patrick-kidger/jaxtyping

Compatibility with multiprocessing / joblib - AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'

jaanli opened this issue · 9 comments

I need to remove jaxtyping type hints from type-checked functions before they can be called via joblib.Parallel or other multiprocessing pipelines; otherwise I get tracebacks like this:

joblib.externals.loky.process_executor._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 426, in _process_worker
    call_item = call_queue.get(block=True, timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
  File "/home/ray/anaconda3/lib/python3.10/weakref.py", line 429, in __setitem__
    self.data[ref(key, self._remove)] = value
  File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 310, in __hash__
    return hash(cls._get_props())
  File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 295, in _get_props
    cls.index_variadic,
AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ray/evaluate_bm25_pytorch.py", line 314, in <module>
    main(cfg)
  File "/home/ray/evaluate_bm25_pytorch.py", line 295, in main
    trainer.evaluate(
  File "/home/ray/trainer.py", line 427, in evaluate
    predictions, objective = model_and_objective(batch)
  File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ray/evaluate_bm25_pytorch.py", line 134, in forward
    predictions = Parallel(n_jobs=self.n_jobs)(
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
    return output if self.return_generator else list(output)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
    yield from self._retrieve()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
    self._raise_error_fast()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
    error_job.get_result(self.timeout)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
    return self._return_or_raise()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
    raise self._result
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
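
For reference, the shape of the code that triggers this is roughly the following. This is a minimal sketch, not my actual code (the function, variable, and shape names are made up): a jaxtyping-annotated function dispatched through joblib.Parallel, whose default loky backend cloudpickles the task into worker processes.

import torch
from joblib import Parallel, delayed
from jaxtyping import Float
from torch import Tensor


def score_batch(x: Float[Tensor, "batch_size num_classes"]) -> Float[Tensor, "batch_size"]:
    # Stand-in for the real per-batch work.
    return x.max(dim=-1).values


batches = [torch.randn(8, 3) for _ in range(4)]

# loky cloudpickles the task into each worker process; when the jaxtyping annotation
# classes are serialised by value, rebuilding them in the worker is where the
# index_variadic AttributeError above is raised.
predictions = Parallel(n_jobs=2)(delayed(score_batch)(b) for b in batches)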

Looks like they're not getting de/serialised correctly, so the index_variadic attribute doesn't make it across.

If you can open a MWE that'd be great. (Or a PR! The fix might just be to implement __setstate__ and __getstate__?)

I'm facing this same issue when trying to save an optax optimizer state using cloudpickle. Hope this issue gets fixed.

File "/root/optimizer.py", line 89, in run_train_on_modal optimizer_state = optax.tree_utils.tree_set(optimizer_state, inner_state=cloudpickle.load(f)) ^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class return _lookup_class_or_track(class_tracker_id, skeleton_class) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^ File "/usr/lib/python3.11/weakref.py", line 428, in __setitem__ self.data[ref(key, self._remove)] = value ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 321, in __hash__ return hash(cls._get_props()) ^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 306, in _get_props cls.index_variadic, ^^^^^^^^^^^^^^^^^^ AttributeError: type object 'Float[Array, '*shape']' has no attribute 'index_variadic'

Do you have a MWE?

Yes, I'm training a model with JAX and Equinox, and I am trying to save the optimizer state.

lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == 'scratch' else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=lr_scheduler, b1=beta1, b2=beta2)

optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))

checkpoint_params = {
    "optimizer_state": optimizer_state
}

with open(checkpoint_params_file, "wb") as f:
    cloudpickle.dump(checkpoint_params, f)
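
For what it's worth, the dump itself goes through for me; it's loading the checkpoint back (as in the traceback I posted above) that raises the error. A minimal sketch of that load step:

with open(checkpoint_params_file, "rb") as f:
    # cloudpickle has to rebuild the jaxtyping annotation classes referenced inside the
    # pickled optimizer state, and that rebuild is where index_variadic goes missing.
    checkpoint_params = cloudpickle.load(f)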

I am also encountering this issue, but only with ray (just using cloudpickle on its own seems to work now). This MWE reproduces the issue:

pip install jaxtyping jax 'ray[default]'
import jax
import ray
from jax import numpy as jnp

from jaxtyping import Int


ray.init()


@ray.remote(max_retries=0)
def f(x: Int[jax.Array, "one two"]):
    return x * 2


a = ray.put(jnp.arange(10))
ray.get(f.remote(a))

I tried implementing __setstate__ and __getstate__ as follows:

# jaxtyping/_array_types.py

@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
    class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
        # ...

        def __getstate__(cls):
            return cls._get_props()

        def __setstate__(cls, props):
            (
                cls.index_variadic,
                cls.dims,
                cls.array_type,
                cls.dtypes,
                cls.dim_str,
            ) = props
        
        # ...

But as best I can tell, neither one gets called at all.

It looks like ray.cloudpickle.cloudpickle internally synthesises a class via types.new_class:

https://github.com/ray-project/ray/blob/200c54859dc87f02f7b40e003917b53e68356a60/python/ray/cloudpickle/cloudpickle.py#L536

and then immediately tries to hash it:

https://github.com/ray-project/ray/blob/200c54859dc87f02f7b40e003917b53e68356a60/python/ray/cloudpickle/cloudpickle.py#L124

which fails, as this class does not yet have our attributes set.
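
In other words, the ordering is roughly this (a simplified sketch, not ray's actual code):

import types
import weakref

_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary()

def _make_skeleton_class(name, bases, class_tracker_id):
    # 1. A bare "skeleton" class is synthesised first; jaxtyping's class attributes
    #    (index_variadic, dims, dtypes, ...) have not been restored onto it yet.
    skeleton = types.new_class(name, bases)
    # 2. The skeleton is immediately used as a weak-dict key, which hashes it. As our
    #    metaclass __hash__ reads cls.index_variadic, this is where the AttributeError
    #    in the tracebacks above comes from.
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[skeleton] = class_tracker_id
    return skeleton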

ray's approach seems a bit dodgy due to exactly the kind of failure we're seeing here! Anyway, I've worked around this in #237 by just always hashing to zero.
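
Concretely, the workaround is along these lines (sketch only; the real change is in #237):

class _MetaAbstractArray(type):
    def __hash__(cls):
        # A constant hash never needs to read class attributes, so hashing a half-built
        # skeleton class can no longer fail. Dict lookups fall back to __eq__, so
        # correctness is unaffected; we just lose hash dispersion.
        return 0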

Thank you for the MWE, that was invaluable for figuring this one out! :)

You guys, gals, and nonbinary pals rock!!

I seem to have a related issue with the Grain dataloader, which also involves cloudpickle and index_variadic. The error only happens when I set worker_count > 0:

ERROR:absl:Error occurred in child process with worker_index: 7
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 176, in _worker_loop
    element_producer = _get_element_producer_from_queue(
  File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 148, in _get_element_producer_from_queue
    element_producer_fn: GetElementProducerFn[Any] = cloudpickle.loads(
  File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 539, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
  File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 124, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
  File "/usr/local/lib/python3.10/weakref.py", line 429, in __setitem__
    self.data[ref(key, self._remove)] = value
  File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 339, in __hash__
    return hash(cls._get_props())
  File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 324, in _get_props
    cls.index_variadic,
AttributeError: type object 'Float[Array, 'N C H W']' has no attribute 'index_variadic'

The above exception was the direct cause of the following exception:

Ah, this has already been fixed and I just haven't done a new release for it yet.

I've done a version bump + new release in #246