proteneer/timemachine

Consider unifying approach to PRNG state

Opened this issue · 5 comments

Thanks to @mcwitt for the thoughtful comments; migrating from #978 (comment). Would be good to discuss and adopt project-wide conventions, if possible.

Some approaches currently used:

  • some objects store and update their own random state (integrator, barostat, context, etc.)
  • some objects have random functions that accept numpy RNGs, JAX PRNG keys, or integer random seeds
  • some sites reference global numpy random state

Some possible trade-offs:

  • ease of making application code deterministic
  • ease of making certain kinds of mistakes
  • explicitness
  • simplicity

See also: https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html

mcwitt commented

Attempting to map out a decision tree:

  1. Nondeterministic functions accept PRNG state. This is the approach taken by JAX.
    Advantages: most explicit; avoids global state / side effects; allows for parallelization and sampling independent chains without creating extra instances.
    Disadvantages: significant refactoring and some boilerplate code to pass around random state arguments
    a. Use the numpy API (numpy.random.Generator)
    b. Use the JAX API (jax.random.PRNGKey).
      Advantages: compatible with JAX transformations (jit, etc.)
      Disadvantages: opens up a unique class of potential errors: forgetting to split PRNG state
  2. Instances with nondeterministic methods maintain their own PRNG state. Seed or initial PRNG state is passed to the constructor.
    Advantages: avoids global state; reduced refactoring / boilerplate compared with option (1).
    Disadvantages: awkward / inefficient to parallelize sampling or sample independent chains (requires constructing many instances that differ only in seed, or additional interface to mutate the random state of an instance); not compatible (?) with JAX
  3. Use global PRNG state. This is the approach taken by the original numpy.random API.
    Advantages: simple; no boilerplate, trivial refactoring.
    Disadvantages: global state; need to be very careful about side effects, e.g. setting seed in one place unintentionally affecting downstream results; not compatible (?) with JAX
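
To make the trade-offs more concrete, here is a minimal sketch of what each option might look like for a hypothetical Gaussian proposal function (the names and signatures below are illustrative only, not taken from the codebase):

```python
import numpy as np
import jax
import jax.numpy as jnp

# (1a) nondeterministic function accepts a numpy Generator
def propose_np(x, rng):
    return x + rng.normal(0.0, 1.0, size=x.shape)

# (1b) nondeterministic function accepts a JAX PRNG key
def propose_jax(x, key):
    return x + jax.random.normal(key, shape=x.shape)

# (2) instance owns its PRNG state; seed passed to the constructor
class GaussianProposal:
    def __init__(self, seed):
        self._rng = np.random.default_rng(seed)

    def propose(self, x):
        return x + self._rng.normal(0.0, 1.0, size=x.shape)

# (3) global numpy state (legacy numpy.random API)
def propose_global(x):
    return x + np.random.randn(*x.shape)

x = np.zeros(3)
print(propose_np(x, np.random.default_rng(2023)))           # (1a)
print(propose_jax(jnp.zeros(3), jax.random.PRNGKey(2023)))  # (1b)
print(GaussianProposal(seed=2023).propose(x))               # (2)
np.random.seed(2023)
print(propose_global(x))                                    # (3)
```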

Thanks for mapping these out.

By default I lean towards (2) due to familiarity, because it imposes looser requirements (maybe one class uses cuRAND, a different class only uses numpy, ...), and since it seems harder to make certain kinds of errors (forgetting to split keys etc.). But (1) does seem cleaner.

An additional practice that may be compatible with all of the above options is to implement a random function random_f(x) by composing a deterministic function f(x, gaussian_noise) (whose implementation is RNG-agnostic) with random generation gaussian_noise = rng.normal(0, 1, x.shape) or gaussian_noise = np.random.randn(*x.shape) or ...
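
A minimal sketch of that pattern, with a toy update rule (the 0.1 step size is a placeholder, and f / random_f are just the names from the sentence above):

```python
import numpy as np

def f(x, gaussian_noise):
    # deterministic given its inputs; knows nothing about any RNG
    return x + 0.1 * gaussian_noise  # placeholder update rule

def random_f(x, rng):
    # random generation lives here; works with any source of standard normals
    gaussian_noise = rng.normal(0.0, 1.0, size=x.shape)
    return f(x, gaussian_noise)

x = np.zeros(10)
print(random_f(x, np.random.default_rng(0)))
# the same f also composes with the legacy global API:
print(f(x, np.random.randn(*x.shape)))
```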

mcwitt commented

and since it seems harder to make certain kinds of errors (forgetting to split keys etc.)

That's a good point that forgetting to split keys would be a class of error unique to (1b) (added to "disadvantages" above).
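
For illustration, a small JAX sketch of that failure mode: reusing a key silently produces identical draws, whereas splitting yields independent ones:

```python
import jax

key = jax.random.PRNGKey(0)

# pitfall: reusing the same key gives identical "random" draws
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))
assert (a == b).all()  # identical, almost certainly not what was intended

# intended usage: split the key before each independent draw
key, subkey_1, subkey_2 = jax.random.split(key, 3)
c = jax.random.normal(subkey_1, shape=(3,))
d = jax.random.normal(subkey_2, shape=(3,))
assert not (c == d).all()
```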

An additional practice that may be compatible with all of the above options is to implement a random function random_f(x) by composing a deterministic function f(x, gaussian_noise) (whose implementation is RNG-agnostic) with random generation

This seems similar in spirit to option (1) to me, but does have the benefit of being RNG-agnostic. It seems like it does introduce some additional room for error, though. E.g. the caller must be careful to ensure that the input randomness has the expected distribution, and because generation is decoupled from transformation, it might be harder to keep in sync.

It seems like it does introduce some additional room for error, though. E.g. the caller must be careful to ensure that the input randomness has the expected distribution, and because generation is decoupled from transformation, it might be harder to keep in sync.

That's true, and this is also a realistic concern in the context of reweighting, where we might have a deterministic function f(theta, samples_from_theta_0), expecting input randomness that may be very complicated, expensive, or failure-prone to generate.
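
As a toy illustration of that shape of function (a self-normalized importance-reweighting estimator over a placeholder Gaussian model; nothing here is from timemachine):

```python
import numpy as np

def log_prob(theta, x):
    # placeholder model: unit-variance Gaussian with mean theta
    return -0.5 * (x - theta) ** 2

def reweighted_mean(theta, theta_0, samples_from_theta_0):
    # deterministic given its inputs: the (possibly expensive) sampling at
    # theta_0 happened elsewhere; this function only reuses those samples
    log_w = log_prob(theta, samples_from_theta_0) - log_prob(theta_0, samples_from_theta_0)
    w = np.exp(log_w - log_w.max())
    return np.sum(w * samples_from_theta_0) / np.sum(w)

rng = np.random.default_rng(0)
samples = rng.normal(0.0, 1.0, size=100_000)  # samples at theta_0 = 0.0
print(reweighted_mean(theta=0.5, theta_0=0.0, samples_from_theta_0=samples))  # ~0.5
```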

So it doesn't get lost: some further observations from @mcwitt in #1128 (comment)