cornell-zhang/hcl-dialect

`memref.reshape` causes segfault in JIT backend with `opt_level=3`

Opened this issue · 3 comments

Description

This thread documents an issue we encountered with `memref.reshape`. The generated IR is correct: it can be compiled with clang, and it executes correctly when the MLIR ExecutionEngine optimization level is set to 0, 1, or 2. However, if the ExecutionEngine optimization level is set to 3, a segfault is triggered.

Specifically, this step causes the segfault:

execution_engine = ExecutionEngine(
            lowered, opt_level=3, shared_libs=shared_libs)

Current solution

This is likely an issue with the MLIR JIT compiler. We work around it by setting the optimization level lower than 3.

Sample IR to reproduce this issue

// Reproducer module. @kernel is the IR for the transpose -> linear -> view
// program; @main only allocates inputs and calls @kernel.
module {
  // Target shapes consumed by the two memref.reshape ops in @kernel.
  memref.global "private" constant @const_0 : memref<3xi64> = dense<[5, 2, 4]>
  memref.global "private" constant @const_1 : memref<2xi64> = dense<[5, 8]>
  // Takes a 5x3x2 input, a 4x3 matrix and a 4-element vector; returns a 5x8 memref.
  func.func @kernel(%arg0: memref<5x3x2xf32>, %arg1: memref<4x3xf32>, %arg2: memref<4xf32>) -> memref<5x8xf32> attributes {itypes = "___", otypes = "_"} {
    // %1 (-1.0) and %3 (-2.0) are computed here but have no uses in this function.
    %c1_i32 = arith.constant 1 : i32
    %0 = arith.sitofp %c1_i32 : i32 to f32
    %1 = arith.negf %0 : f32
    %c2_i32 = arith.constant 2 : i32
    %2 = arith.sitofp %c2_i32 : i32 to f32
    %3 = arith.negf %2 : f32
    // output1: %arg0 with its last two dims swapped (5x3x2 -> 5x2x3).
    %alloc = memref.alloc() {name = "output1"} : memref<5x2x3xf32>
    %c0_i32 = arith.constant 0 : i32
    %4 = arith.sitofp %c0_i32 : i32 to f32
    linalg.fill {op_name = "transpose_init_zero_0"} ins(%4 : f32) outs(%alloc : memref<5x2x3xf32>)
    linalg.transpose ins(%arg0 : memref<5x3x2xf32>) outs(%alloc : memref<5x2x3xf32>) permutation = [0, 2, 1]  {op_name = "transpose_1"}
    // %alloc_0 is zero-filled but never read afterwards in this function.
    // NOTE(review): the stack trace in this report ends in LLVM's
    // DeadStoreElimination pass; whether fills like this one are involved
    // in the crash is unconfirmed.
    %alloc_0 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_1 = arith.constant 0 : i32
    %5 = arith.sitofp %c0_i32_1 : i32 to f32
    linalg.fill {op_name = "linear_init_zero_2"} ins(%5 : f32) outs(%alloc_0 : memref<5x2x4xf32>)
    // %alloc_2: transpose of %arg1 (4x3 -> 3x4).
    %alloc_2 = memref.alloc() : memref<3x4xf32>
    %c0_i32_3 = arith.constant 0 : i32
    %6 = arith.sitofp %c0_i32_3 : i32 to f32
    linalg.fill {op_name = "transpose_init_zero_3"} ins(%6 : f32) outs(%alloc_2 : memref<3x4xf32>)
    linalg.transpose ins(%arg1 : memref<4x3xf32>) outs(%alloc_2 : memref<3x4xf32>) permutation = [1, 0]  {op_name = "transpose_4"}
    // %alloc_4 is zero-filled but never read afterwards in this function.
    %alloc_4 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_5 = arith.constant 0 : i32
    %7 = arith.sitofp %c0_i32_5 : i32 to f32
    linalg.fill {op_name = "matmul_init_zero_5"} ins(%7 : f32) outs(%alloc_4 : memref<5x2x4xf32>)
    // %alloc_6: the 3x4 matrix broadcast along a new leading dim of size 5.
    %alloc_6 = memref.alloc() : memref<5x3x4xf32>
    linalg.broadcast ins(%alloc_2 : memref<3x4xf32>) outs(%alloc_6 : memref<5x3x4xf32>) dimensions = [0] 
    // %alloc_7: batch_matmul of %alloc and %alloc_6 into a zero-filled 5x2x4 buffer.
    %alloc_7 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_8 = arith.constant 0 : i32
    %8 = arith.sitofp %c0_i32_8 : i32 to f32
    linalg.fill {op_name = "bmm_init_zero_6"} ins(%8 : f32) outs(%alloc_7 : memref<5x2x4xf32>)
    linalg.batch_matmul {op_name = "bmm_7"} ins(%alloc, %alloc_6 : memref<5x2x3xf32>, memref<5x3x4xf32>) outs(%alloc_7 : memref<5x2x4xf32>)
    // %alloc_9 is zero-filled but never read afterwards in this function.
    %alloc_9 = memref.alloc() : memref<5x2x4xf32>
    %c0_i32_10 = arith.constant 0 : i32
    %9 = arith.sitofp %c0_i32_10 : i32 to f32
    linalg.fill {op_name = "view_init_zero_8"} ins(%9 : f32) outs(%alloc_9 : memref<5x2x4xf32>)
    // Shape-preserving reshape: source and result are both memref<5x2x4xf32>.
    %10 = memref.get_global @const_0 : memref<3xi64>
    %reshape = memref.reshape %alloc_7(%10) : (memref<5x2x4xf32>, memref<3xi64>) -> memref<5x2x4xf32>
    // %alloc_11: %arg2 broadcast along dims 0 and 1 to 5x2x4.
    %alloc_11 = memref.alloc() : memref<5x2x4xf32>
    linalg.broadcast ins(%arg2 : memref<4xf32>) outs(%alloc_11 : memref<5x2x4xf32>) dimensions = [0, 1] 
    // output2: %reshape added to the broadcast %arg2.
    %alloc_12 = memref.alloc() {name = "output2"} : memref<5x2x4xf32>
    %c0_i32_13 = arith.constant 0 : i32
    %11 = arith.sitofp %c0_i32_13 : i32 to f32
    linalg.fill {op_name = "add_init_zero_9"} ins(%11 : f32) outs(%alloc_12 : memref<5x2x4xf32>)
    linalg.add {op_name = "add_10"} ins(%reshape, %alloc_11 : memref<5x2x4xf32>, memref<5x2x4xf32>) outs(%alloc_12 : memref<5x2x4xf32>)
    // %alloc_14 is zero-filled but never read afterwards in this function.
    %alloc_14 = memref.alloc() : memref<5x8xf32>
    %c0_i32_15 = arith.constant 0 : i32
    %12 = arith.sitofp %c0_i32_15 : i32 to f32
    linalg.fill {op_name = "view_init_zero_11"} ins(%12 : f32) outs(%alloc_14 : memref<5x8xf32>)
    // Final reshape of output2 from 5x2x4 to 5x8; this value is returned.
    %13 = memref.get_global @const_1 : memref<2xi64>
    %reshape_16 = memref.reshape %alloc_12(%13) {name = "output"} : (memref<5x2x4xf32>, memref<2xi64>) -> memref<5x8xf32>
    return %reshape_16 : memref<5x8xf32>
  }


  // Driver: allocates the three inputs without writing to them, calls
  // @kernel, and discards the returned memref.
  func.func @main() {
	%arg0 = memref.alloc() : memref<5x3x2xf32>
	%arg1 = memref.alloc() : memref<4x3xf32>
	%arg2 = memref.alloc() : memref<4xf32>
	%arg3 = func.call @kernel(%arg0, %arg1, %arg2) : (memref<5x3x2xf32>, memref<4x3xf32>, memref<4xf32>) -> memref<5x8xf32>
	return
  }

}

