A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch

Overview

Torchmeta

PyPI Build Status Documentation

A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch. Torchmeta contains popular meta-learning benchmarks, fully compatible with both torchvision and PyTorch's DataLoader.

Features

  • A unified interface for both few-shot classification and regression problems, to allow easy benchmarking on multiple problems and reproducibility.
  • Helper functions for some popular problems, with default arguments from the literature.
  • An thin extension of PyTorch's Module, called MetaModule, that simplifies the creation of certain meta-learning models (e.g. gradient based meta-learning methods). See the MAML example for an example using MetaModule.

Datasets available

Installation

You can install Torchmeta either using Python's package manager pip, or from source. To avoid any conflict with your existing Python setup, it is suggested to work in a virtual environment with virtualenv. To install virtualenv:

pip install --upgrade virtualenv
virtualenv venv
source venv/bin/activate

Requirements

  • Python 3.6 or above
  • PyTorch 1.4 or above
  • Torchvision 0.5 or above

Using pip

This is the recommended way to install Torchmeta:

pip install torchmeta

From source

You can also install Torchmeta from source. This is recommended if you want to contribute to Torchmeta.

git clone https://github.com/tristandeleu/pytorch-meta.git
cd pytorch-meta
python setup.py install

Example

Minimal example

This minimal example below shows how to create a dataloader for the 5-shot 5-way Omniglot dataset with Torchmeta. The dataloader loads a batch of randomly generated tasks, and all the samples are concatenated into a single tensor. For more examples, check the examples folder.

from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader

dataset = omniglot("data", ways=5, shots=5, test_shots=15, meta_train=True, download=True)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

for batch in dataloader:
    train_inputs, train_targets = batch["train"]
    print('Train inputs shape: {0}'.format(train_inputs.shape))    # (16, 25, 1, 28, 28)
    print('Train targets shape: {0}'.format(train_targets.shape))  # (16, 25)

    test_inputs, test_targets = batch["test"]
    print('Test inputs shape: {0}'.format(test_inputs.shape))      # (16, 75, 1, 28, 28)
    print('Test targets shape: {0}'.format(test_targets.shape))    # (16, 75)

Advanced example

Helper functions are only available for some of the datasets available. However, all of them are available through the unified interface provided by Torchmeta. The variable dataset defined above is equivalent to the following

from torchmeta.datasets import Omniglot
from torchmeta.transforms import Categorical, ClassSplitter, Rotation
from torchvision.transforms import Compose, Resize, ToTensor
from torchmeta.utils.data import BatchMetaDataLoader

dataset = Omniglot("data",
                   # Number of ways
                   num_classes_per_task=5,
                   # Resize the images to 28x28 and converts them to PyTorch tensors (from Torchvision)
                   transform=Compose([Resize(28), ToTensor()]),
                   # Transform the labels to integers (e.g. ("Glagolitic/character01", "Sanskrit/character14", ...) to (0, 1, ...))
                   target_transform=Categorical(num_classes=5),
                   # Creates new virtual classes with rotated versions of the images (from Santoro et al., 2016)
                   class_augmentations=[Rotation([90, 180, 270])],
                   meta_train=True,
                   download=True)
dataset = ClassSplitter(dataset, shuffle=True, num_train_per_class=5, num_test_per_class=15)
dataloader = BatchMetaDataLoader(dataset, batch_size=16, num_workers=4)

Note that the dataloader, receiving the dataset, remains the same.

Citation

Tristan Deleu, Tobias Würfl, Mandana Samiei, Joseph Paul Cohen, and Yoshua Bengio. Torchmeta: A Meta-Learning library for PyTorch, 2019 [ArXiv]

If you want to cite Torchmeta, use the following Bibtex entry:

