Implementation of 🦩 Flamingo, state-of-the-art few-shot visual question answering attention net out of Deepmind, in Pytorch

Overview

🦩 Flamingo - Pytorch

Implementation of Flamingo, state-of-the-art few-shot visual question answering attention net, in Pytorch. It will include the perceiver resampler (including the scheme where the learned queries contributes keys / values to be attended to, in addition to media embeddings), the specialized masked cross attention blocks, and finally the tanh gating at the ends of the cross attention + corresponding feedforward blocks

Install

$ pip install flamingo-pytorch

Usage

import torch
from flamingo_pytorch import PerceiverResampler

perceive = PerceiverResampler(
    dim = 1024,
    depth = 2,
    dim_head = 64,
    heads = 8,
    num_latents = 64,    # the number of latents to shrink your media sequence to, perceiver style
    num_time_embeds = 4  # say you have 4 images maximum in your dialogue
)

medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
perceived = perceive(medias) # (1, 2, 64, 1024) - (batch, time, num latents, dimension)

Then you insert the GatedCrossAttentionBlock at different intervals in your giant language model. Your text would then attend to the perceived media from above

The recommended way to derive the media_locations boolean tensor would be to allocate a special token id to the media, and then, at the start of your large language model, do media_locations = text_id == media_token_id

import torch
from flamingo_pytorch import GatedCrossAttentionBlock

cross_attn = GatedCrossAttentionBlock(
    dim = 1024,
    dim_head = 64,
    heads = 8
)

text = torch.randn(1, 512, 1024)
perceived = torch.randn(1, 2, 64, 1024)

media_locations = torch.randint(0, 2, (1, 512)).bool()

text = cross_attn(
    text,
    perceived,
    media_locations = media_locations
)

That's it!

Attention is all you need.

Full working example with Flamingo + PaLM 🌴 🦩 🌴

Integration with PaLM

First install vit-pytorch for the vision encoder

$ pip install vit-pytorch

Then

from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

vit = Extractor(vit, return_embeddings_only = True)

# first take your trained image encoder and wrap it in an adapter that returns the image embeddings
# here we use the ViT from the vit-pytorch library

import torch
from flamingo_pytorch import FlamingoPaLM

# a PaLM language model, the 540 billion parameter model from google that shows signs of general intelligence

flamingo_palm = FlamingoPaLM(
    num_tokens = 20000,          # number of tokens
    dim = 1024,                  # dimensions
    depth = 12,                  # depth
    heads = 8,                   # attention heads
    dim_head = 64,               # dimension per attention head
    img_encoder = vit,           # plugin your image encoder (this can be optional if you pass in the image embeddings separately, but probably want to train end to end given the perceiver resampler)
    media_token_id = 3,          # the token id representing the [media] or [image]
    cross_attn_every = 3,        # how often to cross attend
    perceiver_num_latents = 64,  # perceiver number of latents, should be smaller than the sequence length of the image tokens
    perceiver_depth = 2          # perceiver resampler depth
)

# train your PaLM as usual

text = torch.randint(0, 20000, (2, 512))

palm_logits = flamingo_palm(text)

# after much training off the regular PaLM logits
# now you are ready to train Flamingo + PaLM
# by passing in images, it automatically freezes everything but the perceiver and cross attention blocks, as in the paper

dialogue = torch.randint(0, 20000, (4, 512))
images = torch.randn(4, 2, 3, 256, 256)

flamingo_logits = flamingo_palm(dialogue, images)

# do your usual cross entropy loss

It is quite evident where this is all headed if you think beyond just images.

Inception

For factual correctness, just imagine where this system would stand if one were to use a state of the art retrieval language model as the base.

Citations

