xarray-contrib/xoak

Re-implement scikit-learn's search trees with numba


This could be done at a later stage, if we choose to go down this route.

The implementation approach used in scikit-learn is interesting in several aspects:

  • kd-tree and ball tree are built as thin layers on top of a common binary tree implementation

  • all tree data is pre-allocated, which could make a re-implementation with numba easier and could perhaps facilitate experimenting with those structures and dask.
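
A minimal sketch of what the pre-allocated layout could look like, assuming a complete binary tree stored in flat arrays in the style of scikit-learn's BinaryTree (node i has children 2i + 1 and 2i + 2 and owns a contiguous slice of a permuted index array). Only the per-node bounding boxes are kd-tree specific; a ball tree layer would store a centroid and radius per node instead. `build_kdtree` and the array names are hypothetical, the median split uses a full argsort where a partial sort would be cheaper, and the `njit` fallback only exists so the sketch also runs without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the sketch runs without numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


@njit
def build_kdtree(data, leaf_size):
    n_points, n_dims = data.shape
    # Sizing rule similar to scikit-learn's: a complete binary tree whose
    # leaves hold roughly leaf_size points. Everything is allocated up front.
    n_levels = 1 + int(np.log2(max(1.0, (n_points - 1) / leaf_size)))
    n_nodes = (1 << n_levels) - 1

    idx_array = np.arange(n_points)          # permuted point indices
    idx_start = np.zeros(n_nodes, np.int64)  # node i owns idx_array[start:end]
    idx_end = np.zeros(n_nodes, np.int64)
    is_leaf = np.zeros(n_nodes, np.bool_)
    lower = np.empty((n_nodes, n_dims))      # kd-tree node data: bounding boxes
    upper = np.empty((n_nodes, n_dims))

    idx_end[0] = n_points
    # Parents have smaller indices than their children, so one pass suffices.
    for i in range(n_nodes):
        start = idx_start[i]
        end = idx_end[i]
        for d in range(n_dims):
            lo = np.inf
            hi = -np.inf
            for k in range(start, end):
                v = data[idx_array[k], d]
                lo = min(lo, v)
                hi = max(hi, v)
            lower[i, d] = lo
            upper[i, d] = hi

        left = 2 * i + 1
        if left >= n_nodes:
            is_leaf[i] = True
            continue

        # Split at the median along the dimension with the largest spread.
        split_dim = np.argmax(upper[i] - lower[i])
        keys = np.empty(end - start)
        for k in range(end - start):
            keys[k] = data[idx_array[start + k], split_dim]
        order = np.argsort(keys)
        segment = idx_array[start:end].copy()
        for k in range(end - start):
            idx_array[start + k] = segment[order[k]]

        mid = start + (end - start) // 2
        idx_start[left] = start
        idx_end[left] = mid
        idx_start[left + 1] = mid
        idx_end[left + 1] = end

    return idx_array, idx_start, idx_end, is_leaf, lower, upper
```

Because every array exists before the loop starts, the build is a plain nested loop over NumPy arrays, which is the kind of code numba handles well.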

I think numba is now mature enough, and available in enough distributions, that we can use it as a dependency. I'm not sure whether numba's jitclasses are mature enough, or whether we could avoid using them here, though.

The biggest advantage of using numba is just-in-time compilation, which allows very flexible metric functions.
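
As an illustration of a user-supplied metric, a minimal sketch assuming a brute-force search (no tree): the haversine distance is written as an ordinary jitted function and passed to the search as an argument, which numba supports as long as that argument is itself jitted. `haversine` and `nearest_index` are hypothetical names, and the `njit` fallback only lets the sketch run without numba.

```python
import math

import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the sketch runs without numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


@njit
def haversine(a, b):
    # Great-circle distance on the unit sphere; a and b are (lat, lon) in radians.
    dlat = b[0] - a[0]
    dlon = b[1] - a[1]
    h = math.sin(dlat / 2) ** 2 + math.cos(a[0]) * math.cos(b[0]) * math.sin(dlon / 2) ** 2
    return 2.0 * math.asin(math.sqrt(min(1.0, h)))


@njit
def nearest_index(points, query, metric):
    # Brute-force scan using whatever jitted metric the caller supplies.
    best = -1
    best_dist = np.inf
    for i in range(points.shape[0]):
        d = metric(points[i], query)
        if d < best_dist:
            best_dist = d
            best = i
    return best
```

As far as I know, numba compiles a separate specialization of `nearest_index` for each metric passed in.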

Huite commented

Hey, since you mentioned xoak in NOAA-ORR-ERD/gridded#55:

I did some looking around before, and I came across this repository:
https://github.com/jackd/numba-neighbors

(It has an MIT license, so it's good to go.)

It looks like an almost perfect match for what you're proposing here?

It uses a jitclass, but only very lightly, which is arguably the right approach. If you don't want to pass all the arguments individually, you could pass the tree data more easily as a namedtuple.
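
A minimal sketch of the namedtuple idea, assuming the tree lives in flat arrays; `TreeData` and `node_centroid` are hypothetical names, not the numba-neighbors API. Numba accepts namedtuples as arguments to jitted functions, so the whole tree can be passed as one value instead of a long argument list.

```python
from collections import namedtuple

import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the sketch runs without numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f

# Hypothetical container bundling the pre-allocated tree arrays.
TreeData = namedtuple("TreeData", ["data", "idx_array", "idx_start", "idx_end"])


@njit
def node_centroid(tree, node):
    # Mean of the points owned by `node`: idx_array[idx_start[node]:idx_end[node]].
    start = tree.idx_start[node]
    end = tree.idx_end[node]
    n_dims = tree.data.shape[1]
    out = np.zeros(n_dims)
    for k in range(start, end):
        p = tree.idx_array[k]
        for d in range(n_dims):
            out[d] += tree.data[p, d]
    return out / (end - start)
```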

Some query methods are still missing, but they are not that difficult to implement, although I'm not sure whether dynamic allocation can be done as efficiently. (Numba could use something like C++'s std::vector; or is a typed List already that? It felt significantly slower to me.)
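
One possible answer to the allocation question, sketched under assumptions: collect results in a plain NumPy buffer that doubles its capacity when full, the amortized-growth strategy behind C++'s std::vector. The query below is brute force rather than a tree traversal to keep it short, `_append` and `query_radius_bruteforce` are hypothetical names, and its speed relative to a typed List is untested.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the sketch runs without numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


@njit
def _append(buf, n, value):
    # Double the capacity when the buffer is full, then write at position n.
    if n == buf.shape[0]:
        grown = np.empty(2 * buf.shape[0], dtype=np.int64)
        grown[:n] = buf
        buf = grown
    buf[n] = value
    return buf, n + 1


@njit
def query_radius_bruteforce(points, center, r):
    buf = np.empty(4, dtype=np.int64)  # small initial capacity
    n = 0
    r2 = r * r
    for i in range(points.shape[0]):
        d2 = 0.0
        for j in range(points.shape[1]):
            diff = points[i, j] - center[j]
            d2 += diff * diff
        if d2 <= r2:
            buf, n = _append(buf, n, i)
    return buf[:n].copy()  # trim to the number of hits
```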

Parallelisation is also extremely simple using numba's prange.
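
A minimal sketch of a `prange` loop, assuming a brute-force nearest-neighbour search in which every query point is handled independently (`nearest_neighbors` is a hypothetical name; a tree-based query could be parallelised over query points in the same way). Each iteration writes only its own slot of `out`, so no synchronisation is needed. Without numba, the fallback simply runs the loop serially.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # plain-Python fallback: serial loop, no compilation
    prange = range

    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


@njit(parallel=True)
def nearest_neighbors(points, queries):
    out = np.empty(queries.shape[0], dtype=np.int64)
    for q in prange(queries.shape[0]):  # iterations are spread over threads
        best = -1
        best_d2 = np.inf
        for i in range(points.shape[0]):
            d2 = 0.0
            for j in range(points.shape[1]):
                diff = points[i, j] - queries[q, j]
                d2 += diff * diff
            if d2 < best_d2:
                best_d2 = d2
                best = i
        out[q] = best  # each iteration touches only its own output slot
    return out
```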

JIT indeed provides very flexible metric functions. The best way to introduce them seems to be closures in numba, which avoid the function call overhead, I believe: https://numba.pydata.org/numba-doc/latest/user/faq.html#can-i-pass-a-function-as-an-argument-to-a-jitted-function
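
A minimal sketch of the closure pattern, assuming a factory that compiles one query function per metric (`make_nearest` and `euclidean` are hypothetical names). A jitted function captured from the enclosing scope is frozen as a compile-time constant, so the metric call is resolved when `nearest` is compiled. The `njit` fallback only lets the sketch run without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the sketch runs without numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


@njit
def euclidean(a, b):
    acc = 0.0
    for j in range(a.shape[0]):
        diff = a[j] - b[j]
        acc += diff * diff
    return np.sqrt(acc)


def make_nearest(metric):
    # `metric` is captured by the closure; numba treats it as a constant
    # while compiling `nearest`, giving one specialized query per metric.
    @njit
    def nearest(points, query):
        best = -1
        best_dist = np.inf
        for i in range(points.shape[0]):
            d = metric(points[i], query)
            if d < best_dist:
                best_dist = d
                best = i
        return best

    return nearest


nearest_euclidean = make_nearest(euclidean)
```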

I've also noticed that performance can benefit significantly from aggressive inlining (although this increases compile cost).
Since a tree will generally consist of float32 or float64 coordinates and int32 or int64 indices, it might be a nice idea to compile the built-in metric functions ahead of time for those types.
https://numba.pydata.org/numba-doc/dev/user/pycc.html#compiling-code-ahead-of-time
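
A minimal sketch of both ideas, with a stated substitution: `inline="always"` asks numba to inline a function at the numba IR level, and listing explicit signatures makes numba compile the float32 and float64 variants eagerly, when the decorator runs. This is eager JIT compilation, not the pycc ahead-of-time compiler linked above (which produces a standalone extension module); `sq_dist` and `euclidean_dist` are hypothetical names, and the `njit` fallback only lets the sketch run without numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the sketch runs without numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


@njit(inline="always")
def sq_dist(a, b):
    # Inlined into every numba caller at the IR level (at some compile cost).
    acc = 0.0
    for j in range(a.shape[0]):
        diff = a[j] - b[j]
        acc += diff * diff
    return acc


# Explicit signatures: both variants are compiled when the decorator runs,
# so the first call from Python pays no compilation cost.
@njit(["float64(float32[:], float32[:])", "float64(float64[:], float64[:])"])
def euclidean_dist(a, b):
    return np.sqrt(sq_dist(a, b))
```

Adding `cache=True` would additionally let numba reuse the compiled code across sessions, provided the function is defined in a file.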

Good to know about the numba-neighbors repository and the numba tricks, @Huite, thanks!