google-research/scenic

How to detect more than one prediction in the target image?

DishantMewada opened this issue · 1 comment

In the minimal_example colab, the cell 'Get predictions for target image with the query embedding' detects only the single closest match to the query taken from the source image. Is it possible to detect more than one object, for example every object whose 'score' exceeds a given threshold?

This is the code in question, from the colab:

# Find and draw the single box in `target_image` that best matches
# `query_embedding`.

# Embed the target image; `[None, ...]` prepends a batch axis of size 1.
feature_map = image_embedder(target_image[None, ...])

# Flatten the spatial grid so the predictors see one feature row per patch.
b, h, w, d = feature_map.shape
patch_features = feature_map.reshape(b, h * w, d)

box_outputs = box_predictor(image_features=patch_features, feature_map=feature_map)
target_boxes = box_outputs['pred_boxes']

target_class_predictions = class_predictor(
    image_features=patch_features,
    query_embeddings=query_embedding[None, None, ...],  # [batch, queries, d]
)

# Drop the (single-image) batch axis and move the results into numpy.
target_boxes = np.array(target_boxes[0])
target_logits = np.array(target_class_predictions['pred_logits'][0])

# Column 0 holds the logits for the one query; keep the highest-scoring box.
query_logits = target_logits[:, 0]
top_ind = np.argmax(query_logits, axis=0)
score = sigmoid(query_logits[top_ind])

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(target_image, extent=(0, 1, 1, 0))
ax.set_axis_off()

# Boxes are unpacked as (center x, center y, width, height) and drawn in the
# same [0, 1] frame as the image extent above.
cx, cy, w, h = target_boxes[top_ind]
left, right = cx - w / 2, cx + w / 2
top, bottom = cy - h / 2, cy + h / 2
ax.plot(
    [left, right, right, left, left],
    [top, top, bottom, bottom, top],
    color='lime',
)

# Label the box just inside its lower-left corner.
label_style = dict(facecolor='white', edgecolor='lime', boxstyle='square,pad=.3')
ax.text(
    left + 0.015,
    bottom - 0.015,
    f'Score: {score:1.2f}',
    ha='left',
    va='bottom',
    color='black',
    bbox=label_style,
)

ax.set_xlim(0, 1)
ax.set_ylim(1, 0)
ax.set_title('Closest match')

Thank you so much.

I have modified the code cell as follows to detect multiple objects.

The main change: `top_ind = np.argmax(target_logits[:, 0], axis=0)`, which gives the index of the single closest match, was replaced with
`top_ind = np.argsort(target_logits[:, 0], axis=0)[-i]`, evaluated while iterating `i` over `range(len(target_boxes))`.

# Find and draw up to NUMBER_OF_CLOSEST_OBJECTS boxes in `target_image` whose
# match score for `query_embedding` exceeds DESIRED_SCORE.

DESIRED_SCORE = 0.97  # Minimum sigmoid score a box needs to be drawn.
DESIERED_SCORE = DESIRED_SCORE  # Alias kept for the original misspelled name.
NUMBER_OF_CLOSEST_OBJECTS = 5  # Upper bound on the number of boxes drawn.

feature_map = image_embedder(target_image[None, ...])

# Flatten the spatial grid once; both predictors take the same patch features.
b, h, w, d = feature_map.shape
image_features = feature_map.reshape(b, h * w, d)
target_boxes = box_predictor(
    image_features=image_features, feature_map=feature_map
)['pred_boxes']

target_class_predictions = class_predictor(
    image_features=image_features,
    query_embeddings=query_embedding[None, None, ...],  # [batch, queries, d]
)

# Remove the batch dimension and convert to numpy. The batch holds a single
# image, so index 0 is the only valid entry for the boxes AND the logits.
# (Indexing [1] does not raise under JAX because it clamps out-of-range indices.)
target_boxes = np.array(target_boxes[0])
target_logits = np.array(target_class_predictions['pred_logits'][0])

# Rank all boxes once, best first, and keep only the top candidates.
# Ranking by logit equals ranking by score because the sigmoid is monotonic.
ranked_inds = np.argsort(target_logits[:, 0])[::-1]
top_inds = ranked_inds[:NUMBER_OF_CLOSEST_OBJECTS]

dimension_list = []  # [cx, cy, w, h] of each kept box, best first.
score_list = []  # Matching sigmoid scores, aligned with dimension_list.
for ind in top_inds:
    candidate_score = sigmoid(target_logits[ind, 0])
    if not candidate_score > DESIRED_SCORE:
        break  # Candidates are sorted, so every remaining one scores lower.
    dimension_list.append(list(target_boxes[ind]))
    score_list.append(candidate_score)

# NOTE(review): boxes are not de-duplicated here. If several kept boxes overlap
# the same object, consider applying non-maximum suppression before drawing.

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.imshow(target_image, extent=(0, 1, 1, 0))
ax.set_axis_off()

# Box sizes get their own names so they don't clobber the feature-grid h, w.
for (cx, cy, box_w, box_h), box_score in zip(dimension_list, score_list):
    left, right = cx - box_w / 2, cx + box_w / 2
    top, bottom = cy - box_h / 2, cy + box_h / 2
    ax.plot(
        [left, right, right, left, left],
        [top, top, bottom, bottom, top],
        color='lime',
    )
    # Label each box just inside its lower-left corner.
    ax.text(
        left + 0.015,
        bottom - 0.015,
        f'Score: {box_score:1.2f}',
        ha='left',
        va='bottom',
        color='black',
        bbox={
            'facecolor': 'white',
            'edgecolor': 'lime',
            'boxstyle': 'square,pad=.3',
        },
    )

ax.set_xlim(0, 1)
ax.set_ylim(1, 0)
ax.set_title(f'Closest matches ({len(score_list)} with score > {DESIRED_SCORE})')

Let me know if you find anything wrong with the code, or have tips for writing it better.