
Jax-Baseline

Jax-Baseline is a Reinforcement Learning implementation built on the JAX and Flax/Haiku libraries, mirroring the functionality of Stable-Baselines.

Features

  • 2-3 times faster than comparable PyTorch and TensorFlow implementations
  • Optimized with JAX's Just-In-Time (JIT) compilation
  • A flexible solution for both Gym and Unity ML environments
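The JIT speedup mentioned above comes from compiling a whole update step into a single XLA program. A minimal illustrative sketch (not Jax-Baseline's actual code) of jitting a TD-target computation:

```python
import jax
import jax.numpy as jnp

# Hypothetical TD(0) target computation, illustrating jax.jit;
# not Jax-Baseline's actual implementation.
@jax.jit
def td_targets(rewards, next_q, dones, gamma=0.99):
    # Bootstrapped target: r + gamma * max_a Q(s', a), cut off at episode ends.
    return rewards + gamma * (1.0 - dones) * jnp.max(next_q, axis=1)

rewards = jnp.array([1.0, 0.0])
next_q = jnp.array([[0.5, 1.5], [2.0, 0.0]])
dones = jnp.array([0.0, 1.0])
print(td_targets(rewards, next_q, dones))  # compiled to one XLA program on first call
```

Subsequent calls with the same shapes reuse the compiled program, which is where most of the speedup over eager Torch/TensorFlow loops comes from.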

Installation

pip install -r requirement.txt
pip install .

Implementation Status

  • ✔️ : Optionally implemented
  • ✅ : Implemented as the default in the original paper
  • ❌ : Not implemented yet, or not implementable
  • 💤 : Implemented but not updated for a while (working state not guaranteed)

Supported Environments

| Name | Q-Net based | Actor-Critic based | DPG based |
| --- | --- | --- | --- |
| Gymnasium | ✔️ | ✔️ | ✔️ |
| Multi-worker Gym with Ray | ✔️ | ✔️ | ✔️ |
| Unity-ML Environments | 💤 | 💤 | 💤 |

Implemented Algorithms

Q-Net based

| Name | Double [1] | Dueling [2] | PER [3] | N-step [4][5] | NoisyNet [6] | Munchausen [7] | Ape-X [8] | HL-Gauss [9] |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| DQN [10] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | |
| C51 [11] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| QRDQN [12] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | |
| IQN [13] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | | |
| FQF [14] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | | |
| SPR [15] | ✔️ | ✔️ | | | | | | |
| BBF [16] | ✔️ | ✔️ | ✔️ | | | | | |
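As a concrete example of one table column, "Double" refers to the Double DQN target (footnote 1), which selects the greedy action with the online network but evaluates it with the target network. A hedged sketch in JAX, illustrative only and not the library's implementation:

```python
import jax.numpy as jnp

# Illustrative Double DQN target; a sketch, not Jax-Baseline's actual code.
def double_dqn_targets(rewards, dones, q_online_next, q_target_next, gamma=0.99):
    # Select the greedy next action with the online network...
    best_actions = jnp.argmax(q_online_next, axis=1)
    # ...but evaluate it with the target network, reducing overestimation bias.
    next_values = jnp.take_along_axis(q_target_next, best_actions[:, None], axis=1)[:, 0]
    return rewards + gamma * (1.0 - dones) * next_values

targets = double_dqn_targets(
    rewards=jnp.array([1.0]),
    dones=jnp.array([0.0]),
    q_online_next=jnp.array([[1.0, 2.0]]),   # online net prefers action 1
    q_target_next=jnp.array([[3.0, 0.5]]),   # target net values action 1 at 0.5
)
print(targets)  # 1.0 + 0.99 * 0.5 = 1.495
```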

Actor-Critic based

| Name | Box | Discrete | IMPALA [17] |
| --- | --- | --- | --- |
| A2C [18] | ✔️ | ✔️ | ✔️ |
| PPO [19] | ✔️ | ✔️ | ✔️ [20] |
| Truly PPO (TPPO) [21] | ✔️ | ✔️ | |

DPG based

| Name | PER [3] | N-step [4][5] | Ape-X [8] |
| --- | --- | --- | --- |
| DDPG [22] | ✔️ | ✔️ | ✔️ |
| TD3 [23] | ✔️ | ✔️ | ✔️ |
| SAC [24] | ✔️ | ✔️ | |
| TQC [25] | ✔️ | ✔️ | |
| TD7 [26] | ✅ (LAP [27]) | | |

Performance Comparison

Test

To test Atari with DQN (or C51, QRDQN, IQN, FQF):

python test/run_qnet.py --algo DQN --env BreakoutNoFrameskip-v4 --learning_rate 0.0002 \
		--steps 5e5 --batch 32 --train_freq 1 --target_update 1000 --node 512 \
		--hidden_n 1 --final_eps 0.01 --learning_starts 20000 --gamma 0.995 --clip_rewards

500K steps run in about 15 minutes on Atari Breakout (≈540 steps/sec). Performance was measured on an NVIDIA RTX 3080 and an AMD Ryzen 9 5950X in a single process.

score : 9.600, epsilon : 0.010, loss : 0.181 |: 100%|███████| 500000/500000 [15:24<00:00, 540.88it/s]
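The reported throughput is self-consistent: 500,000 steps at 540.88 steps/sec works out to the 15m 24s wall time shown in the log. A quick check:

```python
# Verify the wall-clock time implied by the logged throughput.
steps = 500_000
steps_per_sec = 540.88
seconds = steps / steps_per_sec
print(f"{int(seconds // 60)}m {int(seconds % 60)}s")  # 15m 24s
```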

Footnotes

  1. Double DQN paper

  2. Dueling DQN paper

  3. PER

  4. N-step TD

  5. Rainbow DQN

  6. Noisy network

  7. Munchausen RL

  8. Ape-X

  9. HL-Gauss

  10. DQN

  11. C51

  12. QRDQN

  13. IQN

  14. FQF

  15. SPR

  16. BBF

  17. IMPALA

  18. A3C

  19. PPO

  20. IMPALA + PPO, APPO

  21. Truly PPO

  22. DDPG

  23. TD3

  24. SAC

  25. TQC

  26. TD7

  27. LAP