Distributed Data Parallel Training of Deep Neural Networks

Primary language: Julia. License: MIT.

FluxMPI.jl

Caution

This package should be considered deprecated and won't receive any updates. Distributed training will become a native feature of Lux, so it makes little sense for me to maintain an additional package that does the same thing. Track LuxDL/Lux.jl#494 for further updates.

Distributed Data Parallel Training of Neural Networks

Installation

Stable release:

] add FluxMPI

Latest development version:

] add FluxMPI#main

Quick Start

using CUDA, FluxMPI, Lux, Optimisers, Random, Zygote

# Initialize FluxMPI; this must happen before local_rank() or total_workers() is called
FluxMPI.Init()
CUDA.allowscalar(false)

model = Chain(Dense(1 => 256, tanh), Dense(256 => 512, tanh), Dense(512 => 256, tanh),
              Dense(256 => 1))
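rng = Random.default_rng()
# Seed each rank differently; the parameters are made identical below
Random.seed!(rng, local_rank())
ps, st = Lux.setup(rng, model) .|> gpu

# Broadcast the parameters and states of rank 0 to all other ranks
ps = FluxMPI.synchronize!(ps; root_rank = 0)
st = FluxMPI.synchronize!(st; root_rank = 0)

# Each rank draws its own data (the rng was seeded per rank); Float32 matches
# the element type of the parameters and avoids mixed-precision computation
x = rand(rng, Float32, 1, 16) |> gpu
y = x .^ 2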

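# DistributedOptimizer sums (does not average) the gradients across ranks;
# divide the loss by total_workers() to average them instead
opt = DistributedOptimizer(Adam(0.001f0))
st_opt = Optimisers.setup(opt, ps)

loss(p) = sum(abs2, model(x, p, st)[1] .- y)

# Make the optimizer state identical on all ranks
st_opt = FluxMPI.synchronize!(st_opt; root_rank = 0)

# Warm-up call to compile the gradient and update code; the result is discarded
gs_ = gradient(loss, ps)[1]
Optimisers.update(st_opt, ps, gs_)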

t1 = time()

for epoch in 1:100
  global ps, st_opt
  l, back = Zygote.pullback(loss, ps)
  FluxMPI.fluxmpi_println("Epoch $epoch: Loss $l")
  gs = back(one(l))[1]
  st_opt, ps = Optimisers.update(st_opt, ps, gs)
end

FluxMPI.fluxmpi_println(time() - t1)

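Run the code using mpiexecjl -n 3 julia --project=. <filename>.jl. (mpiexecjl is the launcher wrapper shipped with MPI.jl; it can be installed with MPI.install_mpiexecjl().)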

Examples

Style Guide

We follow the Lux Style Guide. All contributions must adhere to this style guide.

Changelog

v0.7

  • Dropped support for MPI v0.19.
  • FLUXMPI_DISABLE_CUDAMPI_SUPPORT is no longer used. Instead, use FluxMPI.disable_cudampi_support() to set up a LocalPreferences.toml file.
  • clean_(print/println) functions are now fluxmpi_(print/println).

v0.6

  • Dropped support for LearnBase (used by DataLoaders.jl). DistributedDataContainer is now compatible only with MLUtils.jl.
  • DistributedOptimiser name changed to DistributedOptimizer.

v0.5

v0.5.3

  • Introduces a new API for gradient synchronization
    • The optimiser no longer needs to be wrapped in DistributedOptimiser
    • Instead, call allreduce_gradients(gs::NamedTuple) on the gradients before updating the parameters
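The leaf-wise reduction behind allreduce_gradients can be illustrated in a single process, without MPI or FluxMPI. This sketch assumes the reduction is a sum (consistent with DistributedOptimiser since v0.5.0) and that every rank produces gradients with the same nested NamedTuple structure, as Zygote does for Lux parameters. treesum and per_rank are hypothetical names; each element of per_rank stands in for the gradients computed on one rank.

```julia
# Hypothetical single-process illustration of a leaf-wise sum over nested
# NamedTuple gradients. Each element of `per_rank` stands in for the
# gradients computed on one rank; no MPI communication happens here.

# Leaves: element-wise sum of gradient arrays
treesum(a::AbstractArray, b::AbstractArray) = a .+ b
# Nodes: recurse into NamedTuples that share the same field names
treesum(a::NamedTuple{K}, b::NamedTuple{K}) where {K} =
    NamedTuple{K}(map(treesum, values(a), values(b)))

per_rank = [
    (layer_1 = (weight = [1.0, 2.0], bias = [0.5]), layer_2 = (weight = [3.0], bias = [1.0])),
    (layer_1 = (weight = [0.0, 1.0], bias = [0.5]), layer_2 = (weight = [1.0], bias = [0.0])),
    (layer_1 = (weight = [2.0, 0.0], bias = [1.0]), layer_2 = (weight = [-1.0], bias = [2.0])),
]

# After a real allreduce, every rank would hold this same reduced tree
total = reduce(treesum, per_rank)
println(total)
```

Recursing on the NamedTuple structure keeps the reduced gradients in the same shape as the parameters, so the result can be passed directly to Optimisers.update.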

v0.5.1

  • Internal MPIExtensions functions renamed
    • Allreduce! --> allreduce!
    • Bcast! --> bcast!
    • Reduce! --> reduce!
  • Resolved a bug with CUDA-unaware MPI (LuxDL/Lux.jl#18)
  • CUDA-aware MPI support in FluxMPI can be disabled by setting FLUXMPI_DISABLE_CUDAMPI_SUPPORT=true
  • Temporarily re-added the MLDataUtils and LearnBase dependencies so that DataLoaders.jl still works; they will be dropped in a future release

v0.5.0

  • DistributedOptimiser no longer averages the gradients; instead, they are summed across the processes. To average them, divide the loss by total_workers()
  • rrules and frules are defined for local_rank() and total_workers(), so both can be used safely inside loss functions
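Because gradients are summed rather than averaged, the averaging has to be folded into the loss. The single-process Julia sketch below checks that arithmetic without MPI; simulated_allreduce_sum, local_grad and n_workers are hypothetical names standing in for an MPI sum-allreduce, a per-worker gradient, and total_workers() respectively. Dividing each worker's loss by the number of workers divides its gradient by the same factor, so the summed result equals the average of the unscaled gradients.

```julia
# Single-process simulation of gradient reduction across workers.
# `simulated_allreduce_sum` is a hypothetical stand-in for an MPI sum-allreduce,
# and `n_workers` plays the role of total_workers().

# Gradient of the per-worker loss L(w) = scale * sum((w .* x .- y) .^ 2) w.r.t. w.
# Scaling the loss by `scale` scales its gradient by the same factor.
local_grad(w, x, y; scale = 1.0) = scale * sum(2 .* (w .* x .- y) .* x)

simulated_allreduce_sum(grads) = sum(grads)

n_workers = 3
w = 0.5
# One (x, y) data shard per worker
shards = [([1.0, 2.0], [2.0, 4.0]),
          ([3.0, 4.0], [6.0, 8.0]),
          ([5.0, 6.0], [10.0, 12.0])]

# Unscaled losses: the reduced gradient is the SUM of the per-worker gradients
summed = simulated_allreduce_sum([local_grad(w, x, y) for (x, y) in shards])

# Losses divided by n_workers: the reduced gradient is their AVERAGE
averaged = simulated_allreduce_sum([local_grad(w, x, y; scale = 1 / n_workers)
                                    for (x, y) in shards])

println("summed = $summed, averaged = $averaged")
```

With FluxMPI this would correspond to a loss such as sum(abs2, model(x, p, st)[1] .- y) / total_workers(), which the rrules mentioned above make safe to differentiate.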

v0.4

  • fluxmpi_print and fluxmpi_println print the current time even if FluxMPI has not been initialized.
  • Calling local_rank or total_workers before FluxMPI.Init no longer causes a segfault; an error is thrown instead.
  • MLDataUtils and LearnBase dependencies have been dropped (See #17)
  • Zygote and Flux dependencies have been removed
    • FluxMPI.synchronize! no longer has a method for Zygote.Params; users should instead broadcast the function over the Zygote.Params manually

v0.3

  • broadcast_parameters has been renamed to FluxMPI.synchronize!, since it now synchronizes much more than trainable parameters.
  • DistributedOptimiser is no longer tied to Flux; it works with any training setup that is compatible with Optimisers.jl.