Pytorch (V) -- Notes
2022-04-21 23:16:00 【Deer holding grass】
Catalog
1. Classical convolutional network
1.1 ImageNet

1.2 VGG
VGG network layers

1.3 GoogLeNet

1.4 Stack more layers?

2. Deep residual network
2.1 The residual module
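Judging from the `ResBlk` code in section 2.3, the residual module adds a shortcut branch to a main branch of two convolutions. It can be written as the formula below; the symbols $\mathcal{F}$ (main branch) and $\mathcal{S}$ (shortcut) are notation introduced here for illustration and do not come from the original notes.

```latex
\mathrm{out} = \mathcal{F}(x) + \mathcal{S}(x), \qquad
\mathcal{F}(x) = \mathrm{BN}_2\big(\mathrm{Conv}_2\big(\mathrm{ReLU}\big(\mathrm{BN}_1(\mathrm{Conv}_1(x))\big)\big)\big),
\qquad
\mathcal{S}(x) =
\begin{cases}
x, & \text{if } ch\_in = ch\_out \\
\mathrm{BN}\big(\mathrm{Conv}_{1\times 1}(x)\big), & \text{otherwise}
\end{cases}
```

Note that the block in section 2.3 returns the sum directly, without applying a ReLU after the addition.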

2.2 Deeper residual module

2.3 Code implementation
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# from torchvision.models import resnet18


class ResBlk(nn.Module):
    """ResNet basic block: two 3x3 conv + BN layers plus a shortcut connection."""

    def __init__(self, ch_in, ch_out):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        """
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default (an empty nn.Sequential returns its input unchanged).
        self.extra = nn.Sequential()
        if ch_out != ch_in:
            # 1x1 conv so the shortcut matches the main branch:
            # [b, ch_in, h, w] => [b, ch_out, h, w]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h, w]
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Shortcut: element-wise add of the (possibly projected) input.
        out = self.extra(x) + out
        return out


class ResNet18(nn.Module):
    """Small ResNet-style network: only two residual blocks are enabled here."""

    def __init__(self):
        super(ResNet18, self).__init__()
        # Stem: [b, 3, h, w] => [b, 16, h, w]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16)
        )
        # Followed by residual blocks.
        # [b, 16, h, w] => [b, 16, h, w]
        self.blk1 = ResBlk(16, 16)
        # [b, 16, h, w] => [b, 32, h, w]
        self.blk2 = ResBlk(16, 32)
        # Disabled blocks; their channel sizes would need adjusting to chain after blk2.
        # self.blk3 = ResBlk(128, 256)
        # self.blk4 = ResBlk(256, 512)
        # Flattened features: 32 channels * 32 * 32 pixels (stride 1 keeps CIFAR-10's 32x32 size).
        self.outlayer = nn.Linear(32*32*32, 10)

    def forward(self, x):
        """
        :param x: [b, 3, 32, 32]
        :return: logits, [b, 10]
        """
        x = F.relu(self.conv1(x))
        # [b, 16, h, w] => [b, 32, h, w]
        x = self.blk1(x)
        x = self.blk2(x)
        # x = self.blk3(x)
        # x = self.blk4(x)
        # print(x.shape)
        # Flatten: [b, 32, 32, 32] => [b, 32*32*32]
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x


def main():
    batchsz = 32

    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    # Shuffling is not needed for evaluation.
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)

    # Peek at one batch; iterator.next() was removed, so use the built-in next().
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Fall back to the CPU when no GPU is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # model = Lenet5().to(device)
    model = ResNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32]
            # label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10]
            # label: [b]
            # loss: tensor scalar
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the last training batch in this epoch.
        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32]
                # label: [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)
                # [b]
                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)

            acc = total_correct / total_num
            print(epoch, 'acc:', acc)


if __name__ == '__main__':
    main()
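
The input size of `outlayer`, `32*32*32`, follows from the standard convolution output-size formula: every convolution in the listing above uses either a 3x3 kernel with stride 1 and padding 1 or a 1x1 kernel with stride 1, so the 32x32 CIFAR-10 resolution is never reduced. A minimal sketch in plain Python to check the arithmetic; `conv_out` is a hypothetical helper written for this note, not a PyTorch function, and it assumes dilation 1.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Output size of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


h = w = 32  # CIFAR-10 images are 32x32

# Main path: the stem conv, then two 3x3 convs in each of the two residual blocks.
for _ in range(1 + 2 + 2):
    h = conv_out(h, kernel_size=3, stride=1, padding=1)
    w = conv_out(w, kernel_size=3, stride=1, padding=1)

# The 1x1 shortcut conv in blk2 (stride 1, no padding) also preserves the size.
shortcut_h = conv_out(32, kernel_size=1)

channels = 32  # output channels of blk2
flat_features = channels * h * w
print(h, w, shortcut_h, flat_features)  # 32 32 32 32768
```

If a block were changed to use `stride=2`, the same helper gives `conv_out(32, 3, stride=2, padding=1) == 16`, and the `nn.Linear` input size would have to shrink to match.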
Copyright notice
This article was written by [Deer holding grass]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204212315084992.html