google-deepmind/tapnet

OnnxExporterError: Unsupported: ONNX export of operator GridSample with 5D volumetric input.

Opened this issue · 10 comments

Hi everyone,

Thanks for the awesome work. I've been trying to export the PyTorch model to ONNX for inference with torch.onnx.export, but it yields this error: OnnxExporterError: Unsupported: ONNX export of operator GridSample with 5D volumetric input.

Unfortunately, it seems 5D grid_sample is still unsupported by ONNX / torch. Is there any alternative available? Or any advice on making the model work with ONNX?

Thanks

@Cyril9227, torch.onnx.export() fails for me too. The cause seems to be described in pytorch/pytorch#100790, which will be addressed through pytorch/pytorch#114801 (ONNX opset 20 support).

In the meantime I have been trying to convert to ONNX through Haiku (JAX) -> TensorFlow -> ONNX, using https://dm-haiku.readthedocs.io/en/latest/notebooks/jax2tf.html as a tutorial for Haiku -> TF:

import functools
import haiku as hk
import jax
from jax.experimental import jax2tf
import jax.numpy as jnp
import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
from tqdm import tqdm
import tree
from tapnet import tapir_model
from tapnet.utils import transforms
from tapnet.utils import viz_utils
import tensorflow as tf
import sonnet as snt

checkpoint_path = 'tapnet/checkpoints/causal_tapir_checkpoint.npy'
ckpt_state = np.load(checkpoint_path, allow_pickle=True).item()
params, state = ckpt_state['params'], ckpt_state['state']
params_vars = tf.nest.map_structure(tf.Variable, params)

def build_online_model_init(frames, query_points):
  """Initialize query features for the query points."""
  model = tapir_model.TAPIR(use_causal_conv=True, bilinear_interp_with_depthwise_conv=False) 

  feature_grids = model.get_feature_grids(frames, is_training=False)
  query_features = model.get_query_features(
      frames,
      is_training=False,
      query_points=query_points,
      feature_grids=feature_grids,
  )
  return query_features

init_tf = hk.transform(build_online_model_init) 

class JaxModule(snt.Module):
  def __init__(self, params, apply_fn, name=None):
    super().__init__(name=name)
    self._params = params   
    self._apply = jax2tf.convert(lambda p, x: apply_fn(p, None, x), enable_xla=False)
    self._apply = tf.autograph.experimental.do_not_convert(self._apply)

  def __call__(self, inputs):
    return self._apply(self._params, inputs)

net = JaxModule(params_vars,  init_tf.apply)

# frames: [num_frames, height, width, 3], query_points: [num_points, 3] where 3 for the tuple (t, y, x)
@tf.function(autograph=False, input_signature=[{"frames" : tf.TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32), 
                                                "query_points": tf.TensorSpec(shape=(20,3), dtype=tf.float32)}]) 
def forward(x):
  return net(x)

to_save = tf.Module()
to_save.forward = forward
to_save.params = list(net.variables)
tf.saved_model.save(to_save, "TapirInit")  

but it fails with TypeError: build_online_model_init() missing 1 required positional argument: 'query_points'. The same happens with build_online_model_predict(). Maybe the input_signature in tf.function() is incorrect, but I cannot figure out how to fix it.
Have you tried the TF path?
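One possible reading of the TypeError above: the lambda in JaxModule forwards the whole input dict as a single positional argument, so it lands in frames and nothing is left for query_points. The sketch below reproduces that with plain Python stand-ins. model_fn and apply_fn are hypothetical placeholders that only mimic the apply(params, rng, *args) calling convention of hk.transform; they are not the real TAPIR or Haiku code, and whether unpacking the dict is enough for the full jax2tf conversion to succeed is untested here.

```python
def model_fn(frames, query_points):
    """Stand-in for build_online_model_init: needs two positional arguments."""
    return {"num_frames": len(frames), "num_points": len(query_points)}


def apply_fn(params, rng, *args, **kwargs):
    """Stand-in with the same calling convention as hk.transform(...).apply."""
    return model_fn(*args, **kwargs)


inputs = {"frames": [0] * 32, "query_points": [(0, 1, 2)] * 20}

# As in the snippet above: the whole dict is passed as one positional argument.
broken = lambda p, x: apply_fn(p, None, x)
try:
    broken({}, inputs)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Possible fix: unpack the dict so each entry reaches its own parameter.
fixed = lambda p, x: apply_fn(p, None, x["frames"], x["query_points"])
result = fixed({}, inputs)
```

In the original snippet, the equivalent change would be inside the lambda passed to jax2tf.convert, leaving the input_signature as it is.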

Since tf2onnx only supports ONNX opset up to 18, the TF SavedModel to ONNX conversion is likely to have the same problem as with PyTorch :(

@Cyril9227 I have posted a solution here https://github.com/pytorch/pytorch/issues/100790. See if that works for you

@saikiran321, the solution you posted does not produce the unsupported ONNX error related to opset 20 support.
Instead, torch.onnx.export fails with ValueError: only one element tensors can be converted to Python scalars.
A Dockerfile and Python code to reproduce the result are in the attached zip file, torch2onnx.zip.
Do you know what could be causing this error? Thank you.

I'm no expert on ONNX, but if the problem is a 5D gather operation, then I suspect the source of the problem is extracting query features. It's possible to rewrite the vmap using a 4D gather; it wastes computation, but the cost is probably small compared to the rest of the model. Try setting parallelize_query_extraction to True when constructing the TAPIR model; it should produce exactly the same result given the same checkpoint, but hopefully it will avoid the 5D gather.

As a bit of an explanation, when extracting the query feature, we get a [t,y,x] coordinate and use bilinear interpolation to extract a feature from that location. The parallelize_query_extraction version instead extracts the feature at [y,x] from every frame (using a vmapped 4D gather), and then multiplies the resulting tensor by a 1-hot t vector to discard every query feature except the one on frame t.

Of course, this is only implemented in the JAX version; you'd have to re-implement the same algorithm in the torch model to export from torch.
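To make the idea concrete, here is a minimal NumPy sketch of the one-hot trick described above. It is not the tapnet implementation: the function names, the [T, H, W, C] layout, the border clamping, and the rounding of t to the nearest frame are all assumptions made for illustration.

```python
import numpy as np


def bilinear_sample_2d(grid, y, x):
    """Bilinearly interpolate grid[H, W, C] at a fractional (y, x) location."""
    h, w, _ = grid.shape
    y = float(np.clip(y, 0, h - 1))
    x = float(np.clip(x, 0, w - 1))
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    wy, wx = y - y0, x - x0
    top = (1 - wx) * grid[y0, x0] + wx * grid[y0, x1]
    bottom = (1 - wx) * grid[y1, x0] + wx * grid[y1, x1]
    return (1 - wy) * top + wy * bottom


def query_features_one_hot(feature_grids, query_points):
    """feature_grids: [T, H, W, C]; query_points: [N, 3] as (t, y, x)."""
    num_frames = feature_grids.shape[0]
    features = []
    for t, y, x in query_points:
        # Sample (y, x) on every frame, so only 2D interpolation is needed.
        per_frame = np.stack(
            [bilinear_sample_2d(grid, y, x) for grid in feature_grids]
        )  # [T, C]
        # A one-hot vector over frames discards everything except frame t.
        one_hot_t = np.eye(num_frames)[int(round(float(t)))]  # [T]
        features.append(one_hot_t @ per_frame)  # [C]
    return np.stack(features)  # [N, C]


feature_grids = np.arange(3 * 4 * 4 * 2, dtype=np.float64).reshape(3, 4, 4, 2)
query_points = np.array([[1.0, 1.5, 2.0]])
feats = query_features_one_hot(feature_grids, query_points)
```

The extra work is the interpolation on the T - 1 frames that the one-hot vector then zeroes out, which is the wasted computation mentioned above. A torch port would follow the same shape, with a 4D grid_sample in place of the Python loop.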

Hi! I exported an opset 16 ONNX model, used onnx_graphsurgeon to change the opset to 20 directly, and then ran trtexec --onnx xx --engine. I hit the same problem: Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/network.cpp::addGridSample::1474, condition: input.getDimensions().nbDims == 4) @saikiran321 @SergeySandler @Cyril9227 @yotam

Hi! I get the same error. Did you manage to solve it?

I modified the torch model for the case of t=1 and reduced all the 5D to 4D, among other changes: https://github.com/ibaiGorordo/Tapir-Pytorch-Inference

I also added a script to export the model, but it is very slow when running in onnxruntime compared to PyTorch (RTX 4080): ~700 ms without refinement and ~20 s with 4 iterations (1000 points, 640x640).

@ibaiGorordo,

it is very slow when running in onnxruntime compared to Pytorch (RTX4080)

Do you have the code for inference with ONNX? Do you use CUDA Execution Provider or CPU Execution Provider with ONNX?

I added the inference time calculation to the onnx_export.py script.

CPU is faster (image: tapir_onnx_cpu) than CUDA (image: tapir_onnx_cuda).

The slow part seems to be the convolutions in the PIPs mixer block.

@ibaiGorordo, I have reproduced tapir.onnx and it is three times slower than PyTorch on a CUDA device.
My results on Windows: PyTorch inference takes around 0.1 sec on CUDA and 3 sec on CPU; ONNX takes 0.3 sec with DmlExecutionProvider and 3 sec with CPUExecutionProvider.

There are a couple of hints for Windows that might be useful, especially if your ONNX results on the GPU are worse than on the CPU:

  1. Do not forget to pass device_id (your card's ID, which is 0 in my case) in provider_options:
    predictor = onnxruntime.InferenceSession(f'{output_dir}/tapir.onnx', providers=['DmlExecutionProvider'], provider_options=[{'device_id': 0}])
    Otherwise it might use the integrated Intel graphics card instead of the NVIDIA card.
  2. DmlExecutionProvider is not available on Windows unless you pip install onnxruntime-directml.
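Both hints can be folded into a small helper that builds the providers and provider_options lists from whatever the installed onnxruntime build reports. This is only a sketch: pick_providers is a hypothetical helper, not part of onnxruntime, and the preference order (DirectML, then CUDA, then CPU) is an assumption.

```python
def pick_providers(available, device_id=0):
    """Return (providers, provider_options), preferring GPU providers over CPU."""
    preferred = ["DmlExecutionProvider", "CUDAExecutionProvider"]
    providers, options = [], []
    for name in preferred:
        if name in available:
            providers.append(name)
            # Pin the GPU explicitly so an integrated card is not picked silently.
            options.append({"device_id": device_id})
    # Always keep the CPU provider as the last fallback.
    providers.append("CPUExecutionProvider")
    options.append({})
    return providers, options


# Usage with onnxruntime installed (not executed here):
# import onnxruntime
# providers, options = pick_providers(onnxruntime.get_available_providers())
# predictor = onnxruntime.InferenceSession(
#     "tapir.onnx", providers=providers, provider_options=options)
```

If DmlExecutionProvider is missing from onnxruntime.get_available_providers(), the helper falls back to CUDA or CPU, which is a quick way to notice that onnxruntime-directml is not installed.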