ezyang/pytorch-unattached

Proposal: Modifying behavior of select nodes.


This relates to and would address #231.

Current state:

  • Select nodes are created and inserted into the graph's node list manually. We expect everyone to maintain the "select-node invariant": if a node returns N items, it always has N select nodes, and its use list is ordered from output 0 to output N-1.
  • Anything that deals with modifying the graph has to handle both the single and multi-return case separately, and must maintain the invariant.
  • Viewing the graph is fine: we provide an outputs() function that always returns the correct list of outputs, returning the node itself if it is single-return.
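Here is a minimal Python sketch of the current invariant. It uses a toy IR, and the names `Node`, `Graph`, and `check_select_invariant` are hypothetical stand-ins rather than the real C++ API. It shows that selects are appended by hand after a multi-return node, and that `outputs()` returns either those selects or the node itself.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)  # eq=False keeps identity comparison, like pointers
class Node:
    """Toy model of the current IR (hypothetical names, not the C++ API)."""
    kind: str
    inputs: List["Node"] = field(default_factory=list)
    num_outputs: int = 1   # how many values this node returns
    offset: int = 0        # for Select nodes: which output is selected
    uses: List["Node"] = field(default_factory=list)

    def has_multiple_outputs(self) -> bool:
        return self.num_outputs != 1

    def outputs(self) -> List["Node"]:
        # A single-return node is its own output; a multi-return node is
        # viewed through its select nodes, which are exactly its uses.
        return list(self.uses) if self.has_multiple_outputs() else [self]


class Graph:
    def __init__(self) -> None:
        self.nodes: List[Node] = []

    def append(self, node: Node) -> Node:
        # Selects are appended by hand, like any other node.
        for value in node.inputs:
            value.uses.append(node)
        self.nodes.append(node)
        return node


def check_select_invariant(node: Node) -> bool:
    """A node returning N > 1 items has exactly N Select uses, ordered 0..N-1."""
    if not node.has_multiple_outputs():
        return True
    return len(node.uses) == node.num_outputs and all(
        use.kind == "Select" and use.offset == i for i, use in enumerate(node.uses)
    )


g = Graph()
x = Node("Param")  # params are not kept in the node list
split = g.append(Node("Split", [x], num_outputs=2))
# Every caller that creates a multi-return node must add its selects, in order.
s0 = g.append(Node("Select", [split], offset=0))
s1 = g.append(Node("Select", [split], offset=1))
assert check_select_invariant(split)
```

Nothing in `Graph.append` enforces the invariant, so a caller that forgets the selects leaves the graph in a state that only a separate check would catch.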

Proposed state:

  • We modify the base version of create to take an optional argument giving the number of outputs (like output= in the graph.op wrapper in Python). If it is not 1, we initialize the select nodes when we create the op and mark the op as multi-return.
  • replaceAllUsesWith and destroy replace (or destroy) both the real node and all its selects.
  • Select nodes, like Param and Return, are not put in the nodes() list, for two reasons: (1) I audited every current use of nodes(), and each pass either explicitly filters out selects or would be semantically equivalent, with no code changes, if selects were removed from the list. (2) It simplifies creating and deleting a node with multiple returns, since we don't have to add all of the newly created select nodes to the node list.
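A minimal sketch of the proposed behavior, again using a hypothetical toy IR in Python rather than the real C++ classes: `create` takes a `num_outputs` argument and allocates the selects itself, selects (like params) stay out of `nodes`, and `replace_all_uses_with` / `destroy` treat a node and its selects as one unit. `Return` is not modeled.

```python
from typing import List


class Node:
    """Toy IR node (hypothetical; not the real torch.jit C++ API)."""

    def __init__(self, graph: "Graph", kind: str, inputs: List["Node"], offset: int = 0) -> None:
        self.graph = graph
        self.kind = kind
        self.inputs = list(inputs)
        self.uses: List["Node"] = []
        self.offset = offset              # meaningful only for Select nodes
        self.selects: List["Node"] = []   # filled in by Graph.create for multi-return ops
        for value in self.inputs:
            value.uses.append(self)

    def outputs(self) -> List["Node"]:
        return list(self.selects) if self.selects else [self]

    def replace_all_uses_with(self, other: "Node") -> None:
        # Redirect users of each output (the node itself, or each of its
        # selects) to the corresponding output of `other`.
        assert len(self.outputs()) == len(other.outputs())
        for old, new in zip(self.outputs(), other.outputs()):
            for user in old.uses:
                user.inputs = [new if v is old else v for v in user.inputs]
                new.uses.append(user)
            old.uses = []

    def destroy(self) -> None:
        # The node and its selects go away together.
        assert all(not out.uses for out in self.outputs()), "outputs still in use"
        for value in self.inputs:
            value.uses.remove(self)
        self.selects = []
        self.graph.nodes.remove(self)


class Graph:
    def __init__(self) -> None:
        self.nodes: List[Node] = []   # topological order; Param/Select are kept out
        self.params: List[Node] = []

    def add_param(self) -> Node:
        param = Node(self, "Param", [])
        self.params.append(param)
        return param

    def create(self, kind: str, inputs: List[Node], num_outputs: int = 1) -> Node:
        node = Node(self, kind, inputs)
        if num_outputs != 1:
            # Selects are allocated with the op and never added to self.nodes.
            node.selects = [Node(self, "Select", [node], offset=i) for i in range(num_outputs)]
        return node

    def append(self, node: Node) -> Node:
        self.nodes.append(node)
        return node


g = Graph()
x = g.add_param()
split = g.append(g.create("Split", [x], num_outputs=2))
chunk = g.append(g.create("Chunk", [x], num_outputs=2))
add = g.append(g.create("Add", split.outputs()))
split.replace_all_uses_with(chunk)  # same call as for a single-return node
split.destroy()
```

With this shape, replacing or deleting a multi-return node reads exactly like the single-return case.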

I like this change because it lets you create, replace, and delete multi-return nodes the same way you would single-return nodes, without being a substantial change to make. It also brings the IR closer to Toffee IR (what is in the nodes list is what we export to Toffee).

I agree that these changes make graph manipulation easier, because we can treat multi-return nodes the same way as single-return nodes. Manually maintaining "invariants" is error-prone.

I think this is a good direction; I never liked the select invariant. @houseroad, for context, the diff that added select nodes was 1325fa5. The original IR did multiple outputs in a functional/AST style. Multiple outputs for the mutable/LLVM style IR were considered, but @zdevito stated that he had previously worked with IRs with multiple outputs, and invariably refactored them to have single outputs with select nodes, so we went straight for the select node design.

I wonder if the proposed refactor goes far enough. First off, I think in the current IR there is a clear division between "single-return" syntax and "multi-return" syntax. Single-return things are Select and Param (which makes sense, since in the functional interpretation Select and Param are both binding constructs). Multi-return syntax is all operators; and in the ONNX-ified world, there are literally no operators which are not multi-return. So it seems to me that we should say that EVERYTHING (besides parameters and selects) is multi-return. If we buy into this argument, then I think it makes sense to statically distinguish parameters and selects from computation nodes.

One counterargument for keeping the type distinction here collapsed is that Node contains useful utility functions that apply in both cases. But I do not think this is true. Let us enumerate the operations nodes support, and their relevance in the single-return and multi-return cases:

  • Intrusive linked list, and topological manipulation functions (multiple return only, esp. under the proposal)
  • Kind (both)
  • Type (single return only)
  • Debug name (both?)
  • Owning graph (both)
  • Unique (single return only; uniques for multiple return nodes are never rendered except in uses information; more on this later)
  • Stage (multiple return; an invariant we don't enforce but probably should is that the stages of all select nodes coincide with the multiple output node)
  • Inputs (multiple return; in single return the input is degenerate and is either empty, or always points to the multiple output node)
  • Usage info (single return only; the uses of a multiple return node is always degenerate, and is exactly the select nodes)

You can see that except for some very basic, trivial metadata (kind, debug name, owning graph), all other nontrivial properties are distinct in both cases.

If this argument is still not convincing enough, we can count the number of occurrences of JIT_ASSERT(node->hasMultipleOutputs()) or its inverse; it seems clear to me that the division here is statically justified.

So, I would suggest the following refactor:

  • Assert a new invariant that all computation nodes (Node) are always multireturn.
  • To do this and still support Param/Select, we must split Node into two classes, Node (representing computations) and Select (representing selects), and make Param a subclass of Select (where the multi-return node is left null). Divide the methods between these two classes as makes sense.
  • create no longer has a special case for the 1 output case. It always uniformly allocates select nodes.
  • The remaining changes happen automatically from the representation change.
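The split could look roughly like the following Python sketch. It assumes one possible division of the enumerated properties, in which kind and inputs stay on `Node` while type and uses move to `Select`; the class names follow the proposal, but the fields and constructor signatures are hypothetical.

```python
from typing import List, Optional


class Select:
    """A value: one output of a computation (hypothetical sketch)."""

    def __init__(self, node: Optional["Node"], offset: int, type_: str = "Dynamic") -> None:
        self.node = node      # producing computation; None for a Param
        self.offset = offset  # which output of `node` this value is
        self.type = type_     # type and uses live on values, not computations
        self.uses: List["Node"] = []


class Param(Select):
    """A graph input is a Select whose producing node is left null."""

    def __init__(self, offset: int, type_: str = "Dynamic") -> None:
        super().__init__(None, offset, type_)


class Node:
    """A computation: always multi-return under the proposed invariant."""

    def __init__(self, kind: str, inputs: List[Select], num_outputs: int) -> None:
        self.kind = kind
        self.inputs = list(inputs)  # inputs are values, never Nodes
        for value in self.inputs:
            value.uses.append(self)
        # No special case for a single output: selects are always allocated.
        self.outputs = [Select(self, i) for i in range(num_outputs)]


x = Param(0)
add = Node("Add", [x, x], num_outputs=1)
(y,) = add.outputs  # even a one-output Add is reached through a Select
```

Because `Node.inputs` only ever holds `Select` values in this sketch, code that receives a `Select` knows statically that it has a single value and needs no `hasMultipleOutputs()` test.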

Thanks for reading.

What I proposed was a tiny change that will mostly delete code. What is proposed here is a much larger change (new class hierarchies, refactor almost all uses of Node in the codebase!) that will be more complicated and require a larger adjustment in how we use the code. Why don't we do the simple one first and if we are still unhappy with it after trying it out we can make further changes.

Aside: I have dealt with APIs where everything is multi-return by default. All those hasMultipleOutputs() calls turn into hasSingleOutput() calls and move into the places where you need to know you have a single output (e.g., for an optimization to be valid).

> Why don't we do the simple one first and if we are still unhappy with it after trying it out we can make further changes.

For the record, I am fine with this.

Since my comment above may read as "Someone suggested a small change, and someone else asked them to do a huge refactor instead," let me explain in more words why I wrote "I wonder if the proposed refactor goes far enough."

One thing that I strongly believe in is that core APIs (e.g., the compiler IR) should have a coherent "design". A design goes beyond how the code is actually implemented, because it dictates what further changes/enhancements are acceptable versus not acceptable. It also is less than how the code is implemented, because there may be multiple valid implementations of a design.

I don't have an objection to the implementation strategy, but what I observed is that the implementation seems to leave some of the design intent ambiguous. Is it possible that in the future we will add further node types which are omitted from the topological order? Is it possible that a node in the future will take a multi-return node directly as an input, rather than by going through the selects? Is it possible to create and use a select node beyond the automatic creation that occurs when we return multiple items?

Yes, I suggested an alternate implementation because I would like to see it done (and I earnestly believe that it truly will not require an adjustment in how we use the code, because it simply codifies what I understand to be the implicit design currently diffused through how we use the IR in the codebase). But I have also suggested it because that implementation articulates a more opinionated design than the original implementation strategy. If this design vision diverges from yours, that is a big deal, and we should talk about it! But if it matches, then it is fine to take a less specific implementation; we just have to remember to assess further changes against this design. Just because a change is convenient doesn't mean it contributes to a coherent design.

> Aside: I have dealt with APIs where everything is multi-return by default. All those hasMultipleOutputs() calls turn into hasSingleOutput() calls and move into the places where you need to know you have a single output (e.g., for an optimization to be valid).

Yes, I don't think there is any way around this: sometimes you will have a node and legitimately not know whether it is single-return or multi-return, and then you will have to do the test. What I am advocating is that we avoid the test when the result is already known (because, e.g., we're dealing with a select node), as it is in many situations in our code today.

And indeed, I was missing something.

> Multi-return syntax is all operators; and in the ONNX-ified world, there are literally no operators which are not multi-return.

In fact, this is not true. This symbolic definition (and others like it) exercises the single-return, non-Param/Select case:

    @staticmethod
    def symbolic(g, a, b, inplace=False):
        return g.appendNode(g.create("Add", [a, b]))

No select node is created here, so the ONNX-ified Add node is in fact a single-return node.

(How did I discover this? I was attempting to finish addressing comments in #229 when I noticed that tryToMoveChunk was relying on producer_for_chunk being a single-return node (rather than a select node).)

EDIT: Also, interestingly enough, the test for kFusionGroup here:

  bool isFusable(Node * node) {
    if (!node->hasType()) return false;
    if (node->kind() == kFusionGroup) return true;
    return isSimpleMap(node) && isCuda(node);
  }

is dead as far as tryToMoveChunk is concerned: any fusion-group producer_for_chunk will have an intervening select node, but tryToMoveChunk is special-cased for single-return nodes only and will thus fail to recognize it (even if the fusion group in question has only a single output).
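One hypothetical way for a pass like tryToMoveChunk to recognize a fusion group is to look through the intervening select before testing the producer. The Python sketch below uses toy stand-ins (`Node`, `is_fusable`, `producer_for_chunk`) rather than the real C++ functions, and it assumes that moving the chunk is only safe when the producer has a single output.

```python
from typing import List, Optional


class Node:
    """Toy IR node in the current style (hypothetical, not the C++ API)."""

    def __init__(self, kind: str, inputs: Optional[List["Node"]] = None, num_outputs: int = 1) -> None:
        self.kind = kind
        self.inputs = list(inputs or [])
        self.num_outputs = num_outputs


def is_fusable(node: Node) -> bool:
    # Stand-in for isFusable(); the real check also looks at types and devices.
    return node.kind in ("FusionGroup", "Add", "Mul")


def producer_for_chunk(chunk: Node) -> Optional[Node]:
    """Find the node computing a Chunk's input, looking through a Select.

    A fusion group is multi-return in this IR, so its output reaches the
    chunk through a Select even when the group has a single output.
    """
    value = chunk.inputs[0]
    if value.kind != "Select":
        return value  # single-return producer: the case handled today
    producer = value.inputs[0]
    # Only accept producers with one output; with several, moving the chunk
    # would also change outputs the chunk does not consume.
    return producer if producer.num_outputs == 1 else None
```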

Ok, cool. I generally understand the design you want to express with the different class hierarchies (one for operators, one for values), and agree with it from a design perspective. We can revisit how we encode it when new things pop up. For now, I also like the 'view' of this design that encodes the single-return case without an explicit select node: it makes single return the default, and multi-return the thing you have to check for.

In the tryToMoveChunk case, not handling multi-return nodes is actually intentional (in hindsight, it probably deserved a comment). Pushing a chunk into producer_for_chunk changes all of its outputs, and making sure this doesn't break the multi-return case (where one of the outputs might not be consumed by the chunk) is often harder, so I didn't want to deal with it until we see that it is necessary. This happens a lot in pattern-based simplifications: the single-return case is easy to handle, and the multi-return case is harder. So rather than pepper our simple cases with checks that they are indeed single-return, I'd rather, for now, push all the complex handling into the cases where multi-returns exist.
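A sketch of what "single return by default" can mean for a simple pattern-based rewrite, using a hypothetical toy `Node` in Python and a made-up `Neg(Neg(x))` peephole: because a multi-return node's outputs are reached through Select nodes, matching on the input's kind already rules out multi-return producers, so the simple case needs no explicit check.

```python
from typing import List, Optional


class Node:
    """Toy IR node (hypothetical). Outputs of multi-return nodes are Selects."""

    def __init__(self, kind: str, inputs: Optional[List["Node"]] = None) -> None:
        self.kind = kind
        self.inputs = list(inputs or [])


def simplify_double_neg(node: Node) -> Optional[Node]:
    """Peephole: Neg(Neg(x)) -> x. Returns the replacement value, or None."""
    if node.kind != "Neg":
        return None
    inner = node.inputs[0]
    # A value from a multi-return node appears as a Select, so this kind test
    # already rejects it: the simple pattern needs no hasMultipleOutputs() check.
    if inner.kind != "Neg":
        return None
    return inner.inputs[0]
```

Only a rewrite that accepts arbitrary producers (as moving a chunk would) has to look through selects and handle the multi-return case explicitly.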