AI-Hypercomputer/maxtext

Issues running test_llama2_7b.sh on TPU VM v3-8

korney3 opened this issue · 1 comment

Hi!
I was trying to run test_llama2_7b.sh following the default instructions on a TPU VM with a v3-8.
I was able to run the script successfully up to the fine-tuning step, using the command
python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items run_name=runner_finetuning_2024-04-01-01-24 base_output_directory=gs://MY_BUCKET_NAME dataset_path=gs://MY_BUCKET_NAME async_checkpointing=false per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 steps=10 max_target_length=1024 per_device_batch_size=1 checkpoint_period=5

I got the following traceback:

2024-04-01 02:06:24.369879: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Updating keys from env and command line: ['run_name', 'model_name', 'load_parameters_path', 'async_checkpointing', 'checkpoint_period', 'base_output_directory', 'ici_tensor_parallelism', 'dataset_path', 'per_device_batch_size', 'steps', 'max_target_length']
Running Model: llama2-7b
Updating following parameters in config

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 32
base_mlp_dim: 11008
base_num_decoder_layers: 32
head_dim: 128
mlp_activations: ['silu', 'linear']
vocab_size: 32000
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1e-05
decoder_block: llama2
Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block']
2024-04-01 02:06:31.698078: I external/xla/xla/pjrt/pjrt_c_api_client.cc:137] PjRtCApiClient created.
System Information: Jax Version: 0.4.25
System Information: Jaxlib Version: 0.4.25
System Information: Jax Backend: PJRT C API
TFRT TPU v3
Built on Feb 24 2024 03:12:26 (1708773146) cl/609954703
Config param adam_b1: 0.9
Config param adam_b2: 0.95
Config param adam_eps: 1e-08
Config param adam_eps_root: 0.0
Config param adam_weight_decay: 0.1
Config param async_checkpointing: False
Config param attention: autoselected
Config param autoregressive_decode_assert:
Config param base_emb_dim: 4096
Config param base_mlp_dim: 11008
Config param base_num_decoder_layers: 32
Config param base_num_kv_heads: 32
Config param base_num_query_heads: 32
Config param base_output_directory: gs://MY_BUCKET_NAME
Config param checkpoint_dir: gs://MY_BUCKET_NAME/runner_finetuning_2024-04-01-01-24/checkpoints/
Config param checkpoint_period: 5
Config param collect_stack_trace: False
Config param compile_topology:
Config param compile_topology_num_slices: -1
Config param compiled_trainstep_file:
Config param cosine_learning_rate_final_fraction: 0.1
Config param data_sharding: (('data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)
Config param data_shuffle_seed: 0
Config param dataset_name: c4/en:3.0.1
Config param dataset_path: gs://MY_BUCKET_NAME
Config param dataset_type: c4
Config param dcn_autoregressive_parallelism: 1
Config param dcn_data_parallelism: -1
Config param dcn_fsdp_parallelism: 1
Config param dcn_fsdp_transpose_parallelism: 1
Config param dcn_sequence_parallelism: 1
Config param dcn_tensor_parallelism: 1
Config param decode_sampling_nucleus_p: -1
Config param decode_sampling_strategy: greedy
Config param decode_sampling_temperature: 1.0
Config param decode_sampling_top_k: 0
Config param decoder_block: llama2
Config param dropout_rate: 0
Config param dtype: bfloat16
Config param emb_dim: 4096
Config param enable_checkpointing: True
Config param enable_data_shuffling: True
Config param enable_dropout: False
Config param enable_profiler: False
Config param enable_single_replica_ckpt_restoring: False
Config param eval_dataset_name: c4/en:3.0.1
Config param eval_interval: -1
Config param eval_per_device_batch_size: 0
Config param eval_split: validation
Config param force_unroll: False
Config param fused_mlp: False
Config param fused_qkv: False
Config param gcs_metrics: False
Config param global_batch_size_to_load: 8
Config param global_batch_size_to_train_on: 8
Config param global_parameter_scale: 1
Config param gradient_clipping_threshold: 1.0
Config param grain_worker_count: 4
Config param hardware: tpu
Config param head_dim: 128
Config param ici_autoregressive_parallelism: 1
Config param ici_data_parallelism: 1
Config param ici_fsdp_parallelism: -1
Config param ici_fsdp_transpose_parallelism: 1
Config param ici_sequence_parallelism: 1
Config param ici_tensor_parallelism: 4
Config param init_weights_seed: 0
Config param jax_cache_dir: ~/jax_cache
Config param learning_rate: 3e-05
Config param learning_rate_schedule_steps: 10
Config param load_from_prefill_dir: False
Config param load_full_state_path:
Config param load_parameters_path: gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items
Config param log_period: 100
Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('kv', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()))
Config param logits_dot_in_fp32: True
Config param logits_via_embedding: False
Config param max_corpus_chars: 10000000
Config param max_prefill_predict_length: 64
Config param max_target_length: 1024
Config param mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
Config param metrics_dir: gs://MY_BUCKET_NAME/runner_finetuning_2024-04-01-01-24/metrics/
Config param metrics_file:
Config param mlp_activations: ['silu', 'linear']
Config param mlp_dim: 11008
Config param model_name: llama2-7b
Config param normalization_layer_epsilon: 1e-05
Config param normalize_embedding_logits: True
Config param num_decoder_layers: 32
Config param num_experts: 1
Config param num_experts_per_tok: 1
Config param num_kv_heads: 32
Config param num_query_heads: 32
Config param num_slices: 1
Config param opt_type: adamw
Config param param_scan_axis: 1
Config param per_device_batch_size: 1.0
Config param prefill_cache_dir:
Config param profiler_steps: 5
Config param prompt: I love to
Config param quantization:
Config param quantization_local_shard_count: 1
Config param quantize_kvcache: False
Config param record_internal_nn_metrics: 0
Config param remat_policy: full
Config param reuse_example_batch: 0
Config param run_name: runner_finetuning_2024-04-01-01-24
Config param save_config_to_gcs: False
Config param scan_layers: True
Config param skip_first_n_steps_for_profiler: 1
Config param stack_trace_interval_seconds: 600
Config param stack_trace_to_cloud: False
Config param steps: 10
Config param target_eval_loss: 0.0
Config param tensorboard_dir: gs://MY_BUCKET_NAME/runner_finetuning_2024-04-01-01-24/tensorboard/
Config param tokenizer_path: assets/tokenizer.llama2
Config param trainable_position_size: -1
Config param upload_all_profiler_results: False
Config param use_iota_embed: False
Config param use_untrainable_positional_embedding: False
Config param vocab_size: 32000
Config param warmup_steps_fraction: 0.1
Config param weight_dtype: float32
Creating checkpoint manager...
I0401 02:06:32.866911 140289813645312 checkpoint_manager.py:1040] Found 0 checkpoint steps in gs://MY_BUCKET_NAME/runner_finetuning_2024-04-01-01-24/checkpoints
I0401 02:06:32.867159 140289813645312 checkpoint_manager.py:484] jax.process_index=0, primary_host=0. CheckpointManager created: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7f960fb3dff0>
Checkpoint manager created!
I0401 02:06:32.867487 140289813645312 mesh_utils.py:73] Reordering mesh to physical ring order on single-tray TPU v2/v3.
Num_devices: 8, shape (1, 2, 1, 1, 4, 1)
I0401 02:06:33.556910 140289813645312 dataset_info.py:610] Load dataset info from gs://MY_BUCKET_NAME/c4/en/3.0.1
I0401 02:06:34.487275 140289813645312 dataset_info.py:702] For 'c4/en/3.0.1': fields info.[splits] differ on disk and in the code. Keeping the one from code.
I0401 02:06:34.662644 140289813645312 reader.py:261] Creating a tf.data.Dataset reading 1024 files located in folders: gs://MY_BUCKET_NAME/c4/en/3.0.1.
I0401 02:06:34.819458 140289813645312 logging_logger.py:49] Constructing tf.data.Dataset c4 for split train, from gs://MY_BUCKET_NAME/c4/en/3.0.1
I0401 02:06:35.186302 140289813645312 dataset_info.py:610] Load dataset info from gs://MY_BUCKET_NAME/c4/en/3.0.1
I0401 02:06:36.488089 140289813645312 dataset_info.py:702] For 'c4/en/3.0.1': fields info.[splits] differ on disk and in the code. Keeping the one from code.
I0401 02:06:36.609976 140289813645312 reader.py:261] Creating a tf.data.Dataset reading 8 files located in folders: gs://MY_BUCKET_NAME/c4/en/3.0.1.
I0401 02:06:36.685018 140289813645312 logging_logger.py:49] Constructing tf.data.Dataset c4 for split validation, from gs://MY_BUCKET_NAME/c4/en/3.0.1
Tokenizer path: assets/tokenizer.llama2
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
checkpoint manager exists so trying to load this run's existing checkpoint
restoring params from load_parameters_from_path='gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items'
I0401 02:06:40.339250 140289813645312 checkpointer.py:166] Restoring item from gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1711937201.278496   28101 gcs_resource.cc:109] Using default AdmissionQueue with limit 32
I0000 00:00:1711937201.282484   29547 google_auth_provider.cc:180] Running on GCE, using service account 1035920046472-compute@developer.gserviceaccount.com
W0401 02:07:25.916970 140289813645312 transform_utils.py:229] The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
I0401 02:07:25.947791 140289813645312 checkpointer.py:169] Finished restoring checkpoint from gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items.
number parameters: 6.738 billion
Per train step:
 Total TFLOPs: 42.23
 split as 98.05% learnable weight flops and 1.95% attention flops
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/usr/maxtext/MaxText/train.py", line 497, in <module>
    app.run(main)
  File "/home/usr/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/usr/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/usr/maxtext/MaxText/train.py", line 493, in main
    train_loop(config)
  File "/home/usr/maxtext/MaxText/train.py", line 433, in train_loop
    state, metrics = p_train_step(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.

The MLIR operation involved:
  %1481 = "tpu.matmul"(%1478, %1479, %1480) {transpose_lhs = false, transpose_rhs = true} : (vector<512x128xbf16>, vector<128x128xbf16>, vector<512x128xf32>) -> vector<512x128xf32>
... additional diagnostics were skipped.
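
If it helps, I can rerun with JAX's internal frames included, as the error message suggests, e.g. (same arguments as above, only the environment variable added):

JAX_TRACEBACK_FILTERING=off python3 MaxText/train.py MaxText/configs/base.yml ... (same arguments as above)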

I've also printed an example batch from the code that caused the error:

{'inputs': Array([[    1,  2023, 14606, ...,     0,     0,     0],
       [    1,  2567,   393, ...,     0,     0,     0],
       [    1,   390,  2965, ...,     0,     0,     0],
       ...,
       [    1,   887, 30010, ...,   367, 16010,   746],
       [    1, 12547,   393, ..., 29915, 29885, 10932],
       [    1,  1383,   279, ...,     0,     0,     0]], dtype=int32), 'inputs_position': Array([[   0,    1,    2, ...,    0,    0,    0],
       [   0,    1,    2, ...,    0,    0,    0],
       [   0,    1,    2, ...,    0,    0,    0],
       ...,
       [   0,    1,    2, ..., 1021, 1022, 1023],
       [   0,    1,    2, ..., 1021, 1022, 1023],
       [   0,    1,    2, ...,    0,    0,    0]], dtype=int32), 'inputs_segmentation': Array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32), 'targets': Array([[ 2023, 14606,   437, ...,     0,     0,     0],
       [ 2567,   393,   366, ...,     0,     0,     0],
       [  390,  2965, 15444, ...,     0,     0,     0],
       ...,
       [  887, 30010,   345, ..., 16010,   746,   372],
       [12547,   393,   385, ..., 29885, 10932,   393],
       [ 1383,   279,  9010, ...,     0,     0,     0]], dtype=int32), 'targets_position': Array([[   0,    1,    2, ...,    0,    0,    0],
       [   0,    1,    2, ...,    0,    0,    0],
       [   0,    1,    2, ...,    0,    0,    0],
       ...,
       [   0,    1,    2, ..., 1021, 1022, 1023],
       [   0,    1,    2, ..., 1021, 1022, 1023],
       [   0,    1,    2, ...,    0,    0,    0]], dtype=int32), 'targets_segmentation': Array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)}

I was using the default settings (except for the bucket names, of course).
Could you please point me in the right direction to fix this error?

Please try attention=dot_product; I believe the flash attention kernel is unhappy on v3-8.
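
For example, a sketch of the fine-tuning command from above with the attention override appended (everything else unchanged; the config log shows attention currently resolving to "autoselected"):

python3 MaxText/train.py MaxText/configs/base.yml load_parameters_path=gs://MY_BUCKET_NAME/test/2024-04-01-01-24/decode-ckpt-maxtext/0/items run_name=runner_finetuning_2024-04-01-01-24 base_output_directory=gs://MY_BUCKET_NAME dataset_path=gs://MY_BUCKET_NAME async_checkpointing=false per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 steps=10 max_target_length=1024 checkpoint_period=5 attention=dot_product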