当前位置:网站首页>基于Keras的时装分类案例
基于Keras的时装分类案例
2022-04-23 05:50:00 【Stephen_Tao】
Keras介绍
Keras是一个用Python编写的开源神经网络库。它能够运行在TensorFlow,Microsoft Cognitive Toolkit,Theano或PlaidML之上。
时装分类数据集介绍
该数据集包含70000张灰度图像,一共有10个类别。
步骤分析及代码实现
读取数据集
from tensorflow.python.keras.datasets import fashion_mnist
class SingleNN(object):
def __init__(self):
(self.train,self.train_label),(self.test,self.test_label) = fashion_mnist.load_data()
self.train = self.train / 255.0
self.test = self.test / 255.0
训练数据,测试数据的形状如下:
train: (60000, 28, 28)
train_label: (60000,)
test: (10000, 28, 28)
test_label: (10000,)
模型构建
模型结构:双层神经网络
- 隐藏层有128个神经元,激活函数选择relu
- 全连接层有10个神经元,因为fashion_mnist具有10个类别,激活函数选择softmax
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Flatten,Dense
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
class SingleNN(object):
model = Sequential([
Flatten(input_shape=(28,28)),
Dense(128,activation=tf.nn.relu),
Dense(10,activation=tf.nn.softmax)
])
编译定义优化过程
优化器选择Adam,损失函数选择交叉熵损失**(标签数据是整型数据,需要先转换为one-hot编码)**
from tensorflow.python.keras.optimizer_v1 import Adam
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
class SingleNN(object):
def compile(self):
SingleNN.model.compile(optimizer=Adam(),
loss=sparse_categorical_crossentropy,
metrics=['accuracy'])
定义训练函数
epochs设置为3次,batch_size设置为32
class SingleNN(object):
def fit(self):
SingleNN.model.fit(self.train,self.train_label,epochs=3,batch_size=32)
return None
定义评估函数
记录测试集的损失函数值以及准确率
class SingleNN(object):
def evaluate(self):
test_loss,test_acc = SingleNN.model.evaluate(self.test,self.test_label)
print("test_loss:",test_loss)
print("test_acc:",test_acc)
return None
开启会话运行图
with tf.compat.v1.Session() as sess:
cnn = SingleNN()
cnn.compile()
cnn.fit()
cnn.evaluate()
运行结果
Train on 60000 samples
Epoch 1/3
60000/60000 [==============================] - 5s 81us/sample - loss: 0.4989 - accuracy: 0.8242
Epoch 2/3
60000/60000 [==============================] - 5s 77us/sample - loss: 0.3749 - accuracy: 0.8635
Epoch 3/3
60000/60000 [==============================] - 5s 78us/sample - loss: 0.3373 - accuracy: 0.8766
test_loss: 0.3732113081932068
test_acc: 0.8629
版权声明
本文为[Stephen_Tao]所创,转载请带上原文链接,感谢
https://blog.csdn.net/professor_tao/article/details/119481196
边栏推荐
猜你喜欢
Graduation project, curriculum link, student achievement evaluation system
File viewing commands and user management commands
PHP junior programmers, take orders and earn extra money
【UDS统一诊断服务】二、网络层协议(2)— 数据传输规则(单帧与多帧)
[UDS unified diagnosis service] i. diagnosis overview (3) - ISO 15765 architecture
Swagger2 generates API documents
[UDS unified diagnostic service] i. overview of diagnosis (4) - basic concepts and terms
jenkspy包安装
MySQL groups are sorted by a field, and the first value is taken
拷贝构造函数
随机推荐
Introduction to nonparametric camera distortion model
进程管理命令
爬取小米有品app商品数据
Generate random number
Matching between class template with default template argument and template parameter
利用文件保存数据(c语言)
Arcpy为矢量数据添加字段与循环赋值
[UDS unified diagnostic service] IV. typical diagnostic service (5) - function / component test function unit (routine function unit 0x31)
【踩坑】Win11 WSL2 中 meld 无法正常使用问题修复
Easy to use data set and open source network comparison website
gcc ,g++,gdb的安装
多线程爬取马可波罗网供应商数据
Figure guessing game
PHP junior programmers, take orders and earn extra money
[UDS unified diagnosis service] i. diagnosis overview (2) - main diagnosis protocols (K-line and can)
类和对象
【UDS统一诊断服务】二、网络层协议(1)— 网络层概述与功能
带默认模板实参的类模板与模板模板形参的匹配
Friend function, friend class, class template
Basemap库绘制地图