google-deepmind/graphcast

Floating Point Error in 1-month forecast

tvonich opened this issue · 9 comments

The plot below shows the same set of inputs (06z / 12z 1 June 2021) run 10 times on the same A100 for a 28.5-day forecast. The curves are the loss at each individual timestep (not averaged over the window). This is for GraphCast Small.

Would it be fair to assume this is all floating-point error caused by non-determinism in GPU execution? As far as I know, I am using the same JAX seed everywhere. My goal is to make inference as reproducible as possible out to the longest lead time possible, but speed also matters.

UPDATE: I can achieve deterministic runs on CPU, but this is obviously slow, and speed matters for my purposes. Does anyone have suggestions on how to achieve deterministic runs on an A100?

I've tried:

  • Disabling JIT
  • jax_enable_x64 in case the normalization was contributing
  • CPU only (successful but slow)
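For anyone trying to reproduce the settings listed above, here is a minimal sketch. It assumes a JAX version that reads these options from the environment variables JAX_PLATFORMS, JAX_ENABLE_X64 and JAX_DISABLE_JIT when it is imported. The REPRO_SETTINGS dict and the apply_settings helper are hypothetical names used only for illustration. Calling jax.config.update(...) after import is the in-code alternative.

```python
import os

# JAX reads these variables when it is imported / initializes its backends,
# so they must be set before `import jax`.
REPRO_SETTINGS = {
    "JAX_PLATFORMS": "cpu",   # CPU only: deterministic in this thread's tests, but slow
    "JAX_ENABLE_X64": "1",    # make float64 the default floating-point dtype
    "JAX_DISABLE_JIT": "1",   # execute op by op instead of jit-compiling
}


def apply_settings(settings):
    """Copy `settings` into os.environ and return the values now in effect."""
    for key, value in settings.items():
        os.environ[key] = value
    return {key: os.environ[key] for key in settings}


applied = apply_settings(REPRO_SETTINGS)
# import jax  # only after the environment has been configured
```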

[Plot: loss at each timestep for the 10 repeated runs]

FYI, I think setting XLA_FLAGS=--xla_gpu_deterministic_reductions will help reduce the divergence. Similarly, this flag might help too: XLA_FLAGS='--xla_gpu_deterministic_ops=true'

IIRC, these flags provide more reproducibility, but they will come with a performance penalty.
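A minimal sketch of setting the flag from inside a Python script rather than in the shell. It assumes the flag is set before JAX initializes its GPU backend, and the safe choice is before `import jax`. add_xla_flag is a hypothetical helper. The shell equivalent is `XLA_FLAGS=--xla_gpu_deterministic_ops=true python your_script.py`, where your_script.py is a placeholder. XLA flag names can change between versions, so it is worth checking that your installed version still recognizes them.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Append `flag` to XLA_FLAGS unless it is already present; return the new value."""
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


# Must run before JAX initializes its GPU backend (safest: before `import jax`).
flags = add_xla_flag("--xla_gpu_deterministic_ops=true")
```

Keeping any flags that are already in XLA_FLAGS, rather than overwriting the variable, avoids silently dropping settings made elsewhere (e.g. in a job script).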

Thank you! I'll give this a try.

Setting XLA_FLAGS='--xla_gpu_deterministic_ops=true' worked. The only problem is that each forecast went from taking about 15 seconds to 2.5 hours, so CPU actually ends up being faster! If you find any other tricks, please share. Thank you for responding.

Thanks for sharing all this. If you really need deterministic behavior, one option is to switch to TPU, which is deterministic by default.

Note, however, that even with xla_gpu_deterministic_ops=true or on TPU, the output is only reproducible for a fixed compiled program. If you, e.g., refactor the code so that the computation graph changes slightly and the XLA compilation of the program changes with it, the results will differ, because floating-point operations performed in a slightly different order always produce small numerical differences.
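To make the ordering point concrete, here is a tiny illustration in plain Python (float64, on CPU). When the same three numbers are summed with a different grouping, the results differ in the last bit. GPU kernels that accumulate partial sums in a run-dependent order can introduce differences of this size. Some parallel reductions are an example. An autoregressive rollout then feeds those differences back in at every step.

```python
# Floating-point addition is not associative: the same three numbers summed
# with a different grouping give different float64 results.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

print(left_to_right)                   # 0.6000000000000001
print(right_to_left)                   # 0.6
print(left_to_right == right_to_left)  # False
```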

However, I believe what you see here is evidence that the learned neural network behaves somewhat chaotically (i.e., it is very sensitive to small differences in its inputs). This is not surprising, because weather itself is chaotic: small changes in the initial state can lead to large changes in the final state, and this effect grows with lead time. And of course this can give a single trajectory a higher or lower error at long horizons when the error is computed against a single ground truth.
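The same kind of sensitivity shows up in the classic Lorenz-63 system (sigma=10, rho=28, beta=8/3). It is used here purely as a stand-in for a chaotic model and is not GraphCast. In this minimal sketch, two trajectories that start 1e-9 apart are integrated with a fixed-step RK4 scheme. Their distance grows until it is comparable to the size of the attractor itself.

```python
import math


def lorenz_rhs(state, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Time derivative of the Lorenz-63 system at `state` = (x, y, z)."""
    x, y, z = state
    return (sigma * (y - x), x * (rho - z) - y, x * y - beta * z)


def rk4_step(state, dt):
    """Advance `state` by one fixed step `dt` with classic 4th-order Runge-Kutta."""
    def shifted(s, k, h):
        return tuple(si + h * ki for si, ki in zip(s, k))

    k1 = lorenz_rhs(state)
    k2 = lorenz_rhs(shifted(state, k1, dt / 2))
    k3 = lorenz_rhs(shifted(state, k2, dt / 2))
    k4 = lorenz_rhs(shifted(state, k3, dt))
    return tuple(
        s + dt / 6 * (a + 2 * b + 2 * c + d)
        for s, a, b, c, d in zip(state, k1, k2, k3, k4)
    )


def separation_history(perturbation=1e-9, dt=0.01, n_steps=5000):
    """Distance between an unperturbed and a perturbed trajectory at every step."""
    a = (1.0, 1.0, 1.0)
    b = (1.0 + perturbation, 1.0, 1.0)
    history = [math.dist(a, b)]
    for _ in range(n_steps):
        a = rk4_step(a, dt)
        b = rk4_step(b, dt)
        history.append(math.dist(a, b))
    return history


history = separation_history()
for step in (0, 1000, 2000, 3000, 4000, 5000):
    print(f"t = {step * 0.01:5.1f}  separation = {history[step]:.3e}")
```

Both runs here are exactly reproducible on CPU. The point is only that a difference far below any meaningful physical scale eventually becomes as large as the signal itself.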

On the other hand, if you were to repeat your plot averaging over many initializations rather than using just one, you should see that the error as a function of lead time becomes much more stable.
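A minimal sketch of that averaging. It assumes you already have one per-lead-time error curve per initialization, stacked into an array of shape (n_init_times, n_lead_times). lead_time_statistics is a hypothetical helper name. For independent initializations, the standard error of the mean shrinks roughly like 1/sqrt(n_init_times), which is why the averaged curve looks steadier.

```python
import numpy as np


def lead_time_statistics(errors):
    """Mean error and standard error of the mean for each lead time.

    errors: array of shape (n_init_times, n_lead_times), with n_init_times >= 2.
    """
    errors = np.asarray(errors, dtype=float)
    n_init = errors.shape[0]
    mean = errors.mean(axis=0)                           # average over init times
    sem = errors.std(axis=0, ddof=1) / np.sqrt(n_init)   # spread of that average
    return mean, sem


# Two toy initializations, three lead times each.
mean, sem = lead_time_statistics([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
```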

So in a way, rather than trying to hide the instability/chaos by making the computation deterministic, you may want to actually "own" the non-determinism: average across multiple init times, or even treat the problem as probabilistic prediction and look at, e.g., ensemble-mean RMSE. If you go down this route, you may be interested in the new GenCast paper!
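Ensemble-mean RMSE averages the members first and then scores that mean against the truth. A minimal sketch, assuming member forecasts of shape (n_members, n_lead_times, ...) and a truth array of shape (n_lead_times, ...). ensemble_mean_rmse is a hypothetical name. No latitude weighting is applied here, although real evaluations usually include it.

```python
import numpy as np


def ensemble_mean_rmse(members, truth):
    """RMSE of the ensemble mean, one value per lead time.

    members: (n_members, n_lead_times, ...) forecasts
    truth:   (n_lead_times, ...) verifying data
    """
    members = np.asarray(members, dtype=float)
    truth = np.asarray(truth, dtype=float)
    ens_mean = members.mean(axis=0)            # average the members first
    sq_err = (ens_mean - truth) ** 2
    # Flatten every non-lead-time axis and average over it.
    per_lead = sq_err.reshape(sq_err.shape[0], -1).mean(axis=1)
    return np.sqrt(per_lead)


members = [[[1.0, 3.0]], [[3.0, 5.0]]]  # 2 members, 1 lead time, 2 grid points
score = ensemble_mean_rmse(members, [[2.0, 4.0]])
```

In this toy case each member on its own has an RMSE of 1.0 against the truth, but their errors cancel, so the ensemble mean scores 0.0.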

Thanks for the thorough response! The context is initial-condition sensitivity analysis at long timescales. We have been able to do this out to ~15-day forecasts, but it fails beyond that. As the loss variability grows at longer timescales, we think the gradient-based sensitivity breaks down because there are multiple outputs for the same input. We are preparing to submit a paper to GRL on this and have been trying to see if we can get past 15 days.

The ensemble-mean GPU approach was going to be my next step, but the TPU framework could be a real unlock since it is deterministic! Does Google offer any academic TPU usage grants on Colab, or is purchasing hours the only way to get TPU access?

Hi, I would like to know how you calculate the loss for a single time step. I would also like to make a per-timestep loss visualization like yours. Is there any relevant code I can refer to? Thank you very much.
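This is not the original poster's code, but one common way to get a per-timestep error curve is to compute a latitude-weighted RMSE separately for each lead time. Here is a minimal sketch for a single variable. It assumes NumPy arrays of shape (time, lat, lon) with latitudes in degrees. The cos(latitude) weighting compensates for grid cells shrinking toward the poles on an equiangular grid. As far as I know, GraphCast's training loss additionally weights variables and pressure levels, and that part is not reproduced here. per_timestep_rmse is a hypothetical name.

```python
import numpy as np


def per_timestep_rmse(pred, target, lat_deg):
    """Latitude-weighted RMSE for each time step.

    pred, target: arrays of shape (time, lat, lon) for one variable/level
    lat_deg:      latitudes in degrees, shape (lat,)
    """
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    weights = np.cos(np.deg2rad(np.asarray(lat_deg, dtype=float)))
    weights = weights / weights.mean()                # weights average to 1
    sq_err = (pred - target) ** 2 * weights[None, :, None]
    return np.sqrt(sq_err.mean(axis=(1, 2)))          # reduce lat, lon; keep time


# Usage with random stand-in data (not GraphCast output):
rng = np.random.default_rng(0)
lat = np.linspace(-90, 90, 37)
target = rng.normal(size=(4, 37, 72))
pred = target + 0.1 * np.arange(1, 5)[:, None, None]  # error grows with each step
rmse = per_timestep_rmse(pred, target, lat)
```

Plotting `rmse` against lead time, e.g. `np.arange(1, len(rmse) + 1) * 6` hours for 6-hourly steps, gives one curve per run.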

OK, thank you very much, @tvonich.

Hi @tvonich, I am struggling to implement autoregressive prediction with GraphCast. How were you able to do it?
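This is not necessarily how @tvonich did it, but the core of an autoregressive rollout is a loop that feeds each prediction back in as input. GraphCast takes two input states (t - 6h and t) and predicts t + 6h, so a 28.5-day forecast is 28.5 days × 4 steps/day = 114 steps. Here is a minimal sketch with a hypothetical step_fn(prev, curr) standing in for the model. The real model also needs forcings for each target time, which this sketch omits. As far as I recall, the repository's demo notebook wraps this kind of loop in rollout.chunked_prediction, so it is worth checking that before writing your own.

```python
from typing import Callable, List

import numpy as np

# Hypothetical signature: (state at t - 6h, state at t) -> state at t + 6h.
StepFn = Callable[[np.ndarray, np.ndarray], np.ndarray]


def rollout(step_fn: StepFn, prev: np.ndarray, curr: np.ndarray,
            n_steps: int) -> List[np.ndarray]:
    """Apply `step_fn` autoregressively, feeding each prediction back as input."""
    outputs = []
    for _ in range(n_steps):
        nxt = step_fn(prev, curr)
        outputs.append(nxt)
        prev, curr = curr, nxt   # slide the two-state input window forward
    return outputs


# Toy stand-in model: linear extrapolation from the last two states.
preds = rollout(lambda p, c: 2 * c - p, np.array(0.0), np.array(1.0), n_steps=3)
```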