Primary language: Python. License: MIT.

Interaction Network Contextual Embedding (INCE)

A simple implementation of INCE, the algorithm described in "Graph Neural Network Contextual Embedding for Deep Learning on Tabular Data".

Model Description

INCE is a Deep Learning (DL) model for tabular data that employs Graph Neural Networks (GNNs) and, more specifically, Interaction Networks for contextual embedding.

First, an encoder model maps each feature of the tabular dataset into a latent vector (embedding). A decoder model then takes the embeddings and uses them to solve the supervised learning task. The encoder is composed of two components: the columnar embedding and the contextual embedding. The decoder is a Multi-Layer Perceptron (MLP) tuned to the learning task at hand.

[Figure: Encoder-decoder architecture]

COLUMNAR EMBEDDING: All features (categorical and continuous) are individually projected into a common dense latent space.

[Figure: Columnar embedding]
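A minimal NumPy sketch of the columnar embedding idea, under assumptions not stated in this README: each continuous feature is embedded by its own linear map (the scalar value scales a per-feature weight vector, plus a bias vector), and each categorical feature by its own lookup table. The names `init_columnar_params` and `columnar_embedding`, the dimension `d`, and the random initialization are illustrative only; the repository's PyTorch code may be organized differently.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_columnar_params(n_cont, cat_cardinalities, d):
    """Randomly initialized parameters (illustrative; a real model learns them)."""
    return {
        "w": rng.normal(size=(n_cont, d)),  # one weight vector per continuous feature
        "b": np.zeros((n_cont, d)),         # one bias vector per continuous feature
        "tables": [rng.normal(size=(c, d)) for c in cat_cardinalities],
    }

def columnar_embedding(x_cont, x_cat, params):
    """Map a batch to per-feature embeddings of shape (batch, n_features, d).

    x_cont: float array (batch, n_cont); x_cat: int array (batch, n_cat).
    """
    # Continuous: scale each feature's weight vector by its value, then add the bias.
    cont = x_cont[:, :, None] * params["w"][None] + params["b"][None]
    # Categorical: look up one row per category index in that feature's table.
    cat = [table[x_cat[:, j]] for j, table in enumerate(params["tables"])]
    if not cat:
        return cont
    return np.concatenate([cont, np.stack(cat, axis=1)], axis=1)

params = init_columnar_params(n_cont=3, cat_cardinalities=[4, 2], d=8)
x_cont = rng.normal(size=(5, 3))
x_cat = np.stack([rng.integers(0, 4, size=5), rng.integers(0, 2, size=5)], axis=1)
emb = columnar_embedding(x_cont, x_cat, params)
print(emb.shape)  # (5, 5, 8): 5 samples, 3 + 2 features, 8-dimensional embeddings
```

Every feature ends up as a vector of the same size `d`, which is what allows continuous and categorical features to be treated uniformly as graph nodes in the next stage.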

CONTEXTUAL EMBEDDING: The features obtained from the columnar embedding are organized in a fully-connected graph with an extra virtual node, called CLS as in BERT. A stack of Interaction Networks then models the relationships among all the nodes (the original features and the CLS virtual node) and enhances their representations. The resulting CLS virtual node is sent to the final classifier/regressor.

[Figure: Contextual embedding]

Schematic workflow of an Interaction Network:

[Figure: Interaction Network GNN]
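The contextual embedding can be sketched in NumPy as follows, for a single sample. This is a simplified illustration, not the repository's implementation: it assumes no edge features, a two-layer ReLU MLP for both the edge and node models, sum aggregation, a residual node update, and a zero vector standing in for the learnable CLS embedding. All function names here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_mlp(d_in, d_hidden, d_out):
    """Two-layer MLP parameters (random here; learned in a real model)."""
    return (rng.normal(scale=0.1, size=(d_in, d_hidden)), np.zeros(d_hidden),
            rng.normal(scale=0.1, size=(d_hidden, d_out)), np.zeros(d_out))

def mlp(x, params):
    w1, b1, w2, b2 = params
    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2  # ReLU hidden layer

def fully_connected_edges(n_nodes):
    """Directed edges between every ordered pair of distinct nodes."""
    senders, receivers = np.nonzero(~np.eye(n_nodes, dtype=bool))
    return senders, receivers

def interaction_step(h, senders, receivers, edge_mlp, node_mlp):
    """One message-passing step on node states h of shape (n_nodes, d)."""
    # Edge model: one message per (sender, receiver) pair.
    messages = mlp(np.concatenate([h[senders], h[receivers]], axis=1), edge_mlp)
    # Aggregation: sum the incoming messages of each receiver.
    incoming = np.zeros_like(h)
    np.add.at(incoming, receivers, messages)
    # Node model with a residual connection (an assumption of this sketch).
    return h + mlp(np.concatenate([h, incoming], axis=1), node_mlp)

def contextual_embedding(feature_emb, cls_emb, layers):
    """Return the CLS state after a stack of interaction steps."""
    h = np.vstack([feature_emb, cls_emb[None, :]])  # CLS is the last node
    senders, receivers = fully_connected_edges(h.shape[0])
    for edge_mlp, node_mlp in layers:
        h = interaction_step(h, senders, receivers, edge_mlp, node_mlp)
    return h[-1]

d = 8
layers = [(init_mlp(2 * d, 16, d), init_mlp(2 * d, 16, d)) for _ in range(2)]
feature_emb = rng.normal(size=(5, d))  # e.g. the columnar embeddings of 5 features
cls_emb = np.zeros(d)                  # placeholder for a learnable CLS vector
cls_out = contextual_embedding(feature_emb, cls_emb, layers)
print(cls_out.shape)  # (8,)
```

Because the graph is fully connected, the CLS node receives a message from every feature at each step, so its final state can serve as a summary of the whole row for the MLP decoder.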

Main Results

INCE has been tested on the benchmark described in the table below:

| Dataset | Rows | Num. Feats | Cat. Feats | Task |
| --- | --- | --- | --- | --- |
| HELOC | 9871 | 21 | 2 | Binary |
| California Housing | 20640 | 8 | 0 | Regression |
| Adult Incoming | 32561 | 6 | 8 | Binary |
| Forest Cover Type | 581 K | 10 | 2 (4 + 40) | Multi-Class (7) |
| HIGGS | 11 M | 27 | 1 | Binary |

and compared with the following baselines:

Standard methods: Linear Model, KNN, Decision Tree, Random Forest, XGBoost, LightGBM, CatBoost.

Deep learning models: MLP, DeepFM, DeepGBM, RLN, TabNet, VIME, TabTransformer, NODE, Net-DNF, SAINT, FT-Transformer.

(See the paper for details and references)

The main results are summarized in the table and plot below:

| Dataset | Metric | Best tree result | Best tree model | Best DL result | Best DL model | INCE result | INCE rank |
| --- | --- | --- | --- | --- | --- | --- | --- |
| HELOC | Accuracy ↑ | 83.6 % | CatBoost | 82.6 % | Net-DNF | 84.2 ± 0.5 % | 🥇 Abs. |
| California Housing | MSE ↓ | 0.195 | LightGBM | 0.226 | SAINT | 0.216 ± 0.007 | 🥇 DL |
| Adult Incoming | Accuracy ↑ | 87.4 % | LightGBM | 86.1 % | DeepFM, SAINT | 86.8 ± 0.3 % | 🥇 DL |
| Forest Cover Type | Accuracy ↑ | 97.3 % | XGBoost | 96.3 % | SAINT | 97.1 ± 0.1 % | 🥇 DL, 🥈 Abs. |
| HIGGS | Accuracy ↑ | 77.6 % | XGBoost | 79.8 % | SAINT | 79.1 ± 0.0 % | 🥈 DL, 🥈 Abs. |

[Figure: Boxplot of results]

How to use the code

Requirements:

numpy==1.23.5
pandas==1.5.2
scikit-learn==1.1.3
torch==1.13.0+cu117
torch-cluster==1.6.0+pt113cu117
torch-geometric==2.2.0
torch-scatter==2.1.0+pt113cu117
torch-sparse==0.6.15+pt113cu117
torch-spline-conv==1.2.1+pt113cu117
tqdm==4.64.1
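Before running the code, it can help to compare the installed package versions with the pins above. The following is a small sketch using only the standard library; the helper `check_requirements` is hypothetical and not part of this repository, and it assumes each package is installed under the distribution name used in the list.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the requirements list above.
PINNED = {
    "numpy": "1.23.5",
    "pandas": "1.5.2",
    "scikit-learn": "1.1.3",
    "torch": "1.13.0+cu117",
    "torch-cluster": "1.6.0+pt113cu117",
    "torch-geometric": "2.2.0",
    "torch-scatter": "2.1.0+pt113cu117",
    "torch-sparse": "0.6.15+pt113cu117",
    "torch-spline-conv": "1.2.1+pt113cu117",
    "tqdm": "4.64.1",
}

def check_requirements(pinned):
    """Return {package: (pinned, installed or None)} for every mismatch."""
    mismatches = {}
    for name, wanted in pinned.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches

if __name__ == "__main__":
    for name, (wanted, installed) in check_requirements(PINNED).items():
        print(f"{name}: pinned {wanted}, installed {installed or 'not installed'}")
```

The comparison is an exact string match, so a CPU-only build of `torch` or a build for a different CUDA version will be reported as a mismatch even when the base version number agrees.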

Train and test INCE on the California Housing dataset:

python main.py -d ./src/datasets/json_config/california_housing.json -m ./src/models/json_config/INCE.json

Citation

If you use this codebase, please cite our work:

@article{villaizan2023graph,
    title="{Graph Neural Network contextual embedding for Deep Learning on Tabular Data}",
    author={Villaizán-Vallelado, Mario and Salvatori, Matteo and Carro Martinez, Belén and Sanchez Esguevillas, Antonio Javier},
    year={2024},
    journal={Neural Networks},
    volume={173},
    issn={0893-6080},
    doi={10.1016/j.neunet.2024.106180},
    url={https://www.sciencedirect.com/science/article/pii/S0893608024001047}
}