Implementation of LambdaNetworks, a new approach to image recognition that reaches SOTA with less compute

Overview

Lambda Networks - Pytorch

Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The method uses the λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.
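As a rough illustration of that idea, the sketch below computes only a content lambda in plain PyTorch. It is not the library's implementation: the tensor names and sizes are made up for the example, and positional lambdas, multiple heads, and the intra-depth dimension are left out.

```python
import torch

torch.manual_seed(0)

b, n, m = 2, 5, 7      # batch, number of queries, number of context positions (illustrative sizes)
k, v = 4, 6            # key/query depth and value depth (illustrative sizes)

queries = torch.randn(b, n, k)
keys = torch.randn(b, m, k)
values = torch.randn(b, m, v)

# Normalize the keys over the context positions, then fold the whole context
# into one k x v matrix per example. That matrix is the content lambda.
keys = keys.softmax(dim=1)
content_lambda = torch.einsum('bmk,bmv->bkv', keys, values)

# Apply the same linear function to every query independently.
out = torch.einsum('bnk,bkv->bnv', queries, content_lambda)

print(content_lambda.shape)  # torch.Size([2, 4, 6])
print(out.shape)             # torch.Size([2, 5, 6])
```

Note that the context is summarized once into a small matrix, so no n x m attention map is ever materialized in this sketch.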

Yannic Kilcher's paper review

Install

$ pip install lambda-networks

Usage

Global context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,       # channels going in
    dim_out = 32,   # channels out
    n = 64,         # size of the receptive window - max(height, width)
    dim_k = 16,     # key dimension
    heads = 4,      # number of heads, for multi-query
    dim_u = 1       # 'intra-depth' dimension
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

Localized context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,
    dim_out = 32,
    r = 23,         # the receptive field for relative positional encoding (23 x 23)
    dim_k = 16,
    heads = 4,
    dim_u = 4
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)
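
One way to picture the localized variant is as a convolution of the values with a learned r x r relative-position filter. The sketch below mirrors the conv3d formulation used in the LambdaLayer code quoted in the comments further down this page. The sizes and variable names are illustrative assumptions, and it may not match what lambda_networks does internally.

```python
import torch
import torch.nn.functional as F

b, u, dim_k, dim_v, heads = 1, 4, 16, 8, 4   # illustrative sizes
h = w = 16
r = 7                                        # odd local window size (the example above uses r = 23)

q = torch.randn(b, heads, dim_k, h * w)      # queries, one per position and head
v = torch.randn(b, u, dim_v, h, w)           # values laid out as a 3D volume
rel_pos = torch.randn(dim_k, u, 1, r, r)     # stand-in for the learned relative-position filters

# A kernel of depth 1 leaves the value axis untouched; padding r // 2 keeps h and w.
position_lambdas = F.conv3d(v, rel_pos, padding=(0, r // 2, r // 2))

# Each position n applies its own lambda to its own query.
y_p = torch.einsum('bhkn,bkvn->bnhv', q, position_lambdas.flatten(3))

print(position_lambdas.shape)  # torch.Size([1, 16, 8, 16, 16])
print(y_p.shape)               # torch.Size([1, 256, 4, 8])
```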

For fun, you can also import this as follows

from lambda_networks import λLayer

Tensorflow / Keras version

Shinel94 has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under ./lambda_networks/tfkeras.py or make sure to install tensorflow and keras before running the following.

import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(
    dim_out = 32,
    r = 23,
    dim_k = 16,
    heads = 4,
    dim_u = 1
)

x = tf.random.normal((1, 64, 64, 16)) # channel last format
layer(x) # (1, 64, 64, 32)

Citations

@inproceedings{
    anonymous2021lambdanetworks,
    title={LambdaNetworks: Modeling long-range Interactions without Attention},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=xTJEN-ggl1b},
    note={under review}
}
Comments
  • Contiguity problem: "RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input."

    It seems that LambdaLayer breaks contiguity when I try it.

    layer(x).is_contiguous()
    >> False
    

    I have to call .contiguous() on the output when I train with it. Is that normal?

    opened by Whiax 5
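
For readers hitting the same error, here is a minimal sketch of what non-contiguity means and what .contiguous() does. It uses a plain permute on a random tensor rather than LambdaLayer itself, so it only illustrates the general PyTorch behaviour, not the layer's internals.

```python
import torch

x = torch.randn(1, 8, 4, 4)

# permute only changes the strides, so the result is a non-contiguous view
y = x.permute(0, 2, 3, 1)
print(y.is_contiguous())   # False

# contiguous() copies the data into a tensor with a standard memory layout
z = y.contiguous()
print(z.is_contiguous())   # True
print(torch.equal(y, z))   # True, the values are unchanged
```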
  • Warning: Mixed memory format inputs detected while calling the operator.

    I have added lambda layers to every block of ResNet, but the following warning appears. Will it affect the results?

    Warning: Mixed memory format inputs detected while calling the operator. The operator will output channels_last tensor even if some of the inputs are not in channels_last format. (function operator())

    opened by Pluto1314 4
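
The warning concerns memory layout rather than values. As a small sketch of what the two formats look like in plain PyTorch (independent of LambdaLayer, with made-up tensor sizes), the same tensor can be stored in the default layout or in channels_last, and converting both operands to one format before combining them avoids mixing the two:

```python
import torch

x = torch.randn(2, 8, 4, 4)                    # default contiguous (N, C, H, W) layout
y = x.to(memory_format=torch.channels_last)    # same shape and values, NHWC strides

print(x.is_contiguous())                                    # True
print(y.is_contiguous())                                    # False
print(y.is_contiguous(memory_format=torch.channels_last))   # True
print(torch.equal(x, y))                                    # True

# Convert both operands to one layout before combining them.
z = x.contiguous(memory_format=torch.channels_last) + y
print(z.is_contiguous(memory_format=torch.channels_last))   # True
```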
  • Implementation of Lambda convolution

    Thanks for the great work in the implementations!

    I would like to ask whether there is a difference in using 'Conv2d' as suggested in Eq. 3 in the paper and your implementation of 'Conv3d'. These two convs treat the (h x w)-dimension as a 1-d sequence and a 2-d image, respectively. I believe they are quite different in concept.

    Point out if I misunderstood.

    Thx a lot.

    opened by romulus0914 3
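
A sketch that may help with this question, assuming the Conv3d in question uses a kernel of depth 1 along the value axis (as in the LambdaLayer code quoted further down this page). Under that assumption the 3D convolution equals a 2D convolution over the h x w image applied to each value channel independently. The sizes and names are illustrative, and the sketch makes no claim about Eq. 3 of the paper.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, u, dim_k, dim_v, h, w, r = 2, 3, 4, 5, 6, 7, 3   # illustrative sizes

v = torch.randn(b, u, dim_v, h, w)
kernel = torch.randn(dim_k, u, 1, r, r)

# 3D convolution whose kernel has depth 1 along the value axis
out3d = F.conv3d(v, kernel, padding=(0, r // 2, r // 2))     # (b, dim_k, dim_v, h, w)

# The same computation as a 2D convolution, with the value axis folded into the batch
v2d = v.permute(0, 2, 1, 3, 4).reshape(b * dim_v, u, h, w)
out2d = F.conv2d(v2d, kernel.squeeze(2), padding=r // 2)     # (b * dim_v, dim_k, h, w)
out2d = out2d.reshape(b, dim_v, dim_k, h, w).permute(0, 2, 1, 3, 4)

# Compare the two results up to floating-point tolerance
print(torch.allclose(out3d, out2d, atol=1e-4))
```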
  • How does the lambda layer handle downsampling in LambdaResNet?

    Hi, thanks for your clear code. I am trying to implement LambdaResNet. Does the lambda layer replace every conv2d layer? If so, how does the lambda layer handle the downsampling done by conv2d, e.g. stride=2? Or should conv2d be kept where stride=2, with only the stride=1 conv2d layers replaced?

    opened by qiaoran-dawnlight 2
  • question about hybrid lambdaResnet

    Hi,

    In the paper, there is this paragraph:

    When working with hybrid LambdaNetworks, we use a single lambda layer in c4 for LambdaResNet50, 3 lambda layers for LambdaResNet101, 6 lambda layers for LambdaResNet-152/200/270/350 and 8 lambda layers for LambdaResNet-420.

    I have several questions about constructing the hybrid lambdaResnet:

    1. Do we only need to replace the 3x3 conv with a lambda layer in the C4 stage, rather than in both C4 and C5 (as in the ablation study)?
    2. When there is more than one lambda layer, as in LambdaResNet101, do we replace the 3x3 conv with 3 lambda layers? And in the ResNet50 case, do we replace the 3x3 conv with 1 lambda layer?
    opened by CoinCheung 2
  • Question: Is there an easy way to visualise lambdas?

    I want to train a classifier and tell which regions it pays the most... well... attention to :) I would like to get this during inference, without using Grad-CAM etc. Can I do this?

    opened by lebionick 1
  • Fix relative positional attention (position lambda)

    Hi lucidrains, thanks for your nice work.

    I found an error in the position lambda (relative positional attention) implementation. The position lambda, λp, should be translation equivariant, as described in Sec. 3.2 of the paper. This means the positional embedding must satisfy the constraint E[n, m] = E[t(n), t(m)], but the current implementation does not enforce it. This PR fixes that by adding the translation equivariance constraint. I checked that this PR improves the result in my experiment.

    NOTE that this PR modifies the function parameter n from the total area (n=w*h) to the length of each side (n=w=h).

    opened by khanrc 1
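
One common way to obtain the constraint E[n, m] = E[t(n), t(m)] is to store one embedding per relative offset and look it up by the offset between positions n and m. The sketch below shows that idea for an n x n grid. The helper name calc_rel_pos, the table layout, and the sizes are assumptions made for this illustration and are not taken from the PR.

```python
import torch

def calc_rel_pos(n: int) -> torch.Tensor:
    """Relative (row, col) offsets between all cell pairs of an n x n grid, shifted into [0, 2n - 2]."""
    rows = torch.arange(n).repeat_interleave(n)
    cols = torch.arange(n).repeat(n)
    coords = torch.stack([rows, cols], dim=1)          # (n * n, 2), row-major cell coordinates
    rel = coords[None, :, :] - coords[:, None, :]      # rel[a, b] = coords[b] - coords[a]
    return rel + n - 1

n, dim_k, dim_u = 4, 3, 1                                   # n is the side length, not the area
table = torch.randn(2 * n - 1, 2 * n - 1, dim_k, dim_u)    # one embedding per relative offset
rel = calc_rel_pos(n)
pos_emb = table[rel[..., 0], rel[..., 1]]                   # (n * n, n * n, dim_k, dim_u)

print(pos_emb.shape)  # torch.Size([16, 16, 3, 1])

# Cells (0, 0) -> (1, 1) and (1, 1) -> (2, 2) have the same offset, so they share an embedding.
print(torch.equal(pos_emb[0, 5], pos_emb[5, 10]))  # True
```

Because the table is indexed only by offset, shifting both positions by the same translation leaves the looked-up embedding unchanged.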
  • Use of Keras Lambda

    Hey! Thanks for the awesome implementations :D

    I was wondering why tf.keras.layers.Lambda is used. It seems unnecessary: regular calls to TF operations work and are more readable.

    https://github.com/lucidrains/lambda-networks/blob/06a48f2a5b41f3cd278aee67838c32051a0a9bed/lambda_networks/tfkeras.py#L73

    You can also call the functional version of the softmax instead.

    opened by cgarciae 1
  • Lambda for a sequence of images

    Thanks for the quick implementation!

    I have a problem where the input is a sequence of images (a video) rather than a single image. So instead of the dimensions batch, channels, height, width, I also have a length dimension after batch that gives the sequence length.

    Given a known max_length (for the positional embedding), in forward, should conv4d be used instead of 3d to allow interaction between frames?

    In the paper they do mention this could serve as a general framework for sequences of images, so I wonder if you explored that in implementation (where obviously a single image is just a case where length=1)

    opened by AmitMY 1
  • why flops so high

    I used ResNet50 and changed the C4 layer into LambaBottleNeck, but why are the FLOPs so high, about 20G, with an input size of 224*244? Is that right, or is something wrong with my implementation?

    opened by zisuina 0
  • How to load_model correctly

    Hi everyone, I am struggling to load this model from a saved .h5 file.

    What is the correct way to load this network? If I use custom_objects I get: __init__() got an unexpected keyword argument 'name'

    opened by JNaranjo-Alcazar 0
  • Please add clarity to code

    So Phil, I love your work, and I wish you could go the extra few steps to help out users. I found this class by François-Guillaume (@frgfm), which adds clear math comments. I want to merge it, but there's a bit of code drift and I don't want to introduce any bugs. I beseech you to go the extra step to help users bridge from papers to code.

    https://github.com/frgfm/Holocron/blob/bcc3ea19a477e4b28dc5973cdbe92a9b05c690bb/holocron/nn/modules/lambda_layer.py

    E.g., please spell out return types: def forward(self, x: torch.Tensor) -> torch.Tensor:

    Please add clarity on the arguments, e.g.: # Project input and context to get queries, keys & values

    Throw in some maths as a comment. This is great, as it bridges the paper to the code, e.g.:

    # B x (num_heads * dim_k) x H x W -> B x num_heads x dim_k x (H * W)

    import torch
    from torch import nn, einsum
    import torch.nn.functional as F
    from typing import Optional
    
    __all__ = ['LambdaLayer']
    
    
    class LambdaLayer(nn.Module):
        """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
        <https://openreview.net/pdf?id=xTJEN-ggl1b>`_. The implementation was adapted from `lucidrains'
        <https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`.
        Args:
            in_channels (int): input channels
            out_channels (int, optional): output channels
            dim_k (int): key dimension
            n (int, optional): number of input pixels
            r (int, optional): receptive field for relative positional encoding
            num_heads (int, optional): number of attention heads
            dim_u (int, optional): intra-depth dimension
        """
        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            dim_k: int,
            n: Optional[int] = None,
            r: Optional[int] = None,
            num_heads: int = 4,
            dim_u: int = 1
        ) -> None:
            super().__init__()
            self.u = dim_u
            self.num_heads = num_heads
    
            if out_channels % num_heads != 0:
                raise AssertionError('values dimension must be divisible by number of heads for multi-head query')
            dim_v = out_channels // num_heads
    
            # Project input and context to get queries, keys & values
            self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False)
            self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False)
            self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False)
    
            self.norm_q = nn.BatchNorm2d(dim_k * num_heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
    
            self.local_contexts = r is not None
            if r is not None:
                if r % 2 != 1:
                    raise AssertionError('Receptive kernel size should be odd')
                self.padding = r // 2
                self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
            else:
                if n is None:
                    raise AssertionError('You must specify the total sequence length (h x w)')
                self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
    
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b, c, h, w = x.shape
    
            # Project inputs & context to retrieve queries, keys and values
            q = self.to_q(x)
            k = self.to_k(x)
            v = self.to_v(x)
    
            # Normalize queries & values
            q = self.norm_q(q)
            v = self.norm_v(v)
    
            # B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W)
            q = q.reshape(b, self.num_heads, -1, h * w)
            # B x (dim_k * dim_u) * H * W -> B x dim_u x dim_k x (H * W)
            k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
            # B x (dim_v * dim_u) * H * W -> B x dim_u x dim_v x (H * W)
            v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
    
            # Normalized keys
            k = k.softmax(dim=-1)
    
            # Content function
            λc = einsum('b u k m, b u v m -> b k v', k, v)
            Yc = einsum('b h k n, b k v -> b n h v', q, λc)
    
            # Position function
            if self.local_contexts:
                # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W
                v = v.reshape(b, self.u, v.shape[2], h, w)
                λp = F.conv3d(v, self.R, padding=(0, self.padding, self.padding))
                Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
            else:
                λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
                Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
    
            Y = Yc + Yp
            # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W
            out = Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w)
            return out
    
    opened by johndpope 1
  • Image Size

    Are non-square image blocks allowed for context? Using global context and non-square dimensions (96, 128), I get an error about dimension size on this line:

    λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)

    opened by anklebreaker 0
  • LambdaResNet Implementation?

    I have been looking around and found one implementation of LambdaResNets, although there seem to be some metric performance problems and I've found wall-clock performance problems (runs ~7x slower than normal resnets).

    Do you plan on putting out a lambdaresnet model in this repository?

    opened by nollied 4
Owner
Phil Wang
Working with Attention. It's all we need.