Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch

Overview

CoCa - Pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. They were able to elegantly fit in contrastive learning to a conventional encoder / decoder (image to text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder.

This repository also chooses to adopt the specific transformer architecture from PaLM, for both the unimodal and multimodal transformers as well as the cross attention blocks (parallel SwiGLU feedforwards)

Yannic Kilcher presentation

Install

$ pip install coca-pytorch

Usage

First install the vit-pytorch for the image encoder, which needs to be pretrained

$ pip install vit-pytorch

Then

import torch

# import vision transformer

from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# do your vision transformer training

vit = Extractor(vit, return_embeddings_only = True)

# extractor will enable it so the vision transformer returns its embeddings

# import CoCa and instantiate it

from coca_pytorch.coca_pytorch import CoCa

coca = CoCa(
    dim = 512,                     # model dimension
    img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim = 1024,              # image embedding dimension, if not the same as model dimensions
    num_tokens = 20000,            # number of text tokens
    unimodal_depth = 6,            # depth of the unimodal transformer
    multimodal_depth = 6,          # depth of the multimodal transformer
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # number of attention heads
    caption_loss_weight = 1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
).cuda()

# mock text and images

text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train by giving CoCa your text and images with `return_loss = True`

loss = coca(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)

loss.backward()

# do the above for as much text and images...
# then you can get the caption logits as so

logits = coca(
    text = text,
    images = images
) # (4, 512, 20000)

# and the CLIP-like text and image embeddings as

text_embeds, image_embeds = coca(
    text = text,
    images = images,
    return_embeddings = True
) # (4, 512), (4, 512)

Citations

@inproceedings{Yu2022CoCaCC,
  title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
  author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
  year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • Contrastive loss should be applied to L2-normed embeddings instead of layer normed?

    Contrastive loss should be applied to L2-normed embeddings instead of layer normed?

    Hi @lucidrains, thank you for the implementation. Just wanted to confirm this with you, based on your code we're normalizing the img embedding and text embedding respectively using a learnable Layer Norm transformation before applying the contrastive loss. But based on my understanding, for contrastive loss we typically maximize the relative cosine similarity so the embeddings should be L2-normed instead of layernormed? Thank you.

    opened by fedshyvana 2
  • Extractor in vit_pytorch will detach the tensor.

    Extractor in vit_pytorch will detach the tensor.

    Thanks for your code! I think I may find a little bug. The cloned tensor in Extractor will be detached (https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/extractor.py#L39). So gradient may not propagate back to image encoder.

    opened by techkang 2
  • Maybe don't need this rearrange

    Maybe don't need this rearrange

    I think the logits before this line in shape (bsz, length, num_tokens) -> so I don't think here need one more rearrange https://github.com/lucidrains/CoCa-pytorch/blob/25de0b04326d8dc4c6f969e90b4466fc4894835e/coca_pytorch/coca_pytorch.py#L461

    opened by CiaoHe 2
  • How to train the model using my own dataset?

    How to train the model using my own dataset?

    Can someone tell me how to train the model using my own dataset? is it like below?But I have many images and texts...

    # train by giving CoCa your text and images with `return_loss = True`
    loss = coca(
        text = text,
        images = images,
        return_loss = True  # set this to True to get the full caption + contrastive loss
    )
    
    opened by keepcodeandsmile 1
  • why train VIT visual encoder first?

    why train VIT visual encoder first?

    Hi, thanks for sharing this repo. In the CoCA paper, both the visual encoder and text encoder are end-to end trained. But in this repo, the vit is first pretrained then fixed to train CoCa.

    opened by Flowerfan 1
  • attn_mask

    attn_mask

    cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')  
    attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
    
    attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')  
    sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
    

    Hello, I am confused of the implement of "attn_mask". I think this padding function only can mask the last row of "sim". Could you please explain it? Perhaps it's a very fool question. Thank you so much.

    opened by pldlgb 0
  • Reproducing the results in the paper

    Reproducing the results in the paper

    Thanks for this repo. Curious, is this an independent implementation of the CoCa paper? If yes, did you reproduce any result in the paper to ensure correctness of implementation?

    opened by GKIBMNY 0
  • Generating the caption of a given image

    Generating the caption of a given image

    Hello,

    Thank you for having implemented this model. Have you already implemented some code to generate the caption of a given image? If not, do you have an idea about how you would do it in this particular architecture?

    Thank you in advance.

    opened by claudiogreco 0
Releases(0.0.7)
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Finetune alexnet with tensorflow - Code for finetuning AlexNet in TensorFlow >= 1.2rc0

Finetune AlexNet with Tensorflow Update 15.06.2016 I revised the entire code base to work with the new input pipeline coming with TensorFlow = versio

Frederik Kratzert 766 Jan 04, 2023
Open-sourcing the Slates Dataset for recommender systems research

FINN.no Recommender Systems Slate Dataset This repository accompany the paper "Dynamic Slate Recommendation with Gated Recurrent Units and Thompson Sa

FINN.no 48 Nov 28, 2022
EFENet: Reference-based Video Super-Resolution with Enhanced Flow Estimation

EFENet EFENet: Reference-based Video Super-Resolution with Enhanced Flow Estimation Code is a bit messy now. I woud clean up soon. For training the EF

Yaping Zhao 19 Nov 05, 2022
CR-FIQA: Face Image Quality Assessment by Learning Sample Relative Classifiability

This is the official repository of the paper: CR-FIQA: Face Image Quality Assessment by Learning Sample Relative Classifiability A private copy of the

Fadi Boutros 33 Dec 31, 2022
Deep learning with TensorFlow and earth observation data.

Deep Learning with TensorFlow and EO Data Complete file set for Jupyter Book Autor: Development Seed Date: 04 October 2021 ISBN: (to come) Notebook tu

Development Seed 20 Nov 16, 2022
PromptDet: Expand Your Detector Vocabulary with Uncurated Images

PromptDet: Expand Your Detector Vocabulary with Uncurated Images Paper Website Introduction The goal of this work is to establish a scalable pipeline

103 Dec 20, 2022
Existing Literature about Machine Unlearning

Machine Unlearning Papers 2021 Brophy and Lowd. Machine Unlearning for Random Forests. In ICML 2021. Bourtoule et al. Machine Unlearning. In IEEE Symp

Jonathan Brophy 213 Jan 08, 2023
Proximal Backpropagation - a neural network training algorithm that takes implicit instead of explicit gradient steps

Proximal Backpropagation Proximal Backpropagation (ProxProp) is a neural network training algorithm that takes implicit instead of explicit gradient s

Thomas Frerix 40 Dec 17, 2022
Classical OCR DCNN reproduction based on PaddlePaddle framework.

Paddle-SVHN Classical OCR DCNN reproduction based on PaddlePaddle framework. This project reproduces Multi-digit Number Recognition from Street View I

1 Nov 12, 2021
Imaginaire - NVIDIA's Deep Imagination Team's PyTorch Library

Imaginaire Docs | License | Installation | Model Zoo Imaginaire is a pytorch library that contains optimized implementation of several image and video

NVIDIA Research Projects 3.6k Dec 29, 2022
Algorithm to texture 3D reconstructions from multi-view stereo images

MVS-Texturing Welcome to our project that textures 3D reconstructions from images. This project focuses on 3D reconstructions generated using structur

Nils Moehrle 766 Jan 04, 2023
Rasterize with the least efforts for researchers.

utils3d Rasterize and do image-based 3D transforms with the least efforts for researchers. Based on numpy and OpenGL. It could be helpful when you wan

Ruicheng Wang 8 Dec 15, 2022
LSTM and QRNN Language Model Toolkit for PyTorch

LSTM and QRNN Language Model Toolkit This repository contains the code used for two Salesforce Research papers: Regularizing and Optimizing LSTM Langu

Salesforce 1.9k Jan 08, 2023
Drone Task1 - Drone Task1 With Python

Drone_Task1 Matching Results 3.mp4 1.mp4

MLV Lab (Machine Learning and Vision Lab at Korea University) 11 Nov 14, 2022
Official Chainer implementation of GP-GAN: Towards Realistic High-Resolution Image Blending (ACMMM 2019, oral)

GP-GAN: Towards Realistic High-Resolution Image Blending (ACMMM 2019, oral) [Project] [Paper] [Demo] [Related Work: A2RL (for Auto Image Cropping)] [C

Wu Huikai 402 Dec 27, 2022
Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR

UniSpeech The family of UniSpeech: UniSpeech (ICML 2021): Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR UniSpeech-

Microsoft 282 Jan 09, 2023
A basic implementation of Layer-wise Relevance Propagation (LRP) in PyTorch.

Layer-wise Relevance Propagation (LRP) in PyTorch Basic unsupervised implementation of Layer-wise Relevance Propagation (Bach et al., Montavon et al.)

Kai Fabi 28 Dec 26, 2022
Vector.ai assignment

fabio-tests-nisargatman Low Level Approach: ###Tables: continents: id*, name, population, area, createdAt, updatedAt countries: id*, name, population,

Ravi Pullagurla 1 Nov 09, 2021
Official implementation of deep-multi-trajectory-based single object tracking (IEEE T-CSVT 2021).

DeepMTA_PyTorch Officical PyTorch Implementation of "Dynamic Attention-guided Multi-TrajectoryAnalysis for Single Object Tracking", Xiao Wang, Zhe Che

Xiao Wang(王逍) 7 Dec 03, 2022
Like ThreeJS but for Python and based on wgpu

pygfx A render engine, inspired by ThreeJS, but for Python and targeting Vulkan/Metal/DX12 (via wgpu). Introduction This is a Python render engine bui

139 Jan 07, 2023