当前位置:网站首页>治疗TensorFlow后遗症——简单例子记录torch.utils.data.dataset.Dataset重写时的图片维度问题
治疗TensorFlow后遗症——简单例子记录torch.utils.data.dataset.Dataset重写时的图片维度问题
2022-04-23 05:44:00 【umbrellalalalala】
torch大神请忽略此文。。。
1,一个简单例子回顾DataSet
from torch.utils.data import Dataset
class dataset(Dataset):
def __init__(self):
# 需要转化为array,不然运行结果会很奇怪
self.data = np.array([[1,1,1,1],
[2,2,2,2],
[3,3,3,3],
[4,4,4,4],
[5,5,5,5],
[6,6,6,6]])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# return data[idx],这句少了self,报错...
return self.data[idx]
dataset = dataset()
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=2,
shuffle=False,
num_workers=1)
for i, data_ in enumerate(dataloader):
print(i)
print(data_)
观察运行结果:
0
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
1
tensor([[3, 3, 3, 3],
[4, 4, 4, 4]])
2
tensor([[5, 5, 5, 5],
[6, 6, 6, 6]])
这个例子足够理解DataSet了
2,维度
参考这篇文章:
https://blog.csdn.net/xddwz/article/details/108405817
# -*- coding: utf-8 -*-
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
import os
import cv2
from PIL import Image
class MyDataset(Dataset):
def __init__(self, transform=None):
self.transform = transforms.Compose([
transforms.ToTensor() # 这里仅以最基本的为例
])
self.image_path = './image_data2/'
self.image_names = os.listdir(self.image_path)
def __len__(self):
return len(self.image_names)
def __getitem__(self, item):
image_name = self.image_names[item]
image = cv2.imread(os.path.join(self.image_path, image_name)) # 读到的是BGR数据
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转化为RGB,也可以用img = img[:, :, (2, 1, 0)]
# 这时的image是H,W,C的顺序,因此下面需要转化为C, H, W
image = torch.from_numpy(image).permute(2, 0, 1)
# image = Image.open(os.path.join(self.image_path, image_name))
# # print(image.shape)
# image = self.transform(image)
return image
这段代码是上述链接中的,搬运过来的原因是强调一个事情:我们知道torch的维度顺序是BCHW,而上述代码中的__getitem__()
,是返回一张图片,那么这个时候我们需要注意的是,单张图片本来的维度顺序是HWC,即维度是(height, width, channel),我们需要将它的维度调整为(channel, height, width),然后再返回。
同时,根据上述代码,也需要注意,返回的image的shape不是(1, channel, height, width),而是(channel, height, width),batch对应的那一维在__getitem__()
中不需要考虑。
对于一些channel为1的图片,如果需要增加channel维度,那么只需要squeeze(0)就行了。
版权声明
本文为[umbrellalalalala]所创,转载请带上原文链接,感谢
https://blog.csdn.net/umbrellalalalala/article/details/124330845
边栏推荐
- 2 - principes de conception de logiciels
- Pytorch学习记录(十三):循环神经网络((Recurrent Neural Network)
- refused connection
- RedHat6之smb服务访问速度慢解决办法记录
- JDBC工具类封装
- umi官网yarn create @umijs/umi-app 报错:文件名、目录名或卷标语法不正确
- io. lettuce. core. RedisCommandExecutionException: ERR wrong number of arguments for ‘auth‘ command
- rsync实现文件服务器备份
- 对象转map
- filebrowser实现私有网盘
猜你喜欢
建表到页面完整实例演示—联表查询
Pilotage growth · ingenuity empowerment -- yonmaster developer training and pilotage plan is fully launched
Dva中在effects中获取state的值
PreparedStatement防止SQL注入
关于二叉树的遍历
Pytorch学习记录(十一):数据增强、torchvision.transforms各函数讲解
Getting started with JDBC \ getting a database connection \ using Preparedstatement
Pytorch学习记录(十三):循环神经网络((Recurrent Neural Network)
第36期《AtCoder Beginner Contest 248 打比赛总结》
JVM series (3) -- memory allocation and recycling strategy
随机推荐
PyEMD安装及简单使用
框架解析1.系统架构简介
Pytorch学习记录(三):神经网络的结构+使用Sequential、Module定义模型
数字图像处理基础(冈萨雷斯)二:灰度变换与空间滤波
Dva中在effects中获取state的值
多线程与高并发(3)——synchronized原理
Pilotage growth · ingenuity empowerment -- yonmaster developer training and pilotage plan is fully launched
interviewter:介绍一下MySQL日期函数
poi生成excel,插入图片
类的加载与ClassLoader的理解
字符串(String)笔记
Typescript interface & type rough understanding
excel获取两列数据的差异数据
Latex快速入门
SQL statement simple optimization
Shansi Valley P290 polymorphism exercise
RedHat6之smb服务访问速度慢解决办法记录
Pytorch學習記錄(十三):循環神經網絡((Recurrent Neural Network)
自定义异常类
EditorConfig