当前位置:网站首页>PyTorch框架的 torch.cat()函数
PyTorch框架的 torch.cat()函数
2022-08-09 18:52:00 【哈哈哈哈哈嗝哈哈哈】
前言
搭建深度神经网络模型时,难免会遇到 torch.cat()函数,来进行tensor的拼接。
错误信息
代码案例如下:
import torch
import numpy as np
array1 = np.zeros((4, 1, 28, 28))
array2 = np.zeros((4, 1, 28, 28))
print("array1.shape:", array1.shape) # (4, 1, 28, 28)
tensor1 = torch.tensor(array1)
tensor2 = torch.tensor(array2)
print("tensor1.shape", tensor1.shape) # torch.Size([4, 1, 28, 28])
c = torch.cat(tensor1, tensor2, dim=0) # 报错
这样执行会报错,错误信息如下:
TypeError: cat() received an invalid combination of arguments - got (Tensor, Tensor, dim=int), but expected one of:
* (tuple of Tensors tensors, name dim, Tensor out)
didn't match because some of the keywords were incorrect: dim
* (tuple of Tensors tensors, int dim, Tensor out)
didn't match because some of the keywords were incorrect: dim
解决方法
torch.cat() 函数进行tensor的拼接,将要拼接的tensor组合成元组,即可解决该报错1。
代码修改如下:
import torch
import numpy as np
array1 = np.zeros((4, 1, 28, 28))
array2 = np.zeros((4, 1, 28, 28))
print("array1.shape:", array1.shape) # (4, 1, 28, 28)
tensor1 = torch.tensor(array1)
tensor2 = torch.tensor(array2)
print("tensor1.shape", tensor1.shape) # torch.Size([4, 1, 28, 28])
# c = torch.cat(tensor1, tensor2, dim=0) # 报错
c = torch.cat((tensor1, tensor2), dim=0)
print("c.shape", c.shape)
c1 = torch.cat((tensor1, tensor2), dim=1)
print("c1.shape", c1.shape)
此时输出为:
array1.shape: (4, 1, 28, 28)
tensor1.shape torch.Size([4, 1, 28, 28])
c.shape torch.Size([8, 1, 28, 28])
c1.shape torch.Size([4, 2, 28, 28])
从该案例中,不仅可以学到dubug信息,还可以详细了解到 torch.cat()函数的具体用法,比如参数dim的含义等等。
如果我的这篇文章帮助到了你,那我也会感到很高兴,一个人能走多远,在于与谁同行。
参考
边栏推荐
猜你喜欢

数据分散情况的统计图-盒须图

看完这波 Android 面试题;助你斩获心中 offer

加工制造业智慧采购系统解决方案:助力企业实现全流程采购一体化协同

Openharmony Lightweight System Experiment--GPIO Lighting

Toronto Research Chemicals单羟基舒更葡糖钠说明书

《评估、创建和使用知识图谱的限制》2022最新230页博士论文,根特大学
![[免费专栏] Android安全之APK动态方式逆向应用【三种Smali注入方法】](/img/11/39a25d86c9486bb5201659bbbeaa36.png)
[免费专栏] Android安全之APK动态方式逆向应用【三种Smali注入方法】

基于CC2530 E18-MS1-PCB Zigbee DIY作品(三)

uniapp离线推送华为厂商申请流程

真香|持一建证书央企可破格录取
随机推荐
真香|持一建证书央企可破格录取
为什么maxcompute的数据导入到mysql会乱码?mysql的表是udf8mb4的编码
[免费专栏] Android安全之Android Fragment注入
Fully automated machine learning modeling!The effect hangs the primary alchemist!
shell之变量详解,让你秒懂!
[免费专栏] Android安全之数据存储与数据安全【大集合】
Office 365 Group概述以及创建方法
DP-Differential Privacy概念介绍
Bi Sheng Compiler Optimization: Lazy Code Motion
Toronto Research Chemicals盐酸乙环胺应用说明
基于设计稿识别的可视化低代码系统实践
加工制造业智慧采购系统解决方案:助力企业实现全流程采购一体化协同
competed中访问ref为undefined
AWS CodePipeLine deploys ECS across accounts
以技术创新加速国家“碳中和”建设进程,华为云创新中心助力欣冠精密实现云智控“气”
小满nestjs(第三章 前置知识装饰器)
2022深圳(软考中级)系统集成项目管理工程师报名
IDEA快捷代码实时模板
切绳子【洛谷P1577】【二分】
基于Web的疫情隔离区订餐系统