TuringLang/AbstractMCMC.jl

Extending chain representations

cscherrer opened this issue · 8 comments

Hi,

I had misunderstood some of the goals of this package; thankfully @cpfiffer got me straightened out. I'm looking into using this as an interface for populating SampleChains.

The docs seem mostly oriented toward people building new samplers, and not so much for people building new ways of representing chains. So I have lots of questions...

  1. I'd like to avoid getting all of the samples and then calling bundle_samples. It seems like I should be able to instead overload save!! and then have bundle_samples be a no-op. Is that right?
  2. Cameron mentioned a resume function that can pick up sampling where it left off, without needing to go back through the warmup phase. But I don't see it in this repo. Where can I find it?
  3. Where are things like per-sample diagnostics stored? Divergences for HMC, that sort of thing.
  4. Do you have any examples of using this with log-weighted samples? I need these for importance sampling, VI, etc.

I'm sure I'll have more to come, but this will get me started. Thanks :)

  1. Yeah, I think that's right.
  2. The resume stuff is in Turing proper; it implicitly assumes that the knowledge needed to resume lives with the modelling language. We could add it here, I think, but I'm not sure what form it would take yet.
  3. Totally up to you: store them in whatever internal struct you like, as part of the state or the sample as needed.
  4. MCMCChains has a field for this, but it's not something we use much. I'd just include it in the state or sample and hand it off to the chain to handle as needed.
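To make answer 1 concrete, here is a self-contained sketch that mimics the shape of AbstractMCMC's storage hooks without depending on the package. In AbstractMCMC the corresponding functions are `save!!` (store one draw) and `bundle_samples` (final conversion); the names below (`GrowingChain`, `save_sample!!`, `bundle`, `run_sampler`) are hypothetical stand-ins, and the i.i.d. normal "sampler" exists only to drive the loop.

```julia
using Random

# Hypothetical chain container; stands in for something like a SampleChains chain.
struct GrowingChain{T}
    samples::Vector{T}
end

GrowingChain{T}() where {T} = GrowingChain{T}(T[])

# Same role as AbstractMCMC's `save!!(samples, sample, i, model, sampler, N; kwargs...)`:
# store the new draw and return the (here mutated) container.
function save_sample!!(chain::GrowingChain, sample, i::Integer, N::Integer)
    push!(chain.samples, sample)
    return chain
end

# Storage already happened draw by draw, so bundling has nothing left to do.
bundle(chain::GrowingChain) = chain

# Toy driver loop: i.i.d. standard normal draws stand in for a real sampler.
function run_sampler(rng::AbstractRNG, N::Integer)
    chain = GrowingChain{Float64}()
    for i in 1:N
        draw = randn(rng)
        chain = save_sample!!(chain, draw, i, N)
    end
    return bundle(chain)
end

chain = run_sampler(MersenneTwister(1), 100)
length(chain.samples)  # 100
```

Because each draw is pushed into the container as it arrives, the final bundling step just returns what it is given, which is the no-op behaviour asked about in question 1.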

Great, thanks @cpfiffer .

resume would change the state of an existing object, so I think it should end in ! (i.e. resume!). I could just fall back on my current interface for this.

I think I'm still confused about what you mean by "state". There are a few things other than samples that can be important. Some are fixed-size, like

  • The iterator
  • Some characteristics that don't change but apply to the chain as a whole, for example the log-density function or the mass matrix for HMC. These are fixed-size.
  • The log-density contribution of each scalar component of a sample. This comes into play for samplers like Gen.

And some that scale linearly with the number of samples:

  • Per-sample diagnostics, like divergences. These scale with the number of samples, but don't change the semantics of the samples.
  • Log-density values.
  • Log-weights. These affect the semantics, for example coming into play for computing expected values.

I think only the first group should be considered "state", and per-sample diagnostics should be kept separate from both samples and state (I'm currently calling them "info", which IIRC I got from Turing somewhere). I was thinking of making separate logp and logweight fields part of the interface, each holding one value per sample.
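A minimal sketch of that split, assuming a simple columnar layout: `WeightedChain`, `logsumexp` and `expectation` are hypothetical names (LogExpFunctions.jl provides a production `logsumexp`). The fixed-size `state` sits beside the per-sample columns, and the log-weights only matter when computing an expected value, here via self-normalized importance weighting.

```julia
# Hypothetical layout: fixed-size sampler state kept apart from per-sample columns.
struct WeightedChain{T,I,S}
    samples::Vector{T}          # one draw per iteration
    logp::Vector{Float64}       # log-density of each draw
    logweight::Vector{Float64}  # log importance weight of each draw
    info::Vector{I}             # per-sample diagnostics (sampler-specific)
    state::S                    # fixed-size state, not meant for end users
end

# Numerically stable log(sum(exp.(x))).
function logsumexp(x::AbstractVector{<:Real})
    m = maximum(x)
    isfinite(m) || return m
    return m + log(sum(exp(xi - m) for xi in x))
end

# Self-normalized importance-sampling estimate of E[f(x)]:
# weights are exp(logweight) rescaled to sum to one.
function expectation(f, chain::WeightedChain)
    logZ = logsumexp(chain.logweight)
    return sum(exp(lw - logZ) * f(x) for (x, lw) in zip(chain.samples, chain.logweight))
end

chain = WeightedChain([1.0, 2.0, 3.0], zeros(3), log.([1.0, 1.0, 2.0]), fill(nothing, 3), nothing)
expectation(identity, chain)  # ≈ 2.25, i.e. (1*1 + 1*2 + 2*3) / 4
```

Keeping `logweight` as its own column means unweighted MCMC output can simply store zeros there, while importance sampling or VI output changes the semantics of expectations without touching `samples` or `info`.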

Hmm, and further complicating this: save!! doesn't receive the state, so there's currently no way to save it:

```julia
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
```

That could make this very tricky.

I guess I could overload mcmcsample?

Sure, one can always roll a custom mcmcsample. The only downside is that you lose some of the default features. BTW, if you want more control over the sampling procedure (e.g., computing statistics after every nth step, or using convergence criteria), the iteration (or transducer) interface can be useful.
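As an illustration of stopping on a convergence criterion by iterating, the sketch below uses only Julia's standard iteration protocol (AbstractMCMC's own iterator, `AbstractMCMC.steps`, plays this role in the real package). The random-walk Metropolis sampler, `sample_until_stable`, and the criterion itself (stop once the running mean moves less than `tol` between checks made every `every` draws) are assumptions chosen for illustration; a real check would more likely use R-hat or effective sample size.

```julia
using Random, Statistics

# Hypothetical toy sampler: random-walk Metropolis targeting a standard normal.
struct RWMetropolis{R<:AbstractRNG}
    rng::R
    stepsize::Float64
end

logdensity(x) = -x^2 / 2

# Julia's iteration protocol: each call yields a draw; the iterator state is the current x.
function Base.iterate(s::RWMetropolis, x::Float64 = 0.0)
    proposal = x + s.stepsize * randn(s.rng)
    if log(rand(s.rng)) < logdensity(proposal) - logdensity(x)
        x = proposal
    end
    return x, x
end
Base.IteratorSize(::Type{<:RWMetropolis}) = Base.IsInfinite()

# Keep drawing until the running mean changes by less than `tol` between
# checks made every `every` draws (a naive criterion, for illustration only).
function sample_until_stable(sampler; every::Int = 1_000, tol::Float64 = 1e-2, maxiter::Int = 100_000)
    draws = Float64[]
    previous_mean = Inf
    for x in sampler
        push!(draws, x)
        n = length(draws)
        if n % every == 0
            current_mean = mean(draws)
            abs(current_mean - previous_mean) < tol && break
            previous_mean = current_mean
        end
        n >= maxiter && break
    end
    return draws
end

draws = sample_until_stable(RWMetropolis(MersenneTwister(42), 1.0))
```

Because the loop owns the stopping decision, any statistic can be computed at the check points without changing the sampler itself.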

Regarding the points in the OP, I agree with what @cpfiffer said. resume is currently defined in DynamicPPL, but there are a number of issues and discussions in which we suggested moving it to AbstractMCMC (along with the ability to specify initial samples). In general, the policy has been to experiment with interface changes and extensions in downstream packages first, before moving them to AbstractMCMC and enforcing them in all implementations. I imagined that the sample/mcmcsample methods should probably allow specifying an initial state, and then resume could be defined as

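```julia
using AbstractMCMC: sample
using Random

# Sketch only: the `state` keyword and the `getmodel`, `getsampler` and
# `getstate` accessors are hypothetical here; the idea is to restart
# sampling from the state stored in an existing chain.
function resume(rng::Random.AbstractRNG, chain, args...; kwargs...)
    return sample(
        rng, getmodel(chain), getsampler(chain), args...;
        state=getstate(chain), kwargs...,
    )
end
```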
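A toy, self-contained version of the same idea, assuming a sampler whose warmup tunes a step size: the chain keeps its final state, and a hypothetical `resume` passes that state back in so warmup is skipped. `RWState`, `ResumableChain`, `rw_step`, `warmup`, `sample_rw` and `resume` are illustrative names, not DynamicPPL's or AbstractMCMC's API, and the adaptation rule (nudging the log step size toward 44% acceptance) is just one simple choice.

```julia
using Random

# Hypothetical sampler state: current position plus a step size tuned in warmup.
struct RWState
    x::Float64
    stepsize::Float64
end

# Hypothetical chain: the draws plus the final state needed to continue later.
struct ResumableChain
    draws::Vector{Float64}
    state::RWState
end

logdensity(x) = -x^2 / 2  # standard normal target

# One random-walk Metropolis step; returns the new state and whether it accepted.
function rw_step(rng::AbstractRNG, st::RWState)
    proposal = st.x + st.stepsize * randn(rng)
    accepted = log(rand(rng)) < logdensity(proposal) - logdensity(st.x)
    return RWState(accepted ? proposal : st.x, st.stepsize), accepted
end

# Warmup: raise log(stepsize) after accepts and lower it after rejects,
# which settles near a 44% acceptance rate.
function warmup(rng::AbstractRNG, x0::Float64; iters::Int = 500)
    st = RWState(x0, 1.0)
    for _ in 1:iters
        st, acc = rw_step(rng, st)
        st = RWState(st.x, st.stepsize * exp(0.05 * (acc - 0.44)))
    end
    return st
end

# Run warmup only when no saved state is supplied.
function sample_rw(rng::AbstractRNG, N::Int; state::Union{Nothing,RWState} = nothing)
    st = state === nothing ? warmup(rng, 0.0) : state
    draws = Vector{Float64}(undef, N)
    for i in 1:N
        st, _ = rw_step(rng, st)
        draws[i] = st.x
    end
    return ResumableChain(draws, st)
end

# Continue from the stored state: no second warmup.
resume(rng::AbstractRNG, chain::ResumableChain, N::Int) = sample_rw(rng, N; state = chain.state)

rng = MersenneTwister(0)
chain1 = sample_rw(rng, 1_000)
chain2 = resume(rng, chain1, 1_000)
chain2.state.stepsize == chain1.state.stepsize  # true: the tuned step size is reused
```

This `resume` returns a new chain rather than modifying `chain1`, so by Julia convention it carries no `!`; an in-place `resume!` would instead append to the existing chain.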

Thanks @devmotion, I was hoping to allow a convergence criterion as a stopping condition, so this is great.

There does seem to be an assumption that everything the user could need is required to be part of the sample. For DynamicHMC, my setup looks like this (AdvancedHMC will be very similar):

```julia
using ConcreteStructs: @concrete
using SampleChains: AbstractChain

@concrete struct DynamicHMCChain{T} <: AbstractChain{T}
    samples     # :: AbstractVector{T}
    logp        # log-density for distribution the sample was drawn from
    info        # Per-sample metadata, type depends on sampler used
    meta        # Metadata associated with the sample as a whole
    state       # iterator state, internal
    transform   # HMC-specific transformation
end
```

Here

  • samples includes variables specified by the model:
```julia-repl
julia> samples(chain)
100-element TupleVector with schema (x = Float64, σ = Float64)
(x = -0.1±0.34, σ = 0.576±0.37)
```
  • logp has the log-density information for each sample:
```julia-repl
julia> logp(chain)[1:5]
5-element ElasticArrays.ElasticVector{Float64, 0, Vector{Float64}}:
 -1.136224195720376
 -0.42132266397402207
 -0.9789248604768969
 -1.136224195720376
 -0.8517859618293282
```
  • Many samplers will also use a logweights field.
  • info has some diagnostic information:
```julia-repl
julia> info(chain)[1:5]
5-element ElasticArrays.ElasticVector{DynamicHMC.TreeStatisticsNUTS, 0, Vector{DynamicHMC.TreeStatisticsNUTS}}:
 DynamicHMC.TreeStatisticsNUTS(-1.283461597663962, 3, turning at positions 6:9, 0.9703961901160859, 11, DynamicHMC.Directions(0xdfea943d))
 DynamicHMC.TreeStatisticsNUTS(-1.150959614787742, 1, turning at positions 2:3, 0.9646928495527286, 3, DynamicHMC.Directions(0x74715257))
 DynamicHMC.TreeStatisticsNUTS(-1.1699430991621091, 3, turning at positions 3:6, 1.0, 11, DynamicHMC.Directions(0x8472ffea))
 DynamicHMC.TreeStatisticsNUTS(-1.941965877904205, 1, turning at positions 2:3, 0.7405784149911505, 3, DynamicHMC.Directions(0x8c1d2457))
 DynamicHMC.TreeStatisticsNUTS(-1.3103584844087501, 2, turning at positions -2:1, 0.9999999999999999, 3, DynamicHMC.Directions(0x9483096d))
```
  • meta has information determined by the warmup phase, and will be different for each sampler:
```julia-repl
julia> meta(chain).H
Hamiltonian with Gaussian kinetic energy (Diagonal), diag(M⁻¹): [1.1613920024118645, 0.7589536122573856]

julia> meta(chain).algorithm
DynamicHMC.NUTS{Val{:generalized}}(10, -1000.0, Val{:generalized}())

julia> meta(chain).ϵ
0.2634132789343616

julia> meta(chain).rng
Random._GLOBAL_RNG()
```
  • state contains the iterator state, and is not meant to be accessed by the end user
  • transform is specific to HMC and is what you'd expect: the transformation between the unconstrained space the sampler works in and the model's variables.

I guess I could cram my samples, logp, logweights, and info into your samples, as long as you don't assign any semantics to this. Then meta and transform would only be written after warmup, and our state fields would match up. Does that sound right?

Yeah, that should work. I think you could just dump them into a tuple or small wrapper struct when you return them as state.

I can't return them as state: that would make them unavailable, since save!! doesn't take the state as an argument. I think everything needs to go in the sample, and then I can pull it apart after receiving it.
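A sketch of that packing approach, assuming a step function that returns a `(sample, state)` pair as AbstractMCMC's `step` does: the per-sample `logp`, `logweight` and diagnostic `info` ride along in a NamedTuple `sample`, and a stand-in for a `save!!` overload splits it into columns. `StepInfo`, `toy_step`, `ColumnChain`, `save_packed!!` and `run_packed` are hypothetical names; the random-walk step targeting a standard normal only exists to produce data.

```julia
using Random

# Hypothetical diagnostics record carried with each draw.
const StepInfo = NamedTuple{(:accepted,),Tuple{Bool}}

# Hypothetical step: returns (sample, state) like AbstractMCMC's `step`, but the
# sample packs the value together with logp, logweight and diagnostics.
function toy_step(rng::AbstractRNG, x::Float64)
    proposal = x + randn(rng)
    logp_old = -x^2 / 2
    logp_new = -proposal^2 / 2
    accepted = log(rand(rng)) < logp_new - logp_old
    xnew = accepted ? proposal : x
    sample = (value = xnew, logp = accepted ? logp_new : logp_old,
              logweight = 0.0,  # plain MCMC draws are unweighted
              info = (accepted = accepted,))
    return sample, xnew
end

# Hypothetical columnar chain that keeps each quantity in its own vector.
struct ColumnChain
    values::Vector{Float64}
    logp::Vector{Float64}
    logweight::Vector{Float64}
    info::Vector{StepInfo}
end
ColumnChain() = ColumnChain(Float64[], Float64[], Float64[], StepInfo[])

# Stand-in for a save!! overload: unpack the sample into the matching columns.
function save_packed!!(chain::ColumnChain, sample)
    push!(chain.values, sample.value)
    push!(chain.logp, sample.logp)
    push!(chain.logweight, sample.logweight)
    push!(chain.info, sample.info)
    return chain
end

# Driver loop: the state (current x) stays inside the loop; only samples are saved.
function run_packed(rng::AbstractRNG, N::Int)
    chain = ColumnChain()
    state = 0.0
    for _ in 1:N
        sample, state = toy_step(rng, state)
        chain = save_packed!!(chain, sample)
    end
    return chain
end

chain = run_packed(MersenneTwister(7), 200)
```

The cost of this design is that every per-sample quantity must be known at step time, but it needs nothing from save!! beyond the sample it already receives.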