blackjax-devs/blackjax

Improvements to `run_inference_algorithm`

reubenharry opened this issue

Current behavior

  1. `run_inference_algorithm` can take either an initial position or an initial state, and a try-except handler decides which one it was given. This handler is somewhat unreliable (e.g. the expected exception is not thrown when other exceptions occur). Moreover, it seems more modular to delegate the transformation from `initial_position` to `initial_state` to the caller of `run_inference_algorithm`.
  2. `run_inference_algorithm` stores and returns all n samples. For high-dimensional problems this is memory-inefficient.
  3. `transform` currently applies only to the state and not to the `Info`, so there is no way to discard part of the diagnostic info.
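The snippet below is a hypothetical, pure-Python illustration of one way a try-except dispatch between "position" and "state" inputs can misfire: an unrelated bug inside `init` raises the same exception type the handler is waiting for, so the bug is silently masked. `ToyState`, `init` and `run_with_dispatch` are invented for illustration and are not blackjax code; the last part shows the caller-side initialisation proposed in item 1.

```python
class ToyState:
    """Hypothetical stand-in for an algorithm state (not a blackjax type)."""

    def __init__(self, position):
        self.position = position


def init(position):
    """Hypothetical init containing an unrelated bug that raises TypeError."""
    return ToyState(position + "oops")  # bug: float + str raises TypeError


def run_with_dispatch(initial_state_or_position):
    """Guess whether the input is a position by trying init first."""
    try:
        return init(initial_state_or_position)
    except TypeError:
        # Any TypeError is read as "this is already a state", including the bug above.
        return initial_state_or_position


result = run_with_dispatch(1.0)
print(type(result).__name__)  # prints "float": the raw position is passed on as a "state"

# Caller-side init (item 1 of the desired behavior): the same bug now surfaces immediately.
try:
    initial_state = init(1.0)
except TypeError as err:
    surfaced = str(err)
print(surfaced)  # unsupported operand type(s) for +: 'float' and 'str'
```

With caller-side initialisation, `run_inference_algorithm` no longer needs to guess the input type, so no exception can be misread as a dispatch signal.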

Desired behavior

  1. Make `run_inference_algorithm` take only `initial_state`.
  2. Allow `run_inference_algorithm` to run in a memory-efficient mode, where it computes a running average of a desired expectation instead of storing every sample.
  3. Change `transform` to also take `Info` as an argument. This will be a breaking change (more so than (1)).
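A minimal JAX sketch of what the three proposed changes could look like together, under stated assumptions: `run_sketch` and `run_streaming_sketch` are hypothetical names, not blackjax API, and `toy_step` is a toy random-walk Metropolis kernel on a standard normal target that stands in for a real blackjax kernel. `run_sketch` takes only an initial state (item 1) and passes both state and info to `transform` (item 3); `run_streaming_sketch` keeps only the standard incremental mean of `expectation(state, info)` (item 2).

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    position: jax.Array  # hypothetical state: just a position vector


class ToyInfo(NamedTuple):
    accepted: jax.Array  # hypothetical diagnostic: was the proposal accepted?


def toy_step(rng_key, state):
    """Hypothetical random-walk Metropolis step targeting a standard normal."""
    prop_key, accept_key = jax.random.split(rng_key)
    proposal = state.position + 0.5 * jax.random.normal(prop_key, state.position.shape)
    # log p(proposal) - log p(current) for p = N(0, I)
    log_ratio = 0.5 * (jnp.sum(state.position**2) - jnp.sum(proposal**2))
    accepted = jnp.log(jax.random.uniform(accept_key)) < log_ratio
    position = jnp.where(accepted, proposal, state.position)
    return ToyState(position), ToyInfo(accepted)


def run_sketch(rng_key, step, initial_state, num_steps,
               transform=lambda state, info: (state, info)):
    """Items 1 and 3: only an initial state; transform sees (state, info)."""
    def one_step(state, key):
        state, info = step(key, state)
        return state, transform(state, info)  # only the transformed output is stacked

    keys = jax.random.split(rng_key, num_steps)
    return jax.lax.scan(one_step, initial_state, keys)  # (final_state, history)


def run_streaming_sketch(rng_key, step, initial_state, num_steps, expectation):
    """Item 2: carry a running mean of expectation(state, info), no history."""
    # Zeros with the same structure, shapes and dtypes as one expectation value.
    shapes = jax.eval_shape(lambda k, s: expectation(*step(k, s)), rng_key, initial_state)
    zeros = jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), shapes)

    def one_step(carry, xs):
        state, avg = carry
        key, n = xs  # n is the 1-based sample count
        state, info = step(key, state)
        value = expectation(state, info)
        # Incremental mean: avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n
        avg = jax.tree_util.tree_map(lambda a, x: a + (x - a) / n, avg, value)
        return (state, avg), None  # nothing is stacked per step

    keys = jax.random.split(rng_key, num_steps)
    counts = jnp.arange(1, num_steps + 1)
    (final_state, avg), _ = jax.lax.scan(one_step, (initial_state, zeros), (keys, counts))
    return final_state, avg


key = jax.random.PRNGKey(0)
init_state = ToyState(jnp.zeros(3))

# Item 3: the transform drops the info and keeps only positions.
_, positions = run_sketch(key, toy_step, init_state, 2000,
                          transform=lambda state, info: state.position)

# Item 2: same chain, but only running means of x, x**2 and the acceptance flag are kept.
_, (mean_x, mean_x2, acc_rate) = run_streaming_sketch(
    key, toy_step, init_state, 2000,
    expectation=lambda state, info: (state.position, state.position**2,
                                     info.accepted.astype(jnp.float32)))
```

Because both runs split the same key into the same per-step keys, they follow the same chain, so `mean_x` should match `positions.mean(axis=0)` up to float rounding. Returning `None` as the per-step output of `lax.scan` means there is nothing to stack, so memory does not grow with `num_steps`.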