您的位置:

Dice Loss详解

一、Dice Loss 代码

import torch

def dice_loss(pred, target, smooth=1.):
    num = pred.size(0)
    m1 = pred.view(num, -1)
    m2 = target.view(num, -1)
    intersection = (m1 * m2).sum()
    score = (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
    return 1. - score

Dice Loss(Dice Coefficient Loss)是一种二分类分割的监督学习方法,最早被用于医学图像分割。

二、Dice Loss 计算多分类问题

在处理多分类问题时,我们可以将 Dice Loss 用于每个动态二分类问题:对于每个类别,以该类别的输出值为二分类中的正类,其它类别组成的集合为负类,即把多分类的问题转化为多个二分类问题,分别使用 Dice Loss 计算,最后取平均得到多分类问题的 Dice Loss。

三、Dice Loss 不收敛

在实际应用中,我们可能会发现 Dice Loss 不收敛的情况。一个常见的解决方法是利用交叉熵损失 (Cross Entropy Loss)作为惩罚项进行DICE Loss优化。

四、Dice Loss 多分类

在多分类场景下,我们可以通过将 Dice Loss 与交叉熵 Loss 结合,得到用于多分类问题的 Dice Loss。

五、Dice Loss 不下降

在训练中,我们可能会发现 Dice Loss 不下降,这通常是由于数据不平衡造成的。解决方法是加权,即乘以各自的权重因子来平衡损失,这个方法也常用于解决交叉熵不平衡问题。

六、Dice Loss 出现负数

有时候 Dice Loss 会出现负数,这是因为两个图像之间无法对应,最后得到一个负的 Intersection。解决方法是加上一个平滑项,并保证 Intersection 为正数,比如将 1e-5 置于分母中。

七、Dice Loss 多分类分割

对于多分类分割问题,我们可以使用 Dice Loss 计算每个类别与非该类别的分割情况,在所有类别上取平均得到 Dice Loss。

八、Dice Loss 和 BCE Loss 进行组合

在一些用途中,我们需要同时考虑分类准确性和分割精度,这时可以将 BCE Loss 和 Dice Loss 进行加权组合,如下所示:

import torch.nn.functional as F

def dice_bce_loss(pred, target, smooth=1.):
    bce_loss = F.binary_cross_entropy_with_logits(pred, target)
    pred = torch.sigmoid(pred)
    num = pred.size(0)
    m1 = pred.view(num, -1)
    m2 = target.view(num, -1)
    intersection = (m1 * m2).sum()
    score = (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
    dice_loss = 1. - score
    return bce_loss + dice_loss