NVlabs/PWC-Net

Correlation - Python Implementation

ndimitriou opened this issue · 2 comments

Why is there an external dependency for computing the cross-correlation between feature pyramids? If f1 and f2 are the features for image1 and image2, I believe code roughly like the following would do the job and is much simpler:

```python
import torch

# f1, f2: features of image1 and image2 at one pyramid level, shape (B, C, H, W)
B, C, H, W = f1.shape
cost_vol_lev = torch.empty((B, 81, H, W), device=f1.device)  # the cost volume for a single level
f1_norm = torch.sqrt(torch.sum(f1 ** 2, 1) + 1e-10)  # small constant avoids division by zero
k = 0
for i in range(-4, 5):  # assuming a 9x9 search window
    for j in range(-4, 5):
        # shift the second tensor; torch.roll is a circular shift (wraps around the borders)
        f2_rolled = torch.roll(f2, shifts=(i, j), dims=(2, 3))
        product = f1 * f2_rolled
        f2_rolled_norm = torch.sqrt(torch.sum(f2_rolled ** 2, 1) + 1e-10)
        corr = torch.mean(product, 1)
        norm_fac = f1_norm * f2_rolled_norm
        corr = corr / norm_fac  # normalizing
        cost_vol_lev[:, k, :, :] = corr
        k = k + 1
```

Am I missing something (on the backpropagation step perhaps)?
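Two things the question is unsure about can be checked directly. The sketch below (the toy tensors and the helper name `shifted_corr` are illustrative, not from the thread) shows that `torch.roll` is a circular shift, so features near one border get compared with features from the opposite border rather than with zero padding, and that autograd differentiates through `roll`, `mul` and `mean` without any custom backward pass.

```python
import torch
import torch.nn.functional as F

# A 1x1x1x4 "feature map": torch.roll is circular, so the value pushed
# past the right border reappears on the left.
x = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
rolled = torch.roll(x, shifts=1, dims=3)
print(rolled)  # tensor([[[[4., 1., 2., 3.]]]])

# A zero-padded shift fills the vacated column with zeros instead.
padded_shift = F.pad(x, (1, 0))[..., :-1]
print(padded_shift)  # tensor([[[[0., 1., 2., 3.]]]])

# gradcheck compares autograd's gradients with finite differences
# (double precision is needed for the numerical comparison).
f1 = torch.randn(1, 3, 5, 5, dtype=torch.double, requires_grad=True)
f2 = torch.randn(1, 3, 5, 5, dtype=torch.double, requires_grad=True)

def shifted_corr(a, b):
    # channel-mean correlation for a single displacement, shifts=(1, -1)
    return torch.mean(a * torch.roll(b, shifts=(1, -1), dims=(2, 3)), 1)

print(torch.autograd.gradcheck(shifted_corr, (f1, f2)))  # True
```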

With current versions of PyTorch, I think it would be even simpler to implement the correlation function using torch.nn.functional.unfold.
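To see how the `unfold` output can be reshaped into one slice per displacement, here is a small sketch with toy sizes (`md = 1` and the shapes are chosen only for illustration). It relies on `F.unfold` returning a tensor of shape `(b, c * k * k, h * w)` with channels as the outer factor, and checks that window index `(dy + md) * k + (dx + md)` holds the zero-padded neighbour at offset `(dy, dx)`.

```python
import torch
import torch.nn.functional as F

md = 1                  # max displacement (small value for illustration)
k = 2 * md + 1          # window size
b, c, h, w = 1, 2, 4, 5
f = torch.arange(b * c * h * w, dtype=torch.float32).view(b, c, h, w)

# unfold returns (b, c * k * k, h * w); channels are the outer factor,
# so the result can be viewed as (b, c, k*k, h, w)
cols = F.unfold(f, kernel_size=(k, k), padding=(md, md), stride=(1, 1))
print(cols.shape)  # torch.Size([1, 18, 20])
patches = cols.view(b, c, k * k, h, w)

# Window index (dy + md) * k + (dx + md) holds the zero-padded value f[y + dy, x + dx]
f_pad = F.pad(f, (md, md, md, md))
for dy in range(-md, md + 1):
    for dx in range(-md, md + 1):
        idx = (dy + md) * k + (dx + md)
        expected = f_pad[:, :, md + dy : md + dy + h, md + dx : md + dx + w]
        assert torch.equal(patches[:, :, idx], expected)
```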

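```python
import torch
import torch.nn.functional as F

def corr(f1, f2, md=4):
    b, c, h, w = f1.shape
    k = 2 * md + 1  # window size, e.g. 9 for md=4
    # 1. normalize features along the channel dimension
    f1 = f1 / torch.norm(f1, dim=1, keepdim=True)
    f2 = f2 / torch.norm(f2, dim=1, keepdim=True)
    # 2. compute the correlation: gather every k x k (zero-padded) neighbourhood of f1
    f1 = F.unfold(f1, kernel_size=(k, k), padding=(md, md), stride=(1, 1))
    f1 = f1.view([b, c, -1, h, w])  # (b, c, k*k, h, w)
    f2 = f2.view([b, c, 1, h, w])
    cost = torch.sum(f1 * f2, dim=1)  # (b, k*k, h, w)
    return cost
```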

Is it possible to implement the operation this way?
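One detail worth noting: the `corr` function above unfolds `f1`, so the window slides over image1 features while `f2` stays fixed. A cost volume indexed by image1 positions would instead keep `f1` at the centre and displace `f2`. A minimal sketch under that assumption (the helper names `corr_unfold` and `corr_loop` are hypothetical, and it uses a raw dot product without normalization) compares an unfold version against an explicit zero-padded loop:

```python
import torch
import torch.nn.functional as F

def corr_unfold(f1, f2, md=4):
    # cost[:, (dy+md)*k + (dx+md), y, x] = <f1[:, :, y, x], f2[:, :, y+dy, x+dx]>,
    # with zeros where (y+dy, x+dx) falls outside f2
    b, c, h, w = f1.shape
    k = 2 * md + 1
    f2_patches = F.unfold(f2, kernel_size=(k, k), padding=(md, md)).view(b, c, k * k, h, w)
    return torch.sum(f1.unsqueeze(2) * f2_patches, dim=1)  # (b, k*k, h, w)

def corr_loop(f1, f2, md=4):
    # Reference: explicit zero-padded shifts, one displacement at a time
    b, c, h, w = f1.shape
    f2_pad = F.pad(f2, (md, md, md, md))
    out = []
    for dy in range(-md, md + 1):
        for dx in range(-md, md + 1):
            f2_shift = f2_pad[:, :, md + dy : md + dy + h, md + dx : md + dx + w]
            out.append(torch.sum(f1 * f2_shift, dim=1))
    return torch.stack(out, dim=1)

torch.manual_seed(0)
f1 = torch.randn(2, 8, 6, 7)
f2 = torch.randn(2, 8, 6, 7)
cost = corr_unfold(f1, f2, md=2)
ref = corr_loop(f1, f2, md=2)
print(cost.shape)  # torch.Size([2, 25, 6, 7])
print(torch.allclose(cost, ref, atol=1e-5))  # True
```

As a design trade-off, `unfold` materialises a tensor with `c * (2*md + 1)**2 * h * w` elements (81 times the feature map for `md = 4`), whereas the loop only holds one shifted copy at a time. A normalised (cosine) variant can be obtained by passing `F.normalize(f1, dim=1)` and `F.normalize(f2, dim=1)` as inputs.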