lucidrains/x-transformers

Support for NormSoftmax

Closed this issue · 16 comments

Based on this paper: https://openreview.net/pdf?id=4g7nCbpjNwd

Would require editing this line:

https://github.com/lucidrains/x-transformers/blob/aabee05d6bca6d74646156009159c55f8d27d884/x_transformers/attend.py#L278C70-L278C75

And replacing the `* scale` with:

    if self.norm_softmax:
        dots = dots / torch.clamp(dots.std(dim=-1, keepdim=True), min=1e-6)
    else:
        dots *= scale

And then something similar in the other path (flash attention)
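
For reference, a rough standalone sketch of the non-flash path with that branch (a simplified attention function, not the actual `Attend` class):

    import torch

    def norm_softmax_attend(q, k, v, norm_softmax=True, eps=1e-6):
        # q, k, v: (batch, heads, seq_len, dim_head)
        scale = q.shape[-1] ** -0.5

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        if norm_softmax:
            # divide the logits by their std over the key dimension,
            # instead of the usual 1 / sqrt(dim_head) scaling
            dots = dots / torch.clamp(dots.std(dim=-1, keepdim=True), min=eps)
        else:
            dots = dots * scale

        attn = dots.softmax(dim=-1)
        return torch.einsum('b h i j, b h j d -> b h i d', attn, v)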

@catid oh interesting, reminds me a bit of https://arxiv.org/abs/2005.09561

there will also be a temperature involved
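
something like this maybe (just a sketch; the learnable per-head temperature and where it's applied are my guess, not necessarily the paper's exact formulation):

    import torch
    from torch import nn

    heads = 8
    # hypothetical: learnable per-head temperature applied to the std-normalized logits
    temperature = nn.Parameter(torch.ones(heads, 1, 1))

    dots = torch.randn(2, heads, 16, 16)  # (batch, heads, q_len, k_len) attention logits
    dots = dots / torch.clamp(dots.std(dim=-1, keepdim=True), min=1e-6)
    dots = dots * temperature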

have you tried this? maybe i can run a quick experiment tonight

it won't be compatible with flash attention (the fused kernel never materializes the pre-softmax logits, so there's nothing to take a std over)

NormSoftmax CIFAR-10 benchmark results at epoch=60 using ViT-tiny:
baseline: 77.69%
sqrtd: 76.39%
inf: 77.53%

NormSoftmax CIFAR-10 benchmark results at epoch=300 using ViT-tiny:
baseline: 85.19%
inf: 85.07%

Manages to get about the same result without the extra parameters

@catid well yea, so they claim. cifar-10 is a tiny benchmark too

another engineering obstacle would be handling a masked standard deviation
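
something along these lines for a boolean mask, i think (just a sketch, untested, not what's in the repo):

    import torch

    def masked_std(dots, mask, eps=1e-6):
        # dots: (batch, heads, q_len, k_len) attention logits
        # mask: broadcastable boolean, True = attend, False = masked out
        dots = dots.masked_fill(~mask, 0.)
        num = mask.sum(dim=-1, keepdim=True).clamp(min=1)
        mean = dots.sum(dim=-1, keepdim=True) / num
        var = ((dots - mean).masked_fill(~mask, 0.) ** 2).sum(dim=-1, keepdim=True) / num
        return var.clamp(min=eps).sqrt()

    # then dots = dots / masked_std(dots, mask) in place of the unmasked .std()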

yea, let me run it tonight on enwik8, but if i don't see anything notable on the first or second try, probably will just drop this

@catid i'm thinking of the triangular causal mask for autoregressive text generation (gpt). are you masking out the diagonal?

Yeah I'm just copying your vit_for_small_dataset.py

@catid ohh ok, do you see anything? have you run the experiments yourself? never trust anything a paper says unless you see the curves in front of you 😆

The results I shared above are from my setup

@catid wow! ok, i actually put a lot of weight on results from internet randos

ok, let me try it tonight!

@catid wait, your results show norm softmax to be worse than baseline? is that accuracy?

@catid can you share a wandb report with training curves?

I dunno, I mean the numbers are pretty close and I only ran an N=1 trial, so I'm not sure whether one method produces better accuracy than the other. Also, I don't have wandb integrated into my scripts yet (haven't learned how to use it yet).

ah, looks to be a negative result.