
Various reinforcement learning algorithms written in JAX + Flax

Primary language: Python. License: MIT.


RL-Flax

This is an attempt to reimplement a range of reinforcement learning algorithms in JAX and Flax, with each algorithm contained in a single file.

Run Locally

Clone the project

  git clone https://github.com/MyNameIsArko/RL-Flax

Go to the project directory

  cd RL-Flax

Install basic dependencies

  pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  pip install flax tensorflow-probability
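Before running anything, a quick sanity check can confirm that the packages above are importable and show whether JAX actually picked up the GPU. This is a minimal sketch; the helper name `missing_packages` is hypothetical, and `tensorflow-probability` is checked under its import name `tensorflow_probability`.

```python
import importlib.util

import jax

# Import names of the pip packages installed above
# ("tensorflow-probability" installs the module "tensorflow_probability").
REQUIRED = ("jax", "flax", "tensorflow_probability")


def missing_packages(names=REQUIRED):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages()
print("Missing packages:", missing or "none")
# "gpu" means the CUDA build is active; "cpu" means JAX fell back to the CPU.
print("JAX backend:", jax.default_backend())
print("Devices:", jax.devices())
```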

Install environment-specific dependencies

Run the algorithm you want!

Roadmap

  • DQN

  • Rainbow DQN

  • A2C

  • A3C
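As a taste of what the first roadmap item involves, the core of DQN is a temporal-difference loss against a separate, periodically copied target network. The sketch below is a simplified stand-in, not the repository's implementation: it swaps the Flax network for a hypothetical linear Q-function (`q_values`), uses plain SGD instead of an optimizer library, and assumes a batch layout of `(obs, actions, rewards, next_obs, dones)`.

```python
import jax
import jax.numpy as jnp


def q_values(params, obs):
    # Hypothetical linear Q-function standing in for a Flax network:
    # obs [B, D] -> Q-values [B, A].
    return obs @ params["w"] + params["b"]


def dqn_loss(params, target_params, batch, gamma=0.99):
    """Mean squared TD error over one minibatch of transitions."""
    obs, actions, rewards, next_obs, dones = batch
    # Q(s, a) for the actions that were actually taken.
    q_taken = jnp.take_along_axis(q_values(params, obs), actions[:, None], axis=1)[:, 0]
    # Bootstrapped target from the target network; no bootstrap at terminal states.
    next_q = q_values(target_params, next_obs).max(axis=1)
    target = rewards + gamma * (1.0 - dones) * next_q
    return jnp.mean((q_taken - jax.lax.stop_gradient(target)) ** 2)


@jax.jit
def sgd_step(params, target_params, batch, lr=1e-2):
    """One plain gradient-descent step on the online parameters only."""
    loss, grads = jax.value_and_grad(dqn_loss)(params, target_params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Example usage with made-up shapes and a synthetic batch.
key_w, key_obs, key_next = jax.random.split(jax.random.PRNGKey(0), 3)
obs_dim, num_actions, batch_size = 4, 2, 32
params = {"w": 0.1 * jax.random.normal(key_w, (obs_dim, num_actions)),
          "b": jnp.zeros(num_actions)}
target_params = params  # in full DQN this is refreshed from params every N steps
batch = (
    jax.random.normal(key_obs, (batch_size, obs_dim)),   # obs
    jnp.arange(batch_size) % num_actions,                # actions
    jnp.ones(batch_size),                                # rewards
    jax.random.normal(key_next, (batch_size, obs_dim)),  # next_obs
    jnp.zeros(batch_size),                               # dones
)
new_params, loss = sgd_step(params, target_params, batch)
```

The squared TD error keeps the sketch short; many DQN implementations use a Huber loss instead, and Rainbow DQN layers further changes (such as multi-step targets and a distributional value head) on top of this same update.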

Contributing

Any kind of contribution is welcome!

If you know a bit of JAX + Flax and know the ins and outs of an algorithm, make a pull request. I'll gladly accept it, as this is a big project for one man.

Acknowledgements

License

MIT