google/tree-math

How well does tree-math support computation on multiple devices?

connection-on-fiber-bundles opened this issue · 1 comment

I'm wondering how well tree-math supports computation on multiple devices.

Let's say we have a pytree of tensors with different shapes and want to perform some operations on each of them with tree-math. Can we distribute those tasks across multiple devices (GPUs, for instance)?

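This should not be a problem. tree-math is entirely agnostic to JAX's multi-device APIs: it's just syntactic sugar for jax.tree_util.tree_map, so each operation is applied leaf by leaf with ordinary JAX operations, which run wherever JAX has placed the underlying arrays.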