PyTorch model saving and loading (example)
2022-04-23 07:21:00 【sunshinecxm_ BJTU】
0. Why save and load models
After a model has been trained on data, we have a usable model. In practice, however, it is not feasible to retrain every time the model is needed, so the trained model has to be saved first and then loaded whenever it is used. A model is essentially a set of parameters stored in some structure, so there are two ways to save it. One is to save the whole model object and later load the whole object back, but the saved file is larger and depends on the original class definition. The other is to save only the model's parameters; when the model is needed later, create a new model with the same structure and load the saved parameters into it.
1. How to implement the two approaches
(1) Save only the model's parameter dictionary (recommended)
# Save
torch.save(the_model.state_dict(), PATH)
# Load
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
(2) Save the entire model
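# Save
torch.save(the_model, PATH)
# Load (the model's class definition must be importable at this point;
# on PyTorch 2.6 and later, use torch.load(PATH, weights_only=False))
the_model = torch.load(PATH)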
3. Save only the model parameters (example)
PyTorch keeps a model's parameters in a dictionary (the state_dict). All we have to do is save this dictionary and load it again later.
For example, define a simple LSTM network and train it. After training, save the model's parameter dictionary as the file rnn.pt in the same folder:
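To make the "dictionary of parameters" concrete, here is a minimal sketch that inspects the state_dict of a small layer. The nn.Linear(3, 2) layer is only an illustrative stand-in and is not part of the original example.

```python
import torch.nn as nn

# Illustrative stand-in layer (not from the original article)
layer = nn.Linear(3, 2)

# state_dict() maps parameter names to tensors
state = layer.state_dict()
names = list(state.keys())
shapes = {name: tuple(tensor.shape) for name, tensor in state.items()}

for name in names:
    print(name, shapes[name])
```

Each entry is a named tensor, which is why saving this dictionary is enough to restore the weights into a model with the same structure.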
import torch
import torch.nn as nn

# Use the GPU when available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Initial hidden and cell states: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # Forward propagate the LSTM
        out, _ = self.lstm(x, (h0, c0))
        # out: tensor of shape (batch_size, seq_length, hidden_size)
        out = self.fc(out)
        return out

rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
# Optimize all parameters of the network
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)
# Mean squared error loss (regression targets)
loss_func = nn.MSELoss()

# train_tensor and train_labels_tensor are assumed to have been prepared beforehand
for epoch in range(1000):
    output = rnn(train_tensor)  # network output
    loss = loss_func(output, train_labels_tensor)  # MSE loss
    optimizer.zero_grad()  # clear gradients for this training step
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
    output_sum = output

# Save the model parameters
torch.save(rnn.state_dict(), 'rnn.pt')
After saving, use the trained model to process data:
# Test the saved model
m_state_dict = torch.load('rnn.pt')
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)
predict = new_m(test_tensor)
To explain: when saving, rnn.state_dict() returns the parameter dictionary of the rnn model. When testing the saved model, first load that dictionary: m_state_dict = torch.load('rnn.pt');
Then instantiate a new LSTM object. The constructor arguments must be the same as those used when rnn was created, so that the two structures are identical: new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device);
Next, pass the loaded parameters into the new model: new_m.load_state_dict(m_state_dict);
Finally, use this model to process the data: predict = new_m(test_tensor)
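The round trip described above can be checked with a small self-contained sketch. TinyLSTM, the random input x, and the temporary file path are hypothetical stand-ins for the article's LSTM class and data; the sketch also shows that loading into a model with a different hidden_size fails, which is why the constructor arguments must match.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyLSTM(nn.Module):
    """Hypothetical stand-in for the article's LSTM class."""

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)  # initial states default to zeros
        return self.fc(out)


torch.manual_seed(0)
model = TinyLSTM(input_size=1, hidden_size=10, num_layers=2)
x = torch.randn(4, 5, 1)  # (batch, seq_length, input_size), random stand-in data

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "rnn.pt")
    torch.save(model.state_dict(), path)

    # Same constructor arguments, then load the saved parameters
    restored = TinyLSTM(input_size=1, hidden_size=10, num_layers=2)
    restored.load_state_dict(torch.load(path, map_location="cpu"))

    # A different structure cannot accept the same parameters
    mismatch_failed = False
    try:
        TinyLSTM(input_size=1, hidden_size=20, num_layers=2).load_state_dict(
            torch.load(path, map_location="cpu")
        )
    except RuntimeError:
        mismatch_failed = True

model.eval()
restored.eval()
with torch.no_grad():
    same_output = torch.allclose(model(x), restored(x))
    restored_shape = tuple(restored(x).shape)

print(same_output, mismatch_failed)
```

Calling eval() and wrapping inference in torch.no_grad() is a common habit when a loaded model is only used for prediction; it is an addition here, not something the original code does.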
4. Save the whole model (example)
import torch
import torch.nn as nn

# Use the GPU when available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Initial hidden and cell states: (num_layers, batch_size, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # Forward propagate the LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: (batch_size, seq_length, hidden_size)
        # print("output_in=", out.shape)
        # print("fc_in_shape=", out[:, -1, :].shape)
        # Alternative: decode only the hidden state of the last time step
        # out = torch.cat((out[:, 0, :], out[-1, :, :]), axis=0)
        # out = self.fc(out[:, -1, :])  # take the last time step as out
        out = self.fc(out)
        return out

rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001)  # optimize all network parameters
loss_func = nn.MSELoss()  # mean squared error loss (regression targets)

# train_tensor and train_labels_tensor are assumed to have been prepared beforehand
for epoch in range(1000):
    output = rnn(train_tensor)  # network output
    loss = loss_func(output, train_labels_tensor)  # MSE loss
    optimizer.zero_grad()  # clear gradients for this training step
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
    output_sum = output

# Save the whole model
torch.save(rnn, 'rnn1.pt')
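After saving, use the trained model to process data:
# The LSTM class definition must be importable when loading.
# On PyTorch 2.6 and later, pass weights_only=False to load a whole pickled model.
new_m = torch.load('rnn1.pt')
predict = new_m(test_tensor)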
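As a self-contained check of the whole-model approach, the sketch below saves and reloads an nn.Sequential model, which is an illustrative stand-in rather than the article's LSTM. It assumes a PyTorch version in which torch.load accepts the weights_only argument (1.13 or later); on older versions, drop that argument. Because the whole object is pickled, only load such files from a source you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Illustrative stand-in model built only from torch.nn classes
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
x = torch.randn(2, 3)  # random stand-in input

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "whole_model.pt")
    torch.save(model, path)  # pickles the entire module object

    # weights_only=False allows unpickling a full nn.Module
    # (assumption: PyTorch 1.13 or later, where this argument exists)
    loaded = torch.load(path, weights_only=False)

with torch.no_grad():
    matches = torch.allclose(model(x), loaded(x))

print(type(loaded).__name__, matches)
```

No constructor call is needed on the loading side, but the classes the model was built from still have to be importable there, which is the main trade-off against saving only the state_dict.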
Copyright notice
This article was written by [sunshinecxm_ BJTU]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204230610529717.html