当前位置:网站首页>NLLLoss+log_SoftMax=CE_Loss
NLLLoss+log_SoftMax=CE_Loss
2022-04-23 06:34:00 【365JHWZGo】
题目
NLLLoss+log_SoftMax=CE_Loss
前言
差不多好几天没有更新了,唉,最近有点忙,请见谅!今天终于忙里偷闲看了B站视频发现NLLLoss+log_SoftMax=CE_Loss,哈哈,又长见识了,那么今天来深入了解一下吧!
NLLLoss
中文
负对数似然损失
函数
torch.nn.NLLLoss(
weight=None,
size_average=None,
ignore_index=- 100,
reduce=None,
reduction='mean'
)
参数详解
| 参数 | 类型 | 含义 |
|---|---|---|
| weight | tensor(optional) | 手动给每一个类重新调整权重比例 |
| size_average | bool(optional)已弃用 | 损失是一个batch中每一个损失元素的平均值 |
| ignore_index | int(optional) | 忽略一个目标值后它不会对输入的梯度产生影响 |
| reduce | bool(optional)已弃用 | 根据每个mini-batch的平均尺寸对损失进行均分或汇总 |
| reduction | string(optional) | reuduction将会被应用于输出 |
| reduction参数 | 解释 |
|---|---|
| none | 全部展开 |
| mean | 累加/个数 |
| sum | 把none结果累加 |
函数输入输出
| 输入 | 目标 | 输出 | |
|---|---|---|---|
| 类型 | tensor | tensor | tensor |
| 维度 | ( N , C N,C N,C)/( C C C) | ( N N N)/() | ( N N N)/ ( N , d 1 , d 2 , . . . , d K ) (N, d_1, d_2, ..., d_K) (N,d1,d2,...,dK) |
代码
import torch
# 模仿经过模型之后的输出结果
preds = torch.tensor([[1.5,2.5,3.0]])
# 真实标签
target = torch.tensor([1])
nllloss = torch.nn.NLLLoss()
print('nllloss:',nllloss(preds,target))

与CE_Loss的区别和联系
区别:
- CE_Loss是先经过log_softmax再经过NLLLoss步骤的损失
- NLLLoss仅仅是将对应index的target值变为相反数
联系:
- 存在着一定的相关性,经过softmax之后的值大都在[0,1]之间,在进行log取对数之后值分布在( − ∞ -\infty −∞,0]之间,而loss又不能为负,所以需要NLLLoss来将其变为正数
import torch
# 模仿经过模型之后的输出结果
preds = torch.tensor([[1.5,2.5,3.0]])
# 真实标签
target = torch.tensor([1])
cross_entropy_loss = torch.nn.CrossEntropyLoss()
log_softmax = torch.nn.LogSoftmax(dim=1)
nllloss = torch.nn.NLLLoss()
cs_loss = cross_entropy_loss(preds,target)
nls_loss = nllloss(log_softmax(preds),target)
print(f'交叉熵损失函数为:{
cs_loss}\n先经过log_softmax再经过nll损失函数为:{
nls_loss}')

版权声明
本文为[365JHWZGo]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_44833392/article/details/124342281
边栏推荐
猜你喜欢

Research on software security based on NLP (2)

Robust and Efficient Quadrotor Trajectory Generation for Fast Autonomous Flight

Houdini>流体,刚体导出学习过程笔记

Houdini terrain and fluid solution (simulated debris flow)

使用 Ingress 实现金丝雀发布

Online Safe Trajectory Generation For Quadrotors Using Fast Marching Method and Bernstein Basis Poly

【编程实践/嵌入式比赛】嵌入式比赛学习记录(一):TCP服务器和web界面的建立

Ctf-misc learning from start to give up

Redis--为什么字符串emstr的字符串长度是44字节上限?

国基北盛-openstack-容器云-环境搭建
随机推荐
String self generated code within a given range
Learning records of some shooting ranges: sqli labs, upload labs, XSS
VBA appelle SAP RFC pour réaliser la lecture et l'écriture des données
Ribbon启动流程
A series of articles, a summary of common vulnerabilities of Web penetration (continuously updated)
Buuctf misc brush questions
About USB flash drive data prompt raw, need to format, data recovery notes
Intranet penetration series: icmptunnel of Intranet tunnel (by master dhavalkapil)
第五章 投资性房地产
Houdini地形与流体解算(模拟泥石流)
Redis transaction implements optimistic locking principle
Online Safe Trajectory Generation For Quadrotors Using Fast Marching Method and Bernstein Basis Poly
学fpga(从verilog到hls)
Internal network security attack and defense: a practical guide to penetration testing (5): analysis and defense of horizontal movement in the domain
strcat()、strcpy()、strcmp()、strlen()
BUUCTF [ACTF2020 新生赛]Include1
读书笔记
Automatically fit single line text into the target rectangle
RAID0和RAID5的创建和模拟RAID5工作原理
Reading notes