Variational_Recurrent_Neural_Networks

I implemented this novel architecture in PyTorch as part of the research I carried out as a Research Intern at the Language Technology Research Center, IIIT Hyderabad, India.

I also wrote a blog post about it on Medium: Variational Recurrent Neural Networks - VRNNs


Variational Recurrent Neural Networks (VRNNs) are a class of latent variable models for sequential data. The central idea is to include latent random variables at every time step of the RNN; more specifically, the model contains a variational autoencoder (VAE) at each time step, conditioned on the RNN's hidden state.
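For reference, the per-time-step model can be written out as follows. This uses the notation of the original paper, where the φ functions are neural networks and f is the RNN transition; the symbol names are the paper's and are not necessarily the ones used in this repository's code.

```latex
\begin{aligned}
\text{Prior:} \quad
  & z_t \sim \mathcal{N}\big(\mu_{0,t}, \operatorname{diag}(\sigma_{0,t}^2)\big),
  \quad [\mu_{0,t}, \sigma_{0,t}] = \varphi^{\text{prior}}(h_{t-1}) \\
\text{Generation:} \quad
  & x_t \mid z_t \sim \mathcal{N}\big(\mu_{x,t}, \operatorname{diag}(\sigma_{x,t}^2)\big),
  \quad [\mu_{x,t}, \sigma_{x,t}] = \varphi^{\text{dec}}\big(\varphi^{z}(z_t), h_{t-1}\big) \\
\text{Inference:} \quad
  & z_t \mid x_t \sim \mathcal{N}\big(\mu_{z,t}, \operatorname{diag}(\sigma_{z,t}^2)\big),
  \quad [\mu_{z,t}, \sigma_{z,t}] = \varphi^{\text{enc}}\big(\varphi^{x}(x_t), h_{t-1}\big) \\
\text{Recurrence:} \quad
  & h_t = f_\theta\big(\varphi^{x}(x_t), \varphi^{z}(z_t), h_{t-1}\big) \\
\text{Objective:} \quad
  & \mathbb{E}_{q(z_{\le T} \mid x_{\le T})} \Bigg[ \sum_{t=1}^{T} \Big(
      -\mathrm{KL}\big( q(z_t \mid x_{\le t}, z_{<t}) \,\|\, p(z_t \mid x_{<t}, z_{<t}) \big)
      + \log p(x_t \mid z_{\le t}, x_{<t}) \Big) \Bigg]
\end{aligned}
```

The objective is the usual variational lower bound, summed over time steps: a KL term that keeps the approximate posterior close to the learned prior, plus a reconstruction term.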

High-level structure of a VRNN:

[Figure: high-level structure of a VRNN]

Detailed view of a VRNN cell at time step t:

[Figure: detailed view of a VRNN cell at time step t]
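As a rough illustration of what happens inside one such cell, here is a minimal PyTorch sketch. It is not the code in this repository: the class name `VRNNCell`, the layer sizes, the choice of `nn.GRUCell` for the recurrence, and the decoder that outputs only a mean are all assumptions made to keep the example short.

```python
import torch
import torch.nn as nn


class VRNNCell(nn.Module):
    """Hypothetical single VRNN step: prior, encoder, decoder, recurrence."""

    def __init__(self, x_dim: int, h_dim: int, z_dim: int):
        super().__init__()
        # Feature extractors for the input x_t and the latent z_t.
        self.phi_x = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.phi_z = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU())
        # Prior p(z_t | h_{t-1}): outputs mean and log-variance.
        self.prior = nn.Linear(h_dim, 2 * z_dim)
        # Approximate posterior q(z_t | x_t, h_{t-1}).
        self.encoder = nn.Linear(2 * h_dim, 2 * z_dim)
        # Decoder p(x_t | z_t, h_{t-1}); only the mean is modelled here.
        self.decoder = nn.Linear(2 * h_dim, x_dim)
        # Recurrence h_t = f(phi_x(x_t), phi_z(z_t), h_{t-1}).
        self.rnn = nn.GRUCell(2 * h_dim, h_dim)

    def forward(self, x_t, h_prev):
        x_feat = self.phi_x(x_t)
        prior_mu, prior_logvar = self.prior(h_prev).chunk(2, dim=-1)
        enc_out = self.encoder(torch.cat([x_feat, h_prev], dim=-1))
        enc_mu, enc_logvar = enc_out.chunk(2, dim=-1)
        # Reparameterisation trick: sample z_t from the posterior.
        z_t = enc_mu + torch.randn_like(enc_mu) * torch.exp(0.5 * enc_logvar)
        z_feat = self.phi_z(z_t)
        x_recon = self.decoder(torch.cat([z_feat, h_prev], dim=-1))
        h_t = self.rnn(torch.cat([x_feat, z_feat], dim=-1), h_prev)
        # KL(q || p) between two diagonal Gaussians, summed over latent dims.
        kl = 0.5 * torch.sum(
            prior_logvar - enc_logvar
            + (enc_logvar.exp() + (enc_mu - prior_mu) ** 2) / prior_logvar.exp()
            - 1.0,
            dim=-1,
        )
        return x_recon, h_t, kl


# Usage: unroll the cell over a random sequence of shape (time, batch, features).
cell = VRNNCell(x_dim=8, h_dim=16, z_dim=4)
x = torch.randn(5, 3, 8)
h = torch.zeros(3, 16)
total_kl = torch.zeros(3)
for x_t in x:
    x_recon, h, kl = cell(x_t, h)
    total_kl = total_kl + kl
```

Unlike a standard VAE, the prior here depends on the previous hidden state, so the KL term compares two learned distributions at every step rather than comparing the posterior to a fixed standard normal.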

Original research paper: A Recurrent Latent Variable Model for Sequential Data (Chung et al., 2015)

From the original research paper:
"In this paper, we explore the inclusion of latent random variables into the hidden state of a recurrent neural network (RNN) by combining the elements of the variational autoencoder. We argue that through the use of high-level latent random variables, the variational RNN (VRNN) can model the kind of variability observed in highly structured sequential data such as natural speech. We empirically evaluate the proposed model against other related sequential models on four speech datasets and one handwriting dataset. Our results show the important roles that latent random variables can play in the RNN dynamics."


Running the code:

For training:
Navigate to the src folder and run: python train.py

For sampling:
Navigate to the src folder and run one of the following:

  1. To sample from the prior distribution: python sample_from_prior.py
  2. To sample from the posterior distribution: python sample_from_posterior.py
    (In the main function of sample_from_posterior.py, uncomment the get_data() line for the data you want posterior samples for.)