Implementation of Invariant Point Attention, used for coordinate refinement in the structure module of AlphaFold2, as a standalone PyTorch module

Overview

Invariant Point Attention - Pytorch

Implementation of Invariant Point Attention as a standalone module, as used in the structure module of AlphaFold2 for coordinate refinement.

  • write up a test for invariance under rotation
  • enforce float32 for certain operations
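For the first item, here is a minimal sketch of what such a test could check, written against plain tensors rather than the module itself. IPA sees points only through squared distances between points placed in the global frame, so those distances should not change when every residue frame is moved by the same global rotation and translation. The helpers `random_rotation`, `to_global` and `pairwise_sq_dist` are hypothetical and not part of this package.

```python
import torch

def random_rotation():
    # QR of a Gaussian matrix gives an orthogonal Q; fixing the signs and,
    # if needed, flipping one column makes det(Q) = +1 (a proper rotation)
    q, r = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
    q = q * torch.sign(torch.diagonal(r))
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

def to_global(points, rotations, translations):
    # points: (n, p, 3) in each residue's local frame
    # rotations: (n, 3, 3), translations: (n, 3)
    # x_global = R_i @ x_local + t_i for every point of residue i
    return torch.einsum('nij,npj->npi', rotations, points) + translations[:, None, :]

def pairwise_sq_dist(points):
    # (n, p, 3) -> (n, n, p): squared distances, summed over the xyz axis
    diff = points[:, None, :, :] - points[None, :, :, :]
    return (diff ** 2).sum(dim=-1)

torch.manual_seed(0)
n, p = 8, 4
local_points = torch.randn(n, p, 3, dtype=torch.float64)
rotations = torch.stack([random_rotation() for _ in range(n)])
translations = torch.randn(n, 3, dtype=torch.float64)

# move every frame by the same global rigid motion (R_g, t_g):
# R_i' = R_g R_i,  t_i' = R_g t_i + t_g
R_g, t_g = random_rotation(), torch.randn(3, dtype=torch.float64)
rotations_moved = R_g @ rotations
translations_moved = translations @ R_g.T + t_g

d_before = pairwise_sq_dist(to_global(local_points, rotations, translations))
d_after = pairwise_sq_dist(to_global(local_points, rotations_moved, translations_moved))
print(torch.allclose(d_before, d_after))  # True
```

An end-to-end version would call the module twice, once with the original `rotations, translations` and once with the moved frames, and compare the outputs with `torch.allclose` under a tolerance, since the module runs in float32.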

Install

$ pip install invariant-point-attention

Usage

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,                  # single (and pairwise) representation dimension
    heads = 8,                 # number of attention heads
    scalar_key_dim = 16,       # scalar query-key dimension
    scalar_value_dim = 16,     # scalar value dimension
    point_key_dim = 4,         # point query-key dimension
    point_value_dim = 4        # point value dimension
)

single_repr   = torch.randn(1, 256, 64)      # (batch x seq x dim)
pairwise_repr = torch.randn(1, 256, 256, 64) # (batch x seq x seq x dim)
mask          = torch.ones(1, 256).bool()    # (batch x seq)

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)  # (batch x seq x rot1 x rot2) - example is identity
translations  = torch.zeros(1, 256, 3) # (batch x seq x 3) - zero translation, also the identity in this example

attn_out = attn(
    single_repr,
    pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)

You can also use this module without the pairwise representation, which is specific to the AlphaFold2 architecture.

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,
    heads = 8,
    require_pairwise_repr = False   # set this to False to use the module without pairwise representations
)

seq           = torch.randn(1, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

attn_out = attn(
    seq,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)
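For orientation, here is the attention logit as we read it in the AlphaFold2 supplement (Algorithm 22; worth double-checking against the paper). For head h and residues i, j, frames T_i = (R_i, t_i) act on points as T_i ∘ x = R_i x + t_i, q and k are scalar queries and keys of dimension c, p indexes the query/key points, and γ^h is a positive per-head weight (a softplus of a learned scalar). The pairwise representation enters only through the bias b_ij^h, a linear projection of the pairwise features, so running without it should correspond to dropping that term; how the module reweights the remaining terms is not stated here.

```latex
a_{ij}^{h} = \operatorname{softmax}_j\!\left( w_L \left[ \frac{1}{\sqrt{c}}\, \mathbf{q}_i^{h\top} \mathbf{k}_j^{h} + b_{ij}^{h} - \frac{\gamma^{h} w_C}{2} \sum_{p} \left\lVert T_i \circ \vec{q}_i^{\,hp} - T_j \circ \vec{k}_j^{\,hp} \right\rVert^2 \right] \right),
\qquad w_L = \sqrt{\tfrac{1}{3}}, \quad w_C = \sqrt{\tfrac{2}{9 N_{\text{points}}}}
```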

You can also use a single IPA-based transformer block, which is an IPA layer followed by a feedforward layer. By default it uses post-layernorm, as done in the official code, but you can also try pre-layernorm by setting post_norm = False.
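Schematically, the two options differ only in where the layer norm sits relative to the residual connection. The sketch below is a generic illustration with a plain feedforward standing in for the sublayer; the class `ResidualSublayer` is hypothetical and is not this package's `IPABlock` code.

```python
import torch
from torch import nn

class ResidualSublayer(nn.Module):
    """Residual connection plus LayerNorm around a sublayer f (hypothetical helper).

    post_norm=True:  x -> LayerNorm(x + f(x))   (post-layernorm)
    post_norm=False: x -> x + f(LayerNorm(x))   (pre-layernorm)
    """
    def __init__(self, dim, sublayer, post_norm=True):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer
        self.post_norm = post_norm

    def forward(self, x):
        if self.post_norm:
            return self.norm(x + self.sublayer(x))
        return x + self.sublayer(self.norm(x))

dim = 64
ff = nn.Sequential(nn.Linear(dim, dim * 2), nn.ReLU(), nn.Linear(dim * 2, dim))
x = torch.randn(1, 256, dim)

post = ResidualSublayer(dim, ff, post_norm=True)
pre = ResidualSublayer(dim, ff, post_norm=False)
print(post(x).shape, pre(x).shape)  # torch.Size([1, 256, 64]) torch.Size([1, 256, 64])
```

With post-norm every block output is normalized; with pre-norm the residual path from input to output stays un-normalized and only the sublayer input is normalized.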

import torch
from torch import nn
from einops import repeat
from invariant_point_attention import IPABlock

block = IPABlock(
    dim = 64,
    heads = 8,
    scalar_key_dim = 16,
    scalar_value_dim = 16,
    point_key_dim = 4,
    point_value_dim = 4
)

seq           = torch.randn(1, 256, 64)
pairwise_repr = torch.randn(1, 256, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

block_out = block(
    seq,
    pairwise_repr = pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

updates = nn.Linear(64, 6)(block_out)
quaternion_update, translation_update = updates.chunk(2, dim = -1) # (1, 256, 3), (1, 256, 3)

# apply updates to rotations and translations for the next iteration
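One way to fill in that last step, sketched after the AlphaFold2 backbone update and assuming the three `quaternion_update` values are the (b, c, d) of a non-unit quaternion (1, b, c, d): normalize it, convert it to a rotation matrix, and compose with the current frame, T_new = T ∘ T_update, so the translation update is expressed in the local frame. The helpers `quaternion_to_matrix` and `apply_update` are hypothetical. Following the AlphaFold2 structure module, the rotations are detached before the next iteration, using the out-of-place `detach()`, since the in-place `detach_()` raises on views such as the output of `repeat`.

```python
import torch
import torch.nn.functional as F

def quaternion_to_matrix(q):
    # q: (..., 4) unit quaternions (a, b, c, d), real part first
    a, b, c, d = q.unbind(dim=-1)
    return torch.stack([
        a*a + b*b - c*c - d*d, 2*(b*c - a*d),         2*(b*d + a*c),
        2*(b*c + a*d),         a*a - b*b + c*c - d*d, 2*(c*d - a*b),
        2*(b*d - a*c),         2*(c*d + a*b),         a*a - b*b - c*c + d*d,
    ], dim=-1).reshape(*q.shape[:-1], 3, 3)

def apply_update(rotations, translations, quaternion_update, translation_update):
    # rotations: (b, n, 3, 3), translations: (b, n, 3)
    # quaternion_update, translation_update: (b, n, 3)
    ones = torch.ones_like(quaternion_update[..., :1])
    q = F.normalize(torch.cat((ones, quaternion_update), dim=-1), dim=-1)  # unit (1, b, c, d)
    rot_update = quaternion_to_matrix(q)
    # T o T_update: rotate the local translation update by the *current* rotation
    translations = translations + torch.einsum('bnij,bnj->bni', rotations, translation_update)
    rotations = rotations @ rot_update
    return rotations, translations

b, n = 1, 256
rotations = torch.eye(3).expand(b, n, 3, 3)          # identity frames (a view)
translations = torch.zeros(b, n, 3)
quaternion_update = torch.randn(b, n, 3, requires_grad=True)
translation_update = torch.randn(b, n, 3)

rotations, translations = apply_update(rotations, translations, quaternion_update, translation_update)
print(rotations.requires_grad)  # True
rotations = rotations.detach()  # stop gradients through rotations before the next iteration
print(rotations.requires_grad)  # False
```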

Citations

@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}
Comments
  • Computing point dist - use cartesian dimension instead of hidden dimension

    https://github.com/lucidrains/invariant-point-attention/blob/2f1fb7ca003d9c94d4144d1f281f8cbc914c01c2/invariant_point_attention/invariant_point_attention.py#L130

    I think it should be dim=-1, thus using the cartesian (xyz) axis, rather than dim=-2, which uses the hidden dimension.

    opened by aced125 3
  • In-place rotation detach not allowed

    Hi, this is probably highly version-dependent (I have pytorch=1.11.0, pytorch3d=0.7.0 nightly), but I thought I'd report it. Torch doesn't like the in-place detach of the rotation tensor. Full stack trace (from denoise.py):

    Traceback (most recent call last):
      File "denoise.py", line 56, in <module>
        denoised_coords = net(
      File "/home/pi-user/miniconda3/envs/piai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/pi-user/invariant-point-attention/invariant_point_attention/invariant_point_attention.py", line 336, in forward
        rotations.detach_()
    RuntimeError: Can't detach views in-place. Use detach() instead. If you are using DistributedDataParallel (DDP) for training, and gradient_as_bucket_view is set as True, gradients are views of DDP buckets, and hence detach_() cannot be called on these gradients. To fix this error, please refer to the Optimizer.zero_grad() function in torch/optim/optimizer.py as the solution.
    

    Switching to rotations = rotations.detach() seems to behave correctly (tested in denoise.py and my own code). I'm not totally sure if this allocates a separate tensor, or just creates a new node pointing to the same data.

    opened by sidnarayanan 1
  • Report a bug that causes instability in training

    Hi, I would like to report a bug in the rotation handling that causes instability in training. https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L322

    The IPA Transformer is similar to the structure module in AF2, where recycling is used. Note that we usually detach the gradient of the rotation, because not detaching it may cause instability during training: the gradient of the rotation updates the rotation during backpropagation, which in our experiments results in instability. Therefore we usually detach the rotation to remove the effect of gradient descent on it. I have seen you do this in your alphafold2 repo (https://github.com/lucidrains/alphafold2).

    If you think this is a problem, please let me know. I am happy to submit a PR to fix it.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • Subtle mistake in the implementation

    Hi. Thanks for your implementation; it is very helpful. However, I found that the dropout is missing from the IPAModule.

    https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L239

    In the AlphaFold2 supplementary information, the dropout is nested inside the layer norm, and the same holds for the layer norm after the transition layer (line 9 of the structure module algorithm in the supplement).

    If you think this is a problem, please let me know. I will submit a PR to fix it. Thanks again for sharing such an amazing repo.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • Change the quaternion update to match the original AlphaFold2

    In the original AlphaFold2 IPA module, a pure-quaternion (no real part) description is used for the quaternion update, which can be broken down into a residual-update-like formulation. This code, however, uses (1, a, b, c)-style quaternions, so I believe the quaternion update should be a simple multiplicative update. As far as I have tested, the loss seems to decrease more efficiently with this modification.

    opened by ShintaroMinami 1
  • #126 maybe omit the 'self.point_attn_logits_scale'?

    Hi luci,

    I read the original paper, compared it to your implementation, and found one place that might be a mistake:

    #126. attn_logits_points = -0.5 * (point_dist * point_weights).sum(dim = -1),

    I thought it should be attn_logits_points = -0.5 * (point_dist * point_weights * self.point_attn_logits_scale).sum(dim = -1)

    Thanks for your sharing!

    opened by CiaoHe 1
  • Application of Invariant Point Attention: preserve part of the structure

    Hi lucidrains, first of all, thank you very much for your work!

    I have a question: how can I change (denoise) the structure only in a region I choose? (denoise.py)

    opened by hw-protein 0
  • Equivariance test for IPA Transformer

    @lucidrains I would like to ask about the equivariance of the transformer (not IPA blocks). I wonder if you checked for the equivariance of the output when you allow the transformation of local points to global points using the updated quaternions and translations. I am not sure why this test fails in my case.

    opened by amrhamedp 1
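Two of the reports above concern the point term of the attention logits: which axis the squared distance is summed over, and where the point scale enters. A minimal sketch of that term as we read the AlphaFold2 supplement (Algorithm 22): squared differences are summed over the xyz axis first (dim=-1 of a (..., 3) tensor; summing dim=-2 would sum over the points instead), then over the points, weighted per head by softplus(γ) and by w_C = sqrt(2 / (9 · N_points)). The tensor layout and the name `point_logits` are assumptions for illustration, not this package's code; the overall w_L factor applied to the summed logits is left out.

```python
import math
import torch
import torch.nn.functional as F

def point_logits(q_points, k_points, head_weights):
    """Point term of IPA attention logits (sketch).

    q_points, k_points: (b, h, n, p, 3) query/key points already in the global frame
    head_weights: (h,) raw per-head weights; softplus keeps gamma positive
    returns: (b, h, n, n)
    """
    num_points = q_points.shape[-2]
    # (b, h, i, 1, p, 3) - (b, h, 1, j, p, 3) -> (b, h, i, j, p, 3)
    diff = q_points[:, :, :, None] - k_points[:, :, None, :]
    # sum over the cartesian (xyz) axis -> (b, h, i, j, p)
    sq_dist = (diff ** 2).sum(dim=-1)
    gamma = F.softplus(head_weights)[None, :, None, None]  # (1, h, 1, 1)
    w_c = math.sqrt(2.0 / (9.0 * num_points))
    # -(gamma * w_C / 2) * sum over points of the squared distances
    return -0.5 * gamma * w_c * sq_dist.sum(dim=-1)

b, h, n, p = 1, 8, 16, 4
q = torch.randn(b, h, n, p, 3)
k = torch.randn(b, h, n, p, 3)
logits = point_logits(q, k, torch.zeros(h))
print(logits.shape)  # torch.Size([1, 8, 16, 16])
```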
Owner: Phil Wang (Working with Attention)