Neural Network Dynamics for Model-Based Deep Reinforcement Learning with Model-Free Fine-Tuning
Abstract: Model-free deep reinforcement learning algorithms have been shown to be capable of learning a wide range of robotic skills, but typically require a very large number of samples to achieve good performance. Model-based algorithms, in principle, can provide for much more efficient learning, but have proven difficult to extend to expressive, high-capacity models such as deep neural networks. In this work, we demonstrate that medium-sized neural network models can in fact be combined with model predictive control (MPC) to achieve excellent sample complexity in a model-based reinforcement learning algorithm, producing stable and plausible gaits to accomplish various complex locomotion tasks. We also propose using deep neural network dynamics models to initialize a model-free learner, in order to combine the sample efficiency of model-based approaches with the high task-specific performance of model-free methods. We empirically demonstrate on MuJoCo locomotion tasks that our pure model-based approach trained on just minutes of random action data can follow arbitrary trajectories, and that our hybrid algorithm can accelerate model-free learning on high-speed benchmark tasks, achieving sample efficiency gains of 3-5x on swimmer, cheetah, hopper, and ant agents.
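As a rough illustration of combining a learned dynamics model with MPC, the sketch below implements random-shooting MPC. It samples many random action sequences, rolls each one out through the model, and executes only the first action of the cheapest sequence before replanning. Random shooting is assumed here as the planner. The function name, the toy linear `toy_dynamics` (standing in for the neural network model), the goal-reaching `toy_cost`, and all parameter values are hypothetical and are not taken from this repository.

```python
import numpy as np


def random_shooting_mpc(state, dynamics, cost, action_dim, horizon=10,
                        num_candidates=1000, action_low=-1.0,
                        action_high=1.0, rng=None):
    """Return the first action of the lowest-cost random action sequence.

    dynamics(states, actions) -> next_states, batched over candidates.
    cost(next_states, actions) -> per-candidate cost of one time step.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Candidate action sequences: (num_candidates, horizon, action_dim).
    actions = rng.uniform(action_low, action_high,
                          size=(num_candidates, horizon, action_dim))
    # Every candidate starts from the current (real) state.
    states = np.repeat(state[None, :], num_candidates, axis=0)
    total_cost = np.zeros(num_candidates)
    for t in range(horizon):
        # Roll the model forward on its own predictions.
        states = dynamics(states, actions[:, t])
        total_cost += cost(states, actions[:, t])
    best = np.argmin(total_cost)
    # Execute only the first action; the controller replans at the next step.
    return actions[best, 0]


# Toy stand-ins (assumptions): a linear "model" and a goal-reaching cost.
GOAL = np.array([1.0, 0.0])


def toy_dynamics(states, actions):
    return states + 0.1 * actions


def toy_cost(states, actions):
    return np.sum((states - GOAL) ** 2, axis=1)


first_action = random_shooting_mpc(np.zeros(2), toy_dynamics, toy_cost,
                                   action_dim=2,
                                   rng=np.random.default_rng(0))
print(first_action)
```

In a real controller this function would be called once per environment step with the newly observed state, which is what makes the loop closed-loop even though each plan is open-loop.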
- For the installation guide, see installation.md
- For notes on using your own environment, editing envs, etc., see notes.md
How to run everything
cd scripts
./swimmer_mbmf.sh
./cheetah_mbmf.sh
./hopper_mbmf.sh
./ant_mbmf.sh
Each of those scripts does something similar to the following (but for multiple seeds):
python main.py --seed=0 --run_num=1 --yaml_file='swimmer_forward'
python mbmf.py --run_num=1 --which_agent=2
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=2 --num_workers_trpo=2 --std_on_mlp_policy=0.5
python plot_mbmf.py --trpo_dir=[trpo_dir] --std_on_mlp_policy=0.5 --which_agent=2 --run_nums 1 --seeds 0
Here, [trpo_dir] is the directory where the TRPO runs are saved, probably somewhere in ~/rllab/data/...
Each of these steps is explained further in the following sections.
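The per-seed sequence above can also be generated programmatically, for example to sweep seeds by hand. This sketch only builds the command strings and does not execute them. The helper `build_pipeline` and the choice of one run_num per seed (1, 2, 3, ...) are hypothetical; the flag values mirror the swimmer example, and [trpo_dir] is left as a placeholder.

```python
def build_pipeline(yaml_file, which_agent, seeds, std=0.5, num_workers=2):
    """Build the MB -> MBMF -> TRPO -> plot command strings for several seeds."""
    commands = []
    for i, seed in enumerate(seeds):
        run_num = i + 1  # hypothetical convention: one run_num per seed
        commands += [
            f"python main.py --seed={seed} --run_num={run_num} --yaml_file='{yaml_file}'",
            f"python mbmf.py --run_num={run_num} --which_agent={which_agent}",
            f"python trpo_run_mf.py --seed={seed} --save_trpo_run_num={run_num} "
            f"--which_agent={which_agent} --num_workers_trpo={num_workers} "
            f"--std_on_mlp_policy={std}",
        ]
    # A single plotting call covers all runs and seeds.
    run_nums = " ".join(str(i + 1) for i in range(len(seeds)))
    seed_list = " ".join(str(s) for s in seeds)
    commands.append(
        f"python plot_mbmf.py --trpo_dir=[trpo_dir] --std_on_mlp_policy={std} "
        f"--which_agent={which_agent} --run_nums {run_nums} --seeds {seed_list}")
    return commands


commands = build_pipeline("swimmer_forward", which_agent=2, seeds=[0])
for cmd in commands:
    print(cmd)
```

For seeds=[0] the output reproduces the four swimmer commands shown above.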
How to run MB
Command-line arguments (specify as needed):
--yaml_file Name of the corresponding yaml config file (e.g. 'cheetah_forward')
--seed Random seed for numpy and tensorflow
--run_num Run number; determines the directory that files are saved under
--use_existing_training_data Use the training data that already exists in the run_num directory instead of recollecting it
--use_existing_dynamics_model Use the dynamics model that already exists in the run_num directory instead of retraining it
--desired_traj_type Type of trajectory to follow, for trajectory-following tasks (e.g. 'straight', 'left_turn')
--num_rollouts_save_for_mf Number of on-policy rollouts to save after the last aggregation iteration, for later use by the model-free stage
--might_render Set if you might want to visualize anything during the run
--visualize_MPC_rollout Set a breakpoint and visualize the on-policy rollouts after each aggregation iteration
--perform_forwardsim_for_vis Visualize an open-loop prediction made by the learned dynamics model
--print_minimal Suppress messages regarding progress, notes, etc.
Examples:
python main.py --seed=0 --run_num=0 --yaml_file='cheetah_forward'
python main.py --seed=0 --run_num=1 --yaml_file='swimmer_forward'
python main.py --seed=0 --run_num=2 --yaml_file='ant_forward'
python main.py --seed=0 --run_num=3 --yaml_file='hopper_forward'
python main.py --seed=0 --run_num=4 --yaml_file='cheetah_trajfollow' --desired_traj_type='straight' --visualize_MPC_rollout
python main.py --seed=0 --run_num=4 --yaml_file='cheetah_trajfollow' --desired_traj_type='backward' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=4 --yaml_file='cheetah_trajfollow' --desired_traj_type='forwardbackward' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=5 --yaml_file='swimmer_trajfollow' --desired_traj_type='straight' --visualize_MPC_rollout
python main.py --seed=0 --run_num=5 --yaml_file='swimmer_trajfollow' --desired_traj_type='left_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=5 --yaml_file='swimmer_trajfollow' --desired_traj_type='right_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=6 --yaml_file='ant_trajfollow' --desired_traj_type='straight' --visualize_MPC_rollout
python main.py --seed=0 --run_num=6 --yaml_file='ant_trajfollow' --desired_traj_type='left_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=6 --yaml_file='ant_trajfollow' --desired_traj_type='right_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
python main.py --seed=0 --run_num=6 --yaml_file='ant_trajfollow' --desired_traj_type='u_turn' --visualize_MPC_rollout --use_existing_training_data --use_existing_dynamics_model
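The --perform_forwardsim_for_vis flag visualizes an open-loop prediction: the model starts from a real state and is then fed its own predictions along a sequence of actions, so any model error compounds with the horizon. Below is a minimal sketch of that procedure, assuming a hypothetical one-step model `predict_next(state, action)`. The toy "true" and "learned" dynamics are illustrative assumptions, not the repository's model.

```python
import numpy as np


def forward_sim(predict_next, start_state, actions):
    """Open-loop multi-step prediction: each step is fed the previous prediction."""
    states = [np.asarray(start_state, dtype=float)]
    for action in actions:
        states.append(predict_next(states[-1], action))
    return np.stack(states)


def true_step(state, action):
    # Toy "real" dynamics (assumption, for illustration only).
    return state + 0.1 * action


def model_step(state, action):
    # Toy "learned" model with a 10% gain error standing in for model error.
    return state + 0.11 * action


actions = np.ones((20, 1))
true_traj = forward_sim(true_step, np.zeros(1), actions)
pred_traj = forward_sim(model_step, np.zeros(1), actions)
# Without feedback from real states, the per-step error accumulates.
errors = np.abs(pred_traj - true_traj)[:, 0]
print(errors[[0, 10, 20]])
```

This growing gap is why MPC replans from the real observed state at every step instead of trusting a long open-loop prediction.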
How to run MBMF
Command-line arguments (specify as needed):
--save_trpo_run_num Number used as part of the directory name for saving the MBMF TRPO run (use 1, 2, 3, etc. to distinguish different seeds)
--run_num Run number of the directory to load the relevant MB data from and to save the new MBMF files in
--which_agent Which agent to use (1: ant, 2: swimmer, 4: cheetah, 6: hopper)
--std_on_mlp_policy Initial std to set on the pre-initialization policy that TRPO starts from
--num_workers_trpo Number of worker threads (CPU) for TRPO to use
--might_render Set if you might want to visualize anything during the run
--visualize_mlp_policy Visualize the rollout performed by the trained policy (which will serve as the pre-initialization for TRPO)
--visualize_on_policy_rollouts Set a breakpoint and visualize the on-policy rollouts after each DAgger aggregation iteration
--print_minimal Suppress messages regarding progress, notes, etc.
--use_existing_pretrained_policy Run only the TRPO part (if the imitation learning part has already been done in the same directory)
The finished TRPO run is saved to ~/rllab/data/local/experiments/
Examples:
python mbmf.py --run_num=1 --which_agent=2 --std_on_mlp_policy=1.0
python mbmf.py --run_num=0 --which_agent=4 --std_on_mlp_policy=0.5
python mbmf.py --run_num=3 --which_agent=6 --std_on_mlp_policy=1.0
python mbmf.py --run_num=2 --which_agent=1 --std_on_mlp_policy=0.5
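In the MBMF stage, a neural network policy is trained by imitation learning (DAgger) on the model-based controller and then used to pre-initialize TRPO; --std_on_mlp_policy sets that policy's initial standard deviation. The sketch below shows one common way to parameterize such a Gaussian MLP policy, with a state-independent log-std initialized from the chosen value. The class name, layer sizes, and random weights (standing in for weights learned by imitation) are assumptions, not the code in mbmf.py; `AGENTS` just restates the --which_agent IDs listed above.

```python
import numpy as np

# Agent IDs accepted by --which_agent, as listed above.
AGENTS = {1: "ant", 2: "swimmer", 4: "cheetah", 6: "hopper"}


class GaussianMLPPolicy:
    """Hypothetical Gaussian policy: mean from an MLP, state-independent std."""

    def __init__(self, obs_dim, act_dim, std_on_mlp_policy, hidden=64, seed=0):
        rng = np.random.default_rng(seed)
        # Random weights stand in for weights learned by imitating MPC.
        self.W1 = rng.normal(0.0, 0.1, (obs_dim, hidden))
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(0.0, 0.1, (hidden, act_dim))
        self.b2 = np.zeros(act_dim)
        # Separate log-std parameter, initialized from the chosen std.
        self.log_std = np.full(act_dim, np.log(std_on_mlp_policy))

    def mean(self, obs):
        hidden = np.tanh(obs @ self.W1 + self.b1)
        return hidden @ self.W2 + self.b2

    def sample(self, obs, rng):
        # Gaussian exploration noise scaled by the current std.
        noise = rng.standard_normal(self.log_std.shape)
        return self.mean(obs) + np.exp(self.log_std) * noise


policy = GaussianMLPPolicy(obs_dim=4, act_dim=2, std_on_mlp_policy=0.5)
rng = np.random.default_rng(1)
samples = np.array([policy.sample(np.zeros(4), rng) for _ in range(20000)])
print(samples.mean(axis=0), samples.std(axis=0))
```

With a zero observation the mean action is exactly zero in this sketch, so the sample standard deviation should come out close to 0.5. A larger initial std gives TRPO more exploration at the cost of initially deviating further from the imitated behavior.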
How to run MF
Runs pure TRPO, for comparison.
Command-line arguments (specify as desired):
--seed Random seed for numpy and tensorflow
--steps_per_rollout Length of each rollout that TRPO collects
--save_trpo_run_num Number used as part of the directory name for saving the TRPO run (use 1, 2, 3, etc. to distinguish different seeds)
--which_agent Which agent to use (1: ant, 2: swimmer, 4: cheetah, 6: hopper)
--num_workers_trpo Number of worker threads (CPU) for TRPO to use
--num_trpo_iters Total number of TRPO iterations to run before stopping
The finished TRPO run is saved to ~/rllab/data/local/experiments/
Examples:
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=4 --num_workers_trpo=4
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=2 --num_workers_trpo=4
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=1 --num_workers_trpo=4
python trpo_run_mf.py --seed=0 --save_trpo_run_num=1 --which_agent=6 --num_workers_trpo=4
python trpo_run_mf.py --seed=50 --save_trpo_run_num=2 --which_agent=4 --num_workers_trpo=4
python trpo_run_mf.py --seed=50 --save_trpo_run_num=2 --which_agent=2 --num_workers_trpo=4
python trpo_run_mf.py --seed=50 --save_trpo_run_num=2 --which_agent=1 --num_workers_trpo=4
python trpo_run_mf.py --seed=50 --save_trpo_run_num=2 --which_agent=6 --num_workers_trpo=4
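The eight example commands above form a grid: two (seed, save_trpo_run_num) pairs, (0, 1) and (50, 2), crossed with the four agents. Here is a sketch that regenerates that grid, for instance to write a batch script. The helper name `trpo_commands` is hypothetical, and it only builds strings.

```python
from itertools import product


def trpo_commands(seed_runs=((0, 1), (50, 2)), agents=(4, 2, 1, 6), num_workers=4):
    """Build pure-TRPO commands for every (seed, run) pair and agent."""
    return [
        f"python trpo_run_mf.py --seed={seed} --save_trpo_run_num={run} "
        f"--which_agent={agent} --num_workers_trpo={num_workers}"
        for (seed, run), agent in product(seed_runs, agents)
    ]


commands = trpo_commands()
print("\n".join(commands))
```

Keeping a fixed pairing between seed and save_trpo_run_num makes it easy to match runs across agents when plotting.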
How to plot
- MBMF
  - Specify the command-line arguments as desired (see plot_mbmf.py)
  - Examples of running the plotting script:
cd plotting
python plot_mbmf.py --trpo_dir=[trpo_dir] --std_on_mlp_policy=1.0 --which_agent=2 --run_nums 1 --seeds 0
python plot_mbmf.py --trpo_dir=[trpo_dir] --std_on_mlp_policy=1.0 --which_agent=2 --run_nums 1 2 3 --seeds 0 70 100
Here, [trpo_dir] is the directory where the TRPO runs are saved, probably somewhere in ~/rllab/data/...
- Dynamics model training and validation losses per aggregation iteration
  IPython notebook: plotting/plot_loss.ipynb
  Example plots: docs/sample_plots/...
- Visualize a forward simulation (an open-loop multi-step prediction of the elements of the state space)
  IPython notebook: plotting/plot_forwardsim.ipynb
  Example plots: docs/sample_plots/...
- Visualize the trajectories (on-policy rollouts) per aggregation iteration
  IPython notebook: plotting/plot_trajfollow.ipynb
  Example plots: docs/sample_plots/...
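plot_mbmf.py accepts several run numbers and seeds. A common way to summarize learning curves over seeds is a per-iteration mean and standard deviation. The abstract's sample-efficiency gains suggest comparing how many samples each method needs to reach a given performance. The sketch below does both under assumed conventions. It is not the plotting script's code: the array layout (one row per seed, one column per iteration), the threshold criterion, and all numbers are toy assumptions, not results.

```python
import numpy as np


def mean_and_std(curves):
    """curves: (num_seeds, num_iters) returns -> per-iteration mean and std."""
    curves = np.asarray(curves, dtype=float)
    return curves.mean(axis=0), curves.std(axis=0)


def samples_to_threshold(mean_curve, samples_per_iter, threshold):
    """Samples used up to the first iteration whose mean return reaches threshold."""
    hits = np.nonzero(np.asarray(mean_curve) >= threshold)[0]
    if hits.size == 0:
        return None  # threshold never reached
    return int(hits[0] + 1) * samples_per_iter


# Toy numbers only (not results): two seeds per method, four iterations.
mbmf_mean, mbmf_std = mean_and_std([[2.0, 4.0, 6.0, 8.0], [4.0, 6.0, 8.0, 10.0]])
mf_mean, mf_std = mean_and_std([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])

mbmf_samples = samples_to_threshold(mbmf_mean, samples_per_iter=1000, threshold=4.0)
mf_samples = samples_to_threshold(mf_mean, samples_per_iter=1000, threshold=4.0)
speedup = mf_samples / mbmf_samples
print(mbmf_samples, mf_samples, speedup)
```

If the samples used by the model-based and imitation stages are to be counted toward MBMF, they can be added to `mbmf_samples` before taking the ratio.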