PyTorch introductory notes: observing the output size of each layer in forward propagation with a simple example
2022-04-23 05:58:00 【umbrellalalalala】
I am working through 《Deep Learning Framework PyTorch: Introduction and Practice》 and am recording a simple example here to deepen my understanding of the tensor sizes involved in forward propagation in torch.
This network definition comes from Chapter 2 of the book. Reading the code alone may not make things intuitive, especially the line x = x.view(x.size()[0], -1). As a beginner I wanted to see how the size changes from layer to layer, and where the 16*5*5 in fc1 comes from:
```python
import torch.nn as nn
import torch.nn.functional as F  # activation and pooling functions live here


class Net(nn.Module):
    def __init__(self):
        # A subclass of nn.Module must call the parent class constructor in its own constructor
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)  # 1 input channel, 6 output channels, 5*5 kernel; same pattern below
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)  # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
print(net)
```
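The 16*5*5 in fc1 can also be derived by hand before running anything. This is a minimal sketch, assuming the usual output-size formula for a convolution or pooling layer without dilation, floor((size + 2*padding - kernel) / stride) + 1, plus the defaults used in the network (stride 1 and padding 0 for nn.Conv2d, stride equal to the kernel size for F.max_pool2d). The helpers out_size and trace_shapes are hypothetical names written only for this illustration, and no PyTorch is needed to run it:

```python
# Sketch: compute the per-sample shapes of the network by hand (pure Python).
# Assumptions: no dilation and ceil_mode=False, so sizes are rounded down.

def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial dimension of a conv or pool layer."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(height: int = 32, width: int = 32) -> list:
    """Return (layer name, (channels, height, width)) pairs in forward order."""
    # (name, out_channels, kernel, stride) for each spatial layer
    layers = [
        ("conv1", 6, 5, 1),   # nn.Conv2d(1, 6, 5): stride 1, padding 0 by default
        ("pool1", 6, 2, 2),   # F.max_pool2d(x, 2): stride defaults to the kernel size
        ("conv2", 16, 5, 1),  # nn.Conv2d(6, 16, 5)
        ("pool2", 16, 2, 2),  # F.max_pool2d(x, 2)
    ]
    shapes = []
    channels = 1
    for name, out_channels, kernel, stride in layers:
        height = out_size(height, kernel, stride)
        width = out_size(width, kernel, stride)
        channels = out_channels
        shapes.append((name, (channels, height, width)))
    # view(batch, -1) merges channels*height*width into a single dimension
    shapes.append(("flatten", (channels * height * width,)))
    return shapes


for name, shape in trace_shapes():
    print(name, shape)
```

Running it lists (6, 28, 28), (6, 14, 14), (16, 10, 10), (16, 5, 5) and finally (400,), so 16*5*5 = 400 is simply the flattened size of the second pooling layer's output for a 32*32 input.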
In Jupyter, I recorded the output size of each layer:
```python
import torch as t

# The input image is 32*32 with 1 channel, laid out as batch*channel*height*width.
# (The book wraps this in Variable; since PyTorch 0.4 a plain tensor works directly.)
input = t.randn(1, 1, 32, 32)
out = net(input)
# Run the layers one by one and print each layer's output size.
# ReLU is skipped here because it does not change the shape.
out1 = net.conv1(input)
pool1 = F.max_pool2d(out1, 2)
out2 = net.conv2(pool1)
pool2 = F.max_pool2d(out2, 2)
stretch = pool2.view(pool2.size()[0], -1)
out3 = net.fc1(stretch)
out4 = net.fc2(out3)
out5 = net.fc3(out4)
print('output of conv1:', out1.shape)
print('output of pool1:', pool1.shape)
print('output of conv2:', out2.shape)
print('output of pool2:', pool2.shape)
print('output of stretch:', stretch.shape)
print('output of fc1:', out3.shape)
print('output of fc2:', out4.shape)
print('output of fc3:', out5.shape)
```
Output:
```text
output of conv1: torch.Size([1, 6, 28, 28])
output of pool1: torch.Size([1, 6, 14, 14])
output of conv2: torch.Size([1, 16, 10, 10])
output of pool2: torch.Size([1, 16, 5, 5])
output of stretch: torch.Size([1, 400])
output of fc1: torch.Size([1, 120])
output of fc2: torch.Size([1, 84])
output of fc3: torch.Size([1, 10])
```
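This makes the shape changes very intuitive! The line x = x.view(x.size()[0], -1) turns [1, 16, 5, 5] into [1, 400]: it keeps the batch dimension and flattens the remaining 16*5*5 = 400 values, which is exactly why fc1 expects 16*5*5 input features.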
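The -1 in x.view(x.size()[0], -1) asks PyTorch to infer that dimension so that the total number of elements stays the same: a [1, 16, 5, 5] tensor holds 1*16*5*5 = 400 elements, and with the first dimension fixed to the batch size 1, the other one must be 400. The pure-Python sketch below mimics only that inference rule. resolve_view_shape is a hypothetical helper, not PyTorch's implementation, and it ignores the stride and contiguity checks that the real view also performs:

```python
from math import prod


def resolve_view_shape(numel: int, new_shape: tuple) -> tuple:
    """Infer a single -1 entry the way Tensor.view does (shape logic only)."""
    if list(new_shape).count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(d for d in new_shape if d != -1)  # product of the fixed dimensions
    if -1 not in new_shape:
        if known != numel:
            raise ValueError(f"shape {new_shape} is invalid for {numel} elements")
        return tuple(new_shape)
    if known == 0 or numel % known != 0:
        raise ValueError(f"shape {new_shape} is invalid for {numel} elements")
    # The inferred dimension is whatever keeps the element count unchanged
    return tuple(numel // known if d == -1 else d for d in new_shape)


pool2_shape = (1, 16, 5, 5)
numel = prod(pool2_shape)  # 1 * 16 * 5 * 5 = 400
print(resolve_view_shape(numel, (pool2_shape[0], -1)))  # prints (1, 400)
```

Because the first dimension is taken from x.size()[0] instead of being hard-coded, the same forward code works for any batch size; a batch of 4 would flatten to [4, 400]. In recent PyTorch versions, torch.flatten(x, 1) is another common way to get the same shape.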
Copyright notice
This article was written by [umbrellalalalala]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204230543474540.html