[Enhancement] Add `optree` integration to `etils.etree`
XuehaiPan opened this issue · 1 comments
XuehaiPan commented
`optree` is a standalone package (like `dm-tree`) aimed at high-performance PyTree manipulation (like `jax.tree_util`). It offers APIs similar to `jax.tree_util`, but runs faster in the benchmarks below.
Some initial benchmark results:
| Average Time Cost (↓) | OpTree (v0.9.0) | JAX XLA (v0.4.6) | PyTorch (v2.0.0) | TensorFlow Nest (v2.12.0) | DM-Tree (v0.1.8) |
|---|---|---|---|---|---|
| Tree Flatten | x1.00 | 2.33 | 22.05 | 1.38 | 1.12 |
| Tree UnFlatten | x1.00 | 2.69 | 4.28 | 13.69 | 16.23 |
| Tree Flatten with Path | x1.00 | 16.16 | Not Supported | 21.10 | 27.59 |
| Tree Copy | x1.00 | 2.56 | 9.97 | 9.62 | 11.02 |
| Tree Map | x1.00 | 2.56 | 9.58 | 9.16 | 10.62 |
| Tree Map (nargs) | x1.00 | 2.89 | Not Supported | 74.26 | 31.33 |
| Tree Map with Path | x1.00 | 7.23 | Not Supported | 40.78 | 19.66 |
| Tree Map with Path (nargs) | x1.00 | 6.56 | Not Supported | 69.63 | 29.61 |
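To make the row names concrete, here is a pure-Python sketch of what "flatten with path" and "map (nargs)" mean. It is an illustration only, not how any of the benchmarked libraries implement these operations. The function names `flatten_with_path` and `tree_map_nargs` are hypothetical, and only `dict`, `list`, and `tuple` containers are handled.

```python
from typing import Any, Callable


def flatten_with_path(tree: Any, path: tuple = ()) -> list:
    """Return (path, leaf) pairs depth-first; dict keys are visited in sorted order."""
    if isinstance(tree, dict):
        return [
            pair
            for key in sorted(tree)
            for pair in flatten_with_path(tree[key], path + (key,))
        ]
    if isinstance(tree, (list, tuple)):
        return [
            pair
            for index, child in enumerate(tree)
            for pair in flatten_with_path(child, path + (index,))
        ]
    # Anything that is not a known container is treated as a leaf.
    return [(path, tree)]


def tree_map_nargs(func: Callable, tree: Any, *rests: Any) -> Any:
    """Apply func to corresponding leaves of several identically structured trees."""
    if isinstance(tree, dict):
        return {
            key: tree_map_nargs(func, tree[key], *(rest[key] for rest in rests))
            for key in tree
        }
    if isinstance(tree, (list, tuple)):
        # zip walks the children of all trees in lockstep.
        return type(tree)(
            tree_map_nargs(func, *children) for children in zip(tree, *rests)
        )
    return func(tree, *rests)


pairs = flatten_with_path({"b": [1, 2], "a": (3,)})
total = tree_map_nargs(lambda x, y: x + y, {"a": [1, 2]}, {"a": [10, 20]})
print(pairs)
print(total)
```

The "with Path" rows measure the first kind of traversal, where each leaf is paired with the keys and indices leading to it. The "(nargs)" rows measure the second, where one function is applied across several trees at once.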
We have already seen some `etils` folks get involved in the `optree` and `jax.tree_util` discussions. I wonder whether the `etils` maintainers would be interested in adding `optree` to `etils.etree`.
Ref:
Conchylicultor commented
Good idea. Let me try this.