当前位置:网站首页>PyTorch 21. PyTorch中nn.Embedding模块
PyTorch 21. PyTorch中nn.Embedding模块
2022-04-23 06:11:00 【DCGJ666】
PyTorch 21. PyTorch中nn.Embedding模块
torch.nn.Embedding
函数:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)
参数解释:
- num_embeddings: 查询表的大小
- embedding_dim: 每个查询向量的维度
函数大概解释:一个保存了固定字典和大小的简单查找表。这个模块常用来保存词嵌入和用下标检索他们。模块的输入是一个下标的列表,输出是对应的词嵌入。相当于随机生成了一个tensor,可以把它看作一个查询表,其size为[num_embeddings, embedding_dim]。其中num_embeddings是查询表的大小,embedding_dim是每个查询向量的维度。
通俗来讲,这是一个矩阵类,里面初始化了一个随机矩阵,矩阵的长时字典的大小,宽是用来表示字典中每个元素的属性向量,向量的维度根据你想要表示的元素的复杂度而定。类实例化之后可以根据字典中元素的下标来查找元素对应的向量。输入下标0,输出就是embeds矩阵中第0行。
具体实例:
- 创建查询矩阵并使用它做Embedding:
embedding = nn.Embedding(5,3) #定义一个具有5个单词,维度为3的查询矩阵
print(embedding.weight) #展示该矩阵的具体内容
test = torch.LongTensor([[0, 2, 0, 1],
[1, 3, 4, 4]]) #该test矩阵用于被embed,其size为[2,4]
# 其中的第一行为[0, 2, 0, 1], 表示获取查询矩阵中ID为0, 2, 0, 1的查询向量
# 可以在之后的test输出中与embed的输出进行比较
test = embedding(test)
print(test.size()) #输出embed后test的size, 为[2, 4, 3],增加的3,是因为查询向量的维度为3
print(test) # 输出embed后的test的内容
输出:
Parameter containing:
tensor([[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735],
[-1.3317, 0.8350, -1.7235],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459]], requires_grad=True)
torch.Size([2, 4, 3])
tensor([[[-1.8056, 0.1836, -1.4376],
[-1.3317, 0.8350, -1.7235],
[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735]],
[[ 0.8409, 0.1034, -1.3735],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459],
[ 0.4148, -0.0923, 0.2459]]], grad_fn=<EmbeddingBackward>)
可以看出,embedding相当于用一个矩阵的索引去查询原始嵌入矩阵中元素的特征。
- 寻找查询矩阵中特定ID(词)的查询向量(词向量):
# 访问某个ID,即第N个词的查询向量
print(embedding(torch.LongTensor([3]))) # 这里表示查询第3个词的词向量
输出:
tensor([[-1.6016, -0.8350, -0.7878]], grad_fn=<EmbeddingBackward>)
- 输出的hello这个词的word embedding
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
word_to_ix = {
'hello':0, 'world':1}
embeds = nn.Embedding(2,5)
hello_idx = torch.LongTensor([word_to_ix['hello']])
hello_idx = Variable(hello_idx)
hello_embed = embeds(hello_idx)
print(hello_embed)
输出结果:
Variable containing:
0.4606 0.6847 -1.9592 0.9434 0.2316
[torch.FloatTensor of size 1x5]
版权声明
本文为[DCGJ666]所创,转载请带上原文链接,感谢
https://blog.csdn.net/DCGJ666/article/details/122139828
边栏推荐
- [2021 book recommendation] practical node red programming
- cmder中文乱码问题
- [3D shape reconstruction series] implicit functions in feature space for 3D shape reconstruction and completion
- [recommendation of new books in 2021] enterprise application development with C 9 and NET 5
- 红外传感器控制开关
- 【点云系列】FoldingNet:Point Cloud Auto encoder via Deep Grid Deformation
- [point cloud series] a rotation invariant framework for deep point cloud analysis
- 素数求解的n种境界
- PyTorch 模型剪枝实例教程三、多参数与全局剪枝
- What did you do during the internship
猜你喜欢
C language, a number guessing game
[point cloud series] pnp-3d: a plug and play for 3D point clouds
ArcGIS license server administrator cannot start the workaround
[2021 book recommendation] artistic intelligence for IOT Cookbook
机器学习——PCA与LDA
第2章 Pytorch基础2
【2021年新书推荐】Practical IoT Hacking
Use originpro express for free
What did you do during the internship
1.1 PyTorch和神经网络
随机推荐
PyTorch最佳实践和代码编写风格指南
三子棋小游戏
Pytorch best practices and coding style guide
红外传感器控制开关
How to standardize multidimensional matrix (based on numpy)
Android exposed components - ignored component security
Component based learning (3) path and group annotations in arouter
PyTorch 12. hook的用法
Some common data type conversion methods in pytorch are similar to list and NP Conversion method of ndarray
1.2 preliminary pytorch neural network
Migrating your native/mobile application to Unified Plan/WebRTC 1.0 API
[dynamic programming] triangle minimum path sum
【点云系列】SG-GAN: Adversarial Self-Attention GCN for Point Cloud Topological Parts Generation
PyTorch 模型剪枝实例教程三、多参数与全局剪枝
常见的正则表达式
Chapter 5 fundamentals of machine learning
利用官方torch版GCN训练并测试cora数据集
机器学习——朴素贝叶斯
Miscellaneous learning
【点云系列】FoldingNet:Point Cloud Auto encoder via Deep Grid Deformation