PyTorch Personal Trainer: My framework for deep learning experiments

Alex's PyTorch Personal Trainer (ptpt)

(name subject to change)

This repository contains my personal lightweight framework for deep learning projects in PyTorch.

Disclaimer: this project is very much a work in progress. Although technically usable, it is missing many features. Nonetheless, you may find some of the design patterns and code snippets useful in the meantime.

Installation

Run python -m build in the root of the repo, then run pip install on the resulting .whl file (written to dist/ by default).

There is no pip package yet.

Usage

Import the library as with any other Python library:

from ptpt.trainer import Trainer, TrainerConfig
from ptpt.log import debug, info, warning, error, critical

The core of the library is the trainer.Trainer class. In the simplest case, it takes the following as input:

net:            an `nn.Module` that is the model we wish to train.
loss_fn:        a function that takes an `nn.Module` and a batch as input.
                it returns the loss and optionally other metrics.
train_dataset:  the training dataset.
test_dataset:   the test dataset.
cfg:            a `TrainerConfig` instance that holds all
                hyperparameters.

Once this is instantiated, starting the training loop is as simple as calling trainer.train() where trainer is an instance of Trainer.
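
To make the loss_fn contract concrete, here is a minimal sketch of the kind of loop that a trainer might run around it. This is not ptpt's actual implementation: the function name train_one_epoch, the choice of Adam, and the metric averaging are all assumptions made for illustration. It only shows how a loss_fn(net, batch) that returns the loss followed by optional metrics can be consumed.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(net, loss_fn, dataset, batch_size=64, learning_rate=4e-4):
    """Hypothetical single-epoch loop; assumes a non-empty dataset."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optim = torch.optim.Adam(net.parameters(), lr=learning_rate)  # assumed optimiser
    net.train()

    loss_sum, metric_sums, nb_batches = 0.0, None, 0
    for batch in loader:
        optim.zero_grad()
        # first return value is the loss, anything after it is a metric
        loss, *metrics = loss_fn(net, batch)
        loss.backward()
        optim.step()

        loss_sum += loss.item()
        if metric_sums is None:
            metric_sums = [0.0] * len(metrics)
        metric_sums = [s + float(m) for s, m in zip(metric_sums, metrics)]
        nb_batches += 1

    # per-batch averages of the loss and of each metric
    return loss_sum / nb_batches, [s / nb_batches for s in metric_sums]


def toy_loss_fn(net, batch):
    # toy regression loss with mean absolute error as an extra metric
    X, y = batch
    pred = net(X)
    loss = F.mse_loss(pred, y)
    mae = (pred - y).abs().mean().item()
    return loss, mae


torch.manual_seed(0)
X = torch.randn(256, 3)
y = X.sum(dim=1, keepdim=True)
toy_dataset = TensorDataset(X, y)
toy_net = torch.nn.Linear(3, 1)

avg_loss, avg_metrics = train_one_epoch(toy_net, toy_loss_fn, toy_dataset)
```

Keeping all model-specific logic inside loss_fn means the loop itself never needs to know what the batch contains or how the loss is computed.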

cfg stores most of the configuration options for Trainer. See the class definition of TrainerConfig for details on all options.

Examples

An example workflow would go like this:

Define your training and test datasets:

from torchvision import datasets, transforms

# normalise with the commonly used MNIST mean and standard deviation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, download=True, transform=transform)

Define your model:

# in this case, we have imported `Net` from another file
net = Net()
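
The definition of Net is not shown here. As a rough idea of what such a file could contain, the following is a hypothetical small convolutional network for 28x28 single-channel MNIST images. The layer sizes are my own assumptions, and it ends in a log-softmax so that its output is suitable for a negative log-likelihood loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Hypothetical stand-in for the `Net` imported in the example."""

    def __init__(self, nb_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(32 * 7 * 7, nb_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 14x14 -> 7x7
        x = torch.flatten(x, start_dim=1)
        # return log-probabilities over the classes
        return F.log_softmax(self.fc(x), dim=-1)
```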

Define your loss function that calls net, taking the full batch as input:

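import torch.nn.functional as F

# minimising classification error
# note: F.nll_loss expects log-probabilities, so this assumes `Net` ends
# in a log-softmax; if it returns raw logits, use F.cross_entropy instead
def loss_fn(net, batch):
    X, y = batch
    log_probs = net(X)
    loss = F.nll_loss(log_probs, y)

    # accuracy as a percentage of correct predictions in the batch
    pred = log_probs.argmax(dim=-1, keepdim=True)
    accuracy = 100. * pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
    return loss, accuracy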
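
Because the loss function is an ordinary Python function, it can be checked on a single random batch before handing it to the trainer. The sketch below does this with a tiny stand-in model (toy_net is hypothetical and exists only for this check), confirming that the loss is a scalar that supports backpropagation and that the accuracy is a percentage.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def loss_fn(net, batch):
    X, y = batch
    log_probs = net(X)
    loss = F.nll_loss(log_probs, y)

    pred = log_probs.argmax(dim=-1, keepdim=True)
    accuracy = 100. * pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
    return loss, accuracy


# stand-in model that outputs log-probabilities over 10 classes
toy_net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10),
    nn.LogSoftmax(dim=-1),
)

# a fake batch shaped like MNIST: 8 images and 8 integer labels
X = torch.randn(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))

loss, accuracy = loss_fn(toy_net, (X, y))
loss.backward()  # gradients should now be populated on toy_net
```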

Optionally create a configuration object:

# see class definition for full list of parameters
cfg = TrainerConfig(
    exp_name = 'mnist-conv',
    batch_size = 64,
    learning_rate = 4e-4,
    nb_workers = 4,
    save_outputs = False,
    metric_names = ['accuracy']
)
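
For readers unfamiliar with configuration objects of this kind, one common way to build them is a dataclass with defaults. The sketch below is not the real TrainerConfig: the class name ExampleConfig, the default values, and the validation are assumptions. Only the field names are taken from the example above.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ExampleConfig:
    """Hypothetical config object; defaults are illustrative only."""
    exp_name: str = 'exp'
    batch_size: int = 32
    learning_rate: float = 1e-3
    nb_workers: int = 0
    save_outputs: bool = True
    # mutable defaults must go through default_factory
    metric_names: List[str] = field(default_factory=list)

    def __post_init__(self):
        # catch obviously invalid hyperparameters early
        if self.batch_size <= 0:
            raise ValueError('batch_size must be positive')
        if self.learning_rate <= 0:
            raise ValueError('learning_rate must be positive')


example_cfg = ExampleConfig(
    exp_name='mnist-conv',
    batch_size=64,
    learning_rate=4e-4,
    nb_workers=4,
    save_outputs=False,
    metric_names=['accuracy'],
)
```

Any option left unspecified falls back to its default, which is what makes the configuration object optional in simple cases.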

Initialise the Trainer class:

trainer = Trainer(
    net=net,
    loss_fn=loss_fn,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    cfg=cfg
)

Call trainer.train() to begin the training loop:

trainer.train() # Go!

More examples are available in the repository.

Motivation

I found myself repeating a lot of the same structure in many of my deep learning projects. This project is the culmination of my efforts to refine the typical structure of my projects into (what I hope to be) a wholly reusable and general-purpose library.

Additionally, there are many nice theoretical and engineering tricks available to deep learning researchers. Unfortunately, many of them are forgotten because they fall outside the typical workflow, despite being very beneficial. Another goal of this project is to include these tricks transparently, so they can be added and removed with minimal code changes. Where it is sane to do so, some of these could be on by default.

Finally, I am guilty of forgetting to implement decent logging: both of standard output and of metrics. Logging of standard output is not hard, and is implemented using other libraries such as rich. However, metric logging is less obvious. I'd like to avoid larger dependencies such as tensorboard being an integral part of the project, so metrics will be logged to simple numpy arrays. The library will then provide functions to produce plots from these, or they can be used in another library.
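
As a rough illustration of logging metrics to plain numpy arrays, here is a minimal sketch. The MetricLog class and its methods are hypothetical and are not part of ptpt; it simply collects one row of values per logging step and exposes them as a 2D array that can be saved or plotted by any other tool.

```python
import numpy as np


class MetricLog:
    """Hypothetical metric logger backed by a plain numpy array."""

    def __init__(self, metric_names):
        self.names = list(metric_names)
        self.rows = []

    def append(self, *values):
        # one value per metric name, in the same order
        if len(values) != len(self.names):
            raise ValueError('expected one value per metric name')
        self.rows.append([float(v) for v in values])

    def as_array(self):
        # shape: (number of logged steps, number of metrics)
        return np.asarray(self.rows, dtype=np.float64).reshape(-1, len(self.names))

    def save(self, path):
        # np.save adds a .npy extension if the path does not have one
        np.save(path, self.as_array())


metric_log = MetricLog(['loss', 'accuracy'])
metric_log.append(0.5, 90.0)
metric_log.append(0.4, 92.5)
history = metric_log.as_array()
```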

TODO:

  • Make a todo.

Owner
Alex McKinney