您的位置:

focal loss的多个方面详解

一、focal_loss代码

def binary_focal_loss(gamma=2.0, alpha=0.25):
    def binary_focal_loss_fixed(y_true, y_pred):
        """
        y_true shape need same as y_pred shape
        """
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
        modulating_factor = K.pow(1.0 - p_t, gamma)
        return -K.sum(alpha_factor * modulating_factor * K.log(p_t), axis=-1)
    return binary_focal_loss_fixed

focal loss代码通过使用keras库来创建一个二分类的focal loss函数。在pseudo-Huber损失函数的基础上,利用指数函数来加强焦点。对于y_true=1和y_true=0,alpha参数会对真正和假正误差进行赋值。同样的,gamma参数会调整损失函数的几何形状。

二、focal loss实际并不好用

focal loss实际上并不如预期那样好用。一些研究人员在实验中发现,虽然focal loss在框架用例数据集上的结果要优于标准的交叉熵损失,但在其他数据集上可能会产生比标准交叉熵损失更差的结果。这个原因主要是focal loss只关注于未被正确分类的样本,忽略掉了其他已被正确分类的样本,因此会产生过拟合的问题。

三、focalloss缺点

focal loss最大的缺点之一就是需要经过不断的实验才能确定最优的gamma和alpha参数,而这对于很多工程师或者是研究人员来说是一件非常耗时的过程。此外,focal loss在样本不平衡和分布移位(distribution shift)的时候也会出现问题,这是因为gamma和alpha参数不稳定,它们往往取决于数据的分布情况。

四、focalloss改进

为了解决focal loss所面临的问题,学者们提出了一些改进。比如在目标检测中,RetinaNet提出的focal loss可以通过多层监督来加强难分类样本的训练,而自适应分类的半监督focal loss则可以根据每个类别数据的分布自适应性地进行alpha和gamma参数的调整。

五、focal loss函数

根据最初的论文,focal loss函数可以表示如下:

            FL(p_t)=-α(1−p_t)^γ * log(p_t)

其中,p_t是正确分类的概率,α是正向权重,γ是焦距参数。可以看到,当γ=0并且α=0.25时,公式会退化成标准的二元交叉熵损失函数。

六、focalloss损失函数

focal loss应用在分类任务中时,可以通过将其作为损失函数来优化模型。下面是一个图像分类的focal loss示例:

from keras import backend as K

def categorical_focal_loss(gamma=2.0, alpha=0.25):
    def focal_loss_fixed(y_true, y_pred):
        """
        Multi-class Focal loss for imbalanced data
        """
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        y_true = K.one_hot(tf.cast(y_true, tf.int32), y_pred.shape[1])
        pt = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
        loss = -K.sum(alpha_t * K.pow(1.0 - pt, gamma) * K.log(pt),axis=-1)
        return loss
    
    return focal_loss_fixed
    
model = Sequential()
model.add(Dense(num_classes, activation='softmax', input_shape=input_shape))
model.compile(optimizer='adam', 
              loss=categorical_focal_loss(gamma=2., alpha=.25),
              metrics=['accuracy'])

七、focal选取

最后,被称为有效的优化方案之一的样本调整技术可以用来解决focal loss在样本不均衡的情况下产生过拟合的问题。样本调整可以通过减轻训练数据中的类别不平衡性来消除过拟合。