lucidrains/performer-pytorch

Question: torch.max term used in `softmax_kernel`


Hi! Thanks for the implementation of performer-pytorch. It really helps!

I am confused by the `torch.max` terms in `softmax_kernel`, shown below. I reckon they are there for some kind of normalization, but I could not find the corresponding equations in the original paper. Could you please help explain this term? And why does it need to be different for the query and the key?

    if is_query:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data -
                    torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
    else:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)
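
For reference, subtracting the maximum before exponentiation is the standard log-sum-exp trick for keeping `exp` numerically stable, which appears to be what this term is doing. A standalone sketch of the effect (the tensor values below are made up for illustration and are not from performer-pytorch):

    import torch

    # Large logits overflow exp() in float32, but shifting by the row-wise
    # maximum first keeps everything finite without changing the normalized result.
    logits = torch.tensor([[80.0, 85.0, 90.0]])

    naive = torch.exp(logits)  # the largest entry overflows to inf in float32
    stable = torch.exp(logits - torch.max(logits, dim=-1, keepdim=True).values)

    print(naive)   # roughly [5.5e+34, 8.2e+36, inf] -- overflow
    print(stable)  # roughly [4.5e-05, 6.7e-03, 1.0] -- all finite

Note also the asymmetry visible in the snippet: queries subtract a per-row maximum (`dim=-1, keepdim=True`), while keys subtract the global maximum of the whole tensor.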

@lucidrains Thanks for your response!

Sorry to bother you, but I have another question about the implementation in `reversible.py`:

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

Is the `retain_graph=True` argument necessary here?

@ClawangTU Yes it is; it is necessary for the scenario where both the encoder and decoder are reversible.
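
For background, PyTorch frees a graph's saved buffers after `backward()` unless `retain_graph=True` is passed, so any code path that backpropagates through the same graph more than once needs the flag. A minimal standalone sketch of that behavior (deliberately not the `reversible.py` code):

    import torch

    # backward() frees the graph's saved tensors by default; a second backward
    # pass through the same graph only works if the first call retained it.
    x = torch.randn(3, requires_grad=True)
    y = torch.exp(x).sum()

    torch.autograd.backward(y, retain_graph=True)  # buffers kept for another pass
    torch.autograd.backward(y)                     # ok; without the flag above, this
                                                   # second call raises a RuntimeError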

@lucidrains I get it! Thanks a lot for your answer!