Visualize Image Classification
wkentaro opened this issue · 0 comments
wkentaro commented
def draw_image_classification_top5(img, label_names, proba):
    """Stack a bar chart of the (up to) five most probable classes under img.

    The image is passed through ``centerize`` with a square ``dst_shape``
    whose side is the shorter of the image's height and width
    (NOTE(review): presumably this center-crops/pads to a square — confirm
    against ``centerize``). Below it, a white panel half as tall as that
    side is stacked. The panel has one row per class: a horizontal bar whose
    length is proportional to the class probability, coloured
    ``(p, 0, 1 - p) * 255`` per channel. When OpenCV is available, the
    class name is written on each row.

    Args:
        img: Image array. It is stacked with a 3-channel ``uint8`` panel,
            so it presumably is ``(H, W, 3)`` — TODO confirm for callers.
        label_names: Class names, one per entry of ``proba``. The drawn
            text is the second space-separated token with commas stripped
            (``"n01440764 tench, Tinca tinca"`` -> ``"tench"``). A name
            without a space is used whole.
        proba: Per-class probabilities, the same length as ``label_names``.
            Rows are ordered by descending probability; the sort is stable,
            so input that is already sorted keeps its order. Values are
            clipped to ``[0, 1]`` for drawing only.

    Returns:
        The square image with the bar panel stacked underneath. At most five
        rows are drawn; fewer if fewer classes are given, and none for empty
        input.

    Raises:
        ValueError: If ``label_names`` and ``proba`` differ in length.
    """
    if len(label_names) != len(proba):
        raise ValueError(
            'label_names and proba must have the same length: '
            '{} != {}'.format(len(label_names), len(proba)))
    max_rows = 5
    proba = np.asarray(proba, dtype=np.float64)
    # Stable (mergesort) descending order: pre-sorted input is unchanged,
    # unsorted input yields the real top entries instead of the first ones.
    order = np.argsort(-proba, kind='mergesort')[:max_rows]

    square_size = min(img.shape[:2])
    img = centerize(img, dst_shape=(square_size, square_size))

    # White panel: half the square's height, split into max_rows rows.
    bars = np.full((square_size // 2, square_size, 3), 255, dtype=np.uint8)
    step = square_size // (2 * max_rows)

    if len(order) > 0 and not _OPENCV_AVAILABLE:
        # Text needs OpenCV; warn once rather than once per row.
        _warn_opencv_unavailable()

    y2 = 0  # bottom of the last drawn row; 0 keeps empty input valid
    for row, idx in enumerate(order):
        y1 = step * row
        y2 = y1 + step
        # Clip so the uint8 cast below cannot wrap and the width stays >= 0.
        p = float(np.clip(proba[idx], 0.0, 1.0))
        x2 = int(square_size * p)
        color = np.array((p, 0, 1 - p))
        bars[y1:y2, :x2] = (color * 255).astype(np.uint8)
        # One-pixel black separator lines at the top and bottom of the row.
        bars[y1:y1 + 1, :] = 0
        bars[y2 - 1:y2, :] = 0
        if _OPENCV_AVAILABLE:
            tokens = label_names[idx].split(' ')
            label_name = (tokens[1] if len(tokens) > 1 else tokens[0])
            label_name = label_name.strip(',')
            cv2.putText(bars, label_name,
                        (square_size // 2, y1 + step // 2),
                        cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0))
    # Drop the unused rows when fewer than max_rows classes were drawn.
    bars = bars[:y2, :]
    img = np.vstack((img, bars))
    return img