MultiVae

Unifying Multimodal Variational Autoencoders (VAEs) in PyTorch

This library implements some of the most common multimodal variational autoencoder methods in a unifying framework, for effective benchmarking and development. The list of implemented models is given below. The library includes ready-to-use datasets such as MnistSvhn 🔢, CelebA 😎 and PolyMNIST, as well as the most widely used metrics: coherence, likelihood and FID. It also integrates model monitoring with Wandb and a quick way to save and load models from the HuggingFace Hub 🤗.
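
For example, a model shared on the Hub can be reloaded in one line. A minimal sketch, assuming an AutoModel entry point with a load_from_hf_hub method (the class and method names are an assumption; check the documentation for the exact API):

from multivae.models import AutoModel

# Download the model configuration and weights from a Hub repository
# (the repository name below is a placeholder).
model = AutoModel.load_from_hf_hub("your_hf_username/your_model_repo")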

Implemented models

Model | Paper | Official Implementation
JMVAE | Joint Multimodal Learning with Deep Generative Models | link
MVAE | Multimodal Generative Models for Scalable Weakly-Supervised Learning | link
MMVAE | Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models | link
MoPoE | Generalized Multimodal ELBO | link
MVTCAE | Multi-View Representation Learning via Total Correlation Objective | link
JNF, JNF-DCCA | Improving Multimodal Joint Variational Autoencoders through Normalizing Flows and Correlation Analysis | link
MMVAE+ | MMVAE+: Enhancing the Generative Quality of Multimodal VAEs without Compromises | link
Nexus | Leveraging hierarchy in multimodal generative models for effective cross-modality inference | link

Quickstart

Install the library by running:

pip install multivae

or by cloning the repository:

git clone https://github.com/AgatheSenellart/MultiVae.git
cd MultiVae
pip install .

Cloning the repository gives you access to the tutorial notebooks and scripts in the examples/ folder.

Load a dataset easily:

from multivae.data.datasets import MnistSvhn
train_set = MnistSvhn(data_path='your_data_path', split="train", download=True)
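
Each dataset item bundles all modalities together. A quick sanity check, assuming items expose a modality-keyed data field (the exact attribute names are an assumption; see the documentation):

# Inspect one sample: one tensor per modality, keyed by modality name.
sample = train_set[0]
print(sample.data['mnist'].shape)  # expected: torch.Size([1, 28, 28])
print(sample.data['svhn'].shape)   # expected: torch.Size([3, 32, 32])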

Instantiate your favorite model:

from multivae.models import MVTCAE, MVTCAEConfig
model_config = MVTCAEConfig(
    latent_dim=20,
    input_dims={'mnist': (1, 28, 28), 'svhn': (3, 32, 32)}
)
model = MVTCAE(model_config)

Define a trainer and train the model!

from multivae.trainers import BaseTrainer, BaseTrainerConfig
training_config = BaseTrainerConfig(
    learning_rate=1e-3,
    num_epochs=30
)

trainer = BaseTrainer(
    model=model,
    train_dataset=train_set,
    training_config=training_config,
)
trainer.train()
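
Once training is done, the metrics mentioned above (coherence, likelihood, FID) can be computed with dedicated evaluator classes. A hedged sketch for cross-modal coherence, assuming a CoherenceEvaluator that takes the trained model, a test set and pretrained unimodal classifiers (class name and arguments are assumptions; see the metrics documentation and tutorial notebooks for the exact API):

from multivae.data.datasets import MnistSvhn
from multivae.metrics import CoherenceEvaluator, CoherenceEvaluatorConfig

test_set = MnistSvhn(data_path='your_data_path', split="test", download=True)

# `classifiers` is assumed to be a dict with one pretrained classifier per
# modality, e.g. {'mnist': mnist_clf, 'svhn': svhn_clf}. Coherence is the
# accuracy of these classifiers on cross-modal generations.
coherence_config = CoherenceEvaluatorConfig(batch_size=128)
evaluator = CoherenceEvaluator(
    model=model,
    classifiers=classifiers,
    test_dataset=test_set,
    eval_config=coherence_config,
)
evaluator.eval()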

Documentation and Examples

See https://multivae.readthedocs.io

Several examples are provided in examples/, as well as tutorial notebooks on the main features of MultiVae (training, metrics, samplers) in examples/tutorial_notebooks. As an introduction to the package, see the getting_started.ipynb notebook.

Table of Contents

- Installation
- Usage
- Contribute
- Reproducibility statement
- Citation

Installation

git clone https://github.com/AgatheSenellart/MultiVae.git
cd MultiVae
pip install .

Usage

Our library allows you to use any of the models with custom configurations, encoder and decoder architectures, and datasets. See the tutorial notebook at examples/tutorial_notebooks/getting_started.ipynb for an overview of the main features.
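
As an illustration, custom encoders and decoders can be written for each modality by subclassing the base architecture classes. A rough sketch, assuming the pythae-style BaseEncoder/BaseDecoder interface returning a ModelOutput (MultiVae builds on pythae, but the exact import paths and constructor arguments are assumptions; the getting_started notebook shows the canonical way):

from torch import nn
from pythae.models.base.base_utils import ModelOutput
from pythae.models.nn import BaseEncoder, BaseDecoder

class MnistEncoder(BaseEncoder):
    """Toy encoder for the (1, 28, 28) MNIST modality."""
    def __init__(self, latent_dim):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 512), nn.ReLU())
        self.embedding = nn.Linear(512, latent_dim)
        self.log_covariance = nn.Linear(512, latent_dim)

    def forward(self, x):
        h = self.backbone(x)
        # Encoders return the posterior mean and log-covariance of the latent code.
        return ModelOutput(embedding=self.embedding(h), log_covariance=self.log_covariance(h))

class MnistDecoder(BaseDecoder):
    """Toy decoder mapping latent codes back to MNIST images."""
    def __init__(self, latent_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(latent_dim, 512), nn.ReLU(),
            nn.Linear(512, 28 * 28), nn.Sigmoid(),
        )

    def forward(self, z):
        return ModelOutput(reconstruction=self.layers(z).reshape(-1, 1, 28, 28))

# One network per modality is then passed to the model constructor
# (argument names are an assumption):
# model = MVTCAE(model_config,
#                encoders={'mnist': MnistEncoder(20), 'svhn': SvhnEncoder(20)},
#                decoders={'mnist': MnistDecoder(20), 'svhn': SvhnDecoder(20)})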

Contribute

If you want to contribute to the project, for instance by adding models to the library, clone the repository and install it in editable mode using the -e option:

pip install -e .

To propose a contribution, follow the guidelines in the CONTRIBUTING.md file. Detailed tutorials are provided on how to implement a new model, sampler, metric or dataset.

Reproducibility statement

Most of the implemented models are validated by reproducing a key result from the original paper.

Citation

If you have used our package in your research, please consider citing the paper presenting the package:

MultiVae: A Python library for Multimodal Generative Autoencoders (2023, Agathe Senellart, Clément Chadebec and Stéphanie Allassonnière)

BibTeX entry:

@preprint{senellart:hal-04207151,
  TITLE = {{MultiVae: A Python library for Multimodal Generative Autoencoders}},
  AUTHOR = {Senellart, Agathe and Chadebec, Clement and Allassonniere, Stephanie},
  URL = {https://hal.science/hal-04207151},
  YEAR = {2023},
}