`tp.mean` failure when `dim` is multi dimensional with skipped dimensions
Opened this issue · 0 comments
farazkh80 commented
`tp.mean` fails when the dimensions in `dim` skip a dimension (i.e., the reduced dimensions are not contiguous).
Working examples
x = tp.reshape(tp.arange(12), (2,3,2))
Then the following calls work as expected:
>>> tp.mean(x, dim=[0], keepdim=True)
tensor(
[[[3.0000, 4.0000],
[5.0000, 6.0000],
[7.0000, 8.0000]]],
dtype=float32, loc=gpu:0, shape=(1, 3, 2))
>>> tp.mean(x, dim=[0,1], keepdim=True)
tensor(
[[[5.0000, 6.0000]]],
dtype=float32, loc=gpu:0, shape=(1, 1, 2))
>>> tp.mean(x, dim=[0,1,2], keepdim=True)
tensor([[[5.5000]]], dtype=float32, loc=gpu:0, shape=(1, 1, 1))
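Failed example 1
But if we skip a dimension (for example, reducing over dims 0 and 2 but not dim 1), compilation fails. Note that the emitted `stablehlo.reduce` has the expected result type (`tensor<3xf32>` for a `2x3x2` input). The error is reported as "invalid, post clustering", which suggests the backend rejects the non-contiguous reduction dimensions rather than the shape being wrong: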
>>> tp.mean(x, dim=[0,2], keepdim=True)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/tripy/tripy/frontend/tensor.py", line 214, in __repr__
data_list = self.tolist()
File "/tripy/tripy/frontend/tensor.py", line 195, in tolist
data_memref = self.eval()
File "/tripy/tripy/frontend/tensor.py", line 180, in eval
executable = compiler.compile(mlir, flat_ir=flat_ir)
File "/tripy/tripy/utils/utils.py", line 74, in wrapper
result = func(*args, **kwargs)
File "/tripy/tripy/backend/mlir/compiler.py", line 109, in compile
map_error_to_user_code_and_raise(flat_ir, exc, stderr.decode())
File "/tripy/tripy/backend/mlir/utils.py", line 513, in map_error_to_user_code_and_raise
raise_error(
File "/tripy/tripy/common/exception.py", line 195, in raise_error
raise TripyException(msg) from None
tripy.common.exception.TripyException:
--> <stdin>:1 in <module>()
MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
Additional context:
Traceback (most recent call last):
File "/tripy/tripy/backend/mlir/compiler.py", line 102, in compile
executable = compiler.compiler_stablehlo_to_executable(
mlir_tensorrt.runtime._mlir_libs._api.MTRTException: InternalError: failed to run compilation on module with symbol name: outs_t2231_12
.
(t1926)): error: op: %7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32> from function main is invalid, post clustering.
(t1926)): error: op: "stablehlo.return"(%7) : (tensor<f32>) -> () from function main is invalid, post clustering.
(t1926)): error: op:
%2 = "stablehlo.reduce"(%1, %0) <{dimensions = array<i64: 0, 2>}> ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%7 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%7) : (tensor<f32>) -> ()
}) : (tensor<2x3x2xf32>, tensor<f32>) -> tensor<3xf32> from function main is invalid, post clustering.
This error occured while trying to compile the following FlatIR expression:
|
| t_inter4: [rank=(1), dtype=(float32), loc=(gpu:0)] = ReduceOp(t_inter3, t_inter5, reduce_mode='sum', reduce_dims=[0, 2])
|
This operation was introduced to Cloning tensor t1926: [rank=(1), dtype=(float32), loc=(gpu:0)] for function input/output.
Note: This originated from the following expression:
--> /tripy/tripy/frontend/trace/ops/reduce.py:174 in sum()
|
174 | return _reduce_impl(input, Reduce.Kind.SUM, dim, keepdim)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
--> /tripy/tripy/frontend/trace/ops/reduce.py:318 in mean_impl()
|
318 | sum_val = sum(tensor, dim=dim, keepdim=keepdim)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
--> /tripy/tripy/frontend/trace/ops/reduce.py:361 in mean()
|
361 | return mean_impl(input, dim, keepdim)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ --- required from here
Input 0:
--> /tripy/tripy/frontend/utils.py:455 in wrapper()
|
455 | return func(*new_args, **new_kwargs)
|
--> /tripy/tripy/frontend/trace/ops/reshape.py:145 in reshape()
|
145 | return reshape_impl(input, shape, len(shape), output_len)
|