Tractables/Dice.jl

Make the following function use `unit_exponential`.

Opened this issue · 0 comments

beta_vec[i] = 0.0

end

@show sympy.Poly(simplify(v2*integrate(varint^(α-1) * (varint - v2) *exp(β*varint), (varint, v2, v2 + ϵ))/exp(β*v2)), v2)

https://github.com/Juice-jl/Dice.jl/blob/593e1ab860ef663637f27c79123fcf33cb1469fc/src/dist/number/fix.jl#L318

        slope_flip = flip_prob!(slope_flip, firstprob/avgprob)
        firstinterval + ifelse(slope_flip, unif, tria)
    end  
end

###################################################
# sound bit blasting for mixed gamma distributions
###################################################

# bit blasted exponential distribution on unit interval
# reverse=true refers to LSB to MSB order of flips
#
# Returns a DistFix{W, F} whose W - F integer bits are `false` and whose fractional
# bit i (worth 2^-i) is an independent flip with probability
# exp(beta/2^i) / (1 + exp(beta/2^i)); the product of these bit weights makes
# P(x) ∝ exp(beta * x) over the 2^-F grid on [0, 1).
# `reverse` only changes the order in which the flips are created (bit F first);
# the probability attached to each bit is the same either way.
function unit_exponential(t::Type{DistFix{W, F}}, beta::Float64; reverse=false) where W where F
    # exp(x)/(1+exp(x)); returns 1.0 instead of NaN (Inf/Inf) when exp(x) overflows.
    bitprob(i) = (ex = exp(beta / 2^i); isinf(ex) ? 1.0 : ex / (1 + ex))
    flips = [flip(bitprob(i)) for i in (reverse ? (F:-1:1) : (1:F))]
    if reverse
        # The keyword argument `reverse` shadows Base.reverse inside this function,
        # so the library function must be qualified explicitly.
        flips = Base.reverse(flips)
    end
    DistFix{W, F}(vcat([false for i in 1:W-F], flips))
end

# bit blasted exponential distribution on arbitrary interval
# reverse=true refers to LSB to MSB order of flips
# TODO: make the following function use unit_exponential
#
# Builds the bits of an exponential with slope `beta * (stop - start)` on a unit-scaled
# interval of log2(stop - start) + F bits and shifts the result by `start`, so the
# value ranges over the 2^-F grid on [start, stop) with P(x) ∝ exp(beta * x).
# `stop - start` must be a power of two. `reverse` only changes flip creation order.
function exponential(t::Type{DistFix{W, F}}, beta::Float64, start::Float64, stop::Float64; reverse=false) where W where F
    width = stop - start
    @assert ispow2(width)

    scaled_beta = beta * width
    nbits = Int(log2(width)) + F

    # exp(x)/(1+exp(x)); returns 1.0 instead of NaN (Inf/Inf) when exp(x) overflows.
    bitprob(i) = (ex = exp(scaled_beta / 2^i); isinf(ex) ? 1.0 : ex / (1 + ex))
    flips = [flip(bitprob(i)) for i in (reverse ? (nbits:-1:1) : (1:nbits))]
    if reverse
        # The keyword `reverse` shadows Base.reverse here; qualify the call.
        flips = Base.reverse(flips)
    end
    bit_vector = vcat([false for i in 1:W - nbits], flips)
    DistFix{W, F}(bit_vector) + DistFix{W, F}(start)
end

# Slope of the exponential fitted to `d` on [start, stop]: the log-ratio of the mass
# in the last cell [stop - interval_sz, stop] to the mass in the first cell
# [start, start + interval_sz], divided by the distance between the cells' left ends.
# A zero first-cell mass is replaced by eps(0.0) (smallest positive Float64) so the
# slope cannot become +Inf; a zero last-cell mass still yields -Inf.
function beta(d::ContinuousUnivariateDistribution, start::Float64, stop::Float64, interval_sz::Float64)
    prob_start = cdf(d, start + interval_sz) - cdf(d, start)
    if prob_start == 0.0
        prob_start = eps(0.0)
    end
    prob_end = cdf(d, stop) - cdf(d, stop - interval_sz)
    result = (log(prob_end) - log(prob_start)) / (stop - start - interval_sz)

    # Report degenerate fits: ±Inf (e.g. zero mass in the last cell) or NaN
    # (e.g. stop - start == interval_sz, giving 0/0).
    if !isfinite(result)
        @warn "beta: non-finite slope" result prob_start prob_end start stop interval_sz
    end
    result
