当前位置:网站首页>pytorch使用Dataloader加载自己的数据集train_X和train_Y
pytorch使用Dataloader加载自己的数据集train_X和train_Y
2022-08-10 18:23:00 【王延凯的博客】
Pytorch使用Dataloader加载自己的数据集train_X和train_Y
1.重构一个新的dataloader函数
# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
# 初始化函数,得到数据
def __init__(self, data_root, data_label):
self.data = data_root
self.label = data_label
# index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
def __getitem__(self, index):
data = self.data[index]
labels = self.label[index]
return data, labels
# 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
def __len__(self):
return len(self.data)
2.调用
train_data= GetLoader(train_X, train_Y)
train_loader=DataLoader(dataset=train_data,batch_size=50, shuffle=True, num_workers=0)
#这里只写了train_X和train_Y的,test_X和test_Y的类似
for data,labels in train_loader:
pass # 在这里就可以正常操作啦
边栏推荐
猜你喜欢
随机推荐
开发模式对测试的影响
三星Galaxy Watch5产品图片流出 非Pro表款亦有蓝宝石加持
CSV(Comma-Separate-Values)逗号分隔值文件
三坐标雷达显示软件 SPx Viewer-3D
钻石价格预测的ML全流程!从模型构建调优道部署应用!
Active users of mobile banking grew rapidly in June, hitting a half-year high
6-11 先序输出叶结点(15分)
去除富文本标签样式
Interview Question 04.12. Summation Path-dfs+Auxiliary Array Method
120Hz OLED拒绝“烧屏”!华硕无双全能轻薄本
【快应用】实现自定义导航栏组件
MySQL 原理与优化:Update 优化
2022-08-09 Study Notes day32-IO Stream
21天打卡挑战学习MySQL——《MySQL表管理》第二周 第五篇
MySql主要性能指标说明
flex使用align-content无效
FPGA工程师面试试题集锦81~90
Flexsim 发生器和暂存区设定临时实体流颜色和端口
eager模式和graph模式 Tensorflow
MySQL 查询出重复出现两次以上的数据 - having









