Jax-Baseline is a collection of Reinforcement Learning implementations built on JAX with the Flax/Haiku libraries, mirroring the functionality of Stable-Baselines.
- 2-3x faster than previous PyTorch and TensorFlow implementations
- Optimized with JAX's Just-In-Time (JIT) compilation (see the sketch below)
- Flexible support for both Gym and Unity ML environments
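As a rough illustration of the JIT claim, the pattern is to compile the whole parameter update into a single XLA call. This is a minimal generic sketch, not Jax-Baseline's actual training code; the linear Q-function and `update` helper are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

def td_loss(params, obs, target):
    # Stand-in linear Q-function; real agents use Flax/Haiku networks.
    q = obs @ params
    return jnp.mean((q - target) ** 2)

@jax.jit  # compiled to a fused XLA kernel once, then reused every step
def update(params, obs, target, lr=1e-3):
    grads = jax.grad(td_loss)(params, obs, target)
    return params - lr * grads
```

Compiling the full update this way avoids per-op Python overhead, which is where most of the speed-up over eager PyTorch/TensorFlow training loops comes from.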
```bash
pip install -r requirement.txt
pip install .
```
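After installing, an optional sanity check (plain JAX, not part of Jax-Baseline) confirms that an accelerator is visible:

```bash
python -c "import jax; print(jax.default_backend(), jax.devices())"
```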
- ✔️ : Implemented as an option
- ✅ : Implemented by default, as in the paper
- ❌ : Not implemented yet, or cannot be implemented
- 💤 : Implemented but not updated for a while (not guaranteed to work correctly now)
| Name | Q-Net based | Actor-Critic based | DPG based |
| --- | --- | --- | --- |
| Gymnasium | ✔️ | ✔️ | ✔️ |
| MultiworkerGym with Ray | ✔️ | ✔️ | ✔️ |
| Unity-ML Environments | 💤 | 💤 | 💤 |
| Name | Double [1] | Dueling [2] | PER [3] | N-step [4][5] | NoisyNet [6] | Munchausen [7] | Ape-X [8] | HL-Gauss [9] |
| --- | --- | --- | --- | --- | --- | --- | --- |
| DQN [10] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ |
| C51 [11] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| QRDQN [12] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ |
| IQN [13] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
| FQF [14] | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ |
| SPR [15] | ✅ | ✅ | ✅ | ✅ | ✅ | ✔️ | ❌ | ✔️ |
| BBF [16] | ✅ | ✅ | ✅ | ✅ | ✔️ | ✔️ | ❌ | ✔️ |
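To make the variant columns concrete, here is a hedged sketch of the Double [1] target: the online network selects the next action and the target network evaluates it. This is a generic JAX illustration with assumed argument shapes, not Jax-Baseline's internals:

```python
import jax.numpy as jnp

def double_dqn_target(q_online_next, q_target_next, reward, done, gamma=0.99):
    # q_online_next / q_target_next: (batch, n_actions) Q-values at s'
    best_action = jnp.argmax(q_online_next, axis=1)          # selection: online net
    next_q = jnp.take_along_axis(
        q_target_next, best_action[:, None], axis=1
    )[:, 0]                                                  # evaluation: target net
    return reward + gamma * (1.0 - done) * next_q
```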
| Name | Box | Discrete | IMPALA [17] |
| --- | --- | --- | --- |
| A2C [18] | ✔️ | ✔️ | ✔️ |
| PPO [19] | ✔️ | ✔️ | ✔️ [20] |
| Truly PPO (TPPO) [21] | ✔️ | ✔️ | ❌ |
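For reference, PPO's [19] clipped surrogate objective looks roughly like the following. This is a generic sketch under assumed inputs (per-sample log-probabilities and advantages), not Jax-Baseline's loss code:

```python
import jax.numpy as jnp

def ppo_clip_loss(log_prob, old_log_prob, advantage, clip_eps=0.2):
    ratio = jnp.exp(log_prob - old_log_prob)                  # pi_new / pi_old
    unclipped = ratio * advantage
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage
    return -jnp.mean(jnp.minimum(unclipped, clipped))         # negate to maximize
```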
| Name | PER [3] | N-step [4][5] | Ape-X [8] |
| --- | --- | --- | --- |
| DDPG [22] | ✔️ | ✔️ | ✔️ |
| TD3 [23] | ✔️ | ✔️ | ✔️ |
| SAC [24] | ✔️ | ✔️ | ❌ |
| TQC [25] | ✔️ | ✔️ | ❌ |
| TD7 [26] | ✅ (LAP [27]) | ❌ | ❌ |
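As one concrete example from this family, TD3 [23] stabilizes its critic with a clipped double-Q target. A minimal generic sketch (assumed inputs, not Jax-Baseline's code):

```python
import jax.numpy as jnp

def td3_target(q1_next, q2_next, reward, done, gamma=0.99):
    # q1_next / q2_next: target-critic values at (s', smoothed target action)
    next_q = jnp.minimum(q1_next, q2_next)   # pessimistic min over twin critics
    return reward + gamma * (1.0 - done) * next_q
```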
To test Atari with DQN (or C51, QRDQN, IQN, FQF):

```bash
python test/run_qnet.py --algo DQN --env BreakoutNoFrameskip-v4 --learning_rate 0.0002 \
    --steps 5e5 --batch 32 --train_freq 1 --target_update 1000 --node 512 \
    --hidden_n 1 --final_eps 0.01 --learning_starts 20000 --gamma 0.995 --clip_rewards
```
On Atari Breakout, 500K steps complete in about 15 minutes (≈540 steps/sec; 500,000 steps / 924 s ≈ 541). Performance was measured on an NVIDIA RTX 3080 and an AMD Ryzen 9 5950X in a single process.

```
score : 9.600, epsilon : 0.010, loss : 0.181 |: 100%|███████| 500000/500000 [15:24<00:00, 540.88it/s]
```