KellerJordan/modded-nanogpt

Attempted to get this running on AMD MI300X...

Closed this issue · 5 comments

Hi there!

I saw your tweet and, for 😆's, I tried to run your stuff on one of our 8x MI300X systems, just to see if it works. It would also be nice to eventually compare performance against the H100s.

I pulled the latest container with docker pull rocm/pytorch:latest and ran everything within it:

docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G -v /mnt/drive1:/mnt/drive1 rocm/pytorch:latest

python3 -m venv env --system-site-packages

source env/bin/activate
pip3 install --upgrade pip
pip3 install -r requirements.txt

 python data/cached_fineweb10B.py
./run.sh

Everything seemed to go well out of the box; I could see the GPUs being used in rocm-smi, but then it seems to have 💩 out after sitting at the val loss output for a bit.

I'm no AI/PyTorch expert (yet), but if you have any clear indication of what's going wrong here, let me know what I can try and I'll make any changes you suggest. Either that, or if you are willing, I can give you access to the machine itself (via ssh) and you can try playing with it.

cheers!

jon@hotaisle.xyz

 ./run.sh 
W1006 18:02:14.946367 126467783853248 torch/distributed/run.py:757] 
W1006 18:02:14.946367 126467783853248 torch/distributed/run.py:757] *****************************************
W1006 18:02:14.946367 126467783853248 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1006 18:02:14.946367 126467783853248 torch/distributed/run.py:757] *****************************************
Running pytorch 2.3.0a0+git1b935e2
using device: cuda:3
using device: cuda:1
using device: cuda:7
using device: cuda:2
using device: cuda:4
using device: cuda:6
using device: cuda:5
using device: cuda:0
Training DataLoader: total number of tokens: 10255324043 across 103 files
Validation DataLoader: total number of tokens: 100000000 across 1 files
compiling the model...
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/utils.py:1764: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  return node.target(*args, **kwargs)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/utils.py:1764: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  return node.target(*args, **kwargs)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/utils.py:1764: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  return node.target(*args, **kwargs)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/utils.py:1764: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  return node.target(*args, **kwargs)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/utils.py:1764: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  return node.target(*args, **kwargs)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/utils.py:1764: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  return node.target(*args, **kwargs)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/utils.py:1764: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  return node.target(*args, **kwargs)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/utils.py:1764: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  return node.target(*args, **kwargs)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py:124: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
val loss 11.003429412841797
[rank5]: Traceback (most recent call last):
[rank5]:   File "/mnt/drive1/modded-nanogpt/train_gpt2.py", line 479, in <module>
[rank5]:     loss.backward()
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
[rank5]:     torch.autograd.backward(
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank5]:     _engine_run_backward(
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank5]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 301, in apply
[rank5]:     return user_fn(self, *args)
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 882, in backward
[rank5]:     out = call_compiled_backward()
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 831, in call_compiled_backward
[rank5]:     out = call_func_at_runtime_with_args(
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
[rank5]:     out = normalize_as_list(f(args))
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank5]:     return fn(*args, **kwargs)
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank5]:     return fn(*args, **kwargs)
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 906, in __call__
[rank5]:     return self.get_current_callable()(inputs)
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 784, in run
[rank5]:     return model(new_inputs)
[rank5]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
[rank5]:     return compiled_graph.compiled_artifact(inputs)
[rank5]:   File "/tmp/torchinductor_root/vt/cvte4ixg2lf4m3me4ex3ts76azjvf7r67ywg5pfh4okaetfvxjto.py", line 518, in call
[rank5]:     assert_size_stride(getitem_4, (64, 12, 1024), (12288, 1024, 1))
[rank5]: AssertionError: wrong number of dimensions
[rank1]: Traceback (most recent call last):
[rank1]:   File "/mnt/drive1/modded-nanogpt/train_gpt2.py", line 479, in <module>
[rank1]:     loss.backward()
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 301, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 882, in backward
[rank1]:     out = call_compiled_backward()
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 831, in call_compiled_backward
[rank1]:     out = call_func_at_runtime_with_args(
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
[rank1]:     out = normalize_as_list(f(args))
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 906, in __call__
[rank1]:     return self.get_current_callable()(inputs)
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 784, in run
[rank1]:     return model(new_inputs)
[rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
[rank1]:     return compiled_graph.compiled_artifact(inputs)
[rank1]:   File "/tmp/torchinductor_root/ir/cirkve7dpc5r7ta4hi2wbzncxl6nsu2adq4n7yo5vdsuf63bbxt4.py", line 518, in call
[rank1]:     assert_size_stride(getitem_4, (64, 12, 1024), (12288, 1024, 1))
[rank1]: AssertionError: wrong number of dimensions
[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/drive1/modded-nanogpt/train_gpt2.py", line 479, in <module>
[rank0]:     loss.backward()
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 301, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 882, in backward
[rank0]:     out = call_compiled_backward()
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 831, in call_compiled_backward
[rank0]:     out = call_func_at_runtime_with_args(
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
[rank0]:     out = normalize_as_list(f(args))
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 906, in __call__
[rank0]:     return self.get_current_callable()(inputs)
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 784, in run
[rank0]:     return model(new_inputs)
[rank0]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
[rank0]:     return compiled_graph.compiled_artifact(inputs)
[rank0]:   File "/tmp/torchinductor_root/kg/ckgqmxhmzvz7dlsac4l5uks2rme5x7eeghl3ckvmqlui6z2o4ylf.py", line 518, in call
[rank0]:     assert_size_stride(getitem_4, (64, 12, 1024), (12288, 1024, 1))
[rank0]: AssertionError: wrong number of dimensions
[rank3]: Traceback (most recent call last):
[rank3]:   File "/mnt/drive1/modded-nanogpt/train_gpt2.py", line 479, in <module>
[rank3]:     loss.backward()
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
[rank3]:     torch.autograd.backward(
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank3]:     _engine_run_backward(
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank3]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 301, in apply
[rank3]:     return user_fn(self, *args)
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 882, in backward
[rank3]:     out = call_compiled_backward()
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 831, in call_compiled_backward
[rank3]:     out = call_func_at_runtime_with_args(
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
[rank3]:     out = normalize_as_list(f(args))
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank3]:     return fn(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 906, in __call__
[rank3]:     return self.get_current_callable()(inputs)
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 784, in run
[rank3]:     return model(new_inputs)
[rank3]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
[rank3]:     return compiled_graph.compiled_artifact(inputs)
[rank3]:   File "/tmp/torchinductor_root/ey/ceya432rp26lzw4thezihkebcocsehnlhftqjioubkkd35ngwvhh.py", line 518, in call
[rank3]:     assert_size_stride(getitem_4, (64, 12, 1024), (12288, 1024, 1))
[rank3]: AssertionError: wrong number of dimensions
[rank4]: Traceback (most recent call last):
[rank4]:   File "/mnt/drive1/modded-nanogpt/train_gpt2.py", line 479, in <module>
[rank4]:     loss.backward()
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
[rank4]:     torch.autograd.backward(
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank4]:     _engine_run_backward(
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank4]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 301, in apply
[rank4]:     return user_fn(self, *args)
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 882, in backward
[rank4]:     out = call_compiled_backward()
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 831, in call_compiled_backward
[rank4]:     out = call_func_at_runtime_with_args(
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
[rank4]:     out = normalize_as_list(f(args))
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank4]:     return fn(*args, **kwargs)
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank4]:     return fn(*args, **kwargs)
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 906, in __call__
[rank4]:     return self.get_current_callable()(inputs)
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 784, in run
[rank4]:     return model(new_inputs)
[rank4]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
[rank4]:     return compiled_graph.compiled_artifact(inputs)
[rank4]:   File "/tmp/torchinductor_root/6l/c6llgpjbyq5wo3mxhjrletrm32l2nzzo2loeax4zalumfjr7vsln.py", line 518, in call
[rank4]:     assert_size_stride(getitem_4, (64, 12, 1024), (12288, 1024, 1))
[rank4]: AssertionError: wrong number of dimensions
[rank6]: Traceback (most recent call last):
[rank6]:   File "/mnt/drive1/modded-nanogpt/train_gpt2.py", line 479, in <module>
[rank6]:     loss.backward()
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
[rank6]:     torch.autograd.backward(
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank6]:     _engine_run_backward(
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank6]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 301, in apply
[rank6]:     return user_fn(self, *args)
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 882, in backward
[rank6]:     out = call_compiled_backward()
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 831, in call_compiled_backward
[rank6]:     out = call_func_at_runtime_with_args(
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 113, in call_func_at_runtime_with_args
[rank6]:     out = normalize_as_list(f(args))
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank6]:     return fn(*args, **kwargs)
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank6]:     return fn(*args, **kwargs)
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 906, in __call__
[rank6]:     return self.get_current_callable()(inputs)
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 784, in run
[rank6]:     return model(new_inputs)
[rank6]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 934, in _run_from_cache
[rank6]:     return compiled_graph.compiled_artifact(inputs)
[rank6]:   File "/tmp/torchinductor_root/gl/cgl63o2axbbprnu2acsqrzfqu27xiyxbxnj5esgyjkzvz4d4ohzx.py", line 518, in call
[rank6]:     assert_size_stride(getitem_4, (64, 12, 1024), (12288, 1024, 1))
[rank6]: AssertionError: wrong number of dimensions
W1006 18:04:10.185627 126467783853248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 114 closing signal SIGTERM
W1006 18:04:10.185979 126467783853248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 115 closing signal SIGTERM
W1006 18:04:10.186084 126467783853248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 116 closing signal SIGTERM
W1006 18:04:10.187988 126467783853248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 117 closing signal SIGTERM
W1006 18:04:10.189330 126467783853248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 118 closing signal SIGTERM
W1006 18:04:10.189493 126467783853248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 120 closing signal SIGTERM
W1006 18:04:10.189615 126467783853248 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 121 closing signal SIGTERM
E1006 18:04:10.454885 126467783853248 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 5 (pid: 119) of binary: /opt/conda/envs/py_3.9/bin/python
Traceback (most recent call last):
  File "/opt/conda/envs/py_3.9/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.3.0a0+git1b935e2', 'console_scripts', 'torchrun')())
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/distributed/run.py", line 879, in main
    run(args)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train_gpt2.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-10-06_18:04:10
  host      : 1d5fd75bfd9d
  rank      : 5 (local_rank: 5)
  exitcode  : 1 (pid: 119)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Hmm. Could you try it without torch.compile? (Just comment out the line that compiles the model.)
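
A minimal, self-contained sketch of that change (the use_compile flag and the stand-in model are illustrative only, not names taken from the repo):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)           # stand-in for the GPT model built earlier in train_gpt2.py
use_compile = False               # leave False here to skip Inductor entirely on ROCm
if use_compile:
    model = torch.compile(model)  # this is the call to comment out / skip

With the compile call skipped, the script just runs the eager model, which is what the successful run below used.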

Running!

./run.sh 
W1006 18:45:31.200540 125345176495296 torch/distributed/run.py:757] 
W1006 18:45:31.200540 125345176495296 torch/distributed/run.py:757] *****************************************
W1006 18:45:31.200540 125345176495296 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1006 18:45:31.200540 125345176495296 torch/distributed/run.py:757] *****************************************
Running pytorch 2.3.0a0+git1b935e2
using device: cuda:4
using device: cuda:7
using device: cuda:5
using device: cuda:1
using device: cuda:6
using device: cuda:3
using device: cuda:2
using device: cuda:0
Training DataLoader: total number of tokens: 10255324043 across 103 files
Validation DataLoader: total number of tokens: 100000000 across 1 files
compiling the model...
/mnt/drive1/modded-nanogpt/train_gpt2.py:159: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
/mnt/drive1/modded-nanogpt/train_gpt2.py:159: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
/mnt/drive1/modded-nanogpt/train_gpt2.py:159: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
/mnt/drive1/modded-nanogpt/train_gpt2.py:159: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
/mnt/drive1/modded-nanogpt/train_gpt2.py:159: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
/mnt/drive1/modded-nanogpt/train_gpt2.py:159: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
/mnt/drive1/modded-nanogpt/train_gpt2.py:159: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
/mnt/drive1/modded-nanogpt/train_gpt2.py:159: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at /var/lib/jenkins/pytorch/aten/src/ATen/native/transformers/hip/sdp_utils.cpp:517.)
  y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
val loss 11.021111488342285
step    1/6676 | train loss 11.0205 | lr_scale 1.00e+00 | (64862.29 ms | 8083 tok/s)
step    2/6676 | train loss 9.2245 | lr_scale 1.00e+00 | (880.07 ms | 595736 tok/s)
step    3/6676 | train loss 7.8305 | lr_scale 1.00e+00 | (872.35 ms | 601009 tok/s)
step    4/6676 | train loss 7.6170 | lr_scale 1.00e+00 | (885.78 ms | 591895 tok/s)
step    5/6676 | train loss 7.6862 | lr_scale 1.00e+00 | (886.25 ms | 591581 tok/s)
step    6/6676 | train loss 7.8194 | lr_scale 1.00e+00 | (883.03 ms | 593740 tok/s)
step    7/6676 | train loss 8.5758 | lr_scale 1.00e+00 | (891.98 ms | 587782 tok/s)
step    8/6676 | train loss 7.9219 | lr_scale 1.00e+00 | (880.58 ms | 595388 tok/s)
step    9/6676 | train loss 7.6205 | lr_scale 1.00e+00 | (879.77 ms | 595935 tok/s)
step   10/6676 | train loss 7.3670 | lr_scale 1.00e+00 | (889.60 ms | 589351 tok/s)
step   11/6676 | train loss 7.2259 | lr_scale 1.00e+00 | (873.73 ms | 600054 tok/s)
step   12/6676 | train loss 7.0139 | lr_scale 1.00e+00 | (873.87 ms | 599960 tok/s)
step   13/6676 | train loss 6.9057 | lr_scale 1.00e+00 | (873.36 ms | 600312 tok/s)
step   14/6676 | train loss 6.7420 | lr_scale 1.00e+00 | (875.86 ms | 598597 tok/s)
step   15/6676 | train loss 6.7187 | lr_scale 1.00e+00 | (876.47 ms | 598179 tok/s)
step   16/6676 | train loss 6.6241 | lr_scale 1.00e+00 | (875.32 ms | 598964 tok/s)
step   17/6676 | train loss 6.6011 | lr_scale 1.00e+00 | (876.41 ms | 598220 tok/s)
step   18/6676 | train loss 6.6240 | lr_scale 1.00e+00 | (877.33 ms | 597596 tok/s)
step   19/6676 | train loss 6.6198 | lr_scale 1.00e+00 | (877.07 ms | 597771 tok/s)
step   20/6676 | train loss 6.4442 | lr_scale 1.00e+00 | (877.82 ms | 597261 tok/s)
step   21/6676 | train loss 6.4658 | lr_scale 1.00e+00 | (877.99 ms | 597146 tok/s)
step   22/6676 | train loss 6.5030 | lr_scale 1.00e+00 | (877.85 ms | 597238 tok/s)
step   23/6676 | train loss 6.3411 | lr_scale 1.00e+00 | (876.89 ms | 597894 tok/s)
step   24/6676 | train loss 6.3826 | lr_scale 1.00e+00 | (876.63 ms | 598076 tok/s)
step   25/6676 | train loss 6.2716 | lr_scale 1.00e+00 | (878.37 ms | 596885 tok/s)
step   26/6676 | train loss 6.2952 | lr_scale 1.00e+00 | (877.99 ms | 597143 tok/s)
step   27/6676 | train loss 6.2863 | lr_scale 1.00e+00 | (880.89 ms | 595178 tok/s)
step   28/6676 | train loss 6.1839 | lr_scale 1.00e+00 | (880.93 ms | 595152 tok/s)
step   29/6676 | train loss 6.2608 | lr_scale 1.00e+00 | (879.70 ms | 595983 tok/s)
step   30/6676 | train loss 6.1791 | lr_scale 1.00e+00 | (881.42 ms | 594823 tok/s)
step   31/6676 | train loss 6.2140 | lr_scale 1.00e+00 | (880.76 ms | 595266 tok/s)
step   32/6676 | train loss 6.1020 | lr_scale 1.00e+00 | (882.17 ms | 594316 tok/s)
step   33/6676 | train loss 6.1077 | lr_scale 1.00e+00 | (882.89 ms | 593829 tok/s)
step   34/6676 | train loss 6.0905 | lr_scale 1.00e+00 | (883.37 ms | 593508 tok/s)
step   35/6676 | train loss 6.1618 | lr_scale 1.00e+00 | (882.22 ms | 594281 tok/s)
step   36/6676 | train loss 6.0684 | lr_scale 1.00e+00 | (883.57 ms | 593372 tok/s)
step   37/6676 | train loss 6.0830 | lr_scale 1.00e+00 | (885.20 ms | 592280 tok/s)
step   38/6676 | train loss 6.0379 | lr_scale 1.00e+00 | (885.72 ms | 591934 tok/s)
step   39/6676 | train loss 6.0624 | lr_scale 1.00e+00 | (902.16 ms | 581148 tok/s)
step   40/6676 | train loss 5.9839 | lr_scale 1.00e+00 | (885.54 ms | 592053 tok/s)
step   41/6676 | train loss 6.0095 | lr_scale 1.00e+00 | (886.44 ms | 591453 tok/s)
step   42/6676 | train loss 5.9636 | lr_scale 1.00e+00 | (888.73 ms | 589930 tok/s)
step   43/6676 | train loss 6.0477 | lr_scale 1.00e+00 | (889.83 ms | 589198 tok/s)
step   44/6676 | train loss 5.9231 | lr_scale 1.00e+00 | (888.89 ms | 589823 tok/s)
step   45/6676 | train loss 5.9101 | lr_scale 1.00e+00 | (890.92 ms | 588478 tok/s)

First run with torch.compile disabled:

step 6676/6676 | train loss 3.2916 | lr_scale 5.00e-04 | (883.35 ms | 593522 tok/s)
val loss 3.2770633697509766
peak memory consumption: 62878 MiB

Trying with the nightly image instead of latest, and it seems to hit the same issue:

./run.sh 
W1006 20:51:18.022267 70 site-packages/torch/distributed/run.py:793] 
W1006 20:51:18.022267 70 site-packages/torch/distributed/run.py:793] *****************************************
W1006 20:51:18.022267 70 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1006 20:51:18.022267 70 site-packages/torch/distributed/run.py:793] *****************************************
Running pytorch 2.6.0a0+git77fba0c
using device: cuda:6
using device: cuda:1
using device: cuda:7
using device: cuda:3
using device: cuda:5
using device: cuda:0
using device: cuda:4
using device: cuda:2
Training DataLoader: total number of tokens: 10255324043 across 103 files
Validation DataLoader: total number of tokens: 100000000 across 1 files
compiling the model...
libibverbs: Warning: couldn't open config directory '/etc/libibverbs.d'.
libibverbs: Warning: couldn't open config directory '/etc/libibverbs.d'.
libibverbs: Warning: couldn't open config directory '/etc/libibverbs.d'.
libibverbs: Warning: couldn't open config directory '/etc/libibverbs.d'.
libibverbs: Warning: couldn't open config directory '/etc/libibverbs.d'.
libibverbs: Warning: couldn't open config directory '/etc/libibverbs.d'.
libibverbs: Warning: couldn't open config directory '/etc/libibverbs.d'.
libibverbs: Warning: couldn't open config directory '/etc/libibverbs.d'.
/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:171: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:171: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:171: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:171: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:171: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:171: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:171: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:171: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
val loss 10.966203689575195
[rank1]: Traceback (most recent call last):
[rank1]:   File "train_gpt2.py", line 479, in <module>
[rank1]:     loss.backward()
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_tensor.py", line 581, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/autograd/function.py", line 307, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2048, in backward
[rank1]:     out = call_compiled_backward()
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1980, in call_compiled_backward
[rank1]:     out = call_func_at_runtime_with_args(
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
[rank1]:     out = normalize_as_list(f(args))
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/codecache.py", line 1562, in __call__
[rank1]:     return self.current_callable(inputs)
[rank1]:   File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/utils.py", line 1991, in run
[rank1]:     return model(new_inputs)
[rank1]:   File "/tmp/torchinductor_root/ex/cexygkqmw2w7mvgo6leniyrthpyyjjpivrfi46sxowqcjraiw5e7.py", line 600, in call
[rank1]:     assert_size_stride(getitem_4, (64, 12, 1024), (12288, 1024, 1))

OK, found a version of PyTorch that works! Seems like much higher tok/s too.

https://hub.docker.com/layers/rocm/pytorch/latest-internal/images/sha256-b9a1c61f20ef013264c73b24c258cf6b06620d00399c0fb85f4f5db6f7992f04?context=explore
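
Pulling and running that image should look roughly like the earlier invocation, just swapping in the latest-internal tag from the link above:

docker pull rocm/pytorch:latest-internal

docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G -v /mnt/drive1:/mnt/drive1 rocm/pytorch:latest-internal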

Running pytorch 2.4.0a0+git8531805

step    1/6676 | train loss 10.9977 | lr_scale 1.00e+00 | (176327.68 ms | 2973 tok/s)
step    2/6676 | train loss 9.2039 | lr_scale 1.00e+00 | (853.15 ms | 614531 tok/s)
step    3/6676 | train loss 7.8569 | lr_scale 1.00e+00 | (909.38 ms | 576536 tok/s)
step    4/6676 | train loss 7.4966 | lr_scale 1.00e+00 | (791.65 ms | 662272 tok/s)
step    5/6676 | train loss 7.6010 | lr_scale 1.00e+00 | (779.18 ms | 672868 tok/s)
step    6/6676 | train loss 8.0912 | lr_scale 1.00e+00 | (772.11 ms | 679028 tok/s)
step    7/6676 | train loss 8.1390 | lr_scale 1.00e+00 | (790.10 ms | 663570 tok/s)
step    8/6676 | train loss 7.7398 | lr_scale 1.00e+00 | (781.98 ms | 670463 tok/s)
step    9/6676 | train loss 7.6330 | lr_scale 1.00e+00 | (773.85 ms | 677505 tok/s)
step   10/6676 | train loss 7.2317 | lr_scale 1.00e+00 | (778.70 ms | 673286 tok/s)
step   11/6676 | train loss 7.0780 | lr_scale 1.00e+00 | (769.38 ms | 681438 tok/s)
step   12/6676 | train loss 6.9914 | lr_scale 1.00e+00 | (774.42 ms | 677009 tok/s)
step   13/6676 | train loss 6.8845 | lr_scale 1.00e+00 | (775.21 ms | 676314 tok/s)
step   14/6676 | train loss 6.7133 | lr_scale 1.00e+00 | (773.74 ms | 677601 tok/s)
step   15/6676 | train loss 6.7093 | lr_scale 1.00e+00 | (777.38 ms | 674428 tok/s)
step   16/6676 | train loss 6.6078 | lr_scale 1.00e+00 | (767.97 ms | 682698 tok/s)
step   17/6676 | train loss 6.5961 | lr_scale 1.00e+00 | (787.01 ms | 666174 tok/s)
step   18/6676 | train loss 6.6134 | lr_scale 1.00e+00 | (775.62 ms | 675957 tok/s)
step   19/6676 | train loss 6.5873 | lr_scale 1.00e+00 | (777.11 ms | 674661 tok/s)
step   20/6676 | train loss 6.4243 | lr_scale 1.00e+00 | (772.23 ms | 678926 tok/s)
step   21/6676 | train loss 6.4449 | lr_scale 1.00e+00 | (782.67 ms | 669872 tok/s)
step   22/6676 | train loss 6.4932 | lr_scale 1.00e+00 | (781.58 ms | 670803 tok/s)
step   23/6676 | train loss 6.3389 | lr_scale 1.00e+00 | (782.05 ms | 670402 tok/s)
step   24/6676 | train loss 6.3687 | lr_scale 1.00e+00 | (767.46 ms | 683149 tok/s)
step   25/6676 | train loss 6.2571 | lr_scale 1.00e+00 | (770.30 ms | 680632 tok/s)
step   26/6676 | train loss 6.2845 | lr_scale 1.00e+00 | (784.56 ms | 668254 tok/s)
step   27/6676 | train loss 6.2624 | lr_scale 1.00e+00 | (777.18 ms | 674607 tok/s)
step   28/6676 | train loss 6.1619 | lr_scale 1.00e+00 | (777.76 ms | 674099 tok/s)
step   29/6676 | train loss 6.2400 | lr_scale 1.00e+00 | (768.77 ms | 681982 tok/s)
step   30/6676 | train loss 6.1702 | lr_scale 1.00e+00 | (780.84 ms | 671438 tok/s)

I will close this as it is working now.