Official Implementation of SWAD (NeurIPS 2021)

Related tags

Deep Learningswad
Overview

SWAD: Domain Generalization by Seeking Flat Minima (NeurIPS'21)

Official PyTorch implementation of SWAD: Domain Generalization by Seeking Flat Minima.

Junbum Cha, Sanghyuk Chun, Kyungjae Lee, Han-Cheol Cho, Seunghyun Park, Yunsung Lee, Sungrae Park.

Note that this project is built upon [email protected].

Preparation

Dependencies

pip install -r requirements.txt

Datasets

python -m domainbed.scripts.download --data_dir=/my/datasets/path

Environments

Environment details used for our study.

Python: 3.8.6
PyTorch: 1.7.0+cu92
Torchvision: 0.8.1+cu92
CUDA: 9.2
CUDNN: 7603
NumPy: 1.19.4
PIL: 8.0.1

How to Run

train_all.py script conducts multiple leave-one-out cross-validations for all target domain.

python train_all.py exp_name --dataset PACS --data_dir /my/datasets/path

Experiment results are reported as a table. In the table, the row SWAD indicates out-of-domain accuracy from SWAD. The row SWAD (inD) indicates in-domain validation accuracy.

Example results:

+------------+--------------+---------+---------+---------+---------+
| Selection  | art_painting | cartoon |  photo  |  sketch |   Avg.  |
+------------+--------------+---------+---------+---------+---------+
|   oracle   |   82.245%    | 85.661% | 97.530% | 83.461% | 87.224% |
|    iid     |   87.919%    | 78.891% | 96.482% | 78.435% | 85.432% |
|    last    |   82.306%    | 81.823% | 95.135% | 82.061% | 85.331% |
| last (inD) |   95.807%    | 95.291% | 96.306% | 95.477% | 95.720% |
| iid (inD)  |   97.275%    | 96.619% | 96.696% | 97.253% | 96.961% |
|    SWAD    |   89.750%    | 82.942% | 97.979% | 81.870% | 88.135% |
| SWAD (inD) |   97.713%    | 97.649% | 97.316% | 98.074% | 97.688% |
+------------+--------------+---------+---------+---------+---------+

In this example, the DG performance of SWAD for PACS dataset is 88.135%.

If you set indomain_test option to True, the validation set is splitted to validation and test sets, and the (inD) keys become to indicate in-domain test accuracy.

Reproduce the results of the paper

We provide the instructions to reproduce the main results of the paper, Table 1 and 2. Note that the difference in a detailed environment or uncontrolled randomness may bring a little different result from the paper.

  • PACS
python train_all.py PACS0 --dataset PACS --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py PACS1 --dataset PACS --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py PACS2 --dataset PACS --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • VLCS
python train_all.py VLCS0 --dataset VLCS --deterministic --trial_seed 0 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
python train_all.py VLCS1 --dataset VLCS --deterministic --trial_seed 1 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
python train_all.py VLCS2 --dataset VLCS --deterministic --trial_seed 2 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
  • OfficeHome
python train_all.py OH0 --dataset OfficeHome --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py OH1 --dataset OfficeHome --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py OH2 --dataset OfficeHome --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • TerraIncognita
python train_all.py TR0 --dataset TerraIncognita --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py TR1 --dataset TerraIncognita --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py TR2 --dataset TerraIncognita --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • DomainNet
python train_all.py DN0 --dataset DomainNet --deterministic --trial_seed 0 --checkpoint_freq 500 --data_dir /my/datasets/path
python train_all.py DN1 --dataset DomainNet --deterministic --trial_seed 1 --checkpoint_freq 500 --data_dir /my/datasets/path
python train_all.py DN2 --dataset DomainNet --deterministic --trial_seed 2 --checkpoint_freq 500 --data_dir /my/datasets/path

Main Results

Citation

The paper will be published at NeurIPS 2021.

@inproceedings{cha2021swad,
  title={SWAD: Domain Generalization by Seeking Flat Minima},
  author={Cha, Junbum and Chun, Sanghyuk and Lee, Kyungjae and Cho, Han-Cheol and Park, Seunghyun and Lee, Yunsung and Park, Sungrae},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2021}
}

License

This source code is released under the MIT license, included here.

This project includes some code from DomainBed, also MIT licensed.

Code for models used in Bashiri et al., "A Flow-based latent state generative model of neural population responses to natural images".

A Flow-based latent state generative model of neural population responses to natural images Code for "A Flow-based latent state generative model of ne

Sinz Lab 5 Aug 26, 2022
A library for efficient similarity search and clustering of dense vectors.

Faiss Faiss is a library for efficient similarity search and clustering of dense vectors. It contains algorithms that search in sets of vectors of any

Meta Research 18.8k Jan 08, 2023
Putting NeRF on a Diet: Semantically Consistent Few-Shot View Synthesis

