JuliaSymbolics/Metatheory.jl

Dynamic rule for egraph shows weird behaviour with push!/append! operations on vectors

overshiki opened this issue · 1 comments

Hi,
I tried to use a dynamic rule in an egraph to merge several items into a vector, but I got a weird result. More specifically, the equality saturation steps seem to push the same items into the vector repeatedly, so I end up with many repeated items in one vector, which is not the result I want. The code to reproduce the problem is below:

using Metatheory
using Metatheory.Library: @right_associative, @left_associative

abstract type Param end

struct Posi <: Param end
struct Nega <: Param end

struct Model
    param::Param
    index::Symbol
end

struct ModelChain 
    models::Vector{Model}
    index::Symbol
end

function is_to_ModelChain(x::Union{Model, ModelChain}, y::Union{Model, ModelChain})
    return x.index==y.index && (x isa Model || y isa Model)
end

function to_ModelChain(x::Model, y::Model)
    ms = Vector{Model}([x, y])
    mc = ModelChain(ms, x.index)
    return :($(mc))
end

function to_ModelChain(x::ModelChain, y::Model)
    ms = x.models
    push!(ms, y)
    mc = ModelChain(ms, x.index)
    return :($(mc))
end

function to_ModelChain(x::Model, y::ModelChain)
    return to_ModelChain(y, x)
end


function egraph_rules()
    v = AbstractRule[]
    t = @theory x y begin
        x::Union{Model, ModelChain} + y::Union{Model, ModelChain} => to_ModelChain(x, y) where is_to_ModelChain(x, y)
    end
    append!(v, t)
    ra = @right_associative (+)
    la = @left_associative (+)
    push!(v, ra)
    push!(v, la)
    return v 
end

function egraph_rewriter(circ, v)
    g = EGraph(circ)
    params = SaturationParams(timeout=100, eclasslimit=40000)
    report = saturate!(g, v, params)
    circ = extract!(g, astsize)
    return circ
end 

import Base: +
function (+)(a::Expr, b::Expr)
    circ = :((+)())
    append!(circ.args, a.args[2:end])
    append!(circ.args, b.args[2:end])
    return circ
end

function (+)(a::Expr, b::Model)
    circ = :((+)())
    append!(circ.args, a.args[2:end])
    push!(circ.args, b)
    return circ
end

posi1 = Model(Posi(), :one)
nega1 = Model(Nega(), :one)
posi2 = Model(Posi(), :two)
nega2 = Model(Nega(), :two)

expr = :((+)()) + posi1 + nega1 + nega1 + posi2
v = egraph_rules()
nexpr = egraph_rewriter(expr, v)
@show nexpr

As you can see, ModelChain has a vector that holds multiple Model items, and each time to_ModelChain is called, it rewrites model1+model2 into ModelChain([model1, model2], model1.index). I expected posi1+nega1+nega1 to be rewritten as a single ModelChain containing 3 Model items. The resulting term should contain at most 4 Model items. However, the result looks like this:

nexpr = :(ModelChain(Model[Model(Posi(), :one), Model(Nega(), :one), Model(Nega(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), ...

The terms in the ModelChain just exploded!
So what's the cause of this problem, and how could I solve it?

Hi,
I just found a simple solution to the problem: make a copy of the vector before applying push! or append!. Specifically, modify the to_ModelChain(x::ModelChain, y::Model) function as below:

function to_ModelChain(x::ModelChain, y::Model)
    ms = copy(x.models)
    push!(ms, y)
    mc = ModelChain(ms, x.index)
    return :($(mc))
end

This gives the correct result:

nexpr = :(ModelChain(Model[Model(Posi(), :one), Model(Nega(), :one), Model(Nega(), :one)], :one) + Model(Posi(), :two))
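The thread does not say why copying helps. One plausible reading, stated here as an assumption rather than confirmed Metatheory.jl behaviour, is that the right-hand side of the dynamic rule runs more than once against the same matched ModelChain object during saturation, so each push! grows the one shared vector. The sketch below reproduces that aliasing effect in plain Julia, without Metatheory. Chain, extend_mutating and extend_copying are hypothetical stand-ins for ModelChain and the two versions of to_ModelChain.

```julia
# Hypothetical stand-in for ModelChain: an immutable struct holding a mutable vector.
struct Chain
    items::Vector{Int}
end

# Mirrors the original to_ModelChain: reuses the input's vector, so push! mutates the input.
function extend_mutating(c::Chain, x::Int)
    items = c.items          # same array object, not a copy
    push!(items, x)
    return Chain(items)
end

# Mirrors the fixed to_ModelChain: works on a copy, so the input is left untouched.
function extend_copying(c::Chain, x::Int)
    items = copy(c.items)
    push!(items, x)
    return Chain(items)
end

# Simulate the same match being rewritten three times (an assumption about saturation).
base_mut = Chain([1, 2])
for _ in 1:3
    extend_mutating(base_mut, 3)
end

base_copy = Chain([1, 2])
results = [extend_copying(base_copy, 3) for _ in 1:3]

@show base_mut.items    # the input itself grew to [1, 2, 3, 3, 3]
@show base_copy.items   # the input is still [1, 2]
@show results[1].items  # each result is [1, 2, 3]
```

With the mutating version, every repeated call appends to the input again. With the copying version, repeated calls always produce the same result, which is what a rewrite rule applied many times needs.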