end

# Approximate `dist`, truncated to [start, stop], by `numpieces` equal-width pieces.
# Each piece is bit-blasted with `exponential` using the slope returned by `beta`, and a
# discrete selector picks piece i with probability proportional to its mass under the
# truncated distribution.
#   stop - start : must be a power-of-two multiple of 2^-F
#   numpieces    : power of two, at most the number of 2^-F grid points in [start, stop)
#   strategy     : currently unused
function bitblast_exponential(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution, 
    numpieces::Int, start::Float64, stop::Float64, strategy=:exponential) where {W,F}

    # basic checks
    @assert start >= -(2^(W - F - 1))
    @assert stop <= (2^(W - F - 1))
    @assert start < stop
    # number of 2^-F grid points covered by [start, stop)
    npoints = (stop - start) * 2.0^F
    @assert ispow2(npoints) "stop - start must be a power-of-two multiple of 2^-F"
    @assert ispow2(numpieces) "Number of pieces must be a power of two"
    @assert numpieces <= npoints "Number of pieces must not exceed the number of grid points"
    # selector width; a single piece still gets a 1-bit selector
    piece_bits = max(Int(log2(numpieces)), 1)

    d = truncated(dist, start, stop)
    interval_sz = npoints / numpieces  # grid points per piece

    # piece end points
    start_pts = [start + (i-1)*interval_sz/2^F for i in 1:numpieces]
    stop_pts = [start + i*interval_sz/2^F for i in 1:numpieces]
    # NOTE: non-finite slopes from `beta` are passed to `exponential` unchanged.
    slopes = [beta(d, start_pts[i], stop_pts[i], 2.0^(-F)) for i in 1:numpieces]
    areas = [cdf(d, stop_pts[i]) - cdf(d, start_pts[i]) for i in 1:numpieces]
    # left-to-right accumulation, matching a running sum over the pieces
    rel_prob = areas / foldl(+, areas)

    selector = discrete(DistUInt{piece_bits}, rel_prob)

    # fallback for selector values that match no piece
    # (presumably the largest value representable in DistFix{W, F})
    result = DistFix{W, F}((2^(W-1)-1)/2^F)
    for i in numpieces:-1:1
        result = ifelse(prob_equals(selector, DistUInt{piece_bits}(i-1)),
                        exponential(DistFix{W, F}, slopes[i], start_pts[i], stop_pts[i]),
                        result)
    end
    return result
end

# Piecewise bit-blasting of `d` on [start, stop] with `pieces` pieces: delegates to
# continuous_exp when `exp` is true, otherwise to continuous_linear.
# (Inside this function `exp` is the Bool flag, not Base.exp.)
function continuous(t::Type{DistFix{W, F}}, d::ContinuousUnivariateDistribution, pieces::Int, start::Float64, stop::Float64, exp::Bool=false) where {W, F}
    exp && return continuous_exp(DistFix{W, F}, d, pieces, start, stop)
    return continuous_linear(DistFix{W, F}, d, pieces, start, stop)
end

# https://en.wikipedia.org/wiki/Laplace_distribution
# Bit-blasted Laplace(mean, scale) truncated to [start, stop]: a mixture of a
# decreasing exponential on [mean, stop) and an increasing one on [start, mean).
# Both `stop - mean` and `mean - start` must be powers of two (checked by `exponential`).
function laplace(t::Type{DistFix{W, F}}, mean::Float64, scale::Float64, start::Float64, stop::Float64) where {W, F}
    @assert scale > 0

    # right tail: density ∝ exp(-(x - mean)/scale)
    right = exponential(DistFix{W, F}, -1/scale, mean, stop)
    # left tail: density ∝ exp((x - mean)/scale)
    left = exponential(DistFix{W, F}, 1/scale, start, mean)

    # Truncated mass of each side (common factor scale/2 dropped):
    # ∫_0^L exp(-u/scale) du ∝ 1 - exp(-L/scale) = -expm1(-L/scale).
    # For a symmetric interval the two masses are equal and the weight is exactly 0.5.
    right_mass = -expm1(-(stop - mean) / scale)
    left_mass = -expm1(-(mean - start) / scale)
    ifelse(flip(right_mass / (right_mass + left_mass)), right, left)
end

