当前位置:网站首页>治疗TensorFlow后遗症——简单例子记录torch.utils.data.dataset.Dataset重写时的图片维度问题
治疗TensorFlow后遗症——简单例子记录torch.utils.data.dataset.Dataset重写时的图片维度问题
2022-04-23 05:44:00 【umbrellalalalala】
torch大神请忽略此文。。。
1,一个简单例子回顾DataSet
from torch.utils.data import Dataset
class dataset(Dataset):
def __init__(self):
# 需要转化为array,不然运行结果会很奇怪
self.data = np.array([[1,1,1,1],
[2,2,2,2],
[3,3,3,3],
[4,4,4,4],
[5,5,5,5],
[6,6,6,6]])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# return data[idx],这句少了self,报错...
return self.data[idx]
dataset = dataset()
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=2,
shuffle=False,
num_workers=1)
for i, data_ in enumerate(dataloader):
print(i)
print(data_)
观察运行结果:
0
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
1
tensor([[3, 3, 3, 3],
[4, 4, 4, 4]])
2
tensor([[5, 5, 5, 5],
[6, 6, 6, 6]])
这个例子足够理解DataSet了
2,维度
参考这篇文章:
https://blog.csdn.net/xddwz/article/details/108405817
# -*- coding: utf-8 -*-
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
import os
import cv2
from PIL import Image
class MyDataset(Dataset):
def __init__(self, transform=None):
self.transform = transforms.Compose([
transforms.ToTensor() # 这里仅以最基本的为例
])
self.image_path = './image_data2/'
self.image_names = os.listdir(self.image_path)
def __len__(self):
return len(self.image_names)
def __getitem__(self, item):
image_name = self.image_names[item]
image = cv2.imread(os.path.join(self.image_path, image_name)) # 读到的是BGR数据
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 转化为RGB,也可以用img = img[:, :, (2, 1, 0)]
# 这时的image是H,W,C的顺序,因此下面需要转化为C, H, W
image = torch.from_numpy(image).permute(2, 0, 1)
# image = Image.open(os.path.join(self.image_path, image_name))
# # print(image.shape)
# image = self.transform(image)
return image
这段代码是上述链接中的,搬运过来的原因是强调一个事情:我们知道torch的维度顺序是BCHW,而上述代码中的__getitem__(),是返回一张图片,那么这个时候我们需要注意的是,单张图片本来的维度顺序是HWC,即维度是(height, width, channel),我们需要将它的维度调整为(channel, height, width),然后再返回。
同时,根据上述代码,也需要注意,返回的image的shape不是(1, channel, height, width),而是(channel, height, width),batch对应的那一维在__getitem__()中不需要考虑。
对于一些channel为1的图片,如果需要增加channel维度,那么只需要squeeze(0)就行了。
版权声明
本文为[umbrellalalalala]所创,转载请带上原文链接,感谢
https://blog.csdn.net/umbrellalalalala/article/details/124330845
边栏推荐
- Pytorch learning record (IX): convolutional neural network in pytorch
- 对象转map
- POI generates excel and inserts pictures
- Package mall system based on SSM
- MySQL triggers, stored procedures, stored functions
- XXL job pit guide XXL RPC remoting error (connect timed out)
- rsync实现文件服务器备份
- 给yarn配置国内镜像加速器
- 多线程与高并发(3)——synchronized原理
- PHP处理json_decode()解析JSON.stringify
猜你喜欢
随机推荐
Pytorch學習記錄(十三):循環神經網絡((Recurrent Neural Network)
SQL注入
Idea plug-in --- playing songs in the background
mysql实现主从复制/主从同步
实体中list属性为空或者null,设置为空数组
MySQL query uses \ g, column to row
Hotkeys, interface visualization configuration (interface interaction)
opensips(1)——安装opensips详细流程
Pytorch学习记录(十三):循环神经网络((Recurrent Neural Network)
容器
MySQL triggers, stored procedures, stored functions
JDBC连接数据库
JDBC工具类封装
Excel obtains the difference data of two columns of data
refused connection
ES6之解构函数
Record a project experience and technologies encountered in the project
SQL statement simple optimization
Total score of [Huawei machine test] (how to deal with the wrong answer? Go back once to represent one wrong answer)
框架解析1.系统架构简介









