mit-han-lab/pvcnn

The code can be sped up with PyTorch grid_sample

SteveJunGao opened this issue · 5 comments

Many thanks for releasing the code!

While trying the code, I found that one key function, shown below, can be effectively accelerated using built-in PyTorch functions:

voxel_features = trilinear_devoxelize(voxel_features, voxel_coords, resolution)

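import torch
import torch.nn.functional as F

def pytorch_grid_sample(voxel_features, coords, r):
    # Assumed shapes: voxel_features [B, C, r, r, r], coords [B, 3, N] in voxel units.
    # Map voxel-center coordinates to grid_sample's [-1, 1] range (align_corners=False).
    coords = (coords * 2 + 1.0) / r - 1.0
    coords = coords.permute(0, 2, 1).reshape(coords.shape[0], 1, 1, -1, 3)
    # grid_sample reads the last axis as (W, H, D), so reverse (x, y, z) to (z, y, x).
    coords = torch.flip(coords, dims=[-1])
    f = F.grid_sample(input=voxel_features, grid=coords, padding_mode='border', align_corners=False)
    # [B, C, 1, 1, N] -> [B, C, N]
    f = f.squeeze(dim=2).squeeze(dim=2)
    return f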

pytorch_voxel_features = pytorch_grid_sample(voxel_features, voxel_coords, resolution)
pvcnn_voxel_features = trilinear_devoxelize(voxel_features, voxel_coords, resolution)

The maximum difference between pytorch_voxel_features and pvcnn_voxel_features is less than 1e-6, but the running time of pytorch_grid_sample is significantly reduced.
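For readers without the compiled CUDA extension, the agreement can be checked against a hand-written interpolation. The sketch below is only an illustration: `reference_devoxelize` is a hypothetical pure-PyTorch stand-in, not the repository's `trilinear_devoxelize` kernel, and the shapes ([B, C, r, r, r] features, [B, 3, N] coordinates in voxel units) are assumptions inferred from the snippet above.

```python
import torch
import torch.nn.functional as F


def grid_sample_devoxelize(voxel_features, coords, r):
    # Assumed shapes: voxel_features [B, C, r, r, r], coords [B, 3, N] in voxel units.
    grid = (coords * 2 + 1.0) / r - 1.0
    grid = grid.permute(0, 2, 1).reshape(coords.shape[0], 1, 1, -1, 3)
    grid = torch.flip(grid, dims=[-1])  # (x, y, z) -> (z, y, x) to match (W, H, D)
    out = F.grid_sample(voxel_features, grid, padding_mode='border', align_corners=False)
    return out.squeeze(2).squeeze(2)  # [B, C, N]


def reference_devoxelize(voxel_features, coords, r):
    # Hypothetical reference: explicit trilinear interpolation over the 8 surrounding voxels.
    B, C = voxel_features.shape[:2]
    flat = voxel_features.reshape(B, C, -1)  # voxel (x, y, z) lives at x*r*r + y*r + z
    lo = coords.floor().clamp(0, r - 1)
    hi = (lo + 1).clamp(max=r - 1)
    frac = coords - lo
    corners, weights = (lo, hi), (1.0 - frac, frac)
    out = torch.zeros(B, C, coords.shape[2],
                      dtype=voxel_features.dtype, device=voxel_features.device)
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                idx = (corners[dx][:, 0] * r * r
                       + corners[dy][:, 1] * r
                       + corners[dz][:, 2]).long()                      # [B, N]
                w = weights[dx][:, 0] * weights[dy][:, 1] * weights[dz][:, 2]  # [B, N]
                gathered = flat.gather(2, idx.unsqueeze(1).expand(-1, C, -1))  # [B, C, N]
                out = out + w.unsqueeze(1) * gathered
    return out


torch.manual_seed(0)
r = 8
voxel_features = torch.randn(2, 4, r, r, r)
coords = torch.rand(2, 3, 100) * (r - 1)  # points inside [0, r - 1]
sampled = grid_sample_devoxelize(voxel_features, coords, r)
reference = reference_devoxelize(voxel_features, coords, r)
max_diff = (sampled - reference).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The coordinate mapping `(coords * 2 + 1) / r - 1` places voxel index i at the center of cell i under `align_corners=False`, which is why the two functions are expected to agree up to floating-point error for points inside the grid.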

Is there any way to speed up the voxelization, e.g. average voxelization using native PyTorch?

It might be possible to use torch.sparse.FloatTensor().to_dense(). However, it handles duplicate indices by summing them rather than averaging them. Therefore, one possible approach is to also use to_dense to count the occurrences in each voxel and then divide the sums by the counts. I'm not sure how fast it will be. It would be great if you could try it out and let us know how it works. Thanks!
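A minimal sketch of this sum-then-count idea follows. It is untested against the repository's `avg_voxelize` and makes no claim about speed. The function name `avg_voxelize_sparse`, the use of `torch.sparse_coo_tensor` as the constructor, and the tensor layouts (features [B, C, N], integer voxel coordinates [B, 3, N]) are all assumptions made for illustration.

```python
import torch


def avg_voxelize_sparse(features, vox_coords, r):
    # Assumed layouts: features [B, C, N] float, vox_coords [B, 3, N] integers in [0, r - 1].
    B, C, N = features.shape
    vox = vox_coords.long()
    flat = vox[:, 0] * r * r + vox[:, 1] * r + vox[:, 2]            # [B, N] voxel index
    batch = torch.arange(B, device=features.device).unsqueeze(1).expand(B, N)
    indices = torch.stack([batch.reshape(-1), flat.reshape(-1)])    # [2, B*N]

    # to_dense() sums the values of duplicate indices, giving per-voxel feature sums.
    values = features.permute(0, 2, 1).reshape(B * N, C)            # [B*N, C]
    sums = torch.sparse_coo_tensor(indices, values, size=(B, r ** 3, C)).to_dense()

    # The same trick with ones as values counts the points in each voxel.
    ones = torch.ones(B * N, dtype=features.dtype, device=features.device)
    counts = torch.sparse_coo_tensor(indices, ones, size=(B, r ** 3)).to_dense()

    mean = sums / counts.clamp(min=1).unsqueeze(-1)                 # empty voxels stay 0
    return mean.permute(0, 2, 1).reshape(B, C, r, r, r)


# Three points: two fall into voxel (0, 0, 0), one into voxel (1, 1, 1).
features = torch.tensor([[[1.0, 3.0, 10.0]]])                        # [1, 1, 3]
vox_coords = torch.tensor([[[0, 0, 1], [0, 0, 1], [0, 0, 1]]])       # [1, 3, 3]
grid = avg_voxelize_sparse(features, vox_coords, r=2)
print(grid[0, 0, 0, 0, 0].item(), grid[0, 0, 1, 1, 1].item())
```

Clamping the counts to at least 1 avoids a division by zero in empty voxels, which simply keep a value of 0.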

@sinAshish You can use PyTorch Scatter to emulate

return F.avg_voxelize(features, vox_coords, self.r), norm_coords
like so:

from torch_scatter import scatter_mean

index = vox_coords[:, :, 0] * self.r ** 2 + vox_coords[:, :, 1] * self.r + vox_coords[:, :, 2]
voxel_features = scatter_mean(src=features, index=index.long(), dim_size=self.r ** 3)
voxel_features = voxel_features.view(features.size(0), -1, self.r, self.r, self.r)
HLJT commented

Could you explain your voxelization code in more detail? I tried your voxelization method in the code, but unfortunately the program reported a dimension error. I can't understand the code you wrote, so I would be grateful if you could explain it.

hummat commented

Hi @HLJT, sorry for the late reply. Did you figure it out already? Otherwise, could you provide a minimal example that failed and the error you are getting?
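As a starting point for such a minimal example, the sketch below spells out every tensor shape. It does not reproduce the error reported above, whose cause is not stated in the thread. It uses native `scatter_add_` in place of `torch_scatter.scatter_mean`, and it assumes features of shape [B, C, N] and voxel coordinates of shape [B, 3, N], a layout suggested by the `permute(0, 2, 1)` in the grid_sample snippet earlier in the thread. Note that the scatter snippet above indexes `vox_coords[:, :, 0]`, which instead reads the coordinates from the last axis, so a layout mismatch is one possible (unconfirmed) source of a dimension error. The name `avg_voxelize_scatter` is hypothetical.

```python
import torch


def avg_voxelize_scatter(features, vox_coords, r):
    # Assumed layouts: features [B, C, N] float, vox_coords [B, 3, N] integers in [0, r - 1].
    B, C, N = features.shape
    vox = vox_coords.long()
    # Coordinates are on axis 1 in this layout, so index with [:, k], not [:, :, k].
    index = vox[:, 0] * r ** 2 + vox[:, 1] * r + vox[:, 2]           # [B, N]

    # Sum the features of all points that share a voxel. The index needs the
    # same shape as the source, so it is expanded across the channel axis.
    sums = torch.zeros(B, C, r ** 3, dtype=features.dtype, device=features.device)
    sums.scatter_add_(2, index.unsqueeze(1).expand(B, C, N), features)

    # Count the points per voxel once per batch element.
    counts = torch.zeros(B, r ** 3, dtype=features.dtype, device=features.device)
    counts.scatter_add_(1, index, torch.ones(B, N, dtype=features.dtype,
                                             device=features.device))

    mean = sums / counts.clamp(min=1).unsqueeze(1)                   # [B, C, r^3]
    return mean.view(B, C, r, r, r)


# Two channels, three points: two in voxel (0, 0, 0), one in voxel (1, 1, 1).
features = torch.tensor([[[1.0, 3.0, 10.0],
                          [2.0, 4.0, 6.0]]])                         # [1, 2, 3]
vox_coords = torch.tensor([[[0, 0, 1], [0, 0, 1], [0, 0, 1]]])       # [1, 3, 3]
voxels = avg_voxelize_scatter(features, vox_coords, r=2)
print(voxels.shape)
```

If the coordinates are stored as [B, N, 3] instead, transposing them with `vox_coords.permute(0, 2, 1)` before the call brings them into the layout assumed here.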