A simple implementation of INCE, the algorithm described in "Graph Neural Network Contextual Embedding for Deep Learning on Tabular Data".
INCE is a Deep Learning (DL) model for tabular data that employs Graph Neural Networks (GNNs) and, more specifically, Interaction Networks for contextual embedding.
First, an encoder maps each feature of the tabular dataset into a latent vector (embedding); then a decoder takes the embeddings and uses them to solve the supervised learning task. The encoder is composed of two components: the columnar and the contextual embedding. The decoder is a Multi-Layer Perceptron (MLP) tuned to the learning task at hand.
COLUMNAR EMBEDDING: All features (categorical and continuous) are individually projected into a common dense latent space.
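A minimal PyTorch sketch of what such a columnar embedding can look like (class and argument names here are illustrative, not this repo's API): each categorical column gets its own `nn.Embedding` table and each continuous column its own scalar-to-vector projection, so every feature becomes a token in the same latent space.

```python
import torch
import torch.nn as nn

class ColumnarEmbedding(nn.Module):
    """One embedding table per categorical column, one scalar->vector
    projection per continuous column; all features share the same dim."""

    def __init__(self, cat_cardinalities, num_continuous, dim):
        super().__init__()
        self.cat_embeds = nn.ModuleList(
            [nn.Embedding(card, dim) for card in cat_cardinalities]
        )
        self.num_projs = nn.ModuleList(
            [nn.Linear(1, dim) for _ in range(num_continuous)]
        )

    def forward(self, x_cat, x_num):
        # x_cat: (batch, n_cat) integer codes; x_num: (batch, n_num) floats
        tokens = [emb(x_cat[:, i]) for i, emb in enumerate(self.cat_embeds)]
        tokens += [proj(x_num[:, i:i + 1]) for i, proj in enumerate(self.num_projs)]
        return torch.stack(tokens, dim=1)  # (batch, n_features, dim)
```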
CONTEXTUAL EMBEDDING: The features obtained from the columnar embedding are organized into a fully connected graph with an extra virtual node, called CLS as in BERT. A stack of Interaction Networks then models the relationships among all the nodes (original features plus the CLS virtual node) and enhances their representations. The resulting CLS embedding is fed into the final classifier/regressor.
Schematic workflow of an Interaction Network.
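The sketch below shows one Interaction Network step on a fully connected feature graph, followed by the CLS readout. It reuses the `ColumnarEmbedding` sketch above; the relation/object MLPs and all names are simplified stand-ins under assumed dimensions, not the paper's exact architecture.

```python
import torch
import torch.nn as nn

class InteractionLayer(nn.Module):
    """One Interaction Network step on a fully connected graph: a relation
    MLP builds a message for every ordered (sender, receiver) pair (self-pairs
    included for brevity), messages are summed per receiver, and an object
    MLP updates each node."""

    def __init__(self, dim):
        super().__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * dim, dim), nn.ReLU(), nn.Linear(dim, dim)
        )
        self.node_mlp = nn.Sequential(
            nn.Linear(2 * dim, dim), nn.ReLU(), nn.Linear(dim, dim)
        )

    def forward(self, h):                      # h: (batch, n_nodes, dim)
        b, n, d = h.shape
        sender = h.unsqueeze(2).expand(b, n, n, d)
        receiver = h.unsqueeze(1).expand(b, n, n, d)
        messages = self.edge_mlp(torch.cat([sender, receiver], dim=-1))
        agg = messages.sum(dim=1)              # sum messages arriving at each node
        return h + self.node_mlp(torch.cat([h, agg], dim=-1))  # residual update

# Readout sketch: prepend a learnable CLS token, run the stack, feed the
# resulting CLS embedding to the decoder (a single Linear stand-in here).
dim, batch = 32, 8
columnar = ColumnarEmbedding([5, 12], num_continuous=6, dim=dim)
cls_token = nn.Parameter(torch.zeros(1, 1, dim))
stack = nn.ModuleList([InteractionLayer(dim) for _ in range(3)])
head = nn.Linear(dim, 1)

x_cat = torch.randint(0, 5, (batch, 2))
x_num = torch.randn(batch, 6)
h = torch.cat([cls_token.expand(batch, -1, -1), columnar(x_cat, x_num)], dim=1)
for layer in stack:
    h = layer(h)
prediction = head(h[:, 0])                     # (batch, 1) from the CLS node
```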
INCE has been tested on the benchmark described in the table below:
| Dataset | Rows | Num. Feats | Cat. Feats | Task |
|---|---|---|---|---|
| HELOC | 9,871 | 21 | 2 | Binary |
| California Housing | 20,640 | 8 | 0 | Regression |
| Adult Income | 32,561 | 6 | 8 | Binary |
| Forest Cover Type | 581 K | 10 | 2 (4 + 40) | Multi-Class (7) |
| HIGGS | 11 M | 27 | 1 | Binary |
and compared with the following baselines:
Standard methods: Linear Model, KNN, Decision Tree, Random Forest, XGBoost, LightGBM, CatBoost.
Deep learning models: MLP, DeepFM, DeepGBM, RLN, TabNet, VIME, TabTransformer, NODE, Net-DNF, SAINT, FT-Transformer.
(See the paper for details and references)
The main results are summarized in the table below:
| Dataset | Metric | Best tree result | Best tree model | Best DL result | Best DL model | INCE result | INCE rank |
|---|---|---|---|---|---|---|---|
| HELOC | Accuracy ↑ | 83.6 % | CatBoost | 82.6 % | Net-DNF | 84.2 ± 0.5 % | 🥇 Abs. |
| California Housing | MSE ↓ | 0.195 | LightGBM | 0.226 | SAINT | 0.216 ± 0.007 | 🥇 DL |
| Adult Income | Accuracy ↑ | 87.4 % | LightGBM | 86.1 % | DeepFM, SAINT | 86.8 ± 0.3 % | 🥇 DL |
| Forest Cover Type | Accuracy ↑ | 97.3 % | XGBoost | 96.3 % | SAINT | 97.1 ± 0.1 % | 🥇 DL, 🥈 Abs. |
| HIGGS | Accuracy ↑ | 77.6 % | XGBoost | 79.8 % | SAINT | 79.1 ± 0.0 % | 🥈 DL, 🥈 Abs. |
Requirements:

```
numpy==1.23.5
pandas==1.5.2
scikit-learn==1.1.3
torch==1.13.0+cu117
torch-cluster==1.6.0+pt113cu117
torch-geometric==2.2.0
torch-scatter==2.1.0+pt113cu117
torch-sparse==0.6.15+pt113cu117
torch-spline-conv==1.2.1+pt113cu117
tqdm==4.64.1
```
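Since the `torch-*` extension pins are CUDA-specific builds, one typical way to install them is via the PyTorch cu117 wheel index and PyG's wheel index. The commands below are a suggestion based on those projects' standard install flows, not taken from this repo:

```bash
# PyTorch wheel for CUDA 11.7
pip install torch==1.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
# PyG extensions built against torch 1.13 + cu117
pip install torch-scatter==2.1.0 torch-sparse==0.6.15 torch-cluster==1.6.0 \
    torch-spline-conv==1.2.1 -f https://data.pyg.org/whl/torch-1.13.0+cu117.html
pip install torch-geometric==2.2.0 numpy==1.23.5 pandas==1.5.2 scikit-learn==1.1.3 tqdm==4.64.1
```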
Train/Test INCE on the California Housing dataset:

```bash
python main.py -d ./src/datasets/json_config/california_housing.json -m ./src/models/json_config/INCE.json
```
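The `-d` flag points at a dataset JSON config and `-m` at a model JSON config, so the other benchmark datasets should follow the same pattern. The placeholder below is illustrative; only the California Housing config path is confirmed by this README:

```bash
python main.py -d ./src/datasets/json_config/<dataset>.json -m ./src/models/json_config/INCE.json
```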
If you use this codebase, please cite our work:
```bibtex
@article{villaizan2023graph,
  title={Graph Neural Network contextual embedding for Deep Learning on Tabular Data},
  author={Villaizán-Vallelado, Mario and Salvatori, Matteo and Carro Martinez, Belén and Sanchez Esguevillas, Antonio Javier},
  year={2024},
  journal={Neural Networks},
  volume={173},
  issn={0893-6080},
  doi={10.1016/j.neunet.2024.106180},
  url={https://www.sciencedirect.com/science/article/pii/S0893608024001047}
}
```