pure_callback using a threadpool to accelerate vmap
Opened this issue · 3 comments
I would like to be able to write a callback that accepts unbatched arguments, which would normally be executed under `pure_callback(..., vmap_method='sequential')`, but then have the `vmap` performed
using a thread pool with a given number of workers. Perhaps a common thread pool could be reused, rather than creating one per `pure_callback`.
Use case
I interface JAX with scientific code, which is written in some low-level language, with Python bindings that release the GIL. These ops need to be vmapped, and I'd like to take advantage of threading. Since the scientific code wasn't written with ufuncs or broadcasting in mind, the only ways to achieve vectorised behaviour are sequential execution or threading. Threading is the way to go because the bindings release the GIL.
- I find this pattern coming up time after time, and it's tedious to reconstruct this behaviour for each new callback. Plus I think it could be useful to reuse some common threadpool.
- It makes writing the callback logic much simpler, as the implementer only needs to write the code for the thread target. Currently, if a batching `vmap_method` (e.g. `'expand_dims'` or `'broadcast_all'`) is used, the callback itself must handle the leading batch dimensions. This leads to more lines of code and trickier shape-check logic.
Okay, so this can already be done with what is available. The snippet below uses `vmap_method='expand_dims'` (the default) or `'broadcast_all'`, and then uses a thread pool to execute the core of the callback. By default it reuses a global singleton thread pool, whose size can be configured with `export JAX_PURE_CALLBACK_NUM_THREADS=...`. If `num_threads` is passed explicitly, a new thread pool is instead created for each callback invocation.
import atexit
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Any
import jax
import numpy as np
# Public API of this module; the remaining names are private helpers or demo code.
__all__ = ['construct_threaded_pure_callback']
class _ThreadPoolSingleton:
"""Singleton wrapper for ThreadPoolExecutor."""
_instance: ThreadPoolExecutor | None = None
# Take the same default as threading library, but set it here for clarity
_num_threads = os.environ.get('JAX_PURE_CALLBACK_NUM_THREADS', min(32, (os.cpu_count() or 1) + 4))
def get_instance(cls):
if cls._instance is None:
cls._instance = ThreadPoolExecutor(max_workers=cls._num_threads, thread_name_prefix="jax_pure_callback_")
return cls._instance
def shutdown(cls):
if cls._instance:
# wait True ensures a displaced running threadpool finishes before shutdown.
cls._instance = None
def _build_callback_from_kernel(cb_kernel: Callable, batch_shape_determiner: Callable, num_threads: int | None):
def callback(*args):
# Determine leading dims.
batch_shape = batch_shape_determiner(*args)
batch_size = int(np.prod(batch_shape))
def sliced_kernel(index):
multi_idx = np.unravel_index(index, batch_shape)
def _slice(x):
if x is None:
return x
_multi_idx = []
for idx, s in (zip(multi_idx, np.shape(x))):
if s == 1:
return x[tuple(_multi_idx)]
args_slice = jax.tree.map(_slice, args)
return cb_kernel(*args_slice)
if num_threads is not None:
with ThreadPoolExecutor(max_workers=num_threads) as executor:
result_map = executor.map(sliced_kernel, range(batch_size))
executor = _ThreadPoolSingleton.get_instance()
result_map = executor.map(sliced_kernel, range(batch_size))
results_list = list(result_map)
# pytree stack
results = jax.tree.map(lambda *r: np.stack(r, axis=0), *results_list)
# unflatten
results = jax.tree.map(lambda x: jax.lax.reshape(x, batch_shape + np.shape(x)[1:]), results)
return results
return callback
def _build_batch_shape_determiner(*args_shape_size):
    """Build a function that computes the broadcast batch shape of callback arguments.

    Args:
        *args_shape_size: for each positional argument, the number of trailing
            (core, unbatched) dimensions; all remaining leading dims are batch dims.

    Returns:
        ``batch_shape_determiner(*args)``, which returns the broadcast of all
        arguments' leading batch shapes (``None`` arguments are ignored).

    The returned function raises ValueError on an argument-count mismatch, a
    non-integer or negative core-dim count, or batch shapes that cannot be
    broadcast together.
    """

    def batch_shape_determiner(*args):
        if len(args) != len(args_shape_size):
            raise ValueError(f'Expected {len(args_shape_size)} arguments, got {len(args)}.')

        def _determine(x, shape_size):
            if not isinstance(shape_size, int):
                raise ValueError(f'shape_size must be an integer, got {type(shape_size)}.')
            if shape_size < 0:
                raise ValueError(f'shape_size must be non-negative, got {shape_size}.')
            if x is None:
                return None
            if shape_size == 0:
                # Slicing with [:-0] would yield (), so handle "no core dims" explicitly.
                return np.shape(x)
            return np.shape(x)[:-shape_size]

        batch_shapes = jax.tree.map(_determine, args, args_shape_size, is_leaf=lambda x: x is None)

        def is_leaf(s):
            # A tuple of ints is a shape, not a container to recurse into.
            return isinstance(s, tuple) and all(isinstance(i, int) for i in s)

        shapes = set(jax.tree.leaves(batch_shapes, is_leaf=is_leaf))
        shapes.discard(None)  # defensive: None is an empty pytree and normally absent
        try:
            batch_shape = np.broadcast_shapes(*shapes)
        except ValueError as e:
            if "shape mismatch" in str(e):
                raise ValueError(f'Inconsistent batch shapes: {shapes}') from e
            raise
        return batch_shape

    return batch_shape_determiner
