Why is MERU interpolation done in tangent space as opposed to along a geodesic in hyperbolic space?
As the title asks, why is MERU interpolation done in tangent space as opposed to along a geodesic in hyperbolic space?
It seems pretty reasonable to compute a geodesic between the root and the point of interest on the hyperboloid, and then sample along this geodesic. Did you try this out?
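To make the proposal concrete, here is a rough sketch of what I have in mind. It works on plain Lorentz coordinates with the time component stored first, whereas MERU's `L` module stores only the space components, so `lorentz_inner` and `geodesic_interpolate` are hypothetical helpers for illustration, not part of the repo:

```python
import torch


def lorentz_inner(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Minkowski inner product <x, y>_L = -x0 * y0 + sum_i xi * yi,
    # with the time coordinate stored first.
    return -x[..., 0] * y[..., 0] + (x[..., 1:] * y[..., 1:]).sum(dim=-1)


def geodesic_interpolate(x: torch.Tensor, y: torch.Tensor, steps: int, curv: float = 1.0):
    # Points satisfy <p, p>_L = -1/curv. Hyperbolic distance:
    # d(x, y) = acosh(-curv * <x, y>_L) / sqrt(curv)
    rc = curv ** 0.5
    dist = torch.acosh(torch.clamp(-curv * lorentz_inner(x, y), min=1.0 + 1e-7)) / rc

    # Unit-speed initial direction u, chosen so that
    # y = cosh(rc * d) * x + sinh(rc * d) * u / rc.
    u = rc * (y - torch.cosh(rc * dist) * x) / torch.sinh(rc * dist)

    # Equal increments of t correspond to equal hyperbolic arc length.
    return torch.stack(
        [
            torch.cosh(rc * t) * x + torch.sinh(rc * t) * u / rc
            for t in torch.linspace(0.0, 1.0, steps) * dist
        ]
    )
```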
I have several guesses as to why MERU interpolates in the tangent space rather than along the geodesic in hyperbolic space:
- Perhaps it is too complicated to sample along this geodesic? We know that hyperbolic space expands exponentially with distance from the origin, so equal steps in the interpolation weight might not correspond to equal movement along the geodesic (see the sketch after this list).
- Maybe you wanted to interpolate in the tangent space (i.e., Euclidean space) so that the traversal could be more easily compared to the CLIP model.
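On the first guess: a hypothetical way to test uniformity numerically would be to measure the hyperbolic distance between consecutive interpolants (again assuming full Lorentz coordinates; `lorentz_dist` is a made-up helper, not MERU's API):

```python
import torch


def lorentz_dist(x: torch.Tensor, y: torch.Tensor, curv: float = 1.0) -> torch.Tensor:
    # d(x, y) = acosh(-curv * <x, y>_L) / sqrt(curv) on the hyperboloid
    # <p, p>_L = -1/curv, with the time coordinate stored first.
    inner = -x[..., 0] * y[..., 0] + (x[..., 1:] * y[..., 1:]).sum(dim=-1)
    return torch.acosh(torch.clamp(-curv * inner, min=1.0 + 1e-7)) / curv ** 0.5


# For interpolants of shape (steps, dim + 1), the consecutive gaps should all
# be (nearly) equal if the traversal is uniform in hyperbolic distance:
# gaps = lorentz_dist(interp_feats[:-1], interp_feats[1:])
```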
Any insight / clarification here would be appreciated. Thanks!
As a follow-up, I was taking a look at the `interpolate` function:
```python
def interpolate(model, feats: torch.Tensor, root_feat: torch.Tensor, steps: int):
    """
    Interpolate between given feature vector and `[ROOT]` depending on model type.
    """

    # Linear interpolation between root and image features. For MERU, this happens
    # in the tangent space of the origin.
    if isinstance(model, MERU):
        feats = L.log_map0(feats, model.curv.exp())

    interp_feats = [
        torch.lerp(root_feat, feats, weight.item())
        for weight in torch.linspace(0.0, 1.0, steps=steps)
    ]
    interp_feats = torch.stack(interp_feats)

    # Lift on the Hyperboloid (for MERU), or L2 normalize (for CLIP).
    if isinstance(model, MERU):
        feats = L.log_map0(feats, model.curv.exp())
        interp_feats = L.exp_map0(interp_feats, model.curv.exp())
    else:
        interp_feats = torch.nn.functional.normalize(interp_feats, dim=-1)

    # Reverse the traversal order: (image first, root last)
    return interp_feats.flip(0)
```
Why do we apply the logarithmic map to `feats` twice via the line `feats = L.log_map0(feats, model.curv.exp())`? Perhaps the first occurrence was meant to map the root instead, something like `root_feat = L.log_map0(root_feat, model.curv.exp())`?
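For reference, the version I would have expected is something like this (my own hypothetical rewrite of the MERU branch, not code from the repo):

```python
if isinstance(model, MERU):
    # Map BOTH endpoints to the tangent space at the origin before lerping.
    feats = L.log_map0(feats, model.curv.exp())
    root_feat = L.log_map0(root_feat, model.curv.exp())

interp_feats = torch.stack(
    [
        torch.lerp(root_feat, feats, weight.item())
        for weight in torch.linspace(0.0, 1.0, steps=steps)
    ]
)

# Lift onto the hyperboloid (for MERU), or L2 normalize (for CLIP).
if isinstance(model, MERU):
    interp_feats = L.exp_map0(interp_feats, model.curv.exp())
else:
    interp_feats = torch.nn.functional.normalize(interp_feats, dim=-1)
```

Of course, if `root_feat` happens to be the origin of the hyperboloid (all-zero space components), its log map is the zero vector anyway, so the missing line would not change the output, which might be why this goes unnoticed.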
Any help clarifying this would be greatly appreciated!