facebookresearch/faiss

add a wrapper for search_preassigned in torch_utils

mdouze opened this issue · 4 comments

There are wrappers for search and search_and_reconstruct that accept torch tensors, but not for search_preassigned (which was added later).

It would be useful to implement one here:

https://github.com/facebookresearch/faiss/blob/main/contrib/torch_utils.py#L220
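
For background, search_preassigned splits an IVF search into two steps: the caller first queries the coarse quantizer for the nprobe nearest inverted lists (their ids and, optionally, the distances to the centroids), then the index scans only those lists. A minimal numpy sketch of the existing behavior that the wrapper below forwards to (index type and sizes are illustrative):

    import numpy as np
    import faiss

    d, k = 32, 5
    index = faiss.index_factory(d, "IVF16,Flat")
    xb = np.random.rand(1000, d).astype('float32')
    index.train(xb)
    index.add(xb)
    index.nprobe = 4

    xq = np.random.rand(10, d).astype('float32')
    # step 1: coarse assignment with the quantizer
    Dq, Iq = index.quantizer.search(xq, index.nprobe)
    # step 2: scan only the preassigned inverted lists
    D, I = index.search_preassigned(xq, k, Iq, Dq)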

The following code works:

    def torch_replacement_search_preassigned(self, x, k, Iq, Dq, *, D=None, I=None):
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.search_preassigned_numpy(x, k, Iq, Dq, D=D, I=I)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if D is None:
            D = torch.empty(n, k, device=x.device, dtype=torch.float32)
        else:
            assert type(D) is torch.Tensor
            assert D.shape == (n, k)
        D_ptr = swig_ptr_from_FloatTensor(D)

        if I is None:
            I = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(I) is torch.Tensor
            assert I.shape == (n, k)
        I_ptr = swig_ptr_from_IndicesTensor(I)

        assert Iq.shape == (n, self.nprobe)
        Iq = Iq.contiguous()
        Iq_ptr = swig_ptr_from_IndicesTensor(Iq)

        if Dq is not None:
            Dq = Dq.contiguous()
            assert Dq.shape == Iq.shape
            Dq_ptr = swig_ptr_from_FloatTensor(Dq)
        else:
            Dq_ptr = None

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)  # store_pairs=False
        else:
            # CPU torch
            self.search_preassigned_c(n, x_ptr, k, Iq_ptr, Dq_ptr, D_ptr, I_ptr, False)  # store_pairs=False

        return D, I

    torch_replace_method(the_class, 'search_preassigned', torch_replacement_search_preassigned)
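
With the wrapper registered, the same two-step pattern runs end to end on torch tensors. A quick sanity-check sketch (assuming the wrapper above is installed when importing faiss.contrib.torch_utils; sizes are illustrative):

    import torch
    import faiss
    import faiss.contrib.torch_utils  # registers the torch wrappers

    d, k = 32, 5
    index = faiss.index_factory(d, "IVF16,Flat")
    xb = torch.rand(1000, d)
    index.train(xb)
    index.add(xb)
    index.nprobe = 4

    xq = torch.rand(10, d)
    Dq, Iq = index.quantizer.search(xq, index.nprobe)  # coarse assignment
    D, I = index.search_preassigned(xq, k, Iq, Dq)     # torch in, torch out
    assert type(D) is torch.Tensor and type(I) is torch.Tensor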

implemented in #3916

So we can close this issue?