positional encoding
alalbiol opened this issue · 3 comments
Hi Brian,
I am learning about DDPMs and I find your code really helpful.
I am analyzing each bit of the code, and there is something strange about your positional encoding.
According to https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
sinusoidal positional encodings are great at encoding relative positions,
so if you plot embeddings @ embeddings.T you get an image like this (Figure 3 of that post):
I tried that with your code and the image is completely different. I have made my own positional encodings using this code:
import torch

def sinusoidal_embedding2(n, d):
    # Returns the standard positional embedding
    embedding = torch.zeros(n, d)
    # Per-dimension angular frequencies
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
    wk = wk.reshape((1, d))
    # Position indices as a column vector
    t = torch.arange(n).reshape((n, 1))
    embedding[:, ::2] = torch.sin(t * wk[:, ::2])
    embedding[:, 1::2] = torch.cos(t * wk[:, ::2])
    return embedding
and with it I can get the relative positional encodings as expected. I will train two models, one with each version of the positional encoding, to check how important this issue is.
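For completeness, here is a minimal sketch of the embeddings @ embeddings.T check described above, using the function just defined (the plotting part assumes matplotlib is available; n and d are arbitrary example values):

import torch
import matplotlib.pyplot as plt

n, d = 100, 64                     # example sequence length and embedding size
emb = sinusoidal_embedding2(n, d)  # function defined above
plt.imshow(emb @ emb.T)            # should show the banded pattern from the blog post
plt.colorbar()
plt.title("embeddings @ embeddings.T")
plt.show()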
Dear @alalbiol,
You are absolutely right, there has been a mistake in the way positional embeddings were generated. Thank you very much for your contribution!
Multiple mistakes were actually present in the old implementation, such as the fact that the argument passed to the sine and cosine was not the same for each (sin, cos) pair of dimensions.
Here's the code which fixes the issue:
import torch

def sinusoidal_embedding(n, d):
    # Returns the standard positional embedding
    embedding = torch.tensor([[i / (10_000 ** (2 * (j // 2) / d)) for j in range(d)] for i in range(n)])
    # Even-indexed dimensions get the sine, odd-indexed ones the cosine
    sin_mask = torch.arange(0, d, 2)
    embedding[:, sin_mask] = torch.sin(embedding[:, sin_mask])
    embedding[:, sin_mask + 1] = torch.cos(embedding[:, sin_mask + 1])
    return embedding
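As a rough usage sketch (the variable names below are illustrative, not necessarily those of the repository), the resulting table would typically be loaded into a frozen lookup layer so that a timestep index maps to its sinusoidal vector:

import torch
import torch.nn as nn

n_steps, time_emb_dim = 1000, 100                  # example values
time_embed = nn.Embedding(n_steps, time_emb_dim)
time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
time_embed.requires_grad_(False)                   # keep the table fixed during training

t = torch.randint(0, n_steps, (16,))               # a batch of diffusion timesteps
emb = time_embed(t)                                # (16, time_emb_dim) conditioning vectors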
Notice that the 2 * (j // 2) part makes the argument identical for both the sine (even j) and the cosine (odd j) within each pair of dimensions, as in the original formulation.
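A quick, purely illustrative check of that point (example n and d):

import torch

n, d = 100, 64
arg = torch.tensor([[i / (10_000 ** (2 * (j // 2) / d)) for j in range(d)] for i in range(n)])
# Column 2k (fed to sin) and column 2k + 1 (fed to cos) share the same argument
assert torch.allclose(arg[:, 0::2], arg[:, 1::2])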
The new embedding looks like this:
Whereas embedding @ embedding.T looks like this:
If you'd like, I welcome you to open a pull request to solve this issue 😄. Also, I am looking forward to the results with the correct positional embedding if you'd like to test that 🤓.
I will let you know how the training goes with this change.
Solved with PR #10.