当前位置:网站首页>pytorch实现数据集读取/下载
pytorch实现数据集读取/下载
2022-08-08 21:00:00 【MarkerTm】
#pic_center =400x
系列文章:
数据集读取本地
Dataset基类介绍
在 torch.utils.data.Dataset 提供了数据集的基类,我们只需要继承这个基类重写里面的方法即可完成数据集的加载与读取。
重写两个方法:
- __len__方法,能够实现通过全局的len()方法获取其中的元素个数
- __getitem__方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据
数据加载案例
根据自己数据集的情况修改一下三个方法
- __init__方法可以用来设置读取数据集等初始化数据集的基本操作
- __getitem__方法通常用来根据索引来返回一条对应的数据内容
- __len__方法通常用来返回数据总数
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self):
self.lines = open("datasets/smsspamcollection/SMSSpamCollection", encoding="utf-8").readlines()
def __len__(self):
return len(self.lines)
def __getitem__(self, index):
# strip取消换行
# cur_line = self.lines[index].strip()
# label = cur_line[:4].strip()
# content = cur_line[4:].strip()
cur_line = self.lines[index]
label = cur_line[:4]
content = cur_line[4:]
return label, content
my_dataset = MyDataset()
print(my_dataset[0])
print(len(my_dataset))


本文使用的数据集为开源的本文分类数据集SMS Spam Collection Data Set,下载地址为https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
数据集是从 Grumbletext 网站手动提取了 425 条 SMS 垃圾邮件的集合,由一个文本文件构成,其中每一行都是有一个类别和后面的原始消息构成。
DataLoader类的使用详解
DataLoader的主要作用是将Dataset处理后的数据集进行加载整合成batch用于后续训练,使用方法如下
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self):
self.lines = open("datasets/smsspamcollection/SMSSpamCollection", encoding="utf-8").readlines()
def __len__(self):
return len(self.lines)
def __getitem__(self, index):
# strip取消换行
# cur_line = self.lines[index].strip()
# label = cur_line[:4].strip()
# content = cur_line[4:].strip()
cur_line = self.lines[index]
label = cur_line[:4]
content = cur_line[4:]
return label, content
my_dataset = MyDataset()
if __name__ == '__main__':
data_loader = DataLoader(dataset=my_dataset, batch_size=5, shuffle=True, num_workers=1)
for i in data_loader:
print(i)
break

实际项目中通常使用enumerate方法 读取每一个batch内容的同时也返回 batch的索引
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self):
self.lines = open("datasets/smsspamcollection/SMSSpamCollection", encoding="utf-8").readlines()
def __len__(self):
return len(self.lines)
def __getitem__(self, index):
# strip取消换行
# cur_line = self.lines[index].strip()
# label = cur_line[:4].strip()
# content = cur_line[4:].strip()
cur_line = self.lines[index]
label = cur_line[:4]
content = cur_line[4:]
return label, content
my_dataset = MyDataset()
if __name__ == '__main__':
data_loader = DataLoader(dataset=my_dataset, batch_size=3, shuffle=True, num_workers=1)
for index, (label, content) in enumerate(data_loader):
print(index)
print(label)
print(content)
break
数据集在线下载
边栏推荐
猜你喜欢
随机推荐
【读代码重构有感】
Kotlin-学习的第五天之Handler
【highcharts应用-双饼图】
rancher坑记录
编译原理——逆波兰式分析程序(C#)
语义分割FCN FPN UNet DeepLab HRNet SETR TransFuse...
amd和Intel的CPU到底哪个好?
究竟什么才是“云计算” | 科普好文
go基于泛型实现继承
使用fontforge修改字体,只保留数字
解决gradle导包速度慢问题
numpy
关于Mac终端自定义命令和Mysql命令问题
fastapi-实战-综述
The new library online | CnOpenDataA shares of the listed company basic information data
手机投影到deepin
Kotlin注解
Kotlin reflection
Flask 教程 第四章:数据库
360杜跃进ISC演讲:保障信创软件的可信性和安全性是信创安全体系的基础







