当前位置:网站首页>深度学习模型训练前的必做工作:总览模型信息
深度学习模型训练前的必做工作:总览模型信息
2022-08-10 05:29:00 【公众号学一点会一点】

在使用深度学习模型处理图像数据的时候,输入数据的大小在整个网络中是怎么变化的非常重要,但是如果只看代码的话,我们算起来比较麻烦,比如我们经过了各种上采样、下采样等,中间过程可能有几十个网络层,算过来算过去一来是麻烦,二来是不清晰明了。
今天介绍一个用来查看模型概览信息的包,对新手学习非常有帮助! 这就是torchsummary包(https://github.com/sksq96/pytorch-summary)。
直接上用法。
安装
安装没什么难的,直接用pip即可。
pip install torchsummary
搭建模型
根据自己的需求进行模型的搭建。
模型总体信息概览
比如我们搭建了下面一个网络:
class MyNet(nn.Module):
def __init__(self, inchannels):
super(SRCNN, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(inchannels, 64, kernel_size=9, stride=(1, 1), padding=(4, 4)),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=1, stride=(1, 1), padding=(0, 0)),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 1, kernel_size=5, stride=(1, 1), padding=(2, 2))
)
def forward(self, x):
y = self.main(x)
return y
现在我们想看一下输入数据在网络中的大小变化,以及可学习参数的总数量等信息,那么只需要使用summary函数即可:
model = MyNet(inchannels=1).to('cuda')
summary(model, input_size=(1, 222, 222))
然后便可以得到如下信息:

从中可以看到整个模型的如下信息:
所有的layer 每个layer的输出尺寸(-1代表是可变的,这里是batchsize的大小) 每一层的参数量 总的参数量 总得可学习参数量和无需学习的参数量 占用空间的大小
有什么用?
目前,个人觉得有很多:
最直接的就是可以清楚地了解自己的网络结构; 在用卷积神经网络的时候可以用这个来调整kernel size和padding的尺寸,这样子就不用用公式算了; 可以来测试自己的模型是不是能跑通。。如果错了的话,上面的结果是出不来的。
参考
【1】https://openbase.com/python/torch-summary
【2】https://clay-atlas.com/us/blog/2020/05/13/pytorch-en-note-torchsummary/
本文由 mdnice 多平台发布
边栏推荐
- MongoDB 基础了解(一)
- How does flinksql write that the value of redis has only the last field?
- Nexus_Warehouse Type
- pytorch框架学习(5)torchvision模块&训练一个简单的自己的CNN (二)
- Become a language that hackers have to learn. Do you think it's okay after reading it?
- Pony语言学习(九)——泛型与模式匹配(终章)
- 基于Qiskit——《量子计算编程实战》读书笔记(三)
- Order table delete, insert and search operations
- AVL tree insertion--rotation notes
- FPGA engineer interview questions collection 11~20
猜你喜欢

AVL树的插入--旋转笔记

Advanced Feature Selection Techniques in Linear Models - Based on R

MySql之json_extract函数处理json字段

ThreadPoolExecutor thread pool principle

CORS跨域资源共享漏洞的原理与挖掘方法

Kubernetes:(十六)Ingress的概念和原理

strongest brain (1)

MongoDB 基础了解(一)

基于Servlet的验证码登陆demo

Conda creates a virtual environment method and pqi uses a domestic mirror source to install a third-party library method tutorial
随机推荐
Shield Alt hotkey in vscode
OAuth2 usage scenarios, common misunderstandings, use cases
Practical skills 19: Several postures of List to Map List
利用PyQt5制作YOLOv5的GUI界面
自适应空间特征融合( adaptively spatial feature fusion)一种基于数据驱动的金字塔特征融合策略
You can‘t specify target table ‘kms_report_reportinfo‘ for update in FROM clause
25张炫酷交互图表,一文入门Plotly
An article will help you understand what is idempotency?How to solve the idempotency problem?
pytorch框架学习(3)torch.nn.functional模块和nn.Module模块
树莓派入门(3)树莓派GPIO学习
基于BP神经网络的多因素房屋价格预测matlab仿真
pytorch框架学习(2)使用GPU训练
Order table delete, insert and search operations
Kubernetes:(十六)Ingress的概念和原理
通过一个案例轻松入门OAuth协议
Matlab simulation of multi-factor house price prediction based on BP neural network
暑期学前作业
pytorch框架学习(5)torchvision模块&训练一个简单的自己的CNN (二)
How to improve product quality from the code layer
虚拟土地价格暴跌85% 房地产泡沫破裂?依托炒作的暴富游戏需谨慎参与