DANet for tabular data classification/regression.

Deep Abstract Networks

A PyTorch implementation of the AAAI-2022 paper DANets: Deep Abstract Networks for Tabular Data Classification and Regression.

Brief Introduction

Tabular data are ubiquitous in real-world applications. Although many commonly used neural components (e.g., convolution) and extensible neural networks (e.g., ResNet) have been developed by the machine learning community, few of them are effective for tabular data and few designs are adequately tailored to tabular data structures. In this paper, we propose a novel and flexible neural component for tabular data, called Abstract Layer (AbstLay), which learns to explicitly group correlative input features and generate higher-level features for semantics abstraction. We also design a structure re-parameterization method to compress AbstLay, reducing the computational complexity by a clear margin in the inference phase. A special basic block is built using AbstLays, and we construct a family of Deep Abstract Networks (DANets) for tabular data classification and regression by stacking such blocks. In DANets, a special shortcut path is introduced to fetch information from raw tabular features, assisting feature interactions across different levels. Comprehensive experiments on real-world tabular datasets show that our AbstLay and DANets are effective for tabular data classification and regression, and that their computational complexity is lower than that of competitive methods.
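
To make the re-parameterization idea concrete: when a feature mask, a linear map and batch normalization (with fixed, post-training statistics) are applied in sequence, they compose into a single affine map, so they can be fused once after training. The sketch below shows this general folding in plain Python. It is not the paper's exact procedure; the function names, the toy numbers and the mask, linear, batch-norm ordering are assumptions made for illustration.

```python
import math


def unfused_forward(x, mask, weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Apply mask -> linear -> batch norm (fixed statistics) step by step."""
    masked = [m * xi for m, xi in zip(mask, x)]
    out = []
    for j, row in enumerate(weight):
        z = sum(w * v for w, v in zip(row, masked)) + bias[j]
        out.append(gamma[j] * (z - mean[j]) / math.sqrt(var[j] + eps) + beta[j])
    return out


def fuse(mask, weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold the mask and batch norm into one weight matrix and one bias vector."""
    fused_weight, fused_bias = [], []
    for j, row in enumerate(weight):
        scale = gamma[j] / math.sqrt(var[j] + eps)
        fused_weight.append([scale * w * m for w, m in zip(row, mask)])
        fused_bias.append(scale * (bias[j] - mean[j]) + beta[j])
    return fused_weight, fused_bias


def fused_forward(x, fused_weight, fused_bias):
    """Single affine map that reproduces `unfused_forward`."""
    return [
        sum(w * xi for w, xi in zip(row, x)) + b
        for row, b in zip(fused_weight, fused_bias)
    ]


# Toy parameters (made up for the example): 3 input features, 2 outputs.
params = dict(
    mask=[1.0, 0.0, 0.5],
    weight=[[0.2, -0.4, 0.1], [0.7, 0.3, -0.5]],
    bias=[0.05, -0.1],
    gamma=[1.5, 0.8],
    beta=[0.0, 0.2],
    mean=[0.1, -0.2],
    var=[0.9, 1.1],
)
x = [1.0, 2.0, 3.0]
reference = unfused_forward(x, **params)
fused_w, fused_b = fuse(**params)
compressed = fused_forward(x, fused_w, fused_b)
```

After fusing, only one matrix-vector product per layer remains, which is where the reduction in computation comes from in this simplified picture.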

[Figure: DANets illustration]

Downloads

Dataset

Download the datasets from the following links:

(Optional) Before starting the program, you may change the file format to .pkl by using svm2pkl() or csv2pkl() functions in ./data/data_util.py.
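
The signatures of svm2pkl() and csv2pkl() are not shown here, so the following is only a rough stand-in for what such a conversion might look like, using the standard library alone. The function name csv_to_pkl_sketch, the header-row assumption and the choice to pickle a dictionary of column names and float rows are all hypothetical and not taken from ./data/data_util.py.

```python
import csv
import pickle


def csv_to_pkl_sketch(csv_path, pkl_path):
    """Hypothetical helper: read a numeric CSV with a header row and pickle it.

    This is NOT the repository's csv2pkl(); it only illustrates the idea.
    Returns the number of data rows written.
    """
    with open(csv_path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)  # assumption: the first row holds column names
        rows = [[float(value) for value in row] for row in reader if row]

    with open(pkl_path, "wb") as f:
        pickle.dump({"columns": header, "rows": rows}, f)
    return len(rows)
```

Loading the result back is a single pickle.load() call on the output file, which avoids re-parsing the text file on every run.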

Weights for inference models

Demo weights for the Forest Cover Type dataset are available in the "./Weights/" folder.

How to use

Setting

  1. Clone or download this repository, and cd into it.
  2. Set up a working Python environment. Python 3.7 is fine for this repository.
  3. Install packages following the requirements.txt, e.g., by using pip install -r requirements.txt.

Training

  1. Set the hyperparameters in the config files (./config/default.py or ./config/*.yaml).
    Note that the hyperparameters in the .yaml file override those in default.py.

  2. Run python main.py -c [config_path] -g [gpu_id].

    • -c: The config file path
    • -g: GPU device ID
  3. The checkpoint models and best models will be saved in the ./logs directory.
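
The override rule in step 1 can be pictured as a recursive dictionary merge. The sketch below is an assumption about the behaviour, not the repository's config loader: merge_config is a hypothetical name, the nested fit/model layout is assumed, and plain dicts stand in for default.py and the parsed .yaml file.

```python
import copy


def merge_config(defaults, overrides):
    """Return a copy of `defaults` with values from `overrides` taking precedence."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Merge nested sections key by key instead of replacing them whole.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Stand-in for ./config/default.py (assumed layout, illustrative values).
DEFAULTS = {
    "fit": {"lr": 0.008, "max_epochs": 5000, "patience": 1500},
    "model": {"layer": 20, "k": 5},
}

# Stand-in for a parsed ./config/*.yaml file (hypothetical values).
yaml_values = {"fit": {"lr": 0.001}, "model": {"layer": 32}}

config = merge_config(DEFAULTS, yaml_values)
```

Keys present only in the defaults (here max_epochs, patience and k) survive the merge, so a .yaml file only needs to list the values it changes.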

Inference

  1. Set resume_dir to the path containing your trained model/weights.
  2. Run python predict.py -d [dataset_name] -m [model_file_path] -g [gpu_id].
    • -d: Dataset name
    • -m: Model path for loading
    • -g: GPU device ID
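
For orientation, the three flags above could be declared with argparse roughly as follows. This is a sketch, not the contents of predict.py: the long option names, the defaults, the dataset spelling and the weight file name are all assumptions.

```python
import argparse


def build_predict_parser():
    """Hypothetical parser mirroring the documented -d, -m and -g flags."""
    parser = argparse.ArgumentParser(description="Sketch of a prediction CLI")
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        help="Dataset name")
    parser.add_argument("-m", "--model_file", type=str, required=True,
                        help="Model path for loading")
    parser.add_argument("-g", "--gpu_id", type=str, default="0",
                        help="GPU device ID")
    return parser


# Placeholder values: the dataset spelling and the weight file name are made up.
args = build_predict_parser().parse_args(
    ["-d", "forest_cover_type", "-m", "./Weights/model.pth", "-g", "0"]
)
```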

Config Hyperparameters

Normal parameters

  • dataset: str
    The dataset name. It must match one of the names defined in ./data/dataset.py.

  • task: str
    The task type: either 'classification' or 'regression'.

  • resume_dir: str
    The log path containing the checkpoint models.

  • logname: str
    The name of the directory under ./logs where the models are saved.

  • seed: int
    The random seed.

Model parameters

  • layer: int (default=20)
    Number of abstract layers to stack

  • k: int (default=5)
    Number of masks

  • base_outdim: int (default=64)
    The output feature dimension in abstract layer.

  • drop_rate: float (default=0.1)
    Dropout rate in shortcut module

Fit parameters

  • lr: float (default=0.008)
    Learning rate

  • max_epochs: int (default=5000)
    Maximum number of epochs in training.

  • patience: int (default=1500)
    Number of consecutive epochs without improvement before performing early stopping. If patience is set to 0, then no early stopping will be performed.

  • batch_size: int (default=8192)
    Number of examples per batch.

  • virtual_batch_size: int (default=256)
    Size of the mini batches used for "Ghost Batch Normalization". virtual_batch_size must divide batch_size.
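
Two of the rules above, the patience behaviour and the divisibility requirement on virtual_batch_size, can be sketched in a few lines. The function names are hypothetical and the code is not taken from the repository; in particular, the real Ghost Batch Normalization would normalize each chunk with its own statistics, whereas this sketch only performs the split.

```python
def should_stop(epochs_without_improvement, patience):
    """Early-stopping rule as documented: patience == 0 disables it."""
    if patience == 0:
        return False
    return epochs_without_improvement >= patience


def split_into_virtual_batches(batch, virtual_batch_size):
    """Split one batch into equally sized chunks for Ghost Batch Normalization."""
    if len(batch) % virtual_batch_size != 0:
        raise ValueError("virtual_batch_size must divide batch_size")
    return [
        batch[start:start + virtual_batch_size]
        for start in range(0, len(batch), virtual_batch_size)
    ]


# With the documented defaults, one batch of 8192 examples yields 32 chunks of 256.
chunks = split_into_virtual_batches(list(range(8192)), 256)
print(len(chunks), len(chunks[0]))  # 32 256
```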

Citations

@inproceedings{danets, 
   title={DANets: Deep Abstract Networks for Tabular Data Classification and Regression}, 
   author={Chen, Jintai and Liao, Kuanlun and Wan, Yao and Chen, Danny Z and Wu, Jian}, 
   booktitle={AAAI}, 
   year={2022}
 }
Owner
Ronnie Rocket