NVIDIA/TensorRT-Incubator

`tp.mean` fails when `dim` lists non-adjacent dimensions (e.g. `dim=[0, 2]`)


`tp.mean` fails during compilation when `dim` skips a dimension, for example `dim=[0, 2]` on a rank-3 tensor.

Working examples

>>> import tripy as tp
>>> x = tp.reshape(tp.arange(12), (2, 3, 2))

Reducing over adjacent dimensions works as expected:

>>> tp.mean(x, dim=[0], keepdim=True)
tensor(
    [[[3.0000, 4.0000],
      [5.0000, 6.0000],
      [7.0000, 8.0000]]], 
    dtype=float32, loc=gpu:0, shape=(1, 3, 2))
>>> tp.mean(x, dim=[0,1], keepdim=True)
tensor(
    [[[5.0000, 6.0000]]], 
    dtype=float32, loc=gpu:0, shape=(1, 1, 2))
>>> tp.mean(x, dim=[0,1,2], keepdim=True)
tensor([[[5.5000]]], dtype=float32, loc=gpu:0, shape=(1, 1, 1))
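
For reference, the outputs above can be cross-checked without Tripy, and the same check gives the value a reduction over dims 0 and 2 should produce. The following is a minimal pure-Python sketch; `mean_keepdim` is a hypothetical helper written for this issue, not part of Tripy, and it assumes `tp.mean` is meant to follow the usual "mean over the listed axes, keep them as size 1" semantics.

```python
from itertools import product

def mean_keepdim(data, shape, dims):
    """Reference mean over `dims` of a row-major flat list (keepdim=True semantics)."""
    out_shape = tuple(1 if i in dims else s for i, s in enumerate(shape))
    sums, counts = {}, {}
    # product(...) visits indices in row-major order, matching the flat list.
    for flat_index, idx in enumerate(product(*(range(s) for s in shape))):
        # Collapse every reduced axis to index 0 to find the output cell.
        key = tuple(0 if i in dims else v for i, v in enumerate(idx))
        sums[key] = sums.get(key, 0.0) + data[flat_index]
        counts[key] = counts.get(key, 0) + 1
    values = [sums[k] / counts[k] for k in product(*(range(s) for s in out_shape))]
    return out_shape, values

shape = (2, 3, 2)
data = list(range(12))  # same values as tp.reshape(tp.arange(12), (2, 3, 2))

print(mean_keepdim(data, shape, [0]))        # ((1, 3, 2), [3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
print(mean_keepdim(data, shape, [0, 1]))     # ((1, 1, 2), [5.0, 6.0])
print(mean_keepdim(data, shape, [0, 1, 2]))  # ((1, 1, 1), [5.5])
print(mean_keepdim(data, shape, [0, 2]))     # ((1, 3, 1), [3.5, 5.5, 7.5])
```

Under that assumption, `dim=[0, 2]` with `keepdim=True` should yield shape `(1, 3, 1)` with values 3.5, 5.5, 7.5.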

Failing example

Skipping a dimension (here dim 1) fails:

>>> tp.mean(x, dim=[0,2], keepdim=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/tripy/tripy/frontend/tensor.py", line 214, in __repr__
    data_list = self.tolist()
  File "/tripy/tripy/frontend/tensor.py", line 195, in tolist
    data_memref = self.eval()
  File "/tripy/tripy/frontend/tensor.py", line 180, in eval
    executable = compiler.compile(mlir, flat_ir=flat_ir)
  File "/tripy/tripy/utils/utils.py", line 74, in wrapper
    result = func(*args, **kwargs)
  File "/tripy/tripy/backend/mlir/compiler.py", line 109, in compile
    map_error_to_user_code_and_raise(flat_ir, exc, stderr.decode())
  File "/tripy/tripy/backend/mlir/utils.py", line 513, in map_error_to_user_code_and_raise
    raise_error(
  File "/tripy/tripy/common/exception.py", line 195, in raise_error
    raise TripyException(msg) from None
tripy.common.exception.TripyException: 

--> <stdin>:1 in <module>()

MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12

Additional context:
Traceback (most recent call last):
  File "/tripy/tripy/backend/mlir/compiler.py", line 102, in compile
    executable = compiler.compiler_stablehlo_to_executable(
mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
.
    (t1926)): error: op: %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
    (t1926)): error: op: "stablehlo.return"(%7) : (tensor<f32>) -> () from function main is invalid, post clustering.
    (t1926)): error: op: 
    %2 = "stablehlo.reduce"(%1, %0) <{dimensions = array<i64: 0, 2>}> ({
    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
      %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      "stablehlo.return"(%7) : (tensor<f32>) -> ()
    }) : (tensor<2x3x2xf32>, tensor<f32>) -> tensor<3xf32> from function main is invalid, post clustering.

    This error occured while trying to compile the following FlatIR expression:
          |
          | t_inter4: [rank=(1), dtype=(float32), loc=(gpu:0)] = ReduceOp(t_inter3, t_inter5, reduce_mode='sum', reduce_dims=[0, 2])
          | 

    This operation was introduced to Cloning tensor t1926: [rank=(1), dtype=(float32), loc=(gpu:0)] for function input/output.

    Note: This originated from the following expression:

    --> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
          |
      174 |     return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    --> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
          |
      318 |     sum_val = sum(tensor, dim=dim, keepdim=keepdim)
          |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    --> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
          |
      361 |     return mean_impl(input, dim, keepdim)
          |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here

    Input 0:

    --> /tripy/tripy/frontend/utils.py:455 in wrapper()
          |
      455 |             return func(*new_args, **new_kwargs)
          | 

    --> /tripy/tripy/frontend/trace/ops/reshape.py:145 in reshape()
          |
      145 |     return reshape_impl(input, shape, len(shape), output_len)
          |
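
The rejected op in the log is a single `stablehlo.reduce` over dimensions 0 and 2. Since every reduced axis has a fixed size, a mean over both axes at once equals a mean over axis 0 followed by a mean over the remaining last axis. The plain-Python sketch below only demonstrates that equivalence on the same data. Whether splitting the call into two single-dimension `tp.mean` calls actually avoids the compilation failure is an untested assumption.

```python
# Same values as the tensor in this issue: x[i][j][k] = 6*i + 2*j + k, shape (2, 3, 2).
x = [[[6 * i + 2 * j + k for k in range(2)] for j in range(3)] for i in range(2)]

# Step 1: mean over axis 0, leaving shape (3, 2).
step1 = [[(x[0][j][k] + x[1][j][k]) / 2 for k in range(2)] for j in range(3)]

# Step 2: mean over the last axis of the intermediate, leaving shape (3,).
two_step = [sum(row) / len(row) for row in step1]

# Joint mean over axes 0 and 2 in one pass, for comparison.
joint = [
    sum(x[i][j][k] for i in range(2) for k in range(2)) / 4
    for j in range(3)
]

print(two_step)  # [3.5, 5.5, 7.5]
print(joint)     # [3.5, 5.5, 7.5]
```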