当前位置:网站首页>Poly1CrossEntropyLoss的pytorch实现
Poly1CrossEntropyLoss的pytorch实现
2022-08-09 04:13:00 【Aldersw】
代码改自github
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Poly1CrossEntropyLoss(nn.Module):
def __init__(self,
num_classes: int,
epsilon: float = 1.0,
reduction: str = "none"):
""" Create instance of Poly1CrossEntropyLoss :param num_classes: :param epsilon: :param reduction: one of none|sum|mean, apply reduction to final loss tensor """
super(Poly1CrossEntropyLoss, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.reduction = reduction
return
def forward(self, output, labels):
""" Forward pass :param output: tensor of shape [N, num_classes] :param labels: tensor of shape [N] or [N, num_classes] :return: poly cross-entropy loss """
#timm如果有使用cutmix或者mixup,那么就会得到的labels是[batchsize,classnums]维度的矩阵,因为获取的label不是hard label,但也不算是smooth label,是每个数据对应的label类似于这种[0,0,0.65,0,0,0.35,0,...]
if labels.ndim == 1:
labels_onehot = F.one_hot(labels, num_classes=self.num_classes).to(device=output.device,dtype=output.dtype)
else:
labels=labels.to(device=output.device)
pt = torch.sum(labels * F.softmax(output, dim=-1), dim=-1)#64,1
CE = F.cross_entropy(input=output, target=labels, reduction='none')
# print(CE.shape)#64
poly1 = CE + self.epsilon * (1 - pt)
if self.reduction == "mean":
poly1 = poly1.mean()
elif self.reduction == "sum":
poly1 = poly1.sum()
#加上这个平均操作是为了获取每个数据平均的loss,如果不加就是长度为batchsize的一维矩阵,分别存放每个数据对应的loss
#poly1 = poly1.mean()
return poly1
边栏推荐
- npm package.json
- A few words about Microsoft's 2022/2023 autumn recruits
- leetcode 5705. 判断国际象棋棋盘中一个格子的颜色
- 【 21 based texture (2, bump mapping theory) 】
- 【21 基础纹理(二、凹凸映射的理论)】
- 【数学】点积与叉积
- NanoDet代码逐行精读与修改(五.2)计算Loss
- 【二叉树】重建二叉树
- “error“: { “root_cause“: [{ “type“: “circuit_breaking_exception“, “reason“: “[parent] D [solved]
- 欧拉22.02系统 mysql5.7 arm版本的安装包, 哪里能下载到?
猜你喜欢

【周赛复盘】力扣第 305 场单周赛

【精品向】你真的会写测试用例么?全网超大型测试用例攻略

电脑重装系统如何在 Win11查看显卡型号信息

EventLoop同步异步,宏任务微任务笔记

了解CV和RoboMaster视觉组(五)滤波器、观测器和预测方法:粒子滤波器Particle Filter

了解CV和RoboMaster视觉组(五)滤波器、观测器和预测方法:维纳滤波器Wiener Filter,LMS

One Pass 1258 - Digital Pyramid (Dynamic Programming)

If A, B, C, and D process parts, the total number of processed parts is 370. If the number of parts processed by A is 10 more, if the number of parts processed by B is 20 less, if the number of parts

static成员及代码块

了解CV和RoboMaster视觉组(五)滤波器、观测器和预测方法:卡尔曼滤波器
随机推荐
NanoDet代码逐行精读与修改(零)Architecture
etcd学习笔记 - 入门
自动化测试的生命周期是什么?
了解CV和RoboMaster视觉组(五)目标跟踪:基于深度学习的方法
软件质效领航者 | 优秀案例•东风集团DevOps改革项目
链接脚本-变量使用中遇到一个问题
STM32串口通信不停接受到垃圾数据的问题及其解决
gopacket源码分析
etcd Study Notes - Getting Started
了解CV和RoboMaster视觉组(五)统计特征和global-based方法
项目管理-挣值分析方法学习总结
07.1 类的的补充
NanoDet代码逐行精读与修改(三)辅助训练模块AGM
【数学建模绘图系列教程】绘图模板总结
32 基本统计知识——假设检验
5.索引优化实战
关于微软2022/2023秋招内推的几句
29 机器学习中常常提到的正则化到底是什么意思
5824. 子字符串突变后可能得到的最大整数
driftingblues靶机wp