seanzhuh/SeqTR

Visualization

Closed this issue · 2 comments

Hi,

Congratulations!

I would like to visualize the attention weights of the segmentation points, similar to Fig. 5.

The paper says, "We visualize the cross attention map averaged over decoder layers and attention heads in Fig. 5.", but I am not sure how to overlay these weights onto the original image.
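For reference, my rough understanding of the averaging step is something like the sketch below. The shapes, the number of layers/heads, and the 20x20 grid are just my guesses for illustration, not taken from the code:

import torch

# Hypothetical cross-attention weights: [num_layers, num_heads, num_queries, H*W].
num_layers, num_heads, num_queries = 3, 8, 4
attn = torch.rand(num_layers, num_heads, num_queries, 20 * 20)

# Average over decoder layers and attention heads (as described in the paper),
# then reshape each query's map back onto the feature-map grid.
attn_avg = attn.mean(dim=(0, 1)).reshape(num_queries, 20, 20)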

Could you share the script or suggest a workable approach?

Thanks~

# Helper for overlaying attention maps; the script below imports it via
# "from seqtr.core import imshow_attention".
import cv2
import matplotlib.pyplot as plt
import numpy
import torch


def imshow_attention(img, attn_weights, out_file):
    # img is read with mmcv (BGR); flip channels to RGB for matplotlib.
    img = numpy.ascontiguousarray(img)[:, :, ::-1]
    h, w = img.shape[:2]

    # attn_weights: [4 * decoder_layers, 400], one group of rows per decoding step.
    # The split [3, 3, 3, 3] assumes 3 decoder layers; average each step's rows
    # over the decoder layers.
    attn_weights = torch.cat(list(map(lambda weights: torch.mean(
        weights, dim=0, keepdim=True), torch.split(attn_weights, [3, 3, 3, 3]))), dim=0)
    assert attn_weights.size(0) == 4 and attn_weights.ndim == 2
    # Reshape each averaged map back onto the 20x20 encoder feature-map grid and
    # crop away the part that corresponds to padding (stride-32 feature map).
    attn_weights = attn_weights.reshape(4, 20, 20)
    attn_weights = attn_weights[:, :h // 32 + 1, :w // 32 + 1]
    attn_weights = attn_weights.cpu().numpy()

    for i in range(4):
        plt.clf()
        plt.axis('off')
        plt.imshow(img, alpha=0.7)
        # Upsample the attention map to image resolution and draw it as a
        # heatmap on top of the image.
        attn_mask = cv2.resize(attn_weights[i], (w, h))
        attn_mask = (attn_mask * 255).astype(numpy.uint8)
        plt.imshow(attn_mask, alpha=0.3,
                   interpolation="bilinear", cmap="jet")
        plt.savefig(out_file.replace(".jpg", f"_{i}th_step.jpg"), dpi=300)


# Standalone visualization script (loads a SeqTR checkpoint and saves per-step
# cross-attention overlays using imshow_attention above).
from typing import Sequence
import mmcv
import torch
import numpy
import argparse
import os.path as osp
import torch.nn.functional as F
from mmcv import Config, DictAction
from mmcv.utils import mkdir_or_exist
from seqtr.models import build_model
from seqtr.core import imshow_attention
from seqtr.utils import load_checkpoint, get_root_logger
from seqtr.datasets import extract_data, build_dataset, build_dataloader
try:
    import apex  # optional; only needed when cfg.use_fp16 is enabled
except ImportError:
    pass


def parse_args():
    parser = argparse.ArgumentParser(description="SeqTR")
    parser.add_argument('config', help='visualize config file path')
    parser.add_argument(
        'checkpoint', help='the checkpoint file to load from.')
    parser.add_argument(
        '--output-dir', help='directory where visualized results will be saved.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--visualize',
        type=str,
        nargs='+',
        default='testA',
        help="evaluation set, which depends on the dataset, e.g., \
        'val', 'testA', 'testB' for RefCOCO(Plus)UNC, and 'val', 'test' for RefCOCOgUMD.")
    args = parser.parse_args()
    return args


def main(cfg):
    datasets_cfg = [cfg.data.train]
    for vis_set in cfg.visualize:
        datasets_cfg.append(cfg.data[vis_set])

    datasets = list(map(build_dataset, datasets_cfg))
    dataloaders = list(
        map(lambda dataset: build_dataloader(cfg, dataset), datasets))

    model = build_model(cfg,
                        word_emb=datasets[0].word_emb,
                        num_token=datasets[0].num_token)
    model = model.cuda()
    if cfg.use_fp16:
        model = apex.amp.initialize(model, opt_level="O1")
        for m in model.modules():
            if hasattr(m, "fp16_enabled"):
                m.fp16_enabled = True
    load_checkpoint(model, None, None, cfg.checkpoint)

    model.eval()
    model.head.transformer.need_weights = True
    model.head.transformer.decoder.need_weights = True
    decoder_layers = model.head.transformer.decoder.num_layers
    logger = get_root_logger()
    for i, vis_set in enumerate(cfg.visualize):
        logger.info(f"visualizing attention on set {vis_set}")
        output_dir = osp.join(cfg.output_dir, cfg.dataset, vis_set)
        mkdir_or_exist(output_dir)
        with torch.no_grad():
            prog_bar = mmcv.ProgressBar(len(datasets[i+1]))
            for batch, inputs in enumerate(dataloaders[i+1]):
                inputs = extract_data(inputs)
                """
                    inputs (Dict): {
                        'img_metas' (List[Dict]): {
                            'filename' (str): './data/images/train2014/COCO_train2014_000000580957.jpg',
                            'expression' (str): 'bowl behind the others can only see part',
                            'ori_shape' (tuple): (h_ori, w_ori, 3),
                            'img_shape' (tuple): (h_img, w_img, 3),
                            'pad_shape' (tuple): (h_pad, w_pad, 3),
                            'scale_factor' (Array): (w_scale, h_scale, w_scale, h_scale),
                            'img_norm_cfg' (dict): {
                                'mean' (Array): [0., 0., 0.] 
                                'std' (Array): [1., 1., 1.]
                            }
                            'to_rgb': True
                        },
                        'img' (Tensor): [batch_size, 3, h_batch, w_batch].
                        'ref_expr_inds' (Tensor): [batch_size, max_token].
                        'gt_bbox' (List[Tensor]): [
                            [tl_x, tl_y, br_x, br_y], in (h_img, w_img) coordinate system.
                        ]
                    }
                """
                img, ref_expr_inds, img_metas = inputs['img'], inputs['ref_expr_inds'], inputs['img_metas']
                batch_size = img.size(0)
                batch_input_shape = tuple(img.size()[-2:])
                for img_meta in img_metas:
                    img_meta['batch_input_shape'] = batch_input_shape

                x, y, y_word, y_mask = model.extract_visual_language(
                    img, ref_expr_inds)

                if model.with_neck:
                    x, y = model.neck(x, y, y_word, y_mask)

                x_mask, x_pos_embeds = model.head.transformer.x_mask_pos_enc(
                    x, img_metas)
                if model.with_fusion:
                    y = model.fusion(x, y_word, x_mask, y_mask)
                memory = model.head.transformer.forward_encoder(
                    x, x_mask, x_pos_embeds)

                attn_weights_all_coordinates = []
                seq_in_embeds = y
                # Greedily decode the 4 coordinate tokens one step at a time and
                # keep the cross-attention weights of every decoder layer at each step.
                for step in range(4):
                    out, attn_weights_all_layers = model.head.transformer.forward_decoder(
                        seq_in_embeds, memory, x_pos_embeds, x_mask)
                    attn_weights_all_coordinates.append(
                        attn_weights_all_layers)

                    # Greedy decoding: drop the last logit (keeping the num_bin
                    # coordinate bins), pick the most likely bin, and feed its
                    # embedding back in as the next query.
                    logits = out[:, -1, :]
                    logits = model.head.predictor(logits)
                    logits = logits[:, :-1]  # [batch_size, num_bin]
                    probability = F.softmax(logits, dim=-1)
                    probability, next_token = probability.topk(
                        dim=-1, k=1, largest=True, sorted=True)
                    seq_in_embeds = torch.cat(
                        [seq_in_embeds, model.head.transformer.query_embedding(next_token)], dim=1)

                # For each image, stack the last-query cross-attention of every
                # decoder layer at every decoding step: [4 * decoder_layers, 400].
                attn_weights_all_images = []
                for b in range(batch_size):
                    attn_weights_this_img = []
                    for j in range(4):
                        for k in range(decoder_layers):
                            attn_weights_this_img.append(
                                attn_weights_all_coordinates[j][k][b, -1, :])
                    attn_weights_this_img = torch.vstack(attn_weights_this_img)
                    attn_weights_all_images.append(attn_weights_this_img)

                for img_meta, attn_weights in zip(img_metas, attn_weights_all_images):
                    filename = img_meta['filename']
                    expression = img_meta['expression'].replace(" ", "")
                    # Save into the per-set directory created above.
                    out_file = osp.join(
                        output_dir, expression + "_" + osp.basename(filename))
                    ori_img = mmcv.imread(filename).astype(numpy.uint8)

                    imshow_attention(ori_img, attn_weights, out_file)

                    prog_bar.update()


if __name__ == "__main__":
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    if isinstance(args.visualize, str):
        cfg.visualize = [args.visualize]
    elif isinstance(args.visualize, Sequence):
        cfg.visualize = args.visualize
    cfg.checkpoint = args.checkpoint
    cfg.output_dir = args.output_dir

    main(cfg)
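
# Example invocation (the script filename is just a placeholder; the arguments
# correspond to parse_args above):
#   python visualize_attention.py <config.py> <checkpoint.pth> \
#       --output-dir ./visualizations --visualize val testA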

Hi, you can reference the code above for attention visualization. It may be buggy because I renamed several APIs and possibly changed the attributes of the transformer decoder (whether it returns weights, etc.) during open-sourcing. Nevertheless, this is the code that was used to visualize the attention maps.
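For anyone who just wants the overlay recipe without the SeqTR-specific plumbing, below is a minimal, self-contained sketch of the same idea. The image path and the random attention map are placeholders; in practice the map comes from the decoder cross-attention, averaged and reshaped as in imshow_attention above:

import cv2
import numpy as np
import matplotlib.pyplot as plt

img = cv2.imread("example.jpg")[:, :, ::-1]  # placeholder path; BGR -> RGB
h, w = img.shape[:2]

attn = np.random.rand(20, 20).astype(np.float32)  # stand-in for one averaged attention map
attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-6)  # normalize to [0, 1]
heatmap = cv2.resize(attn, (w, h), interpolation=cv2.INTER_LINEAR)

plt.axis("off")
plt.imshow(img, alpha=0.7)                  # base image
plt.imshow(heatmap, alpha=0.3, cmap="jet")  # attention overlay
plt.savefig("attention_overlay.jpg", dpi=300, bbox_inches="tight")

The script above does the same overlay, but additionally averages over decoder layers and crops away the padded region before resizing.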

Well, thank you very much for generously sharing this. I think many people who come after will learn a lot about visualization from it, e.g. me, haha. Thanks again for your quick reply!