JinheonBaek/RGCN

The memory overflow

Yongquan-He opened this issue · 5 comments

I am using your code as part of our experiments, but there is a problem with the test_graph: because the graph used during validation and testing is built from all training triplets, 64 GB of memory is not enough when I test the model. Could you give me some suggestions? Thank you very much!

Hi, I have also run into this out-of-memory problem. It seems that when there are too many triplets, a massive amount of CPU memory is needed:
[screenshot]

In my experiments, the graph statistics are as follows:
[screenshot]

It seems that the only way to avoid OOM is to reduce the number of triplets, so that fewer edge_type entries are indexed into w:
[screenshot]
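
For context, here is a minimal, self-contained illustration of the indexing referred to above (the dimensions are made up for illustration and are not the repository's defaults). Indexing w by edge_type copies one (in_dim, out_dim) weight matrix per edge, so the resulting tensor grows linearly with the number of edges:

import torch

# Hypothetical, small dimensions purely for illustration.
num_relations, num_bases, in_dim, out_dim = 8, 4, 16, 16
num_edges = 1000

att = torch.randn(num_relations, num_bases)
basis = torch.randn(num_bases, in_dim, out_dim)
edge_type = torch.randint(0, num_relations, (num_edges,))

# Same pattern as the message() function profiled further down in this thread.
w = torch.matmul(att, basis.view(num_bases, -1)).view(num_relations, in_dim, out_dim)
w_per_edge = torch.index_select(w, 0, edge_type)   # shape: (num_edges, in_dim, out_dim)
print(w_per_edge.shape)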


Hi He, could you please tell me how you fixed this overflow issue? Many thanks in advance.

One solution seems to be to modify valid() to batch the evaluation phase (and to modify calc_mrr accordingly so that it also returns the hits):

def valid(valid_triplets, model, test_graph, all_triplets, batch_size=1024):
    with torch.no_grad():
        model.eval()

        # The entity embeddings depend only on the graph, so compute them once
        # rather than once per batch.
        entity_embedding = model(test_graph.entity, test_graph.edge_index,
                                 test_graph.edge_type, test_graph.edge_norm)

        mrr = 0.0
        hits = {1: 0.0, 3: 0.0, 10: 0.0}
        num_batches = 0
        for i in range(0, len(valid_triplets), batch_size):
            batch_valid_triplets = valid_triplets[i:i + batch_size]
            mrr_b, hits_b = calc_mrr(entity_embedding, model.relation_embedding,
                                     batch_valid_triplets, all_triplets,
                                     hits=[1, 3, 10])
            mrr += mrr_b
            hits[1] += hits_b[1]
            hits[3] += hits_b[3]
            hits[10] += hits_b[10]
            num_batches += 1

        # Average over the actual number of batches (the last one may be smaller,
        # so integer division by batch_size would be off).
        mrr /= num_batches
        hits = {k: v / num_batches for k, v in hits.items()}

        print(f'MRR: {mrr}, Hits@1: {hits[1]}, Hits@3: {hits[3]}, Hits@10: {hits[10]}')
    return mrr

Batching valid() in this way, however, does not seem to work for FB15k-237. Could this line be the source of the issue? https://github.com/JinheonBaek/RGCN/blob/818bf70b00d5cd178a7496a748e4f18da3bcde82/main.py#L25C41-L25C47
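
Independently of that line, if calc_mrr allocates a large per-triplet intermediate when ranking each evaluation triplet against every entity, the scores can instead be computed as a single matmul per batch, so that the largest intermediate is only (batch_size, num_entities). A minimal sketch, assuming a DistMult-style decoder (the decoder used for R-GCN link prediction in the original paper); distmult_object_scores is a hypothetical helper name, not the repository's implementation:

import torch

def distmult_object_scores(entity_embedding, relation_embedding, triplets):
    # entity_embedding: (N, d), relation_embedding: (R, d),
    # triplets: (B, 3) LongTensor of (head, relation, tail) indices.
    head = entity_embedding[triplets[:, 0]]     # (B, d)
    rel = relation_embedding[triplets[:, 1]]    # (B, d)
    # DistMult: score(h, r, t) = sum_k e_h[k] * w_r[k] * e_t[k].
    # Computing it as (head * rel) @ E^T scores all candidate tails at once
    # while keeping the largest intermediate at shape (B, N).
    return (head * rel) @ entity_embedding.t()  # (B, N)

Filtered ranking against all_triplets would then be applied to the returned score matrix.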

In case it helps, here is the memory profile of the message function during training and during validation.

During Training:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   188   1904.4 MiB   1904.4 MiB           1       @profile
   189                                             def message(self, x_j, edge_index_j, edge_type, edge_norm):
   190                                                 """
   191                                                 """
   192
   193                                                 # Call the function that might be causing the memory overflow
   194   1904.4 MiB      0.0 MiB           1           w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
   195
   196                                                 # If no node features are given, we implement a simple embedding
   197                                                 # loopkup based on the target node index and its edge type.
   198   1904.4 MiB      0.0 MiB           1           if x_j is None:
   199                                                     w = w.view(-1, self.out_channels)
   200                                                     index = edge_type * self.in_channels + edge_index_j
   201                                                     out = torch.index_select(w, 0, index)
   202                                                 else:
   203   1904.4 MiB      0.0 MiB           1               w = w.view(self.num_rel, self.in_chan, self.out_chan)
   204   3047.9 MiB   1143.5 MiB           1               w = torch.index_select(w, 0, edge_type)
   205   3047.9 MiB      0.0 MiB           1               out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)
   206
   207   3047.9 MiB      0.0 MiB           1           if edge_norm is not None:
   208   3047.9 MiB      0.0 MiB           1               out = out * edge_norm.view(-1, 1)
   209
   210   3047.9 MiB      0.0 MiB           1           return out

During Validation:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   188    844.2 MiB    844.2 MiB           1       @profile
   189                                             def message(self, x_j, edge_index_j, edge_type, edge_norm):
   190                                                 """
   191                                                 """
   192
   193                                                 # Call the function that might be causing the memory overflow
   194    844.2 MiB      0.0 MiB           1           w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))
   195
   196                                                 # If no node features are given, we implement a simple embedding
   197                                                 # loopkup based on the target node index and its edge type.
   198    844.2 MiB      0.0 MiB           1           if x_j is None:
   199                                                     w = w.view(-1, self.out_channels)
   200                                                     index = edge_type * self.in_channels + edge_index_j
   201                                                     out = torch.index_select(w, 0, index)
   202                                                 else:
   203    844.2 MiB      0.0 MiB           1               w = w.view(self.num_rel, self.in_chan, self.out_chan)
   204  11635.1 MiB  10790.8 MiB           1               w = torch.index_select(w, 0, edge_type)
   205  11743.1 MiB    108.0 MiB           1               out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)
   206
   207  11743.1 MiB      0.0 MiB           1           if edge_norm is not None:
   208  11743.2 MiB      0.2 MiB           1               out = out * edge_norm.view(-1, 1)
   209
   210  11743.2 MiB      0.0 MiB           1           return out

It appears that the memory overflow happens specifically during validation because edge_type is much larger during validation than during training:

During Training:

Size of edge_type: 30000

During Validation:

Size of edge_type: 282884
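
A rough back-of-envelope check is consistent with this: torch.index_select(w, 0, edge_type) materialises one (in_dim, out_dim) weight matrix per edge, so its allocation grows linearly with the number of edges. Assuming float32 and a hidden size of about 100 (an assumed value, not one read from the repository's config), the estimate matches the increments in the profiles above:

# Rough estimate only; in_dim = out_dim = 100 is an assumed hidden size.
bytes_per_float = 4
in_dim = out_dim = 100

def w_per_edge_gib(num_edges):
    return num_edges * in_dim * out_dim * bytes_per_float / 2**30

print(w_per_edge_gib(30000))    # ~1.1 GiB,  training profile shows +1143.5 MiB
print(w_per_edge_gib(282884))   # ~10.5 GiB, validation profile shows +10790.8 MiB

If that is what is happening, batching the evaluation triplets cannot fix it on its own, since the allocation occurs inside message() and depends only on the number of edges in the evaluation graph; the graph itself would have to be smaller, or the message computation would have to avoid materialising per-edge weight matrices (for example, by processing edges one relation type at a time).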