当前位置:网站首页>Fashion MNIST 数据集分类训练
Fashion MNIST 数据集分类训练
2022-04-23 02:36:00 【彩色海绵】
首先之前的文章记录了MNIST简单的分类玩法。目前这个以及没有挑战性,作为它的代替者Fashion MNIST 。我们当然要玩一玩了
它长这样,同样有10类。和MNIST分类图片大小张数都一样
实现代码:本人用jupyter。上篇博文发的安装的tf2.6.0环境下写的
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
#首先导入所需包
#加载数据集,这里可以自己引入,不需要下载,为了速度快,可以下载好几个文件,放到对应jupyter存放文件位置,本人放的看下面图,文件评论里有分享
(train_image, train_lable), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data()
#下面不是必须的,这个是看看数据集里内容
train_image.shape,train_lable.shape,test_image.shape, test_label.shape
plt.imshow(train_image[1])
#接下来归一化
train_image = train_image/255
test_image = test_image/255 #归一化
#创建模型,此部分可以根据自己增加层数等搞复杂点
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
# 28*28 拉平
model.add(tf.keras.layers.Dense(128, activation='relu'))
# 128个隐藏层单元个数,太大容易过拟合,太小容易遗漏信息
model.add(tf.keras.layers.Dense(10, activation='softmax'))
#输出层,10个概率
model.summary()
'''
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28))) # 28*28
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
'''
#最后选择优化器adam,损失函数
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['acc']
)
#开始训练,这个我用epochs=5。50的时候可以达到96%多
model.fit(train_image, train_lable, epochs=5)
数据集有60000张28*28的训练集,和10000张测试集
取出训练集的第二张看看是这个样子
模型内容这样:
加了隐藏层模型内容这样:
5个epochs时准去率acc=89.18%
50个epochs时准去率acc=96.76%
附加:完整干净的百分之96.75的简单的代码
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
(train_image, train_lable), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data()
train_image = train_image/255
test_image = test_image/255 #归一化
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28))) # 28*28 拉平
model.add(tf.keras.layers.Dense(128, activation='relu')) # 128个隐藏层单元个数
model.add(tf.keras.layers.Dense(10, activation='softmax'))#输出层,10个概率
model.summary()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['acc']
)
model.fit(train_image, train_lable, epochs=50)
版权声明
本文为[彩色海绵]所创,转载请带上原文链接,感谢
https://blog.csdn.net/m0_63172128/article/details/124340008
边栏推荐
- wordpress 调用指定页面内容详解2 get_children()
- Suggestion: block reference sorting is in the order of keywords
- 智能辅助功能丰富,思皓X6安全配置曝光:将于4月23日预售
- [xjtu Computer Network Security and Management] session 2 Cryptographic Technology
- Execute external SQL script in MySQL workbench and report error
- 【ValueError: math domain error】
- 高效音乐格式转换工具Music Converter Pro
- 都是做全屋智能的,Aqara和HomeKit到底有什么不同?
- Lane cross domain problem
- 想体验HomeKit智能家居?不如来看看这款智能生态
猜你喜欢
高效音乐格式转换工具Music Converter Pro
定了,今日起,本号粉丝可免费参与网易数据分析培训营!
JVM类加载器
基于Torchserve部署SBERT模型<语义相似度任务>
How to solve the complexity of project document management?
16、 Anomaly detection
If 404 page is like this | daily anecdotes
Efficient music format conversion tool Music Converter Pro
双亲委派模型【理解】
[XJTU computer network security and management] Lecture 2 password technology
随机推荐
全局、独享、局部路由守卫
JDBC JDBC
从0开始开发一个chrome插件(2)
How many steps are there from open source enthusiasts to Apache directors?
A domestic image segmentation project is heavy and open source!
Tp6 Alibaba Cloud SMS Window message Curl Error 60: SSL Certificate Problem: Unable to get local issuer Certificate
本地远程访问云服务器的jupyter
[nk]牛客月赛48 D
Suggestion: block reference sorting is in the order of keywords
leetcode 烹饪料理
IAR嵌入式開發STM32f103c8t6之點亮LED燈
They are all intelligent in the whole house. What's the difference between aqara and homekit?
LeetCode 349. Intersection of two arrays (simple, array) Day12
电源电路设计原来是这么回事
双亲委派模型【理解】
PTA: Romantic reflection [binary tree reconstruction] [depth first traversal]
使用Go语言构建Web服务器
PIP install shutil reports an error
Flink real-time data warehouse project - Design and implementation of DWS layer
Flink stream processing engine system learning (III)