当前位置:网站首页>用Pytorch从0到1实现逻辑回归
用Pytorch从0到1实现逻辑回归
2022-08-10 05:30:00 【公众号学一点会一点】
在机器学习/深度学习中,二分类也是经常遇到的任务,逻辑回归就是二分类中常用的模型。
本文简单回顾下逻辑回归,并且用Pytorch实现。
逻辑回归Logistic Regression
逻辑回归是机器学习中常用的一种二分类算法,常用于疾病预测等“非黑即白”的分类,简单说就是在使用逻辑回归的任务中,标签数据的Y值要么是0要么是1。
Sigmoid函数
逻辑回归,不管怎么着,还是一个回归,而我们是用它来进行分类的。回归一般的得到的是一个连续值,二分类需要的是0或者1(类别),那么怎么建立起连续值到类别的映射关系呢?这时候Sigmoid函数就发挥作用了。 Sigmoid函数的公式如下:
其示意图为:

可以看出其值域为0到1之间,以0为分界点:
当x小于0的时候,f(x)值小于0.5; 当x大于0的时候,f(x)值大于0.5;
正是由于上面的特性,sigmoid函数可以被用来进行二分类,比如我们以0.5为界,将大于0.5的值归为类别1,小于0.5的归为类别0。
逻辑回归
有了Sigmoid函数,才有了逻辑回归。 逻辑回归逻辑回归又叫做对数几率回归,也就是说我们是对概率进行建模,而不同于线性回归对Y直接进行建模。
线性回归:
而逻辑回归:
对上面的公式进行转换,可以得到下面的公式:
也就是逻辑回归常见的公式。这样子我们就完成了从输入数据的线性组合到概率的映射,然后根据一定的阈值将概率映射到类别,就完成了分类的过程。
Pytorch实现逻辑回归
本系列上篇文章讲了用Pytorch实现一个模型应用的过程主要包括:
数据 模型 损失函数 优化器 迭代训练
上代码。
导入包:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim
构造数据:
n = 1000
mean_value = 2.5 # mean value of the distribution
bias = 1.3
n_data = torch.ones(n, 2)
x0 = torch.normal(mean_value*n_data,1.0) + bias
y0 = torch.zeros(n)
x1 = torch.normal(-mean_value*n_data, 1.0) + bias
y1 = torch.ones(n)
train_x = torch.cat((x0, x1), 0)
train_y = torch.cat((y0, y1), 0)
模型:
class LogReg(nn.Module):
def __init__(self):
super(LogReg, self).__init__()
self.features1 = nn.Linear(2,5)
self.features2 = nn.Linear(5,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features1(x)
x = self.features2(x)
x = self.sigmoid(x)
return x
model = LogReg()
损失函数和优化器:
criterion = nn.BCELoss()
lr = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
迭代训练:
for epoch in range(1000):
y_pred = model(train_x)
loss = criterion(y_pred.squeeze(), train_y)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
# 计算准确率
mask = y_pred.ge(0.5).float().squeeze()
correct = (mask == train_y).sum()
acc = correct.item()/train_y.shape[0]
print('epoch: ', epoch, 'loss: ', loss.item(), 'acc: ', acc)
训练结果:

可以看到第10个epoch的时候准确率就差不多98%了。

参考
【1】很多机器学习可以用到的数据集:http://archive.ics.uci.edu/ml/index.php
【2】Latex怎么打分式:https://zhuanlan.zhihu.com/p/262715401
【3】二分类交叉熵损失:https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
【4】https://www.jianshu.com/p/63e255e3232f
本文由 mdnice 多平台发布
边栏推荐
- What are the common commands of mysql
- Flutter development: error The following assertion was thrown resolving an image codec: Solution for Unable to...
- Arduino框架下合宙ESP32C3 +1.8“tft 网络时钟
- 基于Qiskit——《量子计算编程实战》读书笔记(一)
- 【格式转换】将JPEG图片批量处理为jpg格式
- [Thesis Notes] Prototypical Contrast Adaptation for Domain Adaptive Semantic Segmentation
- 看了几十篇轻量化目标检测论文扫盲做的摘抄笔记
- 聊聊 API 管理-开源版 到 SaaS 版
- oracle rac 11g安装执行root.sh时报错
- 如何在报表控件FastReport.NET中连接XLSX 文件作为数据源?
猜你喜欢
随机推荐
WSTP初体验
How cursors work in Pulsar
Tkinter 入门之旅
Why are negative numbers in binary represented in two's complement form - binary addition and subtraction
基于Qiskit——《量子计算编程实战》读书笔记(六)
Get started with the OAuth protocol easily with a case
SQLSERVER 2008 parses data in Json format
FPGA engineer interview questions collection 41~50
Shield Alt hotkey in vscode
SEO搜索引擎优化
通过一个案例轻松入门OAuth协议
OpenGauss source code, is it maintained with VSCode in the window system?
基本比例尺标准分幅编号流程
论文精度 —— 2017 ACM《Globally and Locally Consistent Image Completion》
summer preschool assignments
MySql之json_extract函数处理json字段
How to simulate the background API call scene, very detailed!
FPGA engineer interview questions collection 1~10
清览题库--C语言程序设计第五版编程题解析(1)
【论文笔记1】小样本分类









