This repository provides a TensorFlow 2.x implementation of Generative Adversarial Imitation Learning (GAIL) and Behavioural Cloning (BC) for the classic CartPole-v0, MountainCar-v0, and Acrobot-v1 environments from OpenAI Gym, based on *Generative Adversarial Imitation Learning* by Jonathan Ho & Stefano Ermon. Expert policies are generated with PPO (recommended) or SARSA.
- python: 3.8.10
- pyglet: 1.5.21
- TensorFlow: 2.6.0
- gym: 0.21.0
- ffmpeg (for SARSA)
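These pinned versions can typically be installed with pip (ffmpeg comes from your system package manager, e.g. apt):

```
pip install pyglet==1.5.21 tensorflow==2.6.0 gym==0.21.0
```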
- CartPole-v0
- MountainCar-v0
- Acrobot-v1
For all of these environments:
- State: Continuous
- Action: Discrete
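A quick way to confirm this, assuming Gym 0.21's API:

```python
import gym

# Print the state (observation) and action spaces for each environment.
for env_name in ["CartPole-v0", "MountainCar-v0", "Acrobot-v1"]:
    env = gym.make(env_name)
    # observation_space is a continuous Box; action_space is Discrete
    print(env_name, env.observation_space, env.action_space)
    env.close()
```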
Important: ensure that your working directory is `gail-ppo-tf-gym`. This affects how model checkpoints and other important files are saved.
Use `python3 {}.py --help` for help with any file's arguments.
High-performing models are considered 'experts'. The expert thresholds for each environment are as follows (a sketch of this check appears after the list):
- CartPole-v0: score of >=195 in 100 consecutive rollouts of policy (from OpenAI)
- MountainCar-v0: score of >=-110 in 100 consecutive rollouts of policy (from OpenAI)
- Acrobot-v1: score of >=-100 in 100 consecutive rollouts of policy (chosen based on known algorithms' performance)
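As a minimal sketch of what checking such a threshold could look like (the score bookkeeping below is a hypothetical illustration, not this repo's exact code):

```python
from collections import deque

# Expert thresholds from the list above (mean score over 100 consecutive rollouts).
EXPERT_THRESHOLDS = {
    "CartPole-v0": 195.0,
    "MountainCar-v0": -110.0,
    "Acrobot-v1": -100.0,
}

def is_expert(scores, env_name, window=100):
    """Return True once the mean of the last `window` episode scores
    meets the environment's expert threshold."""
    if len(scores) < window:
        return False
    recent = list(scores)[-window:]
    return sum(recent) / window >= EXPERT_THRESHOLDS[env_name]

scores = deque(maxlen=100)  # append each finished episode's total reward here
```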
Except for SARSA (which always renders), the environment is rendered only if the agent passes the 'expert' threshold in at least one iteration. BC is never rendered.
BC and SARSA always save their models when run. By default, GAIL and PPO save their models only if the policy reaches expert performance; this can be bypassed with the `--force_save_final_model` flag.
PPO or SARSA can be used to generate the expert trajectory data. First, train the expert model with `run_ppo.py` or `run_sarsa.py`:
```
python3 run_ppo.py --env CartPole-v0
python3 run_sarsa.py --env CartPole-v0
```
For PPO, if the algorithm prints `Clear!! Model saved.`, the expert threshold for this task has been passed and we can continue. A CartPole-v0 expert is already saved in this repo under `trained_models/ppo`.
SARSA experts for all three environments are saved under `trained_models/sarsa`. These are Python `pickle` files storing a `(theta, rnd_seq)` tuple, which parameterizes the trained SARSA model. For SARSA, the model is always saved.
All trained models are saved under `trained_models/{name_of_algo}/model_{env_name}.ckpt`.
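As a sketch, loading one of these SARSA pickles might look like the following (the path follows the convention above; unpacking the file directly as `(theta, rnd_seq)` is an assumption about the exact layout):

```python
import pickle

# The SARSA .ckpt file is a pickled (theta, rnd_seq) tuple (see above).
with open("trained_models/sarsa/model_CartPole-v0.ckpt", "rb") as f:
    theta, rnd_seq = pickle.load(f)

print(type(theta), type(rnd_seq))
```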
Next, sample expert trajectories with `sample_trajectory.py`. For example:

```
python3 sample_trajectory.py --env CartPole-v0 --model trained_models/ppo/model_CartPole-v0.ckpt
python3 sample_trajectory.py --env CartPole-v0 --model trained_models/sarsa/model_CartPole-v0.ckpt
```
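Conceptually, this step rolls the expert policy out in the environment and logs every (observation, action) pair. A generic sketch, where `expert_act` is a hypothetical stand-in for the loaded expert model and the Gym 0.21 API is assumed:

```python
import gym
import numpy as np

def sample_trajectories(env_name, expert_act, num_episodes=50):
    """Roll out `expert_act` and collect (observation, action) pairs.
    `expert_act` is a hypothetical callable mapping obs -> discrete action."""
    env = gym.make(env_name)
    observations, actions = [], []
    for _ in range(num_episodes):
        obs, done = env.reset(), False
        while not done:
            act = expert_act(obs)
            observations.append(obs)
            actions.append(act)
            obs, _, done, _ = env.step(act)  # gym 0.21 returns a 4-tuple
    env.close()
    return np.array(observations), np.array(actions)
```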
This step saves the expert's (the PPO or SARSA policy's) observations and actions under:
- `trajectory/observations_{env_name}.csv`
- `trajectory/actions_{env_name}.csv`

GAIL expects the observations and actions to be stored at these paths.
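A minimal sketch of reading these files back, assuming they are plain comma-separated numeric arrays:

```python
import numpy as np

env_name = "CartPole-v0"
# Paths follow the naming scheme above; the "," delimiter is an assumption.
observations = np.genfromtxt(f"trajectory/observations_{env_name}.csv", delimiter=",")
actions = np.genfromtxt(f"trajectory/actions_{env_name}.csv", delimiter=",")

print(observations.shape, actions.shape)
```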
Train GAIL on the PPO expert:

```
python3 run_gail.py --env CartPole-v0 --trajectorydir trajectory/ppo
```

Train GAIL on the SARSA expert:

```
python3 run_gail.py --env CartPole-v0 --trajectorydir trajectory/sarsa
```

Run BC on the PPO expert:

```
python3 run_behavior_clone.py --env CartPole-v0 --trajectorydir trajectory/ppo
```

Run BC on the SARSA expert:

```
python3 run_behavior_clone.py --env CartPole-v0 --trajectorydir trajectory/sarsa
```
Run GAIL trained on the PPO expert:

```
python3 test_policy.py --env CartPole-v0 --alg gail/ppo
```

Run GAIL trained on the SARSA expert:

```
python3 test_policy.py --env CartPole-v0 --alg gail/sarsa
```

For example, to test BC trained on the SARSA expert:

```
python3 test_policy.py --alg bc/sarsa --model 1000
```