PyTorch notes: complete code for linear regression, with manual vs. automatic gradient computation compared
2022-04-23 05:59:00 【umbrellalalalala】
Reference: 《Deep learning framework PyTorch: Introduction and practice》
This post annotates and interprets the book's linear-regression code, and supplements the formula for computing gradients during manual backpropagation.
1. Complete code for generating the dataset

We use synthetic "fake data":
import torch as t

# Set the random seed so the output is consistent across machines
t.manual_seed(1000)

def get_fake_data(batch_size=8):
    '''Generate random data from y = x * 2 + 3, with some noise added'''
    x = t.rand(batch_size, 1) * 20
    y = x * 2 + (1 + t.randn(batch_size, 1)) * 3
    return x, y
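A quick sanity check (my addition, not from the book) to confirm the shapes and that y roughly follows 2x + 3:

x, y = get_fake_data(batch_size=4)
print(x.shape, y.shape)      # torch.Size([4, 1]) torch.Size([4, 1])
print(t.cat([x, y], dim=1))  # each row is one (x, y) sample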
2. Complete linear-regression code

The version that computes gradients automatically appears in the comments:
import numpy as np
import torch as t
import matplotlib.pyplot as plt
from IPython import display

# w must be 2-D: t.rand(1) would make t.mm below fail with "mat2 must be a matrix"
w = t.rand(1, 1)
b = t.zeros(1, 1)
# # If computing gradients automatically, initialize like this instead.
# # Note: requires_grad defaults to False; without True, loss.backward() raises an error
# w = t.rand(1, 1, requires_grad=True)
# b = t.zeros(1, 1, requires_grad=True)
lr = 0.0001
losses = np.zeros(500)

for ii in range(500):
    x, y = get_fake_data(batch_size=32)
    # Forward pass: compute the loss (half the summed squared error)
    # torch.mul is element-wise multiplication; torch.mm is matrix multiplication
    y_pred = t.mm(x, w) + b.expand_as(y)
    loss = 0.5 * (y_pred - y) ** 2
    loss = loss.sum()
    losses[ii] = loss.item()
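    # With loss = 0.5 * sum((y_pred - y)^2) and y_pred = x @ w + b:
    #   d(loss)/d(y_pred) = y_pred - y
    #   d(loss)/dw = x^T @ (y_pred - y)   (matrix-product rule; see Section 3)
    #   d(loss)/db = sum(y_pred - y)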
    # Backward pass: compute the gradients manually
    dloss = 1
    dy_pred = dloss * (y_pred - y)
    dw = t.mm(x.t(), dy_pred)
    db = dy_pred.sum()  # b is a scalar; it is broadcast to every element when used

    # Update the parameters
    w.sub_(lr * dw)
    b.sub_(lr * db)
    # # If computing gradients automatically, replace the manual block above with:
    # loss.backward()
    # w.data.sub_(lr * w.grad.data)
    # b.data.sub_(lr * b.grad.data)
    # # Note: the gradients must be zeroed, otherwise they accumulate across iterations
    # w.grad.data.zero_()
    # b.grad.data.zero_()
    # Plot once every 50 iterations
    if ii % 50 == 0:
        display.clear_output(wait=True)

        # predicted line
        x = t.arange(0, 20).view(-1, 1).float()
        y = t.mm(x, w) + b.expand_as(x)
        plt.plot(x.numpy(), y.numpy())

        # true data
        x2, y2 = get_fake_data(batch_size=20)
        plt.scatter(x2.numpy(), y2.numpy())

        plt.xlim(0, 5)
        plt.ylim(0, 13)
        plt.show()
        plt.pause(0.5)
print(w.item(), b.item())
# print(w.data[0][0], b.data[0][0])  # equivalent to the line above
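For reference, here is the autograd version assembled from the commented-out lines above into one runnable sketch (plotting omitted; everything else matches the manual loop):

import numpy as np
import torch as t

t.manual_seed(1000)

w = t.rand(1, 1, requires_grad=True)
b = t.zeros(1, 1, requires_grad=True)
lr = 0.0001
losses = np.zeros(500)

for ii in range(500):
    x, y = get_fake_data(batch_size=32)

    # Forward pass, identical to the manual version
    y_pred = t.mm(x, w) + b.expand_as(y)
    loss = (0.5 * (y_pred - y) ** 2).sum()
    losses[ii] = loss.item()

    # Backward pass: autograd fills in w.grad and b.grad
    loss.backward()

    # Update the parameters, then zero the gradients so they do not accumulate
    w.data.sub_(lr * w.grad.data)
    b.data.sub_(lr * b.grad.data)
    w.grad.data.zero_()
    b.grad.data.zero_()

print(w.item(), b.item())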
Running results:

(figure: the fitted line plotted over a scatter of sampled data points)
Observing how the loss changes:
plt.plot(losses)
plt.ylim(50, 500)
Output:

(figure: the loss curve over the 500 iterations)

The loss decreases steadily.
3. The formula for manually computing the gradient

From 《Deep learning》 (the "flower book"): for a matrix product $C = AB$, if $G$ is the gradient of the loss with respect to $C$, then the gradient with respect to $A$ is $GB^{\top}$ and the gradient with respect to $B$ is $A^{\top}G$.

It suffices to remember $GB^{\top}$ and $A^{\top}G$. By this rule, the gradient of the loss with respect to $w$ is $x^{\top}\,dy\_pred$.
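Applied to the code above (a short worked derivation using the loop's variable names):

$$
\begin{aligned}
\hat{y} &= xw + b, \qquad L = \tfrac{1}{2}\sum_i (\hat{y}_i - y_i)^2,\\
G &= \frac{\partial L}{\partial \hat{y}} = \hat{y} - y &&(\texttt{dy\_pred}),\\
\frac{\partial L}{\partial w} &= x^{\top} G &&(\texttt{dw = t.mm(x.t(), dy\_pred)}),\\
\frac{\partial L}{\partial b} &= \textstyle\sum_i G_i &&(\texttt{db = dy\_pred.sum()}).
\end{aligned}
$$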
4. When the output is "nan nan"

If the final print(w.item(), b.item()) prints nan for w and b, simply lower the learning rate. With the learning rate set to 0.001 I always got nan; with 0.0001 it works fine.
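A possible alternative (my suggestion, not from the original post): average the loss over the batch instead of summing it, so the gradient magnitude no longer scales with batch_size and a larger learning rate is typically tolerated:

# Mean instead of sum: the loss and its gradients no longer grow with batch_size
loss = (0.5 * (y_pred - y) ** 2).mean()

# The manual gradients then pick up the same 1/batch_size factor
dy_pred = (y_pred - y) / y.shape[0]
dw = t.mm(x.t(), dy_pred)
db = dy_pred.sum()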