当前位置:网站首页>PyTorch训练一个网络的基本流程5步法
PyTorch训练一个网络的基本流程5步法
2022-04-23 06:11:00 【小风_】
- step1. 加载数据
- step2. 定义网络
- step3. 定义损失函数和优化器
- step4. 训练网络,循环4.1到4.6直到达到预定epoch数量
– step4.1 加载数据
– step4.2 初始化梯度
– step4.3 计算前馈
– step4.4 计算损失
– step4.5 计算梯度
– step4.6 更新权值 - step5. 保存权重
# 训练一个分类器
import torchvision.datasets
import torch.utils.data
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch import optim
def train():
'''训练'''
'''1.加载数据'''
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
classes = (
'plane', 'car', 'bird', 'cat','deer',
'dog', 'frog', 'horse', 'ship', 'truck'
)
'''2.定义网络'''
Net = LeNet()
'''3.定义损失函数and优化器'''
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(Net.parameters(),lr=1e-3,momentum=0.9)
'''cuda加速'''
device = ['gpu' if torch.cuda.is_available() else 'cpu']
if device == 'gpu':
criterion.cuda()
Net.to(device)
# Net.cuda() #多GPU 请用 DataParallel方法
'''4.训练网络'''
print('开始训练')
for epoch in range(3):
runing_loss = 0.0
for i,data in enumerate(trainloader,0):
inputs,label = data #1.数据加载
if device == 'gpu':
inputs = inputs.cuda()
label = label.cuda()
optimizer.zero_grad() #2.初始化梯度
output = Net(inputs) #3.计算前馈
loss = criterion(output,label) #4.计算损失
loss.backward() #5.计算梯度
optimizer.step() #6.更新权值
runing_loss += loss.item()
if i % 20 == 19:
print('epoch:',epoch,'loss',runing_loss/20)
runing_loss = 0.0
print('训练完成')
'''4.保存模型参数'''
torch.save(Net.state_dict(),'cifar_AlexNet.pth')
if __name__=='__main__':
train()
版权声明
本文为[小风_]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_33952811/article/details/123301500
边栏推荐
- 扫雷小游戏
- Itop4412 LCD backlight drive (PWM)
- [2021 book recommendation] effortless app development with Oracle visual builder
- C connection of new world Internet of things cloud platform (simple understanding version)
- Binder机制原理
- [sm8150] [pixel4] LCD driver
- MySQL笔记2_数据表
- Fill the network gap
- error 403 In most cases, you or one of your dependencies are requesting解决
- 去掉状态栏
猜你喜欢
记录webView显示空白的又一坑
Component based learning (1) idea and Implementation
Ffmpeg common commands
[recommendation for new books in 2021] professional azure SQL managed database administration
Binder机制原理
【2021年新书推荐】Practical Node-RED Programming
取消远程依赖,用本地依赖
[2021 book recommendation] practical node red programming
Android interview Online Economic encyclopedia [constantly updating...]
Encapsulate a set of project network request framework from 0
随机推荐
补补网络缺口
Itop4412 kernel restarts repeatedly
HandlerThread原理和实际应用
Miscellaneous learning
读书小记——Activity
Cancel remote dependency and use local dependency
MySQL笔记5_操作数据
iTOP4412 SurfaceFlinger(4.0.3_r1)
[exynos4412] [itop4412] [android-k] add product options
基于BottomNavigationView实现底部导航栏
【2021年新书推荐】Practical IoT Hacking
ViewPager2实现画廊效果执行notifyDataSetChanged后PageTransformer显示异常 界面变形问题
接口幂等性问题
ARGB透明度换算
Using stack to realize queue out and in
[2021 book recommendation] learn winui 3.0
Cause: dx. jar is missing
[2021 book recommendation] kubernetes in production best practices
[2021 book recommendation] effortless app development with Oracle visual builder
[Andorid] 通过JNI实现kernel与app进行spi通讯