Pointcept/PointTransformerV2

How does your method handle empty voxels?

xiaobaitu123344 opened this issue · 4 comments

How does your method handle empty voxels?

PTv2 is a point-based method. What do you mean by "empty voxels" in your question?

Isn't the partition-based pooling in your paper based on voxels?

No. Partition-based pooling is defined as separating a point cloud into non-overlapping partitions and fusing the points that share the same partition. In our grid pooling implementation, we likewise compute which grid partition each point belongs to and then fuse the points within each partition. The implementation is attached below.

From my perspective, there is not much difference between voxel-based and point-based methods. A voxel may be just a kind of point obtained after grid sampling, and you may have noticed that current point-based methods also apply voxelization during data augmentation to downsample points.

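import torch
import torch.nn as nn
from torch_geometric.nn.pool import voxel_grid
from torch_scatter import segment_csr

# PointBatchNorm, offset2batch and batch2offset are helpers from the PTv2 codebase
# that are not shown in this snippet.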

class GridPool(nn.Module):
    """
    Partition-based Pooling (Grid Pooling)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 grid_size,
                 bias=False):
        super(GridPool, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grid_size = grid_size

        self.fc = nn.Linear(in_channels, out_channels, bias=bias)
        self.norm = PointBatchNorm(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, points, start=None):
        coord, feat, offset = points
        batch = offset2batch(offset)
        feat = self.act(self.norm(self.fc(feat)))
        # Per-sample minimum coordinate, used as the grid origin unless one is given.
        start = segment_csr(coord, torch.cat([batch.new_zeros(1), torch.cumsum(batch.bincount(), dim=0)]),
                            reduce="min") if start is None else start
        # Assign every point the id of the grid cell (partition) it falls into.
        cluster = voxel_grid(pos=coord - start[batch], size=self.grid_size, batch=batch, start=0)
        # Relabel the occupied cells as 0..K-1; cells without points never appear here.
        unique, cluster, counts = torch.unique(cluster, sorted=True, return_inverse=True, return_counts=True)
        # Group points by cell and build CSR pointers for the segment reductions.
        _, sorted_cluster_indices = torch.sort(cluster)
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # Fuse each cell: mean of the coordinates, channel-wise max of the features.
        coord = segment_csr(coord[sorted_cluster_indices], idx_ptr, reduce="mean")
        feat = segment_csr(feat[sorted_cluster_indices], idx_ptr, reduce="max")
        batch = batch[idx_ptr[:-1]]
        offset = batch2offset(batch)
        return [coord, feat, offset], cluster
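
To make the empty-voxel point concrete, the same idea can be sketched without any dependencies. This is only an illustration, not the PTv2 implementation: the function name `grid_pool` and the toy data are hypothetical, it assumes a single sample (no batch handling), and it does not reproduce the cell ordering of `voxel_grid`. It shows that only cells containing at least one point are ever created, so empty cells simply produce no output point.

```python
from collections import defaultdict
from math import floor


def grid_pool(coords, feats, grid_size):
    """Hypothetical pure-Python sketch of grid pooling for one point cloud.

    coords: list of (x, y, z) tuples; feats: list of equal-length feature lists.
    Returns (pooled_coords, pooled_feats, cluster), where cluster[i] is the
    index of the pooled point that input point i was merged into.
    """
    # Shift by the per-axis minimum so that cell indices start at 0.
    start = [min(c[d] for c in coords) for d in range(3)]

    # Map each occupied cell to the indices of the points inside it.
    # Cells that contain no points are never inserted into the dict.
    cells = defaultdict(list)
    for i, c in enumerate(coords):
        key = tuple(floor((c[d] - start[d]) / grid_size) for d in range(3))
        cells[key].append(i)

    pooled_coords, pooled_feats = [], []
    cluster = [0] * len(coords)
    for new_idx, key in enumerate(sorted(cells)):
        members = cells[key]
        # Coordinates are averaged; features are max-reduced per channel.
        pooled_coords.append(
            tuple(sum(coords[i][d] for i in members) / len(members) for d in range(3))
        )
        pooled_feats.append([max(ch) for ch in zip(*(feats[i] for i in members))])
        for i in members:
            cluster[i] = new_idx
    return pooled_coords, pooled_feats, cluster


# Two points share the first cell; the third lies four cells away along x.
coords = [(0.0, 0.0, 0.0), (0.1, 0.1, 0.0), (2.0, 0.0, 0.0)]
feats = [[1.0, 5.0], [3.0, 2.0], [4.0, 4.0]]
pooled_coords, pooled_feats, cluster = grid_pool(coords, feats, grid_size=0.5)
```

In this toy input the grid spans five cells along x, yet only two pooled points come out, because the three cells in between contain no points and are never materialized.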

Thank you for your patience. I could not reply sooner because of the epidemic, and I apologize for the delay.