PyTorch 21. The nn.Embedding Module in PyTorch
2022-04-23 13:16:00 【DCGJ666】
torch.nn.Embedding
Function signature:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)
Parameters:
- num_embeddings: the size of the lookup table, i.e. the number of entries in the dictionary
- embedding_dim: the dimension of each embedding vector
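The remaining arguments are used less often; padding_idx is worth a note because it reserves one row of the table for padding tokens. A minimal sketch of that behavior (this example is an addition, not from the original post):
import torch.nn as nn
# With padding_idx set, that row of the weight matrix is initialized to
# zeros and is excluded from gradient updates during training.
emb = nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)
print(emb.weight[0])  # tensor([0., 0., 0.], grad_fn=<SelectBackward0>)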
What it does, roughly: a simple lookup table that stores embeddings of a fixed dictionary and size. This module is often used to store word embeddings and retrieve them by index. The input to the module is a list of indices, and the output is the corresponding word embeddings. It is equivalent to randomly generating a tensor and treating it as a lookup table of size [num_embeddings, embedding_dim], where num_embeddings is the size of the lookup table and embedding_dim is the dimension of each embedding vector.
Put plainly, this class wraps a matrix: it initializes a random matrix whose height is the size of the dictionary and whose width is the dimension of the attribute vector representing each element in the dictionary; the dimension of the vector depends on how complex the elements you want to represent are. Once the class is instantiated, you can look up the vector for an element by its index in the dictionary: input index 0, and the output is row 0 of the embeds matrix.
Examples:
- Create a lookup matrix and use it to embed a tensor of indices:
import torch
import torch.nn as nn

embedding = nn.Embedding(5, 3)  # define a lookup matrix with 5 words, each of dimension 3
print(embedding.weight)         # show the contents of the matrix
test = torch.LongTensor([[0, 2, 0, 1],
                         [1, 3, 4, 4]])  # the test matrix to be embedded; its size is [2, 4]
# The first row [0, 2, 0, 1] says: fetch the vectors whose IDs in the lookup matrix are 0, 2, 0, 1
# You can compare the contents of test here against the embedded output below
test = embedding(test)
print(test.size())  # size of test after embedding: [2, 4, 3]; the trailing 3 is the embedding dimension
print(test)         # contents of test after embedding
Output:
Parameter containing:
tensor([[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735],
[-1.3317, 0.8350, -1.7235],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459]], requires_grad=True)
torch.Size([2, 4, 3])
tensor([[[-1.8056, 0.1836, -1.4376],
[-1.3317, 0.8350, -1.7235],
[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735]],
[[ 0.8409, 0.1034, -1.3735],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459],
[ 0.4148, -0.0923, 0.2459]]], grad_fn=<EmbeddingBackward>)
As you can see, embedding is equivalent to indexing into the original embedding matrix to look up each element's feature vector.
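To make this concrete, here is a small sketch (an addition, reusing the embedding defined above; the name idx is illustrative, and a fresh index tensor is needed because test was overwritten) showing that the lookup matches plain indexing into the weight matrix:
idx = torch.LongTensor([[0, 2, 0, 1],
                        [1, 3, 4, 4]])
# embedding(idx) returns the same values as indexing embedding.weight directly
print(torch.equal(embedding(idx), embedding.weight[idx]))  # True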
- Look up the vector (word vector) of a specific ID (word) in the lookup matrix:
# Access a single ID, i.e. the embedding vector of the N-th word
print(embedding(torch.LongTensor([3])))  # look up the word vector of word 3
Output:
tensor([[-1.6016, -0.8350, -0.7878]], grad_fn=<EmbeddingBackward>)
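Note that every index passed in must be smaller than num_embeddings; an out-of-range ID raises an error rather than returning a vector. A quick sketch (an addition, reusing the 5-row table from above):
# The lookup matrix has 5 rows (IDs 0 through 4), so ID 5 is invalid
try:
    embedding(torch.LongTensor([5]))
except IndexError as err:
    print(err)  # index out of range in self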
- Output the word embedding of the word "hello":
import torch
import torch.nn as nn
from torch.autograd import Variable  # Variable is deprecated; plain tensors work in modern PyTorch

word_to_ix = {'hello': 0, 'world': 1}  # map each word to an index
embeds = nn.Embedding(2, 5)            # vocabulary of 2 words, 5-dimensional embeddings
hello_idx = torch.LongTensor([word_to_ix['hello']])
hello_idx = Variable(hello_idx)        # wrap as a Variable (a no-op since PyTorch 0.4)
hello_embed = embeds(hello_idx)        # look up the embedding of 'hello'
print(hello_embed)
Output:
Variable containing:
0.4606 0.6847 -1.9592 0.9434 0.2316
[torch.FloatTensor of size 1x5]
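In practice, word embeddings like these are often initialized from pretrained vectors rather than random values; nn.Embedding.from_pretrained builds the lookup table from an existing weight matrix. A minimal sketch (an addition; the weight values here are made up for illustration):
import torch
import torch.nn as nn

# freeze=True keeps the pretrained vectors fixed during training
# (pass freeze=False to fine-tune them)
pretrained = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],   # row 0: 'hello'
                           [0.6, 0.7, 0.8, 0.9, 1.0]])  # row 1: 'world'
embeds = nn.Embedding.from_pretrained(pretrained, freeze=True)
print(embeds(torch.LongTensor([0])))  # the vector for 'hello'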
Copyright notice
This article was written by [DCGJ666]. Please include the original link when reposting. Thanks.
https://yzsam.com/2022/04/202204230611343468.html