Question: `torch.max` term used in `softmax_kernel`
Hi! Thanks for the implementation of performer-pytorch. It really helps!
I am confused about the `torch.max` terms in `softmax_kernel`, shown below. I reckon they are for some kind of normalization, but I could not find the corresponding equations in the original paper. Could you please help explain this term? And why is it necessary to distinguish between the query and the key?
```python
if is_query:
    data_dash = ratio * (
        torch.exp(data_dash - diag_data -
                  torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
else:
    data_dash = ratio * (
        torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)
```
@ClawangTU I don't really know either; I transcribed this faithfully from their repository: https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py#L100-L107
@lucidrains Thanks for your response!
Sorry to bother you, but I have another question, about this part of `reversible.py`:
```python
with torch.enable_grad():
    x2.requires_grad = True
    fx2 = self.f(x2, set_rng=True, **f_args)
    torch.autograd.backward(fx2, dx1, retain_graph=True)
```
Is the `retain_graph=True` parameter necessary here?
@ClawangTU Yes, it is; it is needed for the scenario where both the encoder and the decoder are reversible.
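In general, PyTorch frees the autograd graph after the first backward pass, so a second backward through the same graph fails unless `retain_graph=True` is passed. A minimal sketch of that behaviour, unrelated to the repo's code:

```python
import torch

# Minimal sketch: without retain_graph=True, the intermediate graph is freed
# after the first backward call, and a second backward through it raises an
# error. Retaining the graph allows it to be backpropagated through again.
x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

y.backward(retain_graph=True)  # first pass; graph is kept alive
y.backward()                   # second pass through the same graph succeeds

z = (x * 3).sum()
z.backward()                   # fresh graph, freed after this call
# z.backward()                 # would raise RuntimeError: trying to backward
#                              # through the graph a second time
```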
@lucidrains I get it! Thanks a lot for your answer!