flash attention: simplified
Opened this issue · 2 comments
Deleted user commented
I tried to imitate your educational coding style hehe.
Here's a pure PyTorch implementation of Flash Attention, hope you like it @karpathy
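Before the full function, here is the online-softmax identity the inner loop maintains, as a tiny standalone sketch I put together for illustration (the names running_max, running_den, running_out are mine, not from the function below): process one chunk of scores at a time while rescaling a running max, denominator, and normalized weighted sum.

import torch

# Minimal sketch of the online-softmax recurrence, for a single query row
# with scalar "values"; the real function does the same per tile.
scores = torch.randn(16)
values = torch.randn(16)
running_max = torch.tensor(-1e10)
running_den = torch.tensor(0.0)
running_out = torch.tensor(0.0)
for chunk_scores, chunk_values in zip(scores.split(4), values.split(4)):
    # new running maximum over everything seen so far
    new_max = torch.maximum(running_max, chunk_scores.max())
    p = torch.exp(chunk_scores - new_max)
    # rescale the old denominator to the new maximum, then add this chunk
    old_den = running_den * torch.exp(running_max - new_max)
    new_den = old_den + p.sum()
    # rescale the old normalized output and add this chunk's contribution
    running_out = running_out * (old_den / new_den) + (p * chunk_values).sum() / new_den
    running_max, running_den = new_max, new_den
# should match a full softmax-weighted sum computed in one shot
reference = (torch.softmax(scores, dim=0) * values).sum()
print(torch.allclose(running_out, reference, atol=1e-6))  # expect: True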
import torch

def flash_attention(Q, K, V, is_causal=True, BLOCK_SIZE: int = 64):
    NEG_INFINITY = -1e10
    EPS = 1e-10
    B, nh, T, H = Q.shape
    scale = H ** -0.5
    assert Q.shape == K.shape == V.shape, "Q, K, V must all have shape (B, nh, T, H)"
    # TODO: allow sequences shorter than one block
    assert T >= BLOCK_SIZE, "For small sequences, use standard attention!"
    # initialize running buffers: output, row-wise maximum, softmax denominator
    outputs = torch.zeros_like(Q)
    maximums = torch.full((B, nh, T, 1), fill_value=NEG_INFINITY, device=Q.device, dtype=Q.dtype)
    denominators = torch.full((B, nh, T, 1), fill_value=EPS, device=Q.device, dtype=Q.dtype)
    # chop up the matrices along the sequence dimension
    Q_blocks, K_blocks, V_blocks = map(
        lambda x: torch.split(x, BLOCK_SIZE, dim=2),
        (Q, K, V)
    )
    O_blocks, M_blocks, D_blocks = map(
        lambda x: list(torch.split(x, BLOCK_SIZE, dim=2)),
        (outputs, maximums, denominators)
    )
    # helper indices for the causal mask
    positions = torch.arange(0, T, device=Q.device)
    K_index_blocks = torch.split(positions[None, :], BLOCK_SIZE, dim=1)
    Q_index_blocks = torch.split(positions[:, None], BLOCK_SIZE, dim=0)
    for k_index in range(len(K_blocks)):
        k_block = K_blocks[k_index]
        v_block = V_blocks[k_index]
        for q_index in range(len(Q_blocks)):
            # calculate attention scores for this (query block, key block) tile
            q_block = Q_blocks[q_index]
            attn = q_block @ k_block.transpose(-2, -1) * scale
            # apply the causal mask only when requested
            if is_causal:
                causal_mask = K_index_blocks[k_index] <= Q_index_blocks[q_index]
                attn = torch.where(causal_mask, attn, NEG_INFINITY)
            # calculate the new running maximum attention score per query vector
            old_maximum = M_blocks[q_index]
            local_maximum, _ = torch.max(attn, dim=-1, keepdim=True)
            new_maximum = torch.maximum(old_maximum, local_maximum)
            # now that the maximum is known, we can safely exponentiate the attn scores
            attn = torch.exp(attn - new_maximum)
            # rescale the old softmax denominator to the new maximum, then add this block's sum
            denominator_scaler = torch.exp(old_maximum - new_maximum)
            denominator_update = torch.sum(attn, dim=-1, keepdim=True)
            old_denominator = D_blocks[q_index] * denominator_scaler
            new_denominator = old_denominator + denominator_update
            # rescale the old output and add this block's contribution
            output_scaler = old_denominator / new_denominator
            output_update = attn @ v_block / new_denominator
            old_output = O_blocks[q_index] * output_scaler
            new_output = old_output + output_update
            # store the new maximums, denominators and attention output
            M_blocks[q_index] = new_maximum
            D_blocks[q_index] = new_denominator
            O_blocks[q_index] = new_output
    # patch the per-block outputs back together into a single (B, nh, T, H) tensor
    return torch.cat(O_blocks, dim=2)
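As a quick sanity check (my own snippet, assuming PyTorch >= 2.0 for F.scaled_dot_product_attention), the output should match the built-in attention up to float32 tolerance:

import torch
import torch.nn.functional as F

# Compare the block-wise implementation against PyTorch's reference attention.
B, nh, T, H = 2, 4, 256, 32
Q, K, V = (torch.randn(B, nh, T, H) for _ in range(3))
ours = flash_attention(Q, K, V, is_causal=True, BLOCK_SIZE=64)
reference = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
print(torch.allclose(ours, reference, atol=1e-5))  # expect: True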
Deleted user commented
Inspired by Shreyansh's implementation.
Playerrrrr commented
- Remind me later!