当前位置:网站首页>用Pytorch从0到1实现逻辑回归
用Pytorch从0到1实现逻辑回归
2022-08-10 05:30:00 【公众号学一点会一点】
在机器学习/深度学习中,二分类也是经常遇到的任务,逻辑回归就是二分类中常用的模型。
本文简单回顾下逻辑回归,并且用Pytorch实现。
逻辑回归Logistic Regression
逻辑回归是机器学习中常用的一种二分类算法,常用于疾病预测等“非黑即白”的分类,简单说就是在使用逻辑回归的任务中,标签数据的Y值要么是0要么是1。
Sigmoid函数
逻辑回归,不管怎么着,还是一个回归,而我们是用它来进行分类的。回归一般的得到的是一个连续值,二分类需要的是0或者1(类别),那么怎么建立起连续值到类别的映射关系呢?这时候Sigmoid函数就发挥作用了。 Sigmoid函数的公式如下:
其示意图为:

可以看出其值域为0到1之间,以0为分界点:
当x小于0的时候,f(x)值小于0.5; 当x大于0的时候,f(x)值大于0.5;
正是由于上面的特性,sigmoid函数可以被用来进行二分类,比如我们以0.5为界,将大于0.5的值归为类别1,小于0.5的归为类别0。
逻辑回归
有了Sigmoid函数,才有了逻辑回归。 逻辑回归逻辑回归又叫做对数几率回归,也就是说我们是对概率进行建模,而不同于线性回归对Y直接进行建模。
线性回归:
而逻辑回归:
对上面的公式进行转换,可以得到下面的公式:
也就是逻辑回归常见的公式。这样子我们就完成了从输入数据的线性组合到概率的映射,然后根据一定的阈值将概率映射到类别,就完成了分类的过程。
Pytorch实现逻辑回归
本系列上篇文章讲了用Pytorch实现一个模型应用的过程主要包括:
数据 模型 损失函数 优化器 迭代训练
上代码。
导入包:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim
构造数据:
n = 1000
mean_value = 2.5 # mean value of the distribution
bias = 1.3
n_data = torch.ones(n, 2)
x0 = torch.normal(mean_value*n_data,1.0) + bias
y0 = torch.zeros(n)
x1 = torch.normal(-mean_value*n_data, 1.0) + bias
y1 = torch.ones(n)
train_x = torch.cat((x0, x1), 0)
train_y = torch.cat((y0, y1), 0)
模型:
class LogReg(nn.Module):
def __init__(self):
super(LogReg, self).__init__()
self.features1 = nn.Linear(2,5)
self.features2 = nn.Linear(5,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features1(x)
x = self.features2(x)
x = self.sigmoid(x)
return x
model = LogReg()
损失函数和优化器:
criterion = nn.BCELoss()
lr = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
迭代训练:
for epoch in range(1000):
y_pred = model(train_x)
loss = criterion(y_pred.squeeze(), train_y)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
# 计算准确率
mask = y_pred.ge(0.5).float().squeeze()
correct = (mask == train_y).sum()
acc = correct.item()/train_y.shape[0]
print('epoch: ', epoch, 'loss: ', loss.item(), 'acc: ', acc)
训练结果:

可以看到第10个epoch的时候准确率就差不多98%了。

参考
【1】很多机器学习可以用到的数据集:http://archive.ics.uci.edu/ml/index.php
【2】Latex怎么打分式:https://zhuanlan.zhihu.com/p/262715401
【3】二分类交叉熵损失:https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
【4】https://www.jianshu.com/p/63e255e3232f
本文由 mdnice 多平台发布
边栏推荐
猜你喜欢
在yolov5的网络结构中添加注意力机制模块

深度梳理:防止模型过拟合的方法汇总

论文精度 —— 2016 CVPR 《Context Encoders: Feature Learning by Inpainting》

Kubernetes:(十七)Helm概述、安装及配置

strongest brain (1)

自适应空间特征融合( adaptively spatial feature fusion)一种基于数据驱动的金字塔特征融合策略

Talk about API Management - Open Source Edition to SaaS Edition

Arduino框架下合宙ESP32C3 +1.8“tft 网络时钟

k-近邻实现手写数字识别

GtkD开发之路
随机推荐
Joomla vulnerability reproduced
An article will help you understand what is idempotency?How to solve the idempotency problem?
An article to master the entire JVM, JVM ultra-detailed analysis!!!
You can‘t specify target table ‘kms_report_reportinfo‘ for update in FROM clause
【Static proxy】
Arduino框架下合宙ESP32C3 +1.8“tft 网络时钟
scikit-learn机器学习 读书笔记(一)
接口文档进化图鉴,有些古早接口文档工具,你可能都没用过
openGauss源码,在window系统用VSCode维护吗?
Transforming into a product, is it reliable to take the NPDP test?
基本比例尺标准分幅编号流程
FPGA工程师面试试题集锦21~30
Qiskit 学习笔记1
WSTP初体验
Qiskit学习笔记(三)
【格式转换】将JPEG图片批量处理为jpg格式
基于Qiskit——《量子计算编程实战》读书笔记(三)
pytorch框架学习(1)网络的简单构建
How does Jenkins play with interface automation testing?
树莓派入门(3)树莓派GPIO学习