swarm-jax

Swarm training framework using Haiku + JAX + Ray for layer-parallel transformer language models on unreliable, heterogeneous nodes.


Pipelined Swarm Training

Swarm training "framework" using Haiku + JAX + Ray.

Designed (eventually) for training large language models in a model-parallel fashion, with layers split across unreliable, heterogeneous nodes.
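Splitting layers across nodes only keeps every node busy if each batch is cut into microbatches that flow through the stages in a staggered way. Below is a minimal sketch of a GPipe-style fill-and-drain forward schedule. It assumes one pipeline stage (a contiguous group of layers) per node. This is one common schedule, not necessarily the one swarm-jax uses, and `fill_drain_schedule` is a hypothetical helper rather than part of this repo:

```python
from __future__ import annotations


def fill_drain_schedule(n_stages: int, n_microbatches: int) -> list[list[tuple[int, int]]]:
    """Return, for each clock tick, the (stage, microbatch) pairs running forward.

    Stage s can process microbatch m at tick s + m, i.e. right after
    stage s - 1 finished that microbatch on the previous tick.
    """
    n_ticks = n_stages + n_microbatches - 1
    schedule = []
    for t in range(n_ticks):
        # Every stage whose microbatch index t - s is in range is active now.
        active = [(s, t - s) for s in range(n_stages) if 0 <= t - s < n_microbatches]
        schedule.append(active)
    return schedule


if __name__ == "__main__":
    for tick, work in enumerate(fill_drain_schedule(n_stages=3, n_microbatches=4)):
        print(tick, work)
```

With 3 stages and 4 microbatches the forward pass takes 6 ticks instead of the 12 that a strictly sequential run would need. The idle slots at the start and end (the "bubble") shrink relative to total work as the number of microbatches grows.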

Look in swarm_run.py for an example of running a character-level transformer on enwik8.
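Character-level models on enwik8 usually treat the file as a stream of raw bytes (a vocabulary of 256) and train on next-byte prediction. The following is a minimal sketch of batch sampling under that assumption. The function name, the random-window sampling scheme, and the local file name `enwik8` are hypothetical and may differ from what swarm_run.py does:

```python
import random


def sample_char_batch(data: bytes, batch_size: int, seq_len: int, rng: random.Random):
    """Sample (inputs, targets) windows for next-byte prediction.

    Each window is seq_len + 1 bytes long; targets are the inputs shifted
    left by one position. Requires len(data) > seq_len.
    """
    inputs, targets = [], []
    for _ in range(batch_size):
        # Largest start is len(data) - seq_len - 1, so the window always fits.
        start = rng.randrange(0, len(data) - seq_len)
        window = data[start:start + seq_len + 1]
        inputs.append(list(window[:-1]))   # byte values 0..255
        targets.append(list(window[1:]))
    return inputs, targets


# Example usage (assumes an uncompressed enwik8 file in the working directory):
# with open("enwik8", "rb") as f:
#     data = f.read()
# x, y = sample_char_batch(data, batch_size=8, seq_len=512, rng=random.Random(0))
```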

TODOs

  • Forward passes
  • Backward passes with activation reconstruction
  • Run optimizer
  • Logging
  • Checkpointing
  • Actually do pipelining
  • fp16 with static loss scaling
  • Integer quantization for activations and gradients between layers
  • Eliminate pipeline stalls caused by running the optimizer
  • Data parallelism with multiple nodes per layer and gradient/weight aggregation
  • Heterogeneous nodes with potentially multiple layers per node
  • Handle unbalanced and unreliable nodes (layerdrop)
  • Dynamic node addition
  • 1T or bust?
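Activation reconstruction (also called rematerialization or gradient checkpointing) trades compute for memory. A node keeps only the input it received for each microbatch and recomputes its layers' intermediate activations when the gradient for that microbatch arrives. This matters with pipelining, where many microbatches are in flight at once. Below is a minimal scalar sketch that uses toy layers y = tanh(w * x) in place of real transformer blocks. In JAX the same idea is available through `jax.checkpoint` (also exposed as `jax.remat`):

```python
import math


def forward_stage(weights, x):
    """Run a stage of toy scalar layers y = tanh(w * x), keeping only the output."""
    for w in weights:
        x = math.tanh(w * x)
    return x


def backward_stage(weights, x_in, grad_out):
    """Backward pass that rebuilds activations from the stored stage input.

    Only x_in was kept during the forward pass. The per-layer activations
    are recomputed here, then the chain rule runs from the last layer to
    the first. Returns (grad w.r.t. x_in, list of grads w.r.t. each weight).
    """
    # Reconstruct the activations that the forward pass discarded.
    acts = [x_in]
    for w in weights:
        acts.append(math.tanh(w * acts[-1]))

    grad_w = [0.0] * len(weights)
    g = grad_out
    for i in reversed(range(len(weights))):
        x, y = acts[i], acts[i + 1]
        dpre = g * (1.0 - y * y)   # d tanh(u)/du = 1 - tanh(u)^2
        grad_w[i] = dpre * x       # u = w * x, so du/dw = x
        g = dpre * weights[i]      # du/dx = w, gradient sent to the previous stage
    return g, grad_w
```

The returned input gradient is what a node would pass back to the node holding the previous layers.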
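In fp16, small gradient values underflow to zero. Static loss scaling multiplies the loss by a fixed constant before the backward pass. Because gradients are linear in the loss, this scales every gradient by the same factor and lifts small values into fp16's representable range. The gradients are divided by the constant again, in full precision, before the optimizer step. Below is a minimal simulation in plain Python that uses the struct module's half-precision "e" format to round values to fp16. The loss-scale value and helper names are illustrative assumptions:

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE half precision (OverflowError if too large)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def scaled_fp16_grads(grads, loss_scale):
    """Simulate fp16 gradients from a loss multiplied by loss_scale, then unscale.

    Scaling the loss scales each gradient by loss_scale; each scaled value is
    rounded to fp16, then divided by loss_scale in full precision. Returns
    None if any scaled value overflows fp16 (with a static scale, the usual
    response is to skip that optimizer step).
    """
    out = []
    for g in grads:
        try:
            h = to_fp16(g * loss_scale)
        except (OverflowError, struct.error):
            return None
        if math.isinf(h) or math.isnan(h):
            return None
        out.append(h / loss_scale)  # unscale in full precision
    return out
```

For example, `to_fp16(1e-8)` rounds to 0.0. The same gradient survives, to within about 0.2%, when the loss is scaled by 1024 first.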
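Sending 8-bit integers instead of 32-bit floats between nodes cuts activation and gradient traffic by 4x. The TODO does not name a scheme. One simple option is symmetric absmax quantization with a single scale per tensor, sketched below under that assumption with the int8 range [-127, 127]; the function names are hypothetical:

```python
def quantize_int8(values):
    """Symmetric per-tensor absmax quantization to the int8 range [-127, 127].

    Returns (codes, scale) such that each value is approximately code * scale.
    """
    absmax = max((abs(v) for v in values), default=0.0)
    scale = absmax / 127.0 if absmax > 0 else 1.0
    # Round to the nearest code and clamp against floating-point edge cases.
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize_int8(codes, scale):
    """Map int8 codes back to floats on the receiving node."""
    return [c * scale for c in codes]
```

Only the codes and one float scale per tensor would go over the wire. The reconstruction error of each element is at most half a quantization step (scale / 2), so tensors with a few large outliers lose precision on all their small entries.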
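When several nodes hold copies of the same layer, each replica computes gradients on different microbatches, and the copies must be combined so the weights stay in sync. Below is a minimal sketch of one policy for unreliable replicas. It assumes gradients are plain lists of floats and that a node which fails or times out is simply left out of the average. The node ids and function name are hypothetical:

```python
def aggregate_replica_grads(replica_grads):
    """Average gradients from data-parallel replicas of the same layer.

    replica_grads maps a node id to that node's gradient (a list of floats),
    or to None if the node failed or timed out. Only responding replicas are
    averaged; returns None if no replica responded, so the step can be skipped.
    """
    received = [g for g in replica_grads.values() if g is not None]
    if not received:
        return None
    n = len(received)
    # Element-wise mean across the replicas that responded.
    return [sum(parts) / n for parts in zip(*received)]
```

Averaging over only the nodes that responded keeps the update's scale independent of how many replicas dropped out, at the cost of a noisier gradient on steps with fewer contributors.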