当前位置:网站首页>Pytorch学习记录(四):参数初始化
Pytorch学习记录(四):参数初始化
2022-04-23 05:43:00 【左小田^O^】
参数初始化(Weight Initialization)
只需要学会如何对模型的参数进行初始化的赋值即可。
使用 NumPy 来初始化
举例
import torch
import numpy as np
from torch import nn
net1 = nn.Sequentisl(
nn.Linear(30,40),
nn.ReLU(),
nn.Linear(40,50),
nn.ReLU(),
nn.Linear(50,10)
)
# 访问第一层的参数
w1 = net1[0].weight
b1 = net1[0].bias
print(w1)
Parameter containing:
0.1236 -0.1731 -0.0479 ... 0.0031 0.0784 0.1239
0.0713 0.1615 0.0500 ... -0.1757 -0.1274 -0.1625
0.0638 -0.1543 -0.0362 ... 0.0316 -0.1774 -0.1242
... ⋱ ...
0.1551 0.1772 0.1537 ... 0.0730 0.0950 0.0627
0.0495 0.0896 0.0243 ... -0.1302 -0.0256 -0.0326
-0.1193 -0.0989 -0.1795 ... 0.0939 0.0774 -0.0751
[torch.FloatTensor of size 40x30]
注意,这是一个 Parameter,也就是一个特殊的 Variable,我们可以访问其 .data属性得到其中的数据,然后直接定义一个新的 Tensor 对其进行替换。
我们可以使用 PyTorch 中的一些随机数据生成的方式,比如 torch.randn,如果要使用更多PyTorch 中没有的随机化方式,可以使用 numpy
使用torch.from_numpy(np.random.uniform(min,max,(size))
定义一个 Tensor 直接对其进行替换
# 定义一个 Tensor 直接对其进行替换
net1[0].weight.data = torch.from_numpy(np.random.uniform(3, 5, size=(40, 30)))
再次打印
print(net1[0].weight)
Parameter containing:
4.5768 3.6175 3.3098 ... 4.7374 4.0164 3.3037
4.1809 3.5624 3.1452 ... 3.0305 4.4444 4.1058
3.5277 4.3712 3.7859 ... 3.5760 4.8559 4.3252
... ⋱ ...
4.8983 3.9855 3.2842 ... 4.7683 4.7590 3.3498
4.9168 4.5723 3.5870 ... 3.2032 3.9842 3.2484
4.2532 4.6352 4.4857 ... 3.7543 3.9885 4.4211
[torch.DoubleTensor of size 40x30]
进行循环访问
for layer in net1:
if isinstance(layer, nn.Linear): # 判断是否是线性层
param_shape = layer.weight.shape
layer.weight.data = torch.from_numpy(np.random.normal(0, 0.5, size=param_shape))
# 定义为均值为 0,方差为 0.5 的正态分布
对于 Module 的参数初始化,可以直接像Sequential 一样对其 Tensor 进行重新定义,其唯一不同的地方在于,如果要用循环的方式访问,需要介绍两个属性,children 和 modules.
class sim_net(nn.Module):
def __init__(self):
super(sim_net, self).__init__()
self.l1 = nn.Sequential(
nn.Linear(30, 40),
nn.ReLU()
)
self.l1[0].weight.data = torch.randn(40, 30) # 直接对某一层初始化
self.l2 = nn.Sequential(
nn.Linear(40, 50),
nn.ReLU()
)
self.l3 = nn.Sequential(
nn.Linear(50, 10),
nn.ReLU()
)
def forward(self, x):
x = self.l1(x)
x =self.l2(x)
x = self.l3(x)
return x
net2 = sim_net()
访问 children
# 访问 children
for i in net2.children():
print(i)
Sequential(
(0): Linear(in_features=30, out_features=40)
(1): ReLU()
)
Sequential(
(0): Linear(in_features=40, out_features=50)
(1): ReLU()
)
Sequential(
(0): Linear(in_features=50, out_features=10)
(1): ReLU()
)
访问 modules
# 访问 modules
for i in net2.modules():
print(i)
# 访问 modules
for i in net2.modules():
print(i)
# 访问 modules
for i in net2.modules():
print(i)
sim_net(
(l1): Sequential(
(0): Linear(in_features=30, out_features=40)
(1): ReLU()
)
(l2): Sequential(
(0): Linear(in_features=40, out_features=50)
(1): ReLU()
)
(l3): Sequential(
(0): Linear(in_features=50, out_features=10)
(1): ReLU()
)
)
Sequential(
(0): Linear(in_features=30, out_features=40)
(1): ReLU()
)
Linear(in_features=30, out_features=40)
ReLU()
Sequential(
(0): Linear(in_features=40, out_features=50)
(1): ReLU()
)
Linear(in_features=40, out_features=50)
ReLU()
Sequential(
(0): Linear(in_features=50, out_features=10)
(1): ReLU()
)
Linear(in_features=50, out_features=10)
ReLU()
children 只会访问到模型定义中的第一层,因为上面的模型中定义了三个 Sequential,所以只会访问到三个 Sequential,而 modules 会访问到最后的结构,比如上面的例子,modules 不仅访问到了 Sequential,也访问到了 Sequential 里面.
循环访问进行初始化
for layer in net2.modules():
if isinstance(layer, nn.Linear):
param_shape = layer.weight.shape
layer.weight.data = torch.from_numpy(np.random.normal(0, 0.5, size=param_shape))
torch.nn.init
torch.nn.init:其操作层面仍然在 Tensor 上
init.xavier_uniform
from torch.nn import init
init.xavier_uniform(net1[0].weight) # 这就是上面我们讲过的 Xavier 初始化方法,PyTorch 直接内置了其实现
常见初始化方法
1.Xavier Initialization
Xavier初始化的基本思想是保持输入和输出的方差一致,这样就避免了所有输出值都趋向于0。这是通用的方法,适用于任何激活函数。
nn.init.xavier_uniform_(m.weight)
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.xavier_uniform_(m.weight)
2.He et. al Initialization
He initialization的思想是:在ReLU网络中,假定每一层有一半的神经元被激活,另一半为0。推荐在ReLU网络中使用。
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
2.正交初始化(Orthogonal Initialization)
以解决深度网络下的梯度消失、梯度爆炸问题,在RNN中经常使用的参数初始化方法。
是nn.init.orthogonal(m.weight)
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.orthogonal(m.weight)
3.Batchnorm Initialization
在非线性激活函数之前,我们想让输出值有比较好的分布(例如高斯分布),以便于计算梯度和更新参数。Batch Normalization 将输出值强行做一次 Gaussian Normalization 和线性变换:
是nn.init.constant(m.weight, 1),nn.init.constant(m.bias, 0),
for m in model:
if isinstance(m, nn.BatchNorm2d):
nn.init.constant(m.weight, 1)
nn.init.constant(m.bias, 0)
单层初始化
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
nn.init.xavier_uniform(conv1.weight)
nn.init.constant(conv1.bias, 0.1)
模型初始化
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
nn.init.xavier_normal_(m.weight.data)
nn.init.constant_(m.bias.data, 0.0)
elif classname.find('Linear') != -1:
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0.0)
net = Net()
net.apply(weights_init) #apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。
版权声明
本文为[左小田^O^]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_45802081/article/details/119997522
边栏推荐
- Understand the current commonly used encryption technology system (symmetric, asymmetric, information abstract, digital signature, digital certificate, public key system)
- JDBC工具类封装
- 2-軟件設計原則
- Character recognition easyocr
- acwing854. Floyd finds the shortest path
- MySQL创建oracle练习表
- filebrowser实现私有网盘
- Strategy for improving the conversion rate of independent stations | recovering abandoned users
- DBCP使用
- 多个一维数组拆分合并为二维数组
猜你喜欢

