Passing chain index in parallel sampling
theogf opened this issue · 5 comments
I have the problem that I would like to give different initial parametrizations to each of my chains when sampling them in parallel.
One solution I thought of would be to pass the chain index as a _chain_index keyword argument (or something similar), so that in the subsequent sample or step call I could just do init_params[_chain_index]. Having the chain index could also be practical for other reasons (personal logging, etc.).
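A minimal, self-contained sketch of the _chain_index idea, assuming a toy single-chain function and a threaded driver. The names sample_one_chain and sample_chains are hypothetical stand-ins, not AbstractMCMC functions, and the "sampler" is a deterministic walk so the per-chain starting points are easy to see.

```julia
# Hypothetical sketch: the parallel driver passes each chain's index as a
# keyword argument, and the per-chain code uses it to pick its own initial
# parameters. None of these functions are AbstractMCMC API.

# Stand-in for a single-chain sampler: returns `n` "draws" starting at the
# initial parameters selected by `_chain_index` (or at 0.0 if none are given).
function sample_one_chain(n::Int; init_params=nothing, _chain_index::Int=1)
    θ0 = init_params === nothing ? 0.0 : float(init_params[_chain_index])
    # Toy "sampling": a deterministic walk so the result is easy to check.
    return [θ0 + i for i in 0:(n - 1)]
end

# Stand-in for a multi-chain driver that runs the chains on threads and
# tells each chain which index it has.
function sample_chains(n::Int, nchains::Int; kwargs...)
    chains = Vector{Vector{Float64}}(undef, nchains)
    Threads.@threads for idx in 1:nchains
        chains[idx] = sample_one_chain(n; _chain_index=idx, kwargs...)
    end
    return chains
end

chains = sample_chains(3, 2; init_params=[10.0, -5.0])
```

Here chain 1 starts at 10.0 and chain 2 at -5.0. The catch is that every single-chain implementation must know about _chain_index.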
One practical issue with this problem is that AbstractMCMC itself has no support for init_params; it is only supported in some downstream packages such as AdvancedMH, EllipticalSliceSampling, or DynamicPPL.
Maybe we could make init_params more official and address this problem without too many changes:
- Add an init_params=nothing keyword argument to the multi-chain sample methods
- Forward init_params=init_params === nothing ? nothing : init_params[idx] to the single-chain sample call
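The two steps above can be sketched with toy functions, assuming scalar initial parameters and a threaded loop over chains. multichain_sample and single_chain_sample are hypothetical names, not AbstractMCMC's actual methods; the length check is an extra design choice of this sketch, not part of the proposal.

```julia
# Hypothetical sketch of the proposal: the multi-chain method accepts
# `init_params=nothing` and forwards one entry per chain to the single-chain
# call. Function names are illustrative, not AbstractMCMC's API.

# Toy single-chain sampler with scalar initial parameters.
function single_chain_sample(n::Int; init_params=nothing)
    start = init_params === nothing ? 0.0 : float(init_params)
    return [start + i for i in 0:(n - 1)]
end

function multichain_sample(n::Int, nchains::Int; init_params=nothing)
    # Extra safety check (an assumption of this sketch): one entry per chain.
    if init_params !== nothing && length(init_params) != nchains
        throw(ArgumentError("expected $nchains initial parameters, got $(length(init_params))"))
    end
    chains = Vector{Vector{Float64}}(undef, nchains)
    Threads.@threads for idx in 1:nchains
        # The forwarding rule: pass `nothing` through, otherwise chain idx's entry.
        chains[idx] = single_chain_sample(
            n; init_params=init_params === nothing ? nothing : init_params[idx]
        )
    end
    return chains
end

chains = multichain_sample(2, 3; init_params=[1, 2, 3])
```

Only the multi-chain method indexes by chain; the single-chain code receives an ordinary init_params value and never sees an index.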
We don't even have to handle the keyword argument in the single-chain sample method; it would just be forwarded automatically to the step method, and hence the implementation in downstream packages would take care of it.
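A small sketch of why no handling is needed, assuming the usual Julia pattern of collecting keyword arguments with kwargs... and splatting them into the next call. It is loosely modeled on AbstractMCMC's step(rng, model, sampler[, state]; kwargs...) pattern but drops rng and model; ToySampler, step_sketch, and sample_sketch are hypothetical.

```julia
# Hypothetical sketch: the single-chain `sample_sketch` never names
# `init_params`; it just splats its keyword arguments into `step_sketch`,
# which is where a downstream package would consume them.

struct ToySampler end

# Downstream-style initial step: the only place `init_params` is mentioned.
function step_sketch(::ToySampler; init_params=nothing, kwargs...)
    return init_params === nothing ? 0.0 : float(init_params)
end

# Downstream-style subsequent step: moves the state deterministically.
step_sketch(::ToySampler, state) = state + 1

function sample_sketch(sampler, n::Int; kwargs...)
    state = step_sketch(sampler; kwargs...)  # keyword arguments pass straight through
    samples = [state]
    for _ in 2:n
        state = step_sketch(sampler, state)
        push!(samples, state)
    end
    return samples
end

samples = sample_sketch(ToySampler(), 3; init_params=5.0)
```

Because sample_sketch only forwards kwargs..., adding a new keyword such as init_params requires changes in the downstream step implementation only.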
If one wants to start all chains with the same parameters, one would have to pass, e.g., a Fill array of the initial parameters. Otherwise (i.e., if we try to support both a single set and multiple sets of initial parameters), I think we would always run into ambiguity issues.
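A sketch of the one-entry-per-chain convention, assuming each chain's initial parameters are a 2-element vector. It uses Base's fill to stay self-contained; FillArrays.jl's Fill(x, n) would be the lazy equivalent the comment above refers to. The helper pick is hypothetical.

```julia
# Hypothetical sketch: why requiring one entry per chain avoids ambiguity.
# Assume each chain's initial parameters are a 2-element vector.

nchains = 3

# Different starting points: one vector per chain.
per_chain = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]

# Same starting point for every chain: repeat it once per chain.
# Note: Base `fill` stores the same vector object in every slot.
shared = fill([0.5, 0.5], nchains)

# With the one-entry-per-chain convention, the multi-chain code only ever
# indexes by chain, whatever the element type is.
pick(init_params, idx) = init_params === nothing ? nothing : init_params[idx]

# Without that convention, a value like [0.5, 0.5] would be ambiguous when
# nchains == 2: one 2-parameter starting point shared by both chains, or two
# scalar starting points, one per chain?
```

Since every slot of shared refers to the same vector, a downstream sampler that mutates its initial parameters in place would affect all chains; copying per chain avoids that.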
The init_params is my use case, but I could imagine many more situations where it would be useful. My whole point is to make AbstractMCMC flexible enough that end users can really do what they want (and define init_params the way they want). The problem is that not all information is passed right now.
> The init_params is my usecase
For this use case, I think it is easier not to use indices in downstream implementations. Otherwise, every implementation of the initial step (for a single chain) would have to handle all combinations and types of chain_index and init_params.
> The problem is that not all information is passed right now
Currently, the idea is that one can implement sampling for a single chain and it will just work for multiple chains as well, without having to think about it. So multi-chain-specific keyword arguments are a bit annoying, but maybe that's what we will have to do at some point. If the only use case is initial parameters, we could postpone this discussion, though.
It would be really handy to have an init_params that gets split across the per-chain calls to sample. Currently, not having this is pretty annoying with NUTS.