lucidrains/performer-pytorch

torch.max(data_dash) bug

martinpflaum opened this issue · 2 comments

Hello, I really like your implementation, but I think there is a mistake in line 109 of performer-pytorch: there, torch.max(data_dash) returns a single scalar, meaning the maximum is also computed across batches and attention heads. Whereas in https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py, line 108, this is not the case.
Note that last_dims_t + attention_dims_t is a tuple, not a sum: both last_dims_t and attention_dims_t are tuples, so + concatenates them into a tuple of reduction axes (see the short illustration after the snippet).

data_dash = ratio * (
    jnp.exp(data_dash - diag_data - jnp.max(
        data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) +
    eps)
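
To illustrate the axis arithmetic (the concrete axis indices here are hypothetical, just my reading of the shapes):

last_dims_t = (3,)       # hypothetical: the feature axis
attention_dims_t = (2,)  # hypothetical: the sequence axis

# tuple concatenation, not addition: jnp.max reduces over axes 3 and 2,
# leaving the batch (0) and head (1) axes untouched
axes = last_dims_t + attention_dims_t
print(axes)  # (3, 2)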

I didn't run the code above, but I think it is much more likely that they didn't intend to compute the maximum across batches, nor across multiple attention heads.
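
For comparison, a minimal PyTorch sketch of what I mean (the tensor layout and dimension indices are illustrative assumptions, not the library's actual ones):

import torch

# hypothetical layout: (batch, heads, seq_len, nb_features)
data_dash = torch.randn(2, 4, 16, 32)

# current behavior: a single scalar max, shared across batches and heads
global_max = torch.max(data_dash)
print(global_max.shape)  # torch.Size([])

# behavior matching the JAX code: reduce over the feature and sequence
# axes only, keeping a separate max per batch element and per head
per_head_max = torch.amax(data_dash, dim=(-1, -2), keepdim=True)
print(per_head_max.shape)  # torch.Size([2, 4, 1, 1])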

@martinpflaum Hi Martin, thank you for catching this error (though it should be harmless, as the max is only subtracted for numerical stability). Do you want to check version 1.1.1 and see if it fixes the problem?

Hi, yes looks good 👍