google-deepmind/graphcast

GenCast Ensemble GPU Parallelization

jacob-t-radford opened this issue · 3 comments

I have a few different cloud GPU partitions I'd like to use to run a GenCast ensemble. I was wondering whether running with num_ensemble_members set to 4 on one GPU is equivalent to running with num_ensemble_members set to 1 on four different GPUs (and then combining the outputs at the end). If not, do you have any suggestions on how I could run a small GenCast ensemble with a controller and multiple GPU partitions? Sorry if this question is a bit unclear; I'm just trying to figure out whether we can feasibly run a GenCast ensemble in real time using our existing resources.

Sounds interesting. Running a GenCast ensemble across multiple GPUs by splitting num_ensemble_members should work in principle, but whether the combined output is equivalent depends on keeping the runs consistent. For efficient parallelization, you could use a controller to assign ensemble members to GPU partitions dynamically. Clarifying how the ensemble outputs are synchronized and combined would also help. Would be great to hear thoughts on potential bottlenecks!
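
For combining at the end, one option (assuming each partition's predictions end up as an xarray Dataset with a sample / ensemble-member dimension; the file paths and dimension name below are just placeholders, not from the repo) might be to concatenate along that dimension:

import xarray as xr

# Hypothetical per-partition prediction files, each holding that partition's
# ensemble members along a "sample" (ensemble member) dimension.
paths = [f"predictions_partition_{p}.nc" for p in range(4)]
per_partition = [xr.open_dataset(path) for path in paths]

# Concatenate along the ensemble-member dimension to recover the full ensemble.
ensemble = xr.concat(per_partition, dim="sample")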

This should definitely be possible with some care around rngs!

As per the demo notebook:

import jax
import numpy as np

num_ensemble_members = 4  # total ensemble size

rng = jax.random.PRNGKey(0)
# We fold in the ensemble member index; this way the first N members should
# always match across different runs which use the same inputs, regardless of
# total ensemble size.
rngs = np.stack(
    [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)

I.e. using this code as it stands, 4 runs x 1 ensemble member != 1 run x 4 ensemble members, because in the former each run folds in member index 0, so all four members end up with the same rng.

Otherwise, if you ensure each GPU partition gets the correct rng(s), then this is indeed equivalent.
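
For example, a minimal sketch of how each partition could build its own rngs by folding in the global member indices (the helper name and the partition/member bookkeeping below are illustrative, not something from the repo):

import jax
import numpy as np

def rngs_for_partition(global_member_indices, seed=0):
  # Use the same base key on every partition, then fold in each member's
  # *global* index, so member i gets the same rng no matter which partition
  # it runs on (and therefore matches the single-GPU, 4-member run).
  rng = jax.random.PRNGKey(seed)
  return np.stack(
      [jax.random.fold_in(rng, i) for i in global_member_indices], axis=0)

# One member per partition, e.g. partition 2 of 4 handles global member 2:
rngs = rngs_for_partition([2])

Each partition then uses its rngs just as in the demo notebook.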

Hope this helps!

Andrew

Awesome, thank you for the quick response and clarification!