lucidrains/performer-pytorch

FastAttention is slower than F.scaled_dot_product_attention and uses more GPU memory


import torch
import torch.nn.functional as F
from performer_pytorch import FastAttention

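# FAVOR+ path: random-feature approximation of softmax attention
attn_fn = FastAttention(
    dim_heads=64,
    causal=False,
    no_projection=False)

# no_projection path: linear attention without the random-feature projection
attn_fn_2 = FastAttention(
    dim_heads=64,
    causal=False,
    no_projection=True)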

# shape: (batch, heads, seq_len, dim_head)
query = torch.randn(2, 10, 4096, 64).to('cuda')
key = torch.randn(2, 10, 4096, 64).to('cuda')
value = torch.randn(2, 10, 4096, 64).to('cuda')

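###### scaled_dot_product_attention
import time
# CUDA kernels launch asynchronously: without torch.cuda.synchronize() before
# reading the clock, these timings may capture launch overhead, not run time
a = time.time()
out1 = F.scaled_dot_product_attention(query, key, value)
print(f' scaled_dot_product_attention time is {time.time() - a}')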

###### with projection FA time
a = time.time()
out2 = attn_fn(query, key, value)
print('FA project time is:', time.time() - a)
loss_fn = torch.nn.MSELoss()
loss = loss_fn(out1.float(), out2.float())
print('FA project loss is', loss)

###### no projection FA time
a = time.time()
out2 = attn_fn_2(query, key, value)
print('FA efficient time is:', time.time() - a)
loss = loss_fn(out1.float(), out2.float())
print('FA efficient loss is', loss)
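The two FastAttention instances above differ only in `no_projection`. As a simplified sketch in plain PyTorch, not the library's implementation, the functions below contrast exact softmax attention with the two linear-attention variants. `softmax_attention`, `linear_attention`, `favor_features` and `no_projection_features` are hypothetical names; `favor_features` uses a plain Gaussian projection and omits the orthogonal random rows and numerical stabilizer the library uses, so it is only safe for small-norm inputs. `no_projection_features` reflects my reading of the library's non-causal `no_projection` branch (softmax over the head dimension for queries, over the sequence for keys) and should be checked against the source.

```python
import math

import torch


def softmax_attention(q, k, v):
    # Exact attention: builds an (n x n) score matrix per head, O(n^2 * d).
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return scores.softmax(dim=-1) @ v


def linear_attention(q_feat, k_feat, v):
    # Non-causal linear attention: an (m x d_v) context replaces the
    # (n x n) score matrix, so cost is O(n * m * d) instead of O(n^2 * d).
    context = k_feat.transpose(-2, -1) @ v              # (..., m, d_v)
    denom = q_feat @ k_feat.sum(dim=-2).unsqueeze(-1)   # (..., n, 1)
    return (q_feat @ context) / denom


def favor_features(x, proj):
    # Positive random features whose dot products estimate exp(q.k / sqrt(d)).
    # Scaling both q and k by d**-0.25 yields the usual 1/sqrt(d) temperature.
    d = x.shape[-1]
    x = x * d ** -0.25
    half_sq_norm = (x ** 2).sum(dim=-1, keepdim=True) / 2
    return torch.exp(x @ proj.T - half_sq_norm) / math.sqrt(proj.shape[0])


def no_projection_features(q, k):
    # Hypothetical mirror of the no_projection=True branch (non-causal):
    # not an approximation of softmax attention, a different kernel.
    return q.softmax(dim=-1), k.softmax(dim=-2)


torch.manual_seed(0)
b, h, n, d, m = 1, 2, 32, 16, 8192  # m = number of random features
q = 0.5 * torch.randn(b, h, n, d)
k = 0.5 * torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)
proj = torch.randn(m, d)

exact = softmax_attention(q, k, v)
favor = linear_attention(favor_features(q, proj), favor_features(k, proj), v)
nop = linear_attention(*no_projection_features(q, k), v)
print(f"max |exact - favor|  = {(exact - favor).abs().max().item():.4f}")
print(f"max |exact - no_proj| = {(exact - nop).abs().max().item():.4f}")
```

At the issue's sizes (n = 4096, d = 64 and, by my reading of the library default `int(dim_heads * log(dim_heads))`, m = 266 features), each n×n matmul costs n²·d ≈ 1.07×10⁹ multiply-adds per head, while each matmul on the linear path costs about n·m·d ≈ 7.0×10⁷. The saving is real but modest at this length, and the extra projection and `exp` work can eat into it. Since neither path computes exact softmax attention, a nonzero MSE against `scaled_dot_product_attention` is expected for both configurations.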

Hi @lucidrains, I found that on my V100 GPU, FastAttention is not faster than the standard F.scaled_dot_product_attention. I timed both FastAttention configurations (with and without projection) and got the results below. Am I using the FastAttention module incorrectly? Looking forward to your reply, many thanks!

scaled_dot_product_attention time is 0.0007033348083496094
FA project time is: 0.05042743682861328
FA efficient time is: 0.0007433891296386719
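The timings above come from `time.time()` around single calls. CUDA kernels launch asynchronously, so without a synchronize the clock may stop before the GPU work finishes, and a single call also includes any first-call setup. A minimal benchmarking sketch under those assumptions follows. `benchmark` is a hypothetical helper, not part of performer-pytorch: it warms up, averages several synchronized calls, and reports peak extra memory via `torch.cuda.max_memory_allocated` to check the memory claim in the title. It falls back to the CPU and a shorter sequence when no GPU is present, and only runs the FastAttention comparison if `performer_pytorch` is installed. It also moves each FastAttention module to the device, which moves its registered `projection_matrix` buffer too; the original script never does this, which may add a host-to-device copy on every call.

```python
import time

import torch
import torch.nn.functional as F


def benchmark(fn, *args, warmup=3, iters=10):
    """Return (mean seconds per call, peak extra bytes or None) for fn(*args)."""
    cuda = args[0].is_cuda
    with torch.no_grad():
        # Warm-up keeps one-time setup (allocations, kernel selection) out of the timing.
        for _ in range(warmup):
            fn(*args)
        if cuda:
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
            base = torch.cuda.memory_allocated()
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        if cuda:
            # Wait for all queued kernels before stopping the clock.
            torch.cuda.synchronize()
        elapsed = (time.perf_counter() - start) / iters
    peak = torch.cuda.max_memory_allocated() - base if cuda else None
    return elapsed, peak


device = "cuda" if torch.cuda.is_available() else "cpu"
# The issue's shape on GPU; a shorter sequence so the sketch also runs on CPU.
shape = (2, 10, 4096, 64) if device == "cuda" else (2, 10, 1024, 64)
q, k, v = (torch.randn(*shape, device=device) for _ in range(3))

t, mem = benchmark(F.scaled_dot_product_attention, q, k, v)
print(f"sdpa: {t * 1e3:.3f} ms/call, peak extra bytes: {mem}")

try:
    from performer_pytorch import FastAttention
except ImportError:
    FastAttention = None

if FastAttention is not None:
    for no_proj in (False, True):
        # .to(device) also moves the module's projection_matrix buffer.
        attn = FastAttention(dim_heads=64, causal=False, no_projection=no_proj).to(device)
        t, mem = benchmark(attn, q, k, v)
        print(f"FastAttention(no_projection={no_proj}): {t * 1e3:.3f} ms/call, "
              f"peak extra bytes: {mem}")
```

On the GPU, pairs of `torch.cuda.Event(enable_timing=True)` with `record()` and `elapsed_time()` are an equivalent way to time kernels; `perf_counter` plus `synchronize` is used here only to keep the CPU fallback simple.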