Convert scikit-learn models to PyTorch modules

sk2torch

sk2torch converts scikit-learn models into PyTorch modules that can be tuned with backpropagation and even compiled as TorchScript.

Problems solved by this project:

  1. scikit-learn cannot perform inference on a GPU. Models like SVMs have a lot to gain from fast GPU primitives, and converting the models to PyTorch gives immediate access to these primitives.
  2. While scikit-learn supports serialization through pickle, saved models are not guaranteed to load or behave identically in other versions of the library. On the other hand, TorchScript provides a convenient, safe way to save a model with its corresponding implementation. The resulting models can be loaded anywhere that PyTorch is installed, even without importing sk2torch.
  3. While certain models like SVMs and linear classifiers are theoretically end-to-end differentiable, scikit-learn provides no mechanism to compute gradients through trained models. PyTorch provides this functionality mostly for free.
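Points 1 and 3 can be illustrated without sk2torch itself. The sketch below hand-translates a fitted binary LinearSVC into PyTorch tensors, moves them to a GPU when one is available, and uses autograd to differentiate the decision function with respect to the input. The `decision_function` helper here is a hypothetical stand-in, not sk2torch's API. For a linear model the gradient is simply `coef_`, which makes the result easy to check.

```python
import numpy as np
import torch
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC

# Fit an ordinary scikit-learn model (binary, 3 features).
x, y = make_classification(
    n_samples=200, n_features=3, n_informative=3, n_redundant=0, random_state=0
)
est = LinearSVC().fit(x, y)

# Use a GPU if one is present; everything below also runs on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hand-translated parameters, kept in float64 to match scikit-learn.
coef = torch.tensor(est.coef_, dtype=torch.float64, device=device)            # (1, 3)
intercept = torch.tensor(est.intercept_, dtype=torch.float64, device=device)  # (1,)


def decision_function(inputs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: same formula as a binary linear model, x @ w^T + b.
    return (inputs @ coef.T + intercept).view(-1)


inputs = torch.tensor(x[:5], dtype=torch.float64, device=device, requires_grad=True)
scores = decision_function(inputs)
# Each score depends only on its own row, so summing gives per-row gradients.
scores.sum().backward()
grads = inputs.grad.cpu().numpy()
```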

See Usage for a high-level example of using the library. See How it works for implementation details and the list of supported modules.

For fun, here's a vector field produced by differentiating the probability predictions of a two-class SVM (produced by this script):

A vector field quiver plot with two modes

Usage

First, train a model with scikit-learn as usual:

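from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Any (x, y) dataset works; a synthetic one with 3 features is used here
# so that it matches the 3-feature input passed to predict() below.
x, y = make_classification(n_samples=200, n_features=3, n_informative=3,
                           n_redundant=0, random_state=0)
model = Pipeline([
    ("center", StandardScaler(with_std=False)),
    ("classify", SGDClassifier()),
])
model.fit(x, y)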

Then call sk2torch.wrap on the model to create a PyTorch equivalent:

import sk2torch
import torch

torch_model = sk2torch.wrap(model)
print(torch_model.predict(torch.tensor([[1., 2., 3.]]).double()))

You can save a model with TorchScript:

import torch.jit

torch.jit.script(torch_model).save("path.pt")

# ... sk2torch need not be installed to load the model.
loaded_model = torch.jit.load("path.pt")
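One way to confirm that a saved file is self-contained is a round trip through an in-memory buffer. The sketch below uses a small hand-written module rather than a wrapped model, so it runs with only scikit-learn and PyTorch installed; `LinearPredictSketch` is a hypothetical name, not part of sk2torch. Methods other than `forward` are marked with `@torch.jit.export` so they survive scripting, and the loaded module exposes `predict` without needing the Python class definition.

```python
import io

import torch
from sklearn.datasets import make_classification
from sklearn.svm import LinearSVC


class LinearPredictSketch(torch.nn.Module):
    """Hypothetical TorchScript-friendly port of a binary LinearSVC."""

    def __init__(self, est: LinearSVC):
        super().__init__()
        # Buffers are serialized with the module and follow .to(device).
        self.register_buffer("coef", torch.tensor(est.coef_, dtype=torch.float64))
        self.register_buffer("intercept", torch.tensor(est.intercept_, dtype=torch.float64))
        self.register_buffer("classes", torch.tensor(est.classes_))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.predict(x)

    @torch.jit.export
    def decision_function(self, x: torch.Tensor) -> torch.Tensor:
        return (x @ self.coef.t() + self.intercept).view(-1)

    @torch.jit.export
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        # Positive scores map to classes_[1], as in scikit-learn's binary case.
        return self.classes[(self.decision_function(x) > 0).long()]


x, y = make_classification(
    n_samples=200, n_features=3, n_informative=3, n_redundant=0, random_state=0
)
est = LinearSVC().fit(x, y)

scripted = torch.jit.script(LinearPredictSketch(est))
buffer = io.BytesIO()  # stands in for a file such as "path.pt"
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

preds = loaded.predict(torch.tensor(x, dtype=torch.float64)).numpy()
```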

For a full example of training a model and using its PyTorch translation, see examples/svm_vector_field.py.

How it works

sk2torch contains PyTorch re-implementations of supported scikit-learn models. For a supported estimator X, a class TorchX in sk2torch will be able to read the attributes of X and convert them to torch.Tensor or simple Python types. TorchX subclasses torch.nn.Module and has a method for each inference API of X (e.g. predict, decision_function, etc.).
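The pattern can be sketched for the StandardScaler used in Usage. `TorchStandardScalerSketch` below is a hypothetical illustration of the idea, not sk2torch's actual class: it reads the fitted `mean_` and `scale_` attributes once, stores them as tensors, and exposes `transform` and `inverse_transform` methods named after the scikit-learn ones. Because scikit-learn sets `scale_` to `None` when `with_std=False` and skips centering when `with_mean=False`, the sketch checks those flags and substitutes ones or zeros.

```python
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler


class TorchStandardScalerSketch(torch.nn.Module):
    """Hypothetical TorchX-style port of a fitted StandardScaler."""

    def __init__(self, est: StandardScaler):
        super().__init__()
        n = est.n_features_in_
        # Read the fitted attributes once; use identity values for disabled steps.
        mean = est.mean_ if est.with_mean else np.zeros(n)
        scale = est.scale_ if est.with_std else np.ones(n)
        self.register_buffer("mean", torch.tensor(mean, dtype=torch.float64))
        self.register_buffer("scale", torch.tensor(scale, dtype=torch.float64))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.transform(x)

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.scale

    def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale + self.mean


data = np.random.default_rng(0).normal(loc=3.0, scale=2.0, size=(100, 4))
sk_scaler = StandardScaler(with_std=False).fit(data)  # as in the Usage pipeline
torch_scaler = TorchStandardScalerSketch(sk_scaler)
out = torch_scaler(torch.tensor(data))
```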

Which modules are supported? The easiest way to get an up-to-date list is via the supported_classes() function, which returns all wrap()able scikit-learn classes:

>>> import sk2torch
>>> sk2torch.supported_classes()
[<class 'sklearn.tree._classes.DecisionTreeClassifier'>, <class 'sklearn.tree._classes.DecisionTreeRegressor'>, <class 'sklearn.dummy.DummyClassifier'>, <class 'sklearn.ensemble._gb.GradientBoostingClassifier'>, <class 'sklearn.preprocessing._label.LabelBinarizer'>, <class 'sklearn.svm._classes.LinearSVC'>, <class 'sklearn.svm._classes.LinearSVR'>, <class 'sklearn.neural_network._multilayer_perceptron.MLPClassifier'>, <class 'sklearn.kernel_approximation.Nystroem'>, <class 'sklearn.pipeline.Pipeline'>, <class 'sklearn.linear_model._stochastic_gradient.SGDClassifier'>, <class 'sklearn.preprocessing._data.StandardScaler'>, <class 'sklearn.svm._classes.SVC'>, <class 'sklearn.svm._classes.NuSVC'>, <class 'sklearn.svm._classes.SVR'>, <class 'sklearn.svm._classes.NuSVR'>, <class 'sklearn.compose._target.TransformedTargetRegressor'>]

Comparison to sklearn-onnx

sklearn-onnx is an open source package for converting trained scikit-learn models into ONNX. Like sk2torch, sklearn-onnx re-implements inference functions for various models, meaning that it can also provide serialization and GPU acceleration for supported modules.

Naturally, neither library will support modules that aren't manually ported. As a result, the two libraries support different subsets of all available models/methods. For example, sk2torch supports the SVC probability prediction methods predict_proba and predict_log_proba, whereas sklearn-onnx does not.
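To give a sense of what manually porting an inference method involves, here is a sketch of a binary RBF-kernel SVC's decision_function written with PyTorch tensor operations. It is an illustration, not sk2torch's implementation, and it assumes a two-class problem and an explicit numeric gamma so the kernel width is known. It uses scikit-learn's public fitted attributes, which for the binary case satisfy f(x) = sum_i dual_coef_[0, i] * exp(-gamma * ||x - support_vectors_[i]||^2) + intercept_[0].

```python
import numpy as np
import torch
from sklearn.datasets import make_moons
from sklearn.svm import SVC

x, y = make_moons(n_samples=200, noise=0.1, random_state=0)
gamma = 0.5  # explicit value (assumption), so the kernel width is known
est = SVC(kernel="rbf", gamma=gamma).fit(x, y)

sv = torch.tensor(est.support_vectors_, dtype=torch.float64)  # (n_sv, d)
dual = torch.tensor(est.dual_coef_[0], dtype=torch.float64)   # (n_sv,)
bias = float(est.intercept_[0])


def svc_decision_function(inputs: torch.Tensor) -> torch.Tensor:
    # Squared Euclidean distance from every input to every support vector.
    sq_dist = ((inputs[:, None, :] - sv[None, :, :]) ** 2).sum(dim=-1)  # (n, n_sv)
    kernel = torch.exp(-gamma * sq_dist)
    # Weighted sum of kernel values plus the intercept.
    return kernel @ dual + bias


scores = svc_decision_function(torch.tensor(x, dtype=torch.float64))
```

Since every step is a batched tensor operation, the same function runs on a GPU once `sv`, `dual`, and the inputs are moved there.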

While sklearn-onnx exports models to ONNX, sk2torch exports models to Python objects with familiar method names that can be fine-tuned, backpropagated through, and serialized in a user-friendly way. PyTorch is strictly more general than ONNX, since PyTorch models can be converted to ONNX if desired.

Owner
Alex Nichol
Web developer, math geek, and AI enthusiast.