araffin/sbx

[Question] TypeError when exporting a model to PyTorch in SBX

LennertEvens opened this issue · 3 comments

๐Ÿ› Bug

When using PyTorch JIT to trace and save a model trained with SBX, an exception occurs.

To Reproduce

The following code works fine for a TD3 model trained with SB3. However, a TypeError occurs when tracing a model trained with SBX.

from typing import Tuple

import torch as th
from stable_baselines3.common.policies import BasePolicy

from sbx import TD3


class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        return self.policy(observation, deterministic=True)

jit_path = "model.pt"

cuda_id = th.cuda.current_device()
model = TD3.load("model", device=cuda_id)
onnxable_model = OnnxableSB3Policy(model.policy)

observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size).to(device=cuda_id)

# Trace and optimize the module
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
frozen_module = th.jit.freeze(traced_module)
frozen_module = th.jit.optimize_for_inference(frozen_module)
th.jit.save(frozen_module, jit_path)

Traceback (most recent call last):
  File "/home/.venv/lib/python3.12/site-packages/jax/_src/api_util.py", line 584, in shaped_abstractify
    return _shaped_abstractify_handlers[type(x)](x)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^
KeyError: <class 'torch.Tensor'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/export_to_pt.py", line 33, in <module>
    traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/jit/_trace.py", line 806, in trace
    return trace_module(
           ^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/jit/_trace.py", line 1074, in trace_module
    module._c._create_method_from_trace(
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/export_to_pt.py", line 21, in forward
    return self.policy(observation, deterministic=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/sbx/td3/policies.py", line 178, in forward
    return self._predict(obs, deterministic=deterministic)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.venv/lib/python3.12/site-packages/sbx/td3/policies.py", line 187, in _predict
    return TD3Policy.select_action(self.actor_state, observation)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot interpret 'torch.float32' as a data type
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

System Info

- OS: Linux-4.18.0-513.11.1.el8_9.0.1.x86_64-x86_64-with-glibc2.28 #1 SMP Sun Feb 11 10:42:18 UTC 2024
- Python: 3.12.1
- Stable-Baselines3: 2.3.0a2
- PyTorch: 2.2.1+cu121
- GPU Enabled: True
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
- OpenAI Gym: 0.26.2

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Hello,
Why would you trace a Jax module with PyTorch?

The goal is to use the traced model for inference in C/C++ applications. The significant speedup during training is a huge advantage of SBX over SB3.

Then you need to use ONNX export from JAX.
Apparently, you first need to convert the model to TensorFlow:
google/jax#7629 (comment)
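
For reference, a rough, untested sketch of that route. The SBX attribute names below (model.policy.actor_state.apply_fn / .params) and the jax2tf/tf2onnx options are assumptions that may need adjusting for your SBX, JAX, and tf2onnx versions:

import tensorflow as tf
import tf2onnx
from jax.experimental import jax2tf

from sbx import TD3

model = TD3.load("model")
obs_shape = model.observation_space.shape

# Wrap the flax actor; apply_fn(params, obs) mirrors what TD3Policy.select_action does (assumed)
def jax_actor(obs):
    state = model.policy.actor_state
    return state.apply_fn(state.params, obs)

# JAX -> TF (enable_xla=False keeps ops that tf2onnx can handle), then TF -> ONNX
input_signature = [tf.TensorSpec((1, *obs_shape), tf.float32)]
tf_actor = tf.function(
    jax2tf.convert(jax_actor, enable_xla=False),
    input_signature=input_signature,
    autograph=False,
)
tf2onnx.convert.from_function(
    tf_actor,
    input_signature=input_signature,
    output_path="model.onnx",
)

The resulting model.onnx can then be loaded from C/C++ with ONNX Runtime.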

Otherwise, you need to manually re-create the policy architecture in PyTorch and load the exported weights into it.
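
A minimal sketch of that second route, assuming the default SBX TD3 actor is an MLP with [256, 256] hidden layers, ReLU activations, a tanh output, and flax parameters named Dense_0, Dense_1, ...; check these assumptions against your trained policy before relying on it:

import numpy as np
import torch as th

from sbx import TD3

model = TD3.load("model")
obs_dim = int(np.prod(model.observation_space.shape))
action_dim = int(np.prod(model.action_space.shape))

# Re-create the (assumed) actor architecture in plain PyTorch
torch_actor = th.nn.Sequential(
    th.nn.Linear(obs_dim, 256), th.nn.ReLU(),
    th.nn.Linear(256, 256), th.nn.ReLU(),
    th.nn.Linear(256, action_dim), th.nn.Tanh(),
)

# Copy the flax weights; flax Dense kernels are (in, out), torch Linear weights are (out, in)
flax_params = model.policy.actor_state.params["params"]  # assumed parameter layout
linear_layers = [m for m in torch_actor if isinstance(m, th.nn.Linear)]
with th.no_grad():
    for i, layer in enumerate(linear_layers):
        dense = flax_params[f"Dense_{i}"]
        layer.weight.copy_(th.as_tensor(np.asarray(dense["kernel"]).T))
        layer.bias.copy_(th.as_tensor(np.asarray(dense["bias"])))

# The plain PyTorch module can then be traced and saved as in the original snippet
traced = th.jit.trace(torch_actor.eval(), th.randn(1, obs_dim))
th.jit.save(traced, "model.pt")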