Implementation of 🦩 Flamingo, state-of-the-art few-shot visual question answering attention net out of Deepmind, in Pytorch

Overview

🦩 Flamingo - Pytorch

Implementation of Flamingo, state-of-the-art few-shot visual question answering attention net, in Pytorch. It will include the perceiver resampler (including the scheme where the learned queries contributes keys / values to be attended to, in addition to media embeddings), the specialized masked cross attention blocks, and finally the tanh gating at the ends of the cross attention + corresponding feedforward blocks

Install

$ pip install flamingo-pytorch

Usage

import torch
from flamingo_pytorch import PerceiverResampler

perceive = PerceiverResampler(
    dim = 1024,
    depth = 2,
    dim_head = 64,
    heads = 8,
    num_latents = 64,    # the number of latents to shrink your media sequence to, perceiver style
    num_time_embeds = 4  # say you have 4 images maximum in your dialogue
)

medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
perceived = perceive(medias) # (1, 2, 64, 1024) - (batch, time, num latents, dimension)

Then you insert the GatedCrossAttentionBlock at different intervals in your giant language model. Your text would then attend to the perceived media from above

The recommended way to derive the media_locations boolean tensor would be to allocate a special token id to the media, and then, at the start of your large language model, do media_locations = text_id == media_token_id

import torch
from flamingo_pytorch import GatedCrossAttentionBlock

cross_attn = GatedCrossAttentionBlock(
    dim = 1024,
    dim_head = 64,
    heads = 8
)

text = torch.randn(1, 512, 1024)
perceived = torch.randn(1, 2, 64, 1024)

media_locations = torch.randint(0, 2, (1, 512)).bool()

text = cross_attn(
    text,
    perceived,
    media_locations = media_locations
)

That's it!

Attention is all you need.

Full working example with Flamingo + PaLM 🌴 🦩 🌴

Integration with PaLM

First install vit-pytorch for the vision encoder

$ pip install vit-pytorch

Then

from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

vit = Extractor(vit, return_embeddings_only = True)

# first take your trained image encoder and wrap it in an adapter that returns the image embeddings
# here we use the ViT from the vit-pytorch library

import torch
from flamingo_pytorch import FlamingoPaLM

# a PaLM language model, the 540 billion parameter model from google that shows signs of general intelligence

flamingo_palm = FlamingoPaLM(
    num_tokens = 20000,          # number of tokens
    dim = 1024,                  # dimensions
    depth = 12,                  # depth
    heads = 8,                   # attention heads
    dim_head = 64,               # dimension per attention head
    img_encoder = vit,           # plugin your image encoder (this can be optional if you pass in the image embeddings separately, but probably want to train end to end given the perceiver resampler)
    media_token_id = 3,          # the token id representing the [media] or [image]
    cross_attn_every = 3,        # how often to cross attend
    perceiver_num_latents = 64,  # perceiver number of latents, should be smaller than the sequence length of the image tokens
    perceiver_depth = 2          # perceiver resampler depth
)

# train your PaLM as usual

text = torch.randint(0, 20000, (2, 512))

palm_logits = flamingo_palm(text)

# after much training off the regular PaLM logits
# now you are ready to train Flamingo + PaLM
# by passing in images, it automatically freezes everything but the perceiver and cross attention blocks, as in the paper

dialogue = torch.randint(0, 20000, (4, 512))
images = torch.randn(4, 2, 3, 256, 256)

flamingo_logits = flamingo_palm(dialogue, images)

# do your usual cross entropy loss

It is quite evident where this is all headed if you think beyond just images.

Inception

For factual correctness, just imagine where this system would stand if one were to use a state of the art retrieval language model as the base.

Citations

