当前位置:网站首页>pytorch模型加载跑测试集和训练过程中跑测试集结果不一致的问题?
pytorch模型加载跑测试集和训练过程中跑测试集结果不一致的问题?
2022-04-22 09:05:00 【心之所向521】
前馈网络使用with torch.no_grad()和model.eval()比较
问题描述
将训练好的模型拿来做inference,发现显存被占满,无法进行后续操作,但按理说不应该出现这种情况。
RuntimeError: CUDA out of memory.
Tried to allocate 128.00 MiB (GPU 0; 7.93 GiB total capacity; 6.94 GiB already allocated; 10.56 MiB free; 7.28 GiB reserved in total by PyTorch)
解决方案 经过排查代码,发现做inference时,各模型虽然已经设置为eval()模式,但是并没有取消网络生成计算图这一操作,这就导致网络在单纯做前向传播时也生成了计算图,从而消耗了大量显存。
所以,将模型前向传播的代码放到with torch.no_grad()下,就能使pytorch不生成计算图,从而节省不少显存
with torch.no_grad():
# 代码块
outputs = model(inputs)
# 代码块
经过修改,再进行inference就没有遇到显存不够的情况了。此时显存占用显著降低,只占用5600MB左右(3卡)。
model.eval()和torch.no_grad()比较:
model.eval()
-
使用model.eval()切换到测试模式,不会更新模型的k,b参数
-
通知dropout层和batchnorm层在train和val中间进行切换在。train模式,dropout层会按照设定的参数p设置保留激活单元的概率(保留概率=p,比如keep_prob=0.8),batchnorm层会继续计算数据的mean和var并进行更新。在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值!
-
model.eval()不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播!(backprobagation),即只设置了model.eval()pytorch依旧会生成计算图,占用显存,只是不使用计算图来进行反向传播。
torch.no_grad()
首先从requires_grad讲起:
requires_grad
在pytorch中,tensor有一个requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导,并且保存在计算图中。tensor的requires_grad的属性默认为False,若一个节点(叶子变量:自己创建的tensor)requires_grad被设置为True,那么所有依赖它的节点requires_grad都为True(即使其他相依赖的tensor的requires_grad = False)
当requires_grad设置为False时,反向传播时就不会自动求导了,也就不会生成计算图,而GPU也不用再保存计算图,因此大大节约了显存或者说内存。
with torch.no_grad
在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。
即使一个tensor(命名为x)的requires_grad = True,在with torch.no_grad计算,由x得到的新tensor(命名为w-标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。例子如下所示:
x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
with torch.no_grad():
w = x + y + z
print(w.requires_grad)
print(w.grad_fn)
print(w.requires_grad)
False
None
False
也就是说,在with torch.no_grad结构中的所有tensor的requires_grad属性会被强行设置为false,如果前向传播过程在该结构中,那么inference过程中都不会产生计算图,从而节省不少显存。
reference:with torch.no_grad():显著减少测试时显存占用_落歌439的博客-CSDN博客
但是?又出问题pytorch模型加载跑测试集和训练过程中跑测试集结果不一致的问题?
虽然利用model.val()可以使得结果和训练时结果相似,但是误差相比训练时的测试还是有一定影响!(视情况而定),如何让他们的结果彻底相同呢?
解决方案:
保存训练完成的神经网络模型,来尝试跑了下测试集的结果,发现效果很差,和训练网络时跑测试集的结果不一样。查了些资料,发现是先eval()再测试数据的问题:
错误写法:
.....
model = torch.load('model.pkl')
model.eval() #先eval
x = model.forward(a) #然后传递数据进行测试
.....
改进:
.....
model = torch.load('model.pkl')
x = model.forward(a)
model.eval()
.....
上面的方法其实有问题,当初测单个数据的时候正确了,但是用for循环测大量数据的时候会出问题,例如:
错误写法:
model = torch.load('model.pkl')
for i in range(1,100)
a = load_data.. #导入数据
x = model.forward(a)
model.eval()
正确写法是不需要model.eval():
model = torch.load('model.pkl')
for i in range(1,100)
a = load_data.. #导入数据
x = model.forward(a)
这个小问题真的致命,让我一度怀疑我的数据有问题,但结果证明数据是没问题的!
版权声明
本文为[心之所向521]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_45564943/article/details/124334657
边栏推荐
- 如何在mysql数据表中添加邮箱email
- The bare metal machine developed by single chip microcomputer can also "multitask"?
- 玩转Kubernetes—基础概念篇
- About the fact that I was cheated of fifteen thousand when I wanted to borrow money
- How to add mailbox email in MySQL data table
- Stream API 优化代码
- Baby naming artifact applet source code_ Support multiple traffic master modes
- Construire manuellement le tissu hyperledger V2. X réseau de production (IV) Création de canaux, cycle de vie des codes de chaîne
- 智能手表的突破和新发展机遇
- Advanced view of MySQL
猜你喜欢

About the fact that I was cheated of fifteen thousand when I wanted to borrow money

精彩回顾|「源」来如此 第六期 - 开源经济与产业投资

宝宝起名神器小程序源码_支持多种流量主模式

Halo 开源项目学习(一):项目启动

ERP 集成对公司系统完善的重要性

VMware 虚拟机安装 OpenWrt 作旁路由 单臂路由 img 镜像转 vmdk

(CVPR-2014)通过预测 10,000 个类别的深度学习人脸表示

2022 R1 quick opening pressure vessel operation exercises and online simulation examination

oracle18c rac安装grid执行脚本root.sh报错,PRCR-1013 : 无法启动资源 ora.ons

分布式场景业务操作日志实现(基于redis轻量)
随机推荐
LeetCode 349. 两个数组的交集(简单、数组)day12
2022 R1 quick opening pressure vessel operation exercises and online simulation examination
【ValueError: math domain error】
onenet云平台数据推送到数据库
Axure 如何在页面加载时,设置文本框的内容为当前日期
RHCSA第二天作业
知识点的5W
【微信小程序】为小程序设置底部菜单(tabBar)
Section I: the first step of portrait refinement - reasonable transfer
新动态视频壁纸微信小程序源码_支持多种分类短视频-也有静态壁纸
Cmake uses the basic grammar of basic knowledge I
[path of system analyst] real topic of case analysis of system analyst in 2020
找工作、写简历到面试,这套资料就够了!
栈和队列的初步认识
基于麒麟SP10服务器版的Kubernetes集群安装
精彩回顾|「源」来如此 第六期 - 开源经济与产业投资
学习RHCSA的第三天
相交链表(Set、双指针)
智能手表的突破和新发展机遇
1215_SCons使用之hello world