Details for cross-dataset validation on the CARPK dataset
Hi @djukicn, thanks for your excellent work! While trying to follow your experiments, I found that the code for cross-dataset validation is missing. Could you please provide details about how to perform cross-dataset validation on the CARPK dataset?
You mentioned that you used 12 exemplars (pre-cropped images) for all the test images in CARPK, which I guess are the ones provided by the authors of BMNet. I wonder how I could pass the pre-cropped exemplars into your current framework.
I would appreciate it if you could provide the code for it, thanks!
Hi @5k5000, thank you for your question. As the BMNet paper does not specify exactly which exemplars were used, we randomly sampled 12 objects from the training set and used them in our experiments.
As you noticed, our framework does not easily support pre-cropped exemplars that are not part of the input image, because such an approach deviates from the definition of few-shot counting: an input image and a few bounding boxes around exemplar objects.
To accommodate this experiment, I wrote a script that first extracts shape and appearance queries for the 12 exemplars and then uses them in every iterative adaptation forward pass during testing. The code snippet below should be enough to reproduce the results from the paper. You can find the pretrained model here (it was trained on FSC147 with the car images excluded).
You can run the script in the same way as evaluate.py, making sure that you set --model_name loca_carpk and that your data_path points to the CARPK dataset.
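For reference, a possible invocation could look like this (the script name evaluate_carpk.py is just an example, and the --data_path and --model_path flags are assumed from the args.data_path and args.model_path fields used in the snippet below):

python evaluate_carpk.py --model_name loca_carpk --model_path /path/to/checkpoints --data_path /path/to/CARPK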
from models.loca import build_model
from utils.arg_parser import get_argparser
import argparse
import os
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from torchvision.ops import roi_align
from PIL import Image
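# Minimal CARPK dataset wrapper: loads each image, resizes it to img_size x img_size,
# normalizes it with ImageNet statistics, and rescales the ground-truth bounding boxes
# from the original image resolution to the resized coordinate frame.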
class CARPKDataset(Dataset):
def __init__(
self, data_path, img_size, split='train'
):
self.split = split
self.data_path = data_path
self.img_size = img_size
self.resize = T.Resize((img_size, img_size))
with open(os.path.join(
self.data_path,
'CARPK_devkit',
'data',
'ImageSets',
'train.txt' if split == 'train' else 'test.txt'
)) as file:
self.image_names = [line.strip() for line in file.readlines()]
def __getitem__(self, idx):
img = Image.open(os.path.join(
self.data_path,
'CARPK_devkit',
'data',
'Images',
self.image_names[idx] + '.png'
))
w, h = img.size
img = T.Compose([
T.ToTensor(),
self.resize,
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])(img)
with open(os.path.join(
self.data_path,
'CARPK_devkit',
'data',
'Annotations',
self.image_names[idx] + '.txt'
)) as file:
annotations = [list(map(int, line.strip().split())) for line in file.readlines()]
bboxes = torch.tensor(annotations)[:, :-1]
bboxes = bboxes / torch.tensor([w, h, w, h]) * self.img_size
return img, bboxes
def __len__(self):
return len(self.image_names)
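# Extract OPE queries from an exemplar image: run the backbone and (optionally) the
# transformer encoder, build shape queries from the exemplar box widths and heights,
# and pool kernel_dim x kernel_dim appearance queries from the encoded feature map
# with roi_align.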
def extract_queries(model, x, bboxes):
num_objects = bboxes.size(1) if not model.zero_shot else model.num_objects
# backbone
backbone_features = model.backbone(x)
# prepare the encoder input
src = model.input_proj(backbone_features)
bs, c, h, w = src.size()
pos_emb = model.pos_emb(bs, h, w, src.device).flatten(2).permute(2, 0, 1)
src = src.flatten(2).permute(2, 0, 1)
# push through the encoder
if model.num_encoder_layers > 0:
image_features = model.encoder(src, pos_emb, src_key_padding_mask=None, src_mask=None)
else:
image_features = src
# prepare OPE input
f_e = image_features.permute(1, 2, 0).reshape(-1, model.emb_dim, h, w)
if not model.zero_shot:
box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
shape_or_objectness = model.ope.shape_or_objectness(box_hw).reshape(
bs, -1, model.kernel_dim ** 2, model.emb_dim
).flatten(1, 2).transpose(0, 1)
else:
shape_or_objectness = model.ope.shape_or_objectness.expand(
bs, -1, -1, -1
).flatten(1, 2).transpose(0, 1)
# if not zero shot add appearance
if not model.zero_shot:
# reshape bboxes into the format suitable for roi_align
bboxes = torch.cat([
torch.arange(
bs, requires_grad=False
).to(bboxes.device).repeat_interleave(num_objects).reshape(-1, 1),
bboxes.flatten(0, 1),
], dim=1)
appearance = roi_align(
f_e,
boxes=bboxes, output_size=model.kernel_dim,
spatial_scale=1.0 / model.reduction, aligned=True
).permute(0, 2, 3, 1).reshape(
bs, num_objects * model.kernel_dim ** 2, -1
).transpose(0, 1)
else:
appearance = None
return shape_or_objectness, appearance
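# Predict a density map for a test image using the pre-extracted shape and appearance
# queries: the iterative adaptation module refines the queries against the test image
# features, and the resulting prototypes are correlated with the feature map (grouped
# conv2d) before the regression head produces the final density map.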
def predict(model, x, shape_or_objectness, appearance, num_objects):
# backbone
backbone_features = model.backbone(x)
# prepare the encoder input
src = model.input_proj(backbone_features)
bs, c, h, w = src.size()
pos_emb = model.pos_emb(bs, h, w, src.device).flatten(2).permute(2, 0, 1)
src = src.flatten(2).permute(2, 0, 1)
# push through the encoder
if model.num_encoder_layers > 0:
image_features = model.encoder(src, pos_emb, src_key_padding_mask=None, src_mask=None)
else:
image_features = src
# prepare OPE input
f_e = image_features.permute(1, 2, 0).reshape(-1, model.emb_dim, h, w)
query_pos_emb = model.ope.pos_emb(
bs, model.kernel_dim, model.kernel_dim, f_e.device
).flatten(2).permute(2, 0, 1).repeat(num_objects, 1, 1)
if model.ope.num_iterative_steps > 0:
memory = f_e.flatten(2).permute(2, 0, 1)
all_prototypes = model.ope.iterative_adaptation(
shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
)
else:
if shape_or_objectness is not None and appearance is not None:
all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
else:
all_prototypes = (
shape_or_objectness if shape_or_objectness is not None else appearance
).unsqueeze(0)
outputs = list()
for i in range(all_prototypes.size(0)):
prototypes = all_prototypes[i, ...].permute(1, 0, 2).reshape(
bs, num_objects, model.kernel_dim, model.kernel_dim, -1
).permute(0, 1, 4, 2, 3).flatten(0, 2)[:, None, ...]
response_maps = F.conv2d(
torch.cat([f_e for _ in range(num_objects)], dim=1).flatten(0, 1).unsqueeze(0),
prototypes,
bias=None,
padding=model.kernel_dim // 2,
groups=prototypes.size(0)
).view(
bs, num_objects, model.emb_dim, h, w
).max(dim=1)[0]
# send through regression heads
if i == all_prototypes.size(0) - 1:
predicted_dmaps = model.regression_head(response_maps)
else:
predicted_dmaps = model.aux_heads[i](response_maps)
outputs.append(predicted_dmaps)
return outputs[-1]
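# Cross-dataset evaluation on CARPK: sample 12 exemplar cars from the training split
# (with a fixed seed), extract their queries once, then predict a density map for every
# test image and report MAE / RMSE against the annotated car counts.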
@torch.no_grad()
def eval_carpk(model, args, num_objects=12):
train = CARPKDataset(
args.data_path, 512, split='train'
)
# randomly sample the exemplars from the training set
train_objects = list()
for i in range(len(train)):
_, bboxes = train[i]
train_objects.extend((i, j) for j in range(bboxes.size(0)))
device = torch.device('cuda:0')
model = model.to(device)
model.eval()
torch.manual_seed(1)
bbox_idx = torch.randperm(len(train_objects))[:num_objects]
bbox_idx = [train_objects[t] for t in bbox_idx]
shape, appearance = list(), list()
for i, (img_idx, object_idx) in enumerate(bbox_idx):
img, bboxes = train[img_idx]
img, bboxes = img.to(device), bboxes.to(device)
bboxes = bboxes[[object_idx], :].unsqueeze(0)
sh, app = extract_queries(model, img.unsqueeze(0), bboxes)
shape.append(sh[:args.kernel_dim ** 2])
appearance.append(app[:args.kernel_dim ** 2])
shape, appearance = torch.cat(shape, dim=0), torch.cat(appearance, dim=0)
mae, mse = 0, 0
test = CARPKDataset(
args.data_path, 512, split='test'
)
loader = DataLoader(
test,
batch_size=1,
shuffle=False,
)
for img, bboxes in loader:
img = img.to(device)
target = torch.tensor(bboxes.size(1)).to(device)
output = predict(model, img, shape, appearance, num_objects)
mae += (output.sum() - target).abs().item() / len(test)
mse += (output.sum() - target).pow(2).item() / len(test)
print("MAE", mae, "RMSE", torch.sqrt(torch.tensor(mse)).item())
if __name__ == '__main__':
parser = argparse.ArgumentParser('LOCA', parents=[get_argparser()])
args = parser.parse_args()
model = build_model(args)
state_dict = torch.load(os.path.join(args.model_path, f'{args.model_name}.pt'))['model']
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
eval_carpk(model, args, num_objects=12)
Thanks! The result is reproducible.