How to restore a variable from a checkpoint saved on CPU back onto the CPU when you have both GPU and CPU?
PriyeshV opened this issue · 5 comments
I get the following error when I try to restore,
ValueError: SingleDeviceSharding with Device=TFRT_CPU_0 was not found in jax.local_devices().
Despite enclosing the statements within a CPU device scope, as below, the visible devices are only CUDA devices and not the CPU.
with jax.default_device(jax.devices('cpu')[0]):
    print(jax.devices())
    print(jax.local_devices())
    variables = Mngr.restore(start_step)
Could you give me some pointers on how to handle this?
PS: This is for a DQN code where I'm trying to save the replay buffer (FlashBAX) from the CPU and network parameters from the GPU. I saved the buffer and parameters, but restoration has been an issue.
Thank You
You might have saved an array, inside the CPU scope, with a sharding that is incompatible with your current device setup. You'll need to specify restore_args to communicate the sharding that you want for each array in the tree.
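For example, a minimal sketch of that, assuming a single-host setup where you want every array placed back on the CPU device (the item name 'state' and the tree target_tree are placeholders for whatever you saved):

cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0])

def make_restore_arg(arr):
    # Tell Orbax explicitly which sharding (and hence device) each array should be restored with.
    return ocp.ArrayRestoreArgs(sharding=cpu_sharding)

restore_args = jax.tree_util.tree_map(make_restore_arg, target_tree)
variables = Mngr.restore(start_step, restore_kwargs={'state': {'restore_args': restore_args}})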
Hi,
I don't have any sharding settings specified for the variable.
Below is the entirety of the code.
PS: I'm running on a machine with a GPU
import jax
import jax.numpy as jnp
import chex
import flashbax as fbx
import orbax.checkpoint as ocp
@chex.dataclass(frozen=True)
class TimeStep:
    observation: chex.Array
    action: chex.Array
    reward: chex.Array
    done: chex.Array

with jax.default_device(jax.devices('cpu')[0]):
    rb = fbx.make_flat_buffer(max_length=10000, min_length=1000,
                              sample_batch_size=512, add_sequences=False, add_batch_size=None)
    rb = rb.replace(init=jax.jit(rb.init), add=jax.jit(rb.add, donate_argnums=0),
                    sample=jax.jit(rb.sample), can_sample=jax.jit(rb.can_sample))
    dummy_timestep = TimeStep(observation=jnp.ones((84, 84, 4), dtype=jnp.uint8), action=jnp.int32(0),
                              reward=jnp.float32(0.0), done=jnp.bool_(True))
    rb_state = rb.init(dummy_timestep)

    mngr_options = ocp.CheckpointManagerOptions(max_to_keep=1, save_interval_steps=1)
    Mngr = ocp.CheckpointManager('/home/mila/v/vijayanp/Test', {'rb_state': ocp.PyTreeCheckpointer()}, mngr_options)
    Mngr.save(0, {'rb_state': rb_state})
    Mngr.wait_until_finished()

    rb_variables = Mngr.restore(Mngr.latest_step())
An example, if you're trying to restore on device:

def make_restore_arg(arr):
    return ocp.ArrayRestoreArgs(sharding=...)

restore_args = jax.tree_util.tree_map(make_restore_arg, rb_state)
Mngr.restore(Mngr.latest_step(), restore_kwargs={'rb_state': {'restore_args': restore_args}})
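The sharding you pass in place of the ... depends on where you want each array to land. For instance, to put the replay buffer state back on the CPU and the network parameters on the GPU, the shardings could look like this (assuming a single-host, single-GPU setup; adapt if you actually shard across devices):

buffer_sharding = jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0])
params_sharding = jax.sharding.SingleDeviceSharding(jax.devices('gpu')[0])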
I'm unclear on what you want to do exactly. Maybe you want to restore into CPU memory (as numpy arrays):

import numpy as np

def make_restore_arg(arr):
    return ocp.RestoreArgs(restore_type=np.ndarray)

restore_args = jax.tree_util.tree_map(make_restore_arg, rb_state)
Mngr.restore(Mngr.latest_step(), restore_kwargs={'rb_state': {'restore_args': restore_args}})
I'm sorry for being unclear. I'll try to explain again, if you don't mind.
Objective: Load a variable (originally on the CPU) back into CPU memory.
Issue: When I call restore, it fails to load the object that was saved on the CPU and throws the following error,
ValueError: SingleDeviceSharding with Device=TFRT_CPU_0 was not found in jax.local_devices().
My understanding:
- Orbax looks for the CPU device in jax.local_devices() rather than jax.devices() when restoring, but the CPU is not listed there.
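A quick check along these lines seems consistent with that (the device names are just what I see on my machine):

print(jax.devices())        # only the default (GPU) backend's devices, e.g. [cuda(id=0)]
print(jax.devices('cpu'))   # the CPU backend is still reachable explicitly, e.g. [TFRT_CPU_0]
print(jax.local_devices())  # again only the default (GPU) backend's devices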
I think you do want to do this then:
def make_restore_arg(arr):
    return ocp.RestoreArgs(restore_type=np.ndarray)

restore_args = jax.tree_util.tree_map(make_restore_arg, rb_state)
Mngr.restore(Mngr.latest_step(), restore_kwargs={'rb_state': {'restore_args': restore_args}})
I think what's happening is that with jax.default_device(jax.devices('cpu')[0]) creates arrays with SingleDeviceSharding(device=TFRT_CPU_0), and this gets recorded in the sharding metadata of the checkpoint. When you don't provide restore_args with the sharding property specified, Orbax tries to use that sharding metadata to restore the arrays. But I guess because you're not including the with block again, it's not able to reconstruct the original sharding as recorded in the metadata.
To restore, you would need to either:
- Include the with block again (not completely sure if this would work or not),
- Provide restore_args such that the same sharding as was used to save is used to restore, or
- Provide restore_type=np.ndarray to restore as a numpy array in host memory (a sketch of this follows below).
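A minimal sketch of that last option, assuming the multi-item CheckpointManager API from your snippet (where restore returns a dict keyed by item name), with an optional jax.device_put at the end in case you need jax arrays on the CPU device rather than plain numpy arrays:

import numpy as np

restore_args = jax.tree_util.tree_map(
    lambda arr: ocp.RestoreArgs(restore_type=np.ndarray), rb_state)
restored = Mngr.restore(Mngr.latest_step(),
                        restore_kwargs={'rb_state': {'restore_args': restore_args}})

# Optionally move the restored arrays onto the CPU device as jax arrays.
cpu = jax.devices('cpu')[0]
rb_state_cpu = jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu), restored['rb_state'])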