# Start from unit_exponential(DistFix{W, F}, beta) and, `alpha` times, draw a fresh
# uniform u on [0, 1) plus a coin that is true with probability 1/(1 + 2^-F); when the
# coin is true, observe u < x. Every round constrains the same x, which is returned.
# NOTE(review): presumably each round reweights the density by a factor proportional to
# (x + 2^-F), giving a gamma-like shape ∝ (x + 2^-F)^alpha * exp(beta*x) — confirm.
function shift_point_gamma(::Type{DistFix{W, F}}, alpha::Int, beta::Float64) where {W, F}
    alpha >= 0 || throw(ArgumentError("alpha must be non-negative, got $alpha"))
    DFiP = DistFix{W, F}
    x = unit_exponential(DFiP, beta)
    for _ in 1:alpha
        u = uniform(DFiP, 0.0, 1.0)
        keep = flip(1/(1 + 2.0^(-F)))
        observe(ifelse(keep, u < x, true))
    end
    x
end

#https://www.wolframalpha.com/input?i=sum+%28a*epsilon%29%5E2+e%5E%28beta+*+epsilon+*+a%29+from+a%3D0+to+a%3D2%5Eb-1
# Closed form of sum_{a=0}^{n-1} (a*ϵ)^2 * exp(β*ϵ*a) with n = 1/ϵ
# (assumes 1/ϵ is an integer, e.g. ϵ = 2^-b as in the link above).
function sum_qgp(β::Float64, ϵ::Float64)
    # β == 0 makes the closed form 0/0; use ϵ^2 * Σ a^2 = ϵ^2 (n-1) n (2n-1) / 6.
    if iszero(β)
        n = 1/ϵ
        return ϵ^2 * (n - 1) * n * (2n - 1) / 6
    end
    s = (1/ϵ - 1)^2 * exp(β*ϵ*(2 + 1/ϵ))
    s += (1/ϵ^2)*exp(β)
    s += (2/ϵ - 2/ϵ^2 + 1)*exp(β*(1 + ϵ))
    s -= exp(β*ϵ)
    s -= exp(2*β*ϵ)
    s *= ϵ^2
    s /= expm1(β*ϵ)^3
    s
end

#https://www.wolframalpha.com/input?i=sum+%28a*epsilon%29+*+e%5E%28beta+*+epsilon+*+a%29+from+a%3D0+to+a%3D2%5Eb-1
# Closed form of sum_{a=0}^{n-1} (a*ϵ) * exp(β*ϵ*a) with n = 1/ϵ
# (assumes 1/ϵ is an integer, e.g. ϵ = 2^-b as in the link above).
function sum_agp(β::Float64, ϵ::Float64)
    # β == 0 makes the closed form 0/0; use ϵ * Σ a = ϵ n (n-1) / 2 = (n-1)/2.
    if iszero(β)
        return (1/ϵ - 1) / 2
    end
    s = (1/ϵ - 1)*exp(β*(1 + ϵ))
    s -= exp(β)/ϵ
    s += exp(β*ϵ)
    s *= ϵ
    s /= expm1(β*ϵ)^2
    s
end

#https://www.wolframalpha.com/input?i=sum+e%5E%28beta+*+epsilon+*+a%29+from+a%3D0+to+a%3D2%5Eb-1
# Closed form of the geometric sum sum_{a=0}^{n-1} exp(β*ϵ*a) with n = 1/ϵ
# (assumes 1/ϵ is an integer, e.g. ϵ = 2^-b as in the link above).
function sum_gp(β::Float64, ϵ::Float64)
    # β == 0: every term is 1, so the sum is n = 1/ϵ (the closed form would be 0/0).
    iszero(β) && return 1/ϵ
    # expm1 keeps (e^x - 1) accurate for small |x|.
    expm1(β) / expm1(β*ϵ)
end

# sum_{a=0}^{n-1} (a*ϵ)^p * exp(β*ϵ*a) with n = 1/ϵ: closed forms for p = 0, 1, 2
# (sum_gp, sum_agp, sum_qgp), otherwise a direct sum over the grid 0:ϵ:1-ϵ.
function sum_pgp(β::Float64, ϵ::Float64, p::Int)
    if p == 0
        sum_gp(β, ϵ)
    elseif p == 1
        sum_agp(β, ϵ)
    elseif p == 2
        sum_qgp(β, ϵ)
    else
        total = 0.0
        for x in 0:ϵ:1-ϵ
            total += x^p * exp(β*x)
        end
        total
    end
end

