torch.autograd.Function customization
2022-04-21 10:28:00 【hxxjxw】
Although PyTorch can differentiate automatically, some operations are not differentiable, and in those cases you need to define the backward computation yourself. This is known as "Extending torch.autograd".
To customize an operation through Function, you need to:

① Subclass torch.autograd.Function:

from torch.autograd import Function

class LinearFunction(Function):
    ...

② Implement forward() and backward().
Attributes (member variables):

- saved_tensors: the tensors passed to forward() that will be needed again in backward().
- needs_input_grad: a tuple of bools of length num_inputs indicating whether each input requires a gradient. It can be used to optimize what is cached for the backward pass.
- num_inputs: the number of arguments passed to forward().
- num_outputs: the number of values returned by forward().
- requires_grad: a boolean indicating whether backward() will ever be called.

Member functions:

forward()
forward() can take any number of inputs and return any number of outputs, but the inputs and outputs must be Variables. (The official examples only use tensors as arguments.)

backward()
backward() takes as many inputs as forward() has outputs, and returns as many outputs as forward() has inputs. Each input of backward() is the gradient with respect to the corresponding output of forward() (the gradient already computed for the downstream node in the graph), and each output of backward() is the gradient with respect to the corresponding input of forward(). When an input does not require a gradient (as seen from needs_input_grad) or the operation is not differentiable with respect to it, you may return None for that position.

ctx is a context object that can be used to stash information for the backward computation.
ctx provides save_for_backward() to store tensors, which can then be retrieved via ctx.saved_tensors during the backward stage.

Example 1
import torch
from torch import nn
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, input):
        result = torch.exp(input)
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

x = torch.rand(4, 3, 5, 5)
exp = Exp.apply   # use it by calling the apply method
output = exp(x)
print(output.shape)

The custom forward and backward are written as static methods here. Some posts elsewhere use the def forward(self, input_) form instead, but that style is being phased out by PyTorch.
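As a quick sanity check (this snippet is not from the original post, just an illustration), the custom Exp can be compared against torch.exp, both for the forward value and for the gradient:

import torch

x = torch.rand(4, 3, 5, 5, requires_grad=True)

out_custom = Exp.apply(x)      # uses the backward() defined above
out_builtin = torch.exp(x)     # uses autograd's built-in derivative

out_custom.sum().backward()
grad_custom = x.grad.clone()
x.grad = None                  # reset before the second backward pass
out_builtin.sum().backward()

print(torch.allclose(out_custom, out_builtin))   # expected: True
print(torch.allclose(grad_custom, x.grad))       # expected: True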
Example 2
import torch
from torch import nn
from torch.autograd import Function

class MyReLU(Function):
    @staticmethod
    def forward(ctx, input_):
        # In forward, define the forward computation of the MyReLU operation;
        # at the same time, save any values that will be needed during backpropagation.
        ctx.save_for_backward(input_)   # save the input for use in backward
        output = input_.clamp(min=0)    # relu truncates negative values, setting them to 0
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # By the chain rule of backpropagation: dloss/dx = (dloss/doutput) * (doutput/dx).
        # dloss/doutput is the incoming argument grad_output, so we only need the
        # derivative of relu, multiplied by grad_output.
        input_, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input_ < 0] = 0
        # The computation above keeps only the surviving gradients: in backpropagation,
        # ReLU acts as a channel-selection function, and every element below the
        # threshold (activation < 0) gets a gradient of 0.
        return grad_input

x = torch.rand(4, 3, 5, 5)
myrelu = MyReLU.apply   # use it by calling the apply method
output = myrelu(x)
print(output.shape)

Example 3
import torch
from torch.autograd import Function
from torch.autograd import gradcheck

class LinearFunction(Function):
    # Create a subclass of torch.autograd.Function.
    # The methods must be staticmethods.
    @staticmethod
    # The first argument is ctx, the second is input; the remaining arguments are optional.
    # ctx plays a role similar to self: attributes stored on ctx can be read back in backward.
    # In a custom Function's forward(), all Variable arguments are converted to tensors,
    # so input here is a tensor as well; before they are passed into forward,
    # the autograd engine automatically unpacks Variables into Tensors.
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)   # save the tensors on ctx for use in backward
        output = input @ weight.t()
        if bias is not None:
            # unsqueeze(0) adds a dimension at position 0;
            # expand_as(tensor) is equivalent to expand(tensor.size()): it expands
            # the original tensor to the new size
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the gradient computed by the previous stage of backpropagation
        input, weight, bias = ctx.saved_tensors
        # the gradients of the input, the weight, and the bias, respectively
        grad_input = grad_weight = grad_bias = None
        # check whether each Variable needs its gradient computed during the backward pass
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight       # chain rule for the composite function
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t() @ input   # chain rule for the composite function
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)
        return grad_input, grad_weight, grad_bias

linear = LinearFunction.apply

# gradcheck takes a tuple of tensors as input, checks whether the gradients
# evaluated with these tensors are close enough to numerical approximations,
# and returns True if they all satisfy this condition.
input = torch.randn(20, 20, requires_grad=True).double()
weight = torch.randn(20, 20, requires_grad=True).double()
bias = torch.randn(20, requires_grad=True).double()
test = gradcheck(LinearFunction.apply, (input, weight, bias), eps=1e-6, atol=1e-4)
print(test)   # prints True if everything is correct

ctx.needs_input_grad, as a tuple of booleans, can also be used to control whether each input needs its gradient computed. For example, ctx.needs_input_grad[0] == False means the first input of forward does not need a gradient, so it is enough to return None for that position.
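To illustrate the point about returning None, here is a minimal sketch (not from the original post; the Scale class and its factor argument are made up for illustration) of a Function with a non-tensor constant argument. The constant never receives a gradient, so backward() returns None in that position:

import torch
from torch.autograd import Function

class Scale(Function):
    @staticmethod
    def forward(ctx, input, factor):
        ctx.factor = factor          # a plain Python number; no need for save_for_backward
        return input * factor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        if ctx.needs_input_grad[0]:  # only compute it when the input requires a gradient
            grad_input = grad_output * ctx.factor
        # one return value per forward() input; the constant gets None
        return grad_input, None

x = torch.randn(3, requires_grad=True)
y = Scale.apply(x, 2.0)
y.sum().backward()
print(x.grad)   # expected: tensor([2., 2., 2.])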
Differences between Function and Module, and when to use each
Both Function and Module can be used to extend PyTorch with custom operations that a network needs, but there are important differences between the two:
- Function generally defines a single operation; because it cannot store parameters, it is suited to things like activation functions and pooling. Module does store parameters, so it is suited to defining a layer (such as a linear or convolutional layer), and also to defining a whole network.
- Function requires three methods to be defined: __init__, forward, and backward (you have to write the derivative yourself); Module only needs __init__ and forward, and backward is handled by the automatic differentiation machinery. In practice the two are often combined: a Module holds the learnable parameters and calls the custom Function in its forward(), as sketched below.
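A minimal sketch of that pattern, wrapping the LinearFunction from Example 3 in a Module; the class name and the Kaiming initialization are assumptions for illustration, not part of the original post:

import math
import torch
from torch import nn

class Linear(nn.Module):
    # The Module holds the parameters and delegates the actual computation
    # to the custom LinearFunction defined above.
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        # illustrative initialization (an assumption, not prescribed by the post)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input):
        return LinearFunction.apply(input, self.weight, self.bias)

layer = Linear(20, 30)
out = layer(torch.randn(4, 20))
print(out.shape)   # torch.Size([4, 30])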
Copyright notice
This article was written by [hxxjxw]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204211023332441.html