This repository contains the models and experiments that are used in the article https://interpreting-rl-behavior.github.io/. The code was originally forked from https://github.com/jbkjr/train-procgen-pytorch which contains code to run the Procgen environments (https://openai.com/blog/procgen-benchmark/) and the PPO agent we interpret in this work.
This README provides instructions for how to replicate the results in our paper.
Overview of steps:
- Train agent on procgen task
- Record dataset of real agent-environment rollouts
- Train generative model on recorded dataset of real agent-environment rollouts
- Run analyses on recorded dataset of real agent-environment rollouts
- Record dataset of simulated agent-environment rollouts from the generative model
- Run analyses on the recorded simulated rollouts
- Analyse the prediction quality over time
All scripts should be run from the root dir.
First, install the package locally in editable mode with the following command:
pip install -e .
To train the agent on coinrun:
python train.py --exp_name [agent_training_experiment_name] --env_name coinrun --param_name hard-rec --num_levels 1000000 --distribution_mode hard --num_timesteps 200000000 --num_checkpoints 500
This will save training data and a model in a directory under
logs/procgen/coinrun/[agent_training_experiment_name]/
Each training run has a unique seed, so each seed gets its own directory in the above folder like so:
logs/procgen/coinrun/[agent_training_experiment_name]/[agent_training_unique_seed]
Then to plot the training curve for that training run:
python plot_training_csv.py --datapath="logs/procgen/coinrun/[agent_training_experiment_name]/[agent_training_unique_seed]"
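If you want to inspect the training curve yourself, the following is a minimal sketch of loading and plotting a training log with pandas; the CSV file name and the column names are assumptions and may differ from what your run actually logs.

```python
# Minimal sketch of plotting a training curve directly from a logged CSV.
# The file name and column names below are assumptions; check what your run
# actually wrote under logs/procgen/coinrun/[experiment]/[seed]/.
import pandas as pd
import matplotlib.pyplot as plt

log_path = "logs/procgen/coinrun/my_experiment/my_seed/log-append.csv"  # hypothetical path
df = pd.read_csv(log_path)

plt.plot(df["timesteps"], df["mean_episode_rewards"])  # assumed column names
plt.xlabel("Timesteps")
plt.ylabel("Mean episode reward")
plt.title("Agent training curve")
plt.show()
```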
You can render your trained agent to see what its behaviour looks like:
python render.py --exp_name=[agent_rendering_experiment_name] --env_name="coinrun" --distribution_mode="hard" --param_name="hard-local-dev-rec" --device="cpu" --model_file="logs/procgen/coinrun/[agent_training_experiment_name]/[agent_training_unique_seed]/[agent_name].pth"
Assuming your agent behaves as you'd like it to, we can now start interpreting it.
To begin interpretation, we need to record a dataset of agent-environment rollouts with which to train the generative model:
python record.py --exp_name [recording_experiment_name] --env_name coinrun --param_name hard-rec --num_levels 1000000 --distribution_mode hard --num_checkpoints 200 --model_file="logs/procgen/coinrun/[agent_training_experiment_name]/[agent_training_unique_seed]/[agent_name].pth" --logdir="[path_to_rollout_data_save_dir]"
For example:
python record.py --model_file=./logs/procgen/coinrun/trainhx_1Mlvls/seed_498_07-06-2021_23-26-27/model_80412672.pth --logdir=./ --env_name coinrun --param_name hard-rec-record --num_levels 1000000 --distribution_mode hard --num_checkpoints 200
Note that --logdir should have plenty of storage space (hundreds of GB).
With this recorded data, we can start to train the generative model on agent-environment rollouts:
python train_gen_model.py --agent_file=./logs/procgen/coinrun/trainhx_1Mlvls/seed_498_07-06-2021_23-26-27/model_80412672.pth --gen_mod_exp_name=dev --model_file="generative/results/rssm53_largepos_sim_penalty_extraconverterlayers/20220106_181406/model_epoch3_batch20000.pt"
That will take 1-4 days to train on a single GPU. Once it's trained, we'll record some agent-environment rollouts from the model. This will enable us to compare the simulations to the true rollouts and will help us understand our generative model (which includes the agent that we want to interpret) better. This is how we record samples from the generative model:
python record_gen_samples.py --agent_file=./logs/procgen/coinrun/trainhx_1Mlvls/seed_498_07-06-2021_23-26-27/model_80412672.pth --gen_mod_exp_name=dev --model_file="generative/results/rssm53_largepos_sim_penalty_extraconverterlayers/20220106_181406/model_epoch3_batch20000.pt"
Now we're ready to start some analysis.
The generative model is a VAE and therefore consists of an encoder and a decoder. The decoder is the part we want to interpret because it simulates agent-environment rollouts. It will therefore be informative to get a picture of what's going on inside the latent vector of the VAE, since this is the input to the decoder.
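To make the encoder/decoder split concrete, here is a toy, generic VAE sketch in PyTorch. It is purely illustrative and much simpler than the repo's actual generative model; the point is that the decoder only ever sees the latent vector, so the latent space determines everything the simulator is conditioned on.

```python
import torch
import torch.nn as nn

class ToyVAE(nn.Module):
    """Illustrative only: a generic VAE, not the repo's generative model."""
    def __init__(self, obs_dim=64, latent_dim=8):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU())
        self.to_mu = nn.Linear(32, latent_dim)
        self.to_logvar = nn.Linear(32, latent_dim)
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 32), nn.ReLU(),
                                     nn.Linear(32, obs_dim))

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.to_mu(h), self.to_logvar(h)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterisation trick
        # The decoder sees only z, which is why interpreting the latent space
        # tells us what the decoder (the simulator) is conditioned on.
        return self.decoder(z), mu, logvar
```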
Analysis of agent's hidden state
We'll next analyse the agent's hidden state with a few dimensionality reduction methods. First we precompute the dimensionality reduction analyses:
python analysis/hidden_analysis_precompute.py
This runs with 10,000 episodes (not samples); increase your request for memory and compute time to cope with more episodes. The analysis data will be saved in analysis/hx_analysis_precomp/
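For intuition, the precomputation is along the lines of the following sketch, which projects recorded hidden states onto a low-dimensional basis. The random array is a stand-in for the real recorded hidden states, and the shapes are assumptions, not the script's actual inputs.

```python
# Sketch of the kind of dimensionality reduction applied to agent hidden states.
import numpy as np
from sklearn.decomposition import PCA

# Stand-in for recorded agent hidden states, stacked as (timesteps x hidden_dim).
hx = np.random.randn(10000, 64)

pca = PCA(n_components=16)
hx_pcs = pca.fit_transform(hx)  # each hidden state expressed in PC coordinates

print("Variance explained by the first 16 PCs:", pca.explained_variance_ratio_.sum())
```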
Next we'll make some plots from the precomputed analyses of the agent's hidden states:
python analysis/hidden_analysis_plotting.py
These depict what the agent is 'thinking' during many episodes, visualised using several different dimensionality reduction and clustering methods.
Analysis of environment hidden states
As with the agent's hidden states, we first precompute the dimensionality reduction analyses, this time for the environment hidden states:
python analysis/env_h_analysis_precompute.py
This runs with 20,000 samples of length 24; increase your request for memory and compute time to cope with more samples. Then make the plots:
python analysis/env_h_analysis_plotting.py
Saliency maps calculate the gradient (averaged over noised samples) of some network quantity (e.g. the agent's value function output) with respect to inputs or intermediate network activations. We can thus calculate how important dimensions of the generated observations or agent hidden states are for the value function.
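As an illustration of the averaged-gradient idea (in the style of SmoothGrad), here is a sketch for a generic PyTorch network; the value network, observation shape, and noise level are stand-ins rather than the repo's actual saliency code.

```python
import torch
import torch.nn as nn

def smoothgrad_saliency(value_net, obs, n_samples=32, noise_std=0.1):
    """Average the gradient of a scalar output over noisy copies of the input."""
    grads = torch.zeros_like(obs)
    for _ in range(n_samples):
        noisy = (obs + noise_std * torch.randn_like(obs)).requires_grad_(True)
        value_net(noisy).sum().backward()   # scalar quantity (e.g. the value) to explain
        grads += noisy.grad
    return grads / n_samples                # saliency map with the same shape as obs

# Stand-in value network and observation, purely for illustration
value_net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))
obs = torch.rand(1, 3, 64, 64)
saliency = smoothgrad_saliency(value_net, obs)
```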
Say we wanted to generate saliency maps with respect to value and leftwards actions for the generated samples numbered 33, 39, 56, and 84. We'd use the following command:
python saliency_experiments.py --agent_file=./logs/procgen/coinrun/trainhx_1Mlvls/seed_498_07-06-2021_23-26-27/model_80412672.pth --gen_mod_exp_name=dev --model_file="generative/results/rssm53_largepos_sim_penalty_extraconverterlayers/20220106_181406/model_epoch3_batch20000.pt" --sample_ids 33 39 56 84
If we wanted to generate saliency maps for the same quantities but combine those samples into one sample by taking their mean latent-space vector (instead of iterating over each sample individually), we'd add the flag --combine_samples_not_iterate
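Conceptually, combining samples this way just means averaging their latent vectors before decoding, roughly as in this sketch (the stacked latents are a stand-in):

```python
import torch

# Illustration only: combine several samples by averaging their latent vectors
# rather than iterating over each sample individually.
latents = torch.randn(4, 512)          # stand-in latents, e.g. for samples 33, 39, 56, 84
combined_latent = latents.mean(dim=0)  # one latent vector fed to the decoder instead of four
```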
If we wanted to generate saliency maps for all samples from 0 to 100, we'd replace the --sample_ids 33 39 56 84 flag with --sample_ids 0 to 100.
After we've calculated the saliency maps, we can use them to identify the causal structure of the control algorithm used by the agent.
First we cluster the agent-environment dynamics. These clusters correspond to behaviours.
python analysis/combined_agent_env_hx_analysis_precompute.py
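For intuition, clustering dynamics into behaviours can look roughly like the sketch below, where per-timestep agent and environment hidden-state features are concatenated and clustered with k-means; the arrays, shapes, and number of clusters are assumptions, not what the script actually uses.

```python
import numpy as np
from sklearn.cluster import KMeans

# Stand-in features: per-timestep agent and environment hidden states,
# already projected down to a few components each.
agent_hx = np.random.randn(20000, 16)
env_hx = np.random.randn(20000, 16)

features = np.concatenate([agent_hx, env_hx], axis=1)  # joint agent-env description
clusters = KMeans(n_clusters=8, n_init=10).fit_predict(features)
# Each cluster label is then treated as a candidate 'behaviour'.
```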
(Now would be a good time to look at the interpretability panel since we've just generated everything it needs to run.)
Next we need to summarise the IC dynamics for each behaviour. We summarise and plot them using cross-correlation (xcorr) plots between ICs at each timestep:
python xcorr_and_xcaus_plots.py
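As a sketch of what a time-lagged cross-correlation between two IC time series involves (the IC series below are random stand-ins):

```python
import numpy as np

def xcorr_at_lag(ic_a, ic_b, lag):
    """Pearson correlation between IC a at time t and IC b at time t + lag (lag >= 0)."""
    a = ic_a[:-lag] if lag > 0 else ic_a
    b = ic_b[lag:] if lag > 0 else ic_b
    a = (a - a.mean()) / a.std()
    b = (b - b.mean()) / b.std()
    return float(np.mean(a * b))

# Stand-in IC time series for one behaviour
ic_3, ic_7 = np.random.randn(1000), np.random.randn(1000)
print(xcorr_at_lag(ic_3, ic_7, lag=1))  # how IC 3 now relates to IC 7 one step later
```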
Then we compare the magnitude and sign of the corresponding entries in the cross-correlation and Jacobian matrices to identify where gradients are consistent with the dynamics, both with and without passing gradients through the environment.
python analysis/dynamics_grads_consistency_plots.py
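A sketch of the kind of consistency check this performs, comparing the signs of corresponding entries in a cross-correlation matrix and a Jacobian matrix (both matrices below are random stand-ins):

```python
import numpy as np

# Stand-in matrices: entry (i, j) relates component j at time t to component i at time t+1.
xcorr = np.random.randn(16, 16)
jacobian = np.random.randn(16, 16)

signs_agree = np.sign(xcorr) == np.sign(jacobian)
print(f"Gradients consistent with dynamics in {signs_agree.mean():.0%} of entries")
```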
If our hypotheses about the role of different directions in hidden-state space are correct, we should be able to make predictions about how the agent should behave when those directions are altered.
We can record the hidden states while either swapping different directions in hidden-state-space or collapsing directions into the nullspace so that the agent can't use those directions.
We can use the record_informinit_gen_samples.py script to do this. By default, the CLI arguments --swap_directions_from and --swap_directions_to are empty. If we want to swap the 10th hx direction with the 12th hx direction and at the same time collapse the 5th hx direction into the nullspace, we simply add the arguments:
--swap_directions_from 10 5 --swap_directions_to 12 None
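For intuition, swapping and collapsing directions amounts to editing the hidden state's coordinates in a chosen basis of directions (e.g. the directions found earlier by dimensionality reduction). The sketch below uses a random orthonormal basis as a stand-in; it is not the script's actual implementation.

```python
import numpy as np

def swap_and_collapse(hx, directions, swap_from, swap_to):
    """Swap or remove the components of hx along chosen directions.

    `directions` is a (num_directions x hidden_dim) orthonormal basis, a stand-in
    for directions computed earlier. A target of None collapses that direction
    into the nullspace (its component is set to zero).
    """
    coords = directions @ hx                       # coordinates of hx in the basis
    new_coords = coords.copy()
    for src, dst in zip(swap_from, swap_to):
        if dst is None:
            new_coords[src] = 0.0                  # collapse direction into the nullspace
        else:
            new_coords[src], new_coords[dst] = coords[dst], coords[src]
    return directions.T @ new_coords               # back to hidden-state space

# e.g. swap direction 10 with 12 and collapse direction 5, as in the CLI example above
hx = np.random.randn(64)
basis = np.linalg.qr(np.random.randn(64, 64))[0]   # stand-in orthonormal basis
hx_altered = swap_and_collapse(hx, basis, swap_from=[10, 5], swap_to=[12, None])
```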
It's also advisable to change the directory that the recordings are saved to, so as not to overwrite previous data from the unaltered agent hx dynamics. To do this, add something like:
--data_save_dir=generative/recorded_validations_swapping