TuringLang/AbstractMCMC.jl

Bring back transitions_init and transitions_save!

mohamed82008 opened this issue · 9 comments

In Turing, the interface overloading can be made easier if we overload transitions_init and transitions_save! (or maybe rename it to transitions_save!!) to call transitions and save!! with different arguments. For example, the HMC definitions will be:

function AbstractMCMC.transitions_init(
    transition,
    model::AbstractModel,
    sampler::Sampler{<:Hamiltonian},
    N::Integer;
    discard_adapt = true,
    kwargs...
)
    if discard_adapt && isdefined(sampler.alg, :n_adapts)
        n = max(0, N - sampler.alg.n_adapts)
    else
        n = N
    end
    return AbstractMCMC.transitions(transition, model, sampler, n; kwargs...)
end

and

function AbstractMCMC.transitions_save!!(
    transitions::AbstractVector,
    iteration::Integer,
    transition,
    model::AbstractModel,
    sampler::Sampler{<:Hamiltonian},
    N::Integer;
    discard_adapt = true,
    kwargs...
)
    if discard_adapt && isdefined(sampler.alg, :n_adapts) && iteration <= sampler.alg.n_adapts
        return transitions
    else
        return AbstractMCMC.save!!(transitions, iteration, transition, model, sampler, N; kwargs...)
    end
end

Currently, one has to make use of BangBang in Turing, which is less pleasant.

IMO we already have too many functions in AbstractMCMC that you can overload, so one shouldn't start reintroducing functions just to save an import BangBang statement in Turing. The current implementations of transitions_init and transitions_save! should just be renamed to transitions and save!! with the next release of AbstractMCMC. These methods are supposed to be implemented for input types owned by downstream packages; you are not supposed to call them.

But it's more than just saving import BangBang. The above methods can be made to implement the logic of NUTS sample dropping but are otherwise generic with respect to the type of transitions. So one can separate the logic of Sampler-specific behavior from that of transitions-specific behavior.

We can also just have our own AbstractMCMC.push!! which calls BangBang.push!! or gets overloaded for specific transitions types.

I think the interface is already too messy right now, and I really don't think we should add more to the API. Specific samplers will always give you very specific transitions, so there is no clear separation of sampler- and transitions-specific behaviour anyway. Dropping the first n samples is a common use case and should maybe just be supported by the default logic in mcmcsample and the default implementations of transitions and save!! (at least I wanted/had something similar in EllipticalSliceSampling for quite some time). Again, this would be much nicer with an iterator/transducer-based implementation, since then you could easily drop the first samples.
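To make the iterator idea concrete, here is a minimal sketch using only Base and the Random standard library. The `RandomWalk` type and `collect_after_warmup` are hypothetical names made up for illustration; they are not part of AbstractMCMC or Turing, and the "sampler" is a plain random walk rather than a real MCMC kernel.

```julia
using Random

# Hypothetical toy sampler: an endless Gaussian random walk exposed as an iterator.
struct RandomWalk{R<:AbstractRNG}
    rng::R
    stepsize::Float64
end

Base.IteratorSize(::Type{<:RandomWalk}) = Base.IsInfinite()
Base.eltype(::Type{<:RandomWalk}) = Float64

# The iteration state is the current position; the walk starts from 0.0.
function Base.iterate(walk::RandomWalk, x::Float64 = 0.0)
    xnew = x + walk.stepsize * randn(walk.rng)
    return xnew, xnew
end

# Discard the first `n_adapts` draws, then keep the next `n_keep` draws.
function collect_after_warmup(walk, n_adapts::Integer, n_keep::Integer)
    return collect(Iterators.take(Iterators.drop(walk, n_adapts), n_keep))
end

samples = collect_after_warmup(RandomWalk(MersenneTwister(1), 0.5), 100, 50)
```

With this shape, dropping warm-up draws is a property of how the chain is consumed, so neither the sampler nor the storage code needs a special case for it.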

BTW if you want to separate the sampler- and transitions-part you can always define a transitions and save!! implementation in Turing for the algorithms that it owns that does exactly that.

Well what I meant by transitions-specific is e.g. writing to Vector vs writing to disk vs some other way. This is completely orthogonal to whether we want to save more information in the transition for a specific sampler, e.g. in Gibbs, or drop the first few iterations for NUTS. But I agree that dropping the first few iterations is common enough that it should probably be supported by default through a kwarg maybe.

> Well what I meant by transitions-specific is e.g. writing to Vector vs writing to disk vs some other way.

I think ideally we first implement stuff like this in downstream packages (you have access to everything and can hence implement this logic for every InferenceAlgorithm in Turing) and, based on these experiences, figure out if something should become part of the general API in AbstractMCMC, be removed, or be added (like dropping the first samples). Contrary to what the title of the issue suggests, transitions_init and transitions_save! were not removed; they were just renamed to 1) get cleaner names, and 2) indicate that saving might mutate the existing transitions and has to return the transitions, similar to the convention in BangBang (but, of course, users don't have to use BangBang in their own implementations of save!!).
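As a rough illustration of that "mutate if possible, always return the result" convention, here is a Base-only sketch. The name `save_sketch!!` is hypothetical; this is neither the AbstractMCMC `save!!` nor the BangBang implementation, just the general pattern.

```julia
# Hypothetical helper illustrating the `!!` convention.
# Fast path: the container can already hold `x`, so mutate it in place.
save_sketch!!(xs::Vector{T}, x::T) where {T} = push!(xs, x)

# Fallback: the element type is too narrow, so return a new, widened vector
# and leave the original untouched.
function save_sketch!!(xs::Vector, x)
    return vcat(xs, [x])
end

xs = Int[1, 2]
ys = save_sketch!!(xs, 3)     # mutates in place: ys === xs
zs = save_sketch!!(ys, 0.5)   # cannot mutate: zs is a new Vector{Float64}
```

Because the second call returns a different object, callers always have to rebind the result (`transitions = save_sketch!!(transitions, t)`) instead of relying on mutation.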

For instance, based on my experiences, a major improvement for AbstractMCMC (and Turing etc.) would be switching to a completely stateless approach, i.e., getting rid of mutation in samplers completely (I wanted to make that change for quite some time, but something else always came up...). That's something that we figured out while improving Turing and noticing all the hacks that we currently use for initializing, e.g., HMC samplers.
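One way to picture such a stateless design is a step function that receives the previous state and returns a new one, so the sampler object only carries configuration. The sketch below is an assumption about what that could look like: `ToySampler`, `toy_step`, and `toy_sample` are hypothetical names, and the update is a plain random walk rather than a real MCMC kernel.

```julia
using Random

# Hypothetical immutable sampler: holds configuration only and is never mutated.
struct ToySampler
    stepsize::Float64
end

# First step: there is no previous state, so create the initial one.
function toy_step(rng::AbstractRNG, sampler::ToySampler)
    x = 0.0
    return x, x  # (sample, state)
end

# Later steps: take the previous state, return a fresh sample and a new state.
function toy_step(rng::AbstractRNG, sampler::ToySampler, state::Float64)
    x = state + sampler.stepsize * randn(rng)
    return x, x
end

# Driver loop: all evolving information is threaded through `state`.
function toy_sample(rng::AbstractRNG, sampler::ToySampler, N::Integer)
    sample, state = toy_step(rng, sampler)
    samples = [sample]
    for _ in 2:N
        sample, state = toy_step(rng, sampler, state)
        push!(samples, sample)
    end
    return samples
end

chain = toy_sample(MersenneTwister(42), ToySampler(0.1), 10)
```

Since nothing is stored in the sampler, initialization needs no special hooks: the first `toy_step` method simply builds the initial state.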

So IMO it would be best to implement a more orthogonal approach to the saving in a downstream package such as Turing first, and then figure out if it should be moved upstream, instead of extending the general API right away. E.g., in the past some things such as sampling until convergence were moved from NestedSamplers to AbstractMCMC because we figured out that it might be generally useful.

I agree with @devmotion -- I don't want to muck around with generalish features in AbstractMCMC if we can help it.

I'll close this for now. We can revisit it if needed.