facebookresearch/projectaria_tools

bisection_timestamp_search returns the closest index SMALLER than the query rather than the closest one

georgegu1997 opened this issue · 2 comments

The bisection search used in bisection_timestamp_search returns the index of the closest timestamp smaller than the query, rather than the index of the timestamp closest to it.

See the example below:

timed_data = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]

for query_timestamp_ns in (2.9, 1.1):
    # Bisection search, as in the original helper
    start = 0
    end = len(timed_data) - 1
    while start < end:
        mid = (start + end) // 2
        mid_timestamp = timed_data[mid]
        if mid_timestamp == query_timestamp_ns:
            start = mid
            break
        if mid_timestamp < query_timestamp_ns:
            start = mid + 1
        else:
            end = mid - 1

    # Prints "0 1" for both queries, even though 3 (index 1) is much closer to 2.9 than 1 is
    print(start, timed_data[start])
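
A small sketch to probe the loop a bit further. The wrapper name original_bisection and the extra queries 3.5 and 8.8 are illustrative choices, not from the issue; the loop body is the one from the examples. It compares the index the loop returns with a brute-force closest index.

```python
def original_bisection(timed_data, query):
    """The bisection loop from the examples, wrapped in a function (hypothetical name)."""
    start = 0
    end = len(timed_data) - 1
    while start < end:
        mid = (start + end) // 2
        mid_timestamp = timed_data[mid]
        if mid_timestamp == query:
            start = mid
            break
        if mid_timestamp < query:
            start = mid + 1
        else:
            end = mid - 1  # discards mid, even though mid may be the closest element
    return start


timed_data = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
for query in (1.1, 2.9, 3.5, 8.8):
    returned = original_bisection(timed_data, query)
    # Brute-force reference: the index whose value is truly closest to the query
    closest = min(range(len(timed_data)), key=lambda i: abs(timed_data[i] - query))
    print(query, returned, closest)
```

This prints "1.1 0 0", "2.9 0 1", "3.5 2 1" and "8.8 3 4". For 3.5 the loop stops on the larger neighbour 5 (index 2) although 3 is closer, so in these runs the returned index is one of the two elements bracketing the query, but not reliably the closest one.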

This is still buggy; I will push another update later.

A proposed fix for the problem:

from typing import Optional


def bisection_timestamp_search(timed_data, query_timestamp_ns: int) -> Optional[int]:
    """
    Binary search helper function, assuming that timed_data is sorted by the field named 'tracking_timestamp'.
    Returns the index of the element closest to the query timestamp, or None if timed_data is empty
    or the query is out of the time range.
    """

    def timestamp_ns(index: int) -> float:
        # tracking_timestamp exposes total_seconds() (a datetime.timedelta); convert it to nanoseconds
        return timed_data[index].tracking_timestamp.total_seconds() * 1e9

    # Deal with border cases (an exact match on the first or last sample is still in range)
    if not timed_data:
        return None
    if len(timed_data) > 1:
        if query_timestamp_ns < timestamp_ns(0):
            return None
        elif query_timestamp_ns > timestamp_ns(-1):
            return None
    # If this is safe we perform the bisection search for the first element >= the query
    start = 0
    end = len(timed_data) - 1
    while start < end:
        mid = (start + end) // 2
        mid_timestamp = timestamp_ns(mid)
        if mid_timestamp == query_timestamp_ns:
            return mid
        if mid_timestamp < query_timestamp_ns:
            start = mid + 1
        else:
            end = mid  # not mid - 1: mid itself may be the closest element

    # timed_data[start] is now the first element >= the query, so the closest element is
    # either start or start - 1; keep whichever is actually closer (ties go to start - 1)
    if start > 0:
        start_timestamp = timestamp_ns(start)
        start_minus_1_timestamp = timestamp_ns(start - 1)
        if abs(start_minus_1_timestamp - query_timestamp_ns) <= abs(start_timestamp - query_timestamp_ns):
            start = start - 1

    return start
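
For comparison, a minimal sketch of the same closest-timestamp lookup built on Python's standard bisect module. It assumes the timestamps have already been extracted into a sorted list of nanosecond values; the helper name closest_index and that pre-extraction step are assumptions, not part of projectaria_tools. bisect_left returns the first position whose timestamp is >= the query, so the closest element is either that position or the one just before it.

```python
import bisect
from typing import Optional, Sequence


def closest_index(timestamps_ns: Sequence[float], query_timestamp_ns: float) -> Optional[int]:
    """Index of the timestamp closest to the query, or None when the list is empty or the
    query lies outside [timestamps_ns[0], timestamps_ns[-1]] (hypothetical helper)."""
    if not timestamps_ns:
        return None
    if query_timestamp_ns < timestamps_ns[0] or query_timestamp_ns > timestamps_ns[-1]:
        return None
    # First position whose timestamp is >= the query; always valid because query <= last
    right = bisect.bisect_left(timestamps_ns, query_timestamp_ns)
    if right == 0:
        return 0
    left = right - 1
    # Compare both neighbours; on a tie prefer the earlier sample
    if query_timestamp_ns - timestamps_ns[left] <= timestamps_ns[right] - query_timestamp_ns:
        return left
    return right


timestamps_ns = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
print([closest_index(timestamps_ns, q) for q in (1.1, 2.9, 3.5, 8.8)])  # [0, 1, 1, 4]
```

Extracting the timestamps once costs O(n), but each later lookup is O(log n) and no longer converts tracking_timestamp inside the search loop.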

Thank you, we will take a look at your code; alternatively, please open a PR so we can integrate it.

We also recommend using the MpsDataProvider.