Export torchscript models to C++
Deams51 opened this issue · 10 comments
I've been testing the PPO implementation, and it doesn't seem to be currently possible to export a model as a C++ compatible module.
Is this something you are planning?
If not, I could give it a go, though I would appreciate any pointers.
Could you provide a link to an example of what you want?
I'll try to write some demo code for you.
According to the PyTorch docs, a model must first be converted to TorchScript before it can be used from C++: Example
As of now, rl_games simply saves the pickled version of the model:
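A minimal sketch of that conversion step, assuming a small stand-in network (`TinyPolicy` and `export_torchscript` are hypothetical, not rl_games code): trace the model with an example input, save the TorchScript archive, and reload it to confirm it gives the same outputs. The saved file is what a LibTorch C++ program would open with `torch::jit::load`.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyPolicy(nn.Module):
    """Hypothetical stand-in policy: 4-dim CartPole observation -> 2 action logits."""

    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.body(obs)


def export_torchscript(model: nn.Module, example: torch.Tensor, path: str):
    """Trace `model` on `example` and write the TorchScript archive to `path`."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
    traced.save(path)  # a C++ program can open this file with torch::jit::load(path)
    return traced


if __name__ == "__main__":
    model = TinyPolicy()
    path = os.path.join(tempfile.mkdtemp(), "policy.pt")
    export_torchscript(model, torch.zeros((1, 4)), path)
    loaded = torch.jit.load(path)
    obs = torch.randn(3, 4)  # batch size differs from the tracing example
    print(torch.allclose(loaded(obs), model(obs)))
```

Tracing only records the tensor operations executed for the example input, so it works here because the forward pass has no data-dependent control flow.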
I've tried to trace or script a model:
```python
script_module = torch.jit.script(model)
traced_module = torch.jit.trace(model, example_input)
```
Both fail, because TorchScript cannot properly infer all the types used, and because forward uses unsupported types such as dicts: TorchScript supported types
Thanks,
I tested it some time ago and torch.jit.trace worked for me, but unfortunately I didn't see any performance improvement.
I'll make a small demo for you tomorrow.
Btw, some time ago I updated PyTorch to 1.9.
You are right, dicts don't work.
Looks like we need to use something like this:
https://github.com/facebookresearch/detectron2/blob/main/detectron2/export/flatten.py#L185
Will try it.
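The idea behind such an adapter can be sketched by hand, assuming hypothetical stand-ins `DictNet` and `FlatObsWrapper` (the real rl_games network's keys and outputs may differ): wrap the dict-based network in a module whose `forward` takes only tensors and returns a tuple of tensors, so `torch.jit.trace` never sees a dict or a `None` at the module boundary.

```python
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn


class DictNet(nn.Module):
    """Hypothetical network with an rl_games-style dict input: {'obs': ..., 'rnn_states': ...}."""

    def __init__(self, obs_dim: int = 4, num_actions: int = 2) -> None:
        super().__init__()
        self.actor = nn.Linear(obs_dim, num_actions)
        self.critic = nn.Linear(obs_dim, 1)

    def forward(self, input_dict: Dict[str, Optional[torch.Tensor]]):
        obs = input_dict["obs"]
        # Returns logits, value and the (absent) RNN states: a mix of tensors and None.
        return self.actor(obs), self.critic(obs), input_dict["rnn_states"]


class FlatObsWrapper(nn.Module):
    """Tensor-in, tensor-tuple-out wrapper that hides the dict from the tracer."""

    def __init__(self, net: nn.Module) -> None:
        super().__init__()
        self.net = net

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # The dict is built inside forward, which tracing handles as plain Python.
        logits, value, _ = self.net({"obs": obs, "rnn_states": None})
        # Drop the None entry so every traced output is a tensor.
        return logits, value


if __name__ == "__main__":
    wrapper = FlatObsWrapper(DictNet()).eval()
    with torch.no_grad():
        traced = torch.jit.trace(wrapper, torch.zeros((1, 4)), check_trace=True)
    logits, value = traced(torch.randn(5, 4))
    print(logits.shape, value.shape)
```

A generic adapter like detectron2's automates the same flatten/unflatten step for arbitrary nested inputs; the hand-written wrapper is only practical when the input structure is fixed and small.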
Yeah, I didn't see much performance gain from tracing simple models either, though my experience with it is limited.
However, I did see good inference-time gains when running traced PyTorch models through ONNX Runtime. It seems to be becoming the standard among runtime frameworks.
@Deams51 please take a look at the DM/test_trace branch.
I made a test that works with CartPole:
```
python runner.py --train --file rl_games/configs/ppo_cartpole.yaml
```
I decided to trace the neural network directly, because the model returns randomly sampled actions and therefore didn't pass check_trace=True.
Here is a code example:
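```python
# Method of the agent class; assumes the module already imports torch
# and torch_ext (rl_games.algos_torch.torch_ext).
def save(self, fn):
    import rl_games.algos_torch.flatten as flatten

    # Example input for CartPole: a batch of one 4-dimensional observation.
    inputs = {
        'obs': torch.zeros((1, 4)).to(self.device),
        'rnn_states': None,
    }
    with torch.no_grad():
        # TracingAdapter flattens the dict input into a tuple of tensors;
        # allow_non_tensor=True filters out non-tensor entries such as the None rnn_states.
        adapter = flatten.TracingAdapter(self.model.a2c_network, inputs, allow_non_tensor=True)
        # Trace the network itself, not self.model, which samples random actions.
        traced = torch.jit.trace(adapter, adapter.flattened_inputs, check_trace=True)
        flattened_outputs = traced(*adapter.flattened_inputs)
        print(flattened_outputs)
    # The regular checkpoint is still saved as before.
    state = self.get_full_state_weights()
    torch_ext.save_checkpoint(fn, state)
```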
I tried to run the test locally, but it failed during the trace check: error log
```
First diverging operator:
Node diff:
- %2 : __torch__.rl_games.algos_torch.network_builder.Network = prim::GetAttr[name="model"](%self.1)
+ %2 : __torch__.rl_games.algos_torch.network_builder.___torch_mangle_26.Network = prim::GetAttr[name="model"](%self.1)
```
I will reset my environment to rule out a PyTorch version issue.
Same issue with a fresh conda environment; here is the dump: [conda_env_rl_games.yml](https://github.com/Denys88/rl_games/files/7244417/conda_env_rl_games.txt)
Out of curiosity, I ran it without the check to get a look at the model:
The graph looks strange. I'll double-check it with CartPole again, and will also try to export a continuous action space network.
If you export self.model directly, it should not pass the trace check, because it randomly samples actions. For self.model.a2c_network it should be fine, since there are no random operations inside the graph.
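The difference between the two cases can be reproduced with hypothetical stand-ins (`DeterministicNet` and `SamplingModel` are not rl_games classes, and `mu + randn_like(mu)` is an assumed stand-in for the real action sampling). `check_trace=True` re-runs both the traced module and the original Python module on the example input and compares their outputs; with random sampling they differ, which PyTorch reports as `torch.jit.TracerWarning`s, while the deterministic network traces silently. This is a sketch of the output-comparison part of the check only; a graph mismatch like the one in the error log above is a separate failure that raises instead.

```python
import warnings

import torch
import torch.nn as nn


class DeterministicNet(nn.Module):
    """Stand-in for a2c_network: maps observations to action means, no randomness."""

    def __init__(self, obs_dim: int = 4, act_dim: int = 2) -> None:
        super().__init__()
        self.mu = nn.Linear(obs_dim, act_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.mu(obs)


class SamplingModel(nn.Module):
    """Stand-in for the full model: samples a Gaussian action around the mean."""

    def __init__(self, net: nn.Module) -> None:
        super().__init__()
        self.net = net

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        mu = self.net(obs)
        # Random sampling: every call returns a different action.
        return mu + torch.randn_like(mu)


def tracer_warnings(module: nn.Module, example: torch.Tensor) -> list:
    """Trace with check_trace=True and collect the TracerWarnings it emits."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        torch.jit.trace(module, example, check_trace=True)
    return [w for w in caught if issubclass(w.category, torch.jit.TracerWarning)]


if __name__ == "__main__":
    torch.manual_seed(0)
    net = DeterministicNet()
    obs = torch.zeros((1, 4))
    print("network:", len(tracer_warnings(net, obs)), "warnings")
    print("sampling model:", len(tracer_warnings(SamplingModel(net), obs)), "warnings")
```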