/torchft

Prototype repo for PyTorch fault tolerance

⚠️ NOTE: this has been moved to https://github.com/pytorch-labs/torchft ⚠️

torch-ft

torchft implements a lighthouse server that coordinates across replica groups, along with a per-replica-group manager and a fault tolerance library that can be used in a standard PyTorch training loop.

This allows membership to change at training-step granularity, which can greatly improve efficiency by avoiding a stop-the-world restart of training when errors occur.
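
As a rough illustration of what step-granularity membership means, the toy sketch below averages gradients over whichever replicas are healthy at each step, so a failure shrinks the group for that step instead of halting the run. It is plain Python and is not how torchft is implemented; `average_gradients`, `train`, and the `failures` mapping are hypothetical names made up for this example.

```python
# Toy illustration only: plain Python, no torchft or PyTorch involved.

def average_gradients(grads_by_replica, live_replicas):
    """Average gradients over the replicas that are healthy this step."""
    live = [grads_by_replica[r] for r in sorted(live_replicas)]
    if not live:
        raise RuntimeError("no healthy replicas; cannot take a step")
    return sum(live) / len(live)


def train(num_steps, failures, num_replicas=3, lr=0.5):
    """Run a toy loop; failures[step] lists the replicas that are down at that step."""
    weight = 0.0
    history = []
    for step in range(num_steps):
        down = set(failures.get(step, ()))
        live = set(range(num_replicas)) - down
        # Every hypothetical replica reports a constant gradient of 1.0.
        grads = {r: 1.0 for r in range(num_replicas)}
        weight -= lr * average_gradients(grads, live)
        history.append((step, len(live), weight))
    return history


# Replica 2 is down at step 1 only; the loop keeps going with two replicas.
history = train(num_steps=3, failures={1: [2]})
```

In this sketch the group size goes 3, 2, 3 across the three steps and the weight is updated on every step, which is the behavior the paragraph above describes at a high level.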

Installation

$ pip install .

This uses pyo3+maturin to build the package, so you'll need maturin installed.

To install in editable mode with the Rust extensions, use the normal pip editable install command:

$ pip install -e .

Lighthouse

The lighthouse is used for fault tolerance across replicated workers (DDP/FSDP) when using synchronous training.

You can start a lighthouse server by running:

$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 1000
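
If you would rather launch the lighthouse from a Python script (for example in a test harness), a minimal sketch might look like the following. It only reuses the binary name, flags, and environment variable from the command above; `lighthouse_command` and `start_lighthouse` are hypothetical helpers, not part of torchft.

```python
import os
import subprocess


def lighthouse_command(min_replicas=1, quorum_tick_ms=100, join_timeout_ms=1000):
    """Build the argv for the lighthouse command shown above."""
    return [
        "torchft_lighthouse",
        "--min_replicas", str(min_replicas),
        "--quorum_tick_ms", str(quorum_tick_ms),
        "--join_timeout_ms", str(join_timeout_ms),
    ]


def start_lighthouse(**kwargs):
    """Start the lighthouse as a child process with RUST_BACKTRACE=1 set."""
    env = dict(os.environ, RUST_BACKTRACE="1")
    return subprocess.Popen(lighthouse_command(**kwargs), env=env)
```

Calling `start_lighthouse()` returns a `subprocess.Popen` handle; call `terminate()` on it when training is finished. It assumes `torchft_lighthouse` is on your `PATH`.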

Example Training Loop (DDP)

See train_ddp.py for the full example.

Invoke with:

$ TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py

train.py:

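import torch
from torch import nn, optim

from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo

# The Gloo process group is used here, so keep tensors on the CPU.
device = torch.device("cpu")

manager = Manager(
    pg=ProcessGroupGloo(),
    # Both arguments are elided here; see the full example for real values.
    load_state_dict=...,
    state_dict=...,
)

# Wrap the model and optimizer with their torchft counterparts.
m = nn.Linear(2, 3).to(device)
m = DistributedDataParallel(manager, m)
optimizer = Optimizer(manager, optim.AdamW(m.parameters()))

for i in range(1000):
    # Random input batch; a real loop would read from a data loader.
    batch = torch.rand(2, 2, device=device)

    optimizer.zero_grad()

    out = m(batch)
    loss = out.sum()

    loss.backward()

    optimizer.step()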
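
The `load_state_dict` and `state_dict` arguments are elided in the snippet above. Going by their names, they are presumably a pair of callables: one that returns the current training state and one that restores it. That reading is an assumption, not something this README confirms. The sketch below shows what such a pair could look like, using a dependency-free stand-in class (`TinyModule`, hypothetical) in place of a real model and optimizer; `make_state_callbacks` is also a made-up helper.

```python
import copy


class TinyModule:
    """Stand-in for any object exposing state_dict() / load_state_dict()."""

    def __init__(self, value):
        self.value = value

    def state_dict(self):
        return {"value": self.value}

    def load_state_dict(self, state):
        self.value = state["value"]


def make_state_callbacks(model, optimizer):
    """Return (state_dict, load_state_dict) callables over model and optimizer."""

    def state_dict():
        # Copy so later training steps do not mutate the snapshot.
        return {
            "model": copy.deepcopy(model.state_dict()),
            "optim": copy.deepcopy(optimizer.state_dict()),
        }

    def load_state_dict(state):
        model.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optim"])

    return state_dict, load_state_dict


# A "healthy" replica exports its state; a "recovering" one loads it.
healthy_state, _ = make_state_callbacks(TinyModule(1.5), TinyModule(0.1))
recovering_model, recovering_optim = TinyModule(0.0), TinyModule(0.0)
_, restore = make_state_callbacks(recovering_model, recovering_optim)
restore(healthy_state())
```

With a real `nn.Module` and `torch.optim` optimizer the same shape should work, since both expose `state_dict()` and `load_state_dict()`; check the full example for what the manager actually expects.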

Example Parameter Server

torchft has a fault-tolerant parameter server implementation built on its reconfigurable ProcessGroups. This does not require or use a Lighthouse server.

See parameter_server_test.py for an example.

Running Tests / Lint

$ cargo fmt
$ cargo test

License

BSD 3-Clause -- see LICENSE for more details.

Copyright (c) Tristan Rice 2024