agrimgupta92/sgan

A problem in trajectories.py


I think there is a problem in this code segment.

def poly_fit(traj, traj_len, threshold):
    """
    Input:
    - traj: Numpy array of shape (2, traj_len)
    - traj_len: Len of trajectory
    - threshold: Minimum error to be considered for non linear traj
    Output:
    - int: 1 -> Non Linear 0-> Linear
    """
    t = np.linspace(0, traj_len - 1, traj_len)
    res_x = np.polyfit(t, traj[0, -traj_len:], 2, full=True)[1]
    res_y = np.polyfit(t, traj[1, -traj_len:], 2, full=True)[1]
    if res_x + res_y >= threshold:
        return 1.0
    else:
        return 0.0

res_x and res_y are the residuals of a quadratic (degree 2) fit, so if res_x + res_y is greater than or equal to the threshold, the trajectory does not conform to a quadratic curve. But the code labels this case Non Linear, which means a smoothly curved (parabolic) trajectory would be labeled Linear. Isn't that contradictory?

I think the code should be changed to the following:
res_x = np.polyfit(t, traj[0, -traj_len:], 1, full=True)[1]
res_y = np.polyfit(t, traj[1, -traj_len:], 1, full=True)[1]
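
To illustrate the difference, here is a minimal sketch that compares the degree 1 and degree 2 residuals on two synthetic trajectories. The helper `fit_residual`, the toy trajectories, and the threshold value are my own assumptions for illustration only and are not taken from the repository.

```python
import numpy as np


def fit_residual(traj, traj_len, degree):
    """Sum of the x and y least-squares residuals for a polynomial fit.

    Hypothetical helper: same calls as poly_fit, but with the degree exposed.
    """
    t = np.linspace(0, traj_len - 1, traj_len)
    res_x = np.polyfit(t, traj[0, -traj_len:], degree, full=True)[1]
    res_y = np.polyfit(t, traj[1, -traj_len:], degree, full=True)[1]
    # np.polyfit returns the residuals as an array (possibly empty),
    # so reduce each one to a plain float before adding.
    return float(np.sum(res_x) + np.sum(res_y))


traj_len = 8
t = np.arange(traj_len, dtype=float)

# Toy trajectories of shape (2, traj_len): a straight line and a parabola.
trajectories = {
    "line": np.stack([t, 2.0 * t + 1.0]),
    "parabola": np.stack([t, 0.5 * t ** 2]),
}

threshold = 0.002  # assumed value, for illustration only

# Map each name to (degree 1 residual, degree 2 residual).
results = {
    name: (fit_residual(traj, traj_len, 1), fit_residual(traj, traj_len, 2))
    for name, traj in trajectories.items()
}

for name, (r1, r2) in results.items():
    print(f"{name}: degree 1 non-linear={r1 >= threshold}, "
          f"degree 2 non-linear={r2 >= threshold}")
```

With these toy inputs, the parabola exceeds the threshold only under the degree 1 fit; under the degree 2 fit its residual is close to zero, so it would be labeled Linear.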

Is my idea correct?