Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.

Overview

PAWS-TF 🐾

Implementation of Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples (PAWS) in TensorFlow (2.4.1).

PAWS introduces a simple way to combine a very small fraction of labeled data with a comparatively larger corpus of unlabeled data during pre-training. With its approach, it sets the state-of-the-art in semi-supervised learning (as of May 2021) beating methods like SimCLRV2, Meta Pseudo Labels that too with fewer parameters and a smaller pre-training schedule. For details, I recommend checking out the original paper as well as this blog post by the authors.

This repository implements and includes all the major bits proposed in PAWS in TensorFlow. The only major difference is that the pre-training and subsequent fine-tuning weren't run for the original number of epochs (600 and 30 respectively) to save compute. I have reused the utility components for PAWS loss from the original implementation.

Dataset ⌗

The current code works with CIFAR10 and uses 4000 labeled samples (8%) during pre-training (along with the unlabeled samples).

Features

  • Multi-crop augmentation strategy (originally introduced in SwAV)
  • Class stratified sampler (common in few-shot classification problems)
  • WarmUpCosine learning rate schedule (which is typical for self-supervised and semi-supervised pre-training)
  • LARS optimizer (comes from TensorFlow Model Garden)

The trunk portion (all, except the last classification layer) of a WideResNet-28-2 is used inside the encoder for CIFAR10. All the experimental configurations were followed from the Appendix C of the paper.

Setup and code structure 💻

A GCP VM (n1-standard-8) with a single V100 GPU was used for executing the code.

  • paws_train.py runs the pre-training as introduced in PAWS.
  • fine_tune.py runs the fine-tuning part as suggested in Appendix C. Note that this is only required for CIFAR10.
  • nn_eval.py runs the soft nearest neighbor classification on CIFAR10 test set.

Pre-training and fine-tuning total take 1.4 hours to complete. All the logs are available in misc/logs.txt. Additionally, the indices that were used to sample the labeled examples from the CIFAR10 training set are available here.

Results 📊

Pre-training

PAWS minimizes the cross-entropy loss (as well as maximizes mean-entropy) during pre-training. This is what the training plot indicates too:

To evaluate the effectivity of the pre-training, PAWS performs soft nearest neighbor classification to report the top-1 accuracy score on a given test set.

Top-1 Accuracy

This repository gets to 73.46% top-1 accuracy on the CIFAR10 test set. Again, note that I only pre-trained for 50 epochs (as opposed to 600) and fine-tuned for 10 epochs (as opposed to 30). With the original schedule this score should be around 96.0%.

In the following PCA projection plot, we see that the embeddings of images (computed after fine-tuning) of PAWS are starting to be well separated:

Notebooks 📘

There are two Colab Notebooks:

Misc ⺟

  • Model weights are available here for reproducibility.
  • With mixed-precision training, the performance can further be improved. I am open to accepting contributions that would implement mixed-precision training in the current code.

Acknowledgements

  • Huge amount of thanks to Mahmoud Assran (first author of PAWS) for patiently resolving my doubts.
  • ML-GDE program for providing GCP credit support.

Paper Citation

@misc{assran2021semisupervised,
      title={Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples}, 
      author={Mahmoud Assran and Mathilde Caron and Ishan Misra and Piotr Bojanowski and Armand Joulin and Nicolas Ballas and Michael Rabbat},
      year={2021},
      eprint={2104.13963},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
You might also like...
Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) in PyTorch
Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) in PyTorch

alias-free-gan-pytorch Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) This implementation

Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286
Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286

Pytorch-DPPO Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286 Using PPO with clip loss (from https

PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)
PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)

Asym-Siam: On the Importance of Asymmetry for Siamese Representation Learning This is a PyTorch implementation of the Asym-Siam paper, CVPR 2022: @inp

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).
This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Predicting Patient Outcomes with Graph Representation Learning This repository contains the code used for Predicting Patient Outcomes with Graph Repre

https://arxiv.org/abs/2102.11005
https://arxiv.org/abs/2102.11005

LogME LogME: Practical Assessment of Pre-trained Models for Transfer Learning How to use Just feed the features f and labels y to the function, and yo

Supplementary code for the paper
Supplementary code for the paper "Meta-Solver for Neural Ordinary Differential Equations" https://arxiv.org/abs/2103.08561

Meta-Solver for Neural Ordinary Differential Equations Towards robust neural ODEs using parametrized solvers. Main idea Each Runge-Kutta (RK) solver w

