hyper-optimized alpha-zero implementation with ray + cython for speed
train an agent that beats random actions and pure MCTS in 2 minutes
train.py: distributed training with ray
ctree/: mcts nodes in cython (node.py = pure python)
mcts.py: mcts playouts
network.py: neural net stuff
board.py: gomoku board
- ray distributed parts (train.py), sketched below:
  - one distributed replay buffer
  - N self-play actors with the 'best model' weights, which play games and store data in the replay buffer
  - M 'candidate models', which pull from the replay buffer and train
    - each iteration they play against the 'best model'; if they win, the 'best model' weights are updated
    - includes write/evaluation locks on the 'best weights'
  - 1 best model weights store (PS / parameter server)
    - stores the best weights, which are retrieved by the self-play actors and updated when candidates win
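a minimal sketch of how these parts could be wired together with ray. class and function names here are illustrative, not the actual train.py API; `play_one_game`, `train_step`, `beats_best`, and `initial_weights` are hypothetical stand-ins. the sketch leans on the fact that a ray actor runs one method at a time, which plays the role of the write/evaluation lock:

```python
import random
import ray

ray.init()

@ray.remote
class ReplayBuffer:
    # distributed replay buffer: self-play actors write, candidates sample
    def __init__(self, capacity=100_000):
        self.data = []
        self.capacity = capacity

    def add(self, samples):
        self.data.extend(samples)
        self.data = self.data[-self.capacity:]

    def sample(self, batch_size):
        return random.sample(self.data, min(batch_size, len(self.data)))

@ray.remote
class ParameterServer:
    # holds the 'best model' weights; actor methods run serially,
    # standing in for the write/evaluation lock
    def __init__(self, weights):
        self.weights = weights

    def get_weights(self):
        return self.weights

    def set_weights(self, weights):
        self.weights = weights

@ray.remote
def self_play(ps, buffer, num_games):
    # pull the current best weights, self-play, push data to the buffer
    for _ in range(num_games):
        weights = ray.get(ps.get_weights.remote())
        samples = play_one_game(weights)  # hypothetical mcts self-play fn
        buffer.add.remote(samples)

@ray.remote
def candidate(ps, buffer, num_iters):
    # train on buffer samples; on beating the best model, promote weights
    weights = ray.get(ps.get_weights.remote())
    for _ in range(num_iters):
        batch = ray.get(buffer.sample.remote(256))
        weights = train_step(weights, batch)  # hypothetical sgd step
        if beats_best(weights, ray.get(ps.get_weights.remote())):  # hypothetical eval match
            ps.set_weights.remote(weights)

buffer = ReplayBuffer.remote()
ps = ParameterServer.remote(initial_weights())  # hypothetical init
jobs = [self_play.remote(ps, buffer, 100) for _ in range(8)]   # N = 8
jobs += [candidate.remote(ps, buffer, 50) for _ in range(2)]   # M = 2
ray.get(jobs)
```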
- cython impl (node sketch below):
  - ctree/: c++/cython mcts
  - node.py: pure python mcts
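for reference, a minimal pure-python node with alpha-zero's PUCT selection rule. this is a sketch of the idea the cython version speeds up, not the actual node.py:

```python
import math

class Node:
    # one mcts node: prior from the network, visit stats for PUCT selection
    def __init__(self, prior):
        self.prior = prior        # P(s, a) from the policy head
        self.visits = 0           # N(s, a)
        self.value_sum = 0.0      # W(s, a)
        self.children = {}        # action -> Node

    def value(self):
        # Q(s, a) = W / N, zero for unvisited nodes
        return self.value_sum / self.visits if self.visits else 0.0

    def select(self, c_puct=5.0):
        # pick the (action, child) maximizing Q + U (the PUCT rule)
        sqrt_total = math.sqrt(self.visits)
        def score(item):
            _, child = item
            u = c_puct * child.prior * sqrt_total / (1 + child.visits)
            return child.value() + u
        return max(self.children.items(), key=score)

    def expand(self, priors):
        # priors: dict of action -> probability from the network
        for action, p in priors.items():
            self.children[action] = Node(p)

    def backup(self, value):
        self.visits += 1
        self.value_sum += value
```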
-- todos --
- jax network impl (sketch below)
- tpu + gpu support
- saved model weights
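for the jax todo, one possible shape for the network: a tiny mlp policy/value head in pure jax. board size, layer sizes, and the mlp itself are assumptions (a real alpha-zero net would use conv/resnet blocks):

```python
import jax
import jax.numpy as jnp

def init_params(rng, board_size=15, hidden=128):
    # assumed sizes; 15x15 gomoku board flattened to a vector
    k1, k2, k3 = jax.random.split(rng, 3)
    n = board_size * board_size
    return {
        "w1": jax.random.normal(k1, (n, hidden)) * 0.01,
        "w_pi": jax.random.normal(k2, (hidden, n)) * 0.01,
        "w_v": jax.random.normal(k3, (hidden, 1)) * 0.01,
    }

@jax.jit
def forward(params, board):
    # board: flat (board_size * board_size,) array
    h = jax.nn.relu(board @ params["w1"])
    pi = h @ params["w_pi"]              # move logits, one per square
    v = jnp.tanh(h @ params["w_v"])[0]   # scalar value in [-1, 1]
    return pi, v
```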