YvanYin/Metric3D

Improved ONNX support with dynamic shapes

xenova opened this issue · 9 comments

Hi there! 👋 Following the conversation in #103, I wanted to export the models so they (1) support dynamic shapes and (2) return the normal information, mainly to run the models with Transformers.js. I got them working, and I've uploaded them to the Hugging Face Hub:

(You can find the .onnx weights, both fp32 and fp16, in the onnx subfolder.)

Feel free to use them yourself or add the links to the README for increased visibility! 🤗 PS: I'd also recommend uploading your original PyTorch checkpoints to separate repos (instead of a single repo). Let me know if I can help with any of this!

Regarding the export, there were a few things to consider, mainly fixing the modeling code to avoid Python type casts (ensuring that dynamic shapes work during tracing). I also made a few modifications to support CPU exports. Here's my conversion code:

import torch
import math
import torch.nn as nn

class NullContext:
  def __init__(self, *args, **kwargs):
    pass

  def __enter__(self):
    pass

  def __exit__(self, exc_type, exc_value, traceback):
    pass

# Do not autocast to bf16 or cuda
torch.autocast = NullContext

class Metric3DExportModel(nn.Module):
    """
    The model for exporting to ONNX format. Add custom preprocessing and postprocessing here.
    """

    def __init__(self, meta_arch):
        super().__init__()
        self.meta_arch = meta_arch
        self.register_buffer(
            "rgb_mean", torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "rgb_std", torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
        )

    def normalize_image(self, image):
        image = image - self.rgb_mean
        image = image / self.rgb_std
        return image

    def forward(self, image):
        image = self.normalize_image(image)
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.meta_arch.inference(
                {"input": image}
            )

        pred_depth = pred_depth.squeeze(1)
        pred_normal = output_dict['prediction_normal'][:, :3, :, :] # only available for Metric3Dv2 i.e., ViT models
        normal_confidence = output_dict['prediction_normal'][:, 3, :, :] # see https://arxiv.org/abs/2109.09881 for details

        return pred_depth, pred_normal, normal_confidence


def patch_model(model):

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # Comment out this code (so we always interpolate)
        # if npatch == N and w == h:
        #     return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        if torch.jit.is_tracing():
          sqrt_N = N ** 0.5
          patch_pos_embed = nn.functional.interpolate(
              patch_pos_embed.reshape(1, (sqrt_N).to(torch.int64), (sqrt_N).to(torch.int64), dim).permute(0, 3, 1, 2),
              size=(w0, h0),
              mode="bicubic",
              antialias=self.interpolate_antialias,
          )
        else:
          sqrt_N = math.sqrt(N)
          sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
          patch_pos_embed = nn.functional.interpolate(
              patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
              scale_factor=(sx, sy),
              mode="bicubic",
              antialias=self.interpolate_antialias,
          )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    model.depth_model.encoder.interpolate_pos_encoding = (
        interpolate_pos_encoding.__get__(
            model.depth_model.encoder, model.depth_model.encoder.__class__
        )
    )

    def get_bins(self, bins_num):
        depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num)
        depth_bins_vec = torch.exp(depth_bins_vec)
        return depth_bins_vec

    model.depth_model.decoder.get_bins = (
        get_bins.__get__(
            model.depth_model.decoder, model.depth_model.decoder.__class__
        )
    )

    return model

# Load model
model_name = "metric3d_vit_small" # or "metric3d_vit_large" or "metric3d_vit_giant2"
model = torch.hub.load("yvanyin/metric3d", model_name, pretrain=True)
model.eval()

# Patch model so we can export to ONNX
model = patch_model(model)
export_model = Metric3DExportModel(model)
export_model.eval()

# Export the model
dummy_image = torch.randn([2, 3, 280, 420])
onnx_output = f"{model_name}.onnx"
torch.onnx.export(
    export_model,
    (dummy_image, ),
    onnx_output,
    input_names=["pixel_values"],
    output_names=["predicted_depth", "predicted_normal", "normal_confidence"],
    opset_version=11,
    dynamic_axes={
      "pixel_values": {0: "batch_size", 2: "height", 3: "width"},
      "predicted_depth": {0: "batch_size", 1: "height", 2: "width"},
      "predicted_normal": {0: "batch_size", 2: "height", 3: "width"},
      "normal_confidence": {0: "batch_size", 1: "height", 2: "width"},
    }
)
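If you also want an fp16 variant, one option is to downcast the exported fp32 graph with onnxconverter-common (a minimal sketch, assuming the package is installed; keep_io_types keeps the graph inputs/outputs in fp32 so the calling code doesn't change):

import onnx
from onnxconverter_common import float16

# Load the exported fp32 graph and downcast the float tensors to fp16
fp32_model = onnx.load(onnx_output)
fp16_model = float16.convert_float_to_float16(fp32_model, keep_io_types=True)
onnx.save(fp16_model, f"{model_name}_fp16.onnx")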

There are minor differences in output, but these can be attributed to (1) implementation differences between ONNX Runtime and PyTorch, and (2) default dtypes.
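To quantify the difference, you can run the exported graph and the patched PyTorch model on the same dummy input and compare the outputs (a quick sketch, assuming onnxruntime is installed and export_model, dummy_image, and onnx_output from above are still in scope):

import numpy as np
import onnxruntime as ort

# Run the exported graph and the PyTorch model on the same input
session = ort.InferenceSession(onnx_output, providers=["CPUExecutionProvider"])
ort_depth, _, _ = session.run(None, {"pixel_values": dummy_image.numpy()})
with torch.no_grad():
    torch_depth, _, _ = export_model(dummy_image)

print("max abs depth diff:", np.abs(ort_depth - torch_depth.numpy()).max())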

Diff between normalized images:

[Images: PyTorch output, ONNX output, and their diff]

Example usage in Python:

  1. Download the model:
wget https://huggingface.co/onnx-community/metric3d-vit-small/resolve/main/onnx/model.onnx
  2. Run the model:
import onnxruntime as ort
import requests
import numpy as np
from PIL import Image

# Load session
ort_session = ort.InferenceSession("./model.onnx", providers=['CPUExecutionProvider'])

# Load image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Predict depth
inputs = np.array(image).transpose(2, 0, 1) # HWC -> CHW
inputs = np.expand_dims(inputs, 0) # Add batch dim
onnxruntime_input = {'pixel_values': inputs.astype(np.float32)}
pred_depth, pred_normal, normal_confidence = ort_session.run(None, onnxruntime_input)
  3. Visualize the results:
min_val = pred_depth.min()
max_val = pred_depth.max()
normalized = 255 * ((pred_depth - min_val)/(max_val-min_val))
Image.fromarray(normalized[0].astype(np.uint8)).save('depth.png')
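The surface normals can be visualized in a similar way; one simple option (assuming the usual mapping of normal components from [-1, 1] to [0, 255]) is:

# pred_normal has shape (batch, 3, H, W); map components to [0, 255] and save as RGB
normal_vis = (pred_normal[0].transpose(1, 2, 0) + 1) / 2 * 255
Image.fromarray(np.clip(normal_vis, 0, 255).astype(np.uint8)).save('normal.png')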

[Image: resulting depth map]

Hi @xenova, thanks for your support. Would you mind joining this project and adding your work to our README?

@YvanYin you're welcome! :) Do you mean submitting a PR? If so, then sure!

@xenova I invited you.

@xenova Great work!
I don't know if I'm missing something, but the onnx-community/metric3d-vit-giant2 model seems to be incomplete judging by its size (model.onnx is 1.64 MB).
If I try to load model_fp16, it throws an error:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from weight/giant/model_fp16.onnx failed:Type Error: Type parameter (T) of Optype (Add) bound to different types (tensor(int64) and tensor(float16) in node (/depth_model/decoder/Add_6).

Using the large model it works great!


Hi! Great work on both the models and the export. The smaller ones work great.
I'm hitting the same issue with the giant2 model for both versions (fp16 and fp32). Any update on this?
Tagging in case these comments are overlooked in a closed issue: @xenova @YvanYin

Same problem here. Can anyone share a working inference script for the ONNX giant2 model? I'd really appreciate it.

Hello,

I'm using the code that you provided (here) with my own image, and I see that the output size is not the same as the input size. Is this OK?

Input shape:  (3, 960, 1280)
Pred depth shape:  (1, 952, 1272)

Thank you in advance!