Add a wrapper for `search_preassigned` in `torch_utils`
mdouze opened this issue · 4 comments
mdouze commented
There are wrappers for `search` and `search_and_reconstruct` that accept torch tensors, but not for `search_preassigned` (which was added later).
It would be useful to implement it here:
https://github.com/facebookresearch/faiss/blob/main/contrib/torch_utils.py#L220
mdouze commented
mdouze commented
The following code works:
def torch_replacement_search_preassigned(self, x, k, Iq, Dq, *, D=None, I=None):
    """Torch-aware replacement for ``search_preassigned``.

    Searches for the ``k`` nearest neighbors of each row of ``x``, using
    caller-supplied coarse assignments ``Iq`` (and optionally their
    distances ``Dq``) instead of computing them.

    NumPy input is forwarded unchanged to the original implementation,
    which is stored on the class as ``search_preassigned_numpy``.

    Args:
        x: query vectors, a 2-D ``torch.Tensor`` with ``x.shape[1] == self.d``
            (or an ``np.ndarray``, see above). A contiguous copy is made if
            needed; ``x`` itself is never modified.
        k: number of results per query.
        Iq: ``torch.Tensor`` of shape ``(n, self.nprobe)`` with the
            pre-assigned ids, passed through ``swig_ptr_from_IndicesTensor``
            (presumably requires int64 — TODO confirm).
        Dq: ``None`` or a ``torch.Tensor`` with the same shape as ``Iq``,
            passed through ``swig_ptr_from_FloatTensor``.
        D: optional preallocated ``(n, k)`` output tensor for distances.
        I: optional preallocated ``(n, k)`` output tensor for labels.

    Returns:
        ``(D, I)``. When not supplied, they are allocated on ``x.device``
        as float32 and int64 tensors respectively.

    Raises:
        AssertionError: on type/shape mismatches, or when a CUDA tensor is
            passed to an index without ``getDevice``.
    """
    if type(x) is np.ndarray:
        # forward to faiss __init__.py base method
        return self.search_preassigned_numpy(x, k, Iq, Dq, D=D, I=I)

    assert type(x) is torch.Tensor
    n, d = x.shape
    assert d == self.d
    # x is only read, so copying a non-contiguous view is safe (and matches
    # how Iq/Dq are handled below). Rebinding x keeps the copy alive until
    # the C call below has finished reading through x_ptr.
    x = x.contiguous()
    x_ptr = swig_ptr_from_FloatTensor(x)

    # D and I are written in place, so caller-provided buffers are NOT
    # copied: they must already have the right type and shape.
    if D is None:
        D = torch.empty(n, k, device=x.device, dtype=torch.float32)
    else:
        assert type(D) is torch.Tensor
        assert D.shape == (n, k)
    D_ptr = swig_ptr_from_FloatTensor(D)

    if I is None:
        I = torch.empty(n, k, device=x.device, dtype=torch.int64)
    else:
        assert type(I) is torch.Tensor
        assert I.shape == (n, k)
    I_ptr = swig_ptr_from_IndicesTensor(I)

    # Mixing a torch x with NumPy assignments would otherwise fail with an
    # opaque AttributeError on .contiguous().
    assert type(Iq) is torch.Tensor, 'Iq must be a torch.Tensor when x is'
    assert Iq.shape == (n, self.nprobe)
    Iq = Iq.contiguous()
    Iq_ptr = swig_ptr_from_IndicesTensor(Iq)

    if Dq is None:
        Dq_ptr = None
    else:
        assert type(Dq) is torch.Tensor, 'Dq must be a torch.Tensor when x is'
        # Validate the shape before paying for a possible copy.
        assert Dq.shape == Iq.shape
        Dq = Dq.contiguous()
        Dq_ptr = swig_ptr_from_FloatTensor(Dq)

    # The trailing False is the C++ `store_pairs` flag of search_preassigned.
    if x.is_cuda:
        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
        # On the GPU, use proper stream ordering
        with using_stream(self.getResources()):
            self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)
    else:
        # CPU torch
        self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)

    return D, I
# Install the torch-aware search_preassigned on the_class. Not every index
# class defines search_preassigned (the wrapper relies on IVF state such as
# self.nprobe), so ignore_missing=True makes classes without it a no-op
# instead of raising AttributeError while patching.
torch_replace_method(the_class, 'search_preassigned',
                     torch_replacement_search_preassigned, ignore_missing=True)