[Bug][Parsing] Passing a function call as an arg to an operator does not parse
slyubomirsky opened this issue · 4 comments
The following program results in a parse error:
from __future__ import annotations
import tvm
from tvm import relax as rx
from tvm.script import relax as R

@tvm.script.ir_module
class AddShape:
    @R.function
    def f(x: Tensor) -> Tensor:
        return x

    @R.function
    def plus(x: Tensor) -> Tensor:
        return R.add(f(x), f(x))
The stack trace is as follows:
Traceback (most recent call last):
  File "/home/slyubomirsky/code/sandbox/add_shape.py", line 21, in <module>
    class AddShape:
  File "/home/slyubomirsky/code/relax/python/tvm/script/parser_v1/parser.py", line 1405, in ir_module
    return _ir_module(input_module)
  File "/home/slyubomirsky/code/relax/python/tvm/script/parser_v1/parser.py", line 1398, in _ir_module
    mod = relax.transform.Normalize()(mod)
  File "/home/slyubomirsky/code/relax/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/slyubomirsky/code/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
20: TVMFuncCall
19: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}>(tvm::transform::__mk_TVM9::{lambda(tvm::transform::Pass, tvm::IRModule)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
18: tvm::transform::Pass::operator()(tvm::IRModule) const
17: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
16: tvm::relax::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
15: _ZN3tvm7runtime13PackedFuncObj
14: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::Normalize()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::Normalize()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
13: tvm::relax::Normalize(tvm::RelayExpr const&)
12: tvm::relax::ExprMutatorBase::VisitExpr(tvm::RelayExpr const&)
11: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
10: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
9: tvm::relax::ExprMutatorBase::VisitExpr_(tvm::relax::FunctionNode const*)
8: tvm::relax::NormalizeMutator::VisitExpr(tvm::RelayExpr const&)
7: tvm::relax::BlockBuilderNode::Normalize(tvm::RelayExpr const&)
6: tvm::relax::BlockBuilderNode::ExprNormalizer::VisitExpr(tvm::RelayExpr const&)
5: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
4: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_
3: tvm::relax::BlockBuilderNode::ExprNormalizer::VisitExpr_(tvm::relay::CallNode const*)
2: tvm::relax::BlockBuilderNode::ExprNormalizer::Bind(tvm::RelayExpr const&)
1: tvm::relax::BlockBuilderNode::Emit(tvm::RelayExpr const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)
0: tvm::relax::BlockBuilderNode::CurrentFrame()
File "/home/slyubomirsky/code/relax/src/relax/ir/block_builder.cc", line 700
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (!block_stack_.empty()) is false: no block is being built
The error appears to come from the plus function. If I change the arguments to constants (e.g., R.const(1)), it parses just fine.
The same happens with other nested calls.
Examples that also result in this error:
@R.function
def plus(x: Tensor) -> Tensor:
    return R.add(R.add(x, x), R.add(x, x))

@R.function
def plus(x: Tensor) -> Tensor:
    return R.add(R.unique(x), R.unique(x))
Thanks @slyubomirsky for reporting the bug! It looks like an issue with Bind in the Normalizer. Could you try binding R.unique(x) to a Var and running add on the Var?
Yeah, it works if they're vars. I've done that many times before.
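The workaround amounts to writing the function in a flattened form by hand, with every nested call bound to its own variable before it is used as an argument. As a rough illustration of that shape, the toy Python function below flattens a nested call into a list of bindings plus a final result. It is hypothetical and independent of TVM: the tuple encoding of calls, the to_bindings name, and the lvN variable names are assumptions made only for this sketch, and the printed text merely imitates TVMScript rather than reproducing real parser or Normalize output.

```python
from typing import List, Tuple, Union

# Hypothetical encoding: a call is a tuple (op, arg1, arg2, ...) and a string
# is a leaf (a variable or a constant). This is NOT how TVM represents Relax.
Expr = Union[str, Tuple]


def to_bindings(body: Expr) -> Tuple[List[str], str]:
    """Flatten nested calls into 'lvN = op(args)' lines plus a final result."""
    lines: List[str] = []

    def visit(expr: Expr, is_root: bool) -> str:
        if isinstance(expr, str):  # leaf: already in flattened form
            return expr
        op, *args = expr
        call = f"{op}({', '.join(visit(a, False) for a in args)})"
        if is_root:
            return call  # the outermost call is kept as the result
        name = f"lv{len(lines)}"  # bind each nested call to a fresh variable
        lines.append(f"{name} = {call}")
        return name

    result = visit(body, True)
    return lines, result


bindings, result = to_bindings(("R.add", ("R.unique", "x"), ("R.unique", "x")))
for line in bindings:
    print(line)
print(f"return {result}")
# lv0 = R.unique(x)
# lv1 = R.unique(x)
# return R.add(lv0, lv1)
```

Leaf arguments such as x or R.const(1) produce no bindings in this sketch, which is consistent with the observation that constant arguments parse fine while nested calls do not.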
The root cause is that the Normalize pass did not support normalizing a Relax function whose body is not a SeqExpr. Fixed in #268.
For example, in the following program the body of the function plus is a CallNode:
@R.function
def plus(x: Tensor) -> Tensor:
    return R.add(R.unique(x), R.unique(x))
To normalize it, the Normalize pass needs to emit new bindings for R.unique(x), so the normalized function must have a SeqExpr as its body; this is handled by VisitWithNewScope.
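The failure mode and the fix can be modeled with a toy block builder in plain Python. Everything below is hypothetical and only loosely mirrors names from the traceback (Emit, CurrentFrame, block_stack_) and VisitWithNewScope; the classes ToyBlockBuilder, Var, Const, Call, and SeqExpr are stand-ins, not the tvm.relax API. The sketch shows three things: emitting a binding requires an open block, leaf arguments never emit anything, and the fixed path opens a fresh scope before visiting a non-SeqExpr body and wraps the result in a SeqExpr.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


# Hypothetical toy IR nodes; these are NOT the tvm.relax classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]


Expr = Union[Var, Const, Call]


@dataclass
class SeqExpr:
    bindings: List[Tuple[Var, Expr]]
    body: Expr


class ToyBlockBuilder:
    """Keeps a stack of open blocks, loosely like block_stack_ in the traceback."""

    def __init__(self) -> None:
        self.block_stack: List[List[Tuple[Var, Expr]]] = []
        self.counter = 0

    def emit(self, expr: Expr) -> Var:
        # Mirrors the failing check: emitting a binding needs an open block.
        if not self.block_stack:
            raise RuntimeError("no block is being built")
        var = Var(f"lv{self.counter}")
        self.counter += 1
        self.block_stack[-1].append((var, expr))
        return var

    def normalize_arg(self, arg: Expr) -> Expr:
        if isinstance(arg, (Var, Const)):
            return arg  # leaves are already normalized; nothing is emitted
        # A nested call is bound to a fresh variable, which needs a block.
        return self.emit(self.normalize_call(arg))

    def normalize_call(self, call: Call) -> Call:
        return Call(call.op, tuple(self.normalize_arg(a) for a in call.args))

    def normalize_body_buggy(self, body: Expr) -> Expr:
        # Pre-fix behavior: visit a non-SeqExpr body without opening a scope.
        return self.normalize_call(body) if isinstance(body, Call) else body

    def normalize_body_fixed(self, body: Expr) -> SeqExpr:
        # Post-fix behavior: open a new scope first, then wrap the emitted
        # bindings and the normalized body into a SeqExpr.
        self.block_stack.append([])
        new_body = self.normalize_call(body) if isinstance(body, Call) else body
        return SeqExpr(self.block_stack.pop(), new_body)


x = Var("x")
nested = Call("add", (Call("unique", (x,)), Call("unique", (x,))))
# Leaf arguments only: nothing is emitted, so no block is needed.
print(ToyBlockBuilder().normalize_body_buggy(Call("add", (x, Const(1)))))
try:
    ToyBlockBuilder().normalize_body_buggy(nested)
except RuntimeError as err:
    print(err)  # no block is being built
print(ToyBlockBuilder().normalize_body_fixed(nested).bindings)
```

In this model the bug only surfaces when an argument is itself a call, which matches the reports above: constants and pre-bound vars work, nested calls fail.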