[Quick Notes] Quick Notes on the Official PyTorch Tutorial, EP4
2022-08-10 05:34:00 【Mioulo】
Completed…
These notes record most of the code from the official PyTorch beginner tutorial, together with explanatory comments.
Contents for now: optimizing model parameters, and saving and loading model parameters.
Reference: https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html
Optimizing Model Parameters
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()
As in the earlier parts, the code first sets up the training and test sets (using the ready-made FashionMNIST dataset from torchvision.datasets),
then defines a fully connected network with ReLU activations and instantiates it as model.
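As a quick sanity check (my own addition, not part of the tutorial), you can pull one batch out of the DataLoader and look at its shapes:

X, y = next(iter(train_dataloader))
print(X.shape)  # torch.Size([64, 1, 28, 28]) -> 64 grayscale 28x28 images
print(y.shape)  # torch.Size([64])            -> 64 integer class labels (0-9)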
learning_rate = 1e-1
batch_size = 64
epochs = 10
These are the hyperparameters, controlling the learning rate, the batch size, and the number of epochs (how many passes are made over the training set).
(Note that the batch_size variable defined here is never actually passed anywhere in this code; the effective batch size is the one given to the DataLoader constructor, DataLoader(training_data, batch_size=64). A sketch of wiring the hyperparameter through is given below.)
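A minimal sketch of actually passing that hyperparameter into the data pipeline (my addition; shuffle=True is optional):

batch_size = 64
# Pass the hyperparameter to the DataLoader so it takes effect;
# shuffling the training set each epoch is common but not required.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)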
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        # The arguments passed into this function are defined as follows:
        # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
        # loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # With 64 samples per batch, print the current loss every 100 batches (6400 samples)
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
Here we define the training-loop function, which takes a dataloader (built from the imported DataLoader class), the model, loss_fn, and optimizer.
It iterates over the dataloader with enumerate, feeding the batches one by one into the network we wrote.
The optimizer torch.optim.SGD holds the model's parameters and updates them according to their gradients.
The next three lines then perform backpropagation; see the links below for the implementation details and the theory, and the sketch after the reference list for a tiny self-contained example.
(For the final print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]"), see Python's format specification syntax.)
Function references: enumerate
torch.optim.SGD
Backpropagation
Python output formatting
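A self-contained sketch (my own toy example, not tutorial code) of what the three backpropagation lines do, using a single trainable parameter; the print uses the same format specifiers as above:

import torch

w = torch.tensor([3.0], requires_grad=True)      # one trainable parameter
optimizer = torch.optim.SGD([w], lr=0.1)

for step in range(3):
    loss = (w - 1.0).pow(2).sum()  # quadratic loss, minimum at w = 1
    optimizer.zero_grad()          # clear gradients left over from the previous step
    loss.backward()                # compute d(loss)/dw and store it in w.grad
    optimizer.step()               # update: w <- w - lr * w.grad
    print(f"loss: {loss.item():>7f}  [{step:>5d}/{3:>5d}]")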
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
This is the test-loop function.
y is the ground-truth label, and pred is the prediction our network makes for X.
We accumulate the cross-entropy loss (test_loss) and the number of correct predictions (correct), then derive the average loss and the accuracy.
Tip: backpropagation is skipped here because no gradients are needed during evaluation, so torch.no_grad() is used to save time and memory; a small standalone sketch of the accuracy computation follows.
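A small standalone sketch of the argmax-based accuracy count used above (illustrative numbers of my own):

import torch

# Fake logits for a batch of 4 samples over 3 classes, plus the true labels.
pred = torch.tensor([[2.0, 0.1, 0.3],
                     [0.2, 1.5, 0.1],
                     [0.1, 0.2, 3.0],
                     [1.0, 0.9, 0.8]])
y = torch.tensor([0, 1, 2, 2])

with torch.no_grad():                    # no autograd graph is built in this block
    hits = (pred.argmax(1) == y)         # tensor([True, True, True, False])
    correct = hits.type(torch.float).sum().item()
print(correct / len(y))                  # 0.75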
# loss_fn and optimizer (only given in the comments above) must be defined before the loops run
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 2
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")
Then the training and test loops are run for the given number of epochs.
Saving and Loading Model Parameters
This section is fairly simple, so I mostly just give the reference.
Reference: the "Save and Load the Model" page of the official PyTorch Basics tutorial (https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html)
The code below builds on the training above: it saves the model parameters, and on the next run loads them back in before continuing.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

if torch.cuda.is_available():   # note the parentheses: is_available is a function
    device = "cuda"
    print("CUDA acceleration enabled")
else:
    device = "cpu"

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 20),
        )

    def forward(self, x):
        x = self.flatten(x).to(device)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
# 'data_1' must have been saved by a previous run (see torch.save below)
model.load_state_dict(torch.load('data_1'))
model.eval()
# Hyperparameters
learning_rate = 1e-1
batch_size = 64
epochs = 10

# Training and test loop definitions (same as before, but printing every 500 batches)
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y.to(device))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 500 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y.to(device)).item()
            correct += (pred.argmax(1) == y.to(device)).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 2
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)

# Save the trained parameters; 'data_1' is the file path that load_state_dict reads above
torch.save(model.state_dict(), 'data_1')
print("Done!")
In this code:
Saving the model parameters: torch.save(model.state_dict(), 'data_1')
Note that the second argument is the file path to save to.
Loading the model parameters:
To load model weights, you first need to create an instance of the same model class, then load the parameters with load_state_dict (a sketch that also handles the very first run, when the file does not exist yet, follows below):
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load('data_1'))
model.eval()
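One practical note: model.load_state_dict(torch.load('data_1')) will raise an error on the very first run, before anything has been saved. A minimal guard (my addition, not from the tutorial):

import os

model = NeuralNetwork().to(device)
if os.path.exists('data_1'):
    # Resume from the previously saved parameters
    model.load_state_dict(torch.load('data_1'))
    model.eval()
else:
    print("No saved parameters found; training from scratch")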
Another approach: save the entire model object instead of just its parameters (see the comparison sketch below):
torch.save(model, 'data_1')
model = torch.load('data_1')
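A short comparison of the two approaches (my own notes; the file name data_1_full below is just illustrative). Saving the whole model pickles a reference to the NeuralNetwork class, so the class definition must still be available in the script that loads it; saving only the state_dict is what the official tutorial recommends:

# Option 1: parameters only -- rebuild the model from its class, then load the weights
torch.save(model.state_dict(), 'data_1')
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load('data_1'))

# Option 2: whole model -- shorter, but the NeuralNetwork class must still be
# defined/importable when torch.load is called
torch.save(model, 'data_1_full')
model = torch.load('data_1_full')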