Using the CPU backend as a mock, and unifying use of the multihost_utils wrappers across the repo
qGentry opened this issue · 3 comments
Hi, I'm currently trying to benchmark various storage backends for checkpoint reads/writes using orbax in a large-scale setup.
To do this without allocating thousands of GPUs, I'm using the XLA flag 'xla_force_host_platform_device_count=8' on 100+ CPU-only VMs to simulate 100+ 8-GPU VMs.
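For anyone reproducing this, the single-VM part of the simulation can be sketched as below, assuming a recent JAX install. The `jax.distributed.initialize` call in the comment uses placeholder values (HOST:PORT, NUM_VMS, VM_INDEX) that are not from this issue; only the device-count flag is.

```python
import os

# XLA reads XLA_FLAGS once, when JAX initializes its backends,
# so this must run before `import jax`.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()
# Restrict JAX to the CPU backend so a VM without accelerators does not probe for one.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

# In the multi-VM setup, each VM would additionally call something like
#   jax.distributed.initialize(coordinator_address="HOST:PORT",
#                              num_processes=NUM_VMS, process_id=VM_INDEX)
# where HOST:PORT, NUM_VMS and VM_INDEX are placeholders.

print(jax.local_device_count())  # 8 simulated CPU devices in this process
```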
The problem I'm facing right now is that XLA doesn't seem to support multiprocess jitted computations on the CPU backend:
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_manager.py", line 446, in __init__
    utils.sync_global_devices('CheckpointManager:create_directory')
  File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/utils.py", line 67, in sync_global_devices
    multihost_utils.sync_global_devices(name)
  File "/usr/local/lib/python3.10/dist-packages/jax/experimental/multihost_utils.py", line 88, in sync_global_devices
    assert_equal(h, f"sync_global_devices name mismatch ('{name}')")
  File "/usr/local/lib/python3.10/dist-packages/jax/experimental/multihost_utils.py", line 156, in assert_equal
    expected = broadcast_one_to_all(in_tree)
  File "/usr/local/lib/python3.10/dist-packages/jax/experimental/multihost_utils.py", line 80, in broadcast_one_to_all
    out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding(
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Multiprocess computations aren't implemented on the CPU backend.
Is there anything I'm doing wrong? Could you suggest some best practices for working around this problem?
After digging into the sources, the only jit usage I found in orbax is in simple collective communication primitives such as sync_global_devices and broadcast_one_to_all (please correct me if I'm wrong).
Anyway, as a simple workaround, I'm trying to implement these primitives myself on top of a shared filesystem and monkey-patch the corresponding orbax functions like this:
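import orbax.checkpoint as ocp
import self_implemented_sync  # my own shared-filesystem implementations (not shown)

# Replace orbax's JAX-based collectives with the filesystem-based ones.
ocp.utils.sync_global_devices = self_implemented_sync.sync_global_devices
ocp.utils.broadcast_one_to_all = self_implemented_sync.broadcast_one_to_all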
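The issue does not show what `self_implemented_sync` contains. A minimal sketch of such filesystem-based primitives might look like the following, assuming every process mounts the same initially empty directory, knows its own process index and the total process count, and issues the collectives in the same order (as JAX multihost code also requires). The `FsCollectives` class and its parameters are hypothetical.

```python
import os
import pickle
import tempfile
import time
from typing import Any, Callable, Optional


class FsCollectives:
    """Hypothetical filesystem-backed stand-ins for sync_global_devices and
    broadcast_one_to_all. Assumes a shared, initially empty `shared_dir` and
    identical call order on every process."""

    def __init__(self, shared_dir: str, process_index: int, process_count: int,
                 poll_s: float = 0.01, timeout_s: float = 600.0):
        self.shared_dir = shared_dir
        self.process_index = process_index
        self.process_count = process_count
        self.poll_s = poll_s
        self.timeout_s = timeout_s
        self._seq = 0  # same value on every process if call order matches

    def _next_step_dir(self) -> str:
        # Each collective call gets its own directory, keyed by call order.
        self._seq += 1
        path = os.path.join(self.shared_dir, f"{self._seq:08d}")
        os.makedirs(path, exist_ok=True)
        return path

    @staticmethod
    def _atomic_write(path: str, data: bytes) -> None:
        # Write to a temp file in the same directory, then rename it into place,
        # so readers never observe a partially written file.
        fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path), prefix=".tmp")
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp, path)

    def _wait(self, ready: Callable[[], bool], what: str) -> None:
        deadline = time.monotonic() + self.timeout_s
        while not ready():
            if time.monotonic() > deadline:
                raise TimeoutError(f"process {self.process_index}: timed out on {what}")
            time.sleep(self.poll_s)

    def sync_global_devices(self, name: str) -> None:
        """Barrier: returns once every process has arrived, and checks names match."""
        step = self._next_step_dir()
        self._atomic_write(os.path.join(step, f"arrived_{self.process_index}"),
                           name.encode())

        def arrivals() -> list:
            return [f for f in os.listdir(step) if f.startswith("arrived_")]

        self._wait(lambda: len(arrivals()) == self.process_count, f"barrier {name!r}")
        names = set()
        for fname in arrivals():
            with open(os.path.join(step, fname), "rb") as f:
                names.add(f.read().decode())
        if names != {name}:
            raise RuntimeError(f"sync_global_devices name mismatch: {sorted(names)}")

    def broadcast_one_to_all(self, value: Any, is_source: Optional[bool] = None) -> Any:
        """Returns the source's `value` on every process (process 0 by default)."""
        step = self._next_step_dir()
        if is_source is None:
            is_source = self.process_index == 0
        target = os.path.join(step, "value.pkl")
        if is_source:
            self._atomic_write(target, pickle.dumps(value))
            return value
        # Non-source processes ignore their own `value` and wait for the source's.
        self._wait(lambda: os.path.exists(target), "broadcast value")
        with open(target, "rb") as f:
            return pickle.load(f)
```

One instance per process can then stand in for the patched functions, e.g. `ocp.utils.sync_global_devices = coll.sync_global_devices`. The sketch never cleans up its step directories, so each run should start from a fresh directory, and exactly one process should act as the broadcast source. On network filesystems, client-side attribute caching may delay when other hosts see new files, so the polling interval and timeout are tuning knobs rather than guarantees.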
But I noticed that the use of these wrappers is not unified across orbax. For example, orbax/checkpoint/orbax/checkpoint/utils.py (line 697 in 6639066) calls multihost_utils.broadcast_one_to_all directly instead of the wrapped function defined in the same file.
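To illustrate why this mixed usage matters for patching, here is a self-contained sketch with two stand-in modules built via `types.ModuleType`. They only mimic `jax.experimental.multihost_utils` and the orbax utils module; no real orbax or JAX code is involved, and the function names `goes_through_wrapper` and `calls_jax_directly` are hypothetical.

```python
import types

# Stand-in modules built at runtime (illustrative only).
multihost_utils = types.ModuleType("multihost_utils")


def _jit_broadcast(in_tree, is_source=None):
    # Plays the role of the JAX collective that fails on the CPU backend.
    raise RuntimeError("Multiprocess computations aren't implemented on the CPU backend.")


multihost_utils.broadcast_one_to_all = _jit_broadcast

utils = types.ModuleType("utils")
# The wrapper forwards to whatever multihost_utils provides.
utils.broadcast_one_to_all = (
    lambda pytree, is_source=None: multihost_utils.broadcast_one_to_all(pytree, is_source)
)


def goes_through_wrapper(x):
    # Resolves utils.broadcast_one_to_all at call time, so it sees the patch.
    return utils.broadcast_one_to_all(x)


def calls_jax_directly(x):
    # Resolves multihost_utils.broadcast_one_to_all, which the patch never touched.
    return multihost_utils.broadcast_one_to_all(x)


# The monkey patch, applied only to the wrapper module.
utils.broadcast_one_to_all = lambda pytree, is_source=None: pytree

print(goes_through_wrapper(7))  # 7
try:
    calls_jax_directly(7)
except RuntimeError as e:
    print("direct call still hits JAX:", e)
```

A patch only intercepts callers that look the name up on the patched module, so routing every call site through the wrapper would let a single patch cover all of them.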
By patching orbax, I've actually implemented JIT-less checkpoint save/write in a distributed CPU-backend setup. Patching sync_global_devices and broadcast_one_to_all was enough.
Ah ok nice, so this is solved for you?
Thanks for noting the instance of multihost_utils.broadcast_one_to_all; I think the discrepancy is simply because it's old code.
We might look into a way of doing this without patching as part of our efforts to improve our existing benchmarks.
> Ah ok nice, so this is solved for you?
More or less, yes.
> We might look into a way of doing this without patching as part of our efforts to improve our existing benchmarks.
That definitely would be great.