A PyTorch implementation of Poly1CrossEntropyLoss
2022-08-09 04:13:00 【Aldersw】
The code is adapted from GitHub.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Poly1CrossEntropyLoss(nn.Module):
    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 reduction: str = "none"):
        """
        Create instance of Poly1CrossEntropyLoss
        :param num_classes: number of classes
        :param epsilon: weight of the (1 - pt) term added to the cross-entropy
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor
        """
        super(Poly1CrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, output, labels):
        """
        Forward pass
        :param output: tensor of shape [N, num_classes]
        :param labels: tensor of shape [N] or [N, num_classes]
        :return: poly cross-entropy loss
        """
        # When timm's cutmix or mixup is enabled, labels arrive as a
        # [batch_size, num_classes] matrix. They are neither hard labels nor
        # smoothed labels: each row is a mixed target such as
        # [0, 0, 0.65, 0, 0, 0.35, 0, ...].
        labels = labels.to(device=output.device)
        if labels.ndim == 1:
            # Hard labels: expand to one-hot so pt can be computed per class.
            labels_onehot = F.one_hot(labels, num_classes=self.num_classes).to(dtype=output.dtype)
        else:
            # Soft labels already have shape [N, num_classes].
            labels_onehot = labels
        # pt: probability mass assigned to the target, shape [N]
        pt = torch.sum(labels_onehot * F.softmax(output, dim=-1), dim=-1)
        # Per-sample cross-entropy, shape [N]. Passing class-probability
        # targets (the soft-label case) requires PyTorch 1.10 or later.
        CE = F.cross_entropy(input=output, target=labels, reduction='none')
        poly1 = CE + self.epsilon * (1 - pt)
        # With reduction="none" the result is a 1-D tensor of length batch_size
        # holding each sample's loss; use reduction="mean" to get the average
        # loss per sample instead.
        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1
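The per-sample quantity computed in forward, CE + epsilon * (1 - pt), can be checked by hand without PyTorch. Below is a minimal sketch in plain Python for a single sample. The helper names softmax and poly1_single are illustrative only and are not part of the original code. It treats a hard label as a one-hot vector, so the same function covers the mixup/cutmix soft-label case.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def poly1_single(logits, target, epsilon=1.0):
    """Poly-1 loss for one sample (hypothetical helper for illustration).

    `target` is a probability vector: one-hot for a hard label, or a
    mixed vector such as [0.65, 0.35, 0.0] after mixup/cutmix.
    """
    probs = softmax(logits)
    # Cross-entropy against the (possibly soft) target distribution.
    ce = -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)
    # pt: probability mass the model assigns to the target.
    pt = sum(t * p for t, p in zip(target, probs))
    return ce + epsilon * (1.0 - pt)


logits = [2.0, 0.5, -1.0]
hard = poly1_single(logits, [1.0, 0.0, 0.0])           # hard label, class 0
soft = poly1_single(logits, [0.65, 0.35, 0.0])         # mixed label
plain_ce = poly1_single(logits, [1.0, 0.0, 0.0], 0.0)  # epsilon=0 gives plain CE
print(hard, soft, plain_ce)
```

With epsilon set to 0 the extra term vanishes and the value reduces to ordinary cross-entropy, which makes this a convenient reference when comparing against the module's output for one row of a batch.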