Adding StatsBase.predict to the API
sethaxen opened this issue · 7 comments
In Turing, StatsBase.predict is overloaded to dispatch on DynamicPPL.Model and MCMCChains.Chains (https://github.com/TuringLang/Turing.jl/blob/d76d914231db0198b99e5ca5d69d80934ee016b3/src/inference/Inference.jl#L532-L564). This effectively does batch prediction, conditioning the model on each draw in the chains and calling rand on the model. We also want to do the same thing for InferenceData (see #465).
It would be convenient if StatsBase.predict was added to the DynamicPPL API. It's already an indirect dependency of this package. As suggested by @devmotion in #465 (comment), its default implementation could be to just call rand for a conditioned model:

```julia
StatsBase.predict(rng::AbstractRNG, model::DynamicPPL.Model, x) = rand(rng, condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) = predict(Random.default_rng(), model, x)
```
Maybe this could even be part of AbstractPPL and be defined on AbstractPPL.AbstractProbabilisticProgram: condition is part of its API, only rand is not clearly specified there yet (probably should be done anyway).
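As a usage sketch, assuming the two proposed definitions above are in place (the model `demo` is a hypothetical example, not part of any package):

```julia
using Random
using DynamicPPL, Distributions
import StatsBase

# Proposed default implementation from this issue (not yet in DynamicPPL).
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) =
    rand(rng, condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) =
    StatsBase.predict(Random.default_rng(), model, x)

# A toy model: `μ` is a parameter, `y` an observable.
@model function demo()
    μ ~ Normal()
    y ~ Normal(μ, 1)
end

# Fix the parameters to one posterior draw and sample the observables.
nt = StatsBase.predict(demo(), (μ = 0.5,))
# `nt` contains a sampled `y` (a NamedTuple by default).
```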
Yeah, makes sense.
I'm down with this, but it's worth pointing out that just calling rand(rng, condition(model, x)) is probably not the greatest idea as it defaults to NamedTuple, which can blow up compilation times for many models.

And regarding adding to APPL; we need to propagate that change back to v0.5 too then, because v0.6 is currently not compatible with DPPL (see #440).
> I'm down with this, but it's worth pointing out that just calling rand(rng, condition(model, x)) is probably not the greatest idea as it defaults to NamedTuple, which can blow up compilation times for many models.

Would rand(rng, OrderedDict, condition(model, x)) be the way to go then?
> Would rand(rng, OrderedDict, condition(model, x)) be the way to go then?

For maximal model-compat, yes. But you do of course take a performance hit as a result 😕
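The OrderedDict variant discussed above could be sketched as follows (a sketch only; the method signature mirrors the earlier proposal and is not part of DynamicPPL yet):

```julia
using Random
using OrderedCollections: OrderedDict
using DynamicPPL
import StatsBase

# Sketch: request an OrderedDict from `rand` instead of the NamedTuple
# default. This trades runtime performance (dynamic keys, no type
# specialization) for better compile-time scaling on large models.
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) =
    rand(rng, OrderedDict, condition(model, x))
```

The design trade-off is exactly the one raised in the thread: NamedTuple specializes the return type per model (fast access, but each model shape triggers fresh compilation), while OrderedDict keeps one dynamic container type for all models.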
Hrm. Maybe then predict should use a NamedTuple if x is a NamedTuple (imperfect because you can have few parameters but many data points). Or provide an API for specifying the return type, like rand does (but supporting two optional positional parameters rng and T complicates the interface).
> Or provide an API for specifying the return type, like rand does (but supporting two optional positional parameters rng and T complicates the interface)

Adding T to predict (with some default) would be in line with our API for rand though - there type T can be specified already.
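The T-with-a-default idea could look like this (hypothetical signatures, mirroring the (rng, T, model) argument order that rand already uses; the two-optional-positionals cost is visible in the four methods needed):

```julia
using Random
using DynamicPPL
import StatsBase

# Sketch: optional return-type parameter `T`, defaulting to NamedTuple
# as `rand` does, with the rng also optional.
function StatsBase.predict(rng::Random.AbstractRNG, ::Type{T},
                           model::DynamicPPL.Model, x) where {T}
    return rand(rng, T, condition(model, x))
end
StatsBase.predict(::Type{T}, model::DynamicPPL.Model, x) where {T} =
    StatsBase.predict(Random.default_rng(), T, model, x)
StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) =
    StatsBase.predict(rng, NamedTuple, model, x)
StatsBase.predict(model::DynamicPPL.Model, x) =
    StatsBase.predict(Random.default_rng(), NamedTuple, model, x)
```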