PRNGSequence
- Without the extension:
import jax
key = jax.random.key(0)
for _ in iterations:
    key, _k = jax.random.split(key)
    do_random_things(_k)
- With the extension:
import jrd_extensions
key = jrd_extensions.PRNGSequence(0)
for _ in iterations:
    do_random_things(next(key))
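For intuition, here is a minimal sketch of how an iterator like this could be built on top of `jax.random.split`. The class name `PRNGSequenceSketch` is hypothetical and this is not necessarily how `jrd_extensions.PRNGSequence` is implemented. It only illustrates the pattern of hiding the split step behind `next()`.

```python
import jax


class PRNGSequenceSketch:
    """Hypothetical stand-in for illustration, not the jrd_extensions class."""

    def __init__(self, seed: int):
        # Internal key that is advanced on every call to next().
        self._key = jax.random.key(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # Split once: keep one half as the new internal state,
        # hand the other half to the caller.
        self._key, subkey = jax.random.split(self._key)
        return subkey


keys = PRNGSequenceSketch(0)
a = jax.random.normal(next(keys), (2,))
b = jax.random.normal(next(keys), (2,))  # drawn from a fresh subkey
```

Each call to `next(keys)` returns a subkey that has not been used before, so the caller never has to carry the `key, _k = jax.random.split(key)` line around by hand.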
This package requires Python 3.10 or later and a working JAX installation. To install JAX, refer to the official JAX installation instructions.
pip install --upgrade pip
pip install git+https://github.com/Raffaelbdl/jrd_extensions