lucidrains/PaLM-rlhf-pytorch

I used other params with PaLM, but got an error

wac81 opened this issue · 4 comments

wac81 commented

import torch
from palm_rlhf_pytorch import PaLM

device = torch.device("cuda")

model = PaLM(
    num_tokens=256,   # (or 512, 1024)
    dim=2048,         # dim_head * heads
    depth=24,
    dim_head=256,     # always 256
    heads=8,
    flash_attn=True
).to(device)

wac81 commented

training: 0%| | 0/2000000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train_qa_webtext2.py", line 164, in
    accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)
  File "/home/wac/anaconda3/envs/py38/lib/python3.8/site-packages/accelerate/accelerator.py", line 1683, in backward
    loss.backward(**kwargs)
  File "/home/wac/anaconda3/envs/py38/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/wac/anaconda3/envs/py38/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
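The error text suggests rerunning with CUDA_LAUNCH_BLOCKING=1 so the reported stack trace points at the kernel that actually failed. A minimal sketch of doing that from inside the training script, assuming you would rather not prefix the shell command: the variable has to be in the environment before PyTorch initializes CUDA, so it goes at the very top of the script.

```python
import os

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so an error is
# raised at the op that caused it instead of at some later API call.
# It only takes effect if set before CUDA is initialized, so place this line
# before `import torch` in the training script.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# import torch  (and the rest of the training script) would follow here
```

Synchronous launches slow training down noticeably, so this is only worth keeping while locating the failing kernel.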


Which type of GPU are you using? Are you using PyTorch 2.0? Flash Attention requires an A100. Also, I do not believe Flash Attention supports dim_head larger than 128.
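A small sketch for answering the two questions above in one go. The helper name `describe_environment` is hypothetical; it only uses standard PyTorch calls (`torch.__version__`, `torch.cuda.is_available`, `torch.cuda.get_device_name`, `torch.cuda.get_device_capability`) and returns empty fields if PyTorch or a GPU is not available.

```python
import importlib.util


def describe_environment() -> dict:
    """Hypothetical helper: report the PyTorch version and GPU model/compute capability."""
    info = {"torch": None, "gpu": None, "capability": None}
    if importlib.util.find_spec("torch") is None:
        return info  # PyTorch is not installed in this environment
    import torch

    info["torch"] = torch.__version__
    if torch.cuda.is_available():
        info["gpu"] = torch.cuda.get_device_name(0)
        # (major, minor), e.g. (8, 0) for an A100 or (8, 6) for an RTX A6000
        info["capability"] = torch.cuda.get_device_capability(0)
    return info


print(describe_environment())
```

The compute capability is the more useful field here: the A100 is (8, 0), while the A6000 is an Ampere card with capability (8, 6).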

FlashAttention currently supports:
1. Turing, Ampere, Ada, or Hopper GPUs (e.g., A100).
2. fp16 and bf16
3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100 or H100.
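The three constraints above can be encoded as a quick pre-flight check. This is a sketch of the FlashAttention library's listed limits only, not of what PyTorch's built-in attention or this repo does; the function name `flash_attn_problems` is hypothetical, and the mapping of architecture names to compute capabilities (Turing 7.5, Ampere 8.0/8.6/8.7, Ada 8.9, Hopper 9.0) is an assumption stated in the code.

```python
from typing import List, Tuple

# Assumed compute capabilities for the architectures named in the list above.
SUPPORTED_ARCHS = {(7, 5), (8, 0), (8, 6), (8, 7), (8, 9), (9, 0)}
# A100 is sm_80 and H100 is sm_90: the only ones allowed backward with head dim > 64.
LARGE_HEAD_BWD_ARCHS = {(8, 0), (9, 0)}


def flash_attn_problems(dim_head: int, dtype: str,
                        capability: Tuple[int, int],
                        training: bool = True) -> List[str]:
    """Hypothetical check: list the constraints a config violates (empty list = OK)."""
    problems = []
    if capability not in SUPPORTED_ARCHS:
        problems.append(f"compute capability {capability} is not Turing/Ampere/Ada/Hopper")
    if dtype not in ("fp16", "bf16"):
        problems.append(f"dtype {dtype!r} is not fp16 or bf16")
    if not (0 < dim_head <= 128 and dim_head % 8 == 0):
        problems.append(f"dim_head={dim_head} is not a multiple of 8 in the range 8..128")
    if training and dim_head > 64 and capability not in LARGE_HEAD_BWD_ARCHS:
        problems.append(f"backward with dim_head={dim_head} > 64 needs an A100 or H100")
    return problems


# The configuration from this issue on an RTX A6000 (sm_86), assuming fp16 training.
for reason in flash_attn_problems(dim_head=256, dtype="fp16", capability=(8, 6)):
    print("-", reason)
```

For the config in this issue the check reports two violations: dim_head=256 exceeds 128, and training with dim_head > 64 on an A6000 falls outside the A100/H100 requirement.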

wac81 commented

I'm using an A6000. Can't I enable FA (Flash Attention) on it?
How do I set dim_head larger than 256?

wac81 commented

It runs fine with dim_head = 256 when I comment out flash_attn.
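Not suggested in the thread, but one way to stay within FlashAttention's head-dimension limit while keeping the 2048-wide attention projection from the original config (dim_head * heads = 256 * 8) is to use a smaller dim_head with more heads. The hypothetical helper `head_layouts` below enumerates such layouts; whether any of them trains as well as dim_head = 256 is untested here.

```python
def head_layouts(inner_dim: int, max_dim_head: int = 128, multiple: int = 8):
    """Hypothetical helper: (dim_head, heads) pairs with dim_head * heads == inner_dim,
    where dim_head is a multiple of `multiple` and at most `max_dim_head`."""
    start = max_dim_head - (max_dim_head % multiple)  # largest allowed multiple
    return [
        (d, inner_dim // d)
        for d in range(start, 0, -multiple)
        if inner_dim % d == 0
    ]


print(head_layouts(2048)[:3])  # [(128, 16), (64, 32), (32, 64)]
```

Per the support list quoted earlier, backward with dim_head > 64 needs an A100 or H100, so on an A6000 the candidate for training would be `head_layouts(2048, max_dim_head=64)[0]`, i.e. dim_head=64 with heads=32.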