DM2-ND/MoKGE

comp_gcn function


Hello,

In the comp_gcn function (graph_encoder.py):

```python
def comp_gcn(...):
    ...
    # First part
    o = concept_hidden.gather(1, head.unsqueeze(2).expand(bsz, mem_t, hidden_size))
    o = o.masked_fill(triple_label.unsqueeze(2) == -1, 0)

    scatter_add(o, tail, dim=1, out=update_node)
    scatter_add(-relation_hidden.masked_fill(triple_label.unsqueeze(2) == -1, 0), tail, dim=1, out=update_node)
    scatter_add(count, tail, dim=1, out=count_out)
    # => o, update_node, and count_out are not used anywhere?

    # Second part
    o = concept_hidden.gather(1, tail.unsqueeze(2).expand(bsz, mem_t, hidden_size))
    o = o.masked_fill(triple_label.unsqueeze(2) == -1, 0)
    scatter_add(o, head, dim=1, out=update_node)
    scatter_add(-relation_hidden.masked_fill(triple_label.unsqueeze(2) == -1, 0), head, dim=1, out=update_node)
    scatter_add(count, head, dim=1, out=count_out)
    # => o, update_node, and count_out are used in the part below

    act = nn.ReLU()
    # Calculating the final update_node representation
    update_node = self.W_s[layer_idx](concept_hidden) + self.W_n[layer_idx](update_node) / count_out.clamp(min=1).unsqueeze(2)
    update_node = act(update_node)
```
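To check whether the first part really is dead code, here is a minimal NumPy sketch of both parts. It assumes `scatter_add` is torch_scatter's, which adds into the tensor passed as `out` in place; `np.add.at` stands in for that here. Under that reading the first part's sums stay in `update_node` and `count_out`, and the second part adds onto them (only `o` itself is overwritten). The helper name `comp_gcn_sketch`, the bias-free `W_s`/`W_n` matrices, and `count` being 1 per non-padding triple are all hypothetical, not the repository's code.

```python
import numpy as np


def comp_gcn_sketch(concept_hidden, relation_hidden, head, tail, triple_label, W_s, W_n):
    """Hypothetical NumPy stand-in for one comp_gcn layer (not the repo's code).

    concept_hidden:  (bsz, mem, hidden) node states
    relation_hidden: (bsz, mem_t, hidden) one relation state per triple
    head, tail:      (bsz, mem_t) integer node indices of each triple
    triple_label:    (bsz, mem_t), -1 marks padding triples
    W_s, W_n:        (hidden, hidden) bias-free stand-ins for W_s/W_n[layer_idx]
    """
    bsz, mem, hidden = concept_hidden.shape
    batch = np.arange(bsz)[:, None]               # broadcasts against (bsz, mem_t) indices
    pad = (triple_label == -1)[..., None]         # (bsz, mem_t, 1) padding mask
    rel = np.where(pad, 0.0, relation_hidden)     # masked relation states
    count = np.where(triple_label == -1, 0.0, 1.0)  # assumption: 1 per real triple

    update_node = np.zeros((bsz, mem, hidden))
    count_out = np.zeros((bsz, mem))

    # First part: head -> tail messages. np.add.at accumulates in place,
    # like scatter_add(..., out=update_node), so these sums are kept.
    o = np.where(pad, 0.0, concept_hidden[batch, head])
    np.add.at(update_node, (batch, tail), o - rel)
    np.add.at(count_out, (batch, tail), count)

    # Second part: tail -> head messages, added on top of the first part.
    o = np.where(pad, 0.0, concept_hidden[batch, tail])
    np.add.at(update_node, (batch, head), o - rel)
    np.add.at(count_out, (batch, head), count)

    # Self transform plus neighbour transform divided by the degree, then ReLU.
    out = concept_hidden @ W_s.T + (update_node @ W_n.T) / np.maximum(count_out, 1)[..., None]
    return np.maximum(out, 0.0), update_node, count_out


# Tiny graph: triples 0 -> 1 and 1 -> 2, plus one padding triple.
h = np.array([[[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]])
r = np.array([[[0.5, 0.5], [1.0, 0.0], [9.0, 9.0]]])
out, upd, cnt = comp_gcn_sketch(h, r, np.array([[0, 1, 0]]), np.array([[1, 2, 0]]),
                                np.array([[1, 1, -1]]), np.eye(2), np.eye(2))
# Node 1 is a tail in the first part and a head in the second, so it is
# counted twice: cnt == [[1, 2, 1]].
```

In this toy graph, node 1's aggregated message `[1.5, 1.5]` is the sum of `[0.5, -0.5]` from the first part and `[1, 2]` from the second. If the first part's writes were discarded, only the second contribution would remain.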

Thank you!