当前位置:网站首页>PyTorch 21. PyTorch中nn.Embedding模块
PyTorch 21. PyTorch中nn.Embedding模块
2022-04-23 06:11:00 【DCGJ666】
PyTorch 21. PyTorch中nn.Embedding模块
torch.nn.Embedding
函数:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)
参数解释:
- num_embeddings: 查询表的大小
- embedding_dim: 每个查询向量的维度
函数大概解释:一个保存了固定字典和大小的简单查找表。这个模块常用来保存词嵌入和用下标检索他们。模块的输入是一个下标的列表,输出是对应的词嵌入。相当于随机生成了一个tensor,可以把它看作一个查询表,其size为[num_embeddings, embedding_dim]。其中num_embeddings是查询表的大小,embedding_dim是每个查询向量的维度。
通俗来讲,这是一个矩阵类,里面初始化了一个随机矩阵,矩阵的长时字典的大小,宽是用来表示字典中每个元素的属性向量,向量的维度根据你想要表示的元素的复杂度而定。类实例化之后可以根据字典中元素的下标来查找元素对应的向量。输入下标0,输出就是embeds矩阵中第0行。
具体实例:
- 创建查询矩阵并使用它做Embedding:
embedding = nn.Embedding(5,3) #定义一个具有5个单词,维度为3的查询矩阵
print(embedding.weight) #展示该矩阵的具体内容
test = torch.LongTensor([[0, 2, 0, 1],
[1, 3, 4, 4]]) #该test矩阵用于被embed,其size为[2,4]
# 其中的第一行为[0, 2, 0, 1], 表示获取查询矩阵中ID为0, 2, 0, 1的查询向量
# 可以在之后的test输出中与embed的输出进行比较
test = embedding(test)
print(test.size()) #输出embed后test的size, 为[2, 4, 3],增加的3,是因为查询向量的维度为3
print(test) # 输出embed后的test的内容
输出:
Parameter containing:
tensor([[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735],
[-1.3317, 0.8350, -1.7235],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459]], requires_grad=True)
torch.Size([2, 4, 3])
tensor([[[-1.8056, 0.1836, -1.4376],
[-1.3317, 0.8350, -1.7235],
[-1.8056, 0.1836, -1.4376],
[ 0.8409, 0.1034, -1.3735]],
[[ 0.8409, 0.1034, -1.3735],
[ 1.5251, -0.2363, -0.6729],
[ 0.4148, -0.0923, 0.2459],
[ 0.4148, -0.0923, 0.2459]]], grad_fn=<EmbeddingBackward>)
可以看出,embedding相当于用一个矩阵的索引去查询原始嵌入矩阵中元素的特征。
- 寻找查询矩阵中特定ID(词)的查询向量(词向量):
# 访问某个ID,即第N个词的查询向量
print(embedding(torch.LongTensor([3]))) # 这里表示查询第3个词的词向量
输出:
tensor([[-1.6016, -0.8350, -0.7878]], grad_fn=<EmbeddingBackward>)
- 输出的hello这个词的word embedding
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
word_to_ix = {
'hello':0, 'world':1}
embeds = nn.Embedding(2,5)
hello_idx = torch.LongTensor([word_to_ix['hello']])
hello_idx = Variable(hello_idx)
hello_embed = embeds(hello_idx)
print(hello_embed)
输出结果:
Variable containing:
0.4606 0.6847 -1.9592 0.9434 0.2316
[torch.FloatTensor of size 1x5]
版权声明
本文为[DCGJ666]所创,转载请带上原文链接,感谢
https://blog.csdn.net/DCGJ666/article/details/122139828
边栏推荐
- Migrating your native/mobile application to Unified Plan/WebRTC 1.0 API
- PyTorch 10. 学习率
- [2021 book recommendation] learn winui 3.0
- [recommendation for new books in 2021] professional azure SQL managed database administration
- 【点云系列】SG-GAN: Adversarial Self-Attention GCN for Point Cloud Topological Parts Generation
- MySQL的安装与配置——详细教程
- 面试总结之特征工程
- What did you do during the internship
- [dynamic programming] longest increasing subsequence
- [recommendation of new books in 2021] enterprise application development with C 9 and NET 5
猜你喜欢
Gephi tutorial [1] installation
Visual Studio 2019安装与使用
第4章 Pytorch数据处理工具箱
C language, a number guessing game
【点云系列】FoldingNet:Point Cloud Auto encoder via Deep Grid Deformation
【点云系列】Multi-view Neural Human Rendering (NHR)
Chapter 8 generative deep learning
Chapter 1 numpy Foundation
【2021年新书推荐】Red Hat RHCSA 8 Cert Guide: EX200
【点云系列】Neural Opacity Point Cloud(NOPC)
随机推荐
PyTorch最佳实践和代码编写风格指南
cmder中文乱码问题
Handlerthread principle and practical application
【点云系列】FoldingNet:Point Cloud Auto encoder via Deep Grid Deformation
机器学习 三: 基于逻辑回归的分类预测
Google AdMob advertising learning
GEE配置本地开发环境
常见的正则表达式
Android exposed components - ignored component security
MySQL5. 7 insert Chinese data and report an error: ` incorrect string value: '\ xb8 \ XDF \ AE \ xf9 \ X80 at row 1`
.net加载字体时遇到 Failed to decode downloaded font:
Pymysql connection database
Mysql database installation and configuration details
Component learning (2) arouter principle learning
MySQL数据库安装与配置详解
[recommendation of new books in 2021] enterprise application development with C 9 and NET 5
Machine learning III: classification prediction based on logistic regression
Pytorch模型保存与加载(示例)
Pytorch best practices and coding style guide
1.1 PyTorch和神经网络