当前位置:网站首页>PyTorch训练一个网络的基本流程5步法
PyTorch训练一个网络的基本流程5步法
2022-04-23 06:11:00 【小风_】
- step1. 加载数据
- step2. 定义网络
- step3. 定义损失函数和优化器
- step4. 训练网络,循环4.1到4.6直到达到预定epoch数量
– step4.1 加载数据
– step4.2 初始化梯度
– step4.3 计算前馈
– step4.4 计算损失
– step4.5 计算梯度
– step4.6 更新权值 - step5. 保存权重
# 训练一个分类器
import torchvision.datasets
import torch.utils.data
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch import optim
def train():
'''训练'''
'''1.加载数据'''
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
classes = (
'plane', 'car', 'bird', 'cat','deer',
'dog', 'frog', 'horse', 'ship', 'truck'
)
'''2.定义网络'''
Net = LeNet()
'''3.定义损失函数and优化器'''
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(Net.parameters(),lr=1e-3,momentum=0.9)
'''cuda加速'''
device = ['gpu' if torch.cuda.is_available() else 'cpu']
if device == 'gpu':
criterion.cuda()
Net.to(device)
# Net.cuda() #多GPU 请用 DataParallel方法
'''4.训练网络'''
print('开始训练')
for epoch in range(3):
runing_loss = 0.0
for i,data in enumerate(trainloader,0):
inputs,label = data #1.数据加载
if device == 'gpu':
inputs = inputs.cuda()
label = label.cuda()
optimizer.zero_grad() #2.初始化梯度
output = Net(inputs) #3.计算前馈
loss = criterion(output,label) #4.计算损失
loss.backward() #5.计算梯度
optimizer.step() #6.更新权值
runing_loss += loss.item()
if i % 20 == 19:
print('epoch:',epoch,'loss',runing_loss/20)
runing_loss = 0.0
print('训练完成')
'''4.保存模型参数'''
torch.save(Net.state_dict(),'cifar_AlexNet.pth')
if __name__=='__main__':
train()
版权声明
本文为[小风_]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_33952811/article/details/123301500
边栏推荐
猜你喜欢

机器学习笔记 一:学习思路

【2021年新书推荐】Professional Azure SQL Managed Database Administration
![Android interview Online Economic encyclopedia [constantly updating...]](/img/48/dd1abec83ec0db7d68812f5fa9dcfc.png)
Android interview Online Economic encyclopedia [constantly updating...]

Itop4412 HDMI display (4.4.4_r1)

取消远程依赖,用本地依赖

Project, how to package

【2021年新书推荐】Effortless App Development with Oracle Visual Builder
![[2021 book recommendation] artistic intelligence for IOT Cookbook](/img/8a/3ff45a911becb895e6dd9e061ac252.png)
[2021 book recommendation] artistic intelligence for IOT Cookbook

红外传感器控制开关

WebView displays a blank due to a certificate problem
随机推荐
Using queue to realize stack
adb shell top 命令详解
电脑关机程序
npm ERR code 500解决
Cause: dx.jar is missing
C#新大陆物联网云平台的连接(简易理解版)
Tiny4412 HDMI display
AVD Pixel_ 2_ API_ 24 is already running. If that is not the case, delete the files at C:\Users\admi
iTOP4412 HDMI显示(4.0.3_r1)
iTOP4412 FramebufferNativeWindow(4.0.3_r1)
三种实现ImageView以自身中心为原点旋转的方法
org.xml.sax.SAXParseException; lineNumber: 141; columnNumber: 252; cvc-complex-type.2.4.a: 发现了以元素 ‘b
利用队列实现栈
Itop4412 cannot display boot animation (4.0.3_r1)
[2021 book recommendation] red hat rhcsa 8 cert Guide: ex200
PaddleOCR 图片文字提取
iTOP4412内核反复重启
C# EF mysql更新datetime字段报错Modifying a column with the ‘Identity‘ pattern is not supported
JS 比较2个数组中不同的元素
[Andorid] 通过JNI实现kernel与app进行spi通讯