google/orbax

About using the CPU backend as a mock, and unifying the multihost_utils wrappers across the repo


Hi, I'm currently trying to benchmark various storage backends for checkpoint reads/writes using Orbax in a large-scale setup.
To do so without allocating thousands of GPUs, I'm using the XLA flag 'xla_force_host_platform_device_count=8' on 100+ CPU-only VMs to simulate 100+ 8-GPU VMs.
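For reference, each VM's JAX setup looks roughly like this; the coordinator address, process count, and the VM_INDEX environment variable are placeholders for my actual launcher:

import os

# Must be set before importing jax: each CPU-only VM exposes 8 "devices",
# mimicking an 8-GPU host.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax

jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",      # hypothetical coordinator VM
    num_processes=128,                         # one process per VM
    process_id=int(os.environ["VM_INDEX"]),    # hypothetical per-VM index
)

print(jax.process_count(), jax.local_device_count())  # e.g. 128 and 8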
The problem I'm facing right now is that XLA doesn't seem to support jitted multiprocess computations on the CPU backend:

  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_manager.py", line 446, in __init__
    utils.sync_global_devices('CheckpointManager:create_directory')
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/utils.py", line 67, in sync_global_devices
    multihost_utils.sync_global_devices(name)
  File "/usr/local/lib/python3.10/dist-packages/jax/experimental/multihost_utils.py", line 88, in sync_global_devices
    assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
  File "/usr/local/lib/python3.10/dist-packages/jax/experimental/multihost_utils.py", line 156, in assert_equal
    expected = broadcast_one_to_all(in_tree)
  File "/usr/local/lib/python3.10/dist-packages/jax/experimental/multihost_utils.py", line 80, in broadcast_one_to_all
    out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding(
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Multiprocess computations aren't implemented on the CPU backend.

Is there anything I'm doing wrong? Could you suggest some best practices to work around this problem?

After digging into the sources, the only jit usage I found in Orbax is in simple collective communication primitives (please correct me if I'm wrong), like sync_global_devices

def sync_global_devices(name: str):

and broadcast_one_to_all
def broadcast_one_to_all(pytree: PyTree) -> PyTree:

Anyway, as a simple workaround, I'm trying to implement these primitives myself on top of a shared filesystem and monkey-patch the Orbax functions like this:

import orbax.checkpoint as ocp
import self_implemented_sync

# Replace Orbax's jit-based collectives with filesystem-based versions.
ocp.utils.sync_global_devices = self_implemented_sync.sync_global_devices
ocp.utils.broadcast_one_to_all = self_implemented_sync.broadcast_one_to_all
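For completeness, here is the rough shape of self_implemented_sync. It's a minimal sketch that assumes every process sees the same shared mount (the /mnt/shared path is a placeholder) and does no cleanup or reuse handling between calls:

# self_implemented_sync.py -- minimal sketch, not production code.
# Assumes jax.distributed.initialize() has already run and that every
# process sees the same shared filesystem under SYNC_DIR (placeholder path).
import json
import os
import time

import jax

SYNC_DIR = "/mnt/shared/orbax_sync"


def _wait_for(paths, timeout=600.0, poll=0.5):
  # Block until all marker files exist on the shared filesystem.
  deadline = time.time() + timeout
  while not all(os.path.exists(p) for p in paths):
    if time.time() > deadline:
      raise TimeoutError(f"timed out waiting for {paths}")
    time.sleep(poll)


def sync_global_devices(name: str):
  # Filesystem barrier: each process drops a marker file named after its
  # process index, then waits until markers from all processes appear.
  barrier_dir = os.path.join(SYNC_DIR, f"barrier_{name}")
  os.makedirs(barrier_dir, exist_ok=True)
  with open(os.path.join(barrier_dir, f"{jax.process_index()}.done"), "w") as f:
    f.write("ok")
  _wait_for([
      os.path.join(barrier_dir, f"{i}.done")
      for i in range(jax.process_count())
  ])


def broadcast_one_to_all(pytree, is_source=None):
  # Process 0 serializes the pytree leaves to JSON; everyone else waits for
  # the file and reads it back. Only handles JSON-serializable leaves, which
  # covers the step lists and strings Orbax broadcasts; a real version would
  # also need a per-call filename to avoid collisions between broadcasts.
  if is_source is None:
    is_source = jax.process_index() == 0
  path = os.path.join(SYNC_DIR, "broadcast.json")
  if is_source:
    with open(path + ".tmp", "w") as f:
      json.dump(jax.tree_util.tree_leaves(pytree), f)
    os.replace(path + ".tmp", path)  # atomic publish
  _wait_for([path])
  with open(path) as f:
    leaves = json.load(f)
  return jax.tree_util.tree_unflatten(
      jax.tree_util.tree_structure(pytree), leaves)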

But I noticed that the use of these wrappers is not unified across Orbax; for example, here multihost_utils.broadcast_one_to_all is used directly

padded_step_list = multihost_utils.broadcast_one_to_all(padded_step_list)

instead of the wrapped function defined in the same file:
def broadcast_one_to_all(pytree: PyTree) -> PyTree:

With this Orbax patching I've actually got JIT-less checkpoint saves and reads working in a distributed CPU-backend setup.
Patching sync_global_devices and broadcast_one_to_all was enough.
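In case it helps anyone else, the benchmark path now looks roughly like this, assuming the legacy CheckpointManager API from the version in the traceback above; the checkpoint directory and state pytree are placeholders:

import jax.numpy as jnp
import orbax.checkpoint as ocp
import self_implemented_sync

# Swap in the filesystem-based collectives before creating any managers.
ocp.utils.sync_global_devices = self_implemented_sync.sync_global_devices
ocp.utils.broadcast_one_to_all = self_implemented_sync.broadcast_one_to_all

mngr = ocp.CheckpointManager(
    "/mnt/shared/ckpt_benchmark",   # the storage backend under test
    ocp.PyTreeCheckpointer(),
)
state = {"params": jnp.ones((1024, 1024))}  # dummy payload to time
mngr.save(0, state)
mngr.wait_until_finished()
restored = mngr.restore(0)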

Ah ok nice, so this is solved for you?

Thanks for noting the instance of multihost_utils.broadcast_one_to_all; I think the reason for the discrepancy is just that it's old code.

We might look into a way of doing this without patching as part of our efforts to improve our existing benchmarks.

Ah ok nice, so this is solved for you?

More or less, yes.

We might look into a way of doing this without patching as part of our efforts to improve our existing benchmarks.

That definitely would be great.