Emb-GAM

Interpretable linear model that leverages a pre-trained language model to better learn interactions. One-line fit function.

📚 sklearn-friendly API · 📖 demo notebook

Official code for using and reproducing Emb-GAM from the paper "Emb-GAM: an interpretable and efficient predictor using pre-trained language models" (Singh & Gao, 2022). Emb-GAM uses a pre-trained language model to extract features from text data, then combines them to fit a simple, interpretable linear model.

Quickstart

Installation: The best way to use Emb-GAM is through the imodelsx package: pip install imodelsx

  • For finer control, you can instead clone this repo and install it from source (see the sketch below)
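
A from-source setup typically looks something like this (the repository URL here is an assumption; use the one listed on the project page):

git clone https://github.com/csinva/emb-gam
cd emb-gam
pip install -e .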

Usage example (see the API docs or the demo notebook for more details):

from imodelsx import EmbGAMClassifier
import datasets
import numpy as np

# set up data (subsample 300 examples so the demo runs quickly)
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = EmbGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2, # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
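
Since the model follows the sklearn interface, class probabilities should also be available in the usual way (a hedged one-liner; predict_proba is assumed to behave as in standard sklearn classifiers):

# get class probabilities (assumes standard sklearn-style predict_proba)
probs = m.predict_proba(dset_val['text'])
print('mean positive-class probability:', probs[:, 1].mean())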

Docs

Abstract: Deep learning models have achieved impressive prediction performance but often sacrifice interpretability, a critical consideration in high-stakes domains such as healthcare or policymaking. In contrast, generalized additive models (GAMs) can maintain interpretability but often suffer from poor prediction performance due to their inability to effectively capture feature interactions. In this work, we aim to bridge this gap by using pre-trained large language models to extract embeddings for each input before learning a linear model in the embedding space. The final model (which we call Emb-GAM) is a transparent, linear function of its input features and feature interactions. Leveraging the language model allows Emb-GAM to learn far fewer linear coefficients, model larger interactions, and generalize well to novel inputs (e.g. unseen ngrams in text). Across a variety of natural-language-processing datasets, Emb-GAM achieves strong prediction performance without sacrificing interpretability.
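
To make the mechanism concrete, here is a minimal sketch of the core idea (illustrative only, not the library's implementation; the checkpoint, mean-pooling strategy, and whitespace ngram extraction are all simplifying assumptions):

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
lm = AutoModel.from_pretrained('distilbert-base-uncased')

def embed_ngram(ngram):
    # embed one ngram by mean-pooling the LM's last hidden state
    inputs = tokenizer(ngram, return_tensors='pt')
    with torch.no_grad():
        hidden = lm(**inputs).last_hidden_state  # (1, seq_len, dim)
    return hidden.mean(dim=1).squeeze().numpy()

def embed_text(text, n=2):
    # the GAM structure: a text's representation is the SUM of its ngram embeddings
    words = text.split()
    ngrams = [' '.join(words[i:i + n]) for i in range(len(words) - n + 1)] or [text]
    return np.sum([embed_ngram(ng) for ng in ngrams], axis=0)

# fit a linear model in the summed-embedding space
# X = np.stack([embed_text(t) for t in texts]); clf = LogisticRegression().fit(X, labels)
# each ngram's additive contribution is then clf.decision_function([embed_ngram(ng)])[0]

Because the final classifier is linear in the summed embeddings, each ngram's contribution can be read off independently, which is what coefs_dict_ exposes in the quickstart above.
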
  • the main API requires simply importing embgam.EmbGAMClassifier or embgam.EmbGAMRegressor
  • the experiments and scripts folders contain hyperparameters for running the sweeps reported in the paper
  • the notebooks folder contains notebooks for analyzing the outputs + making figures
  • stored outputs from running all experiments are available in this gdrive folder

Related work

  • imodelsX package (github) - interpretability for text datasets
  • imodels package (JOSS 2021 github) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
  • Adaptive wavelet distillation (NeurIPS 2021 pdf, github) - distilling a neural network into a concise wavelet model
  • Transformation importance (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
  • Hierarchical interpretations (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • Interpretation regularization (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • PDR interpretability framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning

If you find this work useful, please cite the following!

@article{singh2022embgam,
  title = {Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author = {Singh, Chandan and Gao, Jianfeng},
  journal = {arXiv preprint arXiv:2209.11799},
  doi = {10.48550/arxiv.2209.11799},
  url = {https://arxiv.org/abs/2209.11799},
  year = {2022},
}