PyTorch framework || torch.nn.modules.Module (nn.Module)
2022-04-21 20:33:00 【研究生不迟到】
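1 A simple network
- A PyTorch model is written as a class.
- A trainable PyTorch model is a subclass of nn.Module.
- A model implements at least two methods: __init__ (initialization, where the layers are created) and forward (the forward pass).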
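import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # layers assigned as attributes are registered as submodules
        self.conv1 = nn.Conv2d(1, 20, 5)   # 1 input channel -> 20 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(20, 20, 5)  # 20 -> 20 channels, 5x5 kernel

    def forward(self, x):
        # defines the computation; invoke it as model(x), not model.forward(x)
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))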
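A short usage sketch of the model above, repeated here so it runs on its own. The 1×1×28×28 input size is an arbitrary choice for illustration. It shows that calling the module runs forward(), that each unpadded 5×5 convolution shrinks the spatial size by 4, and that the layers created in __init__ are registered automatically, so their parameters are visible through named_parameters().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

model = Model()
x = torch.randn(1, 1, 28, 28)   # assumed input: batch 1, 1 channel, 28x28
out = model(x)                  # nn.Module.__call__ dispatches to forward()
print(out.shape)                # torch.Size([1, 20, 20, 20]): 28 -> 24 -> 20

# Submodules assigned in __init__ are registered, so their parameters are tracked
for name, p in model.named_parameters():
    print(name, tuple(p.shape))

n_params = sum(p.numel() for p in model.parameters())
print(n_params)                 # 10540 = (20*1*25 + 20) + (20*20*25 + 20)
```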
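2 Custom weight initialization: init_weights()
- The code below is the SENet (Squeeze-and-Excitation) attention block, included here to study init_weights. Note that init_weights is an ordinary method defined by the model itself, not a built-in nn.Module hook, so it only runs when it is called explicitly (here, at the end of __init__).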
import torch
from torch import nn
from torch.nn import init

class SEAttention(nn.Module):
    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling: output is C×1×1
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),  # squeeze: compress to channel // reduction
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),  # excitation: restore to channel
            nn.Sigmoid()
        )
        # nn.Module never calls init_weights automatically; call it explicitly
        self.init_weights()

    def init_weights(self):
        # self.modules() yields this module and all of its submodules, recursively
        for m in self.modules():
            if isinstance(m, nn.Conv2d):  # is m an nn.Conv2d?
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):  # the only branch that applies in this SE block
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()  # 50×512×7×7
        y = self.avg_pool(x).view(b, c)  # average pooling -> 50×512×1×1, then view -> 50×512
        y = self.fc(y).view(b, c, 1, 1)  # 50×512×1×1
        return x * y.expand_as(x)  # expand y to x's size, then scale each channel of x

if __name__ == '__main__':
    x = torch.randn(50, 512, 7, 7)
    se = SEAttention(channel=512, reduction=8)  # instantiate the SE module
    output = se(x)
    print(output.shape)  # torch.Size([50, 512, 7, 7])
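An alternative idiom, sketched here as an assumption rather than the SENet authors' method: write the per-layer logic as a free function and pass it to nn.Module.apply, which calls it on every submodule (and on the module itself). The small nn.Sequential below is a hypothetical stand-in for the SE block's excitation MLP with channel=512 and reduction=8; the function name init_weights is also just illustrative.

```python
import torch
import torch.nn as nn
from torch.nn import init

def init_weights(m):
    # Module.apply calls this once per submodule, children first, then the root
    if isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight, mode='fan_out')
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            init.constant_(m.bias, 0)

torch.manual_seed(0)
# Hypothetical stand-in for SEAttention.fc with channel=512, reduction=8
fc = nn.Sequential(
    nn.Linear(512, 64, bias=False),
    nn.ReLU(inplace=True),
    nn.Linear(64, 512, bias=False),
    nn.Sigmoid(),
)
before = fc[0].weight.std().item()   # PyTorch's default Linear init, much larger than 0.001
fc.apply(init_weights)               # recursively runs init_weights on every layer
after = fc[0].weight.std().item()    # close to 0.001 after init.normal_(std=0.001)
print(before, after)
```

Keeping the function outside the class makes it reusable across models; the trade-off is that, as with the method version, nothing runs until apply is actually called.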
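2.1 Kaiming normal (Gaussian) initialization
- Kaiming initialization scales the weights so that the variance of each layer's output stays roughly constant (about 1) from layer to layer, instead of shrinking or exploding as the network gets deeper. The weights are initialized as follows: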
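For reference, the PyTorch documentation for kaiming_normal_ gives the distribution it samples from as below, where fan_mode is fan_in (in_channels × kH × kW for a convolution) or fan_out (out_channels × kH × kW) depending on `mode`, and `a` is the negative slope of a leaky ReLU. The short derivation of the factor 2 follows He et al. (2015) and assumes zero-mean weights and symmetric, zero-mean pre-activations.

```latex
% Distribution sampled by torch.nn.init.kaiming_normal_
W \sim \mathcal{N}\left(0,\ \mathrm{std}^2\right), \qquad
\mathrm{std} = \frac{\mathrm{gain}}{\sqrt{\mathrm{fan\_mode}}}, \qquad
\mathrm{gain} = \sqrt{\frac{2}{1 + a^2}}

% Why gain = sqrt(2) for ReLU (a = 0): for y_l = W_l x_l with n_l = fan_in inputs,
\mathrm{Var}(y_l) = n_l \,\mathrm{Var}(w_l)\, \mathbb{E}[x_l^2],
\qquad \mathbb{E}[x_l^2] = \tfrac{1}{2}\,\mathrm{Var}(y_{l-1}) \quad (x_l = \mathrm{ReLU}(y_{l-1}))

% Requiring Var(y_l) = Var(y_{l-1}) gives
\tfrac{1}{2}\, n_l \,\mathrm{Var}(w_l) = 1
\;\Rightarrow\; \mathrm{Var}(w_l) = \frac{2}{n_l},
\qquad \mathrm{std} = \sqrt{\frac{2}{n_l}}
```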
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
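In this signature, `a` is only used with nonlinearity='leaky_relu' (with the defaults a=0 and 'leaky_relu', the gain is sqrt(2), the same as for 'relu'); mode='fan_in' preserves the variance in the forward pass and mode='fan_out' in the backward pass. The sketch below checks two things empirically: that the sampled std matches gain / sqrt(fan_mode), and that with Kaiming init the activation variance stays near 1 through a stack of ReLU + Linear layers while std=0.001 makes it collapse. The tensor shapes, depth, width, and the helper final_variance are all illustrative assumptions.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# 1) Compare the sampled std with the formula (mode='fan_out', as in the SE code)
w = torch.empty(64, 32, 3, 3)                     # Conv2d weight: (out, in, kH, kW)
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
fan_out = w.shape[0] * w.shape[2] * w.shape[3]    # 64 * 3 * 3 = 576
expected_std = nn.init.calculate_gain('relu') / math.sqrt(fan_out)  # sqrt(2) / 24
print(f"expected {expected_std:.4f}, empirical {w.std().item():.4f}")

# 2) Hypothetical helper: variance after `depth` ReLU -> Linear layers
def final_variance(init_fn, depth=10, width=512, batch=1000):
    x = torch.randn(batch, width)                 # unit-variance input
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            init_fn(layer.weight)                 # weight shape (width, width), fan_in = width
            x = layer(torch.relu(x))
    return x.var().item()

kaiming_var = final_variance(
    lambda t: nn.init.kaiming_normal_(t, mode='fan_in', nonlinearity='relu'))
small_var = final_variance(lambda t: nn.init.normal_(t, std=0.001))
print(kaiming_var)   # stays close to 1
print(small_var)     # shrinks by about 512 * 1e-6 / 2 per layer, effectively 0
```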
Copyright notice
This article was written by [研究生不迟到]. Please include a link to the original when reposting. Thanks.
https://blog.csdn.net/weixin_42521185/article/details/124329781