# One unit-interval bit-blasted exponential per entry of `betas` (same bit layout as
# unit_exponential), returned as a Vector of DistFix{W, F} in the order of `betas`.
# The W - F integer bits are `false`; fractional bit i (worth 2^-i) of output j is
# flip(exp(betas[j]/2^i) / (1 + exp(betas[j]/2^i))).
# Flips are created bit-major: bit i of every output before bit i + 1 of any output.
function n_unit_exponentials(::Type{DistFix{W, F}}, betas::Vector{Float64}) where {W, F}
    DFiP = DistFix{W, F}
    # exp(x)/(1+exp(x)); returns 1.0 instead of NaN (Inf/Inf) when exp(x) overflows.
    bitprob(x) = (ex = exp(x); isinf(ex) ? 1.0 : ex / (1 + ex))
    bitvecs = [Vector{Any}(fill(false, W)) for _ in betas]
    for i in 1:F, j in eachindex(betas)
        bitvecs[j][W - F + i] = flip(bitprob(betas[j] / 2^i))
    end
    [DFiP(v) for v in bitvecs]
end

# Slopes of the unit exponentials that unit_gamma consumes for shape α (0.0 marks a
# uniform). For α >= 2 the layout is:
#   [slopes for α - 1 ; 0.0 ; group 1 ; ... ; group α]
# where group k = [β, 0.0 repeated (α - k) times]. α == 1 gives [β, β, 0.0] and
# α == 0 gives an empty vector. The length for α - 1 is α(α^2 + 5)/6.
function exponential_for_gamma(α::Int, β::Float64)::Vector{Float64}
    α >= 0 || throw(ArgumentError("α must be non-negative, got $α"))
    α == 0 && return Float64[]
    α == 1 && return [β, β, 0.0]

    groups = Float64[]
    for k in 1:α
        push!(groups, β)
        append!(groups, zeros(α - k))
    end
    return vcat(exponential_for_gamma(α - 1, β), [0.0], groups)
end

# Mixture constants consumed by unit_gamma, for levels α, α-1, ..., 1 concatenated:
# level k contributes k + 1 entries [p1, p2[1], ..., p2[k]], followed by the constants
# of level k - 1; α == 0 gives an empty vector.
# p1 and p2 combine SymPy polynomial coefficients (in v2) of integrals of
# varint^(α-1) * exp(β*varint) over [v2, v2 + ϵ] with the grid sums from sum_pgp.
# NOTE(review): presumably p1 is the weight of the recursive (α-1) component in
# unit_gamma and p2 the relative weights of its α correction components — confirm.
# NOTE(review): SymPy's Poly.coeffs() omits zero coefficients (all_coeffs() keeps them);
# the power indexing below relies on the v2 factor's zero constant term being the only
# coefficient dropped.
function gamma_constants(α::Int, β::Float64, ϵ::Float64)
    @vars varint
    @vars v2
    if α == 0
        Float64[]
    else
        c1 = Float64(sympy.Poly(integrate(varint^α*exp(β*varint), (varint, 0, 1)), varint).coeffs().evalf()[1])
        lead_coeffs = [Float64(i) for i in sympy.Poly(simplify(v2*integrate(varint^(α-1)*exp(β*varint), (varint, v2, v2 + ϵ))/exp(β*v2)), v2).coeffs()]
        p1 = 0.0
        for (i, c) in enumerate(lead_coeffs)
            p1 += sum_pgp(β, ϵ, length(lead_coeffs) + 1 - i) * c
        end
        p1 /= c1

        shift_coeffs = [Float64(i) for i in sympy.Poly(simplify(v2*integrate(varint^(α-1) * (varint - v2) *exp(β*varint), (varint, v2, v2 + ϵ))/exp(β*v2)), v2).coeffs()]
        # unit_gamma consumes exactly α entries per level; fail here rather than later
        length(shift_coeffs) == α ||
            error("gamma_constants: expected $α polynomial coefficients, got $(length(shift_coeffs))")
        p2 = [sum_pgp(β, ϵ, length(shift_coeffs) - i) * c for (i, c) in enumerate(shift_coeffs)]

        vcat([p1], p2, gamma_constants(α - 1, β, ϵ))
    end
end




