rusty1s/pytorch_cluster

Speed issues

de-gozaru opened this issue · 1 comment

Hi @rusty1s,

I was comparing the speed of `torch_cluster.knn_graph` on the GPU with the function below:

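```python
import torch

def knn(x, k):
    # x: (batch_size, num_dims, num_points)
    inner = -2 * torch.matmul(x.transpose(2, 1), x)       # (B, N, N): -2 * <x_i, x_j>
    xx = torch.sum(x**2, dim=1, keepdim=True)             # (B, 1, N): squared norms
    pairwise_distance = -xx - inner - xx.transpose(2, 1)  # (B, N, N): negative squared distances

    # Largest negative distance = nearest point (each point is usually its own nearest)
    idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)
    return idx
```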
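The function relies on expanding the squared Euclidean distance. With `x` of shape (batch_size, num_dims, num_points) and $x_i$ denoting the $i$-th point (column), `inner` holds $-2\,x_i^\top x_j$, `xx` holds $\lVert x_j \rVert^2$ (broadcast along rows), and `xx.transpose(2, 1)` holds $\lVert x_i \rVert^2$, so each entry of `pairwise_distance` is:

$$
\texttt{pairwise\_distance}_{ij} = -\lVert x_i \rVert^2 + 2\,x_i^\top x_j - \lVert x_j \rVert^2 = -\lVert x_i - x_j \rVert^2
$$

Taking `topk` of the largest values along the last dimension therefore selects the $k$ smallest distances, i.e. the $k$ nearest neighbors of each point. Because $\lVert x_i - x_i \rVert^2 = 0$, each point typically appears among its own $k$ neighbors, and the whole computation is one dense $N \times N$ matrix product per batch element.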

and found that `torch_cluster.knn_graph` is much slower, roughly 8x. Is this expected?
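When timing GPU code, CUDA kernels are launched asynchronously, so reading a timer without synchronizing can measure only launch overhead, and the first call also pays one-time setup costs. A minimal, framework-agnostic timing helper is sketched below, assuming each implementation is wrapped as a zero-argument callable; `benchmark` and its parameters are hypothetical names, not part of `torch_cluster` or PyTorch.

```python
import time
from statistics import median
from typing import Callable, Optional


def benchmark(fn: Callable[[], object],
              warmup: int = 3,
              repeats: int = 10,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall-clock time of fn() in seconds.

    `sync` (hypothetical hook) is called around every timed region. On a GPU,
    pass torch.cuda.synchronize so queued kernels finish inside the window.
    """
    def _sync() -> None:
        if sync is not None:
            sync()

    # Warm-up runs absorb one-time costs such as CUDA context creation.
    for _ in range(warmup):
        fn()
    _sync()

    times = []
    for _ in range(repeats):
        _sync()                      # start from an idle device
        start = time.perf_counter()
        fn()
        _sync()                      # wait for asynchronous work before stopping
        times.append(time.perf_counter() - start)
    return median(times)
```

On a GPU this might be called as `benchmark(lambda: knn(x, k), sync=torch.cuda.synchronize)`. Note that the two functions expect different layouts: the dense `knn` above takes `(batch_size, num_dims, num_points)`, while `knn_graph` takes points as rows plus an optional `batch` vector and returns an edge index rather than a dense `(B, N, k)` tensor, so the inputs and outputs should be checked for equivalence before comparing timings.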

Thank you in advance!