CLOOB: Modern Hopfield Networks with InfoLOOB Outperform CLIP

Andreas Fürst*¹, Elisabeth Rumetshofer*¹, Viet Tran¹, Hubert Ramsauer¹, Fei Tang³, Johannes Lehner¹, David Kreil², Michael Kopp², Günter Klambauer¹, Angela Bitto-Nemling¹, Sepp Hochreiter¹,²

¹ ELLIS Unit Linz and LIT AI Lab, Institute for Machine Learning, Johannes Kepler University Linz, Austria
² Institute of Advanced Research in Artificial Intelligence (IARAI)
³ HERE Technologies
* Equal contribution


A detailed blog post on this paper is available at this link.

The full paper is available here.


Implementation of CLOOB

This repository contains the implementation of CLOOB used to obtain the results reported in the paper. The implementation is based on OpenCLIP, an open-source implementation of OpenAI's CLIP.
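For orientation, the objective can be sketched in a few lines. The following numpy sketch follows the CLOOB loss as described in the paper; it is not the repository's implementation, and all function names are hypothetical. Image and text embeddings are L2-normalized, each modality's batch is stored in a modern Hopfield network, both modalities retrieve from each store with inverse temperature beta, and the retrieved embeddings are scored with the InfoLOOB (leave-one-out bound) objective at inverse temperature inv_tau, which leaves the positive pair out of the denominator.

```python
import numpy as np


def normalize(x: np.ndarray) -> np.ndarray:
    """L2-normalize each row (one embedding per row)."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def softmax(z: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def hopfield_retrieve(state: np.ndarray, stored: np.ndarray, beta: float) -> np.ndarray:
    """Modern Hopfield retrieval: each state pattern (row of `state`) is replaced
    by a softmax-weighted average of the stored patterns, then re-normalized."""
    weights = softmax(beta * state @ stored.T, axis=1)  # (n_state, n_stored)
    return normalize(weights @ stored)


def info_loob(a: np.ndarray, b: np.ndarray, inv_tau: float) -> float:
    """InfoLOOB with anchors a_i: positive pair (a_i, b_i) in the numerator,
    only the negatives b_j (j != i) in the denominator. Needs batch size >= 2.
    Unlike InfoNCE, the value can be negative."""
    logits = inv_tau * a @ b.T
    n = logits.shape[0]
    positives = np.diag(logits)
    # Mask the diagonal so the positive pair is left out of the denominator.
    masked = np.where(np.eye(n, dtype=bool), -np.inf, logits)
    m = masked.max(axis=1, keepdims=True)
    lse = m[:, 0] + np.log(np.exp(masked - m).sum(axis=1))
    return float(np.mean(lse - positives))


def cloob_loss(img: np.ndarray, txt: np.ndarray,
               inv_tau: float = 30.0, beta: float = 8.0) -> float:
    """Sketch of the CLOOB objective for a batch of paired embeddings."""
    x, y = normalize(img), normalize(txt)
    # Image store: both modalities retrieve from the stored image embeddings.
    u_x = hopfield_retrieve(x, x, beta)
    u_y = hopfield_retrieve(y, x, beta)
    # Text store: both modalities retrieve from the stored text embeddings.
    v_x = hopfield_retrieve(x, y, beta)
    v_y = hopfield_retrieve(y, y, beta)
    # Image anchors against text negatives, and text anchors against image negatives.
    return info_loob(u_x, u_y, inv_tau) + info_loob(v_y, v_x, inv_tau)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    img = rng.normal(size=(8, 512))              # stand-in image embeddings
    txt = img + 0.5 * rng.normal(size=(8, 512))  # correlated stand-in text embeddings
    print("CLOOB loss:", cloob_loss(img, txt))
```

The defaults mirror the values passed via --init-inv-tau and --init-scale-hopfield in the Usage section; whether the repository keeps them fixed or learns them during training is not covered by this sketch.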

Setup

We provide an 'environment.yml' file to set up a conda environment with all required packages. Run the following commands to clone the repository, create and activate the environment, and install the remaining dependencies.

# Clone the repository and switch into its directory
git clone https://github.com/ml-jku/cloob
cd cloob

# Create the environment and activate it
conda env create --file environment.yml
conda activate cloob

# Additionally, webdataset must be installed from its git repository for pre-training on YFCC
pip install git+https://github.com/tmbdev/webdataset.git

# Add the directory to the PYTHONPATH environment variable
export PYTHONPATH="$PYTHONPATH:$PWD/src"

Data

For pre-training we use the two datasets supported by OpenCLIP, namely Conceptual Captions and YFCC.

Conceptual Captions

OpenCLIP already provides a script to download and prepare the Conceptual Captions dataset, which contains 2.89M training images and 13k validation images. First, download the Conceptual Captions URLs and then run the script gather_cc.py.

python3 src/data/gather_cc.py path/to/Train_GCC-training.tsv path/to/Validation_GCC-1.1.0-Validation.tsv
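The Conceptual Captions TSV files list one image per line: a caption and an image URL, separated by a tab. The sketch below only illustrates reading that input format; it is not gather_cc.py, and read_cc_tsv is a hypothetical helper.

```python
import csv
import io
from typing import Iterator, TextIO, Tuple


def read_cc_tsv(f: TextIO) -> Iterator[Tuple[str, str]]:
    """Yield (caption, url) pairs from a Conceptual Captions TSV file."""
    # QUOTE_NONE: quote characters in captions are literal text, not CSV quoting.
    reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        if len(row) < 2:
            continue  # skip blank or malformed lines
        yield row[0], row[1]


if __name__ == "__main__":
    sample = io.StringIO(
        "a dog runs on the beach\thttp://example.com/1.jpg\n"
        "sunset over the mountains\thttp://example.com/2.jpg\n"
    )
    for caption, url in read_cc_tsv(sample):
        print(url, "->", caption)
```

For a real file, open it with open("path/to/Train_GCC-training.tsv", newline="", encoding="utf-8") and pass the handle to read_cc_tsv.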

YFCC

We use the same subset of ~15M images from the YFCC100M dataset as CLIP. OpenAI provides a list of (line number, photo identifier, photo hash) entries for each image contained in this subset here.

For more information, see YFCC100m Subset on OpenAI's GitHub.
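Selecting the subset amounts to keeping the YFCC100M metadata lines whose line numbers appear in OpenAI's list. The sketch below assumes a tab-separated list with the line number in the first column, possibly preceded by a header row; whether line numbers count from 0 or 1 is left as the parameter start, to be checked against the actual files. Both helper names are hypothetical.

```python
import io
from typing import Iterable, Iterator, Set, TextIO


def load_subset_line_numbers(f: TextIO) -> Set[int]:
    """Collect line numbers, assumed to be the first tab-separated column."""
    numbers = set()
    for line in f:
        first = line.split("\t", 1)[0].strip()
        if first.isdigit():  # skips a header row or blank lines, if present
            numbers.add(int(first))
    return numbers


def filter_metadata(lines: Iterable[str], keep: Set[int], start: int = 0) -> Iterator[str]:
    """Yield metadata lines whose index (counted from `start`) is in `keep`.
    The indexing convention of the official list is an assumption to verify."""
    for index, line in enumerate(lines, start=start):
        if index in keep:
            yield line


if __name__ == "__main__":
    subset = io.StringIO("line_number\tphotoid\tphoto_hash\n0\t111\tabc\n2\t333\tdef\n")
    keep = load_subset_line_numbers(subset)
    metadata = ["meta line 0\n", "meta line 1\n", "meta line 2\n"]
    print(list(filter_metadata(metadata, keep)))
```

Streaming the metadata through filter_metadata keeps memory use flat, which matters for a file with 100M lines.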

Downstream Tasks

In the paper, we report results on several downstream tasks. For all datasets except ImageNet and ImageNet v2, we provide links to pre-processed versions of the respective test sets (where pre-processing is necessary).

Dataset | Description | Official | Processed
Birdsnap | This dataset contains images of North American bird species; however, our version is smaller than the one reported in CLIP because some samples are no longer available. | Link | Link
Country211 | This dataset was published with CLIP and is a small subset of the YFCC100m dataset. It consists of photos that can be assigned to 211 countries via GPS coordinates. For each country, 200 photos are sampled for the training set and 100 for testing. | Link | Link
Flowers102 | Images of 102 flower categories commonly occurring in the United Kingdom were collected. Several classes are very similar, and there is large variation in scale, pose and lighting. | Link | Link
GTSRB | This dataset was released for a challenge held at IJCNN 2011. It contains images of German traffic signs from more than 40 classes. | Link | Link
Stanford Cars | This dataset contains images of 196 car models at the level of make, model and year (e.g. Tesla Model S Sedan 2012). | Link | Link
UCF101 | UCF101 is a video action recognition dataset; the processed version was created by extracting the middle frame from each video. | Link | Link
ImageNet | This dataset spans 1000 object classes and contains 1,281,167 training images, 50,000 validation images and 100,000 test images. | Link | -
ImageNet v2 | The ImageNetV2 dataset contains new test data for the ImageNet benchmark. | Link | -

Usage

The following is an example command for pre-training on Conceptual Captions (CC) with an effective batch size of 512 when run on 4 GPUs, i.e. a per-GPU batch size of 128.

python -u src/training/main.py \
--train-data="path/to/conceptual_captions/Train-GCC-training_output.csv" \
--val-data="path/to/conceptual_captions/Validation_GCC-1.1.0-Validation_output.csv" \
--path-data="path/to/conceptual_captions" \
--imagenet-val="path/to/imagenet/val" \
--warmup 20000 \
--batch-size=128 \
--lr=1e-3 \
--wd=0.1 \
--lr-scheduler="cosine-restarts" \
--restart-cycles=10 \
--epochs=70 \
--method="cloob" \
--init-inv-tau=30 \
--init-scale-hopfield=8 \
--workers=8 \
--model="RN50" \
--dist-url="tcp://127.0.0.1:6100" \
--batch-size-eval=512
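The effective batch size of the example command is the per-GPU --batch-size times the number of GPUs (128 × 4 = 512). The sketch below also illustrates one plausible reading of --warmup, --lr-scheduler="cosine-restarts" and --restart-cycles: linear warmup to --lr over 20,000 steps, followed by cosine annealing that restarts 10 times over the remaining steps. This is an assumption for illustration only; the scheduler in src/training may differ, for example in how warmup interacts with the cycles. The function names and the steps-per-epoch estimate are hypothetical.

```python
import math


def effective_batch_size(per_gpu_batch: int, num_gpus: int) -> int:
    """Samples per optimizer step across all GPUs."""
    return per_gpu_batch * num_gpus


def lr_at_step(step: int, base_lr: float, warmup_steps: int,
               total_steps: int, cycles: int) -> float:
    """Assumed schedule: linear warmup, then cosine annealing with `cycles`
    equal-length restarts. Illustrative, not the repository's scheduler."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    cycle_len = max(1, (total_steps - warmup_steps) // cycles)
    t = (step - warmup_steps) % cycle_len  # position within the current cycle
    return 0.5 * base_lr * (1 + math.cos(math.pi * t / cycle_len))


if __name__ == "__main__":
    batch = effective_batch_size(128, 4)
    print("effective batch size:", batch)  # 512
    # Hypothetical: assume all ~2.89M CC training pairs were downloaded.
    steps_per_epoch = 2_890_000 // batch
    total = 70 * steps_per_epoch
    for s in (0, 19_999, 20_000, total - 1):
        print(s, lr_at_step(s, 1e-3, 20_000, total, 10))
```

In practice fewer images download successfully than the 2.89M listed, so the real number of steps per epoch will be somewhat lower.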

Zero-shot evaluation on downstream tasks

We provide a Jupyter notebook to perform zero-shot evaluation with a trained model.
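Zero-shot classification with a contrastively trained image-text model typically works as follows: encode one text prompt per class (for example "a photo of a {label}."), L2-normalize image and text embeddings, and predict the class whose prompt embedding has the highest cosine similarity with the image. The numpy sketch below shows only this final step on precomputed embeddings; the encoders, prompt templates and function names are assumptions, not taken from the notebook.

```python
import numpy as np


def zero_shot_predict(image_emb: np.ndarray, class_emb: np.ndarray) -> np.ndarray:
    """Return the predicted class index for each image.

    image_emb: (n_images, d) image embeddings
    class_emb: (n_classes, d) one text embedding per class prompt
    """
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    cls = class_emb / np.linalg.norm(class_emb, axis=1, keepdims=True)
    similarities = img @ cls.T  # cosine similarities, (n_images, n_classes)
    return similarities.argmax(axis=1)


def top1_accuracy(predictions: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of images whose predicted class matches the label."""
    return float(np.mean(predictions == labels))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    class_emb = rng.normal(size=(5, 64))  # stand-in for encoded class prompts
    labels = rng.integers(0, 5, size=20)
    # Stand-in image embeddings: noisy copies of their class embedding.
    image_emb = class_emb[labels] + 0.1 * rng.normal(size=(20, 64))
    preds = zero_shot_predict(image_emb, class_emb)
    print("top-1 accuracy:", top1_accuracy(preds, labels))
```

When several prompt templates are used per class, a common choice is to average the normalized prompt embeddings of each class before computing similarities.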

LICENSE

MIT LICENSE

Owner
Institute for Machine Learning, Johannes Kepler University Linz
Software of the Institute for Machine Learning, JKU Linz