Time series model: Gated Recurrent Unit network (GRU)
2022-04-23 15:39:00 【HadesZ~】
1. Model definition
The gated recurrent unit network (Gated Recurrent Unit, GRU) [1] is a simplified variant of LSTM. It usually achieves results comparable to those of an LSTM model [2].
2. Model structure and forward propagation formula
The hidden-state computation in GRU introduces no additional memory cell, and its logic gates are reduced to two: a reset gate and an update gate. Its structure diagram and forward-propagation formulas are as follows:
$$
\left\{
\begin{array}{lll}
\text{Input:} & X_t \in \mathbb{R}^{m \times d},\ H_{t-1} \in \mathbb{R}^{m \times h} & \\[1ex]
\text{Reset gate:} & R_t = \sigma(X_t W_{xr} + H_{t-1} W_{hr} + b_r), & W_{xr} \in \mathbb{R}^{d \times h},\ W_{hr} \in \mathbb{R}^{h \times h} \\[1ex]
\text{Candidate hidden state:} & \tilde{H}_t = \tanh(X_t W_{xh} + (R_t \odot H_{t-1}) W_{hh} + b_h), & W_{xh} \in \mathbb{R}^{d \times h},\ W_{hh} \in \mathbb{R}^{h \times h} \\[1ex]
\text{Update gate:} & Z_t = \sigma(X_t W_{xz} + H_{t-1} W_{hz} + b_z), & W_{xz} \in \mathbb{R}^{d \times h},\ W_{hz} \in \mathbb{R}^{h \times h} \\[1ex]
\text{Hidden state:} & H_t = Z_t \odot H_{t-1} + (1 - Z_t) \odot \tilde{H}_t & \\[1ex]
\text{Output:} & \hat{Y}_t = H_t W_{hy} + b_y, & W_{hy} \in \mathbb{R}^{h \times q} \\[1ex]
\text{Loss function:} & L = \frac{1}{T} \sum_{t=1}^{T} l(\hat{Y}_t, Y_t) &
\end{array}
\right.
\tag{2.1}
$$
where $m$ is the batch size, $d$ the input dimension, $h$ the number of hidden units, $q$ the output dimension, $\sigma$ the sigmoid function and $\odot$ element-wise multiplication.
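To make Eq. (2.1) concrete, here is a minimal NumPy sketch of one GRU forward step. The parameter dictionary keys mirror the symbols in the equation (`W_xr`, `b_z`, ...), but the dictionary layout, the random initialization and the helper names `gru_step` and `make_params` are illustrative assumptions, not the author's code. The last part shows the role of the update gate: pushing $Z_t$ toward 1 copies $H_{t-1}$ through unchanged.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(X_t, H_prev, P):
    """One GRU step following Eq. (2.1). X_t: [m, d], H_prev: [m, h]."""
    R = sigmoid(X_t @ P["W_xr"] + H_prev @ P["W_hr"] + P["b_r"])              # reset gate R_t
    Z = sigmoid(X_t @ P["W_xz"] + H_prev @ P["W_hz"] + P["b_z"])              # update gate Z_t
    H_tilde = np.tanh(X_t @ P["W_xh"] + (R * H_prev) @ P["W_hh"] + P["b_h"])  # candidate state
    return Z * H_prev + (1.0 - Z) * H_tilde                                   # new hidden state H_t

def make_params(d, h, seed=0):
    # Hypothetical small random weights and zero biases for the three blocks r, z, h
    rng = np.random.default_rng(seed)
    P = {}
    for g in ("r", "z", "h"):
        P["W_x" + g] = 0.1 * rng.standard_normal((d, h))
        P["W_h" + g] = 0.1 * rng.standard_normal((h, h))
        P["b_" + g] = np.zeros(h)
    return P

m, d, h = 2, 3, 4
rng = np.random.default_rng(1)
X_t = rng.standard_normal((m, d))
H_prev = rng.standard_normal((m, h))
P = make_params(d, h)

H_t = gru_step(X_t, H_prev, P)
print(H_t.shape)  # (2, 4)

# Force the update gate to (numerically) 1: the previous state is copied through.
P["b_z"] = np.full(h, 50.0)
H_copy = gru_step(X_t, H_prev, P)
```

Because $H_t$ is a convex combination of $H_{t-1}$ and $\tilde{H}_t \in (-1, 1)$, every entry of $H_t$ lies between the corresponding entries of those two tensors; this interpolation is what lets the update gate carry information across many steps.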
3. Backpropagation in GRU
Because GRU introduces no additional memory cell, its backpropagation computational graph has the same structure as that of an RNN (as shown in the author's article "Time series model: Recurrent neural network (RNN)" [3]). The backpropagation formulas of GRU are as follows:
$$
\frac{\partial L}{\partial \hat{Y}_t} = \frac{1}{T} \cdot \frac{\partial l(\hat{Y}_t, Y_t)}{\partial \hat{Y}_t}
\tag{3.1}
$$
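As a worked instance of Eq. (3.1), assume a squared-error loss $l(\hat{Y}_t, Y_t) = \tfrac{1}{2}\lVert \hat{Y}_t - Y_t \rVert^2$ (an assumption for illustration; the article leaves $l$ unspecified). Then:

$$
\frac{\partial L}{\partial \hat{Y}_t} = \frac{1}{T} \cdot \frac{\partial}{\partial \hat{Y}_t} \tfrac{1}{2}\lVert \hat{Y}_t - Y_t \rVert^2 = \frac{1}{T}\,(\hat{Y}_t - Y_t)
$$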
$$
\frac{\partial L}{\partial \hat{Y}_t} \Rightarrow
\left\{
\begin{array}{l}
\dfrac{\partial L}{\partial W_{hy}} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial \hat{Y}_t} \dfrac{\partial \hat{Y}_t}{\partial W_{hy}} \\[2ex]
\dfrac{\partial L}{\partial H_t} =
\begin{cases}
\dfrac{\partial L}{\partial \hat{Y}_t} \dfrac{\partial \hat{Y}_t}{\partial H_t}, & t = T \\[2ex]
\dfrac{\partial L}{\partial \hat{Y}_t} \dfrac{\partial \hat{Y}_t}{\partial H_t} + \dfrac{\partial L}{\partial H_{t+1}} \dfrac{\partial H_{t+1}}{\partial H_t}, & t < T
\end{cases}
\end{array}
\right.
\tag{3.2}
$$
$$
\left\{
\begin{array}{l}
\dfrac{\partial L}{\partial W_{xz}} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial Z_t} \dfrac{\partial Z_t}{\partial W_{xz}} \\[2ex]
\dfrac{\partial L}{\partial W_{hz}} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial Z_t} \dfrac{\partial Z_t}{\partial W_{hz}} \\[2ex]
\dfrac{\partial L}{\partial b_z} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial Z_t} \dfrac{\partial Z_t}{\partial b_z}
\end{array}
\right.
\qquad\qquad
\left\{
\begin{array}{l}
\dfrac{\partial L}{\partial W_{xh}} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial \tilde{H}_t} \dfrac{\partial \tilde{H}_t}{\partial W_{xh}} \\[2ex]
\dfrac{\partial L}{\partial W_{hh}} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial \tilde{H}_t} \dfrac{\partial \tilde{H}_t}{\partial W_{hh}} \\[2ex]
\dfrac{\partial L}{\partial b_h} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial \tilde{H}_t} \dfrac{\partial \tilde{H}_t}{\partial b_h}
\end{array}
\right.
\tag{3.3}
$$
$$
\left\{
\begin{array}{l}
\dfrac{\partial L}{\partial W_{xr}} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial R_t} \dfrac{\partial R_t}{\partial W_{xr}} \\[2ex]
\dfrac{\partial L}{\partial W_{hr}} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial R_t} \dfrac{\partial R_t}{\partial W_{hr}} \\[2ex]
\dfrac{\partial L}{\partial b_r} = \sum_{t=1}^{T} \dfrac{\partial L}{\partial R_t} \dfrac{\partial R_t}{\partial b_r}
\end{array}
\right.
\tag{3.4}
$$
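The gate gradients $\partial L / \partial Z_t$, $\partial L / \partial \tilde{H}_t$ and $\partial L / \partial R_t$ used in Eqs. (3.3) and (3.4) follow from the chain rule applied to Eq. (2.1). The sketch below is a derivation under stated assumptions, not taken from the original: it uses the row-vector convention of Eq. (2.1) (each row of $X_t$ is one sample), writes $G_t = \partial L / \partial H_t$ for the total gradient of Eq. (3.2), and introduces $A^z_t, A^h_t, A^r_t$ as names for the pre-activations inside $\sigma$ and $\tanh$; $\mathbf{1}$ is an all-ones vector of length $m$.

$$
\begin{aligned}
\frac{\partial L}{\partial Z_t} &= G_t \odot (H_{t-1} - \tilde{H}_t), &
\frac{\partial L}{\partial A^z_t} &= \frac{\partial L}{\partial Z_t} \odot Z_t \odot (1 - Z_t) \\
\frac{\partial L}{\partial \tilde{H}_t} &= G_t \odot (1 - Z_t), &
\frac{\partial L}{\partial A^h_t} &= \frac{\partial L}{\partial \tilde{H}_t} \odot (1 - \tilde{H}_t \odot \tilde{H}_t) \\
\frac{\partial L}{\partial R_t} &= \Big(\frac{\partial L}{\partial A^h_t} W_{hh}^{\top}\Big) \odot H_{t-1}, &
\frac{\partial L}{\partial A^r_t} &= \frac{\partial L}{\partial R_t} \odot R_t \odot (1 - R_t)
\end{aligned}
$$

$$
\begin{aligned}
\frac{\partial L}{\partial W_{xz}} &= \sum_{t} X_t^{\top} \frac{\partial L}{\partial A^z_t}, &
\frac{\partial L}{\partial W_{hz}} &= \sum_{t} H_{t-1}^{\top} \frac{\partial L}{\partial A^z_t}, &
\frac{\partial L}{\partial b_z} &= \sum_{t} \mathbf{1}^{\top} \frac{\partial L}{\partial A^z_t} \\
\frac{\partial L}{\partial W_{xh}} &= \sum_{t} X_t^{\top} \frac{\partial L}{\partial A^h_t}, &
\frac{\partial L}{\partial W_{hh}} &= \sum_{t} (R_t \odot H_{t-1})^{\top} \frac{\partial L}{\partial A^h_t}, &
\frac{\partial L}{\partial b_h} &= \sum_{t} \mathbf{1}^{\top} \frac{\partial L}{\partial A^h_t} \\
\frac{\partial L}{\partial W_{xr}} &= \sum_{t} X_t^{\top} \frac{\partial L}{\partial A^r_t}, &
\frac{\partial L}{\partial W_{hr}} &= \sum_{t} H_{t-1}^{\top} \frac{\partial L}{\partial A^r_t}, &
\frac{\partial L}{\partial b_r} &= \sum_{t} \mathbf{1}^{\top} \frac{\partial L}{\partial A^r_t}
\end{aligned}
$$

The same quantities give the recurrent term of Eq. (3.2) (shifted by one step), collecting the four paths through which $H_{t-1}$ reaches $H_t$:

$$
\frac{\partial L}{\partial H_t}\frac{\partial H_t}{\partial H_{t-1}} = G_t \odot Z_t + \frac{\partial L}{\partial A^z_t} W_{hz}^{\top} + \frac{\partial L}{\partial A^r_t} W_{hr}^{\top} + \Big(\frac{\partial L}{\partial A^h_t} W_{hh}^{\top}\Big) \odot R_t
$$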
As with LSTM, the key to solving the GRU backpropagation equations is computing the gradient $\partial L / \partial H_t$ that is passed between time steps; the method is the same as for LSTM and is not repeated here. Qualitatively, GRU also alleviates the long-term dependency problem in the same way as LSTM: by adjusting the multipliers of the high-order power terms and adding low-order power terms. Of the two gates, the reset gate helps capture short-term dependencies in the sequence, while the update gate helps capture long-term dependencies. (For details, see the proof in the author's article "Time series model: Long short-term memory network (LSTM)".)
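One way to check backpropagation formulas like Eqs. (3.1) to (3.4) is against finite differences. Below is a minimal NumPy sketch of backpropagation through time for a GRU, assuming a squared-error loss $l = \tfrac{1}{2}\lVert \hat{Y}_t - Y_t \rVert^2$ averaged over $T$ steps and the row-vector convention of Eq. (2.1). The function names and the parameter dictionary are illustrative, not from the original article. The gate gradients come from the chain rule: $\partial L/\partial Z_t = G_t \odot (H_{t-1} - \tilde{H}_t)$ and $\partial L/\partial \tilde{H}_t = G_t \odot (1 - Z_t)$, with $G_t = \partial L / \partial H_t$ from Eq. (3.2).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def init_params(d, h, q, seed=0):
    # Illustrative random initialization; keys mirror the symbols in Eq. (2.1)
    rng = np.random.default_rng(seed)
    shapes = {"W_xr": (d, h), "W_hr": (h, h), "b_r": (h,),
              "W_xz": (d, h), "W_hz": (h, h), "b_z": (h,),
              "W_xh": (d, h), "W_hh": (h, h), "b_h": (h,),
              "W_hy": (h, q), "b_y": (q,)}
    return {k: 0.5 * rng.standard_normal(s) for k, s in shapes.items()}

def forward(P, X, Y, H0):
    """X: [T, m, d], Y: [T, m, q], H0: [m, h]. Returns L and a per-step cache."""
    T = X.shape[0]
    H, caches, loss = H0, [], 0.0
    for t in range(T):
        R = sigmoid(X[t] @ P["W_xr"] + H @ P["W_hr"] + P["b_r"])
        Z = sigmoid(X[t] @ P["W_xz"] + H @ P["W_hz"] + P["b_z"])
        H_tilde = np.tanh(X[t] @ P["W_xh"] + (R * H) @ P["W_hh"] + P["b_h"])
        H_new = Z * H + (1 - Z) * H_tilde
        Y_hat = H_new @ P["W_hy"] + P["b_y"]
        loss += 0.5 * np.sum((Y_hat - Y[t]) ** 2) / T  # L = (1/T) sum_t l
        caches.append((X[t], H, R, Z, H_tilde, H_new, Y_hat, Y[t]))
        H = H_new
    return loss, caches

def backward(P, caches):
    T = len(caches)
    G = {k: np.zeros_like(v) for k, v in P.items()}
    dH_rec = 0.0  # recurrent term dL/dH_{t+1} * dH_{t+1}/dH_t in Eq. (3.2); zero at t = T
    for t in reversed(range(T)):
        x, H_prev, R, Z, H_tilde, H, Y_hat, Y = caches[t]
        dY = (Y_hat - Y) / T                          # Eq. (3.1)
        G["W_hy"] += H.T @ dY
        G["b_y"] += dY.sum(axis=0)
        dH = dY @ P["W_hy"].T + dH_rec                # Eq. (3.2)
        dAz = dH * (H_prev - H_tilde) * Z * (1 - Z)   # through Z_t and its sigmoid
        dAh = dH * (1 - Z) * (1 - H_tilde ** 2)       # through H~_t and tanh
        dRH = dAh @ P["W_hh"].T                       # w.r.t. (R_t * H_{t-1})
        dAr = dRH * H_prev * R * (1 - R)              # through R_t and its sigmoid
        # Eqs. (3.3)-(3.4): shared parameters accumulate over all time steps
        for g, dA, h_in in (("z", dAz, H_prev), ("h", dAh, R * H_prev), ("r", dAr, H_prev)):
            G["W_x" + g] += x.T @ dA
            G["W_h" + g] += h_in.T @ dA
            G["b_" + g] += dA.sum(axis=0)
        # Gradient reaching H_{t-1}: direct path, via Z_t, via R_t, via R_t * H_{t-1}
        dH_rec = dH * Z + dAz @ P["W_hz"].T + dAr @ P["W_hr"].T + dRH * R
    return G

def grad_check(eps=1e-5, seed=0):
    T, m, d, h, q = 4, 3, 5, 6, 2
    rng = np.random.default_rng(seed + 1)
    X, Y = rng.standard_normal((T, m, d)), rng.standard_normal((T, m, q))
    H0 = np.zeros((m, h))
    P = init_params(d, h, q, seed)
    G = backward(P, forward(P, X, Y, H0)[1])
    max_err = 0.0
    for name, W in P.items():
        for idx in np.ndindex(W.shape):  # central differences, one entry at a time
            old = W[idx]
            W[idx] = old + eps
            loss_plus = forward(P, X, Y, H0)[0]
            W[idx] = old - eps
            loss_minus = forward(P, X, Y, H0)[0]
            W[idx] = old
            numeric = (loss_plus - loss_minus) / (2 * eps)
            max_err = max(max_err, abs(numeric - G[name][idx]))
    return max_err

print(f"max |analytic - numeric| = {grad_check():.2e}")
```

The backward loop mirrors Eq. (3.2) directly: `dH_rec` carries $\frac{\partial L}{\partial H_{t+1}}\frac{\partial H_{t+1}}{\partial H_t}$ from one iteration to the next, which is exactly the inter-step gradient the text identifies as the key quantity.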
4. Code implementation of the model
4.1 TensorFlow implementation
4.2 PyTorch implementation
```python
import torch
from torch import nn


class Dense(nn.Module):
    """
    Args:
        input_dim: Positive integer, dimensionality of the input space.
        output_dim: Positive integer, dimensionality of the output space.
        activation: Activation module to use. If you don't specify anything,
            no activation is applied (i.e. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
    Input shape:
        N-D tensor with shape `(batch_size, ..., input_dim)`.
        The most common situation is a 2D input with shape `(batch_size, input_dim)`.
    Output shape:
        N-D tensor with shape `(batch_size, ..., output_dim)`.
        For instance, for a 2D input with shape `(batch_size, input_dim)`,
        the output has shape `(batch_size, output_dim)`.
    """
    def __init__(self, input_dim, output_dim, **kwargs):
        super().__init__()
        # Hyperparameters
        self.use_bias = kwargs.get('use_bias', True)
        self.kernel_initializer = kwargs.get('kernel_initializer', nn.init.xavier_uniform_)
        self.bias_initializer = kwargs.get('bias_initializer', nn.init.zeros_)
        # Affine map followed by an activation (identity by default, as documented above)
        self.Linear = nn.Linear(input_dim, output_dim, bias=self.use_bias)
        self.Activation = kwargs.get('activation', nn.Identity())
        # Parameter initialization (there is no bias tensor when use_bias=False)
        self.kernel_initializer(self.Linear.weight)
        if self.use_bias:
            self.bias_initializer(self.Linear.bias)

    def forward(self, inputs):
        return self.Activation(self.Linear(inputs))


class GRU_Cell(nn.Module):
    def __init__(self, token_dim, hidden_dim,
                 reset_act=nn.Sigmoid(),   # sigma in Eq. (2.1): keeps R_t in (0, 1)
                 update_act=nn.Sigmoid(),  # sigma in Eq. (2.1): keeps Z_t in (0, 1)
                 hathid_act=nn.Tanh(),     # tanh for the candidate hidden state
                 **kwargs):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Each Dense acts on [X_t, H_{t-1}], i.e. it fuses W_x* and W_h* into one matrix
        self.ResetG = Dense(token_dim + hidden_dim, hidden_dim, activation=reset_act, **kwargs)
        self.UpdateG = Dense(token_dim + hidden_dim, hidden_dim, activation=update_act, **kwargs)
        self.HatHidden = Dense(token_dim + hidden_dim, hidden_dim, activation=hathid_act, **kwargs)

    def forward(self, inputs, last_state):
        # last_state is a list of state tensors; for a GRU it only holds H_{t-1}
        last_hidden = last_state[-1]
        xh = torch.cat([inputs, last_hidden], dim=1)
        Rg = self.ResetG(xh)                                   # reset gate R_t
        Zg = self.UpdateG(xh)                                  # update gate Z_t
        hat_hidden = self.HatHidden(
            torch.cat([inputs, Rg * last_hidden], dim=1))      # candidate state H~_t
        hidden = Zg * last_hidden + (1 - Zg) * hat_hidden      # H_t
        return [hidden]

    def zero_initialization(self, batch_size, device=None, dtype=None):
        # Return a list so the format matches what forward() expects and returns
        return [torch.zeros(batch_size, self.hidden_dim, device=device, dtype=dtype)]


class RNN_Layer(nn.Module):
    def __init__(self, rnn_cell, bidirectional=False):
        super().__init__()
        self.RNNCell = rnn_cell
        self.bidirectional = bidirectional

    def forward(self, inputs, mask=None, initial_state=None):
        """
        inputs: shape [batch_size, time_steps, token_dim]
        mask: shape [batch_size, time_steps], 1 for real tokens and 0 for padding
        :return
            hidden_state_seqences: shape [batch_size, time_steps, hidden_dim]
                ([batch_size, time_steps, 2 * hidden_dim] if bidirectional)
            last_state: the hidden state at the last time step. Note that the last
                token may be a padding token, so this is not necessarily the real
                last state of each input sequence; to get that, use
                utils.get_rnn_last_state(hidden_state_seqences).
        """
        batch_size, time_steps, token_dim = inputs.shape
        if initial_state is None:
            initial_state = self.RNNCell.zero_initialization(
                batch_size, device=inputs.device, dtype=inputs.dtype)
        if mask is None:
            if batch_size == 1:
                mask = torch.ones(1, time_steps, device=inputs.device)
            else:
                raise ValueError('Please provide a mask matrix (mask) when batch_size > 1')
        mask = mask.to(inputs.dtype)  # allows bool masks in the arithmetic below

        # Forward pass over time steps
        hidden_list = []
        hidden_state = initial_state
        for i in range(time_steps):
            hidden_state = self.RNNCell(inputs[:, i], hidden_state)
            hidden_list.append(hidden_state[-1])
            # Reset the state to initial_state at padded positions
            m = mask[:, i:i + 1]
            hidden_state = [h * m + h0 * (1 - m) for h, h0 in zip(hidden_state, initial_state)]
        sequences = torch.stack(hidden_list, dim=1)  # [batch_size, time_steps, hidden_dim]
        last_state = hidden_list[-1]

        # Backward pass over time steps (reuses the same cell, i.e. shared weights)
        if self.bidirectional:
            hidden_list = []
            hidden_state = initial_state
            for i in range(time_steps - 1, -1, -1):
                hidden_state = self.RNNCell(inputs[:, i], hidden_state)
                hidden_list.append(hidden_state[-1])
                m = mask[:, i:i + 1]
                hidden_state = [h * m + h0 * (1 - m) for h, h0 in zip(hidden_state, initial_state)]
            # hidden_list runs from t = T-1 down to 0; reverse it to align with the forward outputs
            sequences = torch.cat([sequences, torch.stack(hidden_list[::-1], dim=1)], dim=-1)
            last_state = torch.cat([last_state, hidden_list[-1]], dim=-1)

        return {
            'hidden_state_seqences': sequences,
            'last_state': last_state,
        }
```
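The `RNN_Layer` docstring refers to a `utils.get_rnn_last_state` helper that is not shown in the article. Below is a hypothetical stand-in, named `get_last_real_state` here for illustration only, that picks the hidden state at each sequence's last real token from a right-padded 0/1 mask. It applies to forward-direction states only; for the backward half of a bidirectional layer the real final state sits at $t = 0$. As a cross-check the sketch uses PyTorch's built-in `nn.GRU` with `pack_padded_sequence`, whose returned `h_n` is the state at each sequence's true last token. Note that `nn.GRU` applies the reset gate after the recurrent matrix product, $r_t \odot (W_{hn} h_{t-1} + b_{hn})$, so its outputs will not match `GRU_Cell` numerically even with equal weights.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

def get_last_real_state(sequences, mask):
    """sequences: [batch, time_steps, hidden]; mask: [batch, time_steps], 1 for real
    tokens and 0 for right-padding. Returns [batch, hidden] at each last real token."""
    lengths = mask.long().sum(dim=1)                       # real length of each sequence
    idx = (lengths - 1).clamp(min=0)                       # index of the last real step
    batch_idx = torch.arange(sequences.size(0), device=sequences.device)
    return sequences[batch_idx, idx]

torch.manual_seed(0)
batch, T, d, h = 3, 5, 4, 6
x = torch.randn(batch, T, d)
lengths = torch.tensor([5, 2, 3])
mask = (torch.arange(T)[None, :] < lengths[:, None]).float()  # right-padded 0/1 mask
gru = nn.GRU(input_size=d, hidden_size=h, batch_first=True)

with torch.no_grad():
    out, _ = gru(x)                                        # also runs over the padding
    last_from_mask = get_last_real_state(out, mask)
    packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    _, h_n = gru(packed)                                   # [num_layers * num_directions, batch, h]

print(torch.allclose(last_from_mask, h_n[-1], atol=1e-5))
```

For a unidirectional layer, states before a sequence's last real token are unaffected by later padding, which is why indexing the unpacked output agrees with the packed `h_n`.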
[1] Cho, K., van Merriënboer, B., Bahdanau, D., & Bengio, Y. (2014). On the properties of neural machine translation: Encoder-decoder approaches. arXiv preprint arXiv:1409.1259.
[2] Chung, J., Gulcehre, C., Cho, K., & Bengio, Y. (2014). Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555.
Copyright notice
This article was written by [HadesZ~]. Please include a link to the original when reposting. Thank you.
https://yzsam.com/2022/04/202204231536010859.html