graphdeeplearning/benchmarking-gnns

Inconsistent performance by setting dgl_builtin=True in GCNLayer

ycjing opened this issue · 15 comments

Hi,

Thank you for the great work! This work is really wonderful. When I try to use the GCN model for node classification by running:
python main_SBMs_node_classification.py --dataset SBM_PATTERN --gpu_id 0 --seed 41 --config 'configs/SBMs_node_clustering_GCN_PATTERN_100k.json'
I found that with dgl_builtin set to False, the test accuracy is 63.77, which is consistent with the results reported in the paper; however, with dgl_builtin set to True, the test accuracy rises to 85.56.

I do not think this behavior is normal, but after struggling with it for some time I still have not figured out why the performance is so different. I would appreciate it if you could help me. Thank you! Have a nice day!

Best,
Yongcheng

Hi @ycjing, wow, that seems unexpected. In my experience, the performance was unchanged when enabling/disabling the dgl_builtin flag, but inference with the native DGL layers was noticeably faster.

Could you provide more details:

  • What DGL version are you using + any other differences from our recommended package versions? Also, what hardware are you using?
  • Did you notice this big bump in performance for any other dataset besides PATTERN?

Hi @chaitjo

Thank you for the response! I appreciate it. I followed the instructions at https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/docs/01_benchmark_installation.md for installation. The PyTorch version is 1.3.1 and the DGL version is 0.4.2. I use one Tesla V100 GPU.

To confirm and reproduce the results, I deleted my previous repo and cloned a fresh copy. After preparing the data, I changed the following line in GCNLayer

self.dgl_builtin = dgl_builtin
to self.dgl_builtin = True, and ran python main_SBMs_node_classification.py --dataset SBM_PATTERN --gpu_id 0 --seed 41 --config 'configs/SBMs_node_clustering_GCN_PATTERN_100k.json'. The full log is as follows:

(benchmark_gnn) [crassus@gpu-11-71-1-153 benchmarking-gnns]$ python main_SBMs_node_classification.py --dataset SBM_PATTERN --gpu_id 0 --seed 41 --config 'configs/SBMs_node_clustering_GCN_PATTERN_100k.json'
cuda available with GPU: Tesla V100-PCIE-16GB
[I] Loading dataset SBM_PATTERN...
train, test, val sizes : 10000 2000 2000
[I] Finished loading.
[I] Data load time: 23.2470s
MODEL DETAILS:
MODEL/Total parameters: GCN 100923
Training Graphs: 10000
Validation Graphs: 2000
Test Graphs: 2000
Number of Classes: 2
Epoch 8: 1%| | 8/1000 [03:28<6:25:26, 23.31s/it, lr=0.001, test_acc=84.5, time=22.5, train_acc=85.1, train_loss=0.338, val_acc=84.4, val_loss=0.351]Epoch 8: reducing learning rate of group 0 to 5.0000e-04.
Epoch 29: 3%| | 29/1000 [11:48<6:23:19, 23.69s/it, lr=0.0005, test_acc=85.3, time=24, train_acc=85.5, train_loss=0.331, val_acc=85.3, val_loss=0.336]Epoch 29: reducing learning rate of group 0 to 2.5000e-04.
Epoch 40: 4%| | 40/1000 [16:10<6:19:59, 23.75s/it, lr=0.00025, test_acc=85.3, time=24.3, train_acc=85.5, train_loss=0.329, val_acc=85.1, val_loss=0.33Epoch 40: reducing learning rate of group 0 to 1.2500e-04.
Epoch 58: 6%| | 58/1000 [23:21<6:09:48, 23.56s/it, lr=0.000125, test_acc=85.4, time=24.4, train_acc=85.6, train_loss=0.327, val_acc=85.3, val_loss=0.3Epoch 58: reducing learning rate of group 0 to 6.2500e-05.
Epoch 68: 7%| | 68/1000 [27:20<6:13:28, 24.04s/it, lr=6.25e-5, test_acc=85.5, time=23.5, train_acc=85.7, train_loss=0.326, val_acc=85.4, val_loss=0.33Epoch 68: reducing learning rate of group 0 to 3.1250e-05.
Epoch 74: 7%| | 74/1000 [29:43<6:06:02, 23.72s/it, lr=3.13e-5, test_acc=85.5, time=24.2, train_acc=85.6, train_loss=0.326, val_acc=85.4, val_loss=0.33Epoch 74: reducing learning rate of group 0 to 1.5625e-05.
Epoch 80: 8%| | 80/1000 [32:05<6:04:28, 23.77s/it, lr=1.56e-5, test_acc=85.5, time=24, train_acc=85.7, train_loss=0.325, val_acc=85.4, val_loss=0.333]Epoch 80: reducing learning rate of group 0 to 7.8125e-06.
!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.
Epoch 80: 8%| | 80/1000 [32:05<6:09:08, 24.07s/it, lr=1.56e-5, test_acc=85.5, time=24, train_acc=85.7, train_loss=0.325, val_acc=85.4, val_loss=0.333]
Test Accuracy: 85.5290
Train Accuracy: 85.7210
Convergence Time (Epochs): 80.0000
TOTAL TIME TAKEN: 1946.3838s
AVG TIME PER EPOCH: 23.7732s

However, when I set dgl_builtin=False, the results are consistent with those reported in the paper. This is very strange. I have not tried other datasets yet; I will try some over the next few days and report the results.

Thank you again!

Best,
Yongcheng

Thanks @ycjing for bringing this up; I'll look into this issue.
Would you mind also posting the training log with dgl_builtin=False?

Hi @yzh119

Thank you for the response! Sure, the training log with dgl_builtin=False is as follows:

(benchmark_gnn) [crassus@gpu-11-71-1-153 benchmarking-gnns]$ python main_SBMs_node_classification.py --dataset SBM_PATTERN --gpu_id 0 --seed 41 --config 'configs/SBMs_node_clustering_GCN_PATTERN_100k.json'
cuda available with GPU: Tesla V100-PCIE-16GB
[I] Loading dataset SBM_PATTERN...
train, test, val sizes : 10000 2000 2000
[I] Finished loading.
[I] Data load time: 23.0067s
MODEL DETAILS:
MODEL/Total parameters: GCN 100923
Training Graphs: 10000
Validation Graphs: 2000
Test Graphs: 2000
Number of Classes: 2
Epoch 32: 3%| | 32/1000 [12:31<6:21:05, 23.62s/it, lr=0.001, test_acc=61.3, time=23.5, train_acc=62.8, train_loss=0.642, val_acc=61.5, val_loss=0.651]
Epoch 53: 5%|▌ | 53/1000 [21:14<6:19:38, 24.05s/it, lr=0.001, test_acc=62.4, time=23, train_acc=63.7, train_loss=0.635, val_acc=62.6, val_loss=0.643]Epoch 53: reducing learning rate of group 0 to 5.0000e-04.
Epoch 61: 6%| | 61/1000 [24:23<6:10:31, 23.68s/it, lr=0.0005, test_acc=63.4, time=22.5, train_acc=64.3, train_loss=0.629, val_acc=63.6, val_loss=0.637Epoch 61: reducing learning rate of group 0 to 2.5000e-04.
Epoch 76: 8%| | 76/1000 [30:29<6:19:19, 24.63s/it, lr=0.00025, test_acc=63.5, time=24.8, train_acc=64.8, train_loss=0.625, val_acc=63.8, val_loss=0.63Epoch 76: reducing learning rate of group 0 to 1.2500e-04.
Epoch 84: 8%| | 84/1000 [33:45<6:10:56, 24.30s/it, lr=0.000125, test_acc=63.9, time=24.9, train_acc=65, train_loss=0.624, val_acc=64.1, val_loss=0.633Epoch 84: reducing learning rate of group 0 to 6.2500e-05.
Epoch 92: 9%| | 92/1000 [37:00<6:15:06, 24.79s/it, lr=6.25e-5, test_acc=63.3, time=23.5, train_acc=65, train_loss=0.623, val_acc=63.4, val_loss=0.638]Epoch 92: reducing learning rate of group 0 to 3.1250e-05.
Epoch 107: 11%| | 107/1000 [43:06<6:04:17, 24.48s/it, lr=3.13e-5, test_acc=64.1, time=24.3, train_acc=65.2, train_loss=0.622, val_acc=64.2, val_loss=0.Epoch 107: reducing learning rate of group 0 to 1.5625e-05.
Epoch 114: 11%| | 114/1000 [45:58<6:03:19, 24.60s/it, lr=1.56e-5, test_acc=63.8, time=24, train_acc=65.2, train_loss=0.622, val_acc=63.7, val_loss=0.63Epoch 114: reducing learning rate of group 0 to 7.8125e-06.
!! LR SMALLER OR EQUAL TO MIN LR THRESHOLD.
Epoch 114: 11%| | 114/1000 [45:58<5:57:17, 24.20s/it, lr=1.56e-5, test_acc=63.8, time=24, train_acc=65.2, train_loss=0.622, val_acc=63.7, val_loss=0.63
Test Accuracy: 63.7740
Train Accuracy: 64.8701
Convergence Time (Epochs): 114.0000
TOTAL TIME TAKEN: 2777.8065s
AVG TIME PER EPOCH: 23.9812s

I really appreciate your help! Thank you!

Best,
Yongcheng

Hi everybody,

we reimplemented some parts of the pipeline in a separate project using pytorch geometric and are experiencing very similar behavior.

The GCN model implemented in pytorch geometric reaches more than 80% accuracy after only very short training. Most components of our pipeline are the same/similar (also the weighted loss / accuracy computation) besides the framework used.

Generally, this seems to be indicative of some issues when running with dgl_builtin=False, at least on this particular dataset.

Cheers,
Max

@ExpectationMax dgl_builtin=True is the recommended way to implement GNN models.
For GCN, the CUDA kernels triggered by DGL with builtin=True should be similar to PyG's: the neighbors' messages are aggregated one by one.
With builtin=False, DGL parallelizes message passing by degree bucketing, which calls PyTorch's sum reduction function (which uses tree reduction).
It's weird to see the two execution orders make so much difference.
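To put the execution-order point in perspective, here is a small NumPy sketch (hypothetical helper names; it does not reproduce DGL's or PyTorch's actual CUDA kernels) that sums the same float32 messages one by one and with a pairwise tree. The two results typically differ only at the level of float32 rounding error, so a different summation order alone would not be expected to produce a gap of roughly 20 accuracy points.

```python
import numpy as np

def sequential_sum(x):
    # Accumulate one message at a time, like a one-by-one (edge-wise) aggregation.
    acc = np.float32(0.0)
    for v in x:
        acc = np.float32(acc + v)
    return acc

def tree_sum(x):
    # Pairwise (tree) reduction: add neighbouring pairs until one value remains.
    vals = list(x)
    while len(vals) > 1:
        nxt = [np.float32(vals[i] + vals[i + 1]) for i in range(0, len(vals) - 1, 2)]
        if len(vals) % 2 == 1:
            nxt.append(vals[-1])  # carry an unpaired element to the next round
        vals = nxt
    return vals[0]

rng = np.random.default_rng(0)
msgs = rng.standard_normal(1000).astype(np.float32)  # stand-in for incoming messages
a = sequential_sum(msgs)
b = tree_sum(msgs)
print(a, b, abs(a - b))
```

Both orders approximate the same exact sum; only the rounding differs.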

Hi @ycjing , after digging into the codebase, I found the implementation of GCNLayer is indeed different from the DGL's provided GraphConv layer.

if self.dgl_builtin == False:
    g.ndata['h'] = feature
    g.update_all(msg, reduce)
    g.apply_nodes(func=self.apply_mod)
    h = g.ndata['h']  # result of graph convolution
else:
    h = self.conv(g, feature)

When dgl_builtin=False, it calls update_all with message and reduce functions defined here:

# Sends a message of node feature h
# Equivalent to => return {'m': edges.src['h']}
msg = fn.copy_src(src='h', out='m')
reduce = fn.mean('m', 'h')

This implements a message-passing module that averages the received messages. By contrast, DGL's provided GraphConv layer by default uses Kipf's original formulation (i.e., it normalizes both sides by the inverse square root of the in- and out-degrees). Please check out the API doc for details. I think this discrepancy may explain the performance gap.
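Written out per node, the two aggregations compare roughly as follows. Here N(i) is the set of in-neighbours of node i, d^in and d^out are in- and out-degrees, and W, b are the layer's weight and bias. This assumes the apply step after mean pooling is a linear layer, and it omits activation, batch norm and residual connections.

$$
\begin{aligned}
\text{mean (dgl\_builtin=False):}\quad h_i' &= W\,\frac{1}{d^{\mathrm{in}}_i}\sum_{j\in\mathcal{N}(i)} h_j + b,\\
\text{GraphConv, symmetric (dgl\_builtin=True):}\quad h_i' &= b + \sum_{j\in\mathcal{N}(i)} \frac{1}{\sqrt{d^{\mathrm{in}}_i}\,\sqrt{d^{\mathrm{out}}_j}}\,W h_j .
\end{aligned}
$$

The difference lies in the per-edge weight. The mean uses 1/d^in_i, so the weights over a node's neighbours sum to 1. The symmetric form uses 1/sqrt(d^in_i d^out_j), which also depends on each neighbour's degree and in general does not sum to 1.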

@chaitjo Do you think it is reasonable to remove the dgl_builtin option and always use DGL's provided GraphConv layer?

Hi @jermainewang

Thank you for the response! I appreciate it. If my understanding is correct, the achieved 85.52% accuracy on PATTERN with dgl_builtin=True is not caused by bugs, but rather the real performance of the GCN model, right?

If this is the case, the provided results in the benchmark paper might need updating, since the performances are so different.

Hi @ExpectationMax
I'm also trying to reproduce the results on PATTERN and CLUSTER using GCN in pytorch-geometric. Would you mind sharing the code you used, or describing the tricks you applied (pre-transforms, transforms, etc.)?

In particular, I'm trying to run the experiment on CLUSTER. This dataset has a binary feature of size 6: do you use an nn.Embedding to embed it? And do you use 4 GCN layers?

Thank you in advance!

Hi everyone, first of all, thank you for this discussion, and apologies for the late response. Indeed, @jermainewang's explanation of the performance difference is correct: the built-in GCN layer normalizes by the square root of the source and destination degrees, whereas our initial implementation performs mean pooling.
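To make the difference concrete, the following NumPy sketch (hypothetical function name, not the repository's or DGL's code) aggregates the same features over a small star graph with both normalizations. It works on plain edge lists and, as an assumption of the sketch, clamps degrees to at least 1.

```python
import numpy as np

def aggregate(src, dst, h, num_nodes, norm):
    """Sum incoming messages h[src] at dst under one of two normalizations.

    norm="mean": divide by the destination's in-degree (mean pooling).
    norm="both": scale by 1/sqrt(out-degree of src) and 1/sqrt(in-degree of dst).
    Degrees are clamped to at least 1 to avoid division by zero.
    """
    deg_out = np.maximum(np.bincount(src, minlength=num_nodes), 1).astype(h.dtype)
    deg_in = np.maximum(np.bincount(dst, minlength=num_nodes), 1).astype(h.dtype)
    if norm == "mean":
        msgs = h[src]
        scale_dst = 1.0 / deg_in
    elif norm == "both":
        msgs = h[src] / np.sqrt(deg_out[src])[:, None]
        scale_dst = 1.0 / np.sqrt(deg_in)
    else:
        raise ValueError(f"unknown norm: {norm}")
    out = np.zeros((num_nodes, h.shape[1]), dtype=h.dtype)
    np.add.at(out, dst, msgs)  # scatter-add messages onto destination nodes
    return out * scale_dst[:, None]

# Star graph: hub 0 linked to leaves 1, 2, 3 in both directions.
src = np.array([0, 0, 0, 1, 2, 3])
dst = np.array([1, 2, 3, 0, 0, 0])
h = np.ones((4, 1))  # identical features on every node

mean_out = aggregate(src, dst, h, 4, "mean")
sym_out = aggregate(src, dst, h, 4, "both")
print(mean_out.ravel())  # [1. 1. 1. 1.]
print(sym_out.ravel())   # hub ~1.732, leaves ~0.577
```

With identical input features, mean aggregation returns them unchanged, while the symmetric normalization produces degree-dependent outputs. This shows one concrete way the two layers behave differently; the sketch does not by itself establish why the PATTERN accuracy differs.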

After internal discussion, we plan to move to the DGL built-in GCN layer, and plan to update the benchmark leaderboard along with the next release of our paper/repository.

Hi @chaitjo

Thank you for the response! I truly appreciate it. Now I understand the whole thing. Also, thanks again for this great work!

Best,
Yongcheng

Hi @chaitjo
Is there an expected date when the updated paper with the updated leaderboard will be released?

Hi @ycjing
Do you set the activation function when using dgl_builtin=True? I found that the default activation function is None, which is inconsistent with the original GCN paper.
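On the activation question: if a GCN layer applies no nonlinearity, stacked layers remain linear in the input features. The NumPy sketch below uses dense matrices and Kipf & Welling style normalization A_hat = D^-1/2 (A + I) D^-1/2, where D is the degree matrix of A + I. The helper names are hypothetical, and the sketch does not check what the repository's GCNLayer applies after the convolution. It shows two activation-free layers collapsing into a single linear map, and a ReLU between them breaking that collapse.

```python
import numpy as np

def normalized_adjacency(A):
    """A_hat = D^-1/2 (A + I) D^-1/2 for a dense, symmetric adjacency matrix A."""
    A_tilde = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))
    return A_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(A_hat, H, W, activation=None):
    """One dense GCN layer; activation=None means no nonlinearity is applied."""
    out = A_hat @ H @ W
    return out if activation is None else activation(out)

def relu(x):
    return np.maximum(x, 0.0)

rng = np.random.default_rng(0)
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 1],
              [0, 1, 0, 1],
              [0, 1, 1, 0]], dtype=float)
A_hat = normalized_adjacency(A)
H = rng.standard_normal((4, 3))
W1 = rng.standard_normal((3, 5))
W2 = rng.standard_normal((5, 2))

# Without an activation, two layers equal one linear map: A_hat^2 H (W1 W2).
linear_two = gcn_layer(A_hat, gcn_layer(A_hat, H, W1), W2)
collapsed = A_hat @ A_hat @ H @ (W1 @ W2)
print(np.allclose(linear_two, collapsed))  # True

# With ReLU between the layers, the collapse no longer holds in general.
relu_two = gcn_layer(A_hat, gcn_layer(A_hat, H, W1, activation=relu), W2)
print(np.allclose(relu_two, collapsed))
```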

Issue fixed in the recent update. Thanks everyone!