Saving pretrained models in different forms in PyTorch

2022-04-23 20:47:00 NuerNuer

Note: the .pt and .pth suffixes do not appear to make any difference.
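
As a quick check of the point about suffixes, the sketch below saves the same parameters under both names and compares what comes back. The layer and the file names are hypothetical and chosen only for illustration; it shows that the loaded contents match, not that every tool treats the two suffixes identically.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for any model; only its parameters matter here.
layer = nn.Linear(3, 2)

with tempfile.TemporaryDirectory() as tmp:
    pt_path = os.path.join(tmp, "weights.pt")
    pth_path = os.path.join(tmp, "weights.pth")

    # Save the same state dict under both suffixes.
    torch.save(layer.state_dict(), pt_path)
    torch.save(layer.state_dict(), pth_path)

    from_pt = torch.load(pt_path)
    from_pth = torch.load(pth_path)

# Compare the loaded tensors key by key.
same = all(torch.equal(from_pt[k], from_pth[k]) for k in from_pt)
print(same)
```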

When saving, you can save the whole model or only its parameters, and you can also wrap either one in a new dictionary before saving. Each choice requires different handling when reading the file back. The argument to the load_state_dict function is an OrderedDict of parameters, so below are four different ways of saving and, for each, how to obtain the OrderedDict when loading.
1. Saving

# save.py (the loading script in section 2 imports MLP_ from this file)
import torch
import torch.nn as nn
class MLP_(nn.Module):
    def __init__(self):
        super(MLP_, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

if __name__ == "__main__":
    # Guarded so that importing MLP_ from this file does not re-run the saves.
    net = MLP_()

    # 1) Save the entire model.
    torch.save(net, 'a1.pt')

    # 2) Save the entire model inside a dictionary. Giving the model its own key
    #    makes it easy to add more entries later, such as the optimizer state.
    all_model = {'model': net}
    torch.save(all_model, 'a2.pt')

    # 3) Save only the parameters (the state dict).
    torch.save(net.state_dict(), 'a3.pt')

    # 4) Save the parameters inside a dictionary, again leaving room for
    #    extra entries such as the optimizer state.
    all_states = {'state_dict': net.state_dict()}
    torch.save(all_states, 'a4.pt')
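
The comments above mention adding the optimizer parameters to the dictionary. A minimal sketch of that idea follows. The single linear layer, the SGD settings, the "optimizer" and "epoch" keys, and the file name are all assumptions made for illustration, not part of the original code.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for MLP_ and an arbitrary optimizer choice.
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Take one training step so the optimizer has momentum state worth saving.
loss = model(torch.randn(4, 3)).sum()
loss.backward()
optimizer.step()

# One dictionary holds everything; the key names are a free choice.
checkpoint = {
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": 1,
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    torch.save(checkpoint, path)
    restored = torch.load(path)

# Restore into freshly constructed objects of the same shape.
new_model = nn.Linear(3, 1)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_model.load_state_dict(restored["state_dict"])
new_optimizer.load_state_dict(restored["optimizer"])
print(restored["epoch"])
```

Because the dictionary holds only state dicts and plain Python values, it can be loaded without the model class being importable.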

2. Loading

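# load.py
import torch
from save import MLP_  # needed so the pickled models in a1.pt and a2.pt can be rebuilt

if __name__ == "__main__":
    with torch.no_grad():
        a1 = 'a1.pt'
        a2 = 'a2.pt'
        a3 = 'a3.pt'
        a4 = 'a4.pt'

        # On recent PyTorch versions, where torch.load defaults to
        # weights_only=True, loading the whole-model files (a1.pt, a2.pt)
        # requires passing weights_only=False.

        # 1) Entire model: call state_dict() on the loaded module.
        a1_ = torch.load(a1)
        print(a1_.state_dict())

        # 2) Entire model in a dictionary: select it by key, then call state_dict().
        a2_ = torch.load(a2)['model']
        print(a2_.state_dict())

        # 3) Parameters only: the loaded object is already the OrderedDict.
        a3_ = torch.load(a3)
        print(a3_)

        # 4) Parameters in a dictionary: select the OrderedDict by key.
        a4_ = torch.load(a4)['state_dict']
        print(a4_)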
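
Once the OrderedDict has been obtained, it can be passed to load_state_dict on a freshly built model. The sketch below runs that round trip for all four forms. It makes several assumptions: build_net and to_state_dict are hypothetical helpers, the nn.Sequential network is a stand-in for MLP_ built only from stock layers, and the weights_only argument to torch.load requires a PyTorch version that supports it (1.13 or later).

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_net():
    # Hypothetical stand-in for MLP_, using only stock torch.nn layers.
    return nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))


def to_state_dict(obj):
    """Reduce any of the four saved forms to a state dict."""
    if isinstance(obj, nn.Module):
        return obj.state_dict()
    if isinstance(obj, dict) and isinstance(obj.get("model"), nn.Module):
        return obj["model"].state_dict()
    if isinstance(obj, dict) and "state_dict" in obj:
        return obj["state_dict"]
    return obj  # assumed to be a state dict already


net = build_net()
forms = {
    "a1.pt": net,
    "a2.pt": {"model": net},
    "a3.pt": net.state_dict(),
    "a4.pt": {"state_dict": net.state_dict()},
}

x = torch.randn(5, 3)
with torch.no_grad():
    expected = net(x)

outputs = {}
with tempfile.TemporaryDirectory() as tmp:
    for name, obj in forms.items():
        path = os.path.join(tmp, name)
        torch.save(obj, path)
        # weights_only=False allows unpickling whole modules; only use it
        # on files you trust.
        loaded = torch.load(path, weights_only=False)
        fresh = build_net()
        fresh.load_state_dict(to_state_dict(loaded))
        with torch.no_grad():
            outputs[name] = fresh(x)

for name, out in outputs.items():
    print(name, torch.allclose(out, expected))
```

Saving only the state dict (the third and fourth forms) keeps the file independent of the class definition, which is why those forms load without any extra import.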

Reference: https://zhuanlan.zhihu.com/p/94971100
