PyTorch implementation of "Supervised Contrastive Learning" (and SimCLR incidentally)

Overview

SupContrast: Supervised Contrastive Learning

This repo covers an reference implementation for the following papers in PyTorch, using CIFAR as an illustrative example:
(1) Supervised Contrastive Learning. Paper
(2) A Simple Framework for Contrastive Learning of Visual Representations. Paper

Loss Function

The loss function SupConLoss in losses.py takes features (L2 normalized) and labels as input, and return the loss. If labels is None or not passed to the it, it degenerates to SimCLR.

Usage:

from losses import SupConLoss

# define loss with a temperature `temp`
criterion = SupConLoss(temperature=temp)

# features: [bsz, n_views, f_dim]
# `n_views` is the number of crops from each image
# better be L2 normalized in f_dim dimension
features = ...
# labels: [bsz]
labels = ...

# SupContrast
loss = criterion(features, labels)
# or SimCLR
loss = criterion(features)
...

Comparison

Results on CIFAR-10:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 95.0
SupContrast ResNet50 Supervised Contrastive 96.0
SimCLR ResNet50 Unsupervised Contrastive 93.6

Results on CIFAR-100:

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy 75.3
SupContrast ResNet50 Supervised Contrastive 76.5
SimCLR ResNet50 Unsupervised Contrastive 70.7

Results on ImageNet (Stay tuned):

Arch Setting Loss Accuracy(%)
SupCrossEntropy ResNet50 Supervised Cross Entropy -
SupContrast ResNet50 Supervised Contrastive 79.1 (MoCo trick)
SimCLR ResNet50 Unsupervised Contrastive -

Running

You might use CUDA_VISIBLE_DEVICES to set proper number of GPUs, and/or switch to CIFAR100 by --dataset cifar100.
(1) Standard Cross-Entropy

python main_ce.py --batch_size 1024 \
  --learning_rate 0.8 \
  --cosine --syncBN \

(2) Supervised Contrastive Learning
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.1 \
  --cosine

You can also specify --syncBN but I found it not crucial for SupContrast (syncBN 95.9% v.s. BN 96.0%).
Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 5 \
  --ckpt /path/to/model.pth

(3) SimCLR
Pretraining stage:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5 \
  --temp 0.5 \
  --cosine --syncBN \
  --method SimCLR

The --method SimCLR flag simply stops labels from being passed to SupConLoss criterion. Linear evaluation stage:

python main_linear.py --batch_size 512 \
  --learning_rate 1 \
  --ckpt /path/to/model.pth

On custom dataset:

python main_supcon.py --batch_size 1024 \
  --learning_rate 0.5  \ 
  --temp 0.1 --cosine \
  --dataset path \
  --data_folder ./path \
  --mean "(0.4914, 0.4822, 0.4465)" \
  --std "(0.2675, 0.2565, 0.2761)" \
  --method SimCLR

The --data_folder must be of form ./path/label/xxx.png folowing https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.ImageFolder convension.

and

t-SNE Visualization

(1) Standard Cross-Entropy

(2) Supervised Contrastive Learning

(3) SimCLR

Reference

@Article{khosla2020supervised,
    title   = {Supervised Contrastive Learning},
    author  = {Prannay Khosla and Piotr Teterwak and Chen Wang and Aaron Sarna and Yonglong Tian and Phillip Isola and Aaron Maschinot and Ce Liu and Dilip Krishnan},
    journal = {arXiv preprint arXiv:2004.11362},
    year    = {2020},
}
Owner
Yonglong Tian
CS Ph.D. student in AI @ MIT
Yonglong Tian
Supervised Contrastive Learning for Product Matching

Contrastive Product Matching This repository contains the code and data download links to reproduce the experiments of the paper "Supervised Contrasti

Web-based Systems Group @ University of Mannheim 18 Dec 10, 2022
Python module providing a framework to trace individual edges in an image using Gaussian process regression.

Edge Tracing using Gaussian Process Regression Repository storing python module which implements a framework to trace individual edges in an image usi

Jamie Burke 7 Dec 27, 2022
ADGAN - The Implementation of paper Controllable Person Image Synthesis with Attribute-Decomposed GAN

ADGAN - The Implementation of paper Controllable Person Image Synthesis with Attribute-Decomposed GAN CVPR 2020 (Oral); Pose and Appearance Attributes Transfer;

