当前位置:网站首页>Pytorch trains the basic process of a network in five steps
Pytorch trains the basic process of a network in five steps
2022-04-23 07:17:00 【Breeze_】
- step1. Load data
- step2. Defining network
- step3. Define the loss function and optimizer
- step4. Training network , loop 4.1 To 4.6 Until the scheduled time is reached epoch Number
– step4.1 Load data
– step4.2 Initialization gradient
– step4.3 Computational feedforward
– step4.4 Calculate the loss
– step4.5 Calculate the gradient
– step4.6 Update the weights - step5. Save weights
# Training a classifier
import torchvision.datasets
import torch.utils.data
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch import optim
def train():
''' Training '''
'''1. Load data '''
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
classes = (
'plane', 'car', 'bird', 'cat','deer',
'dog', 'frog', 'horse', 'ship', 'truck'
)
'''2. Defining network '''
Net = LeNet()
'''3. Define the loss function and Optimizer '''
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(Net.parameters(),lr=1e-3,momentum=0.9)
'''cuda Speed up '''
device = ['gpu' if torch.cuda.is_available() else 'cpu']
if device == 'gpu':
criterion.cuda()
Net.to(device)
# Net.cuda() # many GPU Please use DataParallel Method
'''4. Training network '''
print(' Start training ')
for epoch in range(3):
runing_loss = 0.0
for i,data in enumerate(trainloader,0):
inputs,label = data #1. Data loading
if device == 'gpu':
inputs = inputs.cuda()
label = label.cuda()
optimizer.zero_grad() #2. Initialization gradient
output = Net(inputs) #3. Computational feedforward
loss = criterion(output,label) #4. Calculate the loss
loss.backward() #5. Calculate the gradient
optimizer.step() #6. Update the weights
runing_loss += loss.item()
if i % 20 == 19:
print('epoch:',epoch,'loss',runing_loss/20)
runing_loss = 0.0
print(' Training done ')
'''4. Save model parameters '''
torch.save(Net.state_dict(),'cifar_AlexNet.pth')
if __name__=='__main__':
train()
版权声明
本文为[Breeze_]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/04/202204230610323060.html
边栏推荐
- Tiny4412 HDMI display
- Keras如何保存、加载Keras模型
- 【动态规划】杨辉三角
- 读书小记——Activity
- Bottom navigation bar based on bottomnavigationview
- 微信小程序 使用wxml2canvas插件生成图片部分问题记录
- 【2021年新书推荐】Practical IoT Hacking
- [2021 book recommendation] Red Hat Certified Engineer (RHCE) Study Guide
- 取消远程依赖,用本地依赖
- [recommendation for new books in 2021] professional azure SQL managed database administration
猜你喜欢

Fill the network gap

补补网络缺口
![[recommendation of new books in 2021] practical IOT hacking](/img/9a/13ea1e7df14a53088d4777d21ab1f6.png)
[recommendation of new books in 2021] practical IOT hacking

C#新大陆物联网云平台的连接(简易理解版)

【2021年新书推荐】Red Hat Certified Engineer (RHCE) Study Guide

Component based learning (3) path and group annotations in arouter

Easyui combobox 判断输入项是否存在于下拉列表中

Itop4412 HDMI display (4.0.3_r1)

机器学习 三: 基于逻辑回归的分类预测
树莓派:双色LED灯实验
随机推荐
【动态规划】三角形最小路径和
Android暴露组件——被忽略的组件安全
红外传感器控制开关
C#新大陆物联网云平台的连接(简易理解版)
1.1 PyTorch和神经网络
【2021年新书推荐】Practical Node-RED Programming
MySQL笔记5_操作数据
BottomSheetDialogFragment + ViewPager+Fragment+RecyclerView 滑动问题
基于BottomNavigationView实现底部导航栏
【2021年新书推荐】Red Hat RHCSA 8 Cert Guide: EX200
Component based learning (3) path and group annotations in arouter
素数求解的n种境界
电脑关机程序
JNI中使用open打开文件是返回-1问题
【2021年新书推荐】Learn WinUI 3.0
机器学习 二:基于鸢尾花(iris)数据集的逻辑回归分类
第8章 生成式深度学习
【2021年新书推荐】Artificial Intelligence for IoT Cookbook
第3章 Pytorch神经网络工具箱
Encapsulate a set of project network request framework from 0