GenCast - num_steps_per_chunk > 1 breaks rollout.chunked_prediction_generator_multiple_runs
Running the example notebooks with num_steps_per_chunk = 2 and the 1p0 model results in the following error:
ValueError: 'grid2mesh_gnn/~_networks_builder/encoder_nodes_grid_nodes_mlp/~/linear_0/w' with retrieved shape (267, 512) does not match shape=[355, 512] dtype=dtype('float32')
Is the step chunking working?
Thanks!
Hi! Could you clarify what it is that you are trying to achieve by setting that parameter to 2, so we can advise appropriately?
You should be able to get it to run by adding the autoregressive.py wrapper in the construct_wrapped_gencast() function, but it may run out of device memory if you do that.
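For concreteness, here is a minimal sketch of that change, assuming the predictor construction from the public demo notebook (the config and normalization names below come from that notebook; the exact placement of the wrapper is an assumption, not something confirmed in this thread):

```python
from graphcast import autoregressive, gencast, nan_cleaning, normalization

def construct_wrapped_gencast():
  # Existing construction from the demo notebook (unchanged).
  predictor = gencast.GenCast(
      sampler_config=sampler_config,
      task_config=task_config,
      denoiser_architecture_config=denoiser_architecture_config,
      noise_config=noise_config,
      noise_encoder_config=noise_encoder_config,
  )
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level,
  )
  predictor = nan_cleaning.NaNCleaner(
      predictor=predictor,
      reintroduce_nans=True,
      fill_value=min_by_level,
      var_to_clean="sea_surface_temperature",
  )
  # Added wrapper: rolls the one-step predictor forward over multiple
  # steps within a chunk, which is what num_steps_per_chunk > 1 needs.
  # Note this may increase device memory usage, as mentioned above.
  predictor = autoregressive.Predictor(predictor)
  return predictor
```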
I am trying to run 50-member ensembles at 0.25 degree resolution, and encountered the following:
- I am currently running on a TPU v5p-8, as advised in the docs. It takes ~30 minutes to produce an 8-member, 30-step forecast, which is consistent with the docs. But when I run another 8-member, 30-step forecast, it still takes ~30 minutes, not ~8 minutes (as reported in the paper and the docs).
- When I try running a 30-step, 0.25 degree forecast with 50 members instead of 8, it takes far longer, and after 1h+ the process seems to be killed for some reason.
To resolve these issues and speed up inference, my understanding was that this parameter could help. I would really appreciate any other suggestions!
Thanks for explaining.
I don't think that argument will help you with those.
With respect to 1., could you confirm that you are working from a version past this commit?
With respect to 2., I suspect what is happening here is that you are running out of host memory. When you generate a large number of ensemble members, you probably want to write the chunks to disk as they are generated rather than appending them to a list. (Of course, there is a time cost associated with writing to disk, so you may want to set it up to write asynchronously, or to write only a subset of the variables.)
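As a rough illustration of that pattern, here is a sketch that mirrors the generator loop from the demo notebook but writes each chunk out instead of accumulating a list (the output directory, NetCDF format, and variable subsetting are illustrative choices; run_forward_pmap, rngs, and the eval_* datasets are assumed to be set up as in the notebook):

```python
import os

import jax
import numpy as np
from graphcast import rollout

output_dir = "predictions"  # hypothetical output location
os.makedirs(output_dir, exist_ok=True)

for i, chunk in enumerate(
    rollout.chunked_prediction_generator_multiple_runs(
        predictor_fn=run_forward_pmap,
        rngs=rngs,
        inputs=eval_inputs,
        targets_template=eval_targets * np.nan,
        forcings=eval_forcings,
        num_steps_per_chunk=1,
        num_samples=num_ensemble_members,
        pmap_devices=jax.local_devices(),
    )
):
  # Write each chunk to disk immediately so host memory stays bounded,
  # instead of appending it to a list and combining at the end.
  # Optionally subset variables first to cut I/O cost, e.g.
  # chunk = chunk[["2m_temperature"]].
  chunk.to_netcdf(os.path.join(output_dir, f"chunk_{i:04d}.nc"))
```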
Could you confirm what number you get when you print len(jax.local_devices())?
Thanks!
(1) I think that should be the case; I was running this as in the notebooks:
%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip
(2) I see, thanks for the heads-up, I'll try to manage the memory more efficiently.
(3) len(jax.local_devices()) outputs 4. Shouldn't that be 8 for a v5p-8?
Thanks again!
Hey!
Sure, but which version of the notebook are you using? Could you confirm it was the one past this commit? Note the change in that commit to separate out the line that pmaps the run_forward method in the notebooks.
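For reference, a sketch of roughly what the post-commit structure looks like, assuming the run_forward, params, and state objects from the demo notebook (a paraphrase, not a quote of the commit):

```python
import jax
from graphcast import xarray_jax

# Jit the single-sample forward pass.
run_forward_jitted = jax.jit(
    lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]
)
# Pmap on a separate line, mapping over the "sample" dimension, so the
# pmapped function is built once and reused across inference calls
# rather than being recreated each time.
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")
```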
Regarding the number of devices, that's indeed bizarre. Did you mention you were following these instructions? If so, can you confirm how you requested the TPU VM? But indeed, running 8 samples when you have 4 devices is going to double the inference time, because they will be produced sequentially in two batches of 4.
In the meantime, you should be able to reproduce the documented inference speed by generating just 4 samples (matching the number of devices to maximise parallelism).
Andrew
Hey Andrew, you were right, I had the pmap as in the previous version. Fixed it and will test again, thanks a lot!
Also, for running the 1deg version with different ERA5 conditions, do you use a known regridder or a custom one that goes from 0.25 degrees to 1 degree? If so, would it be possible for you to share the script that generated the 1deg ERA5 datasets in dm_graphcast/gencast/dataset/? 🙏
The 1 deg data is simply the 0.25 deg data subsampled, taking 1 of every 4 points along each of the spatial axes. We do it like this so the distribution of the data does not change and we can more easily compare models across resolutions. We follow this approach because we usually use the 1 deg models just as a baseline for the 0.25 deg models, but for other use cases of 1 deg models it may be better to train on data subsampled in a different way.
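For anyone wanting to reproduce that, a minimal sketch with xarray, assuming the usual 'lat'/'lon' coordinate names from the released datasets (the file paths are illustrative, and this is not the official preprocessing script):

```python
import xarray as xr

# Hypothetical input: a 0.25 deg ERA5 dataset (e.g. a 721 x 1440 grid).
ds_025 = xr.open_dataset("era5_0p25deg.nc")

# Keep 1 of every 4 points along each spatial axis -> 1 deg resolution.
# No interpolation is involved, so the pointwise distribution of the
# data is unchanged relative to the 0.25 deg version.
ds_1deg = ds_025.isel(lat=slice(None, None, 4), lon=slice(None, None, 4))

ds_1deg.to_netcdf("era5_1p0deg.nc")  # hypothetical output path
```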