当前位置:网站首页>用Pytorch从0到1实现逻辑回归
用Pytorch从0到1实现逻辑回归
2022-08-10 05:30:00 【公众号学一点会一点】
在机器学习/深度学习中,二分类也是经常遇到的任务,逻辑回归就是二分类中常用的模型。
本文简单回顾下逻辑回归,并且用Pytorch实现。
逻辑回归Logistic Regression
逻辑回归是机器学习中常用的一种二分类算法,常用于疾病预测等“非黑即白”的分类,简单说就是在使用逻辑回归的任务中,标签数据的Y值要么是0要么是1。
Sigmoid函数
逻辑回归,不管怎么着,还是一个回归,而我们是用它来进行分类的。回归一般的得到的是一个连续值,二分类需要的是0或者1(类别),那么怎么建立起连续值到类别的映射关系呢?这时候Sigmoid函数就发挥作用了。 Sigmoid函数的公式如下:
其示意图为:

可以看出其值域为0到1之间,以0为分界点:
当x小于0的时候,f(x)值小于0.5; 当x大于0的时候,f(x)值大于0.5;
正是由于上面的特性,sigmoid函数可以被用来进行二分类,比如我们以0.5为界,将大于0.5的值归为类别1,小于0.5的归为类别0。
逻辑回归
有了Sigmoid函数,才有了逻辑回归。 逻辑回归逻辑回归又叫做对数几率回归,也就是说我们是对概率进行建模,而不同于线性回归对Y直接进行建模。
线性回归:
而逻辑回归:
对上面的公式进行转换,可以得到下面的公式:
也就是逻辑回归常见的公式。这样子我们就完成了从输入数据的线性组合到概率的映射,然后根据一定的阈值将概率映射到类别,就完成了分类的过程。
Pytorch实现逻辑回归
本系列上篇文章讲了用Pytorch实现一个模型应用的过程主要包括:
数据 模型 损失函数 优化器 迭代训练
上代码。
导入包:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim
构造数据:
n = 1000
mean_value = 2.5 # mean value of the distribution
bias = 1.3
n_data = torch.ones(n, 2)
x0 = torch.normal(mean_value*n_data,1.0) + bias
y0 = torch.zeros(n)
x1 = torch.normal(-mean_value*n_data, 1.0) + bias
y1 = torch.ones(n)
train_x = torch.cat((x0, x1), 0)
train_y = torch.cat((y0, y1), 0)
模型:
class LogReg(nn.Module):
def __init__(self):
super(LogReg, self).__init__()
self.features1 = nn.Linear(2,5)
self.features2 = nn.Linear(5,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features1(x)
x = self.features2(x)
x = self.sigmoid(x)
return x
model = LogReg()
损失函数和优化器:
criterion = nn.BCELoss()
lr = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
迭代训练:
for epoch in range(1000):
y_pred = model(train_x)
loss = criterion(y_pred.squeeze(), train_y)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
# 计算准确率
mask = y_pred.ge(0.5).float().squeeze()
correct = (mask == train_y).sum()
acc = correct.item()/train_y.shape[0]
print('epoch: ', epoch, 'loss: ', loss.item(), 'acc: ', acc)
训练结果:

可以看到第10个epoch的时候准确率就差不多98%了。

参考
【1】很多机器学习可以用到的数据集:http://archive.ics.uci.edu/ml/index.php
【2】Latex怎么打分式:https://zhuanlan.zhihu.com/p/262715401
【3】二分类交叉熵损失:https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
【4】https://www.jianshu.com/p/63e255e3232f
本文由 mdnice 多平台发布
边栏推荐
猜你喜欢
随机推荐
pytorch框架学习(3)torch.nn.functional模块和nn.Module模块
Pony语言学习(一):环境配置(续)
kaggle小白必看:小白常见的2个错误解决方案
ThreadPoolExecutor thread pool principle
scikit-learn机器学习 读书笔记(二)
基于Servlet的验证码登陆demo
Big guys, mysql cdc (2.2.1 and previous versions) sometimes has this situation since savepoint, is there anything wrong?
How to improve product quality from the code layer
实战小技巧19:List转Map List的几种姿势
OAuth2的使用场景、常见误区、使用案例
Qiskit官方文档选译之量子傅里叶变换(Quantum Fourier Transform, QFT)
基于Qiskit——《量子计算编程实战》读书笔记(六)
pytorch框架学习(9)torchvision.transform
细数国产接口协作平台的六把武器!
基本比例尺标准分幅编号流程
Depth of carding: prevent model fitting method
FPGA engineer interview questions collection 41~50
Buu Web
Why are negative numbers in binary represented in two's complement form - binary addition and subtraction
summer preschool assignments