InvTorch: memory-efficient models with invertible functions

InvTorch: Memory-Efficient Invertible Functions

This module extends the functionality of torch.utils.checkpoint.checkpoint to work with invertible functions. Not only are the intermediate activations released from memory, but the input tensors are also deallocated and later recomputed from the outputs using the inverse function, only during the backward pass. This is useful in extreme situations where extra compute is traded for memory. However, there are a few caveats to consider, which are detailed below.
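The core mechanism can be illustrated in plain PyTorch, without InvTorch, using a custom torch.autograd.Function. This is a simplified sketch, not InvTorch's implementation. The hypothetical InvertibleAffine computes y = x * scale + shift and saves only its output for the backward pass. When gradients are needed, it recomputes the input from that output with the inverse. InvTorch goes further and actually deallocates the storage of the input tensors, which this sketch does not do.

```python
import torch


class InvertibleAffine(torch.autograd.Function):
    """Hypothetical example: y = x * scale + shift (scale must be nonzero).

    Ordinary autograd would keep `x` alive to compute d(loss)/d(scale).
    Here only the output is saved; the input is rebuilt from it by the
    inverse, x = (y - shift) / scale, when backward() runs.
    """

    @staticmethod
    def forward(ctx, x, scale, shift):
        y = x * scale + shift
        ctx.save_for_backward(y, scale, shift)  # the input x is NOT saved
        return y

    @staticmethod
    def backward(ctx, grad_y):
        y, scale, shift = ctx.saved_tensors
        x = (y - shift) / scale  # recompute the input with the inverse function
        # Re-run the forward computation with autograd enabled to get the gradients
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            scale = scale.detach().requires_grad_(True)
            shift = shift.detach().requires_grad_(True)
            y = x * scale + shift
            return torch.autograd.grad(y, (x, scale, shift), grad_y)
```

To confirm that the gradients match finite differences, pass float64 inputs that require gradients to torch.autograd.gradcheck, for example `torch.autograd.gradcheck(InvertibleAffine.apply, (x, scale, shift))`.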

Installation

InvTorch has minimal dependencies. It only requires PyTorch version 1.10.0 or later.

conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install invtorch
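Before installing, you can check that an existing PyTorch install meets the 1.10.0 requirement with a small helper like the one below. This is a hedged sketch. The name meets_minimum is hypothetical, and the helper compares only the numeric major.minor.patch components. It ignores local build suffixes such as +cu113 and pre-release tags.

```python
import re


def meets_minimum(version, minimum=(1, 10, 0)):
    """Return True if a version string such as "1.10.0+cu113" is >= minimum."""
    core = version.split("+")[0]  # drop local build tags like "+cu113"
    parts = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)  # "0a0" (pre-release) -> "0"
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (3 - len(parts))  # treat "1.10" as "1.10.0"
    return tuple(parts) >= minimum
```

Usage: `import torch; print(meets_minimum(torch.__version__))`.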

Basic Usage

The main module of interest is InvertibleModule, which inherits from torch.nn.Module. Subclass it to implement your own invertible code.

import torch
from torch import nn
from invtorch import InvertibleModule


class InvertibleLinear(InvertibleModule):
    def __init__(self, in_features, out_features):
        super().__init__(invertible=True, checkpoint=True)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

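    def function(self, inputs):
        # y = x W^T + b; this runs under no_grad, so requires_grad is set by hand
        outputs = inputs @ self.weight.T + self.bias
        requires_grad = self.do_require_grad(inputs, self.weight, self.bias)
        return outputs.requires_grad_(requires_grad)

    def inverse(self, outputs):
        # x = (y - b) (W^T)^+ via the pseudo-inverse; exact only when out_features >= in_features
        return (outputs - self.bias) @ self.weight.T.pinverse()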
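The inverse above depends on the pseudo-inverse of W^T. It recovers the input exactly only when W has full column rank, which for a random weight means out_features >= in_features. The plain-PyTorch check below reproduces the same two formulas without InvTorch. The helper roundtrip_ok is hypothetical.

```python
import torch


def roundtrip_ok(in_features, out_features, seed=0):
    """Check whether the pseudo-inverse recovers the input of y = x W^T + b."""
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(10, in_features, dtype=torch.float64, generator=gen)
    weight = torch.randn(out_features, in_features, dtype=torch.float64, generator=gen)
    bias = torch.randn(out_features, dtype=torch.float64, generator=gen)
    y = x @ weight.T + bias                   # same formula as function()
    x_rec = (y - bias) @ weight.T.pinverse()  # same formula as inverse()
    return torch.allclose(x_rec, x)


print(roundtrip_ok(3, 5))  # expanding layer, as in InvertibleLinear(3, 5)
print(roundtrip_ok(3, 2))  # contracting layer: information is lost
```

A contracting layer projects the input onto a lower-dimensional subspace, so no inverse can undo it. A layer like this should only be used with in_features <= out_features.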

Structure

You may immediately notice a few differences from a regular PyTorch module. There is no longer a need to define forward(); it is replaced with function(*inputs). Additionally, its inverse must be defined as inverse(*outputs). Both methods accept only positional arguments (one or more) and return either a torch.Tensor or a tuple of outputs, which may contain anything, including tensors.
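As an illustration of these conventions with more than one input and output, here is a hypothetical function/inverse pair implementing an additive coupling, a standard invertible construction. Each function takes positional tensors and returns a tuple. In an InvertibleModule subclass they would become the function() and inverse() methods, with the requires_grad_ handling described under Requires Gradient. That handling is omitted from this standalone sketch.

```python
import torch


def function(x1, x2):
    # x2 passes through unchanged and determines the shift applied to x1
    return x1 + torch.tanh(x2), x2


def inverse(y1, y2):
    # y2 equals x2, so the same shift can be recomputed and subtracted
    return y1 - torch.tanh(y2), y2
```

The coupling remains invertible even though tanh is applied to x2, because the inverse never has to invert tanh itself.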

Requires Gradient

function() must manually call .requires_grad_(True/False) on all output tensors. The forward pass runs in no_grad mode, and there is no way to detect which outputs need gradients without tracing. However, this can be inferred from the requires_grad values of the inputs and of self.parameters(). The code above uses do_require_grad(), which returns True if any of the tensors passed to it requires a gradient.
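The following plain-PyTorch snippet shows why this step is needed. Under no_grad, an output never requires a gradient, even when a parameter does. The helper any_requires_grad is a hypothetical stand-in for the idea behind do_require_grad(), not its actual code. The snippet only sets the flag. Connecting the output to the autograd graph is the job of the checkpointing machinery, which this sketch leaves out.

```python
import torch


def any_requires_grad(*args):
    # Hypothetical stand-in: True if any tensor argument requires a gradient
    return any(isinstance(a, torch.Tensor) and a.requires_grad for a in args)


x = torch.randn(10, 3)                          # data: no gradient needed
weight = torch.randn(5, 3, requires_grad=True)  # parameter: needs a gradient

with torch.no_grad():                           # how the forward pass is run
    y = x @ weight.T

print(y.requires_grad)                          # False: nothing was recorded
y.requires_grad_(any_requires_grad(x, weight))
print(y.requires_grad)                          # True, inferred from weight
```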

Example

Now, this model is ready to be instantiated and used directly.

x = torch.randn(10, 3)
model = InvertibleLinear(3, 5)
print('Is invertible:', model.check_inverse(x))

y = model(x)
print('Output requires_grad:', y.requires_grad)
print('Input was freed:', x.storage().size() == 0)

y.backward(torch.randn_like(y))
print('Input was restored:', x.storage().size() != 0)
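The "freed" and "restored" checks above depend on releasing a tensor's storage while keeping the tensor object, then refilling it in place. The plain-PyTorch sketch below shows that pattern. It assumes PyTorch 2.x for untyped_storage(), and the helpers free_ and restore_ are hypothetical. It is not InvTorch's code. The input is rebuilt from the output using the same inverse as InvertibleLinear.

```python
import torch


def free_(tensor):
    # Release the tensor's memory while keeping its shape, dtype and device
    tensor.untyped_storage().resize_(0)


def restore_(tensor, values):
    # Reallocate the storage (assumes a contiguous tensor with storage offset 0)
    tensor.untyped_storage().resize_(tensor.numel() * tensor.element_size())
    with torch.no_grad():
        tensor.copy_(values)  # refill in place, e.g. with inverse(outputs)


torch.manual_seed(0)
x = torch.randn(10, 3, dtype=torch.float64)
weight = torch.randn(5, 3, dtype=torch.float64)
bias = torch.randn(5, dtype=torch.float64)

y = x @ weight.T + bias
reference = x.clone()  # kept only to verify the round trip below
free_(x)
print("freed:", x.untyped_storage().nbytes() == 0)
restore_(x, (y - bias) @ weight.T.pinverse())
print("restored:", torch.allclose(x, reference))
```

Because the same tensor object is refilled, any code still holding a reference to x sees the recomputed values.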

Checkpoint and Invertible Modes

InvertibleModule has two flags that control its mode of operation: checkpoint and invertible. It acts exactly like a normal PyTorch module in three cases: when checkpoint is set to False, when working in no_grad mode, or when no input or parameter has requires_grad set to True. Otherwise, the module runs either as an invertible checkpoint or as an ordinary checkpoint, depending on whether invertible is set to True or False, respectively. These flags can be changed at any time during operation without any repercussions.
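These rules can be summarized as a small decision function. select_mode is a hypothetical helper that only restates the paragraph above. It is not part of InvTorch. In practice, grad_enabled would come from torch.is_grad_enabled().

```python
def select_mode(checkpoint, invertible, grad_enabled, any_requires_grad):
    """Hypothetical summary of InvertibleModule's mode rules."""
    if not checkpoint or not grad_enabled or not any_requires_grad:
        return "normal"  # behaves exactly like a regular nn.Module
    return "invertible" if invertible else "checkpoint"
```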

Limitations

Under the hood, InvertibleModule uses invertible_checkpoint(), a low-level implementation that makes this possible. There are a few considerations to keep in mind when working with invertible checkpoints and non-materialized tensors. Please refer to the documentation in the code for more details.

Overriding forward()

Although forward() now does important work to ensure the validity of the results when calling invertible_checkpoint(), it can still be overridden. The main reason to do so is to provide a more user-friendly interface, in terms of both the function signature and the output format. For example, function() could return extra outputs that are not needed among the module's outputs but are essential for correctly computing inverse(). In such a case, define forward() to call outputs = super().forward(*inputs) and return only the outputs that matter, in a cleaner form.
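Here is a sketch of this pattern with a hypothetical standardization layer. Its function() returns the per-feature mean and std as extra outputs, because inverse() needs them. The overridden forward() then returns only the standardized tensor. InvertibleModuleStandIn is a hypothetical placeholder whose forward() simply calls function(). It lets the sketch run without InvTorch. The real forward() does considerably more, as described above, and the requires_grad_ handling is also omitted here.

```python
import torch
from torch import nn


class InvertibleModuleStandIn(nn.Module):
    """Hypothetical placeholder for InvertibleModule: forward() just calls function()."""

    def forward(self, *inputs):
        return self.function(*inputs)


class Standardize(InvertibleModuleStandIn):
    def function(self, inputs):
        # mean and std are extra outputs that only inverse() needs
        mean = inputs.mean(dim=0, keepdim=True)
        std = inputs.std(dim=0, keepdim=True)
        return (inputs - mean) / std, mean, std

    def inverse(self, outputs, mean, std):
        return outputs * std + mean

    def forward(self, inputs):
        outputs, _, _ = super().forward(inputs)
        return outputs  # expose only the tensor callers care about
```

Callers simply write `y = Standardize()(x)`, while inverse() still receives everything that function() produced.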

TODOs

Here are a few feature ideas that could be implemented to enrich the utility of this package:

  • Add more basic operations and modules
  • Add coupling- and interleave-based invertible operations
  • Add more checks to help users debug more features
  • Allow picking some inputs to not be freed in invertible mode
  • Add a context manager to temporarily change the mode of operation
  • Implement dynamic discovery of outputs that require gradients
  • Develop automatic mode optimization of a network for various objectives
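The context-manager idea in the list above could be prototyped roughly as follows. This is a hedged sketch. The helper temporarily is hypothetical, and using it on a model assumes the checkpoint and invertible flags are exposed as plain attributes with those names, which is not confirmed here.

```python
from contextlib import contextmanager


@contextmanager
def temporarily(obj, **attrs):
    """Hypothetical helper: set attributes on obj, restore the old values on exit."""
    saved = {name: getattr(obj, name) for name in attrs}
    try:
        for name, value in attrs.items():
            setattr(obj, name, value)
        yield obj
    finally:
        # Restore even if the body raised an exception
        for name, value in saved.items():
            setattr(obj, name, value)
```

For example, `with temporarily(model, invertible=False): y = model(x)` would run a single call in ordinary checkpoint mode, under that attribute-name assumption.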

Releases (v0.5.0)
Owner
Modar M. Alfadly
Deep learning researcher interested in understanding neural networks