MCTS-NNET

Monte Carlo Tree Search with Reinforcement Learning for Motion Planning



This repository contains the code used in the paper "Monte Carlo Tree Search with Reinforcement Learning for Motion Planning", IEEE ITSC 2020.

The following algorithms are implemented and benchmarked:

  • Rules-based (reflex): a simple emergency-braking policy
  • Tree Search: Uniform Cost Search (A*) and Dynamic Programming
  • MPC: Model Predictive Control
  • Sampling-based Tree Search with MCTS: Monte Carlo Tree Search
  • DDQN: Double Deep Q-learning using Deep Neural Networks
  • MCTS-NNET: MCTS combined with DDQN
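The reflex baseline can be illustrated with a minimal sketch (the state fields, thresholds, and acceleration values below are illustrative, not the repo's actual implementation): brake hard when the time-to-collision with the obstacle ahead drops below a threshold, otherwise keep cruising.

```python
def emergency_brake_policy(ego_speed, gap, obstacle_speed, ttc_threshold=3.0):
    """Reflex rule: hard brake if time-to-collision falls below a threshold.

    ego_speed, obstacle_speed: longitudinal speeds in m/s
    gap: distance to the obstacle ahead in m
    Returns an acceleration command in m/s^2 (illustrative values).
    """
    closing_speed = ego_speed - obstacle_speed
    if closing_speed <= 0:           # not closing in: keep cruising
        return 0.0
    ttc = gap / closing_speed        # time to collision in s
    return -6.0 if ttc < ttc_threshold else 0.0
```

Such a rule is cheap and predictable, which is why it serves as the baseline the learned and search-based planners are measured against.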

Ultimately we find that combining MCTS planning and DQN learning in a single solution provides the best performance with real-time decisions. Here, a pre-trained DQN network is used to guide the tree search, providing fast and reliable estimates of Q-values and state values. We call this model MCTS-NNET, as it leverages the insights of AlphaGo Zero.
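The guidance step can be sketched as AlphaGo-Zero-style PUCT selection, where network priors P(s,a) bias the search toward promising actions (a minimal sketch under that assumption; the `stats`/`priors` dictionaries are illustrative, not the repo's API):

```python
import math

def puct_select(stats, priors, c_puct=1.5):
    """Pick the action maximizing Q(s,a) + c_puct * P(s,a) * sqrt(N) / (1 + N(s,a)).

    stats: {action: (visit_count, total_value)} accumulated by the tree search
    priors: {action: prior_probability} from the pre-trained network
    """
    total_n = sum(n for n, _ in stats.values())
    best_action, best_score = None, float("-inf")
    for action, (n, w) in stats.items():
        q = w / n if n > 0 else 0.0                              # mean value estimate
        u = c_puct * priors[action] * math.sqrt(total_n) / (1 + n)  # exploration bonus
        if q + u > best_score:
            best_action, best_score = action, q + u
    return best_action
```

Unvisited actions with a high prior get a large bonus, so the network effectively prunes the tree toward moves it already believes in, which is what makes the 4 ms-scale inference budget workable.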

Our results demonstrate the performance of MCTS-NNET: a 98% success rate, compared with an 84% success rate for DQN and an 85% success rate for MCTS alone. This is achieved with an inference time of 4 ms.

Presentation video of the paper

Julia source code (optimized for speed)

The code was initially developed in Python and later optimized for speed in Julia; the Julia versions are much faster than the Python ones.

cd julia
julia scn.jl mcts
julia scn.jl mpc
julia scn.jl ucs
julia scn.jl dp

Python source code

Models and algorithms:

Utilities used (from Stanford CS221 and CS230 courses):

MCTS-NNET inference

Run the baseline (reflex-based), DQN, MCTS, MPC, and MCTS-NNET on 100 tests:

cd mdp
python3 test_algo.py --algo baseline
python3 test_algo.py --algo dqn
python3 test_algo.py --algo mcts
python3 test_algo.py --algo mpc
python3 test_algo.py --algo mcts-nnet
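The five runs above can be scripted in one go (a small convenience wrapper, assuming it is launched from the `mdp` directory; not part of the repo):

```python
import subprocess

ALGOS = ["baseline", "dqn", "mcts", "mpc", "mcts-nnet"]

def build_commands(algos=ALGOS):
    """Return one test_algo.py invocation per algorithm."""
    return [["python3", "test_algo.py", "--algo", algo] for algo in algos]

if __name__ == "__main__":
    for cmd in build_commands():
        subprocess.run(cmd, check=True)  # run the 100-test benchmark for each algo
```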

Collision Avoidance Scenario

MCTS-NNET training

cd mdp
python3 train_dqn.py

Training results are stored in mdp/models; see the trained model dnn-0.31.pth.tar.
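The double-DQN target that this training scheme relies on can be sketched in plain NumPy (illustrative names and shapes; the repo's checkpoint format suggests PyTorch is used in practice). The online network picks the argmax action at the next state and the target network evaluates it, which reduces the overestimation bias of vanilla Q-learning:

```python
import numpy as np

def ddqn_targets(rewards, done, q_next_online, q_next_target, gamma=0.99):
    """Double-DQN targets: y = r + gamma * Q_target(s', argmax_a Q_online(s', a)).

    rewards, done: shape (batch,); done is 1.0 for terminal transitions
    q_next_online, q_next_target: shape (batch, n_actions)
    """
    best_actions = np.argmax(q_next_online, axis=1)                 # chosen by online net
    q_eval = q_next_target[np.arange(len(rewards)), best_actions]   # scored by target net
    return rewards + gamma * (1.0 - done) * q_eval
```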