# Source code for layout_data.models.focalloss

import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    """Binary focal loss evaluated directly on predicted probabilities.

    For every element the loss is::

        -alpha       * (1 - x)**gamma * log(x + eps)     * target
        -(1 - alpha) * x**gamma       * log(1 - x + eps) * (1 - target)

    and the result is averaged over all elements.

    Args:
        gamma: Exponent of the modulating factor; larger values shrink the
            contribution of elements whose prediction already agrees with
            the target.
        alpha: Weight of the positive (``target == 1``) term; the negative
            term is weighted by ``1 - alpha``.
        eps: Small constant added inside both logarithms so that
            ``log(0)`` is never evaluated for inputs of exactly 0 or 1.
    """

    def __init__(self, gamma=2, alpha=0.25, eps=1e-8):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.eps = eps

    def forward(self, x, target):
        """Return the mean focal loss between ``x`` and ``target``.

        Args:
            x: Predictions. NOTE(review): the formula takes ``log(x)`` and
                ``log(1 - x)`` without clamping, so ``x`` is presumably
                expected to hold probabilities in ``[0, 1]`` (e.g. sigmoid
                outputs) -- confirm against callers.
            target: Labels broadcastable against ``x``; 1 selects the
                positive term and 0 the negative term. Boolean and integer
                tensors are accepted and cast to the dtype of ``x``.

        Returns:
            A scalar tensor: the element-wise loss averaged over all
            elements.
        """
        # ``1 - target`` is not defined for bool tensors, so bring any
        # non-floating label tensor to the dtype of the predictions first.
        # Floating-point targets are left untouched.
        if torch.is_tensor(target) and not torch.is_floating_point(target):
            target = target.to(x.dtype)

        # Positive term: down-weighted as x approaches 1.
        positive_term = (
            -self.alpha * (1 - x) ** self.gamma
            * torch.log(x + self.eps) * target
        )
        # Negative term: down-weighted as x approaches 0.
        negative_term = (
            (1 - self.alpha) * x ** self.gamma
            * torch.log(1 - x + self.eps) * (1 - target)
        )
        loss = positive_term - negative_term
        return loss.mean()
# class FocalLoss(nn.Module): # r""" # This criterion is an implementation of Focal Loss, which is proposed in # Focal Loss for Dense Object Detection. # Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class]) # The losses are averaged across observations for each minibatch. # Args: # alpha(1D Tensor, Variable) : the scalar factor for this criterion # gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), # putting more focus on hard, misclassified examples # size_average(bool): By default, the losses are averaged over observations for each minibatch. # However, if the field size_average is set to False, the losses are # instead summed for each minibatch. # """ # # def __init__(self, class_num, alpha=None, gamma=2, size_average=True): # super(FocalLoss, self).__init__() # if alpha is None: # self.alpha = Variable(torch.ones(class_num, 1)) # else: # if isinstance(alpha, Variable): # self.alpha = alpha # else: # self.alpha = Variable(alpha) # self.gamma = gamma # self.class_num = class_num # self.size_average = size_average # # def forward(self, inputs, targets): # N = inputs.size(0) # C = inputs.size(1) # P = F.softmax(inputs) # # class_mask = inputs.data.new(N, C).fill_(0) # class_mask = Variable(class_mask) # ids = targets.view(-1, 1) # class_mask.scatter_(1, ids.data, 1.) # # print(class_mask) # # if inputs.is_cuda and not self.alpha.is_cuda: # self.alpha = self.alpha.cuda() # alpha = self.alpha[ids.data.view(-1)] # # probs = (P * class_mask).sum(1).view(-1, 1) # # log_p = probs.log() # # print('probs size= {}'.format(probs.size())) # # print(probs) # # batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p # # print('-----bacth_loss------') # # print(batch_loss) # # if self.size_average: # loss = batch_loss.mean() # else: # loss = batch_loss.sum() # return loss