facebookresearch/meru

Why is MERU interpolation done in tangent space as opposed to along a geodesic in hyperbolic space?


ez2rok commented

As the title asks, why is MERU interpolation done in tangent space as opposed to along a geodesic in hyperbolic space?

It seems pretty reasonable to compute the geodesic between the root and the point of interest on the hyperboloid, and then sample along it. Did you try this out?
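
For concreteness, here is roughly the kind of sampling I had in mind. This is only a sketch and does not use the repo's `meru.lorentz` helpers; I'm assuming `x_space` and `y_space` are the space components of two points on the hyperboloid of curvature `-curv`, where `curv` is the positive scalar from `model.curv.exp()`:

```python
import torch


def geodesic_interpolate(
    x_space: torch.Tensor, y_space: torch.Tensor, curv: float, steps: int
) -> torch.Tensor:
    """Sample `steps` points, uniformly spaced in geodesic distance, between two
    hyperboloid points given by their space components (1D tensors of size D)."""

    # Recover the time components so that <x, x>_L = -1 / curv on the hyperboloid.
    x_time = torch.sqrt(1 / curv + x_space.pow(2).sum(-1, keepdim=True))
    y_time = torch.sqrt(1 / curv + y_space.pow(2).sum(-1, keepdim=True))
    x = torch.cat([x_time, x_space], dim=-1)  # time component stored at index 0
    y = torch.cat([y_time, y_space], dim=-1)

    # Lorentzian inner product: <x, y>_L = -x0 * y0 + sum_i xi * yi.
    lorentz_inner = -x[0] * y[0] + (x[1:] * y[1:]).sum()

    # theta = sqrt(curv) * d(x, y); clamp for numerical safety when x == y.
    theta = torch.acosh(torch.clamp(-curv * lorentz_inner, min=1.0 + 1e-7))

    # Hyperbolic analogue of slerp:
    #   gamma(t) = (sinh((1 - t) * theta) * x + sinh(t * theta) * y) / sinh(theta)
    points = [
        (torch.sinh((1 - t) * theta) * x + torch.sinh(t * theta) * y) / torch.sinh(theta)
        for t in torch.linspace(0.0, 1.0, steps=steps)
    ]
    return torch.stack(points)[:, 1:]  # drop the time component again
```

This is what I mean by "sampling along the geodesic": the intermediate points stay on the hyperboloid and are equally spaced in geodesic distance.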

I have a couple of guesses as to why MERU interpolates in tangent space rather than along a geodesic in hyperbolic space:

  1. Perhaps it is too complicated to sample along this geodesic? We know that distances in hyperbolic space grow exponentially, so uniformly spaced interpolation weights might not correspond to uniformly spaced points along the geodesic.
  2. Maybe you wanted to interpolate in the tangent space (i.e. Euclidean space) so that the results could be compared more easily with the CLIP model.

Any insight / clarification here would be appreciated. Thanks!

ez2rok commented

As a follow-up, I was taking a look at the `interpolate` function:

```python
def interpolate(model, feats: torch.Tensor, root_feat: torch.Tensor, steps: int):
    """
    Interpolate between given feature vector and `[ROOT]` depending on model type.
    """

    # Linear interpolation between root and image features. For MERU, this happens
    # in the tangent space of the origin.
    if isinstance(model, MERU):
        feats = L.log_map0(feats, model.curv.exp())

    interp_feats = [
        torch.lerp(root_feat, feats, weight.item())
        for weight in torch.linspace(0.0, 1.0, steps=steps)
    ]
    interp_feats = torch.stack(interp_feats)

    # Lift on the Hyperboloid (for MERU), or L2 normalize (for CLIP).
    if isinstance(model, MERU):
        feats = L.log_map0(feats, model.curv.exp())
        interp_feats = L.exp_map0(interp_feats, model.curv.exp())
    else:
        interp_feats = torch.nn.functional.normalize(interp_feats, dim=-1)

    # Reverse the traversal order: (image first, root last)
    return interp_feats.flip(0)
```

Why do we apply the logarithmic map to `feats` twice via the line `feats = L.log_map0(feats, model.curv.exp())`? Perhaps the first one was meant to map the root instead, something like `root_feat = L.log_map0(root_feat, model.curv.exp())`?
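
In other words, I would have expected the block before the lerp to map both endpoints into the tangent space of the origin, roughly like this (just my guess at the intent, not tested):

```python
    # Hypothetical change: map both endpoints into the tangent space of the
    # origin before the linear interpolation (only my guess, not the repo's code).
    if isinstance(model, MERU):
        root_feat = L.log_map0(root_feat, model.curv.exp())
        feats = L.log_map0(feats, model.curv.exp())
```

and then only the `exp_map0` lift in the second `isinstance(model, MERU)` branch, without a second call to `log_map0`.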

Any help clarifying this would be greatly appreciated!