Got "scalar prefetch Triton" error with gencast_demo inference on a local GPU device
Closed this issue · 6 comments
I tried to adapt "gencast_demo_cloud_vm.ipynb" to run on my local H100, but got an error from rollout.chunked_prediction_generator_multiple_runs during the inference step.
Is this because inference and training are only supported on TPU? If so, is there any plan for a GPU or CPU version? Or is it just that Triton/Pallas is not yet compatible with this jax/jaxlib version?
I noticed that GPU memory is already occupied while the inference runs.
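The memory observation is probably general JAX GPU behavior, assuming default settings, and not something specific to GenCast. JAX preallocates 75% of total GPU memory when the first JAX operation runs, so tools like nvidia-smi show the memory as in use before any heavy computation happens. Here is a minimal sketch of turning that off. It must run before jax is imported.

```python
import os

# These XLA client flags are read when JAX initialises its GPU backend,
# so they must be set before the first `import jax` in the process.
# Allocate GPU memory on demand instead of preallocating 75% up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternative: keep preallocation but cap it to a fraction of GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# import jax  # import only after the environment is configured
```

Disabling preallocation can increase memory fragmentation, so it is mostly useful for seeing real memory use. It does not affect the Triton error.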
jax and triton packages (from conda list):
jax 0.4.35 pypi_0 pypi
jax-cuda12-pjrt 0.4.35 pypi_0 pypi
jax-cuda12-plugin 0.4.35 pypi_0 pypi
jax-triton 0.2.0 pypi_0 pypi
jaxlib 0.4.34 pypi_0 pypi
triton 3.1.0 pypi_0 pypi
Error info
JaxStackTraceBeforeTransformation: NotImplementedError: scalar prefetch not implemented in the Triton backend
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
NotImplementedError Traceback (most recent call last)
Cell In[12], line 16
12 rngs = np.stack(
13 [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)
15 chunks = []
---> 16 for chunk in rollout.chunked_prediction_generator_multiple_runs(
17 xarray_jax.pmap(run_forward_jitted, dim="sample"),
18 rngs=rngs,
19 inputs=eval_inputs,
20 targets_template=eval_targets * np.nan,
21 forcings=eval_forcings,
22 num_steps_per_chunk = 1,
23 num_samples = num_ensemble_members,
24 pmap_devices=jax.local_devices()
25 ):
26 chunks.append(chunk)
27 predictions = xarray.combine_by_coords(chunks)
File /data3/john/gencast/src/graphcast-main/graphcast/rollout.py:163, in chunked_prediction_generator_multiple_runs(predictor_fn, rngs, inputs, targets_template, forcings, num_samples, pmap_devices, **chunked_prediction_kwargs)
160 else:
161 sample_forcings = None
--> 163 for prediction_chunk in chunked_prediction_generator(
164 predictor_fn=predictor_fn_pmap_named_args,
165 rng=sample_group_rngs,
166 inputs=sample_inputs,
167 targets_template=targets_template,
168 forcings=sample_forcings,
169 pmap_devices=pmap_devices,
170 **chunked_prediction_kwargs,
171 ):
172 prediction_chunk.coords["sample"] = np.arange(
173 sample_idx.start, sample_idx.stop, sample_idx.step
174 )
175 yield prediction_chunk
File /data3/john/gencast/src/graphcast-main/graphcast/rollout.py:345, in chunked_prediction_generator(predictor_fn, rng, inputs, targets_template, forcings, num_steps_per_chunk, verbose, pmap_devices)
343 # Make predictions for the chunk.
344 rng, this_rng = split_rng_fn(rng)
--> 345 predictions = predictor_fn(
346 rng=this_rng,
347 inputs=current_inputs,
348 targets_template=current_targets_template,
349 forcings=current_forcings)
351 # In the pmapped case, profiling reveals that the predictions, forcings and
352 # inputs are all copied onto a single TPU, causing OOM. To avoid this
353 # we pull all of the input/output data off the devices. This will have
354 # some performance impact, but maximise the memory efficiency.
355 # TODO(aelkadi): Pmap `_get_next_inputs` when running under pmap, and
356 # remove the device_get.
357 if pmap_devices is not None:
File /data3/john/gencast/src/graphcast-main/graphcast/rollout.py:121, in chunked_prediction_generator_multiple_runs.<locals>.predictor_fn_pmap_named_args(rng, inputs, targets_template, forcings)
114 def predictor_fn_pmap_named_args(rng, inputs, targets_template, forcings):
115 targets_template = _replicate_dataset(
116 targets_template,
117 replica_dim="sample",
118 replicate_to_device=True,
119 devices=pmap_devices,
120 )
--> 121 return predictor_fn(rng, inputs, targets_template, forcings)
File /data3/john/gencast/src/graphcast-main/graphcast/xarray_jax.py:602, in pmap.<locals>.result_fn(*args)
600 nonlocal input_treedef
601 flat_args, input_treedef = jax.tree_util.tree_flatten(args)
--> 602 flat_result = pmapped_fn(*flat_args)
603 assert output_treedef is not None
604 # After the pmap an extra leading axis will be present, we need to add an
605 # xarray dimension for this when unflattening the result:
[... skipping hidden 36 frame]
File ~/anaconda3/envs/gencast/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:1509, in _pallas_call_lowering(ctx, interpret, backend, *in_nodes, **params)
1503 raise _unsupported_lowering_error("gpu")
1505 return pallas_call_registration.pallas_call_lowering(
1506 ctx, *in_nodes, **params
1507 )
-> 1509 return mlir.lower_per_platform(ctx, "pallas_call",
1510 dict(cpu=cpu_lowering,
1511 tpu=tpu_lowering,
1512 cuda=gpu_lowering,
1513 rocm=gpu_lowering),
1514 None, # default_rule
1515 effects.no_effects,
1516 *in_nodes,
1517 interpret=interpret,
1518 **params)
[... skipping hidden 1 frame]
File ~/anaconda3/envs/gencast/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:1505, in _pallas_call_lowering.<locals>.gpu_lowering(ctx, *in_nodes, **params)
1502 except ImportError as e:
1503 raise _unsupported_lowering_error("gpu")
-> 1505 return pallas_call_registration.pallas_call_lowering(
1506 ctx, *in_nodes, **params
1507 )
File ~/anaconda3/envs/gencast/lib/python3.10/site-packages/jax/_src/pallas/triton/pallas_call_registration.py:60, in pallas_call_lowering(***failed resolving arguments***)
56 raise NotImplementedError(
57 "dynamic grid bounds not supported in the Triton backend"
58 )
59 if grid_mapping.num_index_operands:
---> 60 raise NotImplementedError(
61 "scalar prefetch not implemented in the Triton backend"
62 )
63 triton_params = compiler_params.get("triton", compiler_params)
64 num_warps = triton_params.pop("num_warps", 4)
NotImplementedError: scalar prefetch not implemented in the Triton backend
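The last frame shows the cause. The Triton (GPU) lowering of pallas_call raises as soon as the kernel has scalar-prefetch operands, meaning grid_mapping.num_index_operands is nonzero. The checkpoint's sparse transformer config selects a Splash attention kernel (a later reply calls it splash_spt_cfg). That kernel presumably relies on this TPU-only Pallas feature, so off TPU the attention type has to be swapped.

The sketch below picks config overrides based on the backend. It makes three assumptions:
- You branch on jax.default_backend().
- The helper name gpu_attention_overrides is hypothetical.
- The override values are the ones suggested later in this thread for GPU. Applying them on CPU too is an assumption.

```python
def gpu_attention_overrides(backend: str) -> dict:
    """Return SparseTransformerConfig field overrides for a JAX backend.

    `backend` is expected to be the string returned by jax.default_backend(),
    e.g. "tpu", "gpu" or "cpu". On TPU the checkpoint's own attention settings
    are kept. Elsewhere the Splash kernel is replaced by "triblockdiag_mha"
    with a full mask (reported in this thread to work on GPU; assumed to
    apply to CPU as well).
    """
    if backend == "tpu":
        return {}  # keep the loaded config unchanged
    return {"attention_type": "triblockdiag_mha", "mask_type": "full"}


# Usage, e.g. after loading `ckpt` as in the maintainer's reply:
#   import dataclasses, jax
#   overrides = gpu_attention_overrides(jax.default_backend())
#   tbd_spt_cfg = dataclasses.replace(splash_spt_cfg, **overrides)
```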
Hi, thanks for your message.
Could you confirm whether you made the config changes required for GPU, as described here?
Thanks!
Hi!
I'm also trying to run on an H100 and get exactly the same error. I did make the required config changes, but up to the point of this error, no change to SparseTransformerConfig seems to have any effect.
Thanks for your help!
Tobias
Hi!
Same issue here when running on an L40. Similarly, the required config changes do not seem to have any effect.
Hey!
Sorry, I probably could have been a bit clearer in the documentation here.
If you are running a random model (i.e. you are in the gencast_mini_demo notebook and have chosen Random source), I believe this should work as intended. Please let us know if it doesn't.
In the case that you are running a checkpoint model (in either demo notebook), could you confirm you are indeed changing the attention mechanism of the loaded config (as opposed to e.g. creating a different config that is then not actually passed to the model initialisation)?
To be more concrete, you'll want to do something like
import dataclasses
from graphcast import checkpoint, gencast

with ... as f:
  ckpt = checkpoint.load(f, gencast.CheckPoint)
params = ckpt.params
state = {}
task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config
# Replace attention mechanism.
splash_spt_cfg = ckpt.denoiser_architecture_config.sparse_transformer_config
tbd_spt_cfg = dataclasses.replace(splash_spt_cfg, attention_type="triblockdiag_mha", mask_type="full")
denoiser_architecture_config = dataclasses.replace(ckpt.denoiser_architecture_config, sparse_transformer_config=tbd_spt_cfg)
This is something that could probably be simplified by unfreezing the config class on our end.
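The self-contained sketch below shows why the replacement has to be done this way. It uses hypothetical stand-in classes. SparseTransformerConfig and DenoiserArchitectureConfig here are simplified mock-ups, not the real graphcast definitions (which have many more fields), and the field values are placeholders.

Because the classes are frozen dataclasses, fields cannot be assigned in place. dataclasses.replace returns a new object and leaves the loaded one untouched. The new config therefore only takes effect if it is the object actually passed on to model construction.

```python
import dataclasses


# Hypothetical stand-ins for the frozen config classes stored in a checkpoint.
@dataclasses.dataclass(frozen=True)
class SparseTransformerConfig:
    attention_type: str
    mask_type: str


@dataclasses.dataclass(frozen=True)
class DenoiserArchitectureConfig:
    sparse_transformer_config: SparseTransformerConfig
    other_field: int = 0  # stands in for the many other real fields


loaded = DenoiserArchitectureConfig(
    sparse_transformer_config=SparseTransformerConfig(
        attention_type="splash_mha", mask_type="lazy"  # placeholder values
    )
)

# In-place assignment fails on a frozen dataclass.
try:
    loaded.sparse_transformer_config.attention_type = "triblockdiag_mha"
except dataclasses.FrozenInstanceError:
    print("frozen: use dataclasses.replace instead")

# Nested replace: build a new inner config, then a new outer config around it.
tbd_spt_cfg = dataclasses.replace(
    loaded.sparse_transformer_config,
    attention_type="triblockdiag_mha",
    mask_type="full",
)
denoiser_architecture_config = dataclasses.replace(
    loaded, sparse_transformer_config=tbd_spt_cfg
)

# The loaded object is unchanged; only the new one carries the GPU settings.
print(loaded.sparse_transformer_config.attention_type)  # splash_mha
print(denoiser_architecture_config.sparse_transformer_config.attention_type)  # triblockdiag_mha
```

In the demo notebook, this presumably means passing denoiser_architecture_config, not ckpt.denoiser_architecture_config, into the model construction. Editing a config object that never reaches the model would fit the "no effect" reports earlier in the thread.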
Let me know if this still doesn't work.
Thanks!
Hey!
It works for me after replacing the config once the checkpoint is loaded :)
Thanks a lot!
Brilliant! Thanks for confirming.