Unable to export decoder in onnx format for GPU context
Greg7000 opened this issue · 0 comments
I am currently trying to execute notebook block 10 of https://github.com/AndreyGermanov/sam_onnx_full_export/blob/main/sam_onnx_export.ipynb, which is:
```python
# Export mask decoder from SAM model to ONNX
onnx_model = SamOnnxModel(sam, return_single_mask=True)
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
torch.onnx.export(
    f="vit_b_decoder.onnx",
    model=onnx_model,
    args=tuple(dummy_inputs.values()),
    input_names=list(dummy_inputs.keys()),
    output_names=output_names,
    dynamic_axes={
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    },
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
)
```
This works perfectly fine in a CPU context. However, when I try to do it in a GPU context using:
```python
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

# Load SAM model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
# sam = sam.cuda()
sam.to(device="cuda")

# Export mask decoder from SAM model to ONNX
onnx_model = SamOnnxModel(sam, return_single_mask=True)
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float).cuda(),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float).cuda(),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float).cuda(),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float).cuda(),
    "has_mask_input": torch.tensor([1], dtype=torch.float).cuda(),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float).cuda(),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
torch.onnx.export(
    f="bob/vit_h_decoder.onnx",
    model=onnx_model,
    args=tuple(dummy_inputs.values()),
    input_names=list(dummy_inputs.keys()),
    output_names=output_names,
    dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
)
```
I get:
```
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
```

raised from `segment_anything.modeling.mask_decoder.MaskDecoder.predict_masks`, line 126, i.e. from `torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)`.
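As far as I can tell, the failure is reproducible outside of SAM: when a tensor is passed as the `repeats` argument, `repeat_interleave` builds an index tensor on the device of `repeats` and then calls `index_select` on the input, and under tracing a `.shape[0]` access is captured as a 0-dim CPU tensor. A minimal sketch of what I believe is the same mismatch (my own repro, not from the notebook):

```python
import torch

x = torch.randn(1, 256, device="cuda")  # stands in for image_embeddings on CUDA
n = torch.tensor(5)                     # 0-dim CPU tensor, like a traced tokens.shape[0]

# The index tensor derived from `n` lives on the CPU, while `x` is on CUDA,
# so the internal index_select raises the same device-mismatch RuntimeError.
torch.repeat_interleave(x, n, dim=0)
```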
I tried many different things, but the only way I managed to get it to work is by modifying mask_decoder.py a bit, which is undesirable.
Does anybody have a suggestion that avoids any modifications to mask_decoder.py?
I have a CPU remnant somewhere; maybe I need to convert embed_size and mask_input_size to torch.Size(), but that does not seem to be enough.
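For reference, one alternative I am considering (untested end-to-end, so take it as a sketch): keep the CUDA model for inference, but run the export against a CPU copy of the model with CPU dummy inputs, since the export only traces the graph once and does not need the GPU.

```python
# Sketch of a possible workaround: export from a CPU copy of the model,
# so no tensor inside the traced graph ever sits on a different device.
import copy
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")                      # GPU model, used for inference as before

sam_for_export = copy.deepcopy(sam).cpu()  # CPU copy, used only for the export
onnx_model = SamOnnxModel(sam_for_export, return_single_mask=True)

embed_dim = sam_for_export.prompt_encoder.embed_dim
embed_size = sam_for_export.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
torch.onnx.export(
    f="vit_h_decoder.onnx",
    model=onnx_model,
    args=tuple(dummy_inputs.values()),
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
    dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
)
```

The resulting .onnx file is device-agnostic, so onnxruntime should still be able to run it with CUDAExecutionProvider; exporting from a CPU copy shouldn't cost anything at inference time.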