Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch

Overview

CoCa - Pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. They were able to elegantly fit contrastive learning into a conventional encoder / decoder (image to text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder.

This repository also adopts the specific transformer architecture from PaLM for both the unimodal and multimodal transformers, as well as for the cross attention blocks (parallel SwiGLU feedforwards).
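To make "parallel" concrete, here is a minimal sketch of a PaLM-style block: attention and a SwiGLU feedforward both read the same normalized input, and their outputs are added to the residual together rather than one after the other. This is an illustration written for this README, not the module in `coca_pytorch`; the names `SwiGLU` and `ParallelBlock`, the use of `nn.MultiheadAttention`, and `ff_mult = 4` are assumptions, and details such as positional embeddings are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    # split the projection into two halves and gate one half with SiLU (swish) of the other
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x


class ParallelBlock(nn.Module):
    # hypothetical PaLM-style block: y = x + attn(norm(x)) + ff(norm(x))
    def __init__(self, dim, heads=8, ff_mult=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True, bias=False)
        hidden = dim * ff_mult
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden * 2, bias=False),  # doubled width feeds the SwiGLU gate
            SwiGLU(),
            nn.Linear(hidden, dim, bias=False),
        )

    def forward(self, x, causal=True):
        n = x.shape[1]
        h = self.norm(x)  # one shared pre-norm for both branches
        mask = None
        if causal:
            # True above the diagonal = this query may not attend to that (future) key
            mask = torch.ones(n, n, dtype=torch.bool, device=x.device).triu(1)
        attn_out, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        return x + attn_out + self.ff(h)  # both branches summed into the residual at once
```

Computing both branches from the same normalized input lets their input projections be fused and run concurrently, which is the main efficiency argument PaLM gives for this layout.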

Yannic Kilcher presentation

Install

$ pip install coca-pytorch

Usage

First install vit-pytorch for the image encoder, which needs to be pretrained.

$ pip install vit-pytorch

Then

import torch

# import vision transformer

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# do your vision transformer training

# the Extractor wrapper makes the vision transformer return its embeddings

vit = Extractor(vit, return_embeddings_only = True)

# import CoCa and instantiate it

from coca_pytorch.coca_pytorch import CoCa

coca = CoCa(
    dim = 512,                     # model dimension
    img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim = 1024,              # image embedding dimension, if not the same as the model dimension
    num_tokens = 20000,            # number of text tokens
    unimodal_depth = 6,            # depth of the unimodal transformer
    multimodal_depth = 6,          # depth of the multimodal transformer
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # number of attention heads
    caption_loss_weight = 1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
).cuda()

# mock text and images

text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train by giving CoCa your text and images with `return_loss = True`

loss = coca(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)

loss.backward()

# repeat the above for as many text / image pairs as you have...
# then you can get the caption logits like so

logits = coca(
    text = text,
    images = images
) # (4, 512, 20000)

# and the CLIP-like text and image embeddings as

text_embeds, image_embeds = coca(
    text = text,
    images = images,
    return_embeddings = True
) # (4, 512), (4, 512)
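These embeddings can be used for CLIP-style zero-shot retrieval. A minimal sketch, assuming you compare them by cosine similarity; the helper normalizes them first because this README does not say whether they come back unit-length. `retrieve` is a hypothetical name, and random tensors stand in for the `(4, 512)` outputs above.

```python
import torch
import torch.nn.functional as F


def retrieve(image_embeds, text_embeds):
    # cosine similarity = dot product of unit-length vectors
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    sim = image_embeds @ text_embeds.t()  # (num_images, num_texts)
    return sim, sim.argmax(dim=-1)       # index of the best-matching text for each image


# stand-ins for the (4, 512) embeddings returned with return_embeddings = True
text_embeds = torch.randn(4, 512)
image_embeds = torch.randn(4, 512)
sim, best_text = retrieve(image_embeds, text_embeds)
```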

Citations

@inproceedings{Yu2022CoCaCC,
    title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
    author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
    year    = {2022}
}

@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • Contrastive loss should be applied to L2-normed embeddings instead of layer normed?

    Hi @lucidrains, thank you for the implementation. I just wanted to confirm something with you: based on your code, we're normalizing the image embedding and the text embedding separately, using a learnable LayerNorm, before applying the contrastive loss. But as I understand it, a contrastive loss typically maximizes relative cosine similarity, so shouldn't the embeddings be L2-normed instead of layer-normed? Thank you.

    opened by fedshyvana 2
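For reference on the question above, this is a minimal sketch of the symmetric, CLIP-style contrastive loss computed on L2-normalized embeddings, so each logit is a cosine similarity divided by a temperature. The function name `clip_contrastive_loss` and the fixed temperature of 0.07 are assumptions for illustration (CLIP learns its temperature), and this is not the loss code in this repository.

```python
import torch
import torch.nn.functional as F


def clip_contrastive_loss(text_embeds, image_embeds, temperature=0.07):
    # L2-normalize so every dot product is a cosine similarity
    text_embeds = F.normalize(text_embeds, dim=-1)
    image_embeds = F.normalize(image_embeds, dim=-1)
    logits = text_embeds @ image_embeds.t() / temperature          # (batch, batch)
    labels = torch.arange(logits.shape[0], device=logits.device)  # matching pairs sit on the diagonal
    # symmetric: text -> image and image -> text
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
```

With unit-length vectors every logit lies in [-1/temperature, 1/temperature]; a LayerNorm with learned affine weights does not keep the vector length fixed, so dot products between its outputs are not plain cosine similarities.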
  • Extractor in vit_pytorch will detach the tensor.

    Thanks for your code! I think I may have found a small bug. The cloned tensor in Extractor is detached (https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/extractor.py#L39), so gradients may not propagate back to the image encoder.

    opened by techkang 2
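A quick way to see the effect described above, independent of vit-pytorch: a tensor produced by `.clone().detach()` has no path back to the parameters that produced it, while a plain `.clone()` keeps that path. The toy `weight` tensor below stands in for an image-encoder parameter. Whether the Extractor still detaches depends on the vit-pytorch version you install, so check its source before relying on gradients reaching the ViT.

```python
import torch

weight = torch.randn(3, requires_grad=True)  # stands in for an image-encoder parameter
embeddings = weight * 2                      # stands in for the encoder's output

kept = embeddings.clone()               # still attached to the autograd graph
detached = embeddings.clone().detach()  # cut from the graph

kept.sum().backward()
grad_from_kept = weight.grad.clone()    # the gradient reaches the "encoder"

print(detached.requires_grad)  # False: a loss built on this tensor cannot update `weight`
```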
  • Maybe don't need this rearrange

    I think the logits before this line already have shape (bsz, length, num_tokens), so I don't think another rearrange is needed here: https://github.com/lucidrains/CoCa-pytorch/blob/25de0b04326d8dc4c6f969e90b4466fc4894835e/coca_pytorch/coca_pytorch.py#L461

    opened by CiaoHe 2
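Whether that rearrange is redundant depends on what consumes the logits next. If they feed `F.cross_entropy` (an assumption here; check the linked line), PyTorch expects the class dimension at index 1, so `(bsz, length, num_tokens)` logits must become `(bsz, num_tokens, length)`; flattening to `(bsz * length, num_tokens)` is an equivalent alternative. The check below compares the two.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, length, num_tokens = 2, 5, 11
logits = torch.randn(bsz, length, num_tokens)         # (bsz, length, num_tokens)
labels = torch.randint(0, num_tokens, (bsz, length))  # (bsz, length)

# option 1: move the vocabulary axis to dim 1, the layout F.cross_entropy expects for sequences
loss_class_dim1 = F.cross_entropy(logits.transpose(1, 2), labels)

# option 2: flatten batch and sequence positions into one batch axis
loss_flat = F.cross_entropy(logits.reshape(-1, num_tokens), labels.reshape(-1))

# passing the logits unchanged treats `length` as the class axis and fails the shape check
try:
    F.cross_entropy(logits, labels)
    raised = False
except (RuntimeError, ValueError):
    raised = True
```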
  • How to train the model using my own dataset?

    Can someone tell me how to train the model using my own dataset? Is it like below? But I have many images and texts...

    # train by giving CoCa your text and images with `return_loss = True`
    loss = coca(
        text = text,
        images = images,
        return_loss = True  # set this to True to get the full caption + contrastive loss
    )
    
    opened by keepcodeandsmile 1
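One way to do this is to make that same call once per batch inside an ordinary PyTorch training loop. A minimal sketch, assuming images are already tensors and captions are already tokenized and padded to a fixed length; `PairDataset`, `train_one_epoch`, and `DummyCoCa` are hypothetical, and the dummy model only mimics CoCa's call signature so the snippet runs without a GPU or a pretrained ViT.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class PairDataset(Dataset):
    # hypothetical dataset: images (N, 3, H, W) and pre-tokenized, padded captions (N, seq_len)
    def __init__(self, images, texts):
        assert len(images) == len(texts)
        self.images, self.texts = images, texts

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.texts[idx]


def train_one_epoch(model, loader, optimizer, device="cpu"):
    model.train()
    total = 0.0
    for images, text in loader:
        images, text = images.to(device), text.to(device)
        loss = model(text=text, images=images, return_loss=True)  # same call as in the README
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)  # mean loss over the epoch


class DummyCoCa(nn.Module):
    # stand-in with the same call signature; replace with your CoCa instance
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(3, 1)

    def forward(self, text, images, return_loss=False):
        return self.proj(images.mean(dim=(2, 3))).pow(2).mean()


torch.manual_seed(0)
images = torch.randn(16, 3, 32, 32)
texts = torch.randint(0, 100, (16, 8))
loader = DataLoader(PairDataset(images, texts), batch_size=4, shuffle=True)
model = DummyCoCa()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
avg_loss = train_one_epoch(model, loader, optimizer)
```

For real data, `__getitem__` would typically load an image file, apply transforms, and tokenize its caption instead of indexing preloaded tensors.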
  • Why train the ViT visual encoder first?

    Hi, thanks for sharing this repo. In the CoCa paper, the visual encoder and the text encoder are both trained end-to-end. But in this repo, the ViT is first pretrained and then fixed while training CoCa.

    opened by Flowerfan 1
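For joint training closer to the paper's setup, the generic PyTorch requirements are that the encoder's parameters have `requires_grad` enabled, are included in the optimizer, and that its output is not detached on the way into the loss. A minimal sketch with toy modules; `set_trainable` and the two `nn.Linear` layers are hypothetical stand-ins for the ViT and the rest of CoCa, not this repository's API.

```python
import torch
from torch import nn


def set_trainable(module, trainable):
    # toggle gradient tracking for every parameter of a submodule
    for p in module.parameters():
        p.requires_grad_(trainable)


torch.manual_seed(0)
encoder = nn.Linear(8, 4)  # stands in for the ViT image encoder
head = nn.Linear(4, 1)     # stands in for the rest of CoCa
model = nn.Sequential(encoder, head)

set_trainable(encoder, True)  # True: end-to-end training; False: frozen, pretrained encoder
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)

loss = model(torch.randn(2, 8)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```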
  • attn_mask

    cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')  
    attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
    
    attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')  
    sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
    

    Hello, I am confused by the implementation of "attn_mask". I think this padding function can only mask the last row of "sim". Could you please explain it? Perhaps it's a very silly question. Thank you so much.

    opened by pldlgb 0
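The reading in the question looks right, and it appears to be by design. Assuming, as the shapes suggest, that a CLS token is appended after the `seq` text tokens, `F.pad(cls_mask, (0, 1, seq, 0), value = True)` adds one always-visible column on the right (the CLS key) and `seq` all-True rows on top. Only the last row, the CLS query, is stopped from attending to padding; with right-padded captions, the causal mask presumably already keeps real text tokens from seeing the later padding positions. A small check with a hypothetical `pad_id = 0` and one padded caption:

```python
import torch
import torch.nn.functional as F

pad_id = 0
text = torch.tensor([[5, 7, 0, 0]])  # one caption: two real tokens, then two padding positions
seq = text.shape[1]

# the same two lines as in the question ('b j -> b 1 j' written with unsqueeze)
cls_mask = (text != pad_id).unsqueeze(1)                 # (b, 1, seq)
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)  # (b, seq + 1, seq + 1)

# rows 0..3 (text queries) are all True; row 4 (the CLS query) is [True, True, False, False, True]
print(attn_mask[0].int())
```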
  • Reproducing the results in the paper

    Thanks for this repo. Just curious: is this an independent implementation of the CoCa paper? If so, did you reproduce any results from the paper to verify that the implementation is correct?

    opened by GKIBMNY 0
  • Generating the caption of a given image

    Hello,

    Thank you for having implemented this model. Have you already implemented some code to generate the caption of a given image? If not, do you have an idea about how you would do it in this particular architecture?

    Thank you in advance.

    opened by claudiogreco 0
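A common approach for an encoder / decoder captioner is autoregressive decoding: start from a start token, ask the model for next-token logits given the image and the tokens so far, append the most likely token, and stop at an end token or a length limit. A minimal greedy sketch, assuming `coca(text = ..., images = ...)` returns logits of shape `(batch, seq, num_tokens)` as in the usage example, with the last position predicting the next token; `generate_caption`, `bos_id`, `eos_id`, and `DummyCaptioner` are hypothetical.

```python
import torch
from torch import nn


@torch.no_grad()
def generate_caption(model, images, bos_id, eos_id, max_len=20):
    # greedy decoding: repeatedly append the most likely next token
    batch = images.shape[0]
    tokens = torch.full((batch, 1), bos_id, dtype=torch.long, device=images.device)
    finished = torch.zeros(batch, dtype=torch.bool, device=images.device)
    for _ in range(max_len):
        logits = model(text=tokens, images=images)             # (batch, seq, num_tokens)
        next_token = logits[:, -1].argmax(dim=-1)              # prediction for the next position
        next_token = next_token.masked_fill(finished, eos_id)  # keep emitting eos once done
        tokens = torch.cat([tokens, next_token[:, None]], dim=1)
        finished |= next_token == eos_id
        if finished.all():
            break
    return tokens


class DummyCaptioner(nn.Module):
    # stand-in that always predicts (last token + 1), so its output is deterministic
    def __init__(self, num_tokens=10):
        super().__init__()
        self.num_tokens = num_tokens

    def forward(self, text, images):
        nxt = (text + 1) % self.num_tokens
        return nn.functional.one_hot(nxt, self.num_tokens).float()
```

Greedy search is the simplest choice; beam search or sampling would slot into the same loop by changing how `next_token` is picked. This sketch also re-encodes the image at every step, which a real implementation would likely cache.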
Releases: 0.0.7
Owner: Phil Wang (Working with Attention. It's all we need)