chaimleib/intervaltree

Range Query Much Slower than List Scan

william-silversmith opened this issue · 3 comments

Hi!

Thank you for this great project with an interesting data structure! I'm working with a dataset of ~12k items, and I'm finding that overlap queries take 300 ms to 400 ms, versus < 1 ms for a list filter operation. I did some profiling and saw the results below. I haven't studied the data structure or implementation carefully enough yet to suggest a good fix, though.

I've attached a pickle file of the IntervalTree in case you are interested in examining it. The available query range is 0 to 63 inclusive (I'm classifying columns of a 3D image that is 64 voxels deep). I'll eventually write something in C++, but I was hoping to explore some possibilities in Python first. I am working on an NP-hard set cover problem and was hoping to reduce the cost of a greedy algorithm by reducing the number of cases to examine.

Thanks so much for this library and any attention you give this issue!

File: .../intervaltree/intervaltree/intervaltree.py
Function: overlap at line 913


Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   913                                               @profile
   914                                               def overlap(self, begin, end=None):
   915                                                   """
   916                                                   Returns a set of all intervals overlapping the given range.
   917                                           
   918                                                   Completes in O(m + k*log n) time, where:
   919                                                     * n = size of the tree
   920                                                     * m = number of matches
   921                                                     * k = size of the search range
   922                                                   :rtype: set of Interval
   923                                                   """
   924       100         49.0      0.5      0.0          root = self.top_node
   925       100         33.0      0.3      0.0          if not root:
   926                                                       return set()
   927       100         22.0      0.2      0.0          if end is None:
   928                                                       iv = begin
   929                                                       return self.overlap(iv.begin, iv.end)
   930       100         37.0      0.4      0.0          elif begin >= end:
   931                                                       return set()
   932       100    2438798.0  24388.0    100.0          result = root.search_point(begin, set())  # bound_begin might be greater
   933       100         35.0      0.3      0.0          boundary_table = self.boundary_table
   934       100        416.0      4.2      0.0          bound_begin = boundary_table.bisect_left(begin)
   935       100        118.0      1.2      0.0          bound_end = boundary_table.bisect_left(end)  # up to, but not including end
   936       100        111.0      1.1      0.0          result.update(root.search_overlap(
   937                                                       # slice notation is slightly slower
   938       100         93.0      0.9      0.0              boundary_table.keys()[index] for index in xrange(bound_begin, bound_end)
   939                                                   ))
   940        99         69.0      0.7      0.0          return result

Total time: 53.4637 s
File: .../intervaltree/intervaltree/node.py
Function: search_point at line 309

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   309                                               @profile
   310                                               def search_point(self, point, result):
   311                                                   """
   312                                                   Returns all intervals that contain point.
   313                                                   """
   314   8628853    1808999.0      0.2      3.4          for k in self.s_center:
   315   4971359    1133030.0      0.2      2.1              if k.begin <= point < k.end:
   316   3657493   50511584.0     13.8     94.5                  result.add(k)
   317      4960       1064.0      0.2      0.0          if point < self.x_center and self[0]:
   318      2951       3460.0      1.2      0.0              return self[0].search_point(point, result)
   319      3077       1670.0      0.5      0.0          elif point > self.x_center and self[1]:
   320      3077       3588.0      1.2      0.0              return self[1].search_point(point, result)
   321      1883        317.0      0.2      0.0          return result
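
For readers unfamiliar with the structure being profiled, here is a minimal sketch of a point query on a centered interval tree. It mirrors the shape of `search_point` above, but `Interval`, `Node`, and `build` are hypothetical simplified stand-ins written for illustration, not the library's actual classes, and the builder assumes every interval has `begin < end`.

```python
from collections import namedtuple

# Hypothetical stand-in for the library's Interval (no data field).
Interval = namedtuple("Interval", ["begin", "end"])


class Node:
    """Simplified tree node; an assumption, not the library's Node."""

    def __init__(self, x_center, s_center, left=None, right=None):
        self.x_center = x_center  # split point for this node
        self.s_center = s_center  # intervals that contain x_center
        self.left = left          # intervals ending at or before x_center
        self.right = right        # intervals beginning after x_center

    def search_point(self, point, result):
        # Every interval stored at this node is tested (half-open range).
        for iv in self.s_center:
            if iv.begin <= point < iv.end:
                result.add(iv)
        # At most one child can contain further matches.
        if point < self.x_center and self.left:
            return self.left.search_point(point, result)
        if point > self.x_center and self.right:
            return self.right.search_point(point, result)
        return result


def build(intervals):
    """Build a tree by splitting on the median begin value."""
    if not intervals:
        return None
    ordered = sorted(intervals)
    x_center = ordered[len(ordered) // 2].begin
    center = [iv for iv in ordered if iv.begin <= x_center < iv.end]
    left = [iv for iv in ordered if iv.end <= x_center]
    right = [iv for iv in ordered if iv.begin > x_center]
    return Node(x_center, center, build(left), build(right))


ivs = [Interval(0, 64), Interval(0, 10), Interval(5, 20),
       Interval(30, 40), Interval(63, 64)]
root = build(ivs)
hits = root.search_point(7, set())
```

In this sketch, as in the profile above, the work per query is the loop over `s_center` at each visited node plus one `result.add` per match, so a query point covered by many intervals pays for every one of them.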

The algorithm that I'm comparing this to is not exactly the same (I'm not checking the z range, just hitting the entire list). Here are the timings for that:

Total time: 8.06808 s
File: REDACTED
Function: find_optimal_pins at line 62

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    62                                           @profile
    63                                           def find_optimal_pins(pinset):
    64       138      58739.0    425.6      0.7    sets = [ x[2] for x in pinset ]
    65       138     132850.0    962.7      1.6    isets = [ [i,x] for i,x in enumerate(sets) ]
    66                                           
    67       138      50108.0    363.1      0.6    universe = set.union(*sets)
    68                                           
    69       138         29.0      0.2      0.0    final_pins = []
    70      4463       1281.0      0.3      0.0    while len(universe):
    71      4463    1420992.0    318.4     17.6      sizes = [ len(x[1]) for x in isets ]
    72      4463     174341.0     39.1      2.2      idx = sizes.index(max(sizes))
    73      4463       3307.0      0.7      0.0      i, cur = isets.pop(idx)
    74      4463       3264.0      0.7      0.0      universe -= cur
    75  12131825    2229706.0      0.2     27.6      for j, otherset in isets:
    76  12131825    2125046.0      0.2     26.3        otherset -= cur
    77      4463    1863452.0    417.5     23.1      isets = [ x for x in isets if len(x[1]) > 0 ]
    78      4463       4939.0      1.1      0.1      final_pins.append(pinset[i][:2])
    79                                           
    80       138         25.0      0.2      0.0    return final_pins
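
The greedy loop in the profile above can be sketched as a standalone function. This is a rough restatement rather than the original code: the name `greedy_cover` is hypothetical, and the `(a, b, covered_set)` tuple layout is an assumption inferred from the `x[2]` and `pinset[i][:2]` accesses. Unlike the profiled version, it copies the sets instead of shrinking them in place.

```python
def greedy_cover(pinset):
    """Greedy set cover over (a, b, covered_set) tuples (layout assumed)."""
    # Copy each set so the caller's data is left untouched.
    remaining = [(i, set(entry[2])) for i, entry in enumerate(pinset)]
    remaining = [(i, s) for i, s in remaining if s]

    chosen = []
    while remaining:
        # Pick the candidate covering the most still-uncovered elements.
        best = max(range(len(remaining)), key=lambda k: len(remaining[k][1]))
        i, cur = remaining.pop(best)
        chosen.append(pinset[i][:2])

        # Remove the newly covered elements, then drop exhausted candidates.
        remaining = [(j, s - cur) for j, s in remaining]
        remaining = [(j, s) for j, s in remaining if s]
    return chosen


example = [
    (0, 0, {1, 2, 3}),
    (0, 1, {3, 4}),
    (1, 0, {4, 5, 6, 7}),
    (1, 1, {1}),
]
pins = greedy_cover(example)
```

Each pass still touches every remaining candidate, which is the cost the profile shows on lines 75 to 77 and the part that a narrower candidate list would reduce.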

tree.pkl.zip

What is a list filter operation? Could you explain it for me?

Hi! It's been a while, and I ended up writing something in C++, but all I meant was a list comprehension with a filter clause. Something like:

results = [ x for x in lst if 10 < x < 20 ]
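
Applied to intervals rather than plain numbers, the same idea might look like the sketch below. The `(begin, end)` tuple layout and the name `overlapping` are assumptions made for illustration, with half-open ranges chosen to match the `begin <= point < end` test in the profiled code.

```python
def overlapping(intervals, begin, end):
    """Return the (begin, end) pairs overlapping the half-open range [begin, end)."""
    return [iv for iv in intervals if iv[0] < end and begin < iv[1]]


intervals = [(0, 64), (0, 10), (10, 20), (30, 40)]
results = overlapping(intervals, 10, 30)
```

This visits every item on every query, so it is O(n) per call, but each step is a pair of comparisons and a list append.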

Thank you