facebookresearch/segment-anything

Unable to export decoder in onnx format for GPU context

Greg7000 opened this issue · 0 comments

I am trying to run block 10 of this notebook: https://github.com/AndreyGermanov/sam_onnx_full_export/blob/main/sam_onnx_export.ipynb

Which is:

# Export mask decoder from SAM model to ONNX
onnx_model = SamOnnxModel(sam, return_single_mask=True)
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
torch.onnx.export(
    f="vit_b_decoder.onnx",
    model=onnx_model,
    args=tuple(dummy_inputs.values()),
    input_names=list(dummy_inputs.keys()),
    output_names=output_names,
    dynamic_axes={
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"}
    },
    export_params=True,
    opset_version=17,
    do_constant_folding=True
)

This works perfectly fine in a CPU context. However, when I try the same export in a GPU context using:

import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

# Load SAM model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")

# Move the model weights to the GPU (equivalent to sam = sam.cuda())
sam.to(device="cuda")


# Export mask decoder from SAM model to ONNX
onnx_model = SamOnnxModel(sam, return_single_mask=True)


embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float).cuda(),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float).cuda(),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float).cuda(),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float).cuda(),
    "has_mask_input": torch.tensor([1], dtype=torch.float).cuda(),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float).cuda(),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
torch.onnx.export(
    f="bob/vit_h_decoder.onnx",
    model=onnx_model,
    args=tuple(dummy_inputs.values()),
    input_names=list(dummy_inputs.keys()),
    output_names=output_names,
    dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
)

I get:

Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

The error is raised in segment_anything.modeling.mask_decoder.MaskDecoder.predict_masks, line 126, by the call torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).

I tried many different things, but the only way I managed to get it to work was by slightly modifying mask_decoder.py, which is undesirable.

Does anybody have a suggestion that would avoid any modification to mask_decoder.py?

I suspect there is a CPU remnant somewhere. Maybe I need to convert embed_size and mask_input_size to torch.Size(), but that does not seem to be enough.
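
One way to narrow down where the CPU tensor comes from is to list the device of every parameter, buffer and dummy input before calling torch.onnx.export. The snippet below is only a rough sketch: devices_in_use is a hypothetical helper (it is not part of segment_anything or PyTorch), and it assumes the model exposes named_parameters() and named_buffers() like any torch.nn.Module and that each tensor has a .device attribute.

```python
from collections import defaultdict


def devices_in_use(module, inputs=None):
    """Group parameter, buffer and input names by the device they live on.

    Hypothetical helper, not part of segment_anything. `module` is expected
    to behave like a torch.nn.Module and `inputs` like the dummy_inputs dict
    from the export script.
    """
    found = defaultdict(list)
    # Learned weights of the wrapped model.
    for name, tensor in module.named_parameters():
        found[str(tensor.device)].append("param:" + name)
    # Non-learned state registered with register_buffer().
    for name, tensor in module.named_buffers():
        found[str(tensor.device)].append("buffer:" + name)
    # Tensors handed to torch.onnx.export through `args`.
    for name, tensor in (inputs or {}).items():
        found[str(tensor.device)].append("input:" + name)
    return dict(found)


# Intended use with the objects from the GPU export script:
# report = devices_in_use(onnx_model, dummy_inputs)
# print({device: len(names) for device, names in report.items()})
```

If every entry reports cuda:0, the stray CPU tensor is probably not stored in the model or the inputs at all, but created during tracing (for example a shape value such as tokens.shape[0]), which would be consistent with the failure inside torch.repeat_interleave.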