TuringLang/AbstractMCMC.jl

`sample` equivalent but including states

torfjelde opened this issue · 4 comments

Sometimes I want both the samples and the states rather than just the samples. Of course this can be achieved by using the iterator interface explicitly, or a callback, but it's a bit inconvenient to have to write this every time.

Would it make sense to introduce a sample_with_states method, or simply a keyword argument to sample, e.g. include_states::Bool, specifying whether to also include the states in the return value? (I'm more in favour of just calling save!! with the tuple (sample, state) than of having the kwarg change the return value so that we instead return samples, states at the end.)
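For illustration, here is a toy sketch in plain Julia (no AbstractMCMC calls; the draws, states, and variable names are made up) of why saving (sample, state) tuples loses nothing compared with returning samples, states at the end: the second shape can be recovered from the first by broadcasting first and last.

```julia
# Made-up draws and integer states standing in for sampler output.
draws = [0.1, 0.5, 0.9]
states = [1, 2, 3]

# Option 1: save a (sample, state) tuple at every iteration.
saved = Tuple{Float64,Int}[]
for (s, st) in zip(draws, states)
    push!(saved, (s, st))
end

# Option 2 (returning samples, states at the end) is recoverable from
# Option 1 by broadcasting first/last over the stored tuples.
samples_out = first.(saved)
states_out = last.(saved)
```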

Thoughts?

I'm in favor of adding the keyword argument and passing it to save!!, so instead of calling

samples = save!!(samples, sample, i, model, sampler; kwargs...)

we would call

samples = save!!(samples, (sample, state), i, model, sampler; kwargs...)

as you suggested. Then our default version of save!! would just throw out state if it was not requested. Though, now that I'm looking at this, there might be some type instability: what gets stored (sample alone vs. (sample, state)) would depend on the runtime value of include_states. Perhaps we could add a dispatch type on include_states?
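A minimal sketch of the dispatch-type idea, under loud assumptions: toy_save!!, IncludeStates, ExcludeStates, init_storage, and toy_run are hypothetical names, not AbstractMCMC API, and samples is just a Vector. The loop always hands (sample, state) to the save function, and dispatch on the mode object decides whether the state is kept, so no runtime Bool branch is involved.

```julia
# Hypothetical singleton types: the storage mode is chosen by dispatch
# rather than by branching on a runtime Bool.
struct IncludeStates end
struct ExcludeStates end

# Storage with a concrete element type for each mode
# (toy types: Float64 samples, Int states).
init_storage(::ExcludeStates) = Float64[]
init_storage(::IncludeStates) = Tuple{Float64,Int}[]

# Hypothetical stand-in for save!!; it always receives (sample, state).
# The iteration index i is unused in this toy.
function toy_save!!(samples::Vector, (sample, state)::Tuple, i, ::ExcludeStates)
    push!(samples, sample)           # default: throw the state away
    return samples
end

function toy_save!!(samples::Vector, (sample, state)::Tuple, i, ::IncludeStates)
    push!(samples, (sample, state))  # keep both
    return samples
end

# Toy driver standing in for the sampling loop.
function toy_run(mode, n)
    samples = init_storage(mode)
    for i in 1:n
        sample, state = i / n, i     # made-up sample and state
        samples = toy_save!!(samples, (sample, state), i, mode)
    end
    return samples
end
```

For example, toy_run(ExcludeStates(), 2) keeps only the samples, while toy_run(IncludeStates(), 2) keeps (sample, state) pairs.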

At the very least, we shouldn't add a whole separate sample_with_states method: the changes are fairly minor, and copying all the sample code around isn't worth the extra maintenance.

Hmm, this is actually a bit more annoying than I originally thought 😕

As you said, it'll introduce type instabilities unless we make it a Val-typed argument or something. We could of course do that, but it's more annoying than just a kwarg.
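To illustrate the type-stability concern, here is a toy comparison (collect_kw and collect_val are hypothetical helpers, not part of AbstractMCMC): with a Bool keyword, the element type of the result depends on a runtime value, while a Val argument lifts the choice into the type domain so each method has a concrete return type. Test.@inferred checks the latter.

```julia
using Test  # for @inferred

# Bool keyword: which container is built depends on a runtime value, so when
# include_states is not a compile-time constant the inferred return type is
# Union{Vector{Float64}, Vector{Tuple{Float64, Int}}}.
function collect_kw(n; include_states::Bool=false)
    return include_states ? [(Float64(i), i) for i in 1:n] : [Float64(i) for i in 1:n]
end

# Val argument: the flag is part of the type, so each method has a single
# concrete return type.
collect_val(n, ::Val{true}) = [(Float64(i), i) for i in 1:n]
collect_val(n, ::Val{false}) = [Float64(i) for i in 1:n]

# @inferred throws if the inferred type differs from the actual one; this passes.
ys = @inferred collect_val(3, Val(true))
```

The catch is that Val only helps when the flag is a compile-time constant at the call site; collect_val(n, Val(flag)) with a runtime flag just moves the dynamic dispatch to that call, which is part of why it is more annoying than a plain kwarg.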

Maybe we should just provide a callback? I.e.

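# Callback that records the sampler state from every iteration.
struct StateHistoryCallback{A}
    states::A
end
StateHistoryCallback() = StateHistoryCallback(Any[])

# Invoked by sample after each step; stores the state and ignores the sample.
function (cb::StateHistoryCallback)(rng, model, sampler, sample, state, i; kwargs...)
    push!(cb.states, state)
    return nothing
end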

so users can do

state_history = []
sample(..., callback=StateHistoryCallback(state_history))
state_history  # now holds one state per iteration
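To see the callback in action without a real model, here is a self-contained sketch: toy_sample is a hypothetical stand-in for sample (a Gaussian random walk, not AbstractMCMC's implementation) that invokes the callback with the same positional arguments as the signature above, passing nothing for model and sampler. The callback definition is repeated so the block runs on its own.

```julia
using Random

# Same callback as above, repeated so this block runs standalone.
struct StateHistoryCallback{A}
    states::A
end
StateHistoryCallback() = StateHistoryCallback(Any[])

function (cb::StateHistoryCallback)(rng, model, sampler, sample, state, i; kwargs...)
    push!(cb.states, state)
    return nothing
end

# Hypothetical toy sampling loop: the state is the current position of a
# random walk and the "sample" is derived from it. model and sampler are
# passed as nothing because this toy has neither.
function toy_sample(rng, n; callback=nothing)
    samples = Float64[]
    state = 0.0
    for i in 1:n
        state += randn(rng)      # next state
        sample = 2 * state       # toy sample derived from the state
        push!(samples, sample)
        callback === nothing || callback(rng, nothing, nothing, sample, state, i)
    end
    return samples
end

state_history = Float64[]
samples = toy_sample(MersenneTwister(1), 5; callback=StateHistoryCallback(state_history))
```

Note that the states end up in a container owned by the caller, separate from the returned samples.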

I dunno, that seems super hacky to me. I think it's valuable to provide the states, but I think it's worth investing in giving it its own code path (whatever that looks like).