pytorch-accelerated

Overview

pytorch-accelerated is a lightweight library designed to accelerate the process of training PyTorch models by providing a minimal, but extensible training loop - encapsulated in a single Trainer object - which is flexible enough to handle the majority of use cases, and capable of utilizing different hardware options with no code changes required.

pytorch-accelerated offers a streamlined feature set and places a strong emphasis on simplicity and transparency, enabling users to understand exactly what is going on under the hood without having to write and maintain the boilerplate themselves!

The key features are:

  • A simple and contained, but easily customisable, training loop, which should work out of the box in straightforward cases; behaviour can be customised using inheritance and/or callbacks.
  • Handles device placement, mixed-precision, DeepSpeed integration, multi-GPU and distributed training with no code changes.
  • Uses pure PyTorch components, with no additional modifications or wrappers, and easily interoperates with other popular libraries such as timm, transformers and torchmetrics.
  • A small, streamlined API ensures that there is a minimal learning curve for existing PyTorch users.

Significant effort has been made to ensure that every part of the library - both internal and external components - is as clear and simple as possible, making it easy to customise, debug and understand exactly what is going on behind the scenes at each step; most of the behaviour of the trainer is contained in a single class! In the spirit of Python, nothing is hidden and everything is accessible.
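
As an illustration, behaviour could be customised with a small callback, as in the sketch below. This is a minimal sketch only: the TrainerCallback base class, the on_train_epoch_end hook and the trainer.run_history accessor are assumed from the library's callback API, so consult the documentation for the exact names and signatures.

# minimal callback sketch (hook and accessor names are assumptions; see the docs)
from torch import nn, optim

from pytorch_accelerated import Trainer
from pytorch_accelerated.callbacks import TrainerCallback

class PrintEpochLossCallback(TrainerCallback):
    def on_train_epoch_end(self, trainer, **kwargs):
        # run_history is assumed to record per-epoch metrics such as the training loss
        epoch = trainer.run_history.current_epoch
        loss = trainer.run_history.get_latest_metric("train_loss_epoch")
        print(f"epoch {epoch}: train loss {loss:.4f}")

# any model, loss function and optimizer can be used here
model = nn.Linear(784, 10)
trainer = Trainer(
    model,
    loss_func=nn.CrossEntropyLoss(),
    optimizer=optim.SGD(model.parameters(), lr=0.01),
    callbacks=(PrintEpochLossCallback(),),
)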

pytorch-accelerated is proudly and transparently built on top of Hugging Face Accelerate, which is responsible for the movement of data between devices and the launching of training configurations. When customizing the trainer, or launching training, users are encouraged to consult the Accelerate documentation to understand all available options; Accelerate provides convenient functions for operations such as gathering tensors and gradient clipping, usage of which can be seen in the pytorch-accelerated examples folder!

To learn more about the motivations behind this library, along with a detailed getting started guide, check out this blog post.

Installation

pytorch-accelerated can be installed from PyPI using pip with the following command:

pip install pytorch-accelerated

To make the package as slim as possible, the packages required to run the examples are not included by default. To include these packages, you can use the following command:

pip install pytorch-accelerated[examples]

Quickstart

To get started, simply import and use the pytorch-accelerated Trainer, as demonstrated in the following snippet, and then launch training using the accelerate CLI as described below.

# examples/train_mnist.py
import os

from torch import nn, optim
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from pytorch_accelerated import Trainer

class MNISTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
        )

    def forward(self, input):
        return self.main(input.view(input.shape[0], -1))

def main():
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train_dataset, validation_dataset, test_dataset = random_split(dataset, [50000, 5000, 5000])
    model = MNISTModel()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    loss_func = nn.CrossEntropyLoss()

    trainer = Trainer(
            model,
            loss_func=loss_func,
            optimizer=optimizer,
    )

    trainer.train(
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        num_epochs=8,
        per_device_batch_size=32,
    )

    trainer.evaluate(
        dataset=test_dataset,
        per_device_batch_size=64,
    )
    
if __name__ == "__main__":
    main()

To launch training using the accelerate CLI on your machine(s), run:

accelerate config --config_file accelerate_config.yaml

and answer the questions asked. This will generate a config file that will be used to properly set the default options when running

accelerate launch --config_file accelerate_config.yaml train.py [--training-args]

Note: Using the accelerate CLI is completely optional; training can also be launched in the usual way using:

python train.py / python -m torch.distributed ...

depending on your infrastructure configuration, for users who would like to maintain more fine-grained control over the launch command.
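
For example, a two-GPU run on a single machine could be launched directly with PyTorch's torchrun; the script name and process count here are purely illustrative:

torchrun --nproc_per_node=2 train.py [--training-args]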

More complex training examples can be seen in the examples folder here.

Alternatively, if you would rather understand the core concepts first, these are covered in the documentation.

Usage

Who is pytorch-accelerated aimed at?

  • Users who are familiar with PyTorch but would like to avoid having to write the common training loop boilerplate, so that they can focus on the interesting parts of training.
  • Users who like, and are comfortable with, selecting and creating their own models, loss functions, optimizers and datasets.
  • Users who value a simple and streamlined feature set, where the behaviour is easy to debug, understand, and reason about!

When shouldn't I use pytorch-accelerated?

  • If you are looking for an end-to-end solution, encompassing everything from loading data to inference, which helps you to select a model, optimizer or loss function, you would probably be better suited to fastai. pytorch-accelerated focuses only on the training process, with all other concerns being left to the responsibility of the user.
  • If you would like to write the entire training loop yourself, just without all of the device management headaches, you would probably be best suited to using Accelerate directly! Whilst it is possible to customize every part of the Trainer, the training loop is fundamentally broken up into a number of different methods that you would have to override. But, before you go, is writing those for loops really important enough to warrant starting from scratch again? 😉
  • If you are working on a custom, highly complex, use case which does not fit the patterns of usual training loops and want to squeeze out every last bit of performance on your chosen hardware, you are probably best off sticking with vanilla PyTorch; any high-level API becomes an overhead in highly specialized cases!

Acknowledgements

Many aspects behind the design and features of pytorch-accelerated were greatly inspired by a number of excellent libraries and frameworks such as fastai, timm, PyTorch Lightning and Hugging Face Accelerate. Each of these tools has made an enormous impact on both this library and the machine learning community, and their influence cannot be overstated!

pytorch-accelerated has taken only inspiration from these tools, and all of the functionality contained has been implemented from scratch in a way that benefits this library. The only exceptions to this are some of the scripts in the examples folder in which existing resources were taken and modified in order to showcase the features of pytorch-accelerated; these cases are clearly marked, with acknowledgement being given to the original authors.

Comments
  • Do we need to set mixed-precision explicitly or is it handled if tensor cores available?

    I am following your awesome guide on timm: https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055.

    I am running training on an A100-based VM which should support mixed-precision training. Does Trainer from PyTorch Accelerated take care of that automatically?

    opened by sayakpaul 6
  • ERROR: No matching distribution found for pytorch-accelerated

    I'm just trying to install the package using the pip command and I get the following errors:

    ERROR: Could not find a version that satisfies the requirement pytorch-accelerated (from versions: none)
    ERROR: No matching distribution found for pytorch-accelerated
    

    Am I missing something?

    P.S. I've already installed the requirements including accelerate and tqdm

    opened by phosseini 4
  • Can pytorch-accelerated be used with pytorch-lightning callbacks and loggers?

    I'm interested in this package for its support of methods like EMA that don't seem to have made it into Lightning yet, but don't want to lose my current experiment tracking setup etc.

    opened by GeorgePearse 3
  • Do you know about Lightning Lite ?

    Hey @Chris-hughes10,

    Awesome work there!

    Did you know about Lightning Lite in PyTorch Lightning? Here are the docs: https://pytorch-lightning.readthedocs.io/en/latest/starter/lightning_lite.html

    opened by tchaton 1
  • Refactor loss tracking

    • Create private methods in trainer to handle loss tracking, removing duplication
    • Move loss gathering to the end of each epoch, as opposed to after each batch
    • Add tests for loss tracker
    opened by Chris-hughes10 0
  • Add limit batches context manager

    Add a context manager which can be used to limit the number of training and evaluation batches used without having to manually add the callback. This is done by setting an environment variable.

    opened by Chris-hughes10 0
  • Refactor batch unpacking

    • Refactor batch unpacking to explicitly assign the first two items as xb and yb. This will enable more flexibility in what is returned by a dataloader
    opened by Chris-hughes10 0
  • Enables distributed evaluation on uneven inputs

    Adds functionality to enable distributed evaluation on uneven samples. Previously, this was handled by adding extra samples to the dataset; this behaviour is now disabled by default.

    opened by Chris-hughes10 0
Releases(v0.1.40)
  • v0.1.40(Nov 17, 2022)

    What's Changed

    • Add option to execute callbacks during ModelEma evaluation loop by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/41

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.39...v0.1.40

  • v0.1.39(Oct 14, 2022)

    What's Changed

    • Improve gathering to automatically pad tensors across processes
    • Add get_model method in Trainer by @bepuca in https://github.com/Chris-hughes10/pytorch-accelerated/pull/39

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.38...v0.1.39

  • v0.1.38(Sep 7, 2022)

    What's Changed

    • update worker init function by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/37
    • Separate out decay function in model EMA for easier override by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/38

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.37...v0.1.38

  • v0.1.37(Aug 24, 2022)

  • v0.1.36(Aug 22, 2022)

    What's Changed

    • Improve logging for SaveBestModelCallback by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/35
    • Add sync batchnorm callback by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/34
    • Add Ema model callback by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/36

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.35...v0.1.36

  • v0.1.35(Jul 9, 2022)

    What's Changed

    • Update Custom sampler handling by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/33

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.34...v0.1.35

  • v0.1.34(Jun 29, 2022)

  • v0.1.33(Jun 29, 2022)

  • v0.1.32(Jun 29, 2022)

    What's Changed

    • Add limit batches context manager by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/32

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.31...v0.1.32

  • v0.1.31(Jun 22, 2022)

  • v0.1.30(Jun 22, 2022)

    What's Changed

    • Add Limit batches callback (beta version) by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/31

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.29...v0.1.30

  • v0.1.29(Jun 17, 2022)

    What's Changed

    • Improve grad accumulation by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/30

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.28...v0.1.29
  • v0.1.28(May 31, 2022)

    What's Changed

    • Add local process first decorator by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/27
    • Update fp16 arg to mixed precision by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/28
    • Update accelerate version to 0.8.0 by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/29

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.27...v0.1.28

  • v0.1.27(May 24, 2022)

  • v0.1.26(Apr 28, 2022)

    What's Changed

    • Add process decorators for distributed training by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/25
    • Update accelerate version by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/26

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.25...v0.1.26

  • v0.1.25(Apr 25, 2022)

    What's Changed

    • Add handling for multi boolean tensors by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/23
    • Refactor batch unpacking by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/24

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.24...v0.1.25

  • v0.1.24(Apr 20, 2022)

  • v0.1.23(Apr 17, 2022)

    What's Changed

    • Add schedulers by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/22
    • Add a better way of getting default callbacks
    • Update project license to Apache-2.0

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.22...v0.1.23

  • v0.1.22(Feb 23, 2022)

    What's Changed

    • Add operations to placeholders by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/17
    • Add clarification for LR schedulers in the docs by @bepuca in https://github.com/Chris-hughes10/pytorch-accelerated/pull/16
    • Add specialised trainer to work with timm schedulers

    New Contributors

    • @bepuca made their first contribution in https://github.com/Chris-hughes10/pytorch-accelerated/pull/16

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.21...v0.1.22

  • v0.1.21(Jan 27, 2022)

  • v0.1.20(Jan 19, 2022)

    What's Changed

    • Added an example to the docs for a callback that saves predictions during evaluation by @alexhock10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/13
    • Create run config for standalone evaluation runs by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/14

    New Contributors

    • @alexhock10 made their first contribution in https://github.com/Chris-hughes10/pytorch-accelerated/pull/13

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.9...v0.1.20

  • v0.1.9(Dec 31, 2021)

    What's Changed

    • Freezing exploration by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/11
    • Add gather method by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/12

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.8...v0.1.9

  • v0.1.8(Dec 11, 2021)

    What's Changed

    • Changes to facilitate AzureML example by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/10

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.7...v0.1.8

  • v0.1.7(Nov 30, 2021)

    What's Changed

    • Update early stopping by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/9
    • Remove torch dependency (covered by accelerate)

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.6...v0.1.7

  • v0.1.6(Nov 26, 2021)

  • v0.1.5(Nov 24, 2021)

    What's Changed

    • Add intersphinx to docs by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/7
    • Update device handling by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/8

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.4...v0.1.5

  • v0.1.4(Nov 17, 2021)

    Update the package documentation

    What's Changed

    • Get docs to build properly by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/5
    • Get docs to build properly by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/6

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.3...v0.1.4

  • v0.1.3(Nov 13, 2021)

    Initial release

    What's Changed

    • Add gradient clipping to trainer by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/2
    • Prepare pypi workflow by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/3

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/commits/v0.1.0

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.0...v0.1.2

    What's Changed

    • Add sphinx docs by @Chris-hughes10 in https://github.com/Chris-hughes10/pytorch-accelerated/pull/4

    Full Changelog: https://github.com/Chris-hughes10/pytorch-accelerated/compare/v0.1.2...v0.1.3

Owner
Chris Hughes