karpathy/nn-zero-to-hero

Flash Attention: simplified

Opened this issue · 2 comments

I tried to imitate your educational coding style hehe

Here's a pure PyTorch implementation of Flash Attention — hope you like it, @karpathy!

def flash_attention(Q, K, V, is_causal=True, BLOCK_SIZE:int=64):
    """Compute scaled dot-product attention tile by tile (Flash Attention style).

    Instead of materialising the full (T, T) score matrix, queries and
    keys/values are processed in tiles of ``BLOCK_SIZE`` sequence positions.
    For every query row a running maximum and a running softmax denominator
    are kept, so already-accumulated outputs can be rescaled whenever a new
    key tile raises the maximum ("online softmax").

    Args:
        Q, K, V: tensors of identical shape (B, nh, T, H) — batch, heads,
            sequence length, head size.
        is_causal: if True, query position t only attends to key positions
            <= t. If False, every query attends to every key.
        BLOCK_SIZE: number of sequence positions per tile. Sequences shorter
            than one tile are processed as a single (short) tile.

    Returns:
        The attention output, shape (B, nh, T, H).
    """
    NEG_INFINITY = -1e10
    EPS = 1e-10

    B, nh, T, H = Q.shape
    scale = H ** -0.5
    assert Q.shape == K.shape and Q.shape == V.shape, "Some of Q,K,V are misshapen!"

    # Initialize buffers on the inputs' device so the function also works on GPU.
    outputs = torch.zeros_like(Q)
    maximums = torch.full((B, nh, T, 1), fill_value=NEG_INFINITY, device=Q.device)
    denominators = torch.full((B, nh, T, 1), fill_value=EPS, device=Q.device)

    # Chop up matrices along the sequence dimension. The running buffers are
    # lists because their tiles are replaced on every update.
    Q_blocks, K_blocks, V_blocks = (
        torch.split(x, BLOCK_SIZE, dim=2) for x in (Q, K, V)
    )
    O_blocks, M_blocks, D_blocks = (
        list(torch.split(x, BLOCK_SIZE, dim=2))
        for x in (outputs, maximums, denominators)
    )

    # Helper variables for the causal mask: absolute positions of each tile.
    positions = torch.arange(0, T, device=Q.device)
    K_index_blocks = torch.split(positions[None, :], BLOCK_SIZE, dim=1)
    Q_index_blocks = torch.split(positions[:, None], BLOCK_SIZE, dim=0)

    for k_index, (k_block, v_block) in enumerate(zip(K_blocks, V_blocks)):
        for q_index, q_block in enumerate(Q_blocks):
            # Q and K share T and BLOCK_SIZE, so for k_index > q_index every key
            # in this tile comes after every query. The tile is fully masked and
            # would leave the running max/denominator/output unchanged.
            # Key tile 0 is always processed first, so the running max is
            # already finite here.
            if is_causal and k_index > q_index:
                continue

            # Calculate (optionally masked) attention scores.
            attn = q_block @ k_block.permute(0, 1, 3, 2) * scale
            if is_causal:
                causal_mask = K_index_blocks[k_index] <= Q_index_blocks[q_index]
                attn = torch.where(causal_mask, attn, NEG_INFINITY)

            # Calculate the new maximum attention score per query vector.
            old_maximum = M_blocks[q_index]
            local_maximum, _ = torch.max(attn, dim=-1, keepdim=True)
            new_maximum = torch.maximum(old_maximum, local_maximum)

            # Now that the maximum is known, we can safely exponentiate attn scores.
            attn = torch.exp(attn - new_maximum)

            # Adjust and update the softmax denominator.
            denominator_scaler = torch.exp(old_maximum - new_maximum)
            denominator_update = torch.sum(attn, dim=-1, keepdim=True)
            old_denominator = D_blocks[q_index] * denominator_scaler
            new_denominator = old_denominator + denominator_update

            # Adjust and update the output of attention: rescale the previous
            # normalized output to the new denominator, then add this tile.
            output_scaler = old_denominator / new_denominator
            output_update = attn @ v_block / new_denominator
            old_output = O_blocks[q_index] * output_scaler
            new_output = old_output + output_update

            # Store new maximums, new denominators and new attention output.
            M_blocks[q_index] = new_maximum
            D_blocks[q_index] = new_denominator
            O_blocks[q_index] = new_output

    # Patch together the attention output into a single (B, nh, T, H) tensor.
    return torch.cat(O_blocks, dim=2)
  • Remind me later!