当前位置:网站首页>PyTorch 12. hook的用法
PyTorch 12. hook的用法
2022-04-23 06:11:00 【DCGJ666】
PyTorch 12. hook的用法
hook
- 由于pytorch会自动舍弃图计算的中间结果,所以想要这些数值就需要使用钩子函数。钩子函数包括Variable的钩子和nn.Module钩子,用法相似。
- 在使用hook函数时,不应该修改它的输入,但是它可以返回一个替代当前梯度的新梯度,即,使用该函数会返回一个梯度值
register_hook
针对Tensor变量的hook函数
import torch
grad_list = []
def print_grad(grad):
grad_list.append(grad)
x = torch.randn(2,1)
y = x+2
y.register_hook(print_grad)
y.backward()
在整个网络进行反向传递时,运行到钩子函数注册的变量后,会保存该变量的梯度,调用print_grad函数,将梯度添加到grad_list中
register_forward_hook
针对网络层的hook函数,具体的可视化博客,可以参考我的另一篇文档,可视化特定层
钩子函数不应该修改输入和输出,并且在使用后应及时删除,以避免每次都运行钩子增加运行负载。钩子函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。
import torch
model = VGG()
features = torch.Tensor()
def hook(module, input, output):
# 把这一层的输出拷贝到features中
features.copy_(output.data)
handle = model.layer8.register_forward_hook(hook)
_ = model(input)
# 用完hook后删除
handle.remove()
register_backward_hook
首先介绍Container的概念:当Module的forward函数中只有一个Function的时候,称为Module,如果Module包含其他Module,称之为Container.
在module上注册一个backward hook。此方法目标只能用在Module上,不能用在Container上。
每次计算module的inputs的梯度时,这个hook会被调用
hook(module,grad_input,grad_output)->Tensor or None
求取模块的梯度,与register_grad类似
版权声明
本文为[DCGJ666]所创,转载请带上原文链接,感谢
https://blog.csdn.net/DCGJ666/article/details/121638159
边栏推荐
- 【2021年新书推荐】Red Hat RHCSA 8 Cert Guide: EX200
- 第4章 Pytorch数据处理工具箱
- PyTorch 模型剪枝实例教程三、多参数与全局剪枝
- 红外传感器控制开关
- 【2021年新书推荐】Professional Azure SQL Managed Database Administration
- 素数求解的n种境界
- MySQL的安装与配置——详细教程
- DCMTK (dcm4che) works together with dicoogle
- Reading notes - activity
- Migrating your native/mobile application to Unified Plan/WebRTC 1.0 API
猜你喜欢
[2021 book recommendation] kubernetes in production best practices
【点云系列】FoldingNet:Point Cloud Auto encoder via Deep Grid Deformation
[recommendation of new books in 2021] enterprise application development with C 9 and NET 5
【点云系列】Multi-view Neural Human Rendering (NHR)
Machine learning III: classification prediction based on logistic regression
Google AdMob advertising learning
1.1 pytorch and neural network
ThreadLocal, just look at me!
【2021年新书推荐】Effortless App Development with Oracle Visual Builder
【2021年新书推荐】Artificial Intelligence for IoT Cookbook
随机推荐
【2021年新书推荐】Professional Azure SQL Managed Database Administration
【动态规划】三角形最小路径和
Reading notes - activity
[recommendation of new books in 2021] enterprise application development with C 9 and NET 5
MySQL5. 7 insert Chinese data and report an error: ` incorrect string value: '\ xb8 \ XDF \ AE \ xf9 \ X80 at row 1`
[3D shape reconstruction series] implicit functions in feature space for 3D shape reconstruction and completion
[Point Cloud Series] SG - Gan: Adversarial Self - attachment GCN for Point Cloud Topological parts Generation
给女朋友写个微信双开小工具
Fill the network gap
【点云系列】Neural Opacity Point Cloud(NOPC)
1.2 preliminary pytorch neural network
【点云系列】Learning Representations and Generative Models for 3D pointclouds
微信小程序 使用wxml2canvas插件生成图片部分问题记录
【 planification dynamique】 différentes voies 2
Machine learning II: logistic regression classification based on Iris data set
Chapter 2 pytoch foundation 1
【Tensorflow】共享机制
Chapter 4 pytoch data processing toolbox
【点云系列】FoldingNet:Point Cloud Auto encoder via Deep Grid Deformation
C connection of new world Internet of things cloud platform (simple understanding version)