[PyTorch] Minimal implementation of a Variational Autoencoder (VAE) with categorical latent variables, inspired by "Categorical Reparameterization with Gumbel-Softmax".

Gaussian-Mixture-VAE

This repository contains a minimal PyTorch implementation of a Variational Autoencoder with Categorical Latent variables.

Model

We consider a time series $(x_t)_t$ whose transitions follow an autoregressive mixture of $K$ Gaussians. The generative process can be described hierarchically as follows:

$$z_t \sim \mathrm{Categorical}(p_1, p_2, \dots, p_K)$$

$$x_t | x_{t-1} \sim \mathrm{Normal} \left( \Pi_{z_t}x_{t-1}, \Sigma \right)$$

Intuitively, at each time $t$ the series evolves as a linear autoregressive model whose parameter matrix is one of the $K$ matrices $\Pi_1, \dots, \Pi_K$, selected by the latent variable $z_t$.
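The generative process above can be simulated in a few lines. The following is a minimal sketch in plain Python (the repository itself uses PyTorch), assuming an isotropic covariance $\Sigma = \sigma^2 I$ for simplicity. The function names `matvec` and `simulate` and the example matrices are illustrative and not taken from the repository.

```python
import random

def matvec(A, v):
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]

def simulate(Pi, p, sigma, x0, T, seed=0):
    """Draw z_1..z_T and x_1..x_T from the switching autoregressive model.

    Assumption: Sigma = sigma^2 * I (isotropic noise), chosen for illustration.
    """
    rng = random.Random(seed)
    K = len(Pi)
    x, zs, xs = list(x0), [], []
    for _ in range(T):
        z = rng.choices(range(K), weights=p)[0]        # z_t ~ Categorical(p)
        mean = matvec(Pi[z], x)                          # Pi_{z_t} x_{t-1}
        x = [m + rng.gauss(0.0, sigma) for m in mean]    # x_t ~ Normal(mean, sigma^2 I)
        zs.append(z)
        xs.append(x)
    return zs, xs

# Two hypothetical regimes in two dimensions.
Pi = [
    [[0.9, 0.0], [0.0, 0.9]],    # regime 0: slow decay toward the origin
    [[0.0, -1.0], [1.0, 0.0]],   # regime 1: rotation by 90 degrees
]
zs, xs = simulate(Pi, p=[0.7, 0.3], sigma=0.1, x0=[1.0, 0.0], T=50)
```

Here `zs` holds the regime index chosen at each step and `xs` the resulting trajectory. Inference has to recover `Pi`, `p`, and the noise level from `xs` alone.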

Variational inference

We use Variational Inference to estimate the parameters $(p_1, p_2, \dots, p_K), (\Pi_1, \Pi_2, \dots, \Pi_K), \Sigma$ of the above model, which we collectively denote by $\theta$.

Specifically, this is achieved by maximizing the variational lower bound (ELBO) on the log-likelihood of the observed data.

The resulting optimization problem takes the form

$$\max_{\theta, \phi} \mathbb E_{q_{\phi}(z_t|x_t, x_{t-1})} \left[ \log p_{\theta}(x_t | x_{t-1}, z_t) + \log p_{\theta}(z_t) - \log q_{\phi} (z_t | x_t, x_{t-1}) \right],$$

where $p_{\theta}$ is as described in the generative model and $q_{\phi}$ is a variational distribution parametrized by a feed-forward neural network.
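Because $z_t$ takes only $K$ values, the expectation in the objective can be written as an explicit sum over $k$. This makes the bound easy to check by hand. The sketch below evaluates that sum for a single transition in plain Python, again assuming $\Sigma = \sigma^2 I$. Here the list `q` stands in for the output of the variational network, and all names and numbers are illustrative rather than taken from the repository.

```python
import math

def matvec(A, v):
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]

def log_normal_iso(x, mean, sigma):
    """Log density of Normal(mean, sigma^2 I) evaluated at x."""
    d = len(x)
    sq = sum((a - b) ** 2 for a, b in zip(x, mean))
    return -0.5 * d * math.log(2 * math.pi * sigma ** 2) - 0.5 * sq / sigma ** 2

def elbo_step(x_t, x_prev, Pi, p, q, sigma):
    """Single-step ELBO: sum_k q_k [log p(x_t | x_prev, k) + log p_k - log q_k]."""
    total = 0.0
    for k, q_k in enumerate(q):
        if q_k == 0.0:
            continue  # convention: 0 * log 0 = 0
        log_lik = log_normal_iso(x_t, matvec(Pi[k], x_prev), sigma)
        total += q_k * (log_lik + math.log(p[k]) - math.log(q_k))
    return total

# Hypothetical parameters and a single observed transition.
Pi = [[[0.9, 0.0], [0.0, 0.9]], [[0.0, -1.0], [1.0, 0.0]]]
p = [0.7, 0.3]
x_prev, x_t, sigma = [1.0, 0.0], [0.8, 0.1], 0.5

# Exact marginal likelihood and exact posterior over z_t, for comparison.
joint = [p_k * math.exp(log_normal_iso(x_t, matvec(Pi_k, x_prev), sigma))
         for p_k, Pi_k in zip(p, Pi)]
log_marginal = math.log(sum(joint))
posterior = [j / sum(joint) for j in joint]

elbo_prior = elbo_step(x_t, x_prev, Pi, p, q=p, sigma=sigma)          # q = prior
elbo_post = elbo_step(x_t, x_prev, Pi, p, q=posterior, sigma=sigma)   # q = exact posterior
```

With `q` equal to the exact posterior, the bound is tight and `elbo_post` equals `log_marginal` up to rounding. Any other choice, such as `q = p`, gives a value no larger than `log_marginal`.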

Gumbel approximation

The main challenge in using discrete latent variables in VAEs is that sampling from the categorical variational distribution $q_{\phi}$ is not differentiable with respect to $\phi$, so gradients cannot be backpropagated through the sampled $z_t$. This makes it infeasible to optimize the objective directly with gradient-based methods.

We address this issue by approximating the one-hot encoded vectors $z_t$ with relaxed samples drawn from a Gumbel-Softmax distribution. The temperature of the Gumbel-Softmax distribution is initialized at an arbitrary value and slowly annealed toward 0, so that the samples become progressively closer to one-hot vectors and the approximation becomes more accurate.
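The relaxation can be sketched as follows: add independent Gumbel(0, 1) noise to the logits of $q_{\phi}$, divide by the temperature, and apply a softmax. The plain-Python example below is only an illustration. The function names, the exponential schedule, and its constants (`tau0`, `tau_min`, `rate`) are assumptions, not the repository's settings. PyTorch provides a comparable routine as `torch.nn.functional.gumbel_softmax`, though this README does not state whether the repository uses it.

```python
import math
import random

def gumbel_softmax_sample(logits, tau, rng=random):
    """Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise by inverse transform: g = -log(-log(u)), u ~ Uniform(0, 1).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep u strictly inside (0, 1)
        g = -math.log(-math.log(u))
        noisy.append((logit + g) / tau)
    m = max(noisy)                                   # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]

def temperature(step, tau0=1.0, tau_min=0.1, rate=1e-3):
    """Exponential annealing with a floor (hypothetical schedule and constants)."""
    return max(tau_min, tau0 * math.exp(-rate * step))

rng = random.Random(0)
logits = [1.0, 0.5, -1.0]
y_warm = gumbel_softmax_sample(logits, tau=temperature(0), rng=rng)       # smooth sample
y_cold = gumbel_softmax_sample(logits, tau=temperature(10**6), rng=rng)   # closer to one-hot
```

Each sample is a point on the probability simplex. A lower temperature pushes the sample toward a vertex (a one-hot vector) but also increases the variance of the gradients. This is why the sketch anneals gradually and stops at a floor `tau_min` rather than reaching exactly 0.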

References