djukicn/loca

Density Map Visualization

Closed this issue · 1 comments

Hello, first of all, thank you very much for your work. Is there any code to visualize the model's output density map? Thanks!

Here's my solution:

import os

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

from data.fsc import FSC147Dataset
from models.loca import LOCA


def _normalize_density(density):
    """Min-max scale a density array to [0, 1]; a constant map becomes all zeros."""
    low = float(density.min())
    span = float(density.max()) - low
    # A flat (or non-finite) map would make the division produce NaN/inf,
    # which then turns into garbage when cast to uint8.
    if not np.isfinite(span) or span <= 0:
        return np.zeros_like(density, dtype=np.float32)
    return (density - low) / span


def _to_bgr_uint8(image, mean, std):
    """Undo per-channel normalization of a CHW tensor and return an HWC BGR uint8 array."""
    mean_t = torch.tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    # Out-of-place arithmetic leaves the caller's batch tensor untouched; the
    # clamp keeps slightly out-of-range values from wrapping in the uint8 cast.
    restored = (image * std_t + mean_t).clamp(0, 1)
    rgb = (restored * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    return cv2.cvtColor(np.ascontiguousarray(rgb), cv2.COLOR_RGB2BGR)


@torch.no_grad()
def main(_device):
    """Run LOCA on the first validation batch and save density-map visualizations.

    For every image in the batch three PNG files are written to ``out/``:
    the de-normalized input image, the predicted density map rendered with
    the JET colormap, and a blend of the two. Ground-truth and predicted
    density-map sums are printed for each image.

    Args:
        _device: ``torch.device`` (or device string) on which the model and
            the batch are placed.
    """
    model = LOCA(
        image_size=512,
        num_encoder_layers=3,
        num_ope_iterative_steps=3,
        num_objects=3,
        emb_dim=256,
        num_heads=8,
        kernel_dim=3,
        backbone_name="resnet50",
        swav_backbone=True,
        train_backbone=False,
        reduction=8,
        dropout=0.1,
        layer_norm_eps=1e-5,
        mlp_factor=8,
        norm_first=True,
        activation="gelu",
        norm=True,
        zero_shot=False
    )

    val = FSC147Dataset(
        "datasets/FSC147",
        512,
        split="val",
        num_objects=3,
        tiling_p=0.5
    )

    val_loader = DataLoader(val, batch_size=4, shuffle=False, num_workers=0, drop_last=False)
    # Per-channel statistics used to undo the input normalization
    # (assumes the dataset normalizes with these values - TODO confirm).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load("loca_few_shot.pt", map_location=torch.device(_device))["model"]
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False hides key mismatches, so surface them instead of silently
    # visualizing a partially initialized model.
    if load_result.missing_keys:
        print(f"Warning: {len(load_result.missing_keys)} keys missing from checkpoint")
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys in checkpoint")
    model.eval().to(_device)

    # cv2.imwrite returns False instead of raising when the target directory
    # does not exist, so create it up front.
    os.makedirs("out", exist_ok=True)

    for data in val_loader:
        img, bboxes, density_map = data
        img = img.to(_device)
        bboxes = bboxes.to(_device)
        density_map = density_map.to(_device)

        for i in range(density_map.size(0)):
            print(f"Ground truth density map {i} sum: {density_map[i].sum()}")

        _, out, _ = model(img, bboxes)

        for i in range(out.size(0)):
            print(f"Predicted density map {i} sum: {out[i].sum()}")
            output = _normalize_density(out[i, ...].squeeze().cpu().numpy())
            output_image = _to_bgr_uint8(img[i], mean, std)
            colored_output = cv2.applyColorMap(
                (output * 255).astype(np.uint8), cv2.COLORMAP_JET
            )
            # addWeighted requires equal sizes; resize the heatmap if the
            # model output resolution differs from the input image.
            if colored_output.shape[:2] != output_image.shape[:2]:
                height, width = output_image.shape[:2]
                colored_output = cv2.resize(colored_output, (width, height))
            fuse_image = cv2.addWeighted(output_image, 1, colored_output, 0.7, 0)
            cv2.imwrite(f"out/loca_fuse_image_{i}.png", fuse_image)
            cv2.imwrite(f"out/loca_predict_density_map_{i}.png", colored_output)
            cv2.imwrite(f"out/loca_original_{i}.png", output_image)
        # Only the first batch is visualized.
        break


if __name__ == "__main__":
    # Pick the best available backend: CUDA first, then MPS, otherwise the
    # CPU. The chained conditional only queries MPS when CUDA is unavailable.
    backend_name = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    device = torch.device(backend_name)

    print(f"Using device: {device}")
    main(device)