labmlai/annotated_deep_learning_paper_implementations

Question about gatv2 code

XiaokangORCA opened this issue · 1 comment

Hello, I am a beginner in GAT, and I've been studying your GATv2 code lately. I have a question while going through the code in

labml_nn/graphs/gatv2/__init__.py

When calculating g_sum

g_sum = g_l_repeat + g_r_repeat_interleave

You mentioned in the comments: Now we add the two tensors to get

$$ \lbrace\overrightarrow{g_{l1}} + \overrightarrow{g_{r1}}, \overrightarrow{g_{l1}} + \overrightarrow{g_{r2}}, \dots, \overrightarrow{g_{l1}} + \overrightarrow{g_{rN}}, \overrightarrow{g_{l2}} + \overrightarrow{g_{r1}}, \overrightarrow{g_{l2}} + \overrightarrow{g_{r2}}, \dots, \overrightarrow{g_{l2}} + \overrightarrow{g_{rN}}, \dots\rbrace $$

But in the previous code, g_l_repeat gets

$$ \lbrace\overrightarrow{g_{l1}}, \overrightarrow{g_{l2}}, \dots, \overrightarrow{g_{lN}}, \overrightarrow{g_{l1}}, \overrightarrow{g_{l2}}, \dots, \overrightarrow{g_{lN}}, \dots\rbrace $$

and g_r_repeat_interleave gets

$$ \lbrace\overrightarrow{g_{r1}}, \overrightarrow{g_{r1}}, \dots, \overrightarrow{g_{r1}}, \overrightarrow{g_{r2}}, \overrightarrow{g_{r2}}, \dots, \overrightarrow{g_{r2}}, \dots\rbrace $$

So I think the result of adding the two tensors should be

$$ \lbrace\overrightarrow{g_{l1}} + \overrightarrow{g_{r1}}, \overrightarrow{g_{l2}} + \overrightarrow{g_{r1}}, \dots, \overrightarrow{g_{lN}} + \overrightarrow{g_{r1}}, \overrightarrow{g_{l1}} + \overrightarrow{g_{r2}}, \overrightarrow{g_{l2}} + \overrightarrow{g_{r2}}, \dots, \overrightarrow{g_{lN}} + \overrightarrow{g_{r2}}, \dots\rbrace $$

I'm not sure whether I may have overlooked some crucial information or if there's a mismatch between your comments and the code. I would greatly appreciate it if you could help clarify my confusion. Thank you.

Hello! I am also new to GAT, but I looked into your issue.

To your question: the implementation on the website is correct, and the mismatch you noticed is only in the comment's notation. Here is why.

g_l_repeat

$$ \lbrace\overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \dots\rbrace $$

is:

>>> import torch
>>> n_nodes = 3  # or N
>>> tensor = torch.tensor([[1], [2], [3]])
>>> tensor.repeat(n_nodes, 1)
tensor([[1],
        [2],
        [3],
        [1],
        [2],
        [3],
        [1],
        [2],
        [3]])

and g_r_repeat_interleave

$$ \lbrace\overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, \dots\rbrace $$

is instead:

>>> tensor.repeat_interleave(n_nodes, dim=0)
tensor([[1],
        [1],
        [1],
        [2],
        [2],
        [2],
        [3],
        [3],
        [3]])

So, the operation g_l_repeat + g_r_repeat_interleave

$$ \lbrace\overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \overrightarrow{{g_l}_1}, \overrightarrow{{g_l}_2}, \dots, \overrightarrow{{g_l}_N}, \dots\rbrace $$

$$ + $$

$$ \lbrace\overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_r}_1}, \overrightarrow{{g_r}_2}, \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_r}_2}, \dots\rbrace $$

is

>>> tensor.repeat(n_nodes, 1) + tensor.repeat_interleave(n_nodes, dim=0)
tensor([[2],    # 1 + 1
        [3],    # 2 + 1
        [4],    # 3 + 1
        [3],    # 1 + 2
        [4],    # 2 + 2
        [5],    # 3 + 2
        [4],    # 1 + 3
        [5],    # 2 + 3
        [6]])   # 3 + 3

So the elementwise sum actually comes out in this order:

$$ \lbrace\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_N} + \overrightarrow{{g_r}_2}, \dots\rbrace $$

which matches what you derived. The ordering written in the comment,

$$ \lbrace\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_N}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1}, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2}, \dots, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_N}, \dots\rbrace $$

has the roles of $g_l$ and $g_r$ swapped. So I think it is the comment's notation that should be changed to avoid confusion, while the implementation itself is still correct: every pair of nodes is covered exactly once, and the comment's ordering is just the transpose of the actual one.
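To double-check this end to end, here is a small self-contained sketch (my own toy example with made-up values, not the repo's exact code) that gives `g_l` and `g_r` distinguishable numbers, applies the same `repeat` / `repeat_interleave` / `view` steps used in the GATv2 implementation, and checks which pair lands at each position `(i, j)`:

```python
import torch

n_nodes = 3
n_hidden = 1

# Distinguishable values: g_l holds 1, 2, 3 and g_r holds 10, 20, 30,
# so each sum uniquely identifies which pair produced it.
g_l = torch.tensor([[1.0], [2.0], [3.0]])
g_r = torch.tensor([[10.0], [20.0], [30.0]])

g_l_repeat = g_l.repeat(n_nodes, 1)                            # gl1, gl2, gl3, gl1, ...
g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)  # gr1, gr1, gr1, gr2, ...

# Same reshape as in the implementation: flat index i*N + j -> position (i, j)
g_sum = (g_l_repeat + g_r_repeat_interleave).view(n_nodes, n_nodes, n_hidden)

# Row index i comes from g_r, column index j from g_l
for i in range(n_nodes):
    for j in range(n_nodes):
        assert g_sum[i, j] == g_l[j] + g_r[i]
```

So after the `view`, `g_sum[i, j]` holds $\overrightarrow{{g_l}_j} + \overrightarrow{{g_r}_i}$, the transpose of the indexing the comment suggests; but since all $N^2$ pairs appear exactly once, the attention computation still sees every pair.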