Align before Fuse: Vision and Language Representation Learning with Momentum Distillation

Overview

Align before Fuse: Vision and Language Representation Learning with Momentum Distillation (Salesforce Research)

This is the official PyTorch implementation of the ALBEF paper [Blog]. This repository supports pre-training on custom datasets, as well as finetuning on VQA, SNLI-VE, NLVR2, Image-Text Retrieval on MSCOCO and Flickr30k, and visual grounding on RefCOCO+. Pre-trained and finetuned checkpoints are released.

Requirements:

  • pytorch 1.8.0
  • transformers 4.8.1
  • timm 0.4.9
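
As a quick sanity check before training, the pinned versions above can be compared against what is installed. The following is a minimal sketch and not part of the repository; it assumes the packages were installed from PyPI, where "pytorch" is distributed as `torch`, and the helper names `installed_version` and `find_mismatches` are hypothetical.

```python
from importlib import metadata

# Versions pinned in the requirements list above, keyed by PyPI distribution name.
# Assumption: "pytorch" is installed under its PyPI name "torch".
PINNED = {"torch": "1.8.0", "transformers": "4.8.1", "timm": "0.4.9"}


def installed_version(dist_name):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package whose installed version differs from the pin to (wanted, found)."""
    mismatches = {}
    for name, wanted in pinned.items():
        found = lookup(name)
        # A local build tag such as "1.8.0+cu111" still counts as 1.8.0.
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(PINNED).items():
        print(f"{name}: expected {wanted}, found {found}")
```

An empty result means all three packages match the pins. Other versions may also work, but only the ones listed above are stated in this README.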

Download:

Visualization:

We provide code in visualize.ipynb to visualize the important areas in an image for each word in a text. Here is an example visualization using the visual grounding checkpoint.

Pre-training on custom datasets:

  1. Prepare training json files, each containing a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
  2. In configs/Pretrain.yaml, set the paths for the json files.
  3. Pre-train the model using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain 
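
The json layout from step 1 can be produced with a few lines of Python. This is a sketch under stated assumptions, not the repository's own tooling: the function names, the image paths, and the captions are hypothetical, and only the two keys 'image' and 'caption' come from the description above.

```python
import json
from pathlib import Path


def build_annotations(pairs):
    """Turn (image_path, caption) pairs into the list-of-dicts layout from step 1."""
    return [{"image": str(image), "caption": caption.strip()} for image, caption in pairs]


def write_annotations(pairs, out_file):
    """Write one training json file and return the number of items written."""
    items = build_annotations(pairs)
    out_path = Path(out_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(items, f, ensure_ascii=False)
    return len(items)


if __name__ == "__main__":
    # Hypothetical image paths and captions, for illustration only.
    example_pairs = [
        ("images/0001.jpg", "a dog running on the beach"),
        ("images/0002.jpg", "two people riding bicycles"),
    ]
    print(json.dumps(build_annotations(example_pairs), indent=2))
```

Calling `write_annotations(example_pairs, some_path)` would then produce a file whose path can be listed in configs/Pretrain.yaml as described in step 2.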

Image-Text Retrieval:

  1. Download MSCOCO or Flickr30k datasets from the original websites.
  2. Download and extract the provided dataset json files.
  3. In configs/Retrieval_coco.yaml or configs/Retrieval_flickr.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/Retrieval_flickr \
--checkpoint [Pretrained checkpoint]
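
The launch commands in this README pass --use_env, which tells torch.distributed.launch to hand each process its local rank through the LOCAL_RANK environment variable instead of a --local_rank argument. The launcher also exports RANK and WORLD_SIZE. The sketch below shows how a training script could read these values; it is an illustration only, not the code used in Retrieval.py, and the function name `read_launch_env` is hypothetical.

```python
import os


def read_launch_env(environ=os.environ):
    """Read the per-process settings exported by torch.distributed.launch."""
    return {
        "rank": int(environ.get("RANK", 0)),              # global process index
        "world_size": int(environ.get("WORLD_SIZE", 1)),  # total number of processes
        "local_rank": int(environ.get("LOCAL_RANK", 0)),  # process index on this node
    }


if __name__ == "__main__":
    info = read_launch_env()
    print(f"process {info['rank']} of {info['world_size']}, local rank {info['local_rank']}")
```

With --nproc_per_node=8 on a single machine, the local rank runs from 0 to 7 and is conventionally used as the GPU index. The defaults in the sketch let the same code run as a single process without the launcher.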

VQA:

  1. Download VQA v2 dataset and Visual Genome dataset from the original websites.
  2. Download and extract the provided dataset json files.
  3. In configs/VQA.yaml, set the paths for the json files and the image paths.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir output/vqa \
--checkpoint [Pretrained checkpoint]
  5. Evaluate the result using the official evaluation server.

Visual Entailment:

  1. Download SNLI-VE dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/VE.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VE.py \
--config ./configs/VE.yaml \
--output_dir output/VE \
--checkpoint [Pretrained checkpoint]

Visual Grounding on RefCOCO+:

  1. Download MSCOCO dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/Grounding.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Grounding.py \
--config ./configs/Grounding.yaml \
--output_dir output/RefCOCO \
--gradcam_mode itm \
--block_num 8 \
--checkpoint [Pretrained checkpoint]

NLVR2:

NLVR2 requires an additional pre-training step with text-assignment (TA) to adapt the model for image-pair inputs. To perform TA, first set the paths for the json training files in configs/NLVR_pretrain.yaml, then run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain_nlvr.py \
--config ./configs/NLVR_pretrain.yaml \
--output_dir output/NLVR_pretrain \
--checkpoint [Pretrained checkpoint]

We provide the checkpoint after TA pre-training, which can be fine-tuned with the following steps.

  1. Download NLVR2 dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/NLVR.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir output/NLVR \
--checkpoint [TA pretrained checkpoint]
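
The two NLVR2 stages differ only in script, config, output directory, and input checkpoint, so they can be scripted together. The following sketch builds both command lines from the values shown above; the helper names and the checkpoint paths in the usage example are hypothetical placeholders, and the location of the checkpoint written by the TA stage is left as an input because this README does not state it.

```python
import shlex


def launch_command(script, config, output_dir, checkpoint, nproc_per_node=8):
    """Build the argument list for one torch.distributed.launch run."""
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}", "--use_env", script,
        "--config", config,
        "--output_dir", output_dir,
        "--checkpoint", checkpoint,
    ]


def nlvr2_commands(pretrained_ckpt, ta_ckpt):
    """Return the TA pre-training command followed by the NLVR2 finetuning command."""
    ta_stage = launch_command(
        "Pretrain_nlvr.py", "./configs/NLVR_pretrain.yaml",
        "output/NLVR_pretrain", pretrained_ckpt,
    )
    finetune_stage = launch_command(
        "NLVR.py", "./configs/NLVR.yaml", "output/NLVR", ta_ckpt,
    )
    return [ta_stage, finetune_stage]


if __name__ == "__main__":
    # Placeholder checkpoint paths; replace with real files before running.
    for cmd in nlvr2_commands("path/to/pretrained.pth", "path/to/ta_pretrained.pth"):
        print(shlex.join(cmd))
```

Each list can be passed to `subprocess.run(cmd, check=True)` to execute the stages in order. When starting from the provided TA checkpoint, only the second command is needed.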

Citation

If you find this code useful for your research, please consider citing:

@article{ALBEF,
      title={Align before Fuse: Vision and Language Representation Learning with Momentum Distillation}, 
      author={Junnan Li and Ramprasaath R. Selvaraju and Akhilesh Deepak Gotmare and Shafiq Joty and Caiming Xiong and Steven Hoi},
      year={2021},
      journal={arXiv preprint arXiv:2107.07651},
}