PyTorch 21. The nn.Embedding module in PyTorch
2022-04-23 13:16:00 【DCGJ666】
torch.nn.Embedding
Function:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)
Parameters:
- num_embeddings: the size of the lookup table (the number of entries in the dictionary)
- embedding_dim: the dimension of each lookup vector
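The signature above also lists padding_idx, which the parameter notes do not cover. A minimal added sketch of its documented behavior (this example is not from the original post): the row at padding_idx is initialized to zeros and its gradient stays zero, which is useful for padded positions in variable-length sequences.
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)
print(emb.weight[0])  # row 0 (the padding row) is all zeros
out = emb(torch.LongTensor([0, 2]))
out.sum().backward()
print(emb.weight.grad[0])  # the gradient for the padding row stays zero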
What it does, roughly: a simple lookup table that stores embeddings of a fixed dictionary and size. This module is often used to store word embeddings and retrieve them by index. The input to the module is a list of indices, and the output is the corresponding word embeddings. It is equivalent to randomly generating a tensor and treating it as a lookup table of size [num_embeddings, embedding_dim], where num_embeddings is the size of the lookup table and embedding_dim is the dimension of each lookup vector.
Put simply, this is a matrix class. It initializes a random matrix whose height is the size of the dictionary and whose width is the dimension of the attribute vector representing each element in the dictionary; the dimension of the vector depends on how complex the elements you want to represent are. Once the class is instantiated, you can look up the vector for an element by its index in the dictionary: input index 0, and the output is row 0 of the embeds matrix.
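A quick way to verify this picture (a small added sketch, not from the original post): a forward pass through nn.Embedding returns exactly the corresponding rows of its weight matrix.
import torch
import torch.nn as nn

emb = nn.Embedding(5, 3)  # a 5 x 3 lookup matrix
idx = torch.LongTensor([0])
# index 0 returns row 0 of the weight matrix
assert torch.equal(emb(idx)[0], emb.weight[0])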
Concrete examples:
- Create a lookup matrix and use it for embedding:
import torch
import torch.nn as nn

embedding = nn.Embedding(5, 3)  # define a lookup matrix with 5 words, each of dimension 3
print(embedding.weight)  # show the contents of the matrix
test = torch.LongTensor([[0, 2, 0, 1],
                         [1, 3, 4, 4]])  # the matrix to be embedded; its size is [2, 4]
# The first row [0, 2, 0, 1] means: fetch the lookup vectors with IDs 0, 2, 0, 1 from the lookup matrix
# The embedded output below can be compared against the weight printed above
test = embedding(test)
print(test.size())  # the size of test after embedding is [2, 4, 3]; the added 3 is the lookup-vector dimension
print(test)  # the contents of test after embedding
Output:
Parameter containing:
tensor([[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735],
[-1.3317, 0.8350, -1.7235],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459]], requires_grad=True)
torch.Size([2, 4, 3])
tensor([[[-1.8056, 0.1836, -1.4376],
[-1.3317, 0.8350, -1.7235],
[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735]],
[[ 0.8409, 0.1034, -1.3735],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459],
[ 0.4148, -0.0923, 0.2459]]], grad_fn=<EmbeddingBackward>)
As you can see, embedding amounts to indexing into a matrix: each input ID selects the corresponding row of the underlying embedding matrix.
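To check this equivalence concretely, here is a small added sketch (the name ids is introduced here because test above was overwritten by the embedding output): indexing embedding.weight directly with the ID tensor yields the same [2, 4, 3] result as calling the module.
ids = torch.LongTensor([[0, 2, 0, 1],
                        [1, 3, 4, 4]])
# indexing the weight matrix directly gives the same result as calling the module
assert torch.equal(embedding(ids), embedding.weight[ids])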
- Look up the vector (the word vector) of a specific ID (word) in the lookup matrix:
# Access a single ID, i.e. the lookup vector of the N-th word
print(embedding(torch.LongTensor([3])))  # query the word vector of word 3
# (this output comes from a separate run, so the values below differ from the weight matrix shown earlier)
Output:
tensor([[-1.6016, -0.8350, -0.7878]], grad_fn=<EmbeddingBackward>)
- Output the word embedding of the word "hello":
import torch
import torch.nn as nn
from torch.autograd import Variable  # Variable is deprecated in recent PyTorch; see the note after the output

word_to_ix = {'hello': 0, 'world': 1}  # map each word to an index
embeds = nn.Embedding(2, 5)  # 2 words in the vocabulary, 5-dimensional embeddings
hello_idx = torch.LongTensor([word_to_ix['hello']])
hello_idx = Variable(hello_idx)
hello_embed = embeds(hello_idx)  # look up the embedding of 'hello'
print(hello_embed)
Output:
Variable containing:
0.4606 0.6847 -1.9592 0.9434 0.2316
[torch.FloatTensor of size 1x5]
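For reference, the same lookup in current PyTorch, where Variable is deprecated and plain tensors are passed to modules directly (an added sketch, not part of the original post):
import torch
import torch.nn as nn

word_to_ix = {'hello': 0, 'world': 1}
embeds = nn.Embedding(2, 5)
hello_embed = embeds(torch.LongTensor([word_to_ix['hello']]))  # no Variable wrapper needed
print(hello_embed)  # a 1x5 tensor, printed with grad_fn=<EmbeddingBackward0> in recent versions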
Copyright notice
This article was written by [DCGJ666]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230611343468.html