Mofan TensorFlow learning notes (10-12): building a simple neural network and visualizing it
2022-04-23 20:14:00 【Trabecular trying to write code】
These notes are based on Mofan Python's TensorFlow course.
Personally, I don't think Mofan's video tutorials suit complete beginners. If you are starting from zero, watch Hung-yi Lee's or Andrew Ng's lectures first, or go straight to reading material. Mofan's tutorials suit students who already have some theoretical knowledge of deep learning but lack coding experience (like me). After typing out almost all of Mofan's code myself, I now have a reasonable understanding of how TensorFlow is applied and have picked up some practical techniques for training neural networks.
The notes below correspond to TensorFlow lessons (10), (11) and (12) of the video series. I record them based on Mofan's explanation and my own understanding.
I. Understanding the structure of a neural network
The goal of this tutorial is to build the simplest possible neural network: one that contains only one input layer, one hidden layer and one output layer. Each layer contains neurons. The number of neurons in the hidden layer is a hyperparameter, whereas the number of neurons in the input and output layers is determined by the input and output data. For example, if the input data defined in the program is x_data, the number of input-layer neurons depends on the number of features (attributes) of each sample in x_data. Note: it does not depend on the number of samples in the data! The same goes for the output layer.
The number of neurons in a hidden layer should not be too large: the "deep" in deep learning should come from the number of hidden layers, and packing too many neurons into a single layer turns it into "wide" learning instead. The number of hidden layers matters too, because adding hidden layers blindly can lead to overfitting. Designing the hidden layers is therefore largely a matter of experience.
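To make the sizing rule concrete, here is a small NumPy sketch. It assumes the same x_data as in the code later in this note, and layer_shapes is a hypothetical helper written only for this illustration. It shows that the input-layer size comes from the second dimension of x_data (features per sample), not from the 300 samples.

```python
import numpy as np

# 300 samples, each with a single feature -> shape (300, 1)
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
n_samples, n_features = x_data.shape

n_input = n_features   # 1, no matter how many samples there are
n_hidden = 10          # hyperparameter chosen by hand
n_output = 1           # y_data also has one value per sample

def layer_shapes(sizes):
    """Hypothetical helper: (weight shape, bias shape) between consecutive layers."""
    return [((a, b), (1, b)) for a, b in zip(sizes[:-1], sizes[1:])]

shapes = layer_shapes([n_input, n_hidden, n_output])
# total trainable parameters: weights plus biases of every layer
n_params = sum(w[0] * w[1] + b[1] for w, b in shapes)

print(n_samples, n_features)  # 300 1
print(shapes)                 # [((1, 10), (1, 10)), ((10, 1), (1, 1))]
print(n_params)               # 31
```

Doubling the number of samples would leave every shape and the parameter count unchanged; only the feature and output dimensions matter.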
II. Code details
1. Creating a layer: def add_layer()
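import tensorflow as tf  # this tutorial uses the TensorFlow 1.x API

def add_layer(inputs, in_size, out_size, activation_function=None):
    # weight matrix of shape [in_size, out_size], drawn from a standard normal distribution
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    # bias row vector, initialized to 0.1 instead of 0
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    # linear part: inputs * Weights + biases
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    if activation_function is None:
        outputs = Wx_plus_b  # no activation (linear output)
    else:
        outputs = activation_function(Wx_plus_b)
    return outputs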
This function is defined first because it is the building block for the layers created afterwards. One point deserves special attention: biases are often initialized to 0, but here we don't want them to be 0, so 0.1 is added to tf.zeros([1, out_size]) to guarantee a non-zero starting value.
The check on activation_function: between the input layer and the hidden layer, we use an activation function to make the transformation non-linear, but between the hidden layer and the output layer no activation function is needed. That is why the function checks: when no activation function is given, it returns Wx_plus_b directly; otherwise it passes Wx_plus_b through the activation function and returns the result.
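To see concretely what add_layer computes, including the effect of the 0.1 bias offset, here is a NumPy sketch of one forward pass. The name add_layer_np and the seeded np.random.default_rng are illustrative assumptions; unlike the TensorFlow version, which only builds graph nodes, this sketch computes values immediately.

```python
import numpy as np

def add_layer_np(inputs, in_size, out_size, activation_function=None, rng=None):
    """NumPy sketch of add_layer: one forward pass with freshly initialized parameters."""
    rng = np.random.default_rng(0) if rng is None else rng
    # weights ~ N(0, 1), mirroring tf.random_normal([in_size, out_size])
    weights = rng.normal(0.0, 1.0, size=(in_size, out_size))
    # biases start at 0.1 rather than 0, mirroring tf.zeros([1, out_size]) + 0.1
    biases = np.zeros((1, out_size)) + 0.1
    # (n_samples, in_size) @ (in_size, out_size) + (1, out_size): bias broadcasts over samples
    wx_plus_b = inputs @ weights + biases
    if activation_function is None:
        return wx_plus_b
    return activation_function(wx_plus_b)

def relu(z):
    return np.maximum(z, 0.0)

x = np.linspace(-1, 1, 300)[:, np.newaxis]
hidden = add_layer_np(x, 1, 10, activation_function=relu)
out = add_layer_np(hidden, 10, 1)
print(hidden.shape, out.shape)  # (300, 10) (300, 1)
```

With an all-zero input, every output of a freshly created layer equals its bias, 0.1, which is exactly the non-zero starting point the paragraph above asks for.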
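Why the hidden layer needs an activation function at all: without one, two stacked layers are still a single linear (affine) map, because (x·W1 + b1)·W2 + b2 = x·(W1·W2) + (b1·W2 + b2). The NumPy sketch below checks this numerically with fixed, arbitrarily chosen weights (an assumption for illustration) and shows that applying ReLU to the hidden layer breaks the equivalence.

```python
import numpy as np

x = np.linspace(-1, 1, 300)[:, np.newaxis]
# arbitrary fixed weights for a 1 -> 10 -> 1 network
W1 = np.linspace(-2.0, 2.0, 10).reshape(1, 10)
b1 = np.full((1, 10), 0.1)
W2 = np.linspace(0.5, 1.5, 10).reshape(10, 1)
b2 = np.full((1, 1), 0.1)

# two stacked layers with no activation function
two_linear = (x @ W1 + b1) @ W2 + b2
# one equivalent linear layer
W, b = W1 @ W2, b1 @ W2 + b2
one_linear = x @ W + b
print(np.allclose(two_linear, one_linear))  # True: still a straight line in x

# same weights, ReLU on the hidden layer: the output now bends at each ReLU kink
with_relu = np.maximum(x @ W1 + b1, 0.0) @ W2 + b2
print(np.allclose(with_relu, one_linear))   # False
```

This is why a network without a hidden-layer activation could never fit the parabola used below, no matter how many neurons it had.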
2. Generating the data and building the network
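import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

# 300 samples with 1 feature each: shape (300, 1)
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
# Gaussian noise: mean 0, standard deviation 0.05, same shape as x_data
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# placeholders: None = any number of samples, 1 = number of features
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

# hidden layer: 1 input feature -> 10 neurons, with a ReLU activation
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# output layer: 10 neurons -> 1 output, no activation
prediction = add_layer(l1, 10, 1, activation_function=None)

# loss: mean squared error (sum over the output dimension, then mean over samples)
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

init = tf.global_variables_initializer()  # op that initializes all variables
sess = tf.Session()
sess.run(init)  # very important: variables must be initialized before use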
We add noise when generating the data. Because the noise is added to the constant term of the quadratic (next to the -0.5), y_data does not lie exactly on the curve y = x^2 - 0.5 but scatters randomly around it. The resulting plot of y against x is shown in the figure below.
Here we use TensorFlow's placeholder function. A placeholder is declared while building the model graph; at that point no data is passed into the model, only the data type and shape of the future input are declared. Later, in a session, the model is run with sess.run() and the actual data is passed to the placeholders through its feed_dict argument.
The benefit of placeholder is that you can choose how much data to feed at each step (the None in [None, 1] allows any number of samples). In deep-learning training, the whole data set is usually not fed in at once; for example, stochastic gradient descent (SGD) trains on the data in mini-batches.
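For reference, the data-generating process in the code above can be written as follows (the symbol ε for the noise term is our own notation):

```latex
x_i = -1 + \frac{2(i-1)}{299}, \qquad
y_i = x_i^2 - 0.5 + \varepsilon_i, \qquad
\varepsilon_i \sim \mathcal{N}(0,\ 0.05^2), \qquad i = 1, \dots, 300
```

Since 0.05 is the standard deviation passed to np.random.normal, even a network that recovers x^2 - 0.5 perfectly would still show a mean squared error of about 0.05^2 = 0.0025.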
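The "declare now, feed later" idea can be illustrated with a toy, pure-NumPy analogy. Everything here (the Placeholder class, the run function, the (operation, inputs) tuples) is hypothetical and is not how TensorFlow is implemented internally; it only mimics the behaviour described above, where the graph is built without data and values arrive through a feed dictionary at run time.

```python
import numpy as np

class Placeholder:
    """Toy stand-in for tf.placeholder: a named slot that holds no data yet."""
    def __init__(self, name, n_features):
        self.name = name
        self.n_features = n_features  # like the 1 in shape [None, 1]

def run(node, feed_dict):
    """Toy stand-in for sess.run: data enters the graph only at this point."""
    if isinstance(node, Placeholder):
        value = np.asarray(feed_dict[node], dtype=np.float32)
        # None in [None, 1]: any number of rows, but a fixed number of columns
        if value.ndim != 2 or value.shape[1] != node.n_features:
            raise ValueError(f"{node.name} expects shape (None, {node.n_features})")
        return value
    # any other node is an (operation, inputs) pair, evaluated only when run
    op, inputs = node
    return op(*(run(i, feed_dict) for i in inputs))

# build the "graph" first: nothing is computed here
xs = Placeholder("xs", 1)
double = (lambda a: 2 * a, [xs])

# feed data later, with any batch size
print(run(double, {xs: np.ones((5, 1))}).shape)    # (5, 1)
print(run(double, {xs: np.ones((300, 1))}).shape)  # (300, 1)
```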
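Mini-batch training can be sketched in NumPy as follows. The generator name iterate_minibatches, the batch size of 64 and the reshuffling per epoch are assumptions for illustration; the tutorial itself feeds the full data set at every step.

```python
import numpy as np

def iterate_minibatches(x, y, batch_size, rng):
    """Yield shuffled (x_batch, y_batch) pairs that together cover the data once (one epoch)."""
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]

x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
y_data = np.square(x_data) - 0.5
rng = np.random.default_rng(0)
sizes = [len(xb) for xb, yb in iterate_minibatches(x_data, y_data, 64, rng)]
print(sizes)  # [64, 64, 64, 64, 44]
```

In the TensorFlow 1.x code, each batch would be passed as sess.run(train_step, feed_dict={xs: xb, ys: yb}); because the placeholders were declared with shape [None, 1], the smaller last batch of 44 samples is accepted as well.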
3. Visualizing the training process
In this program, visualization uses the matplotlib library to display the training result dynamically. However, when Mofan's code is run directly in Jupyter Notebook, only the static function plot is displayed and it does not update dynamically. In the comments under Mofan's video on YouTube, many people also reported that in IDLE, a terminal, or other environments, even adding %matplotlib inline at the top does not give the expected effect. After some searching, I found several ways to get a dynamic display in Jupyter Notebook. The method used here requires Qt to be installed; Jupyter then uses Qt as the plotting backend.
The code is as follows:
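# Jupyter magic: use the interactive Qt backend (Qt must be installed);
# declare it together with the library imports
%matplotlib qt
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(x_data, y_data)  # blue dots: the noisy training data
plt.ion()  # interactive mode, so plt.show() does not block the loop
plt.show()

lines = None  # handle to the currently drawn prediction curve
for i in range(2000):
    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
    if i % 50 == 0:
        # remove the previous prediction curve before drawing a new one
        if lines is not None:
            lines[0].remove()
        prediction_value = sess.run(prediction, feed_dict={xs: x_data})
        # red curve: the network's current prediction
        lines = ax.plot(x_data, prediction_value, 'r-', lw=10)
        plt.pause(0.5)  # redraw the figure and wait 0.5 s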
The figure shows the training result with dynamic plotting: the red curve is the one learned by the network, and the blue dots are the true (x, y) data points. After training, the network produces a curve that roughly follows the trend of the data.
That's all for this tutorial. If you have any questions, feel free to discuss them in the comments!
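As a cross-check on what GradientDescentOptimizer(0.1).minimize(loss) does, here is a NumPy sketch that trains the same 1-10-1 network with hand-written backpropagation and full-batch gradient descent (learning rate 0.1, 2000 steps). It is an illustrative re-implementation under our own assumptions (np.random.default_rng with seed 0, gradients derived by hand from the chain rule), not the tutorial's code, and its exact loss values will differ from a TensorFlow run.

```python
import numpy as np

rng = np.random.default_rng(0)

# same data as in the tutorial: y = x^2 - 0.5 + noise
x = np.linspace(-1, 1, 300)[:, np.newaxis]
y = np.square(x) - 0.5 + rng.normal(0, 0.05, x.shape)
n = len(x)

# 1 -> 10 (ReLU) -> 1 network, initialized like add_layer
W1, b1 = rng.normal(size=(1, 10)), np.zeros((1, 10)) + 0.1
W2, b2 = rng.normal(size=(10, 1)), np.zeros((1, 1)) + 0.1
lr = 0.1

losses = []
for step in range(2000):
    # forward pass
    z1 = x @ W1 + b1
    h = np.maximum(z1, 0.0)
    pred = h @ W2 + b2
    # MSE: squared error summed over the single output, averaged over samples
    losses.append(np.mean(np.sum(np.square(y - pred), axis=1)))

    # backward pass (chain rule), from the output layer back to the hidden layer
    d_pred = 2.0 * (pred - y) / n
    dW2 = h.T @ d_pred
    db2 = d_pred.sum(axis=0, keepdims=True)
    d_z1 = (d_pred @ W2.T) * (z1 > 0)  # ReLU passes gradient only where z1 > 0
    dW1 = x.T @ d_z1
    db1 = d_z1.sum(axis=0, keepdims=True)

    # plain gradient-descent update on every parameter
    W1 -= lr * dW1
    b1 -= lr * db1
    W2 -= lr * dW2
    b2 -= lr * db2

print(losses[0], losses[-1])
```

Plotting pred against x after training would give the same kind of curve as the red line above; the TensorFlow version simply derives these gradients automatically.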
Copyright notice
This article was written by [Trabecular trying to write code]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204210555381778.html