当前位置:网站首页>pytorch实现数据集读取/下载
pytorch实现数据集读取/下载
2022-08-08 21:00:00 【MarkerTm】
#pic_center =400x
系列文章:
数据集读取本地
Dataset基类介绍
在 torch.utils.data.Dataset 提供了数据集的基类,我们只需要继承这个基类重写里面的方法即可完成数据集的加载与读取。
重写两个方法:
- __len__方法,能够实现通过全局的len()方法获取其中的元素个数
- __getitem__方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据
数据加载案例
根据自己数据集的情况修改一下三个方法
- __init__方法可以用来设置读取数据集等初始化数据集的基本操作
- __getitem__方法通常用来根据索引来返回一条对应的数据内容
- __len__方法通常用来返回数据总数
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self):
self.lines = open("datasets/smsspamcollection/SMSSpamCollection", encoding="utf-8").readlines()
def __len__(self):
return len(self.lines)
def __getitem__(self, index):
# strip取消换行
# cur_line = self.lines[index].strip()
# label = cur_line[:4].strip()
# content = cur_line[4:].strip()
cur_line = self.lines[index]
label = cur_line[:4]
content = cur_line[4:]
return label, content
my_dataset = MyDataset()
print(my_dataset[0])
print(len(my_dataset))
本文使用的数据集为开源的本文分类数据集SMS Spam Collection Data Set,下载地址为https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection
数据集是从 Grumbletext 网站手动提取了 425 条 SMS 垃圾邮件的集合,由一个文本文件构成,其中每一行都是有一个类别和后面的原始消息构成。
DataLoader类的使用详解
DataLoader的主要作用是将Dataset处理后的数据集进行加载整合成batch用于后续训练,使用方法如下
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self):
self.lines = open("datasets/smsspamcollection/SMSSpamCollection", encoding="utf-8").readlines()
def __len__(self):
return len(self.lines)
def __getitem__(self, index):
# strip取消换行
# cur_line = self.lines[index].strip()
# label = cur_line[:4].strip()
# content = cur_line[4:].strip()
cur_line = self.lines[index]
label = cur_line[:4]
content = cur_line[4:]
return label, content
my_dataset = MyDataset()
if __name__ == '__main__':
data_loader = DataLoader(dataset=my_dataset, batch_size=5, shuffle=True, num_workers=1)
for i in data_loader:
print(i)
break
实际项目中通常使用enumerate方法 读取每一个batch内容的同时也返回 batch的索引
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self):
self.lines = open("datasets/smsspamcollection/SMSSpamCollection", encoding="utf-8").readlines()
def __len__(self):
return len(self.lines)
def __getitem__(self, index):
# strip取消换行
# cur_line = self.lines[index].strip()
# label = cur_line[:4].strip()
# content = cur_line[4:].strip()
cur_line = self.lines[index]
label = cur_line[:4]
content = cur_line[4:]
return label, content
my_dataset = MyDataset()
if __name__ == '__main__':
data_loader = DataLoader(dataset=my_dataset, batch_size=3, shuffle=True, num_workers=1)
for index, (label, content) in enumerate(data_loader):
print(index)
print(label)
print(content)
break
数据集在线下载
边栏推荐
猜你喜欢
【idea_取消自动import .*】
二分查找的坑
The new database is online | CnOpenData information transmission, software and information technology service industry basic information data of industrial and commercial registered enterprises
OneNote 教程,如何在 OneNote 中检查拼写?
GeoServer入门学习:01-开篇
神经网络论文Enhancing deep neural networks via multiple kernel learning
关于Mac终端自定义命令和Mysql命令问题
Everything原理探究以及C#实现
Flask 教程 第十二章:日期和时间
Swoole 健康检查
随机推荐
Kotlin delegate property knowledge points
关于KotlinAndroid遇到的小知识
Use fontforge to modify font, keep only numbers
Blazor PWA 单页应用身份认证机制示例
Bluu Seafood launches first lab-grown fish products
Kotlin's JSON format parsing
比较器? 如何使用比较器? 如何自定义比较器?
[highcharts application - double pie chart]
pm2安装配置与基本命令你知道吗?
去噪论文 Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising
内网渗透之代理转发
Centos下载安装redis- 使用yum
矩阵相乘
The first day of a solid foundation for Kotlin
charles简单使用
PHP传递任意数量的函数参数
Redis Bloom Filter
手机投影到deepin
The new database is online | CnOpenData information transmission, software and information technology service industry basic information data of industrial and commercial registered enterprises
澳洲ABM创新模式将销售代理权给到个体,让利消费者