Disallow Sampling From Correct
Opened this issue · 1 comment
zanussbaum commented
Hey, thanks for the great package!
Just wanted to double-check something. In the paper, they say:
> First, if the generator happens to generate the correct token, that token is considered "real" instead of "fake"; we found this formulation to moderately improve results on downstream tasks
In this code, if I'm understanding correctly, it doesn't seem like we account for this: we only sample replacements for the masked tokens, whereas in the original code they disallow sampling the correct token.
Shouldn't we only replace the tokens where the generator is incorrect?
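For reference, here's a minimal sketch (with hypothetical tensors, not this repo's code) of what the paper's formulation implies for the discriminator labels: a masked position where the generator happens to sample the original token gets labeled "real" rather than "fake".

```python
import torch

# Hypothetical example tensors.
original_ids = torch.tensor([[5, 9, 3]])                 # ground-truth token ids
masked_positions = torch.tensor([[True, True, False]])   # 2 of 3 positions masked
sampled_ids = torch.tensor([[5, 7, 3]])                  # generator output; position 0 happens to be correct

# Replace masked positions with the generator's samples.
replaced = original_ids.clone()
replaced[masked_positions] = sampled_ids[masked_positions]

# Discriminator labels: 1 = "fake" only where the token actually changed,
# so a correct generator sample at position 0 counts as "real".
disc_labels = (replaced != original_ids).long()
print(disc_labels)  # tensor([[0, 1, 0]])
```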
zanussbaum commented
Nevermind! It seems I misunderstood the original ELECTRA code:
```python
# force the generator to sample incorrect
# tokens (so 15% of tokens are always
# fake)
```
But then how does the sampling enforce that we don't replace a token when the generator is correct? Isn't there a case where, even though the generator's argmax is correct, adding noise causes the correct token to be replaced?
Here's a simple example:
```python
>>> inputs
tensor([[0, 1, 0],
        [0, 0, 1]])
>>> torch.argmax(inputs, dim=-1)
tensor([1, 2])
>>> logits
tensor([[-1.0000,  1.1200,  1.0000],
        [ 0.0000,  0.0000,  1.0000]])
>>> noise
tensor([[ 0.5214, -0.6400, -0.3058],
        [-1.2392,  1.8223, -1.1086]])
>>> (logits + noise).argmax(dim=-1)
tensor([2, 1])
```
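For what it's worth, one way to honor that comment is to mask out the correct token's logit with `-inf` before adding the noise, which guarantees the sample differs from the original. A minimal sketch, reusing the tensors above (the `scatter` masking is my assumption about how "force the generator to sample incorrect tokens" could be implemented, not necessarily how this repo does it):

```python
import torch

inputs = torch.tensor([[0, 1, 0],
                       [0, 0, 1]])
logits = torch.tensor([[-1.0000, 1.1200, 1.0000],
                       [ 0.0000, 0.0000, 1.0000]])
noise = torch.tensor([[ 0.5214, -0.6400, -0.3058],
                      [-1.2392,  1.8223, -1.1086]])

correct = inputs.argmax(dim=-1)  # tensor([1, 2])

# Set the correct token's logit to -inf so it can never win the argmax,
# no matter what (finite) noise gets added on top.
disallowed = logits.scatter(-1, correct.unsqueeze(-1), float('-inf'))
sampled = (disallowed + noise).argmax(dim=-1)
print(sampled)  # tensor([2, 1]) -- never equal to `correct`
```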