The code can be sped up with PyTorch grid_sample
SteveJunGao opened this issue · 5 comments
Many thanks for releasing the code!
While trying the code, I found that one key function, shown below, can be effectively accelerated using built-in PyTorch functions:
voxel_features = trilinear_devoxelize(voxel_features, voxel_coords, resolution)
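import torch
import torch.nn.functional as F

def pytorch_grid_sample(voxel_features, coords, r):
    # Map voxel-space coordinates to [-1, 1] (voxel centers, align_corners=False).
    coords = (coords * 2 + 1.0) / r - 1.0
    # (B, 3, N) -> (B, 1, 1, N, 3), the grid layout grid_sample expects for 5D input.
    coords = coords.permute(0, 2, 1).reshape(coords.shape[0], 1, 1, -1, 3)
    # grid_sample reads the last grid dimension as (x, y, z) = (W, H, D), so reverse it.
    coords = torch.flip(coords, dims=[-1])
    f = F.grid_sample(input=voxel_features, grid=coords, padding_mode='border', align_corners=False)
    # (B, C, 1, 1, N) -> (B, C, N)
    f = f.squeeze(dim=2).squeeze(dim=2)
    return f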
pytorch_voxel_features = pytorch_grid_sample(voxel_features, voxel_coords, resolution)
pvcnn_voxel_features = trilinear_devoxelize(voxel_features, voxel_coords, resolution)
The maximum difference between pytorch_voxel_features and pvcnn_voxel_features is less than 1e-6, but the running time of pytorch_grid_sample is significantly reduced.
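The coordinate mapping can also be sanity-checked without the compiled trilinear_devoxelize extension. The sketch below is only an illustration under stated assumptions: it assumes voxel_features has shape (B, C, r, r, r), that coords has shape (B, 3, N) in voxel units, and the helper name grid_sample_devoxelize is made up for this example. At integer voxel coordinates, trilinear interpolation should return the stored voxel value, so the result is compared against a direct lookup.

```python
import torch
import torch.nn.functional as F

def grid_sample_devoxelize(voxel_features, coords, r):
    # Assumed shapes: voxel_features (B, C, r, r, r), coords (B, 3, N) in voxel units.
    grid = (coords * 2 + 1.0) / r - 1.0  # voxel center i -> (2i + 1) / r - 1
    grid = grid.permute(0, 2, 1).reshape(coords.shape[0], 1, 1, -1, 3)
    grid = torch.flip(grid, dims=[-1])   # (x, y, z) -> (z, y, x) to match (W, H, D)
    out = F.grid_sample(voxel_features, grid, mode='bilinear',
                        padding_mode='border', align_corners=False)
    return out.squeeze(2).squeeze(2)     # (B, C, N)

torch.manual_seed(0)
r = 4
voxels = torch.rand(2, 3, r, r, r)
idx = torch.randint(0, r, (2, 3, 10))    # integer voxel coordinates

sampled = grid_sample_devoxelize(voxels, idx.float(), r)

# Direct lookup of the same voxels for comparison.
b = torch.arange(2).view(2, 1).expand(2, 10)
direct = voxels[b, :, idx[:, 0], idx[:, 1], idx[:, 2]].permute(0, 2, 1)

max_err = (sampled - direct).abs().max().item()
print(max_err)
```

If the flip or the normalization were wrong, max_err would be far from zero, so this is a cheap check to run before comparing against the CUDA kernel.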
Is there any way to speed up the voxelization? For example, average voxelization using native PyTorch?
It might be possible to use torch.sparse.FloatTensor().to_dense(). However, it handles duplicates by summing them rather than averaging them. Therefore, one possible way is to also use to_dense to count the occurrences in each voxel and then divide the sums by the counts. I'm not sure how fast it will be. It would be great if you could try it out and let us know how it works. Thanks!
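A minimal sketch of that sum-and-count idea follows. It is untimed and rests on assumptions: features has shape (B, C, N), vox_coords holds integer voxel indices with shape (B, 3, N), and the function name avg_voxelize_dense is hypothetical. It uses torch.sparse_coo_tensor (the current constructor for sparse COO tensors) and appends a channel of ones, so a single to_dense call yields both the per-voxel sums and the per-voxel counts.

```python
import torch

def avg_voxelize_dense(features, vox_coords, r):
    # Assumed shapes: features (B, C, N), vox_coords (B, 3, N) with values in [0, r - 1].
    B, C, N = features.shape
    vox_coords = vox_coords.long()
    # Flatten (x, y, z) into a single voxel index per point: (B, N).
    flat = vox_coords[:, 0] * r * r + vox_coords[:, 1] * r + vox_coords[:, 2]
    batch = torch.arange(B, device=features.device).view(B, 1).expand(B, N)
    indices = torch.stack([batch.reshape(-1), flat.reshape(-1)])  # (2, B * N)
    # Extra channel of ones: summing it gives the number of points per voxel.
    ones = torch.ones(B, 1, N, dtype=features.dtype, device=features.device)
    values = torch.cat([features, ones], dim=1).permute(0, 2, 1).reshape(B * N, C + 1)
    # to_dense() sums the values of duplicate indices.
    dense = torch.sparse_coo_tensor(indices, values, size=(B, r ** 3, C + 1)).to_dense()
    sums, counts = dense[..., :C], dense[..., C:]
    mean = sums / counts.clamp(min=1)  # empty voxels stay zero
    return mean.permute(0, 2, 1).reshape(B, C, r, r, r)

torch.manual_seed(0)
B, C, N, r = 2, 3, 50, 4
features = torch.rand(B, C, N)
vox_coords = torch.randint(0, r, (B, 3, N))
voxels = avg_voxelize_dense(features, vox_coords, r)
print(voxels.shape)
```

Whether this is faster than the CUDA kernel would need to be measured; the dense intermediate has B * r^3 * (C + 1) elements, which may matter at high resolutions.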
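@sinAshish You can use PyTorch Scatter to emulate the code at line 25 in 476715b, like so:
from torch_scatter import scatter_mean

# Assumes features is (B, C, N) and vox_coords is (B, N, 3).
# For (B, 3, N) coordinates, index with vox_coords[:, 0], vox_coords[:, 1], vox_coords[:, 2] instead.
index = vox_coords[:, :, 0] * self.r ** 2 + vox_coords[:, :, 1] * self.r + vox_coords[:, :, 2]
# Add a channel dimension so the (B, N) index broadcasts against the (B, C, N) features.
voxel_features = scatter_mean(src=features, index=index.long().unsqueeze(1), dim_size=self.r ** 3)
voxel_features = voxel_features.view(features.size(0), -1, self.r, self.r, self.r)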
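For readers who hit shape errors with the snippet above, the same averaging can be written with plain torch.Tensor.scatter_add_, which makes every shape explicit. This is a sketch under assumptions, not the repository's implementation: features is taken to be (B, C, N), vox_coords to be (B, N, 3) as the indexing above implies, and the function name scatter_mean_voxelize is invented for this example.

```python
import torch

def scatter_mean_voxelize(features, vox_coords, r):
    # Assumed shapes: features (B, C, N), vox_coords (B, N, 3) with values in [0, r - 1].
    B, C, N = features.shape
    vox_coords = vox_coords.long()
    index = vox_coords[:, :, 0] * r ** 2 + vox_coords[:, :, 1] * r + vox_coords[:, :, 2]  # (B, N)
    # scatter_add_ needs an index with the same number of dimensions as the source.
    index = index.unsqueeze(1).expand(B, C, N)
    sums = torch.zeros(B, C, r ** 3, dtype=features.dtype, device=features.device)
    sums.scatter_add_(2, index, features)
    counts = torch.zeros_like(sums).scatter_add_(2, index, torch.ones_like(features))
    return (sums / counts.clamp(min=1)).view(B, C, r, r, r)

# Three points, one channel, a 2 x 2 x 2 grid: two points share voxel (0, 0, 0).
features = torch.tensor([[[1.0, 3.0, 5.0]]])                       # (1, 1, 3)
vox_coords = torch.tensor([[[0, 0, 0], [0, 0, 0], [1, 1, 1]]])     # (1, 3, 3) = (B, N, 3)
voxels = scatter_mean_voxelize(features, vox_coords, r=2)
print(voxels[0, 0, 0, 0, 0].item(), voxels[0, 0, 1, 1, 1].item())  # 2.0 5.0
```

If the coordinates are stored as (B, 3, N) instead, transposing them with vox_coords.permute(0, 2, 1) before the call should give the layout assumed here.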
Could you explain your voxelization code in more detail? I tried your voxelization method in the code, but unfortunately the program reported a dimension error. I can't understand the code you wrote, so I would be grateful if you could explain it.