PyTorch 21. The nn.Embedding Module in PyTorch
2022-04-23 13:16:00 [DCGJ666]
torch.nn.Embedding
Signature:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)
Parameters:
- num_embeddings: the size of the lookup table (number of entries in the dictionary)
- embedding_dim: the dimension of each embedding vector
- padding_idx (optional): if given, the entry at this index is initialized to zeros and is not updated by gradients, so it can serve as a padding token
What it does: a simple lookup table with a fixed dictionary and size. The module is commonly used to store word embeddings and retrieve them by index. Its input is a list of indices, and its output is the corresponding word embeddings. It is equivalent to a randomly initialized tensor treated as a lookup table of size [num_embeddings, embedding_dim], where num_embeddings is the number of entries in the table and embedding_dim is the dimension of each lookup vector.
Put plainly, this class wraps a matrix: it initializes a random matrix whose height is the size of the dictionary and whose width is the dimension of the attribute vector representing each dictionary element; that dimension depends on how complex the elements you want to represent are. After the class is instantiated, you can look up the vector for an element by its index in the dictionary. Input index 0, and the output is row 0 of the embeds matrix.
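The padding_idx parameter from the signature above deserves a quick illustration. A minimal sketch (the sizes 5 and 3 are arbitrary): reserving index 0 as padding pins that row to zeros, and its gradient is never updated during training.

```python
import torch
import torch.nn as nn

# Reserve index 0 as a padding token: its row is initialized to zeros
# and excluded from gradient updates.
emb = nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)
print(emb.weight[0])              # row 0 is all zeros

out = emb(torch.LongTensor([0, 2]))
print(out[0])                     # looking up index 0 returns the zero vector
```

This is handy when batching variable-length sequences: positions filled with the padding index contribute zero vectors and do not pull gradients into a "real" embedding row.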
Specific examples:
- Create a lookup matrix and use it for embedding:
import torch
import torch.nn as nn

embedding = nn.Embedding(5, 3)  # define a lookup matrix with 5 words, each of dimension 3
print(embedding.weight)         # show the contents of the weight matrix
test = torch.LongTensor([[0, 2, 0, 1],
                         [1, 3, 4, 4]])  # the tensor to be embedded; its size is [2, 4]
# The first row [0, 2, 0, 1] fetches the lookup vectors with IDs 0, 2, 0, 1
# from the lookup matrix. Compare this input against the embedded output below.
test = embedding(test)
print(test.size())  # after embedding, test has size [2, 4, 3]; the added 3 is the lookup dimension
print(test)         # show the contents of test after embedding
Output:
Parameter containing:
tensor([[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735],
[-1.3317, 0.8350, -1.7235],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459]], requires_grad=True)
torch.Size([2, 4, 3])
tensor([[[-1.8056, 0.1836, -1.4376],
[-1.3317, 0.8350, -1.7235],
[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735]],
[[ 0.8409, 0.1034, -1.3735],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459],
[ 0.4148, -0.0923, 0.2459]]], grad_fn=<EmbeddingBackward>)
As this shows, embedding amounts to indexing the original embedding matrix to retrieve the feature vector of each element.
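That equivalence can be checked directly: calling the module is the same as fancy-indexing rows of its weight matrix (sizes below are illustrative).

```python
import torch
import torch.nn as nn

emb = nn.Embedding(5, 3)
idx = torch.LongTensor([[0, 2, 0, 1],
                        [1, 3, 4, 4]])

looked_up = emb(idx)                          # shape [2, 4, 3]
# The module call and direct row indexing of the weight matrix agree.
assert torch.equal(looked_up, emb.weight[idx])
print(looked_up.shape)                        # torch.Size([2, 4, 3])
```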
- Look up the lookup vector (word vector) of a specific ID (word) in the lookup matrix:
# Access a single ID, i.e. the lookup vector of the N-th word
print(embedding(torch.LongTensor([3])))  # look up the word vector of the word with ID 3
Output:
tensor([[-1.6016, -0.8350, -0.7878]], grad_fn=<EmbeddingBackward>)
- Output the word embedding of the word "hello":
import torch
import torch.nn as nn
from torch.autograd import Variable  # deprecated in modern PyTorch; plain tensors work directly

word_to_ix = {'hello': 0, 'world': 1}
embeds = nn.Embedding(2, 5)
hello_idx = torch.LongTensor([word_to_ix['hello']])
hello_idx = Variable(hello_idx)
hello_embed = embeds(hello_idx)
print(hello_embed)
Output:
Variable containing:
0.4606 0.6847 -1.9592 0.9434 0.2316
[torch.FloatTensor of size 1x5]
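If you already have word vectors (e.g. from a pretrained model), you don't have to start from a random matrix: nn.Embedding.from_pretrained builds the lookup table from a given tensor. The vector values below are made up for illustration; freeze=True keeps them fixed during training.

```python
import torch
import torch.nn as nn

word_to_ix = {'hello': 0, 'world': 1}
# Hypothetical pretrained vectors, one row per word in the dictionary.
pretrained = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
                           [0.5, 0.4, 0.3, 0.2, 0.1]])

# freeze=True (the default) means these rows receive no gradient updates.
embeds = nn.Embedding.from_pretrained(pretrained, freeze=True)
hello_embed = embeds(torch.LongTensor([word_to_ix['hello']]))
print(hello_embed)  # tensor([[0.1000, 0.2000, 0.3000, 0.4000, 0.5000]])
```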
Copyright notice
This article was written by [DCGJ666]; please include the original link when reposting. Thanks.
https://yzsam.com/2022/04/202204230611343468.html