The official PyTorch implementation of recent paper - SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Overview

This repository is the official PyTorch implementation of SAINT. Find the paper on arxiv

SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Overview

Requirements

We recommend using anaconda or miniconda for python. Our code has been tested with python=3.8 on linux.

To create a new environment with conda

conda create -n saint_env python=3.8
conda activate saint_env

We recommend installing the latest pytorch, torchvision, einops, pandas, wget, sklearn packages.

You can install them using

conda install pytorch torchvision -c pytorch
conda install -c conda-forge einops 
conda install -c conda-forge pandas 
conda install -c conda-forge python-wget 
conda install -c anaconda scikit-learn 

Make sure the following requirements are met

  • torch>=1.8.1
  • torchvision>=0.9.1

Optional

We used wandb to update our logs. But it is optional.

conda install -c conda-forge wandb 

Training & Evaluation

In each of our experiments, we use a single Nvidia GeForce RTX 2080Ti GPU.

First download the processed datasets from this link into the folder ./data

To train the model(s) in the paper, run this command:

python train.py  --dataset <dataset_name> --attentiontype <attention_type> 

Pretraining is useful when there are few training data samples. Sample code looks like this

python train.py  --dataset <dataset_name> --attentiontype <attention_type> --pretrain --pt_tasks <pretraining_task_touse> --pt_aug <augmentations_on_data_touse> --ssl_avail_y <Number_of_labeled_samples>

Train all 16 datasets by running bash files. train.sh for supervised learning and train_pt.sh for pretraining and semi-supervised learning

bash train.sh
bash train_pt.sh

Arguments

  • --dataset : Dataset name. We support only the 16 datasets discussed in the paper. Supported datasets are ['1995_income','bank_marketing','qsar_bio','online_shoppers','blastchar','htru2','shrutime','spambase','philippine','mnist','arcene','volkert','creditcard','arrhythmia','forest','kdd99']
  • --embedding_size : Size of the feature embeddings
  • --transformer_depth : Depth of the model. Number of stages.
  • --attention_heads : Number of attention heads in each Attention layer.
  • --cont_embeddings : Style of embedding continuous data.
  • --attentiontype : Variant of SAINT. 'col' refers to SAINT-s variant, 'row' is SAINT-i, and 'colrow' refers to SAINT.
  • --pretrain : To enable pretraining
  • --pt_tasks : Losses we want to use for pretraining. Multiple arguments can be passed.
  • --pt_aug : Types of data augmentations used in pretraining. Multiple arguments are allowed. We support only mixup and CutMix right now.
  • --ssl_avail_y : Number of labeled samples used in semi-supervised experiments. Default is 0, which means all samples are labeled and is supervised case.
  • --pt_projhead_style : Projection head style used in contrastive pipeline.
  • --nce_temp : Temperature used in contrastive loss function.
  • --active_log : To update the logs onto wandb. This is optional

Evaluation

We choose the best model by evaluating the model on validation dataset. The AUROC(for binary classification datasets) and Accuracy (for multiclass classification datasets) of the best model on test datasets is printed after training is completed. If wandb is enabled, they are logged to 'test_auroc_bestep', 'test_accuracy_bestep' variables.

Acknowledgements

We would like to thank the following public repo from which we borrowed various utilites.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Cite us

@article{somepalli2021saint,
  title={SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training},
  author={Somepalli, Gowthami and Goldblum, Micah and Schwarzschild, Avi and Bruss, C Bayan and Goldstein, Tom},
  journal={arXiv preprint arXiv:2106.01342},
  year={2021}
}

Owner
Gowthami Somepalli
Gowthami Somepalli
PROJECT - Az Residential Real Estate Analysis

AZ RESIDENTIAL REAL ESTATE ANALYSIS -Decided on libraries to import. Includes pa