@article{Alayrac2022Flamingo,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac et al},
    year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • PerceiverResampler missing some LayerNorms?

    PerceiverResampler missing some LayerNorms?

    Hey, it feels like PerceiverResampler is missing some LayerNorms? it seems to me we should layer-norm x before sending to attentions loop, and may be add layer-norm to ff(latents) + latents?

    opened by inspirit 7
  • Missing flatten op in PerceiverResampler?

    Missing flatten op in PerceiverResampler?

    Hi, It seems that Flamingo did "x_f = flatten(x_f) # [T, S, d] -> [T * S, d]" (batch size == 1) before putting x_f to attention.

    So, it should be like: medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension) perceived = perceive(medias) # (1, 64, 1024) - (batch, num latents, dimension)

    ??

    opened by zengyan-97 6
  • wrong attention masks?

    wrong attention masks?

    https://github.com/lucidrains/flamingo-pytorch/blob/44920f4191ba3c280ff84c6ebc76025656d1dab5/flamingo_pytorch/flamingo_pytorch.py#L159

    In the flamingo paper, the language features in the gated cross-attention layers only attend to the visual features from the immediate preceding image. I believe your attention masks are created in such a way that they attend to the visual features from all preceding images. Can you confirm? If so, a fix would be to simply change the '>=' to '=='.

    opened by dhansmair 4
  • zeroing out attention not working

    zeroing out attention not working

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_pytorch.py#L179

    you are not using the inplace version of the function: https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch.Tensor.masked_fill_

    so I think this line does not have an effect.

    Best, David

    opened by dhansmair 2
  • Applying parallel attn with ff to existing pretrained model?

    Applying parallel attn with ff to existing pretrained model?

    Hi - awesome work! I am trying to understand ? I couldn't find a paper - only a reference to https://github.com/kingoflolz/mesh-transformer-jax. Is this right? Am I understanding that it is bascially applying multiple operations of for qkv and ff at once? Is it possible to use this trick to modify an existing pretrained model?

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_palm.py#L90

    Many thanks in advance!

    Huu

    opened by ontocord 1
  • How to use Flamingo for VQA task?

    How to use Flamingo for VQA task?

    Hi, Thanks for sharing this awesome implementation. I am very interested in using Flamingo model for my usecase. How I can use this implementation to get inference on my dataset for VQA task? I have certain images of products and want extract some information image of product by questioning it. How I can do it ?

    Please help.

    thanks

    opened by karndeepsingh 0
  • Fine-tuning of a model

    Fine-tuning of a model

    Hi, Thank you for this great work. I want to ask how can I fine-tune this model on my dataset for some downstream task like image captioning or image classification? If it is possible for you can you also please share the code?

    opened by ans92 0
  • Need a sample ipython notebook

    Need a sample ipython notebook

    Hello, @lucidrains,

    Thank you for providing this.

    For demo purposes, could you please provide a sample demo in Jupyter notebook?🫠

    Best, LITDataScience

    opened by LITDataScience 0
Releases(0.1.2)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Deep Sketch-guided Cartoon Video Inbetweening

Cartoon Video Inbetweening Paper | DOI | Video The source code of Deep Sketch-guided Cartoon Video Inbetweening by Xiaoyu Li, Bo Zhang, Jing Liao, Ped

Xiaoyu Li 37 Dec 22, 2022
CellRank's reproducibility repository.

CellRank's reproducibility repository We believe that reproducibility is key and have made it as simple as possible to reproduce our results. Please e

Theis Lab 8 Oct 08, 2022
Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2020

Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2020

Phillip Lippe 1.1k Jan 07, 2023
Improving 3D Object Detection with Channel-wise Transformer

"Improving 3D Object Detection with Channel-wise Transformer" Thanks for the OpenPCDet, this implementation of the CT3D is mainly based on the pcdet v

Hualian Sheng 107 Dec 20, 2022
Research code for the paper "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models"

Introduction This repository contains research code for the ACL 2021 paper "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual

AdapterHub 20 Aug 04, 2022
This repository is for EMNLP 2021 paper: It is Not as Good as You Think! Evaluating Simultaneous Machine Translation on Interpretation Data

InterpretationData This repository is for our EMNLP 2021 paper: It is Not as Good as You Think! Evaluating Simultaneous Machine Translation on Interpr

