当前位置:网站首页>深度学习模型训练前的必做工作:总览模型信息
深度学习模型训练前的必做工作:总览模型信息
2022-08-10 05:29:00 【公众号学一点会一点】

在使用深度学习模型处理图像数据的时候,输入数据的大小在整个网络中是怎么变化的非常重要,但是如果只看代码的话,我们算起来比较麻烦,比如我们经过了各种上采样、下采样等,中间过程可能有几十个网络层,算过来算过去一来是麻烦,二来是不清晰明了。
今天介绍一个用来查看模型概览信息的包,对新手学习非常有帮助! 这就是torchsummary包(https://github.com/sksq96/pytorch-summary)。
直接上用法。
安装
安装没什么难的,直接用pip即可。
pip install torchsummary
搭建模型
根据自己的需求进行模型的搭建。
模型总体信息概览
比如我们搭建了下面一个网络:
class MyNet(nn.Module):
def __init__(self, inchannels):
super(SRCNN, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(inchannels, 64, kernel_size=9, stride=(1, 1), padding=(4, 4)),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=1, stride=(1, 1), padding=(0, 0)),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 1, kernel_size=5, stride=(1, 1), padding=(2, 2))
)
def forward(self, x):
y = self.main(x)
return y
现在我们想看一下输入数据在网络中的大小变化,以及可学习参数的总数量等信息,那么只需要使用summary函数即可:
model = MyNet(inchannels=1).to('cuda')
summary(model, input_size=(1, 222, 222))
然后便可以得到如下信息:

从中可以看到整个模型的如下信息:
所有的layer 每个layer的输出尺寸(-1代表是可变的,这里是batchsize的大小) 每一层的参数量 总的参数量 总得可学习参数量和无需学习的参数量 占用空间的大小
有什么用?
目前,个人觉得有很多:
最直接的就是可以清楚地了解自己的网络结构; 在用卷积神经网络的时候可以用这个来调整kernel size和padding的尺寸,这样子就不用用公式算了; 可以来测试自己的模型是不是能跑通。。如果错了的话,上面的结果是出不来的。
参考
【1】https://openbase.com/python/torch-summary
【2】https://clay-atlas.com/us/blog/2020/05/13/pytorch-en-note-torchsummary/
本文由 mdnice 多平台发布
边栏推荐
- Jenkins 如何玩转接口自动化测试?
- Important transformation and upgrading
- SQL Server query optimization
- 看了几十篇轻量化目标检测论文扫盲做的摘抄笔记
- Introduction to curl command
- 并发工具类——CountDownLatch、CyclicBarrier、Semaphore、Exchanger的介绍与使用
- 【论文笔记1】小样本分类
- How does flinksql write that the value of redis has only the last field?
- strongest brain (1)
- 基于Qiskit——《量子计算编程实战》读书笔记(六)
猜你喜欢

How to improve product quality from the code layer

R语言:修改chart.Correlation()函数绘制相关性图——完美出图

聊聊 API 管理-开源版 到 SaaS 版

Interface debugging also can play this?

基于Qiskit——《量子计算编程实战》读书笔记(四)

PyTorch 入门之旅

【Pei Shu Theorem】CF1055C Lucky Days

CORS跨域资源共享漏洞的原理与挖掘方法

Qiskit官方文档选译之量子傅里叶变换(Quantum Fourier Transform, QFT)

awk of the Three Musketeers of Shell Programming
随机推荐
Big guys, mysql cdc (2.2.1 and previous versions) sometimes has this situation since savepoint, is there anything wrong?
每周推荐短视频:探索AI的应用边界
再肝3天,整理了90个 NumPy 例子,不能不收藏!
Linear Algebra (4)
SQL database field to append to main table
论文精读 —— 2021 CVPR《Progressive Temporal Feature Alignment Network for Video Inpainting》
Matlab simulation of multi-factor house price prediction based on BP neural network
WSTP初体验
几种绘制时间线图的方法
【Static proxy】
MySql's json_extract function processes json fields
awk of the Three Musketeers of Shell Programming
Concurrency tool class - introduction and use of CountDownLatch, CyclicBarrier, Semaphore, Exchanger
基本比例尺标准分幅编号流程
MySql之json_extract函数处理json字段
【LeetCode】41. The first missing positive number
k-近邻实现手写数字识别
Transforming into a product, is it reliable to take the NPDP test?
Order table delete, insert and search operations
Rpc interface stress test