DLR-RM/stable-baselines3

[Feature Request] torch compile / integrating intel extension for pytorch

george-adams1 opened this issue · 6 comments

🚀 Feature

Request to integrate the Intel extension for PyTorch into SB3. The extension optimizes PyTorch to make better use of the computational capabilities of Intel processors. With this integration, SB3 users running on Intel processors could see significant performance improvements.

Motivation

The goal is to get the most out of Intel processors. This could lead to faster training and more efficient resource use, improving the overall experience for SB3 users on Intel hardware.

Pitch

The integration of the Intel extension for PyTorch into SB3 would involve modifying the library to utilize the extension when it detects that it is running on an Intel processor.
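One possible shape for that detection logic is sketched here. It is a minimal sketch, not SB3 code: the helper names `is_intel_cpu` and `should_use_ipex` are hypothetical, and it assumes Linux (where /proc/cpuinfo reports a `vendor_id` field) and that the extension is installed as the `intel_extension_for_pytorch` package.

```python
import importlib.util
from pathlib import Path


def is_intel_cpu(cpuinfo_text: str) -> bool:
    """Return True if a /proc/cpuinfo dump reports an Intel vendor (Linux only)."""
    for line in cpuinfo_text.splitlines():
        # Lines look like "vendor_id\t: GenuineIntel"; split on the first colon
        key, _, value = line.partition(":")
        if key.strip() == "vendor_id":
            return value.strip() == "GenuineIntel"
    return False


def should_use_ipex(cpuinfo_path: str = "/proc/cpuinfo") -> bool:
    """Hypothetical check: Intel CPU detected AND the extension is importable."""
    path = Path(cpuinfo_path)
    if not path.exists():
        # Non-Linux systems (or unreadable info): stay conservative, do nothing
        return False
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        # Extension not installed: nothing to integrate
        return False
    return is_intel_cpu(path.read_text())
```

Falling back to `False` whenever anything is unknown keeps the default code path unchanged for users who do not have the extension or an Intel CPU.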

Alternatives

An alternative would be to provide guidelines on how users can manually integrate the Intel extension into their SB3 setups.

Additional context

No response

Checklist

  • I have checked that there is no similar issue in the repo

Hello,
could you give some pointers or an example of how to integrate that extension?

Would you be willing to contribute such an integration?

I believe it was a reference to this project: https://github.com/intel/intel-extension-for-pytorch

The latest version is compatible with torch.compile, which makes it much easier to leverage its optimizations without completely reworking the libraries that rely on PyTorch. Here is the example script. However, it is described here as an inference-only backend.

Thanks for the pointers @duburcqa =)!

I see. In that case, it is already supported, since SB3 exposes its PyTorch policy via the .policy attribute (it is an nn.Module, see doc).

Related: #1391 and #1439

You should also know that the bottleneck at training time is the gradient update, so optimizing inference won't help you much (except at test time). If you need a significant speed-up, we do have an experimental JAX version (SBX), see README.

I would actually welcome a PR that shows how to use th.compile in our doc =)