janvainer/speedyspeech

Positional encoding


Hi Jan,

I think the positional encoding function is missing parentheses around 10.0**4. Since ** is right-associative in Python, 10.0**4 ** x parses as 10.0 ** (4 ** x) rather than (10.0**4) ** x. To be sure, I checked locally, and the result indeed differs without the parentheses. Here is the current code snippet:

def positional_encoding(channels, length, w=1):
    """The positional encoding from the `Attention is all you need` paper.

    :param channels: How many channels to use
    :param length: Number of positions to encode
    :param w: Scaling factor
    :return: A (length, channels) tensor of encodings
    """
    enc = torch.FloatTensor(length, channels)
    rows = torch.arange(length, out=torch.FloatTensor())[:, None]
    cols = 2 * torch.arange(channels // 2, out=torch.FloatTensor())
    # NOTE: ** is right-associative, so the divisor below evaluates as
    # 10.0 ** (4 ** (cols / channels)), not (10.0**4) ** (cols / channels)
    enc[:, 0::2] = torch.sin(w * rows / (10.0**4 ** (cols / channels)))
    enc[:, 1::2] = torch.cos(w * rows / (10.0**4 ** (cols / channels)))
    return enc
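
With the parentheses added, the two affected lines would read as follows (a minimal sketch of the proposed fix, keeping everything else unchanged):

    # parentheses force (10.0**4) == 10000 to be the exponentiation base
    enc[:, 0::2] = torch.sin(w * rows / ((10.0**4) ** (cols / channels)))
    enc[:, 1::2] = torch.cos(w * rows / ((10.0**4) ** (cols / channels)))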

Can you please verify it and let me know? I shall fix it and open a PR. Thank you

Position encoding matrix without bracket: [image]

Position encoding matrix with bracket: [image]

Hi, good catch! Please use the develop branch, where this issue is fixed, as discussed here: #23. Thanks! :)

Thanks Jan. I will use the develop branch. Let me close this issue here.

Hi Jan,
I have two questions regarding this, please:

  1. The positional encoding is weighted: the keys are scaled by a hyperparameter w equal to 6.42, whereas the queries use w = 1. Can you please explain how you chose this value?

        if hp.positional_encoding:
            keys += positional_encoding(keys.shape[-1], keys.shape[1], w=hp.w).to(self.device)
            queries += positional_encoding(queries.shape[-1], queries.shape[1], w=1).to(self.device)

        class HPDurationExtractor:
            positional_encoding = True
            w = 6.42

  2. Given that I am using the new positional encoding function from the develop branch, do you have any suggestion on what weight I should choose?

My best guess, with my current understanding, is w = num_spectrogram_frames / num_phonemes, i.e. the average number of frames generated per phoneme. But I am not sure.
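
To make that guess concrete, here is a hypothetical sketch of how such a w could be estimated from a dataset; frame_counts and phoneme_counts are made-up names standing in for per-utterance statistics that would come from the alignments:

    # Hypothetical sketch: estimate w as the average number of spectrogram
    # frames generated per phoneme over the training set.
    # frame_counts / phoneme_counts are placeholder dummy data.
    frame_counts = [350, 500, 420]   # spectrogram frames per utterance
    phoneme_counts = [50, 70, 61]    # phonemes per utterance

    w_estimate = sum(frame_counts) / sum(phoneme_counts)
    print(f"estimated w = {w_estimate:.2f}")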
Thank you very much