nshepperd/flash_attn_jax

High difference between values from vanilla attention and flash_mha

Closed this issue · 8 comments

Hi, I observed that my language model wasn't converging (it did converge with vanilla_att), so I ran the comparison below:

import jax
import jax.numpy as jnp
import jax.random as jrand
from flash_attn_jax import flash_mha  # flash_mha as provided by this repo

B = 128; h = 6; T = 256; dim = 288//6; shape=(B, h, T, dim)
q = jrand.uniform(jrand.PRNGKey(2323), shape=shape, dtype=jnp.float16) # (B, h, T, dim)
k = jrand.uniform(jrand.PRNGKey(323232), shape=shape, dtype=jnp.float16)
v = jrand.uniform(jrand.PRNGKey(2323221111), shape=shape, dtype=jnp.float16)

def vanilla_att(q, k, v, is_causal=True):
    att_wei = (q @ jnp.matrix_transpose(k))/(dim**0.5) # (B, h, T, T) <= (B, h, T, dim) @ (B, h, T, dim).transpose(2, 3)
    # causal mask
    if is_causal:
        att_wei + jnp.triu(jnp.full(shape=(1, 1, T, T), fill_value=-jnp.inf), k=1)[:, :, :T, :T] # (B, h, T, T)
    att_wei = jax.nn.softmax(att_wei, axis=-1) # (B, h, T, T)
    # apply attention weights to v
    att_out = att_wei @ v # (B, h, T, T) @ (B, h, T, dv) => (B, h, T, dv)
    return att_out

diff = abs(flash_mha(q, k, v, is_causal=True, softmax_scale=dim**-0.5)-vanilla_att(q, k, v))
>>> jnp.max(diff), jnp.min(diff), jnp.mean(diff)
(Array(0.5703, dtype=float16),
 Array(0., dtype=float16),
 Array(0.146, dtype=float16))

Oh! I see the problem. This confused me for a bit. Flash attention accepts inputs in NTHD order (aka [n, l, h, d], as I wrote in the readme). You're comparing it against an NHTD vanilla attention. You need to add some transposes (or remove some) in your model.
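For reference, an equivalent layout fix is a single swapaxes per tensor. This is just a minimal sketch, not library code; it assumes the same q, k, v, and dim as above:

# Sketch: flash_mha expects [n, l, h, d] (NTHD), so swap the head and
# sequence axes before the call and swap them back afterwards.
q_nthd = jnp.swapaxes(q, 1, 2)  # (B, h, T, dim) -> (B, T, h, dim)
k_nthd = jnp.swapaxes(k, 1, 2)
v_nthd = jnp.swapaxes(v, 1, 2)
out = flash_mha(q_nthd, k_nthd, v_nthd, is_causal=True, softmax_scale=dim**-0.5)
out_nhtd = jnp.swapaxes(out, 1, 2)  # back to (B, h, T, dim) for comparison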

After Transposing...

flash_out = flash_mha(
    q.transpose((0, 2, 1, 3)),  # (B, T, h, dim)
    k.transpose((0, 2, 1, 3)),  # (B, T, h, dim)
    v.transpose((0, 2, 1, 3)),  # (B, T, h, dim)
    is_causal=True, softmax_scale=dim**-0.5,
).transpose((0, 2, 1, 3))  # (B, h, T, dim) <= (B, T, h, dim)
diff = abs(flash_out - vanilla_att(q, k, v))  # (B, h, T, dim)

>>> jnp.max(diff), jnp.min(diff), jnp.mean(diff)
(Array(0.551, dtype=float16),
 Array(0., dtype=float16),
 Array(0.02173, dtype=float16))

Oh, you also have a bug here. The result of the addition is discarded, so the causal mask is never actually applied:

 att_wei + jnp.triu(jnp.full(shape=(1, 1, T, T), fill_value=-jnp.inf), k=1)[:, :, :T, :T] # (B, h, T, T)

this should be

att_wei += jnp.triu(jnp.full(shape=(1, 1, T, T), fill_value=-jnp.inf), k=1)[:, :, :T, :T] # (B, h, T, T)

Your test passes on my desktop with this.
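Concretely, with the += fix applied to vanilla_att and the transposed inputs from the earlier comment, the comparison can be written as an assertion. The fp16 tolerances below are illustrative choices, not values from this thread:

flash_out = flash_mha(q.transpose((0, 2, 1, 3)),
                      k.transpose((0, 2, 1, 3)),
                      v.transpose((0, 2, 1, 3)),
                      is_causal=True, softmax_scale=dim**-0.5)
flash_out = flash_out.transpose((0, 2, 1, 3))  # back to (B, h, T, dim)
# Illustrative fp16 tolerances; tighten or loosen for your use case.
assert jnp.allclose(flash_out, vanilla_att(q, k, v), atol=2e-2, rtol=2e-2)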

whoops! ya sry, thanks.

And why is flash_mha not supported for float32?

The authors of the paper didn't implement it for float32, probably mainly because it would need 2x more SM shared memory and would also be slower (shared memory is in pretty short supply, AFAIK). You could probably use it in a float32 model by casting to bf16 and back?
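If it helps, here is a minimal sketch of that cast-and-call idea. flash_mha_f32 is a hypothetical wrapper, not part of this library:

def flash_mha_f32(q, k, v, **kwargs):
    # The flash kernel only supports fp16/bf16, so round-trip through bf16.
    # Note: this loses precision relative to a true fp32 attention.
    out = flash_mha(q.astype(jnp.bfloat16),
                    k.astype(jnp.bfloat16),
                    v.astype(jnp.bfloat16),
                    **kwargs)
    return out.astype(jnp.float32)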

Thanks for your interest in my repo btw!

And should I learn C++ CUDA, or can I implement Flash Attention from scratch using the Python version of CUDA? What's the best resource for learning it (the Python or the C++ version)? Thanks!

(BTW are you a student?)

Join the CUDA MODE Discord and watch the recorded lectures! I'm not sure I can say what you should learn, but I think C++ CUDA is quite interesting and a good way to understand the hardware.

I'm not a student. I'm, errr, a researcher (neet, lol).