Stack trace

 #3 0x00007f2186a5d950 llvm::isPotentiallyReachable(llvm::Instruction const*, llvm::Instruction const*, llvm::SmallPtrSetImpl<llvm::BasicBlock*> const*, llvm::DominatorTree const*, llvm::LoopInfo const*) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a21950)
 #4 0x00007f2186a2d619 llvm::EarliestEscapeInfo::isNotCapturedBeforeOrAt(llvm::Value const*, llvm::Instruction const*) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59f1619)
 #5 0x00007f2186a276bb llvm::BasicAAResult::getModRefInfo(llvm::CallBase const*, llvm::MemoryLocation const&, llvm::AAQueryInfo&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59eb6bb)
 #6 0x00007f2186a0756b llvm::AAResults::getModRefInfo(llvm::CallBase const*, llvm::MemoryLocation const&, llvm::AAQueryInfo&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59cb56b)
 #7 0x00007f2186a08e51 llvm::AAResults::getModRefInfo(llvm::Instruction const*, std::optional<llvm::MemoryLocation> const&, llvm::AAQueryInfo&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x59cce51)
 #8 0x00007f218638ebdf (anonymous namespace)::DSEState::isReadClobber(llvm::MemoryLocation const&, llvm::Instruction*) DeadStoreElimination.cpp:0:0
 #9 0x00007f2186398eae (anonymous namespace)::DSEState::getDomMemoryDef(llvm::MemoryDef*, llvm::MemoryAccess*, llvm::MemoryLocation const&, llvm::Value const*, unsigned int&, unsigned int&, bool, unsigned int&) DeadStoreElimination.cpp:0:0
