当前位置:网站首页>基于TensorFlow的线性回归实例
基于TensorFlow的线性回归实例
2022-04-23 05:50:00 【Stephen_Tao】
文章目录
线性回归原理
- 根据训练数据建立回归模型 y = w 1 x 1 + w 2 x 2 + . . . + w n x n + b y=w_1x_1+w_2x_2+...+w_nx_n+b y=w1x1+w2x2+...+wnxn+b
- 建立预测值与真实值间的误差损失函数
- 采用梯度下降法优化误差损失函数,对最优的权重和偏置进行预测
实例分析
1. 训练数据生成
- 随机生成200个服从正态分布的数据点
- 数据本身分布为 y = 0.6 x + 0.9 y=0.6x+0.9 y=0.6x+0.9
随机数据点可视化代码如下:
import tensorflow as tf
import matplotlib.pyplot as plt
X = tf.random.normal(shape=(200, 1), mean=0, stddev=1)
y_true = tf.matmul(X, [[0.6]]) + 0.9
plt.scatter(X,y_true)
plt.show()
2. 建立线性回归模型
采用TensorFlow建立线性回归模型,代码如下:
def Linear_regression():
with tf.compat.v1.Session() as sess:
# 生成正态分布的随机数据
X = tf.random.normal(shape=(200, 1), mean=0, stddev=1)
y_true = tf.matmul(X, [[0.6]]) + 0.9
# 初始化线性回归的权重和偏置
weight = tf.Variable(initial_value=tf.random.normal(shape=(1, 1)))
bias = tf.Variable(initial_value=tf.random.normal(shape=(1, 1)))
# 采用初始化参数建立线性回归模型
y_predict = tf.matmul(X, weight) + bias
# 建立误差损失函数
error = tf.reduce_mean(tf.square(y_predict - y_true))
# 采用随机梯度下降法进行模型训练
optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.05).minimize(error)
# 初始化会话中的变量
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
# 记录每次训练得到的损失值
error_set = []
# 训练200次,打印每次训练的权重、偏置及损失
for i in range(200):
sess.run(optimizer)
error_set.append(error.eval())
print("第%d步的误差为%f,权重为%f,偏置为%f" %(i,error.eval(),weight.eval(),bias.eval()))
运行结果如下:
第0步的误差为0.869870,权重为1.532328,偏置为0.675902
第1步的误差为0.663277,权重为1.428975,偏置为0.691342
第2步的误差为0.637269,权重为1.346699,偏置为0.713041
第3步的误差为0.434829,权重为1.280713,偏置为0.727712
第4步的误差为0.456109,权重为1.209046,偏置为0.745070
第5步的误差为0.276588,权重为1.146608,偏置为0.758473
第6步的误差为0.268131,权重为1.080588,偏置为0.770594
第7步的误差为0.201195,权重为1.024212,偏置为0.794897
第8步的误差为0.136823,权重为0.988268,偏置为0.804994
第9步的误差为0.123242,权重为0.950725,偏置为0.811171
第10步的误差为0.131416,权重为0.922447,偏置为0.819484
第11步的误差为0.084628,权重为0.897494,偏置为0.826534
第12步的误差为0.072482,权重为0.873246,偏置为0.833084
第13步的误差为0.073851,权重为0.841141,偏置为0.842617
第14步的误差为0.046367,权重为0.811977,偏置为0.849992
第15步的误差为0.034317,权重为0.789964,偏置为0.855224
第16步的误差为0.026889,权重为0.767914,偏置为0.861630
第17步的误差为0.020377,权重为0.751666,偏置为0.866115
第18步的误差为0.020895,权重为0.735054,偏置为0.871480
第19步的误差为0.016006,权重为0.718687,偏置为0.872991
第20步的误差为0.012947,权重为0.707634,偏置为0.875760
第21步的误差为0.011159,权重为0.698776,偏置为0.878277
第22步的误差为0.008542,权重为0.688199,偏置为0.881770
第23步的误差为0.006672,权重为0.679353,偏置为0.884191
第24步的误差为0.005545,权重为0.671653,偏置为0.885150
第25步的误差为0.003817,权重为0.664658,偏置为0.885894
第26步的误差为0.003786,权重为0.658720,偏置为0.886094
第27步的误差为0.002937,权重为0.652663,偏置为0.887146
第28步的误差为0.002135,权重为0.646956,偏置为0.888744
第29步的误差为0.001883,权重为0.641135,偏置为0.890537
第30步的误差为0.001530,权重为0.637612,偏置为0.891546
第31步的误差为0.001205,权重为0.633786,偏置为0.892773
第32步的误差为0.000896,权重为0.630734,偏置为0.892999
第33步的误差为0.000742,权重为0.627161,偏置为0.893859
第34步的误差为0.000600,权重为0.624422,偏置为0.894859
第35步的误差为0.000556,权重为0.622399,偏置为0.895277
第36步的误差为0.000471,权重为0.620324,偏置为0.895607
第37步的误差为0.000341,权重为0.618261,偏置为0.895853
第38步的误差为0.000309,权重为0.616247,偏置为0.896142
第39步的误差为0.000233,权重为0.614667,偏置为0.896652
第40步的误差为0.000219,权重为0.613210,偏置为0.896935
...
第194步的误差为0.000000,权重为0.600000,偏置为0.900000
第195步的误差为0.000000,权重为0.600000,偏置为0.900000
第196步的误差为0.000000,权重为0.600000,偏置为0.900000
第197步的误差为0.000000,权重为0.600000,偏置为0.900000
第198步的误差为0.000000,权重为0.600000,偏置为0.900000
第199步的误差为0.000000,权重为0.600000,偏置为0.900000
训练过程中损失值变化折线图如下:
可以看到随着训练次数的增加,损失函数的值逐渐减小,最后变为0。训练过程中所预测的最优参数也与实际的数据分布参数相符。
小结
本文介绍了线性回归的基本原理与步骤,并基于TensorFlow实现了简单的线性回归任务,取得了良好的效果。
版权声明
本文为[Stephen_Tao]所创,转载请带上原文链接,感谢
https://blog.csdn.net/professor_tao/article/details/119292134
边栏推荐
猜你喜欢
文件查看命令和用户管理命令
【UDS统一诊断服务】四、诊断典型服务(5)— 功能/元件测试功能单元(例行程序功能单元0x31)
【UDS统一诊断服务】四、诊断典型服务(4)— 在线编程功能单元(0x34-0x38)
【UDS统一诊断服务】二、网络层协议(2)— 数据传输规则(单帧与多帧)
MySQL groups are sorted by a field, and the first value is taken
Class inheritance and derivation
如何安装jsonpath包
【UDS统一诊断服务】三、应用层协议(1)
ArcGIS表转EXCEL超出上限转换失败
[UDS unified diagnosis service] i. diagnosis overview (3) - ISO 15765 architecture
随机推荐
File viewing commands and user management commands
数组旋转
在MFC中使用printf
进程管理命令
爬取彩票数据
Tabbar implementation of dynamic bottom navigation bar in uniapp, authority management
Robocode教程4——Robocode的游戏物理
Installation of GCC, G + +, GDB
【学习一下】HF-Net 训练
猜数字游戏
C语言实用小技巧合集(持续更新)
[UDS unified diagnostic service] III. application layer protocol (1)
Flask - 中间件
[UDS unified diagnostic service] II. Network layer protocol (2) - data transmission rules (single frame and multi frame)
ArcGIS license错误-15解决方法
Common shortcut keys of IDE
【UDS统一诊断服务】三、应用层协议(2)
Introduction to nonparametric camera distortion model
[ThreadX] h743zi + lan8720 + ThreadX + netx duo transplantation
[UDS unified diagnosis service] IV. typical diagnosis service (3) - read fault information function unit (storage data transmission function unit)