Improved ONNX support with dynamic shapes
xenova opened this issue · 9 comments
Hi there! 👋 Following the conversation in #103, I wanted to export the models so they (1) support dynamic shapes and (2) return the surface normal predictions, mainly to run the models with Transformers.js. I got them working, and I've uploaded them to the Hugging Face Hub:
- https://huggingface.co/onnx-community/metric3d-vit-small
- https://huggingface.co/onnx-community/metric3d-vit-large
- https://huggingface.co/onnx-community/metric3d-vit-giant2
(you can find the .onnx weights, both fp32 and fp16, in the onnx subfolder)
Feel free to use them yourself or add the links to the README for increased visibility! 🤗 PS: I'd also recommend uploading your original PyTorch checkpoints to separate repos (instead of a single repo). Let me know if I can help with any of this!
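For reference, the weights can also be fetched programmatically. A minimal sketch using huggingface_hub (assuming the repo layout described above, with the ONNX files in the onnx subfolder):
from huggingface_hub import hf_hub_download
# Download the fp32 ONNX weights for the small model; the fp16 variant sits in the same folder
model_path = hf_hub_download(
    repo_id="onnx-community/metric3d-vit-small",
    filename="onnx/model.onnx",
)
print(model_path)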
Regarding the export, there were a few things to consider, mainly fixing the modelling code to avoid Python type casts (so that dynamic shapes survive tracing). I also made a few modifications to support CPU exports. Here's my conversion code:
import torch
import math
import torch.nn as nn
class NullContext:
def __init__(self, *args, **kwargs):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
pass
# Do not autocast to bf16 or cuda
torch.autocast = NullContext
class Metric3DExportModel(nn.Module):
"""
The model for exporting to ONNX format. Add custom preprocessing and postprocessing here.
"""
def __init__(self, meta_arch):
super().__init__()
self.meta_arch = meta_arch
self.register_buffer(
"rgb_mean", torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
)
self.register_buffer(
"rgb_std", torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
)
def normalize_image(self, image):
image = image - self.rgb_mean
image = image / self.rgb_std
return image
def forward(self, image):
image = self.normalize_image(image)
with torch.no_grad():
pred_depth, confidence, output_dict = self.meta_arch.inference(
{"input": image}
)
pred_depth = pred_depth.squeeze(1)
pred_normal = output_dict['prediction_normal'][:, :3, :, :] # only available for Metric3Dv2 i.e., ViT models
normal_confidence = output_dict['prediction_normal'][:, 3, :, :] # see https://arxiv.org/abs/2109.09881 for details
return pred_depth, pred_normal, normal_confidence
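# Monkey-patch the encoder's positional-embedding interpolation and the decoder's
# depth-bin computation so the traced graph keeps dynamic input shapes (see the notes above).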
def patch_model(model):
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
# Comment out this code (so we always interpolate)
# if npatch == N and w == h:
# return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
if torch.jit.is_tracing():
sqrt_N = N ** 0.5
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, (sqrt_N).to(torch.int64), (sqrt_N).to(torch.int64), dim).permute(0, 3, 1, 2),
size=(w0, h0),
mode="bicubic",
antialias=self.interpolate_antialias,
)
else:
sqrt_N = math.sqrt(N)
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
scale_factor=(sx, sy),
mode="bicubic",
antialias=self.interpolate_antialias,
)
assert int(w0) == patch_pos_embed.shape[-2]
assert int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
model.depth_model.encoder.interpolate_pos_encoding = (
interpolate_pos_encoding.__get__(
model.depth_model.encoder, model.depth_model.encoder.__class__
)
)
def get_bins(self, bins_num):
depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num)
depth_bins_vec = torch.exp(depth_bins_vec)
return depth_bins_vec
model.depth_model.decoder.get_bins = (
get_bins.__get__(
model.depth_model.decoder, model.depth_model.decoder.__class__
)
)
return model
# Load model
model_name = "metric3d_vit_small" # or "metric3d_vit_large" or "metric3d_vit_giant2"
model = torch.hub.load("yvanyin/metric3d", model_name, pretrain=True)
model.eval()
# Patch model so we can export to ONNX
model = patch_model(model)
export_model = Metric3DExportModel(model)
export_model.eval()
# Export the model
dummy_image = torch.randn([2, 3, 280, 420])
onnx_output = f"{model_name}.onnx"
torch.onnx.export(
export_model,
(dummy_image, ),
onnx_output,
input_names=["pixel_values"],
output_names=["predicted_depth", "predicted_normal", "normal_confidence"],
opset_version=11,
dynamic_axes= {
"pixel_values": {0: "batch_size", 2: "height", 3: "width"},
"predicted_depth": {0: "batch_size", 1: "height", 2: "width"},
"predicted_normal": {0: "batch_size", 2: "height", 3: "width"},
"normal_confidence": {0: "batch_size", 1: "height", 2: "width"},
}
)
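After the export, it can be worth sanity-checking the graph and comparing the ONNX Runtime outputs against the PyTorch model on the same dummy input. A minimal sketch, assuming onnx and onnxruntime are installed (the tolerances are arbitrary):
import onnx
import onnxruntime as ort
import numpy as np
# Validate the exported graph
onnx.checker.check_model(onnx.load(onnx_output))
# Reference outputs from the (patched) PyTorch model
with torch.no_grad():
    torch_depth, torch_normal, torch_conf = export_model(dummy_image)
# ONNX Runtime outputs on the same input
session = ort.InferenceSession(onnx_output, providers=["CPUExecutionProvider"])
ort_depth, ort_normal, ort_conf = session.run(None, {"pixel_values": dummy_image.numpy()})
# Compare the depth predictions within a loose tolerance
np.testing.assert_allclose(torch_depth.numpy(), ort_depth, rtol=1e-2, atol=1e-3)
print("PyTorch and ONNX Runtime depth outputs match within tolerance.")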
Example usage in Python:
- Download the model:
wget https://huggingface.co/onnx-community/metric3d-vit-small/resolve/main/onnx/model.onnx
- Run the model:
import onnxruntime as ort
import requests
import numpy as np
from PIL import Image
# Load session
ort_session = ort.InferenceSession("./model.onnx", providers=['CPUExecutionProvider'])
# Load image
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
# Predict depth
pixel_values = np.array(image).transpose(2, 0, 1)  # HWC -> CHW
pixel_values = np.expand_dims(pixel_values, 0)  # Add batch dim
onnxruntime_input = {'pixel_values': pixel_values.astype(np.float32)}
pred_depth, pred_normal, normal_confidence = ort_session.run(None, onnxruntime_input)
- Visualize the results:
min_val = pred_depth.min()
max_val = pred_depth.max()
normalized = 255 * ((pred_depth - min_val)/(max_val-min_val))
Image.fromarray(normalized[0].astype(np.uint8)).save('depth.png')
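The predicted normals and their confidence can be visualized in a similar way. A minimal sketch, assuming the normal vectors are roughly unit length so each component lies in [-1, 1]:
# Map normal components from [-1, 1] to [0, 255] for an RGB preview
normal_rgb = ((pred_normal[0].transpose(1, 2, 0) + 1) / 2 * 255).clip(0, 255)
Image.fromarray(normal_rgb.astype(np.uint8)).save('normal.png')
# Normalize the confidence map to [0, 255] for visualization
conf = normal_confidence[0]
conf = 255 * (conf - conf.min()) / (conf.max() - conf.min())
Image.fromarray(conf.astype(np.uint8)).save('normal_confidence.png')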
Hi @xenova, thanks for your support. Would you mind joining this project and adding your work to our README?
@xenova Great work!
I don't know if I'm missing something, but the onnx-community/metric3d-vit-giant2 model seems to be incomplete judging by its size (model.onnx is 1.64 MB).
If I try to load model_fp16, it throws an error:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from weight/giant/model_fp16.onnx failed:Type Error: Type parameter (T) of Optype (Add) bound to different types (tensor(int64) and tensor(float16) in node (/depth_model/decoder/Add_6).
Using the large model it works great!
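The offending node can be inspected by loading the graph with the onnx Python package. A minimal sketch (the file path and node name are taken from the error message above):
import onnx
model = onnx.load("weight/giant/model_fp16.onnx")
for node in model.graph.node:
    if node.name == "/depth_model/decoder/Add_6":
        print(node.op_type, list(node.input), list(node.output))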
Hi! Great work on both the models and the export. The smaller ones work great.
I'm hitting the same issue with the giant2 model, for both versions (fp16 and fp32). Any update on this?
Tagging @xenova @YvanYin in case these comments are overlooked in a closed issue.
Same problem here. Can anyone share a working inference script for the ONNX giant2 model? I'd really appreciate it.
Hello,
I'm using the code that you provided (here) with my own image. I see that the output size is not the same as the input size. Is this ok?
Input shape: (3, 960, 1280)
Pred depth shape: (1, 952, 1272)
Thank you in advance!