# Bit-blasted distribution on the unit interval built from `alpha` rounds of
# conditioning on unit exponentials with slope `beta`.
# NOTE(review): presumably targets a density ∝ x^alpha * exp(beta*x) on [0, 1) — confirm.
#   alpha == 0 : unit_exponential(DistFix{W, F}, beta)
#   alpha == 1 : Y, Z ~ exponential(beta), U uniform; observe(U < Y); then a shared
#                correction coin chooses Z (true) or Y (false)
#   alpha >= 2 : recursive; see the comments below
# The keyword arguments are internal: the top-level call leaves them empty, and the
# recursion uses them to pass precomputed exponentials (`vec_arg`), constants, their
# flips, the per-level discrete selectors and the shared coin (`f`, a 1-element vector).
function unit_gamma(t::Type{DistFix{W, F}}, alpha::Int, beta::Float64; vec_arg=[], constants = [], discrete_bdd=[], constant_flips=[], f=[]) where {W, F}
    DFiP = DistFix{W, F}
    alpha >= 0 || throw(ArgumentError("alpha must be non-negative, got $alpha"))
    alpha == 0 && return unit_exponential(DFiP, beta)

    ϵ = 2.0^(-F)
    # Probability of the correction coin; it depends only on beta and F, so one coin
    # is shared by every recursion level.
    coin_prob = (exp(beta*ϵ)*(beta*ϵ - 1) + 1)*(1 - exp(beta)) / ((1 - exp(beta*ϵ))*(exp(beta) * (beta - 1) + 1))
    # Bits needed by a selector over k alternatives (at least 1). Must be log2: the
    # natural log undercounts the required bits for k >= 5.
    selector_bits(k) = max(ceil(Int, log2(k)), 1)

    if alpha == 1
        coinflip = isempty(f) ? flip(coin_prob) : f[1]
        Y, Z, U = isempty(vec_arg) ? n_unit_exponentials(DFiP, [beta, beta, 0.0]) : vec_arg
        observe(U < Y)
        return ifelse(coinflip, Z, Y)
    end

    α = alpha
    β = beta
    if isempty(vec_arg)
        # Top-level call: create every flip up front.
        discrete_bdd = Vector(undef, α)
        constants = gamma_constants(α, β, ϵ)
        constant_flips = [flip(c) for c in constants]
        shared_coin = [flip(coin_prob)]

        # One discrete selector per level k = α, ..., 1 over that level's k weights.
        offset = 0
        for k in α:-1:1
            weights = normalize(constants[offset + 2:offset + k + 1])
            discrete_bdd[α - k + 1] = discrete(DistUInt{selector_bits(k)}, weights)
            offset += k + 1
        end
        vec_expo = n_unit_exponentials(DFiP, exponential_for_gamma(α, β))
    else
        # Recursive call: reuse what the parent created. `f` already holds the
        # shared coin as a 1-element vector; pass it on without re-wrapping.
        vec_expo = vec_arg
        shared_coin = f
    end

    # Leading entries of vec_expo consumed by the α - 1 level; this equals
    # length(exponential_for_gamma(α - 1, β)) = α(α^2 + 5)/6.
    seq = (α * (α^2 + 5)) ÷ 6
    x1 = unit_gamma(DFiP, α - 1, β, vec_arg=vec_expo[1:seq], constants=constants[α + 2:end], discrete_bdd=discrete_bdd[2:α], constant_flips=constant_flips[α + 2:end], f=shared_coin)
    x2 = vec_expo[seq + 1]
    observe(x2 < x1)

    # Candidate i is an exponential conditioned to lie above α - i uniforms.
    candidates = Vector(undef, α)
    pos = seq + 2
    for i in 1:α
        x = vec_expo[pos]
        pos += 1
        for _ in 1:α - i
            observe(vec_expo[pos] < x)
            pos += 1
        end
        candidates[i] = x
    end

    selector = discrete_bdd[1]
    mixed = DFiP(0.0)
    for i in 1:α
        mixed = ifelse(prob_equals(selector, DistUInt{selector_bits(α)}(i-1)), candidates[i], mixed)
    end

    return ifelse(constant_flips[1], x1, mixed)
end

# Scale the weights in `v` so they sum to one, returning a new array
# (a zero total is not checked for).
function normalize(v)
    mass = sum(v)
    scaled = (w / mass for w in v)
    return collect(scaled)
end
# # TODO: Write tests for the following function
# function unit_concave(t::Type{DistFix{W, F}}, beta::Float64) where {W, F}
#     @assert beta <= 0
#     DFiP = DistFix{W, F}
#     Y = uniform(DFiP, 0.0, 1.0)
#     X = unit_exponential(DFiP, beta)
#     observe((X < Y)| prob_equals(X, Y))
#     Y
# end