Variable cost to variant
AzizZayed opened this issue · 10 comments
How do we assign a variable cost to a variant? For example:
(datatype Expr
(Var String)
(Num i64) ; I want to set the cost of a Num to be its input
(Abs Expr)
(Add Expr Expr)
(Mul Expr Expr))
(Num 10) ; I want this to have cost of 10
(Num 55) ; I want this to have cost of 55
It's not currently possible. We had two PRs with different ways of adding custom costs (#355 and #353) but punted on them, because we first wanted a better idea of the space of user needs here.
If you could give a larger example of how you would use this, that would be helpful in designing a solution around it.
Sure, here is my use case: I want to optimize tensor multiplications according to the shape of the tensors. Assume you have 3 tensors
I guess in this case I would need a dimension analysis, then set the cost of each operation using that analysis. Do you have an example with dimension analysis?
Thanks!
> Do you have an example with dimension analysis?
Sort of: https://github.com/egraphs-good/egglog/blob/main/tests/matrix.egg
Let's say I have the following datatype and rewrite rule:
(datatype MatrixOp
(Matrix String i64 i64) ; (<name> <nrows> <ncols>)
(MatMul MatrixOp MatrixOp)
)
(birewrite
(MatMul ?a (MatMul ?b ?c))
(MatMul (MatMul ?a ?b) ?c)
)
; below here I create some ops, run the rewrite rules and extract
...
Using what egglog currently offers, how would I tell the extract command to return the MatMul operation with the fewest multiplications? Here is how to calculate the number of multiplications:
Assume you have 3 tensors: X with shape a×b, Y with shape b×c, and Z with shape c×d. The number of multiplications (which I define as the cost) of (XY)Z is different from that of X(YZ), even though they are equivalent expressions. The cost of (XY)Z is abc+acd, and the cost of X(YZ) is bcd+abd.
If I can't do this with the extract function, how can I achieve this goal with the current egglog commands?
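To make the cost model above concrete, here is a small Python sketch (not egglog) that counts scalar multiplications for the two parenthesizations. The function names `matmul_cost`, `cost_left`, and `cost_right` and the example shapes are hypothetical, introduced only for illustration.

```python
def matmul_cost(rows: int, inner: int, cols: int) -> int:
    """Scalar multiplications for a (rows x inner) times (inner x cols) product."""
    return rows * inner * cols


def cost_left(a: int, b: int, c: int, d: int) -> int:
    """Cost of (XY)Z: XY costs a*b*c, then (XY)Z costs a*c*d."""
    return matmul_cost(a, b, c) + matmul_cost(a, c, d)


def cost_right(a: int, b: int, c: int, d: int) -> int:
    """Cost of X(YZ): YZ costs b*c*d, then X(YZ) costs a*b*d."""
    return matmul_cost(b, c, d) + matmul_cost(a, b, d)


# Assumed example shapes: X is 10x100, Y is 100x5, Z is 5x50.
print(cost_left(10, 100, 5, 50))   # 5000 + 2500 = 7500
print(cost_right(10, 100, 5, 50))  # 25000 + 50000 = 75000
```

For these shapes the two equivalent expressions differ in cost by a factor of ten, which is why a constant per-constructor cost cannot pick the better one.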
The only command that can be used to influence extraction on a per-node basis is subsume: #301 (or, I guess, delete). But I am not sure it's possible to do what you want with those commands; they weren't made for this kind of situation.
So am I understanding correctly that you want to model the cost of a MatMul expression of two matrices a x b and b x c as a * b * c?
So I think with #355 your case could maybe be supported like this?
(datatype MatrixOp
(Matrix String i64 i64) ; (<name> <nrows> <ncols>)
(MatMul MatrixOp MatrixOp)
)
(birewrite
(MatMul ?a (MatMul ?b ?c))
(MatMul (MatMul ?a ?b) ?c)
)
(function nrows (MatrixOp) i64)
(function ncols (MatrixOp) i64)
(rule ((= ?m (Matrix ?s ?r ?c)))
((set (nrows ?m) ?r)
(set (ncols ?m) ?c)))
(rule ((= ?m (MatMul ?a ?b))
(= ?r (nrows ?a))
(= ?k (ncols ?a) (nrows ?b)) ; shared inner dimension
(= ?c (ncols ?b)))
((set (nrows ?m) ?r)
(set (ncols ?m) ?c)
(cost (MatMul ?a ?b) (* (* ?r ?k) ?c)))) ; hypothetical `cost` action in the style of #355
Does that seem right?
> Does that seem right?
Yes, this looks right.
A way to do it with the current commands is maybe a conditional rewrite? So:
(function nrows (MatrixOp) i64)
(function ncols (MatrixOp) i64)
; X(YZ) costs bd(a+c) and (XY)Z costs ac(b+d); union toward the cheaper form
(rule ((= ?m (MatMul ?x (MatMul ?y ?z)))
(= ?a (nrows ?x))
(= ?b (ncols ?x) (nrows ?y))
(= ?c (ncols ?y) (nrows ?z))
(= ?d (ncols ?z))
(> (* (* ?b ?d) (+ ?a ?c)) (* (* ?a ?c) (+ ?b ?d))))
((union ?m (MatMul (MatMul ?x ?y) ?z))))
(rule ((= ?m (MatMul (MatMul ?x ?y) ?z))
(= ?a (nrows ?x))
(= ?b (ncols ?x) (nrows ?y))
(= ?c (ncols ?y) (nrows ?z))
(= ?d (ncols ?z))
(> (* (* ?a ?c) (+ ?b ?d)) (* (* ?b ?d) (+ ?a ?c))))
((union ?m (MatMul ?x (MatMul ?y ?z)))))
Assuming
- cost of $(XY)Z$ is $abc+acd$
- cost of $X(YZ)$ is $bcd+abd$
Then
- rewrite $(XY)Z$ to $X(YZ)$ if $abc+acd > bcd+abd$
- rewrite $X(YZ)$ to $(XY)Z$ if $bcd+abd > abc+acd$
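As a sanity check on the conditions above, the following Python sketch (again not egglog) compares the two costs in the factored form the rules use, $ac(b+d)$ and $bd(a+c)$, and reports which association is cheaper. The function name `cheaper_association` is hypothetical, and the strict comparisons mirror the strict inequalities in the rules, so a tie selects neither direction.

```python
def cheaper_association(a: int, b: int, c: int, d: int) -> str:
    """Pick the cheaper way to multiply X (a x b), Y (b x c), Z (c x d)."""
    left = a * c * (b + d)    # (XY)Z: abc + acd
    right = b * d * (a + c)   # X(YZ): bcd + abd

    # The factored forms must agree with the expanded cost formulas.
    assert left == a * b * c + a * c * d
    assert right == b * c * d + a * b * d

    if left < right:
        return "(XY)Z"
    if right < left:
        return "X(YZ)"
    return "tie"  # neither rewrite should fire


print(cheaper_association(10, 100, 5, 50))  # (XY)Z
print(cheaper_association(50, 5, 100, 10))  # X(YZ)
```

Because at most one of the two strict comparisons can hold, the pair of conditional rules never rewrites back and forth between the two forms.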
Oh yeah, good idea! I think you can do something like that with the :subsume keyword to rewrite, which will end up desugaring to something like the rule you wrote, plus "subsuming" the LHS (meaning that it cannot be extracted or matched again; it's like a permanent delete), but is a bit more succinct to write:
(datatype MatrixOp
(Matrix String i64 i64) ; (<name> <nrows> <ncols>)
(MatMul MatrixOp MatrixOp)
)
(function nrows (MatrixOp) i64)
(function ncols (MatrixOp) i64)
(rule ((= ?m (Matrix ?s ?r ?c)))
((set (nrows ?m) ?r)
(set (ncols ?m) ?c)))
(rule ((= ?m (MatMul ?a ?b))
(= ?r (nrows ?a))
(= ?c (ncols ?b)))
((set (nrows ?m) ?r)
(set (ncols ?m) ?c)))
; a = (nrows ?x), b = (ncols ?x), c = (ncols ?y), d = (ncols ?z)
; rewrite X(YZ) to (XY)Z when abc+acd < bcd+abd
(rewrite
(MatMul ?x (MatMul ?y ?z))
(MatMul (MatMul ?x ?y) ?z)
:when (
(<
(+ (* (nrows ?x) (* (ncols ?x) (ncols ?y))) (* (nrows ?x) (* (ncols ?y) (ncols ?z))))
(+ (* (ncols ?x) (* (ncols ?y) (ncols ?z))) (* (nrows ?x) (* (ncols ?x) (ncols ?z))))
)
)
:subsume
)
; rewrite (XY)Z to X(YZ) when bcd+abd < abc+acd
(rewrite
(MatMul (MatMul ?x ?y) ?z)
(MatMul ?x (MatMul ?y ?z))
:when (
(<
(+ (* (ncols ?x) (* (ncols ?y) (ncols ?z))) (* (nrows ?x) (* (ncols ?x) (ncols ?z))))
(+ (* (nrows ?x) (* (ncols ?x) (ncols ?y))) (* (nrows ?x) (* (ncols ?y) (ncols ?z))))
)
)
:subsume
)
Closing as duplicate of #294