google/etils

[Enhancement] Add `optree` integration to `etils.etree`

XuehaiPan opened this issue · 1 comment

optree is a standalone package (like dm-tree) aimed at high-performance PyTree manipulation (like jax.tree_util). It offers APIs similar to jax.tree_util's, but with better performance.
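The core operations being accelerated are flatten (tree to leaves plus structure), unflatten (the inverse), and map (flatten, apply, unflatten). The pure-Python sketch below only illustrates their semantics; it is not how optree works internally. The names `flatten`, `unflatten`, and `tree_map` and the tuple-based structure encoding are hypothetical. It handles only plain `dict`, `list`, and `tuple` containers and treats everything else, including `None`, as a leaf. optree and jax.tree_util instead treat `None` as an empty subtree by default. In optree itself, the corresponding calls are `optree.tree_flatten(tree)`, `optree.tree_unflatten(treespec, leaves)`, and `optree.tree_map(func, tree)`.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical structure encoding for this sketch only:
#   None                          -> a leaf slot
#   (container_type, keys, kids)  -> a dict/list/tuple node (keys is None for sequences)
Structure = Any


def flatten(tree: Any) -> Tuple[List[Any], Structure]:
    """Return the leaves of `tree` in a deterministic order plus its structure."""
    if type(tree) is dict:
        keys = sorted(tree)  # dict leaves are ordered by sorted key
        leaves: List[Any] = []
        children = []
        for key in keys:
            sub_leaves, sub_structure = flatten(tree[key])
            leaves.extend(sub_leaves)
            children.append(sub_structure)
        return leaves, (dict, tuple(keys), tuple(children))
    if type(tree) in (list, tuple):
        leaves = []
        children = []
        for item in tree:
            sub_leaves, sub_structure = flatten(item)
            leaves.extend(sub_leaves)
            children.append(sub_structure)
        return leaves, (type(tree), None, tuple(children))
    # Anything else (numbers, strings, arrays, None, ...) is a leaf in this sketch.
    return [tree], None


def unflatten(structure: Structure, leaves: List[Any]) -> Any:
    """Rebuild a tree from `structure` and a flat list of leaves."""
    leaf_iter = iter(leaves)

    def build(node: Structure) -> Any:
        if node is None:  # leaf slot: consume the next leaf in order
            return next(leaf_iter)
        kind, keys, children = node
        rebuilt = [build(child) for child in children]
        if kind is dict:
            return dict(zip(keys, rebuilt))
        return kind(rebuilt)  # list or tuple

    return build(structure)


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply `fn` to every leaf and rebuild the same structure."""
    leaves, structure = flatten(tree)
    return unflatten(structure, [fn(leaf) for leaf in leaves])
```

For example, `flatten({"b": [1, 2], "a": (3, {"c": 4})})` yields the leaves `[3, 4, 1, 2]`, because dict keys are visited in sorted order.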

Some initial benchmark results (average time cost relative to OpTree; lower is better):

| Average Time Cost (↓) | OpTree (v0.9.0) | JAX XLA (v0.4.6) | PyTorch (v2.0.0) | TensorFlow Nest (v2.12.0) | DM-Tree (v0.1.8) |
|---|---|---|---|---|---|
| Tree Flatten | x1.00 | 2.33 | 22.05 | 1.38 | 1.12 |
| Tree UnFlatten | x1.00 | 2.69 | 4.28 | 13.69 | 16.23 |
| Tree Flatten with Path | x1.00 | 16.16 | Not Supported | 21.10 | 27.59 |
| Tree Copy | x1.00 | 2.56 | 9.97 | 9.62 | 11.02 |
| Tree Map | x1.00 | 2.56 | 9.58 | 9.16 | 10.62 |
| Tree Map (nargs) | x1.00 | 2.89 | Not Supported | 74.26 | 31.33 |
| Tree Map with Path | x1.00 | 7.23 | Not Supported | 40.78 | 19.66 |
| Tree Map with Path (nargs) | x1.00 | 6.56 | Not Supported | 69.63 | 29.61 |
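Several rows above measure "with path" variants, which PyTorch does not support. As a rough illustration of what those operations compute, here is a pure-Python sketch that pairs each leaf with the sequence of dict keys and sequence indices leading to it from the root. The function names and the tuple-of-keys path format are assumptions for this sketch; the benchmarked libraries use their own path representations, which may differ.

```python
from typing import Any, Callable, List, Tuple

Path = Tuple[Any, ...]  # dict keys / sequence indices from the root to a leaf


def flatten_with_path(tree: Any, prefix: Path = ()) -> List[Tuple[Path, Any]]:
    """Return (path, leaf) pairs, visiting dict keys in sorted order."""
    if type(tree) is dict:
        pairs: List[Tuple[Path, Any]] = []
        for key in sorted(tree):
            pairs.extend(flatten_with_path(tree[key], prefix + (key,)))
        return pairs
    if type(tree) in (list, tuple):
        pairs = []
        for index, item in enumerate(tree):
            pairs.extend(flatten_with_path(item, prefix + (index,)))
        return pairs
    return [(prefix, tree)]  # leaf: record the path that reached it


def map_with_path(fn: Callable[[Path, Any], Any], tree: Any, prefix: Path = ()) -> Any:
    """Apply fn(path, leaf) to every leaf, preserving the container structure."""
    if type(tree) is dict:
        return {key: map_with_path(fn, value, prefix + (key,)) for key, value in tree.items()}
    if type(tree) in (list, tuple):
        return type(tree)(
            map_with_path(fn, item, prefix + (index,)) for index, item in enumerate(tree)
        )
    return fn(prefix, tree)
```

A typical use is naming parameters: `map_with_path(lambda p, x: "/".join(map(str, p)), tree)` replaces each leaf with a string such as `"params/w"`.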

We have already seen some etils folks get involved in the optree and jax.tree_util discussions. I wonder whether the etils maintainers would be interested in adding optree to etils.etree.

Ref:

Good idea. Let me try this.