PyTorch Model Saving and Loading (Examples)
2022-04-23 06:11:00 【sunshinecxm_BJTU】
0. Why save and load a model
Once a model has been trained to a satisfactory state, it is impractical to retrain it from scratch every time it is needed, so the trained model has to be saved and then loaded whenever it is used. A model is essentially a collection of parameters stored in some structure, which gives two ways to save it: save the entire model and later load the whole thing back (which uses more storage), or save only the model's parameters and, when needed, construct a new model with the same architecture and load the saved parameters into it.
1. How to implement the two approaches
(1) Save only the model's parameter dictionary (recommended)
# save
torch.save(the_model.state_dict(), PATH)
# load
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
(2) Save the entire model
# save
torch.save(the_model, PATH)
# load
the_model = torch.load(PATH)
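One practical note: if the weights were saved on a GPU but need to be loaded on a CPU-only machine, torch.load accepts a map_location argument. A minimal sketch, using the same PATH and the_model as above:

# remap GPU-saved tensors onto the CPU while loading
state = torch.load(PATH, map_location=torch.device('cpu'))
the_model.load_state_dict(state)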
2. Saving only the model parameters (example)
PyTorch keeps a model's parameters in a dictionary (its state_dict), so all we have to do is save this dictionary and load it back later.
For example, define an LSTM network, train it, and after training save its parameter dictionary to a file named rnn.pt in the same folder:
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # set initial hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # forward propagate the LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: (batch_size, seq_length, hidden_size)
        out = self.fc(out)
        return out

rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all model parameters
loss_func = nn.MSELoss()  # mean-squared-error loss for this regression task

# train_tensor and train_labels_tensor are assumed to be prepared elsewhere
for epoch in range(1000):
    output = rnn(train_tensor)                     # model output
    loss = loss_func(output, train_labels_tensor)  # MSE loss
    optimizer.zero_grad()                          # clear gradients for this step
    loss.backward()                                # backpropagation, compute gradients
    optimizer.step()                               # apply gradients
    output_sum = output                            # keep the most recent output (unused below)

# save the model's parameter dictionary
torch.save(rnn.state_dict(), 'rnn.pt')
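To see what this parameter dictionary actually contains, you can print its entries after training (a quick illustration; the key names follow PyTorch's submodule.parameter naming):

for name, tensor in rnn.state_dict().items():
    print(name, tuple(tensor.shape))
# prints entries such as lstm.weight_ih_l0, lstm.bias_hh_l1, fc.weight, fc.bias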
After saving, use the trained model to process data:
# test the saved model
m_state_dict = torch.load('rnn.pt')
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)
predict = new_m(test_tensor)
A few notes. When saving, rnn.state_dict() is the parameter dictionary of the rnn model, and to test the saved model this dictionary must first be loaded: m_state_dict = torch.load('rnn.pt'). Then instantiate a new LSTM object, passing the same arguments that were used to instantiate rnn so that the two architectures match: new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device). Next, load the saved parameters into the new model: new_m.load_state_dict(m_state_dict). Finally, the model can be used to process data: predict = new_m(test_tensor).
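One detail worth adding for inference (a minimal sketch, not in the original code): switch the reloaded model to evaluation mode and disable gradient tracking, which is good practice even though this particular model has no dropout or batch-norm layers:

new_m.eval()            # put the model in evaluation mode
with torch.no_grad():   # no gradients needed for inference
    predict = new_m(test_tensor)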
3. Saving the entire model (example)
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # set initial hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # forward propagate the LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: (batch_size, seq_length, hidden_size)
        # out = self.fc(out[:, -1, :])   # alternative: decode only the last time step
        out = self.fc(out)
        return out

rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all model parameters
loss_func = nn.MSELoss()  # mean-squared-error loss for this regression task

for epoch in range(1000):
    output = rnn(train_tensor)                     # model output
    loss = loss_func(output, train_labels_tensor)  # MSE loss
    optimizer.zero_grad()                          # clear gradients for this step
    loss.backward()                                # backpropagation, compute gradients
    optimizer.step()                               # apply gradients
    output_sum = output                            # keep the most recent output (unused below)

# save the entire model
torch.save(rnn, 'rnn1.pt')
After saving, use the trained model to process data:
new_m = torch.load('rnn1.pt')
predict = new_m(test_tensor)
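Note that saving the entire model serializes it with Python's pickle, so the script that calls torch.load('rnn1.pt') must still be able to import or define the LSTM class; the state_dict approach from the previous section avoids this coupling. As before, inference is best done in evaluation mode with gradients disabled (a minimal sketch):

new_m.eval()            # put the reloaded model in evaluation mode
with torch.no_grad():   # no gradients needed for inference
    predict = new_m(test_tensor)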
Copyright notice
This article was written by [sunshinecxm_BJTU]. Please include a link to the original when reposting. Thank you.
https://blog.csdn.net/qq_36744449/article/details/124323478