@article{Alayrac2022Flamingo,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac et al},
    year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • PerceiverResampler missing some LayerNorms?

    PerceiverResampler missing some LayerNorms?

    Hey, it feels like PerceiverResampler is missing some LayerNorms? it seems to me we should layer-norm x before sending to attentions loop, and may be add layer-norm to ff(latents) + latents?

    opened by inspirit 7
  • Missing flatten op in PerceiverResampler?

    Missing flatten op in PerceiverResampler?

    Hi, It seems that Flamingo did "x_f = flatten(x_f) # [T, S, d] -> [T * S, d]" (batch size == 1) before putting x_f to attention.

    So, it should be like: medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension) perceived = perceive(medias) # (1, 64, 1024) - (batch, num latents, dimension)

    ??

    opened by zengyan-97 6
  • wrong attention masks?

    wrong attention masks?

    https://github.com/lucidrains/flamingo-pytorch/blob/44920f4191ba3c280ff84c6ebc76025656d1dab5/flamingo_pytorch/flamingo_pytorch.py#L159

    In the flamingo paper, the language features in the gated cross-attention layers only attend to the visual features from the immediate preceding image. I believe your attention masks are created in such a way that they attend to the visual features from all preceding images. Can you confirm? If so, a fix would be to simply change the '>=' to '=='.

    opened by dhansmair 4
  • zeroing out attention not working

    zeroing out attention not working

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_pytorch.py#L179

    you are not using the inplace version of the function: https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch.Tensor.masked_fill_

    so I think this line does not have an effect.

    Best, David

    opened by dhansmair 2
  • Applying parallel attn with ff to existing pretrained model?

    Applying parallel attn with ff to existing pretrained model?

    Hi - awesome work! I am trying to understand ? I couldn't find a paper - only a reference to https://github.com/kingoflolz/mesh-transformer-jax. Is this right? Am I understanding that it is bascially applying multiple operations of for qkv and ff at once? Is it possible to use this trick to modify an existing pretrained model?

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_palm.py#L90

    Many thanks in advance!

    Huu

    opened by ontocord 1
  • How to use Flamingo for VQA task?

    How to use Flamingo for VQA task?

    Hi, Thanks for sharing this awesome implementation. I am very interested in using Flamingo model for my usecase. How I can use this implementation to get inference on my dataset for VQA task? I have certain images of products and want extract some information image of product by questioning it. How I can do it ?

    Please help.

    thanks

    opened by karndeepsingh 0
  • Fine-tuning of a model

    Fine-tuning of a model

    Hi, Thank you for this great work. I want to ask how can I fine-tune this model on my dataset for some downstream task like image captioning or image classification? If it is possible for you can you also please share the code?

    opened by ans92 0
  • Need a sample ipython notebook

    Need a sample ipython notebook

    Hello, @lucidrains,

    Thank you for providing this.

    For demo purposes, could you please provide a sample demo in Jupyter notebook?🫠

    Best, LITDataScience

    opened by LITDataScience 0
Releases(0.1.2)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Proposal, Tracking and Segmentation (PTS): A Cascaded Network for Video Object Segmentation

Proposal, Tracking and Segmentation (PTS): A Cascaded Network for Video Object Segmentation By Qiang Zhou*, Zilong Huang*, Lichao Huang, Han Shen, Yon

Forest 117 Apr 01, 2022
project page for VinVL

VinVL: Revisiting Visual Representations in Vision-Language Models Updates 02/28/2021: Project page built. Introduction This repository is the project

