Implementation of ConvMixer for "Patches Are All You Need? 🤷"

Overview

This repository contains an implementation of ConvMixer for the ICLR 2022 submission "Patches Are All You Need?" by Asher Trockman and Zico Kolter.

🔎 New: Check out this repository for training ConvMixers on CIFAR-10.

Code overview

The most important code is in convmixer.py. We trained ConvMixers using the timm framework, which we copied from here.

Update: ConvMixer is now integrated into the timm framework itself. You can see the PR here.

Inside pytorch-image-models, we have made the following modifications. (Though one could look at the diff, we think it is convenient to summarize them here.)

  • Added ConvMixers
    • added timm/models/convmixer.py
    • modified timm/models/__init__.py
  • Added "OneCycle" LR Schedule
    • added timm/scheduler/onecycle_lr.py
    • modified timm/scheduler/scheduler.py
    • modified timm/scheduler/scheduler_factory.py
    • modified timm/scheduler/__init__.py
    • modified train.py (added two lines to support this LR schedule)

We are confident that the use of the OneCycle schedule here is not critical, and one could likely just as well train ConvMixers with the built-in cosine schedule.
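A OneCycle schedule raises the learning rate to a single peak and then brings it back down. The plain-Python sketch below illustrates the idea with a piecewise-linear version. The function name one_cycle_lr, the knot positions (peak at 40% of training, lr_max / 20 at 80%, zero at the end) and the end_div parameter are hypothetical choices for this example. They are not values read from timm/scheduler/onecycle_lr.py.

```python
def one_cycle_lr(t: float, total: float, lr_max: float, end_div: float = 20.0) -> float:
    """Piecewise-linear one-cycle learning rate at (possibly fractional) epoch t.

    Hypothetical knots: ramp 0 -> lr_max over the first 40% of training,
    decay to lr_max / end_div by 80%, then decay to 0 at the end.
    Assumes total > 0.
    """
    knots = [0.0, total * 2 / 5, total * 4 / 5, total]
    values = [0.0, lr_max, lr_max / end_div, 0.0]
    t = min(max(t, 0.0), total)  # clamp t to [0, total]
    # Linear interpolation inside the segment that contains t.
    for x0, x1, y0, y1 in zip(knots, knots[1:], values, values[1:]):
        if t <= x1:
            return y0 + (y1 - y0) * (t - x0) / (x1 - x0)
    return values[-1]


for epoch in (0, 30, 60, 120, 150):
    print(epoch, one_cycle_lr(epoch, 150, 0.01))
```

Under these assumed knots, one_cycle_lr(epoch, 150, 0.01) would give a per-epoch learning rate for the 150-epoch, --lr 0.01 run described in the Training section. A cosine schedule could be swapped in without changing anything else, in line with the remark above.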

Evaluation

We provide some model weights below:

Model Name          Kernel Size   Patch Size   File Size
ConvMixer-1536/20   9             7            207MB
ConvMixer-768/32*   7             7            85MB
ConvMixer-1024/20   9             14           98MB

* Important: ConvMixer-768/32 here uses ReLU instead of GELU, so you would have to change convmixer.py accordingly (we will fix this later).
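The file sizes in the table appear consistent with storing each model's weights as 32-bit floats. The helper below is a hypothetical sketch (the name convmixer_params is made up for this example) that counts learnable parameters layer by layer. It assumes the layer structure of the tweetable version of ConvMixer further down: biased convolutions, a BatchNorm after every activation, and a 1000-way linear head. GELU and ReLU have no parameters, so the ReLU variant of ConvMixer-768/32 has the same count.

```python
def convmixer_params(h: int, d: int, k: int, p: int, n: int = 1000) -> int:
    """Count learnable parameters of ConvMixer-h/d with kernel size k and patch size p."""
    bn = 2 * h                              # BatchNorm2d(h): weight + bias
    patch_embed = 3 * h * p * p + h + bn    # Conv2d(3, h, p, stride=p) with bias, then BN
    depthwise = h * k * k + h + bn          # Conv2d(h, h, k, groups=h) with bias, then BN
    pointwise = h * h + h + bn              # 1x1 Conv2d(h, h, 1) with bias, then BN
    head = h * n + n                        # Linear(h, n)
    return patch_embed + d * (depthwise + pointwise) + head


models = {
    "ConvMixer-1536/20": (1536, 20, 9, 7),
    "ConvMixer-768/32": (768, 32, 7, 7),
    "ConvMixer-1024/20": (1024, 20, 9, 14),
}
for name, (h, d, k, p) in models.items():
    count = convmixer_params(h, d, k, p)
    # 4 bytes per float32 parameter; MB here means 10**6 bytes.
    print(f"{name}: {count:,} parameters, ~{4 * count / 1e6:.1f} MB as float32")
```

Multiplying the counts by 4 bytes gives about 206.5 MB, 84.4 MB and 97.5 MB. These are slightly under the listed 207MB, 85MB and 98MB. The remainder is plausibly BatchNorm running statistics and checkpoint metadata, though that is an assumption about what the checkpoint files contain.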

You can evaluate ConvMixer-1536/20 as follows:

python validate.py --model convmixer_1536_20 -b 64 --num-classes 1000 --checkpoint [/path/to/convmixer_1536_20_ks9_p7.pth.tar] [/path/to/ImageNet1k-val]

You should get 81.37% top-1 accuracy.

Training

If you had a node with 10 GPUs, you could train a ConvMixer-1536/20 as follows (these are exactly the settings we used):

sh distributed_train.sh 10 [/path/to/ImageNet1k] \
    --train-split [your_train_dir] \
    --val-split [your_val_dir] \
    --model convmixer_1536_20 \
    -b 64 \
    -j 10 \
    --opt adamw \
    --epochs 150 \
    --sched onecycle \
    --amp \
    --input-size 3 224 224 \
    --lr 0.01 \
    --aa rand-m9-mstd0.5-inc1 \
    --cutmix 0.5 \
    --mixup 0.5 \
    --reprob 0.25 \
    --remode pixel \
    --num-classes 1000 \
    --warmup-epochs 0 \
    --opt-eps=1e-3 \
    --clip-grad 1.0
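In the command above, -b 64 is the batch size per GPU process, so each optimizer step on the 10-GPU node sees 10 x 64 = 640 images. If you run with fewer GPUs or a smaller -b (for example, to avoid running out of memory), the global batch shrinks. One common heuristic, the linear scaling rule, scales the learning rate down by the same factor. The helpers below are a hypothetical sketch of that arithmetic. Nothing here indicates that the authors tested this rule or that it reproduces their results.

```python
def global_batch_size(num_gpus: int, per_gpu_batch: int) -> int:
    """Images per optimizer step when each GPU process uses its own -b batch."""
    return num_gpus * per_gpu_batch


def linearly_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling heuristic: keep the ratio lr / global batch size constant."""
    return base_lr * new_batch / base_batch


reference_batch = global_batch_size(10, 64)   # the 10-GPU run above
small_batch = global_batch_size(2, 8)         # hypothetical: 2 GPUs, 8 images each
print(reference_batch, small_batch, linearly_scaled_lr(0.01, reference_batch, small_batch))
```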

We also included a ConvMixer-768/32 in timm/models/convmixer.py (though it is simple to add more ConvMixers). We trained that one with the above settings but with 300 epochs instead of 150 epochs.

Note: If you are training on CIFAR-10 instead of ImageNet-1k, we recommend setting --scale 0.75 1.0 as well, since the default value of 0.08 1.0 does not make sense for 32x32 inputs.
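In timm, --scale is the range of area fractions that the random-resized-crop augmentation samples from before resizing the crop back to the input size. The plain-Python sketch below uses a hypothetical helper and assumes square crops for simplicity (the real augmentation also samples an aspect ratio). It shows the arithmetic behind the note: with the default lower bound of 0.08, a crop of a 32x32 image can be only about 9 pixels wide, while a lower bound of 0.75 keeps crops at least about 28 pixels wide.

```python
import math


def crop_side_range(image_side: int, scale: tuple[float, float]) -> tuple[float, float]:
    """Smallest and largest side of a square random crop covering a fraction of the image area.

    Assumes aspect ratio 1; the real augmentation also samples an aspect ratio.
    """
    area = image_side * image_side
    low, high = scale
    return math.sqrt(low * area), math.sqrt(high * area)


for side, scale in [(224, (0.08, 1.0)), (32, (0.08, 1.0)), (32, (0.75, 1.0))]:
    lo, hi = crop_side_range(side, scale)
    print(f"{side}x{side} image, scale {scale}: crop side {lo:.1f} to {hi:.1f} px")
```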

The tweetable version of ConvMixer, which requires from torch.nn import *:

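from torch.nn import *

# h: width (channels), d: depth, k: depthwise kernel size, p: patch size, n: number of classes
# A(x): x -> GELU -> BatchNorm; R: residual block whose forward returns s[0](x) + x
def ConvMixer(h,d,k,p,n):
 S,C,A=Sequential,Conv2d,lambda x:S(x,GELU(),BatchNorm2d(h))
 R=type('',(S,),{'forward':lambda s,x:s[0](x)+x})
 return S(A(C(3,h,p,p)),*[S(R(A(C(h,h,k,groups=h,padding=k//2))),A(C(h,h,1))) for i in range(d)],AdaptiveAvgPool2d(1),Flatten(),Linear(h,n))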
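The tweetable code pads the depthwise convolution with k//2 rather than padding='same'. String padding is only accepted by newer PyTorch releases, which is most likely the cause of the TypeError reported in one of the comments below. The helper below is a hypothetical sketch based on the output-size formula in the PyTorch Conv2d documentation, assuming dilation 1. It shows that the patch embedding C(3,h,p,p) turns a 224x224 image into a (224/p)x(224/p) grid, and that k//2 padding keeps the grid size unchanged for odd kernel sizes such as 7 and 9, but not for even ones.

```python
def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a 2D convolution with dilation 1.

    Follows the formula floor((size + 2*padding - kernel) / stride) + 1.
    """
    return (size + 2 * padding - kernel) // stride + 1


# Patch embedding: kernel = stride = patch size p, no padding.
for p in (7, 14):
    side = conv_out_size(224, p, stride=p)
    print(f"patch size {p}: {side}x{side} grid")

# Depthwise conv with padding k//2 preserves the grid for odd k, but not for even k.
for k in (7, 8, 9):
    print(f"kernel {k}, padding {k // 2}: 32 -> {conv_out_size(32, k, padding=k // 2)}")
```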
Comments
  • Cifar10 baseline doesn't reach 95%

    Hello, I tried convmixer256 on CIFAR-10 with the same timm options specified for ImageNet (except num_classes), and it doesn't go beyond 90% accuracy. Could you please specify the options used for the CIFAR-10 experiment?

    opened by K-H-Ismail 13
  • What's new about this model?

    Why are “patches” all you need? The patch embedding is a Conv7x7 stem, and the body is simply repeated Conv9x9 + Conv1x1. (I'm not challenging your work; it's indeed very interesting.) I'm just kindly wondering: what's new about this model?

    opened by vztu 5
  • Training scheme modifications for small GPUs

    Hi authors. Your paper demonstrates quite an intriguing observation, and I wish you luck with your submission. Thanks for sharing the code. When running it, I ran out of memory (OOM) with the default batch size of 64; in the end, I could only run 8 samples per batch per GPU, as my GPUs have only 11GB. Have you tried smaller GPUs, and did you achieve the same results? So far, besides modifying the learning rate according to the linear rule, I haven't changed anything. If you have tried training on smaller GPUs before, could you please share your experience? Thank you very much!

    opened by justanhduc 4
  • Experiments with full convolutional layers instead of patch embedding?

    Have the authors tried replacing the patch embedding with just a convolution, that is, using a stride of 1 instead of p?

    With this setting, this is a standard convolutional network like MobileNet. I wonder what the performance would be. Is the performance gain of ConvMixer due to the patch embedding or to the depthwise conv layers?

    Very interested in this work, thanks.

    opened by forjiuzhou 2
  • Training time

    Hi, first of all thanks for a very interesting paper.

    I would like to know how long it took you to train the models. I'm trying to train ConvMixer-768/32 using 2xV100, and one epoch takes ~3 hours, so I estimate that full training would take ~2 * 3 * 300 = 1800 GPU hours, which is insane. Even if you trained with 10 GPUs, one experiment would take ~1 week to finish. Are my calculations correct?

    opened by bonlime 1
  • padding=same?

    At https://github.com/tmp-iclr/convmixer/blob/1cefd860a1a6a85369887d1a633425cedc2afd0a/convmixer.py#L18 there is an error: TypeError: conv2d(): argument 'padding' (position 5) must be tuple of ints, not str.

    opened by linhaoqi027 1
  • Add Docker environment & web demo

    Hey @ashertrockman, @tmp-iclr! 👋

    This pull request makes it possible to run your model inside a Docker environment, which makes it easier for other people to run it. We're using an open source tool called Cog to make this process easier.

    This also means we can make a web page where other people can try out your model! View it here: https://replicate.com/locuslab/convmixer and have a look at some Image classification examples we already uploaded.

    By clicking "Claim this model", you'll be able to edit everything, and we'll feature it on our website and tweet about it too.

    In case you're wondering who I am, I'm from Replicate, where we're trying to make machine learning reproducible. We got frustrated that we couldn't run all the really interesting ML work being done. So, we're going round implementing models we like. 😊

    opened by ariel415el 0
  • Fix notebooks

    Hi.

    Fixed errors in the pytorch-image-models/notebooks/{EffResNetComparison,GeneralizationToImageNetV2}.ipynb notebooks:

    • added the missing pynvml installation;
    • resolved missing imports;
    • resolved errors caused by outdated timm library calls.

    Tested in Colab: "Run all" completes without any errors.

    opened by amrzv 0
  • CIFAR-10 training settings

    First of all, thank you for the interesting work. I was experimenting with the model with patch size 1 and kernel size 9 on CIFAR-10, using the following training settings:

    --model tiny_convmixer
     -b 64 -j 8 
    --opt adamw 
    --epochs 200 
    --sched onecycle 
    --amp 
    --input-size 3 32 32 
    --lr 0.01 
    --aa rand-m9-mstd0.5-inc1 
    --cutmix 0.5 
    --mixup 0.5 
    --reprob 0.25 
    --remode pixel 
    --num-classes 10
    --warmup-epochs 0
    --opt-eps 1e-3
    --clip-grad 1.0
    --scale 0.75 1.0
    --weight-decay 0.01
    --mean 0.4914 0.4822 0.4465
    --std 0.2471 0.2435 0.2616
    

    I could only get 95.89%, whereas Table 4 in the paper reports 96.03%. Could you please let me know if I missed any setting? Thank you again.

    opened by fugokidi 0
  • Segmentation ConvMixer architecture ?

    I was trying to figure out what a segmentation ConvMixer would look like, and came up with this (residual connection inspired by MultiResUNet). Does it make sense to you?

    [image]

    opened by divideconcept 0
  • Request more experiment results to compare to other architecture.

    Hi! This work is pretty interesting, but I think there should be more results like those in "Demystifying Local Vision Transformer: Sparse Connectivity, Weight Sharing, and Dynamic Weight", which replaces local self-attention with depth-wise convolution in Swin Transformer. Since you propose an even simpler architecture than Swin Transformer, I wonder whether ConvMixer can achieve similar performance on object detection and semantic segmentation.

    opened by LuoXin-s 1
Releases: timm-v1.0
Owner: CMU Locus Lab (Zico Kolter's Research Group)