Self-supervised learning of optimally robust representations for domain generalization.

Overview

OptDom: Learning Optimal Representations for Domain Generalization

This repository contains the official implementation for Optimal Representations for Covariate Shift. Our paper theoretically characterizes the minimal sufficient representations for optimal domain generalization (DG) under covariate shift and derives practical self-supervised learning (SSL) objectives for learning such representations.

We provide code for reproducing our main results. Contribution highlights:

  • Finetuning pretrained SSL models (CLIP) into more robust DG models [minimal example]
  • A novel contrastive adversarial domain bottleneck for learning domain-invariant representations [implementation]

Setup

  1. Install PyTorch 1.7.1 and CLIP by following their installation instructions.
  2. Install the remaining packages: pip install -r requirements.txt.
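
After installation, a quick check that the key packages are importable can save a failed run later. The snippet below is a minimal sketch using only the standard library; the module names are taken from the imports used in the minimal example further down in this README, and the helper name `missing_modules` is made up for illustration.

```python
import importlib.util

# Module names taken from the imports used later in this README.
REQUIRED_MODULES = ("torch", "clip", "tqdm")


def missing_modules(names=REQUIRED_MODULES):
    """Return the names that cannot be located on the current Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages were found.")
```

This only checks that the packages can be found; it does not verify the installed PyTorch version.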

Finetune & Evaluate CLIP on DomainBed

Our paper derives SSL objectives for learning optimally robust representations and gives insights into the superior robustness of CLIP (Sec 4). Here we include the code for finetuning CLIP with our proposed objectives and evaluating on the DomainBed benchmark, which reproduces our experiments in Sec 6.2.

The implementation lives in the DomainBed directory, which is closely based on the DomainBed repo. The CLIP-based models are implemented in domainbed/clip_algorithms.py, and the domain bottlenecks are in domainbed/bottlenecks.py. The training script for finetuning CLIP with bottlenecks is domainbed/scripts/train_clip.py.
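
To give a feel for what a contrastive domain bottleneck penalizes, here is a minimal NumPy sketch. It is not the code in domainbed/bottlenecks.py and not necessarily the exact objective from the paper: it only computes how confidently a softmax over pairwise similarities picks a neighbour from the same domain, a quantity that a domain-invariant representation would keep low. The function name, the temperature value, and the use of cosine similarity are assumptions made for illustration.

```python
import numpy as np


def same_domain_log_prob(z, domains, temperature=0.1):
    """Mean log-probability that a similarity softmax picks a same-domain neighbour.

    z: (n, d) array of representations; domains: (n,) integer domain labels.
    Every domain must contain at least two samples.
    Lower values mean the domains are harder to tell apart.
    """
    z = np.asarray(z, dtype=float)
    domains = np.asarray(domains)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # cosine similarity
    logits = z @ z.T / temperature
    np.fill_diagonal(logits, -np.inf)  # a sample never picks itself
    same = domains[:, None] == domains[None, :]
    np.fill_diagonal(same, False)

    def logsumexp(a):
        m = a.max(axis=1, keepdims=True)
        return (m + np.log(np.exp(a - m).sum(axis=1, keepdims=True))).ravel()

    log_all = logsumexp(logits)
    log_same = logsumexp(np.where(same, logits, -np.inf))
    return float(np.mean(log_same - log_all))


# Toy usage: two tight clusters, labelled so that domains match the clusters.
z = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
print(same_domain_log_prob(z, [0, 0, 1, 1]))
```

When the domain labels coincide with the clusters, the value is close to zero (domains are easy to identify); a bottleneck of this kind would push it down during finetuning.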

Preparation

Move to the DomainBed directory and download the datasets:

python -m domainbed.scripts.download --data_dir ./datasets/

By default, this downloads the following datasets: PACS, VLCS, OfficeHome, TerraIncognita, and DomainNet.

Launch a single run

To launch a single run for debugging, use the command:

bash run_debug.sh

The key arguments include:

  • --dataset: dataset for finetuning and evaluation.
  • --algorithm: algorithms implemented with CLIP, see domainbed/clip_algorithms.py.
  • --test_envs: list of left-out environments for testing, others used for training/finetuning.
  • --hparams: JSON-serialized hyperparameter dict, see domainbed/hparams_registry.py for the list of all hyperparameters.
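
Because --hparams expects a JSON string, quoting it by hand on the command line is error-prone. The sketch below builds such a command in Python. The four flags are the ones listed above; the module path `domainbed.scripts.train_clip` is inferred from the script location and is an assumption, the algorithm name follows the naming scheme used in the minimal example later in this README, and the hyperparameter values are placeholders.

```python
import json
import shlex

# Hyperparameter overrides; the keys mirror those read in the minimal example
# in this README, the values are placeholders.
hparams = {"clip_model": "ViT-B/32", "batch_size": 32}

command = [
    "python", "-m", "domainbed.scripts.train_clip",  # assumed module path
    "--dataset", "PACS",
    "--algorithm", "SupCLIPBottleneckCondCAD",
    "--test_envs", "0",
    "--hparams", json.dumps(hparams),
]

# shlex.join (Python 3.8+) quotes the JSON string so it survives the shell intact.
print(shlex.join(command))
```

The script run_debug.sh remains the reference for the exact invocation.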

Note that the result of a single run can be very sensitive to the hyperparameters and the random seed. We recommend launching a sweep over hyperparameters and random seeds, as in DomainBed.

Launch a sweep with tuning

To launch a sweep, use the command:

bash run_sweep_clip.sh

A sweep over 10 hyperparameter configurations and 5 random seeds is launched for each dataset and algorithm. By default, the CLIP-RN50 model is used; you can run other models by changing the clip_model argument, e.g., ViT-B/32 for CLIP-ViT-B/32. To launch a sweep, you also need to select or implement a command launcher in domainbed/command_launchers.py and choose it with the launcher argument. If you are using Slurm, a Slurm launcher is already implemented that you can adapt.
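
To make the size of a sweep concrete, the sketch below enumerates one command per pair of hyperparameter draw and random seed, and hands the list to a dry-run launcher. It assumes this repository keeps the upstream DomainBed convention that a launcher is a function receiving a list of shell command strings. The module path and the `--hparams_seed` / `--trial_seed` flags are borrowed from upstream DomainBed and are assumptions here; `dry_run_launcher` is a hypothetical name, and test environments are left out for brevity.

```python
import itertools
import shlex


def sweep_commands(dataset, algorithm, n_hparams=10, n_seeds=5):
    """Build one training command per (hyperparameter draw, random seed) pair."""
    commands = []
    for hparams_seed, trial_seed in itertools.product(range(n_hparams), range(n_seeds)):
        commands.append(shlex.join([
            "python", "-m", "domainbed.scripts.train_clip",  # assumed module path
            "--dataset", dataset,
            "--algorithm", algorithm,
            "--hparams_seed", str(hparams_seed),  # assumed flag
            "--trial_seed", str(trial_seed),      # assumed flag
        ]))
    return commands


def dry_run_launcher(commands):
    """Hypothetical launcher: print each command instead of executing it."""
    for command in commands:
        print(command)


commands = sweep_commands("PACS", "SupCLIPBottleneckCondCAD")
dry_run_launcher(commands[:3])  # show only the first few of the 10 * 5 commands
```

A cluster launcher would follow the same shape, submitting each command as a job instead of printing it.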

After the sweep finishes, you can collect the results with the notebook collect_clip_results.ipynb. Note that the results may differ slightly from those in the paper because the code has been cleaned up.
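
For readers unfamiliar with how sweep results are usually summarized, the sketch below shows one common scheme, the training-domain validation selection used by upstream DomainBed: for each random seed, keep the hyperparameter draw with the best validation accuracy, then report the mean and standard error of the matching test accuracies. The notebook's actual logic may differ; the record fields and the numbers are synthetic placeholders, not results from the paper.

```python
from collections import defaultdict
from statistics import mean, stdev


def summarize(records):
    """Pick the best hyperparameter draw per seed by validation accuracy,
    then average the matching test accuracies across seeds."""
    by_seed = defaultdict(list)
    for record in records:
        by_seed[record["trial_seed"]].append(record)
    selected = [max(runs, key=lambda r: r["val_acc"])["test_acc"]
                for runs in by_seed.values()]
    std_err = stdev(selected) / len(selected) ** 0.5 if len(selected) > 1 else 0.0
    return mean(selected), std_err


# Synthetic numbers for illustration only; they are not results from the paper.
records = [
    {"trial_seed": 0, "hparams_seed": 0, "val_acc": 0.80, "test_acc": 0.70},
    {"trial_seed": 0, "hparams_seed": 1, "val_acc": 0.90, "test_acc": 0.75},
    {"trial_seed": 1, "hparams_seed": 0, "val_acc": 0.85, "test_acc": 0.65},
    {"trial_seed": 1, "hparams_seed": 1, "val_acc": 0.70, "test_acc": 0.60},
]
print(summarize(records))
```

Selecting on validation accuracy rather than test accuracy keeps the reported number free of test-set tuning.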

(Optional) Run CAD in DomainBed setup

You can also evaluate our proposed (conditional) CAD bottleneck in the DomainBed setup where a ResNet-50 is end-to-end trained on source domains and evaluated on a left-out target domain. We include the implementation in domainbed/algorithms.py, which you can run with command:

bash run_sweep_e2e_dombed.sh

You can collect the results with the notebook collect_e2e_results.ipynb. Note that, as our paper argues, the algorithms in this setup lack access to information about the target domain, so we do not expect our bottlenecks or other algorithms to necessarily outperform ERM. Surprisingly, however, our CAD bottleneck does lead to consistent improvements.
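
The leave-one-domain-out protocol used in this setup can be sketched in a few lines: each domain is held out once as the target, and the remaining domains serve as sources. The helper name `leave_one_out_splits` is made up for illustration; the domain list is that of PACS.

```python
def leave_one_out_splits(domains):
    """Yield (train_domains, test_domain) pairs, holding out each domain once."""
    for index, test_domain in enumerate(domains):
        train_domains = domains[:index] + domains[index + 1:]
        yield train_domains, test_domain


# The four PACS domains.
pacs = ["art_painting", "cartoon", "photo", "sketch"]
for train_domains, test_domain in leave_one_out_splits(pacs):
    print(f"train on {train_domains}, test on {test_domain}")
```

In the scripts, the held-out domain is selected through the --test_envs argument.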

Finetune CLIP on LAION-400M

Coming soon!

Minimal Code for Custom Finetuning

If you want to finetune CLIP on your own dataset with our bottlenecks, here is a minimal code example:

import torch
from torch.utils.data import DataLoader, TensorDataset
import clip
from tqdm import tqdm

from domainbed import hparams_registry
from domainbed import algorithms


# 1. Choose between supervised and contrastive finetuning:
#       - True: use a cross-entropy loss with a supervised dataset
#       - False: use a contrastive loss with a text-image-pair dataset
supervised_finetuning = True

if supervised_finetuning:
    loss_name = "Sup"
    dataset_name = "my supervised dataset"
else:
    loss_name = "Contrast"
    dataset_name = "my text-image pair dataset"


# 2. Choose the bottleneck to use; each option has different properties
bottleneck_name = "CondCAD"  # Ent, CAD, CondCAD
algorithm_name = loss_name + "CLIPBottleneck" + bottleneck_name


# 3. Set hyperparameters; you can also edit the hyperparameter dict to change the default values
hparams = hparams_registry.default_hparams(algorithm_name, dataset_name)


# 4. Load the pretrained CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"

pretrained, preprocess = clip.load(hparams['clip_model'], device, jit=False)


# 5. Load your dataset. Your dataset should have the form:
#       - (image, label) for supervised finetuning
#       - (image, text) for contrastive finetuning
#    Remember to use the CLIP preprocessing function for image transformation,
#       and your dataset should be a list of sub-datasets from different domains (singleton for a single domain)
#    `load_your_dataset` is a placeholder that you need to implement yourself
dataset = load_your_dataset(dataset_name, preprocess)
num_envs = len(dataset)
num_classes = dataset.num_classes  # dummy for text-image-pair dataset


# 6. Featurize your dataset with CLIP models

def get_clip_feature(clip_model, x, y, device):
    """Compute CLIP image features; for contrastive finetuning, also encode the texts"""
    with torch.no_grad():
        z = clip_model.encode_image(x.to(device)).float()
        if supervised_finetuning:  # `y` is a tensor of labels
            y = y.to(device)
        else:  # `y` is a batch of raw texts that must be tokenized before encoding
            y = clip_model.encode_text(clip.tokenize(list(y)).to(device)).float()
    return z, y

def clip_featurize_data(clip_model, dataset, device):
    """Featurize a dataset"""
    Z, Y = [], []
    for x, y in tqdm(DataLoader(dataset, batch_size=512, num_workers=4)):
        # `y` stays on the CPU here: it may be a list of strings, which has no `.to(device)`
        z, y = get_clip_feature(clip_model, x, y, device)
        Z.append(z.cpu())
        Y.append(y.cpu())
    return TensorDataset(torch.cat(Z), torch.cat(Y))

def clip_precompute_splits(clip_model, splits, device):
    _splits = []
    for ds in splits:
        _splits.append(clip_featurize_data(clip_model, ds, device))
    return _splits


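dataset = clip_precompute_splits(pretrained, dataset, device)
train_loaders = [DataLoader(
    dataset=env,
    batch_size=hparams['batch_size'],
    num_workers=4)
    for env in dataset]


def infinite_minibatches(loaders):
    """Yield one minibatch per environment forever, restarting the loaders after each pass"""
    while True:
        yield from zip(*loaders)


# A plain zip(*train_loaders) is exhausted after one epoch, so cycle through the loaders instead
train_minibatches_iterator = infinite_minibatches(train_loaders)
steps_per_epoch = int(min([len(env) / hparams['batch_size'] for env in dataset]))
n_steps = hparams['max_step']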


# 7. Initialize the model:
algorithm_class = algorithms.get_algorithm_class(algorithm_name)
algorithm = algorithm_class(pretrained.visual.output_dim, num_classes, num_envs, hparams, pretrained, None)
algorithm.to(device)
algorithm.train()


# 8. Finetune the model:
for step in range(n_steps):
    minibatches_device = [(x.to(device), y.to(device)) for x, y in next(train_minibatches_iterator)]
    algorithm.adjust_lr(step, n_steps, steps_per_epoch)
    step_vals = algorithm.update(minibatches_device, None)
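
Step 2 of the example composes the algorithm name from a loss prefix and a bottleneck suffix. The short sketch below lists the names this scheme produces from the options mentioned in the comments; whether every combination is registered in the repository is an assumption.

```python
from itertools import product

LOSS_NAMES = ("Sup", "Contrast")
BOTTLENECK_NAMES = ("Ent", "CAD", "CondCAD")

# Same concatenation as in step 2: loss name + "CLIPBottleneck" + bottleneck name.
algorithm_names = [
    f"{loss}CLIPBottleneck{bottleneck}"
    for loss, bottleneck in product(LOSS_NAMES, BOTTLENECK_NAMES)
]
print(algorithm_names)
```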

Cite

If you find this work relevant to your work, please cite our paper:

@article{ruan2021optdom,
  title={Optimal Representations for Covariate Shift},
  author={Ruan, Yangjun and Dubois, Yann and Maddison, Chris J},
  journal={arXiv preprint arXiv:2201.00057},
  year={2022},
}

Acknowledgement

Our code is based on:
