pytorch-05: Implementing Linear Regression with PyTorch
Posted 2022-08-10 05:32:00 by 生信研究猿
import torch

# Training data: three (x, y) pairs that follow y = 2x
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # input dimension 1, output dimension 1

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

model = LinearModel()

criterion = torch.nn.MSELoss(reduction='mean')  # loss function; 'mean' is equivalent to the deprecated size_average=True
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # optimizer
'''
Training procedure (each epoch):
1. forward pass to compute y_hat
2. compute the loss
3. zero the gradients
4. backward pass
5. update the parameters
'''
for epoch in range(100):
    y_pred = model(x_data)                  # forward pass
    loss = criterion(y_pred, y_data)        # compute loss
    print('epoch=', epoch, ' loss=', loss)
    optimizer.zero_grad()                   # zero the gradients
    loss.backward()                         # backpropagation
    optimizer.step()                        # update the parameters
# Output weight and bias
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())
# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
# Two output formats
print('y_pred=',y_test.data)
print('y_pred=',y_test.data.item())
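
As a side note, recent PyTorch versions discourage reading predictions through the .data attribute. A minimal sketch of the same test step using torch.no_grad() (reusing the model defined above; this is an optional variant, not part of the original script) could look like:

model.eval()                                  # not strictly needed for a single Linear layer
with torch.no_grad():                         # disable gradient tracking during inference
    y_test = model(torch.Tensor([[4.0]]))
print('y_pred=', y_test.item())               # same scalar value as y_test.data.item() above
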
Output:
epoch= 0 loss= tensor(34.0210, grad_fn=<MseLossBackward0>)
epoch= 1 loss= tensor(26.9701, grad_fn=<MseLossBackward0>)
epoch= 2 loss= tensor(21.3965, grad_fn=<MseLossBackward0>)
epoch= 3 loss= tensor(16.9904, grad_fn=<MseLossBackward0>)
epoch= 4 loss= tensor(13.5073, grad_fn=<MseLossBackward0>)
epoch= 5 loss= tensor(10.7537, grad_fn=<MseLossBackward0>)
epoch= 6 loss= tensor(8.5767, grad_fn=<MseLossBackward0>)
epoch= 7 loss= tensor(6.8556, grad_fn=<MseLossBackward0>)
epoch= 8 loss= tensor(5.4948, grad_fn=<MseLossBackward0>)
epoch= 9 loss= tensor(4.4188, grad_fn=<MseLossBackward0>)
epoch= 10 loss= tensor(3.5679, grad_fn=<MseLossBackward0>)
epoch= 11 loss= tensor(2.8949, grad_fn=<MseLossBackward0>)
epoch= 12 loss= tensor(2.3626, grad_fn=<MseLossBackward0>)
epoch= 13 loss= tensor(1.9415, grad_fn=<MseLossBackward0>)
epoch= 14 loss= tensor(1.6083, grad_fn=<MseLossBackward0>)
epoch= 15 loss= tensor(1.3446, grad_fn=<MseLossBackward0>)
epoch= 16 loss= tensor(1.1357, grad_fn=<MseLossBackward0>)
epoch= 17 loss= tensor(0.9703, grad_fn=<MseLossBackward0>)
epoch= 18 loss= tensor(0.8392, grad_fn=<MseLossBackward0>)
epoch= 19 loss= tensor(0.7352, grad_fn=<MseLossBackward0>)
epoch= 20 loss= tensor(0.6527, grad_fn=<MseLossBackward0>)
epoch= 21 loss= tensor(0.5871, grad_fn=<MseLossBackward0>)
epoch= 22 loss= tensor(0.5350, grad_fn=<MseLossBackward0>)
epoch= 23 loss= tensor(0.4934, grad_fn=<MseLossBackward0>)
epoch= 24 loss= tensor(0.4602, grad_fn=<MseLossBackward0>)
epoch= 25 loss= tensor(0.4336, grad_fn=<MseLossBackward0>)
epoch= 26 loss= tensor(0.4122, grad_fn=<MseLossBackward0>)
epoch= 27 loss= tensor(0.3950, grad_fn=<MseLossBackward0>)
epoch= 28 loss= tensor(0.3811, grad_fn=<MseLossBackward0>)
epoch= 29 loss= tensor(0.3697, grad_fn=<MseLossBackward0>)
epoch= 30 loss= tensor(0.3604, grad_fn=<MseLossBackward0>)
epoch= 31 loss= tensor(0.3527, grad_fn=<MseLossBackward0>)
epoch= 32 loss= tensor(0.3464, grad_fn=<MseLossBackward0>)
epoch= 33 loss= tensor(0.3410, grad_fn=<MseLossBackward0>)
epoch= 34 loss= tensor(0.3364, grad_fn=<MseLossBackward0>)
epoch= 35 loss= tensor(0.3325, grad_fn=<MseLossBackward0>)
epoch= 36 loss= tensor(0.3290, grad_fn=<MseLossBackward0>)
epoch= 37 loss= tensor(0.3260, grad_fn=<MseLossBackward0>)
epoch= 38 loss= tensor(0.3233, grad_fn=<MseLossBackward0>)
epoch= 39 loss= tensor(0.3208, grad_fn=<MseLossBackward0>)
epoch= 40 loss= tensor(0.3186, grad_fn=<MseLossBackward0>)
epoch= 41 loss= tensor(0.3165, grad_fn=<MseLossBackward0>)
epoch= 42 loss= tensor(0.3145, grad_fn=<MseLossBackward0>)
epoch= 43 loss= tensor(0.3126, grad_fn=<MseLossBackward0>)
epoch= 44 loss= tensor(0.3109, grad_fn=<MseLossBackward0>)
epoch= 45 loss= tensor(0.3092, grad_fn=<MseLossBackward0>)
epoch= 46 loss= tensor(0.3075, grad_fn=<MseLossBackward0>)
epoch= 47 loss= tensor(0.3059, grad_fn=<MseLossBackward0>)
epoch= 48 loss= tensor(0.3043, grad_fn=<MseLossBackward0>)
epoch= 49 loss= tensor(0.3028, grad_fn=<MseLossBackward0>)
epoch= 50 loss= tensor(0.3012, grad_fn=<MseLossBackward0>)
epoch= 51 loss= tensor(0.2997, grad_fn=<MseLossBackward0>)
epoch= 52 loss= tensor(0.2983, grad_fn=<MseLossBackward0>)
epoch= 53 loss= tensor(0.2968, grad_fn=<MseLossBackward0>)
epoch= 54 loss= tensor(0.2953, grad_fn=<MseLossBackward0>)
epoch= 55 loss= tensor(0.2939, grad_fn=<MseLossBackward0>)
epoch= 56 loss= tensor(0.2925, grad_fn=<MseLossBackward0>)
epoch= 57 loss= tensor(0.2910, grad_fn=<MseLossBackward0>)
epoch= 58 loss= tensor(0.2896, grad_fn=<MseLossBackward0>)
epoch= 59 loss= tensor(0.2882, grad_fn=<MseLossBackward0>)
epoch= 60 loss= tensor(0.2869, grad_fn=<MseLossBackward0>)
epoch= 61 loss= tensor(0.2855, grad_fn=<MseLossBackward0>)
epoch= 62 loss= tensor(0.2841, grad_fn=<MseLossBackward0>)
epoch= 63 loss= tensor(0.2827, grad_fn=<MseLossBackward0>)
epoch= 64 loss= tensor(0.2814, grad_fn=<MseLossBackward0>)
epoch= 65 loss= tensor(0.2800, grad_fn=<MseLossBackward0>)
epoch= 66 loss= tensor(0.2787, grad_fn=<MseLossBackward0>)
epoch= 67 loss= tensor(0.2773, grad_fn=<MseLossBackward0>)
epoch= 68 loss= tensor(0.2760, grad_fn=<MseLossBackward0>)
epoch= 69 loss= tensor(0.2747, grad_fn=<MseLossBackward0>)
epoch= 70 loss= tensor(0.2733, grad_fn=<MseLossBackward0>)
epoch= 71 loss= tensor(0.2720, grad_fn=<MseLossBackward0>)
epoch= 72 loss= tensor(0.2707, grad_fn=<MseLossBackward0>)
epoch= 73 loss= tensor(0.2694, grad_fn=<MseLossBackward0>)
epoch= 74 loss= tensor(0.2681, grad_fn=<MseLossBackward0>)
epoch= 75 loss= tensor(0.2668, grad_fn=<MseLossBackward0>)
epoch= 76 loss= tensor(0.2656, grad_fn=<MseLossBackward0>)
epoch= 77 loss= tensor(0.2643, grad_fn=<MseLossBackward0>)
epoch= 78 loss= tensor(0.2630, grad_fn=<MseLossBackward0>)
epoch= 79 loss= tensor(0.2618, grad_fn=<MseLossBackward0>)
epoch= 80 loss= tensor(0.2605, grad_fn=<MseLossBackward0>)
epoch= 81 loss= tensor(0.2592, grad_fn=<MseLossBackward0>)
epoch= 82 loss= tensor(0.2580, grad_fn=<MseLossBackward0>)
epoch= 83 loss= tensor(0.2568, grad_fn=<MseLossBackward0>)
epoch= 84 loss= tensor(0.2555, grad_fn=<MseLossBackward0>)
epoch= 85 loss= tensor(0.2543, grad_fn=<MseLossBackward0>)
epoch= 86 loss= tensor(0.2531, grad_fn=<MseLossBackward0>)
epoch= 87 loss= tensor(0.2519, grad_fn=<MseLossBackward0>)
epoch= 88 loss= tensor(0.2507, grad_fn=<MseLossBackward0>)
epoch= 89 loss= tensor(0.2495, grad_fn=<MseLossBackward0>)
epoch= 90 loss= tensor(0.2483, grad_fn=<MseLossBackward0>)
epoch= 91 loss= tensor(0.2471, grad_fn=<MseLossBackward0>)
epoch= 92 loss= tensor(0.2459, grad_fn=<MseLossBackward0>)
epoch= 93 loss= tensor(0.2447, grad_fn=<MseLossBackward0>)
epoch= 94 loss= tensor(0.2435, grad_fn=<MseLossBackward0>)
epoch= 95 loss= tensor(0.2424, grad_fn=<MseLossBackward0>)
epoch= 96 loss= tensor(0.2412, grad_fn=<MseLossBackward0>)
epoch= 97 loss= tensor(0.2400, grad_fn=<MseLossBackward0>)
epoch= 98 loss= tensor(0.2389, grad_fn=<MseLossBackward0>)
epoch= 99 loss= tensor(0.2377, grad_fn=<MseLossBackward0>)
w= 1.4350532293319702
b= 1.2842093706130981
y_pred= tensor([[7.0244]])
y_pred= 7.0244221687316895
Process finished with exit code 0
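
A quick sanity check on the prediction: with the learned parameters, y_pred = w * x + b = 1.4351 * 4.0 + 1.2842 ≈ 7.0244, which matches the printed output. Because the training data satisfy y = 2x exactly, the loss minimum lies at w = 2, b = 0; after only 100 epochs the loss is still around 0.24, so training longer (say, a few thousand epochs) would be expected to push w toward 2 and b toward 0.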