4 Apr 21, 2022
Implementations of paper Controlling Directions Orthogonal to a Classifier

Classifier Orthogonalization Implementations of paper Controlling Directions Orthogonal to a Classifier , ICLR 2022, Yilun Xu, Hao He, Tianxiao Shen,

Yilun Xu 33 Dec 01, 2022
Hso-groupie - A pwnable challenge in Real World CTF 4th

Hso-groupie - A pwnable challenge in Real World CTF 4th

Riatre Foo 42 Dec 05, 2022
The official implementation code of "PlantStereo: A Stereo Matching Benchmark for Plant Surface Dense Reconstruction."

PlantStereo This is the official implementation code for the paper "PlantStereo: A Stereo Matching Benchmark for Plant Surface Dense Reconstruction".

Wang Qingyu 14 Nov 28, 2022
TLDR: Twin Learning for Dimensionality Reduction

TLDR (Twin Learning for Dimensionality Reduction) is an unsupervised dimensionality reduction method that combines neighborhood embedding learning with the simplicity and effectiveness of recent self

NAVER 105 Dec 28, 2022
DvD-TD3: Diversity via Determinants for TD3 version

DvD-TD3: Diversity via Determinants for TD3 version The implementation of paper Effective Diversity in Population Based Reinforcement Learning. Instal

3 Feb 11, 2022
Logistic Bandit experiments. Official code for the paper "Jointly Efficient and Optimal Algorithms for Logistic Bandits".

Code for the paper Jointly Efficient and Optimal Algorithms for Logistic Bandits, by Louis Faury, Marc Abeille, Clément Calauzènes and Kwang-Sun Jun.

Faury Louis 1 Jan 22, 2022
RIFE - Real-Time Intermediate Flow Estimation for Video Frame Interpolation

RIFE - Real-Time Intermediate Flow Estimation for Video Frame Interpolation YouTube | BiliBili 16X interpolation results from two input images: Introd

旷视天元 MegEngine 28 Dec 09, 2022
PyTorch Implementation for Fracture Detection in Wrist Bone X-ray Images

wrist-d PyTorch Implementation for Fracture Detection in Wrist Bone X-ray Images note: Paper: Under Review at MPDI Diagnostics Submission Date: Novemb

Fatih UYSAL 5 Oct 12, 2022
A large-scale video dataset for the training and evaluation of 3D human pose estimation models

ASPset-510 ASPset-510 (Australian Sports Pose Dataset) is a large-scale video dataset for the training and evaluation of 3D human pose estimation mode

Aiden Nibali 36 Oct 30, 2022
SOLOv2 on onnx & tensorRT

SOLOv2.tensorRT: NOTE: code based on WXinlong/SOLO add support to TensorRT inference onnxruntime tensorRT full_dims and dynamic shape postprocess with

47 Nov 26, 2022
An optimization and data collection toolbox for convenient and fast prototyping of computationally expensive models.

An optimization and data collection toolbox for convenient and fast prototyping of computationally expensive models. Hyperactive: is very easy to lear

Simon Blanke 422 Jan 04, 2023
Portfolio analytics for quants, written in Python

QuantStats: Portfolio analytics for quants QuantStats Python library that performs portfolio profiling, allowing quants and portfolio managers to unde

Ran Aroussi 2.7k Jan 08, 2023
2021搜狐校园文本匹配算法大赛 分比我们低的都是帅哥队

sohu_text_matching 2021搜狐校园文本匹配算法大赛Top2:分比我们低的都是帅哥队 本repo包含了本次大赛决赛环节提交的代码文件及答辩PPT,提交的模型文件可在百度网盘获取(链接:https://pan.baidu.com/s/1T9FtwiGFZhuC8qqwXKZSNA ,

hflserdaniel 43 Oct 01, 2022
9th place solution in "Santa 2020 - The Candy Cane Contest"

Santa 2020 - The Candy Cane Contest My solution in this Kaggle competition "Santa 2020 - The Candy Cane Contest", 9th place. Basic Strategy In this co

toshi_k 22 Nov 26, 2021