@misc{deleu2019torchmeta,
  title={{Torchmeta: A Meta-Learning library for PyTorch}},
  author={Deleu, Tristan and W\"urfl, Tobias and Samiei, Mandana and Cohen, Joseph Paul and Bengio, Yoshua},
  year={2019},
  url={https://arxiv.org/abs/1909.06576},
  note={Available at: https://github.com/tristandeleu/pytorch-meta}
}
A simplified framework and utilities for PyTorch

Here is Poutyne. Poutyne is a simplified framework for PyTorch and handles much of the boilerplating code needed to train neural networks. Use Poutyne

GRAAL/GRAIL 534 Dec 17, 2022
Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

PyTorch Implementation of Differentiable ODE Solvers This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpr

Ricky Chen 4.4k Jan 04, 2023
PyTorch Extension Library of Optimized Scatter Operations

PyTorch Scatter Documentation This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations fo

Matthias Fey 1.2k Jan 07, 2023
A Closer Look at Structured Pruning for Neural Network Compression

A Closer Look at Structured Pruning for Neural Network Compression Code used to reproduce experiments in https://arxiv.org/abs/1810.04622. To prune, w

Bayesian and Neural Systems Group 140 Dec 05, 2022
TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards

TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards. It can reduce GPU memory and scale up the training when the model has massive linear layers (e.g., ViT, BERT and

Kaiyu Yue 275 Nov 22, 2022
PyTorch Extension Library of Optimized Autograd Sparse Matrix Operations

PyTorch Sparse This package consists of a small extension library of optimized sparse matrix operations with autograd support. This package currently

Matthias Fey 757 Jan 04, 2023
A PyTorch implementation of L-BFGS.

PyTorch-LBFGS: A PyTorch Implementation of L-BFGS Authors: Hao-Jun Michael Shi (Northwestern University) and Dheevatsa Mudigere (Facebook) What is it?

Hao-Jun Michael Shi 478 Dec 27, 2022
A code copied from google-research which named motion-imitation was rewrited with PyTorch

motor-system Introduction A code copied from google-research which named motion-imitation was rewrited with PyTorch. More details can get from this pr

NewEra 6 Jan 08, 2022
A tiny package to compare two neural networks in PyTorch

Compare neural networks by their feature similarity

Anand Krishnamoorthy 180 Dec 30, 2022
Tacotron 2 - PyTorch implementation with faster-than-realtime inference

Tacotron 2 (without wavenet) PyTorch implementation of Natural TTS Synthesis By Conditioning Wavenet On Mel Spectrogram Predictions. This implementati

NVIDIA Corporation 4.1k Jan 03, 2023
Pretrained EfficientNet, EfficientNet-Lite, MixNet, MobileNetV3 / V2, MNASNet A1 and B1, FBNet, Single-Path NAS

(Generic) EfficientNets for PyTorch A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter ef

Ross Wightman 1.5k Jan 01, 2023
An implementation of Performer, a linear attention-based transformer, in Pytorch

Performer - Pytorch An implementation of Performer, a linear attention-based transformer variant with a Fast Attention Via positive Orthogonal Random

Phil Wang 900 Dec 22, 2022
A pure Python implementation of Compact Bilinear Pooling and Count Sketch for PyTorch.

Compact Bilinear Pooling for PyTorch. This repository has a pure Python implementation of Compact Bilinear Pooling and Count Sketch for PyTorch. This

Grégoire Payen de La Garanderie 234 Dec 07, 2022
S3-plugin is a high performance PyTorch dataset library to efficiently access datasets stored in S3 buckets.

S3-plugin is a high performance PyTorch dataset library to efficiently access datasets stored in S3 buckets.

Amazon Web Services 138 Jan 03, 2023
High-fidelity performance metrics for generative models in PyTorch

High-fidelity performance metrics for generative models in PyTorch

Vikram Voleti 5 Oct 24, 2021
Learning Sparse Neural Networks through L0 regularization

Example implementation of the L0 regularization method described at Learning Sparse Neural Networks through L0 regularization, Christos Louizos, Max W

AMLAB 202 Nov 10, 2022
Pytorch bindings for Fortran

Pytorch bindings for Fortran

Dmitry Alexeev 46 Dec 29, 2022
Use Jax functions in Pytorch with DLPack

Use Jax functions in Pytorch with DLPack

Phil Wang 106 Dec 17, 2022
Training PyTorch models with differential privacy

Opacus is a library that enables training PyTorch models with differential privacy. It supports training with minimal code changes required on the cli

1.3k Dec 29, 2022
PyGCL: Graph Contrastive Learning Library for PyTorch

PyGCL is an open-source library for graph contrastive learning (GCL), which features modularized GCL components from published papers, standardized evaluation, and experiment management.

GCL: Graph Contrastive Learning Library for PyTorch 592 Jan 07, 2023