PyTorch META-DATASET (Few-shot classification benchmark)

Overview

This repo contains a PyTorch implementation of meta-dataset and a unified implementation of some few-shot methods. This repo may be useful to you if you:

  • want some pre-trained ImageNet models in PyTorch for META-DATASET;
  • want to benchmark your method on META-DATASET (but do not want to mix your PyTorch code with the original TensorFlow implementation);
  • are looking for a codebase to visualize few-shot episodes.

Benefits over original code:

  1. This repo can be properly seeded, so the same random series of episodes can be replayed if needed;
  2. Data shuffling is performed without using a buffer, hence reducing the memory consumption;
  3. Better results can be obtained with this repo thanks to an enhanced way of resizing images. More details are given in the paper.
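Points 1 and 2 can be illustrated with a minimal sketch. It is not the repo's actual loader, and it assumes records can be addressed by position (for instance through the index files produced in Sect. 1.2). A full permutation of record positions is drawn from a seeded generator and records are read in that order, so no shuffle buffer of decoded examples is kept in memory, and reusing the seed replays the same order. The function names shuffled_record_order and sample_episode_classes are hypothetical.

```python
import random
from typing import List


def shuffled_record_order(num_records: int, seed: int) -> List[int]:
    """Return a reproducible permutation of record positions.

    Only integer positions are shuffled, so memory grows with the number of
    records, not with the size of a buffer of decoded images.
    """
    rng = random.Random(seed)  # private generator: global random state is untouched
    order = list(range(num_records))
    rng.shuffle(order)
    return order


def sample_episode_classes(class_ids: List[int], n_way: int, seed: int) -> List[int]:
    """Draw the classes of one episode; the same seed gives the same classes."""
    rng = random.Random(seed)
    return sorted(rng.sample(class_ids, n_way))


if __name__ == "__main__":
    first = shuffled_record_order(10, seed=0)
    replay = shuffled_record_order(10, seed=0)
    print(first == replay)  # True: the same seed replays the same order
    print(sample_episode_classes(list(range(64)), n_way=5, seed=1))
```

Using a private random.Random instance rather than the global random module keeps episode sampling reproducible even if other code consumes random numbers in between.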

Note that this code also includes the original implementation for comparison (using the PyTorch workaround proposed by the authors). If you wish to use the original implementation, set the option loader_version: 'tf' in base.yaml (the default is 'pytorch').

Yet to do:

  1. Add more methods
  2. Test for the multi-source setting

1. Setting up

Please carefully follow the instructions below to get started.

1.1 Requirements

The present code was developed and tested in Python 3.8. The list of requirements is provided in requirements.txt; install them with:

pip install -r requirements.txt

1.2 Data

To download META-DATASET, please follow the detailed instructions provided in the meta-dataset repository to obtain the data converted to .tfrecords. Once done, make sure all converted datasets are in a single folder, and execute the following script to produce the index files:

bash scripts/make_records/make_index_files.sh <path_to_converted_data>

This may take a few minutes. Once all this is done, set the path variable in config/base.yaml to your data folder.

1.3 Download pre-trained models

We provide Resnet-18 and WRN-28-10 models trained on the training split of ILSVRC_2012 at checkpoints. All non-episodic baselines use the same checkpoint, stored in the standard folder. The results (accuracy in %, averaged over 600 episodes) obtained with the provided Resnet-18 are summarized below:

Inductive methods    | Architecture | ILSVRC | Omniglot | Aircraft | Birds | Textures | Quick Draw | Fungi | VGG Flower | Traffic Signs | MSCOCO | Mean
Finetune             | Resnet-18    | 59.8   | 60.5     | 63.5     | 80.6  | 80.9     | 61.5       | 45.2  | 91.1       | 55.1          | 41.8   | 64.0
ProtoNet             | Resnet-18    | 48.2   | 46.7     | 44.6     | 53.8  | 70.3     | 45.1       | 38.5  | 82.4       | 42.2          | 38.0   | 51.0
SimpleShot           | Resnet-18    | 60.0   | 54.2     | 55.9     | 78.6  | 77.8     | 57.4       | 49.2  | 90.3       | 49.6          | 44.2   | 61.7
Transductive methods | Architecture | ILSVRC | Omniglot | Aircraft | Birds | Textures | Quick Draw | Fungi | VGG Flower | Traffic Signs | MSCOCO | Mean
BD-CSPN              | Resnet-18    | 60.5   | 54.4     | 55.2     | 80.9  | 77.9     | 57.3       | 50.0  | 91.7       | 47.8          | 43.9   | 62.0
TIM-GD               | Resnet-18    | 63.6   | 65.6     | 66.4     | 85.6  | 84.7     | 65.8       | 57.5  | 95.6       | 65.2          | 50.9   | 70.1
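The Mean column is the unweighted average of the ten per-dataset accuracies. The short check below recomputes it from the rows of the table (rounded to one decimal, as in the table). It only uses the numbers reported above.

```python
from statistics import mean

DATASETS = ["ILSVRC", "Omniglot", "Aircraft", "Birds", "Textures",
            "Quick Draw", "Fungi", "VGG Flower", "Traffic Signs", "MSCOCO"]

# Per-dataset accuracies (%) and the reported Mean, copied from the table.
RESULTS = {
    "Finetune":   ([59.8, 60.5, 63.5, 80.6, 80.9, 61.5, 45.2, 91.1, 55.1, 41.8], 64.0),
    "ProtoNet":   ([48.2, 46.7, 44.6, 53.8, 70.3, 45.1, 38.5, 82.4, 42.2, 38.0], 51.0),
    "SimpleShot": ([60.0, 54.2, 55.9, 78.6, 77.8, 57.4, 49.2, 90.3, 49.6, 44.2], 61.7),
    "BD-CSPN":    ([60.5, 54.4, 55.2, 80.9, 77.9, 57.3, 50.0, 91.7, 47.8, 43.9], 62.0),
    "TIM-GD":     ([63.6, 65.6, 66.4, 85.6, 84.7, 65.8, 57.5, 95.6, 65.2, 50.9], 70.1),
}

for method, (accs, reported) in RESULTS.items():
    assert len(accs) == len(DATASETS)  # one accuracy per dataset column
    recomputed = round(mean(accs), 1)
    print(f"{method:<10} recomputed={recomputed:.1f} reported={reported:.1f}")
```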

See Sect. 1.4 and 1.5 to reproduce these results.
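Each entry in the table is an accuracy averaged over 600 test episodes. Few-shot results are often reported together with a 95% confidence interval over episodes. The sketch below computes both from a list of per-episode accuracies using the normal approximation (1.96 standard errors). The function name mean_and_ci95 is hypothetical, and this README does not state whether or how the repo reports intervals.

```python
import math
from statistics import mean, stdev
from typing import List, Tuple


def mean_and_ci95(episode_accs: List[float]) -> Tuple[float, float]:
    """Mean accuracy and half-width of a 95% confidence interval.

    Normal approximation: 1.96 * sample standard deviation / sqrt(n_episodes).
    """
    n = len(episode_accs)
    if n < 2:
        raise ValueError("need at least two episodes to estimate an interval")
    m = mean(episode_accs)
    half_width = 1.96 * stdev(episode_accs) / math.sqrt(n)
    return m, half_width


if __name__ == "__main__":
    # Toy per-episode accuracies (%), not results from this repo.
    accs = [60.0, 62.0, 58.0, 61.0, 59.0]
    m, h = mean_and_ci95(accs)
    print(f"{m:.1f} +/- {h:.2f}")
```

Because the half-width shrinks as 1/sqrt(n), averaging over 600 episodes gives much tighter intervals than the five toy episodes above.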

1.4 Train models from scratch (optional)

To train your model from scratch, execute the scripts/train.sh script:

bash scripts/train.sh <method> <architecture> <dataset>

<method> must be one of the method-specific config files in config/, <architecture> one of ['resnet18', 'wideres2810'], and <dataset> one of the datasets (named as the META-DATASET converted folders). Arguments passed to src/train.py and src/eval.py follow this hierarchy: base_config < method_config < opts arguments, i.e., an option set further to the right overrides the same option set further to the left.
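The precedence base_config < method_config < opts arguments can be sketched as a left-to-right dictionary merge. This is a minimal illustration, assuming flat key/value configs and opts given as alternating key/value strings on the command line (both assumptions). The function names parse_opts and merge_configs are hypothetical, and the option values in the example are placeholders, although the option names appear in this README.

```python
import ast
from typing import Any, Dict, List


def parse_opts(opts: List[str]) -> Dict[str, Any]:
    """Turn ['key1', 'val1', 'key2', 'val2'] into a dict with typed values."""
    if len(opts) % 2 != 0:
        raise ValueError("opts must be key/value pairs")
    parsed: Dict[str, Any] = {}
    for key, raw in zip(opts[0::2], opts[1::2]):
        try:
            # "5" -> 5, "[0, 1]" -> [0, 1], "True" -> True
            parsed[key] = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            parsed[key] = raw  # bare words such as tf are kept as strings
    return parsed


def merge_configs(*configs: Dict[str, Any]) -> Dict[str, Any]:
    """Merge left to right: later configs override earlier ones."""
    merged: Dict[str, Any] = {}
    for cfg in configs:
        merged.update(cfg)
    return merged


if __name__ == "__main__":
    base_config = {"loader_version": "pytorch", "gpus": [0], "visu": False}
    method_config = {"method": "TIM", "episodic_training": False, "iter": 100}
    opts = parse_opts(["gpus", "[0, 1, 2, 3]", "loader_version", "tf"])
    print(merge_configs(base_config, method_config, opts))
```

A shallow update is enough for flat configs. If the YAML files are nested, a recursive merge would be needed so that overriding one sub-key does not discard its siblings.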

Multiprocessing: this code supports distributed training. To use it, set the gpus option accordingly (for instance gpus: [0, 1, 2, 3]).

1.5 Test your models

Once your model is trained (or the pre-trained models are downloaded), you can evaluate it on the test split of each dataset by running:

bash scripts/test.sh <method> <architecture> <base_dataset> <test_dataset>

Results will be saved in a subfolder of results/ named after a unique hash of the config, so two runs share the same result folder if and only if all their hyperparameters are identical.
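This README does not specify how the hash is computed. One common way to get a folder name that depends only on the hyperparameters is to hash a canonical serialization of the config. The sketch below assumes JSON with sorted keys and SHA-1, and config_hash is a hypothetical name. Equal configs always produce the same name. Different configs produce different names except in the astronomically unlikely case of a hash collision.

```python
import hashlib
import json
from typing import Any, Dict


def config_hash(config: Dict[str, Any], length: int = 12) -> str:
    """Deterministic short hash of a JSON-serializable config dict.

    sort_keys=True makes the serialization independent of key order, so
    equal configs always map to the same folder name.
    """
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"))
    return hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:length]


if __name__ == "__main__":
    a = {"method": "TIM", "iter": 100, "gpus": [0]}
    b = {"gpus": [0], "iter": 100, "method": "TIM"}  # same content, other key order
    c = {"method": "TIM", "iter": 50, "gpus": [0]}   # one hyperparameter differs
    print(config_hash(a) == config_hash(b))  # True
    print(config_hash(a), config_hash(c))
```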

2. Visualization of results

2.1 Training metrics

During training, the training loss and validation accuracy are recorded and saved as .npy files in the checkpoint folder. You can then use src/plot.py to plot these metrics (even while training is still running).

Example 1: Plot the metrics of the standard (i.e., non-episodic) Resnet-18 trained on ImageNet:

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/standard/

Example 2: Plot the metrics of all Resnet-18 models trained on ImageNet:

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/

2.2 Inference metrics

For methods that perform test-time optimization (for instance MAML, TIM, Finetune, ...), method-specific metrics are plotted in real time (versus test iterations) and averaged over test episodes, which makes it easy to spot unexpected behavior. Such metrics are implemented in src/metrics/, and the metrics to plot are specified through the eval_metrics option in the method's .yaml config file. An example with the TIM method is provided below.
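Averaging an inference metric over episodes, iteration by iteration, amounts to keeping one running sum per test-time optimization step. The class below is a hypothetical stand-in, not the interface of src/metrics/. It records one value per iteration for each episode and returns the curve averaged over the episodes seen so far, which is the quantity one would redraw in real time. Here n_iter plays the role of the iter option.

```python
from typing import List


class EpisodeAveragedCurve:
    """Running average of a per-iteration metric across test episodes."""

    def __init__(self, n_iter: int):
        self.sums = [0.0] * n_iter  # one accumulator per test-time iteration
        self.n_episodes = 0

    def update(self, values: List[float]) -> None:
        """Add the metric values of one episode (one value per iteration)."""
        if len(values) != len(self.sums):
            raise ValueError("expected one value per test-time iteration")
        for i, v in enumerate(values):
            self.sums[i] += v
        self.n_episodes += 1

    def curve(self) -> List[float]:
        """Mean value at each iteration over the episodes seen so far."""
        if self.n_episodes == 0:
            return []
        return [s / self.n_episodes for s in self.sums]


if __name__ == "__main__":
    tracker = EpisodeAveragedCurve(n_iter=3)
    tracker.update([1.0, 0.5, 0.25])  # toy loss values for episode 1
    tracker.update([3.0, 1.5, 0.75])  # toy loss values for episode 2
    print(tracker.curve())  # [2.0, 1.0, 0.5]
```

Storing sums rather than every episode's curve keeps memory constant in the number of episodes, at the cost of not being able to compute per-iteration spread afterwards.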

2.3 Visualization of episodes

By setting the option visu: True at inference, you can visualize samples of episodes. An example of such visualization is given below:

The samples will be saved in results/. All relevant options can be found in the EVAL-VISU section of base.yaml.

3. Incorporate your own method

This code was designed to allow easy incorporation of new methods.

Step 1: Add your method .py file to src/methods/ by following the template provided in src/methods/method.py.

Step 2: Add the corresponding import to src/methods/__init__.py.

Step 3: Add your method's .yaml config file to config/, including the required options episodic_training and method (the name of the class implementing your method). If your method performs test-time optimization, also set the option iter, which specifies the number of optimization steps performed at inference (this option is also used to plot the inference metrics, see Section 2.2).
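Steps 2 and 3 suggest that the method option is resolved to a class by its name, which is why the class must be importable from src/methods/. The sketch below illustrates such a name-based lookup with a toy nearest-centroid classifier. The base class FewShotMethod, the run_episode signature and the METHODS registry are hypothetical and do not reproduce the template in src/methods/method.py.

```python
import math
from typing import Dict, List, Sequence, Type

Vector = Sequence[float]


class FewShotMethod:
    """Hypothetical minimal interface: classify the query samples of one episode."""

    def __init__(self, config: Dict):
        self.config = config

    def run_episode(self, support: List[Vector], y_support: List[int],
                    query: List[Vector]) -> List[int]:
        raise NotImplementedError


class NearestCentroid(FewShotMethod):
    """Assign each query to the class whose mean support feature is closest."""

    def run_episode(self, support, y_support, query):
        centroids = {}
        for label in sorted(set(y_support)):
            members = [x for x, y in zip(support, y_support) if y == label]
            centroids[label] = [sum(col) / len(members) for col in zip(*members)]
        # Euclidean distance to each class centroid (math.dist needs Python >= 3.8)
        return [min(centroids, key=lambda c: math.dist(q, centroids[c])) for q in query]


# Registry keyed by class name, mirroring "method: <class name>" in the config.
METHODS: Dict[str, Type[FewShotMethod]] = {cls.__name__: cls for cls in [NearestCentroid]}


if __name__ == "__main__":
    config = {"method": "NearestCentroid", "episodic_training": False}
    method = METHODS[config["method"]](config)
    support = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]]
    y_support = [0, 0, 1, 1]
    print(method.run_episode(support, y_support, [[0.2, 0.4], [4.8, 5.1]]))  # [0, 1]
```

With such a lookup, forgetting the import of Step 2 shows up as a missing key when the config is resolved, rather than as a silent fallback to another method.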

4. Contributions

Contributions are more than welcome. In particular, if you want to add methods or pre-trained models, please open a pull request.

5. Citation

If you find this repo useful for your research, please consider citing the following paper:

@misc{boudiaf2021mutualinformation,
      title={Mutual-Information Based Few-Shot Classification}, 
      author={Malik Boudiaf and Ziko Imtiaz Masud and Jérôme Rony and Jose Dolz and Ismail Ben Ayed and Pablo Piantanida},
      year={2021},
      eprint={2106.12252},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Additionally, do not hesitate to file issues if you encounter problems, or reach out directly to Malik Boudiaf ([email protected]).

6. Acknowledgments

I thank the authors of meta-dataset for releasing their code, and the author of the open-source TFRecord reader for open-sourcing an awesome PyTorch-compatible TFRecordReader! Also, big thanks to @hkervadec for his thorough code review!

Owner
Malik Boudiaf