当前位置:网站首页>用Pytorch从0到1实现逻辑回归
用Pytorch从0到1实现逻辑回归
2022-08-10 05:30:00 【公众号学一点会一点】
在机器学习/深度学习中,二分类也是经常遇到的任务,逻辑回归就是二分类中常用的模型。
本文简单回顾下逻辑回归,并且用Pytorch实现。
逻辑回归Logistic Regression
逻辑回归是机器学习中常用的一种二分类算法,常用于疾病预测等“非黑即白”的分类,简单说就是在使用逻辑回归的任务中,标签数据的Y值要么是0要么是1。
Sigmoid函数
逻辑回归,不管怎么着,还是一个回归,而我们是用它来进行分类的。回归一般的得到的是一个连续值,二分类需要的是0或者1(类别),那么怎么建立起连续值到类别的映射关系呢?这时候Sigmoid函数就发挥作用了。 Sigmoid函数的公式如下:
其示意图为:

可以看出其值域为0到1之间,以0为分界点:
当x小于0的时候,f(x)值小于0.5; 当x大于0的时候,f(x)值大于0.5;
正是由于上面的特性,sigmoid函数可以被用来进行二分类,比如我们以0.5为界,将大于0.5的值归为类别1,小于0.5的归为类别0。
逻辑回归
有了Sigmoid函数,才有了逻辑回归。 逻辑回归逻辑回归又叫做对数几率回归,也就是说我们是对概率进行建模,而不同于线性回归对Y直接进行建模。
线性回归:
而逻辑回归:
对上面的公式进行转换,可以得到下面的公式:
也就是逻辑回归常见的公式。这样子我们就完成了从输入数据的线性组合到概率的映射,然后根据一定的阈值将概率映射到类别,就完成了分类的过程。
Pytorch实现逻辑回归
本系列上篇文章讲了用Pytorch实现一个模型应用的过程主要包括:
数据 模型 损失函数 优化器 迭代训练
上代码。
导入包:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim
构造数据:
n = 1000
mean_value = 2.5 # mean value of the distribution
bias = 1.3
n_data = torch.ones(n, 2)
x0 = torch.normal(mean_value*n_data,1.0) + bias
y0 = torch.zeros(n)
x1 = torch.normal(-mean_value*n_data, 1.0) + bias
y1 = torch.ones(n)
train_x = torch.cat((x0, x1), 0)
train_y = torch.cat((y0, y1), 0)
模型:
class LogReg(nn.Module):
def __init__(self):
super(LogReg, self).__init__()
self.features1 = nn.Linear(2,5)
self.features2 = nn.Linear(5,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features1(x)
x = self.features2(x)
x = self.sigmoid(x)
return x
model = LogReg()
损失函数和优化器:
criterion = nn.BCELoss()
lr = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
迭代训练:
for epoch in range(1000):
y_pred = model(train_x)
loss = criterion(y_pred.squeeze(), train_y)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
# 计算准确率
mask = y_pred.ge(0.5).float().squeeze()
correct = (mask == train_y).sum()
acc = correct.item()/train_y.shape[0]
print('epoch: ', epoch, 'loss: ', loss.item(), 'acc: ', acc)
训练结果:

可以看到第10个epoch的时候准确率就差不多98%了。

参考
【1】很多机器学习可以用到的数据集:http://archive.ics.uci.edu/ml/index.php
【2】Latex怎么打分式:https://zhuanlan.zhihu.com/p/262715401
【3】二分类交叉熵损失:https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
【4】https://www.jianshu.com/p/63e255e3232f
本文由 mdnice 多平台发布
边栏推荐
- 细数国产接口协作平台的六把武器!
- FPGA工程师面试试题集锦1~10
- scikit-learn机器学习 读书笔记(二)
- Interface documentation evolution illustration, some ancient interface documentation tools, you may not have used it
- AVL树的插入--旋转笔记
- Consulting cdc 2.0 for mysql does not execute flush with read lock. How to ensure bin
- scikit-learn机器学习 读书笔记(一)
- R中设置图形参数--函数par()详解
- pytorch learning
- 论文精读 —— 2021 CVPR《Progressive Temporal Feature Alignment Network for Video Inpainting》
猜你喜欢
【Pei Shu Theorem】CF1055C Lucky Days
Linear Algebra (4)
Interface debugging also can play this?
基于Qiskit——《量子计算编程实战》读书笔记(四)
基本比例尺标准分幅编号流程
Pony语言学习(六):Struct, Type Alias, Type Expressions
Pony语言学习(一):环境配置(续)
基于Qiskit——《量子计算编程实战》读书笔记(二)
Order table delete, insert and search operations
Stacks and Queues | Valid parentheses, delete all adjacent elements in a string, reverse Polish expression evaluation, maximum sliding window, top K high frequency elements | leecode brush questions
随机推荐
【LeetCode】41. The first missing positive number
在yolov5的网络结构中添加注意力机制模块
R语言:修改chart.Correlation()函数绘制相关性图——完美出图
【Static proxy】
FPGA engineer interview questions collection 1~10
Rpc interface stress test
awk of the Three Musketeers of Shell Programming
【yolov5训练错误】WARNING: Ignoring corrupted image
基本比例尺标准分幅编号流程
每周推荐短视频:探索AI的应用边界
Interface debugging also can play this?
conda创建虚拟环境方法和pqi使用国内镜像源安装第三方库的方法教程
Mysql CDC (2.1.1) inital snapshot database set up five concurrent degree, se
FPGA工程师面试试题集锦1~10
pytorch框架学习(2)使用GPU训练
【论文笔记1】小样本分类
WSTP初体验
scikit-learn机器学习 读书笔记(二)
聊聊 API 管理-开源版 到 SaaS 版
Linear Algebra (4)