drugilsberg/interact

Small nntree addition: return_dist=True

Closed this issue · 0 comments

class NearestNeighborsTree(object):
    """Nearest neighbor tree support.

    Wraps a search tree built over the vectors of a word embedding together
    with two lookup series keyed by row position: one giving the word (the
    embedding's index label) of each row, and one giving the cluster label
    assigned to each row.
    """

    # Search structure returned by ``build_tree``; populated in ``__init__``.
    tree = None
    # Row position -> word (index label of the embedding DataFrame).
    word_series = None
    # Row position -> cluster label taken from ``cluster_vectors(...).labels_``.
    cluster_series = None

    def __init__(self, embedding, algorithm='kd_tree', metric='minkowski',
                 k=500, n_init=10, n_jobs=-1):
        """Build from embedding pd.DataFrame (index: words).

        Args:
            embedding: pd.DataFrame whose index holds the words and whose
                rows hold the corresponding vectors.
            algorithm: tree algorithm, forwarded to ``build_tree``.
            metric: distance metric, forwarded to ``build_tree``.
            k: forwarded to ``cluster_vectors`` (presumably the number of
                clusters -- confirm against its definition).
            n_init: forwarded to ``cluster_vectors`` (presumably the number
                of clustering restarts -- confirm).
            n_jobs: parallelism setting forwarded to both helpers.
        """
        # Positional row index -> word, so tree indices can be mapped back.
        self.word_series = pd.Series(dict(enumerate(embedding.index.values)))
        # Positional row index -> cluster label of that row.
        self.cluster_series = pd.Series(dict(enumerate(
            cluster_vectors(
                embedding.values, k=k, n_init=n_init, n_jobs=n_jobs
            ).labels_
        )))
        self.tree = build_tree(
            embedding.values, algorithm=algorithm,
            metric=metric, n_jobs=n_jobs
        )

    def kneighbors(self, X=None, k=5, mode=NeighborsMode.BOTH):
        """Get k neighbors from query points.

        Args:
            X: query points, passed unchanged to ``self.tree.kneighbors``.
                ``None`` is forwarded as-is (for a scikit-learn style tree
                this queries the indexed points themselves -- confirm for
                whatever ``build_tree`` returns).
            k: number of neighbors to retrieve per query point.
            mode: a ``NeighborsMode`` member selecting what is returned.

        Returns:
            tuple, depending on ``mode``:
                - ``NeighborsMode.WORDS``:
                  ``(words, distances, similarities)``
                - ``NeighborsMode.CLUSTERS``:
                  ``(clusters, distances, similarities)``
                - ``NeighborsMode.BOTH``:
                  ``(words, clusters, distances, similarities)``
            ``words``/``clusters`` are numpy arrays built from one mapped
            entry per row of neighbor indices; ``similarities`` equals
            ``1 / (1 + distances)``.

        Raises:
            RuntimeError: if ``mode`` is not a ``NeighborsMode`` member, or
                is a member this method does not handle.
        """
        if not isinstance(mode, NeighborsMode):
            raise RuntimeError(
                'mode has to be a value from enum NeighborsMode'
            )
        neighbors_dist, neighbors_indices = self.tree.kneighbors(
            X=X, n_neighbors=k, return_distance=True
        )
        # Convert distances to a bounded score: 1 at distance 0, shrinking
        # towards 0 as the (non-negative) distance grows.
        neighbors_simil = 1 / (1 + neighbors_dist)

        if mode == NeighborsMode.WORDS:
            return (
                self._map_neighbors(neighbors_indices, self.word_series),
                neighbors_dist,
                neighbors_simil,
            )
        if mode == NeighborsMode.CLUSTERS:
            return (
                self._map_neighbors(neighbors_indices, self.cluster_series),
                neighbors_dist,
                neighbors_simil,
            )
        if mode == NeighborsMode.BOTH:
            return (
                self._map_neighbors(neighbors_indices, self.word_series),
                self._map_neighbors(neighbors_indices, self.cluster_series),
                neighbors_dist,
                neighbors_simil,
            )
        # Reached only if NeighborsMode gains a member not handled above.
        raise RuntimeError('invalid return mode')

    @staticmethod
    def _map_neighbors(neighbors_indices, series):
        """Map each row of neighbor indices through ``series`` into an array.

        Args:
            neighbors_indices: iterable of per-query index rows, as returned
                by the tree's ``kneighbors``.
            series: positional lookup (``word_series`` or ``cluster_series``)
                handed to ``_map_indices_with_series``.

        Returns:
            numpy array with one entry per row, each produced by
            ``_map_indices_with_series`` (presumably the looked-up values
            for that row's indices -- confirm against its definition).
        """
        return np.array([
            _map_indices_with_series(indices, series)
            for indices in neighbors_indices
        ])