onnx/onnx-tensorrt

`ScatterElements` with Reduction (opset 16) Not Fully Supported

anthony-correia opened this issue.

Short Description

Converting an ONNX model to TensorRT with `trtexec` fails when the model contains a `ScatterElements` operation with a reduction (opset 16; for example `reduction="add"`, the ONNX counterpart of a PyTorch `"sum"` reduction) and the number of indices in the operation exceeds the number of outputs.

Conversion only succeeds when `n_indices <= n_outputs`.
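For comparison, my reading of the ONNX operator description is that `ScatterElements` with `reduction="add"` only bounds the dimensions of `indices` *other than* `axis` by the shape of `data`; along the scatter axis, `indices` may be longer than `data`, as long as each index value is a valid row. The function below is a minimal NumPy sketch of those semantics, assuming 2-D tensors and `axis=0`. `scatter_elements_add_axis0` is a hypothetical name for illustration, not any runtime's implementation:

```python
import numpy as np


def scatter_elements_add_axis0(data: np.ndarray, indices: np.ndarray,
                               updates: np.ndarray) -> np.ndarray:
    """Sketch of ScatterElements(axis=0, reduction="add") for 2-D inputs.

    output[indices[i][j]][j] += updates[i][j] for every position (i, j) of indices.
    """
    if indices.shape != updates.shape:
        raise ValueError("indices and updates must have the same shape")
    output = data.copy()
    n_rows, n_cols = indices.shape
    for i in range(n_rows):      # n_rows may exceed data.shape[0]
        for j in range(n_cols):  # n_cols must not exceed data.shape[1]
            output[indices[i, j], j] += updates[i, j]
    return output


# 5 update rows scattered into 2 output rows: indices is taller than data.
data = np.zeros((2, 3))
indices = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0], [1, 1, 1], [0, 0, 0]])
updates = np.ones((5, 3))
# Row 0 receives 3 updates and row 1 receives 2.
print(scatter_elements_add_axis0(data, indices, updates))
```

The operation is well defined here even though `indices` has 5 rows and `data` only 2, which is exactly the `n_indices > n_outputs` situation described above.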

Long Description

Consider the following PyTorch model snippet:

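```python
import torch
import torch_scatter

n_indices: int = ...  # e.g. 1000 in the failing example below
dim_size: int = ...   # e.g. 3
n_outputs: int = ...  # e.g. 100
device = "cuda" if torch.cuda.is_available() else "cpu"

# One feature row per index (edge), each pointing at one of n_outputs rows (nodes).
e_dummy = torch.randn(size=(n_indices, dim_size), device=device)
index_dummy = torch.randint(high=n_outputs, size=(n_indices,), device=device)


class ScatterModule(torch.nn.Module):
    def forward(self, e: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # Sum the rows of `e` that share the same target in `index`.
        return torch_scatter.scatter(
            src=e,
            # Explicit broadcasting of `index` to the shape of `e`
            # (torch_scatter would do this automatically anyway).
            index=index.unsqueeze(-1).expand(-1, e.shape[1]),
            dim=0,
            reduce="sum",
        )
```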
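The forward pass above is an edge-to-node sum: row `i` of `e` is added to output row `index[i]`. The NumPy sketch below reproduces the same computation (it is not the `torch_scatter` implementation) to make the shapes explicit: after the `expand`, the index tensor has shape `(n_indices, dim_size)`, while the output only has `n_outputs` rows. The sizes are illustrative assumptions. Note that the output size is passed explicitly here; `torch_scatter.scatter` infers it as `index.max() + 1` when its own `dim_size` keyword (unrelated to the snippet's `dim_size` variable) is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes (assumption): more edges than nodes, as in a GNN.
n_indices, dim_size, n_outputs = 10, 3, 4

e = rng.standard_normal((n_indices, dim_size))      # one row per edge
index = rng.integers(0, n_outputs, size=n_indices)  # target node per edge

# Same as index.unsqueeze(-1).expand(-1, e.shape[1]) in the PyTorch snippet.
index_2d = np.broadcast_to(index[:, None], e.shape)

# Element-wise scatter-add along axis 0: out[index_2d[i, j], j] += e[i, j].
# np.add.at is unbuffered, so repeated targets accumulate correctly.
out = np.zeros((n_outputs, dim_size))
cols = np.broadcast_to(np.arange(dim_size), e.shape)
np.add.at(out, (index_2d, cols), e)

print("indices shape:", index_2d.shape)  # (n_indices, dim_size)
print("output shape: ", out.shape)       # (n_outputs, dim_size)
```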

Converting the corresponding ONNX model with `trtexec` triggers an assertion failure:

```
Assertion failed: indicesDims.d[i] <= dataDims.d[i] && "Indices dimensions must be less than data dimensions!"
```

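This error most likely originates from a shape check in the ONNX-TensorRT `ScatterElements` importer, which appears to compare every dimension of `indices` with the matching dimension of `data`, including the scatter axis.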
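Judging only from the assertion message, the failing condition can be mirrored with a hypothetical helper (this is not the ONNX-TensorRT source). It assumes that `data` has shape `(n_outputs, dim_size)` and `indices` has shape `(n_indices, dim_size)`, as in the module above, and contrasts the check with a version that skips the scatter axis, which is my reading of what the ONNX operator description requires:

```python
from typing import Sequence


def passes_importer_check(indices_dims: Sequence[int], data_dims: Sequence[int]) -> bool:
    """Hypothetical mirror of: indicesDims.d[i] <= dataDims.d[i] for every i."""
    return all(i <= d for i, d in zip(indices_dims, data_dims))


def onnx_spec_allows(indices_dims: Sequence[int], data_dims: Sequence[int], axis: int) -> bool:
    """Assumed ONNX constraint: only dimensions other than `axis` are bounded by data."""
    return all(
        i <= d
        for k, (i, d) in enumerate(zip(indices_dims, data_dims))
        if k != axis
    )


# The two repro models: (n_indices, dim_size, n_outputs).
for n_indices, dim_size, n_outputs in [(1000, 3, 100), (100, 3, 100)]:
    indices_dims = (n_indices, dim_size)
    data_dims = (n_outputs, dim_size)
    print(
        (n_indices, dim_size, n_outputs),
        "importer check:", passes_importer_check(indices_dims, data_dims),
        "ONNX (axis=0):", onnx_spec_allows(indices_dims, data_dims, axis=0),
    )
```

Under these assumptions, the mirrored check rejects the 1000-index model and accepts the 100-index one, matching the `trtexec` results reported below, whereas the axis-aware version accepts both.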

In the Graph Neural Network use cases I have encountered, the number of indices (`n_indices`, one per edge of the graph) is much larger than the number of outputs (`n_outputs`, one per node of the graph), so this restriction rules out the usual edge-to-node aggregation.

Environment

- TensorRT Version: 8.6.1.6-1+cuda11.8
- GPU Type: NVIDIA RTX A2000 (laptop)
- NVIDIA Driver Version: 520.61.05
- CUDA Version: 11.8.0-1
- cuDNN Version: 8.7.0.84-1+cuda11.8
- Operating System + Version: Ubuntu 22.04.1 LTS

Relevant Files

I've created a repository to reproduce the issue: anthony-correia/scatter_onnx2tensorrt.

The ONNX models are stored with the naming convention `onnx/{n_indices}_{dim_size}_{n_outputs}_{seed}.onnx`.
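The naming convention can be decoded with a small helper. `parse_model_name` and `ModelParams` are hypothetical names written for this illustration, not part of the repository; the "expected to fail" flag simply applies the `n_indices > n_outputs` condition described above:

```python
import re
from pathlib import Path
from typing import NamedTuple


class ModelParams(NamedTuple):
    n_indices: int
    dim_size: int
    n_outputs: int
    seed: int


# Matches "{n_indices}_{dim_size}_{n_outputs}_{seed}.onnx".
_NAME_RE = re.compile(r"(\d+)_(\d+)_(\d+)_(\d+)\.onnx")


def parse_model_name(path: str) -> ModelParams:
    """Decode onnx/{n_indices}_{dim_size}_{n_outputs}_{seed}.onnx into its parameters."""
    match = _NAME_RE.fullmatch(Path(path).name)
    if match is None:
        raise ValueError(f"unexpected model name: {path}")
    return ModelParams(*(int(g) for g in match.groups()))


params = parse_model_name("onnx/1000_3_100_0.onnx")
print(params, "expected to fail:", params.n_indices > params.n_outputs)
```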

To replicate the issue, run the following commands:

```shell
# Fails: n_indices = 1000 > n_outputs = 100.
trtexec --onnx="onnx/1000_3_100_0.onnx"

# Succeeds: n_indices = n_outputs = 100.
trtexec --onnx="onnx/100_3_100_0.onnx"
```