# Requires, at module level (as in the surrounding file):
#   import torch
#   from torch.nn import functional as F
#   from detectron2.structures import Boxes, Instances  # and BitMasks for the optional box path

def instance_inference(self, mask_cls, mask_pred):
    # mask_pred is already processed to have the same shape as the original input
    image_size = mask_pred.shape[-2:]

    # [Q, K] per-query class probabilities; drop the last ("no object") column
    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
    labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
    # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
    scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
    labels_per_image = labels[topk_indices]

    # recover the query index of each selected (query, class) pair
    topk_indices = topk_indices // self.sem_seg_head.num_classes
    # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
    mask_pred = mask_pred[topk_indices]
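    # e.g. with Q = 2 queries and K = 3 classes, scores.flatten(0, 1) orders
    # entries as [q0c0, q0c1, q0c2, q1c0, q1c1, q1c2]; a flat index f thus
    # decodes to class f % K (matched by the repeated `labels` range above)
    # and query f // K, which is what the division computes.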

    # if this is panoptic segmentation, we only keep the "thing" classes
    if self.panoptic_on:
        keep = torch.zeros_like(scores_per_image).bool()
        for i, lab in enumerate(labels_per_image):
            keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()

        scores_per_image = scores_per_image[keep]
        labels_per_image = labels_per_image[keep]
        mask_pred = mask_pred[keep]

    result = Instances(image_size)
    # binarize mask logits (before sigmoid) at 0, i.e. probability 0.5
    result.pred_masks = (mask_pred > 0).float()
    result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
    # Uncomment the following to get boxes from masks (this is slow)
    # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()

    # average foreground mask probability; the 1e-6 guards against empty masks
    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
    # final confidence = class score modulated by mask quality
    result.scores = scores_per_image * mask_scores_per_image
    result.pred_classes = labels_per_image
    return result
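
# A minimal standalone sketch (not part of the class above): the same
# flattened top-k selection and mask-aware rescoring on dummy tensors.
# Q, K, H, W and k are illustrative stand-ins for self.num_queries,
# self.sem_seg_head.num_classes, the mask resolution, and
# self.test_topk_per_image.
if __name__ == "__main__":
    import torch
    from torch.nn import functional as F

    Q, K, H, W, k = 4, 3, 8, 8, 5
    mask_cls = torch.randn(Q, K + 1)    # class logits, last column = "no object"
    mask_pred = torch.randn(Q, H, W)    # mask logits per query

    scores = F.softmax(mask_cls, dim=-1)[:, :-1]             # [Q, K]
    scores_per_image, topk = scores.flatten(0, 1).topk(k)    # best k (query, class) pairs
    labels_per_image = topk % K                              # class of each pair
    masks = mask_pred[topk // K]                             # mask of the originating query

    # class confidence modulated by mean foreground probability of the mask
    fg = (masks > 0).float().flatten(1)
    mask_scores = (masks.sigmoid().flatten(1) * fg).sum(1) / (fg.sum(1) + 1e-6)
    print(scores_per_image * mask_scores)
    print(labels_per_image)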