[BUG] Enable multiprocessing queues with mx.array
This may be a feature request or a bug, I'm not sure:
```python
import mlx.core as mx
import multiprocessing as mp

queue = mp.Queue(5)

def worker(queue):
    queue.put(mx.array(1.0))

if __name__ == "__main__":
    process = mp.Process(target=worker, args=(queue,))
    process.start()
    print(queue.get())
    process.join()
```

This doesn't really work with the CUDA or Metal backends. For CUDA it just hangs. For Metal it runs and prints the array but then exits uncleanly:
```
UserWarning: resource_tracker: There appear to be 3 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
```
### CUDA
I believe the primary issue is that Python's multiprocessing defaults to fork (on Linux, for Python < 3.14), and the CUDA runtime doesn't play well with fork (e.g., see PyTorch's note on this, lore).
Switching to `spawn` works (tested with `mlx-cuda==0.28.0`):
```python
import mlx.core as mx
import multiprocessing as mp

def worker(queue):
    queue.put(mx.array(1.0))

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.Queue(5)
    process = ctx.Process(target=worker, args=(queue,))
    process.start()
    print(queue.get())
    process.join()
```

Besides the fork/spawn issue, there's another potential issue here that manifests exactly the same way (the process appears hung). It's due to the implicit dependency on numpy (which is, I believe, only marked as a dev dependency for MLX). Pickling an `mx.array` (done implicitly by `queue.put`) converts it to a numpy ndarray (via `__getstate__` -> `mlx_to_np_array` -> `nanobind::detail::ndarray_export()`), which can run into a silent ImportError in the worker and hang.
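One hedged way to guard against that (a minimal sketch, assuming the pickling path described above; drop-in for the worker in the spawn example): import numpy explicitly at the top of the worker, so a missing package raises a loud error up front instead of leaving the parent blocked in `queue.get()`:

```python
import mlx.core as mx

def worker(queue):
    # Assumption (per the pickling path above): queue.put pickles the
    # mx.array through numpy, and an ImportError raised there can be
    # swallowed, hanging the parent. Importing numpy up front makes
    # the failure visible instead.
    import numpy  # noqa: F401

    queue.put(mx.array(1.0))
```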
### Metal
fwiw, on macOS + Metal, I don't hit the resource tracker warnings. Those are pretty common in the PyTorch + multiprocessing world, where folks have largely been ignoring 'em.
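If you do want them gone, the usual hygiene (a sketch of generic multiprocessing cleanup appended to the spawn example above; not verified against this exact repro, and it may not silence the tracker in every case) is to shut the queue down explicitly before exit:

```python
# Generic multiprocessing cleanup, replacing the tail of the spawn example.
print(queue.get())
queue.close()        # flush and shut down the queue's feeder thread
queue.join_thread()  # wait for the feeder thread (valid only after close())
process.join()
```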