Align before Fuse: Vision and Language Representation Learning with Momentum Distillation

Overview

Align before Fuse: Vision and Language Representation Learning with Momentum Distillation (Salesforce Research)

This is the official PyTorch implementation of the ALBEF paper [Blog]. This repository supports pre-training on custom datasets, as well as finetuning on VQA, SNLI-VE, NLVR2, Image-Text Retrieval on MSCOCO and Flickr30k, and visual grounding on RefCOCO+. Pre-trained and finetuned checkpoints are released.

Requirements:

  • pytorch 1.8.0
  • transformers 4.8.1
  • timm 0.4.9
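A minimal sketch for checking an environment against the pinned versions above. It uses only the standard library's importlib.metadata (Python 3.8+) and assumes the usual pip distribution names (`torch` for PyTorch). The helper name `check_versions` is hypothetical and not part of the ALBEF repository. PyTorch wheels often report a local build suffix such as `1.8.0+cu111`, so the sketch compares only the part before the `+`.

```python
from importlib import metadata
from typing import Callable, Dict, List, Optional

# Versions listed under Requirements; "torch" is the pip distribution name for PyTorch.
PINNED: Dict[str, str] = {"torch": "1.8.0", "transformers": "4.8.1", "timm": "0.4.9"}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_versions(
    pinned: Dict[str, str] = PINNED,
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> List[str]:
    """Hypothetical helper: return one message per package that does not match its pin."""
    problems = []
    for dist, wanted in pinned.items():
        found = lookup(dist)
        if found is None:
            problems.append(f"{dist}: not installed (want {wanted})")
        # Drop local build tags such as "+cu111" before comparing.
        elif found.split("+")[0] != wanted:
            problems.append(f"{dist}: found {found}, want {wanted}")
    return problems


if __name__ == "__main__":
    for line in check_versions():
        print(line)
```

An empty result means the installed versions match the pins; newer versions may still work, but they are not the ones listed above.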

Download:

Visualization:

We provide code in visualize.ipynb to visualize the important areas in an image for each word in a text. Here is an example visualization using the visual grounding checkpoint.

Pre-training on custom datasets:

  1. Prepare training json files, where each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
  2. In configs/Pretrain.yaml, set the paths for the json files.
  3. Pre-train the model using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain 
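The json format from step 1 can be produced with a few lines of standard-library Python. This is a sketch under assumptions: the input is a hypothetical tab-separated file with one `image_path<TAB>caption` pair per line, and the names `build_pretrain_items`, `build_pretrain_json`, `captions.tsv`, and `pretrain_custom.json` are illustrative, not part of the repository. The output is a single json list of {'image': ..., 'caption': ...} dictionaries, whose path you would then set in configs/Pretrain.yaml.

```python
import json
from typing import Dict, Iterable, List


def build_pretrain_items(lines: Iterable[str]) -> List[Dict[str, str]]:
    """Turn 'image_path<TAB>caption' lines into {'image': ..., 'caption': ...} dicts."""
    items = []
    for line in lines:
        line = line.rstrip("\n")
        if not line.strip():
            continue  # skip blank lines
        # Split on the first tab only, so captions may themselves contain tabs.
        image, caption = line.split("\t", 1)
        items.append({"image": image.strip(), "caption": caption.strip()})
    return items


def build_pretrain_json(tsv_path: str, json_path: str) -> int:
    """Hypothetical helper: write the list to json_path and return the number of pairs."""
    with open(tsv_path, encoding="utf-8") as f:
        items = build_pretrain_items(f)
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(items, f, ensure_ascii=False)
    return len(items)


# Example: build_pretrain_json("captions.tsv", "pretrain_custom.json")
```

An image with several captions simply appears in several items, one per caption.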

Image-Text Retrieval:

  1. Download MSCOCO or Flickr30k datasets from the original websites.
  2. Download and extract the provided dataset json files.
  3. In configs/Retrieval_coco.yaml or configs/Retrieval_flickr.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Retrieval.py \
--config ./configs/Retrieval_flickr.yaml \
--output_dir output/Retrieval_flickr \
--checkpoint [Pretrained checkpoint]
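All fine-tuning commands in this README share the same shape: `torch.distributed.launch` with `--nproc_per_node` and `--use_env`, followed by the task script and its `--config`, `--output_dir`, and `--checkpoint` flags. With `--use_env`, the launcher exports the local rank as the `LOCAL_RANK` environment variable instead of passing a `--local_rank` argument to the script. The sketch below assembles that argument list in Python, for example to script several runs or change the GPU count. The function name `launch_argv` is hypothetical, and the checkpoint path in the example is a placeholder. The sketch only builds and prints the command; running it still requires the repository, the data, and the GPUs.

```python
import shlex
import sys
from typing import Dict, List, Optional


def launch_argv(
    script: str,
    config: str,
    output_dir: str,
    checkpoint: Optional[str] = None,
    nproc_per_node: int = 8,
    extra: Optional[Dict[str, str]] = None,
    python: str = sys.executable,
) -> List[str]:
    """Hypothetical helper: build the argv for one torch.distributed.launch run."""
    argv = [
        python, "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}", "--use_env", script,
        "--config", config,
        "--output_dir", output_dir,
    ]
    # Task-specific flags, e.g. {"--gradcam_mode": "itm", "--block_num": "8"} for Grounding.py.
    for flag, value in (extra or {}).items():
        argv += [flag, value]
    if checkpoint is not None:
        argv += ["--checkpoint", checkpoint]
    return argv


if __name__ == "__main__":
    cmd = launch_argv("Retrieval.py", "./configs/Retrieval_flickr.yaml",
                      "output/Retrieval_flickr", checkpoint="path/to/pretrained.pth")
    # Print a shell-ready command; subprocess.run(cmd, check=True) would launch it.
    print(shlex.join(cmd))
```

Building an argv list rather than a single string avoids shell quoting problems, including the easy-to-miss case of a space after a trailing backslash in multi-line shell commands.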

VQA:

  1. Download VQA v2 dataset and Visual Genome dataset from the original websites.
  2. Download and extract the provided dataset json files.
  3. In configs/VQA.yaml, set the paths for the json files and the image paths.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VQA.py \
--config ./configs/VQA.yaml \
--output_dir output/vqa \
--checkpoint [Pretrained checkpoint]
  5. Evaluate the results using the official evaluation server.
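Before uploading to the evaluation server, it can help to sanity-check the predictions file. This sketch assumes the file is the json list the official VQA evaluation expects, with one {"question_id": int, "answer": str} entry per question. This README does not state where VQA.py writes its predictions, so the helpers take a path or an already-loaded object. The names `check_vqa_results` and `check_vqa_file` are hypothetical, and `expected_ids` would come from the question file of the split you are submitting.

```python
import json
from typing import Any, Iterable, List, Optional


def check_vqa_results(results: Any, expected_ids: Optional[Iterable[int]] = None) -> List[str]:
    """Hypothetical helper: list format problems in a VQA results object."""
    if not isinstance(results, list):
        return ["results must be a JSON list"]
    problems = []
    seen = set()
    for i, entry in enumerate(results):
        if not isinstance(entry, dict):
            problems.append(f"entry {i}: not an object")
            continue
        qid, answer = entry.get("question_id"), entry.get("answer")
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(qid, bool) or not isinstance(qid, int):
            problems.append(f"entry {i}: question_id must be an int")
        elif qid in seen:
            problems.append(f"entry {i}: duplicate question_id {qid}")
        else:
            seen.add(qid)
        if not isinstance(answer, str):
            problems.append(f"entry {i}: answer must be a string")
    if expected_ids is not None:
        missing = set(expected_ids) - seen
        if missing:
            problems.append(f"{len(missing)} question_id(s) have no answer")
    return problems


def check_vqa_file(path: str, expected_ids: Optional[Iterable[int]] = None) -> List[str]:
    """Load a results json file and check it."""
    with open(path, encoding="utf-8") as f:
        return check_vqa_results(json.load(f), expected_ids)
```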

Visual Entailment:

  1. Download SNLI-VE dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/VE.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env VE.py \
--config ./configs/VE.yaml \
--output_dir output/VE \
--checkpoint [Pretrained checkpoint]

Visual Grounding on RefCOCO+:

  1. Download MSCOCO dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/Grounding.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env Grounding.py \
--config ./configs/Grounding.yaml \
--output_dir output/RefCOCO \
--gradcam_mode itm \
--block_num 8 \
--checkpoint [Pretrained checkpoint]
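The `--gradcam_mode itm` and `--block_num 8` flags appear to select Grad-CAM-based grounding: a heatmap is derived from the image-text matching (itm) objective at one transformer block and used to localize the referring expression. One plausible way to turn such a heatmap into a box prediction, together with the usual RefCOCO+ scoring criterion, is sketched below under stated assumptions. Candidate boxes come from an external proposal set, each box is scored by the mean heatmap value inside it, and a prediction counts as correct when its intersection-over-union (IoU) with the ground-truth box is at least 0.5. All function names are hypothetical, and this is not necessarily how Grounding.py computes its results.

```python
from typing import List, Sequence, Tuple

Box = Tuple[int, int, int, int]  # (x1, y1, x2, y2) in pixels; x2 and y2 are exclusive


def box_score(heatmap: Sequence[Sequence[float]], box: Box) -> float:
    """Mean heatmap value inside a box (heatmap indexed as heatmap[y][x])."""
    x1, y1, x2, y2 = box
    values = [heatmap[y][x] for y in range(y1, y2) for x in range(x1, x2)]
    return sum(values) / len(values) if values else 0.0


def pick_box(heatmap: Sequence[Sequence[float]], proposals: List[Box]) -> Box:
    """Choose the proposal with the highest mean activation (proposals must be non-empty)."""
    return max(proposals, key=lambda b: box_score(heatmap, b))


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    iw = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def grounding_accuracy(preds: List[Box], gts: List[Box], thresh: float = 0.5) -> float:
    """Fraction of predictions whose IoU with the matching ground-truth box is >= thresh."""
    hits = sum(iou(p, g) >= thresh for p, g in zip(preds, gts))
    return hits / len(gts) if gts else 0.0
```

Averaging rather than summing the heatmap inside each box keeps large proposals from winning just because they cover more pixels.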

NLVR2:

NLVR2 requires an additional pre-training step with text-assignment (TA) to adapt the model to image-pair inputs. To perform TA, first set the paths for the json training files in configs/NLVR_pretrain.yaml, then run:

python -m torch.distributed.launch --nproc_per_node=8 --use_env Pretrain_nlvr.py \
--config ./configs/NLVR_pretrain.yaml \
--output_dir output/NLVR_pretrain \
--checkpoint [Pretrained checkpoint]

We provide the checkpoint after TA pre-training, which can be fine-tuned with the following steps.

  1. Download NLVR2 dataset from the original website.
  2. Download and extract the provided dataset json files.
  3. In configs/NLVR.yaml, set the paths for the json files and the image path.
  4. Finetune the pre-trained checkpoint using 8 A100 GPUs:
python -m torch.distributed.launch --nproc_per_node=8 --use_env NLVR.py \
--config ./configs/NLVR.yaml \
--output_dir output/NLVR \
--checkpoint [TA pretrained checkpoint]
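The NLVR2 workflow is a two-stage pipeline: the TA pre-training run writes its outputs to output/NLVR_pretrain, and a TA checkpoint (either your own or the released one) is then passed to NLVR.py through --checkpoint. A minimal driver for the two stages might look like the following. It assumes the commands are run from the repository root. The helper name `nlvr_pipeline` and the TA checkpoint file name are hypothetical placeholders, since this README does not say which file name Pretrain_nlvr.py writes.

```python
import shlex
import sys
from typing import List, Optional


def _launch(script: str, config: str, output_dir: str, checkpoint: str, gpus: int) -> List[str]:
    """Argv for one torch.distributed.launch run, mirroring the commands above."""
    return [sys.executable, "-m", "torch.distributed.launch",
            f"--nproc_per_node={gpus}", "--use_env", script,
            "--config", config, "--output_dir", output_dir,
            "--checkpoint", checkpoint]


def nlvr_pipeline(pretrained_ckpt: str, ta_ckpt: Optional[str] = None,
                  gpus: int = 8) -> List[List[str]]:
    """Hypothetical helper: commands for TA pre-training (if needed) and NLVR2 fine-tuning.

    If ta_ckpt is given (e.g. the released TA checkpoint), TA pre-training is skipped.
    Otherwise TA runs first, and its checkpoint is assumed to be saved as
    output/NLVR_pretrain/checkpoint_TA.pth (a placeholder; check the actual file name).
    """
    commands = []
    if ta_ckpt is None:
        commands.append(_launch("Pretrain_nlvr.py", "./configs/NLVR_pretrain.yaml",
                                "output/NLVR_pretrain", pretrained_ckpt, gpus))
        ta_ckpt = "output/NLVR_pretrain/checkpoint_TA.pth"  # hypothetical file name
    commands.append(_launch("NLVR.py", "./configs/NLVR.yaml", "output/NLVR", ta_ckpt, gpus))
    return commands


if __name__ == "__main__":
    # Print the stages in order; pass each one to subprocess.run(cmd, check=True)
    # so that fine-tuning does not start if TA pre-training fails.
    for cmd in nlvr_pipeline("path/to/pretrained.pth"):
        print(shlex.join(cmd))
```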

Citation

If you find this code useful for your research, please consider citing:

@article{ALBEF,
      title={Align before Fuse: Vision and Language Representation Learning with Momentum Distillation}, 
      author={Junnan Li and Ramprasaath R. Selvaraju and Akhilesh Deepak Gotmare and Shafiq Joty and Caiming Xiong and Steven Hoi},
      year={2021},
      journal={arXiv preprint arXiv:2107.07651},
}
Owner
Salesforce
A variety of vendor agnostic projects which power Salesforce
Salesforce
Video Frame Interpolation with Transformer (CVPR2022)

VFIformer Official PyTorch implementation of our CVPR2022 paper Video Frame Interpolation with Transformer Dependencies python = 3.8 pytorch = 1.8.0

DV Lab 63 Dec 16, 2022
AFL binary instrumentation

E9AFL --- Binary AFL E9AFL inserts American Fuzzy Lop (AFL) instrumentation into x86_64 Linux binaries. This allows binaries to be fuzzed without the

242 Dec 12, 2022
WPPNets: Unsupervised CNN Training with Wasserstein Patch Priors for Image Superresolution

WPPNets: Unsupervised CNN Training with Wasserstein Patch Priors for Image Superresolution This code belongs to the paper [1] available at https://arx

Fabian Altekrueger 5 Jun 02, 2022
Reimplement of SimSwap training code

SimSwap-train Reimplement of SimSwap training code Instructions 1.Environment Preparation (1)Refer to the README document of SIMSWAP to configure the

seeprettyface.com 111 Dec 31, 2022
PyTorch implementation of PSPNet segmentation network

pspnet-pytorch PyTorch implementation of PSPNet segmentation network Original paper Pyramid Scene Parsing Network Details This is a slightly different

Roman Trusov 532 Dec 29, 2022
Neighbor2Seq: Deep Learning on Massive Graphs by Transforming Neighbors to Sequences

Neighbor2Seq: Deep Learning on Massive Graphs by Transforming Neighbors to Sequences This repository is an official PyTorch implementation of Neighbor

DIVE Lab, Texas A&M University 8 Jun 12, 2022
[TPDS'21] COSCO: Container Orchestration using Co-Simulation and Gradient Based Optimization for Fog Computing Environments

COSCO Framework COSCO is an AI based coupled-simulation and container orchestration framework for integrated Edge, Fog and Cloud Computing Environment

imperial-qore 39 Dec 25, 2022
A PyTorch Lightning Callback for pushing models to the Hugging Face Hub 🤗⚡️

hf-hub-lightning A callback for pushing lightning models to the Hugging Face Hub. Note: I made this package for myself, mostly...if folks seem to be i

Nathan Raw 27 Dec 14, 2022
UNION: An Unreferenced Metric for Evaluating Open-ended Story Generation

UNION Automatic Evaluation Metric described in the paper UNION: An UNreferenced MetrIc for Evaluating Open-eNded Story Generation (EMNLP 2020). Please

50 Dec 30, 2022
GluonMM is a library of transformer models for computer vision and multi-modality research

GluonMM is a library of transformer models for computer vision and multi-modality research. It contains reference implementations of widely adopted baseline models and also research work from Amazon

42 Dec 02, 2022
This repository contains the re-implementation of our paper deSpeckNet: Generalizing Deep Learning Based SAR Image Despeckling

deSpeckNet-TF-GEE This repository contains the re-implementation of our paper deSpeckNet: Generalizing Deep Learning Based SAR Image Despeckling publi

Adugna Mullissa 16 Sep 07, 2022
Text-Based Ideal Points

Text-Based Ideal Points Source code for the paper: Text-Based Ideal Points by Keyon Vafa, Suresh Naidu, and David Blei (ACL 2020). Update (June 29, 20

Keyon Vafa 37 Oct 09, 2022
Code for Transformers Solve Limited Receptive Field for Monocular Depth Prediction

Official PyTorch code for Transformers Solve Limited Receptive Field for Monocular Depth Prediction. Guanglei Yang, Hao Tang, Mingli Ding, Nicu Sebe,

stanley 152 Dec 16, 2022
[ICCV 2021] Relaxed Transformer Decoders for Direct Action Proposal Generation

RTD-Net (ICCV 2021) This repo holds the codes of paper: "Relaxed Transformer Decoders for Direct Action Proposal Generation", accepted in ICCV 2021. N

Multimedia Computing Group, Nanjing University 80 Nov 30, 2022
Code for our paper: Online Variational Filtering and Parameter Learning

Variational Filtering To run phi learning on linear gaussian (Fig1a) python linear_gaussian_phi_learning.py To run phi and theta learning on linear g

16 Aug 14, 2022
Changing the Mind of Transformers for Topically-Controllable Language Generation

We will first introduce how to run the IPython notebook demo by downloading our pretrained models. Then, we will introduce how to run our training and evaluation code.

IESL 20 Dec 06, 2022
2021 National Underwater Robotics Vision Optics

2021-National-Underwater-Robotics-Vision-Optics 2021 National Underwater Robot Algorithm Competition, Optics Track, 18th place in accuracy on Leaderboard B (Kilian_Di's team: Leaderboard A [email pro

Di Chang 9 Nov 04, 2022
Step by Step on how to create a vision recognition model using LOBE.ai, export the model and run the model in an Azure Function

Step by Step on how to create a vision recognition model using LOBE.ai, export the model and run the model in an Azure Function

El Bruno 3 Mar 30, 2022
Sign Language Transformers (CVPR'20)

Sign Language Transformers (CVPR'20) This repo contains the training and evaluation code for the paper Sign Language Transformers: Sign Language Trans

Necati Cihan Camgoz 164 Dec 30, 2022
Reinforcement Learning via Supervised Learning

Reinforcement Learning via Supervised Learning Installation Run pip install -e . in an environment with Python = 3.7.0, 3.9. The code depends on MuJ

Scott Emmons 49 Nov 28, 2022