当前位置:网站首页>第二讲 Linear Model 线性模型
第二讲 Linear Model 线性模型
2022-08-05 05:13:00 【长路漫漫 大佬为伴】
参考资料
- 一句话解释numpy.meshgrid()
- matplotlib教程之——自定义配置文件和绘图风格(rcParams和style)
- python中zip()函数的用法
- matplotlib之plot()详解
- matplotlib 3D绘图警告
课堂练习
实现线性模型y=wx的平面图
import numpy as np
import matplotlib.pyplot as plt
#保存数据集,相同的索引为一个样本
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
#模型的前馈
def forward(x):
return x * w
#损失函数
def loss(x, y):
y_pred = forward(x) #根据前馈求y_hat
return (y_pred - y) ** 2 #计算损失
# 穷举法
w_list = [] #权重
mse_list = [] #权重对应的损失值
for w in np.arange(0.0, 4.1, 0.1):
print("w=", w)
l_sum = 0
#从x_data, y_data取出x_val, y_val
for x_val, y_val in zip(x_data, y_data):
y_pred_val = forward(x_val)
loss_val = loss(x_val, y_val)
l_sum += loss_val
print('x_val==', x_val, 'y_val==',y_val, 'y_pred_val==',y_pred_val,'loss_val==', loss_val)
print('MSE=', l_sum / 3)
w_list.append(w)
mse_list.append(l_sum / 3)
#调用画图
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()
图案轨迹
课后练习
实现线性模型(y=wx+b)并输出loss的3D图像
这里存在几个问题需要解决
1.w,b的取值
之前课堂练习中,只需要取一个w,因此可以用for循环取值。课后练习中需要对w,b两个值进行取值操作,因此要使用meshgrid函数
2.图像无法显示中文
在前方加上
from pylab import * mpl.rcParams[‘font.sans-serif’] = [‘SimHei’]
3.matplotlib 3D绘图警告
matplotlib 3D绘图警告
课后习题代码:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']
#这里设函数为y=3x+2
x_data = [1.0,2.0,3.0]
y_data = [5.0,8.0,11.0]
def forward(x):
return x * w + b
def loss(x,y):
y_pred = forward(x)
return (y_pred-y)*(y_pred-y)
mse_list = []
W=np.arange(0.0,4.1,0.1)
B=np.arange(0.0,4.1,0.1)
w,b=np.meshgrid(W,B)
# print("w==",w)
# print('b==',b)
l_sum = 0
for x_val, y_val in zip(x_data, y_data):
y_pred_val = forward(x_val)
loss_val = loss(x_val, y_val)
print('x_val==', x_val,'\ny_val==', y_val,'\ny_pred_val==', y_pred_val, '\nloss_val==',loss_val)
l_sum += loss_val
fig = plt.figure()
ax = Axes3D(fig,auto_add_to_figure=False)
fig.add_axes(ax)
ax.plot_surface(w, b, l_sum/3)
ax.set_xlabel("权重 W")
ax.set_ylabel("偏置项 B")
ax.set_zlabel("损失值")
plt.show()
3D图:
边栏推荐
猜你喜欢

RL强化学习总结(一)

Flutter真机运行及模拟器运行

Flutter learning 5-integration-packaging-publish

jvm three heap and stack

Excel画图

Flutter学习2-dart学习

Develop a highly fault-tolerant distributed system

MySQL Foundation (1) - Basic Cognition and Operation

Structured light 3D reconstruction (1) Striped structured light 3D reconstruction
![[Study Notes Dish Dog Learning C] Classic Written Exam Questions of Dynamic Memory Management](/img/0b/f7d9205c616f7785519cf94853d37d.png)
[Study Notes Dish Dog Learning C] Classic Written Exam Questions of Dynamic Memory Management
随机推荐
Redis - 13、开发规范
Flutter学习-开篇
u-boot中的u-boot,dm-pre-reloc
RDD和DataFrame和Dataset
【Untitled】
MySQL中控制导出文件后变为了\N有什么解决方案吗?
HQL语句执行过程
Flutter学习2-dart学习
2022牛客多校第四场C.Easy Counting Problem(EGF+NTT)
数字_获取指定位数的小数
数据库 单表查询
"PHP8 Beginner's Guide" A brief introduction to PHP
How can Flutter parent and child components receive click events
Distributed systems revisited: there will never be a perfect consistency scheme...
mutillidae download and installation
number_gets the specified number of decimals
Reverse theory knowledge 4
【cesium】Load and locate 3D Tileset
"Recursion" recursion concept and typical examples
Mesos学习