#10 0x00007f218639af52 (anonymous namespace)::eliminateDeadStores(llvm::Function&, llvm::AAResults&, llvm::MemorySSA&, llvm::DominatorTree&, llvm::PostDominatorTree&, llvm::AssumptionCache&, llvm::TargetLibraryInfo const&, llvm::LoopInfo const&) DeadStoreElimination.cpp:0:0
#11 0x00007f218639d038 llvm::DSEPass::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5361038)
#12 0x00007f218554e35e llvm::detail::PassModel<llvm::Function, llvm::DSEPass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x451235e)
#13 0x00007f218737d514 llvm::PassManager<llvm::Function, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x6341514)
#14 0x00007f21855472ce llvm::detail::PassModel<llvm::Function, llvm::PassManager<llvm::Function, llvm::AnalysisManager<llvm::Function>>, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Function>>::run(llvm::Function&, llvm::AnalysisManager<llvm::Function>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x450b2ce)
#15 0x00007f2186a7739f llvm::CGSCCToFunctionPassAdaptor::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a3b39f)
#16 0x00007f218554d40e llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::CGSCCToFunctionPassAdaptor, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x451140e)
#17 0x00007f2186a6ff5b llvm::PassManager<llvm::LazyCallGraph::SCC, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a33f5b)
#18 0x00007f218554d3ce llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::PassManager<llvm::LazyCallGraph::SCC, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x45113ce)
#19 0x00007f2186a73a55 llvm::DevirtSCCRepeatedPass::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a37a55)
#20 0x00007f218554d3ee llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::DevirtSCCRepeatedPass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x45113ee)
#21 0x00007f2186a71bf9 llvm::ModuleToPostOrderCGSCCPassAdaptor::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x5a35bf9)
#22 0x00007f21857956ef llvm::ModuleInlinerWrapperPass::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x47596ef)
#23 0x00007f218554cf8e llvm::detail::PassModel<llvm::Module, llvm::ModuleInlinerWrapperPass, llvm::PreservedAnalyses, llvm::AnalysisManager<llvm::Module>>::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x4510f8e)
#24 0x00007f2185542f50 mlir::makeOptimizingTransformer(unsigned int, unsigned int, llvm::TargetMachine*)::'lambda'(llvm::Module*)::operator()(llvm::Module*) const OptUtils.cpp:0:0
#25 0x00007f2185543cad std::_Function_handler<llvm::Error (llvm::Module*), mlir::makeOptimizingTransformer(unsigned int, unsigned int, llvm::TargetMachine*)::'lambda'(llvm::Module*)>::_M_invoke(std::_Any_data const&, llvm::Module*&&) OptUtils.cpp:0:0
#26 0x00007f218264550d llvm::Error llvm::function_ref<llvm::Error (llvm::Module*)>::callback_fn<std::function<llvm::Error (llvm::Module*)>>(long, llvm::Module*) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x160950d)
#27 0x00007f218306e1e6 mlir::ExecutionEngine::create(mlir::Operation*, mlir::ExecutionEngineOptions const&, std::unique_ptr<llvm::TargetMachine, std::default_delete<llvm::TargetMachine>>) (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x20321e6)
#28 0x00007f2182646ab5 mlirExecutionEngineCreate (/work/shared/users/phd/nz264/mlir/hcl-dialect/build/tools/hcl/python_packages/hcl_core/hcl_mlir/_mlir_libs/libHCLMLIRAggregateCAPI.so.18git+0x160aab5)
#29 0x00007f21803de5e5 pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)::operator()(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool) const /work/shared/users/common/llvm-project-18.x/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp:82:77
#30 0x00007f21803df91a void pybind11::detail::initimpl::factory<pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type (*)(), (anonymous namespace)::PyExecutionEngine* (MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<(anonymous namespace)::PyExecutionEngine>, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [327]>(pybind11::class_<(anonymous namespace)::PyExecutionEngine>&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [327]) &&::'lambda'(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)::operator()(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool) const /home/nz264/anaconda3/envs/mlir/lib/python3.8/site-packages/pybind11/include/pybind11/detail/init.h:242:29
#31 0x00007f21803e37e3 pybind11::class_<(anonymous namespace)::PyExecutionEngine> pybind11::detail::argument_loader<pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool>::call_impl<void, void pybind11::detail::initimpl::factory<pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type (*)(), (anonymous namespace)::PyExecutionEngine* (MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<(anonymous namespace)::PyExecutionEngine>, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [327]>(pybind11::class_<(anonymous namespace)::PyExecutionEngine>&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [327]) &&::'lambda'(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool)&, 0ul, 1ul, 2ul, 3ul, 4ul, pybind11::detail::void_type>(void pybind11::detail::initimpl::factory<pybind11_init__mlirExecutionEngine(pybind11::module_&)::'lambda'(MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type (*)(), (anonymous namespace)::PyExecutionEngine* (MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, bool), pybind11::detail::void_type ()>::execute<pybind11::class_<(anonymous namespace)::PyExecutionEngine>, pybind11::arg, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v, char [327]>(pybind11::class_<(anonymous namespace)::PyExecutionEngine>&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&, char const (&) [327]) &&::'lambda'(pybind11::detail::value_and_holder&, MlirModule, int, std::vector<std::string, std::allocator<std::string>> const&, 
bool)&, std::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul, 4ul>, pybind11::detail::void_type&&) && /home/nz264/anaconda3/envs/mlir/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1205:91
#32 0x00007f21803e3439 _ZNO8pybind116detail15argument_loaderIJRNS0_16value_and_holderE10MlirModuleiRKSt6vectorISsSaISsEEbEE4callIvNS0_9void_typeERZNOS0_8initimpl7factoryIZL34pybind11_init__mlirExecutionEngineRNS_7module_EEUlS4_iS9_bE_PFSC_vEFPN12_GLOBAL__N_117PyExecutionEngineES4_iS9_bESI_E7executeINS_6class_ISL_JEEEJNS_3argENS_5arg_vEST_ST_A327_cEEEvRT_DpRKT0_EUlS3_S4_iS9_bE_EENSt9enable_ifIXsrSt7is_voidISV_E5valueESC_E4typeEOT1_ /home/nz264/anaconda3/envs/mlir/lib/python3.8/site-packages/pybind11/include/pybind11/cast.h:1183:26

