Gradient flow from PyTorch to Dr.Jit
shhra opened this issue · 1 comment
I am currently working on a project where I need to pass gradients between PyTorch and Dr.Jit. While the forward pass works seamlessly, I am encountering difficulties during the backward pass. Unfortunately, the error messages provided are not particularly informative, making it challenging to diagnose the root cause.
To better illustrate the problem, I have attached a minimal code example that replicates the issue. I would greatly appreciate it if you could review the code and provide insights into what might be going wrong during the backward pass. 😄
import torch
import drjit as dr


@dr.wrap(source="drjit", target="torch")
def sum_a_b(a, b):
    # Callable from Dr.Jit code; the body is evaluated with PyTorch.
    return a + b


@dr.wrap(source="torch", target="drjit")
def prod(a, b):
    # Callable from PyTorch code; the body is evaluated with Dr.Jit.
    res = torch.tensor([0.0]).cuda()
    for i in range(int(b.numpy().item())):
        res = sum_a_b(res, a)  # nested call into the torch-wrapped function
    return res


def one_up(a):
    return a + 1


if __name__ == '__main__':
    a = torch.tensor([1.0], requires_grad=True, device="cuda")
    b = torch.tensor([2.0]).cuda()
    optim = torch.optim.Adam(params=[a], lr=1e-4)
    target = torch.tensor([128], device="cuda")

    for i in range(200):
        optim.zero_grad()
        res = prod(a, b)
        res = one_up(res)
        loss = target - res
        print(f"Step: {i}; loss: {loss}, a: {a}")
        loss.backward()
        optim.step()
Hi @shhra

The only thing that is a bit weird is the initialisation of res in prod(). The code within that function should be Dr.Jit code, as it's marked as the target framework. But I tried it myself, and even with a Dr.Jit variable it still crashes.
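For reference, a variant where the accumulator is created as a Dr.Jit array rather than a torch tensor could look roughly like the sketch below, used as a drop-in replacement for prod() in the snippet above. The dr.zeros call and the type(a)/a.shape arguments are assumptions about how @dr.wrap hands the tensors over; as said, this variant still crashes during the backward pass.

@dr.wrap(source="torch", target="drjit")
def prod(a, b):
    # Accumulator created as a Dr.Jit array matching a, instead of
    # a CUDA torch tensor (assumed type/shape handling by @dr.wrap).
    res = dr.zeros(type(a), shape=a.shape)
    for _ in range(int(b.numpy().item())):
        res = sum_a_b(res, a)  # still a nested call into the torch-wrapped sum_a_b
    return res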
I think we never thought of nesting @wrap calls into each other. I'll have a deeper look into what's happening.
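In the meantime, a possible workaround might be to drop the inner @wrap entirely and keep the whole loop in Dr.Jit, roughly along these lines (again just a sketch, with the same assumptions about the accumulator as above):

@dr.wrap(source="torch", target="drjit")
def prod(a, b):
    res = dr.zeros(type(a), shape=a.shape)
    for _ in range(int(b.numpy().item())):
        res = res + a  # plain Dr.Jit addition, no nested @wrap call
    return res

That keeps a single framework boundary per call, which seems to be the case the wrapper was designed for.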