axlearn on GPU started failing during init after upgrade
samos123 opened this issue · 7 comments
This is the error message I see when launching like this:
timeout -k 60s 900s python3 -m axlearn.common.launch_trainer_main --module=gke_fuji --config=fuji-7B-b512-fsdp8 --trainer_dir=/tmp/test_trainer --data_dir=gs://axlearn-public/tensorflow_datasets --jax_backend=gpu --num_processes=8 --distributed_coordinator=stoelinga-may13-1-j-0-0.stoelinga-may13-1 --process_id=0 --trace_at_steps=25
Error message:
2024-05-13 16:17:05.732984: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/tmp/axlearn/axlearn/common/launch_trainer_main.py", line 16, in <module>
app.run(main)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/tmp/axlearn/axlearn/common/launch_trainer_main.py", line 10, in main
launch.setup()
File "/tmp/axlearn/axlearn/common/launch.py", line 92, in setup
setup_spmd(
File "/tmp/axlearn/axlearn/common/utils_spmd.py", line 118, in setup
jax.distributed.initialize(**init_kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/distributed.py", line 196, in initialize
global_state.initialize(coordinator_address, num_processes, process_id,
File "/usr/local/lib/python3.10/dist-packages/jax/_src/distributed.py", line 72, in initialize
default_coordinator_bind_address = '[::]:' + coordinator_address.rsplit(':', 1)[1]
IndexError: list index out of range
It seems there was a code change in JAX that now requires the coordinator address to include a port. I can probably fix this on my end.
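The failure can be reproduced in isolation: `rsplit(':', 1)` only yields two elements when the address actually contains a colon, so a coordinator address without an explicit port makes the `[1]` index raise `IndexError`. A minimal sketch of the parsing logic (hostnames and the port are placeholders):

```python
def split_coordinator(address: str) -> tuple[str, str]:
    """Split a JAX-style coordinator address into (host, port).

    Mirrors the `rsplit(':', 1)` logic that fails inside
    jax.distributed.initialize when no port is present.
    """
    host, sep, port = address.rpartition(":")
    if not sep:
        # This is the case that surfaces as "IndexError: list index out of range".
        raise ValueError(
            f"coordinator address {address!r} has no port; append one, e.g. {address}:6666"
        )
    return host, port

# With an explicit port, parsing succeeds:
print(split_coordinator("coordinator-host:6666"))  # → ('coordinator-host', '6666')
```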
hmm setting the port on my side didn't work either:
+ timeout -k 60s 900s python3 -m axlearn.common.launch_trainer_main --module=gke_fuji --config=fuji-7B-b512-fsdp8 --trainer_dir=/tmp/test_trainer --data_dir=gs://axlearn-public/tensorflow_datasets --jax_backend=gpu --num_processes=8 --distributed_coordinator=stoelinga-may13-2-j-0-0.stoelinga-may13-2:6666 --process_id=0 --trace_at_steps=25
jax version=0.4.28
2024-05-13 16:38:13.047916: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-13 16:38:13.050211: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-13 16:38:13.078257: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-13 16:38:13.078283: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-13 16:38:13.078306: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-13 16:38:13.084257: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-13 16:38:13.084434: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-13 16:38:13.599922: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
2024-05-13 16:38:14.524477: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
I0513 16:38:14.525568 137754883110016 distributed.py:90] Starting JAX distributed service on stoelinga-may13-2-j-0-0.stoelinga-may13-2:6666
2024-05-13 16:38:14.525810: I external/tsl/tsl/platform/default/grpc_credentials.cc:38] gRPC insecure server credentials are used.
2024-05-13 16:38:14.526924: I external/xla/xla/pjrt/distributed/service.cc:72] Coordination service is enabled.
2024-05-13 16:38:14.527239: I external/xla/xla/pjrt/distributed/service.cc:102] Jax service listening on [::]:6666
2024-05-13 16:38:14.527257: I external/tsl/tsl/platform/default/grpc_credentials.cc:30] gRPC insecure client credentials are used.
I0513 16:38:14.527360 137754883110016 distributed.py:101] Connecting to JAX distributed service on stoelinga-may13-2-j-0-0.stoelinga-may13-2:6666
2024-05-13 16:38:15.515558: I external/xla/xla/pjrt/distributed/client.cc:130] Connected to distributed JAX controller
2024-05-13 16:38:15.519482: F external/xla/xla/parse_flags_from_env.cc:225] Unknown flags in XLA_FLAGS: --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_async_all_reduce=true --xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=1073741824 --xla_gpu_reduce_scatter_combine_threshold_bytes=1073741824 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization --xla_dump_hlo_as_text --xla_dump_to=/tmp/hlo_dump --xla_gpu_pgle_profile_file_or_directory_path=/workspace/pgle_profile.pb
Fatal Python error: Aborted
Current thread 0x00007d498edfb480 (most recent call first):
File "/usr/local/lib/python3.10/dist-packages/jaxlib/xla_client.py", line 75 in make_cpu_client
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 261 in make_cpu_client
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 965 in _init_backend
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 874 in backends
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 990 in _get_backend_uncached
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 1011 in get_backend
File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 1077 in devices
File "/tmp/axlearn/axlearn/common/launch.py", line 108 in setup
File "/tmp/axlearn/axlearn/common/launch_trainer_main.py", line 10 in main
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254 in _run_main
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308 in run
File "/tmp/axlearn/axlearn/common/launch_trainer_main.py", line 16 in <module>
File "/usr/lib/python3.10/runpy.py", line 86 in _run_code
File "/usr/lib/python3.10/runpy.py", line 196 in _run_module_as_main
@samos123 you may need to remove some deprecated XLA flags, like `--xla_gpu_enable_async_all_reduce=true`. See https://github.com/openxla/xla/blob/d3e881ad668b7aa44283d47bb553a04b86b71315/xla/xla.proto#L394
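Since XLA aborts hard on any unknown flag, one way to keep images working across jaxlib upgrades is to strip known-removed flags from `XLA_FLAGS` before launch. A hedged sketch; the deprecated set below is an assumption and needs updating per release:

```python
# Flags known to have been removed in newer XLA builds
# (assumed list for illustration; update per release).
DEPRECATED_XLA_FLAGS = {
    "--xla_gpu_enable_async_all_gather",
    "--xla_gpu_enable_async_reduce_scatter",
    "--xla_gpu_enable_async_all_reduce",
}

def strip_deprecated(xla_flags: str) -> str:
    """Drop any flag whose name (the text before '=') is in the deprecated set."""
    kept = [
        flag
        for flag in xla_flags.split()
        if flag.split("=", 1)[0] not in DEPRECATED_XLA_FLAGS
    ]
    return " ".join(kept)

flags = "--xla_gpu_enable_async_all_reduce=true --xla_gpu_graph_level=0"
print(strip_deprecated(flags))  # → --xla_gpu_graph_level=0
```

In an entrypoint script, the same filtering could be applied to the environment variable before `python3 -m axlearn.common.launch_trainer_main` runs.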
That got me past the "Fatal Python error: Aborted" issue! Thanks a lot Mark.
The next issue:
I0513 18:16:23.374316 139735000650880 launch_trainer.py:77] Did not find config 'fuji-7B-b512-fsdp8' or module 'gke_fuji' -- will continue searching.
This is what I do in my Dockerfile:
RUN pip install --no-cache-dir --upgrade \
"jax[cuda12_pip]==${JAX_VERSION}" \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
-c constraints.txt
COPY axlearn /tmp/axlearn
RUN pip install --no-cache-dir -e /tmp/axlearn
# Install Google Cloud CLI
RUN apt-get update && apt-get install -y apt-transport-https ca-certificates gnupg curl && \
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - && \
echo "deb https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
apt-get update && apt-get install -y google-cloud-sdk && \
apt-get clean && rm -rf /var/lib/apt/lists/*
# RUN pip install --no-cache-dir axlearn==0.0.1.post20240213000026
RUN mkdir -p .axlearn && touch .axlearn/.axlearn.config
RUN mkdir -p /tmp/test_trainer /tmp/first_run
COPY gke_fuji.py .
COPY run-fuji.sh .
COPY convert_pgle.py .
CMD bash run-fuji.sh
inside my entrypoint script I set this:
export PYTHONPATH="${PYTHONPATH}:/workspace"
and gke_fuji.py is copied to /workspace/gke_fuji.py
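A quick way to verify that the custom experiment module actually resolves from inside the container, before axlearn's config search runs. A small sketch; `gke_fuji` is the module name assumed from the setup above:

```python
import importlib.util

def check_module(name: str) -> bool:
    """Return True if `name` resolves on sys.path (i.e. PYTHONPATH is set up)."""
    return importlib.util.find_spec(name) is not None

# 'json' is in the stdlib, so this should always hold:
print(check_module("json"))  # → True

# The experiment module only resolves if /workspace is on PYTHONPATH:
print(check_module("gke_fuji"))
```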
I think this is resolved after appending the version suffix to the config name -- @samos123 are we good to close this?
There were two issues:
- the deprecated XLA flags
- an incorrect fuji 7B config name, which now needs the `-v1` suffix
Thanks a lot for helping troubleshoot it. Closing.
One thing to note is that we should have better error reporting when there is a code error in your custom experiment. It currently surfaces as a "config not found" error, with no indication of what actually went wrong.
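The "config not found" symptom typically arises when a launcher imports candidate modules inside a broad try/except and treats any failure as "module absent". An illustrative sketch (not axlearn's actual code) of why a real bug gets masked; `named_trainer_configs` is a hypothetical entry point:

```python
import importlib

def find_config_silent(module_name: str, config_name: str):
    """Mimic a search loop that swallows import errors: any exception
    raised inside the experiment module looks identical to the module
    simply not existing."""
    try:
        module = importlib.import_module(module_name)
    except Exception:
        return None  # a SyntaxError in gke_fuji.py ends up here too
    configs = getattr(module, "named_trainer_configs", lambda: {})()
    return configs.get(config_name)

# A module that doesn't exist and a module that crashes on import
# both yield None -- indistinguishable to the caller:
print(find_config_silent("no_such_module_xyz", "fuji-7B"))  # → None
```

Logging the caught exception (instead of discarding it) before continuing the search would make the difference visible.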