What is the shape of the encoder's boxes?
raoxinyu4977 opened this issue · 7 comments
I attempted to set the shape of the encoder's input boxes to (4, 10, 4), i.e. (bs, num_boxes, 4), where the last dimension holds the two box corners (x1, y1, x2, y2). However, it fails in this method:
```python
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    """Embeds box prompts."""
    boxes = boxes + 0.5  # Shift to center of pixel
    coords = boxes.reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    return corner_embedding
```
`_embed_boxes` reshapes the boxes to (bs*num_boxes, 2, 2) and outputs (bs*num_boxes, 2, 256). However, in the `forward` function, the sparse embeddings are initialized with `sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())`, which has a shape of (10, 0, 256). When they are concatenated with `sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)`, the first (batch) dimensions don't match, so the concatenation fails.
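A minimal sketch that reproduces the mismatch with plain tensors, assuming (as in SAMv1's `PromptEncoder`) that `bs` is taken from the first dimension of the box tensor. The zero tensor is a stand-in for the real positional encoding; only the shapes matter here.

```python
import torch

bs, num_boxes, embed_dim = 4, 10, 256
boxes = torch.rand(bs, num_boxes, 4)  # (bs, num_boxes, 4) as in the question

# _embed_boxes flattens every box into its own row of 2 corners: (bs*num_boxes, 2, 2)
coords = (boxes + 0.5).reshape(-1, 2, 2)

# Stand-in for the corner embeddings (the real code runs the positional encoding here)
box_embeddings = torch.zeros(coords.shape[0], 2, embed_dim)

# forward() starts from an empty (bs, 0, embed_dim) tensor before appending box tokens
sparse_embeddings = torch.empty((bs, 0, embed_dim))

cat_failed = False
try:
    torch.cat([sparse_embeddings, box_embeddings], dim=1)
except RuntimeError as err:
    # dim 0 differs (bs vs bs*num_boxes); torch.cat only allows the concat dim to differ
    cat_failed = True
    print(err)

print(coords.shape, box_embeddings.shape, sparse_embeddings.shape)
```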
The expected shape for box inputs is: Bx4
And if you're using points & labels, the shape should be BxNx2 and BxN respectively, where B is the batch size and N is the number of points. So as-is, the model doesn't support having multiple box prompts for a single mask in the same way that it supports having multiple points (i.e. there is no 'N' component in the box shape).
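To make those expected shapes concrete, here is a small hypothetical validation helper (`check_sam_prompt_shapes` is not part of SAM). It accepts boxes only when they reshape to exactly one box per batch entry (Bx4, or equivalently Bx1x4), which is the condition under which `boxes.reshape(-1, 2, 2)` keeps B rows.

```python
from typing import Optional

import torch


def check_sam_prompt_shapes(
    points: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    boxes: Optional[torch.Tensor] = None,
) -> int:
    """Hypothetical helper: return batch size B or raise ValueError on a shape mismatch."""
    batch_sizes = []
    if points is not None:
        if labels is None:
            raise ValueError("points need matching labels")
        if points.dim() != 3 or points.shape[-1] != 2:
            raise ValueError(f"points must be BxNx2, got {tuple(points.shape)}")
        if labels.shape != points.shape[:2]:
            raise ValueError(f"labels must be BxN, got {tuple(labels.shape)}")
        batch_sizes.append(points.shape[0])
    if boxes is not None:
        # Exactly one box (4 values) per batch entry, so reshape(-1, 2, 2) yields B rows
        if boxes.shape[-1] != 4 or boxes.numel() != boxes.shape[0] * 4:
            raise ValueError(f"boxes must be Bx4, got {tuple(boxes.shape)}")
        batch_sizes.append(boxes.shape[0])
    if len(set(batch_sizes)) > 1:
        raise ValueError(f"inconsistent batch sizes: {batch_sizes}")
    return batch_sizes[0] if batch_sizes else 1
```

With the (4, 10, 4) boxes from the question, this raises a `ValueError` instead of failing later inside `torch.cat`.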
In theory you could modify the code to support 'N' boxes by generating the `corner_embedding` output for each of the 'N' boxes (10 in your example) and concatenating them into a single `corner_embedding`. I think the result should have a shape of Bx(2N)x(embed_dim) to work with the existing code, though it's unclear how well this would work since it isn't part of the original model's behavior or training (but it may be worth trying).
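One way to sketch that modification, assuming boxes arrive as BxNx4: keep the batch dimension and reshape to Bx(2N)x2, so the positional encoding treats the corners like 2N points, then add the top-left and bottom-right corner embeddings at alternating positions. `embed_boxes_multi` and its parameters are hypothetical; in SAMv1 you would pass `self.pe_layer.forward_with_coords`, `self.point_embeddings[2].weight`, `self.point_embeddings[3].weight` and `self.input_image_size`. This is untested against trained weights.

```python
from typing import Callable, Tuple

import torch


def embed_boxes_multi(
    boxes: torch.Tensor,  # B x N x 4, (x1, y1, x2, y2) per box
    pe_forward_with_coords: Callable[[torch.Tensor, Tuple[int, int]], torch.Tensor],
    top_left_weight: torch.Tensor,  # 1 x D, e.g. point_embeddings[2].weight
    bottom_right_weight: torch.Tensor,  # 1 x D, e.g. point_embeddings[3].weight
    input_image_size: Tuple[int, int],
) -> torch.Tensor:
    """Hypothetical variant of _embed_boxes that returns B x (2N) x D."""
    b, n, _ = boxes.shape
    boxes = boxes + 0.5  # shift to the center of the pixel, as in the original
    # Keep the batch dim: each box contributes 2 corners -> 2N "points" per sample
    coords = boxes.reshape(b, 2 * n, 2)
    corner_embedding = pe_forward_with_coords(coords, input_image_size)  # B x 2N x D
    # Even token positions are top-left corners, odd positions are bottom-right corners
    corner_embedding[:, 0::2, :] += top_left_weight
    corner_embedding[:, 1::2, :] += bottom_right_weight
    return corner_embedding
```

Because the output keeps B as its first dimension, it concatenates cleanly with the (B, 0, D) `sparse_embeddings` in `forward`.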
thanks, I know
Hello, I've encountered the same issue as you. Could you please share how you resolved it? I would greatly appreciate your assistance. Thank you.
There's some code on the SAMv2 issue board that provides support for having multiple boxes for a single prompt. That code references changes to the newer code base, but the equivalent code for SAMv1 can be found in the prompt_encoder script.
Though it seems both SAMv2 and SAMv1 perform poorly when using more than 1 box.
Thank you very much! I will try.
Hello, I've encountered a new issue. For example, in one training batch I have 2 images; the first image has 2 bounding boxes and the second has 3. In this case, how should I conduct the training? Can the program understand the correspondence between images and bounding boxes within a batch?
In general, data with different shapes needs to be processed in separate batches. In this case, if you had multiple images with 2 bounding boxes, you could batch all of them together, and likewise for images with 3 bounding boxes.
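A minimal sketch of that bucketing, assuming each sample is an `(image_id, boxes)` pair where `boxes` is a list of `[x1, y1, x2, y2]` lists. The helper name and sample format are hypothetical; the point is that every batch only mixes images with the same box count, so their prompt tensors stack to the same shape.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Sample = Tuple[str, List[List[float]]]  # (image_id, list of [x1, y1, x2, y2])


def group_by_box_count(samples: Sequence[Sample], batch_size: int) -> List[List[Sample]]:
    """Split samples into batches whose members all have the same number of boxes."""
    buckets: Dict[int, List[Sample]] = defaultdict(list)
    for sample in samples:
        buckets[len(sample[1])].append(sample)
    batches: List[List[Sample]] = []
    for count in sorted(buckets):
        bucket = buckets[count]
        # Chunk each same-count bucket into batches of at most batch_size
        for start in range(0, len(bucket), batch_size):
            batches.append(bucket[start:start + batch_size])
    return batches
```

The trade-off is that batches can end up smaller than `batch_size` when few images share a box count.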
Alternatively, the SAM model includes a 'not a point' embedding (`not_a_point_embed`) that can be used to pad the prompts, so you could use it to make the 2-box prompt tensors the same shape as the 3-box prompt tensors.
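Each box prompt adds two 'points' (tokens) to the prompt tensor, so I think padding a 2-box prompt (4 tokens) to match a 3-box prompt (6 tokens) needs two padding tokens, something like:
```python
import torch

# not_a_point_embed is an nn.Embedding; use its weight (1 x embed_dim) and expand to B x 2 x embed_dim
pad_embed = predictor.model.prompt_encoder.not_a_point_embed.weight
bs, _, embed_dim = sparse_embeddings.shape
pad_embed = pad_embed.reshape(1, 1, embed_dim).expand(bs, 2, embed_dim)
sparse_embeddings = torch.cat([sparse_embeddings, pad_embed], dim=1)
```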
This would require modifying the `sparse_embeddings` that are generated by the prompt encoder (which normally happens inside the `predict` function).
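For a mixed batch like the 2-box/3-box example above, the same idea can be applied per image: pad each image's box tokens to the largest token count, then stack. This is a sketch under the assumption that each image's sparse embeddings were computed separately with shape 1 x (2 * n_boxes) x D; `pad_sparse_embeddings` is a hypothetical helper. Padding tokens consist of the `not_a_point_embed` weight alone, which, as far as I can tell, mirrors how SAMv1's `_embed_points` fills points labeled -1. Row i of the result still belongs to image i, so the image/box correspondence is preserved by ordering.

```python
from typing import List

import torch


def pad_sparse_embeddings(
    per_image: List[torch.Tensor],  # each 1 x (2 * n_boxes_i) x D
    not_a_point_weight: torch.Tensor,  # 1 x D, e.g. prompt_encoder.not_a_point_embed.weight
) -> torch.Tensor:
    """Pad every image's box tokens with the 'not a point' embedding and stack to B x T x D."""
    max_tokens = max(e.shape[1] for e in per_image)
    dim = not_a_point_weight.shape[-1]
    padded = []
    for emb in per_image:
        missing = max_tokens - emb.shape[1]  # 2 tokens per missing box
        if missing > 0:
            pad = not_a_point_weight.reshape(1, 1, dim).expand(1, missing, dim)
            emb = torch.cat([emb, pad], dim=1)
        padded.append(emb)
    # Stacking along dim 0 keeps image i at row i of the batch
    return torch.cat(padded, dim=0)
```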