当前位置:网站首页>Official explanation, detailed explanation and example of torch.cat() function

Official explanation, detailed explanation and example of torch.cat() function

2022-08-09 10:45:00 Fuzzy Pack

You can directly see the example below,Look back at the previous explanation,就很明白了.

pytorch中,常见的拼接函数主要是两个,分别是:

  1. stack()
  2. cat()

一般torch.cat()是为了把多个tensorexist for splicing.实际使用中,和torch.stack()使用场景不同:参考链接torch.stack(),But this article mainly sayscat().

torch.cat()python中的内置函数cat(), in use and purpose,是没有区别的,The difference is that the former operation object istensor.

1. cat()

函数目的: 在给定维度上对输入的张量序列seq 进行连接操作.

outputs = torch.cat(inputs, dim=?) → Tensor

参数

  • inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列
  • dim : 选择的扩维, 必须在0len(inputs[0])之间,沿着此维连接张量序列.

2. 重点

  1. 输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor
  2. 维度不可以超过输入数据的任一个张量的维度

3.举例子

  1. 准备数据,每个的shape都是[2,3]
# x1
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape  # torch.Size([2, 3])
  1. 合成inputs
'inputs为2个形状为[2 , 3]的矩阵 '
inputs = [x1, x2]
print(inputs)
'打印查看'
[tensor([[11, 21, 31],
         [21, 31, 41]], dtype=torch.int32),
 tensor([[12, 22, 32],
         [22, 32, 42]], dtype=torch.int32)]

3.查看结果, 测试不同的dim拼接结果

In    [1]: torch.cat(inputs, dim=0).shape
Out[1]: torch.Size([4,  3])

In    [2]: torch.cat(inputs, dim=1).shape
Out[2]: torch.Size([2, 6])

In    [3]: torch.cat(inputs, dim=2).shape
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

You can copy the code and run it and you will find the rules.

总结

通常用来,把torch.stack得到tensorexist for splicing.

原网站

版权声明
本文为[Fuzzy Pack]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/221/202208091041529964.html