def construct_threaded_pure_callback(cb_kernel: Callable, result_shape_dtypes: Any, *args_shape_size,
                                     num_threads: int | None = None, vmap_method='expand_dims'):
    """Construct a pure callback with vmap using threading.

    Args:
        cb_kernel: a callable that takes a consistently shaped set of arguments and returns a
            consistently shaped pytree of results.
        result_shape_dtypes: a pytree of ShapeDtypeStruct objects representing the expected shape
            and dtype of the result of cb_kernel.
        *args_shape_size: the number of (unbatched) dimensions for each argument to cb_kernel.
        num_threads: the number of threads to use. If None, reuses a shared global threadpool whose
            size defaults to ``min(32, os.cpu_count() + 4)`` and can be configured with the
            environment variable `JAX_PURE_CALLBACK_NUM_THREADS`. Otherwise a new pool with this
            many workers is created for every invocation of the callback.
        vmap_method: the vmap method to use. Must be one of 'expand_dims' or 'broadcast_all'.
            Default is 'expand_dims'. See jax.pure_callback for more information.

    Returns:
        A pure callback that works with vmap, using threading to parallelize the computation.

    Note:
        Argument-count, batch-broadcast and per-element shape mismatches are detected on the host
        inside the callback and raised there as ValueError.
    """

    def wrapped_cb_kernel(*args):
        # Check that each per-element argument has exactly the declared number
        # of core dimensions before running the user kernel.
        def _check_shape(x, shape_size):
            if x is None:
                return
            if len(np.shape(x)) != shape_size:
                raise ValueError(
                    f'Expected shape of size {shape_size} but got {np.shape(x)}, sized ({len(np.shape(x))}).')

        jax.tree.map(_check_shape, args, args_shape_size, is_leaf=lambda x: x is None)
        return cb_kernel(*args)

    batch_shape_determiner = _build_batch_shape_determiner(*args_shape_size)
    cb = _build_callback_from_kernel(wrapped_cb_kernel, batch_shape_determiner, num_threads=num_threads)

    def callback(*args):
        return jax.pure_callback(cb, result_shape_dtypes, *args, vmap_method=vmap_method)

    return callback
def add_kernel(x, y, z):
    """Example kernel: return the sum of three scalar (0-d) operands.

    Each operand's shape is asserted to be () before summing.
    """
    for operand in (x, y, z):
        assert operand.shape == ()
    return x + y + z
if __name__ == '__main__':
    # Demo: a doubly-vmapped threaded callback over scalar kernel inputs.
    import jax.numpy as jnp

    x = jnp.ones((4,), dtype=jnp.float32)
    y = jnp.ones((5,), dtype=jnp.float32)
    z = jnp.ones((), dtype=jnp.float32)
    cb = construct_threaded_pure_callback(
        add_kernel,
        jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32),
        0, 0, 0,
    )
    # Outer vmap batches x, inner vmap batches y; z stays unbatched.
    cb_vmap = jax.vmap(jax.vmap(cb, in_axes=(None, 0, None)), in_axes=(0, None, None))
    result = cb_vmap(x, y, z)
    assert result.shape == (4, 5)
    assert bool(jnp.all(result == 3.0))
An improvement would be to avoid tiling the arguments and instead use `vmap_method='expand_dims'`.
I updated the code snippet above to use `expand_dims` by default, so that the arguments do not need to be tiled.