BrianPulfer/PapersReimplementations

positional encoding

alalbiol opened this issue · 3 comments

Hi Brian,
I am learning about DDPMs, and I find your code really helpful.

I am analyzing each bit of the code, and there is something strange about your positional encoding.
According to https://kazemnejad.com/blog/transformer_architecture_positional_encoding/, sinusoidal positional encodings are good at encoding relative positions, so if you plot `embeddings @ embeddings.T` you get an image like this (Figure 3 in that post):

[Figure: `embeddings @ embeddings.T` for sinusoidal positional encodings, Figure 3 of the linked post]
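The relative-position property can be read off the standard definition. In the sketch below, $\omega_k = 1/10000^{2k/d}$ and an even embedding size $d$ are assumptions following the linked post's notation:

$$
PE_{t,2k} = \sin(\omega_k t), \qquad PE_{t,2k+1} = \cos(\omega_k t), \qquad \omega_k = \frac{1}{10000^{2k/d}}
$$

$$
PE_t \cdot PE_{t'} = \sum_{k=0}^{d/2-1} \left[\sin(\omega_k t)\sin(\omega_k t') + \cos(\omega_k t)\cos(\omega_k t')\right] = \sum_{k=0}^{d/2-1} \cos\big(\omega_k (t - t')\big)
$$

So each entry of `embeddings @ embeddings.T` depends only on the offset $t - t'$ (the matrix is constant along its diagonals), and every diagonal entry equals $d/2$. The last step uses $\cos(a - b) = \cos a \cos b + \sin a \sin b$, which only applies when the $\sin$ and $\cos$ columns of a pair share the same argument $\omega_k t$.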

I tried that with your code and the image is completely different. I have made my own positional encodings using this code:

```python
import torch

def sinusoidal_embedding2(n, d):
    # Returns the standard positional embedding (assumes d is even)
    embedding = torch.zeros(n, d)
    # One frequency per (sin, cos) column pair: w_k = 1 / 10000^(2k / d)
    wk = torch.tensor([1 / 10_000 ** (2 * k / d) for k in range(d // 2)])
    wk = wk.reshape((1, d // 2))
    t = torch.arange(n).reshape((n, 1))
    embedding[:, ::2] = torch.sin(t * wk)
    embedding[:, 1::2] = torch.cos(t * wk)

    return embedding
```

With it, I get the relative positional encodings as expected. I will train two models, one with each version of the positional encoding, to check how important this issue is.

Dear @alalbiol,

You are absolutely right: there was a mistake in the way the positional embeddings were generated. Thank you very much for your contribution!

There were actually multiple mistakes in the old implementation. For example, the argument to the $\sin$ and $\cos$ functions should have been the same at positions $(i, j)$ and $(i, j+1)$ (for even $j$), but it was not. Also, the cosine function was previously applied to the argument already transformed by the $\sin$ function.
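Mistakes like these can be checked without plotting: when the $\sin$ and $\cos$ columns of each pair share the same argument, $\sin^2 + \cos^2 = 1$ holds in every row. The following is a minimal sketch assuming even `d`; `reference_embedding`, `cos_of_sin_embedding` and `pairs_share_argument` are hypothetical helpers, and `cos_of_sin_embedding` only imitates the second bug described (cos applied to the sin output). It is not the old repository code.

```python
import torch


def reference_embedding(n, d):
    # Standard sinusoidal embedding (assumes even d): columns 2k and 2k+1
    # share the argument t * w_k with w_k = 1 / 10000^(2k / d).
    t = torch.arange(n, dtype=torch.float32).unsqueeze(1)                   # (n, 1)
    wk = 1.0 / 10_000 ** (torch.arange(0, d, 2, dtype=torch.float32) / d)   # (d/2,)
    args = t * wk                                                           # (n, d/2)
    emb = torch.zeros(n, d)
    emb[:, 0::2] = torch.sin(args)
    emb[:, 1::2] = torch.cos(args)
    return emb


def cos_of_sin_embedding(n, d):
    # Hypothetical buggy variant: cos is applied to the already sin-transformed value
    emb = reference_embedding(n, d)
    emb[:, 1::2] = torch.cos(emb[:, 0::2])
    return emb


def pairs_share_argument(emb, atol=1e-5):
    # If every (sin, cos) column pair uses the same argument,
    # sin^2 + cos^2 == 1 for every row and every pair.
    s = emb[:, 0::2] ** 2 + emb[:, 1::2] ** 2
    return torch.allclose(s, torch.ones_like(s), atol=atol)


print(pairs_share_argument(reference_embedding(50, 16)))
print(pairs_share_argument(cos_of_sin_embedding(50, 16)))
```

The first check passes for the standard embedding; the second fails because, for example, at $t = 1$ and $k = 0$ the pair is $(\sin 1, \cos(\sin 1))$, whose squares sum to about 1.15.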

Here's the code which fixes the issue:

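```python
import torch

def sinusoidal_embedding(n, d):
    # Returns the standard positional embedding (assumes d is even)
    # Entry (i, j) holds the shared argument i / 10000^(2 * (j // 2) / d)
    embedding = torch.tensor([[i / (10_000 ** (2*(j//2) / d)) for j in range(d)] for i in range(n)])
    # Even columns get sin, odd columns get cos (the mask runs over the d columns, not n)
    sin_mask = torch.arange(0, d, 2)

    embedding[:, sin_mask] = torch.sin(embedding[:, sin_mask])
    embedding[:, sin_mask+1] = torch.cos(embedding[:, sin_mask+1])

    return embedding
```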

Notice that the `2 * (j//2)` term makes the argument of $\sin$ and $\cos$ the same for each pair of adjacent columns, although omitting it probably has only a minor effect.
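The nested list comprehension runs element by element in Python, which gets slow for large `n` and `d`. Here is a minimal vectorized sketch that builds the same matrix with broadcasting, assuming even `d`. `sinusoidal_embedding_vectorized` is a hypothetical name, and the block repeats the fixed function (with the column mask taken over `d`) only so the two can be compared.

```python
import torch


def sinusoidal_embedding(n, d):
    # Fixed version from the comment, with the column mask over d
    embedding = torch.tensor([[i / (10_000 ** (2*(j//2) / d)) for j in range(d)] for i in range(n)])
    sin_mask = torch.arange(0, d, 2)
    embedding[:, sin_mask] = torch.sin(embedding[:, sin_mask])
    embedding[:, sin_mask+1] = torch.cos(embedding[:, sin_mask+1])
    return embedding


def sinusoidal_embedding_vectorized(n, d):
    i = torch.arange(n, dtype=torch.float32).unsqueeze(1)   # (n, 1) positions
    j = torch.arange(d)                                      # (d,) column indices
    # Floor division gives 0, 0, 1, 1, 2, 2, ...: one frequency per column pair
    pair = torch.div(j, 2, rounding_mode="floor")
    args = i / 10_000 ** (2 * pair / d)                      # (n, d) shared arguments
    emb = torch.empty_like(args)
    emb[:, 0::2] = torch.sin(args[:, 0::2])
    emb[:, 1::2] = torch.cos(args[:, 1::2])
    return emb


loop = sinusoidal_embedding(50, 16)
fast = sinusoidal_embedding_vectorized(50, 16)
print(torch.allclose(loop, fast, atol=1e-4))
```

The loose tolerance is there because the list comprehension computes the arguments in Python floats (float64) before converting to float32, while the vectorized sketch works in float32 throughout.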

The new embedding looks like this:
[Figure: the new positional embedding]

And `embedding @ embedding.T` looks like this:
[Figure: `embedding @ embedding.T` for the new positional embedding]
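To check numerically that the corrected embedding encodes relative positions, one can verify that `embedding @ embedding.T` is a Toeplitz matrix (constant along every diagonal) and matches the closed form $\sum_k \cos(\omega_k (i - j))$. This is a minimal sketch assuming even `d` and $\omega_k = 1/10000^{2k/d}$; `standard_embedding` and `is_toeplitz` are hypothetical helpers, and float64 is used so the comparisons can use tight tolerances.

```python
import torch


def standard_embedding(n, d):
    # Standard sinusoidal embedding in float64 (assumes even d)
    t = torch.arange(n, dtype=torch.float64).unsqueeze(1)
    wk = 1.0 / 10_000 ** (torch.arange(0, d, 2, dtype=torch.float64) / d)
    emb = torch.zeros(n, d, dtype=torch.float64)
    emb[:, 0::2] = torch.sin(t * wk)
    emb[:, 1::2] = torch.cos(t * wk)
    return emb


def is_toeplitz(m, atol=1e-9):
    # Every diagonal of a Toeplitz matrix is constant: m[i, j] == m[i+1, j+1]
    return torch.allclose(m[:-1, :-1], m[1:, 1:], atol=atol)


n, d = 100, 32
emb = standard_embedding(n, d)
gram = emb @ emb.T  # the matrix plotted in the comment above

# Closed form: <PE_i, PE_j> = sum_k cos(w_k * (i - j))
wk = 1.0 / 10_000 ** (torch.arange(0, d, 2, dtype=torch.float64) / d)
pos = torch.arange(n, dtype=torch.float64)
diff = pos.unsqueeze(1) - pos.unsqueeze(0)                 # (n, n) matrix of i - j
closed_form = torch.cos(diff.unsqueeze(-1) * wk).sum(-1)   # (n, n)

print(is_toeplitz(gram))
print(torch.allclose(gram, closed_form))
print(gram.diagonal()[:3])  # each diagonal entry equals d / 2
```

Both checks depend on each $\sin$/$\cos$ pair sharing its argument, which is exactly what the fix restores.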

If you'd like, feel free to open a pull request to solve this issue 😄. I am also looking forward to the results with the correct positional embedding, if you decide to test it 🤓.

I will let you know how the training goes with this change.

Solved with PR #10.