Experiments in multi-architecture parallelism for deep learning with JAX.
What if we could create a new kind of multi-architecture parallelism library for deep learning compilers, supporting expressive frontends like JAX? Such a library would optimize a mix of pipeline and operator parallelism across accelerated devices, using CPU, GPU, and/or TPU in the same program and automatically interleaving work between them. A rough sketch of the idea follows.
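As a minimal illustration of the concept (not the library itself), the JAX sketch below pins one pipeline stage to the CPU and a second to an accelerator by placing data explicitly; jitted computation then follows the placement of its inputs. The stage functions, array shapes, and GPU fallback are hypothetical, and the interleaving here is manual rather than automatic.

```python
import jax
import jax.numpy as jnp

# Enumerate available backends. This sketch assumes a CPU is always
# present and falls back to it if no GPU backend is available.
cpu = jax.devices("cpu")[0]
try:
    accel = jax.devices("gpu")[0]
except RuntimeError:
    accel = cpu  # no accelerator visible; run everything on CPU

@jax.jit
def stage1(x):
    # First pipeline stage: a cheap elementwise op, suited to the CPU.
    return jnp.tanh(x)

@jax.jit
def stage2(x):
    # Second pipeline stage: a heavier matmul, suited to an accelerator.
    return x @ x.T

x = jnp.ones((512, 512))

# Computation follows committed data placement: stage1 executes on the
# CPU, then its output is moved so stage2 executes on the accelerator.
y = stage1(jax.device_put(x, cpu))
z = stage2(jax.device_put(y, accel))
print(z.shape, z.devices())
```

A multi-architecture library along these lines would aim to make this device choreography automatic, rather than requiring explicit `jax.device_put` calls at each stage boundary.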
The experiments live in this repository, each dated and annotated with a brief description.
All code and notebooks in this repository are distributed under the terms of the MIT license.