rusty1s/pytorch_cluster

`radius_graph` bug: it might produce `max_num_neighbors + 1` neighbors for both CPU and GPU versions

stslxg-nv opened this issue

Hi team, I found that `radius_graph` can produce `max_num_neighbors + 1` neighbors per node in both the CPU and GPU versions.
A minimal reproducer for CPU, adapted from the example in the README:

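```python
import torch
from torch_cluster import radius_graph

x = torch.tensor([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]])
# r=100 puts every pair of points in range, so each node should get exactly 1 neighbor.
edge_index = radius_graph(x, r=100, max_num_neighbors=1)
print(edge_index)
```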

We get

```
tensor([[1, 0, 0, 1, 0, 1],
        [0, 1, 2, 2, 3, 3]])
```

Note that nodes 2 and 3 each get 2 neighbors instead of 1.
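A quick way to spot the violation is to count how many edges point at each center node. The sketch below hard-codes the tensor printed above as plain Python lists and assumes the default `flow="source_to_target"`, under which `edge_index[1]` holds the center nodes and `edge_index[0]` their neighbors.

```python
from collections import Counter

# edge_index as printed above: row 0 = neighbor (source), row 1 = center node (target)
edge_index = [[1, 0, 0, 1, 0, 1],
              [0, 1, 2, 2, 3, 3]]
max_num_neighbors = 1

# Count how many neighbors each center node received.
neighbor_counts = Counter(edge_index[1])
over_limit = {node: n for node, n in neighbor_counts.items() if n > max_num_neighbors}

print(dict(neighbor_counts))  # {0: 1, 1: 1, 2: 2, 3: 2}
print(over_limit)             # {2: 2, 3: 2}
```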

The minimal reproducer for GPU is similar:

```python
import torch
from torch_cluster import radius_graph
x = torch.tensor([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]]).cuda()
edge_index = radius_graph(x, r=100, max_num_neighbors=1)
print(edge_index)
```

We also get

```
tensor([[1, 0, 0, 1, 0, 1],
        [0, 1, 2, 2, 3, 3]], device='cuda:0')
```

I believe this is due to this line.
When self-loops are not allowed (`loop=False`), the code assumes that each node will find itself as one of its neighbors and that this self-match will then be deleted by these lines. That is why it searches for `max_num_neighbors + 1` neighbors instead of just `max_num_neighbors`.
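The assumed pattern can be modeled in a few lines of plain Python. `truncate_then_drop_self` is a hypothetical helper, not the library's code: it keeps the first `max_num_neighbors + 1` candidates in whatever order a search produced them and then removes the self-match. The result only has the right size when the node actually appears among those candidates.

```python
def truncate_then_drop_self(i, candidates, max_num_neighbors):
    """Hypothetical model: keep the first max_num_neighbors + 1 candidates,
    then delete the self-match (if it is there at all)."""
    kept = candidates[: max_num_neighbors + 1]
    return [j for j in kept if j != i]

# Self is among the first k + 1 candidates: exactly k neighbors remain.
print(truncate_then_drop_self(0, [0, 1, 2, 3], max_num_neighbors=1))  # [1]

# Self is NOT among the first k + 1 candidates: nothing is deleted, k + 1 remain.
print(truncate_then_drop_self(2, [0, 1, 2, 3], max_num_neighbors=1))  # [0, 1]
```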

However, this assumption is not always true:

  • For the GPU version, if at least `max_num_neighbors + 1` nodes with a lower index than node i lie within `radius` of node i (distance <= radius), the search fills up before it ever reaches node i. Node i is then missing from its own result, so nothing gets deleted. This is what happens to nodes 2 and 3 in the minimal reproducer.
  • For the CPU version, we use a KD-tree and simply copy the first `max_num_neighbors + 1` neighbors it returns. Since those results are not sorted by distance, there is no guarantee that the node itself is among them.
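To check the GPU explanation, the sketch below is a hypothetical pure-Python model of the scan described in the first bullet, not the actual CUDA kernel. For each center node it visits candidates in index order, stops once it has `max_num_neighbors + 1` hits within `r` (using `<=`, as described above), and then drops the self-loop. Run on the reproducer's points, it yields the same `edge_index` as the GPU output.

```python
import math

def gpu_like_radius_graph(points, r, max_num_neighbors):
    """Hypothetical model of the GPU search: scan j in index order, stop after
    max_num_neighbors + 1 hits, then drop the self-loop."""
    sources, targets = [], []
    for i, p in enumerate(points):
        hits = []
        for j, q in enumerate(points):
            if math.dist(p, q) <= r:
                hits.append(j)
                if len(hits) == max_num_neighbors + 1:
                    break  # budget exhausted, possibly before reaching j == i
        for j in hits:
            if j != i:  # remove the self-loop only if it was found
                sources.append(j)
                targets.append(i)
    return [sources, targets]

x = [(-1.0, -1.0), (-1.0, 1.0), (1.0, -1.0), (1.0, 1.0)]
print(gpu_like_radius_graph(x, r=100, max_num_neighbors=1))
# [[1, 0, 0, 1, 0, 1], [0, 1, 2, 2, 3, 3]]
```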

A potential solution for the CPU version is to set `params.sorted` to `true` on this line instead. With sorting enabled, results come back ordered by distance, so each node should find itself (at distance 0) as its closest neighbor and therefore include itself in the result. However, this adds some runtime overhead.
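A similar hypothetical model of the sorted variant (plain Python, not nanoflann's API) sorts the in-range candidates by distance before truncating. On the reproducer's points every node is at distance 0 from itself and at least 2 from every other node, so it always comes first and each node ends up with exactly one neighbor.

```python
import math

def sorted_radius_neighbors(points, i, r, max_num_neighbors):
    """Hypothetical model of a sorted radius search (nearest first), followed by
    the same truncate-to-(k + 1)-then-drop-self step."""
    in_range = [j for j, q in enumerate(points) if math.dist(points[i], q) <= r]
    in_range.sort(key=lambda j: math.dist(points[i], points[j]))  # stable sort
    kept = in_range[: max_num_neighbors + 1]
    return [j for j in kept if j != i]

x = [(-1.0, -1.0), (-1.0, 1.0), (1.0, -1.0), (1.0, 1.0)]
for i in range(len(x)):
    print(i, sorted_radius_neighbors(x, i, r=100, max_num_neighbors=1))
# prints: 0 [1] / 1 [0] / 2 [0] / 3 [1]
```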

Even with sorting, though, the CPU fix is not complete: with coincident points, several nodes are at distance 0 from node i, and any of them may be ordered ahead of node i itself, so node i is still not guaranteed to appear in its own result.
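The coincident-point case can be reproduced with the same kind of hypothetical model. Below, nodes 0, 1 and 2 sit at the same location, so a stable distance sort puts 0 and 1 ahead of node 2 itself and node 2 keeps two neighbors. The second helper, `exclude_self_then_truncate`, is an illustrative alternative that this issue does not propose: it skips `j == i` explicitly before truncating, so it does not depend on where the self-match lands.

```python
import math

def sorted_radius_neighbors(points, i, r, max_num_neighbors):
    # Same hypothetical sorted model: nearest first, keep k + 1, drop self.
    in_range = [j for j, q in enumerate(points) if math.dist(points[i], q) <= r]
    in_range.sort(key=lambda j: math.dist(points[i], points[j]))  # ties keep index order
    kept = in_range[: max_num_neighbors + 1]
    return [j for j in kept if j != i]

def exclude_self_then_truncate(points, i, r, max_num_neighbors):
    # Hypothetical alternative: filter out j == i first, then keep k neighbors.
    in_range = [j for j, q in enumerate(points)
                if j != i and math.dist(points[i], q) <= r]
    return in_range[:max_num_neighbors]

# Nodes 0, 1 and 2 coincide, so all three are at distance 0 from node 2.
x = [(0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (5.0, 5.0)]
print(sorted_radius_neighbors(x, 2, r=100, max_num_neighbors=1))     # [0, 1]
print(exclude_self_then_truncate(x, 2, r=100, max_num_neighbors=1))  # [0]
```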