The basic five-step process for training a network in PyTorch
2022-04-23 07:17:00 【Breeze_】
- step1. Load the data
- step2. Define the network
- step3. Define the loss function and optimizer
- step4. Train the network: repeat steps 4.1 to 4.6 until the planned number of epochs is reached
– step4.1 Load a batch of data
– step4.2 Zero the gradients
– step4.3 Run the forward pass
– step4.4 Compute the loss
– step4.5 Compute the gradients (backward pass)
– step4.6 Update the weights
- step5. Save the weights
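Step 2 asks for a network definition, but the training script in this article calls `LeNet()` without showing the class. The following is a minimal sketch of what such a class might look like, assuming a LeNet-style layout sized for 3x32x32 CIFAR-10 images. The layer sizes and the `num_classes` parameter are assumptions for illustration, not the author's original definition.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    """LeNet-style CNN for 3x32x32 inputs (illustrative sketch, not the author's code)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)    # 3x32x32 -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)                 # halves height and width
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)   # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # -> 6x14x14
        x = self.pool(F.relu(self.conv2(x)))   # -> 16x5x5
        x = torch.flatten(x, 1)                # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                     # raw logits, as nn.CrossEntropyLoss expects


# Quick shape check with a random batch of four images
net = LeNet()
logits = net(torch.randn(4, 3, 32, 32))
print(logits.shape)  # torch.Size([4, 10])
```

The final layer returns raw logits because `nn.CrossEntropyLoss` applies log-softmax internally, so no softmax layer is needed in the model.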
# Training a classifier
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.datasets
import torchvision.transforms as transforms
from torch import optim

# NOTE: LeNet is not defined in this listing; define or import it before calling train().


def train():
    '''Train a CIFAR-10 classifier.'''
    # 1. Load data
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )
    # download=False assumes CIFAR-10 is already in ./data; use download=True to fetch it
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
    # The test set and class names are prepared here but not used during training
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
    classes = (
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck'
    )

    # 2. Define the network
    net = LeNet()

    # Use CUDA if available. The device name is 'cuda', not 'gpu', and it must be
    # a string or torch.device, not a list. The model is moved before the optimizer is built.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    # For multiple GPUs, wrap the model in nn.DataParallel

    # 3. Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)

    # 4. Train the network
    print('Start training')
    for epoch in range(3):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, label = data               # 4.1 Load a batch of data
            inputs = inputs.to(device)
            label = label.to(device)
            optimizer.zero_grad()              # 4.2 Zero the gradients
            output = net(inputs)               # 4.3 Forward pass
            loss = criterion(output, label)    # 4.4 Compute the loss
            loss.backward()                    # 4.5 Compute the gradients
            optimizer.step()                   # 4.6 Update the weights
            running_loss += loss.item()
            if i % 20 == 19:                   # report the mean loss every 20 batches
                print('epoch:', epoch, 'loss', running_loss / 20)
                running_loss = 0.0
    print('Training done')

    # 5. Save model parameters
    torch.save(net.state_dict(), 'cifar_AlexNet.pth')


if __name__ == '__main__':
    train()
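The script above prepares a test loader and saves the weights, but it stops before using either. As a rough sketch of a possible next step, the saved `state_dict` could be reloaded into a fresh model and scored for accuracy. The `evaluate` helper, the `make_net` stand-in model, the temporary file path, and the synthetic data below are all hypothetical, chosen so that the example runs without the CIFAR-10 files.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(net, loader, device):
    """Return the fraction of samples in `loader` that `net` classifies correctly."""
    net.eval()                      # eval mode: dropout off, batch norm uses running stats
    correct, total = 0, 0
    with torch.no_grad():           # gradients are not needed for inference
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = net(inputs).argmax(dim=1)   # index of the largest logit
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total


def make_net():
    # Hypothetical stand-in for the article's network, kept tiny on purpose
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Synthetic stand-in for the CIFAR-10 test set: 16 random images with random labels
images = torch.randn(16, 3, 32, 32)
labels = torch.randint(0, 10, (16,))
loader = DataLoader(TensorDataset(images, labels), batch_size=4, shuffle=False)

# Save the weights the same way the training script does, then reload them
trained = make_net()
path = os.path.join(tempfile.mkdtemp(), 'weights.pth')
torch.save(trained.state_dict(), path)

restored = make_net()
restored.load_state_dict(torch.load(path, map_location=device))
restored.to(device)

accuracy = evaluate(restored, loader, device)
print(f'accuracy: {accuracy:.2%}')
```

With the real script, `make_net()` would be replaced by the same network class used in training, the path by the saved `.pth` file, and the synthetic loader by `testloader`. Because the weights here are untrained and the labels random, the printed accuracy carries no meaning beyond showing that the pipeline runs.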
Copyright notice
This article was written by [Breeze_]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204230610323060.html