Compatibility with multiprocessing / joblib - AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
jaanli opened this issue · 9 comments
I have to remove type hints from type-checked functions that need to be called in joblib.Parallel or other multiprocessing pipelines; otherwise I get tracebacks like this:
joblib.externals.loky.process_executor._RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 426, in _process_worker
call_item = call_queue.get(block=True, timeout=timeout)
File "/home/ray/anaconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
return _lookup_class_or_track(class_tracker_id, skeleton_class)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
File "/home/ray/anaconda3/lib/python3.10/weakref.py", line 429, in __setitem__
self.data[ref(key, self._remove)] = value
File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 310, in __hash__
return hash(cls._get_props())
File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 295, in _get_props
cls.index_variadic,
AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ray/evaluate_bm25_pytorch.py", line 314, in <module>
main(cfg)
File "/home/ray/evaluate_bm25_pytorch.py", line 295, in main
trainer.evaluate(
File "/home/ray/trainer.py", line 427, in evaluate
predictions, objective = model_and_objective(batch)
File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ray/evaluate_bm25_pytorch.py", line 134, in forward
predictions = Parallel(n_jobs=self.n_jobs)(
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
return output if self.return_generator else list(output)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
yield from self._retrieve()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
self._raise_error_fast()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
error_job.get_result(self.timeout)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
return self._return_or_raise()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
raise self._result
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
joblib.externals.loky.process_executor._RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 426, in _process_worker
call_item = call_queue.get(block=True, timeout=timeout)
File "/home/ray/anaconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
return _lookup_class_or_track(class_tracker_id, skeleton_class)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
File "/home/ray/anaconda3/lib/python3.10/weakref.py", line 429, in __setitem__
self.data[ref(key, self._remove)] = value
File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 310, in __hash__
return hash(cls._get_props())
File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 295, in _get_props
cls.index_variadic,
AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ray/evaluate_bm25_pytorch.py", line 314, in <module>
main(cfg)
File "/home/ray/evaluate_bm25_pytorch.py", line 295, in main
trainer.evaluate(
File "/home/ray/trainer.py", line 427, in evaluate
predictions, objective = model_and_objective(batch)
File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ray/evaluate_bm25_pytorch.py", line 134, in forward
predictions = Parallel(n_jobs=self.n_jobs)(
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
return output if self.return_generator else list(output)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
yield from self._retrieve()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
self._raise_error_fast()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
error_job.get_result(self.timeout)
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
return self._return_or_raise()
File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
raise self._result
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
It looks like these classes aren't being (de)serialised correctly, so the `index_variadic` attribute doesn't make it across.
If you can provide an MWE that'd be great. (Or a PR! The fix might just be to implement `__setstate__` and `__getstate__`?)
I'm facing the same issue when trying to save an optax optimizer state using cloudpickle. I hope this gets fixed.
```
File "/root/optimizer.py", line 89, in run_train_on_modal
    optimizer_state = optax.tree_utils.tree_set(optimizer_state, inner_state=cloudpickle.load(f))
                                                                             ^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
File "/usr/lib/python3.11/weakref.py", line 428, in __setitem__
    self.data[ref(key, self._remove)] = value
    ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 321, in __hash__
    return hash(cls._get_props())
                ^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 306, in _get_props
    cls.index_variadic,
    ^^^^^^^^^^^^^^^^^^
AttributeError: type object 'Float[Array, '*shape']' has no attribute 'index_variadic'
```
Do you have an MWE?
Yes, I'm training a model with JAX and Equinox, and I am trying to save the optimizer state.
```python
lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == 'scratch' else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=lr_scheduler, b1=beta1, b2=beta2)
optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))
checkpoint_params = {
    "optimizer_state": optimizer_state
}
with open(checkpoint_params_file, "wb") as f:
    cloudpickle.dump(checkpoint_params, f)
```
> It looks like these classes aren't being (de)serialised correctly, so the `index_variadic` attribute doesn't make it across. If you can provide an MWE that'd be great. (Or a PR! The fix might just be to implement `__setstate__` and `__getstate__`?)
I am also encountering this issue, but only with `ray` (using `cloudpickle` on its own seems to work now). This MWE reproduces the issue:
```bash
pip install jaxtyping jax 'ray[default]'
```

```python
import jax
import ray
from jax import numpy as jnp
from jaxtyping import Int

ray.init()

@ray.remote(max_retries=0)
def f(x: Int[jax.Array, "one two"]):
    return x * 2

a = ray.put(jnp.arange(10))
ray.get(f.remote(a))
```
I tried implementing `__setstate__` and `__getstate__` as follows:
```python
# jaxtyping/_array_types.py
@ft.lru_cache(maxsize=None)
def _make_metaclass(base_metaclass):
    class MetaAbstractArray(_MetaAbstractArray, base_metaclass):
        # ...
        def __getstate__(cls):
            return cls._get_props()
        def __setstate__(cls, props):
            (
                cls.index_variadic,
                cls.dims,
                cls.array_type,
                cls.dtypes,
                cls.dim_str,
            ) = props
    # ...
```
But as best I can tell, neither one gets called at all.
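It looks like `ray.cloudpickle.cloudpickle` internally synthesises a class via `types.new_class` and then immediately tries to hash it (the same `_make_skeleton_class` → `_lookup_class_or_track` path appears in the tracebacks above). That fails because this class does not yet have our attributes set.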
`ray`'s approach seems a bit dodgy, because it leads to exactly the kind of failure we're seeing here! Anyway, I've worked around this in #237 by just always hashing to zero.
Thank you for the MWE, that was invaluable to figure this one out! :)
You guys, gals, and nonbinary pals rock!!
I seem to have a related issue with the Grain dataloader, which also involves cloudpickle and `index_variadic`. The error only happens when I set `worker_count > 0`:
ERROR:absl:Error occurred in child process with worker_index: 7
Traceback (most recent call last):
File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 176, in _worker_loop
element_producer = _get_element_producer_from_queue(
File "/usr/local/lib/python3.10/site-packages/grain/_src/python/grain_pool.py", line 148, in _get_element_producer_from_queue
element_producer_fn: GetElementProducerFn[Any] = cloudpickle.loads(
File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 539, in _make_skeleton_class
return _lookup_class_or_track(class_tracker_id, skeleton_class)
File "/usr/local/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 124, in _lookup_class_or_track
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
File "/usr/local/lib/python3.10/weakref.py", line 429, in __setitem__
self.data[ref(key, self._remove)] = value
File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 339, in __hash__
return hash(cls._get_props())
File "/usr/local/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 324, in _get_props
cls.index_variadic,
AttributeError: type object 'Float[Array, 'N C H W']' has no attribute 'index_variadic'
The above exception was the direct cause of the following exception:
Ah, this has already been fixed and I just haven't done a new release for it yet.
I've done a version bump + new release in #246