Visualize Semantic Segmentation with Legend
wkentaro opened this issue · 1 comment
wkentaro commented
def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
    """Draw pixel-wise label with colorization and a legend of label names.

    The label is colorized with ``label2rgb``, rendered through matplotlib's
    off-screen ``agg`` backend together with a legend listing only the label
    values that actually occur in ``label``, and read back as an RGB array.

    Parameters
    ----------
    label: ndarray, (H, W)
        Pixel-wise labels to colorize.
    img: ndarray, (H, W, 3), optional
        Image on which the colorized label will be drawn.
    label_names: iterable, optional
        List of label names. The position in the list is used as the label
        value. Defaults to ``'0' .. str(label.max())``.
    colormap: optional
        Passed through ``_validate_colormap`` together with the number of
        labels; the result is indexed by label value to get each legend
        patch color and is also forwarded to ``label2rgb``.
    **kwargs
        Extra keyword arguments forwarded to ``label2rgb``.

    Returns
    -------
    out: ndarray
        RGB rendering (colorized label plus legend), resized to the width
        and height of the array returned by ``label2rgb``.
    """
    import matplotlib.pyplot as plt

    # Remember the caller's backend: rendering is done with the
    # non-interactive 'agg' backend and the original is restored afterwards.
    backend_org = plt.rcParams['backend']
    plt.switch_backend('agg')

    try:
        # Strip every margin/tick so the saved figure contains only the image.
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                            wspace=0, hspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())

        if label_names is None:
            label_names = [str(value) for value in range(label.max() + 1)]

        colormap = _validate_colormap(colormap, len(label_names))

        label_viz = label2rgb(
            label, img, n_labels=len(label_names), colormap=colormap, **kwargs
        )
        plt.imshow(label_viz)
        plt.axis('off')

        # One legend entry (colored patch + "value: name") per label value
        # that is present in the label image.
        plt_handlers = []
        plt_titles = []
        for label_value, label_name in enumerate(label_names):
            if label_value not in label:
                continue
            fc = colormap[label_value]
            p = plt.Rectangle((0, 0), 1, 1, fc=fc)
            plt_handlers.append(p)
            plt_titles.append('{value}: {name}'
                              .format(value=label_value, name=label_name))
        plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)

        # Render into memory instead of touching the filesystem.
        f = io.BytesIO()
        plt.savefig(f, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure and restore the caller's backend, even
        # when colorization or rendering raised; otherwise the process would
        # be left on 'agg' with a dangling figure.
        plt.cla()
        plt.close()
        plt.switch_backend(backend_org)

    # The saved figure size depends on the figure DPI, so scale it back to
    # the size of the colorized label. PIL sizes are (width, height).
    out_size = (label_viz.shape[1], label_viz.shape[0])
    out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
    out = np.asarray(out)
    return out