当前位置:网站首页>深度学习模型训练前的必做工作:总览模型信息
深度学习模型训练前的必做工作:总览模型信息
2022-08-10 05:29:00 【公众号学一点会一点】

在使用深度学习模型处理图像数据的时候,输入数据的大小在整个网络中是怎么变化的非常重要,但是如果只看代码的话,我们算起来比较麻烦,比如我们经过了各种上采样、下采样等,中间过程可能有几十个网络层,算过来算过去一来是麻烦,二来是不清晰明了。
今天介绍一个用来查看模型概览信息的包,对新手学习非常有帮助! 这就是torchsummary包(https://github.com/sksq96/pytorch-summary)。
直接上用法。
安装
安装没什么难的,直接用pip即可。
pip install torchsummary
搭建模型
根据自己的需求进行模型的搭建。
模型总体信息概览
比如我们搭建了下面一个网络:
class MyNet(nn.Module):
def __init__(self, inchannels):
super(SRCNN, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(inchannels, 64, kernel_size=9, stride=(1, 1), padding=(4, 4)),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=1, stride=(1, 1), padding=(0, 0)),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 1, kernel_size=5, stride=(1, 1), padding=(2, 2))
)
def forward(self, x):
y = self.main(x)
return y
现在我们想看一下输入数据在网络中的大小变化,以及可学习参数的总数量等信息,那么只需要使用summary函数即可:
model = MyNet(inchannels=1).to('cuda')
summary(model, input_size=(1, 222, 222))
然后便可以得到如下信息:

从中可以看到整个模型的如下信息:
所有的layer 每个layer的输出尺寸(-1代表是可变的,这里是batchsize的大小) 每一层的参数量 总的参数量 总得可学习参数量和无需学习的参数量 占用空间的大小
有什么用?
目前,个人觉得有很多:
最直接的就是可以清楚地了解自己的网络结构; 在用卷积神经网络的时候可以用这个来调整kernel size和padding的尺寸,这样子就不用用公式算了; 可以来测试自己的模型是不是能跑通。。如果错了的话,上面的结果是出不来的。
参考
【1】https://openbase.com/python/torch-summary
【2】https://clay-atlas.com/us/blog/2020/05/13/pytorch-en-note-torchsummary/
本文由 mdnice 多平台发布
边栏推荐
- Abstract problem methodology
- Advanced Feature Selection Techniques in Linear Models - Based on R
- MySql's json_extract function processes json fields
- OneFlow源码解析:算子指令在虚拟机中的执行
- 接口文档进化图鉴,有些古早接口文档工具,你可能都没用过
- 【格式转换】将JPEG图片批量处理为jpg格式
- 基于Qiskit——《量子计算编程实战》读书笔记(二)
- 树莓派入门(4)LED闪烁&呼吸灯
- pytorch框架学习(6)训练一个简单的自己的CNN (三)细节篇
- oracle rac 11g安装执行root.sh时报错
猜你喜欢

动手写prometheus的exporter-02-Counter(计数器)

MySQL simple tutorial

pytorch框架学习(1)网络的简单构建

R语言:修改chart.Correlation()函数绘制相关性图——完美出图

【写下自用】每次都忘记如何train?记录如何训练自己的yolov5

【静态代理】

strongest brain (1)

Order table delete, insert and search operations

SSM框架整合实例

EasyGBS connects to mysql database and prompts "can't connect to mysql server", how to solve it?
随机推荐
WSTP初体验
MySQL simple tutorial
Important transformation and upgrading
深度梳理:防止模型过拟合的方法汇总
k-近邻实现手写数字识别
论文精读 —— 2021 CVPR《Progressive Temporal Feature Alignment Network for Video Inpainting》
The time for flinkcdc to read pgsql is enlarged. Does anyone know what happened? gmt_create':1
看了几十篇轻量化目标检测论文扫盲做的摘抄笔记
Why are negative numbers in binary represented in two's complement form - binary addition and subtraction
FPGA engineer interview questions collection 1~10
Ask you guys.The FlinkCDC2.2.0 version in the CDC community has a description of the supported sqlserver version, please
Order table delete, insert and search operations
aliases节点分析
在vscode中屏蔽Alt热键
pytorch learning
Talk about API Management - Open Source Edition to SaaS Edition
Qiskit 学习笔记2
几种绘制时间线图的方法
Become a language that hackers have to learn. Do you think it's okay after reading it?
如何模拟后台API调用场景,很细!