当前位置:网站首页>The torch. The stack () official explanation, explanation and example

The torch. The stack () official explanation, explanation and example

2022-08-09 10:46:00 Fuzzy Pack

可以直接看最下面的【3.例子】,再回头看前面的解释

pytorch中,常见的拼接函数主要是两个,分别是:

  1. stack()
  2. cat()

实际使用中,These two functions complement each other,使用场景不同:关于cat()参考torch.cat(),但是本文主要说stack().

函数的意义:使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数.

形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度.

This function often appears in natural language processing(NLP)和图像卷积神经网络(CV)中.

1. stack()

官方解释:沿着一个新维度对输入张量序列进行连接. 序列中所有的张量都应该为相同形状.

浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠.

outputs = torch.stack(inputs, dim=?) → Tensor

参数

  • inputs : 待连接的张量序列.
    注:python的序列数据只有listtuple.

  • dim : 新的维度, 必须在0len(outputs)之间.
    注:len(outputs)是生成数据的维度大小,也就是outputs的维度值.

2. 重点

  1. 函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等

----举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape

  1. dim是选择生成的维度,必须满足0<=dim<len(outputs);len(outputs)是输出后的tensor的维度大小

See the example if you don't understand,Looking back, you will understand.

3. 例子

1.准备2个tensor数据,每个的shape都是[3,3]

# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
        		[4, 5, 6],
        		[7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
        		[40, 50, 60],
        		[70, 80, 90]])

2.测试stack函数

print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)
print(torch.stack((T1,T2),dim=3).shape)
# outputs:
torch.Size([2, 3, 3])
torch.Size([3, 2, 3])
torch.Size([3, 3, 2])
'选择的dim>len(outputs),所以报错'
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

You can copy the code and try it out:拼接后的tensor形状,会根据不同的dim发生变化.

dimshape
0[2, 3, 3]
1[3, 2, 3]
2[3, 3, 2]
3溢出报错

4. 总结

  1. 函数作用:
    函数stack()序列数据Internal tensors are performed扩维拼接,The specified dimension is chosen by the programmer、Size is the dimension interval of the generated data.

  2. 存在意义:
    In Natural Language Processing and Convolutional and Neural Networks, Usually for reservations–[序列(先后)信息] 和 [Matrix information for the tensor] 才会使用stack.

函数存在意义?》》》

手写过RNN的同学,Know that the output data in a recurrent neural network is:一个list,The list is insertedseq_len个形状是[batch_size, output_size]tensor,不利于计算,需要使用stack进行拼接,保留–[1.seq_lenthis time step]和–[2.张量属性[batch_size, output_size]].

原网站

版权声明
本文为[Fuzzy Pack]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/221/202208091041530035.html