RAS

AISTATS 2019: Reference-based Adversarial Sampling and its applications to Soft Q-learning


Adversarial Learning of a Sampler Based on an Unnormalized Distribution

The RAS (Reference-based Adversarial Sampling) algorithm enables adversarial learning for sampling from a general unnormalized distribution, with demonstrations on constrained-domain sampling and soft Q-learning. This repository contains source code to reproduce the results presented in the paper Adversarial Learning of a Sampler Based on an Unnormalized Distribution (AISTATS 2019):

@inproceedings{Li_RAS_2019_AISTATS,
  title={Adversarial Learning of a Sampler Based on an Unnormalized Distribution},
  author={Chunyuan Li and Ke Bai and Jianqiao Li and Guoyin Wang and Changyou Chen and Lawrence Carin},
  booktitle={AISTATS},
  year={2019}
}

Introduction

Comparison of RAS and GAN learning scenarios

Learning a neural sampler q to approximate the target distribution p, where only p's unnormalized form u (RAS) or empirical samples p' (GAN) are available, respectively.

|  | RAS | GAN |
| --- | --- | --- |
| Method | We propose a "reference" p_r to bridge neural samples from q and the unnormalized form u, making both the F_1 and F_2 terms feasible to evaluate | Directly match neural samples from q to empirical samples p' |
| Setup | Learning from the unnormalized form u | Learning from empirical samples p' |
| Discriminator | q vs. p_r | q vs. p' |
| Application to reinforcement learning | Soft Q-learning: learning to take optimal actions based on Q-functions | GAIL: learning to take optimal actions based on expert sample trajectories (a.k.a. imitation learning) |

Discussion

  1. In many applications (e.g. soft Q-learning), only u is known, and we are interested in drawing samples from it efficiently; see the sketch below.
  2. The choice of p_r affects learning, so it should be chosen carefully.
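
To make the reference idea concrete, here is a minimal, hypothetical sketch (not the repo's implementation) for a 1-D target. Writing p = u/Z, one has KL(q||p) = KL(q||p_r) + E_q[log p_r(x) - log u(x)] + log Z: only the first term needs adversarial (density-ratio) estimation against the tractable reference p_r, while the second is computed directly from u and p_r. This loosely mirrors the F_1/F_2 split in the table above; the mixture target below is an invented example.

```python
import numpy as np
from scipy import stats

def log_u(x):
    # Example unnormalized target: a two-mode Gaussian mixture (illustrative).
    return np.logaddexp(stats.norm.logpdf(x, -2.0, 0.5),
                        stats.norm.logpdf(x, 2.0, 0.8))

p_r = stats.norm(0.0, 2.0)  # tractable reference distribution (Gaussian here)

def sampler_objective(q_samples, est_log_q_minus_log_pr):
    """Loss minimized by the neural sampler q (up to the constant log Z).

    est_log_q_minus_log_pr: estimate of log q(x) - log p_r(x) at q_samples;
    in RAS this ratio comes from a discriminator trained to tell q from p_r.
    """
    f1 = np.mean(est_log_q_minus_log_pr)                 # ~ KL(q || p_r)
    f2 = np.mean(p_r.logpdf(q_samples) - log_u(q_samples))
    return f1 + f2                                       # = KL(q || p) - log Z
```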

Contents

There are three steps to use this codebase to reproduce the results in the paper.

  1. Dependencies

  2. Experiments

    2.1. Adversarial Soft Q-learning

    2.2. Constrained Domain Sampling

    2.3. Entropy Regularization

  3. Reproduce paper figure results

Dependencies

This code is based on Python 2.7, with TensorFlow 1.7.0 as the main dependency. The experiments additionally use numpy, scipy and gensim (cPickle and math ship with Python 2.7).
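
For example, assuming the pip package names match (versions other than TensorFlow are not pinned by this README):

pip install tensorflow==1.7.0 numpy scipy gensim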

Adversarial Soft Q-learning

We consider the following environments: Hopper, Half-cheetah, Ant, Walker, Swimmer and Humanoid. All soft Q-learning code is in the sql folder.

To run:

python mujoco_all_sql.py --env Hopper

It takes the following options (among others) as arguments:

  • --env specifies the MuJoCo/rllab environment; default: Hopper.
  • --log_dir specifies the directory in which to save the training log.
  • For other arguments, please refer to the GitHub repo soft-q-learning.

Other hyper-parameter settings are located in sql/examples/mujoco_all_sql.py. The default reference distribution is the Beta distribution; the reference-distribution option supports "beta" (Beta distribution) and "norm" (Gaussian distribution).
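
As a rough illustration of what the two options mean (a hypothetical helper, not the repo's API), a Beta reference stays inside a bounded action range by construction, while a Gaussian reference has unbounded support; the shape parameters below are arbitrary:

```python
import numpy as np

def sample_reference(kind, size, low=-1.0, high=1.0):
    """Illustrative "beta" / "norm" reference samplers over an action range."""
    if kind == "beta":
        # Beta(2, 2) rescaled from (0, 1) to the bounded range [low, high]
        return low + (high - low) * np.random.beta(2.0, 2.0, size)
    if kind == "norm":
        # Gaussian centered on the range; its support is all of R
        return np.random.normal((low + high) / 2.0, (high - low) / 4.0, size)
    raise ValueError("unknown reference kind: %s" % kind)
```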

Results: Swimmer (rllab), Humanoid (rllab), Hopper-v1, Half-cheetah-v1, Ant-v1, Walker-v1.

Note: Humanoid has a higher-dimensional action space, which makes adversarial learning unstable; more work is needed to make Humanoid run better.

Constrained Domain Sampling

To show that RAS can draw samples when the support is bounded, we apply it to sample from distributions with support [c1, c2]. Please see the code at constrained_sampling.
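
A toy sketch (ours, not the repo's code) of why a bounded reference is natural here: rescaled to [c1, c2], a Beta reference puts no mass outside the support, and squashing the sampler's output keeps the neural samples feasible as well.

```python
import numpy as np

def beta_reference(c1, c2, a=2.0, b=2.0, size=1000):
    # Reference samples confined to [c1, c2] by construction.
    return c1 + (c2 - c1) * np.random.beta(a, b, size)

def squash_to_support(raw, c1, c2):
    # Map the sampler's unbounded output into (c1, c2) with a sigmoid.
    return c1 + (c2 - c1) / (1.0 + np.exp(-raw))
```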

Results: RAS (Beta reference), RAS (Gaussian reference), SVGD, amortized SVGD.

Please note that RAS with a Gaussian reference recovers AVB-AC (Adversarial Variational Bayes with Adaptive Contrast).

Entropy Regularization

An entropy term H(x) is approximated to stabilize adversarial training. As examples, we consider regularizing the following GAN variants: GAN, SN-GAN, D2GAN and Unrolled-GAN. All entropy-regularization code is in the entropy folder.
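
As a sketch of the idea (using a standard Kozachenko-Leonenko nearest-neighbor estimator as a stand-in for the paper's approximation of H), the generator loss is penalized when its samples have low entropy, discouraging mode collapse:

```python
import numpy as np
from scipy.spatial import cKDTree
from scipy.special import digamma, gammaln

def knn_entropy(x, k=3):
    """Kozachenko-Leonenko k-NN entropy estimate (nats) for samples x: (n, d)."""
    n, d = x.shape
    # Distance from each point to its k-th nearest neighbor (self excluded).
    dist, _ = cKDTree(x).query(x, k=k + 1)
    r = dist[:, -1]
    log_vd = (d / 2.0) * np.log(np.pi) - gammaln(d / 2.0 + 1.0)  # log unit-ball volume
    return digamma(n) - digamma(k) + log_vd + d * np.mean(np.log(r + 1e-12))

# Sketch of the regularized objective (lam > 0 encourages diverse samples):
#   generator_loss = gan_loss - lam * knn_entropy(fake_samples)
```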

To run:

python run_test.py --model gan_cc

It takes the following options (among others) as arguments:

  • --model specifies the GAN variant to apply the entropy regularizer to. It supports [gan, d2gan, ALLgan, SNgan]; default: gan. To apply the entropy regularizer, change the --model argument to one of [gan_cc, d2gan_cc, ALLgan_cc, SNgan_cc].

Entropy regularizer on the 8-GMM toy dataset: SN-GAN vs. SN-GAN + entropy.

Reproduce paper figure results

Jupyter notebooks in the plots folder are used to reproduce the paper's figure results.

Note that the notebooks already contain our extracted results, so running them without modification will output the figures as they appear in the paper. If you have run your own training and wish to plot those results, you will have to organize them in the same format first.

Questions?

Please drop us (Chunyuan, Ke, Jianqiao or Guoyin) a line if you have any questions.