tensorflow/mlir-hlo

while loop simplifier missing

lipracer opened this issue · 11 comments

The XLA service has a while_loop_simplifier pass.
Loop-invariant code motion already exists for MHLO.
But the following optimizations are missing:

  • Eliminate loops with a zero trip count:
while (operands) {
  cond(operands) {
    return false
  }
  body(operands) {
  }
}
    Replace with: operands
  • Loops with a trip count of one:
    inline the while's body.
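The two rewrites above can be sketched on a toy model. This is only an illustration in Python, not MHLO or XLA code: `ToyWhile` and `rewrite_known_trip_count` are hypothetical names, and the trip count is assumed to have been computed by some earlier analysis.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

Values = Tuple[int, ...]


@dataclass
class ToyWhile:
    """Hypothetical stand-in for a while op (not the mhlo::WhileOp API)."""
    operands: Values
    body: Callable[[Values], Values]


def rewrite_known_trip_count(loop: ToyWhile, trip_count: Optional[int]) -> Optional[Values]:
    """Return replacement values for the loop results, or None to keep the loop.

    `trip_count` is assumed to come from an earlier analysis; None means unknown.
    """
    if trip_count == 0:
        # The body never runs, so the results are just the init operands.
        return loop.operands
    if trip_count == 1:
        # The body runs exactly once, so it can be inlined in place of the loop.
        return loop.body(loop.operands)
    return None


loop = ToyWhile(operands=(0, 10), body=lambda v: (v[0] + 1, v[1]))
print(rewrite_known_trip_count(loop, 0))     # (0, 10)
print(rewrite_known_trip_count(loop, 1))     # (1, 10)
print(rewrite_known_trip_count(loop, None))  # None
```

In a real pass the "replacement values" would be SSA values and the one-trip case would move the body's ops before the loop; the toy only shows which values the results are replaced with.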

Currently mhlo::WhileOp is only lowered to scf::WhileOp, and I'm not sure whether scf::WhileOp has a similar simplification.
If mhlo::WhileOp needs to support this functionality, it may need to simulate executing the WhileOp's body, or use data flow analysis to implement constant propagation that simulates the body's execution.
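As a rough illustration of the "simulate the body" option, here is a minimal Python sketch that runs the condition and body on constant operands with an iteration budget. The function name `simulate_trip_count` and the `max_trips` parameter are assumptions made for this example; nothing here is taken from XLA or MLIR.

```python
from typing import Callable, Optional, Tuple

Values = Tuple[int, ...]


def simulate_trip_count(
    operands: Values,
    cond: Callable[[Values], bool],
    body: Callable[[Values], Values],
    max_trips: int = 1,
) -> Optional[int]:
    """Evaluate cond/body on constant operands; give up beyond `max_trips` trips."""
    state = operands
    for trips in range(max_trips + 1):
        if not cond(state):
            return trips  # The condition failed after `trips` body executions.
        state = body(state)
    return None  # The loop runs more than `max_trips` times: treat as unknown.


def cond(v: Values) -> bool:
    return v[0] < v[1]


def body(v: Values) -> Values:
    return (v[0] + 1, v[1])


print(simulate_trip_count((5, 5), cond, body))  # 0
print(simulate_trip_count((4, 5), cond, body))  # 1
print(simulate_trip_count((0, 5), cond, body))  # None
```

The budget keeps the analysis cheap: for the two simplifications requested here it is enough to know whether the trip count is 0, 1, or "something else".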

I recently changed scf.while to be simplifiable by SCCP, after fixing a few bugs with getRegionSuccessors and how that function was getting used. The same can be done for mhlo.while.

Thanks a lot. I will use that as a reference and then implement the simplification for mhlo::while.

Isn't a folder for the WhileOp when the return of the condition is false enough here?

Yeah. First we need to propagate the constants into the cond region, then try to fold it.

while(%cst0, %cst1) {
  cond (%arg0, %arg1) {
    return compare(%arg0, %arg1)  
  }
}

Also, mhlo-sink-constants-to-control-flow only clones constant ops into the body; it does not do the same for constants passed as arguments.
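To make the "propagate, then fold" step concrete, here is a small Python sketch over a toy expression tree. All class and function names are hypothetical, and the comparison direction (`"LT"`) is an assumption, since the snippet above does not specify one. It folds only the entry evaluation of the condition, which is what the zero-trip case needs; later iterations would see values produced by the body.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union


@dataclass(frozen=True)
class Arg:
    index: int  # Position of a cond-region block argument.


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Compare:
    direction: str  # Assumed directions for this sketch: "LT", "EQ", "NE".
    lhs: Union[Arg, Const]
    rhs: Union[Arg, Const]


_DIRECTIONS = {
    "LT": lambda a, b: a < b,
    "EQ": lambda a, b: a == b,
    "NE": lambda a, b: a != b,
}


def fold_entry_condition(cond: Compare, init: Sequence[Optional[int]]) -> Optional[bool]:
    """Fold cond for the first iteration; init[i] is None if operand i is not constant."""
    def resolve(node: Union[Arg, Const]) -> Optional[int]:
        if isinstance(node, Const):
            return node.value
        return init[node.index]  # Substitute the block argument with its init value.

    lhs, rhs = resolve(cond.lhs), resolve(cond.rhs)
    if lhs is None or rhs is None:
        return None  # Not all inputs are constant, so nothing can be folded.
    return _DIRECTIONS[cond.direction](lhs, rhs)


cond = Compare("LT", Arg(0), Arg(1))
print(fold_entry_condition(cond, [3, 3]))     # False
print(fold_entry_condition(cond, [None, 3]))  # None
```

A result of `False` here means the loop never executes and its results can be replaced by the init operands.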

I have a patch for the folder, but I also looked into the RegionBranchOpInterface and it seems like there is a limitation to the interface that prevents us from using it in MHLO.

In particular the getMutableSuccessorOperands for RegionBranchTerminatorOpInterface can't be used with our while op because the condition terminator does not take as operands the values to pass to the body...

Can we apply the data flow analysis framework and implement this through constant propagation? That would also cover other control flow simplifications. I'm not sure whether the scf dialect already does these for mhlo as well, but we need them for completeness.

This relies on RegionBranchOpInterface, which is difficult to apply here (as I described above).

In particular the getMutableSuccessorOperands for RegionBranchTerminatorOpInterface can't be used with our while op because the condition terminator does not take as operands the values to pass to the body...

I've run into this limitation before but just added as operands the forwarded values. I will check whether the function really needs to return a MutableOperandRange or whether it can just return a ValueRange. Then, the terminator could just return the block arguments.
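The difference being discussed can be modelled roughly as follows. This is a toy Python sketch, not the MLIR interface: `CondTerminator` and `values_passed_to_body` are hypothetical names. A terminator that lists forwarded operands (as scf.condition does) can answer the question directly, while a terminator that only returns the predicate would have to fall back to the cond region's block arguments.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class CondTerminator:
    """Toy terminator of a cond region (hypothetical, for illustration only)."""
    predicate: str
    # Values explicitly forwarded by the terminator; None if it carries none.
    forwarded: Optional[Tuple[str, ...]] = None


def values_passed_to_body(
    term: CondTerminator, cond_block_args: Tuple[str, ...]
) -> Tuple[str, ...]:
    """Return the values the body region receives when the predicate is true."""
    if term.forwarded is not None:
        # The terminator names the forwarded values as its own operands.
        return term.forwarded
    # The terminator only returns the predicate, so the body receives the
    # same values the cond region was entered with: its block arguments.
    return cond_block_args


with_operands = CondTerminator("%pred", forwarded=("%a", "%b"))
predicate_only = CondTerminator("%pred")
print(values_passed_to_body(with_operands, ("%arg0", "%arg1")))   # ('%a', '%b')
print(values_passed_to_body(predicate_only, ("%arg0", "%arg1")))  # ('%arg0', '%arg1')
```

The second case is why a read-only value range would be enough for analysis, whereas a mutable operand range has nothing on the terminator to point at.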

Thanks!

Thanks. I need to learn more about RegionBranchOpInterface. But the XLA service already has while_loop_simplifier.

I will check whether the function really needs to return a MutableOperandRange or whether it can just return a ValueRange

The only upstream use of the MutableOperandRange return type is in buffer deallocation.

This signature matches the interface for BranchOpInterface in that it allows transformations to reason about and modify control-flow constructs. I'm not sure it can go away.