This is the official PyTorch implementation for "Mesa: A Memory-saving Training Framework for Transformers".

Mesa: A Memory-saving Training Framework for Transformers

By Zizheng Pan, Peng Chen, Haoyu He, Jing Liu, Jianfei Cai and Bohan Zhuang.

Installation

  1. Create a virtual environment with Anaconda.

    conda create -n mesa python=3.7 -y
    conda activate mesa
    
    # Install PyTorch (we use PyTorch 1.7.1 with CUDA 10.1)
    pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
    
    # Install ninja
    pip install ninja
  2. Build and install Mesa.

    # clone this repo
    git clone https://github.com/zhuang-group/Mesa
    # build
    cd Mesa/
    # You need to have an NVIDIA GPU
    python setup.py develop
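
Before or after building, it can help to confirm that the environment matches the steps above. The following is a minimal sketch and is not part of Mesa: the function name `check_environment` is hypothetical, and the `mesa` entry only tells you that some module with that name is importable, not that it is this project's build.

```python
import importlib.util
import sys


def check_environment():
    """Report whether the pieces the installation steps rely on are present."""
    report = {
        "python": "%d.%d" % sys.version_info[:2],
        "torch_installed": importlib.util.find_spec("torch") is not None,
        # Assumption: the package built above is importable as "mesa".
        "mesa_installed": importlib.util.find_spec("mesa") is not None,
        "cuda_available": False,
    }
    if report["torch_installed"]:
        import torch

        report["torch_version"] = torch.__version__
        # The build step above requires an NVIDIA GPU.
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If `cuda_available` is reported as `False`, the build step is likely to fail, since it needs an NVIDIA GPU.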

Usage

  1. Prepare your policy and save as a text file, e.g. policy.txt.

    on gelu: # layer tag, choices: fc, conv, gelu, bn, relu, softmax, matmul, layernorm
        by_index: all # layer index
        enable: True # enable compression
        level: 256 # we adopt 8-bit quantization by default
        ema_decay: 0.9 # the decay rate for running estimates
        
        by_index: 1 2 # e.g. excluding the GELU layers indexed by 1 and 2
        enable: False
  2. Next, you can wrap your model with Mesa by:

    import mesa as ms
    ms.policy.convert_by_num_groups(model, 3)
    # or convert by group size with ms.policy.convert_by_group_size(model, 64)
    
    # setup compression policy
    ms.policy.deploy_on_init(model, '[path to policy.txt]', verbose=print, override_verbose=False)

    That's all you need to use Mesa for memory saving.

    Note that convert_by_num_groups and convert_by_group_size only recognize nn.XXX modules. If your code uses functional operations, such as Q@K and F.softmax, you may need to set up these layers manually. For example:

    # matrix multiplication (before)
    out = q @ k.transpose(-2, -1)
    # with Mesa: create the module in __init__, call it in forward
    self.mm = ms.MatMul(quant_groups=3)
    out = self.mm(q, k.transpose(-2, -1))
    
    # softmax (before)
    attn = attn.softmax(dim=-1)
    # with Mesa
    self.softmax = ms.Softmax(dim=-1, quant_groups=3)
    attn = self.softmax(attn)
  3. You can also target one layer by:

    import mesa as ms
    # before
    self.act = nn.GELU()
    # with Mesa
    self.act = ms.GELU(quant_groups=[num of quantization groups])
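
To make the policy fields more concrete, here is a conceptual sketch of what `level`, `ema_decay`, and the number of quantization groups plausibly control. It is pure Python for illustration only and assumes uniform quantization with a min/max range per group, tracked by an exponential moving average. The class `GroupQuantizer`, the way values are split into contiguous groups, and the round-to-nearest step are all assumptions made for this example; this is not Mesa's API or its CUDA implementation.

```python
class GroupQuantizer:
    """Toy uniform quantizer: one (min, max) range per group, tracked by EMA."""

    def __init__(self, num_groups, level=256, ema_decay=0.9):
        self.num_groups = num_groups
        self.level = level          # number of quantization levels (256 -> 8 bits)
        self.ema_decay = ema_decay  # weight given to the previous running estimate
        self.running_min = [None] * num_groups
        self.running_max = [None] * num_groups

    def _split(self, values):
        if len(values) % self.num_groups != 0:
            raise ValueError("length must be divisible by num_groups")
        size = len(values) // self.num_groups
        return [values[g * size:(g + 1) * size] for g in range(self.num_groups)]

    def update(self, values):
        """Blend each group's current min/max into its running estimate."""
        d = self.ema_decay
        for g, chunk in enumerate(self._split(values)):
            lo, hi = min(chunk), max(chunk)
            if self.running_min[g] is None:  # first batch: take the range as-is
                self.running_min[g], self.running_max[g] = lo, hi
            else:
                self.running_min[g] = d * self.running_min[g] + (1 - d) * lo
                self.running_max[g] = d * self.running_max[g] + (1 - d) * hi

    def _scale(self, g):
        span = self.running_max[g] - self.running_min[g]
        return span / (self.level - 1) if span > 0 else 1.0

    def quantize(self, values):
        """Map floats to integer codes in [0, level - 1], clipping outliers."""
        codes = []
        for g, chunk in enumerate(self._split(values)):
            lo, scale = self.running_min[g], self._scale(g)
            for v in chunk:
                q = round((v - lo) / scale)  # round-to-nearest, for simplicity
                codes.append(min(self.level - 1, max(0, q)))
        return codes

    def dequantize(self, codes):
        """Recover approximate floats from the integer codes."""
        values = []
        for g, chunk in enumerate(self._split(codes)):
            lo, scale = self.running_min[g], self._scale(g)
            values.extend(lo + c * scale for c in chunk)
        return values


quantizer = GroupQuantizer(num_groups=2, level=256, ema_decay=0.9)
activations = [0.0, 1.0, 2.0, 3.0, -4.0, 0.0, 4.0, 8.0]
quantizer.update(activations)
codes = quantizer.quantize(activations)
restored = quantizer.dequantize(codes)
print(codes)
print(max(abs(a - b) for a, b in zip(activations, restored)))
```

In this toy version each stored value fits in one byte instead of a 32-bit float, and giving each group its own range keeps a group with small values from being dominated by a group with large ones.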

Demo projects for DeiT and Swin

We provide demo projects that replicate our results of training DeiT and Swin with Mesa. Please refer to DeiT-Mesa and Swin-Mesa.

Results on ImageNet

Model             Param (M)   FLOPs (G)   Train Memory (MB)   Top-1 (%)
DeiT-Ti           5           1.3         4,171               71.9
DeiT-Ti w/ Mesa   5           1.3         1,858               72.1
DeiT-S            22          4.6         8,459               79.8
DeiT-S w/ Mesa    22          4.6         3,840               80.0
DeiT-B            86          17.5        17,691              81.8
DeiT-B w/ Mesa    86          17.5        8,616               81.8
Swin-Ti           29          4.5         11,812              81.3
Swin-Ti w/ Mesa   29          4.5         5,371               81.3
PVT-Ti            13          1.9         7,800               75.1
PVT-Ti w/ Mesa    13          1.9         3,782               74.9
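
The relative saving is easier to compare across models than the raw numbers. As a small worked example, the snippet below computes the fraction of training memory saved for each model, using only the Train Memory figures from the table above; the helper name `memory_reduction` is made up for this illustration.

```python
# (model, baseline train memory, train memory with Mesa), copied from the table
results = [
    ("DeiT-Ti", 4171, 1858),
    ("DeiT-S", 8459, 3840),
    ("DeiT-B", 17691, 8616),
    ("Swin-Ti", 11812, 5371),
    ("PVT-Ti", 7800, 3782),
]


def memory_reduction(baseline, with_mesa):
    """Return the fraction of training memory saved (unit-free)."""
    return 1.0 - with_mesa / baseline


for name, baseline, with_mesa in results:
    saved = memory_reduction(baseline, with_mesa)
    print(f"{name}: {saved:.1%} less training memory")
```

For every model listed, the saving comes out at a little over half of the baseline training memory.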

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Acknowledgments

This repository adopts part of the quantization code from ActNN. We thank the authors for their open-source code.

Owner
Zhuang AI Group