当前位置:网站首页>pytorch使用Dataloader加载自己的数据集train_X和train_Y

pytorch使用Dataloader加载自己的数据集train_X和train_Y

2022-08-10 18:23:00 王延凯的博客

Pytorch使用Dataloader加载自己的数据集train_X和train_Y

1.重构一个新的dataloader函数

# 定义GetLoader类,继承Dataset方法,并重写__getitem__()和__len__()方法
class GetLoader(torch.utils.data.Dataset):
	# 初始化函数,得到数据
    def __init__(self, data_root, data_label):
        self.data = data_root
        self.label = data_label
    # index是根据batchsize划分数据后得到的索引,最后将data和对应的labels进行一起返回
    def __getitem__(self, index):
        data = self.data[index]
        labels = self.label[index]
        return data, labels
    # 该函数返回数据大小长度,目的是DataLoader方便划分,如果不知道大小,DataLoader会一脸懵逼
    def __len__(self):
        return len(self.data)

2.调用

    train_data= GetLoader(train_X, train_Y)
    train_loader=DataLoader(dataset=train_data,batch_size=50, shuffle=True, num_workers=0)
    #这里只写了train_X和train_Y的,test_X和test_Y的类似
    
    for data,labels in train_loader:
        pass	#	在这里就可以正常操作啦
原网站

版权声明
本文为[王延凯的博客]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_38468077/article/details/126036159