
Got "scalar prefetch Triton" error with gencast_demo inference on a local GPU device


I tried to adapt the "gencast_demo_cloud_vm.ipynb" notebook to run on my local H100 device, but got an error from rollout.chunked_prediction_generator_multiple_runs in the inference step.

Is it because inference and training are only available on TPU? If so, is there any plan for a GPU or CPU version? Or is it just that Triton/Pallas is not yet compatible with jax/jaxlib?

I noticed that the GPU memory is already occupied when doing the inference.

jax and triton info in conda

jax 0.4.35 pypi_0 pypi
jax-cuda12-pjrt 0.4.35 pypi_0 pypi
jax-cuda12-plugin 0.4.35 pypi_0 pypi
jax-triton 0.2.0 pypi_0 pypi
jaxlib 0.4.34 pypi_0 pypi

jax-triton 0.2.0 pypi_0 pypi
triton 3.1.0 pypi_0 pypi
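
For anyone comparing environments, the same version information can be collected from inside the notebook kernel instead of from conda. This is only a minimal sketch using the standard library; the `installed_versions` helper and the `PACKAGES` tuple are illustrative names, not part of graphcast or JAX.

```python
from importlib import metadata

# Distribution names taken from the listing above; adjust as needed.
PACKAGES = (
    "jax",
    "jaxlib",
    "jax-cuda12-pjrt",
    "jax-cuda12-plugin",
    "jax-triton",
    "triton",
)


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:20s} {version or 'not installed'}")
```

Running this in the same kernel as the notebook avoids reporting versions from a different environment than the one that raised the error.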

Error info

JaxStackTraceBeforeTransformation: NotImplementedError: scalar prefetch not implemented in the Triton backend

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.


The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
Cell In[12], line 16
     12 rngs = np.stack(
     13     [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)
     15 chunks = []
---> 16 for chunk in rollout.chunked_prediction_generator_multiple_runs(
     17     xarray_jax.pmap(run_forward_jitted, dim="sample"),
     18     rngs=rngs,
     19     inputs=eval_inputs,
     20     targets_template=eval_targets * np.nan,
     21     forcings=eval_forcings,
     22     num_steps_per_chunk = 1,
     23     num_samples = num_ensemble_members,
     24     pmap_devices=jax.local_devices()
     25     ):
     26     chunks.append(chunk)
     27 predictions = xarray.combine_by_coords(chunks)

File /data3/john/gencast/src/graphcast-main/graphcast/, in chunked_prediction_generator_multiple_runs(predictor_fn, rngs, inputs, targets_template, forcings, num_samples, pmap_devices, **chunked_prediction_kwargs)
    160 else:
    161   sample_forcings = None
--> 163 for prediction_chunk in chunked_prediction_generator(
    164     predictor_fn=predictor_fn_pmap_named_args,
    165     rng=sample_group_rngs,
    166     inputs=sample_inputs,
    167     targets_template=targets_template,
    168     forcings=sample_forcings,
    169     pmap_devices=pmap_devices,
    170     **chunked_prediction_kwargs,
    171 ):
    172   prediction_chunk.coords["sample"] = np.arange(
    173       sample_idx.start, sample_idx.stop, sample_idx.step
    174   )
    175   yield prediction_chunk

File /data3/john/gencast/src/graphcast-main/graphcast/, in chunked_prediction_generator(predictor_fn, rng, inputs, targets_template, forcings, num_steps_per_chunk, verbose, pmap_devices)
    343 # Make predictions for the chunk.
    344 rng, this_rng = split_rng_fn(rng)
--> 345 predictions = predictor_fn(
    346     rng=this_rng,
    347     inputs=current_inputs,
    348     targets_template=current_targets_template,
    349     forcings=current_forcings)
    351 # In the pmapped case, profiling reveals that the predictions, forcings and
    352 # inputs are all copied onto a single TPU, causing OOM. To avoid this
    353 # we pull all of the input/output data off the devices. This will have
    354 # some performance impact, but maximise the memory efficiency.
    355 # TODO(aelkadi): Pmap `_get_next_inputs` when running under pmap, and
    356 # remove the device_get.
    357 if pmap_devices is not None:

File /data3/john/gencast/src/graphcast-main/graphcast/, in chunked_prediction_generator_multiple_runs.<locals>.predictor_fn_pmap_named_args(rng, inputs, targets_template, forcings)
    114 def predictor_fn_pmap_named_args(rng, inputs, targets_template, forcings):
    115   targets_template = _replicate_dataset(
    116       targets_template,
    117       replica_dim="sample",
    118       replicate_to_device=True,
    119       devices=pmap_devices,
    120   )
--> 121   return predictor_fn(rng, inputs, targets_template, forcings)

File /data3/john/gencast/src/graphcast-main/graphcast/, in pmap.<locals>.result_fn(*args)
    600 nonlocal input_treedef
    601 flat_args, input_treedef = jax.tree_util.tree_flatten(args)
--> 602 flat_result = pmapped_fn(*flat_args)
    603 assert output_treedef is not None
    604 # After the pmap an extra leading axis will be present, we need to add an
    605 # xarray dimension for this when unflattening the result:

    [... skipping hidden 36 frame]

File ~/anaconda3/envs/gencast/lib/python3.10/site-packages/jax/_src/pallas/, in _pallas_call_lowering(ctx, interpret, backend, *in_nodes, **params)
   1503     raise _unsupported_lowering_error("gpu")
   1505   return pallas_call_registration.pallas_call_lowering(
   1506       ctx, *in_nodes, **params
   1507   )
-> 1509 return mlir.lower_per_platform(ctx, "pallas_call",
   1510                                dict(cpu=cpu_lowering,
   1511                                     tpu=tpu_lowering,
   1512                                     cuda=gpu_lowering,
   1513                                     rocm=gpu_lowering),
   1514                                None,  # default_rule
   1515                                effects.no_effects,
   1516                                *in_nodes,
   1517                                interpret=interpret,
   1518                                **params)

    [... skipping hidden 1 frame]

File ~/anaconda3/envs/gencast/lib/python3.10/site-packages/jax/_src/pallas/, in _pallas_call_lowering.<locals>.gpu_lowering(ctx, *in_nodes, **params)
   1502 except ImportError as e:
   1503   raise _unsupported_lowering_error("gpu")
-> 1505 return pallas_call_registration.pallas_call_lowering(
   1506     ctx, *in_nodes, **params
   1507 )

File ~/anaconda3/envs/gencast/lib/python3.10/site-packages/jax/_src/pallas/triton/, in pallas_call_lowering(***failed resolving arguments***)
     56   raise NotImplementedError(
     57       "dynamic grid bounds not supported in the Triton backend"
     58   )
     59 if grid_mapping.num_index_operands:
---> 60   raise NotImplementedError(
     61       "scalar prefetch not implemented in the Triton backend"
     62   )
     63 triton_params = compiler_params.get("triton", compiler_params)
     64 num_warps = triton_params.pop("num_warps", 4)

NotImplementedError: scalar prefetch not implemented in the Triton backend

Hi, thanks for your message.

Could you confirm if you made the config changes required for GPU as described here?



I am also trying to run on an H100 and am getting the exact same error. I did make the required config changes, although up to the point of this error, changes to SparseTransformerConfig do not seem to have any effect.

Thanks for your help!

Same issue here when running on an L40. Similarly, the required config changes do not seem to have an effect.


Sorry, I probably could have been a bit clearer with the documentation here.

In the case that you are running a random model (i.e. you are in the gencast_mini_demo notebook and have chosen Random source), I believe this should work as intended. Please let us know if it doesn't.

In the case that you are running a checkpoint model (in either demo notebook), could you confirm you are indeed changing the attention mechanism of the loaded config (as opposed to e.g. creating a different config that is then not actually passed to the model initialisation)?

To be more concrete, you'll want to do something like

  import dataclasses

  from graphcast import checkpoint
  from graphcast import gencast

  # Load the checkpoint (the file-opening call is elided here).
  with ... as f:
    ckpt = checkpoint.load(f, gencast.CheckPoint)
  params = ckpt.params
  state = {}

  task_config = ckpt.task_config
  sampler_config = ckpt.sampler_config
  noise_config = ckpt.noise_config
  noise_encoder_config = ckpt.noise_encoder_config

  # Replace the attention mechanism on the loaded config.
  splash_spt_cfg = ckpt.denoiser_architecture_config.sparse_transformer_config
  tbd_spt_cfg = dataclasses.replace(
      splash_spt_cfg, attention_type="triblockdiag_mha", mask_type="full")
  # Pass this patched config (not ckpt.denoiser_architecture_config) to the model.
  denoiser_architecture_config = dataclasses.replace(
      ckpt.denoiser_architecture_config, sparse_transformer_config=tbd_spt_cfg)

This is something that could probably be simplified by unfreezing the config class on our end.
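
To show why the nested `dataclasses.replace` calls are needed, here is a small self-contained sketch. The two config classes below are hypothetical stand-ins, not the real graphcast classes: only the `attention_type`, `mask_type` and `sparse_transformer_config` names come from this thread, and the default values and the `num_heads` and `hidden_size` fields are invented for illustration.

```python
import dataclasses


# Hypothetical stand-in for the real sparse transformer config.
# The defaults and the num_heads field are assumptions for this sketch.
@dataclasses.dataclass(frozen=True)
class SparseTransformerConfig:
    attention_type: str = "splash_mha"
    mask_type: str = "lazy"
    num_heads: int = 4


# Hypothetical stand-in for the denoiser architecture config.
@dataclasses.dataclass(frozen=True)
class DenoiserArchitectureConfig:
    sparse_transformer_config: SparseTransformerConfig
    hidden_size: int = 512


def use_gpu_attention(arch_cfg):
    """Return a copy of arch_cfg with the attention settings swapped."""
    new_spt = dataclasses.replace(
        arch_cfg.sparse_transformer_config,
        attention_type="triblockdiag_mha",
        mask_type="full",
    )
    # The outer config is frozen too, so it must also be rebuilt.
    return dataclasses.replace(arch_cfg, sparse_transformer_config=new_spt)


loaded = DenoiserArchitectureConfig(SparseTransformerConfig())
patched = use_gpu_attention(loaded)

# Assigning to a field of a frozen dataclass raises FrozenInstanceError,
# so the loaded config cannot simply be edited in place.
try:
    loaded.sparse_transformer_config.attention_type = "triblockdiag_mha"
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False
```

Note that `loaded` is left unchanged: only `patched` carries the new attention settings, so it is `patched` that has to be handed on to whatever builds the model.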

Let me know if this still doesn't work.



It works for me after replacing the config once the checkpoint has been loaded :)

Thanks a lot!



Brilliant! Thanks for confirming.