Saving PyTorch pretrained models in different forms
2022-04-23 20:47:00 【NuerNuer】
Note: the .pt and .pth suffixes do not seem to make any difference.
When saving, you can save the whole model or only its parameters, and in either case you can also wrap the object in a new dictionary before saving. Each choice requires different handling when the file is read back. When loading, the argument passed to load_state_dict must be an OrderedDict of parameters. Here are four different ways of saving, and for each one the way to get the OrderedDict back when reading.
1. Save
# coding=gbk
# Saved as save.py, because the load script imports MLP_ from this module.
import torch
import torch.nn as nn


class MLP_(nn.Module):
    def __init__(self):
        super(MLP_, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)


net = MLP_()

# Save the entire model
torch.save(net, 'a1.pt')
# Put the model under a key, so that other values (for example the optimizer
# parameters) can be added to the same dictionary later
all_model = {'model': net}
torch.save(all_model, 'a2.pt')

# Save only the parameters
torch.save(net.state_dict(), 'a3.pt')
# Put the parameters under a key, so that other values (for example the
# optimizer parameters) can be added to the same dictionary later
all_states = {'state_dict': net.state_dict()}
torch.save(all_states, 'a4.pt')
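The comments above mention adding optimizer parameters to the same dictionary. A minimal sketch of that idea follows. The `nn.Linear` stand-in model, the SGD optimizer, the `'optimizer'` and `'epoch'` keys, and the temporary file path are all assumptions made for illustration and are not part of the original scripts.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in model; any nn.Module is handled the same way.
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Bundle the model parameters and the optimizer state under separate keys.
# The key names are arbitrary; they only have to match when reading.
checkpoint = {
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': 5,
}

path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')
torch.save(checkpoint, path)

# Reading it back: select each part through its key.
loaded = torch.load(path)
restored_model = nn.Linear(3, 1)
restored_model.load_state_dict(loaded['state_dict'])
restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.1)
restored_optimizer.load_state_dict(loaded['optimizer'])
```

Because this dictionary holds only tensors and plain Python values, it can be read back without the model class being importable, which is not the case for a file that contains the whole model.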
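2. Load
# coding=gbk
import torch
from save import MLP_  # the class definition is needed to unpickle a whole model

# Note: since PyTorch 2.6, torch.load defaults to weights_only=True, so
# loading a whole pickled model (a1.pt, a2.pt) requires weights_only=False.
if __name__ == "__main__":
    with torch.no_grad():
        a1 = 'a1.pt'
        a2 = 'a2.pt'
        a3 = 'a3.pt'
        a4 = 'a4.pt'

        # Whole model: call state_dict() to get the OrderedDict
        a1_ = torch.load(a1)
        print(a1_.state_dict())

        # Whole model in a dictionary: select it through its key first
        a2_ = torch.load(a2)['model']
        print(a2_.state_dict())

        # Parameters only: the loaded object is already the OrderedDict
        a3_ = torch.load(a3)
        print(a3_)

        # Parameters in a dictionary: select them through their key
        a4_ = torch.load(a4)['state_dict']
        print(a4_)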
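The load script only prints each OrderedDict. As a rough sketch of the remaining step, the snippet below passes a loaded OrderedDict to load_state_dict on a freshly constructed model. It repeats the MLP_ definition so that it runs on its own, and the temporary file path and the variable names `source`, `target`, and `state` are assumptions for illustration.

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn


class MLP_(nn.Module):
    # Same architecture as the model defined in the save script.
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        return self.output(self.act(self.hidden(x)))


# Save only the parameters of one model instance.
source = MLP_()
path = os.path.join(tempfile.mkdtemp(), 'params.pt')
torch.save(source.state_dict(), path)

# The loaded object is the OrderedDict that load_state_dict expects.
state = torch.load(path)
is_ordered_dict = isinstance(state, OrderedDict)

# A new instance starts with different random weights until they are loaded.
target = MLP_()
target.load_state_dict(state)

# After loading, both models give the same output for the same input.
x = torch.ones(1, 3)
with torch.no_grad():
    same_output = torch.equal(source(x), target(x))
```

The keys of the OrderedDict follow the attribute names of the layers (`hidden.weight`, `hidden.bias`, `output.weight`, `output.bias`), so the target model must define its layers under the same names for load_state_dict to succeed.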
Reference: https://zhuanlan.zhihu.com/p/94971100
Copyright notice
This article was written by [NuerNuer]. Please include the original link when reposting. Thanks.
https://yzsam.com/2022/04/202204210545523057.html