import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    """Binary focal loss computed on probabilities.

    Implements the two-class form of the focal loss from "Focal Loss for
    Dense Object Detection" (Lin et al.)::

        loss = -alpha * (1 - x)**gamma * log(x + eps) * target
               - (1 - alpha) * x**gamma * log(1 - x + eps) * (1 - target)

    Args:
        gamma: Focusing exponent. Larger values down-weight examples whose
            predicted probability is already close to the target.
        alpha: Weight of the ``target`` (positive) term; the ``1 - target``
            (negative) term is weighted by ``1 - alpha``.
        eps: Small constant added inside each ``log`` so that predictions of
            exactly 0 or 1 do not produce ``log(0)``.
        reduction: ``"mean"`` (default) averages the element-wise losses,
            ``"sum"`` sums them, ``"none"`` returns them unreduced.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, gamma=2, alpha=0.25, eps=1e-8, reduction="mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.alpha = alpha
        self.eps = eps
        self.reduction = reduction

    def forward(self, x, target):
        """Compute the focal loss between predictions and targets.

        Args:
            x: Predicted probabilities, expected to lie in [0, 1]. No sigmoid
                or softmax is applied here, so pass already-activated outputs.
            target: Tensor broadcastable with ``x``; 1 marks positive and 0
                negative entries (fractional values mix the two terms
                linearly).

        Returns:
            The reduced loss (a scalar tensor) for ``"mean"``/``"sum"``, or the
            element-wise loss tensor for ``"none"``.
        """
        # Positive term: down-weighted by (1 - x)**gamma as x approaches 1.
        pos = -self.alpha * (1 - x) ** self.gamma * torch.log(x + self.eps) * target
        # Negative term: down-weighted by x**gamma as x approaches 0.
        neg = (1 - self.alpha) * x ** self.gamma * torch.log(
            1 - x + self.eps) * (1 - target)
        loss = pos - neg
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
# class FocalLoss(nn.Module):
# r"""
# This criterion is an implementation of Focal Loss, which was proposed in
# "Focal Loss for Dense Object Detection".
# Loss(x, class) = - \alpha (1-softmax(x)[class])^\gamma \log(softmax(x)[class])
# The losses are averaged across observations for each minibatch.
# Args:
# alpha(1D Tensor, Variable) : per-class weighting factors for this criterion (indexed by target class)
# gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
# putting more focus on hard, misclassified examples
# size_average(bool): By default, the losses are averaged over observations for each minibatch.
# However, if the field size_average is set to False, the losses are
# instead summed for each minibatch.
# """
#
# def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
# super(FocalLoss, self).__init__()
# if alpha is None:
# self.alpha = Variable(torch.ones(class_num, 1))
# else:
# if isinstance(alpha, Variable):
# self.alpha = alpha
# else:
# self.alpha = Variable(alpha)
# self.gamma = gamma
# self.class_num = class_num
# self.size_average = size_average
#
# def forward(self, inputs, targets):
# N = inputs.size(0)
# C = inputs.size(1)
# P = F.softmax(inputs)
#
# class_mask = inputs.data.new(N, C).fill_(0)
# class_mask = Variable(class_mask)
# ids = targets.view(-1, 1)
# class_mask.scatter_(1, ids.data, 1.)
# # print(class_mask)
#
# if inputs.is_cuda and not self.alpha.is_cuda:
# self.alpha = self.alpha.cuda()
# alpha = self.alpha[ids.data.view(-1)]
#
# probs = (P * class_mask).sum(1).view(-1, 1)
#
# log_p = probs.log()
# # print('probs size= {}'.format(probs.size()))
# # print(probs)
#
# batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
# # print('-----bacth_loss------')
# # print(batch_loss)
#
# if self.size_average:
# loss = batch_loss.mean()
# else:
# loss = batch_loss.sum()
# return loss