facebookresearch/theseus

Implicit backprop with batched matrix Lie group variables

holmesco opened this issue · 1 comment

๐Ÿ› Bug

When using implicit backpropagation with batched matrix Lie group variables, the following error occurs in the _merge_infos method of the NonlinearOptimizer class when calling forward():

[image: screenshot of the error traceback]

Note: Batched dimension is 2.

The error is due to a dimension mismatch between idx_no_grad and the sol_no_grad/sol_grad tensors. The variable tensors are three-dimensional, with shape (batch, 3, 3), because they are batched SO3 variables and each SO3 element is represented by a 3x3 matrix.
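The mismatch can be illustrated with plain PyTorch broadcasting, independent of Theseus. The sketch below assumes the mask has shape (batch, 1); that shape is an assumption made for illustration and is not taken from the Theseus source. The tensor names only mirror the ones mentioned above.

```python
import torch

batch = 2

# Assumed per-batch boolean mask with shape (batch, 1); the actual shape of
# idx_no_grad inside Theseus is an assumption here.
idx_no_grad = torch.tensor([[True], [False]])

# Vector-valued variables, shape (batch, dof): the mask broadcasts as intended.
vec_no_grad = torch.zeros(batch, 3)
vec_grad = torch.ones(batch, 3)
merged_vec = torch.where(idx_no_grad, vec_no_grad, vec_grad)

# Matrix-valued variables, shape (batch, 3, 3): broadcasting aligns trailing
# dimensions, so the mask is treated as (1, 2, 1) and its 2 clashes with a 3.
sol_no_grad = torch.zeros(batch, 3, 3)
sol_grad = torch.ones(batch, 3, 3)
try:
    torch.where(idx_no_grad, sol_no_grad, sol_grad)
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True
```

Under the same assumption, a batch size of 3 would not raise at all: a (3, 1) mask broadcasts against (3, 3, 3) along the wrong axis, so the selection would silently mix rows rather than batch elements.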

Steps to Reproduce

Example Code:

import theseus as th
import torch
import numpy as np

torch.set_default_dtype(torch.float64)


def error(optim_vars, aux_vars):
    Cs = optim_vars[0]
    a = torch.tensor(np.array([[1.0, 1.0, 1.0]]).T)
    error = a - Cs.tensor @ a
    return error.squeeze(-1)


Cs = th.SO3(name="Cs")
objective = th.Objective()
optim_vars = [Cs]
aux_vars = []
cost_function = th.AutoDiffCostFunction(
    optim_vars=optim_vars,
    dim=3,
    err_fn=error,
    aux_vars=aux_vars,
    cost_weight=th.ScaleCostWeight(1.0),
    name="trace",
)
objective.add(cost_function)
layer = th.TheseusLayer(th.GaussNewton(objective, max_iterations=20))


inputs = {"Cs": torch.randn(2, 3, 3)}
with torch.no_grad():
    vars, info = layer.forward(
        inputs,
        optimizer_kwargs={
            "track_best_solution": True,
            "verbose": True,
            "backward_mode": "implicit",
        },
    )

Expected behavior

The dimensions should be handled correctly for batched matrix-valued variables. One workaround is to check the number of dimensions of sol_grad and unsqueeze idx_no_grad accordingly, though I don't know if this is the best approach:

[image: screenshot of the proposed workaround code]
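Since the screenshot is not reproduced here, the following is a minimal sketch of what such a workaround could look like in plain PyTorch. It is not the code from the screenshot or from the eventual fix; the helper name expand_mask is hypothetical, and the (batch, 1) mask shape is an assumption.

```python
import torch


def expand_mask(mask: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Append singleton dimensions to `mask` until it has as many dims as `target`.

    Hypothetical helper for illustration; not part of the Theseus API.
    """
    while mask.dim() < target.dim():
        mask = mask.unsqueeze(-1)
    return mask


# Assumed mask shape (batch, 1), one boolean per batch element.
idx_no_grad = torch.tensor([[True], [False]])

# Batched 3x3 solutions, as for SO3 variables.
sol_no_grad = torch.zeros(2, 3, 3)
sol_grad = torch.ones(2, 3, 3)

# The mask becomes (2, 1, 1), which broadcasts over each 3x3 matrix.
best_solution = torch.where(
    expand_mask(idx_no_grad, sol_grad), sol_no_grad, sol_grad
)
```

For vector-valued variables of shape (batch, dof) the helper leaves the mask unchanged, so the same call would cover both cases.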

System Info

  • OS (e.g., Linux): Linux
  • Python version: 3.9.18
  • Using CPU version
  • pip/conda dependencies packages versions:
torch                     2.1.1+cpu
torchaudio                2.1.1+cpu
torchkin                  0.1.1
torchlie                  0.1.0
torchvision               0.16.1+cpu

Thanks for reporting, @holmesco! Added a fix in the PR linked above.