rusty1s/pytorch_cluster

'nearest' slows down after one run

dddraxxx opened this issue · 3 comments

Hi, I am using `nearest` in my network training, but I found that it slows down sharply after the first run, and the batched call is even slower. The first run takes only around 10 milliseconds, while later runs need tens of seconds. Below is code to reproduce:

import time

import torch
from torch_cluster import nearest

torch.manual_seed(12345)

N = 100 * 100 * 100  # 1,000,000 points per batch element

A = torch.randn(4 * N, 3)
B = torch.randn(4 * N, 3)

A_cu = A.cuda()
B_cu = B.cuda()

# CUDA kernels launch asynchronously, so synchronize before reading timers.
torch.cuda.synchronize()
st3 = time.time()
for i in range(4):
    nn_torch = nearest(B_cu[i * N:(i + 1) * N], A_cu[i * N:(i + 1) * N])
    torch.cuda.empty_cache()
torch.cuda.synchronize()
print("no_batch time:", time.time() - st3)
del nn_torch
torch.cuda.empty_cache()

torch.cuda.synchronize()
st2 = time.time()
batch_x = torch.arange(4).repeat_interleave(N).cuda()
batch_y = torch.arange(4).repeat_interleave(N).cuda()
nn_torch = nearest(B_cu, A_cu, batch_x=batch_x, batch_y=batch_y)
torch.cuda.synchronize()
print("batch time:", time.time() - st2)

I really like the repository and I want to know if there is any solution for it. Thanks!

Thanks for reporting. I also tried torch_cluster.knn with k=1, and it is indeed faster for N=100*100*10 - but for anything larger it likewise gets increasingly slow. The current implementation does not construct a kD-tree; it performs the neighbor search in a brute-force fashion with O(N^2) comparisons. You might be able to speed this up with a dedicated kNN package such as faiss-gpu.
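For intuition, the brute-force search described above can be sketched in plain PyTorch: build the full pairwise distance matrix, then take a row-wise argmin. This illustrates why cost grows quadratically with the number of points; it is a small-scale sketch, not the library's actual CUDA kernel, and the sizes here are illustrative only.

```python
import torch

torch.manual_seed(0)
N, M, D = 1000, 800, 3
x = torch.randn(N, D)  # query points
y = torch.randn(M, D)  # reference points

# Brute-force nearest neighbor: an (N, M) distance matrix followed by a
# row-wise argmin. Memory and compute both scale with N * M, which is why
# inputs of four million points become very slow.
dist = torch.cdist(x, y)      # (N, M) pairwise Euclidean distances
nn_idx = dist.argmin(dim=1)   # index of the nearest reference per query
```

A kD-tree or an approximate index avoids materializing this full matrix, which is where the dedicated packages gain their speed.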

Thanks for your answer! I tried faiss-gpu and FRNN, and they both worked well.