mitsuba-renderer/drjit

Is it possible to call Torch's functions (or a model) from Dr.Jit's Loop?

annadapb opened this issue · 1 comment

I am trying to call a Torch function from a Dr.Jit loop, and it seems that the function call is not happening. Here is the MWE:

import drjit
import torch
from drjit.llvm.ad import Float, Array3f, Loop, UInt32
from torch import nn

# drjit.set_log_level(drjit.LogLevel.Info)

class mlp(nn.Module):
    def __init__(self,): super().__init__()
    def forward(self, x): return x.reshape(1, -1).squeeze(dim=0)

def wrap(fn):
    # Convert the Dr.Jit input to a Torch tensor, call fn, convert the result back
    def wrapper(*argv):
        torch_tensor = argv[0].torch()
        torch_return = fn(torch_tensor)
        drjit_return = Float(torch_return)
        return drjit_return
    return wrapper

model = mlp()
x = torch.arange(12, dtype=torch.float32).reshape(4, 3)
y = 0.

for i in range(10):
    y += model(x)
print(y)

p = drjit.arange(Float, 12)
q = drjit.unravel(Array3f, p)
g = Float(0.)

# Recorded loop: start the counter at 0 and register every variable the body modifies
i = UInt32(0)
loop = Loop('Loop test', lambda: (i, g))
while loop(i < 10):
    g += wrap(model)(q)
    i += 1
print(g.torch())

g = Float(0.)
for i in range(10):
    g += wrap(model)(q)
print(g.torch())


which outputs:

tensor([  0.,  10.,  20.,  30.,  40.,  50.,  60.,  70.,  80.,  90., 100., 110.])
tensor([5.4735e-03, 4.5748e-41, 5.4735e-03, 4.5748e-41, 0.0000e+00, 9.1834e-41,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])
tensor([  0.,  10.,  20.,  30.,  40.,  50.,  60.,  70.,  80.,  90., 100., 110.])

With Python's (unrolled?) loop I get the expected result, but with Dr.Jit's recorded loop I get a mostly zero array with a few garbage values. What would be the correct way to do this?

Hi @annadapb

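Indeed, this is not possible. Dr.Jit recorded loops are only meant to record Dr.Jit operations: the loop body is executed once on symbolic variables to capture the computation, so a Torch call inside it never sees actual data. You will need to unroll the loop if you want to weave in operations from another library.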
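To make the difference concrete, here is a toy model in plain Python. It is hypothetical and is not how Dr.Jit is implemented; `Sym`, `foreign_fn`, `record_loop`, and `unrolled_loop` are illustrative names only. A "recorded" loop traces its body once on symbolic placeholders, so a function from another library receives a placeholder instead of data. The "unrolled" loop passes concrete values on every iteration. In this toy the foreign call fails loudly, whereas in the issue above Dr.Jit silently produced garbage values.

```python
# Toy model (hypothetical): NOT Dr.Jit's implementation, only an illustration
# of recorded (symbolic) vs. unrolled (concrete) loop execution.


class Sym:
    """Symbolic placeholder: records an expression but holds no data."""

    def __init__(self, expr):
        self.expr = expr

    def __add__(self, other):
        rhs = other.expr if isinstance(other, Sym) else repr(other)
        return Sym(f"({self.expr} + {rhs})")


def foreign_fn(x):
    """Stands in for a Torch call: it can only work on concrete numbers."""
    if isinstance(x, Sym):
        raise TypeError("foreign_fn needs concrete data, got a symbolic value")
    return 2.0 * x


def body_native(g, q):
    return g + q  # only "native" operations: these can be recorded


def body_foreign(g, q):
    return g + foreign_fn(q)  # calls out to another library


def record_loop(body):
    """Trace the body once on placeholders and return the recorded expression."""
    return body(Sym("g"), Sym("q")).expr


def unrolled_loop(body, g, q, n):
    """Run the body n times on concrete values, like a plain Python loop."""
    for _ in range(n):
        g = body(g, q)
    return g


if __name__ == "__main__":
    print(record_loop(body_native))                   # (g + q)
    print(unrolled_loop(body_foreign, 0.0, 3.0, 10))  # 60.0
    try:
        record_loop(body_foreign)
    except TypeError as err:
        print("recording failed:", err)
```

In this sketch, recording works only because every operation in `body_native` understands placeholders; anything that must inspect real values, like a Torch model, only works in the unrolled form.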