解决报错:ImportError: IProgress not found. Please update jupyter and ipywidgets

Differences between sea level anatomy and sea surface height anatomy

SQL statement simple optimization

Batch import of orange single micro service

引航成长·匠心赋能——YonMaster开发者培训领航计划全面开启

JVM系列(4)——内存溢出(OOM)

框架解析1.系统架构简介

容器

opensips(1)——安装opensips详细流程

delete和truncate
随机推荐
基于ssm 包包商城系统
Record a project experience and technologies encountered in the project
MySQL triggers, stored procedures, stored functions
一文读懂当前常用的加密技术体系(对称、非对称、信息摘要、数字签名、数字证书、公钥体系)
What is JSON? First acquaintance with JSON
xxl-job采坑指南xxl-rpc remoting error(connect timed out)
建表到页面完整实例演示—联表查询
Pilotage growth · ingenuity empowerment -- yonmaster developer training and pilotage plan is fully launched
JVM系列(3)——内存分配与回收策略
AcWing 1096. Detailed notes of Dungeon Master (3D BFS) code
Flutter nouvelle génération de rendu graphique Impeller
Duplicate key update in MySQL
Multithreading and high concurrency (1) -- basic knowledge of threads (implementation, common methods, state)
redhat实现目录下特定文本类型内关键字查找及vim模式下关键字查找
jdbc入门\获取数据库连接\使用PreparedStatement
filebrowser实现私有网盘
protected( 被 protected 修饰的成员对于本包和其子类可见)
线性规划问题中可行解,基本解和基本可行解有什么区别?
事实最终变量与最终变量
Ora: 28547 connection to server failed probable Oracle net admin error