Possible error in sample_last_hop() function
Thank you for open sourcing the project, and for adding a lot of comments in it!
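I think there is a logical error on line 397 of utils.py: `nnz = A[nnz, new_sample].nonzero()[1]`. The issue is that `.nonzero()` returns the indices of the non-zero entries with respect to the sub-matrix `A[nnz, new_sample]`, not with respect to the original matrix `A`. The positions it returns therefore run from 0 to len(nnz) - 1 within that selection rather than indexing the full array of samples.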
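A minimal sketch of this pitfall, assuming `A` is a `scipy.sparse` CSR adjacency matrix as in the project. The toy path graph and the names `retry`, `picked`, `relative` and `absolute` are illustrative only. Indexing with two integer arrays picks one entry per (row, column) pair, and `.nonzero()` on that result reports positions inside the selection. Mapping them back through the retry indices is what recovers positions in the full arrays.

```python
import numpy as np
import scipy.sparse as sp

# Illustrative toy graph (not the graph from this issue): path 0-1-2-3-4.
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
r, c = zip(*edges)
A = sp.csr_matrix((np.ones(len(edges)), (r, c)), shape=(5, 5))
A = (A + A.T).tocsr()  # undirected graph: make the adjacency symmetric

nodes = np.arange(5)
# Suppose a previous pass flagged positions 3 and 4 of `nodes` for a retry,
# and the new candidates are 4 and 3 respectively (both are neighbors).
retry = np.array([3, 4])
new_sample = np.array([4, 3])

# A[rows, cols] with two index arrays picks one entry per (row, col) pair.
# np.asarray(...).ravel() yields a flat array whether SciPy returns a 1 x k
# np.matrix (sparse matrices) or a 1-D array (sparse arrays).
picked = np.asarray(A[nodes[retry], new_sample]).ravel()

relative = picked.nonzero()[0]  # positions inside the 2-element selection
absolute = retry[relative]      # mapped back to positions in `nodes`
print(relative)  # [0 1]
print(absolute)  # [3 4]
```

Using `relative` directly, as the buggy line effectively does, would resample positions 0 and 1 and leave positions 3 and 4 holding neighbors.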
This results in generating incorrect triplets. For example, for the following simple graph:
the triplets generated (with seed = 0) are:
0,1,3
1,0,2
2,3,4
3,2,0
4,3,2
3,2,4
3,0,4
Here the first column is the reference node, and the other two columns satisfy shortest_path(col0, col1) < shortest_path(col0, col2). The last triplet is incorrect because Node 4 is closer to Node 3 than Node 0 is. Changing line 397 to `nnz = A[nodes, sampled].nonzero()[1]` seems to fix the error.
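For comparison, here is a minimal sketch of a rejection-sampling loop with the proposed fix applied. The function name `sample_non_neighbors` and its exact contract are assumptions for illustration, not the project's `sample_last_hop()`. For each node in `nodes` it draws a random node, then redraws only the offending positions until no sample is adjacent to its reference node according to `A`. The key point, as in the proposed line, is that the check is recomputed over the full `nodes`/`sampled` arrays, so the indices it returns are positions in `sampled`.

```python
import numpy as np
import scipy.sparse as sp


def sample_non_neighbors(A, nodes, rng):
    """Hypothetical sketch: for each node in `nodes`, sample j with A[node, j] == 0.

    Assumes A is a square scipy.sparse matrix and that every node in `nodes`
    has at least one non-neighbor; otherwise the loop never terminates.
    """
    N = A.shape[0]
    sampled = rng.integers(0, N, len(nodes))

    def bad_positions():
        # Recompute over the full arrays (the proposed fix), so the returned
        # indices are positions in `nodes`/`sampled`, not in a sub-selection.
        return np.asarray(A[nodes, sampled]).ravel().nonzero()[0]

    nnz = bad_positions()
    while len(nnz) != 0:
        sampled[nnz] = rng.integers(0, N, len(nnz))  # redraw offenders only
        nnz = bad_positions()
    return sampled


# Usage on an illustrative path graph 0-1-2-3-4.
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
r, c = zip(*edges)
A = sp.csr_matrix((np.ones(len(edges)), (r, c)), shape=(5, 5))
A = (A + A.T).tocsr()
nodes = np.arange(5)
sampled = sample_non_neighbors(A, nodes, np.random.default_rng(0))
```

Recomputing over all positions costs one extra lookup per node per iteration. The alternative of tracking only the retried positions also works, provided the result is mapped back through them (e.g. `nnz = nnz[...nonzero()[0]]`) and the rows are taken as `nodes[nnz]`.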
Thanks for spotting this and for the PR.