MAGICS-LAB/DNABERT_2

Getting the embedding of a sequence

CorvusVaine opened this issue · 2 comments

Hello, I am trying to get the output embeddings of my dataset from DNABERT-2 and then use them with another model. I am using the following code:

import os
import transformers
import torch

# Compute one embedding per DNA sequence with DNABERT-2.
# NOTE: CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
# initialised in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# The model's bundled Triton flash-attention kernel asserts that its inputs are
# CUDA tensors (`assert q.is_cuda ...` in the traceback of this issue), so if
# this falls back to "cpu" while that kernel is active, the forward pass fails.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = transformers.AutoModel.from_pretrained(
    "zhihan1996/DNABERT-2-117M",
    trust_remote_code=True,
)
model.config.use_cache = False
model.config.pretraining_tp = 1
model.eval()

# Only tokenizer-construction options go here; padding / truncation /
# max_length are per-call arguments and are passed on the tokenizer(...) call.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "zhihan1996/DNABERT-2-117M",
    model_max_length=42,
    padding_side="right",
    use_fast=True,
    trust_remote_code=True,
)

seqs = [
    "ATCTAGCTAGACGTTACGCTACGCATGTACGTACGCTCAGTAGCATGCTAGCT",
    "CGTAGGTCGTCTAGCTGATCAGTACGCATGCATAGCTAGCTGCATCGTAGCATCGATGATCGATCGATGATGC",
]
model.to(device)
# Tokenize `seqs` (the original referenced an undefined name `a`).
inputs = tokenizer(
    seqs,
    padding="max_length",
    truncation=True,
    max_length=40,
    return_tensors="pt",
)
inputs = {key: value.to(device) for key, value in inputs.items()}
with torch.no_grad():
    outputs = model(**inputs)
    # outputs[0] is assumed to be per-token hidden states of shape
    # (batch, seq_len, hidden) -- TODO confirm for this checkpoint.
    hidden_states = outputs[0]
    # Mean-pool over real (non-padding) tokens only.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
    embeddings = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)

When I run this code, I get the following error:

AssertionError Traceback (most recent call last)
Cell In[11], line 9
7 print(inputs)
8 with torch.no_grad():
----> 9 outputs = model(**inputs)
10 print(outputs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:609, in BertModel.forward(self, input_ids, token_type_ids, attention_mask, position_ids, output_all_encoded_layers, masked_tokens_mask, **kwargs)
606 first_col_mask[:, 0] = True
607 subset_mask = masked_tokens_mask | first_col_mask
--> 609 encoder_outputs = self.encoder(
610 embedding_output,
611 attention_mask,
612 output_all_encoded_layers=output_all_encoded_layers,
613 subset_mask=subset_mask)
615 if masked_tokens_mask is None:
616 sequence_output = encoder_outputs[-1]

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:447, in BertEncoder.forward(self, hidden_states, attention_mask, output_all_encoded_layers, subset_mask)
445 if subset_mask is None:
446 for layer_module in self.layer:
--> 447 hidden_states = layer_module(hidden_states,
448 cu_seqlens,
449 seqlen,
450 None,
451 indices,
452 attn_mask=attention_mask,
453 bias=alibi_attn_mask)
454 if output_all_encoded_layers:
455 all_encoder_layers.append(hidden_states)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:328, in BertLayer.forward(self, hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias)
306 def forward(
307 self,
308 hidden_states: torch.Tensor,
(...)
314 bias: Optional[torch.Tensor] = None,
315 ) -> torch.Tensor:
316 """Forward pass for a BERT layer, including both attention and MLP.
317
318 Args:
(...)
326 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
327 """
--> 328 attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
329 subset_idx, indices, attn_mask, bias)
330 layer_output = self.mlp(attention_output)
331 return layer_output

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:241, in BertUnpadAttention.forward(self, input_tensor, cu_seqlens, max_s, subset_idx, indices, attn_mask, bias)
219 def forward(
220 self,
221 input_tensor: torch.Tensor,
(...)
227 bias: Optional[torch.Tensor] = None,
228 ) -> torch.Tensor:
229 """Forward pass for scaled self-attention without padding.
230
231 Arguments:
(...)
239 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
240 """
--> 241 self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
242 attn_mask, bias)
243 if subset_idx is not None:
244 return self.output(index_first_axis(self_output, subset_idx),
245 index_first_axis(input_tensor, subset_idx))

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py:182, in BertUnpadSelfAttention.forward(self, hidden_states, cu_seqlens, max_seqlen_in_batch, indices, attn_mask, bias)
180 bias_dtype = bias.dtype
181 bias = bias.to(torch.float16)
--> 182 attention = flash_attn_qkvpacked_func(qkv, bias)
183 attention = attention.to(orig_dtype)
184 bias = bias.to(bias_dtype)

File ~/.local/lib/python3.10/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
503 if not torch._C._are_functorch_transforms_active():
504 # See NOTE: [functorch vjp and autograd interaction]
505 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506 return super().apply(*args, **kwargs) # type: ignore[misc]
508 if cls.setup_context == _SingleLevelFunction.setup_context:
509 raise RuntimeError(
510 'In order to use an autograd.Function with functorch transforms '
511 '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
512 'staticmethod. For more details, please see '
513 'https://pytorch.org/docs/master/notes/extending.func.html')

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/flash_attn_triton.py:1021, in _FlashAttnQKVPackedFunc.forward(ctx, qkv, bias, causal, softmax_scale)
1019 if qkv.stride(-1) != 1:
1020 qkv = qkv.contiguous()
-> 1021 o, lse, ctx.softmax_scale = _flash_attn_forward(
1022 qkv[:, :, 0],
1023 qkv[:, :, 1],
1024 qkv[:, :, 2],
1025 bias=bias,
1026 causal=causal,
1027 softmax_scale=softmax_scale)
1028 ctx.save_for_backward(qkv, o, lse, bias)
1029 ctx.causal = causal

File ~/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/flash_attn_triton.py:781, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
778 assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
779 assert q.dtype in [torch.float16,
780 torch.bfloat16], 'Only support fp16 and bf16'
--> 781 assert q.is_cuda and k.is_cuda and v.is_cuda
782 softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
784 has_bias = bias is not None

AssertionError:

I suspect this error comes from the structure of the network, which may require the data to be fed differently.
Could you tell me how to simply get the output embeddings from DNABERT-2, please?

Please try this:

import os
import numpy as np
import transformers
import torch
import torch.utils.data as util_data
import torch.nn as nn
import tqdm
import argparse
from sklearn.preprocessing import normalize



def calculate_llm_embedding(
    dna_sequences,
    model_name_or_path,
    model_max_length=400,
    batch_size=25,
    attn_implementation="flash_attention_2",
):
    """Return one mean-pooled embedding per DNA sequence.

    Sequences are sorted by length so that each batch holds sequences of
    similar length (less padding, faster). Each batch is run through the
    model, its token states are mean-pooled over non-padding tokens, and the
    rows are finally restored to the caller's input order.

    Args:
        dna_sequences: Sequence of DNA strings.
        model_name_or_path: Hugging Face model id or local path.
        model_max_length: Maximum tokenized length; longer inputs are truncated.
        batch_size: Number of sequences per GPU per forward pass.
        attn_implementation: Forwarded to ``AutoModel.from_pretrained``. Pass
            ``None`` to omit the keyword entirely (it was reportedly rejected
            by some transformers versions).

    Returns:
        A float32 ``numpy.ndarray`` with one row per input sequence, in the
        same order as ``dna_sequences``. An empty input returns an empty
        ``(0, 0)`` array without loading the tokenizer or model.
    """
    if len(dna_sequences) == 0:
        # The original crashed here with UnboundLocalError on `embeddings`.
        return np.empty((0, 0), dtype=np.float32)

    # Process sequences in length order: similar lengths per batch mean less padding.
    order = np.argsort([len(seq) for seq in dna_sequences])
    sorted_sequences = [dna_sequences[i] for i in order]

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path,
        cache_dir=None,
        model_max_length=model_max_length,
        padding_side="left",
        use_fast=True,
        trust_remote_code=True,
    )

    model_kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
    if attn_implementation is not None:
        model_kwargs["attn_implementation"] = attn_implementation
    model = transformers.AutoModel.from_pretrained(model_name_or_path, **model_kwargs)
    model.eval()

    n_gpu = torch.cuda.device_count()
    device = "cuda" if n_gpu > 0 else "cpu"
    if n_gpu > 1:
        model = nn.DataParallel(model)
    model.to(device)

    # Scale the batch by GPU count, but never to 0 (the original's value on a
    # machine without GPUs, which is an invalid batch size).
    step = batch_size * max(n_gpu, 1)

    # Batches are sliced directly instead of via a multi-worker DataLoader:
    # the workers only batched strings, and forking them after tokenizer use
    # triggered the huggingface/tokenizers fork warning.
    pooled_batches = []
    with torch.no_grad():
        for start in tqdm.tqdm(range(0, len(sorted_sequences), step)):
            batch = sorted_sequences[start:start + step]
            token_feat = tokenizer(
                batch,
                max_length=model_max_length,
                return_tensors="pt",
                padding="longest",
                truncation=True,
            )
            input_ids = token_feat["input_ids"].to(device)
            attention_mask = token_feat["attention_mask"].to(device)
            # Call the module (not .forward) so registered hooks also run.
            hidden = model(input_ids=input_ids, attention_mask=attention_mask)[0]
            # Pool in float32: summing bf16 values loses precision.
            hidden = hidden.float().cpu()
            mask = attention_mask.unsqueeze(-1).float().cpu()
            pooled_batches.append((hidden * mask).sum(dim=1) / mask.sum(dim=1))

    # One concatenation at the end instead of re-growing a tensor every batch.
    embeddings = torch.cat(pooled_batches, dim=0).numpy()

    # Undo the length sort so row i corresponds to dna_sequences[i].
    return embeddings[np.argsort(order)]

It appears that `AutoModel.from_pretrained()` does not accept `attn_implementation` as a parameter (although the config does), so your code raised an error for me. I modified it slightly and loaded the model as follows:

# Load the pretrained checkpoint with a customised config.
# AutoModel.from_config would only build the architecture with randomly
# initialised weights -- it does not load the checkpoint -- so embeddings from
# such a model would be meaningless. from_pretrained(..., config=config) loads
# the weights and applies the modified config.
config = transformers.AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config.attn_implementation = "flash_attention_2"
model = transformers.AutoModel.from_pretrained(
    model_name_or_path,
    config=config,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

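and it seems to load fine. (Note: `AutoModel.from_config` builds the architecture with randomly initialized weights and does not load the pretrained checkpoint; `AutoModel.from_pretrained(model_name_or_path, config=config, ...)` is needed to get meaningful embeddings.)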

However, I still get two errors in the forward pass:

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
0%| | 0/1 [00:00<?, ?it/s]

RuntimeError Traceback (most recent call last)
Cell In[47], line 1
----> 1 a= calculate_llm_embedding(["ATCGACTGTGCATAGCCTAGATAGCTACGTACGCTCAGCTGACTGATGCTACAGCGT","ATCGACTGTGCATAGCCTAGATAGCTACGTACGCTCAGCTGACTGATGCTACAGCGT","ATCGACTGTGCATAGCCTAGATAGCTACGTACGCTCAGCTGACTGATGCTACAGCGT","ATCGACTGTGCATAGCCTAGATAGCTACGTACGCTCAGCTGACTGATGCTACAGCGT"],"zhihan1996/DNABERT-2-117M")

Cell In[46], line 45, in calculate_llm_embedding(dna_sequences, model_name_or_path, model_max_length, batch_size)
43 input_ids = token_feat['input_ids'].cuda()
44 attention_mask = token_feat['attention_mask'].cuda()
---> 45 model_output = model.forward(input_ids=input_ids, attention_mask=attention_mask)[0].detach().cpu()
47 attention_mask = attention_mask.unsqueeze(-1).detach().cpu()
48 embedding = torch.sum(model_output*attention_mask, dim=1) / torch.sum(attention_mask, dim=1)

File ~/.local/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:171, in DataParallel.forward(self, *inputs, **kwargs)
169 return self.module(*inputs[0], **kwargs[0])
170 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 171 outputs = self.parallel_apply(replicas, inputs, kwargs)
172 return self.gather(outputs, self.output_device)

File ~/.local/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:181, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
180 def parallel_apply(self, replicas, inputs, kwargs):
--> 181 return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File ~/.local/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:89, in parallel_apply(modules, inputs, kwargs_tup, devices)
87 output = results[i]
88 if isinstance(output, ExceptionWrapper):
---> 89 output.reraise()
90 outputs.append(output)
91 return outputs

File ~/.local/lib/python3.10/site-packages/torch/_utils.py:643, in ExceptionWrapper.reraise(self)
639 exception = self.exc_type(msg)
640 except TypeError:
641 # If the exception takes multiple arguments, don't try to
642 # instantiate since we don't know how to
--> 643 raise RuntimeError(msg) from None
644 raise exception

RuntimeError: Caught CompilationError in replica 0 on device 0.
Original Traceback (most recent call last):
File "", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-83ca8b715a9dc5f32dc1110973485f64-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-3d2aedeb40d6d81c66a42791e268f98b-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 937, in build_triton_ir
generator.visit(fn.parse())
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 183, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/opt/conda/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
has_ret = self.visit_compound_statement(node.body)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
self.last_ret_type = self.visit(stmt)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 678, in visit_For
self.visit_compound_statement(node.body)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
self.last_ret_type = self.visit(stmt)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 319, in visit_AugAssign
self.visit(assign)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 301, in visit_Assign
values = self.visit(node.value)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 339, in visit_BinOp
rhs = self.visit(node.right)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/opt/conda/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 797, in visit_Call
return fn(*args, _builder=self.builder, **kws)
File "/home/groy/.local/lib/python3.10/site-packages/triton/impl/base.py", line 22, in wrapper
return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
output = module(*input, **kwargs)
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 609, in forward
encoder_outputs = self.encoder(
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 447, in forward
hidden_states = layer_module(hidden_states,
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 328, in forward
attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 241, in forward
self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
File "/home/groy/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/bert_layers.py", line 186, in forward
attention = flash_attn_qkvpacked_func(qkv, bias)
File "/home/groy/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/flash_attn_triton.py", line 1021, in forward
o, lse, ctx.softmax_scale = _flash_attn_forward(
File "/home/groy/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/d064dece8a8b41d9fb8729fbe3435278786931f1/flash_attn_triton.py", line 826, in _flash_attn_forward
_fwd_kernel[grid]( # type: ignore
File "/home/groy/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 90, in run
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
File "/home/groy/.local/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 199, in run
return self.fn.run(*args, **kwargs)
File "", line 41, in _fwd_kernel
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 1621, in compile
next_module = compile(module)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 1550, in
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 962, in ast_to_ttir
mod, _ = build_triton_ir(fn, signature, specialization, constants)
File "/home/groy/.local/lib/python3.10/site-packages/triton/compiler.py", line 942, in build_triton_ir
raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 114:24:
def _fwd_kernel(
Q,
K,
V,
Bias,
Out,
Lse,
TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_bb,
stride_bh,
stride_bm,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
# off_b = tl.program_id(1)
# off_h = tl.program_id(2)
# off_hb = off_b * nheads + off_h
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Initialize pointers to Q, K, V
# Adding parenthesis around indexing might use int32 math instead of int64 math?
# triton-lang/triton#741
# I'm seeing a tiny bit of difference (5-7us)
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (
offs_m[:, None] * stride_qm + offs_d[None, :])
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (
offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (
offs_n[:, None] * stride_vn + offs_d[None, :])
if BIAS_TYPE == 'vector':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
elif BIAS_TYPE == 'matrix':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (
offs_m[:, None] * stride_bm + offs_n[None, :])
else:
raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
# initialize pointer to m and l
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(q_ptrs,
mask=(offs_m[:, None] < seqlen_q) &
(offs_d[None, :] < headdim),
other=0.0)
# loop over k, v and update accumulator
end_n = seqlen_k if not IS_CAUSAL else tl.minimum(
(start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=offs_d[None, :] < headdim,
other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) &
(offs_d[None, :] < headdim),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
^

Thank you for your help