Differences between 'edge_vector' & 'edge_attr'
Closed this issue · 7 comments
Hi, I'm wondering about the differences between `edge_vector` and `edge_attr`, especially in layers.py:
```python
class EdgeEncoding(nn.Module):
    def __init__(self, edge_dim: int, max_path_distance: int):
        """
        :param edge_dim: dimensionality of the edge features
        :param max_path_distance: maximum path length used for the encoding
        """
        super().__init__()
        self.edge_dim = edge_dim
        self.max_path_distance = max_path_distance

        self.edge_vector = nn.Parameter(torch.randn(self.max_path_distance, self.edge_dim))

    def dot_product(self, x1, x2) -> torch.Tensor:
        return (x1 * x2).sum(dim=1)

    def forward(self, x: torch.Tensor, edge_attr: torch.Tensor, edge_paths) -> torch.Tensor:
        """
        :param x: node feature matrix
        :param edge_attr: edge feature matrix
        :param edge_paths: pairwise node paths in edge indexes
        :return: torch.Tensor, Edge Encoding matrix
        """
        cij = torch.zeros((x.shape[0], x.shape[0])).to(next(self.parameters()).device)

        for src in edge_paths:
            for dst in edge_paths[src]:
                path_ij = edge_paths[src][dst][:self.max_path_distance]
                weight_inds = [i for i in range(len(path_ij))]
                cij[src][dst] = self.dot_product(self.edge_vector[weight_inds], edge_attr[path_ij]).mean()

        cij = torch.nan_to_num(cij)
        return cij
```
Does `edge_vector` represent the weight embedding mentioned in the paper, and `edge_attr` the edge features? So we just need to initialize `edge_vector` with random values and it will be updated automatically during training, right?
You are absolutely right.
`edge_vector` is a trainable parameter, while `edge_attr` is the matrix of edge attributes that comes with the graph.
`edge_vector` is initialized randomly and updated automatically during training, just like the weights of any other layer.
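For illustration, here is a minimal sketch (toy shapes, reusing the `EdgeEncoding` class quoted above; the graph, path dict, and numbers are made up) showing that `edge_vector` receives gradients while `edge_attr` stays plain input data:

```python
import torch

# Toy sizes (assumed for illustration only)
edge_dim, max_path_distance = 4, 5
enc = EdgeEncoding(edge_dim, max_path_distance)  # class quoted above

# edge_vector is registered as a parameter, so any optimizer will update it
print([name for name, _ in enc.named_parameters()])  # ['edge_vector']

# Toy graph: 3 nodes, 2 edges, and a path dict {src: {dst: [edge indices]}}
x = torch.randn(3, 8)                 # node features (forward only uses shape[0])
edge_attr = torch.randn(2, edge_dim)  # edge features that come with the graph
edge_paths = {0: {2: [0, 1]}}         # path 0 -> 2 goes through edges 0 and 1

cij = enc(x, edge_attr, edge_paths)   # (3, 3) edge-encoding matrix
cij.sum().backward()
print(enc.edge_vector.grad is not None)  # True: gradients flow into edge_vector
# edge_attr, by contrast, is not a parameter and receives no update
```

An optimizer step then changes `edge_vector` exactly like any other layer weight, while `edge_attr` is read-only input.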
Oh I see, thx! This is very helpful!
Hi, I'm confused about the meaning of `batch_mask` in layers.py:
```python
def forward(self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            edge_attr: torch.Tensor,
            b: torch.Tensor,
            edge_paths,
            ptr) -> torch.Tensor:
    """
    :param query: node feature matrix
    :param key: node feature matrix
    :param value: node feature matrix
    :param edge_attr: edge feature matrix
    :param b: spatial encoding matrix
    :param edge_paths: pairwise node paths in edge indexes
    :param ptr: batch pointer that shows graph indexes in batch of graphs
    :return: torch.Tensor, node embeddings after attention operation
    """
    batch_mask = torch.zeros((query.shape[0], query.shape[0])).to(next(self.parameters()).device)

    # OPTIMIZE: get rid of slices: rewrite to torch
    if ptr is None:
        batch_mask += 1
    else:
        for i in range(len(ptr) - 1):
            batch_mask[ptr[i]:ptr[i + 1], ptr[i]:ptr[i + 1]] = 1

    query = self.q(query)
    key = self.k(key)
    value = self.v(value)

    c = self.edge_encoding(query, edge_attr, edge_paths)
    a = query.mm(key.transpose(0, 1)) / query.size(-1) ** 0.5
    a = (a + b + c) * batch_mask

    print('a.shape:', a.shape)
    print('a:', a)

    softmax = torch.softmax(a, dim=-1)
    print('softmax.shape:', softmax.shape)
    print('softmax:', softmax)

    x = softmax.mm(value)
    return x
```
I'm not sure what `batch_mask` means here. Is it related to masked self-attention?
Also, if I feed a batch of graphs into this model, it seems that the model treats the whole `DataBatch` as one large graph (by aggregating every graph from `Data`)?
You are right that the `DataBatch` class aggregates all the inner graphs into one large graph. Because every graph keeps its own adjacency matrix (the batched adjacency is block-diagonal), the graphs do not exchange any information through it.
Because the matrices inside the model have shape [num_nodes_in_batch, embedding_dim] rather than [batch_size, seq_len, embedding_dim] as in classical transformers, self-attention is computed as one big matrix multiplication over all nodes in the batch, so we need to exclude any exchange of information between nodes from different graphs. That's where `batch_mask` comes in.
This code is not optimised, as we perform unneeded operations in self-attention. In the future I will do my best to optimise it.
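As a minimal sketch (with a made-up `ptr` of the kind a PyG `DataBatch` provides), this is the block-diagonal mask the loop above builds and how it is applied to the attention scores:

```python
import torch

# Toy ptr: graph 0 owns nodes 0..2, graph 1 owns nodes 3..4 of the big batched graph
ptr = torch.tensor([0, 3, 5])
num_nodes = ptr[-1].item()

batch_mask = torch.zeros(num_nodes, num_nodes)
for i in range(len(ptr) - 1):
    # 1-blocks on the diagonal, one block per graph
    batch_mask[ptr[i]:ptr[i + 1], ptr[i]:ptr[i + 1]] = 1

print(batch_mask)
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [0., 0., 0., 1., 1.],
#         [0., 0., 0., 1., 1.]])

# Multiplying the attention scores by this mask zeroes every cross-graph score,
# which is the same `(a + b + c) * batch_mask` step as in the forward pass above
scores = torch.randn(num_nodes, num_nodes)
masked = scores * batch_mask
```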
Thanks for your patient explanation. So all the graphs in a `DataBatch` are aggregated into one large graph, and we need `batch_mask` to separate the nodes of the different graphs, right?
We need `batch_mask` in the attention operation so that nodes only do message passing within their own graph, not with nodes from other graphs.
Overall, yes, you are right!
Got it, thanks again!