当前位置:网站首页>pytorch implementation of Poly1CrossEntropyLoss
pytorch implementation of Poly1CrossEntropyLoss
2022-08-09 04:19:00 【Aldersw】
代码改自github
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Poly1CrossEntropyLoss(nn.Module):
def __init__(self,
num_classes: int,
epsilon: float = 1.0,
reduction: str = "none"):
""" Create instance of Poly1CrossEntropyLoss :param num_classes: :param epsilon: :param reduction: one of none|sum|mean, apply reduction to final loss tensor """
super(Poly1CrossEntropyLoss, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.reduction = reduction
return
def forward(self, output, labels):
""" Forward pass :param output: tensor of shape [N, num_classes] :param labels: tensor of shape [N] or [N, num_classes] :return: poly cross-entropy loss """
#timm如果有使用cutmix或者mixup,then you will getlabels是[batchsize,classnums]维度的矩阵,因为获取的label不是hard label,但也不算是smooth label,corresponding to each datalabel类似于这种[0,0,0.65,0,0,0.35,0,...]
if labels.ndim == 1:
labels_onehot = F.one_hot(labels, num_classes=self.num_classes).to(device=output.device,dtype=output.dtype)
else:
labels=labels.to(device=output.device)
pt = torch.sum(labels * F.softmax(output, dim=-1), dim=-1)#64,1
CE = F.cross_entropy(input=output, target=labels, reduction='none')
# print(CE.shape)#64
poly1 = CE + self.epsilon * (1 - pt)
if self.reduction == "mean":
poly1 = poly1.mean()
elif self.reduction == "sum":
poly1 = poly1.sum()
#This averaging operation is added to obtain the average of each dataloss,If not added, the length is batchsize的一维矩阵,Each data is stored separatelyloss
#poly1 = poly1.mean()
return poly1
边栏推荐
猜你喜欢
随机推荐
etcd学习笔记 - 入门
Polygon zkEVM Prover
单根k线图知识别以为自己都懂了
阿里云天池大赛赛题(机器学习)——阿里云安全恶意程序检测(完整代码)
软件质效领航者 | 优秀案例•国金证券DevOps建设项目
整数倍数数列
容易混淆的指针知识点
笔记本电脑重装系统后开机蓝屏要怎么办
Crosstalk and Protection
维护RAC日志轮转
电脑重装系统如何在 Win11查看显卡型号信息
简单的数学公式计算
了解CV和RoboMaster视觉组(五)滤波器、观测器和预测方法:维纳滤波器Wiener Filter,LMS
MySQL:redo log日志——笔记自用
给一时兴起想要学习 “ 测试 ” 的同学的几条建议.....
[Server data recovery] A case of data recovery when the Ext4 file system cannot be mounted and an error is reported after fsck
两种K线形态预示今日伦敦银走向
新一代CMDB构建方法,是能够给企业带来收益的
BaseDexClassLoader的正确使用方式
阿里云天池大赛赛题(深度学习)——视频增强(完整代码)