Men Yifang 400 Dec 29, 2022
《Rethinking Sptil Dimensions of Vision Trnsformers》(2021)

Rethinking Spatial Dimensions of Vision Transformers Byeongho Heo, Sangdoo Yun, Dongyoon Han, Sanghyuk Chun, Junsuk Choe, Seong Joon Oh | Paper NAVER

NAVER AI 224 Dec 27, 2022
AWS documentation corpus for zero-shot open-book question answering.

aws-documentation We present the AWS documentation corpus, an open-book QA dataset, which contains 25,175 documents along with 100 matched questions a

Sia Gholami 2 Jul 07, 2022
Source code of our BMVC 2021 paper: AniFormer: Data-driven 3D Animation with Transformer

AniFormer This is the PyTorch implementation of our BMVC 2021 paper AniFormer: Data-driven 3D Animation with Transformer. Haoyu Chen, Hao Tang, Nicu S

24 Nov 02, 2022
Official Python implementation of the 'Sparse deconvolution'-v0.3.0

Sparse deconvolution Python v0.3.0 Official Python implementation of the 'Sparse deconvolution', and the CPU (NumPy) and GPU (CuPy) calculation backen

Weisong Zhao 23 Dec 28, 2022
It is a simple library to speed up CLIP inference up to 3x (K80 GPU)

CLIP-ONNX It is a simple library to speed up CLIP inference up to 3x (K80 GPU) Usage Install clip-onnx module and requirements first. Use this trick !

Gerasimov Maxim 93 Dec 20, 2022
Code release for Local Light Field Fusion at SIGGRAPH 2019

Local Light Field Fusion Project | Video | Paper Tensorflow implementation for novel view synthesis from sparse input images. Local Light Field Fusion

1.1k Dec 27, 2022
Hand Gesture Volume Control | Open CV | Computer Vision

Gesture Volume Control Hand Gesture Volume Control | Open CV | Computer Vision Use gesture control to change the volume of a computer. First we look i

Jhenil Parihar 3 Jun 15, 2022
Robust fine-tuning of zero-shot models

Robust fine-tuning of zero-shot models This repository contains code for the paper Robust fine-tuning of zero-shot models by Mitchell Wortsman*, Gabri

224 Dec 29, 2022
GPU Programming with Julia - course at the Swiss National Supercomputing Centre (CSCS), ETH Zurich

Course Description The programming language Julia is being more and more adopted in High Performance Computing (HPC) due to its unique way to combine

Samuel Omlin 192 Jan 03, 2023
This repository contains the source code of an efficient 1D probabilistic model for music time analysis proposed in ICASSP2022 venue.

Jump Reward Inference for 1D Music Rhythmic State Spaces An implementation of the probablistic jump reward inference model for music rhythmic informat

Mojtaba Heydari 25 Dec 16, 2022
Official pytorch implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion"

DSPoint Official pytorch implementation of "DSPoint: Dual-scale Point Cloud Recognition with High-frequency Fusion" Coming soon, as soon as I finish a

Ziyao Zeng 14 Feb 26, 2022
A PyTorch implementation of Radio Transformer Networks from the paper "An Introduction to Deep Learning for the Physical Layer".

An Introduction to Deep Learning for the Physical Layer An usable PyTorch implementation of the noisy autoencoder infrastructure in the paper "An Intr

Gram.AI 120 Nov 21, 2022
Code for approximate graph reduction techniques for cardinality-based DSFM, from paper

SparseCard Code for approximate graph reduction techniques for cardinality-based DSFM, from paper "Approximate Decomposable Submodular Function Minimi

Nate Veldt 1 Nov 25, 2022
Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion Models

Label-Efficient Semantic Segmentation with Diffusion Models Official implementation of the paper Label-Efficient Semantic Segmentation with Diffusion

Yandex Research 355 Jan 06, 2023
Securetar - A streaming wrapper around python tarfile and allow secure handling files and support encryption

Secure Tar Secure Tarfile library It's a streaming wrapper around python tarfile

Pascal Vizeli 2 Dec 09, 2022
A modular domain adaptation library written in PyTorch.

A modular domain adaptation library written in PyTorch.

Kevin Musgrave 225 Dec 29, 2022
Pseudo-Visual Speech Denoising

Pseudo-Visual Speech Denoising This code is for our paper titled: Visual Speech Enhancement Without A Real Visual Stream published at WACV 2021. Autho

Sindhu 94 Oct 22, 2022