/fast-arg-top-k

Get the indices of the top K values in an array

Primary Language: Jupyter Notebook

Get the indices of the top K values in an (1-D) array

The implementation uses a function compiled with Numba and is, in some cases, more than 50x faster than using numpy.argsort(...)! The speed-up applies in:

  • The cases where K is very small compared to the size of the array
import sys

try:
    import numpy as np
except ImportError:
    !{sys.executable} -m pip install numpy==1.17.4
try:
    import numba as nb
except ImportError:
    !{sys.executable} -m pip install numba==0.45.1

import numpy as np
import numba as nb
# Element dtype used for the benchmark array and for the working buffer inside
# fast_arg_top_k.
FLOAT_TYPE = np.float32
# Approximate decimal resolution of FLOAT_TYPE as reported by np.finfo.
FLOAT_BUFFER = np.finfo(FLOAT_TYPE).resolution

# Number of top values to select; read as a module-level global by both
# fast_arg_top_k and numpy_arg_top_k.
K = 100
@nb.njit(nb.types.Array(nb.int64, 1, "A")(nb.float32[:]))
def fast_arg_top_k(array):
    """
    Gets the indices of the top K values in a (1-D) float32 array.

    A single pass keeps the K largest values seen so far in a small buffer;
    the smallest value in that buffer is the admission threshold. A final
    vectorised comparison against the threshold recovers the indices.

    * NOTE: The returned indices are not sorted based on the top values.
    * NOTE: K is the module-level constant, captured when Numba compiles
      this function.
    * NaN entries never compare greater than the threshold and are therefore
      never selected.

    :param array: 1-D float32 array to search.
    :return: int64 array holding at most K indices into ``array`` (fewer
        when ``array`` has fewer than K selectable elements).
    """
    # Buffer of the K largest values seen so far. It is seeded with -inf
    # (not 0) so that negative inputs are eligible as well.
    top_values = np.empty((K,), dtype=FLOAT_TYPE)
    top_values[:] = -np.inf
    weakest_slot = 0
    threshold = top_values[weakest_slot]
    for value in array:
        if value > threshold:
            # Evict the current weakest entry, then locate the new weakest.
            top_values[weakest_slot] = value
            weakest_slot = top_values.argmin()
            threshold = top_values[weakest_slot]
    # The threshold is an exact element of `array` (same dtype), so a plain
    # `>=` already matches at least K elements once the buffer is full; no
    # resolution slack is required.
    candidates = (array >= threshold).nonzero()[0]
    if candidates.size > K:
        # Values tied with the threshold produce surplus candidates. Trim by
        # value (not by position) so a genuine top-K element is never dropped.
        keep = np.argsort(array[candidates])[candidates.size - K:]
        candidates = candidates[keep]
    return candidates
def numpy_arg_top_k(array):
    """
    Baseline implementation: full argsort of the negated array.

    Sorting the negated values ascending orders the original values from
    largest to smallest; the first K positions of that ordering are returned.
    K is the module-level constant.
    """
    negated = -array
    descending_order = np.argsort(negated)
    return descending_order[:K]
# Benchmark input: one million random samples converted to FLOAT_TYPE.
array = np.array(np.random.sample((1000000,)), dtype=FLOAT_TYPE)
# Time the Numba-compiled implementation (IPython %timeit magic; -o keeps the
# timing result object). The line after each magic is the recorded output.
time_fast = %timeit -n 100 -o fast_arg_top_k(array)
1.9 ms ± 88.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Time the plain NumPy argsort baseline on the same input.
time_numpy = %timeit -n 10 -o numpy_arg_top_k(array)
106 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Sanity check: both implementations should select the same set of indices.
# Each result is sorted because the two functions return indices in
# different orders.
result_fast = sorted(fast_arg_top_k(array))
result_numpy = sorted(numpy_arg_top_k(array))

# Share of the K selected indices that both methods agree on, as a rounded
# percentage.
number_of_common = len(set(result_fast).intersection(result_numpy))
percentage_of_common = round((number_of_common / K) * 100)

# The top K may contain several exactly equal values; in that case the two
# methods can pick different indices, so the overlap may be below 100%.
print(f'{percentage_of_common}% of the indices are same!')
100% of the indices are same!
# Speed-up factor computed from the best run of each timing.
print(f'"fast_arg_top_k" is around {round(time_numpy.best / time_fast.best)}x faster than "numpy_arg_top_k"!')
"fast_arg_top_k" is around 57x faster than "numpy_arg_top_k"!