NazirNayal8/RbA

Batch inference?

Hey Nazir, thank you for your exciting research and work on RbA!
I'd like to run the model on multiple images/frames from a directory in batches.
I noticed that here you mention that only a batch_size of 1 is supported.
Is that a limitation related to Mask2Former?
What are the difficulties in adapting the model/code to accept more images in parallel?

Thanks again for your contributions to the community!

Hi @pieris98, thank you for your interest in our work!

In the script that you have linked, we use a batch size of 1 for simplicity, but it can, of course, be extended to any batch size you need. The only thing to be careful about is matching the input format Mask2Former expects: a batch must be passed as a list in which each image is wrapped in its own dictionary. For example, a modified version of the function that runs a batch of size 2 would look like this:

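import torch

# DEVICE is defined elsewhere in the original script; this definition is an assumption added for completeness.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_logits(model, img1, img2, **kwargs):
    # Mask2Former expects a batch as a list with one {"image": tensor} dict per image.
    with torch.no_grad():
        out = model([{"image": img1.to(DEVICE)}, {"image": img2.to(DEVICE)}])

    return out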

Note that img1 and img2 each have shape (3, H, W); that is, they have no batch dimension. You can extend this in the same way to any batch size you need.
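As a minimal sketch of that extension, for example when processing all frames from a directory, the two-image function can be generalized to a list of any length and driven in fixed-size chunks. This assumes `model` is the already-loaded RbA/Mask2Former model and that each frame has already been loaded and preprocessed into a (3, H, W) tensor the same way as in the single-image script. The helper names `chunked`, `get_logits_batch` and `run_in_batches` are hypothetical and not part of the RbA code.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

import torch

T = TypeVar("T")


def chunked(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of `items`, each with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def get_logits_batch(model: Callable, images: Sequence[torch.Tensor], device: torch.device):
    """Hypothetical generalization of get_logits to any number of (3, H, W) images.

    Each image gets its own {"image": tensor} dict, and the dicts are passed
    to the model together as one list, i.e. one forward pass per batch.
    """
    batch = [{"image": img.to(device)} for img in images]
    with torch.no_grad():
        return model(batch)


def run_in_batches(model: Callable, images: Sequence[torch.Tensor],
                   batch_size: int, device: torch.device) -> List:
    """Hypothetical driver: split `images` into batches and run each one.

    Returns one model output per batch, in input order, because the exact
    return type of the model is not specified here.
    """
    return [get_logits_batch(model, chunk, device) for chunk in chunked(images, batch_size)]
```

For example, `run_in_batches(model, frames, batch_size=4, device=DEVICE)` with 10 frames calls the model three times, on batches of 4, 4 and 2 images. If your frames are already stacked into a single (N, 3, H, W) tensor, `list(frames.unbind(0))` gives the per-image (3, H, W) tensors this sketch expects.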

I hope this was helpful. If you have any further questions, please do not hesitate to ask.