pytorch-widedeep

A flexible package to combine tabular data with text and images using Wide and Deep models in PyTorch

Primary language: Python. License: MIT.

Introduction

pytorch-widedeep is based on Google's Wide and Deep algorithm. Details of the original algorithm can be found in the research paper "Wide & Deep Learning for Recommender Systems" (Cheng et al., 2016).

In general terms, pytorch-widedeep is a package for using deep learning with tabular data. In particular, it is intended to facilitate the combination of text and images with corresponding tabular data using wide and deep models. With that in mind, there are two architectures that can be implemented with just a few lines of code.

Architectures

Architecture 1

Architecture 1 combines the Wide, one-hot encoded features with the outputs from the DeepDense, DeepText and DeepImage components. All of these are connected to a final output neuron (binary classification or regression) or to several output neurons (multi-class classification). The components within the faded-pink rectangles are concatenated.

In math terms, and following the notation in the paper, Architecture 1 can be formulated as:
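A form consistent with the joint prediction equation in the paper and with the symbols defined in the next paragraph is sketched here. The subscripts naming each component, the sigmoid σ and the bias b (binary classification case) are assumptions taken from the paper's notation, not from this README:

```latex
pred = \sigma\left(
    W_{wide}^{T}\,[x, \phi(x)]
  + W_{deepdense}^{T}\, a_{deepdense}^{(l_f)}
  + W_{deeptext}^{T}\, a_{deeptext}^{(l_f)}
  + W_{deepimage}^{T}\, a_{deepimage}^{(l_f)}
  + b
\right)
```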

Where 'W' are the weight matrices applied to the wide model and to the final activations of the deep models, 'a' are these final activations, and φ(x) are the cross product transformations of the original features 'x'. In case you are wondering what "cross product transformations" are, here is a quote taken directly from the paper: "For binary features, a cross-product transformation (e.g., “AND(gender=female, language=en)”) is 1 if and only if the constituent features (“gender=female” and “language=en”) are all 1, and 0 otherwise".
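To make the quote concrete, here is a minimal pure-Python sketch of a cross-product transformation. The function name cross_product and the toy rows are hypothetical and are not part of pytorch-widedeep; the package builds crossed columns for you through the preprocessor.

```python
def cross_product(row, constituents):
    """Return 1 if every (column, value) pair in `constituents` holds for `row`, else 0."""
    return int(all(row.get(col) == val for col, val in constituents))


# toy data: each row is a dict of categorical features
rows = [
    {"gender": "female", "language": "en"},
    {"gender": "female", "language": "es"},
    {"gender": "male", "language": "en"},
]

# AND(gender=female, language=en)
and_female_en = [("gender", "female"), ("language", "en")]

crossed = [cross_product(r, and_female_en) for r in rows]
print(crossed)  # [1, 0, 0]
```

Only the first row satisfies both constituent features, so it is the only one that gets a 1.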

Architecture 2

Architecture 2 combines the Wide, one-hot encoded features with the Deep components of the model connected to the output neuron(s), after the different Deep components have themselves been combined through an FC-Head (which I refer to as deephead).

In math terms, and following the notation in the paper, Architecture 2 can be formulated as:
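A form consistent with the description of Architecture 2 is sketched here: the deep components contribute a single term, the final activations of the deephead. The subscripts, the sigmoid σ and the bias b (binary classification case) are assumptions taken from the paper's notation:

```latex
pred = \sigma\left(
    W_{wide}^{T}\,[x, \phi(x)]
  + W_{deephead}^{T}\, a_{deephead}^{(l_f)}
  + b
\right)
```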

When using pytorch-widedeep, the assumption is that the so-called Wide and DeepDense components in the figures are always present, while DeepText and DeepImage are optional. pytorch-widedeep includes standard text (stack of LSTMs) and image (pre-trained ResNets or stack of CNNs) models. However, the user can use any custom model as long as it has an attribute called output_dim with the size of the last layer of activations, so that WideDeep can be constructed. See the examples folder for more information.
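As an illustration of the output_dim requirement, here is a minimal sketch of a custom text model written in plain PyTorch. The class name MyTextModel, its layer sizes and the choice of a GRU are all hypothetical. The only thing taken from the text above is that the module exposes an output_dim attribute equal to the size of its last layer of activations.

```python
import torch
from torch import nn


class MyTextModel(nn.Module):  # hypothetical custom component
    def __init__(self, vocab_size, embed_dim=16, hidden_dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        # size of the last layer of activations, as required by the text above
        self.output_dim = hidden_dim

    def forward(self, X):
        # X: (batch, seq_len) tensor of integer token ids
        _, h = self.rnn(self.embed(X.long()))
        return h[-1]  # last hidden state, shape (batch, hidden_dim)


model = MyTextModel(vocab_size=100)
tokens = torch.randint(1, 100, (4, 10))  # 4 fake sequences of 10 token ids
out = model(tokens)
print(out.shape, model.output_dim)
```

The width of the returned tensor matches output_dim, which is the information a combining model needs in order to size the layer that sits on top of this component.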

Installation

Install using pip:

pip install pytorch-widedeep

Or install directly from GitHub:

pip install git+https://github.com/jrzaurin/pytorch-widedeep.git

Developer Install

# Clone the repository
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep

# Install in dev mode
pip install -e .

Examples

There are a number of notebooks in the examples folder plus some additional files. These notebooks cover most of the utilities of this package and can also act as documentation. If GitHub does not render the notebooks, or renders them with some parts missing, they are also saved as markdown files in the docs folder.

Quick start

Binary classification with the adult dataset using Wide and DeepDense and default settings.

import pandas as pd
from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import BinaryAccuracy

# these next 3 lines are not directly related to pytorch-widedeep. I assume
# you have downloaded the dataset and placed it in a dir called data/adult/
df = pd.read_csv('data/adult/adult.csv.zip')
df['income_label'] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop('income', axis=1, inplace=True)

# prepare wide, crossed, embedding and continuous columns
wide_cols  = ['education', 'relationship', 'workclass', 'occupation','native-country', 'gender']
cross_cols = [('education', 'occupation'), ('native-country', 'occupation')]
embed_cols = [('education',16), ('workclass',16), ('occupation',16),('native-country',32)]
cont_cols  = ["age", "hours-per-week"]
target_col = 'income_label'

# target
target = df[target_col].values

# wide
preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = preprocess_wide.fit_transform(df)
wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)

# deepdense
preprocess_deep = DeepPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
X_deep = preprocess_deep.fit_transform(df)
deepdense = DeepDense(hidden_layers=[64,32],
                      deep_column_idx=preprocess_deep.deep_column_idx,
                      embed_input=preprocess_deep.embeddings_input,
                      continuous_cols=cont_cols)

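# hold out 20% of the rows (with scikit-learn) so there is unseen data to predict on
from sklearn.model_selection import train_test_split
X_wide_tr, X_wide_te, X_deep_tr, X_deep_te, y_tr, y_te = train_test_split(
    X_wide, X_deep, target, test_size=0.2, random_state=1)

# build, compile, fit and predict
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method='binary', metrics=[BinaryAccuracy])
model.fit(X_wide=X_wide_tr, X_deep=X_deep_tr, target=y_tr, n_epochs=5, batch_size=256, val_split=0.2)
preds = model.predict(X_wide=X_wide_te, X_deep=X_deep_te)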

Of course, one can do much more, such as using different initializations, optimizers or learning rate schedulers for each component of the overall model, adding FC-Heads to the Text and Image components, using the Focal Loss, or warming up individual components before joint training. See the examples or the docs folders for a better understanding of the content of the package and its functionalities.

Testing

pytest tests

Acknowledgments

This library borrows from a number of other libraries, so I think it is only fair to mention them here in the README (specific mentions are also included in the code).

The Callbacks and Initializers structure and code are inspired by the torchsample library, which is itself partially inspired by Keras.

The TextProcessor class in this library uses fastai's Tokenizer and Vocab. The code at utils.fastai_transforms is a minor adaptation of their code so that it functions within this library. In my experience their Tokenizer is the best in class.

The ImageProcessor class in this library uses code from the fantastic Deep Learning for Computer Vision (DL4CV) book by Adrian Rosebrock.