Tractables/Dice.jl

Share an expander? (`SimpleLossMgr` and `SamplingEntropyLossMgr` each build their own `Dice.LogPrExpander`.)

Opened this issue · 0 comments

# TODO: share an expander?

    println_flush(rs.io)
end

"""
    SimpleLossMgr(loss::ADNode)

Loss manager that holds a single, fixed loss. The `LogPr` terms of `loss` are
expanded once, at construction time, and the expanded node is stored.
"""
struct SimpleLossMgr <: LossMgr
    # `loss` after `Dice.expand_logprs` has been applied to it.
    loss::ADNode

    function SimpleLossMgr(loss::ADNode)
        # TODO: share an expander?
        # The expander is built from the Boolean roots of this one loss only.
        roots = Dice.bool_roots([loss])
        compiler = BDDCompiler(roots)
        expander = Dice.LogPrExpander(WMC(compiler))
        expanded = Dice.expand_logprs(expander, loss)
        return new(expanded)
    end
end
# Return the loss stored in `m`; `rs` and `epoch` are ignored because this
# manager's loss is fixed at construction time.
function produce_loss(rs::RunState, m::SimpleLossMgr, epoch::Integer)
    return m.loss
end

"""
    SamplingEntropy{T}(resampling_frequency, samples_per_batch)

Configuration for a sampling-based loss (consumed by `SamplingEntropyLossMgr`).

# Fields
- `resampling_frequency`: a new batch of samples is drawn on every epoch `e`
  with `(e - 1) % resampling_frequency == 0`; must be positive.
- `samples_per_batch`: number of samples drawn per batch (also the divisor of
  the summed loss); must be positive.

# Throws
- `ArgumentError` if either value is not positive.
"""
struct SamplingEntropy{T} <: LossConfig{T}
    # Concrete `Int` (rather than abstract `Integer`) keeps field access type-stable.
    resampling_frequency::Int
    samples_per_batch::Int

    # Arguments are left untyped so any value convertible to `Int` is still
    # accepted, exactly like the default constructor this replaces.
    function SamplingEntropy{T}(resampling_frequency, samples_per_batch) where {T}
        # A zero frequency would make `(epoch - 1) % resampling_frequency` throw.
        resampling_frequency > 0 || throw(ArgumentError(
            "resampling_frequency must be positive, got $(resampling_frequency)"))
        # Zero samples would leave nothing to sum into a loss.
        samples_per_batch > 0 || throw(ArgumentError(
            "samples_per_batch must be positive, got $(samples_per_batch)"))
        return new{T}(resampling_frequency, samples_per_batch)
    end
end

"""
    SamplingEntropyLossMgr(p, val, consider, ignore)

Mutable loss manager for a `SamplingEntropy` configuration `p`.

- `val`: the distribution that samples are drawn from.
- `consider`, `ignore`: predicates called on each drawn sample.
- `current_loss`: the most recently built loss; starts as `nothing` and is
  replaced whenever a new batch is sampled.
"""
mutable struct SamplingEntropyLossMgr <: LossMgr
    p::SamplingEntropy
    val::Dist
    consider
    ignore
    current_loss::Union{Nothing,ADNode}

    function SamplingEntropyLossMgr(p, val, consider, ignore)
        # No loss exists until the first sampling epoch builds one.
        initial_loss = nothing
        return new(p, val, consider, ignore, initial_loss)
    end
end
"""
    produce_loss(rs::RunState, m::SamplingEntropyLossMgr, epoch::Integer)

Return the loss for `epoch`.

When `(epoch - 1) % m.p.resampling_frequency == 0`, this draws a fresh batch
and rebuilds the loss:

- It draws `m.p.samples_per_batch` samples from `m.val`, with concrete flip
  values taken from `rs.var_vals`.
- It sums `LogPr(prob_equals(m.val, sample))` over the samples accepted by
  `m.consider`.
- It expands the `LogPr` terms and divides by `m.p.samples_per_batch`.
- It caches the result in `m.current_loss`.

On all other epochs it returns the cached `m.current_loss`.

# Throws
- `AssertionError` if a sample does not satisfy exactly one of `m.consider` and
  `m.ignore`, or if no loss has been built yet.
- `ErrorException` if no sample in the batch satisfies `m.consider`.
"""
function produce_loss(rs::RunState, m::SamplingEntropyLossMgr, epoch::Integer)
    if (epoch - 1) % m.p.resampling_frequency == 0
        println_flush(rs.io, "Sampling...")
        time_sample = @elapsed samples = with_concrete_ad_flips(rs.var_vals, m.val) do
            [sample_as_dist(rs.rng, Valuation(), m.val) for _ in 1:m.p.samples_per_batch]
        end
        println_flush(rs.io, "  $(time_sample) seconds")

        # Every sample must be classified by exactly one predicate. This must be
        # `xor`: for Bools, `a ^ b` is exponentiation, i.e. `a | !b`.
        msg = "each sample must satisfy exactly one of `consider` and `ignore`"
        for sample in samples
            @assert xor(m.consider(sample), m.ignore(sample)) msg
        end

        considered = filter(m.consider, samples)
        # `sum` over an empty collection has no element type to fall back on.
        isempty(considered) &&
            error("no sampled value satisfied `consider`; cannot build a loss")
        loss = sum(LogPr(prob_equals(m.val, sample)) for sample in considered)

        # TODO: share an expander?
        expander = Dice.LogPrExpander(WMC(BDDCompiler(Dice.bool_roots([loss]))))
        # Dividing by the full batch size (not just the considered count) is
        # what the original code did; kept as-is.
        m.current_loss = Dice.expand_logprs(expander, loss) / m.p.samples_per_batch
    end

    @assert !isnothing(m.current_loss)
    return m.current_loss
end

"""
    save_learning_curve(out_dir, learning_curve, name)

Save `learning_curve` in two files under `out_dir`:

- `name.csv`: one tab-separated `epoch`, `value` pair per line, with epochs
  numbered from 0.
- `name.svg`: a line plot of the same data.
"""
function save_learning_curve(out_dir, learning_curve, name)
    xs = 0:length(learning_curve)-1

    # Write and close the data file before doing any plotting.
    open(joinpath(out_dir, "$(name).csv"), "w") do file
        for (epoch, logpr) in zip(xs, learning_curve)
            println(file, "$(epoch)\t$(logpr)")
        end
    end

    # Pass the plot explicitly: one-argument `savefig(path)` saves the global
    # "current plot", which any other plotting call could have replaced.
    fig = plot(xs, learning_curve)
    return savefig(fig, joinpath(out_dir, "$(name).svg"))
end

##################################