Official PyTorch implementation of the MixMo framework


MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks


Alexandre Ramé, Rémy Sun, Matthieu Cord

Citation

If you find this code useful for your research, please cite:

@article{rame2021mixmo,
    title={MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks},
    author={Alexandre Rame and Remy Sun and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2103.06132}
}

Abstract

Recent strategies achieved ensembling “for free” by fitting concurrently diverse subnetworks inside a single base network. The main idea during training is that each subnetwork learns to classify only one of the multiple inputs simultaneously provided. However, the question of how to best mix these multiple inputs has not been studied so far.

In this paper, we introduce MixMo, a new generalized framework for learning multi-input multi-output deep subnetworks. Our key motivation is to replace the suboptimal summing operation hidden in previous approaches by a more appropriate mixing mechanism. For that purpose, we draw inspiration from successful mixed sample data augmentations. We show that binary mixing in features - particularly with rectangular patches from CutMix - enhances results by making subnetworks stronger and more diverse.

We improve the state of the art for image classification on CIFAR-100 and Tiny ImageNet datasets. Our easy-to-implement models notably outperform data-augmented deep ensembles, without the inference and memory overheads. As we operate in features and simply better leverage the expressiveness of large networks, we open a new line of research complementary to previous works.

Overview

Most important code sections

This repository provides a general wrapper over PyTorch to reproduce the main results from the paper. The code sections specific to MixMo can be found in:

  1. mixmo.loaders.dataset_wrapper.py and specifically MixMoDataset to create batches with multiple inputs and multiple outputs.
  2. mixmo.augmentations.mixing_blocks.py where we create the mixing masks, e.g. via linear summing (_mixup_mask) or via patch mixing (_cutmix_mask).
  3. mixmo.networks.resnet.py and mixmo.networks.wrn.py where we adapt the network structures to handle:
    • multiple inputs via multiple conv1 encoders (one for each input). The function mixmo.augmentations.mixing_blocks.mix_manifold is used to mix the extracted representations according to the masks provided in metadata from MixMoDataset.
    • multiple outputs via multiple predictions.

This translates to additional tensor management in mixmo.learners.learner.py.
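To make the two mixing blocks concrete, here is a minimal pure-Python sketch on 2D feature maps stored as nested lists. The names sample_kappa, linear_mask, cutmix_mask and mix_features are hypothetical and do not reproduce _mixup_mask, _cutmix_mask or mix_manifold. In particular, this sketch keeps the rectangle fully inside the map, whereas CutMix usually samples a box centre and clips the box at the borders.

```python
import random


def sample_kappa(alpha, rng=random):
    """Mixing ratio kappa drawn from a symmetric Beta(alpha, alpha) distribution."""
    return rng.betavariate(alpha, alpha)


def linear_mask(height, width, kappa):
    """Linear (Mixup-style) mask: every position keeps a fraction kappa of input 0."""
    return [[kappa] * width for _ in range(height)]


def cutmix_mask(height, width, kappa, rng=random):
    """Binary (CutMix-style) mask: 1 keeps input 0, 0 takes input 1.

    A rectangle of zeros covering about (1 - kappa) of the area is placed
    at a random position entirely inside the map.
    """
    cut_ratio = (1.0 - kappa) ** 0.5
    cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)
    top = rng.randint(0, height - cut_h)
    left = rng.randint(0, width - cut_w)
    mask = [[1.0] * width for _ in range(height)]
    for i in range(top, top + cut_h):
        for j in range(left, left + cut_w):
            mask[i][j] = 0.0
    return mask


def mix_features(feat_0, feat_1, mask):
    """Element-wise mask * feat_0 + (1 - mask) * feat_1 on 2D feature maps."""
    return [
        [m * a + (1.0 - m) * b for m, a, b in zip(row_m, row_0, row_1)]
        for row_m, row_0, row_1 in zip(mask, feat_0, feat_1)
    ]
```

With a linear mask both inputs are present everywhere; with a binary mask each spatial position comes from exactly one input, which is the property the abstract credits for stronger and more diverse subnetworks.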

Pseudo code

Our MixMoDataset wraps a PyTorch Dataset. The batch_repetition_sampler repeats the same index b times in each batch. Moreover, we provide SoftCrossEntropyLoss, which handles the soft labels required by mixed sample data augmentations such as CutMix.
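Soft labels arise because an MSDA such as CutMix assigns each mixed image a convex combination of two one-hot labels. As an illustration only (plain Python, not the SoftCrossEntropyLoss implementation; all names are hypothetical), the cross-entropy against such a soft target t is −Σ_c t_c log p_c, which reduces to the usual cross-entropy when t is one-hot:

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    shift = max(logits)
    log_norm = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return [x - log_norm for x in logits]


def soft_cross_entropy(logits, soft_target):
    """Cross-entropy -sum_c t_c * log p_c against a probability-vector target."""
    return -sum(t * lp for t, lp in zip(soft_target, log_softmax(logits)))


def mixed_target(class_a, class_b, lam, num_classes):
    """Soft label lam * onehot(class_a) + (1 - lam) * onehot(class_b)."""
    target = [0.0] * num_classes
    target[class_a] += lam
    target[class_b] += 1.0 - lam
    return target
```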

import os

from torchvision.datasets import CIFAR100
from mixmo.loaders import (dataset_wrapper, batch_repetition_sampler)
from mixmo.networks.wrn import WideResNetMixMo
from mixmo.core.loss import SoftCrossEntropyLoss

...

# cf mixmo.loaders.loader
train_dataset = dataset_wrapper.MixMoDataset(
        dataset=CIFAR100(os.path.join(dataplace, "cifar100-data")),
        num_members=2,  # we use M=2 subnetworks
        mixmo_mix_method="cutmix",  # patch mixing, linked to mixmo.augmentations.mixing_blocks._cutmix_mask
        mixmo_alpha=2,  # mixing ratio sampled from a symmetric Beta(2, 2) distribution
        mixmo_weight_root=3  # root r=3 used to reweight the loss components
        )
network = WideResNetMixMo(depth=28, widen_factor=10, num_classes=100)
criterion = SoftCrossEntropyLoss(...)  # instantiate the loss; constructor arguments omitted, cf mixmo.core.loss

...

# cf mixmo.learners.learner and mixmo.learners.model_wrapper
for _ in range(num_epochs):
    for indexes_0, indexes_1 in batch_repetition_sampler(batch_size=64, b=4, max_index=len(train_dataset)):
        for (inputs_0, inputs_1, targets_0, targets_1, metadata_mixmo_masks) in train_dataset(indexes_0, indexes_1):
            outputs_0, outputs_1 = network([inputs_0, inputs_1], metadata_mixmo_masks)
            loss = criterion(outputs_0, targets_0) + criterion(outputs_1, targets_1)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
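The loop above sums the two member losses with equal weights, while mixmo_weight_root controls a reweighting of the loss components. The sketch below assumes the weighting described in the MixMo paper, w_r(κ) = 2κ^{1/r} / (κ^{1/r} + (1 − κ)^{1/r}) for the member whose input makes up a share κ of the mix, and 2 − w_r(κ) for the other member; check the exact formula against mixmo/core/loss.py before relying on it. The function names are hypothetical.

```python
def member_weights(kappa, root):
    """Loss weights (w_0, w_1) for the two members given mixing ratio kappa.

    Assumed form: w_0 = 2 * k^(1/r) / (k^(1/r) + (1 - k)^(1/r)) and w_1 = 2 - w_0,
    so the two weights always sum to 2.
    """
    a = kappa ** (1.0 / root)
    b = (1.0 - kappa) ** (1.0 / root)
    w_0 = 2.0 * a / (a + b)
    return w_0, 2.0 - w_0


def weighted_mixmo_loss(loss_0, loss_1, kappa, root=3):
    """Combine the per-member losses (e.g. soft cross-entropies) with the weights above."""
    w_0, w_1 = member_weights(kappa, root)
    return w_0 * loss_0 + w_1 * loss_1
```

With r = 1 the weights are proportional to the mixing ratios (2κ and 2(1 − κ)); as r grows, both weights tend to 1, which recovers the plain sum used in the loop.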

Configuration files

Our code relies heavily on YAML config files. In the mixmo-pytorch/config folder, we provide the configs to reproduce the main paper results.

For example, the name of the state-of-the-art config exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4 encodes:

  • cifar100: the dataset is CIFAR-100.
  • wrn2810-2: WideResNet-28-10 architecture with M=2 subnetworks.
  • cutmixmo-p5: the mixing block uses patch mixing with probability p=0.5 and linear mixing otherwise.
  • msdacutmix: CutMix is used as mixed sample data augmentation.
  • bar4: batch repetition with b=4.
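With bar4, each batch contains batch_size / b distinct images, each repeated b times. The following is a minimal sketch of such a sampler, assuming each member receives its own shuffle of the repeated indices so that an image is paired with different partners across its b copies. repeated_batches is a hypothetical name and not the signature of batch_repetition_sampler.

```python
import random


def repeated_batches(batch_size, b, max_index, num_members=2, rng=random):
    """Yield one list of dataset indices per member for each batch.

    Each batch draws batch_size // b distinct indices, repeats each one b times,
    then shuffles the repeated list independently for every member.
    The last incomplete group of indices is dropped.
    """
    if batch_size % b != 0:
        raise ValueError("batch_size must be divisible by b")
    order = list(range(max_index))
    rng.shuffle(order)
    unique_per_batch = batch_size // b
    for start in range(0, max_index - unique_per_batch + 1, unique_per_batch):
        unique = order[start:start + unique_per_batch]
        repeated = [idx for idx in unique for _ in range(b)]
        members = []
        for _ in range(num_members):
            shuffled = list(repeated)
            rng.shuffle(shuffled)
            members.append(shuffled)
        yield tuple(members)
```

For example, with batch_size=64 and b=4 as in the pseudo code, every batch holds 16 distinct images, each appearing 4 times in the input list of each member.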

Results and available checkpoints

CIFAR-100 with WideResNet-28-10

| Subnetwork method | MSDA | Top-1 Accuracy (%) | Config file in mixmo-pytorch/config/cifar100 |
| --- | --- | --- | --- |
| -- | Vanilla | 81.79 | exp_cifar100_wrn2810_1net_standard_bar1.yaml |
| -- | Mixup | 83.43 | exp_cifar100_wrn2810_1net_msdamixup_bar1.yaml |
| -- | CutMix | 83.95 | exp_cifar100_wrn2810_1net_msdacutmix_bar1.yaml |
| MIMO | -- | 82.92 | exp_cifar100_wrn2810-2_mimo_standard_bar4.yaml |
| Linear-MixMo | -- | 82.96 | exp_cifar100_wrn2810-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | -- | 85.52 - 85.59 | exp_cifar100_wrn2810-2_cutmixmo-p5_standard_bar4.yaml |
| Linear-MixMo | CutMix | 85.36 - 85.57 | exp_cifar100_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml |
| Cut-MixMo | CutMix | 85.77 - 85.92 | exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml |

CIFAR-10 with WideResNet-28-10

| Subnetwork method | MSDA | Top-1 Accuracy (%) | Config file in mixmo-pytorch/config/cifar10 |
| --- | --- | --- | --- |
| -- | Vanilla | 96.37 | exp_cifar10_wrn2810_1net_standard_bar1.yaml |
| -- | Mixup | 97.07 | exp_cifar10_wrn2810_1net_msdamixup_bar1.yaml |
| -- | CutMix | 97.28 | exp_cifar10_wrn2810_1net_msdacutmix_bar1.yaml |
| MIMO | -- | 96.71 | exp_cifar10_wrn2810-2_mimo_standard_bar4.yaml |
| Linear-MixMo | -- | 96.88 | exp_cifar10_wrn2810-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | -- | 97.52 | exp_cifar10_wrn2810-2_cutmixmo-p5_standard_bar4.yaml |
| Linear-MixMo | CutMix | 97.73 | exp_cifar10_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml |
| Cut-MixMo | CutMix | 97.83 | exp_cifar10_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml |

Tiny ImageNet-200 with PreActResNet-18-width

| Method | Width | Top-1 Accuracy (%) | Config file in mixmo-pytorch/config/tiny |
| --- | --- | --- | --- |
| Vanilla | 1 | 62.75 | exp_tinyimagenet_res18_1net_standard_bar1.yaml |
| Linear-MixMo | 1 | 62.91 | exp_tinyimagenet_res18-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | 1 | 64.32 | exp_tinyimagenet_res18-2_cutmixmo-p5_standard_bar4.yaml |
| Vanilla | 2 | 64.91 | exp_tinyimagenet_res182_1net_standard_bar1.yaml |
| Linear-MixMo | 2 | 67.03 | exp_tinyimagenet_res182-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | 2 | 69.12 | exp_tinyimagenet_res182-2_cutmixmo-p5_standard_bar4.yaml |
| Vanilla | 3 | 65.84 | exp_tinyimagenet_res183_1net_standard_bar1.yaml |
| Linear-MixMo | 3 | 68.36 | exp_tinyimagenet_res183-2_linearmixmo_standard_bar4.yaml |
| Cut-MixMo | 3 | 70.23 | exp_tinyimagenet_res183-2_cutmixmo-p5_standard_bar4.yaml |

Installation

Requirements overview

  • python >= 3.6
  • torch >= 1.4.0
  • torchsummary >= 1.5.1
  • torchvision >= 0.5.0
  • tensorboard >= 1.14.0

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/mixmo-pytorch.git
  2. Create a conda environment and install the dependencies using pip:
$ conda create --name mixmo python=3.6.10
$ conda activate mixmo
$ cd mixmo-pytorch
$ pip install -r requirements.txt

With this, you can edit the MixMo code on the fly.

Datasets

We advise first creating a dedicated data folder, dataplace, which is then passed as an argument to the subsequent scripts.

  • CIFAR

The CIFAR-10 and CIFAR-100 datasets are handled by the PyTorch dataloaders. The first time you run a script, the dataloader downloads the dataset into your provided dataplace.

  • Tiny-ImageNet

The Tiny-ImageNet dataset needs to be downloaded beforehand. The following procedure is adapted from Manifold Mixup.

  1. Download the zipped data from https://tiny-imagenet.herokuapp.com/.
  2. Extract the zipped data in folder dataplace.
  3. Run the following script, which arranges the validation data in the format required by the PyTorch loader:
$ python scripts/script_load_tiny_data.py --dataplace $dataplace
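The rearrangement performed in step 3 can be sketched as follows, assuming the standard Tiny ImageNet layout (val/images/ plus a tab-separated val/val_annotations.txt whose first two fields are the file name and the class wnid) and the one-folder-per-class layout read by torchvision's ImageFolder. arrange_tiny_val is a hypothetical helper; scripts/script_load_tiny_data.py may differ in its details.

```python
import os
import shutil


def arrange_tiny_val(val_dir):
    """Move val/images/<file> into val/<wnid>/<file> using val_annotations.txt.

    Each annotation line is tab-separated: file name, wnid, then box coordinates.
    """
    images_dir = os.path.join(val_dir, "images")
    annotations = os.path.join(val_dir, "val_annotations.txt")
    with open(annotations) as handle:
        for line in handle:
            fields = line.strip().split("\t")
            if len(fields) < 2:
                continue  # skip blank or malformed lines
            file_name, wnid = fields[0], fields[1]
            class_dir = os.path.join(val_dir, wnid)
            os.makedirs(class_dir, exist_ok=True)
            src = os.path.join(images_dir, file_name)
            if os.path.exists(src):  # tolerate re-runs where the file was already moved
                shutil.move(src, os.path.join(class_dir, file_name))
```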

Running the code

Training

Baseline

First, to train a baseline model, simply execute the following command:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810_1net_standard_bar1.yaml --dataplace $dataplace --saveplace $saveplace

This creates an output folder exp_cifar100_wrn2810_1net_standard_bar1 inside the parent folder saveplace. This folder contains model checkpoints, a copy of your config file, logs and TensorBoard logs. By default, if the output folder already exists, training loads the weights from the last saved epoch and resumes from there. To force training to restart, add --from_scratch as an argument.
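The resume-or-restart behaviour can be pictured with the sketch below. The checkpoint file naming (checkpoint_epoch_<N>.ckpt), the 0-based epoch counter and both function names are hypothetical; the naming actually used in the output folder may differ.

```python
import os
import re

# Hypothetical checkpoint naming: checkpoint_epoch_<N>.ckpt
CKPT_PATTERN = re.compile(r"checkpoint_epoch_(\d+)\.ckpt$")


def last_checkpoint(output_dir):
    """Return (epoch, path) of the latest checkpoint in output_dir, or None."""
    if not os.path.isdir(output_dir):
        return None
    found = []
    for name in os.listdir(output_dir):
        match = CKPT_PATTERN.match(name)
        if match:
            found.append((int(match.group(1)), os.path.join(output_dir, name)))
    return max(found) if found else None


def start_epoch(output_dir, from_scratch=False):
    """Epoch to start from: 0 when restarting, else the one after the last checkpoint."""
    latest = None if from_scratch else last_checkpoint(output_dir)
    return 0 if latest is None else latest[0] + 1
```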

MixMo

When training MixMo, you just need to select the appropriate config file. For example, to obtain state-of-the-art results on CIFAR-100 by combining Cut-MixMo and CutMix, execute:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --saveplace $saveplace

Evaluation

To evaluate the accuracy of a given strategy, either train your own model or download our pretrained checkpoints, then run:

$ python3 scripts/evaluate.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --checkpoint $checkpoint --tempscal
  • checkpoint can be one of:
    • a path to a checkpoint file.
    • an int matching the training epoch you wish to evaluate. In that case, you need to provide --saveplace $saveplace.
    • the string best: we then automatically select the best training epoch. In that case, you need to provide --saveplace $saveplace.
  • --tempscal: apply temperature scaling.
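One way to interpret the three accepted forms of --checkpoint is sketched below. The file naming, the function names and the criterion used for best (highest recorded validation accuracy) are assumptions for illustration, not necessarily what scripts/evaluate.py does.

```python
import os


def checkpoint_path(exp_dir, epoch):
    """Hypothetical file naming inside the experiment folder."""
    return os.path.join(exp_dir, "checkpoint_epoch_{}.ckpt".format(epoch))


def resolve_checkpoint(checkpoint, exp_dir=None, val_accuracy_per_epoch=None):
    """Interpret --checkpoint as the string "best", an epoch number, or a file path."""
    checkpoint = str(checkpoint)
    if checkpoint == "best":
        if exp_dir is None or not val_accuracy_per_epoch:
            raise ValueError('"best" needs --saveplace and recorded validation accuracies')
        # assumed criterion: the epoch with the highest validation accuracy
        best_epoch = max(val_accuracy_per_epoch, key=val_accuracy_per_epoch.get)
        return checkpoint_path(exp_dir, best_epoch)
    if checkpoint.isdigit():
        if exp_dir is None:
            raise ValueError("an epoch number needs --saveplace")
        return checkpoint_path(exp_dir, int(checkpoint))
    return checkpoint  # anything else is treated as a path to a checkpoint file
```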

Results will be printed at the end of the script.
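Temperature scaling divides the logits by a single scalar T fitted on held-out data to minimise the negative log-likelihood; it leaves the predicted class unchanged and only adjusts confidence. The sketch below fits T by grid search (rather than a gradient-based optimiser) on plain Python lists; it is an illustration with hypothetical names, not the code behind --tempscal.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    shift = max(logits)
    log_norm = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return [x - log_norm for x in logits]


def nll_at_temperature(all_logits, labels, temperature):
    """Mean negative log-likelihood after dividing every logit by temperature."""
    total = 0.0
    for logits, label in zip(all_logits, labels):
        scaled = [x / temperature for x in logits]
        total -= log_softmax(scaled)[label]
    return total / len(labels)


def fit_temperature(all_logits, labels, grid=None):
    """Pick the temperature minimising validation NLL over a simple grid."""
    if grid is None:
        grid = [0.5 + 0.05 * i for i in range(91)]  # 0.5 to 5.0 in steps of 0.05
    return min(grid, key=lambda t: nll_at_temperature(all_logits, labels, t))
```

Overconfident, often-wrong logits push T toward the top of the grid (softer probabilities), while confident, correct logits push it toward the bottom.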

If you wish to test the models against common corruptions and perturbations, download the CIFAR-100-C dataset into your dataplace, then add --robustness when running the evaluation.

Create your own configuration files and learning strategies

You can create new configs automatically via:

$ python3 scripts/templateutils_mixmo.py --template_path scripts/exp_mixmo_template.yaml --config_dir config/$your_config_dir --dataset $dataset

Acknowledgements and references
