当前位置:网站首页>深度学习模型训练前的必做工作:总览模型信息
深度学习模型训练前的必做工作:总览模型信息
2022-08-10 05:29:00 【公众号学一点会一点】

在使用深度学习模型处理图像数据的时候,输入数据的大小在整个网络中是怎么变化的非常重要,但是如果只看代码的话,我们算起来比较麻烦,比如我们经过了各种上采样、下采样等,中间过程可能有几十个网络层,算过来算过去一来是麻烦,二来是不清晰明了。
今天介绍一个用来查看模型概览信息的包,对新手学习非常有帮助! 这就是torchsummary
包(https://github.com/sksq96/pytorch-summary)。
直接上用法。
安装
安装没什么难的,直接用pip即可。
pip install torchsummary
搭建模型
根据自己的需求进行模型的搭建。
模型总体信息概览
比如我们搭建了下面一个网络:
class MyNet(nn.Module):
def __init__(self, inchannels):
super(SRCNN, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(inchannels, 64, kernel_size=9, stride=(1, 1), padding=(4, 4)),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=1, stride=(1, 1), padding=(0, 0)),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 1, kernel_size=5, stride=(1, 1), padding=(2, 2))
)
def forward(self, x):
y = self.main(x)
return y
现在我们想看一下输入数据在网络中的大小变化,以及可学习参数的总数量等信息,那么只需要使用summary函数即可:
model = MyNet(inchannels=1).to('cuda')
summary(model, input_size=(1, 222, 222))
然后便可以得到如下信息:

从中可以看到整个模型的如下信息:
所有的layer 每个layer的输出尺寸(-1代表是可变的,这里是batchsize的大小) 每一层的参数量 总的参数量 总得可学习参数量和无需学习的参数量 占用空间的大小
有什么用?
目前,个人觉得有很多:
最直接的就是可以清楚地了解自己的网络结构; 在用卷积神经网络的时候可以用这个来调整kernel size和padding的尺寸,这样子就不用用公式算了; 可以来测试自己的模型是不是能跑通。。如果错了的话,上面的结果是出不来的。
参考
【1】https://openbase.com/python/torch-summary
【2】https://clay-atlas.com/us/blog/2020/05/13/pytorch-en-note-torchsummary/
本文由 mdnice 多平台发布
边栏推荐
- You can‘t specify target table ‘kms_report_reportinfo‘ for update in FROM clause
- 【格式转换】将JPEG图片批量处理为jpg格式
- Introduction to curl command
- aliases节点分析
- 深度梳理:防止模型过拟合的方法汇总
- Jenkins 如何玩转接口自动化测试?
- ThreadPoolExecutor线程池原理
- SQLSERVER 2008 parses data in Json format
- 自适应空间特征融合( adaptively spatial feature fusion)一种基于数据驱动的金字塔特征融合策略
- Qiskit 学习笔记2
猜你喜欢
Kubernetes:(十六)Ingress的概念和原理
Rpc interface stress test
Flutter development: error The following assertion was thrown resolving an image codec: Solution for Unable to...
从GET切换为POST提交数据的方法
Interface documentation evolution illustration, some ancient interface documentation tools, you may not have used it
strongest brain (1)
An article to master the entire JVM, JVM ultra-detailed analysis!!!
Talk about API Management - Open Source Edition to SaaS Edition
PyTorch 入门之旅
Qiskit学习笔记(三)
随机推荐
How to simulate the background API call scene, very detailed!
flex related
YOLOv5 PyQt5(一起制作YOLOv5的GUI界面)
【yolov5训练错误】WARNING: Ignoring corrupted image
反转链表中的第m至第n个节点---leetcode
Introduction to curl command
实战小技巧19:List转Map List的几种姿势
FPGA工程师面试试题集锦41~50
conda创建虚拟环境方法和pqi使用国内镜像源安装第三方库的方法教程
【静态代理】
基于Qiskit——《量子计算编程实战》读书笔记(三)
FPGA工程师面试试题集锦1~10
AVL树的插入--旋转笔记
Order table delete, insert and search operations
FPGA engineer interview questions collection 21~30
CORS跨域资源共享漏洞的原理与挖掘方法
如何用Apifox 的智能Mock功能?
Advanced Feature Selection Techniques in Linear Models - Based on R
手把手带你写嵌入式物联网的第一个项目
大咖说·对话生态|当Confluent遇见云:实时流动的数据更有价值