31 - GRU principle and line-by-line source code implementation
2022-04-21 19:05:00 【It's hard to choose a name】
1. Principle and structure
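The structure diagram from the original post did not survive extraction. As a stand-in, these are the GRU update equations as stated in the PyTorch `nn.GRU` documentation. They line up with `r_t` (reset gate), `z_t` (update gate), `n_t` (candidate state) and `prev_h` in the `custom_gru` code in section 2:

```latex
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```

Note that the reset gate multiplies the whole hidden-side term, bias $b_{hn}$ included, which is why the input and hidden projections have to be kept separate rather than summed up front.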

2. Code details
import torch
from torch import nn
batch_size, sequence_length, H_in, H_out = 2, 3, 4, 5
gru_layer = nn.GRU(input_size=H_in, hidden_size=H_out, batch_first=True)
input = torch.randn((batch_size, sequence_length, H_in))
# h0 is [bs, H_out]; nn.GRU expects (num_layers, bs, H_out), hence unsqueeze(0)
h0 = torch.randn((batch_size, H_out))
output, h_n = gru_layer(input, h0.unsqueeze(0))
# Parameter shapes (three stacked gate blocks: r, z, n):
# weight_ih_l0 torch.Size([15, 4]) -> (3*H_out, H_in)
# weight_hh_l0 torch.Size([15, 5]) -> (3*H_out, H_out)
# bias_ih_l0   torch.Size([15])    -> (3*H_out)
# bias_hh_l0   torch.Size([15])    -> (3*H_out)
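The shape comments above show each parameter holding three gate blocks stacked along dim 0. As a quick sketch (sizes chosen to mirror the post, `gru` is a fresh illustrative layer), `chunk(3, dim=0)` splits them into the per-gate matrices in PyTorch's documented order: reset, update, candidate. This is the same split that the slices `[:h_out]`, `[h_out:2*h_out]` and `[2*h_out:3*h_out]` perform inside `custom_gru`.

```python
import torch
from torch import nn

# Illustrative layer with the post's sizes: H_in=4, H_out=5
gru = nn.GRU(input_size=4, hidden_size=5, batch_first=True)

# PyTorch stacks the reset (r), update (z) and candidate (n) blocks
# along dim 0, in that order, so chunk(3) recovers the per-gate tensors.
W_ir, W_iz, W_in = gru.weight_ih_l0.chunk(3, dim=0)
W_hr, W_hz, W_hn = gru.weight_hh_l0.chunk(3, dim=0)
b_ir, b_iz, b_in = gru.bias_ih_l0.chunk(3, dim=0)
b_hr, b_hz, b_hn = gru.bias_hh_l0.chunk(3, dim=0)

print(W_ir.shape, W_hr.shape, b_ir.shape)
# torch.Size([5, 4]) torch.Size([5, 5]) torch.Size([5])
```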
def custom_gru(input, h0, w_ih, w_hh, b_ih, b_hh):
    # input: [bs, T, h_in]; h0: [bs, h_out]
    # w_ih / w_hh stack the reset (r), update (z) and candidate (n) blocks along dim 0
    bs, T, h_in = input.shape
    h_out = w_ih.shape[0] // 3
    # w_ih.shape = torch.Size([3*h_out, h_in])
    # batch_w_ih.shape = torch.Size([bs, 3*h_out, h_in])
    batch_w_ih = w_ih.unsqueeze(0).tile([bs, 1, 1])
    # h0.shape = prev_h.shape = torch.Size([bs, h_out])
    prev_h = h0
    # w_hh.shape = torch.Size([3*h_out, h_out])
    # batch_w_hh.shape = torch.Size([bs, 3*h_out, h_out])
    batch_w_hh = w_hh.unsqueeze(0).tile([bs, 1, 1])
    output = torch.zeros([bs, T, h_out])
    for t in range(T):
        # input.shape = torch.Size([bs, T, h_in])
        # x.shape = torch.Size([bs, h_in]) -> ([bs, h_in, 1])
        x = input[:, t, :].unsqueeze(-1)
        # batch_w_ih.shape = torch.Size([bs, 3*h_out, h_in])
        # w_ih_times_x.shape = torch.Size([bs, 3*h_out, 1]) -> ([bs, 3*h_out])
        w_ih_times_x = torch.bmm(batch_w_ih, x).squeeze(-1)
        # batch_w_hh.shape = torch.Size([bs, 3*h_out, h_out])
        # prev_h.shape = torch.Size([bs, h_out]) -> ([bs, h_out, 1])
        # w_hh_times_h.shape = torch.Size([bs, 3*h_out, 1]) -> ([bs, 3*h_out])
        w_hh_times_h = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1)).squeeze(-1)
        # reset gate r_t: rows [0, h_out)
        r_t = torch.sigmoid(w_ih_times_x[:, :h_out] + b_ih[:h_out] + w_hh_times_h[:, :h_out] + b_hh[:h_out])
        # update gate z_t: rows [h_out, 2*h_out)
        z_t = torch.sigmoid(w_ih_times_x[:, h_out:2*h_out] + b_ih[h_out:2*h_out] + w_hh_times_h[:, h_out:2*h_out] + b_hh[h_out:2*h_out])
        # candidate n_t: rows [2*h_out, 3*h_out); r_t scales the hidden term including its bias
        n_t = torch.tanh(w_ih_times_x[:, 2*h_out:3*h_out] + b_ih[2*h_out:3*h_out] + r_t * (w_hh_times_h[:, 2*h_out:3*h_out] + b_hh[2*h_out:3*h_out]))
        # new hidden state: interpolate between the candidate and the previous state
        prev_h = (1 - z_t) * n_t + z_t * prev_h
        output[:, t, :] = prev_h
    # prev_h.shape = torch.Size([bs, h_out])
    # h_n.shape = torch.Size([1, bs, h_out]), matching nn.GRU's (num_layers, bs, h_out)
    h_n = prev_h.unsqueeze(0)
    return output, h_n
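The `tile` + `bmm` combination above materializes one copy of each weight matrix per batch element. A possibly simpler equivalent, sketched here, uses plain matrix multiplication (`x @ w.T`), which broadcasts over the batch, precomputes the input projection for all time steps at once, and splits the gate blocks with `chunk`. The name `custom_gru_matmul` is hypothetical, and the sketch assumes a single-layer, unidirectional, `batch_first` GRU as in the post.

```python
import torch
from torch import nn

def custom_gru_matmul(input, h0, w_ih, w_hh, b_ih, b_hh):
    """Hypothetical helper: single-layer, unidirectional, batch_first GRU; h0 is [bs, h_out]."""
    bs, T, _ = input.shape
    # The input projection does not depend on h, so do it for every step at once:
    # [bs, T, h_in] @ [h_in, 3*h_out] -> [bs, T, 3*h_out]
    gi_all = input @ w_ih.T + b_ih
    prev_h = h0
    outputs = []
    for t in range(T):
        # Hidden projection: [bs, h_out] @ [h_out, 3*h_out] -> [bs, 3*h_out]
        gh = prev_h @ w_hh.T + b_hh
        i_r, i_z, i_n = gi_all[:, t].chunk(3, dim=-1)
        hh_r, hh_z, hh_n = gh.chunk(3, dim=-1)
        r_t = torch.sigmoid(i_r + hh_r)
        z_t = torch.sigmoid(i_z + hh_z)
        n_t = torch.tanh(i_n + r_t * hh_n)  # reset gate scales W_hn h + b_hn
        prev_h = (1 - z_t) * n_t + z_t * prev_h
        outputs.append(prev_h)
    # Stack to [bs, T, h_out]; add the num_layers dim so h_n matches nn.GRU
    return torch.stack(outputs, dim=1), prev_h.unsqueeze(0)

# Usage: compare against nn.GRU using the same weights (seed is arbitrary)
torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=5, batch_first=True)
x = torch.randn(2, 3, 4)
h0 = torch.randn(2, 5)
ref_out, ref_hn = gru(x, h0.unsqueeze(0))
out, hn = custom_gru_matmul(x, h0, gru.weight_ih_l0, gru.weight_hh_l0,
                            gru.bias_ih_l0, gru.bias_hh_l0)
print(torch.allclose(out, ref_out, atol=1e-6), torch.allclose(hn, ref_hn, atol=1e-6))
```

Both comparisons are expected to agree up to floating-point rounding; the trade-off is mainly readability and avoiding the per-batch weight copies.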
# Feed nn.GRU's own parameters into custom_gru so both implementations use identical weights
cu_input = input
cu_h0 = h0
cu_w_ih = gru_layer.weight_ih_l0
cu_w_hh = gru_layer.weight_hh_l0
cu_b_ih = gru_layer.bias_ih_l0
cu_b_hh = gru_layer.bias_hh_l0
cu_output, cu_hn = custom_gru(cu_input, cu_h0, cu_w_ih, cu_w_hh, cu_b_ih, cu_b_hh)
print(torch.isclose(output, cu_output))
# tensor([[[True, True, True, True, True],
#          [True, True, True, True, True],
#          [True, True, True, True, True]],
#
#         [[True, True, True, True, True],
#          [True, True, True, True, True],
#          [True, True, True, True, True]]])
print(torch.isclose(cu_hn, h_n))
# tensor([[[True, True, True, True, True],
#          [True, True, True, True, True]]])
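Two small checks that may be handy next to the element-wise `isclose` printouts above: `torch.allclose` collapses the comparison into a single bool, and for a single-layer, unidirectional GRU like this one, `h_n` should equal the last time step of `output`. A sketch under those assumptions (seed and sizes are illustrative):

```python
import torch
from torch import nn

torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=5, batch_first=True)
x = torch.randn(2, 3, 4)
h0 = torch.randn(1, 2, 5)  # (num_layers * num_directions, bs, H_out)
output, h_n = gru(x, h0)

# Single layer, one direction: the final hidden state is the last output step
last_step = output[:, -1, :]  # [bs, H_out]
print(torch.allclose(h_n[0], last_step))
```

With more layers or `bidirectional=True` this shortcut no longer holds, since `h_n` then stacks one final state per layer and direction.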
3. Summary
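A GRU has only 3/4 as many parameters as an LSTM with the same input and hidden sizes, because it has three gate blocks instead of four, and its implementation is just as simple as the LSTM's.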
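The 3/4 ratio can be checked directly. Each gate block contributes `H_out*H_in + H_out*H_out + 2*H_out` parameters (input weights, hidden weights, two bias vectors); a GRU has three such blocks and an LSTM four. A minimal sketch with the post's sizes (`rnn_param_count` is a hypothetical helper):

```python
from torch import nn

def rnn_param_count(module):
    # Total number of learnable scalars (weights + biases)
    return sum(p.numel() for p in module.parameters())

H_in, H_out = 4, 5
gru = nn.GRU(input_size=H_in, hidden_size=H_out)
lstm = nn.LSTM(input_size=H_in, hidden_size=H_out)

# One gate block: W_i* (H_out x H_in) + W_h* (H_out x H_out) + b_i* + b_h*
per_block = H_out * H_in + H_out * H_out + 2 * H_out  # 20 + 25 + 10 = 55
print(rnn_param_count(gru), 3 * per_block)    # 165 165
print(rnn_param_count(lstm), 4 * per_block)   # 220 220
print(rnn_param_count(gru) / rnn_param_count(lstm))  # 0.75
```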
Copyright notice
This article was written by [It's hard to choose a name]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204211900501037.html