mitsuba-renderer/drjit

Tracking derivatives through wrapped texture sampling not working

daseyb opened this issue · 4 comments

Hi! I'm on the latest versions of drjit and drjit-core (master, as of today) and I'm running into the following problem:

I'm trying to take a PyTorch tensor, convert it to a Dr.Jit Texture, sample from it, and then propagate gradients back to the PyTorch tensor. This does not seem to work: the gradients of the tensor are always zero. Here is a script that reproduces the issue:

import drjit as dr
import torch


@dr.wrap(source="torch", target="drjit")
def wrap_texture_sample(tensor):
    image = dr.llvm.ad.Texture2f(
        tensor,
        filter_mode=dr.FilterMode.Linear,
        wrap_mode=dr.WrapMode.Repeat,
    )

    return dr.llvm.ad.Array1f(image.eval(dr.llvm.Array2f([0.5, 0.5])))


def no_wrap_texture_sample(tensor):
    image = dr.llvm.ad.Texture2f(
        tensor,
        filter_mode=dr.FilterMode.Linear,
        wrap_mode=dr.WrapMode.Repeat,
    )

    return dr.llvm.ad.Array1f(image.eval(dr.llvm.Array2f([0.5, 0.5])))


def test_texture():
    tensor_torch = torch.ones((1, 1, 1), dtype=torch.float32, requires_grad=True)
    tensor_drjit = dr.ones(dr.llvm.ad.TensorXf, shape=(1, 1, 1))
    dr.enable_grad(tensor_drjit)

    result_wrap = wrap_texture_sample(tensor_torch)
    result_wrap.backward()
    print("[Texture] Wrapped grad:\t\t", tensor_torch.grad)

    result_no_wrap = no_wrap_texture_sample(tensor_drjit)
    dr.backward(result_no_wrap)
    print("[Texture] Not wrapped grad:\t\t", tensor_drjit.grad)


@dr.wrap(source="torch", target="drjit")
def wrap_tensor_read(tensor):
    return tensor[0, 0, 0]


def no_wrap_tensor_read(tensor):
    return tensor[0, 0, 0]


def test_tensor():
    tensor_torch = torch.ones((1, 1, 1), dtype=torch.float32, requires_grad=True)
    tensor_drjit = dr.ones(dr.llvm.ad.TensorXf, shape=(1, 1, 1))
    dr.enable_grad(tensor_drjit)

    result_wrap = wrap_tensor_read(tensor_torch)
    result_wrap.backward()
    print("[Tensor] Wrapped grad:\t\t", tensor_torch.grad)

    result_no_wrap = no_wrap_tensor_read(tensor_drjit)
    dr.backward(result_no_wrap)
    print("[Tensor] Not wrapped grad:\t\t", tensor_drjit.grad)


if __name__ == "__main__":
    test_texture()
    test_tensor()

Expected output:

[Texture] Wrapped grad:          tensor([[[1.]]])
[Texture] Not wrapped grad:              [[[1]]]
[Tensor] Wrapped grad:           tensor([[[1.]]])
[Tensor] Not wrapped grad:               [[[1]]]

Actual output:

[Texture] Wrapped grad:          tensor([[[0.]]])
[Texture] Not wrapped grad:              [[[1]]]
[Tensor] Wrapped grad:           tensor([[[1.]]])
[Tensor] Not wrapped grad:               [[[1]]]

Is there an issue in my code or is this a bug? Thanks!

Interestingly, when I manually write a wrapper around the drjit operation, things seem to work:

from torch.autograd import Function


class ManualTorchWrapper(Function):
    @staticmethod
    def forward(ctx, tensor):
        # Convert the incoming torch tensor to a Dr.Jit tensor and track gradients on it
        inputs_drjit = dr.llvm.ad.TensorXf(tensor)
        dr.enable_grad(inputs_drjit)

        outputs_drjit = no_wrap_texture_sample(inputs_drjit)

        # Keep the Dr.Jit input/output handles around for the backward pass
        ctx.inputs, ctx.output = inputs_drjit, outputs_drjit

        return outputs_drjit.torch()

    @staticmethod
    def backward(ctx, grad_output):
        # Seed the Dr.Jit output with the incoming torch gradient ...
        grad_output = dr.llvm.ad.Array1f(grad_output)
        dr.set_grad(ctx.output, grad_output)

        # ... and propagate it back to the Dr.Jit input
        grad_input = dr.backward_to(ctx.inputs)

        return grad_input.torch()
        
def test_manual():
    tensor_torch = torch.ones((1, 1, 1), dtype=torch.float32, requires_grad=True)

    result_wrap = ManualTorchWrapper.apply(tensor_torch)
    result_wrap.backward()
    print("[Manual] Wrapped grad:\t\t", tensor_torch.grad)

This produces: [Manual] Wrapped grad: tensor([[[1.]]])
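
As a usage note, here is a rough sketch of how the manual wrapper could be dropped into an ordinary torch optimization loop in the meantime (the target value, learning rate, and step count below are made up for illustration):

def optimize_with_manual_wrapper(steps=10):
    # Hypothetical example: fit the (constant) texture value to a made-up target
    tex = torch.full((1, 1, 1), 0.5, dtype=torch.float32, requires_grad=True)
    target = torch.tensor([0.8])
    opt = torch.optim.SGD([tex], lr=0.1)

    for _ in range(steps):
        opt.zero_grad()
        sample = ManualTorchWrapper.apply(tex)  # (1,) torch tensor
        loss = ((sample - target) ** 2).sum()
        loss.backward()  # gradients reach `tex` through the manual wrapper
        opt.step()

    print("[Manual] Optimized value:", tex.detach().flatten())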

Hi @daseyb

Just to keep you updated, we've seen this and are still looking into it 🙇

Thank you!

The fix has been merged into master. Thank you for reporting this!
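
For anyone landing here later, one way to confirm the fix locally (assuming a rebuilt master checkout of drjit and drjit-core) is to re-run the wrapped texture repro from the original report and check that the gradient now matches the expected output:

t = torch.ones((1, 1, 1), dtype=torch.float32, requires_grad=True)
wrap_texture_sample(t).backward()
assert torch.allclose(t.grad, torch.ones_like(t)), t.grad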