TuringLang/DynamicPPL.jl

Adding StatsBase.predict to the API

sethaxen opened this issue · 7 comments

In Turing, StatsBase.predict is overloaded to dispatch on DynamicPPL.Model and MCMCChains.Chains (https://github.com/TuringLang/Turing.jl/blob/d76d914231db0198b99e5ca5d69d80934ee016b3/src/inference/Inference.jl#L532-L564). This effectively does batch prediction: it conditions the model on each draw in the chains and calls rand on the conditioned model. We also want to do the same thing for InferenceData (see #465).

It would be convenient if StatsBase.predict were added to the DynamicPPL API. StatsBase is already an indirect dependency of this package. As suggested by @devmotion in #465 (comment), its default implementation could just call rand on a conditioned model:

StatsBase.predict(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) = rand(rng, condition(model, x))
StatsBase.predict(model::DynamicPPL.Model, x) = StatsBase.predict(Random.default_rng(), model, x)

Maybe this could even be part of AbstractPPL and be defined on AbstractPPL.AbstractProbabilisticProgram: condition is already part of its API; only rand is not clearly specified there yet (which should probably be done anyway).

Yeah, makes sense.

I'm down with this, but it's worth pointing out that just calling rand(rng, condition(model, x)) is probably not the greatest idea as it defaults to NamedTuple which can blow up compilation times for many models.

And regarding adding it to AbstractPPL: we would then need to propagate that change back to v0.5 too, because v0.6 is currently not compatible with DynamicPPL (see #440).

> I'm down with this, but it's worth pointing out that just calling rand(rng, condition(model, x)) is probably not the greatest idea as it defaults to NamedTuple which can blow up compilation times for many models.

Would rand(rng, OrderedDict, condition(model, x)) be the way to go then?

> Would rand(rng, OrderedDict, condition(model, x)) be the way to go then?

For maximal model-compat, yes. But you do of course take a performance hit as a result 😕
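To make the trade-off concrete, here is a small Base-only sketch of why the container type matters. It uses Dict as a stand-in for OrderedDict (an assumption made only to avoid an extra dependency) and shows just the type-level difference, not DynamicPPL internals: a NamedTuple carries its field names in its type, so each distinct set of variable names is a new type to specialize on, while a dictionary's type depends only on its key and value types.

```julia
# Two "draws" that differ only in the variable name.
nt_a = (a = 1.0,)
nt_b = (b = 1.0,)
d_a = Dict(:a => 1.0)
d_b = Dict(:b => 1.0)

# NamedTuple types include the field names, so these are distinct types,
# and methods taking them are specialized separately for each.
println(typeof(nt_a) == typeof(nt_b))  # false

# Dict types depend only on the key and value types, not on the keys used.
println(typeof(d_a) == typeof(d_b))    # true
```

This is presumably where the extra compilation time comes from for models with many or structured variable names; the dictionary avoids it, but lookups are no longer resolved at compile time.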

Hrm. Maybe predict should then return a NamedTuple if x is a NamedTuple (imperfect, because you can have few parameters but many data points). Or provide an API for specifying the return type, like rand does (but supporting two optional positional parameters rng and T complicates the interface).

> Or provide an API for specifying the return type, like rand does (but supporting two optional positional parameters rng and T complicates the interface)

Adding T to predict (with some default) would be in line with our API for rand though, where the type T can already be specified.
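A rough sketch of what that could look like, mirroring the optional rng and T arguments of rand. The name predict_sketch and the demo model are hypothetical and only for illustration (a real implementation would extend StatsBase.predict instead). The NamedTuple default is an assumption, and the sketch assumes a DynamicPPL version in which rand(rng, T, model) accepts the return type, as discussed above.

```julia
using Random
using Distributions
using DynamicPPL
using OrderedCollections: OrderedDict

# Hypothetical toy model, used only to exercise the sketch.
@model function demo()
    m ~ Normal(0, 1)
    y ~ Normal(m, 1)
end

# Full method: condition on x, then draw the remaining variables as a T.
function predict_sketch(
    rng::Random.AbstractRNG, ::Type{T}, model::DynamicPPL.Model, x
) where {T}
    return rand(rng, T, DynamicPPL.condition(model, x))
end

# Convenience methods for the two optional positional arguments.
# The default return type (NamedTuple) is an assumption, matching rand.
predict_sketch(rng::Random.AbstractRNG, model::DynamicPPL.Model, x) =
    predict_sketch(rng, NamedTuple, model, x)
predict_sketch(::Type{T}, model::DynamicPPL.Model, x) where {T} =
    predict_sketch(Random.default_rng(), T, model, x)
predict_sketch(model::DynamicPPL.Model, x) =
    predict_sketch(Random.default_rng(), NamedTuple, model, x)

# Usage: condition on m and draw the remaining variables.
draw_nt = predict_sketch(demo(), (m = 0.5,))
draw_od = predict_sketch(Random.default_rng(), OrderedDict, demo(), (m = 0.5,))
```

The two three-argument methods do not clash because their first arguments (an AbstractRNG versus a Type) are disjoint, so the cost of supporting both optional arguments is mainly the extra forwarding methods.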