Code for paper "A Critical Assessment of State-of-the-Art in Entity Alignment" (https://arxiv.org/abs/2010.16314)

A Critical Assessment of State-of-the-Art in Entity Alignment This repository contains the source code for the paper A Critical Assessment of State-of

Code for the paper: Learning Adversarially Robust Representations via Worst-Case Mutual Information Maximization (https://arxiv.org/abs/2002.11798)

Representation Robustness Evaluations Our implementation is based on code from MadryLab's robustness package and Devon Hjelm's Deep InfoMax. For all t

ISTR: End-to-End Instance Segmentation with Transformers (https://arxiv.org/abs/2105.00637)

This is the project page for the paper: ISTR: End-to-End Instance Segmentation via Transformers, Jie Hu, Liujuan Cao, Yao Lu, ShengChuan Zhang, Yan Wa

Releases(v1.0.0)
Owner
Sayak Paul
Trying to learn how machines learn.
Sayak Paul
Code for the paper "Improving Vision-and-Language Navigation with Image-Text Pairs from the Web" (ECCV 2020)

Improving Vision-and-Language Navigation with Image-Text Pairs from the Web Arjun Majumdar, Ayush Shrivastava, Stefan Lee, Peter Anderson, Devi Parikh

Arjun Majumdar 44 Dec 14, 2022
Video2x - A lossless video/GIF/image upscaler achieved with waifu2x, Anime4K, SRMD and RealSR.

Official Discussion Group (Telegram): https://t.me/video2x A Discord server is also available. Please note that most developers are only on Telegram.

K4YT3X 5.9k Dec 31, 2022
LOFO (Leave One Feature Out) Importance calculates the importances of a set of features based on a metric of choice,

LOFO (Leave One Feature Out) Importance calculates the importances of a set of features based on a metric of choice, for a model of choice, by iteratively removing each feature from the set, and eval

Ahmet Erdem 691 Dec 23, 2022
ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation.

ENet This work has been published in arXiv: ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation. Packages: train contains too

e-Lab 344 Nov 21, 2022
A template repository for submitting a job to the Slurm Cluster installed at the DISI - University of Bologna

Cluster di HPC con GPU per esperimenti di calcolo (draft version 1.0) Per poter utilizzare il cluster il primo passo è abilitare l'account istituziona

20 Dec 16, 2022
This repository contains the implementation of the following paper: Cross-Descriptor Visual Localization and Mapping

Cross-Descriptor Visual Localization and Mapping This repository contains the implementation of the following paper: "Cross-Descriptor Visual Localiza

Mihai Dusmanu 81 Oct 06, 2022
An Open-Source Package for Information Retrieval.

OpenMatch An Open-Source Package for Information Retrieval. 😃 What's New Top Spot on TREC-COVID Challenge (May 2020, Round2) The twin goals of the ch

THUNLP 439 Dec 27, 2022
Direct design of biquad filter cascades with deep learning by sampling random polynomials.

IIRNet Direct design of biquad filter cascades with deep learning by sampling random polynomials. Usage git clone https://github.com/csteinmetz1/IIRNe

Christian J. Steinmetz 55 Nov 02, 2022
An Ensemble of CNN (Python 3.5.1 Tensorflow 1.3 numpy 1.13)

An Ensemble of CNN (Python 3.5.1 Tensorflow 1.3 numpy 1.13)

0 May 06, 2022
Just playing with getting VQGAN+CLIP running locally, rather than having to use colab.

Just playing with getting VQGAN+CLIP running locally, rather than having to use colab.

Nerdy Rodent 2.3k Jan 04, 2023
Exploration & Research into cross-domain MEV. Initial focus on ETH/POLYGON.

xMEV, an apt exploration This is a small exploration on the xMEV opportunities between Polygon and Ethereum. It's a data analysis exercise on a few pa

odyslam.eth 7 Oct 18, 2022
Resources for the "Evaluating the Factual Consistency of Abstractive Text Summarization" paper

Evaluating the Factual Consistency of Abstractive Text Summarization Authors: Wojciech Kryściński, Bryan McCann, Caiming Xiong, and Richard Socher Int

Salesforce 165 Dec 21, 2022
Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting

Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting This is the origin Pytorch implementation of Informer in the followin

Haoyi 3.1k Dec 29, 2022
PyTorch implementation of InstaGAN: Instance-aware Image-to-Image Translation

InstaGAN: Instance-aware Image-to-Image Translation Warning: This repo contains a model which has potential ethical concerns. Remark that the task of

Sangwoo Mo 827 Dec 29, 2022
Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index.

TechSEO Crawler Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index. Play with the r

JR Oakes 57 Nov 24, 2022
Adversarial Graph Representation Adaptation for Cross-Domain Facial Expression Recognition (AGRA, ACM 2020, Oral)

Cross Domain Facial Expression Recognition Benchmark Implementation of papers: Cross-Domain Facial Expression Recognition: A Unified Evaluation Benchm

89 Dec 09, 2022
Official PyTorch implementation of paper: Standardized Max Logits: A Simple yet Effective Approach for Identifying Unexpected Road Obstacles in Urban-Scene Segmentation (ICCV 2021 Oral Presentation)

SML (ICCV 2021, Oral) : Official Pytorch Implementation This repository provides the official PyTorch implementation of the following paper: Standardi

SangHun 61 Dec 27, 2022
Preprocessed Datasets for our Multimodal NER paper

Unified Multimodal Transformer (UMT) for Multimodal Named Entity Recognition (MNER) Two MNER Datasets and Codes for our ACL'2020 paper: Improving Mult

76 Dec 21, 2022
T-LOAM: Truncated Least Squares Lidar-only Odometry and Mapping in Real-Time

T-LOAM: Truncated Least Squares Lidar-only Odometry and Mapping in Real-Time The first Lidar-only odometry framework with high performance based on tr

Pengwei Zhou 183 Dec 01, 2022