Don't bother tensorflow learning notes (10-12) -- Constructing a simple neural network and its visualization
2022-04-23 20:14:00 [Trabecular trying to write code]
This note is based on Mofan Python's TensorFlow course.
Personally, I don't think Mofan's video tutorials are suitable for complete beginners. If you are new to the field, first watch Hung-yi Lee's or Andrew Ng's lectures, or read a textbook directly. Mofan's tutorials suit students who already have some theoretical background in deep learning but lack coding experience (like me). By basically retyping Mofan's code myself, I have gained a working understanding of how to apply TensorFlow and picked up some neural-network training techniques.
The following corresponds to episodes (10), (11) and (12) of the TensorFlow course. I will take notes following Mofan's explanation and my own understanding.
I. Understanding the structure of a neural network
The goal of this tutorial is to build the simplest possible neural network: one input layer, one hidden layer, and one output layer. Each layer contains neurons. The number of hidden-layer neurons is a hyperparameter, while the sizes of the input and output layers are determined by the data. For example, if the input defined in the program is x_data, then the number of input-layer neurons equals the number of features in each sample of x_data. Be careful: it is not the number of samples! The same applies to the output layer.
The hidden layer should not have too many neurons either. The "depth" in deep learning comes from stacking hidden layers; if a single layer has a huge number of neurons, that is width rather than depth. The number of hidden layers also matters: blindly adding hidden layers can lead to overfitting. Designing the hidden layers is therefore largely a matter of experience.
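To make the point about layer sizes concrete, here is a minimal NumPy sketch of the weight shapes in a 1-10-1 network (the same sizes as the network built later in this note); the zero initialization here is only to show shapes, not a recommended initializer:

```python
import numpy as np

# A 1-10-1 network: 1 input feature, 10 hidden neurons, 1 output.
in_size, hidden_size, out_size = 1, 10, 1

W1 = np.zeros((in_size, hidden_size))   # input -> hidden weights
b1 = np.zeros((1, hidden_size))         # hidden-layer biases
W2 = np.zeros((hidden_size, out_size))  # hidden -> output weights
b2 = np.zeros((1, out_size))            # output-layer bias

# Layer sizes are fixed by the feature dimensions, not the sample count:
x = np.zeros((300, in_size))            # 300 samples still use the same weights
n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)  # 31 trainable parameters
```

Note that feeding 300 samples or 3 samples changes nothing about the weight shapes, which is exactly why the input-layer size depends on the feature count rather than the sample count.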
II. Code details
1. Defining add_layer() to create layers
def add_layer(inputs, in_size, out_size, activation_function=None):
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_function(Wx_plus_b)
    return outputs
This function is defined first as the building block for the input, hidden and output layers created later. One point deserves special attention: biases are often initialized to zero, but here we do not want them to start at zero, so 0.1 is added to tf.zeros to guarantee a small non-zero initial value.
On the check of activation_function: between the input layer and the hidden layer we apply an activation function to make the transformation non-linear, but between the hidden layer and the output layer we no longer need one. Thanks to the if-branch, when no activation function is needed the layer outputs Wx_plus_b directly; when one is needed, Wx_plus_b is passed through the activation function before being returned.
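The same forward pass can be sketched in plain NumPy, which makes the activation-function branch easy to inspect (add_layer_np and the seeded generator are illustrative stand-ins, not part of the original code):

```python
import numpy as np

def add_layer_np(inputs, in_size, out_size, activation=None, rng=None):
    # Mirror of add_layer(): random-normal weights, biases of 0.1.
    rng = rng or np.random.default_rng(0)
    W = rng.normal(size=(in_size, out_size))
    b = np.zeros((1, out_size)) + 0.1
    z = inputs @ W + b
    # Same branch as the TF version: linear output when no activation is given.
    return z if activation is None else activation(z)

relu = lambda z: np.maximum(z, 0.0)

x = np.linspace(-1, 1, 5)[:, np.newaxis]      # 5 samples, 1 feature each
h = add_layer_np(x, 1, 10, activation=relu)   # hidden layer with ReLU
y = add_layer_np(h, 10, 1, activation=None)   # linear output layer
print(h.shape, y.shape)  # (5, 10) (5, 1)
```

Because ReLU clamps negatives to zero, every entry of h is non-negative, while the linear output layer can still produce values of either sign.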
2. Importing data and building the network
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)  # l1 is the hidden layer
# the activation function is ReLU
prediction = add_layer(l1, 10, 1, activation_function=None)  # prediction is the output layer

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))
# the loss function is the MSE (mean squared error)
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

init = tf.global_variables_initializer()  # initialize all variables
sess = tf.Session()
sess.run(init)  # very important
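The loss above, tf.reduce_sum over axis 1 followed by tf.reduce_mean, reduces to ordinary mean squared error when the output has a single column. A NumPy sketch with made-up numbers shows the two reduction steps:

```python
import numpy as np

# Hypothetical targets and predictions, shaped [samples, 1] like ys/prediction.
y_true = np.array([[0.1], [0.4], [0.9]])
y_pred = np.array([[0.0], [0.5], [1.1]])

# Step 1: reduce_sum over axis 1 gives one squared error per sample.
per_sample = np.sum((y_true - y_pred) ** 2, axis=1)  # [0.01, 0.01, 0.04]
# Step 2: reduce_mean averages over the samples.
loss = np.mean(per_sample)
print(round(loss, 4))  # 0.02
```

With a single output column the axis-1 sum is a no-op, so this is exactly np.mean((y_true - y_pred) ** 2); the two-step form only matters when the output layer has more than one neuron.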
We add noise when generating the data. With noise, y_data no longer follows the quadratic function exactly; the values are slightly random, as if noise had been added at the bias position. The resulting y-x plot is shown below:
Here we use TensorFlow's placeholder. tf.placeholder() creates a node in the computation graph; no data is passed into the model at this point, it merely reserves the necessary memory. Later, inside the session, we run the model and pass real data to the placeholders through the feed_dict argument.
The benefit of placeholders is that you can decide how much data to feed in at run time. In deep-learning training we generally do not train on the whole dataset at once; for example, stochastic gradient descent (SGD) trains on mini-batches of the data.
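Since the placeholders are declared with shape [None, 1], either the full dataset or a mini-batch fits. A minimal NumPy sketch of drawing a mini-batch (the batch size of 30 and the fixed seed are arbitrary choices for illustration):

```python
import numpy as np

x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
y_data = np.square(x_data) - 0.5

# Sample 30 distinct row indices, then slice both arrays with them.
batch_size = 30
idx = np.random.default_rng(0).choice(len(x_data), size=batch_size, replace=False)
x_batch, y_batch = x_data[idx], y_data[idx]

# Both of these would be valid feeds for placeholders shaped [None, 1]:
# sess.run(train_step, feed_dict={xs: x_data, ys: y_data})    # full batch
# sess.run(train_step, feed_dict={xs: x_batch, ys: y_batch})  # mini-batch
print(x_batch.shape, y_batch.shape)  # (30, 1) (30, 1)
```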
3. Visualizing the training
Visualization in this program uses the matplotlib library, and the training result is displayed dynamically. However, running Mofan's code directly in Jupyter Notebook only shows a static function image; the plot does not update dynamically. In the comments under Mofan's YouTube video, many viewers report that whether in IDLE, a terminal, or another environment, adding %matplotlib inline at the top does not achieve the expected effect either. After searching, I found a method that does display the plot dynamically in Jupyter Notebook, but it requires installing Qt, which Jupyter then uses as the plotting backend.
The code is as follows:
%matplotlib  # declare the interactive backend when importing the libraries

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(2000):
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
    if i % 50 == 0:
        try:
            plt.pause(0.5)
        except Exception:
            pass
        try:
            ax.lines.remove(lines[0])  # erase the previous curve
            plt.show()
        except Exception:
            pass  # on the first pass there is no curve to remove yet
        prediction_value = sess.run(prediction, feed_dict={xs: x_data})
        lines = ax.plot(x_data, prediction_value, 'r-', lw=10)
This is the training effect with the dynamic plot added: the red curve is the curve learned by the network, and the blue dots are the true (x, y) values. You can see that after training, the network produces a curve that roughly matches the trend of the data.
That is the whole content of this tutorial. If you have any questions, feel free to discuss them in the comments!
Copyright notice
This article was created by [Trabecular trying to write code]. Please include the original link when reposting. Thanks!
https://yzsam.com/2022/04/202204210555381778.html