TensorFlow implementation of "Generating Sentences from a Continuous Space".
- Requirements:
  - Python 3.4 or higher
  - TensorFlow r0.12
  - NumPy
- Clone this repository:
git clone https://github.com/Chung-I/Variational-Recurrent-Autoencoder-Tensorflow.git
- Set up a conda environment:
conda create -n vrae python=3.6
conda activate vrae
- Install the Python package requirements:
pip install -r requirements.txt
Training:
python vrae.py --model_dir models --do train --new True
Reconstruct:
python vrae.py --model_dir models --do reconstruct --new False --input input.txt --output output.txt
Sample (this script reads only the first line of input.txt, generates num_pts samples, and writes them to output.txt):
python vrae.py --model_dir models --do sample --new False --input input.txt --output output.txt
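How the num_pts samples are drawn is not spelled out here. In a VAE they typically come from the encoder's posterior for the input sentence, via the reparameterization trick z = mu + sigma * eps. The sketch below assumes the encoder produces a mean `mu` and a log-variance `log_var` (hypothetical names, not necessarily those used in vrae.py) and only covers the latent-sampling step; decoding each z into a sentence is not shown. It also illustrates the `probablistic` option in config.json: with the variance set to zero, every sample collapses to mu.

```python
import math
import random
from typing import List, Sequence


def sample_latents(mu: Sequence[float], log_var: Sequence[float], num_pts: int,
                   probablistic: bool = True, seed: int = 0) -> List[List[float]]:
    """Draw num_pts latent codes z = mu + sigma * eps, eps ~ N(0, I).

    mu and log_var stand in for the encoder outputs for the first line of
    input.txt (hypothetical names). With probablistic=False the variance is
    treated as zero, so every returned code equals mu.
    """
    rng = random.Random(seed)  # seeded so runs are reproducible
    samples = []
    for _ in range(num_pts):
        z = []
        for m, lv in zip(mu, log_var):
            sigma = math.exp(0.5 * lv) if probablistic else 0.0  # std = exp(log_var / 2)
            z.append(m + sigma * rng.gauss(0.0, 1.0))
        samples.append(z)
    return samples
```

For example, `sample_latents([0.5, -1.0], [0.0, -2.0], num_pts=3)` returns three 2-dimensional codes scattered around `[0.5, -1.0]`.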
Interpolate (this script requires input.txt to contain exactly two sentences; it generates num_pts interpolations between them and writes the interpolated sentences to output.txt):
python vrae.py --model_dir models --do interpolate --new False --input input.txt --output output.txt
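The paper's homotopy experiments produce interpolations by decoding evenly spaced points on the straight line between the latent codes of two sentences. A minimal sketch of the point-generation step, assuming the two latent vectors are already available as plain lists; `interpolate_latents` is a hypothetical helper, not a function from vrae.py, and whether the repo counts the two endpoints among its num_pts outputs is an assumption.

```python
from typing import List, Sequence


def interpolate_latents(z_start: Sequence[float], z_end: Sequence[float],
                        num_pts: int) -> List[List[float]]:
    """Return num_pts evenly spaced points on the line from z_start to z_end.

    Both endpoints are included (an assumption about how vrae.py counts them).
    """
    if len(z_start) != len(z_end):
        raise ValueError("latent vectors must have the same dimension")
    if num_pts < 2:
        raise ValueError("num_pts must be at least 2 to include both endpoints")
    points = []
    for i in range(num_pts):
        t = i / (num_pts - 1)  # t runs from 0.0 to 1.0
        points.append([(1.0 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return points
```

For example, `interpolate_latents([0.0, 2.0], [1.0, 0.0], 3)` returns `[[0.0, 2.0], [0.5, 1.0], [1.0, 0.0]]`; each point would then be decoded into a sentence.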
model_dir: The location of the config file config.json and the checkpoint file.
do: Accepts 4 values: train, reconstruct, sample, or interpolate.
new: create models with fresh parameters if set to True; else read model parameters from checkpoints in model_dir.
Unlike tensorflow/models/rnn/translate/translate.py, vrae.py does not take hyperparameters as command-line flags. Instead, it reads them from config.json in model_dir.
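The split between command-line flags and config.json can be sketched as follows. This uses argparse purely for illustration (vrae.py may parse its flags differently, for example with TensorFlow's flag module), and merging the model section with the section named by --do is an assumption based on the config layout described below. Note that --new is given as the string "True" or "False", so it needs an explicit string-to-boolean conversion.

```python
import argparse
import json
import os


def str2bool(value: str) -> bool:
    """Interpret "True"/"False" strings as passed on the command line."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected True or False, got %r" % value)


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the flags used in the commands above (illustrative only)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--do", required=True,
                        choices=["train", "reconstruct", "sample", "interpolate"])
    parser.add_argument("--new", type=str2bool, default=False)
    parser.add_argument("--input")
    parser.add_argument("--output")
    return parser.parse_args(argv)


def load_config(model_dir: str, mode: str) -> dict:
    """Read <model_dir>/config.json; merge the "model" section with the section for mode."""
    with open(os.path.join(model_dir, "config.json")) as f:
        config = json.load(f)
    merged = dict(config["model"])
    merged.update(config.get(mode, {}))  # mode-specific values override model values
    return merged
```

For example, `load_config(args.model_dir, args.do)` after `args = parse_args()` yields one flat dictionary of hyperparameters for the requested mode.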
Below are hyperparameters in config.json:
- model:
  - size: embedding size and encoder/decoder state size.
  - latent_dim: latent space size.
  - in_vocab_size: source vocabulary size.
  - out_vocab_size: target vocabulary size.
  - data_dir: path to the corpus.
  - num_layers: number of layers for the encoder and decoder.
  - use_lstm: whether to use LSTM cells for the encoder and decoder. BasicLSTMCell is used if set to True; else GRUCell is used.
  - buckets: a list of [input size, output size] pairs, one per bucket.
  - bidirectional: bidirectional_rnn is used if set to True.
  - probablistic: the variance is set to zero if set to False.
  - orthogonal_initializer: orthogonal_initializer is used if set to True; else uniform_unit_scaling_initializer is used.
  - iaf: inverse autoregressive flow is used if set to True.
  - activation: activation for the encoder-to-latent and latent-to-decoder layers.
    - elu: exponential linear unit.
    - prelu: parametric rectified linear unit (default).
    - None: linear.
- train:
  - batch_size: training batch size.
  - beam_size: beam size for decoding. Warning: beam search is not fully implemented yet; a NotImplementedError is raised if beam_size is set greater than 1.
  - learning_rate: learning rate passed into AdamOptimizer.
  - steps_per_checkpoint: save a checkpoint every steps_per_checkpoint steps.
  - anneal: do KL cost annealing if set to True.
  - kl_rate_rise_factor: the KL term weight is increased by this amount every steps_per_checkpoint steps.
  - max_train_data_size: limit on the size of the training data (0: no limit).
  - feed_previous: if True, only the first of decoder_inputs (the "GO" symbol) is used, and all other decoder inputs are generated by next = embedding_lookup(embedding, argmax(previous_output)). In effect, this implements a greedy decoder. It can also be used during training to emulate http://arxiv.org/abs/1506.03099. If False, decoder_inputs are used as given (the standard decoder case).
  - kl_min: the minimum information constraint. Should be a non-negative float (0 means no constraint).
  - max_gradient_norm: gradients are clipped to at most this norm.
  - word_dropout_keep_prob: probability of keeping each conditioned-on word token; tokens that are not kept are replaced with the generic unknown word token UNK. When equal to 0, the decoder sees no input.
- reconstruct:
  - feed_previous
  - word_dropout_keep_prob
- sample:
  - feed_previous
  - word_dropout_keep_prob
  - num_pts: number of samples to generate.
- interpolate:
  - feed_previous
  - word_dropout_keep_prob
  - num_pts: number of interpolated sentences to generate.
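The train options anneal, kl_rate_rise_factor, kl_min and word_dropout_keep_prob control how the KL term and the decoder inputs are treated during training. The sketch below shows one plausible reading of each in plain Python. The schedule's starting weight of 0 and cap of 1, treating kl_min as a "free bits" floor on the KL term, and the UNK token string are assumptions, not details confirmed by this README.

```python
import random
from typing import List, Sequence

UNK = "_UNK"  # hypothetical unknown-token string; the repo's vocabulary may differ


def kl_weight(step: int, steps_per_checkpoint: int, kl_rate_rise_factor: float,
              anneal: bool = True) -> float:
    """Weight on the KL term at a given training step.

    Assumed schedule: start at 0, add kl_rate_rise_factor every
    steps_per_checkpoint steps, and never exceed 1. Without annealing the
    weight is simply 1.
    """
    if not anneal:
        return 1.0
    return min(1.0, (step // steps_per_checkpoint) * kl_rate_rise_factor)


def kl_term(kl: float, kl_min: float) -> float:
    """Minimum information constraint, read as a "free bits" floor.

    Below kl_min the term is constant, so the optimizer gets no push to
    shrink the KL divergence further; kl_min = 0 leaves kl unchanged
    because KL is non-negative.
    """
    return max(kl, kl_min)


def word_dropout(tokens: Sequence[str], keep_prob: float,
                 rng: random.Random) -> List[str]:
    """Keep each decoder input token with probability keep_prob; replace the rest with UNK.

    keep_prob = 0 replaces every token, so the decoder sees no input words.
    """
    return [tok if rng.random() < keep_prob else UNK for tok in tokens]
```

Under these assumptions, the per-step objective would be reconstruction_loss + kl_weight(step, ...) * kl_term(kl, kl_min), with word_dropout applied only to the decoder inputs; the encoder still sees the full sentence.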
The Penn Treebank corpus is included in the repo. We also provide a Chinese poem corpus, its preprocessed version (to train on it, set {"model": {"data_dir": "<corpus_dir>"}} in <model_dir>/config.json to its location), and a pretrained model for it (to use it, set model_dir to its location), all of which can be found here.
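Switching an existing config to another corpus only requires changing model.data_dir. A small stdlib sketch of that edit; `set_data_dir` is a hypothetical helper, not part of the repo, and it leaves every other key in config.json untouched.

```python
import json
import os


def set_data_dir(model_dir: str, corpus_dir: str) -> dict:
    """Set {"model": {"data_dir": corpus_dir}} in <model_dir>/config.json.

    Other keys are left unchanged. Returns the updated config.
    """
    path = os.path.join(model_dir, "config.json")
    with open(path) as f:
        config = json.load(f)
    config.setdefault("model", {})["data_dir"] = corpus_dir
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    return config
```

For example, `set_data_dir("models", "<corpus_dir>")`, with `<corpus_dir>` replaced by the location of the preprocessed corpus, before running vrae.py with --model_dir models.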