
Auto parallelization for large-scale neural networks


Alpa

Documentation | Slack


Alpa is a system for training and serving large-scale neural networks.

Scaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training and serving these large-scale neural networks require complicated distributed system techniques. Alpa aims to automate large-scale distributed training and serving with just a few lines of code.

The key features of Alpa include:

💻 Automatic Parallelization. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism.

🚀 Excellent Performance. Alpa achieves linear scaling when training models with billions of parameters on distributed clusters.

✨ Tight Integration with Machine Learning Ecosystem. Alpa is backed by open-source, high-performance, and production-ready libraries such as Jax, XLA, and Ray.
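To make the first strategy concrete, here is a minimal NumPy sketch of what data parallelism does by hand: split the batch across devices, compute a gradient on each shard, and average the results (an all-reduce on a real cluster). This is one of the strategies Alpa applies automatically; the function names below are illustrative and not part of Alpa's API.

```python
import numpy as np

def loss_grad(w, x, y):
    # Gradient of the mean-squared-error loss 0.5 * mean((x @ w - y)**2)
    # with respect to the weights w.
    n = x.shape[0]
    return x.T @ (x @ w - y) / n

def data_parallel_grad(w, x, y, num_devices=4):
    # Data parallelism by hand: each "device" sees one shard of the
    # batch, then the per-shard gradients are averaged.
    x_shards = np.array_split(x, num_devices)
    y_shards = np.array_split(y, num_devices)
    grads = [loss_grad(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]
    return np.mean(grads, axis=0)
```

When the batch divides evenly across devices, the averaged shard gradients equal the full-batch gradient, which is why this decomposition preserves the training semantics.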

👉 Try Alpa-served OPT-175B!

Alpa provides a free, unlimited OPT-175B text generation service. Try our service at https://opt.alpa.ai/ and share your prompting results!

Join Alpa slack and let us know any new features you want!

Quick Start

Use Alpa's decorator @parallelize to scale your single-device training code to distributed clusters.

import alpa
import jax.numpy as jnp
from jax import grad

# Parallelize the training step in Jax by simply using a decorator
@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)

Check out the Alpa Documentation site for installation instructions, tutorials, examples, and more.

Learning More

Getting Involved

License

Alpa is licensed under the Apache-2.0 license.