tlc-pack/relax

[ARCH] Normal Form of SeqExpr

tqchen opened this issue · 14 comments

Right now we allow a Call to appear in places where it is not bound to a Var or DataflowVar.

Possible locations include: function.body, if.then/else, and SeqExpr.body.

While this normal form is OK, it limits the pass-writing styles we can use, since not every call corresponds to a VarBinding. Pass writers need to consider both Call and VarBinding. It is also a common source of bugs when people write passes assuming every call is bound to a var (e.g., using the Var as a map key).

  • Most passes need to override VisitExpr(Call).
  • When we write passes in terms of VarBindings and vars, there are usually missing cases.

We had a similar discussion before, and the overall summary is that it boils down to the programming style of normal forms. Imperative-style designs like MLIR enforce a structured construct (Region) to indicate a sub-control block, while requiring every call to bind to a Var. See the previous post from @YuchenJin for an accurate summary: #52 (comment)

We did not take action at the time, so that we could write more passes and learn more lessons first.

More recently, as we have started to write more passes on Relax, we have learned lessons indicating that always having a Call bound to a var would be extremely helpful for pass writing: the corresponding var can serve as a key to maps, which enables richer pass-writing styles.

On the Expr side, we have two kinds of expressions:

  • Those that must be bound to vars in ANF:
    • Call and TupleGetItem
  • Those that can appear as call arguments without being bound to a var (LeafExpr):
    • Right now these are Tuple, ShapeExpr, Const, and Var

In this post we discuss a structured normal form for SeqExpr. Specifically:

  • Control-flow sub-scopes, such as function.body and if.then/else, always contain a SeqExpr.
  • SeqExpr.body is always a LeafExpr; any value produced by a call is referenced through a var bound in the SeqExpr's binding blocks.
  • A Call can only appear in a binding block, bound to a var.

This can be enforced through well-formedness checks and further type restrictions.
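To make the three rules concrete, here is a minimal sketch of such a well-formedness check in Python. It runs on a hypothetical toy AST whose class names mirror the nodes discussed here (it is not the Relax API), treats Var, Constant, ShapeExpr, and leaf-only Tuples as LeafExpr, and ignores TupleGetItem and dataflow blocks for brevity.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import List


# Hypothetical toy AST; the names mirror this thread, not the real Relax classes.
@dataclass
class Var:
    name: str

@dataclass
class Constant:
    value: float

@dataclass
class ShapeExpr:
    dims: List[int]

@dataclass
class Tuple:
    fields: list

@dataclass
class Call:
    op: str
    args: list

@dataclass
class VarBinding:
    var: Var
    value: object

@dataclass
class SeqExpr:
    bindings: List[VarBinding]
    body: object

@dataclass
class If:
    cond: object
    then_branch: object
    else_branch: object

@dataclass
class Function:
    params: List[Var]
    body: object


def is_leaf(e) -> bool:
    # LeafExpr: Var, Constant, ShapeExpr, or a Tuple whose fields are all leaves.
    if isinstance(e, Tuple):
        return all(is_leaf(f) for f in e.fields)
    return isinstance(e, (Var, Constant, ShapeExpr))


def check_scope(body, owner: str, errors: List[str]) -> None:
    # Rule 1: control-flow sub-scopes always contain a SeqExpr.
    if not isinstance(body, SeqExpr):
        errors.append(f"{owner} must be a SeqExpr, got {type(body).__name__}")
        return
    for b in body.bindings:
        check_value(b.value, errors)
    # Rule 2: SeqExpr.body is always a leaf expression.
    if not is_leaf(body.body):
        errors.append(f"SeqExpr.body is not a leaf: {type(body.body).__name__}")


def check_value(value, errors: List[str]) -> None:
    # Rule 3: a Call appears only as the value of a binding, with leaf arguments.
    if isinstance(value, Call):
        for a in value.args:
            if not is_leaf(a):
                errors.append(f"argument of {value.op} is not a leaf: {type(a).__name__}")
    elif isinstance(value, If):
        if not is_leaf(value.cond):
            errors.append("If.cond is not a leaf")
        check_scope(value.then_branch, "If.then", errors)
        check_scope(value.else_branch, "If.else", errors)
    elif isinstance(value, Function):
        check_scope(value.body, "Function.body", errors)
    elif not is_leaf(value):
        errors.append(f"unexpected binding value: {type(value).__name__}")


def well_formed(fn: Function) -> List[str]:
    errors: List[str] = []
    check_scope(fn.body, "Function.body", errors)
    return errors


x, y, lv0 = Var("x"), Var("y"), Var("lv0")
good = Function([x, y], SeqExpr([VarBinding(lv0, Call("add", [x, y]))], lv0))
bad_body = Function([x, y], Call("add", [x, y]))        # Call used directly as the body
bad_seq = Function([x], SeqExpr([], Call("relu", [x])))  # Call used as SeqExpr.body
print(well_formed(good))      # []
print(well_formed(bad_body))  # ['Function.body must be a SeqExpr, got Call']
print(well_formed(bad_seq))   # ['SeqExpr.body is not a leaf: Call']
```

Because every Call is forced into a VarBinding, a pass that walks bindings sees every call exactly once, which is the property the proposal is after.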

SeqExpr.body is always a Var that refers to its binding block
Can SeqExpr.body be a Tuple?

And there is another question: is Tuple a leaf node, which can be used as a binding value (RHS), SeqExpr.body, If condition, etc.?

@Hzfengsy this is a great question. I think your point also generalizes to other possible constant nodes, such as ShapeExpr and Const.

The point of consistency here is that Exprs that must be normalized by ANF (right now we call them non-leaf exprs) cannot be put in the body, so they always have a VarBinding. Updated the post to clarify this.

T-ANF: TVM-Enhanced A-Normal Form

Principles. In short, the language is designed to mimic a dataflow graph as much as possible, including:

  • Convenient producer/consumer retrieval
  • Minimized aliasing where possible

Recap: in an NNVM-like dataflow graph, any intermediate tensor could be represented by a (o, i) where o is the operation that produces the tensor, and i means it’s the i-th output of o.

Design. Expressions are categorized into the following types:

  • Leaf nodes:
    • relax.Var
    • relax.Constant
  • RValues: temporary values that are usually used for intermediate parameter passing
    • Tuple => TupleMake(t0, …)
    • TupleGetItem(t, i)
    • ShapeExpr => ShapeMake
  • Non-leaf expressions:
    • Call
    • SeqExpr
    • IfExpr
    • Function

T-ANF. Slightly different from standard ANF, T-ANF requires that:

  • The body of a Function/IfExpr must be a SeqExpr.
  • The body of a SeqExpr must be a leaf node or an RValue.
  • Leaf nodes and RValues are not allowed to be bound to a var (i.e., they never appear as the value of a binding).

Its implications are:

  • If a relax.Var has Tuple type, it must be the return value of a Call or SeqExpr.
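A minimal sketch of how a normalizer could produce T-ANF from a nested expression, assuming a hypothetical toy AST (not the Relax API): every Call is bound to a fresh var (the names lv0, lv1, ... are made up), while Tuple and TupleGetItem stay inline as RValues, so a tuple-typed var can only come from a Call.

```python
import itertools
from dataclasses import dataclass
from typing import List


# Hypothetical toy nodes; frozen so they compare by value.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class TupleExpr:      # RValue: stays inline, never bound
    fields: tuple

@dataclass(frozen=True)
class TupleGetItem:   # RValue: stays inline, never bound
    tup: object
    index: int

@dataclass(frozen=True)
class Call:           # non-leaf: always bound to a fresh var
    op: str
    args: tuple


class Normalizer:
    """Flattens a nested expression into T-ANF bindings plus a leaf/RValue body."""

    def __init__(self) -> None:
        self.bindings: List[tuple] = []  # (Var, Call) pairs in emission order
        self._ids = itertools.count()

    def visit(self, e):
        if isinstance(e, (Var, Const)):
            return e
        if isinstance(e, TupleExpr):
            return TupleExpr(tuple(self.visit(f) for f in e.fields))
        if isinstance(e, TupleGetItem):
            return TupleGetItem(self.visit(e.tup), e.index)
        if isinstance(e, Call):
            call = Call(e.op, tuple(self.visit(a) for a in e.args))
            var = Var(f"lv{next(self._ids)}")
            self.bindings.append((var, call))
            return var
        raise TypeError(f"unsupported expression: {e!r}")


def show(e) -> str:
    if isinstance(e, Var):
        return e.name
    if isinstance(e, Const):
        return repr(e.value)
    if isinstance(e, TupleExpr):
        return "(" + ", ".join(show(f) for f in e.fields) + ")"
    if isinstance(e, TupleGetItem):
        return f"{show(e.tup)}[{e.index}]"
    return f"{e.op}(" + ", ".join(show(a) for a in e.args) + ")"


x, y = Var("x"), Var("y")
# concat((split(x)[0], relu(y))) written as one nested expression
expr = Call("concat", (TupleExpr((TupleGetItem(Call("split", (x,)), 0), Call("relu", (y,)))),))
norm = Normalizer()
body = norm.visit(expr)
for var, call in norm.bindings:
    print(f"{var.name} = {show(call)}")
print("body:", show(body))
# lv0 = split(x)
# lv1 = relu(y)
# lv2 = concat((lv0[0], lv1))
# body: lv2
```

Note how the tuple argument of concat and the TupleGetItem on lv0 survive inline, which is what makes the one-hop lookups below possible.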

One-hop producer-consumer retrieval.

  • producers: given a var binding v = OP(v0, v1, ...), where v0, v1, ... are all leaf nodes or RValues, by looking up what expressions they are bound to, we can find their producers without extra hops.
  • consumers: given a var binding v = OP(v0, v1, ...), i.e. var v is produced by OP, by looking up where v is used, all its consumers can be found without extra hops.
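The one-hop property can be illustrated with two maps keyed by var. This sketch assumes a hypothetical flattened program format (var, op, args), with inline Tuple RValues written as Python tuples; build_maps and used_vars are illustrative helpers, not Relax functions.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical flattened T-ANF program: (bound var, op, args). An arg is a var name,
# or a Python tuple standing for an inline Tuple RValue (tuples are never bound).
program = [
    ("lv0", "relu", ("x",)),
    ("lv1", "exp", ("x",)),
    ("lv2", "concat", (("lv0", "lv1"),)),
    ("lv3", "sum", ("lv2",)),
]


def used_vars(arg):
    """Yields the var names an argument refers to, looking inside inline tuples."""
    if isinstance(arg, tuple):
        for field in arg:
            yield from used_vars(field)
    else:
        yield arg


def build_maps(bindings):
    producer: Dict[str, Tuple[str, tuple]] = {}          # var -> (op, args) defining it
    consumers: Dict[str, List[str]] = defaultdict(list)  # var -> vars whose binding reads it
    for var, op, args in bindings:
        producer[var] = (op, args)
        for arg in args:
            for used in used_vars(arg):
                if var not in consumers[used]:
                    consumers[used].append(var)
    return producer, dict(consumers)


producer, consumers = build_maps(program)
# One hop from lv2 back to the ops that produced its inputs:
print([producer[u][0] for a in producer["lv2"][1] for u in used_vars(a)])  # ['relu', 'exp']
print(consumers["x"])  # ['lv0', 'lv1']
```

Since tuples are never bound, looking through an argument never requires chasing another binding; the only recursion is into the inline tuple itself.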

No-alias guarantee. If Call does not produce any alias to inputs, then there is no alias in the graph.

I agree that calls and TupleGetItem nodes should be considered non-atomic expressions in ANF and should be required to be bound to vars. However, I think that also applies to tuples and they should also be required to be bound to vars. This would ensure that applying ANF would mean that there is no nesting on the RHS of a var binding, which corresponds to the usual notion of ANF.

Note that if we apply canonicalization of bindings like in this pass, this would ensure that all vars used point to the original definition, allowing for pass writers to easily look up binding sites for, e.g., tuples.
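A rough sketch of binding canonicalization, assuming a hypothetical toy representation where a binding value is either another var's name (an alias) or a Call over var names. It drops alias bindings and rewrites each use to the original definition; a real pass would also need to respect distinctions such as DataflowVar versus output Var, which this sketch ignores.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


@dataclass
class Call:
    op: str
    args: List[str]  # argument var names


# A binding value is either another var's name (an alias such as `lv1 = lv0`) or a Call.
Binding = Tuple[str, Union[str, Call]]


def canonicalize(bindings: List[Binding], body: str) -> Tuple[List[Binding], str]:
    """Drops var-to-var bindings and rewrites every use to the original definition."""
    alias: Dict[str, str] = {}

    def resolve(var: str) -> str:
        return alias.get(var, var)  # alias entries are stored already resolved

    out: List[Binding] = []
    for var, value in bindings:
        if isinstance(value, str):
            alias[var] = resolve(value)
        else:
            out.append((var, Call(value.op, [resolve(a) for a in value.args])))
    return out, resolve(body)


before = [
    ("lv0", Call("relu", ["x"])),
    ("lv1", "lv0"),  # alias
    ("lv2", "lv1"),  # alias of an alias
    ("lv3", Call("add", ["lv2", "lv1"])),
]
after, new_body = canonicalize(before, "lv3")
for var, call in after:
    print(f"{var} = {call.op}({', '.join(call.args)})")
# lv0 = relu(x)
# lv3 = add(lv0, lv0)
```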

The case of Tuple is interesting, because we have quite a few intrinsics that leverage tuples in their args, such as call_tir. Requiring that Tuples not be bound to vars might be helpful in our case, so that we can do direct pattern matching without an extra hop.

Could you give some examples of such usage?

For example, we have call_tir(tir_func, arg_tuple, shape), where arg_tuple is a tuple. In fact, we require arg_tuple to be an explicit Tuple that can be unpacked; this makes pattern matching easy, so we can do

auto args = call->args[1].as<TupleNode>();  // nullptr if args[1] is not a Tuple literal

It also brings clarity to printing.

def main_tuple_bound(x, y):
    lv0 = (x, y)
    lv1 = call_tir(tir_f0, lv0, shape)

def main_no_tuple_bound(x, y):
    lv1 = call_tir(tir_f0, (x, y), shape)

Could we just make it a variadic function? It seems inconsistent with the rest of the design to require an argument to not only be of tuple type but also a tuple literal.

Having such a structural grouping is itself helpful. Because arg_tuple is grouped separately from shape, this signature goes beyond what a variadic function could express (a variadic function could not separate arg_tuple from shape).

Here is another motivating intrinsic that we recently encountered in some of the catalyst efforts.

def call_with_explicit_rw(func, read_tensors, write_tensors):
    func(*read_tensors, *write_tensors)

call_with_explicit_rw(add, [x, y], [z])
call_with_explicit_rw(split, [x], [y, z])

Effectively, the semantics of the intrinsic is to call func with read_tensors and write_tensors, which are grouped separately. This read/write information can help us plan parallelization (e.g., across multiple GPU streams).
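As an illustration of how such read/write grouping could feed stream planning, here is a minimal sketch that derives hazards (read-after-write, write-after-read, write-after-write) from explicit read/write sets and assigns each call a level; calls on the same level have no hazard between them and could in principle run on different streams. RWCall, depends, and schedule_levels are hypothetical names, not part of Relax, and a real planner would work on the IR rather than on strings.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class RWCall:
    """One call_with_explicit_rw(func, read_tensors, write_tensors) site (hypothetical model)."""
    func: str
    reads: Tuple[str, ...]
    writes: Tuple[str, ...]


def depends(earlier: RWCall, later: RWCall) -> bool:
    r0, w0 = set(earlier.reads), set(earlier.writes)
    r1, w1 = set(later.reads), set(later.writes)
    # read-after-write, write-after-read, or write-after-write hazard
    return bool(w0 & r1 or r0 & w1 or w0 & w1)


def schedule_levels(calls: List[RWCall]) -> List[int]:
    """Assigns each call a level; calls that share a level have no hazard between them."""
    levels: List[int] = []
    for j, call in enumerate(calls):
        deps = [levels[i] for i in range(j) if depends(calls[i], call)]
        levels.append(1 + max(deps) if deps else 0)
    return levels


calls = [
    RWCall("add", reads=("x", "y"), writes=("z",)),
    RWCall("split", reads=("x",), writes=("a", "b")),
    RWCall("mul", reads=("z", "a"), writes=("c",)),
]
print(schedule_levels(calls))  # [0, 0, 1]: add and split could go on separate streams
```

Without the explicit grouping, the analysis would have to infer from each callee which arguments are written, which is exactly the information the tuple structure carries for free.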

The tuple can be useful to inform such grouping. We might have a similar case for call_tir with tuple return values:

t = call_tir(func, (x,), (shape0, shape1)) 

Here we return a tuple of tensor[shape0] and tensor[shape1]; the tuple groups the shape arguments (to indicate the outputs).
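To spell out the destination-passing reading of call_tir with a tuple of shapes, here is a reference-semantics sketch in plain Python. Tensor, empty, call_tir_reference, and split_halves are hypothetical stand-ins (a list-backed tensor and a Python function in place of a TIR PrimFunc); the sketch only assumes that call_tir allocates one output per shape and passes the outputs after the inputs.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Tensor:
    shape: Tuple[int, ...]
    data: List[float]  # flat, row-major storage


def empty(shape) -> Tensor:
    return Tensor(tuple(shape), [0.0] * math.prod(shape))


def call_tir_reference(func, args, out_shapes):
    """Hypothetical reference semantics: allocate outputs, then call func in
    destination-passing style as func(*args, *outputs)."""
    single = len(out_shapes) == 0 or isinstance(out_shapes[0], int)
    shapes = [out_shapes] if single else list(out_shapes)
    outs = [empty(s) for s in shapes]
    func(*args, *outs)
    return outs[0] if single else tuple(outs)


def split_halves(x: Tensor, lo: Tensor, hi: Tensor) -> None:
    """Stands in for a TIR PrimFunc that writes the two halves of x into lo and hi."""
    n = len(lo.data)
    lo.data[:] = x.data[:n]
    hi.data[:] = x.data[n:]


x = Tensor((4,), [1.0, 2.0, 3.0, 4.0])
# Corresponds to t = call_tir(split_halves, (x,), ((2,), (2,)))
t = call_tir_reference(split_halves, (x,), ((2,), (2,)))
print(t[0].data, t[1].data)  # [1.0, 2.0] [3.0, 4.0]
```

The outer tuple of shapes is what tells the caller how many output buffers to allocate, which a flat argument list could not distinguish from input arguments.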

I see, those are curious applications. I'm not sure it's much of an advantage to have special handling for these cases in ANF when you could just use BlockBuilder::LookupBinding to get the tuple definition (not much more verbose). If we want to require structurally that these must be tuple literals, maybe making this operation a separate AST node would be simpler, as that way, we would not need to have special cases for how we treat the call node depending on the op that's being called.
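The two options can be compared in a small sketch: a helper that reads call_tir's argument fields either directly from a tuple literal or, when the argument is a var, through one extra binding lookup. The dict `bindings` is a hypothetical stand-in for BlockBuilder::LookupBinding, and the AST classes are toy models rather than Relax nodes.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Union


@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Tuple:
    fields: tuple

@dataclass(frozen=True)
class Call:
    op: str
    args: tuple

Expr = Union[Var, Tuple, Call]


def call_tir_arg_fields(call: Call, bindings: Dict[Var, Expr]) -> Optional[tuple]:
    """Returns the fields of call_tir's arg tuple, or None if they are not statically known.

    `bindings` is a hypothetical stand-in for BlockBuilder::LookupBinding."""
    arg = call.args[1]
    if isinstance(arg, Var):  # tuple bound to a var: one extra hop
        arg = bindings.get(arg, arg)
    return arg.fields if isinstance(arg, Tuple) else None


x, y, lv0 = Var("x"), Var("y"), Var("lv0")
inline = Call("call_tir", (Var("tir_f0"), Tuple((x, y)), Var("shape")))
bound = Call("call_tir", (Var("tir_f0"), lv0, Var("shape")))
print(call_tir_arg_fields(inline, {}))                   # direct pattern match
print(call_tir_arg_fields(bound, {lv0: Tuple((x, y))}))  # via the binding map
print(call_tir_arg_fields(bound, {}))                    # None: e.g. lv0 came from a call
```

The extra hop is only a couple of lines; the open question in the thread is whether passes should be allowed to rely on the first form as an invariant.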

Thanks for the discussion. Normally we would like to keep the AST reasonably stable while retaining some flexibility to introduce new intrinsics. Most rewriting still follows the same call convention, so common rewrites that pass through these intrinsics still work fine. The intrinsics carry some extra information, and lowering/analysis can benefit from some op-specific behavior.

Note that in these cases we do not necessarily require tuple literals, but tuple literals are indeed the common case (they also print concisely, with all the context in one place). We can also turn a tuple into a literal during lowering (which is likely the only place where we need to unpack it):

def before():
    t: Tuple[T, T] = func()
    call_tir(func, t, shape)

def after():
    t: Tuple[T, T] = func()
    v0 = t[0]
    v1 = t[1]
    call_tir(func, (v0, v1), shape)
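A minimal sketch of the lowering step from before() to after(), assuming a hypothetical toy AST and that the tuple's arity is known from its Tuple[T, T] annotation (passed in here as a dict). The call_tir result is bound to a var res only so that every call sits in a binding; the helper name unpack_call_tir_args and the callee name tir_func are made up.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class TupleExpr:
    fields: tuple

@dataclass(frozen=True)
class TupleGetItem:
    tup: Var
    index: int

@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


def unpack_call_tir_args(bindings: List[tuple], arity: Dict[Var, int]) -> List[tuple]:
    """Rewrites call_tir(f, t, shape) with a tuple-typed var t into
    v_i = t[i] bindings followed by call_tir(f, (v_0, ..., v_k), shape)."""
    out: List[tuple] = []
    counter = 0
    for var, call in bindings:
        if call.op == "call_tir" and isinstance(call.args[1], Var):
            tup = call.args[1]
            items = []
            for i in range(arity[tup]):  # arity comes from the Tuple[...] annotation
                item = Var(f"v{counter}")
                counter += 1
                out.append((item, TupleGetItem(tup, i)))
                items.append(item)
            call = Call(call.op, (call.args[0], TupleExpr(tuple(items)), *call.args[2:]))
        out.append((var, call))
    return out


t, res = Var("t"), Var("res")
before = [
    (t, Call("func", ())),  # t: Tuple[T, T] = func()
    (res, Call("call_tir", (Var("tir_func"), t, Var("shape")))),
]
after = unpack_call_tir_args(before, {t: 2})
for var, value in after:
    print(var.name, "=", value)
```

Call sites that already pass a tuple literal are left untouched, so the rewrite is only needed at the point where lowering actually has to unpack the tuple.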

Coming back to the complexity brought to pass writing: one thing to note is that every op usually has a fixed type signature, so if we rewrite by detecting ops, we usually know what to expect. E.g., if we look at add(x, y), we know to expect no nesting in x or y. If we look at concat(arg_tuple), we know to expect a tuple (one level of nesting). So most op-specific rewriting passes do not need to worry about handling the un-nesting/nesting cases.

I think having lots of op-specific conventions could actually become more complex for passes to deal with, especially if these become invariants that other passes rely on. For example, it might mean that passes would have to have lots of special cases when handling call nodes because of specific operators. However, if tuples in calls are the only such case, then it might be reasonable to leave only that one (we might have to figure out what to do about tuples nested in tuples, though).

For printing, we might want to have different rules/flags. I think, though, that having a simple and consistent representation would ultimately simplify pass-writing.

Agreed on less special casing, and instead stating things as simple, consistent rules (e.g., tuples can appear in call args, and we even encourage folding tuple-literal bindings) along with general ways to deal with the normal form.

In the current passes (which do support tuples in calls), we have generic, op-independent code that handles tuple rewriting recursively, such as in memory planning and constant folding.

At the same time, this allows quick local pattern matching for specific ops whose signatures we know, when op-specific rules are needed.
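To illustrate what op-independent tuple handling can look like, here is a small sketch: a generic traversal rebuilds nested tuples, applies a callback to leaves, and folds TupleGetItem whenever its operand becomes a tuple literal. It is a toy model, not the memory-planning or constant-folding code mentioned above; the substitution environment is hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class TupleExpr:
    fields: tuple

@dataclass(frozen=True)
class TupleGetItem:
    tup: object
    index: int


def rewrite(expr, on_leaf: Callable):
    """Op-independent traversal: rebuilds nested tuples, applies on_leaf to Var/Const,
    and folds TupleGetItem when its operand turns out to be a tuple literal."""
    if isinstance(expr, TupleExpr):
        return TupleExpr(tuple(rewrite(f, on_leaf) for f in expr.fields))
    if isinstance(expr, TupleGetItem):
        tup = rewrite(expr.tup, on_leaf)
        if isinstance(tup, TupleExpr):
            return tup.fields[expr.index]
        return TupleGetItem(tup, expr.index)
    return on_leaf(expr)


# Hypothetical constant-propagation environment: var a is known to be (1.0, 2.0).
env: Dict[Var, object] = {Var("a"): TupleExpr((Const(1.0), Const(2.0)))}


def on_leaf(e):
    return env.get(e, e) if isinstance(e, Var) else e


result = rewrite(TupleExpr((TupleGetItem(Var("a"), 1), Var("b"))), on_leaf)
print(result)  # TupleExpr(fields=(Const(value=2.0), Var(name='b')))
```

The traversal never inspects which op consumes the tuple, which is what lets one helper serve several passes, while op-specific rules can still match on a known signature directly.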

Addressed by @slyubomirsky in #288.