mihaidusmanu/d2-net

Any suggestion for parallelization

alexlopezcifuentes opened this issue · 3 comments

Hello,

I have tested your code on my own images and so far I'm getting nice results. However, I want to speed up the process of detecting and describing keypoints. I have been looking at your code and it does not seem to support computation with batch sizes greater than 1, or at least I haven't been able to run it that way.

Is there any possibility of running the process with batch sizes greater than 1? If not, do you have any suggestions on how to parallelize the process to obtain keypoints for many images at the same time (assuming we have the GPU and CPU computational power)?

Thanks.

Hello. All the modules implemented in https://github.com/mihaidusmanu/d2-net/blob/master/lib/model_test.py are already batch-friendly. However, the multi-scale pyramid is sadly not adapted, nor is the bilinear feature interpolation that we currently use (though it could be replaced by torch.nn.functional.grid_sample; see the sketch after the snippet):

d2-net/lib/pyramid.py, lines 80 to 83 at commit 2a4d88f:

raw_descriptors, _, ids = interpolate_dense_features(
    fmap_keypoints.to(device),
    dense_features[0]
)
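For illustration, here is a minimal single-image sketch of what that grid_sample replacement could look like (untested; grid_sample_descriptors is a hypothetical helper, and unlike interpolate_dense_features it does not filter out-of-bounds keypoints or return their ids):

import torch
import torch.nn.functional as F

# fmap: [C, H, W] dense features of one image;
# fmap_keypoints: [2, N] with rows (i, j) in feature-map coordinates.
def grid_sample_descriptors(fmap, fmap_keypoints):
    c, h, w = fmap.shape
    # grid_sample expects normalized (x, y) coordinates in [-1, 1]; with
    # align_corners=True, x = 2 * j / (w - 1) - 1 lands exactly on column j.
    grid = torch.stack([
        2.0 * fmap_keypoints[1] / (w - 1) - 1.0,  # x from j
        2.0 * fmap_keypoints[0] / (h - 1) - 1.0,  # y from i
    ], dim=-1).view(1, 1, -1, 2)
    raw_descriptors = F.grid_sample(
        fmap.unsqueeze(0), grid, mode='bilinear', align_corners=True
    )  # [1, C, 1, N]
    return raw_descriptors[0, :, 0, :]  # [C, N]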

The main reason we didn't make the feature extraction script batch-friendly is that the images we evaluated on have different resolutions.

If you know that your images have the same resolution and want to extract single-scale features only, then you can run something along the lines of the following (WARNING: untested):

import torch

# Dense feature extraction on a batch of same-resolution images
# (images: [B, 3, H, W]).
dense_features = model.dense_feature_extraction(images)

# Recover detections; fmap_pos is [4, N] with rows (batch, channel, i, j).
detections = model.detection(dense_features)
fmap_pos = torch.nonzero(detections).t()

# Recover the sub-pixel displacements at the detected positions.
displacements = model.localization(dense_features)
displacements_i = displacements[
    fmap_pos[0, :], 0, fmap_pos[1, :], fmap_pos[2, :], fmap_pos[3, :]
]
displacements_j = displacements[
    fmap_pos[0, :], 1, fmap_pos[1, :], fmap_pos[2, :], fmap_pos[3, :]
]
# Keep only detections whose displacement stays inside the cell.
mask = (displacements_i.abs() < 0.5) & (displacements_j.abs() < 0.5)
fmap_pos = fmap_pos[:, mask]
valid_displacements = torch.stack([
    displacements_i[mask], displacements_j[mask]
], dim=0)

# Add displacements to detected points.
fmap_keypoints = fmap_pos[[0, 2, 3]].float()  # Drop the channel index; keep (batch, i, j).
fmap_keypoints[1:3] += valid_displacements

# TODO: Interpolate descriptors using F.grid_sample.
# This will require some tricks in the tensor representation
# (see the sketch below).

# Upscale keypoints to the resolution of the input images; this mirrors
# upscale_positions(..., scaling_steps=2) from lib/pyramid.py. Clone first
# so fmap_keypoints keeps feature-map coordinates for descriptor sampling.
keypoints = fmap_keypoints.clone()
keypoints[1:3] = (keypoints[1:3] * 2 + 0.5) * 2 + 0.5
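For the descriptor TODO above, here is a rough and untested sketch of the batched interpolation, reusing the hypothetical grid_sample_descriptors helper from the earlier snippet; it loops over the batch because each image ends up with a different number of keypoints:

import torch.nn.functional as F

# Interpolate descriptors at the (sub-pixel) feature-map keypoints.
descriptors = dense_features.new_empty(
    dense_features.shape[1], fmap_keypoints.shape[1]
)
batch_idx = fmap_keypoints[0].long()
for idx in range(dense_features.shape[0]):
    sel = batch_idx == idx
    if sel.any():
        descriptors[:, sel] = grid_sample_descriptors(
            dense_features[idx], fmap_keypoints[1:3, sel]
        )
# L2-normalize, as done for the D2-Net descriptors.
descriptors = F.normalize(descriptors, dim=0)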

For now, I'm working with a single scale, so it is not a problem that the multi-scale pyramid does not work with batches. I'll try your code and share the results.

Thanks!

Closed! Feel free to reopen or create a new issue if you run into any problems!