当前位置:网站首页>记录贴:pytorch学习Part4
记录贴:pytorch学习Part4
2022-08-08 17:33:00 【安联之夜】
记录贴:pytorch学习Part4
一、卷积
import torch
import torch.nn as nn
from torch.nn import functional as F
#方式一
layer=nn.Conv2d(1,3,kernel_size=3,stride=1,padding=0)
x = torch.rand(1,1,28,28)
out = layer.forward(x)
layer=nn.Conv2d(1,3,kernel_size=3,stride=1,padding=1)
out = layer.forward(x)
layer=nn.Conv2d(1,3,kernel_size=3,stride=2,padding=1)
out = layer.forward(x)
out.shape
layer.weight#权重的维度为3,1,3,3,代表三个卷积,一个图像通道,3*3的卷积核
layer.bias#一层卷积一个偏置
#方式二
w = torch.rand(16,3,5,5)#16个卷积层,3个图像通道,5*5的卷积核
b = torch.rand(16)
x = torch.randn(1,3,28,28)
out = F.conv2d(x,w,b,stride=1,padding=1)
二、下采样和上采样
#Pooling
x = torch.randn(1,16,14,14)
layer = nn.MaxPool2d(2,stride=2)#最大
out = layer(x)
out = F.avg_pool2d(x,2,stride=2)#平均
#upsample
x = out
out = F.interpolate(x,scale_factor=2,mode='nearest')
out = F.interpolate(x,scale_factor=3,mode='nearest')
三、标准化
#Image Normalization
normalize = transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
#Batch Normalization
x = torch.rand(100,16,784)
layer = nn.BatchNorm1d(16)
out = layer(x)
layer.running_mean
layer.running_var
x = torch.rand(1,16,7,7)
layer = nn.BatchNorm2d(16)
四、Resnet
class ResBlk(nn.Module):
def __init__(self,ch_in,ch_out):
self.conv1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
self.conv2 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)
self.extra = nn.Sequential()
if ch_out != ch_in:
self.extra = nn.Sequential(
nn.Conv2d(ch_in,ch_out,kernel_size=4,stride=1),
nn.BatchNorm2d(ch_out))
def forward(self,x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out = self.extra(x) + out
return out
五、类
# 网络结构
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.model = nn.Sequential(
#nn.Linear(784, 200),
Mylinear(784,200),
nn.BatchNorm1d(200, eps=1e-8),
nn.LeakyReLU(inplace=True),
#nn.Linear(200, 200),
Mylinear(200, 200),
nn.BatchNorm1d(200, eps=1e-8),
nn.LeakyReLU(inplace=True),
#nn.Linear(200, 10),
Mylinear(200,10),
nn.LeakyReLU(inplace=True)
)
#Container
self.net = nn.Sequential(
nn.Conv2d(1,32,5,1,1),
nn.MaxPool2d(2,2),
nn.ReLU(True),
nn.BatchNorm2d(32),
nn.Conv2d(32,64,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(64),
nn.Conv2d(64,64,3,1,1),
nn.MaxPool2d(2,2),
nn.ReLU(True),
nn.BatchNorm2d(64),
nn.Conv2d(64,128,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(128))
边栏推荐
猜你喜欢
随机推荐
咸阳广发证券股票开户安全吗,需要准备什么证件
无需精子卵子子宫体外培育胚胎,Cell论文作者这番话让网友们炸了
Reprinted, the fragment speaks very well, the big guy
【CC3200AI 实验教程4】疯壳·AI语音人脸识别(会议记录仪/人脸打卡机)-GPIO
【历史上的今天】8 月 8 日:中国第一个校园 BBS 成立;网景通信上市;EarthLink 创始人出生
Cy5反式环辛烯,TCO-Cy5,Cy5 trans-cyclooctene标记生物分子
正则在js中的使用
leetcode:313. 超级丑数
Debug和Release的区别
win10如何设置定时联网断网辅助自律
How to set timed network disconnection to assist self-discipline in win10
socket concept
LeetCode_Backtrack_Medium_491. Incrementing Subsequence
Qt——获取文件夹下所有子文件名称
R文件找不到问题
Cyanine5 tetrazine,Cy5 tetrazineCY5四嗪,1427705-31-4
什么是服务网格?在微服务体系中又是如何使用的?
Prometheus+Grafana监控系统
MySQL中怎么对varchar类型排序问题
为什么MySQL的主键查询这么快