Handwritten numeral recognition in deep learning environment
2022-04-22 07:02:00, by PRML_ MAN
Having already configured a deep learning environment under Windows, this article records the training and use of a simple handwritten digit recognition model in that environment.
1. Configure the conda environment in PyCharm.

Once the environment is configured, you can start writing the handwritten digit recognition code.
2. Import the TensorFlow and Keras libraries:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
3. Configure the GPU if the local environment supports one.
If you are not sure whether the local machine supports a GPU, you can print the list of available GPUs with the following code:
gpuList = tf.config.list_physical_devices("GPU")
print(gpuList)
The printed output is as follows:
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Configure the GPU for model training:
if gpuList:
    usedGPU = gpuList[0]
    tf.config.experimental.set_memory_growth(usedGPU, True)
    tf.config.set_visible_devices([usedGPU], "GPU")
4. Download the MNIST handwritten digit dataset
The dataset contains 70,000 images of handwritten digits; each image is 28 × 28 × 1 (grayscale). Here, 60,000 images are used as the training set and 10,000 as the validation set. The dataset is bundled with Keras, so you can load it directly via Keras's datasets.mnist.load_data() function. The code is as follows:
# trainImageSet: training images
# trainLabelSet: training labels
# testImageSet:  test images
# testLabelSet:  test labels
(trainImageSet, trainLabelSet), (testImageSet, testLabelSet) = datasets.mnist.load_data()
5. Preprocess the images
trainImageSet = trainImageSet.reshape((60000, 28, 28, 1))  # add a channel dimension
testImageSet = testImageSet.reshape((10000, 28, 28, 1))    # add a channel dimension
trainImageSet = trainImageSet / 255.0  # map pixel values from [0, 255] to [0, 1]
testImageSet = testImageSet / 255.0    # map pixel values from [0, 255] to [0, 1]
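The effect of the reshape and scaling above can be checked on a synthetic batch. This is a sketch using a small random uint8 array in place of the real MNIST arrays:

```python
import numpy as np

# stand-in for a batch of MNIST images: 1000 grayscale 28x28 images, uint8 pixels
fakeBatch = np.random.randint(0, 256, size=(1000, 28, 28), dtype=np.uint8)

fakeBatch = fakeBatch.reshape((1000, 28, 28, 1))  # add the channel dimension
fakeBatch = fakeBatch / 255.0                     # scale pixels into [0, 1]

print(fakeBatch.shape)  # (1000, 28, 28, 1)
print(fakeBatch.min() >= 0.0, fakeBatch.max() <= 1.0)  # True True
```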
6. Build the Keras model
Next, build the following network:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 26, 26, 32) 320
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 11, 11, 64) 18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 1600) 0
_________________________________________________________________
dense (Dense) (None, 64) 102464
_________________________________________________________________
dense_1 (Dense) (None, 10) 650
=================================================================
Total params: 121,930
Trainable params: 121,930
Non-trainable params: 0
_________________________________________________________________
The implementation code is as follows :
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),  # conv layer 1: 32 filters of size 3x3
    layers.MaxPooling2D((2, 2)),          # pooling layer 1: 2x2 max pooling
    layers.Conv2D(64, (3, 3), activation='relu'),  # conv layer 2: 64 filters of size 3x3
    layers.MaxPooling2D((2, 2)),          # pooling layer 2: 2x2 max pooling
    layers.Flatten(),                     # flatten the feature maps into a 1-D vector for the dense layers
    layers.Dense(64, activation='relu'),  # fully connected layer, 64 units
    layers.Dense(10)                      # output layer: 10 logits, one per digit 0-9
])
You can print the network structure with the following code:
model.summary()
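The parameter counts in the summary table above can be verified by hand: a Conv2D layer has kernel_h × kernel_w × in_channels × filters weights plus one bias per filter, and a Dense layer has in_features × units weights plus one bias per unit. A quick check:

```python
# per-layer parameter counts for the model above
conv1  = 3 * 3 * 1 * 32 + 32         # 320
conv2  = 3 * 3 * 32 * 64 + 64        # 18496
dense1 = (5 * 5 * 64) * 64 + 64      # 102464 (flatten output is 5*5*64 = 1600)
dense2 = 64 * 10 + 10                # 650

print(conv1 + conv2 + dense1 + dense2)  # 121930, matching "Total params"
```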
7. Train the model
Set the optimizer, loss function, and evaluation metric:
model.compile(optimizer='adam',  # optimizer
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  # loss function
              metrics=['accuracy'])  # evaluation metric
Train the model and save the trained weights:
history = model.fit(trainImageSet, trainLabelSet, epochs=10,  # train for 10 epochs
                    validation_data=(testImageSet, testLabelSet))
model.save('recogNumber.h5')  # save the model to the file recogNumber.h5
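The history object returned by fit stores per-epoch metrics in history.history (keys such as 'accuracy' and 'val_accuracy'). A minimal sketch of picking the epoch with the best validation accuracy, using a hypothetical metrics dict in place of a real training run:

```python
# hypothetical per-epoch metrics, shaped like history.history after 5 epochs
hist = {
    'accuracy':     [0.90, 0.95, 0.97, 0.98, 0.99],
    'val_accuracy': [0.91, 0.94, 0.96, 0.97, 0.965],
}

# index of the epoch with the highest validation accuracy
bestEpoch = max(range(len(hist['val_accuracy'])),
                key=lambda i: hist['val_accuracy'][i])
print(bestEpoch + 1, hist['val_accuracy'][bestEpoch])  # 4 0.97
```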
At this point, the whole model training process is complete.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Training / Inference divider ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Next, use the trained model
8. Read a local image
Read the local image with matplotlib.image; a locally stored picture of the digit 3 is used for testing.
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np

pil_im = mpimg.imread('333.JPG')  # read the local image
9. Preprocess the image
Preprocess the image. Note: the local picture must first be converted into a 28 × 28 × 1 (grayscale) image.
im_data2 = pil_im / 255.0  # normalize pixel values to [0, 1]
print(im_data2)
plt.imshow(im_data2)
plt.show()  # display the image; after the window is closed, the program continues
im_data2 = image.img_to_array(im_data2)
im_data2 = np.expand_dims(im_data2, axis=0)  # add a batch dimension
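The code above assumes that '333.JPG' is already a 28 × 28 grayscale image. If the file is RGB, it must be reduced to a single channel first. A sketch of the conversion on a synthetic array (the channel weights are the standard luminance coefficients, not something from the original article):

```python
import numpy as np

# stand-in for an RGB image read by mpimg.imread: 28x28 pixels, 3 channels
rgb = np.random.rand(28, 28, 3)

# weighted sum over the channel axis -> single-channel luminance image
gray = rgb @ np.array([0.299, 0.587, 0.114])
print(gray.shape)  # (28, 28)

# add batch and channel dimensions to match the model's (1, 28, 28, 1) input
batch = gray[np.newaxis, :, :, np.newaxis]
print(batch.shape)  # (1, 28, 28, 1)
```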

10. Load the model and predict
model = load_model('C://Users//Administrator//Desktop//recogNumber.h5')  # load the saved model
pre = model.predict(im_data2)  # run inference
print(pre)             # print the scores for all labels
print(np.argmax(pre))  # index of the largest score, i.e. the predicted label
After the image window is closed, the output vector and prediction are as follows:
# logits for each label; the index corresponds to digits 0-9
[[-11.468813 -1.7106687 -6.7820425 15.381593 -2.7402487 0.973804 -8.244136 -2.8609006 -1.9001967 -0.34003535]]
# Predicted result:
3
From the output, the classifier separates the classes well: apart from class 3, the logits of the other classes are mostly below zero.
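Because the model was compiled with from_logits=True, predict returns raw logits rather than probabilities. They can be converted with a softmax; a minimal sketch applied to the logits printed above:

```python
import numpy as np

logits = np.array([-11.468813, -1.7106687, -6.7820425, 15.381593, -2.7402487,
                   0.973804, -8.244136, -2.8609006, -1.9001967, -0.34003535])

# numerically stable softmax: subtract the max before exponentiating
probs = np.exp(logits - logits.max())
probs /= probs.sum()

print(np.argmax(probs))    # 3
print(round(probs[3], 4))  # 1.0 -- nearly all probability mass on class 3
```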
11. Summary
The deep learning workflow demonstrated here is:
1. Acquire and load the dataset (training set and test set)
2. Define the model
3. Train the model and save it
4. Load the saved model and use it for prediction
It seemed hard at first, but taking the first step makes everything easier. If you have read this far, your environment is presumably set up. Feel free to leave a like ^_^, and leave a comment if you have any questions.
Copyright notice
This article was written by [PRML_ MAN]. Please include a link to the original when reposting. Thanks!
https://yzsam.com/2022/04/202204220603058627.html