当前位置:网站首页>PyTorch入门小笔记——利用简单例子观察前向传播各个层输出的size
PyTorch入门小笔记——利用简单例子观察前向传播各个层输出的size
2022-04-23 05:44:00 【umbrellalalalala】
博主正在学习《深度学习框架PyTorch:入门与实践》,记录一个简单的例子,加深对torch前向传播参数的理解。
这是书籍第二章的一个定义网络的例子,直接看代码可能会不太直观,特别是x = x.view(x.size()[0], -1)这一句,初学者希望能直观感受size的变化,以及fc1中的 16 ∗ 5 ∗ 5 16*5*5 16∗5∗5的来源:
import torch.nn as nn
import torch.nn.functional as F # 激活和池化都在这里
class Net(nn.Module):
def __init__(self):
# nn.Module子类的函数必须在构造函数中执行父类的构造函数
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5) # 输入通道1,输出通道6,卷积核5*5,下同
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
print(net)
博主用jupyter记录了每一层输出的size:
# 输入图片的尺寸是32*32,channel为1。即batch*channel*height*width
input = Variable(t.randn(1, 1, 32, 32))
out = net(input)
# 逐层运行并输出每层的size
out1 = net.conv1(input)
pool1 = F.max_pool2d(out1, 2)
out2 = net.conv2(pool1)
pool2 = F.max_pool2d(out2, 2)
stretch = pool2.view(pool2.size()[0], -1)
out3 = net.fc1(strech)
out4 = net.fc2(out3)
out5 = net.fc3(out4)
print('output of conv1:', out1.shape)
print('output of pool1:', pool1.shape)
print('output of conv2:', out2.shape)
print('output of pool2:', pool2.shape)
print('output of stretch:', stretch.shape)
print('output of fc1:', out3.shape)
print('output of fc2:', out4.shape)
print('output of fc3:', out5.shape)
运行结果:
output of conv1: torch.Size([1, 6, 28, 28])
output of pool1: torch.Size([1, 6, 14, 14])
output of conv2: torch.Size([1, 16, 10, 10])
output of pool2: torch.Size([1, 16, 5, 5])
output of strech: torch.Size([1, 400])
output of fc1: torch.Size([1, 120])
output of fc2: torch.Size([1, 84])
output of fc3: torch.Size([1, 10])
这样就非常直观了!其中x = x.view(x.size()[0], -1)是将[1, 16, 5, 5]变成了[1, 400]
版权声明
本文为[umbrellalalalala]所创,转载请带上原文链接,感谢
https://blog.csdn.net/umbrellalalalala/article/details/119891921
边栏推荐
- Kingdee EAS "general ledger" system calls "de posting" button
- Dwsurvey is an open source questionnaire system. Solve the problem that cannot be run and modify the bug.
- MySQL事务
- Package mall system based on SSM
- sklearn之 Gaussian Processes
- 治療TensorFlow後遺症——簡單例子記錄torch.utils.data.dataset.Dataset重寫時的圖片維度問題
- opensips(1)——安装opensips详细流程
- 自定义异常类
- 手动删除eureka上已经注册的服务
- Pytorch學習記錄(十三):循環神經網絡((Recurrent Neural Network)
猜你喜欢

JVM系列(4)——内存溢出(OOM)

Conda 虚拟环境管理(创建、删除、克隆、重命名、导出和导入)

Software architecture design - software architecture style

Pytorch學習記錄(十三):循環神經網絡((Recurrent Neural Network)

Package mall system based on SSM

Pyemd installation and simple use

2 - principes de conception de logiciels

jdbc入门\获取数据库连接\使用PreparedStatement

Configure domestic image accelerator for yarn

Anaconda
随机推荐
Pytorch Learning record (XIII): Recurrent Neural Network
EditorConfig
多线程与高并发(2)——synchronized用法详解
图像恢复论文简记——Uformer: A General U-Shaped Transformer for Image Restoration
Font shape `OMX/cmex/m/n‘ in size <10.53937> not available (Font) size <10.95> substituted.
MySQL lock mechanism
去噪论文阅读——[CVPR2022]Blind2Unblind: Self-Supervised Image Denoising with Visible Blind Spots
容器
如何利用对比学习做无监督——[CVPR22]Deraining&[ECCV20]Image Translation
金蝶EAS“总账”系统召唤“反过账”按钮
Manually delete registered services on Eureka
PyQy5学习(二):QMainWindow+QWidget+QLabel
Package mall system based on SSM
RedHat6之smb服务访问速度慢解决办法记录
Latex快速入门
MySql基础狂神说
JDBC连接数据库
Multithreading and high concurrency (2) -- detailed explanation of synchronized usage
编程记录——图片旋转函数scipy.ndimage.rotate()的简单使用和效果观察
2.devops-sonar安装