
Nonparametric Variational Information Bottleneck (NVIB)

Paper: https://openreview.net/forum?id=6QkjC_cs03X

Figure: NVIB denoising example (figures/nvib_denoising.png)

The NVIB Python package contains the NVIB layer, the KL divergence loss functions, and the denoising attention module. It is the package for the paper A Variational AutoEncoder for Transformers with Nonparametric Variational Information Bottleneck.

Please cite the original authors in any publication that uses this work:

@inproceedings{henderson23_nvib,
  author    = {James Henderson and Fabio Fehr},
  title     = {{A VAE for Transformers with Nonparametric Variational Information Bottleneck}},
  year      = {2023},
  booktitle = {International Conference on Learning Representations},
  url       = {https://openreview.net/forum?id=6QkjC_cs03X}
}

Description

The NVIB project contains the NVIB layer, the KL divergence loss functions, and the denoising attention module.

Requirements

  • Python 3.9
  • PyTorch 1.10.0
  • math (Python standard library; no separate install needed)

Installation

Clone this repository, activate your environment, and install the package locally:

git clone https://gitlab.idiap.ch/ffehr/nvib.git
pip install nvib/.
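
As a quick check that the installation worked, the package's components should import cleanly (a minimal smoke test using the modules listed under Python Usage below):

# Minimal smoke test: these imports should succeed after installation
from nvib.nvib_layer import Nvib
from nvib.kl import kl_gaussian, kl_dirichlet
from nvib.denoising_attention import DenoisingMultiheadAttention
print("nvib imported successfully")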

Project status

Development is ongoing; upcoming implementations include:

  • Denoising attention for multihead attention
  • Implicit reparameterisation gradients
  • KL divergence functions as methods of the NVIB layer class
  • Initialisations
  • Update to PyTorch 1.13.0

Python Usage

Import the package and its components

from nvib.nvib_layer import Nvib
from nvib.kl import kl_gaussian, kl_dirichlet
from nvib.denoising_attention import DenoisingMultiheadAttention

For running the following examples:

# For examples
import torch
import torch.nn as nn

# Sequence lengths (source, target), batch size, and embedding size
Ns, Nt, B, H = 10, 6, 2, 512
number_samples = 3
# Encoder output (Ns, B, H) and source mask (B, Ns); True masks a token
encoder_output = torch.rand(Ns, B, H)
src_key_padding_mask = torch.zeros((B, Ns), dtype=bool)
# Decoder input (Nt, B, H) and target mask (B, Nt)
tgt = torch.rand(Nt, B, H)
tgt_key_padding_mask = torch.zeros((B, Nt), dtype=bool)
# Mask used when sampling from the prior (number_samples, Ns)
memory_key_padding_mask = torch.zeros((number_samples, Ns), dtype=bool)
device = "cpu"

Nonparametric Variational Information Bottleneck

Initialise the NVIB layer (Source length = N_s, embedding size = H, Batch size = B).

  • size_in The embedding size input
  • size_out The embedding size output (typically the same)
  • prior_mu Prior for Gaussian means \mu^p
  • prior_var Prior for Gaussian variance (\sigma^2)^p
  • prior_alpha Prior for Dirichlet pseudo-counts \alpha_0^p
  • delta Conditional prior \alpha^\Delta - the proportion of vectors to retain
  • kappa Number of samples per component \kappa^\Delta

Note: During training the output length is always (N_s+1) \times \kappa^\Delta, since the prior adds one component (+1) and \kappa^\Delta samples are drawn per component. At evaluation time only the means are used, so the length is N_s+1.

nvib_layer = Nvib(size_in=H,
                  size_out=H,
                  prior_mu=0,
                  prior_var=1,
                  prior_alpha=1,
                  delta=1,
                  kappa=1)

Run the forward of the layer with encoder_output size (N_s, B, H) and boolean mask size (B, N_s) where True masks the token.

latent_dict = nvib_layer(encoder_output, src_key_padding_mask)

The dictionary returned is of the form:

{z, pi, memory_key_padding_mask, mu, logvar, alpha, avg_num_vec, avg_prop_vec, avg_alpha0}

where z is a tuple containing the (z, pi, mu, logvar) variables. This tuple is what is passed to the DenoisingMultiheadAttention forward function so that it can access the parameters. The entries are described below; a shape-check sketch follows the list.

  • The z within the tuple is the Gaussian component vectors. ((N_s+1) \times \kappa^\Delta, B, H)
  • alpha is the pseudo-counts. ((N_s+1) \times \kappa^\Delta, B, 1)
  • pi is the Dirichlet probability reparameterised from the pseudo-counts. ((N_s+1) \times \kappa^\Delta, B, 1)
  • mu is the means of the Gaussian components. ((N_s+1) \times \kappa^\Delta, B, H)
  • logvar is the log variance of the Gaussian components. ((N_s+1) \times \kappa^\Delta, B, H)
  • memory_key_padding_mask is the encoder's boolean attention mask. (B, (N_s+1) \times \kappa^\Delta)
  • avg_num_vec is the number of non-zero pseudo-counts averaged over the batch (used for logging)
  • avg_prop_vec is the proportion of non-zero pseudo-counts averaged over the batch (used for logging)
  • avg_alpha0 is the sum of pseudo-counts, averaged over the batch (used for logging)
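
Continuing the example above, these shapes can be sanity-checked as follows (a minimal sketch based on the shapes listed above; it assumes the layer was initialised with kappa=1, so the latent length is (Ns + 1) * 1 = 11 during training):

# Unpack the tuple stored under "z" and check the expected shapes
kappa = 1  # as set at initialisation above
z, pi, mu, logvar = latent_dict["z"]
assert z.shape == ((Ns + 1) * kappa, B, H)
assert pi.shape == ((Ns + 1) * kappa, B, 1)
assert mu.shape == ((Ns + 1) * kappa, B, H)
assert latent_dict["memory_key_padding_mask"].shape == (B, (Ns + 1) * kappa)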

Sampling can be done as follows, with an integer number_samples (treated as a batch size) and a boolean mask of size (number_samples, N_s) where True masks the token. This mask is built with N_s as the largest length you wish to sample; the lengths can be predetermined by the user, as in the sketch below.

z = nvib_layer.sample(number_samples, memory_key_padding_mask, device)
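
For example, to sample sequences of different lengths, the mask can mark positions beyond each desired length as padding (an illustrative sketch; the lengths and mask construction here are hypothetical, not package API):

# Hypothetical: sample 3 sequences with lengths 4, 7 and 10, padded to Ns
lengths = [4, 7, 10]
sample_mask = torch.ones((len(lengths), Ns), dtype=bool)  # True masks a position
for i, length in enumerate(lengths):
    sample_mask[i, :length] = False  # unmask the first `length` positions

z_sampled = nvib_layer.sample(len(lengths), sample_mask, device)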

Denoising Attention

This duplicates and augments the multi_head_attention_forward function and MultiheadAttention class from PyTorch.

Initialise the Transformer decoder (note: nhead = 1):

decoder_layer = nn.TransformerDecoderLayer(d_model=H,
                                           dim_feedforward=4*H,
                                           nhead=1,
                                           dropout=0.1)

transformer_decoder = nn.TransformerDecoder(decoder_layer,
                                            num_layers=1)

Set each layer which interfaces encoder and decoder to Denoising Attention:

for layer_num, layer in enumerate(transformer_decoder.layers):
    layer.multihead_attn = DenoisingMultiheadAttention(embed_dim=H,
                                                       num_heads=1,
                                                       dropout=0.1,
                                                       bias=False)

Now run the forward pass of this decoder. Note: it assumes the keys and values from the encoder output are a tuple (z, pi, mu, logvar), where the z within the tuple was the original input.

output = transformer_decoder(tgt=tgt,
                             memory=latent_dict["z"],
                             tgt_key_padding_mask=tgt_key_padding_mask,
                             memory_key_padding_mask=latent_dict["memory_key_padding_mask"])
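
The decoder output keeps the standard Transformer shape, which can be checked in the running example:

# Output of the decoder has shape (Nt, B, H), as for a standard nn.TransformerDecoder
assert output.shape == (Nt, B, H)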

KL functions

A simple implementation of the KL divergence between univariate Gaussian tensors, weighted by our pseudo-counts \alpha (see the paper for more details). Note: remember to set the priors here.

kl_g = kl_gaussian(**latent_dict, prior_mu=0, prior_var=1, kappa=1)

where mu, logvar, alpha and the memory_key_padding_mask come from the NVIB layer's latent dictionary, and the priors and number of samples \kappa^\Delta are set by the user. The output is a KL loss of dimension (B).

The KL divergence between Dirichlet components (see paper for more details).

kl_d = kl_dirichlet(**latent_dict, prior_alpha=1, delta=1, kappa=1)

where alpha and the memory_key_padding_mask come from the NVIB layer's latent dictionary, and the priors and number of samples \kappa^\Delta are set by the user. The output is a KL loss of dimension (B).
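
In a full training objective the two KL terms are typically combined with a reconstruction loss. The sketch below continues the running example; the weights lambda_g and lambda_d and the MSE stand-in objective are illustrative, not part of the package:

# Hypothetical training objective: reconstruction + weighted KL terms
lambda_g, lambda_d = 1.0, 1.0  # illustrative KL weights
reconstruction_loss = nn.functional.mse_loss(output, tgt)  # stand-in objective
loss = reconstruction_loss + lambda_g * kl_g.mean() + lambda_d * kl_d.mean()
loss.backward()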

Repository Structure

.
├── nvib
│   ├── __init__.py
│   ├── denoising_attention.py
│   ├── kl.py
│   └── nvib_layer.py
├── README.rst
└── setup.py

Contact

For questions or to report issues with this software package, please contact the second author.