[Question] TypeError when exporting a model to PyTorch in SBX
LennertEvens opened this issue · 3 comments
🐛 Bug
Using PyTorch JIT to trace and save a model trained with SBX raises an exception.
To Reproduce
The following code works for a TD3 model trained with SB3. With a model trained with SBX, however, it raises a TypeError.
from typing import Tuple

import torch as th
from sbx import TD3
from stable_baselines3.common.policies import BasePolicy


class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        return self.policy(observation, deterministic=True)


jit_path = "model.pt"
cuda_id = th.cuda.current_device()
model = TD3.load("model", device=cuda_id)
onnxable_model = OnnxableSB3Policy(model.policy)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size).to(device=cuda_id)

# Trace and optimize the module
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
frozen_module = th.jit.freeze(traced_module)
frozen_module = th.jit.optimize_for_inference(frozen_module)
th.jit.save(frozen_module, jit_path)
Traceback (most recent call last):
File "/home/.venv/lib/python3.12/site-packages/jax/_src/api_util.py", line 584, in shaped_abstractify
return _shaped_abstractify_handlers[type(x)](x)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^
KeyError: <class 'torch.Tensor'>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/export_to_pt.py", line 33, in <module>
traced_module = th.jit.trace(onnxable_model.eval(), dummy_input)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.venv/lib/python3.12/site-packages/torch/jit/_trace.py", line 806, in trace
return trace_module(
^^^^^^^^^^^^^
File "/home/.venv/lib/python3.12/site-packages/torch/jit/_trace.py", line 1074, in trace_module
module._c._create_method_from_trace(
File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/export_to_pt.py", line 21, in forward
return self.policy(observation, deterministic=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.venv/lib/python3.12/site-packages/sbx/td3/policies.py", line 178, in forward
return self._predict(obs, deterministic=deterministic)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/.venv/lib/python3.12/site-packages/sbx/td3/policies.py", line 187, in _predict
return TD3Policy.select_action(self.actor_state, observation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot interpret 'torch.float32' as a data type
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
### System Info
- OS: Linux-4.18.0-513.11.1.el8_9.0.1.x86_64-x86_64-with-glibc2.28 # 1 SMP Sun Feb 11 10:42:18 UTC 2024
- Python: 3.12.1
- Stable-Baselines3: 2.3.0a2
- PyTorch: 2.2.1+cu121
- GPU Enabled: True
- Numpy: 1.26.4
- Cloudpickle: 3.0.0
- Gymnasium: 0.29.1
- OpenAI Gym: 0.26.2
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
Hello,
Why would you trace a Jax module with PyTorch?
The goal is to use the traced model for inference in C/C++ applications. The significant speedup during training is a huge advantage of SBX over SB3.
Then you need to use ONNX with Jax.
Apparently, you need to convert it first to TF:
google/jax#7629 (comment)
Otherwise, you need to manually re-create the policy architecture in PyTorch and load the exported weights into it.
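A minimal sketch of that second option, under stated assumptions: the actor is taken to be a plain MLP (Dense, ReLU, Dense, ReLU, Dense, tanh), and the exported weights are taken to be a dict of NumPy arrays keyed with Flax's default `Dense_0`, `Dense_1`, ... layer names. `TorchActor`, `load_flax_dense_params`, and the random `fake_params` are hypothetical stand-ins, not SBX API; check the real architecture and parameter tree of your trained model before relying on this.

```python
import numpy as np
import torch as th


class TorchActor(th.nn.Module):
    """PyTorch copy of an ASSUMED actor: Dense -> ReLU -> Dense -> ReLU -> Dense -> tanh."""

    def __init__(self, obs_dim: int, action_dim: int, hidden: int = 256):
        super().__init__()
        self.net = th.nn.Sequential(
            th.nn.Linear(obs_dim, hidden),
            th.nn.ReLU(),
            th.nn.Linear(hidden, hidden),
            th.nn.ReLU(),
            th.nn.Linear(hidden, action_dim),
            th.nn.Tanh(),
        )

    def forward(self, observation: th.Tensor) -> th.Tensor:
        return self.net(observation)


def load_flax_dense_params(actor: TorchActor, params: dict) -> None:
    """Copy {"Dense_i": {"kernel", "bias"}} arrays into the Linear layers, in order."""
    linears = [m for m in actor.net if isinstance(m, th.nn.Linear)]
    with th.no_grad():
        for i, linear in enumerate(linears):
            kernel = np.asarray(params[f"Dense_{i}"]["kernel"], dtype=np.float32)
            bias = np.asarray(params[f"Dense_{i}"]["bias"], dtype=np.float32)
            # Flax Dense kernels are (in, out); torch.nn.Linear weights are (out, in).
            linear.weight.copy_(th.from_numpy(np.ascontiguousarray(kernel.T)))
            linear.bias.copy_(th.from_numpy(bias))


# Random stand-in for the weights exported from the trained model (hypothetical data).
rng = np.random.default_rng(0)
obs_dim, action_dim, hidden = 3, 1, 8
sizes = [(obs_dim, hidden), (hidden, hidden), (hidden, action_dim)]
fake_params = {
    f"Dense_{i}": {
        "kernel": rng.standard_normal(size).astype(np.float32),
        "bias": rng.standard_normal(size[1]).astype(np.float32),
    }
    for i, size in enumerate(sizes)
}

actor = TorchActor(obs_dim, action_dim, hidden).eval()
load_flax_dense_params(actor, fake_params)

# The module is now pure PyTorch, so tracing no longer passes through JAX.
dummy_input = th.randn(1, obs_dim)
traced_module = th.jit.trace(actor, dummy_input)
```

From here, `traced_module` can go through the same `th.jit.freeze`, `th.jit.optimize_for_inference`, and `th.jit.save` steps as in the original script. It is worth comparing its output against the SBX policy on a few observations before using it from C/C++, since a mismatch in layer order, activation, or kernel orientation would fail silently.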