wkentaro/imgviz

Visualize Image Classification

wkentaro opened this issue · 0 comments

def draw_image_classification_top5(img, label_names, proba):
    """Draw the (up to) five most probable classes as bars below an image.

    ``img`` is passed to ``centerize`` with ``dst_shape=(S, S)`` where
    ``S = min(img.shape[:2])``. A white panel of height ``S // 2`` is then
    stacked below it. The panel holds one horizontal bar per class, highest
    probability first. Each bar's length is proportional to the class
    probability, and the class name is drawn on it when OpenCV is available.

    Parameters
    ----------
    img : numpy.ndarray
        Input image. NOTE(review): the bar panel is a ``uint8`` array of
        shape (H, W, 3), so ``img`` presumably must be 3-channel for the
        final ``np.vstack`` to succeed. Confirm against callers.
    label_names : sequence of str
        Class names aligned with ``proba``. Names are expected to look like
        ``"<id> <name>, <synonym>, ..."``. The leading id token and
        everything from the first comma on are dropped for display. A name
        without a space is shown up to its first comma.
    proba : array-like of float
        Per-class probabilities, the same length as ``label_names``. They
        need not be sorted. Values are clipped to [0, 1] for drawing.

    Returns
    -------
    numpy.ndarray
        The centered square image with the bar panel appended below it.

    Raises
    ------
    ValueError
        If ``label_names`` and ``proba`` differ in length.
    """
    proba = np.asarray(proba, dtype=float)
    if len(label_names) != len(proba):
        raise ValueError(
            'label_names and proba must have the same length: '
            '{} != {}'.format(len(label_names), len(proba)))

    square_size = min(img.shape[:2])
    img = centerize(img, dst_shape=(square_size, square_size))

    # Pick the top-5 classes. A stable sort on -proba keeps the caller's
    # order for input that is already sorted descending, so pre-sorted
    # callers get exactly the same picture as before.
    top_indices = np.argsort(-proba, kind='mergesort')[:5]

    # The panel is half the image height and is split into 5 equal slots.
    step = square_size // (2 * 5)
    bars = np.full((square_size // 2, square_size, 3), 255, dtype=np.uint8)

    if not _OPENCV_AVAILABLE:
        # Warn once per call instead of once per bar.
        _warn_opencv_unavailable()

    for i, index in enumerate(top_indices):
        # Clip so bar length and colour stay in range before the uint8 cast.
        p = float(np.clip(proba[index], 0.0, 1.0))
        y1 = step * i
        y2 = y1 + step
        x2 = int(square_size * p)
        # Channel 0 scales with p and channel 2 with 1 - p.
        color = np.array((p, 0, 1 - p))
        bars[y1:y2, :x2] = (color * 255).astype(np.uint8)
        # 1-pixel black separator lines at the top and bottom of each bar.
        bars[y1:y1 + 1, :] = 0
        bars[y2 - 1:y2, :] = 0
        if _OPENCV_AVAILABLE:
            # Drop the leading id token, then keep the first (possibly
            # multi-word) name before any comma-separated synonyms.
            full_name = label_names[index]
            _, sep, rest = full_name.partition(' ')
            label_name = (rest if sep else full_name).split(',')[0].strip()
            cv2.putText(bars, label_name,
                        (square_size // 2, y1 + step // 2),
                        cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))

    # Drop the unused rows below the last drawn bar (rounding leftovers, or
    # empty slots when fewer than 5 classes were given).
    bars = bars[:step * len(top_indices)]
    return np.vstack((img, bars))