当前位置:网站首页>PyTorch框架的 torch.cat()函数

PyTorch框架的 torch.cat()函数

2022-08-09 18:52:00 哈哈哈哈哈嗝哈哈哈

前言

搭建深度神经网络模型时,难免会遇到 torch.cat()函数,来进行tensor的拼接。

错误信息

代码案例如下:

import torch
import numpy as np

array1 = np.zeros((4, 1, 28, 28))
array2 = np.zeros((4, 1, 28, 28))
print("array1.shape:", array1.shape)      # (4, 1, 28, 28)

tensor1 = torch.tensor(array1)
tensor2 = torch.tensor(array2)
print("tensor1.shape", tensor1.shape)    # torch.Size([4, 1, 28, 28])

c = torch.cat(tensor1, tensor2, dim=0)    # 报错

这样执行会报错,错误信息如下:

TypeError: cat() received an invalid combination of arguments - got (Tensor, Tensor, dim=int), but expected one of:
 * (tuple of Tensors tensors, name dim, Tensor out)
      didn't match because some of the keywords were incorrect: dim
 * (tuple of Tensors tensors, int dim, Tensor out)
      didn't match because some of the keywords were incorrect: dim

解决方法

torch.cat() 函数进行tensor的拼接,将要拼接的tensor组合成元组,即可解决该报错1

代码修改如下:

import torch
import numpy as np

array1 = np.zeros((4, 1, 28, 28))
array2 = np.zeros((4, 1, 28, 28))
print("array1.shape:", array1.shape)      # (4, 1, 28, 28)

tensor1 = torch.tensor(array1)
tensor2 = torch.tensor(array2)
print("tensor1.shape", tensor1.shape)    # torch.Size([4, 1, 28, 28])

# c = torch.cat(tensor1, tensor2, dim=0) # 报错
c = torch.cat((tensor1, tensor2), dim=0)
print("c.shape", c.shape)
c1 = torch.cat((tensor1, tensor2), dim=1)
print("c1.shape", c1.shape)

此时输出为:

array1.shape: (4, 1, 28, 28)
tensor1.shape torch.Size([4, 1, 28, 28])
c.shape torch.Size([8, 1, 28, 28])
c1.shape torch.Size([4, 2, 28, 28])

从该案例中,不仅可以学到dubug信息,还可以详细了解到 torch.cat()函数的具体用法,比如参数dim的含义等等。


如果我的这篇文章帮助到了你,那我也会感到很高兴,一个人能走多远,在于与谁同行


参考


  1. 使用torch.cat()和torch.stack()时出现的小错误

原网站

版权声明
本文为[哈哈哈哈哈嗝哈哈哈]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_43051346/article/details/126124060