Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch

Overview

CoCa - Pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. The paper's authors elegantly fit contrastive learning into a conventional encoder / decoder (image to text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder.

This repository also adopts the specific transformer architecture from PaLM for both the unimodal and multimodal transformers, as well as for the cross attention blocks (parallel SwiGLU feedforwards).
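As a rough illustration of what "parallel SwiGLU feedforwards" means, here is a minimal sketch of a PaLM-style block in which attention and the feedforward read the same normalized input and are summed into a single residual. The class names `SwiGLU` and `ParallelBlock` are hypothetical, the block uses `nn.MultiheadAttention` without any causal mask for brevity, and it is not the repository's own (fused) implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SwiGLU(nn.Module):
    """Gated activation: split the last dimension in two and gate one half with SiLU of the other."""

    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x


class ParallelBlock(nn.Module):
    """Hypothetical PaLM-style block: attention and feedforward share one normed input."""

    def __init__(self, dim, heads=8, ff_mult=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * ff_mult * 2, bias=False),  # doubled width feeds both halves of SwiGLU
            SwiGLU(),
            nn.Linear(dim * ff_mult, dim, bias=False),
        )

    def forward(self, x):
        normed = self.norm(x)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        # one residual add covers both branches, instead of two sequential sublayers
        return x + attn_out + self.ff(normed)


block = ParallelBlock(dim=64)
out = block(torch.randn(2, 10, 64))
print(out.shape)  # torch.Size([2, 10, 64])
```

The point of the parallel form is that both branches can be computed from the same `normed` tensor, so their input projections can be fused in an optimized implementation.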

Yannic Kilcher presentation

Install

$ pip install coca-pytorch

Usage

First install vit-pytorch for the image encoder, which needs to be pretrained.

$ pip install vit-pytorch

Then

import torch

# import vision transformer

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# do your vision transformer training

vit = Extractor(vit, return_embeddings_only = True)

# the Extractor wrapper makes the vision transformer return its embeddings

# import CoCa and instantiate it

from coca_pytorch.coca_pytorch import CoCa

coca = CoCa(
    dim = 512,                     # model dimension
    img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim = 1024,              # image embedding dimension, if not the same as model dimensions
    num_tokens = 20000,            # number of text tokens
    unimodal_depth = 6,            # depth of the unimodal transformer
    multimodal_depth = 6,          # depth of the multimodal transformer
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # number of attention heads
    caption_loss_weight = 1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
).cuda()

# mock text and images

text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train by giving CoCa your text and images with `return_loss = True`

loss = coca(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)

loss.backward()

# do the above for as many text and image pairs as you have...
# then you can get the caption logits like so

logits = coca(
    text = text,
    images = images
) # (4, 512, 20000)

# and the CLIP-like text and image embeddings as

text_embeds, image_embeds = coca(
    text = text,
    images = images,
    return_embeddings = True
) # (4, 512), (4, 512)
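
One way the CLIP-like embeddings might be used is text-to-image retrieval. The sketch below uses random stand-in tensors with the shapes noted above in place of real `text_embeds` and `image_embeds`; the variable names `sim` and `top2_images` are illustrative only.

```python
import torch
import torch.nn.functional as F

# stand-ins for the (batch, dim) embeddings returned with return_embeddings = True
text_embeds = torch.randn(4, 512)
image_embeds = torch.randn(4, 512)

# pairwise cosine similarity via broadcasting: (4, 1, 512) against (1, 4, 512) -> (4, 4)
sim = F.cosine_similarity(text_embeds[:, None, :], image_embeds[None, :, :], dim=-1)

# for each text, the indices of the two most similar images
top2_images = sim.topk(k=2, dim=-1).indices  # (4, 2)
```

Row `i` of `sim` scores text `i` against every image, so transposing it gives image-to-text retrieval.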

Citations

@inproceedings{Yu2022CoCaCC,
  title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
  author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
  year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • Contrastive loss should be applied to L2-normed embeddings instead of layer normed?

    Hi @lucidrains, thank you for the implementation. I just wanted to confirm this with you: based on your code, the image and text embeddings are each normalized with a learnable Layer Norm transformation before the contrastive loss is applied. But in my understanding, contrastive loss typically maximizes the relative cosine similarity, so shouldn't the embeddings be L2-normed instead of layer-normed? Thank you.

    opened by fedshyvana 2
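
To make the distinction in this question concrete, the following sketch compares the two normalizations on random stand-in embeddings and then computes a symmetric contrastive loss on the L2-normalized ones. It is an illustration of the general technique, not the repository's code; the `temperature` value is a hypothetical constant (it is often a learned parameter).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
text_embeds = torch.randn(4, 512)
image_embeds = torch.randn(4, 512)

# layer norm (no learned affine here) standardizes each vector, so its length is about sqrt(dim), not 1
layer_normed = F.layer_norm(text_embeds, (512,))
print(layer_normed.norm(dim=-1))  # each close to sqrt(512), about 22.6

# L2 normalization gives unit vectors, so dot products are cosine similarities
text_unit = F.normalize(text_embeds, dim=-1)
image_unit = F.normalize(image_embeds, dim=-1)

temperature = 0.07  # hypothetical value
sim = text_unit @ image_unit.t() / temperature  # (4, 4)

# matching pairs sit on the diagonal; average the text->image and image->text losses
labels = torch.arange(sim.shape[0])
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
```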
  • Extractor in vit_pytorch will detach the tensor.

    Thanks for your code! I think I may have found a small bug. The cloned tensor in Extractor is detached (https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/extractor.py#L39), so gradients may not propagate back to the image encoder.

    opened by techkang 2
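
The effect described here can be reproduced in plain PyTorch. In this sketch, `encoder` and `head` are small hypothetical stand-ins for the image encoder and for everything downstream of it; the first path mimics storing `output.clone().detach()`, the second keeps the autograd graph intact.

```python
import torch
from torch import nn

encoder = nn.Linear(8, 8)  # stand-in for an image encoder
head = nn.Linear(8, 1)     # stand-in for the layers that consume the embeddings
x = torch.randn(2, 8)

# path 1: the encoder output is cloned and detached before being used
feats = encoder(x).clone().detach()
head(feats).sum().backward()
grad_after_detach = encoder.weight.grad
print(grad_after_detach)  # None: no gradient reached the encoder

# path 2: the encoder output stays attached to the graph
head.zero_grad()
feats = encoder(x)
head(feats).sum().backward()
print(encoder.weight.grad is not None)  # True
```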
  • Maybe don't need this rearrange

    I think the logits before this line are already in shape (bsz, length, num_tokens), so I don't think one more rearrange is needed here: https://github.com/lucidrains/CoCa-pytorch/blob/25de0b04326d8dc4c6f969e90b4466fc4894835e/coca_pytorch/coca_pytorch.py#L461

    opened by CiaoHe 2
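
For context on why a layout change often appears before a token-level loss: `F.cross_entropy` expects the class dimension in position 1, so logits shaped (bsz, length, num_tokens) need to be either transposed or flattened first. The sketch below shows both options on random data; whether the specific rearrange at the linked line is redundant depends on the layout at that point in the code, which this example does not settle.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 16, 100)         # (bsz, length, num_tokens)
labels = torch.randint(0, 100, (4, 16))  # (bsz, length)

# option 1: move the class dimension to position 1, same as rearrange 'b n c -> b c n'
loss_transposed = F.cross_entropy(logits.transpose(1, 2), labels)

# option 2: flatten batch and length into one dimension
loss_flat = F.cross_entropy(logits.reshape(-1, 100), labels.reshape(-1))

print(torch.allclose(loss_transposed, loss_flat))  # True
```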
  • How to train the model using my own dataset?

    Can someone tell me how to train the model using my own dataset? Is it like below? But I have many images and texts...

    # train by giving CoCa your text and images with `return_loss = True`
    loss = coca(
        text = text,
        images = images,
        return_loss = True  # set this to True to get the full caption + contrastive loss
    )
    
    opened by keepcodeandsmile 1
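
With many image and text pairs, the usual PyTorch pattern is to wrap them in a `Dataset` and iterate in batches with a `DataLoader`. The sketch below is one possible loop under stated assumptions: `StandInCoCa` is a hypothetical placeholder that only mimics the call signature shown above (swap in the real `coca` instance), the captions are assumed to be already tokenized and padded to equal length, and `train_one_epoch` is an illustrative helper name.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class StandInCoCa(nn.Module):
    """Hypothetical stand-in with the same call signature; replace with the real model."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(3 * 32 * 32, 1)

    def forward(self, text, images, return_loss=False):
        # dummy scalar "loss" so the loop below has something to optimize
        return self.proj(images.flatten(1)).pow(2).mean()


def train_one_epoch(model, loader, optimizer, device="cpu"):
    """Run one pass over the loader and return the mean batch loss."""
    model.train()
    total = 0.0
    for text, images in loader:
        text, images = text.to(device), images.to(device)
        loss = model(text=text, images=images, return_loss=True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


# mock dataset: 16 tokenized captions of length 12 and 16 preprocessed images
dataset = TensorDataset(torch.randint(0, 20000, (16, 12)), torch.randn(16, 3, 32, 32))
loader = DataLoader(dataset, batch_size=4, shuffle=True)

model = StandInCoCa()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
avg_loss = train_one_epoch(model, loader, optimizer)
```

For a real dataset, a custom `Dataset` that loads an image from disk and tokenizes its caption in `__getitem__` would replace the `TensorDataset`.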
  • why train VIT visual encoder first?

    Hi, thanks for sharing this repo. In the CoCa paper, both the visual encoder and the text encoder are trained end-to-end. But in this repo, the ViT is first pretrained and then fixed while training CoCa.

    opened by Flowerfan 1
  • attn_mask

    cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')  
    attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
    
    attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')  
    sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
    

    Hello, I am confused by the implementation of "attn_mask". I think this padding function can only mask the last row of "sim". Could you please explain it? Perhaps it's a very foolish question. Thank you so much.

    opened by pldlgb 0
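
Running the padding step on a tiny example shows what the resulting mask looks like. This sketch assumes `pad_id = 0` and a single toy sequence, and uses plain indexing in place of `rearrange`; it only reproduces the `F.pad` call from the snippet above, not the rest of the attention code.

```python
import torch
import torch.nn.functional as F

pad_id = 0                             # assumed padding id for this toy example
text = torch.tensor([[5, 7, 0, 0]])    # one sequence whose last two positions are padding
seq = text.shape[1]

cls_mask = (text != pad_id)[:, None, :]  # (1, 1, 4), same as rearrange 'b j -> b 1 j'

# last dim: pad 0 on the left and 1 on the right
# second-to-last dim: pad `seq` rows on top and 0 below
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)  # (1, 5, 5)

print(attn_mask[0].int())
```

The first `seq` rows come out entirely True and only the final row carries the padding pattern (True, True, False, False, True), which matches the observation in the question: this mask restricts just the last query row. How the other rows are restricted (for example by a separate causal mask) is outside this snippet.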
  • Reproducing the results in the paper

    Thanks for this repo. Curious, is this an independent implementation of the CoCa paper? If yes, did you reproduce any result in the paper to ensure correctness of implementation?

    opened by GKIBMNY 0
  • Generating the caption of a given image

    Hello,

    Thank you for having implemented this model. Have you already implemented some code to generate the caption of a given image? If not, do you have an idea about how you would do it in this particular architecture?

    Thank you in advance.

    opened by claudiogreco 0
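
One generic way to caption an image with a model that returns next-token logits is greedy decoding. The sketch below assumes that calling the model with `text` and `images` returns logits of shape (batch, length, num_tokens), as in the usage example, and that the logits at the last position score the next token. `greedy_caption`, `bos_id`, `eos_id`, and `StandInCaptioner` are all hypothetical names; the stand-in model exists only so the example runs without a trained CoCa.

```python
import torch
import torch.nn.functional as F
from torch import nn


@torch.no_grad()
def greedy_caption(model, images, bos_id, eos_id, max_len=32):
    """Greedy decoding: repeatedly append the highest-scoring next token."""
    tokens = torch.full((images.shape[0], 1), bos_id, dtype=torch.long, device=images.device)
    for _ in range(max_len):
        logits = model(text=tokens, images=images)  # (batch, length, num_tokens)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat((tokens, next_token), dim=-1)
        if (tokens == eos_id).any(dim=-1).all():    # every caption has produced EOS
            break
    return tokens


class StandInCaptioner(nn.Module):
    """Hypothetical stand-in: always predicts (previous token + 1)."""

    def __init__(self, num_tokens=10):
        super().__init__()
        self.num_tokens = num_tokens

    def forward(self, text, images):
        return F.one_hot((text + 1) % self.num_tokens, self.num_tokens).float()


captions = greedy_caption(StandInCaptioner(), torch.randn(2, 3, 8, 8), bos_id=1, eos_id=4)
print(captions.tolist())  # [[1, 2, 3, 4], [1, 2, 3, 4]]
```

Sampling with a temperature or beam search would slot into the same loop in place of the `argmax`.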
Releases(0.0.7)
Owner
Phil Wang
Working with Attention. It's all we need