Focal Loss的Pytorch实现

1 公式推导

1.1 交叉熵(cross entropy)

  1. 信息论中认为事件$X$中概率小的可能性$x_i$如果发生了将会包含了更多的信息量。假设$x_i$就是指的某个可能性,而$P(X=x_i)=p(x_i)$是该可能性发生的概率,所以对信息量的定义就是
    $$
    I(x_i) = \log{\frac{1}{p(x_i)}}
    $$

  2. 而熵的概念就是对于事件的所有可能性的期望,$N$指的是样本数量
    $$
    H(X) = -\sum_{i=1}^{N}{p(x_i)\log{p(x_i)}}
    $$

  3. 交叉熵为我们提供了一种表达两种概率分布的差异的方法。
    $X$和$Y$的分布越不相同, $X$相对于$Y$的交叉熵将越大于$Y$的熵
    $$
    H_{Y}(X) = -\sum_{i=1}^{N}{p(y_i)\log{p(x_i)}}
    $$

  4. 多分类的交叉熵损失函数
    假设$N$个样本,$K$个分类,$I(y_i=k)$记作$y_{i,k}$,这一般是gt,要么为1,要么为0
    $$
    l(X)=-\frac{1}{N}\sum_{i=1}^{N}{\sum_{k=1}^{K}{y_{i,k}\log{x_{i,k}}}}
    $$

1.2 平衡交叉熵函数(balanced cross entropy)

$$
l(X)=-\frac{1}{N}\sum_{i=1}^{N}{\sum_{k=1}^{K}{\alpha_{k}y_{i,k}\log{x_{i,k}}}}
$$
$\alpha_{k}$为样本分布比例

1.3 Focal Loss

如果数据集中的分类样本不均匀,会导致损失函数中多数类别的权重会提高,少数样本的参数学习会很困难。
$$
l(X)=-\frac{1}{N}\sum_{i=1}^{N}{\sum_{k=1}^{K}{\alpha_{k}(1-x_{i,k})^{\gamma}y_{i,k}\log{x_{i,k}}}}
$$

focal loss相比balanced cross entropy而言,二者都是试图解决样本不平衡带来的模型训练问题,后者从样本分布角度对损失函数添加权重因子,前者从样本分类难易程度出发,使loss聚焦于难分样本。

2 代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# pytorch实现
class FocalLoss(nn.Module):
def __init__(self, gamma = 2, alpha = 1, size_average = True):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.size_average = size_average
self.elipson = 0.000001

def forward(self, outputs, labels):
# 先计算CE Loss
ce_loss = torch.nn.functional.cross_entropy(outputs, labels, reduction='none')
# 消掉log
pt = torch.exp(-ce_loss)
# mean over the batch
focal_loss = (self.alpha * (1-pt)**self.gamma * ce_loss).mean()
return focal_loss

3 参考资料

  1. 损失函数:交叉熵详解
  2. 交叉熵的原理
  3. Focal loss及多分类任务实现
Author

Huan Yang

Posted on

2023-05-21

Updated on

2023-10-16

Licensed under

Your browser is out-of-date!

Update your browser to view this website correctly.&npsb;Update my browser now

×