naver/roma

`shortest_path` bug in `unitquat_to_rotvec()`


roma/roma/utils.py

Lines 255 to 280 in 22806df

def unitquat_slerp(q0, q1, steps, shortest_path=False):
    """
    Spherical linear interpolation between two unit quaternions.

    Args:
        q0, q1 (Ax4 tensor): batch of unit quaternions (A may contain multiple dimensions).
        steps (tensor of shape B): interpolation steps, 0.0 corresponding to q0 and 1.0 to q1 (B may contain multiple dimensions).
        shortest_path (boolean): if True, interpolation will be performed along the shortest path on SO(3).
    Returns:
        batch of interpolated quaternions (BxAx4 tensor).
    Note:
        When considering quaternions as rotation representations,
        one should keep in mind that spherical interpolation is not necessarily performed along the shortest arc,
        depending on the sign of ``torch.sum(q0*q1,dim=-1)``.
    """
    if shortest_path:
        # Flip some quaternions to ensure the shortest path interpolation
        q1 = -torch.sign(torch.sum(q0*q1, dim=-1, keepdim=True)) * q1
    # Relative rotation
    rel_q = quat_product(quat_conjugation(q0), q1)
    rel_rotvec = roma.mappings.unitquat_to_rotvec(rel_q)
    # Relative rotations to apply
    rel_rotvecs = steps.reshape(steps.shape + (1,) * rel_rotvec.dim()) * rel_rotvec.reshape((1,) * steps.dim() + rel_rotvec.shape)
    rots = roma.mappings.rotvec_to_unitquat(rel_rotvecs.reshape(-1, 3)).reshape(*rel_rotvecs.shape[:-1], 4)
    interpolated_q = quat_product(q0.reshape((1,) * steps.dim() + q0.shape).repeat(steps.shape + (1,) * q0.dim()), rots)
    return interpolated_q

  1. On Line 272, `q1 = -torch.sign(torch.sum(q0*q1, dim=-1, keepdim=True)) * q1` should be `q1 = torch.sign(torch.sum(q0*q1, dim=-1, keepdim=True)) * q1`, since one of the quaternions only needs to be flipped when $\cos \Omega < 0$, i.e. when the dot product is negative.
  2. The `shortest_path` flag is actually useless, since `unitquat_to_rotvec()` (used on L275) always returns rotation vectors with angles within $[0, \pi]$.
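The flip condition in item 1 can be checked numerically. The following is a minimal sketch in plain Python, not roma code: the helpers `dot`, `flip_for_shortest_arc` and `arc_angle` are hypothetical names introduced for illustration. It relies only on the fact that `q` and `-q` represent the same rotation, so negating `q1` when the dot product is negative shortens the arc without changing the target rotation.

```python
import math

def dot(q0, q1):
    """Dot product of two quaternions given as 4-tuples."""
    return sum(a * b for a, b in zip(q0, q1))

def flip_for_shortest_arc(q0, q1):
    """Return q1 or -q1, whichever lies in the same hemisphere as q0.

    Hypothetical helper: q1 is negated only when the dot product is
    strictly negative, so it is never scaled by zero.
    """
    if dot(q0, q1) < 0:
        return tuple(-c for c in q1)
    return q1

def arc_angle(q0, q1):
    """Great-circle angle between two unit quaternions on the 3-sphere."""
    return math.acos(max(-1.0, min(1.0, dot(q0, q1))))

# Identity, and a rotation of 270 degrees about z, written in (x, y, z, w)
# order (the component order does not matter for this check).
q0 = (0.0, 0.0, 0.0, 1.0)
half = math.radians(270.0) / 2.0
q1 = (0.0, 0.0, math.sin(half), math.cos(half))

before = arc_angle(q0, q1)                             # long arc
after = arc_angle(q0, flip_for_shortest_arc(q0, q1))   # short arc
print(math.degrees(before), math.degrees(after))
```

The rotation angle is twice the arc angle on the 3-sphere, so the unflipped pair corresponds to a 270 degree rotation and the flipped pair to the equivalent 90 degree rotation in the opposite direction.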

Thank you for opening this issue!
Indeed, the current implementation of this function appears to be buggy:

  • As you mentioned, slerp is always performed along the shortest arc here, regardless of the `shortest_path` flag.
  • Additionally, `torch.sign` could return 0 in some (rare) cases, namely when the dot product is exactly 0, leading to a wrong interpolation (i.e. ignoring q1 and returning q0).
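The zero-sign case is easy to reproduce in isolation. The snippet below is a small sketch, not the fix that was eventually committed: it shows `torch.sign` zeroing out a quaternion orthogonal to `q0`, and one assumed alternative based on `torch.where` that negates only when the dot product is strictly negative.

```python
import torch

q0 = torch.tensor([[0.0, 0.0, 0.0, 1.0],
                   [0.0, 0.0, 0.0, 1.0]])
q1 = torch.tensor([[1.0, 0.0, 0.0, 0.0],    # orthogonal to q0: dot product is exactly 0
                   [0.0, 0.0, 0.0, -1.0]])  # opposite hemisphere: dot product is -1

dot = torch.sum(q0 * q1, dim=-1, keepdim=True)

# Multiplying by the sign zeroes out the first quaternion entirely.
q1_sign = torch.sign(dot) * q1

# Assumed alternative: flip only where the dot product is strictly negative,
# so every quaternion keeps unit norm.
q1_where = torch.where(dot < 0, -q1, q1)

print(q1_sign.norm(dim=-1))   # first entry has norm 0
print(q1_where.norm(dim=-1))  # both entries have norm 1
```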

These bugs must have appeared during the refactoring of the code, before open-sourcing the library. I will try to fix them in the next release.
Again, thanks for reporting the issue.

The issue should be solved by commit c460d15.
I will try to push a new pip release next week, with some minor additional features.

Version 1.3.0 should have solved the issue. Thank you for the feedback, and feel free to reopen the issue if needed.

@rbregier Thanks for the fix! I noticed there is an in-place operation in `unitquat_to_rotvec()`. Would it be better to use out-of-place operations instead?

Thanks for reviewing changes.
Actually, there is a copy just before this line, so it should not be a problem:

quat = quat.clone()
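To illustrate why the copy makes the in-place write harmless, here is a minimal sketch. The function `canonicalize` is a hypothetical stand-in, not the actual body of `unitquat_to_rotvec()`, and it assumes the in-place step is a masked sign flip on the last component.

```python
import torch

def canonicalize(quat):
    """Hypothetical example: return a copy of quat with a non-negative last component."""
    quat = quat.clone()            # work on a private copy of the input
    mask = quat[..., -1] < 0
    quat[mask] = -quat[mask]       # in-place write touches only the copy
    return quat

original = torch.tensor([[0.0, 0.0, 1.0, 0.0],
                         [0.0, 0.0, 0.0, -1.0]])
snapshot = original.clone()

result = canonicalize(original)

print(torch.equal(original, snapshot))  # the caller's tensor is unchanged
print(result)
```

Because `clone()` is itself a differentiable operation, gradients can still flow back to the original tensor through the copy.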