probcomp/Gen.jl

Without `load_generated_functions`, static proposals must now be defined before static models

alex-lew opened this issue · 2 comments

The type StaticIRTraceAssmt{T} is used to represent the ChoiceMaps of Static DSL generative function traces. The type parameter T is the trace type of the Static DSL generative function in question.

If a user defines a static model and then defines a static proposal for the model, we have a problem when we try to update a model trace with a choicemap from the proposal:

  • The model's update method was code-generated before the proposal's trace type existed.
  • As a result, any methods defined on the proposal's choice maps (e.g. get_schema) have a higher 'world age' than the model's update method.
  • Therefore, the model's update method cannot call (e.g.) get_schema on a choicemap generated by the proposal.
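
The world-age restriction described in the bullets above can be reproduced in plain Julia, without Gen. The sketch below is only an analogue: schema_for and LateTrace are hypothetical stand-ins (for something like get_schema and a proposal's trace type), and it uses an ordinary function with @eval rather than an @generated function. A caller that started running before a method was defined cannot see that method, whereas Base.invokelatest can.

```julia
# Hypothetical stand-in for a method such as get_schema (no methods yet).
function schema_for end

# Hypothetical stand-in for a trace type that is introduced "late".
struct LateTrace end

function stale_caller()
    # The method is defined only after stale_caller has started running...
    @eval schema_for(::Type{LateTrace}) = :static
    # ...so this call still executes in the older world and cannot see it.
    return schema_for(LateTrace)
end

# Capture the error instead of letting it abort the script.
result = try
    stale_caller()
catch err
    err
end

# Dispatching in the latest world finds the new method.
latest = Base.invokelatest(schema_for, LateTrace)
```

Here result holds a MethodError (the method exists but is too new for the running caller), while latest is :static. In the issue the "stale caller" is the code generator behind the model's update method rather than a function that is literally still running, but the visibility rule is the same.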

Previously, Gen.load_generated_functions worked around this by waiting until all static generative functions had been defined and then performing codegen 'all at once'. After #472, however, types and GFI implementations are generated immediately after a new static DSL method is defined.

I don't have thoughts yet on how best to fix this, but wanted to record the issue! Code to reproduce below:

using Gen

# Model:
@gen (static) function datum(x::Float64, (grad)(inlier_std::Float64), (grad)(outlier_std::Float64),
                             (grad)(slope::Float64), (grad)(intercept::Float64))
    z ~ bernoulli(0.5)
    std = ifelse(z, inlier_std, outlier_std)
    y ~ normal(x * slope + intercept, std)
    return y
end

data = Map(datum)

@gen (static) function model(xs::Vector{Float64})
    n = length(xs)
    log_inlier_std ~ normal(0, 2)
    log_outlier_std ~ normal(0, 2)
    inlier_std = exp(log_inlier_std)
    outlier_std = exp(log_outlier_std)
    slope ~ normal(0, 2)
    intercept ~ normal(0, 2)
    data ~ data(xs, fill(inlier_std, n), fill(outlier_std, n),
        fill(slope, n), fill(intercept, n))
end

# Proposal:
@gen (static) function slope_proposal(trace)
    old_slope = trace[:slope]
    slope ~ normal(old_slope, 0.5)
end

xs = [1.0, 2.0, 3.0, 4.0, 5.0]
tr = simulate(model, (xs,))
mh(tr, slope_proposal, ())

(@ztangent, @femtomc)

Thanks for discovering this issue! I looked through the static modeling language (SML) implementation, and it looks like two things need to be fixed.

First, the following code runs when the @generated function is compiled. It uses the schema returned by get_address_schema to determine whether conversion to a StaticChoiceMap is necessary.

function codegen_update(trace_type::Type{T}, args_type::Type, argdiffs_type::Type,
                        constraints_type::Type) where {T<:StaticIRTrace}
    gen_fn_type = get_gen_fn_type(trace_type)
    schema = get_address_schema(constraints_type)
    # convert the constraints to a static assignment if it is not already one
    if !(isa(schema, StaticAddressSchema) || isa(schema, EmptyAddressSchema))
        return quote $(GlobalRef(Gen, :update))(trace, args, argdiffs, $(QuoteNode(StaticChoiceMap))(constraints)) end
    end
    # (remainder of the function is not shown in this excerpt)

Looking at this code, and also the similar code used for the static modeling language implementation of regenerate, project, generate, and choice_gradients, I think it should be possible to handle the dispatching logic using regular, non-@generated functions. Specifically, it should be possible to just add new method definitions to the code block returned by generate_generative_function such that non-static choicemaps or selections are converted to static ones.
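
As a rough illustration of the dispatch-based conversion suggested above, here is a minimal sketch in plain Julia. Every name in it (ToyChoiceMap, ToyDynamicMap, ToyStaticMap, to_static, ToyTrace, toy_update) is hypothetical and not part of Gen; the point is only that an ordinary fallback method can convert a non-static argument and re-dispatch, so no schema inspection inside a @generated body is needed for that step.

```julia
# All names here are hypothetical; this is not Gen's implementation.
abstract type ToyChoiceMap end

# A dynamically structured choice map backed by a dictionary.
struct ToyDynamicMap <: ToyChoiceMap
    choices::Dict{Symbol,Any}
end

# A statically structured choice map whose addresses are part of its type.
struct ToyStaticMap{NT<:NamedTuple} <: ToyChoiceMap
    choices::NT
end

# Conversion from the dynamic to the static representation.
function to_static(m::ToyDynamicMap)
    d = m.choices
    return ToyStaticMap((; d...))
end

struct ToyTrace end

# Specialized path: only ever sees static choice maps.
toy_update(::ToyTrace, c::ToyStaticMap) = (:static_path, keys(c.choices))

# Fallback path: convert, then re-dispatch to the specialized method.
toy_update(t::ToyTrace, c::ToyChoiceMap) = toy_update(t, to_static(c))

path, addrs = toy_update(ToyTrace(), ToyDynamicMap(Dict{Symbol,Any}(:slope => 1.0)))
```

Because the fallback is a regular method, it is resolved by ordinary dispatch at call time, which is what makes it insensitive to when the argument's type was defined.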

The above would resolve the issue for functions that only work with selections, as far as I can tell. However, this doesn't handle the fact that in update, generate and choice_gradients, get_address_schema can be called on StaticIRTraceAssmt, and hence get_schema can still be called on a newly defined StaticIRTrace, which results in a world-age error. Furthermore, we cannot omit the call to get_address_schema because the returned schema variable is needed later on by sub-functions like process_forward! etc. to generate more specialized code for only the variables that are in the schema.

To work around this, I can think of at least two different options:

  1. Define a generic version of get_address_schema that can directly return the necessary information, instead of forwarding to a newly defined version of get_schema for each new static function. To do this, we could store the schema information as either a field or type parameter within StaticIRTraceAssmt, and modify get_choices for static traces accordingly.

  2. Instead of calling get_address_schema within the body of a @generated function, call it in a non-@generated wrapper function, and pass the resulting information to the inner @generated function as a type parameter (perhaps by adding one more argument to the inner @generated function so that this information can be passed along).
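
To make the two options above concrete, here is a small sketch in plain Julia. It is an assumption-laden toy, not the change that was eventually made in Gen: ToyAssmt, n_addresses, toy_schema, n_addresses_inner, n_addresses_outer and LateProposalChoices are all hypothetical names, and a tuple of Symbols stands in for a real address schema.

```julia
# Option 1 (sketch): the schema lives in a type parameter, so the
# generator can read it directly without calling any per-type method.
struct ToyAssmt{Schema} end

@generated function n_addresses(::ToyAssmt{S}) where {S}
    n = length(S)              # S comes straight from the argument's type
    return :(return $n)
end

# Option 2 (sketch): a regular wrapper looks up the schema in the latest
# world and forwards it to the @generated function as a Val type parameter.
function toy_schema end        # hypothetical per-type schema lookup

@generated function n_addresses_inner(x, ::Val{S}) where {S}
    n = length(S)              # again, no method lookup inside the generator
    return :(return $n)
end

n_addresses_outer(x) = n_addresses_inner(x, Val(toy_schema(x)))

# A "proposal" choice-map type defined after all of the code above.
struct LateProposalChoices end
toy_schema(::LateProposalChoices) = (:slope,)

opt1 = n_addresses(ToyAssmt{(:slope, :intercept)}())
opt2 = n_addresses_outer(LateProposalChoices())
```

In both variants the generator body only touches information carried by the argument types, so it does not depend on methods that were defined after the @generated function itself. Option 2 pays for this with a dynamic dispatch on the Val type in the wrapper.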

My sense is that Option 1 would be the neater solution. I can take a stab at it once I find some time!

Resolved by #510.