当前位置:网站首页>torch.gather() 用法解读
torch.gather() 用法解读
2022-08-08 05:41:00 【00000cj】
torch.gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor
沿\(dim\)指定的轴和\(index\)指定的索引从\(input\)中提取对应的值。
对于一个三维张量
\(output[i][j][k]=input[index[i][j][k]][j][k] \quad \#\enspace if \enspace dim==0\)
\(output[i][j][k]=input[i][index[i][j][k]][k] \quad \#\enspace if \enspace dim==1\)
\(output[i][j][k]=input[i][j][index[i][j][k]] \quad \#\enspace if \enspace dim==2\)
\(input\)和\(index\)的\(dimensions\)数目必须相同。 \(out\)和\(index\)的\(shape\)是相同的。(注意\(dimensions\)和\(shape\)的区别)
示例
下面用两个例子来解释一下具体的用法
例1
import torch
dim = 0
_input = torch.tensor([[10, 11, 12],
[13, 14, 15],
[16, 17, 18]])
index = torch.tensor([[0, 1, 2],
[1, 2, 0]])
output = torch.gather(_input, dim, index)
print(output)
# tensor([[10, 14, 18],
# [13, 17, 12]])
该例中 _input.shape=(3, 3),dimensions=2,其中_input和index的dimensions相同都为2,output和index的shape相同都为(2, 3)。
因为dim=0,index中的每个数其值代表dim=0即"行"这个维度的索引,而每个数本身所在位置的索引指定了其它维度的索引。比如index中第0行的[0, 1, 2]分别表示第0、1、2行,而这三个数本身在dim=1维度的索引为0、1、2即第0、1、2列。因此第一个数0定位到_input中的第0行,而0本身在index中的第0列,因此又定位到_input的第0列,这样就找到了10这个数,同理找到14和18。
index中的第1行[1, 2, 0]分别表示_input中的第1、2、0行和第0、1、2列,因此找到_input中对应的数[13, 17, 12]。
例2
import torch
dim = 1
_input = torch.tensor([[10, 11, 12],
[13, 14, 15],
[16, 17, 18]])
index = torch.tensor([[0, 1],
[1, 2],
[2, 0]])
output = torch.gather(_input, dim, index)
print(output)
# tensor([[10, 11],
# [14, 15],
# [18, 16]])
该例中 _input.shape=(3, 3),dimensions=2,其中_input和index的dimensions相同都为2,output和index的shape相同都为(3, 2)。
因为dim=1,index中的每个数其值代表dim=1即"列"这个维度的索引,而每个数本身所在位置的索引指定了其它维度的索引。比如index中第0行的[0, 1]分别表示第0、1列,而这三个数本身在dim=0维度的索引为0即第0行。因此第一个数0定位到_input中的第0列,而0本身在index中的第0行,因此又定位到_input的第0行,这样就找到了10这个数,同理找到11。
index中的第1行[1, 2]分别表示_input中的第1、2列和第1行,因此找到_input中对应的数[14, 15]。
index中的第2行[2, 0]分别表示_input中的第2、0列和第2行,因此找到_input中对应的数[18, 16]。
总结
上面的示例是二维的情况,同理也可以推广到三维甚至更多维。总结来说,index中每个数其本身的值表示参数dim指定维度的索引,而其它的每个维度都由每个数在index中的对应维度的索引指定。
参考
torch.gather — PyTorch 1.12 documentation
python - What does the gather function do in pytorch in layman terms? - Stack Overflow
边栏推荐
猜你喜欢

28. Anomaly detection

How to batch import files and rename them all to the same file name

TCP/IP基本实现

Eighteen, OIDC OAuth2 】 【 the understanding of the application

Servlet---ServletConfig类使用介绍

automation tool

stack-queue

文件操作 - IO

apifox使用文档之环境变量 / 全局变量 / 临时变量附apifox学习路线图

TSF Microservice Governance Combat Series (2) - Service Routing
随机推荐
VSCode已经设置过为中文但变成英文的解决办法
什么是 DevOps?看这一篇就够了!
Rust开发——Struct使用示例
Personal Summary of OLTP and OLAP Issues
Use of Filter
The difference between classification, object detection, semantic segmentation, and instance segmentation
cs软件ui构建办法
C语言力扣第58题之最后一个单词的长度。从后往前遍历
[Redis] Redis Learning - Transaction
Several postman features worth collecting will help you do more with less!
The big and small end problem caused by union union
121道分布式面试题和答案
Week 9 10 Neural Networks
Day7:面试必考选择题
Sequence table (below)
神经网络解决哪些问题,神经网络结果不稳定
【js基础】闭包的几种情况(代码)
leetcode-isomorphic string judgment
IP核之RAM实验
关于如何做选择