pytorch/extension-cpp

How do layers built with C++ extensions translate to TorchScript or ONNX?

yanglinxiabuaaa opened this issue · 1 comment

Test code:

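import torch
from LLTM import LLTM  # assumed: the LLTM.py module (see traceback) that wraps the C++ extension

batch_size = 16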
input_features = 32
state_size = 128

X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)

rnn = LLTM(input_features, state_size)

inputs = (X, (h, C))

traced = torch.jit.trace(rnn, inputs)
print(traced.graph)
torch.jit.save(traced, "lltm.pt")

graph(%self : torch.torch.nn.modules.module.Module,
%input : Float(16, 32),
%5 : (Float(16, 128), Float(16, 128))):
%39 : Tensor = prim::GetAttr[name="bias"](%self)
%38 : Tensor = prim::GetAttr[name="weights"](%self)
%old_h : Float(16, 128), %old_cell : Float(16, 128) = prim::TupleUnpack(%5)
%34 : (Tensor, Tensor) = ^LLTMFunction()(%input, %38, %39, %old_h, %old_cell) # /workspace/yanglinxia/CenterNet/torchscript/lltm-extension/LLTM.py:42:0
%35 : Float(16, 128), %36 : Float(16, 128) = prim::TupleUnpack(%34)
%37 : (Float(16, 128), Float(16, 128)) = prim::TupleConstruct(%35, %36)
return (%37)
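The ^LLTMFunction() node in the graph above is how the tracer records a call to a Python torch.autograd.Function: it becomes an opaque prim::PythonOp node whose implementation stays in Python, so the serializer has nothing it can write into lltm.pt. The following is a minimal, self-contained sketch that reproduces this without the C++ extension. Double, Wrapper, python_op_kinds and save_succeeds are hypothetical names; Double only stands in for LLTMFunction.

```python
import io

import torch


class Double(torch.autograd.Function):
    # Hypothetical stand-in for LLTMFunction: any Python-level autograd.Function
    # is recorded by the tracer as an opaque prim::PythonOp node.
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


class Wrapper(torch.nn.Module):
    def forward(self, x):
        return Double.apply(x)


def python_op_kinds(graph):
    """Collect the prim::PythonOp nodes of a graph; torch.jit.save cannot serialize them."""
    return [node.kind() for node in graph.nodes() if node.kind() == "prim::PythonOp"]


def save_succeeds(module):
    """Try to serialize the module into memory; return False if the exporter rejects it."""
    try:
        torch.jit.save(module, io.BytesIO())
    except (RuntimeError, torch.jit.Error):
        return False
    return True


traced = torch.jit.trace(Wrapper(), torch.randn(3))
print(python_op_kinds(traced.graph))  # the Python-op nodes found in the traced graph
print(save_succeeds(traced))          # whether the traced module can be saved
```

Scanning the graph for prim::PythonOp before calling torch.jit.save is a cheap way to find which calls still need to be replaced.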

Traceback (most recent call last):
  File "script.py", line 22, in <module>
    torch.jit.save(traced, "lltm.pt")
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 153, in save
    m.save(f, _extra_files=_extra_files)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py", line 1626, in save
    return self._c.save(*args, **kwargs)
RuntimeError:
Could not export Python function call 'LLTMFunction'. Remove calls to Python functions before export. Did you forget add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/workspace/yanglinxia/CenterNet/torchscript/lltm-extension/LLTM.py(42): forward
/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py(516): _slow_forward
/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py(530): __call__
/usr/local/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py(1034): trace_module
/usr/local/anaconda3/lib/python3.6/site-packages/torch/jit/__init__.py(882): trace
script.py(19): <module>
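One workaround, if giving up the custom C++ kernel for the exported model is acceptable, is to express the same LLTM forward pass with native PyTorch ops so that the traced graph contains no Python op. The sketch below assumes the extension computes the same math as the pure-Python LLTM from the C++ extension tutorial (concatenate [old_h, input], one linear layer, three gates); LLTMNative and roundtrip are hypothetical names. Keeping the C++ kernel itself in a saved TorchScript model generally requires registering it as a TorchScript custom operator (for example with TORCH_LIBRARY) and calling it through torch.ops instead of through a Python autograd.Function; ONNX export likewise needs a symbolic mapping for the op, for example via torch.onnx.register_custom_op_symbolic.

```python
import io

import torch
import torch.nn.functional as F


class LLTMNative(torch.nn.Module):
    """Hypothetical pure-PyTorch LLTM: the extension's math written with native ops."""

    def __init__(self, input_features, state_size):
        super().__init__()
        self.state_size = state_size
        # 3 * state_size rows: input gate, output gate and candidate cell.
        self.weights = torch.nn.Parameter(
            torch.empty(3 * state_size, input_features + state_size))
        self.bias = torch.nn.Parameter(torch.empty(3 * state_size))
        stdv = 1.0 / state_size ** 0.5
        with torch.no_grad():
            for p in self.parameters():
                p.uniform_(-stdv, stdv)

    def forward(self, input, state):
        old_h, old_cell = state
        X = torch.cat([old_h, input], dim=1)
        gates = F.linear(X, self.weights, self.bias).chunk(3, dim=1)
        input_gate = torch.sigmoid(gates[0])
        output_gate = torch.sigmoid(gates[1])
        candidate_cell = F.elu(gates[2])
        new_cell = old_cell + candidate_cell * input_gate
        new_h = torch.tanh(new_cell) * output_gate
        return new_h, new_cell


def roundtrip(module):
    """Serialize a traced/scripted module into memory and load it back."""
    buf = io.BytesIO()
    torch.jit.save(module, buf)
    buf.seek(0)
    return torch.jit.load(buf)


batch_size, input_features, state_size = 16, 32, 128
X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)

rnn = LLTMNative(input_features, state_size)
traced = torch.jit.trace(rnn, (X, (h, C)))
loaded = roundtrip(traced)  # only aten ops in the graph, so saving is possible
new_h, new_cell = loaded(X, (h, C))
```

The trade-off is that the exported model runs PyTorch's built-in kernels rather than the hand-written C++ ones, so any speedup from the extension is not carried over.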