Extending chain representations
cscherrer opened this issue · 8 comments
Hi,
I had misunderstood some of the goals of this package, thankfully @cpfiffer got me straightened out. I'm looking into using this as an interface for populating SampleChains.
The docs seem mostly oriented toward people building new samplers, and not so much for people building new ways of representing chains. So I have lots of questions...
- I'd like to avoid collecting all of the samples and then calling `bundle_samples`. It seems like I should be able to instead overload `save!!` and have `bundle_samples` be a no-op. Is that right?
- Cameron mentioned a `resume` function that can pick up sampling where it left off, without needing to go back through the warmup phase. But I don't see it in this repo. Where can I find it?
- Where are things like per-sample diagnostics stored? Divergences for HMC, that sort of thing.
- Do you have any examples of using this with log-weighted samples? I need these for importance sampling, VI, etc.
I'm sure I'll have more to come, but this will get me started. Thanks :)
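For the first question, a minimal sketch of the shape this could take, assuming the `save!!` signature quoted later in this thread; `MyChain` is a hypothetical container, and in practice these would be overloads of the AbstractMCMC generics rather than standalone functions:

```julia
# Hypothetical container that accumulates samples as they arrive.
struct MyChain{T}
    samples::Vector{T}
end

# Accumulate each new sample directly into the chain, BangBang-style:
# return the (possibly replaced) container.
function save!!(chain::MyChain, sample, i, model, sampler, N; kwargs...)
    push!(chain.samples, sample)
    return chain
end

# If save!! already built the final chain, bundling becomes the identity.
bundle_samples(chain::MyChain, model, sampler, state, chain_type; kwargs...) = chain
```

With this setup, the chain is fully built by the time sampling finishes, and `bundle_samples` just passes it through.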
- Yeah, I think that's right.
- The `resume` stuff is in Turing proper -- it kind of implicitly assumes that the `resume` knowledge lives with the modelling language. We could add it in here, I think, but I'm not sure yet what form it would take.
- Totally up to you -- store them in whatever internal struct as part of `state` or `sample` as needed (in this context).
- MCMCChains has a field for this but nothing we use that much. I'd just include it in `state` or `sample` and toss it off to the chain to handle as needed.
Great, thanks @cpfiffer.

`resume` will change the state of an existing object, so I think it should have a `!`. I could just fall back on my current interface for this.
I think I'm still confused about what you mean by "state". There are a few things other than samples that can be important. Some are fixed-size:
- The iterator.
- Characteristics that don't change but apply to the chain overall, for example the log-density function, or the mass matrix for HMC.
- The log-density contribution of each scalar component of a sample. This comes into play for samplers like Gen.

And some scale linearly with the number of samples:
- Per-sample diagnostics, like divergences. These don't change the semantics of the samples.
- Log-density values.
- Log-weights. These do affect the semantics, for example when computing expected values.
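To make the last point concrete, here is a small illustration (not part of any package API) of why log-weights change how expected values are computed -- a self-normalized importance-sampling estimate:

```julia
# Estimate E[f(x)] from samples with per-sample log-weights.
function weighted_mean(f, samples, logweights)
    w = exp.(logweights .- maximum(logweights))  # subtract max for numerical stability
    return sum(w .* f.(samples)) / sum(w)
end

# With equal log-weights this reduces to the ordinary sample mean:
weighted_mean(identity, [1.0, 2.0, 3.0], [0.0, 0.0, 0.0])  # ≈ 2.0
```

Unequal log-weights shift the estimate toward the heavily weighted samples, which is exactly the semantic difference from unweighted draws.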
I think only the first group should be considered "state", and per-sample diagnostics should be separate from both samples and state (I'm currently calling them "info", which IIRC I got from Turing somewhere). I was thinking of having separate `logp` and `logweight` fields as part of the interface, one for each sample.
Hmm, and further complicating this is that `state` is not currently allowed to be saved:

```julia
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
```

That could make this very tricky. I guess I could overload `mcmcsample`?
Sure, one can always roll a custom `mcmcsample`. The only downside is that you lose some of the default features. BTW, if you want more control over the sampling procedure (e.g., computing statistics after every nth step, or using convergence criteria), the iteration (or transducer) interface can be useful.
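As a sketch of what a convergence-based stopping rule could look like over such an iterator: here `itr` stands in for whatever the iteration interface yields (e.g. an `AbstractMCMC.steps` iterator), and `converged` is an assumed user-supplied check -- none of these names are prescribed by the package.

```julia
# Draw samples from an iterator until a user-supplied convergence check
# passes, or a maximum iteration count is hit.
function sample_until(itr, converged; maxiters = 10_000)
    samples = []
    for sample in itr
        push!(samples, sample)
        (converged(samples) || length(samples) >= maxiters) && break
    end
    return samples
end
```

The same pattern works for computing running statistics every nth step: just hook the extra work into the loop body.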
Regarding the points in the OP, I agree with what @cpfiffer said. `resume` is currently defined in DynamicPPL, but there are a bunch of issues and discussions in which we suggested moving it to AbstractMCMC (along with the possibility of specifying initial samples). In general, the policy was to experiment with interface changes/extensions in downstream packages first before moving them to AbstractMCMC and enforcing them in all implementations. I imagined that the `sample`/`mcmcsample` methods should probably allow specifying an initial state, and then `resume` could be defined as:
```julia
function resume(rng::Random.AbstractRNG, chain, args...; kwargs...)
    return sample(
        rng, getmodel(chain), getsampler(chain), args...;
        state=getstate(chain), kwargs...,
    )
end
```
Thanks @devmotion, I was hoping to allow a convergence criterion as a stopping condition, so this is great.

There does seem to be an assumption that everything the user could need is required to be part of the sample. For DynamicHMC, my setup looks like this (AdvancedHMC will be very similar):
```julia
@concrete struct DynamicHMCChain{T} <: AbstractChain{T}
    samples    # ::AbstractVector{T}
    logp       # log-density for the distribution each sample was drawn from
    info       # per-sample metadata; type depends on the sampler used
    meta       # metadata associated with the chain as a whole
    state
    transform
end
```
Here `samples` includes the variables specified by the model:

```julia
julia> samples(chain)
100-element TupleVector with schema (x = Float64, σ = Float64)
(x = -0.1±0.34, σ = 0.576±0.37)
```
`logp` has the log-density information for each sample:

```julia
julia> logp(chain)[1:5]
5-element ElasticArrays.ElasticVector{Float64, 0, Vector{Float64}}:
 -1.136224195720376
 -0.42132266397402207
 -0.9789248604768969
 -1.136224195720376
 -0.8517859618293282
```
Many samplers will also use a `logweights` field.

`info` has some diagnostic information:
```julia
julia> info(chain)[1:5]
5-element ElasticArrays.ElasticVector{DynamicHMC.TreeStatisticsNUTS, 0, Vector{DynamicHMC.TreeStatisticsNUTS}}:
 DynamicHMC.TreeStatisticsNUTS(-1.283461597663962, 3, turning at positions 6:9, 0.9703961901160859, 11, DynamicHMC.Directions(0xdfea943d))
 DynamicHMC.TreeStatisticsNUTS(-1.150959614787742, 1, turning at positions 2:3, 0.9646928495527286, 3, DynamicHMC.Directions(0x74715257))
 DynamicHMC.TreeStatisticsNUTS(-1.1699430991621091, 3, turning at positions 3:6, 1.0, 11, DynamicHMC.Directions(0x8472ffea))
 DynamicHMC.TreeStatisticsNUTS(-1.941965877904205, 1, turning at positions 2:3, 0.7405784149911505, 3, DynamicHMC.Directions(0x8c1d2457))
 DynamicHMC.TreeStatisticsNUTS(-1.3103584844087501, 2, turning at positions -2:1, 0.9999999999999999, 3, DynamicHMC.Directions(0x9483096d))
```
`meta` has information determined by the warmup phase, and will be different for each sampler:
```julia
julia> meta(chain).H
Hamiltonian with Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [1.1613920024118645, 0.7589536122573856]

julia> meta(chain).algorithm
DynamicHMC.NUTS{Val{:generalized}}(10, -1000.0, Val{:generalized}())

julia> meta(chain).ϵ
0.2634132789343616

julia> meta(chain).rng
Random._GLOBAL_RNG()
```
`state` contains the iterator state, and is assumed not to be accessed by the end user. `transform` is specific to HMC, and is just what you'd expect.
I guess I could cram my `samples`, `logp`, `logweights`, and `info` into your `samples`, as long as you don't assign any semantics to this. Then `meta` and `transform` would only be written after warmup, and our `state` fields would match up. Does that sound right?
Yeah, that should work. I think you could just dump them into a tuple or small wrapper struct when you return them as `state`.
I can't return them as `state` -- that would make them unavailable, since `save!!` doesn't include `state` as an argument. I think everything needs to be in `sample`; then I can pull it apart after receiving it.
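A sketch of what that pulling-apart could look like, with the per-sample extras bundled into `sample` as a NamedTuple; the field names (`draw`, `logp`, `info`) and the simplified `Chain` stand-in are assumptions for illustration, and in practice `save!!` would be the AbstractMCMC overload:

```julia
# Simplified stand-in for the DynamicHMCChain struct shown earlier.
struct Chain{T}
    samples::Vector{T}
    logp::Vector{Float64}
    info::Vector{Any}
end

# The sampler step bundles everything the chain needs into `sample`;
# save!! unpacks it into the chain's per-sample fields.
function save!!(chain::Chain, sample, i, model, sampler, N; kwargs...)
    push!(chain.samples, sample.draw)
    push!(chain.logp, sample.logp)
    push!(chain.info, sample.info)
    return chain
end

chain = Chain(Float64[], Float64[], Any[])
save!!(chain, (draw = 0.5, logp = -1.2, info = nothing), 1, nothing, nothing, 10)
```

This keeps `state` out of the picture entirely: everything that scales with the number of samples travels inside `sample`.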