Denys88/rl_games

Export torchscript models to C++

Deams51 opened this issue · 10 comments

I've been testing the PPO implementation, and it doesn't seem currently possible to export a model as a C++-compatible module.
Is it something you are planning?

If not, I could try to give it a go, though would appreciate it if you have any pointers.

Could you provide a link with an example of what you want?
I'll try to write some demo code for you.

From the PyTorch docs, to use a model in C++ it first needs to be converted to TorchScript: Example
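For reference, that workflow can be sketched roughly like this (the `Policy` module here is a made-up stand-in, not an rl_games class):

```python
import torch

# Hypothetical tiny policy network, standing in for an rl_games model.
class Policy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 2)

    def forward(self, obs):
        return self.net(obs)

model = Policy().eval()
example = torch.zeros(1, 4)

# trace() records the operations executed on the example input.
traced = torch.jit.trace(model, example)

# The serialized module can then be loaded from C++ via torch::jit::load("policy.pt").
traced.save("policy.pt")
```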

As of now rl_games is simply dumping the pickle version of the model:

def save_checkpoint(filename, state):

I've tried to trace or script a model:

script_module = torch.jit.script(model)
traced_module = torch.jit.trace(model, example_input)

This fails because torch cannot properly infer all the types used, and because types unsupported by TorchScript (such as dicts) are used during forward: TorchScript supported types

Thanks,
I tested it some time ago and torch.jit.trace worked for me. But unfortunately I didn't see any performance improvements.
I'll make a small demo for you tomorrow.
Btw, some time ago I updated PyTorch to 1.9.

You are right, dict doesn't work.
Looks like we need to use something like this:
https://github.com/facebookresearch/detectron2/blob/main/detectron2/export/flatten.py#L185
I'll try it.
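The flattening idea in that file can be illustrated roughly like this (a pure-Python sketch, not detectron2's actual code): nested dict/list inputs are broken down into a flat list of leaves that tracing can handle, plus a schema for rebuilding the original structure inside the wrapper's forward.

```python
# Flatten nested dicts/lists/tuples into a leaf list plus a schema. Tracing only
# accepts (tuples of) tensors, so structured inputs must be flattened first and
# rebuilt inside the traced wrapper.
def flatten(obj):
    if isinstance(obj, dict):
        leaves, spec = [], []
        for key in sorted(obj):                  # fixed order so replay matches
            sub_leaves, sub_spec = flatten(obj[key])
            leaves += sub_leaves
            spec.append((key, sub_spec))
        return leaves, ('dict', spec)
    if isinstance(obj, (list, tuple)):
        leaves, spec = [], []
        for item in obj:
            sub_leaves, sub_spec = flatten(item)
            leaves += sub_leaves
            spec.append(sub_spec)
        return leaves, (type(obj).__name__, spec)
    return [obj], ('leaf', None)                 # tensors (or None) are leaves

def unflatten(leaves, schema):
    kind, spec = schema
    if kind == 'leaf':
        return leaves.pop(0)
    if kind == 'dict':
        return {key: unflatten(leaves, sub) for key, sub in spec}
    items = [unflatten(leaves, sub) for sub in spec]
    return tuple(items) if kind == 'tuple' else items

inputs = {'obs': 'obs_tensor', 'rnn_states': None}
leaves, schema = flatten(inputs)
assert unflatten(list(leaves), schema) == inputs
```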

Yeah, I didn't see much performance gain from tracing simple models either, though my experience with it is limited.
But I did see good gains in inference time when using traced PyTorch models via the ONNX Runtime. It seems to be becoming the standard runtime framework.

@Deams51 please take a look into the DM/test_trace branch:
I made a test which works with CartPole:
python runner.py --train --file rl_games/configs/ppo_cartpole.yaml

I decided to use the neural network directly, because the model returns random actions and didn't pass check_trace=True.
Here is a code example:

```python
def save(self, fn):
    import rl_games.algos_torch.flatten as flatten
    inputs = {
        'obs': torch.zeros((1, 4)).to(self.device),
        'rnn_states': None
    }
    with torch.no_grad():
        adapter = flatten.TracingAdapter(self.model.a2c_network, inputs, allow_non_tensor=True)
        traced = torch.jit.trace(adapter, adapter.flattened_inputs, check_trace=True)
        flattened_outputs = traced(*adapter.flattened_inputs)
        print(flattened_outputs)

    state = self.get_full_state_weights()
    torch_ext.save_checkpoint(fn, state)
```
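As a sanity check on the tracing step, a traced module saved to disk should reload with identical outputs (a standalone sketch with a stand-in network, not rl_games code):

```python
import torch

# Stand-in network; in the real save() this would be the traced adapter.
net = torch.nn.Sequential(torch.nn.Linear(4, 2)).eval()
example = torch.zeros(1, 4)

traced = torch.jit.trace(net, example)
traced.save("a2c_network.pt")            # the same file can be opened from C++
reloaded = torch.jit.load("a2c_network.pt")

# Round-trip check: saved and reloaded modules agree on the example input.
assert torch.allclose(traced(example), reloaded(example))
```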

I tried to run the test locally, but it failed during the check: error log

First diverging operator:
Node diff:
  - %2 : __torch__.rl_games.algos_torch.network_builder.Network = prim::GetAttr[name="model"](%self.1)
  + %2 : __torch__.rl_games.algos_torch.network_builder.___torch_mangle_26.Network = prim::GetAttr[name="model"](%self.1) 

I will reset my environment to rule out a version issue with PyTorch.

Same issue with a fresh conda environment, here is the dump: [conda_env_rl_games.yml](https://github.com/Denys88/rl_games/files/7244417/conda_env_rl_games.txt)
Out of curiosity, I ran it without the check to get a look at the model: cartpole_model

The graph looks strange. I'll double-check it with CartPole again, and will also try to export a continuous-space network.
If you export self.model directly it should not pass check_trace, because it randomly samples actions. But for self.model.a2c_network it should be fine: there are no random operations inside the graph.
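A minimal illustration of why sampling breaks the check (the `StochasticPolicy` module here is hypothetical): the trace itself succeeds, but two runs on the same input disagree, which is exactly what check_trace=True compares.

```python
import torch

# Hypothetical stochastic head: noise added in forward() means two runs on the
# same input disagree, so the check_trace re-run comparison would fail.
class StochasticPolicy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mu = torch.nn.Linear(4, 2)

    def forward(self, obs):
        mu = self.mu(obs)
        return mu + torch.randn_like(mu)   # random action sampling

model = StochasticPolicy().eval()
obs = torch.zeros(1, 4)

# check_trace=False skips the comparison that the sampling would fail.
traced = torch.jit.trace(model, obs, check_trace=False)
print(torch.equal(traced(obs), traced(obs)))  # almost surely False
```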

@Deams51 unfortunately I didn't find the reason why it doesn't pass check_trace. You can try using it with check_trace=False.
It works for me.