
Python implementation of GLN in different frameworks

PyGLN: Gated Linear Network implementations for NumPy, PyTorch, TensorFlow and JAX

Implementations of Gated Linear Networks (GLNs), a new family of neural networks introduced by DeepMind in a recent paper, using various frameworks: NumPy, PyTorch, TensorFlow and JAX.

To use pygln, simply clone the repository and install the package:

git clone git@github.com:aiwabdn/pygln.git
cd pygln
pip install -e .


To get started, we provide some utility functions in pygln.utils, for instance, to obtain the MNIST dataset:

from pygln.utils import get_mnist

X_train, y_train, X_test, y_test = get_mnist()

Since Gated Linear Networks are binary classifiers by default, let's first train a classifier for the target digit 3:

y_train_3 = (y_train == 3)
y_test_3 = (y_test == 3)

We provide a generic wrapper around all four backend implementations. Here, we use the NumPy version (see below for full list of arguments):

from pygln import GLN

model_3 = GLN(backend='numpy', layer_sizes=[4, 4, 1], input_size=X_train.shape[1])

Alternatively, the various implementations can be imported directly via their respective submodule:

from pygln.numpy import GLN

model_3 = GLN(layer_sizes=[4, 4, 1], input_size=X_train.shape[1])

Next we train the model for one epoch on the dataset:

for n in range(X_train.shape[0]):
    pred = model_3.predict(input=X_train[n:n+1], target=y_train_3[n:n+1])

Note that GLNs are updated online and, by default, one instance at a time: each instance is passed to model.predict() together with its binary target, which triggers a weight update. To speed up training, it can make sense to use small batch sizes (~10).
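As a minimal sketch of the small-batch idea, the helper below slices a dataset into consecutive mini-batches. The name iterate_minibatches and the toy arrays are illustrative assumptions and are not part of pygln.

```python
import numpy as np

def iterate_minibatches(inputs, targets, batch_size=10):
    """Yield consecutive (inputs, targets) slices of at most batch_size rows."""
    for start in range(0, inputs.shape[0], batch_size):
        stop = start + batch_size
        yield inputs[start:stop], targets[start:stop]

# Toy stand-in data: 25 instances with 4 features each (not MNIST).
X_demo = np.arange(100, dtype=np.float32).reshape(25, 4)
y_demo = np.arange(25) % 2 == 0

batch_shapes = [xb.shape for xb, yb in iterate_minibatches(X_demo, y_demo)]

# With a model as created above, training would presumably look like:
# for xb, yb in iterate_minibatches(X_train, y_train_3):
#     model_3.predict(input=xb, target=yb)
```

The last batch is simply shorter when the dataset size is not a multiple of the batch size.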

Finally, to predict on unseen instances, we omit the target parameter. This time we use the batched version:

import numpy as np

preds = []
batch_size = 100
for n in range(np.ceil(X_test.shape[0] / batch_size).astype(int)):
    batch = X_test[n * batch_size: (n + 1) * batch_size]
    pred = model_3.predict(batch)
    preds.append(pred)  # collect per-batch predictions for evaluation

As accuracy for the trained model we get:

import numpy as np
from sklearn.metrics import accuracy_score

accuracy_score(y_test_3, np.concatenate(preds, axis=0))

As can be seen, the accuracy is already quite high, despite the fact that we only did one pass through the data.

To train a classifier for the entire MNIST dataset, we create a GLN model with 10 classes. If num_classes provided is greater than 2, our implementations implicitly create the same number of separate binary GLNs and train them simultaneously in a one-vs-all fashion:

model = GLN(backend='numpy', layer_sizes=[4, 4, 1], input_size=X_train.shape[1],
            num_classes=10)

for n in range(X_train.shape[0]):
    model.predict(input=X_train[n:n+1], target=y_train[n:n+1])

preds = []
for n in range(X_test.shape[0]):
    preds.append(model.predict(X_test[n:n+1]))

accuracy_score(y_test, np.vstack(preds))
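The one-vs-all scheme described above can be illustrated in plain NumPy. This is only a sketch of the general technique, not PyGLN's internal code; the function names and the probability values are made up for illustration.

```python
import numpy as np

def one_vs_all_targets(labels, num_classes):
    """Return a bool array [B, num_classes]; column k is the target for classifier k."""
    return labels[:, None] == np.arange(num_classes)[None, :]

def combine_one_vs_all(probs):
    """Pick, for each instance, the class whose binary classifier is most confident."""
    return np.argmax(probs, axis=1)

labels = np.array([3, 0, 2])
targets = one_vs_all_targets(labels, num_classes=4)

# Hypothetical per-class probabilities for three instances and four classes.
probs = np.array([[0.1, 0.2, 0.3, 0.9],
                  [0.8, 0.1, 0.1, 0.2],
                  [0.2, 0.3, 0.6, 0.4]])
predicted = combine_one_vs_all(probs)
```

Each column of targets plays the role that y_train_3 played for the single-digit classifier earlier.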

We provide utils.evaluate_mnist to run experiments on the MNIST dataset. For instance, to train a GLN as a binary classifier for a particular digit with batches of 4:

from pygln.utils import evaluate_mnist

model_3 = GLN(backend='numpy', layer_sizes=[4, 4, 1], input_size=784)

print(evaluate_mnist(model_3, mnist_class=3, batch_size=4))

And to train on all classes:

model = GLN(backend='numpy', layer_sizes=[4, 4, 1], input_size=784,
            num_classes=10)

print(evaluate_mnist(model, batch_size=4))

GLN Interface


GLN(backend: str,
    layer_sizes: Sequence[int],
    input_size: int,
    context_map_size: int = 4,
    num_classes: int = 2,
    base_predictor: Optional[Callable] = None,
    learning_rate: float = 1e-4,
    pred_clipping: float = 1e-3,
    weight_clipping: float = 5.0,
    bias: bool = True,
    context_bias: bool = True)

Gated Linear Network constructor.


  • backend ("jax", "numpy", "pytorch", "tf"): Which backend implementation to use.
  • layer_sizes (list[int >= 1]): List of layer output sizes.
  • input_size (int >= 1): Input vector size.
  • context_map_size (int >= 1): Context dimension, i.e. number of context halfspaces.
  • num_classes (int >= 2): For values >2, turns the GLN into a multi-class classifier by internally creating one binary one-vs-all GLN classifier per class and returning the argmax as output.
  • base_predictor (np.array[N] -> np.array[K]): If given, maps the N-dim input vector to a corresponding K-dim vector of base predictions (could be a constant prior), instead of simply using the clipped input vector itself.
  • learning_rate (float > 0.0): Update learning rate.
  • pred_clipping (0.0 < float < 0.5): Clip predictions into [p, 1 - p] at each layer.
  • weight_clipping (float > 0.0): Clip weights into [-w, w] after each update.
  • bias (bool): Whether to add a bias prediction in each layer.
  • context_bias (bool): Whether to use a random non-zero bias for context halfspace gating.
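To make the roles of context_map_size, pred_clipping, weight_clipping and learning_rate more concrete, here is a rough NumPy sketch of a single gated linear neuron as we understand it from the GLN paper. It is not PyGLN's actual implementation: the class GatedNeuron, the side_info argument and the seed parameter are illustrative assumptions, and the learning rate is raised so the toy update is visible.

```python
import numpy as np

def logit(p):
    return np.log(p / (1.0 - p))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class GatedNeuron:
    """Single gated linear neuron (illustrative, not PyGLN's internal class)."""

    def __init__(self, num_inputs, side_info_size, context_map_size=4,
                 learning_rate=1e-2, pred_clipping=1e-3, weight_clipping=5.0,
                 seed=0):
        rng = np.random.default_rng(seed)
        # One random halfspace (normal vector and offset) per context bit.
        self.normals = rng.normal(size=(context_map_size, side_info_size))
        self.offsets = rng.normal(size=context_map_size)
        # One weight vector per context: 2 ** context_map_size in total.
        self.weights = np.full((2 ** context_map_size, num_inputs), 1.0 / num_inputs)
        self.learning_rate = learning_rate
        self.pred_clipping = pred_clipping
        self.weight_clipping = weight_clipping

    def context(self, side_info):
        # Each halfspace contributes one bit; the bits form the context index.
        bits = (self.normals @ side_info > self.offsets).astype(int)
        return int(bits @ (2 ** np.arange(bits.size)))

    def predict(self, probs, side_info, target=None):
        eps = self.pred_clipping
        # Clip input predictions into [eps, 1 - eps] before taking logits.
        x = logit(np.clip(probs, eps, 1.0 - eps))
        c = self.context(side_info)
        out = np.clip(sigmoid(self.weights[c] @ x), eps, 1.0 - eps)
        if target is not None:
            # Gradient step on the log loss, then clip weights into [-w, w].
            w = self.weights[c] - self.learning_rate * (out - float(target)) * x
            self.weights[c] = np.clip(w, -self.weight_clipping, self.weight_clipping)
        return out

neuron = GatedNeuron(num_inputs=3, side_info_size=2)
side_info = np.array([0.5, -1.0])
probs = np.array([0.9, 0.8, 0.7])

before = neuron.predict(probs, side_info)
for _ in range(50):
    neuron.predict(probs, side_info, target=1)  # online updates toward target 1
after = neuron.predict(probs, side_info)
```

Only the weight vector selected by the current context is updated, which is why the side information matters even though the prediction itself is a linear mix of input logits.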


GLN.predict(input: np.ndarray,
            target: np.ndarray = None,
            return_probs: bool = False) -> np.ndarray

Predict the class for the given inputs, and optionally update the weights.

PyTorch implementation takes torch.Tensors (on the same device as the model) as parameters.


  • input (np.array[B, N]): Batch of B N-dim float input vectors.
  • target (np.array[B]): Optional batch of B bool/int target class labels which, if given, triggers an online update.
  • return_probs (bool): Whether to return the classification probability (for each one-vs-all classifier if num_classes > 2) instead of the class.


  • Returns: predicted class per input instance, or classification probabilities if return_probs is set.


We attempted to replicate the results of the paper on MNIST. The paper does not specify the model architecture exactly. With the tf model we achieved ~96% accuracy on all 10 classes in MNIST in one pass over the data. We present the results from some of our experiments here. We welcome results from other experiments and ideas on how to improve these results.

  • Varying the depth and width of GLNs

  • Varying weight_clipping, pred_clipping, context_map_size of a GLN

  • Switching base predictor functions, layer bias and context bias of a GLN

Cite PyGLN

  author       = {Basu, Anindya and Kuhnle, Alexander},
  title        = {{PyGLN}: {G}ated {L}inear {N}etwork implementations for {NumPy}, {PyTorch}, {TensorFlow} and {JAX}},
  year         = {2020},
  }