lucidrains/electra-pytorch

Disallow Sampling From Correct

Opened this issue · 1 comment

Hey thanks for the great package!

Just wanted to double check on something. In the paper, they say

First, if the generator happens to generate the correct token, that token is considered "real" instead of "fake"; we found this formulation to moderately improve results on downstream tasks

In this code, if I'm understanding correctly, it doesn't seem like we account for this: we only sample replacements for the masked tokens, whereas the original code disallows sampling the correct token at those positions.

Shouldn't we only replace the tokens where the generator is incorrect?

Never mind! It seems I misunderstood the original ELECTRA code:

# force the generator to sample incorrect
# tokens (so 15% of tokens are always fake)

But then how does the sampling enforce that we don't replace a token when the generator is correct? Isn't there a case where the generator's argmax is correct, but adding noise causes a different (incorrect) token to be sampled and substituted?

Simple example:

>>> inputs
tensor([[0, 1, 0],
        [0, 0, 1]])
>>> torch.argmax(inputs, dim=-1)
tensor([1, 2])
>>> logits
tensor([[-1.0000,  1.1200,  1.0000],
        [ 0.0000,  0.0000,  1.0000]])
>>> noise
tensor([[ 0.5214, -0.6400, -0.3058],
        [-1.2392,  1.8223, -1.1086]])
>>> (logits + noise).argmax(dim=-1)
tensor([2, 1])