jax-ml/jax

Clarify: Larger-than-memory arrays


Hi,

One thing I'm interested in is larger-than-memory deep RL: storing weight matrices and other large arrays on disk for larger-scale projects in genomics and structural bioinformatics. Can JAX/XLA compile operations over memmapped arrays, the way dask.array.dot computes dot products of absurdly huge arrays? Numba does something similar, but it would be fun to do this sort of thing with JAX. Apologies if this is already supported.
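For concreteness, here is roughly what I mean by "memmap arrays like dask.array.dot": just a sketch with made-up filenames and shapes, where NumPy memory-maps the data on disk and Dask chunks it and schedules the out-of-core dot product.

```python
import numpy as np
import dask.array as da

# Disk-backed operands; filenames and shapes are made up for illustration.
a_mm = np.memmap("a.dat", dtype=np.float32, mode="r", shape=(200_000, 50_000))
b_mm = np.memmap("b.dat", dtype=np.float32, mode="r", shape=(50_000, 30_000))

# Wrap the memmaps in chunked Dask arrays; nothing is loaded into RAM yet.
a = da.from_array(a_mm, chunks=(10_000, 10_000))
b = da.from_array(b_mm, chunks=(10_000, 10_000))

c = da.dot(a, b)                       # builds a lazy task graph
block = c[:1000, :1000].compute()      # only the blocks needed are read from disk
```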

For example, in meta-learning we might want a function that produces a new weight matrix. With a naive 1-layer dense feedforward/feedback meta-controller, that weight tensor can be on the order of ((d_state, d_action, d_reward), ((d_state, d_action, d_reward), d_action)). This is doable for smaller RL projects, but not necessarily for larger MIMO RL environments where images are part of the state and action vectors. GPU memory is expensive, so it's nice to be able to memmap these things.
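As a rough back-of-the-envelope illustration of why this blows up, under one possible reading of the shape above and with made-up dimensions (nothing here is from a real project):

```python
# Hypothetical dimensions: an 84x84 RGB observation, 64 actions, scalar reward.
d_state, d_action, d_reward = 84 * 84 * 3, 64, 1
d_in = d_state + d_action + d_reward          # concatenated (state, action, reward)

# One reading of ((d_state, d_action, d_reward), ((d_state, d_action, d_reward), d_action)):
# map the input vector to a full (d_in, d_action) weight matrix.
n_params = d_in * d_in * d_action
print(f"{n_params:,} params, ~{n_params * 4 / 1e9:.0f} GB at float32")
# -> roughly 2.9e10 params, ~115 GB, far beyond a single GPU's memory.
```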

(Yes, I know about sampling, convolution, and Huffman coding -- I just wanted to play with huge dense meta stuff on XLA.)

Thanks,

Bion @ bitpharma.com

Cool ideas!

We don't have anything like that in JAX or XLA, but would love to see it built. JAX solves a specific set of problems (e.g. JIT compilation of Python+NumPy code with end-to-end array-level optimization, accelerator execution, autodiff, autobatching), and hopefully it can be used as a tool by other libraries that might want to pair the things that JAX does with other capabilities (like supporting absurdly large arrays). But JAX isn't a monolith that solves everything, and disk-backed arrays are out of scope.

NumPy itself is in an analogous position: it has a specific scope, yet pretty much every numerical computing project (Dask, TensorFlow, etc.) can use it to build things outside of that scope.

If you have ideas for how to interface JAX with Dask, or how to build a library that supports super large arrays while using JAX for autodiff and/or JIT compilation to accelerators, we'll do what we can on the JAX side to support that use case!
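For example (purely a sketch, not an existing integration; the kernel, filename, and shapes are invented), one way to pair the two today is to jax.jit a per-block kernel and let Dask stream disk-backed chunks through it:

```python
import numpy as np
import dask.array as da
import jax
import jax.numpy as jnp

@jax.jit
def kernel(x):
    # Any per-block JAX computation; an elementwise nonlinearity as a stand-in.
    return jnp.tanh(x) * 2.0

def apply_kernel(block):
    # Run the jitted kernel on one chunk and hand NumPy data back to Dask.
    return np.asarray(kernel(block))

# A disk-backed array (made-up file and shape), chunked so Dask can stream it.
x_disk = np.memmap("weights.dat", dtype=np.float32, mode="r",
                   shape=(1_000_000, 4_096))
x = da.from_array(x_disk, chunks=(50_000, 4_096))

y = x.map_blocks(apply_kernel, dtype=np.float32)
print(y.sum().compute())  # Dask schedules the blocks; JAX/XLA executes each one
```

Autodiff through such a pipeline is the harder part, since Dask's task graph is opaque to JAX; that's exactly the kind of gap a dedicated library could fill.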

(Please re-open the issue if I failed to answer your question :) )

Thank you