
PyGLN: Gated Linear Network implementations for NumPy, PyTorch, TensorFlow and JAX

Implementations of Gated Linear Networks (GLNs), a new family of neural networks introduced by DeepMind in a recent paper, using various frameworks: NumPy, PyTorch, TensorFlow and JAX.

Published under GNU GPLv3 license.

Find our blogpost on Gated Linear Networks here.

Installation

To use pygln, simply clone the repository and install the package:

git clone git@github.com:aiwabdn/pygln.git
cd pygln
pip install -e .

Usage

To get started, we provide some utility functions in pygln.utils, for instance, to obtain the MNIST dataset:

from pygln import utils

X_train, y_train, X_test, y_test = utils.get_mnist()
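The images are returned as flattened 784-dimensional vectors. As a quick sanity check, assuming utils.get_mnist returns the standard 60,000/10,000 MNIST split as NumPy arrays:

print(X_train.shape, y_train.shape)  # expected: (60000, 784) (60000,)
print(X_test.shape, y_test.shape)    # expected: (10000, 784) (10000,)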

Since Gated Linear Networks are binary classifiers by default, let's first train a classifier for the target digit 3:

y_train_3 = (y_train == 3)
y_test_3 = (y_test == 3)

We provide a generic wrapper around all four backend implementations. Here, we use the NumPy version (see below for full list of arguments):

from pygln import GLN

model_3 = GLN(backend='numpy', layer_sizes=[4, 4, 1], input_size=X_train.shape[1])

Alternatively, each implementation can be imported directly via its respective submodule:

from pygln.numpy import GLN

model_3 = GLN(layer_sizes=[4, 4, 1], input_size=X_train.shape[1])

Next we train the model for one epoch on the dataset:

for n in range(X_train.shape[0]):
    pred = model_3.predict(X_train[n:n+1], target=y_train_3[n:n+1])

Note that GLNs are updated in an online, per-instance fashion, simply by passing each instance and its corresponding binary target to model.predict(). In practice, small batch sizes (~10) can nonetheless be used to speed up training.
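As a minimal sketch of such mini-batch training (the batch size of 10 is just an example value):

batch_size = 10  # small example batch size for online updates

for n in range(0, X_train.shape[0], batch_size):
    # passing a target triggers an online update on the whole batch
    model_3.predict(X_train[n:n + batch_size],
                    target=y_train_3[n:n + batch_size])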

Finally, to use the model for prediction on unseen instances, we simply omit the target parameter, this time using the batched version:

import numpy as np

preds = []
batch_size = 100
for n in range(int(np.ceil(X_test.shape[0] / batch_size))):
    batch = X_test[n * batch_size: (n + 1) * batch_size]
    pred = model_3.predict(batch)
    preds.append(pred)

For the trained model, we get the following accuracy:

import numpy as np
from sklearn.metrics import accuracy_score

accuracy_score(y_test_3, np.concatenate(preds, axis=0))
0.9861

As can be seen, the accuracy is already quite high, despite only a single pass through the data.

To train a classifier for the entire MNIST dataset, we create a GLN model with 10 classes. If num_classes is greater than 2, our implementations internally create that many separate binary GLNs and train them simultaneously in a one-vs-all fashion:

model = GLN(backend='numpy', layer_sizes=[4, 4, 1], input_size=X_train.shape[1],
            num_classes=10)

for n in range(X_train.shape[0]):
    model.predict(X_train[n:n+1], target=y_train[n:n+1])

preds = []
for n in range(X_test.shape[0]):
    preds.append(model.predict(X_test[n:n+1]))

accuracy_score(y_test, np.concatenate(preds, axis=0))
0.9409
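Since the multi-class model internally consists of one-vs-all binary classifiers, return_probs=True can be used to inspect their individual probabilities. A minimal sketch, assuming the probabilities are returned as an array of shape [B, num_classes]:

# probabilities of the 10 one-vs-all classifiers for the first 5 test images
probs = model.predict(X_test[:5], return_probs=True)

# the predicted class corresponds to the most confident one-vs-all classifier
print(probs.argmax(axis=1))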

We provide utils.evaluate_mnist to run experiments on the MNIST dataset. For instance, to train a GLN as a binary classifier for a particular digit with batches of 4:

from pygln import GLN, utils

model_3 = GLN(backend='numpy', layer_sizes=[4, 4, 1], input_size=784)

print(utils.evaluate_mnist(model_3, mnist_class=3, batch_size=4))
100%|███████████████████████████████| 15000/15000 [00:10<00:00, 1366.94it/s]
100%|█████████████████████████████████| 2500/2500 [00:01<00:00, 2195.59it/s]

98.69

And to train on all classes:

model = GLN(backend='numpy', layer_sizes=[4, 4, 1], input_size=784,
            num_classes=10)

print(utils.evaluate_mnist(model, batch_size=4))
100%|████████████████████████████████| 15000/15000 [00:35<00:00, 418.21it/s]
100%|██████████████████████████████████| 2500/2500 [00:03<00:00, 764.10it/s]

94.69
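The same experiment should run with any of the other backends by changing the backend argument, provided the corresponding framework is installed; for instance:

# identical model and evaluation, using the PyTorch backend instead of NumPy
model_torch = GLN(backend='pytorch', layer_sizes=[4, 4, 1], input_size=784,
                  num_classes=10)

print(utils.evaluate_mnist(model_torch, batch_size=4))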

GLN Interface

Constructor

GLN(backend: str,
    layer_sizes: Sequence[int],
    input_size: int,
    context_map_size: int = 4,
    num_classes: int = 2,
    base_predictor: Optional[Callable] = None,
    learning_rate: float = 1e-4,
    pred_clipping: float = 1e-3,
    weight_clipping: float = 5.0,
    bias: bool = True,
    context_bias: bool = True)

Gated Linear Network constructor.

Args:

  • backend ("jax", "numpy", "pytorch", "tf"): Which backend implementation to use.
  • layer_sizes (list[int >= 1]): List of layer output sizes.
  • input_size (int >= 1): Input vector size.
  • context_map_size (int >= 1): Context dimension, i.e. number of context halfspaces.
  • num_classes (int >= 2): For values > 2, turns the GLN into a multi-class classifier by internally creating a one-vs-all binary GLN classifier per class and returning the argmax as output.
  • base_predictor (np.array[N] -> np.array[K]): If given, maps the N-dim input vector to a corresponding K-dim vector of base predictions (could be a constant prior), instead of simply using the clipped input vector itself (see the sketch after this list).
  • learning_rate (float > 0.0): Update learning rate.
  • pred_clipping (0.0 < float < 0.5): Clip predictions into [p, 1 - p] at each layer.
  • weight_clipping (float > 0.0): Clip weights into [-w, w] after each update.
  • bias (bool): Whether to add a bias prediction in each layer.
  • context_bias (bool): Whether to use a random non-zero bias for context halfspace gating.
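For illustration, a minimal sketch of a custom base_predictor, assuming the function receives a batch of input vectors of shape [B, N] and returns base predictions of shape [B, K]; the constant prior of 0.5 and K = 8 are arbitrary example choices:

import numpy as np

def constant_base_predictor(inputs):
    # ignore the inputs and return a constant prior of 0.5 per base prediction
    # (K = 8 base predictions, chosen arbitrarily for this example)
    return np.full((inputs.shape[0], 8), 0.5)

model = GLN(backend='numpy', layer_sizes=[4, 4, 1], input_size=784,
            base_predictor=constant_base_predictor)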

Predict

GLN.predict(input: np.ndarray,
            target: Optional[np.ndarray] = None,
            return_probs: bool = False) -> np.ndarray

Predict the class for the given inputs, and optionally update the weights.

The PyTorch implementation takes torch.Tensors (on the same device as the model) as arguments (see the sketch at the end of this section).

Args:

  • input (np.array[B, N]): Batch of B N-dim float input vectors.
  • target (np.array[B]): Optional batch of B bool/int target class labels which, if given, triggers an online update.
  • return_probs (bool): Whether to return the classification probability (for each one-vs-all classifier if num_classes given) instead of the class.

Returns:

  • Predicted class per input instance, or classification probabilities if return_probs set.
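For example, with the PyTorch backend, prediction and online updates might look as follows; a minimal sketch with random data, assuming model and tensors live on the default device:

import torch

from pygln import GLN

model = GLN(backend='pytorch', layer_sizes=[4, 4, 1], input_size=784)

# inputs and targets as torch.Tensors on the same device as the model
x = torch.rand(10, 784)
y = torch.randint(0, 2, (10,), dtype=torch.bool)

model.predict(x, target=y)  # prediction plus online update
classes = model.predict(x)  # prediction only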

Cite PyGLN

@misc{pygln2020,
  author       = {Basu, Anindya and Kuhnle, Alexander},
  title        = {{PyGLN}: {G}ated {L}inear {N}etwork implementations for {NumPy}, {PyTorch}, {TensorFlow} and {JAX}},
  year         = {2020},
  url          = {https://github.com/aiwabdn/pygln}
}