当前位置:网站首页>PyTorch 12. Hook usage
PyTorch 12. Hook usage
2022-04-23 07:28:00 【DCGJ666】
PyTorch 12. hook Usage of
hook
- because pytorch The intermediate result of graph calculation will be discarded automatically , So if you want these values, you need to use hook functions . Hook functions include Variable Hook and nn.Module hook , Similar usage .
- In the use of hook Function time , Its input should not be modified , But it can return a new gradient that replaces the current gradient , namely , Using this function will return a gradient value
register_hook
in the light of Tensor Variable hook function
import torch
grad_list = []
def print_grad(grad):
grad_list.append(grad)
x = torch.randn(2,1)
y = x+2
y.register_hook(print_grad)
y.backward()
When the whole network carries out reverse transmission , After running to the variable registered by the hook function , The gradient of the variable is saved , call print_grad function , Add gradient to grad_list in
register_forward_hook
For the network layer hook function , Specific visual blog , You can refer to my other document , Visualize specific layers
Hook functions should not modify input and output , And it should be deleted in time after use , To avoid increasing the running load by running the hook every time . Hook function is mainly used to obtain some intermediate results , Such as the output of an intermediate layer or the gradient of a layer .
import torch
model = VGG()
features = torch.Tensor()
def hook(module, input, output):
# Copy the output of this layer to features in
features.copy_(output.data)
handle = model.layer8.register_forward_hook(hook)
_ = model(input)
# run out hook Delete after
handle.remove()
register_backward_hook
First introduced Container The concept of : When Module Of forward There is only one of the functions Function When , be called Module, If Module Include others Module, be called Container.
stay module Register one on backward hook. The target of this method can only be used in Module On , Cannot be used in Container On .
Every time calculation module Of inputs The gradient of , This hook Will be called
hook(module,grad_input,grad_output)->Tensor or None
Find the gradient of the module , And register_grad similar
版权声明
本文为[DCGJ666]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/04/202204230611343899.html
边栏推荐
- onnxruntime-gpu 1.7 出现的警告“Force fallback to CPU execution for node: Gather_191”等
- PyTorch 9. 优化器
- Infrared sensor control switch
- 《Attention in Natural Language Processing》翻译
- WinForm scroll bar beautification
- 南方投资大厦SDC智能通信巡更管理系统
- Wechat applet uses wxml2canvas plug-in to generate some problem records of pictures
- 以智能生产引领行业风潮!美摄智能视频生产平台亮相2021世界超高清视频产业发展大会
- PyTorch 14. Module class
- 美摄科技推出桌面端专业视频编辑解决方案——美映PC版
猜你喜欢
项目文件“ ”已被重命名或已不在解决方案中、未能找到与解决方案关联的源代码管理提供程序——两个工程问题
公专融合对讲机是如何实现多模式通信下的协同工作?
PyTorch 10. Learning rate
LPDDR4笔记
以智能生产引领行业风潮!美摄智能视频生产平台亮相2021世界超高清视频产业发展大会
【点云系列】Relationship-based Point Cloud Completion
AUTOSAR从入门到精通100讲(八十六)-UDS服务基础篇之2F
电力行业巡检对讲通信系统
UEFI学习01-ARM AARCH64编译、ArmPlatformPriPeiCore(SEC)
1.1 pytorch and neural network
随机推荐
关于短视频平台框架搭建与技术选型探讨
Chapter 2 pytoch foundation 1
【点云系列】Neural Opacity Point Cloud(NOPC)
x86架构初探之8086
excel实战应用案例100讲(八)-Excel的报表连接功能
Draw margin curve in arcface
海南凤凰机场智能通信解决方案
基于51单片机的三路超声波测距系统(定时器方式测距)
地铁无线对讲系统
UEFI学习01-ARM AARCH64编译、ArmPlatformPriPeiCore(SEC)
无盲区、长续航|公专融合对讲机如何提升酒店服务效率?
美摄科技推出桌面端专业视频编辑解决方案——美映PC版
画 ArcFace 中的 margin 曲线
Chapter 2 pytoch foundation 2
自组网灵活补盲|北峰油气田勘测解决方案
AUTOSAR从入门到精通100讲(八十三)-BootLoader自我刷新
torch.where能否传递梯度
《Attention in Natural Language Processing》翻译
onnxruntime-gpu 1.7 出现的警告“Force fallback to CPU execution for node: Gather_191”等
“Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggregated