当前位置:网站首页>paddlepaddle实现CS_CE Loss且并入PaddleClas
paddlepaddle实现CS_CE Loss且并入PaddleClas
2022-08-11 04:44:00 【shier_smile】
文章目录
1、 环境
1.1 paddlepaddle版本
paddlepaddle-gpu 2.2.2
1.2 PaddleClas版本
* release/2.4
2、 CS_CE Loss
CS_CE Loss在CEloss上加入了对于类别数量的权重系数,增加模型对于样本数量少类别的回归能力
计算方式:
L o s s ( z , c ) = − ( N m i n N c ) γ ∗ C r o s s E n t r o p y ( z , c ) Loss(z, c) = - (\frac{N_{min}}{N_c})^\gamma * CrossEntropy(z, c) Loss(z,c)=−(NcNmin)γ∗CrossEntropy(z,c)
其中:
- γ \gamma γ为控制权重的超参数
- N m i n N_{min} Nmin为存在样本最少的类别的样本数量
- N c N_c Nc为c类别样本数量
3、paddle代码
import paddle.nn as nn
import paddle
import numpy as np
import paddle.nn.functional as F
class CostSensitiveCE(nn.Layer):
r""" Equation: Loss(z, c) = - (\frac{N_min}{N_c})^\gamma * CrossEntropy(z, c), where gamma is a hyper-parameter to control the weights, N_min is the number of images in the smallest class, and N_c is the number of images in the class c. The representative re-weighting methods, which assigns class-dependent weights to the loss function Args: gamma (float or double): to control the loss weights: (N_min/N_i)^gamma """
def __init__(self, num_class_list, gamma):
super(CostSensitiveCE, self).__init__()
self.num_class_list = num_class_list
self.csce_weight = paddle.to_tensor(np.array([(min(self.num_class_list) / N)**gamma for N in self.num_class_list], dtype=np.float32))
def forward(self, x, label):
if isinstance(x, dict):
x = x["logits"]
if label.shape[-1] == x.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
cs_ce_loss = F.cross_entropy(x, label=label, soft_label=soft_label, weight=self.csce_weight)
cs_ce_loss = cs_ce_loss.mean()
return {
"CS_CELoss":cs_ce_loss}
4、并入PaddleClas
4.1 loss代码
在PaddeClas/ppcls/loss目录下添加文件cs_celoss.py文件,并在PaddeClas/ppcls/loss/__init__.py中添加
from .cs_celoss import CostSensitiveCE
4.2 config 文件
Loss:
Train:
- CostSensitiveCE:
gamma: 1
num_class_list: [4400, 1520, 560, 1680] # 每个类别样本数量
weight: 1 # 一个Classfier默认为1
Eval:
- CostSensitiveCE:
gamma: 1
num_class_list: [4400, 1520, 560, 1680] # 每个类别样本数量
weight: 1
5、参考文献
- 论文:Bag of Tricks for Long-Tailed Visual Recognition with Deep Convolutional Neural Networks
- 官方代码:https://github.com/zhangyongshun/BagofTricks-LT
本帖写于:2022年8月6号, 未经本人允许,禁止转载。
边栏推荐
- 0 Basic software test for career change, self-study for 3 months, 12k*13 salary offer
- .NET Custom Middleware
- Jetson Orin platform 4-16 channel GMSL2/GSML1 camera acquisition kit recommended
- Research on a Consensus Mechanism-Based Anti-Runaway Scheme for Digital Trunking Terminals
- 增加PRODUCT_BOOT_JARS及类 提供jar包给应用
- 利用Navicat Premium导出数据库表结构信息至Excel
- 力扣——旋转数组的最小数字
- 【服务器安装Redis】Centos7离线安装redis
- 直播软件搭建,流式布局,支持单选、多选等
- 破解事务性工作瓶颈,君子签电子合同释放HR“源动力”!
猜你喜欢
随机推荐
视觉任务种常用的类别文件之一json文件
[Web3 series development tutorial - create your first NFT (9)] How to view your NFT in the mobile wallet
干货:服务器网卡组技术原理与实践
"3 Longest Substring Without Repeating Characters" on the 17th day of LeetCode brushing
监听U盘插入 拔出 消息,获得U盘盘符
Use jackson to parse json data in detail
Redis:解决分布式高并发修改同一个Key的问题
一起Talk编程语言吧
洛谷P2370 yyy2015c01 的 U 盘
使用百度EasyDL实现施工人员安全装备检测
洛谷P1196 银河英雄传说
"239 Sliding Window Maximum Value" on the 16th day of LeetCode brushing
Mysql: set the primary key to automatically increase the starting value
Three 】 【 yolov7 series of actual combat from 0 to build training data sets
分层架构&SOA架构
Bubble sort and heap sort
(转)JVM中那些区域会发生OOM?
1815. Get the maximum number of groups of fresh donuts state compression
"125 Palindrome Verification" of the 10th day string series of LeetCode brushing questions
如何进行AI业务诊断,快速识别降本提效增长点?









