
Primary language: Python. License: MIT.

tf2-fm-zoo

Python package for the factorization machine implementations from tensorflow2_model_zoo.

Acknowledgement

The original implementations of the methods in this repo were written by Ren Zhang, who kindly granted permission to use his code to create this package.

Installation

pip install tf2_fm_zoo

Basic Example

import tensorflow as tf
import numpy as np
import pandas as pd

from sklearn.preprocessing import KBinsDiscretizer
from sklearn.datasets import load_boston

from fm_zoo.fm import FactorizationMachine


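# load_boston was removed in scikit-learn 1.2, so this example needs scikit-learn < 1.2
X, y = load_boston(return_X_y=True)

# keep the first three features and cast the target to float32
X = X[:,:3]
y = tf.cast(y, dtype=tf.float32)

# discretize each continuous feature into at most 15 ordinal bins (integer codes)
kbd = KBinsDiscretizer(n_bins=15, encode="ordinal")
X = tf.cast(kbd.fit_transform(X), dtype=tf.int64)

# cardinality of each binned feature; bins that are too narrow are dropped,
# so read the per-feature count from n_bins_ rather than from the raw data
nunique_vals = kbd.n_bins_

fm = FactorizationMachine(
    feature_cards=tf.cast(nunique_vals, tf.int32),
    factor_dim=3)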

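# regression objective: mean squared error
fm.compile(loss=tf.keras.losses.mean_squared_error, optimizer="Adam")

# hold out 15% of the rows for validation and stop once val_loss stops improving
hist = fm.fit(
    X, y,
    validation_split=0.15,
    batch_size=16,
    epochs=100,
    callbacks=[
      tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
    ])

# plot training and validation loss per epoch (pandas plotting requires matplotlib)
pd.DataFrame(hist.history).plot(figsize=(15,10))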
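The example feeds the model a matrix of integer category codes plus one cardinality per column (`feature_cards`). For data that is already categorical, the same two inputs can be built without a discretizer. The pure-Python sketch below assumes, as the example suggests, that column `i` must hold codes in `0 .. feature_cards[i] - 1`; the helper name `encode_columns` is hypothetical and not part of fm_zoo.

```python
from typing import Dict, List, Sequence, Tuple


def encode_columns(rows: Sequence[Sequence[str]]) -> Tuple[List[List[int]], List[int]]:
    """Map each column's raw categories to dense integer codes (hypothetical helper).

    Returns the encoded rows and the cardinality of every column, i.e. the
    values one would pass as X and feature_cards.
    """
    n_cols = len(rows[0])
    # one vocabulary per column; codes are assigned in order of first appearance
    vocabs: List[Dict[str, int]] = [{} for _ in range(n_cols)]
    encoded: List[List[int]] = []
    for row in rows:
        codes = []
        for col, value in enumerate(row):
            vocab = vocabs[col]
            if value not in vocab:
                vocab[value] = len(vocab)
            codes.append(vocab[value])
        encoded.append(codes)
    cards = [len(vocab) for vocab in vocabs]
    return encoded, cards


rows = [
    ["mobile", "US", "sports"],
    ["desktop", "DE", "news"],
    ["mobile", "DE", "sports"],
]
X_codes, feature_cards = encode_columns(rows)
print(X_codes)        # [[0, 0, 0], [1, 1, 1], [0, 1, 0]]
print(feature_cards)  # [2, 2, 2]
```

The resulting lists could then be converted with `tf.constant(X_codes, dtype=tf.int64)` and `tf.constant(feature_cards, dtype=tf.int32)` in place of the discretized arrays. Categories never seen during encoding would need a reserved code, which this sketch does not handle.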

Supported Models

Model | Reference | Year
FM | Factorization Machines | 2010
FFM | Field-aware Factorization Machines for CTR Prediction | 2016
FNN | Deep Learning over Multi-field Categorical Data: A Case Study on User Response Prediction | 2016
AFM | Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks | 2017
DeepFM | DeepFM: A Factorization-Machine based Neural Network for CTR Prediction | 2017
NFM | Neural Factorization Machines for Sparse Predictive Analytics | 2017
xDeepFM | xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems | 2018
AutoInt | AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks | 2018
FNFM | Field-aware Neural Factorization Machine for Click-Through Rate Prediction | 2019
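For reference, the FM model in the first row scores an input with a global bias, per-feature linear weights, and pairwise interactions between latent factor vectors. When each column contributes exactly one active category (x_i = 1), as with the integer-coded inputs in the example, the second-order FM equation from Rendle (2010) reduces to `y_hat = w0 + sum_i w_i + sum_{i<j} <v_i, v_j>`, and the pairwise term can be computed in O(nk) time as `0.5 * sum_f [(sum_i v_{i,f})^2 - sum_i v_{i,f}^2]`. The pure-Python sketch below checks that identity against the naive O(n^2 k) loop. It illustrates the model, not fm_zoo's TensorFlow implementation; the parameter layout (one weight list and one factor table per column) and the function names are assumptions.

```python
from typing import List, Sequence

# Assumed parameter layout (hypothetical): for column i and category c,
# weights[i][c] is the linear weight w and factors[i][c] the latent vector v.
Weights = List[List[float]]
Factors = List[List[List[float]]]


def _active(x: Sequence[int], weights: Weights, factors: Factors):
    """Sum the linear weights and collect the factor vectors of the active categories."""
    linear = sum(weights[i][c] for i, c in enumerate(x))
    vecs = [factors[i][c] for i, c in enumerate(x)]
    return linear, vecs


def fm_score_naive(x: Sequence[int], w0: float, weights: Weights, factors: Factors) -> float:
    """Second-order FM score with an explicit O(n^2 k) loop over feature pairs."""
    linear, vecs = _active(x, weights, factors)
    pairwise = 0.0
    for i in range(len(vecs)):
        for j in range(i + 1, len(vecs)):
            pairwise += sum(a * b for a, b in zip(vecs[i], vecs[j]))
    return w0 + linear + pairwise


def fm_score(x: Sequence[int], w0: float, weights: Weights, factors: Factors) -> float:
    """Same score via the O(n k) sum-of-squares identity."""
    linear, vecs = _active(x, weights, factors)
    pairwise = 0.0
    for f in range(len(vecs[0])):
        s = sum(v[f] for v in vecs)          # sum of factor f over active features
        sq = sum(v[f] * v[f] for v in vecs)  # sum of squared factor f
        pairwise += 0.5 * (s * s - sq)
    return w0 + linear + pairwise


# Toy parameters: three columns with cardinalities [2, 3, 2] and k = 2.
w0 = 0.05
weights = [[0.1, -0.2], [0.0, 0.3, 0.5], [0.2, -0.1]]
factors = [
    [[1.0, 0.0], [0.5, 0.5]],
    [[0.0, 1.0], [1.0, 1.0], [2.0, -1.0]],
    [[1.0, 2.0], [-1.0, 0.5]],
]
x = [1, 2, 0]  # one category code per column
print(round(fm_score(x, w0, weights, factors), 6))        # 2.55
print(round(fm_score_naive(x, w0, weights, factors), 6))  # 2.55
```

Both functions agree for every category combination in the toy tables; the sum-of-squares form is what keeps FM practical with many features, since its cost grows linearly rather than quadratically in the number of active features.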