mitsuba-renderer/drjit

Gradient flow from PyTorch to Dr.Jit

shhra opened this issue · 1 comment

I am currently working on a project where I need to pass gradients between PyTorch and Dr.Jit. While the forward pass works seamlessly, I am encountering difficulties during the backward pass. Unfortunately, the error messages provided are not particularly informative, making it challenging to diagnose the root cause.

To better illustrate the problem, I have attached a minimal code example that replicates the issue. I would greatly appreciate it if you could review the code and provide insights into what might be going wrong during the backward pass.
😄

import torch
import drjit as dr


# Dr.Jit -> PyTorch boundary: the body is evaluated by PyTorch,
# but the function is meant to be called from Dr.Jit code
@dr.wrap(source="drjit", target="torch")
def sum_a_b(a, b):
    return a + b


# PyTorch -> Dr.Jit boundary: the body is meant to be Dr.Jit code,
# with the function called from PyTorch code
@dr.wrap(source="torch", target="drjit")
def prod(a, b):
    res = torch.tensor([0.0]).cuda()  # accumulator created as a PyTorch tensor
    for i in range(int(b.numpy().item())):
        res = sum_a_b(res, a)  # nested @dr.wrap call
    return res


def one_up(a):
    return a + 1


if __name__ == '__main__':
    
    a = torch.tensor([1.0], requires_grad=True, device="cuda")
    b = torch.tensor([2.0]).cuda()
    
    optim = torch.optim.Adam(params=[a], lr=1e-4)
    target = torch.tensor([128], device="cuda")
    
    for i in range(200):
        optim.zero_grad()
        res = prod(a, b)  # forward pass works fine
        res = one_up(res)
        loss = target - res
        print(f"Step: {i}; loss: {loss}, a: {a}")
        loss.backward()  # backward pass fails here
        optim.step()

Hi @shhra
The only thing that looks a bit odd is the initialisation of res in prod(): since Dr.Jit is the target framework, the body of that function should consist of Dr.Jit code, yet res is created as a PyTorch tensor. I tried it myself, though, and it still crashes during the backward pass even when res is a Dr.Jit variable.
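
For reference, this is roughly what I mean by initialising res on the Dr.Jit side (just a sketch; it assumes the CUDA AD backend, i.e. drjit.cuda.ad.TensorXf, which I believe is what dr.wrap converts a float32 CUDA tensor into):

from drjit.cuda.ad import TensorXf

@dr.wrap(source="torch", target="drjit")
def prod(a, b):
    # accumulator created on the Dr.Jit side instead of as a torch.Tensor
    res = dr.zeros(TensorXf, shape=(1,))
    for i in range(int(b.numpy().item())):
        res = sum_a_b(res, a)  # the nested @dr.wrap call is still here
    return res

As mentioned, this variant still crashes on loss.backward(), so the initialisation alone is not the culprit.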

I don't think we ever considered nesting @dr.wrap calls inside each other. I'll take a deeper look into what's happening.
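
In the meantime, a possible workaround (an untested sketch, assuming the nesting itself is what breaks the backward pass) is to drop the inner @dr.wrap layer and keep the whole loop as plain Dr.Jit arithmetic, so gradients only cross the PyTorch/Dr.Jit boundary once:

import drjit as dr
from drjit.cuda.ad import TensorXf

@dr.wrap(source="torch", target="drjit")
def prod(a, b):
    # plain Dr.Jit addition instead of the wrapped sum_a_b(),
    # i.e. a single @dr.wrap boundary around the whole loop
    res = dr.zeros(TensorXf, shape=(1,))
    for _ in range(int(b.numpy().item())):
        res = res + a
    return res

If that avoids the crash on your side, it would confirm that the nested wrap is the problem.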