当前位置:网站首页>用Pytorch从0到1实现逻辑回归
用Pytorch从0到1实现逻辑回归
2022-08-10 05:30:00 【公众号学一点会一点】
在机器学习/深度学习中,二分类也是经常遇到的任务,逻辑回归就是二分类中常用的模型。
本文简单回顾下逻辑回归,并且用Pytorch实现。
逻辑回归Logistic Regression
逻辑回归是机器学习中常用的一种二分类算法,常用于疾病预测等“非黑即白”的分类,简单说就是在使用逻辑回归的任务中,标签数据的Y值要么是0要么是1。
Sigmoid函数
逻辑回归,不管怎么着,还是一个回归,而我们是用它来进行分类的。回归一般的得到的是一个连续值,二分类需要的是0或者1(类别),那么怎么建立起连续值到类别的映射关系呢?这时候Sigmoid函数就发挥作用了。 Sigmoid函数的公式如下:
其示意图为:

可以看出其值域为0到1之间,以0为分界点:
当x小于0的时候,f(x)值小于0.5; 当x大于0的时候,f(x)值大于0.5;
正是由于上面的特性,sigmoid函数可以被用来进行二分类,比如我们以0.5为界,将大于0.5的值归为类别1,小于0.5的归为类别0。
逻辑回归
有了Sigmoid函数,才有了逻辑回归。 逻辑回归逻辑回归又叫做对数几率回归,也就是说我们是对概率进行建模,而不同于线性回归对Y直接进行建模。
线性回归:
而逻辑回归:
对上面的公式进行转换,可以得到下面的公式:
也就是逻辑回归常见的公式。这样子我们就完成了从输入数据的线性组合到概率的映射,然后根据一定的阈值将概率映射到类别,就完成了分类的过程。
Pytorch实现逻辑回归
本系列上篇文章讲了用Pytorch实现一个模型应用的过程主要包括:
数据 模型 损失函数 优化器 迭代训练
上代码。
导入包:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim
构造数据:
n = 1000
mean_value = 2.5 # mean value of the distribution
bias = 1.3
n_data = torch.ones(n, 2)
x0 = torch.normal(mean_value*n_data,1.0) + bias
y0 = torch.zeros(n)
x1 = torch.normal(-mean_value*n_data, 1.0) + bias
y1 = torch.ones(n)
train_x = torch.cat((x0, x1), 0)
train_y = torch.cat((y0, y1), 0)
模型:
class LogReg(nn.Module):
def __init__(self):
super(LogReg, self).__init__()
self.features1 = nn.Linear(2,5)
self.features2 = nn.Linear(5,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features1(x)
x = self.features2(x)
x = self.sigmoid(x)
return x
model = LogReg()
损失函数和优化器:
criterion = nn.BCELoss()
lr = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
迭代训练:
for epoch in range(1000):
y_pred = model(train_x)
loss = criterion(y_pred.squeeze(), train_y)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
# 计算准确率
mask = y_pred.ge(0.5).float().squeeze()
correct = (mask == train_y).sum()
acc = correct.item()/train_y.shape[0]
print('epoch: ', epoch, 'loss: ', loss.item(), 'acc: ', acc)
训练结果:

可以看到第10个epoch的时候准确率就差不多98%了。

参考
【1】很多机器学习可以用到的数据集:http://archive.ics.uci.edu/ml/index.php
【2】Latex怎么打分式:https://zhuanlan.zhihu.com/p/262715401
【3】二分类交叉熵损失:https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
【4】https://www.jianshu.com/p/63e255e3232f
本文由 mdnice 多平台发布
边栏推荐
- FPGA工程师面试试题集锦31~40
- FPGA engineer interview questions collection 41~50
- Jenkins 如何玩转接口自动化测试?
- Qiskit学习笔记(三)
- oracle rac 11g安装执行root.sh时报错
- pytorch框架学习(4)torchvision模块&训练一个简单的自己的CNN (一)
- Arduino框架下合宙ESP32C3 +1.8“tft 网络时钟
- Practical skills 19: Several postures of List to Map List
- awk of the Three Musketeers of Shell Programming
- 我用这一招让团队的开发效率提升了 100%!
猜你喜欢

awk of the Three Musketeers of Shell Programming

pytorch框架学习(3)torch.nn.functional模块和nn.Module模块

How to improve product quality from the code layer

Practical skills 19: Several postures of List to Map List

【静态代理】

Kubernetes:(十六)Ingress的概念和原理

自适应空间特征融合( adaptively spatial feature fusion)一种基于数据驱动的金字塔特征融合策略

【LeetCode】41. The first missing positive number

基于Servlet的验证码登陆demo

Zhongang Mining: Strong downstream demand for fluorite
随机推荐
SSM框架整合实例
summer preschool assignments
Linear Algebra (4)
CORS跨域资源共享漏洞的原理与挖掘方法
pytorch框架学习(1)网络的简单构建
Ask you guys.The FlinkCDC2.2.0 version in the CDC community has a description of the supported sqlserver version, please
YOLOv5 PyQt5(一起制作YOLOv5的GUI界面)
MySql之json_extract函数处理json字段
I have a dream for Career .
CSDN Markdown 之我见代码块 | CSDN编辑器测评
How to get the last day of a month
Pony语言学习(八):引用能力(Reference Capabilities)
基于Qiskit——《量子计算编程实战》读书笔记(三)
canvas canvas drawing clock
Transforming into a product, is it reliable to take the NPDP test?
Pony语言学习(六):Struct, Type Alias, Type Expressions
The sword refers to Offer 033. Variation array
How to improve product quality from the code layer
基于Qiskit——《量子计算编程实战》读书笔记(一)
Conda creates a virtual environment method and pqi uses a domestic mirror source to install a third-party library method tutorial