Implementation of TabTransformer, attention network for tabular data, in Pytorch

Overview

Tab Transformer

Implementation of Tab Transformer, attention network for tabular data, in Pytorch. This simple architecture came within a hair's breadth of GBDT's performance.

Install

$ pip install tab-transformer-pytorch

Usage

import torch
from tab_transformer_pytorch import TabTransformer

cont_mean_std = torch.randn(10, 2)

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
    mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
    continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
)

x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)

Unsupervised Training

To undergo the type of unsupervised training described in the paper, you can first convert your categories tokens to the appropriate unique ids, and then use Electra on model.transformer.

Citations

@misc{huang2020tabtransformer,
    title={TabTransformer: Tabular Data Modeling Using Contextual Embeddings}, 
    author={Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
    year={2020},
    eprint={2012.06678},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
Comments
  • Minor Bug: actuation function being applied to output layer in class MLP

    Minor Bug: actuation function being applied to output layer in class MLP

    The code for class MLP is mistakingly applying the actuation function to the last (i.e. output) layer. The error is in the evaluation of the is_last flag. The current code is:

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 1)
    

    The last line should be changed to is_last = ind >= (len(dims) - 2):

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 2)
    

    If you like, I can do a pull request.

    opened by rminhas 1
  • Update tab_transformer_pytorch.py

    Update tab_transformer_pytorch.py

    Add activation function out of the loop for the whole model, not after each of the linear layers. 'if is_last' condition was creating linear output all the time no matter what the activation function was.

    opened by EveryoneDirn 0
  • Unindent continuous_mean_std buffer

    Unindent continuous_mean_std buffer

    Problem: continuous_mean_std is not an attribute of TabTransformer if not defined in the argument explicitly. Example reproducing AttributeError:

    model = TabTransformer(
        categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
        num_continuous = 10,                # number of continuous values
        dim = 32,                           # dimension, paper set at 32
        dim_out = 1,                        # binary prediction, but could be anything
        depth = 6,                          # depth, paper recommended 6
        heads = 8,                          # heads, paper recommends 8
        attn_dropout = 0.1,                 # post-attention dropout
        ff_dropout = 0.1,                   # feed forward dropout
        mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
        mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
    # continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm)
    x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 - max number of categories, in the order as passed into the constructor above
    x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually
    pred = model(x_categ, x_cont) # gives AttributeError
    
    

    Solution: Simply un-indenting the buffer registration of continuous_mean_std.

    opened by spliew 0
  • low gpu usage,

    low gpu usage,

    Hi.

    I'm having a problem with running your code with my dataset. It's pretty slow. GPU runs at 50% usage in average and each epoch takes almost 900 seconds to run.

    My dataset has 590540 rows, 24 categorical features, and 192 continuous features. Categories are encoded using Label encoder. Total dataset size is around 600Mb. My gpu is an integrated NVIDIA RTX 3060 with 6Gb of RAM. Optimizer is Adam.

    These are the software versions:

    Windows 10

    Python: 3.7.11 Pytorch: 1.7.0+cu110 Numpy: 1.21.2

    Let me know if you need more info from my side.

    Thanks.

    Xin.

    opened by xinqiao123 0
  • Intended usage of num_special_tokens?

    Intended usage of num_special_tokens?

    From what I understand, these are supposed to be reserved for oov values. Is the intended usage to set oov values in the input to some negative number and overwrite the offset? That is what it seems like it would take to achieve the desired outcome, but also seems somewhat confusing and clunky to do. Or perhaps I am misunderstanding its purpose? Thanks!

    opened by LLYX 2
  • No Category Shared Embedding?

    No Category Shared Embedding?

    I noticed that this implementation does not seem to have the feature of a shared embedding between each value belonging to the same category (unless I missed it) that the paper mentions (c_phi_i). If it's indeed missing, do you have plans to add that?

    Thanks for this implementation!

    opened by LLYX 3
  • index -1 is out of bounds for dimension 1 with size 17

    index -1 is out of bounds for dimension 1 with size 17

    I encountered this problem during the training process. What is the possible reason for this problem, and how can I solve this problem? Thanks!

      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 583, in forward
        return self.tabnet(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 468, in forward
        steps_output, M_loss = self.encoder(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 160, in forward
        M = self.att_transformers[step](prior, att)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 637, in forward
        x = self.selector(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 109, in forward
        return sparsemax(input, self.dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 52, in forward
        tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 94, in _threshold_and_support
        tau = input_cumsum.gather(dim, support_size - 1)
    RuntimeError: index -1 is out of bounds for dimension 1 with size 17
    Experiment has terminated.
    
    opened by hengzhe-zhang 2
  • Is there any training example about tabtransformer?

    Is there any training example about tabtransformer?

    Hi, I want to use it in a tabular dataset to finish a supervised learning,But I dont really know how to train this model with dataset(it seems that there is no such content in the readme file ). Could you please help me? thank you.

    opened by pancodex 0
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
CTRL-C: Camera calibration TRansformer with Line-Classification

CTRL-C: Camera calibration TRansformer with Line-Classification This repository contains the official code and pretrained models for CTRL-C (Camera ca

57 Nov 14, 2022
Complete* list of autonomous driving related datasets

AD Datasets Complete* and curated list of autonomous driving related datasets Contributing Contributions are very welcome! To add or update a dataset:

Daniel Bogdoll 13 Dec 19, 2022
Architecture Patterns with Python (TDD, DDD, EDM)

architecture-traning Architecture Patterns with Python (TDD, DDD, EDM) Chapter 5. 높은 기어비와 낮은 기어비의 TDD 5.2 도메인 계층 테스트를 서비스 계층으로 옮겨야 하는가? 도메인 계층 테스트 def

minsung sim 2 Mar 04, 2022
DABO: Data Augmentation with Bilevel Optimization

DABO: Data Augmentation with Bilevel Optimization [Paper] The goal is to automatically learn an efficient data augmentation regime for image classific

ElementAI 24 Aug 12, 2022
Multi-Task Deep Neural Networks for Natural Language Understanding

New Release We released Adversarial training for both LM pre-training/finetuning and f-divergence. Large-scale Adversarial training for LMs: ALUM code

Xiaodong 2.1k Dec 30, 2022
NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling

NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling For Official repo of NU-Wave: A Diffusion Probabilistic Model for Neural Audio Up

Rishikesh (ऋषिकेश) 38 Oct 11, 2022
Image Segmentation using U-Net, U-Net with skip connections and M-Net architectures

Brain-Image-Segmentation Segmentation of brain tissues in MRI image has a number of applications in diagnosis, surgical planning, and treatment of bra

Angad Bajwa 8 Oct 27, 2022
A PyTorch implementation of the baseline method in Panoptic Narrative Grounding (ICCV 2021 Oral)

A PyTorch implementation of the baseline method in Panoptic Narrative Grounding (ICCV 2021 Oral)

Biomedical Computer Vision @ Uniandes 52 Dec 19, 2022
PyTorch implementation of Rethinking Positional Encoding in Language Pre-training

TUPE PyTorch implementation of Rethinking Positional Encoding in Language Pre-training. Quickstart Clone this repository. git clone https://github.com

Jake Tae 5 Jan 27, 2022
Deep Reinforcement Learning for Keras.

Deep Reinforcement Learning for Keras What is it? keras-rl implements some state-of-the art deep reinforcement learning algorithms in Python and seaml

Keras-RL 0 Dec 15, 2022
GAN-based 3D human pose estimation model for 3DV'17 paper

Tensorflow implementation for 3DV 2017 conference paper "Adversarially Parameterized Optimization for 3D Human Pose Estimation". @inproceedings{jack20

Dominic Jack 15 Feb 27, 2021
Implementation of 'X-Linear Attention Networks for Image Captioning' [CVPR 2020]

Introduction This repository is for X-Linear Attention Networks for Image Captioning (CVPR 2020). The original paper can be found here. Please cite wi

JDAI-CV 240 Dec 17, 2022
sssegmentation is a general framework for our research on strongly supervised semantic segmentation.

sssegmentation is a general framework for our research on strongly supervised semantic segmentation.

445 Jan 02, 2023
Plenoxels: Radiance Fields without Neural Networks

Plenoxels: Radiance Fields without Neural Networks Alex Yu*, Sara Fridovich-Keil*, Matthew Tancik, Qinhong Chen, Benjamin Recht, Angjoo Kanazawa UC Be

Sara Fridovich-Keil 81 Dec 25, 2022
CountDown to New Year and shoot fireworks

CountDown and Shoot Fireworks About App This is an small application make you re

5 Dec 31, 2022
A Benchmark For Measuring Systematic Generalization of Multi-Hierarchical Reasoning

Orchard Dataset This repository contains the code used for generating the Orchard Dataset, as seen in the Multi-Hierarchical Reasoning in Sequences: S

Bill Pung 1 Jun 05, 2022
Compares various time-series feature sets on computational performance, within-set structure, and between-set relationships.

feature-set-comp Compares various time-series feature sets on computational performance, within-set structure, and between-set relationships. Reposito

Trent Henderson 7 May 25, 2022
Instance Semantic Segmentation List

Instance Semantic Segmentation List This repository contains lists of state-or-art instance semantic segmentation works. Papers and resources are list

bighead 87 Mar 06, 2022
RM Operation can equivalently convert ResNet to VGG, which is better for pruning; and can help RepVGG perform better when the depth is large.

RMNet: Equivalently Removing Residual Connection from Networks This repository is the official implementation of "RMNet: Equivalently Removing Residua

184 Jan 04, 2023
DGN pymarl - Implementation of DGN on Pymarl, which could be trained by VDN or QMIX

This is the implementation of DGN on Pymarl, which could be trained by VDN or QM

4 Nov 23, 2022