To compile any MLIR IR with gcc/clang, we can do this:

# Step 1: lower example.mlir through the listed conversion passes to the
# LLVM dialect and write the result to example.llvm.mlir.
mlir-opt example.mlir \
	--convert-linalg-to-affine-loops \
	--one-shot-bufferize \
	--lower-affine \
	--convert-scf-to-cf \
	--convert-cf-to-llvm \
	--convert-func-to-llvm \
	--convert-arith-to-llvm \
	--finalize-memref-to-llvm \
	--reconcile-unrealized-casts \
	-o example.llvm.mlir


# Step 2: translate the LLVM-dialect MLIR into LLVM IR (example.ll).
mlir-translate example.llvm.mlir \
	--mlir-to-llvmir \
	-o example.ll

# Step 3: compile the LLVM IR to assembly, assemble it into an object file,
# and link the object file into an executable.
llc example.ll -o example.s
as example.s -o example.o
gcc example.o -o example.exe

The allo program associated with this sample:

import allo
from allo.ir.types import int32, float32
import numpy as np

def test_library_higher_dimension_ops(enable_tensor):
    """Build the transpose -> linear -> view kernel with allo and check its output.

    The kernel is customized and built through allo, run on random float32
    inputs, and compared against the result of calling the same ``kernel``
    function directly.

    Args:
        enable_tensor: Forwarded unchanged to ``allo.customize`` as its
            ``enable_tensor`` keyword argument.

    Raises:
        AssertionError: If the two results differ beyond ``rtol=1e-5``.
    """
    # Problem sizes: A is (M, K, L), B is (N, K), C is (N,).
    M = 5
    N = 4
    K = 3
    L = 2
    A = np.random.uniform(size=(M, K, L)).astype(np.float32)
    B = np.random.uniform(size=(N, K)).astype(np.float32)
    C = np.random.uniform(size=(N,)).astype(np.float32)

    # NOTE(review): the local names inside `kernel` (output1, output2, output)
    # also appear as name attributes in the sample IR of this report, so they
    # are presumably picked up by allo -- avoid renaming them in the reproducer.
    def kernel(
        A: float32[M, K, L], B: float32[N, K], C: float32[N]
    ) -> float32[M, L * N]:
        output1 = allo.transpose(A, (-1, -2))
        output2 = allo.linear(output1, B, C)
        output = allo.view(output2, (5, 8))
        return output

    s = allo.customize(kernel, enable_tensor=enable_tensor)
    mod = s.build()
    # Result from the module produced by s.build().
    outp = mod(A, B, C)
    # Reference result: the plain Python `kernel` called directly on the arrays.
    np_outp = kernel(A, B, C)
    np.testing.assert_allclose(outp, np_outp, rtol=1e-5)
    
if __name__ == "__main__":
	test_library_higher_dimension_ops(False)

Thank you! This works for me. However, when the opt-level is set to 0 or 1, the following type test case does not pass:

def test_compare_int_float():
        Ty = Int(5)
    
        def kernel(A: Ty) -> Ty:
            B: Ty = 0
            if A > B or A + 1 < 0.0:
                B = A
            return B
    
        s = allo.customize(kernel)
        mod = s.build()
        assert mod(2) == kernel(2)
>       assert mod(-3) == kernel(-3)
E       assert 29 == -3
E        +  where 29 = <allo.backend.llvm.LLVMModule object at 0x7f8d21573c70>(-3)
E        +  and   -3 = <function test_compare_int_float.<locals>.kernel at 0x7f8d21d588b0>(-3)

tests/test_types.py:165: AssertionError