当前位置:网站首页>torch.utils.data.DataLoader
torch.utils.data.DataLoader
2022-08-09 00:27:00 【沙小菜】
#设置数据增强方法
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#加载数据集的数据,返回所有样本的img和label
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
#对数据进行batch采样
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)1.加载数据集的是数据,返回所有样本的img和label
通过数据加载类完成这一操作
数据加载类包括三个函数:__init__()、__getitem__()、__len()__()
(1)__init__()
__init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
)返回所有样本的img和label
(2)__getitem__()
这个函数在进行epoch训练时才会运行,根据给出的index确定样本,并进行数据增强操作。
返回数据增强后的样本。
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target(3)__len()__()
返回数据的数量
def __len__(self) -> int:
return len(self.data)2.确定训练时的数据加载方式
torch.utils.data.DataLoader,结合了数据集和取样器,并且可以提供多个线程处理数据集。用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。
参数:
dataset:包含所有数据的数据集
batch_size :每一小组所包含数据的数量
Shuffle : 是否打乱数据位置,当为Ture时打乱数据,全部抛出数据后再次dataloader时重新打乱。
sampler : 自定义从数据集中采样的策略,如果制定了采样策略,shuffle则必须为False.
Batch_sampler:和sampler一样,但是每次返回一组的索引,和batch_size, shuffle, sampler, drop_last 互斥。
num_workers : 使用线程的数量,当为0时数据直接加载到主程序,默认为0。
collate_fn:不太了解
pin_memory:s 是一个布尔类型,为T时将会把数据在返回前先复制到CUDA的固定内存中
drop_last:布尔类型,为T时将会把最后不足batch_size的数据丢掉,为F将会把剩余的数据作为最后一小组。
timeout:默认为0。当为正数的时候,这个数值为时间上限,每次取一个batch超过这个值的时候会报错。此参数必须为正数。
worker_init_fn:和进程有关系,暂时用不到
torch.utils.data.DataLoader中有采样器、迭代器、__len__()。
边栏推荐
猜你喜欢

笔记&代码 | 统计学——基于R(第四版) 第九章一元线性回归

"Replay" interview BAMT came back to sort out 398 high-frequency interview questions to help you get a high salary offer

笔记&代码 | 统计学——基于R(第四版) 第二章数据可视化

整流七 - 三相PWM整流器—公式推导篇

《MySQL入门很轻松》第3章:数据库的创建与操作

控件限制总结

mysql 批量修改表及字段字符集

非科班毕业生,五面阿里:四轮技术面 +HR 一面已拿 offer

整流十四---直接功率控制策略

无代码平台邮箱入门教程
随机推荐
AutoX安途杯中山大学程序设计校赛(同步赛)
【学习-目标检测】目标检测之——YOLO v3
JSON基础,传递JSON数据,介绍jackson、gson、fastjson、json-lib四种主流框架!
Mysql Workbench uses .sql file to import data into database
[Deep Learning] TensorFlow Learning Road 2: Introduction to ANN and TensorFlow Implementation
GaN图腾柱无桥 Boost PFC(单相)二 (公式推到理解篇)
centos7 yum 安装最新版redis
Using MySQL in Ubuntu/Linux environment: Modify the database sql_mode to solve the "this is incompatible with sql_mode=only_full_group_by" problem
怎么重置mysql的自增列AUTO_INCREMENT初时值
数学模型建立常用方法
插值拟合——数据处理或预测
【 StoneDB Class 】 introductory lesson 3: StoneDB installation of compilation
A - A + B Problem II
全新Swagger3.0教程,OAS3快速配置指南,实现API接口文档自动化!
supervisor 安装、配置、常用命令
【C语言刷题】链表中快慢指针的应用
光照衰减-Lights
NodeJs连接mysql数据库
2021.10.7 2020 CCPC重现赛
「复盘」面试 BAMT 回来整理 398 道高频面试题,助你拿高薪 offer