Handling padded variable-length sequences as RNN input in PyTorch
2022-04-22 12:13:00 【Guanrunwei is getting fatter and fatter】
Why RNNs need to handle variable-length input
Suppose we have a sentiment analysis task that assigns a sentiment class to each sentence. The overall process is roughly as shown in the figure below:

The idea is simple. But when we group training data into a batch, the samples in it have different lengths, so the natural step is padding: shorter sentences are padded to the length of the longest sentence in the batch.
For example, as shown in the figure below:
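The padding step can be sketched in code as follows. This is a minimal illustration, assuming the sentences have already been mapped to integer token ids and that id 0 is reserved for the pad symbol; the ids themselves are made up. It uses `torch.nn.utils.rnn.pad_sequence`, which right-pads every sequence to the longest one in the batch.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three "sentences" as token ids (made-up values); id 0 is assumed to be <pad>.
sentences = [
    torch.tensor([4, 9, 2, 7, 5, 3]),  # 6 tokens
    torch.tensor([8, 6, 1]),           # 3 tokens
    torch.tensor([11]),                # 1 token, e.g. "Yes"
]
lengths = [len(s) for s in sentences]  # true lengths, needed later for packing

# Right-pad every sequence to the length of the longest one.
padded = pad_sequence(sentences, batch_first=True, padding_value=0)

print(padded.shape)        # torch.Size([3, 6])
print(padded[2].tolist())  # [11, 0, 0, 0, 0, 0]
```

The one-word sentence ends up as one real token followed by 5 pad symbols, which is the situation discussed next.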

But this creates a problem. In the example above, the sentence “Yes” contains only one word but is padded with 5 pad symbols, so the LSTM also runs over many useless pad tokens when building its representation, and the resulting sentence representation is distorted. A more intuitive picture is as follows:

So what is the correct approach?
This is why RNNs in PyTorch need to handle variable-length input. In the example above, what we want is the representation the LSTM produces right after the word “Yes”, not the one it produces after also passing through several useless “Pad” tokens, as in the figure below:
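The effect can be checked numerically. The sketch below is only an illustration under assumed sizes: a randomly initialized `nn.LSTM`, a random 4-dimensional vector standing in for the embedding of "Yes", and all-zero vectors standing in for the pad embeddings. It compares the final hidden state from the unpadded input, from the naively padded input, and from the packed input.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=3, batch_first=True)

# One real time step ("Yes") followed by 5 all-zero padding steps.
yes = torch.randn(1, 1, 4)
padded = torch.cat([yes, torch.zeros(1, 5, 4)], dim=1)  # shape (1, 6, 4)

with torch.no_grad():
    _, (h_true, _) = lstm(yes)        # the representation we actually want
    _, (h_padded, _) = lstm(padded)   # LSTM also steps through the 5 pads
    packed = pack_padded_sequence(padded, lengths=[1], batch_first=True)
    _, (h_packed, _) = lstm(packed)   # LSTM stops after the real token

print(torch.allclose(h_true, h_padded))             # False
print(torch.allclose(h_true, h_packed, atol=1e-6))  # True
```

Even with zero inputs, the recurrent weights and biases keep changing the state at every padding step, which is why the naive result drifts away from the desired one.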

How PyTorch RNNs handle padded variable-length input
This is mainly done with two functions:
- torch.nn.utils.rnn.pack_padded_sequence() and
- torch.nn.utils.rnn.pad_packed_sequence()
Let's look at how they are used.
In pack_padded_sequence(), "pack" is best understood as "compress": the function compresses a padded batch of variable-length sequences. (Padding introduces redundancy, so we squeeze it out.)
The input can have shape (T×B×*), where T is the length of the longest sequence, B is the batch size, and * stands for any number of trailing dimensions (including 0). If batch_first=True, the expected input shape is (B×T×*) instead.
By default, the sequences in the input must be sorted by length, longest first (pay special attention to this ordering). That is, in the T×B×* layout, input[:, 0] holds the longest sequence and input[:, B-1] the shortest.
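A minimal sketch of the sorting step, assuming a batch-first tensor whose rows arrive in arbitrary order; the tensor values and variable names are made up for illustration. It sorts the batch longest first, keeps the index needed to restore the original order afterwards, and also shows the `enforce_sorted=False` argument that recent PyTorch releases provide as an alternative to sorting by hand.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Batch of shape (B=3, T=4, features=2) in arbitrary order; entries past each
# true length are treated as padding.
batch = torch.arange(24, dtype=torch.float32).view(3, 4, 2)
lengths = torch.tensor([2, 4, 1])

# Sort longest first and remember how to undo the permutation.
sorted_lengths, sort_idx = lengths.sort(descending=True)
sorted_batch = batch[sort_idx]
unsort_idx = sort_idx.argsort()

packed = pack_padded_sequence(sorted_batch, sorted_lengths, batch_first=True)

# Anything computed per sequence in sorted order can be put back like this:
restored = sorted_batch[unsort_idx]
print(sorted_lengths.tolist())       # [4, 2, 1]
print(torch.equal(restored, batch))  # True

# Alternative: let pack_padded_sequence sort internally.
packed_unsorted = pack_padded_sequence(
    batch, lengths, batch_first=True, enforce_sorted=False)
```

With `enforce_sorted=False`, the RNN returns its final states in the original batch order, so no manual unsorting is needed.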
Parameters:
- input (Variable) – the padded batch of variable-length sequences (a plain Tensor in current PyTorch, where Variable has been merged into Tensor).
- lengths (list[int]) – the length of each sequence in the batch. (Only by knowing each sequence's length can the RNN know at which step to stop processing it.)
- batch_first (bool, optional) – if True, the input is expected in B×T×* format.
Return value:
A PackedSequence object. A PackedSequence is represented as follows:
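The layout of a PackedSequence can be inspected directly on a tiny batch. In this sketch the integer values are made up and 0 plays the role of the pad value. The `data` field holds the non-pad elements in time-major order, and `batch_sizes` records how many sequences are still active at each time step.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Padded batch (B=3, T=3), sorted longest first; 0 is the pad value.
padded = torch.tensor([[1, 2, 3],
                       [4, 5, 0],
                       [6, 0, 0]])
packed = pack_padded_sequence(padded, lengths=[3, 2, 1], batch_first=True)

# Step 0 contributes 1, 4, 6; step 1 contributes 2, 5; step 2 contributes 3.
print(packed.data.tolist())         # [1, 4, 6, 2, 5, 3]
print(packed.batch_sizes.tolist())  # [3, 2, 1]
```

None of the pad entries appear in `data`, so the RNN never sees them.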

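The code looks like this:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# embed_input_x: padded embeddings, shape (B, T, E); sentence_lens: true lengths, longest first
embed_input_x_packed = pack_padded_sequence(embed_input_x, sentence_lens, batch_first=True)
encoder_outputs_packed, (h_last, c_last) = self.lstm(embed_input_x_packed)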
Here, the returned h_last and c_last are the hidden state and cell state with the padding tokens excluded; both are of type Variable. Their meaning is as follows: each is the representation of a sentence, where the LSTM runs over only the actual length of the sentence and not over the useless padding tokens (marked with a red tick in the figure below):
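For context, the two lines above could sit inside a module like the following. This is a sketch under assumptions, not the author's actual model: the class name `SentenceEncoder`, the layer sizes, and the token ids are hypothetical, and id 0 is assumed to be the pad index.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class SentenceEncoder(nn.Module):  # hypothetical name and sizes
    def __init__(self, vocab_size=20, embed_dim=8, hidden_dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, input_x, sentence_lens):
        embed_input_x = self.embed(input_x)  # (B, T, embed_dim)
        embed_input_x_packed = pack_padded_sequence(
            embed_input_x, sentence_lens, batch_first=True)
        encoder_outputs_packed, (h_last, c_last) = self.lstm(embed_input_x_packed)
        return encoder_outputs_packed, h_last, c_last

# Padded token ids, sorted longest first; 0 is <pad>.
input_x = torch.tensor([[4, 9, 2, 7],
                        [8, 6, 0, 0],
                        [11, 0, 0, 0]])
encoder = SentenceEncoder()
_, h_last, c_last = encoder(input_x, [4, 2, 1])

# One state per sentence: (num_layers * num_directions, B, hidden_dim).
print(h_last.shape)  # torch.Size([1, 3, 16])
```

Row i of `h_last[0]` is the state after the last real token of sentence i, regardless of how much padding follows it in the batch.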

The returned output, however, is a PackedSequence. You can use:
encoder_outputs, _ = pad_packed_sequence(encoder_outputs_packed, batch_first=True)
to convert encoder_outputs_packed back into a padded Variable; the second return value, discarded here as _, holds the length of each sentence.
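The unpacking step can be examined on a small random batch. The sketch below assumes made-up sizes and random inputs. It shows the shape of the re-padded output, the returned lengths, the zero filling past each true length, and that the output at the last valid step of each sentence matches h_last.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=3, batch_first=True)
x = torch.randn(3, 5, 4)           # padded batch, (B, T, features)
lengths = torch.tensor([5, 3, 1])  # true lengths, longest first

packed = pack_padded_sequence(x, lengths, batch_first=True)
with torch.no_grad():
    out_packed, (h_last, _) = lstm(packed)
encoder_outputs, out_lens = pad_packed_sequence(out_packed, batch_first=True)

print(encoder_outputs.shape)  # torch.Size([3, 5, 3])
print(out_lens.tolist())      # [5, 3, 1]

# Positions past each true length are filled with zeros.
print(encoder_outputs[2, 1:].abs().sum().item())  # 0.0

# The output at the last valid step of each sentence equals h_last.
last_step = encoder_outputs[torch.arange(3), out_lens - 1]
print(torch.allclose(last_step, h_last[0]))  # True
```

If only the sentence representation is needed, h_last is enough; the re-padded outputs matter when every time step is used, for example by an attention layer.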
Summary
To sum up, when an RNN processes variable-length inputs such as sentence sequences, we can use
- torch.nn.utils.rnn.pack_padded_sequence()
- torch.nn.utils.rnn.pad_packed_sequence()
together to keep padding from affecting the sentence representations.
Copyright notice
This article was written by [Guanrunwei is getting fatter and fatter]. Please include the original link when reposting. Thank you.
https://yzsam.com/2022/04/202204221206458180.html