This implementation uses a function compiled with Numba and is, in some cases, more than 50x faster than using numpy.argsort(...)!
- Specifically, in the cases where K is very small compared to the size of the array.
import sys
try:
import numpy as np
except ImportError:
!{sys.executable} -m pip install numpy==1.17.4
try:
import numba as nb
except ImportError:
!{sys.executable} -m pip install numba==0.45.1
import numpy as np
import numba as nb
# Number of top values to select.
K = 100
# Element type used for the value buffer and the benchmark array.
FLOAT_TYPE = np.float32
# Approximate decimal resolution of FLOAT_TYPE, usable as a comparison tolerance.
FLOAT_BUFFER = np.finfo(FLOAT_TYPE).resolution
@nb.njit(nb.types.Array(nb.int64, 1, "A")(nb.float32[:]))
def fast_arg_top_k(array):
    """
    Gets the indices of the top K values in a (1-D) float32 array.

    K is the module-level constant.

    * NOTE: The returned indices are not sorted based on the top values.
    * NOTE: NaN entries are never selected. If the array holds fewer than K
      selectable values, the indices of all of them are returned.
    * NOTE: When several entries tie with the K-th largest value, the tied
      entries with the lowest indices are kept.
    """
    # Buffer holding the K largest values seen so far. Starting from -inf
    # (rather than 0) lets zero and negative values compete as well.
    top_values = np.full((K,), -np.inf, dtype=FLOAT_TYPE)
    minimum_index = 0
    minimum_value = -np.inf
    for value in array:
        if value > minimum_value:
            # Evict the smallest buffered value, then locate the new smallest.
            top_values[minimum_index] = value
            minimum_index = top_values.argmin()
            minimum_value = top_values[minimum_index]
    # minimum_value is now the K-th largest value (or -inf when fewer than K
    # values were buffered), so at most K - 1 entries are strictly greater.
    # Take all of those first and fill the remaining slots with exact ties;
    # this guarantees no genuine top-K entry is dropped by the truncation.
    above = (array > minimum_value).nonzero()[0]
    ties = (array == minimum_value).nonzero()[0]
    return np.concatenate((above, ties[:K - above.size]))
def numpy_arg_top_k(array):
    """
    Baseline: indices of the top K values obtained through a full argsort.

    The array is negated so that the ascending argsort ranks the largest
    values first; the first K positions of that ordering are returned.
    """
    descending_order = np.argsort(-array)
    return descending_order[:K]
# Benchmark input: one million random samples converted to FLOAT_TYPE.
array = np.array(np.random.sample((1000000,)), dtype=FLOAT_TYPE)
# Time the Numba implementation; -o keeps the timing result for the comparison below.
time_fast = %timeit -n 100 -o fast_arg_top_k(array)
1.9 ms ± 88.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Time the argsort-based baseline (fewer loops per run, since each call is much slower).
time_numpy = %timeit -n 10 -o numpy_arg_top_k(array)
106 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Collect the indices chosen by each implementation (sorted for easy inspection).
result_fast = sorted(fast_arg_top_k(array))
result_numpy = sorted(numpy_arg_top_k(array))
# Measure the overlap between the two selections as a percentage of K.
number_of_common = len(set(result_fast) & set(result_numpy))
percentage_of_common = round((number_of_common / K) * 100)
# The overlap can fall below 100% when identical values occur among the
# top K: the two implementations may then pick different tied indices.
print(f'{percentage_of_common}% of the indices are same!')
100% of the indices are same!
# Report the speed-up as the ratio of the best timings recorded by the two %timeit runs.
print(f'"fast_arg_top_k" is around {round(time_numpy.best / time_fast.best)}x faster than "numpy_arg_top_k"!')
"fast_arg_top_k" is around 57x faster than "numpy_arg_top_k"!