comp_gcn function
eujhwang commented
Hello,
In the comp_gcn function (graph_encoder.py):
def comp_gcn(...):
    ...
    # First part
    o = concept_hidden.gather(1, head.unsqueeze(2).expand(bsz, mem_t, hidden_size))
    o = o.masked_fill(triple_label.unsqueeze(2) == -1, 0)
    scatter_add(o, tail, dim=1, out=update_node)
    scatter_add(-relation_hidden.masked_fill(triple_label.unsqueeze(2) == -1, 0), tail, dim=1, out=update_node)
    scatter_add(count, tail, dim=1, out=count_out)
    # => Are o, update_node, and count_out from this first part not used anywhere?

    # Second part
    o = concept_hidden.gather(1, tail.unsqueeze(2).expand(bsz, mem_t, hidden_size))
    o = o.masked_fill(triple_label.unsqueeze(2) == -1, 0)
    scatter_add(o, head, dim=1, out=update_node)
    scatter_add(-relation_hidden.masked_fill(triple_label.unsqueeze(2) == -1, 0), head, dim=1, out=update_node)
    scatter_add(count, head, dim=1, out=count_out)
    # => o, update_node, and count_out from this second part are used in the part below

    act = nn.ReLU()
    # Calculating the final update_node representation
    update_node = self.W_s[layer_idx](concept_hidden) + self.W_n[layer_idx](update_node) / count_out.clamp(min=1).unsqueeze(2)
    update_node = act(update_node)
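
One possible reading, assuming scatter_add here is the torch_scatter function and that it adds into the tensor passed as out in place rather than overwriting it: the first part would not be discarded, because the second part keeps accumulating into the same update_node and count_out buffers. Below is a minimal pure-Python sketch of that accumulation pattern. scatter_add_into and the toy head, tail, and node_value lists are hypothetical stand-ins, not the repository's code.

```python
def scatter_add_into(src, index, out):
    """Hypothetical stand-in for scatter_add(..., out=out).

    Assumption: values are ADDED into `out` in place, so whatever
    `out` already holds is kept rather than overwritten.
    """
    for value, target in zip(src, index):
        out[target] += value
    return out


# Toy graph with three triples: 0 -> 1, 1 -> 2, 2 -> 0
head = [0, 1, 2]
tail = [1, 2, 0]
node_value = [1.0, 10.0, 100.0]  # one scalar "hidden state" per node

update_node = [0.0, 0.0, 0.0]
count_out = [0, 0, 0]

# First part: each head node sends its value to its tail node.
scatter_add_into([node_value[h] for h in head], tail, update_node)
scatter_add_into([1] * len(tail), tail, count_out)
after_first = list(update_node)  # snapshot before the second part

# Second part: each tail node sends its value to its head node,
# accumulating into the SAME buffers as the first part.
scatter_add_into([node_value[t] for t in tail], head, update_node)
scatter_add_into([1] * len(head), head, count_out)

print(after_first)   # contribution of the first part only
print(update_node)   # first part + second part
print(count_out)     # number of incoming messages per node
```

If that assumption holds, update_node ends up holding the messages from both directions, and count_out counts each node's appearances as a tail and as a head, which is what the final normalization divides by.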
Thank you!