您的位置:

Focal Loss损失函数详解

一、Focal Loss代码

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.25, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C

        target = target.view(-1, 1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = logpt.data.exp()

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        else:
            return loss.sum()

Focal Loss(FL)是一种针对类不平衡的分类问题的一种有效方法,成功的应用于目标检测任务中,FL损失函数可通过调整$\gamma$和$\alpha$来适应不同类别正负样本的分布。下面将从多个方面详细讲解FL损失函数的基本原理、优点、缺点和改进方向。

二、Focal Loss实际并不好用

FL是一种比较新的损失函数,目前在一些高档的目标检测模型中得到了应用。但是,在实际应用中,FL有时候并不能很好的提高模型的性能。比如,FL存在一些缺点:

1. 当$\gamma$设定不合理时,可能会使训练过程中模型的表现变得更差;

2. $\alpha$需要事先设定,不同的数据集需要不同的设定,这种设定需要基于模型的训练数据的经验,并不能很好的动态调整;

3. 精调每一个参数都比较复杂,特别是$\gamma$和$\alpha$的精细调整;

综上所述,FL在实际应用中,并不是一种非常好的损失函数方法。

三、Focal Loss缺点

FL的主要缺点体现在如下几个方面:

1. 不同数据集下的最佳损失函数参数需要重新调试;

# 想要从各个角度看到loss函数的效果
class FocalLoss(nn.Module):
    def __init__(self, alpha, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        pt = torch.exp(-BCE_loss)
        F_loss = (self.alpha * (1-pt)**self.gamma * BCE_loss).mean()

        if self.reduction == 'none':
            return F_loss
        elif self.reduction == 'sum':
            return F_loss.sum()
        else:
            return F_loss.mean()

2. 确定合理的$\gamma$并不容易;

3. 当类别数量较大时,缺少包容性和良好的可视化效果。

四、Focal Loss改进

1. Focal Loss With Label Smoothing(FL-LS):此方法考虑了标签过度自信的问题,因此在目标检测中提出了FL-LS方法,以降低标签噪声的影响。

class Focal_Loss_with_LS(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=2, ls_epsilon=0.1):
        super(Focal_Loss_with_LS, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ls_epsilon = ls_epsilon
        self.class_num = class_num

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        p = F.sigmoid(inputs)
        p_smooth = (1 - self.ls_epsilon) * p + self.ls_epsilon / self.class_num
        pt = p * targets + (1 - p) * (1 - targets)
        pt_smooth = p_smooth * targets + (1 - p_smooth) * (1 - targets)
        FL_loss = - self.alpha * (1 - pt_smooth) ** self.gamma * torch.log(pt_smooth)
        F_loss = (FL_loss * (1 - pt) ** 2).mean()

        return F_loss

2. Focal Cosine Loss(FCOS):此方法针对第一个问题,采用了余弦相似度来代替 softmax1。FCOS方法如下:

class FocalCosineLoss(nn.Module):
    def __init__(self, gamma=1.0, eps=1e-7):
        super(FocalCosineLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, inputs, targets):
        cosine_loss = F.cosine_embedding_loss(inputs.view(-1), F.one_hot(targets, num_classes=inputs.size(-1)).float().view(-1, inputs.size(-1)), torch.tensor([1.0], device=inputs.device), reduction="none")
        sine_loss = F.sin_embedding_loss(inputs.view(-1), F.one_hot(targets, num_classes=inputs.size(-1)).float().view(-1, inputs.size(-1)), torch.tensor([1.0], device=inputs.device), reduction="none")
        one_hot = torch.zeros_like(inputs)
        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
        pt = (one_hot * inputs).sum(1) + self.eps
        focal_loss = -((1 - pt) ** self.gamma) * cosine_loss
        return focal_loss.mean()

3. Focal Loss with Dynamic Number of Objects(FL-DoN): FL-DoN是针对第三个问题的改进,其模型结构如下:

class FocalLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target, num_objs=None):
        target = target.type(input.type()).unsqueeze(1)
        logpt = F.logsigmoid(input * (target * 2 - 1))
        pt = logpt.exp()

        if num_objs is not None:  # dynamic focal loss
            alpha = num_objs / num_objs.mean()
            alpha = alpha.unsqueeze(-1)
            at = (target * alpha + (1 - target) * (1 - alpha))
            f_loss = -at * ((1 - pt) ** self.gamma) * logpt
        else:
            f_loss = -((1 - pt) ** self.gamma) * logpt

        return f_loss.mean()

五、Focal Loss函数

Focal Loss函数可视化如图所示:

def focal_loss(p, t, alpha = 0.25, gamma = 2.0):
    t = t.view(-1, 1)
    p = p.view(-1, 1)
    alpha_t = alpha*(2.0*t-1.0)
    modulating_factor = (1.0-p).pow(gamma)
    FL = -1.0 * alpha_t * modulating_factor * p.log()-(1.0-alpha_t) * modulating_factor * (1.0-p).log()
    return FL.mean()

六、Focal Loss损失函数选取

在多个数据集上对比使用不同的损失函数,如交叉熵损失函数(CE)、平衡交叉熵(BCE)、focal损失函数(FL)和Direct Cross Entropy (CEDE)损失函数4种常用的损失函数,实验结果显示 FL损失函数具有最佳性能。

七、小结

总的来说,FL是一种解决类不平衡问题的有效方法,但是由于其固有的限制,实际使用中可能会遇到瓶颈。因此,研究改进FL的方法和更有效的替代方法是值得我们继续深入研究的问题。