当前位置:网站首页>PyTorch笔记——实现线性回归完整代码&手动或自动计算梯度代码对比
PyTorch笔记——实现线性回归完整代码&手动或自动计算梯度代码对比
2022-04-23 05:44:00 【umbrellalalalala】
参考资料:《深度学习框架PyTorch:入门与实践》
本文对此书中线性回归部分的代码进行注释解读,并补充手动反向传播过程中求解梯度的公式。
一、生成数据集完整代码
采用“假数据”:
# 设置随机数种子,保证在不同计算机上运行时下面的输出一致
t.manual_seed(1000)
def get_fake_data(batch_size=8):
''' 产生随机数据:y=x*2+3,加上了一些噪声 '''
x = t.rand(batch_size, 1) * 20
y = x * 2 + (1 + t.randn(batch_size, 1) * 3)
return x, y
*二、线性回归完整代码
自动计算梯度的代码在注释中:
# 如果括号内填一个1,则报错:mat2 must be a matrix
w = t.rand(1, 1)
b = t.zeros(1, 1)
# # 如果自动计算梯度
# # 注意requires_grad默认是False,不设置为True会在loss.backward()报错
# w = t.rand(1, 1, requires_grad=True)
# b = t.zeros(1, 1, requires_grad=True)
lr = 0.0001
losses = np.zeros(500)
for ii in range(500):
x, y = get_fake_data(batch_size=32)
# 前向传播,计算loss,采用均方误差
# torch.mul是逐元素相乘;torch.mm是矩阵相乘
y_pred = t.mm(x, w) + b.expand_as(y)
loss = 0.5 * (y_pred - y) ** 2
loss = loss.sum()
losses[ii] = loss.item()
# 反向传播,手动计算梯度
dloss = 1
dy_pred = dloss * (y_pred - y)
dw = t.mm(x.t(), dy_pred)
db = dy_pred.sum() # 注意b是标量,使用的时候扩展为元素全为b的向量
# 更新参数
w.sub_(lr * dw)
b.sub_(lr * db)
# # 如果自动计算梯度
# loss.backward()
# w.data.sub_(lr * w.grad.data)
# b.data.sub_(lr * b.grad.data)
# # 注意梯度清零
# w.grad.data.zero_()
# b.grad.data.zero_()
# 每1000次训练画一次图
if ii % 50 == 0:
display.clear_output(wait=True)
# predicted
x = t.arange(0, 20).view(-1, 1).float()
y = t.mm(x, w) + b.expand_as(x)
plt.plot(x.numpy(), y.numpy())
# true data
x2, y2 = get_fake_data(batch_size=20)
plt.scatter(x2.numpy(), y2.numpy())
plt.xlim(0, 5)
plt.ylim(0, 13)
plt.show()
plt.pause(0.5)
print(w.item(), b.item())
# print(w.data[0][0], b.data[0][0]) # 和上面等价
运行结果:

观察loss的变化:
plt.plot(losses)
plt.ylim(50, 500)
输出结果:

loss是在稳步变小。
三、手动计算梯度的公式
选自《深度学习》(花书):

记住上述公式 G B T GB^T GBT或者 A T G A^TG ATG即可,根据这个公式,根据这个公式,loss对w的梯度为 x T d y _ p r e d x^Tdy\_pred xTdy_pred.
四、关于输出为“nan nan”的情况
print(w.item(), b.item()),如果最后w和b的值输出都为nan,那么调小学习率就行了。我将学习率定为0.001都会遇到这个情况,定为0.0001就好了。
版权声明
本文为[umbrellalalalala]所创,转载请带上原文链接,感谢
https://blog.csdn.net/umbrellalalalala/article/details/119945805
边栏推荐
- 容器
- Getting started with JDBC \ getting a database connection \ using Preparedstatement
- The attendance client date of K / 3 wise system can only be selected to 2019
- Create enterprise mailbox account command
- JVM系列(4)——内存溢出(OOM)
- 创建二叉树
- 关于二叉树的遍历
- CONDA virtual environment management (create, delete, clone, rename, export and import)
- 线程的底部实现原理—静态代理模式
- interviewter:介绍一下MySQL日期函数
猜你喜欢

手动删除eureka上已经注册的服务
![去噪论文——[Noise2Void,CVPR19]Noise2Void-Learning Denoising from Single Noisy Images](/img/9d/487c77b5d25d3e37fb629164c804e2.png)
去噪论文——[Noise2Void,CVPR19]Noise2Void-Learning Denoising from Single Noisy Images

Pytorch学习记录(十二):学习率衰减+正则化

Opensips (1) -- detailed process of installing opensips
![去噪论文阅读——[CVPR2022]Blind2Unblind: Self-Supervised Image Denoising with Visible Blind Spots](/img/fd/84df62c88fe90a73294886642036a0.png)
去噪论文阅读——[CVPR2022]Blind2Unblind: Self-Supervised Image Denoising with Visible Blind Spots

JVM系列(4)——内存溢出(OOM)

Pytorch learning record (XII): learning rate attenuation + regularization

框架解析1.系统架构简介

The user name and password of users in the domain accessing the samba server outside the domain are wrong

Latex快速入门
随机推荐
数据处理之Numpy常用函数表格整理
Pyemd installation and simple use
JSP语法及JSTL标签
Multithreading and high concurrency (3) -- synchronized principle
Treatment of tensorflow sequelae - simple example record torch utils. data. dataset. Picture dimension problem when rewriting dataset
数字图像处理基础(冈萨雷斯)一
Pyqy5 learning (III): qlineedit + qtextedit
JVM系列(4)——内存溢出(OOM)
建表到页面完整实例演示—联表查询
io.lettuce.core.RedisCommandExecutionException: ERR wrong number of arguments for ‘auth‘ command
Understand the current commonly used encryption technology system (symmetric, asymmetric, information abstract, digital signature, digital certificate, public key system)
解决报错:ImportError: IProgress not found. Please update jupyter and ipywidgets
Pytorch学习记录(九):Pytorch中卷积神经网络
框架解析1.系统架构简介
创建二叉树
数字图像处理基础(冈萨雷斯)二:灰度变换与空间滤波
实操—Nacos安装与配置
2 - principes de conception de logiciels
Fundamentals of digital image processing (Gonzalez) I
filebrowser实现私有网盘