JRD_EXTENSIONS

Features

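PRNGSequence

PRNGSequence removes the need to split keys by hand. It is created from a seed, and each call to next() returns a new key to consume.

  • Without extension:
import jax

def do_random_things(k):
    # Placeholder for any code that consumes a PRNG key.
    return jax.random.normal(k, (3,))

key = jax.random.key(0)
for _ in range(10):
    key, subkey = jax.random.split(key)  # carry `key`, consume `subkey`
    do_random_things(subkey)
  • With extension:
import jrd_extensions

keys = jrd_extensions.PRNGSequence(0)
for _ in range(10):
    do_random_things(next(keys))  # a new key on every call, no manual split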
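To show what the two loops have in common, here is a minimal sketch of a key iterator built only on jax.random. It assumes the sequence keeps an internal key and splits it on every next(), exactly like the manual loop. The class name KeySequence is hypothetical, and this is not necessarily how jrd_extensions implements PRNGSequence.

```python
import jax


class KeySequence:
    """Hypothetical key iterator (illustration only, not the jrd_extensions code).

    Holds an internal key and, on each next(), splits it into a new
    internal key and a fresh subkey that is handed to the caller.
    """

    def __init__(self, seed: int):
        self._key = jax.random.key(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Keep one half for future calls and return the other half.
        self._key, subkey = jax.random.split(self._key)
        return subkey


seq = KeySequence(0)
a = jax.random.normal(next(seq), (3,))
b = jax.random.normal(next(seq), (3,))  # different key, so different samples
```

Because the splitting happens inside __next__, the same seed always reproduces the same stream of keys, and the keys match those of the manual `key, subkey = jax.random.split(key)` loop.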

Installation

This package requires Python 3.10 or later and a working JAX installation. To install JAX, refer to the official JAX installation instructions.

pip install --upgrade pip
pip install git+https://github.com/Raffaelbdl/jrd_extensions