A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron.

Overview

The GatedTabTransformer.

A deep learning tabular classification architecture inspired by TabTransformer with integrated gated multilayer perceptron. Check out our paper on arXiv.

Architecture

Usage

import torch
import torch.nn as nn
from gated_tab_transformer import GatedTabTransformer

model = GatedTabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    transformer_dim = 32,               # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    transformer_depth = 6,              # depth, paper recommended 6
    transformer_heads = 8,              # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_act = nn.LeakyReLU(0),          # activation for final mlp, defaults to relu, but could be anything else (selu, etc.)
    mlp_depth=4,                        # mlp hidden layers depth
    mlp_dimension=32,                   # dimension of mlp layers
    gmlp_enabled=True                   # gmlp or standard mlp
)

x_categ = torch.randint(0, 5, (1, 5))   # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10)             # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)
print(pred)

Citation

@misc{cholakov2022gatedtabtransformer,
      title={The GatedTabTransformer. An enhanced deep learning architecture for tabular modeling}, 
      author={Radostin Cholakov and Todor Kolev},
      year={2022},
      eprint={2201.00199},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@software{GatedTabTransformer,
  author = {{Radostin Cholakov, Todor Kolev}},
  title = {The GatedTabTransformer.},
  url = {https://github.com/radi-cho/GatedTabTransformer},
  version = {0.0.1},
  date = {2021-12-15},
}
Owner
Radi Cho
10th grader, coding for more than 6 years now ;) Looking forward the next big thing!
Radi Cho
WarpDrive: Extremely Fast End-to-End Deep Multi-Agent Reinforcement Learning on a GPU

WarpDrive is a flexible, lightweight, and easy-to-use open-source reinforcement learning (RL) framework that implements end-to-end multi-agent RL on a single GPU (Graphics Processing Unit).

Salesforce 334 Jan 06, 2023
EMNLP 2021: Single-dataset Experts for Multi-dataset Question-Answering

MADE (Multi-Adapter Dataset Experts) This repository contains the implementation of MADE (Multi-adapter dataset experts), which is described in the pa

Princeton Natural Language Processing 68 Jul 18, 2022
Implementation of Perceiver, General Perception with Iterative Attention in TensorFlow

Perceiver This Python package implements Perceiver: General Perception with Iterative Attention by Andrew Jaegle in TensorFlow. This model builds on t

Rishit Dagli 84 Oct 15, 2022
Stock-Prediction - prediction of stock market movements using sentiment analysis and deep learning.

Stock-Prediction- In this project, we aim to enhance the prediction of stock market movements using sentiment analysis and deep learning. We divide th

5 Jan 25, 2022
House_prices_kaggle - Predict sales prices and practice feature engineering, RFs, and gradient boosting

House Prices - Advanced Regression Techniques Predicting House Prices with Machine Learning This project is build to enhance my knowledge about machin

Gurpreet Singh 1 Jan 01, 2022
Code for paper [ACE: Ally Complementary Experts for Solving Long-Tailed Recognition in One-Shot] (ICCV 2021, oral))

ACE: Ally Complementary Experts for Solving Long-Tailed Recognition in One-Shot This repository is the official PyTorch implementation of ICCV-21 pape

Jiarui 21 May 09, 2022
Generic Foreground Segmentation in Images

Pixel Objectness The following repository contains pretrained model for pixel objectness. Please visit our project page for the paper and visual resul

Suyog Jain 157 Nov 21, 2022
Transfer Learning library for Deep Neural Networks.

Transfer and meta-learning in Python Each folder in this repository corresponds to a method or tool for transfer/meta-learning. xfer-ml is a standalon

Amazon 245 Dec 08, 2022
The official homepage of the COCO-Stuff dataset.

The COCO-Stuff dataset Holger Caesar, Jasper Uijlings, Vittorio Ferrari Welcome to official homepage of the COCO-Stuff [1] dataset. COCO-Stuff augment

Holger Caesar 715 Dec 31, 2022
Subpopulation detection in high-dimensional single-cell data

PhenoGraph for Python3 PhenoGraph is a clustering method designed for high-dimensional single-cell data. It works by creating a graph ("network") repr

Dana Pe'er Lab 42 Sep 05, 2022
A library for hidden semi-Markov models with explicit durations

hsmmlearn hsmmlearn is a library for unsupervised learning of hidden semi-Markov models with explicit durations. It is a port of the hsmm package for

Joris Vankerschaver 69 Dec 20, 2022
Model parallel transformers in Jax and Haiku

Mesh Transformer Jax A haiku library using the new(ly documented) xmap operator in Jax for model parallelism of transformers. See enwik8_example.py fo

Ben Wang 4.8k Jan 01, 2023
Dynamic View Synthesis from Dynamic Monocular Video

Dynamic View Synthesis from Dynamic Monocular Video Project Website | Video | Paper Dynamic View Synthesis from Dynamic Monocular Video Chen Gao, Ayus

Chen Gao 139 Dec 28, 2022
Weight initialization schemes for PyTorch nn.Modules

nninit Weight initialization schemes for PyTorch nn.Modules. This is a port of the popular nninit for Torch7 by @kaixhin. ##Update This repo has been

Alykhan Tejani 69 Jan 26, 2021
PaRT: Parallel Learning for Robust and Transparent AI

PaRT: Parallel Learning for Robust and Transparent AI This repository contains the code for PaRT, an algorithm for training a base network on multiple

Mahsa 0 May 02, 2022
Implementation of Memory-Compressed Attention, from the paper "Generating Wikipedia By Summarizing Long Sequences"

Memory Compressed Attention Implementation of the Self-Attention layer of the proposed Memory-Compressed Attention, in Pytorch. This repository offers

Phil Wang 47 Dec 23, 2022
FlingBot: The Unreasonable Effectiveness of Dynamic Manipulations for Cloth Unfolding

This repository contains code for training and evaluating FlingBot in both simulation and real-world settings on a dual-UR5 robot arm setup for Ubuntu 18.04

Columbia Artificial Intelligence and Robotics Lab 70 Dec 06, 2022
Adjust Decision Boundary for Class Imbalanced Learning

Adjusting Decision Boundary for Class Imbalanced Learning This repository is the official PyTorch implementation of WVN-RS, introduced in Adjusting De

Peyton Byungju Kim 16 Jan 04, 2023
On the Analysis of French Phonetic Idiosyncrasies for Accent Recognition

On the Analysis of French Phonetic Idiosyncrasies for Accent Recognition With the spirit of reproducible research, this repository contains codes requ

0 Feb 24, 2022
Transformers4Rec is a flexible and efficient library for sequential and session-based recommendation, available for both PyTorch and Tensorflow.

Transformers4Rec is a flexible and efficient library for sequential and session-based recommendation, available for both PyTorch and Tensorflow.

730 Jan 09, 2023