Is it possible to call Torch's functions (or a model) from Dr.Jit's Loop?
annadapb opened this issue · 1 comments
annadapb commented
I am trying to call a Torch function from a Dr.Jit loop, and it seems that the function call is not happening. Here is the MWE:
import drjit
import torch
from drjit.llvm.ad import Float, Array3f, Loop, UInt32
from torch import nn
# drjit.set_log_level(drjit.LogLevel.Info)
class mlp(nn.Module):
    def __init__(self,): super().__init__()
    def forward(self, x): return x.reshape(1, -1).squeeze(dim=0)
def wrap(fn):
    def wrapper(*argv):
        torch_tensor = argv[0].torch()
        torch_return = model(torch_tensor)
        drjit_return = Float(torch_return)
        return drjit_return
    return wrapper
model = mlp()
x = torch.arange(12, dtype=torch.float32).reshape(4, 3)
y = 0.
for i in range(10):
    y += model(x)
print(y)
p = drjit.arange(Float, 12)
q = drjit.unravel(Array3f, p)
g = Float(0.)
i = UInt32(10)
loop = Loop('Loop test', lambda: (i))
while loop(i<10):
    g += wrap(model)(q)
    i += 1
print(g.torch())
g = Float(0.)
for i in range(10):
    g += wrap(model)(q)
print(g.torch())
pass
which outputs,
tensor([ 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., 110.])
tensor([5.4735e-03, 4.5748e-41, 5.4735e-03, 4.5748e-41, 0.0000e+00, 9.1834e-41,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])
tensor([ 0., 10., 20., 30., 40., 50., 60., 70., 80., 90., 100., 110.])
With Python's (unrolled?) loop I get the expected results, but with Dr.Jit's recorded loop I get an array of zeros and garbage values. What would be the correct way to do this?
njroussel commented
Hi @annadapb
Indeed, this is not possible. Dr.Jit recorded loops are only meant to record Dr.Jit operations. You will need to unroll the loop if you want to weave in operations from another library.
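To make the distinction concrete, here is a toy sketch in plain Python. It is not Dr.Jit's implementation; `Symbol`, `run_recorded`, `run_unrolled`, and `external_fn` are hypothetical names invented for illustration. It assumes a recorder that runs the loop body once with a placeholder value and then replays only the operations it knows how to record, so a call into another library happens once at trace time rather than once per iteration.

```python
# Toy model of loop recording vs. unrolling. All names here are
# hypothetical; this is NOT how Dr.Jit is implemented internally.

calls = []  # one entry per invocation of the "external" function


def external_fn(value):
    """Stand-in for a call into another library (e.g. a Torch model)."""
    calls.append(value)
    return 1.0


class Symbol:
    """Placeholder passed to the body while tracing; records additions."""

    def __init__(self, trace):
        self.trace = trace

    def __add__(self, other):
        self.trace.append(("add", other))
        return self


def body(acc):
    # The loop body mixes a recordable op (+) with an external call.
    return acc + external_fn(acc)


def run_recorded(body, start, n_iters):
    """Trace the body once, then replay only the recorded operations."""
    trace = []
    body(Symbol(trace))  # the external call happens here, exactly once
    acc = start
    for _ in range(n_iters):
        for op, operand in trace:
            if op == "add":
                acc = acc + operand  # operand was frozen at trace time
    return acc


def run_unrolled(body, start, n_iters):
    """Plain Python loop: the body really executes on every iteration."""
    acc = start
    for _ in range(n_iters):
        acc = body(acc)
    return acc
```

In this toy, `run_recorded` invokes `external_fn` a single time and hands it a placeholder rather than real data, while `run_unrolled` invokes it on every iteration with actual values. That mirrors why the Python `for` loop in the MWE behaves as expected and the recorded loop does not.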