tlc-pack/relax

[Bug][Parsing] Passing a function call as an arg to an operator does not parse

slyubomirsky opened this issue · 4 comments

The following program results in a parse error:

from __future__ import annotations

import tvm
from tvm import relax as rx
from tvm.script import relax as R

@tvm.script.ir_module
class AddShape:
    @R.function
    def f(x: Tensor) -> Tensor:
        return x

    @R.function
    def plus(x: Tensor) -> Tensor:
        return R.add(f(x), f(x))

The stack trace is as follows:

Traceback (most recent call last):
  File "/home/slyubomirsky/code/sandbox/add_shape.py", line 21, in <module>
    class AddShape:
  File "/home/slyubomirsky/code/relax/python/tvm/script/parser_v1/parser.py", line 1405, in ir_module
    return _ir_module(input_module)
  File "/home/slyubomirsky/code/relax/python/tvm/script/parser_v1/parser.py", line 1398, in _ir_module
    mod = relax.transform.Normalize()(mod)
  File "/home/slyubomirsky/code/relax/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/slyubomirsky/code/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  20: TVMFuncCall
  19: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  18: tvm::transform::Pass::operator()(tvm::IRModule) const
  17: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  16: tvm::relax::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  15: _ZN3tvm7runtime13PackedFuncObj
  14: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::Normalize()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::Normalize()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  13: tvm::relax::Normalize(tvm::RelayExpr const&)
  12: tvm::relax::ExprMutatorBase::VisitExpr(tvm::RelayExpr const&)
  11: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  10: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
  9: tvm::relax::ExprMutatorBase::VisitExpr_(tvm::relax::FunctionNode const*)
  8: tvm::relax::NormalizeMutator::VisitExpr(tvm::RelayExpr const&)
  7: tvm::relax::BlockBuilderNode::Normalize(tvm::RelayExpr const&)
  6: tvm::relax::BlockBuilderNode::ExprNormalizer::VisitExpr(tvm::RelayExpr const&)
  5: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  4: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
  3: tvm::relax::BlockBuilderNode::ExprNormalizer::VisitExpr_(tvm::relay::CallNode const*)
  2: tvm::relax::BlockBuilderNode::ExprNormalizer::Bind(tvm::RelayExpr const&)
  1: tvm::relax::BlockBuilderNode::Emit(tvm::RelayExpr const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
  0: tvm::relax::BlockBuilderNode::CurrentFrame()
  File "/home/slyubomirsky/code/relax/src/relax/ir/block_builder.cc", line 700
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (!block_stack_.empty()) is false: no block is being built

The error appears to be triggered by the body of the plus function, where function calls are passed directly as arguments to R.add. If I change the arguments to constants (e.g., R.const(1)), it parses just fine.

The same happens with other nested calls.

Examples that also result in this error:

@R.function
def plus(x: Tensor) -> Tensor:
    return R.add(R.add(x, x), R.add(x, x))

@R.function
def plus(x: Tensor) -> Tensor:
    return R.add(R.unique(x), R.unique(x))

Thanks @slyubomirsky for reporting the bug! It seems like an issue with Bind in the Normalizer. Could you try binding R.unique(x) to a Var and running R.add on that Var?

Yeah it works if they're vars. I've done that many times before.

The root cause is that the Normalize pass does not support normalizing a Relax function whose body is not a SeqExpr. Fixed in #268.

For example, the following program has a CallNode as the body of the function plus:

@R.function
def plus(x: Tensor) -> Tensor:
    return R.add(R.unique(x), R.unique(x))

To normalize it, the Normalize pass needs to emit new bindings for the nested R.unique(x) calls, so the normalized function must have a SeqExpr as its body, which is handled by VisitWithNewScope.
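
To make the failure mode concrete, here is a minimal pure-Python sketch of the idea. It is not TVM's implementation: ToyNormalizer, Var, Call, SeqExpr, and the method names are hypothetical stand-ins chosen for illustration. It assumes only what the explanation above says: nested calls must be bound to fresh variables, and bindings can only be emitted while a scope (block) is open.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]


@dataclass(frozen=True)
class SeqExpr:
    bindings: Tuple[Tuple[Var, Call], ...]
    body: "Expr"


Expr = Union[Var, Call, SeqExpr]


class ToyNormalizer:
    """Hypothetical stand-in for a block builder plus normalizer (not TVM code)."""

    def __init__(self) -> None:
        self.block_stack: List[List[Tuple[Var, Call]]] = []
        self.counter = 0

    def emit(self, call: Call) -> Var:
        # Mirrors the failing check in the trace: emitting needs an open block.
        if not self.block_stack:
            raise RuntimeError("no block is being built")
        var = Var(f"tmp{self.counter}")
        self.counter += 1
        self.block_stack[-1].append((var, call))
        return var

    def visit(self, expr: Expr) -> Expr:
        # Leaves (here: Var) need no binding; call arguments are normalized first.
        if isinstance(expr, Call):
            return Call(expr.op, tuple(self.visit_arg(a) for a in expr.args))
        return expr

    def visit_arg(self, expr: Expr) -> Expr:
        new_expr = self.visit(expr)
        # A call in argument position must be bound to a fresh variable.
        return self.emit(new_expr) if isinstance(new_expr, Call) else new_expr

    def visit_with_new_scope(self, expr: Expr) -> SeqExpr:
        # Open a scope, normalize the body, and wrap the result in a SeqExpr.
        self.block_stack.append([])
        body = self.visit(expr)
        bindings = self.block_stack.pop()
        return SeqExpr(tuple(bindings), body)


x = Var("x")
plus_body = Call("add", (Call("unique", (x,)), Call("unique", (x,))))

# Visiting the bare call body has no open scope, so emitting a binding fails.
try:
    ToyNormalizer().visit(plus_body)
    failed_without_scope = False
except RuntimeError:
    failed_without_scope = True

# Visiting with a new scope yields a SeqExpr holding the two emitted bindings.
normalized = ToyNormalizer().visit_with_new_scope(plus_body)
```

In this sketch, a body whose arguments are all leaves never calls emit, which is consistent with the observation that constant arguments parse fine. Whether the actual fix is structured this way is not something this sketch can confirm.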