PyTorch-05. Implementing Linear Regression with PyTorch
2022-08-10 05:56:00 [Shengxin Research Ape]
import torch
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # input dimension 1, output dimension 1
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
model = LinearModel()
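# Quick shape check (illustrative line, not part of the original script):
# nn.Linear(1, 1) maps an (N, 1) batch to an (N, 1) batch via y = x * w + b.
print(model(x_data).shape)  # torch.Size([3, 1])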
criterion = torch.nn.MSELoss(reduction='mean')  # loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # optimizer
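# Note: MSELoss(size_average=True) is deprecated in current PyTorch;
# reduction='mean' is the equivalent setting, and it computes
# loss = (1/N) * sum((y_pred - y)**2) over the N samples.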
'''
Training procedure (repeated each epoch):
    compute y_pred (forward pass)
    compute the loss
    zero the gradients
    backward propagation
    update the parameters
'''
for epoch in range(100):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print('epoch=', epoch, " loss=", loss)
    optimizer.zero_grad()  # reset gradients to zero
    loss.backward()        # backward propagation
    optimizer.step()       # update parameters
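# Note: zero_grad() is required because loss.backward() accumulates gradients
# into each parameter's .grad buffer; without it, gradients from earlier
# epochs would be summed into the current update.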
# Output weight and bias
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())
# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
# Two output formats
print('y_pred=',y_test.data)
print('y_pred=',y_test.data.item())
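For inference, an equivalent idiom (a minimal sketch, not part of the original script) is to disable gradient tracking instead of reading .data:

with torch.no_grad():
    print('y_pred=', model(torch.Tensor([[4.0]])).item())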
Result:
epoch= 0 loss= tensor(34.0210, grad_fn=<MseLossBackward0>)
epoch= 1 loss= tensor(26.9701, grad_fn=<MseLossBackward0>)
epoch= 2 loss= tensor(21.3965, grad_fn=<MseLossBackward0>)
epoch= 3 loss= tensor(16.9904, grad_fn=<MseLossBackward0>)
epoch= 4 loss= tensor(13.5073, grad_fn=<MseLossBackward0>)
epoch= 5 loss= tensor(10.7537, grad_fn=<MseLossBackward0>)
epoch= 6 loss= tensor(8.5767, grad_fn=<MseLossBackward0>)
epoch= 7 loss= tensor(6.8556, grad_fn=<MseLossBackward0>)
epoch= 8 loss= tensor(5.4948, grad_fn=<MseLossBackward0>)
epoch= 9 loss= tensor(4.4188, grad_fn=<MseLossBackward0>)
epoch= 10 loss= tensor(3.5679, grad_fn=<MseLossBackward0>)
epoch= 11 loss= tensor(2.8949, grad_fn=<MseLossBackward0>)
epoch= 12 loss= tensor(2.3626, grad_fn=<MseLossBackward0>)
epoch= 13 loss= tensor(1.9415, grad_fn=<MseLossBackward0>)
epoch= 14 loss= tensor(1.6083, grad_fn=<MseLossBackward0>)
epoch= 15 loss= tensor(1.3446, grad_fn=<MseLossBackward0>)
epoch= 16 loss= tensor(1.1357, grad_fn=<MseLossBackward0>)
epoch= 17 loss= tensor(0.9703, grad_fn=<MseLossBackward0>)
epoch= 18 loss= tensor(0.8392, grad_fn=<MseLossBackward0>)
epoch= 19 loss= tensor(0.7352, grad_fn=<MseLossBackward0>)
epoch= 20 loss= tensor(0.6527, grad_fn=<MseLossBackward0>)
epoch= 21 loss= tensor(0.5871, grad_fn=<MseLossBackward0>)
epoch= 22 loss= tensor(0.5350, grad_fn=<MseLossBackward0>)
epoch= 23 loss= tensor(0.4934, grad_fn=<MseLossBackward0>)
epoch= 24 loss= tensor(0.4602, grad_fn=<MseLossBackward0>)
epoch= 25 loss= tensor(0.4336, grad_fn=<MseLossBackward0>)
epoch= 26 loss= tensor(0.4122, grad_fn=<MseLossBackward0>)
epoch= 27 loss= tensor(0.3950, grad_fn=<MseLossBackward0>)
epoch= 28 loss= tensor(0.3811, grad_fn=<MseLossBackward0>)
epoch= 29 loss= tensor(0.3697, grad_fn=<MseLossBackward0>)
epoch= 30 loss= tensor(0.3604, grad_fn=<MseLossBackward0>)
epoch= 31 loss= tensor(0.3527, grad_fn=<MseLossBackward0>)
epoch= 32 loss= tensor(0.3464, grad_fn=<MseLossBackward0>)
epoch= 33 loss= tensor(0.3410, grad_fn=<MseLossBackward0>)
epoch= 34 loss= tensor(0.3364, grad_fn=<MseLossBackward0>)
epoch= 35 loss= tensor(0.3325, grad_fn=<MseLossBackward0>)
epoch= 36 loss= tensor(0.3290, grad_fn=<MseLossBackward0>)
epoch= 37 loss= tensor(0.3260, grad_fn=<MseLossBackward0>)
epoch= 38 loss= tensor(0.3233, grad_fn=<MseLossBackward0>)
epoch= 39 loss= tensor(0.3208, grad_fn=<MseLossBackward0>)
epoch= 40 loss= tensor(0.3186, grad_fn=<MseLossBackward0>)
epoch= 41 loss= tensor(0.3165, grad_fn=<MseLossBackward0>)
epoch= 42 loss= tensor(0.3145, grad_fn=<MseLossBackward0>)
epoch= 43 loss= tensor(0.3126, grad_fn=<MseLossBackward0>)
epoch= 44 loss= tensor(0.3109, grad_fn=<MseLossBackward0>)
epoch= 45 loss= tensor(0.3092, grad_fn=<MseLossBackward0>)
epoch= 46 loss= tensor(0.3075, grad_fn=<MseLossBackward0>)
epoch= 47 loss= tensor(0.3059, grad_fn=<MseLossBackward0>)
epoch= 48 loss= tensor(0.3043, grad_fn=<MseLossBackward0>)
epoch= 49 loss= tensor(0.3028, grad_fn=<MseLossBackward0>)
epoch= 50 loss= tensor(0.3012, grad_fn=<MseLossBackward0>)
epoch= 51 loss= tensor(0.2997, grad_fn=<MseLossBackward0>)
epoch= 52 loss= tensor(0.2983, grad_fn=<MseLossBackward0>)
epoch= 53 loss= tensor(0.2968, grad_fn=<MseLossBackward0>)
epoch= 54 loss= tensor(0.2953, grad_fn=<MseLossBackward0>)
epoch= 55 loss= tensor(0.2939, grad_fn=<MseLossBackward0>)
epoch= 56 loss= tensor(0.2925, grad_fn=<MseLossBackward0>)
epoch= 57 loss= tensor(0.2910, grad_fn=<MseLossBackward0>)
epoch= 58 loss= tensor(0.2896, grad_fn=<MseLossBackward0>)
epoch= 59 loss= tensor(0.2882, grad_fn=<MseLossBackward0>)
epoch= 60 loss= tensor(0.2869, grad_fn=<MseLossBackward0>)
epoch= 61 loss= tensor(0.2855, grad_fn=<MseLossBackward0>)
epoch= 62 loss= tensor(0.2841, grad_fn=<MseLossBackward0>)
epoch= 63 loss= tensor(0.2827, grad_fn=<MseLossBackward0>)
epoch= 64 loss= tensor(0.2814, grad_fn=<MseLossBackward0>)
epoch= 65 loss= tensor(0.2800, grad_fn=<MseLossBackward0>)
epoch= 66 loss= tensor(0.2787, grad_fn=<MseLossBackward0>)
epoch= 67 loss= tensor(0.2773, grad_fn=<MseLossBackward0>)
epoch= 68 loss= tensor(0.2760, grad_fn=<MseLossBackward0>)
epoch= 69 loss= tensor(0.2747, grad_fn=<MseLossBackward0>)
epoch= 70 loss= tensor(0.2733, grad_fn=<MseLossBackward0>)
epoch= 71 loss= tensor(0.2720, grad_fn=<MseLossBackward0>)
epoch= 72 loss= tensor(0.2707, grad_fn=<MseLossBackward0>)
epoch= 73 loss= tensor(0.2694, grad_fn=<MseLossBackward0>)
epoch= 74 loss= tensor(0.2681, grad_fn=<MseLossBackward0>)
epoch= 75 loss= tensor(0.2668, grad_fn=<MseLossBackward0>)
epoch= 76 loss= tensor(0.2656, grad_fn=<MseLossBackward0>)
epoch= 77 loss= tensor(0.2643, grad_fn=<MseLossBackward0>)
epoch= 78 loss= tensor(0.2630, grad_fn=<MseLossBackward0>)
epoch= 79 loss= tensor(0.2618, grad_fn=<MseLossBackward0>)
epoch= 80 loss= tensor(0.2605, grad_fn=<MseLossBackward0>)
epoch= 81 loss= tensor(0.2592, grad_fn=<MseLossBackward0>)
epoch= 82 loss= tensor(0.2580, grad_fn=<MseLossBackward0>)
epoch= 83 loss= tensor(0.2568, grad_fn=<MseLossBackward0>)
epoch= 84 loss= tensor(0.2555, grad_fn=<MseLossBackward0>)
epoch= 85 loss= tensor(0.2543, grad_fn=<MseLossBackward0>)
epoch= 86 loss= tensor(0.2531, grad_fn=<MseLossBackward0>)
epoch= 87 loss= tensor(0.2519, grad_fn=<MseLossBackward0>)
epoch= 88 loss= tensor(0.2507, grad_fn=<MseLossBackward0>)
epoch= 89 loss= tensor(0.2495, grad_fn=<MseLossBackward0>)
epoch= 90 loss= tensor(0.2483, grad_fn=<MseLossBackward0>)
epoch= 91 loss= tensor(0.2471, grad_fn=<MseLossBackward0>)
epoch= 92 loss= tensor(0.2459, grad_fn=<MseLossBackward0>)
epoch= 93 loss= tensor(0.2447, grad_fn=<MseLossBackward0>)
epoch= 94 loss= tensor(0.2435, grad_fn=<MseLossBackward0>)
epoch= 95 loss= tensor(0.2424, grad_fn=<MseLossBackward0>)
epoch= 96 loss= tensor(0.2412, grad_fn=<MseLossBackward0>)
epoch= 97 loss= tensor(0.2400, grad_fn=<MseLossBackward0>)
epoch= 98 loss= tensor(0.2389, grad_fn=<MseLossBackward0>)
epoch= 99 loss= tensor(0.2377, grad_fn=<MseLossBackward0>)
w= 1.4350532293319702
b= 1.2842093706130981
y_pred= tensor([[7.0244]])
y_pred= 7.0244221687316895
Process finished with exit code 0
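Note that the loss is still decreasing at epoch 99, so w and b have not yet converged. Since the data satisfy y = 2x exactly, continued training drives w toward 2.0 and b toward 0.0. A minimal sketch of an extended run (the 1000-epoch count is an illustrative choice, not from the original post):

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print('w=', model.linear.weight.item())  # approaches 2.0 with continued training
print('b=', model.linear.bias.item())    # approaches 0.0 with continued training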