Tractables/Dice.jl

discuss `ADNode`s, `compute`

Closed this issue · 1 comments

https://github.com/Juice-jl/Dice.jl/blob/b2d0baa5f6133d102811e2a5743223d48d571731/examples/qc/tour_2_learning.jl#L30

################################################################################
# Maximizing expression probabilities
################################################################################
using Revise
using Dice

# What value for ? maximizes the probability of the following expression?
#   flip(?) & flip(?) & !flip(?)

# Let's check!
p = add_unit_interval_var!("p")
x = flip(p) & flip(p) & !flip(p)
train_vars!([x])
compute(p)  # ~ 2/3

# What just happened?
# - `add_unit_interval_var!()` registers a learnable parameter in (0, 1),
#   initialized to 0.5
# - `train_vars!(bools)` performs maximum likelihood estimation: it trains the
#   parameters to maximize the product of the probabilities of the bools
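
# (Sanity check in plain Julia, independent of Dice.) Assuming each `flip(p)`
# call is an independent coin, Pr(x) = p^2 * (1 - p). Setting the derivative of
# 2 log(p) + log(1 - p), i.e. 2/p - 1/(1 - p), to zero gives p = 2/3.
# The names `sketch_pr_x` and `sketch_grid` are hypothetical, for illustration.
sketch_pr_x(q) = q^2 * (1 - q)
sketch_grid = 0:0.001:1
sketch_grid[argmax(sketch_pr_x.(sketch_grid))]  # the grid point nearest 2/3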

# We can also perform computation on params before using them for flip
# probabilities. For example, `x` could have been equivalently defined as:
#   x = flip(p) & flip(p) & flip(1 - p)

clear_vars!()  # call before adding the params of the next "program"

# (For the curious) What's happening under the hood?
# - TODO: discuss `ADNode`s, `compute`
# - TODO: discuss how `add_unit_interval_var!` wraps a param in `sigmoid`
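
# A rough sketch of the sigmoid idea in plain Julia. This is NOT Dice's actual
# implementation; `sketch_sigmoid`, `sketch_fit`, and the default
# hyperparameters are hypothetical. It keeps an unconstrained real θ, uses
# sigmoid(θ) as the flip probability, and runs gradient ascent on
# log Pr(x) = 2 log(s) + log(1 - s) with s = sigmoid(θ). Starting at θ = 0
# corresponds to the initial probability 0.5.
sketch_sigmoid(θ) = 1 / (1 + exp(-θ))
function sketch_fit(; epochs=2000, learning_rate=0.1)
    θ = 0.0
    for _ in 1:epochs
        s = sketch_sigmoid(θ)
        # Since ds/dθ = s(1 - s), the gradient of the log-likelihood with
        # respect to θ is 2(1 - s) - s.
        θ += learning_rate * (2 * (1 - s) - s)
    end
    return sketch_sigmoid(θ)
end
sketch_fit()  # approaches 2/3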

# If the flips above use different parameters, each can learn its own
# probability.
a = add_unit_interval_var!("a")
b = add_unit_interval_var!("b")
c = add_unit_interval_var!("c")
x = flip(a) & flip(b) & !flip(c)
train_vars!([x])
compute(a)  # 0.8419880024053406
compute(b)  # 0.8419880024053406
compute(c)  # 0.1580119975946594

# We can also keep training to get closer to 1, 1, 0.
train_vars!([x]; epochs=10000, learning_rate=3.0)
compute(a)  # 0.9999666784828445
compute(b)  # 0.9999666784828445
compute(c)  # 3.332151715556398e-5

clear_vars!()


################################################################################