Without `load_generated_functions`, static proposals must now be defined before static models
alex-lew opened this issue.
The type `StaticIRTraceAssmt{T}` is used to represent the `ChoiceMap`s of Static DSL generative function traces. The type parameter `T` is the trace type of the Static DSL generative function in question.
If a user defines a static model and then defines a static proposal for that model, we have a problem when we try to update a model trace with a choicemap from the proposal:
- The model's `update` was code-gen'd before the proposal trace type existed.
- As such, the 'world age' of any methods defined on the proposal's choice maps (e.g. `get_schema`) is higher than the world age of the model's `update` method.
- Therefore, the model's `update` method cannot call (e.g.) `get_schema` on a choicemap generated by the proposal.
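The world-age rule behind this can be seen in plain Julia, without Gen. The sketch below uses `@eval` rather than a `@generated` function, so it is an analogy to the mechanism above, not Gen's code; `describe` and `extend_then_call` are hypothetical names. Code that is already running keeps dispatching in the world it started in, so a method added later is invisible to it until something like `Base.invokelatest` (or a fresh top-level call) moves to the latest world.

```julia
# Hypothetical illustration of Julia's world-age rule (not Gen code).
# A generic method exists before anything else runs.
describe(x) = :generic

function extend_then_call()
    # Add a more specific method while this function is already executing...
    @eval describe(::Int) = :int
    # ...but this call still dispatches in the world age this call started in,
    # so the new `describe(::Int)` method is not visible here.
    return describe(1)
end

old_world_result = extend_then_call()               # sees only describe(x)
latest_world_result = Base.invokelatest(describe, 1) # sees describe(::Int)
```

In the issue, the "already running" code plays the role of the model's code-gen'd `update`, and the later method plays the role of `get_schema` for the proposal's trace type.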
Previously, the use of `Gen.load_generated_functions` got around this by waiting for all static generative functions to be defined before performing codegen 'all at once'. After #472, however, types and GFI implementations are generated immediately after a new static DSL method is defined.
I don't have thoughts yet on how best to fix this, but wanted to record the issue! Code to reproduce below:
```julia
using Gen

# Model:
@gen (static) function datum(x::Float64, (grad)(inlier_std::Float64), (grad)(outlier_std::Float64),
                             (grad)(slope::Float64), (grad)(intercept::Float64))
    z ~ bernoulli(0.5)
    std = ifelse(z, inlier_std, outlier_std)
    y ~ normal(x * slope + intercept, std)
    return y
end

data = Map(datum)

@gen (static) function model(xs::Vector{Float64})
    n = length(xs)
    log_inlier_std ~ normal(0, 2)
    log_outlier_std ~ normal(0, 2)
    inlier_std = exp(log_inlier_std)
    outlier_std = exp(log_outlier_std)
    slope ~ normal(0, 2)
    intercept ~ normal(0, 2)
    data ~ data(xs, fill(inlier_std, n), fill(outlier_std, n),
                fill(slope, n), fill(intercept, n))
end

# Proposal (defined after the model, which triggers the problem):
@gen (static) function slope_proposal(trace)
    old_slope = trace[:slope]
    slope ~ normal(old_slope, 0.5)
end

xs = [1.0, 2.0, 3.0, 4.0, 5.0]
tr = simulate(model, (xs,))
mh(tr, slope_proposal, ())
```
Thanks for discovering this issue! I looked through the static modeling language (SML) implementation, and it looks like there are two things that need to be fixed.

First, there is this code that gets called at `@generated`-function compilation time, which uses the `schema` returned by `get_address_schema` to determine whether conversion to a `StaticChoiceMap` is necessary:
`Gen.jl/src/static_ir/update.jl`, lines 467 to 475 (at commit e1fb6eb).
Looking at this code, and also at the similar code used in the static modeling language implementations of `regenerate`, `project`, `generate`, and `choice_gradients`, I think it should be possible to handle the dispatching logic using regular, non-`@generated` functions. Specifically, it should be possible to just add new method definitions to the code block returned by `generate_generative_function` such that non-static choicemaps or selections are converted to static ones.
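A minimal sketch of that dispatch pattern in plain Julia, assuming simplified stand-in types: `ChoiceMapSketch`, `StaticChoiceMapSketch`, `DynamicChoiceMapSketch`, `to_static`, and `update_sketch` are all hypothetical and are not Gen's real types or API. The core method is written only for the static representation, and an ordinary (non-`@generated`) fallback method converts anything else and forwards, so no branching on the choicemap type happens inside generated code.

```julia
# Hypothetical stand-ins for choicemap types (NOT Gen's real API).
abstract type ChoiceMapSketch end

# "Static" choicemap: the set of addresses is part of the type (via the NamedTuple).
struct StaticChoiceMapSketch{NT<:NamedTuple} <: ChoiceMapSketch
    values::NT
end

# "Dynamic" choicemap: addresses are only known at run time.
struct DynamicChoiceMapSketch <: ChoiceMapSketch
    values::Dict{Symbol, Any}
end

# Convert a dynamic choicemap to the static representation (addresses sorted).
function to_static(cm::DynamicChoiceMapSketch)
    addrs = Tuple(sort!(collect(keys(cm.values))))
    vals = Tuple(cm.values[a] for a in addrs)
    return StaticChoiceMapSketch(NamedTuple{addrs}(vals))
end

# Core implementation, written only for static choicemaps.
# Here it just reports which addresses it would process.
update_sketch(cm::StaticChoiceMapSketch) = keys(cm.values)

# Ordinary fallback method: convert to static, then forward to the method above.
update_sketch(cm::ChoiceMapSketch) = update_sketch(to_static(cm))

dyn = DynamicChoiceMapSketch(Dict{Symbol, Any}(:slope => 0.3, :intercept => 1.0))
processed = update_sketch(dyn)
```

Because Julia picks the most specific method, static choicemaps hit the core method directly, while every other `ChoiceMapSketch` goes through the conversion first.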
As far as I can tell, the above would resolve the issue for functions that only work with selections. However, it doesn't handle the fact that in `update`, `generate`, and `choice_gradients`, `get_address_schema` can be called on a `StaticIRTraceAssmt`, and hence `get_schema` can still be called on a newly defined `StaticIRTrace`, which results in a world-age error. Furthermore, we cannot omit the call to `get_address_schema`, because the returned `schema` variable is needed later on by sub-functions like `process_forward!` to generate more specialized code for only the variables that are in the schema.
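To illustrate why the schema has to be known at code-generation time, here is a toy `@generated` function that emits one statement per address in a type-level schema and no code at all for other addresses. It is a hypothetical miniature of the idea, not `process_forward!` itself; `collect_schema` and the `NamedTuple`/`Val` encoding of the schema are assumptions made for this sketch.

```julia
# Hypothetical toy (NOT Gen's process_forward!): the schema is a tuple of
# addresses carried in the type `Val{Schema}`, so the generator can see it.
@generated function collect_schema(vals::NamedTuple, ::Val{Schema}) where {Schema}
    # One statement per address in the schema; other fields generate no code.
    stmts = [:(push!(out, $(QuoteNode(a)) => vals.$a)) for a in Schema]
    return quote
        out = Pair{Symbol, Any}[]
        $(stmts...)
        out
    end
end

nt = (slope = 0.3, intercept = 1.0, noise = 0.1)
selected = collect_schema(nt, Val((:slope, :intercept)))
```

Here `:noise` is never touched by the generated code, which is the kind of specialization that requires the schema to be available inside the generator.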
To work around this, I can think of at least two different options:

1. Defining a generic version of `get_address_schema` that can directly return the necessary information, instead of forwarding to a newly defined version of `get_schema` for each new static function. To do this, we could store the schema information as either a field or a type parameter within `StaticIRTraceAssmt`, and modify `get_choices` for static traces accordingly.
2. Instead of calling `get_address_schema` within the body of a `@generated` function, calling it in a wrapper, non-`@generated` function, and passing the resulting information to the inner `@generated` function as a type parameter (perhaps by adding one more argument to the inner `@generated` function so that this information can be passed along).
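Both options can be sketched in plain Julia under simplified assumptions. All names here (`AssmtSketch`, `schema_of`, `n_addresses_opt1`, `get_schema_sketch`, `n_addresses_opt2`, `n_addresses_inner`, `LaterTrace`) are hypothetical stand-ins, not Gen's API, and the schema is reduced to a tuple of addresses. The point of the sketch is that in both cases a trace type defined *after* the `@generated` function still works: in Option 1 the generator only calls a generic method defined up front, and in Option 2 the generator only reads a type parameter computed by an ordinary function at run time.

```julia
# Hypothetical stand-ins; none of these names are Gen's real API.

# Option 1: the schema lives in the assignment's type, so ONE generic method,
# defined up front, can read it for every trace type.
struct AssmtSketch{Addrs, T}
    trace::T
end
schema_of(::Type{<:AssmtSketch{Addrs}}) where {Addrs} = Addrs

@generated function n_addresses_opt1(cm::AssmtSketch)
    # Inside the generator `cm` is a type; schema_of was defined before this.
    n = length(schema_of(cm))
    return :($n)
end

# Option 2: a plain (non-@generated) wrapper computes the schema in the caller's
# world and passes it to the @generated inner function as a type parameter.
function get_schema_sketch end

n_addresses_opt2(cm) = n_addresses_inner(cm, Val(get_schema_sketch(cm)))

@generated function n_addresses_inner(cm, ::Val{Addrs}) where {Addrs}
    # The generator only reads the static parameter; it calls nothing new.
    n = length(Addrs)
    return :($n)
end

# A trace type (and a per-type method) defined AFTER everything above.
struct LaterTrace end
get_schema_sketch(::LaterTrace) = (:slope, :intercept)

r1 = n_addresses_opt1(AssmtSketch{(:slope, :intercept), LaterTrace}(LaterTrace()))
r2 = n_addresses_opt2(LaterTrace())
```

Option 1 needs no per-trace-type method at all on the code-generation path, which is one reason it may be the tidier of the two; Option 2 keeps per-type methods but moves the call out of the generator.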
My sense is that Option 1 would be the neater solution. I can take a stab at it once I find some time!