

Nonparametric Variational Information Bottleneck (NVIB)



The NVIB Python package contains the NVIB layer, KL divergence loss functions and the Denoising attention module. It accompanies the paper A Variational AutoEncoder for Transformers with Nonparametric Variational Information Bottleneck.

Please cite the original authors in any publication that uses this work:

author    = {James Henderson and Fabio Fehr},
title     = {{A VAE for Transformers with Nonparametric Variational Information Bottleneck}},
year      = {2023},
booktitle = {International Conference on Learning Representations},



Requirements

  • Python 3.9
  • PyTorch 1.10.0
  • math

Installation

Clone this repository, activate your environment, and install the package locally into it:

git clone https://gitlab.idiap.ch/ffehr/nvib.git
pip install nvib/.

Project status

Development is ongoing; planned additions include:

  • Denoising attention for multihead attention
  • Implicit reparameterisation gradients
  • KL divergence functions as methods of the NVIB layer class
  • Initialisations
  • Update to PyTorch 1.13.0

Python Usage

Import the package and its components:

from nvib.nvib_layer import Nvib
from nvib.kl import kl_gaussian, kl_dirichlet
from nvib.denoising_attention import DenoisingMultiheadAttention

The following examples use these inputs:

# For examples
import torch
import torch.nn as nn

# Source length, target length, batch size and embedding size
Ns, Nt, B, H = 10, 6, 2, 512
number_samples = 3
encoder_output = torch.rand(Ns, B, H)  # (N_s, B, H)
src_key_padding_mask = torch.zeros((B, Ns), dtype=bool)  # True masks a token
tgt = torch.rand(Nt, B, H)  # (N_t, B, H)
tgt_key_padding_mask = torch.zeros((B, Nt), dtype=bool)
# Mask used for sampling: number_samples acts as the batch size
memory_key_padding_mask = torch.zeros((number_samples, Ns), dtype=bool)
device = "cpu"

Nonparametric Variational Information Bottleneck

Initialise the NVIB layer (source length = N_s, embedding size = H, batch size = B).

  • size_in: the input embedding size
  • size_out: the output embedding size (typically the same as size_in)
  • prior_mu: prior for the Gaussian means \mu^p
  • prior_var: prior for the Gaussian variances (\sigma^2)^p
  • prior_alpha: prior for the Dirichlet pseudo-counts \alpha_0^p
  • delta: conditional prior \alpha^\Delta, the proportion of vectors you would like to retain
  • kappa: number of samples per component \kappa^\Delta

Note: In training the output length is always (N_s+1) \times \kappa^\Delta, since it includes the prior component (+1) and draws \kappa^\Delta samples per component. At evaluation time only the means are used, so the length is N_s+1.
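As a quick check of the note above, the hypothetical helper below (not part of the nvib package) restates the (N_s+1) \times \kappa^\Delta rule and evaluates it for the example source length N_s = 10.

```python
def nvib_latent_length(ns: int, kappa: int, training: bool) -> int:
    """Number of latent vectors returned by the NVIB layer (hypothetical helper).

    In training each of the N_s + 1 components (N_s inputs plus one prior)
    is sampled kappa times; at evaluation only the N_s + 1 means are used.
    """
    components = ns + 1  # +1 for the prior component
    return components * kappa if training else components


print(nvib_latent_length(10, 1, training=True))   # 11
print(nvib_latent_length(10, 3, training=True))   # 33
print(nvib_latent_length(10, 3, training=False))  # 11
```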

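nvib_layer = Nvib(size_in=H, size_out=H, prior_mu=0, prior_var=1,
                  prior_alpha=1, delta=1, kappa=1)  # illustrative priors, mirroring the KL examples below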

Run the layer's forward pass on encoder_output of size (N_s, B, H) with a boolean mask of size (B, N_s), where True masks a token.

latent_dict = nvib_layer(encoder_output, src_key_padding_mask)

The returned dictionary contains the entries listed below. Its z entry is a tuple (z, pi, mu, logvar); this tuple is what is passed to the DenoisingMultiheadAttention forward function so that it can access the parameters.

  • z (inside the tuple): the Gaussian component vectors. ((N_s+1) \times \kappa^\Delta, B, H)
  • alpha: the pseudo-counts. ((N_s+1) \times \kappa^\Delta, B, 1)
  • pi: the Dirichlet probabilities reparameterised from the pseudo-counts. ((N_s+1) \times \kappa^\Delta, B, 1)
  • mu: the means of the Gaussian components. ((N_s+1) \times \kappa^\Delta, B, H)
  • logvar: the log-variances of the Gaussian components. ((N_s+1) \times \kappa^\Delta, B, H)
  • memory_key_padding_mask: the encoder's boolean attention mask. (B, (N_s+1) \times \kappa^\Delta)
  • avg_num_vec: the number of non-zero pseudo-counts, averaged over the batch (used for logging)
  • avg_prop_vec: the proportion of non-zero pseudo-counts, averaged over the batch (used for logging)
  • avg_alpha0: the sum of the pseudo-counts, averaged over the batch (used for logging)
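The three logging statistics can be illustrated with plain Python lists. This is a minimal sketch under assumed definitions: masked positions are ignored, and the proportion is taken over the unmasked vectors of each batch item. The package computes these on tensors and its exact denominator may differ; logging_stats is a hypothetical name.

```python
from typing import List, Tuple


def logging_stats(alpha: List[List[float]],
                  mask: List[List[bool]]) -> Tuple[float, float, float]:
    """Sketch of avg_num_vec, avg_prop_vec and avg_alpha0 (assumed definitions).

    alpha[b][i] is the pseudo-count of vector i in batch item b and
    mask[b][i] is True where that vector is padding (masked out).
    """
    nums, props, alpha0s = [], [], []
    for alpha_b, mask_b in zip(alpha, mask):
        # Keep only the pseudo-counts of unmasked vectors.
        kept = [a for a, m in zip(alpha_b, mask_b) if not m]
        num_nonzero = sum(1 for a in kept if a > 0)
        nums.append(num_nonzero)
        # Assumption: the proportion is relative to the unmasked vectors.
        props.append(num_nonzero / len(kept) if kept else 0.0)
        alpha0s.append(sum(kept))
    batch = len(alpha)
    return sum(nums) / batch, sum(props) / batch, sum(alpha0s) / batch


alpha = [[0.0, 2.0, 1.0, 0.5], [3.0, 0.0, 0.0, 9.0]]
mask = [[False, False, False, True], [False, False, False, True]]
print(logging_stats(alpha, mask))  # (1.5, 0.5, 3.0)
```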

Sampling can be done as follows, with an integer number_samples (treated as the batch size) and a boolean mask of size (number_samples, N_s), where True masks a token. Here N_s is the largest length you wish to sample, and the individual lengths can be predetermined by the user.

z = nvib_layer.sample(number_samples, memory_key_padding_mask, device)
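A mask like the one passed to sample can be built from user-chosen lengths. The pure-Python sketch below (lengths_to_padding_mask is a hypothetical helper) keeps the first lengths[b] positions of row b and masks the rest. With PyTorch the same mask is torch.arange(max_len).unsqueeze(0) >= torch.tensor(lengths).unsqueeze(1), a bool tensor of shape (number_samples, max_len).

```python
from typing import List


def lengths_to_padding_mask(lengths: List[int], max_len: int) -> List[List[bool]]:
    """Build a (len(lengths), max_len) boolean mask where True marks padding.

    Row b keeps its first lengths[b] positions (False) and masks the rest (True),
    matching the "True masks the token" convention used above.
    """
    if any(n < 0 or n > max_len for n in lengths):
        raise ValueError("each length must lie in [0, max_len]")
    return [[i >= n for i in range(max_len)] for n in lengths]


mask = lengths_to_padding_mask([2, 4, 3], max_len=4)
for row in mask:
    print(row)
# [False, False, True, True]
# [False, False, False, False]
# [False, False, False, True]
```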

Denoising Attention

This module duplicates and augments PyTorch's multi_head_attention_forward function and MultiheadAttention class.

Initialise the Transformer decoder (note: nhead = 1):

decoder_layer = nn.TransformerDecoderLayer(d_model=H, nhead=1)

transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)  # layer count chosen for illustration

In each decoder layer, replace the attention sub-layer that connects encoder and decoder (the cross-attention, multihead_attn) with Denoising Attention:

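for layer in transformer_decoder.layers:
    # Swap the encoder-decoder (cross-)attention for Denoising Attention; num_heads=1 follows the nhead = 1 note
    layer.multihead_attn = DenoisingMultiheadAttention(embed_dim=H, num_heads=1)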

Now run the forward pass of this decoder. Note: it assumes that the keys and values from the encoder output are a tuple (z, pi, mu, logvar), where the z within the tuple is the original input.

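output = transformer_decoder(tgt=tgt, memory=latent_dict["z"], tgt_key_padding_mask=tgt_key_padding_mask,
                             memory_key_padding_mask=latent_dict["memory_key_padding_mask"])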

KL functions

A simple implementation of the KL divergence between univariate Gaussian tensors, weighted by our pseudo-counts \alpha (see the paper for more details). Note: remember to set the priors here.

kl_g = kl_gaussian(**latent_dict, prior_mu=0, prior_var=1, kappa=1)

Here mu, logvar, alpha and memory_key_padding_mask come from the NVIB layer's latent dict, while the priors and the number of samples \kappa^\Delta are set explicitly. The output is a KL loss of size (B).
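For reference, the closed-form KL divergence between a univariate Gaussian N(\mu, \sigma^2) and the prior N(\mu^p, (\sigma^2)^p) is 0.5 [log((\sigma^2)^p / \sigma^2) + (\sigma^2 + (\mu - \mu^p)^2) / (\sigma^2)^p - 1]. The sketch below evaluates it for a single dimension from mu and logvar. It deliberately omits the pseudo-count weighting, masking and \kappa^\Delta handling that kl_gaussian applies, so it illustrates the per-dimension term only and is not the package's implementation.

```python
import math


def kl_univariate_gaussian(mu: float, logvar: float,
                           prior_mu: float = 0.0, prior_var: float = 1.0) -> float:
    """KL( N(mu, exp(logvar)) || N(prior_mu, prior_var) ) for one dimension."""
    var = math.exp(logvar)
    return 0.5 * (math.log(prior_var) - logvar
                  + (var + (mu - prior_mu) ** 2) / prior_var
                  - 1.0)


print(kl_univariate_gaussian(0.0, 0.0))  # 0.0 (posterior equals prior)
print(kl_univariate_gaussian(1.0, 0.0))  # 0.5
```

Working with logvar rather than the variance keeps the log term exact and matches the logvar entry of the latent dict.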

The KL divergence between the Dirichlet components (see the paper for more details).

kl_d = kl_dirichlet(**latent_dict, prior_alpha=1, delta=1, kappa=1)

Here alpha and memory_key_padding_mask come from the NVIB layer's latent dict, while the priors and the number of samples \kappa^\Delta are set explicitly. The output is a KL loss of size (B).
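The package's kl_dirichlet is specific to the NVIB prior (it takes delta and kappa), and its exact form is given in the paper. Purely as a point of reference, the sketch below implements the standard closed-form KL divergence between two Dirichlet distributions using only the math module. The digamma helper is a small recurrence-plus-asymptotic-series approximation written for this sketch; neither function is part of nvib.

```python
import math
from typing import Sequence


def digamma(x: float) -> float:
    """Digamma function for x > 0 via recurrence plus asymptotic series."""
    result = 0.0
    # Shift x upwards using psi(x) = psi(x + 1) - 1/x until the series is accurate.
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    f = 1.0 / (x * x)
    result += (math.log(x) - 0.5 / x
               - f * (1 / 12 - f * (1 / 120 - f * (1 / 252 - f * (1 / 240 - f / 132)))))
    return result


def kl_dirichlet_standard(alpha: Sequence[float], beta: Sequence[float]) -> float:
    """KL( Dir(alpha) || Dir(beta) ) for strictly positive concentration vectors."""
    a0, b0 = sum(alpha), sum(beta)
    kl = math.lgamma(a0) - math.lgamma(b0)
    kl += sum(math.lgamma(b) - math.lgamma(a) for a, b in zip(alpha, beta))
    kl += sum((a - b) * (digamma(a) - digamma(a0)) for a, b in zip(alpha, beta))
    return kl
```

For example, kl_dirichlet_standard([2.0, 3.0], [1.0, 1.0]) equals ln 12 - 2.25 (about 0.235), and the divergence of any Dirichlet from itself is 0.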

Repository Structure

├── nvib
│   ├── __init__.py
│   ├── denoising_attention.py
│   ├── kl.py
│   └── nvib_layer.py
├── README.rst
└── setup.py


For questions or to report issues with this software package, please contact the second author.