R0oup1iao/Traffic-Transformer

LearnedPositionalEncoding

Opened this issue · 1 comments

关于模型里
class LearnedPositionalEncoding(nn.Embedding):
def init(self,d_model, dropout = 0.1,max_len = 500):
super().init(max_len, d_model)
self.dropout = nn.Dropout(p = dropout)

def forward(self, x):
    weight = self.weight.data.unsqueeze(1)
    x = x + weight[:x.size(0),:]
    return self.dropout(x)

这部分起到什么作用呢?论文里好像没有相关的讲解

这是用nn.Embedding类实现了一个可学习的位置编码。

当时主要是想做到一种相对位置编码的感觉。现在关于Related Postional Encoding,以及各种魔改PE的文章很多很强,应该可以把这部分升级替换掉