softmax in GraphormerAttentionHead
JingweiQu opened this issue · 3 comments
JingweiQu commented
softmax = torch.softmax(a, dim=-1)
The softmax here is not quite right, since attention should be computed within each graph separately. Computing it directly over the batch allows attention between nodes that belong to different graphs.
Maybe combining it with batch_mask, replacing the 0 entries with -10^6, would be a solution.
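A minimal sketch of the idea (the helper name `masked_softmax` and the assumed shape of `batch_mask` are illustrative, not the repository's actual code):

```python
import torch

def masked_softmax(a: torch.Tensor, batch_mask: torch.Tensor) -> torch.Tensor:
    # a:          raw attention scores, shape (batch, n_nodes, n_nodes)
    # batch_mask: 1 where query and key nodes belong to the same graph, 0 otherwise
    # Fill cross-graph positions with a large negative value so their
    # softmax weights become effectively zero.
    a = a.masked_fill(batch_mask == 0, -1e6)
    return torch.softmax(a, dim=-1)

# Example: a batch of 2 graphs padded to 4 nodes each
scores = torch.randn(2, 4, 4)
mask = torch.ones(2, 4, 4)
mask[0, :, 2:] = 0  # pretend nodes 2-3 are padding / another graph
attn = masked_softmax(scores, mask)  # cross-graph weights are ~0
```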
leffff commented
Thanks for the issue!
Yeah, that sounds like a good solution!