Focal Loss的Pytorch实现
1 公式推导
1.1 交叉熵(cross entropy)
信息论中认为事件$X$中概率小的可能性$x_i$如果发生了将会包含了更多的信息量。假设$x_i$就是指的某个可能性,而$P(X=x_i)=p(x_i)$是该可能性发生的概率,所以对信息量的定义就是
$$
I(x_i) = \log{\frac{1}{p(x_i)}}
$$而熵的概念就是对于事件的所有可能性的期望,$N$指的是样本数量
$$
H(X) = -\sum_{i=1}^{N}{p(x_i)\log{p(x_i)}}
$$交叉熵为我们提供了一种表达两种概率分布的差异的方法。
$X$和$Y$的分布越不相同, $X$相对于$Y$的交叉熵将越大于$Y$的熵
$$
H_{Y}(X) = -\sum_{i=1}^{N}{p(y_i)\log{p(x_i)}}
$$多分类的交叉熵损失函数
假设$N$个样本,$K$个分类,$I(y_i=k)$记作$y_{i,k}$,这一般是gt,要么为1,要么为0
$$
l(X)=-\frac{1}{N}\sum_{i=1}^{N}{\sum_{k=1}^{K}{y_{i,k}\log{x_{i,k}}}}
$$
1.2 平衡交叉熵函数(balanced cross entropy)
$$
l(X)=-\frac{1}{N}\sum_{i=1}^{N}{\sum_{k=1}^{K}{\alpha_{k}y_{i,k}\log{x_{i,k}}}}
$$
$\alpha_{k}$为样本分布比例
1.3 Focal Loss
如果数据集中的分类样本不均匀,会导致损失函数中多数类别的权重会提高,少数样本的参数学习会很困难。
$$
l(X)=-\frac{1}{N}\sum_{i=1}^{N}{\sum_{k=1}^{K}{\alpha_{k}(1-x_{i,k})^{\gamma}y_{i,k}\log{x_{i,k}}}}
$$
focal loss相比balanced cross entropy而言,二者都是试图解决样本不平衡带来的模型训练问题,后者从样本分布角度对损失函数添加权重因子,前者从样本分类难易程度出发,使loss聚焦于难分样本。
2 代码实现
1 | # pytorch实现 |