Tracking derivatives through wrapped texture sampling not working
daseyb opened this issue · 4 comments
Hi! I'm on the latest versions of drjit and drjit-core (master, as of today) and I have the following problem (example code below):
I'm trying to take a PyTorch tensor, convert it to a drjit Texture, sample from it, and then propagate gradients back to the PyTorch tensor. This does not seem to work: the gradients of the tensor are always 0. Here is a script that reproduces the issue:
```python
import drjit as dr
import torch


@dr.wrap(source="torch", target="drjit")
def wrap_texture_sample(tensor):
    image = dr.llvm.ad.Texture2f(
        tensor,
        filter_mode=dr.FilterMode.Linear,
        wrap_mode=dr.WrapMode.Repeat,
    )
    return dr.llvm.ad.Array1f(image.eval(dr.llvm.Array2f([0.5, 0.5])))


# Identical to wrap_texture_sample, except for the @dr.wrap decorator.
def no_wrap_texture_sample(tensor):
    image = dr.llvm.ad.Texture2f(
        tensor,
        filter_mode=dr.FilterMode.Linear,
        wrap_mode=dr.WrapMode.Repeat,
    )
    return dr.llvm.ad.Array1f(image.eval(dr.llvm.Array2f([0.5, 0.5])))


def test_texture():
    tensor_torch = torch.ones((1, 1, 1), dtype=torch.float32, requires_grad=True)
    tensor_drjit = dr.ones(dr.llvm.ad.TensorXf, shape=(1, 1, 1))
    dr.enable_grad(tensor_drjit)

    result_wrap = wrap_texture_sample(tensor_torch)
    result_wrap.backward()
    print("[Texture] Wrapped grad:\t\t", tensor_torch.grad)

    result_no_wrap = no_wrap_texture_sample(tensor_drjit)
    dr.backward(result_no_wrap)
    print("[Texture] Not wrapped grad:\t\t", tensor_drjit.grad)


@dr.wrap(source="torch", target="drjit")
def wrap_tensor_read(tensor):
    return tensor[0, 0, 0]


def no_wrap_tensor_read(tensor):
    return tensor[0, 0, 0]


def test_tensor():
    tensor_torch = torch.ones((1, 1, 1), dtype=torch.float32, requires_grad=True)
    tensor_drjit = dr.ones(dr.llvm.ad.TensorXf, shape=(1, 1, 1))
    dr.enable_grad(tensor_drjit)

    result_wrap = wrap_tensor_read(tensor_torch)
    result_wrap.backward()
    print("[Tensor] Wrapped grad:\t\t", tensor_torch.grad)

    result_no_wrap = no_wrap_tensor_read(tensor_drjit)
    dr.backward(result_no_wrap)
    print("[Tensor] Not wrapped grad:\t\t", tensor_drjit.grad)


if __name__ == "__main__":
    test_texture()
    test_tensor()
```
Expected output:

```
[Texture] Wrapped grad:      tensor([[[1.]]])
[Texture] Not wrapped grad:  [[[1]]]
[Tensor] Wrapped grad:       tensor([[[1.]]])
[Tensor] Not wrapped grad:   [[[1]]]
```

Actual output:

```
[Texture] Wrapped grad:      tensor([[[0.]]])
[Texture] Not wrapped grad:  [[[1]]]
[Tensor] Wrapped grad:       tensor([[[1.]]])
[Tensor] Not wrapped grad:   [[[1]]]
```
Is there an issue in my code or is this a bug? Thanks!
Interestingly, when I manually write a wrapper around the drjit operation, things seem to work:
```python
from torch.autograd import Function


class ManualTorchWrapper(Function):
    @staticmethod
    def forward(ctx, tensor):
        inputs_drjit = dr.llvm.ad.TensorXf(tensor)
        dr.enable_grad(inputs_drjit)
        outputs_drjit = no_wrap_texture_sample(inputs_drjit)
        # Keep the drjit inputs/outputs alive so backward() can traverse
        # the drjit AD graph later.
        ctx.inputs, ctx.output = inputs_drjit, outputs_drjit
        return outputs_drjit.torch()

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = dr.llvm.ad.Array1f(grad_output)
        dr.set_grad(ctx.output, grad_output)
        grad_input = dr.backward_to(ctx.inputs)
        return grad_input.torch()


def test_manual():
    tensor_torch = torch.ones((1, 1, 1), dtype=torch.float32, requires_grad=True)
    result_wrap = ManualTorchWrapper.apply(tensor_torch)
    result_wrap.backward()
    print("[Manual] Wrapped grad:\t\t", tensor_torch.grad)
```
This produces: `[Manual] Wrapped grad: tensor([[[1.]]])`
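In case it's useful, here is a minimal sketch of how the same manual pattern could be generalized to wrap an arbitrary drjit function. `GenericTorchWrapper` and its `fn` argument are hypothetical names for illustration; the drjit calls simply mirror `ManualTorchWrapper` above and assume the wrapped function returns an `Array1f`:

```python
class GenericTorchWrapper(Function):
    @staticmethod
    def forward(ctx, fn, tensor):
        # `fn` is any drjit function taking a TensorXf, e.g.
        # no_wrap_texture_sample (hypothetical generalization).
        inputs_drjit = dr.llvm.ad.TensorXf(tensor)
        dr.enable_grad(inputs_drjit)
        outputs_drjit = fn(inputs_drjit)
        ctx.inputs, ctx.output = inputs_drjit, outputs_drjit
        return outputs_drjit.torch()

    @staticmethod
    def backward(ctx, grad_output):
        # Assumes the output type is Array1f, as in the example above.
        dr.set_grad(ctx.output, dr.llvm.ad.Array1f(grad_output))
        grad_input = dr.backward_to(ctx.inputs)
        # forward() took (fn, tensor); `fn` receives no gradient.
        return None, grad_input.torch()


# Usage:
# result = GenericTorchWrapper.apply(no_wrap_texture_sample, tensor_torch)
```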
Hi @daseyb
Just to keep you updated, we've seen this and are still looking into it 🙇
Thank you!
The fix has been merged into master. Thank you for reporting this!