Implementation of TabTransformer, an attention network for tabular data, in PyTorch

Overview

Implementation of Tab Transformer, an attention network for tabular data, in PyTorch. This simple architecture came within a hair's breadth of GBDT's performance.

Install

$ pip install tab-transformer-pytorch

Usage

import torch
import torch.nn as nn
from tab_transformer_pytorch import TabTransformer

cont_mean_std = torch.randn(10, 2)

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
    mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
    continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
)

x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 to the number of unique values per column, in the order passed to the constructor above
x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)
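
The model returns raw logits, so training is an ordinary PyTorch loop. A minimal supervised step might look like the following (a sketch: the targets, optimizer, and learning rate are placeholders, not part of the library):

import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr = 3e-4)
loss_fn = nn.BCEWithLogitsLoss()               # matches the binary dim_out = 1 above

labels = torch.randint(0, 2, (1, 1)).float()   # hypothetical targets for the batch above

optimizer.zero_grad()
loss = loss_fn(model(x_categ, x_cont), labels)
loss.backward()
optimizer.step()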

Unsupervised Training

To perform the unsupervised pre-training described in the paper, first convert your category tokens to the appropriate unique IDs, and then use Electra on model.transformer.
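
The ID conversion can be sketched as below; the per-column offsets here are computed by hand and are an assumption about how unique IDs should be laid out (the library may arrange its internal vocabulary differently, e.g. to reserve special tokens):

import torch
import torch.nn.functional as F

categories = (10, 5, 6, 5, 8)

# shift each column's values so every (column, value) pair gets a unique id
offsets = F.pad(torch.tensor(categories).cumsum(0), (1, 0))[:-1]
x_categ = torch.randint(0, 5, (1, 5))
unique_ids = x_categ + offsets                 # ids suitable for a token-level objective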

Citations

@misc{huang2020tabtransformer,
    title={TabTransformer: Tabular Data Modeling Using Contextual Embeddings}, 
    author={Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
    year={2020},
    eprint={2012.06678},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
Comments
  • Minor bug: activation function applied to the output layer in class MLP

    The code for class MLP mistakenly applies the activation function to the last (i.e. output) layer. The error is in the evaluation of the is_last flag. The current code is:

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 1)
    

    The last line should be changed to is_last = ind >= (len(dims) - 2):

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 2)
    

    If you like, I can do a pull request.
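
    For reference, a self-contained corrected version might look like this (a sketch with the fix applied, not the repo's exact code):

    import torch.nn as nn

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                # len(dims) - 2 indexes the final (dim_in, dim_out) pair, so the
                # activation is now skipped only after the output layer
                is_last = ind >= (len(dims) - 2)
                layers.append(nn.Linear(dim_in, dim_out))
                if not is_last:
                    layers.append(act if act is not None else nn.ReLU())
            self.mlp = nn.Sequential(*layers)

        def forward(self, x):
            return self.mlp(x)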

    opened by rminhas 1
  • Update tab_transformer_pytorch.py

    Adds the activation function outside the loop, for the whole model, rather than after each linear layer. The 'if is_last' condition was producing a linear output every time, no matter which activation function was passed.

    opened by EveryoneDirn 0
  • Unindent continuous_mean_std buffer

    Problem: continuous_mean_std is not an attribute of TabTransformer unless it is explicitly passed as an argument. Example reproducing the AttributeError:

    model = TabTransformer(
        categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
        num_continuous = 10,                # number of continuous values
        dim = 32,                           # dimension, paper set at 32
        dim_out = 1,                        # binary prediction, but could be anything
        depth = 6,                          # depth, paper recommended 6
        heads = 8,                          # heads, paper recommends 8
        attn_dropout = 0.1,                 # post-attention dropout
        ff_dropout = 0.1,                   # feed forward dropout
        mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
        mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
        # continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
    )
    x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 - max number of categories, in the order as passed into the constructor above
    x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually
    pred = model(x_categ, x_cont) # gives AttributeError
    
    

    Solution: Simply un-indenting the buffer registration of continuous_mean_std.
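
    The fix can be sketched as follows (the surrounding names and checks are assumptions about the constructor, not the exact repo code):

    # inside TabTransformer.__init__ (sketch)
    if continuous_mean_std is not None:
        assert continuous_mean_std.shape == (num_continuous, 2), \
            'continuous_mean_std must have shape (num_continuous, 2)'
    # registered unconditionally (possibly as None), so the attribute always
    # exists when forward() looks it up
    self.register_buffer('continuous_mean_std', continuous_mean_std)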

    opened by spliew 0
  • Low GPU usage

    Hi,

    I'm having a problem running your code with my dataset. It's pretty slow: the GPU runs at 50% usage on average, and each epoch takes almost 900 seconds.

    My dataset has 590,540 rows, 24 categorical features, and 192 continuous features. Categories are encoded with a label encoder. The total dataset size is around 600 MB. My GPU is an NVIDIA RTX 3060 with 6 GB of RAM. The optimizer is Adam.

    These are the software versions:

    Windows 10
    Python: 3.7.11
    PyTorch: 1.7.0+cu110
    NumPy: 1.21.2

    Let me know if you need more info from my side.

    Thanks.

    Xin.
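
    Common first steps for this kind of underutilization are larger batches, parallel data loading, and mixed precision, sketched below; the dataset, model, optimizer, and loss_fn names are placeholders (an editor's sketch, not specific to this repo):

    from torch.utils.data import DataLoader
    from torch.cuda.amp import GradScaler, autocast

    loader = DataLoader(dataset, batch_size = 1024, shuffle = True,
                        num_workers = 4, pin_memory = True)
    scaler = GradScaler()

    for x_categ, x_cont, y in loader:
        x_categ, x_cont, y = x_categ.cuda(), x_cont.cuda(), y.cuda()
        optimizer.zero_grad()
        with autocast():                     # fp16 matmuls keep the GPU busier
            loss = loss_fn(model(x_categ, x_cont), y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()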

    opened by xinqiao123 0
  • Intended usage of num_special_tokens?

    From what I understand, these are supposed to be reserved for out-of-vocabulary (OOV) values. Is the intended usage to set OOV values in the input to some negative number and overwrite the offset? That seems to be what it would take to achieve the desired outcome, but it also feels confusing and clunky. Or perhaps I'm misunderstanding its purpose? Thanks!

    opened by LLYX 2
  • No Category Shared Embedding?

    I noticed that this implementation does not seem to include the shared embedding between values belonging to the same column (c_phi_i in the paper), unless I missed it. If it is indeed missing, do you have plans to add it?

    Thanks for this implementation!
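
    The paper's shared column embedding can be sketched by devoting part of each token's dimension to a per-column learned vector (an illustration of the idea, not this repo's code; all names and sizes here are hypothetical):

    import torch
    import torch.nn as nn

    num_columns, dim, shared_dim = 5, 32, 8
    value_embed = nn.Embedding(100, dim - shared_dim)     # per-value portion
    shared_embed = nn.Parameter(torch.randn(num_columns, shared_dim))

    x_ids = torch.randint(0, 100, (1, num_columns))       # already-offset ids
    tokens = torch.cat((
        value_embed(x_ids),                               # (1, columns, dim - shared_dim)
        shared_embed.unsqueeze(0).expand(1, -1, -1),      # (1, columns, shared_dim)
    ), dim = -1)                                          # (1, columns, dim)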

    opened by LLYX 3
  • index -1 is out of bounds for dimension 1 with size 17

    I encountered this error during training. What could be causing it, and how can I solve it? Thanks!

      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 583, in forward
        return self.tabnet(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 468, in forward
        steps_output, M_loss = self.encoder(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 160, in forward
        M = self.att_transformers[step](prior, att)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 637, in forward
        x = self.selector(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 109, in forward
        return sparsemax(input, self.dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 52, in forward
        tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 94, in _threshold_and_support
        tau = input_cumsum.gather(dim, support_size - 1)
    RuntimeError: index -1 is out of bounds for dimension 1 with size 17
    Experiment has terminated.
    
    opened by hengzhe-zhang 2
  • Is there any training example about tabtransformer?

    Hi, I want to use this on a tabular dataset for supervised learning, but I don't really know how to train the model on a dataset (there doesn't seem to be anything about this in the README). Could you please help me? Thank you.

    opened by pancodex 0
Owner
Phil Wang
Working with Attention. It's all we need.