当前位置:网站首页>深度学习——怎样读取本地数据作为训练集和测试集
深度学习——怎样读取本地数据作为训练集和测试集
2022-08-06 08:01:00 【尘心平】
活动地址:CSDN21天学习挑战赛
目录
利用pathlib库检测数据集(如果数据集路径正确,这一步可以不用)
利用image_dataset_from_directory方法导入数据集
image_dataset_from_directory方法的介绍(参数、返回值、总结):
在机器学习的实践中,我们通常会遇到一些用到提前通过网络下载到本地计算机的数据集,而之前博客中的实例,都是利用tf.keras.datasets包中的现成的数据集直接load_data()
# 导入服装图像数据集(此外还有mnist,cifar10等数据集) (train_images, train_labels), ( test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data() # datasets内部集成了数据集,一些tensorflow内置数据集链接
常用数据集 Datasets - Keras 中文文档
本篇博客将介绍怎么导入这些下载到本地的数据集
图片识别类数据集导入
利用pathlib库检测数据集(如果数据集路径正确,这一步可以不用)
首先利用pathlib库,计算数据集总数,以及查看某数据集图片,确保数据集路径的正确
pathlib库将在下篇博客介绍,这里显示图片是用PIL库内的Image类
PS:这里数据集用的是 花朵数据集
右键url内容,转到链接,即可开始下载
本地数据集的目录结构
首先我们来看一下数据集的目录结构:
图像识别类数据集一般都有若按个子目录,子目录数即为CNN神经网络最后Dense层的输出维度,每个子目录下即有若干张图片,子目录名即为标签名。训练集,测试集混合
pathlib库检测数据集
# 设置数据集路径(这个路径是分类文件夹的上级目录) data_url = "E:/Download/flower_photos/flower_photos" # 数据集路径 注意\换成/,不然读不到数据 data_url = pathlib.Path(data_url) # 生成Path对象 image_count = len(list(data_url.glob("*/*.jpg"))) # 图片总数为: 3670 # glob在此路径表示的目录中下全局查找给定的相对pattern,生成所有匹配的文件列表 *是递归查找的意思 print("图片总数为:", image_count) # 获取数据集中的图片总数 roses = list(data_url.glob("daisy/*.jpg")) Image.open(roses[0]).show() # 调取默认的图片打开应用,打开open路径指定的图片调用Windows图片显示软件,打开roses[0]路径所表示的图片
十分重要的是:数据集路径data_url 中的 \ 要替换成 /
利用image_dataset_from_directory方法导入数据集
image_dataset_from_directory方法的介绍(参数、返回值、总结):
这里我使用的是 tensorflow 2.8.0 keras 2.8.0 (不同版本该方法所在的包有不同)
在我的使用版本中,该方法是在 keras.preprocessing,image 包中
即方法全名为:keras.preprocessing.image.image_dataset_from_directory
def image_dataset_from_directory(directory: Any, labels: str = 'inferred', label_mode: str = 'int', class_names: Any = None, color_mode: str = 'rgb', batch_size: int = 32, image_size: tuple[int, int] = (256, 256), shuffle: bool = True, seed: Any = None, validation_split: {__mul__} = None, subset: {__eq__} = None, interpolation: str = 'bilinear', follow_links: bool = False, crop_to_aspect_ratio: bool = False, **kwargs: Any) -> Any
1. 参数说明:
- directory :Any类型,这里directory可以传字符串数据集路径(注意 路径中是/ ),也可以传pathlib的Path对象
- labels:str类型,默认值是inferred,表示标签从目录结构中生成,子目录按照字母顺序从0开始编号
- label_mode:str类型,默认值是int,指label的索引值是0~class_names.length-1 label_mode的值类型还可以是 categorica,binary(索引要么是0要么是1)
- class_names:Any类型,仅当labels值为inferred时有效,存储实际标签名(按子目录字母顺序排序一一对应)的列表或元组
- color_mode:str类型,默认值是rgb(3通道),还可以是grayscale(1通道),rgba(4通道)
- batch_size:int类型,默认值是32,一次取样的大小
- image_size:int两元组,默认值是(256,256)
- shuffle:bool类型,默认值True,表示打乱数据
- seed:随机数种子,int型数即可(例如123,一般验证集训练集两个函数的seed要相同)。如果使用validation_split和shuffle,则必须提供一个seed参数,确保train和validation子集之间没有重叠。
- validation_split:数据集中验证集的比例,double型,0~1
- subset:str类型,有"training"和"validation"两种取值,表示函数返回的是训练集还是验证集
- interpolation:str类型,指定图像大小调整时的插值方法,默认值是bilinear(双线性插值),此外还有nearest(最邻近插值),bucubic(双三次插值),lanczos3,lanczos5。
nearest:这个算法的优点是计算简单方便,缺点是图像容易出现锯齿。
bucubic:效果最好,但耗时
bilinear与lanczos用时相仿,但lanczos效果更好
2. 返回值介绍:
函数返回 tensorflow.data.Dataset 类型的对象
Dataset对象可以看成一个元组列表,每个元组有batch_size个image和bacth_size个label,image和label一一对应,而列表有总大小/batch_size个
这里介绍我们使用该返回值的常用方法:
- 直接作为训练集和测试集使用,内部封装了(images,labels)
history = model.fit(train_data, epochs=6, validation_data=test_data) # 预测 pre = model.predict(test_data) # 对所有测试图片进行预测
- 可以通过take(数字n)按顺序获取Dataset中n个batch_size的元组
通过take得到的image数据是张量,显示图片时候需要先归一化 numpy()/255.0
for train_image, train_label in train_data.take(1): for i in range(len(train_image)): plt.subplot(5, 10, i + 1) plt.xticks([]) plt.yticks([]) plt.grid(False) plt.imshow(train_image[i].numpy() / 255., cmap=plt.cm.binary) plt.xlabel(class_names[train_label[i]]) plt.show()
- 可以通过as_numpy_iterator()返回Dataset中所有的元组数据
该方法得到的image数据就是numpy数字,直接 /255.0 归一化即可
for test_image, test_label in test_data.as_numpy_iterator(): for i in range(len(test_image)): plt.subplot(5, 10, i + 1) plt.xticks([]) plt.yticks([]) plt.grid(False) plt.imshow(test_image[i]/255.0, cmap=plt.cm.binary) plt.xlabel(class_names[test_label[i]]) plt.show() break
3. 总结:
一般使用该函数时,如果没有更高的要求,只需要指定 directory、class_names、color_mode、image_size、subset、validation_split、seed即可
(interpolation可以指定为lanczos3,batch_size也可以指定为其他2的幂)
代码示例:
batch_size = 32 image_width = 256 image_height = 256 class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] train_data = keras.preprocessing.image.image_dataset_from_directory( directory="E:/Download/flower_photos/flower_photos", class_names=class_names, color_mode='rgb', batch_size=batch_size, image_size=(image_width, image_height), subset='training', validation_split=0.2, seed=123 ) test_data = keras.preprocessing.image.image_dataset_from_directory( directory=data_url, # 之前的data_url class_names=class_names, color_mode='rgb', batch_size=batch_size, image_size=(image_width, image_height), subset='validation', validation_split=0.2, seed=123 ) # train_data中包含了train_image和train_label # test_data中包含了test_image和test_label
Dataset的额外处理
cache()函数
使用cache函数可以提高程序性能,主要是通过在第一次epoch迭代时,将数据集加入缓存中,后续epoch迭代从缓存中读取即可
使用注意事项:
- 必须完整的输入迭代的数据集,如果有缺少的数据集,那么该缺少的数据集不会产生缓存数据
- 由于cache()每次迭代生成的数据相同,所以后面必须接shuffle函数
shuffle(buffer_size=)函数
shuffle(buffer_size)的主要功能是扰乱数据
主要实现机理是:
prefetch(buffer_size=)函数
prefetch()函数用于生成一个数据集,该数据集从该给定数据集中预取指定数量buffer_size个元素。
通常buffer_size参数取 tensorflow.data.AUTOTUNE (值是-1)该参数是表示prefetch函数自动选择合适的buffer_size
代码:
train_data = train_data.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE) # 自动寻找合适的 buffer_size test_data = test_data.cache().prefetch(buffer_size=tf.data.AUTOTUNE) # train_data是训练集,训练需要保证数据集的无关性,所以需要扰乱 # 测试集不用于训练,且测试数据有没有顺序都不要紧,故不用扰乱
边栏推荐
- selenium4.0以上元素被定位
- Parameter ‘courseId’ not found. Available parameters are [arg1, arg0, param1, para
- 猴子都能上手的Unity插件Photon之重要部分(PUN)
- 按钮只能点击一次
- DescrTab2包,输出SCI级别的描述统计表
- How much is a code signing certificate?
- QianBase Operation and Maintenance Practical Commands
- Parameter ‘courseId‘ not found. Available parameters are [arg1, arg0, param1, para
- 【手机】手机选购指南
- 【云原生--Kubernetes】配置管理
猜你喜欢
![[Cloud Native--Kubernetes] Configuration Management](/img/ef/37732ff3ec1b3609ba53d876ad8178.png)
[Cloud Native--Kubernetes] Configuration Management

明日立秋 autumn begins,天气渐凉

Jetpack WorkManager is enough to read this article~

CSDN官方插件

使用Specification与Example方式实现动态条件查询案例

在Windows上安装Go语言开发包

Process finished with exit code -1073740791 (0xC0000409)

Day020 方法重写与多态

Jetpack WorkManager 看这一篇就够了~

I set the global mapping table prefix in yml, but the database does not recognize it
随机推荐
BuuWeb
Day020 Method Overriding and Polymorphism
QianBase Operation and Maintenance Practical Commands
[ CTF ]【天格】战队WriteUp-第六届“强网杯”全国安全挑战赛(初赛)
山石发声 | 做好安全运营,没有你想象的那么难
凹语言——名字的由来和寓意
动手学深度学习_Batch Normalization
亿纬首件大圆柱电池系统产品成功下线
How to improve the quality of articles without being "recommended and affected" by the post assistant
2022 Hailiang SC Travel Notes
如何提高文章质量,不被发文助手“推荐受影响”
pacman包 管理各种R包
剑指 Offer 15. 二进制中1的个数,位运算,与运算
模拟实现strcpy函数的实现(含多次优化思想)
【云原生--Kubernetes】配置管理
qwqのtechnology flag
DemographicTable 新的基线特征表绘制 R包
Cesium关于Entity中的parent、isShowing、entityCollection和监听事件的探讨
剑指 Offer 39. 数组中出现次数超过一半的数字
测试用例设计方法-场景法详解