2 Jul 05, 2022
public repo for ESTER dataset and modeling (EMNLP'21)

Project / Paper Introduction This is the project repo for our EMNLP'21 paper: https://arxiv.org/abs/2104.08350 Here, we provide brief descriptions of

PlusLab 19 Oct 27, 2022
Hierarchical Motion Encoder-Decoder Network for Trajectory Forecasting (HMNet)

Hierarchical Motion Encoder-Decoder Network for Trajectory Forecasting (HMNet) Our paper: https://arxiv.org/abs/2111.13324 We will release the complet

15 Oct 17, 2022
Code for T-Few from "Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning"

T-Few This repository contains the official code for the paper: "Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learni

220 Dec 31, 2022
Official public repository of paper "Intention Adaptive Graph Neural Network for Category-Aware Session-Based Recommendation"

Intention Adaptive Graph Neural Network (IAGNN) This is the official repository of paper Intention Adaptive Graph Neural Network for Category-Aware Se

9 Nov 22, 2022
This program automatically runs Python code copied in clipboard

CopyRun This program runs Python code which is copied in clipboard WARNING!! USE AT YOUR OWN RISK! NO GUARANTIES IF ANYTHING GETS BROKEN. DO NOT COPY

vertinski 4 Sep 10, 2021
Torchserve server using a YoloV5 model running on docker with GPU and static batch inference to perform production ready inference.

Yolov5 running on TorchServe (GPU compatible) ! This is a dockerfile to run TorchServe for Yolo v5 object detection model. (TorchServe (PyTorch librar

82 Nov 29, 2022
Inverse Rendering for Complex Indoor Scenes: Shape, Spatially-Varying Lighting and SVBRDF From a Single Image

Inverse Rendering for Complex Indoor Scenes: Shape, Spatially-Varying Lighting and SVBRDF From a Single Image (Project page) Zhengqin Li, Mohammad Sha

209 Jan 05, 2023
Diagnostic tests for linguistic capacities in language models

LM diagnostics This repository contains the diagnostic datasets and experimental code for What BERT is not: Lessons from a new suite of psycholinguist

61 Jan 02, 2023
Cossim - Sharpened Cosine Distance implementation in PyTorch

Sharpened Cosine Distance PyTorch implementation of the Sharpened Cosine Distanc

Istvan Fehervari 10 Mar 22, 2022
[ICCV 2021] Deep Hough Voting for Robust Global Registration

Deep Hough Voting for Robust Global Registration, ICCV, 2021 Project Page | Paper | Video Deep Hough Voting for Robust Global Registration Junha Lee1,

Junha Lee 10 Dec 02, 2022
Differentiable rasterization applied to 3D model simplification tasks

nvdiffmodeling Differentiable rasterization applied to 3D model simplification tasks, as described in the paper: Appearance-Driven Automatic 3D Model

NVIDIA Research Projects 336 Dec 30, 2022
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Ă–zdemir Cetin & Heinz Koeppl | Pr

Christoph Reich 23 Sep 21, 2022
Unofficial PyTorch reimplementation of the paper Swin Transformer V2: Scaling Up Capacity and Resolution

PyTorch reimplementation of the paper Swin Transformer V2: Scaling Up Capacity and Resolution [arXiv 2021].

Christoph Reich 122 Dec 12, 2022
Quantify the difference between two arbitrary curves in space

similaritymeasures Quantify the difference between two arbitrary curves Curves in this case are: discretized by inidviudal data points ordered from a

Charles Jekel 175 Jan 08, 2023
PyTorch implementation of SCAFFOLD (Stochastic Controlled Averaging for Federated Learning, ICML 2020).

Scaffold-Federated-Learning PyTorch implementation of SCAFFOLD (Stochastic Controlled Averaging for Federated Learning, ICML 2020). Environment numpy=

KI 30 Dec 29, 2022
Equivariant layers for RC-complement symmetry in DNA sequence data

Equi-RC Equivariant layers for RC-complement symmetry in DNA sequence data This is a repository that implements the layers as described in "Reverse-Co

7 May 19, 2022
A Benchmark For Measuring Systematic Generalization of Multi-Hierarchical Reasoning

Orchard Dataset This repository contains the code used for generating the Orchard Dataset, as seen in the Multi-Hierarchical Reasoning in Sequences: S

Bill Pung 1 Jun 05, 2022
PyTorch code for Composing Partial Differential Equations with Physics-Aware Neural Networks

FInite volume Neural Network (FINN) This repository contains the PyTorch code for models, training, and testing, and Python code for data generation t

Cognitive Modeling 20 Dec 18, 2022
Anchor Retouching via Model Interaction for Robust Object Detection in Aerial Images

Anchor Retouching via Model Interaction for Robust Object Detection in Aerial Images In this paper, we present an effective Dynamic Enhancement Anchor

13 Dec 09, 2022