308 Jan 09, 2023
[CVPR'20] TTSR: Learning Texture Transformer Network for Image Super-Resolution

TTSR Official PyTorch implementation of the paper Learning Texture Transformer Network for Image Super-Resolution accepted in CVPR 2020. Contents Intr

Multimedia Research 689 Dec 28, 2022
3D-Reconstruction 基于深度学习方法的单目多视图三维重建

基于深度学习方法的单目多视图三维重建 Part I 三维重建 代码:Part1 技术文档:[Markdown] [PDF] 原始图像:Original Images 点云结果:Point Cloud Results-1

HMT_Curo 19 Dec 26, 2022
PyTorch implementation of "Optimization Planning for 3D ConvNets"

Optimization-Planning-for-3D-ConvNets Code for the ICML 2021 paper: Optimization Planning for 3D ConvNets. Authors: Zhaofan Qiu, Ting Yao, Chong-Wah N

Zhaofan Qiu 2 Jan 12, 2022
A little Python application to auto tag your photos with the power of machine learning.

Tag Machine A little Python application to auto tag your photos with the power of machine learning. Report a bug or request a feature Table of Content

Florian Torres 14 Dec 21, 2022
This library is a location of the LegacyLogger for PyTorch Lightning.

neptune-contrib Documentation See neptune-contrib documentation site Installation Get prerequisites python versions 3.5.6/3.6 are supported Install li

neptune.ai 26 Oct 07, 2021
A map update dataset and benchmark

MUNO21 MUNO21 is a dataset and benchmark for machine learning methods that automatically update and maintain digital street map datasets. Previous dat

16 Nov 30, 2022
Evaluation and Benchmarking of Speech Super-resolution Methods

Speech Super-resolution Evaluation and Benchmarking What this repo do: A toolbox for the evaluation of speech super-resolution algorithms. Unify the e

Haohe Liu (刘濠赫) 84 Dec 20, 2022
Repositório da disciplina de APC, no segundo semestre de 2021

NOTAS FINAIS: https://github.com/fabiommendes/apc2018/blob/master/nota-final.pdf Algoritmos e Programação de Computadores Este é o Git da disciplina A

16 Dec 16, 2022
Deep Learning and Reinforcement Learning Library for Scientists and Engineers 🔥

TensorLayer is a novel TensorFlow-based deep learning and reinforcement learning library designed for researchers and engineers. It provides an extens

TensorLayer Community 7.1k Dec 29, 2022
face2comics by Sxela (Alex Spirin) - face2comics datasets

This is a paired face to comics dataset, which can be used to train pix2pix or similar networks.

Alex 164 Nov 13, 2022
This repository contains code for the paper "Decoupling Representation and Classifier for Long-Tailed Recognition", published at ICLR 2020

Classifier-Balancing This repository contains code for the paper: Decoupling Representation and Classifier for Long-Tailed Recognition Bingyi Kang, Sa

Facebook Research 820 Dec 26, 2022
Group Activity Recognition with Clustered Spatial Temporal Transformer

GroupFormer Group Activity Recognition with Clustered Spatial-TemporalTransformer Backbone Style Action Acc Activity Acc Config Download Inv3+flow+pos

28 Dec 12, 2022
ICLR 2021: Pre-Training for Context Representation in Conversational Semantic Parsing

SCoRe: Pre-Training for Context Representation in Conversational Semantic Parsing This repository contains code for the ICLR 2021 paper "SCoRE: Pre-Tr

Microsoft 28 Oct 02, 2022
Unicorn can be used for performance analyses of highly configurable systems with causal reasoning

Unicorn can be used for performance analyses of highly configurable systems with causal reasoning. Users or developers can query Unicorn for a performance task.

AISys Lab 27 Jan 05, 2023
Optimized Gillespie algorithm for simulating Stochastic sPAtial models of Cancer Evolution (OG-SPACE)

OG-SPACE Introduction Optimized Gillespie algorithm for simulating Stochastic sPAtial models of Cancer Evolution (OG-SPACE) is a computational framewo

Data and Computational Biology Group UNIMIB (was BI*oinformatics MI*lan B*icocca) 0 Nov 17, 2021
Implementation of PyTorch-based multi-task pre-trained models

mtdp Library containing implementation related to the research paper "Multi-task pre-training of deep neural networks for digital pathology" (Mormont

Romain Mormont 27 Oct 14, 2022
This code is for eCaReNet: explainable Cancer Relapse Prediction Network.

eCaReNet This code is for eCaReNet: explainable Cancer Relapse Prediction Network. (Towards Explainable End-to-End Prostate Cancer Relapse Prediction

Institute of Medical Systems Biology 2 Jul 28, 2022
Logsig-RNN: a novel network for robust and efficient skeleton-based action recognition

GCN_LogsigRNN This repository holds the codebase for the paper: Logsig-RNN: a novel network for robust and efficient skeleton-based action recognition

7 Oct 14, 2022