AI-Hypercomputer/maxtext

PGLE doesn't work for Tensor Parallelism


We observed good overlap with FSDP + PGLE:
[screenshot: profiler trace for FSDP + PGLE]
Turning PGLE on and off makes a big difference here.

However, with TP + PGLE:
[screenshot: profiler trace for TP + PGLE]

There is no performance improvement; computation and communication are completely exposed.

Here is the command (first switch to the lance-405b-clean branch):

python3 MaxText/train.py MaxText/configs/models/gpu/llama3.1_405b.yml hardware=gpu \
  run_name=maxtext-llama3.1-405b steps=10 max_target_length=4096 model_name=llama3.1-405b \
  enable_checkpointing=false attention=cudnn_flash_te dataset_type=synthetic \
  async_checkpointing=false base_output_directory=gs://lancewang-dev-supercomputer-testing/maxtext_gpu \
  logits_dot_in_fp32=false use_iota_embed=true ici_tensor_parallelism=8 dcn_fsdp_parallelism=32 \
  dcn_pipeline_parallelism=1 per_device_batch_size=1 num_layers_per_pipeline_stage=16 weight_dtype=bfloat16 \
  remat_policy=save_qkv_proj profiler=xplane skip_first_n_steps_for_profiler=5 \
  base_num_decoder_layers=126
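
For reference, the FSDP + PGLE trace above came from an equivalent run. The exact baseline command wasn't captured in this issue, but as a rough sketch it should only differ in the intra-node mesh axis (assuming the baseline assigns the 8-way intra-node axis to FSDP via ici_fsdp_parallelism instead of ici_tensor_parallelism):

# Hypothetical FSDP-only baseline: identical to the command above except that the
# 8 GPUs inside a node shard via FSDP rather than tensor parallelism.
python3 MaxText/train.py MaxText/configs/models/gpu/llama3.1_405b.yml hardware=gpu \
  run_name=maxtext-llama3.1-405b-fsdp steps=10 max_target_length=4096 model_name=llama3.1-405b \
  enable_checkpointing=false attention=cudnn_flash_te dataset_type=synthetic \
  ici_fsdp_parallelism=8 dcn_fsdp_parallelism=32 per_device_batch_size=1 \
  remat_policy=save_qkv_proj profiler=xplane skip_first_n_steps_for_profiler=5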

Here are the xla flags:
--xla_gpu_graph_level=0 --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
--xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=536870912
--xla_gpu_all_gather_combine_threshold_bytes=536870912 --xla_gpu_reduce_scatter_combine_threshold_bytes=536870912
--xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true
--xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true
--xla_disable_hlo_passes=rematerialization --xla_gpu_enable_pgle_accuracy_checker=false
--xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false
--xla_gpu_enable_reduce_scatter_combine_by_dim=false
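
These are passed to the compiler through the XLA_FLAGS environment variable; a minimal sketch of how we wire them up (only a few flags shown, the rest go in the same string):

# All flags live in a single space-separated XLA_FLAGS string.
export XLA_FLAGS="--xla_gpu_graph_level=0 \
  --xla_gpu_enable_latency_hiding_scheduler=true \
  --xla_gpu_enable_triton_gemm=false \
  --xla_gpu_all_reduce_combine_threshold_bytes=536870912 \
  --xla_gpu_enable_pipelined_all_gather=true"
# ...and so on for the remaining flags listed above.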

Here are the environment variables:
NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto
NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS=600000
JAX_ENABLE_PGLE=true
JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY=true
JAX_DEBUG_LOG_MODULES=compiler
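
For the on/off comparison, we simply flip the PGLE variable between otherwise identical runs and compare the xplane traces; a minimal sketch (JAX_PGLE_PROFILING_RUNS is not part of our original setup and is shown only as an assumption about a typical configuration):

# Run with PGLE: JAX profiles the first few steps and recompiles using measured costs.
export JAX_ENABLE_PGLE=true
export JAX_PGLE_PROFILING_RUNS=3    # assumed value; controls how many profiled steps feed the estimator
python3 MaxText/train.py ...        # same train.py command and XLA flags as above

# Baseline without PGLE: the latency-hiding scheduler falls back to its analytical cost model.
export JAX_ENABLE_PGLE=false
python3 MaxText/train.py ...        # same train.py command and XLA flags as above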

We built the image on Oct 22nd.

@Tixxx do you know what the issue is? I'm still trying to reproduce it myself.

I cannot access the screenshots above; they say page not found. Just a preliminary guess: a large combiner threshold can introduce more data dependencies, so we usually tune it down when the collective is a combined one with many data dependencies.
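
If it's easy for you to try, the concrete experiment would be to shrink the combine thresholds from 512 MiB and check whether the TP collectives start overlapping again; a sketch of the change (128 MiB is only an illustrative starting point, not a tuned value):

# Replace the 512 MiB (536870912) combiner thresholds with a smaller value, e.g. 128 MiB:
--xla_gpu_all_reduce_combine_threshold_bytes=134217728
--xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728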

I have tried reproducing with your command on MaxText main, but the YAML file doesn't exist for me. Would you be able to share a smaller model that can easily be reproduced on a single node? Thanks
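
Something along these lines would be ideal, just as a sketch of what I could run locally (the base.yml config path and the llama2-7b model choice are my assumptions, not something from your setup):

# Hypothetical single-node repro: 8-way tensor parallelism on one node, synthetic data,
# reusing the same PGLE env vars and XLA flags as above.
python3 MaxText/train.py MaxText/configs/base.yml hardware=gpu model_name=llama2-7b \
  run_name=pgle-tp-repro steps=10 per_device_batch_size=1 max_target_length=4096 \
  enable_checkpointing=false dataset_type=synthetic attention=cudnn_flash_te \
  ici_tensor_parallelism=8 profiler=xplane skip_first_n_steps_for_profiler=5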