Official PyTorch implementation of the MixMo framework

Overview

MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks

Official PyTorch implementation of the MixMo framework | paper | docs

Alexandre Ramé, Rémy Sun, Matthieu Cord

Citation

If you find this code useful for your research, please cite:

@article{rame2021mixmo,
    title={MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks},
    author={Alexandre Rame and Remy Sun and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2103.06132}
}

Abstract

Recent strategies achieved ensembling “for free” by fitting concurrently diverse subnetworks inside a single base network. The main idea during training is that each subnetwork learns to classify only one of the multiple inputs simultaneously provided. However, the question of how to best mix these multiple inputs has not been studied so far.

In this paper, we introduce MixMo, a new generalized framework for learning multi-input multi-output deep subnetworks. Our key motivation is to replace the suboptimal summing operation hidden in previous approaches by a more appropriate mixing mechanism. For that purpose, we draw inspiration from successful mixed sample data augmentations. We show that binary mixing in features - particularly with rectangular patches from CutMix - enhances results by making subnetworks stronger and more diverse.

We improve state of the art for image classification on CIFAR-100 and Tiny ImageNet datasets. Our easy to implement models notably outperform data augmented deep ensembles, without the inference and memory overheads. As we operate in features and simply better leverage the expressiveness of large networks, we open a new line of research complementary to previous works.

Overview

Most important code sections

This repository provides a general wrapper over PyTorch to reproduce the main results from the paper. The code sections specific to MixMo can be found in:

  1. mixmo.loaders.dataset_wrapper.py and specifically MixMoDataset to create batches with multiple inputs and multiple outputs.
  2. mixmo.augmentations.mixing_blocks.py where we create the mixing masks, e.g. via linear summing (_mixup_mask) or via patch mixing (_cutmix_mask).
  3. mixmo.networks.resnet.py and mixmo.networks.wrn.py where we adapt the network structures to handle:
    • multiple inputs via multiple conv1 encoders (one for each input). The function mixmo.augmentations.mixing_blocks.mix_manifold is used to mix the extracted representations according to the masks provided in metadata from MixMoDataset.
    • multiple outputs via multiple predictions.

This translates to additional tensor management in mixmo.learners.learner.py.
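To make the mixing step concrete, here is a minimal NumPy sketch for M=2 subnetworks, assuming (C, H, W) feature maps coming out of the two conv1 encoders. The functions linear_mask, cutmix_mask, mix_features and sample_mask are hypothetical stand-ins for _mixup_mask, _cutmix_mask and mix_manifold; the repository's exact patch placement and any rescaling of the mixed features may differ. The mixing ratio kappa is drawn from Beta(alpha, alpha), and patch mixing is chosen with probability p, mirroring the cutmixmo-p5 convention.

```python
import numpy as np


def linear_mask(kappa, shape):
    """Linear (Mixup-style) mask: every position keeps a fraction kappa of input 0."""
    return np.full(shape, kappa, dtype=np.float32)


def cutmix_mask(kappa, shape, rng):
    """Binary (CutMix-style) mask: 1 keeps input 0, 0 takes input 1.

    A rectangle covering roughly a fraction (1 - kappa) of the map is set to 0,
    so input 0 still occupies about a fraction kappa of the positions.
    """
    height, width = shape
    ratio = np.sqrt(1.0 - kappa)
    cut_h, cut_w = int(round(height * ratio)), int(round(width * ratio))
    # Place the rectangle uniformly at random, fully inside the map.
    top = rng.integers(0, height - cut_h + 1)
    left = rng.integers(0, width - cut_w + 1)
    mask = np.ones(shape, dtype=np.float32)
    mask[top:top + cut_h, left:left + cut_w] = 0.0
    return mask


def mix_features(feat_0, feat_1, mask):
    """Mix two (C, H, W) feature maps; the (H, W) mask broadcasts over channels."""
    return mask * feat_0 + (1.0 - mask) * feat_1


def sample_mask(shape, rng, alpha=2.0, patch_prob=0.5):
    """Draw kappa ~ Beta(alpha, alpha), then use patch mixing with probability patch_prob."""
    kappa = rng.beta(alpha, alpha)
    if rng.random() < patch_prob:
        return kappa, cutmix_mask(kappa, shape, rng)
    return kappa, linear_mask(kappa, shape)


# Usage: two hypothetical encoder outputs mixed into one shared representation.
rng = np.random.default_rng(0)
feat_0, feat_1 = rng.normal(size=(16, 8, 8)), rng.normal(size=(16, 8, 8))
kappa, mask = sample_mask((8, 8), rng)
mixed = mix_features(feat_0, feat_1, mask)  # would be fed to the shared core network
```

The binary mask keeps each spatial position from exactly one input, which is the property the paper credits for stronger and more diverse subnetworks, whereas the linear mask blends both inputs everywhere.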

Pseudo code

Our MixMoDataset wraps a PyTorch Dataset. The batch_repetition_sampler repeats the same index b times in each batch. Moreover, we provide SoftCrossEntropyLoss, which handles the soft labels required by mixed sample data augmentations such as CutMix.

import os
from torchvision.datasets import CIFAR100
from mixmo.loaders import (dataset_wrapper, batch_repetition_sampler)
from mixmo.networks.wrn import WideResNetMixMo
from mixmo.core.loss import SoftCrossEntropyLoss as criterion

...

# cf mixmo.loaders.loader
train_dataset = dataset_wrapper.MixMoDataset(
        dataset=CIFAR100(os.path.join(dataplace, "cifar100-data")),
        num_members=2,  # we use M=2 subnetworks
        mixmo_mix_method="cutmix",  # patch mixing, linked to mixmo.augmentations.mixing_blocks._cutmix_mask
        mixmo_alpha=2,  # mixing ratio sampled from a Beta distribution with concentration 2
        mixmo_weight_root=3  # root used to reweight the loss components
        )
network = WideResNetMixMo(depth=28, widen_factor=10, num_classes=100)

...

# cf mixmo.learners.learner and mixmo.learners.model_wrapper
for _ in range(num_epochs):
    for indexes_0, indexes_1 in batch_repetition_sampler(batch_size=64, b=4, max_index=len(train_dataset)):
        for (inputs_0, inputs_1, targets_0, targets_1, metadata_mixmo_masks) in train_dataset(indexes_0, indexes_1):
            outputs_0, outputs_1 = network([inputs_0, inputs_1], metadata_mixmo_masks)
            loss = criterion(outputs_0, targets_0) + criterion(outputs_1, targets_1)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
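The pseudo code relies on two helpers from the repository. Below is a standalone, plain-Python sketch of what they do under our assumptions: batch_repetition_indices and soft_cross_entropy are hypothetical names, not the repository's API. We assume MIMO-style batch repetition, where each batch holds batch_size // b distinct samples repeated b times and every subnetwork receives its own shuffle of that list. The soft label follows the usual CutMix convention: kappa of one class and (1 - kappa) of the other.

```python
import math
import random


def batch_repetition_indices(num_samples, batch_size, b, num_members=2, seed=0):
    """Yield, for each batch, one list of dataset indices per subnetwork."""
    if b < 1 or batch_size < b or batch_size % b != 0:
        raise ValueError("batch_size must be a positive multiple of b")
    rng = random.Random(seed)
    order = list(range(num_samples))
    rng.shuffle(order)
    unique_per_batch = batch_size // b
    # Only full batches are yielded; a trailing remainder is dropped.
    for start in range(0, num_samples - unique_per_batch + 1, unique_per_batch):
        repeated = [idx for idx in order[start:start + unique_per_batch] for _ in range(b)]
        members = []
        for _ in range(num_members):
            shuffled = repeated[:]
            rng.shuffle(shuffled)  # independent pairing of inputs per member
            members.append(shuffled)
        yield members


def soft_cross_entropy(logits, soft_target):
    """Cross-entropy -sum_k q_k * log softmax(logits)_k for a soft label q."""
    shift = max(logits)  # log-sum-exp trick for numerical stability
    log_norm = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return -sum(q * (z - log_norm) for q, z in zip(soft_target, logits))


# A CutMix-style soft label: kappa of class 0 and (1 - kappa) of class 2.
kappa = 0.7
target = [kappa, 0.0, 1.0 - kappa]
loss = soft_cross_entropy([2.0, 0.5, 1.0], target)
```

With a one-hot target, soft_cross_entropy reduces to the standard cross-entropy, which is why a single loss can serve both the vanilla and the mixed-sample configurations.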

Configuration files

Our code heavily relies on YAML config files. In the mixmo-pytorch/config folder, we provide the configs to reproduce the main paper results.

For example, the name of the state-of-the-art config exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4 reads as follows:

  • cifar100: dataset is CIFAR-100.
  • wrn2810-2: WideResNet-28-10 network architecture with M=2 subnetworks.
  • cutmixmo-p5: the mixing block uses patch mixing with probability p=0.5, else linear mixing.
  • msdacutmix: use CutMix mixed sample data augmentation.
  • bar4: batch repetition with b=4.
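The naming convention above can be decoded mechanically. The sketch below is a hypothetical helper (parse_config_name is not part of the repository); it assumes every name has the six underscore-separated fields seen in the provided configs, and that the patch probability is written as a single digit of tenths (p5 for 0.5). The training scripts read the YAML content itself, so this is only a convenience for sorting or labeling experiments.

```python
import os


def parse_config_name(path):
    """Decode a config name such as exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml."""
    name = os.path.basename(path)
    if name.endswith(".yaml"):
        name = name[:-len(".yaml")]
    fields = name.split("_")
    if len(fields) != 6 or fields[0] != "exp":
        raise ValueError("unexpected config name: " + name)
    _, dataset, network, mixing, msda, repetition = fields
    architecture, _, members = network.partition("-")  # "wrn2810-2" -> ("wrn2810", "2")
    method, _, prob = mixing.partition("-p")           # "cutmixmo-p5" -> ("cutmixmo", "5")
    return {
        "dataset": dataset,
        "architecture": architecture,
        "num_members": int(members) if members else 1,
        "mixing_block": method,  # e.g. 1net, mimo, linearmixmo, cutmixmo
        # assumes one digit of tenths, as in p5 -> 0.5
        "patch_mixing_prob": int(prob) / 10 if prob else None,
        "msda": msda[len("msda"):] if msda.startswith("msda") else None,  # "standard" -> None
        "batch_repetition": int(repetition[len("bar"):]),
    }


print(parse_config_name("config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml"))
```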

Results and available checkpoints

CIFAR-100 with WideResNet-28-10

Subnetwork method | MSDA | Top-1 Accuracy (%) | Config file in mixmo-pytorch/config/cifar100
-- | Vanilla | 81.79 | exp_cifar100_wrn2810_1net_standard_bar1.yaml
-- | Mixup | 83.43 | exp_cifar100_wrn2810_1net_msdamixup_bar1.yaml
-- | CutMix | 83.95 | exp_cifar100_wrn2810_1net_msdacutmix_bar1.yaml
MIMO | -- | 82.92 | exp_cifar100_wrn2810-2_mimo_standard_bar4.yaml
Linear-MixMo | -- | 82.96 | exp_cifar100_wrn2810-2_linearmixmo_standard_bar4.yaml
Cut-MixMo | -- | 85.52 - 85.59 | exp_cifar100_wrn2810-2_cutmixmo-p5_standard_bar4.yaml
Linear-MixMo | CutMix | 85.36 - 85.57 | exp_cifar100_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml
Cut-MixMo | CutMix | 85.77 - 85.92 | exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml

CIFAR-10 with WideResNet-28-10

Subnetwork method | MSDA | Top-1 Accuracy (%) | Config file in mixmo-pytorch/config/cifar10
-- | Vanilla | 96.37 | exp_cifar10_wrn2810_1net_standard_bar1.yaml
-- | Mixup | 97.07 | exp_cifar10_wrn2810_1net_msdamixup_bar1.yaml
-- | CutMix | 97.28 | exp_cifar10_wrn2810_1net_msdacutmix_bar1.yaml
MIMO | -- | 96.71 | exp_cifar10_wrn2810-2_mimo_standard_bar4.yaml
Linear-MixMo | -- | 96.88 | exp_cifar10_wrn2810-2_linearmixmo_standard_bar4.yaml
Cut-MixMo | -- | 97.52 | exp_cifar10_wrn2810-2_cutmixmo-p5_standard_bar4.yaml
Linear-MixMo | CutMix | 97.73 | exp_cifar10_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml
Cut-MixMo | CutMix | 97.83 | exp_cifar10_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml

Tiny ImageNet-200 with PreActResNet-18-width

Method | Width | Top-1 Accuracy (%) | Config file in mixmo-pytorch/config/tiny
Vanilla | 1 | 62.75 | exp_tinyimagenet_res18_1net_standard_bar1.yaml
Linear-MixMo | 1 | 62.91 | exp_tinyimagenet_res18-2_linearmixmo_standard_bar4.yaml
Cut-MixMo | 1 | 64.32 | exp_tinyimagenet_res18-2_cutmixmo-p5_standard_bar4.yaml
Vanilla | 2 | 64.91 | exp_tinyimagenet_res182_1net_standard_bar1.yaml
Linear-MixMo | 2 | 67.03 | exp_tinyimagenet_res182-2_linearmixmo_standard_bar4.yaml
Cut-MixMo | 2 | 69.12 | exp_tinyimagenet_res182-2_cutmixmo-p5_standard_bar4.yaml
Vanilla | 3 | 65.84 | exp_tinyimagenet_res183_1net_standard_bar1.yaml
Linear-MixMo | 3 | 68.36 | exp_tinyimagenet_res183-2_linearmixmo_standard_bar4.yaml
Cut-MixMo | 3 | 70.23 | exp_tinyimagenet_res183-2_cutmixmo-p5_standard_bar4.yaml

Installation

Requirements overview

  • python >= 3.6
  • torch >= 1.4.0
  • torchsummary >= 1.5.1
  • torchvision >= 0.5.0
  • tensorboard >= 1.14.0

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/mixmo-pytorch.git
  2. Install this repository and its dependencies in a dedicated conda environment:
$ conda create --name mixmo python=3.6.10
$ conda activate mixmo
$ cd mixmo-pytorch
$ pip install -r requirements.txt

With this, you can edit the MixMo code on the fly.

Datasets

We advise first creating a dedicated data folder, dataplace, which will be passed as an argument to the subsequent scripts.

  • CIFAR

CIFAR-10 and CIFAR-100 are handled by the PyTorch data loaders: the first time you run a script, the dataset is downloaded into your dataplace.

  • Tiny-ImageNet

The Tiny-ImageNet dataset needs to be downloaded beforehand. The following process is forked from Manifold Mixup.

  1. Download the zipped data from https://tiny-imagenet.herokuapp.com/.
  2. Extract the zipped data in folder dataplace.
  3. Run the following script (it arranges the validation data in the format required by the PyTorch loader).
$ python scripts/script_load_tiny_data.py --dataplace $dataplace
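For reference, the validation rearrangement amounts to regrouping images into one subfolder per class, the layout expected by folder-based PyTorch loaders. This is a minimal sketch, not the repository's script: it assumes the usual tiny-imagenet-200 layout, where val/images/ holds the images and val/val_annotations.txt has tab-separated lines starting with the file name and the class id (wnid). The helper names are hypothetical, and the actual script may produce a different folder structure.

```python
import os
import shutil
from collections import defaultdict


def group_val_images(annotation_lines):
    """Map each class id (wnid) to the validation image file names annotated with it."""
    groups = defaultdict(list)
    for line in annotation_lines:
        if not line.strip():
            continue  # skip blank lines
        filename, wnid = line.split("\t")[:2]
        groups[wnid].append(filename)
    return dict(groups)


def arrange_val_folder(val_dir):
    """Move val/images/<file> to val/<wnid>/<file>, one subfolder per class."""
    with open(os.path.join(val_dir, "val_annotations.txt")) as annotations:
        groups = group_val_images(annotations)
    for wnid, filenames in groups.items():
        class_dir = os.path.join(val_dir, wnid)
        os.makedirs(class_dir, exist_ok=True)
        for filename in filenames:
            shutil.move(os.path.join(val_dir, "images", filename),
                        os.path.join(class_dir, filename))
    # Remove the now-empty source folder so it is not mistaken for a class.
    os.rmdir(os.path.join(val_dir, "images"))
```

Under these assumptions, the call would be arrange_val_folder applied to the val folder of the extracted archive inside dataplace; run it only once, since the second run would find no images left to move.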

Running the code

Training

Baseline

First, to train a baseline model, simply execute the following command:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810_1net_standard_bar1.yaml --dataplace $dataplace --saveplace $saveplace

This creates an output folder exp_cifar100_wrn2810_1net_standard_bar1 inside the parent folder saveplace. This folder contains model checkpoints, a copy of your config file, logs, and TensorBoard logs. By default, if the output folder already exists, training loads the weights of the last saved epoch and resumes from there. To force training to restart, add --from_scratch as an argument.

MixMo

When training MixMo, you only need to select the appropriate config file. For example, to obtain state-of-the-art results on CIFAR-100 by combining Cut-MixMo and CutMix, execute:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --saveplace $saveplace

Evaluation

To evaluate the accuracy of a given strategy, you can train your own model, or just download our pretrained checkpoints:

$ python3 scripts/evaluate.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --checkpoint $checkpoint --tempscal
  • $checkpoint can be one of:
    • a path to a checkpoint.
    • an integer matching the training epoch you wish to evaluate. In that case, you also need to provide --saveplace $saveplace.
    • the string best: the best training epoch is then selected automatically. In that case, you also need to provide --saveplace $saveplace.
  • --tempscal: applies temperature scaling.
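The rules for the --checkpoint value can be summarized as a small dispatch. This is a sketch of the documented behavior only; resolve_checkpoint_arg is a hypothetical helper, not the parsing code of scripts/evaluate.py, and it assumes a value made only of digits is always meant as an epoch number rather than a path.

```python
def resolve_checkpoint_arg(checkpoint, saveplace=None):
    """Classify a --checkpoint value as a path, an epoch number, or 'best'."""
    value = str(checkpoint)
    if value == "best" or value.isdigit():
        # Both forms look the checkpoint up inside the experiment's output folder.
        if saveplace is None:
            raise ValueError("--saveplace is required when --checkpoint is an epoch or 'best'")
        if value == "best":
            return {"kind": "best", "saveplace": saveplace}
        return {"kind": "epoch", "epoch": int(value), "saveplace": saveplace}
    # Anything else is treated as a path to a checkpoint.
    return {"kind": "path", "path": value}
```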

Results will be printed at the end of the script.
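As background on the --tempscal flag: temperature scaling divides the logits by a single scalar T > 0 fitted on held-out data to minimize the negative log-likelihood. The plain-Python sketch below uses a grid search; the grid, the function names and the choice of held-out split are assumptions for illustration, and the repository may fit T differently (for instance with a gradient-based optimizer).

```python
import math


def log_softmax(logits):
    shift = max(logits)  # log-sum-exp trick for numerical stability
    log_norm = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return [z - log_norm for z in logits]


def mean_nll(all_logits, labels, temperature):
    """Average negative log-likelihood after dividing the logits by the temperature."""
    total = 0.0
    for logits, label in zip(all_logits, labels):
        total -= log_softmax([z / temperature for z in logits])[label]
    return total / len(labels)


def fit_temperature(all_logits, labels, grid=None):
    """Pick the temperature that minimizes held-out NLL over a simple grid."""
    if grid is None:
        grid = [0.5 + 0.05 * i for i in range(71)]  # 0.5, 0.55, ..., 4.0
    return min(grid, key=lambda t: mean_nll(all_logits, labels, t))


# Over-confident toy model: large margins, but one prediction out of four is wrong.
val_logits = [[10.0, 0.0]] * 4
val_labels = [0, 0, 0, 1]
temperature = fit_temperature(val_logits, val_labels)  # T > 1 softens the predictions
```

Dividing all logits by the same positive T never changes the argmax, so temperature scaling leaves top-1 accuracy untouched and only affects calibration-related metrics such as NLL.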

If you wish to test the models against common corruptions and perturbations, download the CIFAR-100-C dataset into your dataplace. Then add --robustness to the evaluation command.

Create your own configuration files and learning strategies

You can create new configs automatically via:

$ python3 scripts/templateutils_mixmo.py --template_path scripts/exp_mixmo_template.yaml --config_dir config/$your_config_dir --dataset $dataset

Acknowledgements and references
