Density Map Visualization
Closed this issue · 1 comment
Sunhill666 commented
Hello, first of all, thank you very much for your work. Is there any code to visualize the model's output density map? Thanks!
Sunhill666 commented
Here's my solution:
import os

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

from data.fsc import FSC147Dataset
from models.loca import LOCA
def _build_model():
    """Construct the LOCA model with the hyper-parameters used by this script.

    NOTE(review): these values presumably mirror the configuration the
    checkpoint was trained with; verify before changing any of them.
    """
    return LOCA(
        image_size=512,
        num_encoder_layers=3,
        num_ope_iterative_steps=3,
        num_objects=3,
        emb_dim=256,
        num_heads=8,
        kernel_dim=3,
        backbone_name="resnet50",
        swav_backbone=True,
        train_backbone=False,
        reduction=8,
        dropout=0.1,
        layer_norm_eps=1e-5,
        mlp_factor=8,
        norm_first=True,
        activation="gelu",
        norm=True,
        zero_shot=False,
    )


def _denormalize_to_bgr(image, mean, std):
    """Undo per-channel ``(x - mean) / std`` normalization; return a BGR uint8 image.

    Args:
        image: ``(C, H, W)`` RGB float tensor on any device. It is not modified.
        mean: Per-channel means, one value per channel.
        std: Per-channel standard deviations, one value per channel.

    Returns:
        ``(H, W, C)`` uint8 NumPy array in BGR channel order, ready for OpenCV.
    """
    mean_t = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    # Out-of-range floats wrap around (or are undefined) when cast to uint8,
    # so clamp to the displayable range before scaling.
    rgb = (image * std_t + mean_t).clamp(0.0, 1.0)
    rgb = (rgb * 255).permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def _colorize_density(density, size_hw):
    """Render a 2-D density map as a JET-colormapped BGR image of size ``size_hw``.

    The map is min-max scaled to [0, 255]. A constant or non-finite map
    (e.g. all zeros) becomes all zeros instead of producing NaNs.

    Args:
        density: 2-D array-like density map.
        size_hw: ``(height, width)`` of the output image. The colormap is
            resized if the density map has a different size.

    Returns:
        ``(H, W, 3)`` uint8 BGR image.
    """
    density = np.asarray(density, dtype=np.float32)
    lo, hi = float(density.min()), float(density.max())
    span = hi - lo
    if np.isfinite(span) and span > 0:
        scaled = (density - lo) / span
    else:
        # Guard against division by zero / NaN, whose uint8 cast is undefined.
        scaled = np.zeros_like(density)
    colored = cv2.applyColorMap((scaled * 255).astype(np.uint8), cv2.COLORMAP_JET)
    height, width = size_hw
    if colored.shape[:2] != (height, width):
        # cv2.addWeighted requires both inputs to have identical sizes.
        colored = cv2.resize(colored, (width, height), interpolation=cv2.INTER_LINEAR)
    return colored


def _write_png(path, image):
    """Write ``image`` to ``path``, raising OSError if OpenCV reports failure.

    ``cv2.imwrite`` signals failure only through its return value.
    """
    if not cv2.imwrite(path, image):
        raise OSError(f"Failed to write image to {path}")


@torch.no_grad()
def main(_device, checkpoint_path="loca_few_shot.pt", data_root="datasets/FSC147", out_dir="out"):
    """Visualize LOCA density-map predictions for the first FSC147 validation batch.

    For each image in the first batch, prints the ground-truth and predicted
    density-map sums. It then writes three PNGs to ``out_dir``:
    ``loca_fuse_image_{i}.png`` (input blended with the colorized prediction),
    ``loca_predict_density_map_{i}.png`` (the colorized prediction) and
    ``loca_original_{i}.png`` (the de-normalized input).

    Args:
        _device: torch device (or device string) to run the model on.
        checkpoint_path: Checkpoint file. Its ``"model"`` entry is loaded as
            the model state dict.
        data_root: Root directory passed to ``FSC147Dataset``.
        out_dir: Directory for the output images; created if it does not exist.

    Raises:
        OSError: If an output image cannot be written.
    """
    # Per-channel statistics used to undo input normalization.
    # NOTE(review): assumes FSC147Dataset normalizes with these same values; confirm.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    model = _build_model()
    state_dict = torch.load(checkpoint_path, map_location=torch.device(_device))["model"]
    # strict=False tolerates key mismatches; report them rather than hiding them.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"Warning: checkpoint has {len(incompatible.missing_keys)} missing and "
            f"{len(incompatible.unexpected_keys)} unexpected keys"
        )
    model.eval().to(_device)

    # NOTE(review): tiling_p is passed for the val split as well; confirm whether
    # FSC147Dataset applies tiling outside of training.
    val = FSC147Dataset(data_root, 512, split="val", num_objects=3, tiling_p=0.5)
    val_loader = DataLoader(val, batch_size=4, shuffle=False, num_workers=0, drop_last=False)

    # Only the first batch is visualized.
    batch = next(iter(val_loader), None)
    if batch is None:
        print("Validation split is empty; nothing to visualize.")
        return

    os.makedirs(out_dir, exist_ok=True)

    img, bboxes, density_map = batch
    img = img.to(_device)
    bboxes = bboxes.to(_device)
    for i in range(density_map.size(0)):
        print(f"Ground truth density map {i} sum: {density_map[i].sum().item()}")

    _, out, _ = model(img, bboxes)
    for i in range(out.size(0)):
        print(f"Predicted density map {i} sum: {out[i].sum().item()}")
        original = _denormalize_to_bgr(img[i], mean, std)
        colored = _colorize_density(out[i, ...].squeeze().cpu().numpy(), original.shape[:2])
        fused = cv2.addWeighted(original, 1, colored, 0.7, 0)
        _write_png(os.path.join(out_dir, f"loca_fuse_image_{i}.png"), fused)
        _write_png(os.path.join(out_dir, f"loca_predict_density_map_{i}.png"), colored)
        _write_png(os.path.join(out_dir, f"loca_original_{i}.png"), original)
if __name__ == "__main__":
    # Pick the first available accelerator: CUDA, then MPS, otherwise the CPU.
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    device = torch.device(backend)
    print(f"Using device: {device}")
    main(device)