THUDM/HGB

convolution optimization

felipemello1 opened this issue · 1 comment

Hi, I was checking the convolution, and it looks like there are expensive layers there that could be eliminated entirely:

The code is:

e_feat = self.edge_emb(e_feat)
e_feat = self.fc_e(e_feat).view(-1, self._num_heads, self._edge_feats)
ee = (e_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
graph.edata.update({'ee': ee})
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e') + graph.edata.pop('ee'))

Problem 1:

self.edge_emb = nn.Embedding(num_etypes, edge_feats)
e_feat = self.edge_emb(e_feat)
e_feat = self.fc_e(e_feat).view(-1, self._num_heads, self._edge_feats)

Is it necessary to run a fully connected layer on top of the embeddings? As far as I understand, the embedding table can naturally learn the same projection as emb = self.fc_e(emb). It becomes even more expensive when you consider that the conv might have only ~20 edge types, yet this fully connected layer is run hundreds of thousands of times on the same 20 repeated rows (see the sketch below for two cheaper alternatives).
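
To make this concrete, here is a minimal sketch, not the repository's code: etype_ids is a stand-in for the per-edge type indices, and the sizes are made up. Option A lets the embedding table learn the projected values directly; Option B keeps fc_e but applies it once to the 20 rows of the table and gathers per edge.

import torch as th
import torch.nn as nn

num_etypes, num_heads, edge_feats = 20, 8, 64
etype_ids = th.randint(0, num_etypes, (100000,))  # one type id per edge

# Option A: let the embedding table learn the projected values directly.
edge_emb = nn.Embedding(num_etypes, num_heads * edge_feats)
e_feat = edge_emb(etype_ids).view(-1, num_heads, edge_feats)

# Option B: keep fc_e, but project the 20 table rows once and gather per edge,
# instead of running the matmul on every edge.
edge_emb = nn.Embedding(num_etypes, edge_feats)
fc_e = nn.Linear(edge_feats, num_heads * edge_feats, bias=False)
projected = fc_e(edge_emb.weight).view(num_etypes, num_heads, edge_feats)
e_feat = projected[etype_ids]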

Problem 2:

self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.attn_e = nn.Parameter(th.FloatTensor(size=(1, num_heads, edge_feats)))

ee = (e_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)

e_feat, feat_src and feat_dst are already the outputs of an MLP. Is it necessary to multiply them by a learned constant (called attention here)? I guess the MLP can naturally absorb the same values (a small sketch follows the derivation below). We can just say that:

y  = a*x + b
y' = d*(a*x + b)
   = (d*a)*x + (d*b)
   = new_a*x + new_b
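
As a concrete check of that argument, here is a standalone sketch with made-up sizes (not HGB's code); fc and attn_l play the roles of the projection and attention vector from the snippet above, and the attention dot product collapses into a single linear map whose weights absorb attn_l:

import torch as th
import torch.nn as nn

in_feats, num_heads, out_feats = 32, 8, 16
fc = nn.Linear(in_feats, num_heads * out_feats, bias=False)
attn_l = th.randn(1, num_heads, out_feats)
h = th.randn(5, in_feats)

# Original path: project, then reduce against the attention vector.
feat_src = fc(h).view(-1, num_heads, out_feats)
el = (feat_src * attn_l).sum(dim=-1)                              # (5, heads)

# Folded path: absorb attn_l into the projection weight, one (heads x in) matrix.
W = fc.weight.view(num_heads, out_feats, in_feats)
folded_weight = (attn_l.squeeze(0).unsqueeze(-1) * W).sum(dim=1)  # (heads, in)
el_folded = h @ folded_weight.t()

print(th.allclose(el, el_folded, atol=1e-5))                      # True

Of course this only shows that the two parameterizations express the same functions, not that training behaves identically.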

If you remove these two parts, then the attention score can be computed simply as right + left + edge_emb (graph.apply_edges(fn.u_add_v('feat_src', 'feat_dst', 'e_feat')) plus the edge term), without doing all of these transformations beforehand; a sketch of what that could look like follows.
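
Here is a hedged sketch of that simplified scoring path. This is not HGB's implementation: fc_l, fc_r and etype_score are hypothetical stand-ins for the folded projections, and the graph and features are random.

import torch as th
import torch.nn as nn
import dgl
import dgl.function as fn

num_heads, in_feats, num_etypes = 8, 32, 20
fc_l = nn.Linear(in_feats, num_heads)               # plays the role of fc + attn_l
fc_r = nn.Linear(in_feats, num_heads)               # plays the role of fc + attn_r
etype_score = nn.Embedding(num_etypes, num_heads)   # plays the role of edge_emb + fc_e + attn_e
leaky_relu = nn.LeakyReLU(0.2)

g = dgl.rand_graph(100, 500)
h = th.randn(100, in_feats)
etype_ids = th.randint(0, num_etypes, (500,))

g.srcdata['el'] = fc_l(h).unsqueeze(-1)             # (N, heads, 1)
g.dstdata['er'] = fc_r(h).unsqueeze(-1)             # (N, heads, 1)
g.edata['ee'] = etype_score(etype_ids).unsqueeze(-1)  # (E, heads, 1)
g.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = leaky_relu(g.edata.pop('e') + g.edata.pop('ee'))  # (E, heads, 1)

The node features used for message aggregation ('ft') would still need their own projection; only the attention-score path is shortened here.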

Thank you very much for your suggestion, but sometimes, even though an equivalent parameterization exists, the optimization method might not find it.
We will test the performance in the future, but due to some personal health problems we cannot test it now. If you get good results, it would be great to contribute a pull request!