当前位置:网站首页>【机器学习】门控循环单元
【机器学习】门控循环单元
2022-04-21 17:07:00 【CHH3213】
1. GRU
在循环神经⽹络中的梯度计算⽅法中,我们发现,当时间步数较⼤或者时间步较小时,循环神经⽹络的梯度较容易出现衰减或爆炸。虽然裁剪梯度可以应对梯度爆炸,但⽆法解决梯度衰减的问题。通常由于这个原因,循环神经⽹络在实际中较难捕捉时间序列中时间步距离较⼤的依赖关系。
门控循环神经⽹络(gated recurrent neural network)的提出,正是为了更好地捕捉时间序列中时间步距离较⼤的依赖关系。它通过可以学习的⻔来控制信息的流动。其中,门控循环单元(gatedrecurrent unit,GRU)是⼀种常⽤的门控循环神经⽹络。
2. ⻔控循环单元
2.1 重置门和更新门
GRU它引⼊了**重置⻔(reset gate)和更新⻔(update gate)**的概念,从而修改了循环神经⽹络中隐藏状态的计算⽅式。
门控循环单元中的重置⻔和更新⻔的输⼊均为当前时间步输⼊ 与上⼀时间步隐藏状态,输出由激活函数为sigmoid函数的全连接层计算得到。 如下图所示:

具体来说,假设隐藏单元个数为 h,给定时间步 t 的小批量输⼊ X t ∈ R n ∗ d X_t\in_{}\mathbb{R}^{n*d} Xt∈Rn∗d(样本数为n,输⼊个数为d)和上⼀时间步隐藏状态 H t − 1 ∈ R n ∗ h H_{t-1}\in_{}\mathbb{R}^{n*h} Ht−1∈Rn∗h。重置⻔ H t ∈ R n ∗ h H_t\in_{}\mathbb{R}^{n*h} Ht∈Rn∗h和更新⻔ Z t ∈ R n ∗ h Z_t\in_{}\mathbb{R}^{n*h} Zt∈Rn∗h的计算如下:
R t = σ ( X t W x r + H t − 1 W h r + b r ) R_t=\sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r) Rt=σ(XtWxr+Ht−1Whr+br)
Z t = σ ( X t W x z + H t − 1 W h z + b z ) Z_t=\sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z) Zt=σ(XtWxz+Ht−1Whz+bz)
sigmoid函数可以将元素的值变换到0和1之间。因此,重置⻔ R t R_t Rt和更新⻔ Z t Z_t Zt中每个元素的值域都是[0,1]。
2.2 候选隐藏状态
接下来,⻔控循环单元将计算候选隐藏状态来辅助稍后的隐藏状态计算。我们将当前时间步重置⻔的输出与上⼀时间步隐藏状态做按元素乘法(符号为⊙)。如果重置⻔中元素值接近0,那么意味着重置对应隐藏状态元素为0,即丢弃上⼀时间步的隐藏状态。如果元素值接近1,那么表⽰保留上⼀时间步的隐藏状态。然后,将按元素乘法的结果与当前时间步的输⼊连结,再通过含激活函数tanh的全连接层计算出候选隐藏状态,其所有元素的值域为[-1,1]。

具体来说,时间步 t 的候选隐藏状态 H ~ ∈ R n ∗ h \tilde{H}\in_{}\mathbb{R}^{n*h} H~∈Rn∗h的计算为:
H ~ t = t a n h ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h \tilde{H}_t=tanh(X_tW_{xh}+(R_t⊙H_{t-1})W_{hh}+b_h H~t=tanh(XtWxh+(Rt⊙Ht−1)Whh+bh
2.3 隐藏状态
时间步t的隐藏状态 H t ∈ R n ∗ h H_t\in_{}\mathbb{R}^{n*h} Ht∈Rn∗h的计算使⽤当前时间步的更新⻔ Z t Z_t Zt来对上⼀时间步的隐藏状态 H t − 1 H_{t-1} Ht−1和当前时间步的候选隐藏状态 H ~ t \tilde{H}_t H~t做组合:

值得注意的是,更新⻔可以控制隐藏状态应该如何被包含当前时间步信息的候选隐藏状态所更新,如上图所⽰。假设更新⻔在时间步 t ′ t' t′到 t ( t ′ < t ) t(t'<t) t(t′<t)之间⼀直近似1。那么,在时间步 t ′ t' t′到 t t t间的输⼊信息⼏乎没有流⼊时间步 t 的隐藏状态 H t H_t Ht,这可以看作是较早时刻的隐藏状态 H t ′ − 1 H_{t^{′}-1} Ht′−1通过时间保存并传递⾄当前时间步 t。这个设计可以应对循环神经⽹络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较⼤的依赖关系。
我们对⻔控循环单元的设计稍作总结:
- 重置⻔有助于捕捉时间序列⾥短期的依赖关系;
- 更新⻔有助于捕捉时间序列⾥⻓期的依赖关系。
版权声明
本文为[CHH3213]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_42301220/article/details/124314368
边栏推荐
- How to hide important data in Excel table
- [free] download 19860 yuan programming course materials on a platform, only once
- 一、数据库系列之数据库系统概述
- MySQL查询语句关键字执行的优先级问题
- 建议提供大纲 - 一键展开功能
- Summary of Wu Enda's course of machine learning (I)
- 使用epoll时需要将socket设为非阻塞吗?
- sqli-labs 23-25a关闯关心得与思路
- Vitis HLS build project and generate IP core (vivado HLS)
- 824. 山羊拉丁文
猜你喜欢

解读论文记录 指出经典的RMS证明过程小错误的一个论文的解读

Win10 bridging network card enables QEMU virtual machine to access the network normally
![[newcode] cattle team competition](/img/1f/e4bc0a246c4e6631a9201b067d07cd.png)
[newcode] cattle team competition

机器学习吴恩达课程总结(四)

Summary of Wu Enda's course of machine learning (5)

Database Principle -- library management system

前五章内容思维导图

Summary of Wu Enda's course of machine learning (4)

. net treasure API: ihostedservice, background task execution

域内信息查询工具AdFind
随机推荐
R语言使用grepl函数检查子字符串是否存在于指定的字符串中、字符串匹配,负责搜索给定字符串对象中是否包含特定表达式
Image Manipulation Detection by Multi-View Multi-Scale Supervision
建议提供大纲 - 一键展开功能
338-Leetcode 同构字符串
mysql为数据库表起别名的注意事项
Shell case series 5 Oracle detects whether there are invalid objects
幹貨 | 實戰演練基於加密接口測試測試用例設計
CSP Darknet53
Download the tutorial of chrome plug-in CRX
Buuctf Web [WANGDING Cup 2018] Fakebook
Redis三种模式——主从复制,哨兵模式,集群
剑指 Offer II 011. 0 和 1 个数相同的子数组
338 leetcode isomorphic string
Conception d'un cas d'essai basé sur l'interface de chiffrement pour le forage pratique de marchandises sèches
flutter dart .. Addall
SSD_ RESNET aircraft and oil drum data set actual combat
Win10 bridging network card enables QEMU virtual machine to access the network normally
How to judge whether a binary differential equation is a total differential
一个简单易用的文件上传方案
pytorch index_ add_ Usage introduction