ml-explore/mlx

[BUG] Enable multiprocessing queues with mx.array

Opened this issue · 1 comment

awni commented

This may be a feature request or a bug, I'm not sure:

import mlx.core as mx
import multiprocessing as mp

queue = mp.Queue(5)

def worker(queue):
    # Send an mx.array from the child back to the parent.
    queue.put(mx.array(1.0))

if __name__ == "__main__":
    process = mp.Process(target=worker, args=(queue,))
    process.start()
    print(queue.get())
    process.join()

This doesn't really work with the CUDA or Metal backends. With CUDA it just hangs. With Metal it runs and prints the array, but then exits uncleanly:

UserWarning: resource_tracker: There appear to be 3 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

CUDA

I believe the primary issue is that Python's multiprocessing defaults to fork on Linux (for Python < 3.14), and the CUDA runtime doesn't play well with forked processes (e.g., see PyTorch's note on this, lore).

Switching to spawn works (tested with mlx-cuda==0.28.0):

import mlx.core as mx
import multiprocessing as mp

def worker(queue):
    queue.put(mx.array(1.0))

if __name__ == "__main__":
    # Create the queue and process from a spawn context so the child
    # starts with a fresh interpreter (and a fresh CUDA context).
    ctx = mp.get_context("spawn")
    queue = ctx.Queue(5)
    process = ctx.Process(target=worker, args=(queue,))
    process.start()
    print(queue.get())
    process.join()
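
Alternatively (same effect, assuming you control the program's entry point), the default start method can be set once, so plain mp.Queue / mp.Process pick up spawn without threading a context through:

import multiprocessing as mp

if __name__ == "__main__":
    # Must be called once, before any Queue or Process is created.
    mp.set_start_method("spawn")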

Besides the fork/spawn issue, there's another potential problem here that manifests in exactly the same way (the process appears hung). It comes from the implicit dependency on numpy (which, I believe, is only marked as a dev dependency for MLX). Pickling an mx.array (done implicitly by the queue.put call) converts it to an ndarray (via __getstate__ -> mlx_to_np_array -> nanobind::detail::ndarray_export()), which can hit a silent ImportError in the worker and hang.
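
A quick way to check the numpy dependency in isolation (a minimal sketch, assuming pickling really does route through numpy as described above) is to round-trip an array through pickle directly; in an environment without numpy this should surface the ImportError instead of hanging:

import pickle

import mlx.core as mx

# mx.array pickles by converting to a numpy ndarray, so numpy must be
# importable in every process that serializes or deserializes one,
# including spawned workers.
data = pickle.dumps(mx.array(1.0))
print(pickle.loads(data))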

Metal

fwiw, on macOS + Metal, I don't hit the resource tracker warnings. Those are pretty common in the PyTorch + multiprocessing world, where folks have largely been ignoring them.
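
If anyone does want to chase the warning, one thing worth trying (a sketch, untested against this exact repro, and assuming the leaked semaphores the tracker reports are the queue's internal SemLocks) is tearing the queue down explicitly before interpreter shutdown:

import mlx.core as mx
import multiprocessing as mp

def worker(queue):
    queue.put(mx.array(1.0))

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.Queue(5)
    process = ctx.Process(target=worker, args=(queue,))
    process.start()
    print(queue.get())
    process.join()
    # Close the queue and join its feeder thread so its internal
    # semaphores are released before shutdown.
    queue.close()
    queue.join_thread()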