Putting NeRF on a Diet: Semantically Consistent Few-Shot View Synthesis Website | ICCV paper | arXiv | Twitter This repository contains the official i

Ajay Jain 73 Dec 27, 2022
Instance Segmentation in 3D Scenes using Semantic Superpoint Tree Networks

SSTNet Instance Segmentation in 3D Scenes using Semantic Superpoint Tree Networks(ICCV2021) by Zhihao Liang, Zhihao Li, Songcen Xu, Mingkui Tan, Kui J

83 Nov 29, 2022
[CVPR 2022] Deep Equilibrium Optical Flow Estimation

Deep Equilibrium Optical Flow Estimation This is the official repo for the paper Deep Equilibrium Optical Flow Estimation (CVPR 2022), by Shaojie Bai*

CMU Locus Lab 136 Dec 18, 2022
[CVPR 2021] Unsupervised 3D Shape Completion through GAN Inversion

ShapeInversion Paper Junzhe Zhang, Xinyi Chen, Zhongang Cai, Liang Pan, Haiyu Zhao, Shuai Yi, Chai Kiat Yeo, Bo Dai, Chen Change Loy "Unsupervised 3D

100 Dec 22, 2022
Representing Long-Range Context for Graph Neural Networks with Global Attention

Graph Augmentation Graph augmentation/self-supervision/etc. Algorithms gcn gcn+virtual node gin gin+virtual node PNA GraphTrans Augmentation methods N

UC Berkeley RISE 67 Dec 30, 2022
[ACM MM 2021] Diverse Image Inpainting with Bidirectional and Autoregressive Transformers

Diverse Image Inpainting with Bidirectional and Autoregressive Transformers Installation pip install -r requirements.txt Dataset Preparation Given the

Yingchen Yu 25 Nov 09, 2022
A copy of Ares that costs 30 fucking dollars.

Finalement, j'ai décidé d'abandonner cette idée, je me suis comporté comme un enfant qui été en colère. Comme m'ont dit certaines personnes j'ai des c

Bleu 24 Apr 14, 2022
A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron.

The GatedTabTransformer. A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron. C

Radi Cho 60 Dec 15, 2022
Official implementation of the paper ``Unifying Nonlocal Blocks for Neural Networks'' (ICCV'21)

Spectral Nonlocal Block Overview Official implementation of the paper: Unifying Nonlocal Blocks for Neural Networks (ICCV'21) Spectral View of Nonloca

91 Dec 14, 2022
[Link]deep_portfolo - Use Reforcemet earg ad Supervsed learg to Optmze portfolo allocato []

rl_portfolio This Repository uses Reinforcement Learning and Supervised learning to Optimize portfolio allocation. The goal is to make profitable agen

Deepender Singla 165 Dec 02, 2022
IOT: Instance-wise Layer Reordering for Transformer Structures

Introduction This repository contains the code for Instance-wise Ordered Transformer (IOT), which is introduced in the ICLR2021 paper IOT: Instance-wi

IOT 19 Nov 15, 2022
Explainability for Vision Transformers (in PyTorch)

Explainability for Vision Transformers (in PyTorch) This repository implements methods for explainability in Vision Transformers

Jacob Gildenblat 442 Jan 04, 2023
A rule-based log analyzer & filter

Flog 一个根据规则集来处理文本日志的工具。 前言 在日常开发过程中,由于缺乏必要的日志规范,导致很多人乱打一通,一个日志文件夹解压缩后往往有几十万行。 日志泛滥会导致信息密度骤减,给排查问题带来了不小的麻烦。 以前都是用grep之类的工具先挑选出有用的,再逐条进行排查,费时费力。在忍无可忍之后决

上山打老虎 9 Jun 23, 2022
Exploration of some patients clinical variables.

Answer_ALS_clinical_data Exploration of some patients clinical variables. All the clinical / metadata data is available here: https://data.answerals.o

1 Jan 20, 2022
Data-depth-inference - Data depth inference with python

Welcome! This readme will guide you through the use of the code in this reposito

Marco 3 Feb 08, 2022
Latex code for making neural networks diagrams

PlotNeuralNet Latex code for drawing neural networks for reports and presentation. Have a look into examples to see how they are made. Additionally, l

Haris Iqbal 18.6k Jan 01, 2023
Categorical Depth Distribution Network for Monocular 3D Object Detection

CaDDN CaDDN is a monocular-based 3D object detection method. This repository is based off of [OpenPCDet]. Categorical Depth Distribution Network for M

Toronto Robotics and AI Laboratory 289 Jan 05, 2023
Reference implementation for Structured Prediction with Deep Value Networks

Deep Value Network (DVN) This code is a python reference implementation of DVNs introduced in Deep Value Networks Learn to Evaluate and Iteratively Re

Michael Gygli 55 Feb 02, 2022