Benchmarking Batch Inference
merveenoyan opened this issue · 2 comments
Hello! I'm running some benchmarks on TinySAM and trying to benchmark batch inference, but I hit a wall. All inputs are torch tensors with the shapes expected by the docstrings in the inference code:
point_prompt.shape # torch.Size([4, 1, 2]) BXNX2
input_label.shape # torch.Size([4, 1]) BXN
batched_image.shape # torch.Size([4, 3, 1024, 1024]), BCHW
predictor.set_torch_image(batched_image, original_image_size=batched_image[0, 0, :, :].shape) # goes well
# this fails
with torch.no_grad():
    _, _, _ = predictor.predict_torch(
        point_coords=point_prompt,
        point_labels=input_label)
I don't really have a lot of time to debug this, as I've already tried a couple of things. I feel like I'm missing a step and that's why it errors out; can you let me know if so? I can post a full trace if you want.
Hi merveenoyan,
Thanks for your attempt to benchmark TinySAM. The original SAM interface does not support batch inference on multiple images, and we've followed this design. From the docstrings of set_torch_image and predict_torch, we can see that,
def set_torch_image(
    self,
    transformed_image: torch.Tensor,
    original_image_size: Tuple[int, ...],
) -> None:
    """
    Arguments:
      transformed_image (torch.Tensor): The input image, with shape
        1x3xHxW, which has been transformed with ResizeLongestSide.
    """
in which the image shape is fixed at 1x3xHxW. And in predict_torch,
def predict_torch(
    self,
    point_coords: Optional[torch.Tensor],
    point_labels: Optional[torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    mask_input: Optional[torch.Tensor] = None,
    return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Predict masks for the given input prompts, using the currently set image.
    Input prompts are batched torch tensors and are expected to already be
    transformed to the input frame using ResizeLongestSide.
    Arguments:
      point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
        model. Each point is in (X,Y) in pixels.
      point_labels (torch.Tensor or None): A BxN array of labels for the
        point prompts. 1 indicates a foreground point and 0 indicates a
        background point.
    """
There is a batch dimension for point_coords and point_labels, but it means multiple point prompts on one image. So batch inference in SAM/TinySAM only supports multiple prompts for one image, not multiple images.
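If you still want to run over several images, one option is to loop and set one image at a time. This is only a rough sketch, not something shipped with TinySAM: predict_per_image is a hypothetical helper name, and it assumes the predictor exposes set_torch_image and predict_torch as quoted above, and that the inputs are sliceable along the first dimension (as torch tensors are).

```python
def predict_per_image(predictor, batched_image, point_coords, point_labels,
                      original_size):
    """Run a single-image predictor over a batch, one image at a time.

    Hypothetical helper: `predictor` is assumed to provide set_torch_image
    and predict_torch with the signatures quoted in this thread.
    """
    results = []
    for i in range(len(batched_image)):
        # Slicing with i:i+1 keeps the leading dimension, so the image
        # stays 1x3xHxW as the set_torch_image docstring requires.
        predictor.set_torch_image(
            batched_image[i:i + 1], original_image_size=original_size
        )
        # For image i, pass its prompts as a batch of size 1:
        # coords are 1xNx2 and labels are 1xN.
        results.append(
            predictor.predict_torch(
                point_coords=point_coords[i:i + 1],
                point_labels=point_labels[i:i + 1],
            )
        )
    return results
```

With the shapes from the original post, this would call the predictor four times, each with a 1x3x1024x1024 image and a 1x1x2 point prompt. Wrap the call in torch.no_grad() as in the original snippet if you are only doing inference.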
As for benchmarking, I think it is possible to evaluate SAM and TinySAM under the same batch settings, since they share the same interface. So there is no need to implement batch inference for multiple images.
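For the timing itself, something like the following might be enough. It is a minimal sketch using only the Python standard library; time_call and its warmup and runs parameters are hypothetical names, not part of SAM or TinySAM.

```python
import time


def time_call(fn, warmup=2, runs=10):
    """Return the mean wall-clock seconds per call of fn() over `runs` calls.

    Hypothetical helper. On a GPU, CUDA kernels run asynchronously, so fn
    should call torch.cuda.synchronize() before returning; otherwise the
    measured time can be too short.
    """
    # Warm-up calls are not timed; they absorb one-off setup costs.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs
```

Passing the same callable shape for both models (for example, one set_torch_image call followed by one predict_torch call with identical prompts) keeps the comparison under the same batch settings.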
@Gaffey thanks a lot for the swift response! I'll keep this in mind.