vectorization
qiminchen opened this issue · 2 comments
Hi, thanks for the amazing work. Do you know how to vectorize the code for computing normals from depth? It is pretty slow using a for loop. I tried to vectorize it myself but couldn't work it out.
Normal-Assisted-Stereo/convert_normal.py
Lines 52 to 57 in 491c0d3
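For reference, here is a minimal sketch of the kind of per-pixel loop being discussed, fitting a plane X · n = 1 to the back-projected points in each local window. All names, shapes, and the window size below are illustrative assumptions, not the repository's actual code:

```python
import numpy as np

def normals_from_points(xyz, win_sz=9):
    # xyz: H x W x 3 back-projected 3D points (hypothetical layout)
    h, w, _ = xyz.shape
    r = win_sz // 2
    normal = np.zeros((h, w, 3))
    for i in range(r, h - r):
        for j in range(r, w - r):
            # Stack the window's points as rows of A and solve A n ~= 1
            A = xyz[i - r:i + r + 1, j - r:j + r + 1].reshape(-1, 3)
            n, *_ = np.linalg.lstsq(A, np.ones(A.shape[0]), rcond=None)
            normal[i, j] = n / (np.linalg.norm(n) + 1e-8)
    return normal
```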
Yes, I agree it is very slow; we worked with the same slow loop too, unfortunately. I don't currently have a vectorized version, but I can look into it when I find some time.
Hi @udaykusupati, I found a way to vectorize the code in PyTorch, but without using `val_mask`, which might cause some information loss. It's pretty hard to vectorize the computation with `val_mask` because each window can have a different number of valid points (how to address this would be the next/final step of vectorization).
```python
import torch
import torch.nn.functional as F

# XY: 1 x 2 x 240 x 320 (B x C x Height x Width)
# Z:  1 x 1 x 240 x 320 (B x C x Height x Width)
XYZ = torch.cat((XY, Z), dim=1)  # 1 x 3 x 240 x 320
batch, _, height, width = XYZ.shape
# Reflect-pad so the output keeps the same Height and Width as the input
XYZ = F.pad(XYZ, (win_sz // 2, win_sz // 2, win_sz // 2, win_sz // 2), mode='reflect')
patches = F.unfold(XYZ, kernel_size=win_sz).view(batch, 3, win_sz**2, height, width)  # 1 x 3 x win_sz**2 x 240 x 320
A = patches.permute(0, 3, 4, 1, 2)  # 1 x 240 x 320 x 3 x win_sz**2
A_t = A.permute(0, 1, 2, 4, 3)      # transpose: 1 x 240 x 320 x win_sz**2 x 3
A_At = torch.matmul(A, A_t)         # 1 x 240 x 320 x 3 x 3
# Least-squares plane fit per window: solve (A A^T) n = A . 1
normal = torch.sum(torch.matmul(A_t, A_At.pinverse()), dim=-2)  # 1 x 240 x 320 x 3
normal = normal.permute(0, 3, 1, 2)  # 1 x 3 x 240 x 320
```
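One small note: since this solves X · n = 1 per window, the resulting `normal` is not unit length; depending on how it's consumed downstream, you may want to normalize it:

```python
normal = F.normalize(normal, dim=1)  # scale each normal to unit length
```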
This significantly speeds up the computation (you can look at `view_as_windows` for a similar implementation in numpy), but since it doesn't filter out the invalid points, it's a bit less accurate than your original implementation. Do you have any ideas on how I could change the values in each window according to the value of the center pixel? I think your original implementation discards the invalid points and keeps only the valid ones.
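One idea that might work (untested, so treat it as a sketch): apply `val_mask` by zeroing out the invalid points in the unfolded patches before the matmuls. A zeroed point contributes nothing to either A Aᵀ or A · 1, so it is effectively dropped from the per-window least-squares fit, and `pinverse` should still cope with windows left with too few valid points. Assuming `val_mask` is a 1 x 1 x 240 x 320 binary tensor:

```python
pad = win_sz // 2
# Pad with zeros so points outside the image count as invalid
mask = F.pad(val_mask.float(), (pad, pad, pad, pad), mode='constant', value=0)
mask = F.unfold(mask, kernel_size=win_sz).view(batch, 1, win_sz**2, height, width)
patches = patches * mask            # zero out invalid points in every window
A = patches.permute(0, 3, 4, 1, 2)  # 1 x 240 x 320 x 3 x win_sz**2
A_t = A.permute(0, 1, 2, 4, 3)
A_At = torch.matmul(A, A_t)
normal = torch.sum(torch.matmul(A_t, A_At.pinverse()), dim=-2)
normal = normal.permute(0, 3, 1, 2)
```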