udaykusupati/Normal-Assisted-Stereo

vectorization

qiminchen opened this issue · 2 comments

Hi, thanks for the amazing work. Do you know how to vectorize the code for computing normals from depth? It is pretty slow with the nested for loops. I tried to vectorize it but couldn't work it out.

h, w = Z.shape[:2]
normal = np.zeros((h-win_sz, w-win_sz, 3))
for i in range(h-win_sz):
    for j in range(w-win_sz):
        norm_val = cal_patch(i, j, win_sz)
        normal[i, j] = norm_val

Yes, I agree it is very slow, and unfortunately we worked with the same loop-based code ourselves. I don't currently have a vectorized version. I can look into it later when I find some time.

Hi @udaykusupati, I found a way to vectorize the code in PyTorch, but without using val_mask, which might cause some information loss. It's pretty hard to vectorize the computation with val_mask because each window would have a different number of valid points. (Addressing this would be the next and final step of the vectorization.)

import torch
import torch.nn.functional as F

# XY: 1 x 2 x 240 x 320 (B x C x Height x Width)
# Z:  1 x 1 x 240 x 320 (B x C x Height x Width)
# win_sz is assumed to be odd so that the padding below preserves Height and Width

batch, _, height, width = Z.shape
XYZ = torch.cat((XY, Z), dim=1)  # 1 x 3 x 240 x 320
XYZ = F.pad(XYZ, (win_sz // 2, win_sz // 2, win_sz // 2, win_sz // 2), mode='reflect')  # keep the Height and Width of the output the same as input
patches = F.unfold(XYZ, kernel_size=win_sz).view(batch, 3, win_sz**2, height, width)  # 1 x 3 x win_sz**2 x 240 x 320
A = patches.permute(0, 3, 4, 1, 2)       # 1 x 240 x 320 x 3 x win_sz**2
A_t = A.permute(0, 1, 2, 4, 3)           # transpose: 1 x 240 x 320 x win_sz**2 x 3
A_At = torch.matmul(A, A_t)  # 1 x 240 x 320 x 3 x 3
normal = torch.sum(torch.matmul(A_t, A_At.pinverse()), dim=-2)  # sum over window points: 1 x 240 x 320 x 3
normal = normal.permute(0, 3, 1, 2)  # 1 x 3 x 240 x 320

This significantly speeds up the computation. For a similar implementation in NumPy, you can refer to scikit-image's view_as_windows. However, since the vectorized version doesn't filter out the invalid points, it's a bit less accurate than your original implementation. Do you have any ideas on how I should change the values in each window according to the value of the center pixel? I think your original implementation discards the invalid points and keeps only the valid ones.
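One possible way to handle the varying number of valid points per window, offered as a sketch and not as the repository's implementation: the least-squares solution only involves sums over the points in a window, so setting the rows of invalid points to zero should give the same result as discarding them. The function name `masked_normals`, the parameters `rel_tol` and `min_valid`, and the validity rule itself (positive depth that lies within a relative tolerance of the center pixel's depth) are assumptions made for illustration; the actual val_mask criterion may differ.

```python
import torch
import torch.nn.functional as F


def masked_normals(XY, Z, win_sz=5, rel_tol=0.05, min_valid=10):
    """Per-window least-squares normals that ignore invalid points.

    XY: B x 2 x H x W, Z: B x 1 x H x W, win_sz is assumed odd.
    rel_tol and min_valid are hypothetical parameters, not taken
    from the original code.
    """
    batch, _, height, width = Z.shape
    pad = win_sz // 2
    XYZ = torch.cat((XY, Z), dim=1)                          # B x 3 x H x W
    XYZ = F.pad(XYZ, (pad, pad, pad, pad), mode='reflect')   # keep H and W
    patches = F.unfold(XYZ, kernel_size=win_sz)              # B x 3*k*k x H*W
    patches = patches.view(batch, 3, win_sz ** 2, height, width)
    A = patches.permute(0, 3, 4, 2, 1)                       # B x H x W x k*k x 3

    # Assumed validity rule: depth is positive and close to the center depth.
    z = A[..., 2]                                            # B x H x W x k*k
    z_center = Z.permute(0, 2, 3, 1)                         # B x H x W x 1
    mask = (z > 0) & ((z - z_center).abs() < rel_tol * z_center)

    # Zeroed rows contribute nothing to A^T A or to A^T 1,
    # which matches dropping those points from the window.
    A_m = A * mask.to(A.dtype).unsqueeze(-1)                 # B x H x W x k*k x 3
    AtA = A_m.transpose(-1, -2) @ A_m                        # B x H x W x 3 x 3
    Atb = A_m.sum(dim=-2).unsqueeze(-1)                      # B x H x W x 3 x 1
    normal = (torch.linalg.pinv(AtA) @ Atb).squeeze(-1)      # B x H x W x 3

    # Return a zero vector where a window has too few valid points.
    enough = mask.sum(dim=-1, keepdim=True) >= min_valid     # B x H x W x 1
    normal = normal * enough.to(normal.dtype)
    return normal.permute(0, 3, 1, 2)                        # B x 3 x H x W
```

With XY and Z shaped as in the snippet above, `normal = masked_normals(XY, Z, win_sz)` would return a 1 x 3 x 240 x 320 tensor. The mask is applied as a multiplicative weight, so every window keeps the same tensor shape regardless of how many of its points are valid.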