当前位置:网站首页>基于Keras的时装分类案例
基于Keras的时装分类案例
2022-04-23 05:50:00 【Stephen_Tao】
Keras介绍
Keras是一个用Python编写的开源神经网络库。它能够运行在TensorFlow,Microsoft Cognitive Toolkit,Theano或PlaidML之上。
时装分类数据集介绍

该数据集包含70000张灰度图像,一共有10个类别。
步骤分析及代码实现
读取数据集
from tensorflow.python.keras.datasets import fashion_mnist
class SingleNN(object):
def __init__(self):
(self.train,self.train_label),(self.test,self.test_label) = fashion_mnist.load_data()
self.train = self.train / 255.0
self.test = self.test / 255.0
训练数据,测试数据的形状如下:
train: (60000, 28, 28)
train_label: (60000,)
test: (10000, 28, 28)
test_label: (10000,)
模型构建
模型结构:双层神经网络
- 隐藏层有128个神经元,激活函数选择relu
- 全连接层有10个神经元,因为fashion_mnist具有10个类别,激活函数选择softmax
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Flatten,Dense
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
class SingleNN(object):
model = Sequential([
Flatten(input_shape=(28,28)),
Dense(128,activation=tf.nn.relu),
Dense(10,activation=tf.nn.softmax)
])
编译定义优化过程
优化器选择Adam,损失函数选择交叉熵损失**(标签数据是整型数据,需要先转换为one-hot编码)**
from tensorflow.python.keras.optimizer_v1 import Adam
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
class SingleNN(object):
def compile(self):
SingleNN.model.compile(optimizer=Adam(),
loss=sparse_categorical_crossentropy,
metrics=['accuracy'])
定义训练函数
epochs设置为3次,batch_size设置为32
class SingleNN(object):
def fit(self):
SingleNN.model.fit(self.train,self.train_label,epochs=3,batch_size=32)
return None
定义评估函数
记录测试集的损失函数值以及准确率
class SingleNN(object):
def evaluate(self):
test_loss,test_acc = SingleNN.model.evaluate(self.test,self.test_label)
print("test_loss:",test_loss)
print("test_acc:",test_acc)
return None
开启会话运行图
with tf.compat.v1.Session() as sess:
cnn = SingleNN()
cnn.compile()
cnn.fit()
cnn.evaluate()
运行结果
Train on 60000 samples
Epoch 1/3
60000/60000 [==============================] - 5s 81us/sample - loss: 0.4989 - accuracy: 0.8242
Epoch 2/3
60000/60000 [==============================] - 5s 77us/sample - loss: 0.3749 - accuracy: 0.8635
Epoch 3/3
60000/60000 [==============================] - 5s 78us/sample - loss: 0.3373 - accuracy: 0.8766
test_loss: 0.3732113081932068
test_acc: 0.8629
版权声明
本文为[Stephen_Tao]所创,转载请带上原文链接,感谢
https://blog.csdn.net/professor_tao/article/details/119481196
边栏推荐
- For() loop parameter call order
- copy constructor
- SSH 公钥 私钥的理解
- 日志
- Opencv uses genericindex for KNN search
- [UDS unified diagnosis service] i. diagnosis overview (3) - ISO 15765 architecture
- 相机标定:关键点法 vs 直接法
- Rust:单元测试(cargo test )的时候显示 println 的输出信息
- Initialization of classes and objects (constructors and destructors)
- 静态成员
猜你喜欢
![[UDS unified diagnostic service] i. overview of diagnosis (4) - basic concepts and terms](/img/fb/3d3cf54dc5b67ce42d60e0fe63baa6.png)
[UDS unified diagnostic service] i. overview of diagnosis (4) - basic concepts and terms

【UDS统一诊断服务】四、诊断典型服务(2)— 数据传输功能单元

C#【文件操作篇】PDF文件和图片互相转换

Graduation project, viewing screenshots of epidemic psychological counseling system

拷贝构造函数
![[UDS unified diagnostic service] IV. typical diagnostic service (2) - data transmission function unit](/img/22/c501c79176a93345dc72ff150c53c3.png)
[UDS unified diagnostic service] IV. typical diagnostic service (2) - data transmission function unit

C语言的浪漫

PM2 deploy nuxt project

安装pyshp库

Call procedure of function
随机推荐
Basemap库绘制地图
Tabbar implementation of dynamic bottom navigation bar in uniapp, authority management
ArcGIS表转EXCEL超出上限转换失败
Class inheritance and derivation
使用TransmittableThreadLocal实现参数跨线程传递
Robocode教程8——AdvancedRobot
[UDS unified diagnostic service] III. application layer protocol (1)
[untitled]
爬西瓜视频url
客户端软件增量更新
C语言进阶要点笔记3
爬虫效率提升方法
大学概率论与数理统计知识点详细整理
grub boot. S code analysis
数组旋转
Jeu de devinettes
CUDA环境安装
Initialization of classes and objects (constructors and destructors)
C#【文件操作篇】按行读取txt文本
PM2 deploy nuxt project