lucidrains/performer-pytorch

FastAttention doesn't give results in agreement with standard attention?

simonaxelrod opened this issue · 7 comments

Hi there,

I ran this code to compare the results of standard attention with fast attention. Surprisingly, I'm getting very large errors (about 80%). Any idea as to where this comes from?

import torch
import numpy as np
from performer_pytorch import FastAttention

num_nodes = 24
feat_dim = 128
nb_features = 8 * feat_dim

num_its = 5
errs = []

for _ in range(num_its):
    Q = torch.randn(1, 1, num_nodes, feat_dim)
    K = torch.randn(1, 1, num_nodes, feat_dim)
    V = torch.randn(1, 1, num_nodes, feat_dim)

    # fast attention
    
    attn = FastAttention(dim_heads=feat_dim,
                         nb_features=nb_features,
                         causal=False)
        
    fast = attn(q=Q,
                k=K,
                v=V)


    Q = Q.reshape(-1, feat_dim)
    K = K.reshape(-1, feat_dim)
    V = V.reshape(-1, feat_dim)

    
    # standard attention
    
    A = torch.exp(torch.matmul(Q, K.transpose(0, 1)) / feat_dim ** 0.5)
    ones = torch.ones(num_nodes)
    D_inv = torch.diag(1 / torch.matmul(A, ones))
    slow = torch.matmul(D_inv, torch.matmul(A, V))

    err = (abs(slow - fast).mean() / abs(slow).mean() * 100).item()
    
    errs.append(err)

mean_err = np.mean(errs)
std_err = np.std(errs)

print("Error is (%.2f +/- %.2f)%%" % (mean_err, std_err)) # prints Error is (73.28 +/- 1.99)%

@lucidrains Bumping this up again. Any thoughts on this?

@simonaxelrod
Does this kind of large error occur when experimenting with the original code written in Jax from Google?

I haven't tried it with Jax yet, but I'll give that a shot.

All right, I'm curious too :)

Hi @simonaxelrod,
I tried the above comparison using the SLiM Performer code, which was written by the original 'Performer' authors. It is also written in PyTorch, so it was easy to try.

import torch
import numpy as np
from slim_performer_model import MultiHeadAttention

batch = 1
num_nodes = 24 # seq_len
feat_dim = 64
n_heads = 1

num_its = 5
errs = []

for _ in range(num_its):


    # fast attention
    x = torch.randn((batch, num_nodes, feat_dim))  # x: [B, seq_len, feat_dim]
    attn = MultiHeadAttention(feature_type='favor+', n_heads=n_heads, hidden_dim=feat_dim, compute_type='iter')
    rfs = attn.sample_rfs(x.device)  # [n_heads, feat_dim, feat_dim]
    fast = attn.full_forward(x, rfs)  # x: [B, seq_len, feat_dim] -> fast: [B, seq_len, feat_dim]



    # '_get_original_qkv' is a helper I temporarily added to pull out the Q, K, V used in 'fast' (it is not part of the original 'MultiHeadAttention')
    Q, K, V = attn._get_original_qkv(x)  # -> Q, K, V: [B, seq_len, feat_dim]  Note that these are the original Q and K, not the feature-mapped Q' and K'.

    # standard attention
    A = torch.einsum('bid, bjd -> bij', Q, K) / feat_dim ** 0.5 # [B, seq_len, seq_len]
    A = torch.softmax(A, dim=-1)
    slow = torch.einsum('bij, bjd -> bid', A, V)  # [B, seq_len, feat_dim]


    err = (abs(slow - fast).mean() / abs(slow).mean() * 100).item()
    
    errs.append(err)

mean_err = np.mean(errs)
std_err = np.std(errs)

print("Error is (%.2f +/- %.2f)%%" % (mean_err, std_err)) # Error is (130.53 +/- 2.27)%

But the error is (130.53 +/- 2.27)%.
I don't know why we're getting such large errors...
@lucidrains, @simonaxelrod, is this normal?

@simonaxelrod I observed something similar.
One more observation: in the code that you linked, the error decreases a lot (to ~1.5%) if, when calculating the standard attention, we scale down the logits, for example by dividing the query matrix by 100 (sketched below). This probably makes the attention distribution much flatter, so I suspect that when Q, K, V are not learned, as in this case, Performer tends to produce a much flatter attention distribution than regular attention does. I am not sure whether this would still hold if this were part of a trained neural network, because then the weights might be adjusted so that it is no longer an issue.
Also, the authors have mentioned that "Backwards compatibility with pretrained models is available as a benefit from softmax approximation, via small finetuning...". So even though the two are compatible, transferring the weights of standard attention to Performer takes some finetuning.
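
For reference, here is a rough sketch of the logit-scaling change I described, reusing the names from the first script in this thread; the factor of 100 is arbitrary and the ~1.5% figure is just what I saw on my run, so treat it as illustrative.

    # inside the loop of the first script, right after Q, K, V are reshaped to (num_nodes, feat_dim)
    scale = 100.0  # flatten the standard-attention logits; the performer side is left untouched
    A = torch.exp(torch.matmul(Q / scale, K.transpose(0, 1)) / feat_dim ** 0.5)
    ones = torch.ones(num_nodes)
    D_inv = torch.diag(1 / torch.matmul(A, ones))
    slow_flat = torch.matmul(D_inv, torch.matmul(A, V))

    err = (abs(slow_flat - fast).mean() / abs(slow_flat).mean() * 100).item()  # this is where I saw ~1.5%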

Consider equation (7) of the paper. In that lemma, SM+ is defined as an expectation of some exponential terms, and the expectation is of course an integral over R^d. What is done in the code and in the paper is to approximate this integral by taking m = nb_features samples and orthogonalizing them. This induces a large error: the integral cannot be approximated well enough with only m points from R^d.

This could explain the errors observed by @simonaxelrod, or am I missing something here?
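
To make this concrete, here is a minimal sketch of that estimator for a single query/key pair, using plain i.i.d. Gaussian features (no orthogonalization), so it only illustrates the sampling error, not the exact constants used in the library; the helper name is just for the sketch.

import torch

d = 64  # per-head feature dimension, as in the second snippet above

def sm_plus_estimate(x, y, m):
    # positive random feature estimate of the softmax kernel SM(x, y) = exp(x^T y) from the lemma:
    #   SM(x, y) = E_{omega ~ N(0, I_d)} [ exp(omega^T x - ||x||^2 / 2) * exp(omega^T y - ||y||^2 / 2) ]
    w = torch.randn(m, x.shape[0])                 # m i.i.d. Gaussian features omega
    phi_x = torch.exp(w @ x - x.pow(2).sum() / 2)
    phi_y = torch.exp(w @ y - y.pow(2).sum() / 2)
    return (phi_x * phi_y).mean()                  # Monte Carlo average over the m features

# rough stand-in for one query/key pair after the 1/sqrt(d) temperature has been split between q and k
x = torch.randn(d) * d ** -0.25
y = torch.randn(d) * d ** -0.25
exact = torch.exp(x @ y)

for m in [512, 4096, 32768]:
    est = sm_plus_estimate(x, y, m)
    # the estimator is unbiased, but its per-sample variance grows roughly like exp(||x||^2 + ||y||^2),
    # so for vectors with norms on this scale the relative error typically stays large even at these m
    print(m, (abs(est - exact) / exact).item())

Increasing m shrinks the error only at the usual 1/sqrt(m) Monte Carlo rate, which would be consistent with nb_features = 1024 in the first script still leaving a large error for unnormalized Gaussian inputs.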