Tractables/Dice.jl

if recursing, pass values of sibling *enumlikes*

Closed this issue · 1 comment

# TODO: if recursing, pass values of sibling *enumlikes*

"""
    tocoq(v::Tuple)

Render a tuple as a parenthesized, comma-separated string. Each element is
rendered with its own `tocoq` method (presumably Coq term syntax — confirm).
"""
function tocoq(v::Tuple)
    rendered = map(tocoq, v)
    return "(" * join(rendered, ", ") * ")"
end




"""
    generate(rs::RunState, p, track_return)

Build a size-indexed generator for every type reachable from `p.root_ty`, then
run the root type's generator at size `p.init_size` with `empty_stack(p)` and
return its result.

A type is generated if it is `p.root_ty`, or if it appears as a constructor
parameter of a generated type and `variants` has a method for it. The
generator for a type `ty` is a closure `(size, stack_tail) -> value`. It picks
one constructor from `variants(ty)` via `frequency_for` and builds that
constructor's arguments with `Dice.construct`:

- a parameter of a generated type recurses with `size - 1`, with the stack
  tail extended by the `(ty, ctor)` id;
- an `AnyBool` parameter is a single `flip_for`;
- a `DistUInt32` parameter is a sum over `twopowers(p.intwidth)`, where each
  power of two contributes when its own `flip_for` is true;
- any other parameter type raises an `ErrorException`.

At `size == 0`, only constructors with no parameter equal to `ty` are offered.
Every name passed to `frequency_for`/`flip_for` is also prefixed with `"0_"`
(presumably so the base case gets its own parameters — confirm).

`track_return` is accepted but not used by this function.

NOTE(review): at `size == 0`, a constructor whose parameter is a *different*
generated type still recurses with `size - 1` (negative). Negative sizes take
the non-base branch, so if that other type is itself recursive, generation
never reaches a base case.
"""
function generate(rs::RunState, p, track_return)
    # Depth-first discovery of every type reachable through constructor
    # parameters that has its own `variants` method. `Any`-typed containers so
    # that types which are not `DataType`s can be recorded next to the root.
    to_visit = Any[p.root_ty]
    seen = Set{Any}([p.root_ty])
    while !isempty(to_visit)
        ty = pop!(to_visit)
        for (_, params) in variants(ty), param in params
            if param ∉ seen && hasmethod(variants, (Type{param},))
                push!(seen, param)
                push!(to_visit, param)
            end
        end
    end

    # Give each (type, constructor) pair a distinct integer id. The id is
    # added to the stack tail whenever that constructor recurses.
    type_ctor_to_id = Dict{Any,Int}()
    for ty in seen
        for (ctor, _) in variants(ty)
            type_ctor_to_id[(ty, ctor)] = length(type_ctor_to_id)
        end
    end

    type_to_gen = Dict{Any,Function}()
    for ty in seen
        type_to_gen[ty] = (size, stack_tail) -> begin
            dependents = (size, stack_tail)
            # Base-case names get a "0_" prefix; otherwise both cases use the
            # same name layout.
            prefix = size == 0 ? "0_" : ""

            # Build the `i`-th argument (of type `param`) for constructor `ctor`.
            gen_arg(ctor, i, param) = begin
                argname = "$(prefix)$(ty)_$(ctor)_$(i)"
                if param ∈ seen
                    type_to_gen[param](
                        size - 1,
                        update_stack_tail(p, stack_tail, type_ctor_to_id[(ty, ctor)])
                    )
                elseif param == AnyBool
                    flip_for(rs, argname, dependents)
                elseif param == DistUInt32
                    # Independently include each power of two below the width.
                    sum(
                        @dice_ite if flip_for(rs, "$(argname)_num$(n)", dependents)
                            DistUInt32(n)
                        else
                            DistUInt32(0)
                        end
                        for n in twopowers(p.intwidth)
                    )
                else
                    error("generate: unsupported parameter type $(param) " *
                          "for constructor $(ctor) of $(ty)")
                end
            end

            # At size 0, drop constructors that directly mention `ty` so the
            # recursion on `ty` itself bottoms out.
            # TODO: if recursing, pass values of sibling *enumlikes*
            frequency_for(rs, "$(prefix)$(ty)_variant", dependents, [
                "$(ctor)" => Dice.construct(ty, ctor, [
                    gen_arg(ctor, i, param) for (i, param) in enumerate(params)
                ])
                for (ctor, params) in variants(ty)
                if size != 0 || all(param != ty for param in params)
            ])
        end
    end

    return type_to_gen[p.root_ty](p.init_size, empty_stack(p))
end