当前位置:网站首页>PyTorch 12. hook的用法
PyTorch 12. hook的用法
2022-04-23 06:11:00 【DCGJ666】
PyTorch 12. hook的用法
hook
- 由于pytorch会自动舍弃图计算的中间结果,所以想要这些数值就需要使用钩子函数。钩子函数包括Variable的钩子和nn.Module钩子,用法相似。
- 在使用hook函数时,不应该修改它的输入,但是它可以返回一个替代当前梯度的新梯度,即,使用该函数会返回一个梯度值
register_hook
针对Tensor变量的hook函数
import torch
grad_list = []
def print_grad(grad):
grad_list.append(grad)
x = torch.randn(2,1)
y = x+2
y.register_hook(print_grad)
y.backward()
在整个网络进行反向传递时,运行到钩子函数注册的变量后,会保存该变量的梯度,调用print_grad函数,将梯度添加到grad_list中
register_forward_hook
针对网络层的hook函数,具体的可视化博客,可以参考我的另一篇文档,可视化特定层
钩子函数不应该修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。
import torch
model = VGG()
features = torch.Tensor()
def hook(module, input, output):
# 把这一层的输出拷贝到features中
features.copy_(output.data)
handle = model.layer8.register_forward_hook(hook)
_ = model(input)
# 用完hook后删除
handle.remove()
register_backward_hook
首先介绍Container的概念:当Module的forward函数中只有一个Function的时候,称为Module,如果Module包含其他Module,称之为Container.
在module上注册一个backward hook。此方法目标只能用在Module上,不能用在Container上。
每次计算module的inputs的梯度时,这个hook会被调用
hook(module,grad_input,grad_output)->Tensor or None
求取模块的梯度,与register_grad类似
版权声明
本文为[DCGJ666]所创,转载请带上原文链接,感谢
https://blog.csdn.net/DCGJ666/article/details/121638159
边栏推荐
- Component based learning (3) path and group annotations in arouter
- Mysql database installation and configuration details
- MySQL notes 5_ Operation data
- 【点云系列】SO-Net:Self-Organizing Network for Point Cloud Analysis
- 【2021年新书推荐】Kubernetes in Production Best Practices
- Component based learning (1) idea and Implementation
- 【 planification dynamique】 différentes voies 2
- PyTorch 模型剪枝实例教程三、多参数与全局剪枝
- Data class of kotlin journey
- Android exposed components - ignored component security
猜你喜欢

【点云系列】FoldingNet:Point Cloud Auto encoder via Deep Grid Deformation

【点云系列】SG-GAN: Adversarial Self-Attention GCN for Point Cloud Topological Parts Generation
树莓派:双色LED灯实验

Component learning (2) arouter principle learning
![[3D shape reconstruction series] implicit functions in feature space for 3D shape reconstruction and completion](/img/4d/6d5821759766a6bf1d77ad51b69e24.png)
[3D shape reconstruction series] implicit functions in feature space for 3D shape reconstruction and completion

免费使用OriginPro学习版

【2021年新书推荐】Enterprise Application Development with C# 9 and .NET 5

MySQL的安装与配置——详细教程

Record WebView shows another empty pit

第5 章 机器学习基础
随机推荐
PyMySQL连接数据库
PaddleOCR 图片文字提取
第4章 Pytorch数据处理工具箱
GEE配置本地开发环境
ArcGIS license server administrator cannot start the workaround
Android暴露组件——被忽略的组件安全
Pytorch trains the basic process of a network in five steps
【3D形状重建系列】Implicit Functions in Feature Space for 3D Shape Reconstruction and Completion
机器学习 三: 基于逻辑回归的分类预测
WebRTC ICE candidate里面的raddr和rport表示什么?
红外传感器控制开关
Android exposed components - ignored component security
Handler进阶之sendMessage原理探索
【2021年新书推荐】Effortless App Development with Oracle Visual Builder
[2021 book recommendation] learn winui 3.0
免费使用OriginPro学习版
Machine learning II: logistic regression classification based on Iris data set
MySQL的安装与配置——详细教程
Minesweeping games
如何对多维矩阵进行标准化(基于numpy)