PoE VAE

Product of experts variational autoencoders for multimodal data

This repo contains code for quickly and easily implementing multimodal variational autoencoders (VAEs).

Usage

$ python main.py --help

Modular Multimodal VAE Abstraction

import torch
import torch.nn as nn

from src.encoders_decoders import GatherLayer, NetworkList, SplitLinearLayer
from src.likelihoods import GroupedLikelihood, BernoulliLikelihood
from src.objectives import StandardElbo
from src.priors import StandardGaussianPrior
from src.variational_posteriors import DiagonalGaussianPosterior
from src.variational_strategies import GaussianPoeStrategy

# Make a VAE with two modalities, both with 392 dimensions, and a 20-dimensional
# latent space. The VAE is simply a collection of different pieces, with each
# piece subclassing `torch.nn.Module`.
latent_dim, m_dim = 20, 392
vae = nn.ModuleDict({
  'encoder': NetworkList(
    nn.ModuleList([
      nn.Sequential(
        nn.Linear(m_dim, 200),
        nn.ReLU(),
        SplitLinearLayer(200, (latent_dim, latent_dim)),
      ),
      nn.Sequential(
        nn.Linear(m_dim, 200),
        nn.ReLU(),
        SplitLinearLayer(200, (latent_dim, latent_dim)),
      ),
    ])
  ),
  'variational_strategy': GaussianPoeStrategy(),
  'variational_posterior': DiagonalGaussianPosterior(),
  'decoder': nn.Sequential(
    nn.Linear(latent_dim, 200),
    nn.ReLU(),
    SplitLinearLayer(200, (m_dim, m_dim)),
    GatherLayer(),
  ),
  'likelihood': GroupedLikelihood(
    BernoulliLikelihood(),
    BernoulliLikelihood(),
  ),
  'prior': StandardGaussianPrior(),
})

# Feed the VAE to an objective. The objective determines how data is routed
# through the various VAE pieces to determine a loss. Objectives also subclass
# `torch.nn.Module`.
objective = StandardElbo(vae)

# Train the VAE like any other PyTorch model.
loader = make_dataloader(...)  # placeholder: build a DataLoader for your multimodal data
optimizer = torch.optim.Adam(objective.parameters())
for epoch in range(100):
  for batch in loader:
    objective.zero_grad()
    loss = objective(batch)
    loss.backward()
    optimizer.step()
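
The core operation behind GaussianPoeStrategy is the product-of-experts combination of diagonal Gaussians: a product of Gaussian experts is itself a Gaussian whose precision is the sum of the experts' precisions. The snippet below is an illustrative sketch of that math, not the repository's actual implementation; it assumes expert means and log-variances stacked along a leading modality dimension and, as in the MVAE paper, treats the standard Gaussian prior as an additional expert.

import torch

def gaussian_poe(means, log_vars):
    """Combine diagonal Gaussian experts (plus a standard normal prior expert).

    means, log_vars: tensors of shape [n_experts, batch, latent_dim].
    Returns the mean and log-variance of the product distribution.
    """
    precisions = torch.exp(-log_vars)               # inverse variances
    total_precision = 1.0 + precisions.sum(dim=0)   # prior expert has precision 1
    poe_mean = (means * precisions).sum(dim=0) / total_precision
    poe_log_var = -torch.log(total_precision)
    return poe_mean, poe_log_var

# Two modalities, batch of 5, 20-dimensional latent space.
mu, log_var = torch.randn(2, 5, 20), torch.zeros(2, 5, 20)
print(gaussian_poe(mu, log_var)[0].shape)  # torch.Size([5, 20])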

Methods Implemented

  • MVAE --variational-strategy=gaussian_poe --variational-posterior=diag_gaussian --prior=standard_gaussian --objective=mvae_elbo
  • MMVAE --variational-strategy=gaussian_moe --variational-posterior=diag_gaussian_mixture --prior=standard_gaussian --objective=mmvae_elbo
  • s-VAE (originally a single-modality VAE) --variational-strategy=vmf_poe --variational-posterior=vmf_product --prior=uniform_hyperspherical --objective=elbo
  • MIWAE --unstructured-encoder=True --variational-posterior=diag_gaussian --prior=standard_gaussian --objective=elbo
  • partial VAE (TO DO) --variational-strategy=permutation_invariant --variational-posterior=diag_gaussian --prior=standard_gaussian --objective=elbo
  • VAEVAE (under consideration)
  • MoPoE VAE (under consideration)
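
For example, assuming these flags can be combined directly on the command line (run python main.py --help for the authoritative list of options), an MVAE-style run would look something like:

$ python main.py --variational-strategy=gaussian_poe --variational-posterior=diag_gaussian --prior=standard_gaussian --objective=mvae_elbo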

Applying this to your own data

Check out src/datasets/ for some examples of how to do this. To use the existing training framework, you will also have to modify DATASET_MAP and MODEL_MAP in src/param_maps.py.
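
As a rough sketch, assuming DATASET_MAP and MODEL_MAP are plain dictionaries mapping a command-line name to a constructor (the actual structure may differ; consult src/param_maps.py and the examples in src/datasets/), registering a new two-modality dataset might look like the following. The class name and map key are hypothetical and used only for illustration.

from torch.utils.data import Dataset

class MyPairedDataset(Dataset):
    """Illustrative dataset: each item is a tuple with one tensor per modality."""

    def __init__(self, modality_a, modality_b):
        self.modality_a = modality_a  # e.g. an [N, 392] float tensor
        self.modality_b = modality_b  # e.g. an [N, 392] float tensor

    def __len__(self):
        return len(self.modality_a)

    def __getitem__(self, index):
        return self.modality_a[index], self.modality_b[index]

# Hypothetical registration in src/param_maps.py, assuming name -> constructor maps:
# DATASET_MAP['my_paired_data'] = MyPairedDataset
# MODEL_MAP['my_paired_data'] = ...  # a function that builds the matching VAE pieces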

Dependencies

  • PyTorch (torch)

See also:

  • MVAE repo, which uses a product-of-experts strategy for combining evidence across modalities.
  • MMVAE repo, which uses a mixture-of-experts strategy for combining evidence across modalities.
  • Hyperspherical VAE repo, a VAE with a latent space defined on an n-sphere with von Mises-Fisher-distributed approximate posteriors.

TO DO

  1. Validation set for early stopping
  2. Implement STL gradients?
  3. Student experts?
  4. Compare network architectures w/ other papers
  5. partial-VAE implementation
  6. Add a documentation markdown file
  7. Implement jackknife variational inference?
  8. AR-ELBO for vMF
  9. Double-check that unstructured recognition models work
  10. Is there an easy way for Encoder and DecoderModalityEmbeddings to share parameters?
  11. Test vMF KL divergence
  12. Why is MVAE performing poorly?