lucidrains/x-transformers

RuntimeError: No available kernel. Aborting execution.

kyegomez opened this issue · 12 comments

training:   0% 0/100000 [00:00<?, ?it/s]/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/attend.py:168: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:545.)
  out = F.scaled_dot_product_attention(
/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/attend.py:168: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:338.)
  out = F.scaled_dot_product_attention(
/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/attend.py:168: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:547.)
  out = F.scaled_dot_product_attention(
/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/attend.py:168: UserWarning: Both fused kernels do not support non-null attn_mask. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:191.)
  out = F.scaled_dot_product_attention(
training:   0% 0/100000 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/trainandromeda.py", line 109, in <module>
    loss = model(next(train_loader))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/autoregressive_wrapper.py", line 141, in forward
    logits = self.net(inp, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/x_transformers.py", line 1365, in forward
    x = self.attn_layers(x, mask = mask, mems = mems, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/x_transformers.py", line 1112, in forward
    out, inter = block(x, mask = mask, context_mask = self_attn_context_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/x_transformers.py", line 541, in forward
    return self.fn(x, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/x_transformers.py", line 823, in forward
    out, intermediates = self.attend(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/attend.py", line 198, in forward
    return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
  File "/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/Optimus-Prime/x_transformers/attend.py", line 168, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: No available kernel.  Aborting execution.

Flash attention doesn't work on the A100.

CODE:


from torch.serialization import load
import torch 
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

#training
import random
import tqdm
import gzip
import numpy as np
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import os
# from torch.utils.tensorboard import SummaryWriter
# from torchmetrics import MetricCollection, Accuracy


# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 1024
SAVE_EVERY=500


# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

model = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=5000,
    use_abs_pos_emb = False,
    attn_layers = Decoder(
        dim=512,
        depth=6,
        heads=8,
        alibi_pos_bias=True,
        alibi_num_heads=4,
        rotary_xpos=True,
        attn_flash = True,
        deepnorm=True,
        # dynamic_pos_bias=True,
        # dynamic_pos_bias_log_distance=False,
        shift_tokens=1,
        # rel_pos_bias=True
    )
)


model = AutoregressiveWrapper(model)
model.cuda()

with gzip.open('./enwik8.gz') as file:
  data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
  train_x, valid_x = np.split(data, [int(90e6)])
  data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))

# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

# #init tensorboard 
# writer = SummaryWriter(log_dir="./log")

# #define metrics
# metrics = MetricCollection({'accuracy': Accuracy(num_classes=num_classes, task='classification')})

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()


    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

            # # Calculate validation metrics
            # val_metrics = MetricCollection({'val_accuracy': Accuracy()})
            # val_metrics(loss, model(next(val_loader)).argmax(dim=-1))

            # # Add validation metrics to the SummaryWriter
            # writer.add_scalar('Validation/Accuracy', val_metrics['val_accuracy'].compute(), global_step=i)

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

    # Save the model every save_every iterations
    if i % SAVE_EVERY == 0:
        # Specify the directory and filename to save the model
        save_dir = './saved_models/'
        save_filename = 'model_checkpoint.pt'

        # Create the save directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)

        # Save the model checkpoint
        torch.save(model.state_dict(), os.path.join(save_dir, save_filename))
        print(f"Model saved at iteration {i}")

#     # Add training metrics to the SummaryWriter
#     writer.add_scalar('Training/Accuracy', metrics['accuracy'].compute(), global_step=i)

#     # Close the SummaryWriter
# writer.close()

And, NVIDIA SMI:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   30C    P0    44W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

It was working before, when not on an A100, but now it doesn't 😔

It works with a V100.
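
For reference, the two cards report different compute capabilities, which is what fused-kernel availability keys off; a quick check (plain PyTorch, nothing library-specific):

# A100-class cards report compute capability (8, 0); V100-class cards report (7, 0)
import torch
print(torch.cuda.get_device_name(0), torch.cuda.get_device_capability(0))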

!git clone https://github.com/kyegomez/Optimus-Prime.git
%cd Optimus-Prime
!pip install --upgrade torch
# !pip install -r requirements.txt
!pip install einops
# !pip install --upgrade torch

# %cd Optimus-Prime
# # %cd examples
# # !ls
# !python3 trainandromeda.py 
# #%cd enwik8_simple
# # !python trainandromeda.py


(The next cell is the same trainandromeda.py training script pasted above, with a few commented-out debugging tweaks.)

That is strange; the A100 is the last machine I would have expected the kernels not to support.

Yeah, maybe the bug is in the gradient accumulation logic.

Any fix for this? I'm trying to train.

I'm facing the same issue:

File "/home/users/user/falcontune/venv_falcontune/lib/python3.8/site-packages/falcontune-0.1.0-py3.8.egg/falcontune/model/falcon/model.py", line 527, in forward
    attn_output = F.scaled_dot_product_attention(
RuntimeError: No available kernel.  Aborting execution.

I'm running it on 3x Tesla V100-SXM2-32GB with the config below:
OS: Ubuntu 18.04.5 LTS
Libs:

bitsandbytes==0.39.0
transformers==4.29.2
triton==2.0.0
sentencepiece==0.1.99
datasets==2.12.0
peft==0.3.0
torch==2.0.1+cu118
accelerate==0.19.0
safetensors==0.3.1
einops==0.6.1
wandb==0.15.3
scipy==1.10.1

@chintan-donda download the latest PyTorch nightly or preview release; this worked for me.
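
In case it helps, that suggestion corresponds to a cell along these lines (a sketch; the cu118 index URL is an assumption, adjust it to your CUDA toolkit):

# nightly PyTorch build; the cu118 suffix assumes a CUDA 11.8 toolchain
!pip uninstall -y torch
!pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118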

@chintan-donda download the latest PyTorch nightly or preview release; this worked for me.

Didn't help. Any other fix?

It worked for me

@kyegomez what CUDA version do you use?
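
A quick way to answer that from inside the failing environment (plain PyTorch calls; the backend-flag helpers assume PyTorch 2.0+):

# Report the versions and SDPA backend flags relevant to this error
import torch

print(torch.__version__)        # e.g. 2.0.1+cu118
print(torch.version.cuda)       # CUDA version PyTorch was built against
print(torch.cuda.get_device_name(0))
print(torch.backends.cuda.flash_sdp_enabled())
print(torch.backends.cuda.mem_efficient_sdp_enabled())
print(torch.backends.cuda.math_sdp_enabled())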

PyTorch's blog says that to use the built-in flash attention, attn_mask should be None. I see you set alibi_pos_bias=True; x-transformers then generates a float attn_mask internally, which causes the issue.

I think flash attention is not compatible with ALiBi. See Dao-AILab/flash-attention#214.
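
A minimal standalone sketch of that interaction (not x-transformers code; it assumes PyTorch 2.0.x, where torch.backends.cuda.sdp_kernel is a context manager): with only the flash backend allowed, a non-null attn_mask leaves SDPA with no kernel to run, which is the RuntimeError above, while allowing the math backend gives it a fallback.

# Sketch only: shows how a float attn_mask interacts with the SDPA backend selection
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float16) for _ in range(3))
alibi_like_bias = torch.randn(1, 8, 1024, 1024, device='cuda', dtype=torch.float16)  # stand-in for the ALiBi bias

# Flash backend only: the fused kernel rejects the non-null attn_mask -> "No available kernel"
try:
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        F.scaled_dot_product_attention(q, k, v, attn_mask=alibi_like_bias)
except RuntimeError as err:
    print(err)

# Math backend allowed as a fallback: the same call succeeds, just without the fused kernel
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=alibi_like_bias)
print(out.shape)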

That is strange; the A100 is the last machine I would have expected the kernels not to support.

Hi @lucidrains
I was training your naturalspeech2 and ran into the same problem, so it was lucky to meet you here.
Changing use_flash_attn = True to use_flash_attn = False in the Model class of naturalspeech2 worked for me.
Likewise, attn_flash = True in the TransformerWrapper config here can be changed to attn_flash = False; give it a try.
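
Concretely, for the trainandromeda.py script pasted at the top of this issue that is a one-flag change on the Decoder. A sketch of the adjusted config (it trades the fused kernel for the regular attention path, so the ALiBi float bias no longer hits SDPA's mask restriction):

from x_transformers import Decoder

# Same Decoder config as in the script above, with only attn_flash flipped off
attn_layers = Decoder(
    dim=512,
    depth=6,
    heads=8,
    alibi_pos_bias=True,
    alibi_num_heads=4,
    rotary_xpos=True,
    attn_flash=False,   # was True; the flash SDPA path rejects the float attention bias ALiBi needs
    deepnorm=True,
    shift_tokens=1,
)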