bisection_timestamp_search returns the index closest and SMALLER than the query
georgegu1997 opened this issue · 2 comments
georgegu1997 commented
The bisection search used in this line returns the index of the closest element that is smaller than the query, rather than the index of the element closest to it.
See the examples below:
timed_data = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
query_timestamp_ns = 2.9
# If this is safe we perform the Bisection search
start = 0
end = len(timed_data) - 1
while start < end:
    mid = (start + end) // 2
    mid_timestamp = timed_data[mid]
    if mid_timestamp == query_timestamp_ns:
        start = mid
        break
    if mid_timestamp < query_timestamp_ns:
        start = mid + 1
    else:
        end = mid - 1
print(start, timed_data[start])  # it prints 0 1
timed_data = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
query_timestamp_ns = 1.1
# If this is safe we perform the Bisection search
start = 0
end = len(timed_data) - 1
while start < end:
    mid = (start + end) // 2
    mid_timestamp = timed_data[mid]
    if mid_timestamp == query_timestamp_ns:
        start = mid
        break
    if mid_timestamp < query_timestamp_ns:
        start = mid + 1
    else:
        end = mid - 1
print(start, timed_data[start])  # This also prints 0 1
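To make the mismatch easier to reproduce, here is a minimal sketch that wraps the loop from the two examples in a function and compares it with a brute-force nearest-neighbour scan. The names `current_search` and `brute_force_nearest` are hypothetical helpers introduced only for this comparison; they are not part of projectaria_tools.

```python
def current_search(timed_data, query):
    """Same loop as in the examples above, wrapped in a function."""
    start, end = 0, len(timed_data) - 1
    while start < end:
        mid = (start + end) // 2
        if timed_data[mid] == query:
            return mid
        if timed_data[mid] < query:
            start = mid + 1
        else:
            end = mid - 1
    return start


def brute_force_nearest(timed_data, query):
    """Reference answer: linear scan for the index with the smallest distance."""
    return min(range(len(timed_data)), key=lambda i: abs(timed_data[i] - query))


timed_data = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
for query in (1.1, 2.9):
    got = current_search(timed_data, query)
    want = brute_force_nearest(timed_data, query)
    print(query, "loop ->", got, "nearest ->", want)
```

For the query 2.9 the loop returns index 0 while the nearest element is at index 1; for 1.1 both agree on index 0, since 1 really is the nearest value there.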
georgegu1997 commented
This is still buggy - will push another update later.
A proposed fix for the problem:
def bisection_timestamp_search(timed_data, query_timestamp_ns: int) -> int:
    """
    Binary search helper function, assuming that timed_data is sorted by the field names 'tracking_timestamp'
    Returns index of the element closest to the query timestamp else returns None if not found (out of time range)
    """
    # Deal with border case
    if timed_data and len(timed_data) > 1:
        first_timestamp = timed_data[0].tracking_timestamp.total_seconds() * 1e9
        last_timestamp = timed_data[-1].tracking_timestamp.total_seconds() * 1e9
        if query_timestamp_ns <= first_timestamp:
            return None
        elif query_timestamp_ns >= last_timestamp:
            return None
    # If this is safe we perform the Bisection search
    start = 0
    end = len(timed_data) - 1
    while start < end:
        mid = (start + end) // 2
        mid_timestamp = timed_data[mid].tracking_timestamp.total_seconds() * 1e9
        if mid_timestamp == query_timestamp_ns:
            return mid
        if mid_timestamp < query_timestamp_ns:
            start = mid + 1
        else:
            end = mid - 1
    # Get the value that is actually closest to the query timestamp
    if start + 1 < len(timed_data):
        start_timestamp = timed_data[start].tracking_timestamp.total_seconds() * 1e9
        start_plus_1_timestamp = timed_data[start + 1].tracking_timestamp.total_seconds() * 1e9
        if abs(start_timestamp - query_timestamp_ns) > abs(start_plus_1_timestamp - query_timestamp_ns):
            start = start + 1
    return start
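One reason the version above may still be off: the loop can also stop one element past the nearest one (for example, query 3.1 on the list [1, 3, 5, ...] ends at index 2), so checking only `start + 1` is not enough. A possible alternative, sketched here with the standard-library `bisect` module, is to find the insertion point and compare the two neighbours around it. This is only a sketch under assumptions: `nearest_index` is a hypothetical name, the timestamps are assumed to be already extracted into a sorted list of numbers in nanoseconds, and queries strictly outside the range return None (the function above also rejects queries equal to the first or last timestamp).

```python
from bisect import bisect_left
from typing import Optional, Sequence


def nearest_index(timestamps_ns: Sequence[float], query_ns: float) -> Optional[int]:
    """Return the index of the timestamp closest to query_ns, or None if out of range.

    Assumes timestamps_ns is sorted in ascending order.
    """
    if not timestamps_ns:
        return None
    # Assumption: a query strictly outside the recorded range is "not found".
    if query_ns < timestamps_ns[0] or query_ns > timestamps_ns[-1]:
        return None
    # pos is the first index whose timestamp is >= query_ns.
    pos = bisect_left(timestamps_ns, query_ns)
    if pos == 0:
        return 0
    before = timestamps_ns[pos - 1]
    after = timestamps_ns[pos]
    # Pick the closer neighbour; ties go to the earlier element.
    return pos - 1 if query_ns - before <= after - query_ns else pos


timed_data = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
for query in (1.1, 2.9, 3.1):
    print(query, nearest_index(timed_data, query))
```

With the list from the examples, the queries 1.1, 2.9 and 3.1 map to indices 0, 1 and 1 respectively.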
SeaOtocinclus commented
Thank you, we will take a look at your code; alternatively, please open a PR for integration.
We also recommend using the mpsDataProvider:
projectaria_tools/core/python/MpsPyBind.h
Line 530 in 85ce967