Dynamic rule for egraph has weird behaviour with push!/append! operations on vectors
overshiki opened this issue · 1 comments
Hi,
I tried to use a dynamic rule in an egraph to merge several items into a vector, but I got a weird result. More specifically, it seems that the equality saturation steps repeatedly push the same items into the vector, so I end up with many repeated items in one vector, which is not the result I want. The code to reproduce the problem is below:
using Metatheory
using Metatheory.Library: @right_associative, @left_associative

abstract type Param end
struct Posi <: Param end
struct Nega <: Param end

struct Model
    param::Param
    index::Symbol
end

struct ModelChain
    models::Vector{Model}
    index::Symbol
end

# Guard for the dynamic rule: same index, and at least one side is a plain Model.
function is_to_ModelChain(x::Union{Model, ModelChain}, y::Union{Model, ModelChain})
    return x.index == y.index && (x isa Model || y isa Model)
end

function to_ModelChain(x::Model, y::Model)
    ms = Vector{Model}([x, y])
    mc = ModelChain(ms, x.index)
    return :($(mc))
end

function to_ModelChain(x::ModelChain, y::Model)
    ms = x.models
    push!(ms, y)
    mc = ModelChain(ms, x.index)
    return :($(mc))
end

function to_ModelChain(x::Model, y::ModelChain)
    return to_ModelChain(y, x)
end

function egraph_rules()
    v = AbstractRule[]
    t = @theory x y begin
        x::Union{Model, ModelChain} + y::Union{Model, ModelChain} => to_ModelChain(x, y) where is_to_ModelChain(x, y)
    end
    append!(v, t)
    ra = @right_associative (+)
    la = @left_associative (+)
    push!(v, ra)
    push!(v, la)
    return v
end

function egraph_rewriter(circ, v)
    g = EGraph(circ)
    params = SaturationParams(timeout=100, eclasslimit=40000)
    report = saturate!(g, v, params)
    circ = extract!(g, astsize)
    return circ
end

import Base: +

# Build a flat n-ary `+` call expression from the operands.
function (+)(a::Expr, b::Expr)
    circ = :((+)())
    append!(circ.args, a.args[2:end])
    append!(circ.args, b.args[2:end])
    return circ
end

function (+)(a::Expr, b::Model)
    circ = :((+)())
    append!(circ.args, a.args[2:end])
    push!(circ.args, b)
    return circ
end

posi1 = Model(Posi(), :one)
nega1 = Model(Nega(), :one)
posi2 = Model(Posi(), :two)
nega2 = Model(Nega(), :two)

expr = :((+)()) + posi1 + nega1 + nega1 + posi2
v = egraph_rules()
nexpr = egraph_rewriter(expr, v)
@show nexpr
As you can see, the `ModelChain` has a vector that holds multiple `Model`s, and each time `to_ModelChain` is called, it rewrites `model1+model2` into `ModelChain([model1, model2], model1.index)`. I expected `posi1+nega1+nega1` to be rewritten as a single `ModelChain` containing 3 `Model` terms. The resulting term should contain at most 4 `Model` terms. However, the result looks like this:
nexpr = :(ModelChain(Model[Model(Posi(), :one), Model(Nega(), :one), Model(Nega(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), Model(Posi(), :one), Model(Nega(), :one), ...
The terms in the `ModelChain` just exploded!
So what's the cause of this problem, and how could I solve it?
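The repeated items are consistent with array aliasing: in Julia, `ms = x.models` binds a second name to the same array, so `push!(ms, y)` also changes the `ModelChain` that was matched. The following is a minimal sketch of that effect without Metatheory. `Chain` and `extend_aliased` are hypothetical stand-ins for `ModelChain` and `to_ModelChain`, and the loop only imitates a rule firing several times on the same value; it is not how the saturation loop is actually implemented.

```julia
# Hypothetical stand-in for ModelChain, holding plain integers.
struct Chain
    items::Vector{Int}
end

# Mirrors the method in the issue: `ms` is the same array as `x.items`.
function extend_aliased(x::Chain, y::Int)
    ms = x.items
    push!(ms, y)          # mutates x.items as well
    return Chain(ms)
end

c = Chain([1, 2])

# Imitate the rule being applied three times to the same matched value.
for _ in 1:3
    extend_aliased(c, 3)
end

@show c.items             # the input itself has grown on every call
```

Each call leaves one more element in the original vector, which matches the kind of growth seen in the output above.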
Hi,
I just found a simple solution to the problem: make a copy of the vector before applying the `push!` or `append!` operation. Specifically, modify the `to_ModelChain(x::ModelChain, y::Model)` function as below:
function to_ModelChain(x::ModelChain, y::Model)
    ms = copy(x.models)
    push!(ms, y)
    mc = ModelChain(ms, x.index)
    return :($(mc))
end
This gives the correct result:
nexpr = :(ModelChain(Model[Model(Posi(), :one), Model(Nega(), :one), Model(Nega(), :one)], :one) + Model(Posi(), :two))
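The copy-based fix can also be written without any mutation, which makes the rule's right-hand side a pure function of its inputs. The sketch below uses `vcat` to build a fresh vector. `Chain` and `extend_copied` are hypothetical stand-ins for `ModelChain` and `to_ModelChain`, and the repeated calls only imitate a rule firing more than once on the same match.

```julia
# Hypothetical stand-in for ModelChain, holding plain integers.
struct Chain
    items::Vector{Int}
end

# Non-mutating variant: vcat allocates a new vector, so x.items is untouched.
extend_copied(x::Chain, y::Int) = Chain(vcat(x.items, y))

c = Chain([1, 2])

# Imitate the rule being applied three times to the same matched value.
results = [extend_copied(c, 3) for _ in 1:3]

@show c.items             # unchanged input
@show results[1].items    # every application yields the same new value
```

Because the input is never modified, applying the function repeatedly to the same match produces the same result each time, which is the property a dynamic rule needs when a match can be visited more than once.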