Reinforcement Learning Library in TensorFlow and OpenAI gym
NOTE: This is a work in progress, and the core structure of some parts of the library is likely to change.
The goal of this library is to provide standard implementations of core Reinforcement Learning algorithms. The library is specifically targeted at research applications and aims to provide reusable constructs that make it easy to implement new algorithms. Furthermore, it includes detailed hyperparameter and training logs and real-time training metrics plots (currently in TensorBoard; configurable matplotlib plots coming soon). The code is written in TensorFlow and supports gym-compatible environments.
All currently implemented algorithms achieve performance competitive with the results reported in the original papers (note that the default hyperparameters are not optimal for all environments).
To install, clone the repository:
git clone https://github.com/nikonikolov/rltf.git

Dependencies:
- Python >= 3.5
- TensorFlow >= 1.4.0
- OpenAI gym
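The Python dependencies can typically be installed with pip, for example (the version pin below is illustrative, not the project's official install command):

pip3 install "tensorflow>=1.4.0" gym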
Algorithm | Model | Agent | Original Paper |
---|---|---|---|
DQN | dqn.py | dqn_agent.py | DQN |
Double DQN | next | next | Double DQN |
Dueling DQN | next | next | Dueling DQN |
C51 | c51.py | dqn_agent.py | C51 |
QR-DQN | qrdqn.py | dqn_agent.py | QR-DQN |
DDPG | ddpg.py | ddpg_agent.py | DDPG |
NAF | next | next | NAF |
Other algorithms are also coming in the near future.
An implementation of an algorithm is composed of two parts: an agent and a model.

Agent:
- Should inherit from the Agent class
- Provides communication between the Model and the environment
- Responsible for stepping the environment and running the training procedure
- Manages the replay buffer (if any)

Model:
- Should inherit from the Model class
- A passive component which only implements the TensorFlow computation graph for the algorithm
- Implements the graph training procedure
- Exposes the graph input and output Tensors so they can be run by the Agent
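As a rough illustration of this split, here is a minimal sketch of the two roles (this is not rltf's actual API; all class, attribute, and method names below are assumptions made for the example):

```python
import tensorflow as tf

class ToyModel:
    """Passive component: builds the TF computation graph and exposes its
    input/output tensors. Illustrative only -- not rltf's Model class."""

    def build(self):
        # Graph inputs, exposed so the agent can feed them
        self.obs_ph    = tf.placeholder(tf.float32, [None, 4], name="obs")
        self.target_ph = tf.placeholder(tf.float32, [None],    name="target")

        # A tiny Q-network, its greedy action, and its training procedure
        q             = tf.layers.dense(self.obs_ph, 2)
        self.action   = tf.argmax(q, axis=-1)
        loss          = tf.reduce_mean(tf.square(tf.reduce_max(q, axis=-1) - self.target_ph))
        self.train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)


class ToyAgent:
    """Active component: communicates between the model and the environment.
    Illustrative only -- not rltf's Agent class."""

    def __init__(self, env, model, sess):
        self.env, self.model, self.sess = env, model, sess

    def step_and_train(self, obs, target):
        # Run the graph to pick an action, step the environment,
        # then run one training update of the model
        action = self.sess.run(self.model.action, {self.model.obs_ph: obs[None]})[0]
        next_obs, reward, done, _ = self.env.step(action)
        self.sess.run(self.model.train_op,
                      {self.model.obs_ph: obs[None], self.model.target_ph: [target]})
        return next_obs, reward, done
```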
After running any of the examples below, your logs will be saved in trained_models/<model>/<env-id>_<run-number>. If you enabled model saving, the NN and its weights will be saved in the same folder. Furthermore, the folder will contain:
- params.txt - a file containing the values of the hyperparameters used
- run.log - the runtime log of the program
- tb/ - a folder containing the TensorBoard plots of the training process
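For example, after one run the folder might look roughly like this (the model name, environment, and run number are illustrative assumptions):

```
trained_models/dqn/PongNoFrameskip-v4_1/
├── params.txt
├── run.log
└── tb/
```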
To see the TensorBoard plots, run:
tensorboard --logdir="<path/to/tb/dir>"
and then go to http://localhost:6006 in your browser.
To run the DDPG example:
python3 -m examples.run_ddpg_agent --model <model-name> --env-id <env-id>
For more details run:
python3 -m examples.run_ddpg_agent --help
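For instance, a hypothetical DDPG invocation (both the model name and the environment here are assumptions; check the --help output for the accepted values) might look like:

python3 -m examples.run_ddpg_agent --model DDPG --env-id Pendulum-v0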
To run the DQN example:
python3 -m examples.run_dqn_agent --model <model-name> --env-id <env-id>
For more details run:
python3 -m examples.run_dqn_agent --help
Note that run_dqn_agent supports only Atari environments. Moreover, it requires that the environment used is of type <env-name>NoFrameskip-v4 (e.g. PongNoFrameskip-v4). The NoFrameskip-v4 gym environments, combined with some additional wrappers, are the ones corresponding to the training process described in the original DQN Nature paper. If you want to use other environment versions, you will need to add or remove some of the environment wrappers.
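For example, a hypothetical Pong run (the model name dqn is an assumption; check the --help output for the accepted model names) would look like:

python3 -m examples.run_dqn_agent --model dqn --env-id PongNoFrameskip-v4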