TuringLang/AbstractMCMC.jl

Passing chain index in parallel sampling

theogf opened this issue · 5 comments

I would like to give each of my chains a different set of initial parameters when sampling them in parallel.
One solution would be to pass the chain index as a keyword argument (e.g. _chain_index), so that in the subsequent sample or step call I could simply use init_params[_chain_index]. Having the chain index available could also be useful for other purposes, such as custom logging.

One practical issue is that AbstractMCMC itself has no support for init_params; it is only supported in some downstream packages, such as AdvancedMH, EllipticalSliceSampling, and DynamicPPL.

Maybe we could make init_params official and address this problem with only a few changes:

  • Add an init_params=nothing keyword argument to the multichain sample methods
  • Forward init_params = init_params === nothing ? nothing : init_params[idx] to the single-chain sample call

We don't even have to handle the keyword argument in the single-chain sample method: it would be forwarded automatically to the step method, so the implementations in downstream packages would take care of it.
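A minimal sketch of the forwarding rule proposed above, written in plain Julia without AbstractMCMC. All function names here (initial_step, sample_single, sample_multichain) are hypothetical stand-ins, not the AbstractMCMC API, and the model, sampler, and number of samples are omitted. The point it illustrates: only the multichain wrapper knows the chain index, while the single-chain code just passes keywords through to the initial step.

```julia
# Hypothetical initial step of a downstream sampler. The sampler, not the
# multichain wrapper, decides what `init_params === nothing` means (here: 0.0).
initial_step(; init_params=nothing, kwargs...) =
    init_params === nothing ? 0.0 : init_params

# Hypothetical single-chain `sample`: it never inspects `init_params`,
# it only forwards all keyword arguments to the initial step.
sample_single(; kwargs...) = initial_step(; kwargs...)

# Hypothetical multichain `sample`: the only new logic is picking the
# per-chain entry of `init_params` before calling the single-chain method.
function sample_multichain(nchains; init_params=nothing, kwargs...)
    if init_params !== nothing && length(init_params) != nchains
        throw(ArgumentError("need one set of initial parameters per chain"))
    end
    states = Vector{Any}(undef, nchains)
    Threads.@threads for idx in 1:nchains
        chain_init = init_params === nothing ? nothing : init_params[idx]
        # Each task writes to its own slot, so no locking is needed.
        states[idx] = sample_single(; init_params=chain_init, kwargs...)
    end
    return states
end
```

For example, sample_multichain(3; init_params=[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]) starts the second chain at [2.0, 3.0], and calling it without init_params leaves the default to the (hypothetical) sampler.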

To start all chains with the same parameters, one would have to pass, e.g., a Fill array of the initial parameters. Otherwise (i.e., if we tried to support both a single set and multiple sets of initial parameters), I think we would always run into ambiguities.
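A small sketch of the ambiguity mentioned above, assuming the forwarding rule init_params[idx]; the helper chain_init is hypothetical. A plain vector of numbers could mean "one multi-dimensional starting point for every chain" or "one scalar per chain", and indexing always picks the second reading. Wrapping the shared vector makes the intent explicit.

```julia
# Hypothetical helper mirroring the proposed per-chain forwarding rule.
chain_init(init_params, idx) = init_params === nothing ? nothing : init_params[idx]

nchains = 2
theta = [0.5, 1.0]  # a single two-dimensional starting point

# Passing theta directly: indexing hands each chain a scalar.
@show chain_init(theta, 1)  # 0.5, not [0.5, 1.0]

# Wrapping theta: every chain receives the full vector.
# `fill` stores `nchains` references to the same `theta`; FillArrays.jl's
# `Fill(theta, nchains)` would represent the same thing lazily.
same_for_all = fill(theta, nchains)
@show chain_init(same_for_all, 2)  # [0.5, 1.0]
```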

init_params is my use case, but I can imagine many more situations where the chain index would be useful. My whole point is to make AbstractMCMC flexible enough that end users can really do what they want (and define init_params the way they want). The problem is that not all information is passed along right now.

The init_params is my use case

For this use case, I think it is easier not to use indices in downstream implementations. Otherwise, every implementation of the initial step (for a single chain) would have to handle all combinations and types of chain_index and init_params.

The problem is that not all information is passed right now

Currently, the idea is that one implements sampling for a single chain, and it then works for multiple chains as well without any extra effort. Keyword arguments specific to multiple chains are therefore a bit annoying, but maybe that's what we will have to do at some point. If initial parameters are the only use case, though, we could postpone this discussion.

It would be really handy to have an init_params keyword that gets split across the per-chain calls to sample. Not having this is currently pretty annoying with NUTS.

Support for init_params in the ensemble methods (split across the per-chain calls to sample) was added in #94. Can this issue be closed?