Saving and Loading PyTorch Models (with Examples)
2022-04-23 06:11:00 [sunshinecxm_BJTU]
0. Why save and load models at all
After training a model on data and reaching a satisfactory result, you cannot realistically retrain it from scratch every time you want to use it. Instead, you save the trained model once and load it whenever it is needed. A model is, at its core, a set of parameters stored in a particular structure, so there are two ways to save it. One is to save the entire model object and later load the whole thing directly; this is convenient but consumes more storage and ties the file to the exact class definition. The other is to save only the model's parameters; when you need the model again, you create a new model with the same architecture and load the saved parameters into it.
1. How to do it in each case
(1) Save only the model's parameter dict (recommended)
# Save
torch.save(the_model.state_dict(), PATH)
# Load
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
(2) Save the entire model
# Save
torch.save(the_model, PATH)
# Load
the_model = torch.load(PATH)
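To make approach (1) concrete before the LSTM example below, here is a minimal, self-contained sketch. The tiny linear model, the file name, and the input shapes are illustrative only, not from the original article:

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
path = os.path.join(tempfile.gettempdir(), "tiny_state.pt")
torch.save(model.state_dict(), path)   # persist only the parameter dict

new_model = TinyNet()                  # rebuild the same architecture
new_model.load_state_dict(torch.load(path))
new_model.eval()                       # good practice: switch to inference mode

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), new_model(x))
print(same)  # True: the reloaded model reproduces the original's outputs
```

Calling `eval()` matters for models with dropout or batch-norm layers; `TinyNet` has neither, but the habit is cheap.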
2. Saving only the model parameters (example)
PyTorch keeps a model's parameters in a dictionary; all we have to do is save that dictionary and load it again later.
For example, design a network with a (stacked) LSTM, train it, and after training save its parameter dict to a file named rnn.pt in the same folder:
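You can inspect this dictionary directly; it maps each parameter name to its tensor. A quick sketch with a bare `nn.Linear` layer:

```python
import torch.nn as nn

m = nn.Linear(3, 2)
sd = m.state_dict()
print(list(sd.keys()))     # ['weight', 'bias']
print(sd['weight'].shape)  # torch.Size([2, 3]): (out_features, in_features)
```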
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Initial hidden and cell states: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # Forward propagate the LSTM
        # out: tensor of shape (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out)
        return out
rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
# Optimize all of the model's parameters
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)
# Regression task, so use mean-squared-error loss
loss_func = nn.MSELoss()
# train_tensor / train_labels_tensor are assumed to be prepared beforehand
for epoch in range(1000):
    output = rnn(train_tensor)                     # model output
    loss = loss_func(output, train_labels_tensor)  # MSE loss
    optimizer.zero_grad()                          # clear gradients for this step
    loss.backward()                                # backpropagation: compute gradients
    optimizer.step()                               # apply gradients
# Save the model's parameter dict
torch.save(rnn.state_dict(), 'rnn.pt')
After saving, use the trained model to process new data:
# Test the saved model
m_state_dict = torch.load('rnn.pt')
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)
predict = new_m(test_tensor)
A few notes on this. When saving, rnn.state_dict() is the parameter dict of the rnn model. To test the saved model, first load that parameter dict: m_state_dict = torch.load('rnn.pt').
Then instantiate a new LSTM object, making sure the constructor arguments match the ones used to build rnn, so the two architectures are identical: new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device).
Next, load the saved parameters into the new model: new_m.load_state_dict(m_state_dict).
Finally, the model is ready to process data: predict = new_m(test_tensor).
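One detail worth knowing that the steps above skip: if the parameter dict was saved on a GPU and you later load it on a CPU-only machine, pass map_location to torch.load. A self-contained sketch, using an in-memory buffer in place of a file:

```python
import io

import torch
import torch.nn as nn

m = nn.Linear(3, 2)
buf = io.BytesIO()
torch.save(m.state_dict(), buf)
buf.seek(0)

# map_location forces every tensor onto the CPU,
# regardless of the device it was saved from
sd = torch.load(buf, map_location=torch.device('cpu'))
print(all(t.device.type == 'cpu' for t in sd.values()))  # True
```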
3. Saving the entire model (example)
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Initial hidden and cell states: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # Forward propagate the LSTM
        # out: tensor of shape (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        # Alternative: decode only the last time step with self.fc(out[:, -1, :])
        out = self.fc(out)
        return out
rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all of the model's parameters
loss_func = nn.MSELoss()                                  # regression task, so MSE loss
for epoch in range(1000):
    output = rnn(train_tensor)                     # model output
    loss = loss_func(output, train_labels_tensor)  # MSE loss
    optimizer.zero_grad()                          # clear gradients for this step
    loss.backward()                                # backpropagation: compute gradients
    optimizer.step()                               # apply gradients
# Save the entire model
torch.save(rnn, 'rnn1.pt')
After saving, use the trained model to process new data:
new_m = torch.load('rnn1.pt')
predict = new_m(test_tensor)
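A caveat with this approach (my note, not the original author's): torch.save(rnn, ...) pickles a reference to the LSTM class, not its source code, so the class definition must be importable under the same module path when torch.load is called. Also, on recent PyTorch versions (2.6 and later) torch.load defaults to weights_only=True, so loading a whole pickled model requires passing weights_only=False explicitly. A sketch with a throwaway class:

```python
import io

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)

net = Net()
buf = io.BytesIO()
torch.save(net, buf)  # pickles a *reference* to Net, not its code
buf.seek(0)

# weights_only=False is required on PyTorch >= 2.6 to unpickle a full model;
# on versions before 1.13 the argument does not exist and should be dropped
loaded = torch.load(buf, weights_only=False)
print(isinstance(loaded, Net))  # True
```

This fragility at load time is the main reason the state_dict approach in section 2 is the recommended one.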
Copyright notice
This article was written by [sunshinecxm_BJTU]. If you repost it, please include a link to the original. Thank you.
https://blog.csdn.net/qq_36744449/article/details/124323478