PyTorch notes: complete code for linear regression & a comparison of manual vs. automatic gradient computation
2022-04-23 05:59:00 【umbrellalalalala】
Reference: Deep Learning Framework PyTorch: Getting Started and Practice (《深度学习框架PyTorch:入门与实践》)
This post annotates and explains the book's linear-regression code, and supplements it with the formulas for computing the gradients in the manual backward pass.
Contents
1. Generating the dataset: complete code
2. Linear regression: complete code
3. The formula for manually computing the gradient
4. When the output is "nan nan"
1. Generating the dataset: complete code
We use "fake data":
import torch as t

# Set the random seed so that runs on different machines produce the same output
t.manual_seed(1000)

def get_fake_data(batch_size=8):
    '''Generate random data: y = x*2 + 3, plus some noise'''
    x = t.rand(batch_size, 1) * 20
    y = x * 2 + (1 + t.randn(batch_size, 1)) * 3
    return x, y
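As a quick sanity check (my own illustrative addition, not from the book), you can inspect the shapes and a few samples:

# Illustrative check of the generator above (assumes the code above has run)
x, y = get_fake_data(batch_size=4)
print(x.shape, y.shape)       # torch.Size([4, 1]) torch.Size([4, 1])
print(t.cat([x, y], dim=1))   # each row is an (x, y) pair scattered around y = 2x + 3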
2. Linear regression: complete code
The code for computing the gradient automatically is included in the comments:
import numpy as np
from matplotlib import pyplot as plt
from IPython import display

# w must be 2-D: writing t.rand(1) here makes t.mm fail with "mat2 must be a matrix"
w = t.rand(1, 1)
b = t.zeros(1, 1)

# # If computing the gradient automatically:
# # note that requires_grad defaults to False; if it is not set to True,
# # loss.backward() will raise an error
# w = t.rand(1, 1, requires_grad=True)
# b = t.zeros(1, 1, requires_grad=True)

lr = 0.0001
losses = np.zeros(500)

for ii in range(500):
    x, y = get_fake_data(batch_size=32)

    # Forward pass: compute the loss (mean squared error)
    # torch.mul is element-wise multiplication; torch.mm is matrix multiplication
    y_pred = t.mm(x, w) + b.expand_as(y)
    loss = 0.5 * (y_pred - y) ** 2
    loss = loss.sum()
    losses[ii] = loss.item()

    # Backward pass: compute the gradients manually
    dloss = 1
    dy_pred = dloss * (y_pred - y)
    dw = t.mm(x.t(), dy_pred)
    db = dy_pred.sum()  # note that b is a scalar; it is broadcast to a vector when used

    # Update the parameters
    w.sub_(lr * dw)
    b.sub_(lr * db)

    # # If computing the gradient automatically:
    # loss.backward()
    # w.data.sub_(lr * w.grad.data)
    # b.data.sub_(lr * b.grad.data)
    # # remember to zero the gradients, which otherwise accumulate
    # w.grad.data.zero_()
    # b.grad.data.zero_()

    # Draw a plot every 50 iterations
    if ii % 50 == 0:
        display.clear_output(wait=True)

        # predicted line
        x = t.arange(0, 20).view(-1, 1).float()
        y = t.mm(x, w) + b.expand_as(x)
        plt.plot(x.numpy(), y.numpy())

        # ground-truth data
        x2, y2 = get_fake_data(batch_size=20)
        plt.scatter(x2.numpy(), y2.numpy())

        plt.xlim(0, 5)
        plt.ylim(0, 13)
        plt.show()
        plt.pause(0.5)

print(w.item(), b.item())
# print(w.data[0][0], b.data[0][0])  # equivalent to the line above
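For comparison, here is a minimal sketch of the same loop driven entirely by autograd (my own restatement of the commented-out lines above; it uses torch.no_grad() to guard the update instead of touching .data, which is the more current idiom):

# Minimal autograd sketch (assumes the imports and get_fake_data defined above)
w = t.rand(1, 1, requires_grad=True)
b = t.zeros(1, 1, requires_grad=True)
lr = 0.0001

for ii in range(500):
    x, y = get_fake_data(batch_size=32)
    y_pred = t.mm(x, w) + b.expand_as(y)
    loss = (0.5 * (y_pred - y) ** 2).sum()
    loss.backward()              # autograd fills in w.grad and b.grad
    with t.no_grad():            # parameter updates must not be tracked
        w -= lr * w.grad
        b -= lr * b.grad
        w.grad.zero_()           # gradients accumulate, so clear them each step
        b.grad.zero_()

print(w.item(), b.item())        # should approach 2 and 3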
Running results:

[Figure: the predicted line plotted over scattered ground-truth points, redrawn every 50 iterations]
Observing how the loss changes:
plt.plot(losses)
plt.ylim(50, 500)
Output:

[Figure: the loss curve over the 500 training iterations]
The loss decreases steadily.
3. The formula for manually computing the gradient
Excerpted from Deep Learning (the "flower book"):

[Figure: the book's formulas for the gradient of a matrix product]
It suffices to remember $G B^T$ and $A^T G$: for a matrix product $C = AB$ with upstream gradient $G = \partial\,\mathrm{loss}/\partial C$, the gradient with respect to $A$ is $G B^T$ and the gradient with respect to $B$ is $A^T G$. Since $y\_pred = xw$, the gradient of the loss with respect to $w$ is $x^T\,dy\_pred$.
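For completeness, the full chain-rule derivation behind the manual backward pass above is (my own restatement, consistent with the code):

$$
\begin{aligned}
\mathrm{loss} &= \tfrac{1}{2}\textstyle\sum_i (\hat{y}_i - y_i)^2, \qquad \hat{y} = xw + b \\
dy\_pred &= \partial\,\mathrm{loss}/\partial \hat{y} = \hat{y} - y \\
dw &= \partial\,\mathrm{loss}/\partial w = x^T\,dy\_pred \\
db &= \partial\,\mathrm{loss}/\partial b = \textstyle\sum_i (dy\_pred)_i
\end{aligned}
$$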
4. When the output is "nan nan"
If the values printed by print(w.item(), b.item()) at the end are nan, simply lower the learning rate. With the learning rate set to 0.001 this happened every time; setting it to 0.0001 works fine.
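If you want the loop to fail fast instead of silently printing nan, one option (an illustrative addition of mine, not from the book) is to check the loss inside the training loop:

# Inside the training loop, right after computing the loss (illustrative sketch)
if not t.isfinite(loss):
    print('loss diverged at iteration', ii, '- try a smaller learning rate')
    break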
Copyright notice
This article was written by [umbrellalalalala]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/04/202204230543474468.html