当前位置:网站首页>MNIST数据集的训练(内附完整代码及其注释)
MNIST数据集的训练(内附完整代码及其注释)
2022-08-09 14:57:00 【[email protected]】
#引入必要的包
import numpy as np
from keras.datasets import mnist #MNIST数据集
import cv2
from keras.models import Sequential #Sequential序贯模型
from keras.layers import Dense,Dropout,Activation #全连接层 丢弃节点 激活函数
from keras.optimizers import SGD #优化函数
import numpy
'''选择模型'''
model=Sequential() #序贯模型
'''构建模型'''
model.add(Dense(500,input_shape=(784,))) #第一个隐藏层 输出节点个数为500,输入节点个数为784
model.add(Activation('tanh')) #指定tanh为激活函数
model.add(Dropout(0.5)) #每次丢弃掉一半节点的信息
model.add(Dense(500)) #第二个隐藏层
model.add(Activation('tanh')) #指定tanh为激活函数
model.add(Dropout(0.5)) #每次丢弃掉一半节点的信息
model.add(Dense(500)) #第三个隐藏层
model.add(Activation('tanh')) #指定tanh为激活函数
model.add(Dropout(0.5)) #每次丢弃掉一半节点的信息
model.add(Dense(10)) #输出层
model.add(Activation('softmax')) #指定tanh为激活函数
'''训练设置和网络编译'''
sgd=SGD(lr=0.01,decay=1e-6) #使用SGD为优化参数 初始化学习率(0.01)和学习率衰减值(1e-6)
model.compile(loss='categorical_crossentropy',optimizer=sgd) #使用交叉熵作为loss函数
model.summary() #查看网络结构
'''数据准备'''
(x_train,y_train),(x_test,y_test)=mnist.load_data() #获取mnist数据集
print('原始训练样本的shape为:{}'.format(x_train.shape))
#将每个训练样本的输入变成一维
x_train=x_train.reshape(x_train.shape[0],x_train.shape[1]*x_train.shape[2]) #将每个训练样本的输入变成一维(由于mist的输入数据维度是(num,28,28),这里需要把后面的维度直接拼起来)
print('训练样本的输入变成一维后的shape为:{}'.format(x_train.shape))
#将每个测试样本的输入变成一维
x_test=x_test.reshape(x_test.shape[0],x_test.shape[1]*x_test.shape[2]) #将每个测试样本的输入变成一维
print('*'*100)
#创建one-hot向量(将每个样本的预期输出变为一个One-Hot的10维向量,真实标签对应的位置设为1,其余设为0)
print('第一个训练输出值为:{}'.format(y_train[0]))
y_train=(np.arange(10)==y_train[:,None]).astype(int) #对训练输出进行处理
print('对训练输出进行处理后的第一个训练输出值为:{}'.format(y_train[0]))
y_test=(np.arange(10)==y_test[:,None]).astype(int) #对训练输出进行处理
print('*'*100)
'''网络训练'''
model.fit(x_train,y_train,batch_size=128,epochs=20,shuffle=True,verbose=2,validation_split=0.3)
'''模型训练'''
scores=model.evaluate(x_test,y_test,batch_size=128,verbose=0)
print("The test loss is %f"%scores)
'''计算模型在测试集上的准确率'''
result=model.predict(x_test,batch_size=128,verbose=1)
print(result.shape)
print(result[0])
result_max=np.argmax(result,axis=1) #得到网络预测的最大概率对应的类别序号
print("得到网络预测的最大概率对应的类别序号:%f" % result_max[0])
test_max=np.argmax(y_test,axis=1) #得到真实类别的最大概率对应的类别序号
result_bool=np.equal(result_max,test_max) #得到预测值和真实值的样本
true_number=np.sum(result_bool) #正确结果的样本数
print("正确结果的样本数为:%f" % true_number)
print('The accuracy of the model is %f' % (true_number/len(result_bool))) #验证结果的准确率
以下是输出的训练过程:

以下是模型使用的代码
'''模型使用'''
test_image=x_test[0,:].reshape(28,28) #获取序号为0的测试图片
cv2.imshow("测试图片",test_image)
cv2.waitKey()
test_image=test_image.reshape(1,test_image.shape[0]*test_image.shape[1]) #转换为(1,784)
test_result=model.predict(test_image,batch_size=1) #预测
print(test_result)
test_max=np.argmax(test_result,axis=1) #这是结果的真实序号
print(test_max) #打印结果以上就是MNIST数据集训练的全过程啦
原网站
版权声明
本文为[[email protected]]所创,转载请带上原文链接,感谢
https://blog.csdn.net/m0_59405106/article/details/124470877
边栏推荐
- Qt control - QTextEdit usage record
- 【研究生工作周报】(第十二周)
- 【Postgraduate Work Weekly】(Week 12)
- 数据缺失对任务影响
- LNK1123: Failed during transition to COFF: invalid or corrupt file
- 流式布局总结
- 【 Leetcode 】 433. The smallest genetic changes
- [Deep Learning] Original Problem and Dual Problem (6)
- 【论文阅读】LIME:Low-light Image Enhancement via Illumination Map Estimation(笔记最全篇)
- 【研究生工作周报】(第八周)
猜你喜欢
随机推荐
人脸识别示例代码解析(一)——程序参数解析
链表翻转 全翻转 部分翻转
A shortcut method for writing menu commands in C
【Leetcode】433. 最小基因变化
Sort method (Hill, Quick, Heap)
分类任务系列学习——总述
【Postgraduate Work Weekly】(Week 9)
路由的懒加载与接口的封装
小型项目如何使用异步任务管理器实现不同业务间的解耦
Android面试题基础集锦《一》
浏览器指纹识别是什么意思?
Xgboost系列-XGB实际参数调优指南附源码
利用qrcode组件实现图片转二维码
Common compilation problems
AsyncTask 串行还是并行
NLP-Reading Comprehension Task Learning Summary Overview
将从后台获取到的数据 转换成 树形结构数据
响应式布局总结
WebGL:BabylonJS入门——初探:注入活力
Example of file operations - downloading and merging streaming video files









