当前位置:网站首页>治療TensorFlow後遺症——簡單例子記錄torch.utils.data.dataset.Dataset重寫時的圖片維度問題
治療TensorFlow後遺症——簡單例子記錄torch.utils.data.dataset.Dataset重寫時的圖片維度問題
2022-04-23 05:53:00 【umbrellalalalala】
torch大神請忽略此文。。。
1,一個簡單例子回顧DataSet
from torch.utils.data import Dataset
class dataset(Dataset):
def __init__(self):
# 需要轉化為array,不然運行結果會很奇怪
self.data = np.array([[1,1,1,1],
[2,2,2,2],
[3,3,3,3],
[4,4,4,4],
[5,5,5,5],
[6,6,6,6]])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# return data[idx],這句少了self,報錯...
return self.data[idx]
dataset = dataset()
dataloader = torch.utils.data.DataLoader(dataset,
batch_size=2,
shuffle=False,
num_workers=1)
for i, data_ in enumerate(dataloader):
print(i)
print(data_)
觀察運行結果:
0
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
1
tensor([[3, 3, 3, 3],
[4, 4, 4, 4]])
2
tensor([[5, 5, 5, 5],
[6, 6, 6, 6]])
這個例子足够理解DataSet了
2,維度
參考這篇文章:
https://blog.csdn.net/xddwz/article/details/108405817
# -*- coding: utf-8 -*-
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import transforms
import os
import cv2
from PIL import Image
class MyDataset(Dataset):
def __init__(self, transform=None):
self.transform = transforms.Compose([
transforms.ToTensor() # 這裏僅以最基本的為例
])
self.image_path = './image_data2/'
self.image_names = os.listdir(self.image_path)
def __len__(self):
return len(self.image_names)
def __getitem__(self, item):
image_name = self.image_names[item]
image = cv2.imread(os.path.join(self.image_path, image_name)) # 讀到的是BGR數據
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 轉化為RGB,也可以用img = img[:, :, (2, 1, 0)]
# 這時的image是H,W,C的順序,因此下面需要轉化為C, H, W
image = torch.from_numpy(image).permute(2, 0, 1)
# image = Image.open(os.path.join(self.image_path, image_name))
# # print(image.shape)
# image = self.transform(image)
return image
這段代碼是上述鏈接中的,搬運過來的原因是强調一個事情:我們知道torch的維度順序是BCHW,而上述代碼中的__getitem__()
,是返回一張圖片,那麼這個時候我們需要注意的是,單張圖片本來的維度順序是HWC,即維度是(height, width, channel),我們需要將它的維度調整為(channel, height, width),然後再返回。
同時,根據上述代碼,也需要注意,返回的image的shape不是(1, channel, height, width),而是(channel, height, width),batch對應的那一維在__getitem__()
中不需要考慮。
對於一些channel為1的圖片,如果需要增加channel維度,那麼只需要squeeze(0)就行了。
版权声明
本文为[umbrellalalalala]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/04/202204230543450851.html
边栏推荐
- 多线程与高并发(1)——线程的基本知识(实现,常用方法,状态)
- JDBC操作事务
- 无监督去噪——[TMI2022]ISCL: Interdependent Self-Cooperative Learning for Unpaired Image Denoising
- SQL基础:初识数据库与SQL-安装与基本介绍等—阿里云天池
- JDBC连接数据库
- 域内用户访问域外samba服务器用户名密码错误
- Typescript interface & type rough understanding
- PreparedStatement防止SQL注入
- DWSurvey是一个开源的调查问卷系统。解决无法运行问题,修改bug。
- Record a project experience and technologies encountered in the project
猜你喜欢
字符串(String)笔记
SQL注入
Dva中在effects中获取state的值
Navicate连接oracle(11g)时ORA:28547 Connection to server failed probable Oeacle Net admin error
图解HashCode存在的意义
Pytorch learning record (XI): data enhancement, torchvision Explanation of various functions of transforms
Pytorch学习记录(十一):数据增强、torchvision.transforms各函数讲解
JVM family (4) -- memory overflow (OOM)
Understand the current commonly used encryption technology system (symmetric, asymmetric, information abstract, digital signature, digital certificate, public key system)
filebrowser实现私有网盘
随机推荐
rsync实现文件服务器备份
线程的底部实现原理—静态代理模式
基于ssm 包包商城系统
delete和truncate
Write your own redistemplate
sklearn之 Gaussian Processes
编写一个自己的 RedisTemplate
The role of brackets' [] 'in the parameter writing method in MDN documents
jdbc入门\获取数据库连接\使用PreparedStatement
尚硅谷 p290 多态性练习
SQL基础:初识数据库与SQL-安装与基本介绍等—阿里云天池
mysql如何将存储的秒转换为日期
poi生成excel,插入图片
基于thymeleaf实现数据库图片展示到浏览器表格
异常的处理:抓抛模型
Get the value of state in effects in DVA
第36期《AtCoder Beginner Contest 248 打比赛总结》
Pyemd installation and simple use
Deconstruction function of ES6
Pyqy5 learning (4): qabstractbutton + qradiobutton + qcheckbox