`ScatterElements` with Reduction (opset 16) Not Fully Supported
anthony-correia opened this issue
Short Description
Converting an ONNX model to TensorRT with `trtexec` fails when the model contains a `ScatterElements` operation with a summing reduction (`reduction="add"` in ONNX terms, opset 16) and the number of indices along the scatter axis exceeds the number of output rows. Conversion only succeeds when `n_indices <= n_outputs`.
Long Description
Consider the following PyTorch model snippet:
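```python
import torch
import torch_scatter

# Sizes used to generate the models, e.g. 1000, 3 and 100 in the failing case below.
n_indices: int = ...
dim_size: int = ...
n_outputs: int = ...

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy inputs: one row of `dim_size` features per index (e.g. per edge)
# and, for each row, the output row (e.g. node) it is summed into.
e_dummy = torch.randn(size=(n_indices, dim_size), device=device)
index_dummy = torch.randint(high=n_outputs, size=(n_indices,), device=device)


class ScatterModule(torch.nn.Module):
    def forward(self, e: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # Sum the rows of `e` that share the same index.
        return torch_scatter.scatter(
            src=e,
            # Broadcast the index over the feature dimension
            # (torch_scatter should do this automatically anyway).
            index=index.unsqueeze(-1).expand(-1, e.shape[1]),
            dim=0,
            reduce="sum",
        )
```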
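For reference, the computation performed by the module can be sketched in plain Python, without PyTorch or ONNX. This is a minimal sketch, assuming a sum along `dim=0` with the index broadcast over the feature dimension, i.e. `out[index[i]][j] += e[i][j]`, and assuming the number of output rows is passed explicitly (the snippet above does not pass `dim_size`, so `torch_scatter` infers it from the largest index). `scatter_sum_reference` is a hypothetical helper, not part of `torch_scatter`. Nothing in the computation itself requires `n_indices <= n_outputs`.

```python
from typing import List


def scatter_sum_reference(
    e: List[List[float]], index: List[int], n_outputs: int
) -> List[List[float]]:
    """Sum the rows of `e` into `n_outputs` rows: out[index[i]][j] += e[i][j].

    Hypothetical reference helper; every index must lie in [0, n_outputs).
    """
    if len(e) != len(index):
        raise ValueError("e and index must have the same number of rows")
    dim_size = len(e[0]) if e else 0
    out = [[0.0] * dim_size for _ in range(n_outputs)]
    for row, target in zip(e, index):
        if not 0 <= target < n_outputs:
            raise IndexError(f"index {target} out of range for {n_outputs} outputs")
        for j, value in enumerate(row):
            out[target][j] += value
    return out


# n_indices = 5 rows (e.g. edges) summed into n_outputs = 2 rows (e.g. nodes).
e = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [1.0, 1.0, 1.0], [0.5, 0.5, 0.5]]
index = [0, 1, 0, 1, 1]
print(scatter_sum_reference(e, index, n_outputs=2))
# [[8.0, 10.0, 12.0], [5.5, 6.5, 7.5]]
```

Here five index entries map onto two output rows, which is exactly the `n_indices > n_outputs` situation that fails to convert.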
Converting the corresponding ONNX model with `trtexec` triggers an assertion error:

```
Assertion failed: indicesDims.d[i] <= dataDims.d[i] && "Indices dimensions must be less than data dimensions!"
```
This error likely originates from this line of the ONNX-TensorRT code.
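To make the failing condition concrete, the sketch below contrasts the assertion, which compares every dimension of `indices` against `data`, with the looser constraint PyTorch documents for `Tensor.scatter_` (`index.size(d) <= self.size(d)` only for `d != dim`). Both function names are hypothetical, and the first one only mimics the assertion message; it is not the actual ONNX-TensorRT code. It assumes that `data` has shape `(n_outputs, dim_size)` and `indices` has shape `(n_indices, dim_size)`, with both tensors of the same rank as ONNX requires.

```python
from typing import Sequence


def passes_trt_parser_check(indices_shape: Sequence[int], data_shape: Sequence[int]) -> bool:
    """Mimics the failing assertion: every indices dim must be <= the matching data dim."""
    return all(i <= d for i, d in zip(indices_shape, data_shape))


def passes_torch_scatter_check(
    indices_shape: Sequence[int], data_shape: Sequence[int], axis: int
) -> bool:
    """PyTorch's documented scatter_ bound: only axes other than `axis` are constrained."""
    return all(
        i <= d for k, (i, d) in enumerate(zip(indices_shape, data_shape)) if k != axis
    )


# Shapes from the failing example: n_indices=1000, dim_size=3, n_outputs=100.
indices_shape = (1000, 3)
data_shape = (100, 3)
print(passes_trt_parser_check(indices_shape, data_shape))        # False
print(passes_torch_scatter_check(indices_shape, data_shape, 0))  # True
```

The only difference is whether the scatter axis itself is bounded; for a scatter-sum that axis is precisely where `n_indices` may exceed `n_outputs`.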
In the Graph Neural Network scenarios I've encountered, the number of indices (`n_indices`, the number of edges in the graph) is significantly larger than the number of outputs (`n_outputs`, the number of nodes in the graph).
Environment
TensorRT Version: 8.6.1.6-1+cuda11.8
GPU Type: NVIDIA RTX A2000 (laptop)
Nvidia Driver Version: 520.61.05
CUDA Version: 11.8.0-1
CUDNN Version: 8.7.0.84-1+cuda11.8
Operating System + Version: Ubuntu 22.04.1 LTS
Relevant Files
I've created a repository to reproduce the issue: anthony-correia/scatter_onnx2tensorrt.
The ONNX models are stored following the naming convention `onnx/{n_indices}_{dim_size}_{n_outputs}_{seed}.onnx`.
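Given this naming convention, a small helper can recover the parameters of a model file and predict, from the observed condition `n_indices <= n_outputs`, whether conversion is expected to succeed. `ModelParams` and `parse_model_name` are hypothetical names for illustration only; they are not part of the repository.

```python
import re
from dataclasses import dataclass
from pathlib import PurePosixPath


@dataclass(frozen=True)
class ModelParams:
    n_indices: int
    dim_size: int
    n_outputs: int
    seed: int

    @property
    def expected_to_convert(self) -> bool:
        # Behaviour reported in this issue: conversion needs n_indices <= n_outputs.
        return self.n_indices <= self.n_outputs


# Matches "{n_indices}_{dim_size}_{n_outputs}_{seed}.onnx".
_NAME_RE = re.compile(r"(\d+)_(\d+)_(\d+)_(\d+)\.onnx")


def parse_model_name(path: str) -> ModelParams:
    """Parse a model path such as 'onnx/1000_3_100_0.onnx' into its parameters."""
    name = PurePosixPath(path).name
    match = _NAME_RE.fullmatch(name)
    if match is None:
        raise ValueError(f"unexpected model file name: {name!r}")
    return ModelParams(*(int(group) for group in match.groups()))


for path in ["onnx/1000_3_100_0.onnx", "onnx/100_3_100_0.onnx"]:
    params = parse_model_name(path)
    print(path, params.expected_to_convert)
```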
To replicate the issue, execute the following commands:
```bash
# This command fails when `n_outputs = 100` and `n_indices = 1000`.
trtexec --onnx="onnx/1000_3_100_0.onnx"
# This command succeeds when `n_outputs` equals `n_indices` (both are 100).
trtexec --onnx="onnx/100_3_100_0.onnx"
```