Implementing CS_CE Loss in PaddlePaddle and integrating it into PaddleClas
2022-08-11 04:44:00 【shier_smile】
1. Environment
1.1 paddlepaddle version
paddlepaddle-gpu 2.2.2
1.2 PaddleClas version
* release/2.4
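2. CS_CE Loss
CS_CE Loss (cost-sensitive cross-entropy loss) adds a per-class weight, derived from the number of samples in each class, on top of the CE loss. This strengthens the model's ability to fit classes that have few samples.
Formula: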
$$Loss(z, c) = - \left(\frac{N_{min}}{N_c}\right)^{\gamma} \cdot CrossEntropy(z, c)$$
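where:
- $\gamma$ is a hyperparameter that controls the strength of the weighting
- $N_{min}$ is the number of samples in the class with the fewest samples
- $N_c$ is the number of samples in class $c$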
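As a worked example of the weight term, the sketch below evaluates $(N_{min}/N_c)^{\gamma}$ for the class counts used in the config in section 4.2. The helper name `class_weights` is made up for this illustration and is not part of PaddleClas.

```python
def class_weights(num_class_list, gamma):
    """Return (N_min / N_c) ** gamma for every class c."""
    n_min = min(num_class_list)
    return [(n_min / n) ** gamma for n in num_class_list]


# Class counts taken from the config example in this post.
counts = [4400, 1520, 560, 1680]

weights_g1 = class_weights(counts, gamma=1)
weights_g0 = class_weights(counts, gamma=0)

print([round(w, 4) for w in weights_g1])  # [0.1273, 0.3684, 1.0, 0.3333]
print(weights_g0)                          # [1.0, 1.0, 1.0, 1.0]
```

The smallest class always gets weight 1 and larger classes get proportionally smaller weights. With gamma set to 0 every weight is 1, so the loss falls back to ordinary cross entropy.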
3. Paddle code
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class CostSensitiveCE(nn.Layer):
    r"""
    Equation: Loss(z, c) = - (\frac{N_min}{N_c})^\gamma * CrossEntropy(z, c),
    where gamma is a hyper-parameter to control the weights,
    N_min is the number of images in the smallest class, and N_c is the number of images in the class c.

    A representative re-weighting method that assigns class-dependent weights to the loss function.

    Args:
        num_class_list (list of int): number of samples in each class
        gamma (float or double): to control the loss weights: (N_min/N_i)^gamma
    """

    def __init__(self, num_class_list, gamma):
        super(CostSensitiveCE, self).__init__()
        self.num_class_list = num_class_list
        # Per-class weight (N_min / N_c) ** gamma; the smallest class gets weight 1.
        self.csce_weight = paddle.to_tensor(
            np.array([(min(self.num_class_list) / N)**gamma for N in self.num_class_list],
                     dtype=np.float32))

    def forward(self, x, label):
        # The model output may be a dict; in that case take the logits from it.
        if isinstance(x, dict):
            x = x["logits"]
        # A label whose last dimension matches the logits is treated as a soft label.
        if label.shape[-1] == x.shape[-1]:
            label = F.softmax(label, axis=-1)
            soft_label = True
        else:
            soft_label = False
        # Cross entropy with the class-dependent weights computed in __init__.
        cs_ce_loss = F.cross_entropy(
            x, label=label, soft_label=soft_label, weight=self.csce_weight)
        cs_ce_loss = cs_ce_loss.mean()
        return {"CS_CELoss": cs_ce_loss}
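To make explicit what the `weight` argument contributes, here is a plain-Python sketch of class-weighted cross entropy for hard labels. It assumes that, with a weight and the default `reduction='mean'`, the weighted per-sample losses are summed and divided by the sum of the weights of the target classes (a weighted mean). That normalization is my reading of `paddle.nn.functional.cross_entropy`, so check it against the documentation for your Paddle version. The function name `weighted_ce` and the toy logits are illustrative only.

```python
import math


def weighted_ce(logits, labels, weights):
    """Class-weighted cross entropy over a batch of hard labels.

    logits:  list of per-sample score lists, shape [batch, num_classes]
    labels:  list of integer class ids, length batch
    weights: list of per-class weights, length num_classes
    """
    total_loss, total_weight = 0.0, 0.0
    for z, c in zip(logits, labels):
        # Numerically stable log-sum-exp of the scores.
        z_max = max(z)
        log_sum_exp = z_max + math.log(sum(math.exp(v - z_max) for v in z))
        ce = log_sum_exp - z[c]          # standard cross entropy, >= 0
        total_loss += weights[c] * ce    # scale by the weight of class c
        total_weight += weights[c]
    # Assumed normalization: weighted mean over the batch.
    return total_loss / total_weight


counts = [4400, 1520, 560, 1680]
weights = [(min(counts) / n) ** 1 for n in counts]

# Two toy samples with identical scores favouring class 0:
# the first belongs to class 0 (correct), the second to the rare class 2 (wrong).
logits = [[2.0, 0.0, 0.0, 0.0], [2.0, 0.0, 0.0, 0.0]]
labels = [0, 2]

plain = weighted_ce(logits, labels, [1.0] * 4)
cost_sensitive = weighted_ce(logits, labels, weights)
print(plain, cost_sensitive)
```

In this toy batch the cost-sensitive value is larger than the plain one, because the error on the rare class carries weight 1 while the easy sample from the largest class is down-weighted.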
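4. Integrating into PaddleClas
4.1 Loss code
In the PaddleClas/ppcls/loss directory, add a file named cs_celoss.py containing the code from section 3, then add the following line to PaddleClas/ppcls/loss/__init__.py: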
from .cs_celoss import CostSensitiveCE
4.2 Config file
Loss:
  Train:
    - CostSensitiveCE:
        gamma: 1
        num_class_list: [4400, 1520, 560, 1680] # number of samples in each class
        weight: 1 # defaults to 1 when there is a single classifier
  Eval:
    - CostSensitiveCE:
        gamma: 1
        num_class_list: [4400, 1520, 560, 1680] # number of samples in each class
        weight: 1
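The values in `num_class_list` have to be counted from the training set. A minimal sketch of that step follows, assuming an annotation list in which each line holds an image path followed by an integer label separated by whitespace. The function name `count_per_class` and the sample lines are hypothetical.

```python
from collections import Counter


def count_per_class(lines, num_classes):
    """Count samples per class from 'image_path label' lines."""
    counter = Counter()
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # The label is assumed to be the last whitespace-separated field.
        label = int(line.rsplit(maxsplit=1)[-1])
        counter[label] += 1
    # Classes that never appear get a count of 0.
    return [counter[c] for c in range(num_classes)]


# Hypothetical annotation lines; in practice read them from the training list file.
example_lines = [
    "images/a.jpg 0",
    "images/b.jpg 2",
    "images/c.jpg 0",
    "images/d.jpg 1",
]
print(count_per_class(example_lines, num_classes=3))  # [2, 1, 1]
```

Every class should have at least one sample: a count of 0 would make N_min zero and cause a division by zero when the weight for that class is computed.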
5. References
- Paper: Bag of Tricks for Long-Tailed Visual Recognition with Deep Convolutional Neural Networks
- Official code: https://github.com/zhangyongshun/BagofTricks-LT
This post was written on August 6, 2022. Reproduction without the author's permission is prohibited.