bayer-science-for-a-better-life/phc-gnn

assert x_j.size(-1) == edge_attr.size(-1)

wtzhao1631 opened this issue · 3 comments

Hi,

When I run this code with bash run_script_pcba_phm4.sh, I get the following error:

Traceback (most recent call last):
  File "train_pcba.py", line 634, in <module>
    main()
  File "train_pcba.py", line 610, in main
    ogb_bestEpoch_test_metrics, ogb_lastEpoch_test_metric, ogb_val_metrics = do_run(i, model, args,
  File "train_pcba.py", line 327, in do_run
    train_metrics = train(epoch=epoch, model=model, device=device, transform=transform,
  File "train_pcba.py", line 174, in train
    logits = model(data)
  File "/root/miniconda3/envs/phc-gnn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/ogb/phc-gnn-master/phc/hypercomplex/undirectional/models.py", line 243, in forward
    x = self.compute_hidden_layer_embedding(conv=self.convs[i], norm=self.norms[i],
  File "/ogb/phc-gnn-master/phc/hypercomplex/undirectional/models.py", line 208, in compute_hidden_layer_embedding
    x = conv(x=tmp[0], edge_index=edge_index, edge_attr=edge_attr, size=size)
  File "/root/miniconda3/envs/phc-gnn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/ogb/phc-gnn-master/phc/hypercomplex/undirectional/messagepassing.py", line 519, in forward
    return self.transform(x, edge_index, edge_attr, size)
  File "/root/miniconda3/envs/phc-gnn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/phc-gnn-master/phc/hypercomplex/undirectional/messagepassing.py", line 60, in forward
    x = self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr, size=size)
  File "/root/miniconda3/envs/phc-gnn/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 236, in propagate
    out = self.message(**msg_kwargs)
  File "/ogb/phc-gnn-master/phc/hypercomplex/undirectional/messagepassing.py", line 74, in message
    assert x_j.size(-1) == edge_attr.size(-1)
AssertionError

I then printed the shapes of x_j and edge_attr:
x_j : [28302, 512]
edge_attr : [28302, 2048]

What caused this error?
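A minimal, pure-Python sketch of the check that fails here, using the shapes reported above. It assumes, as the assertion suggests, that the message function combines x_j (the source-node features gathered per edge) with edge_attr elementwise, so both need the same last dimension. check_message_inputs is a hypothetical helper for illustration, not part of phc-gnn or PyTorch Geometric.

```python
# Hypothetical shape-level sketch, not code from phc-gnn.
# x_j has one row per edge (source-node features); edge_attr has one row per
# edge (embedded edge features). Combining them elementwise requires equal
# row counts and an equal last (feature) dimension.

def check_message_inputs(x_j_shape, edge_attr_shape):
    """Return None if the shapes are compatible, else a description of the mismatch."""
    if x_j_shape[0] != edge_attr_shape[0]:
        return f"edge count differs: {x_j_shape[0]} vs {edge_attr_shape[0]}"
    if x_j_shape[-1] != edge_attr_shape[-1]:
        ratio = edge_attr_shape[-1] / x_j_shape[-1]
        return (f"feature dim differs: {x_j_shape[-1]} vs {edge_attr_shape[-1]} "
                f"(edge_attr is {ratio:g}x wider)")
    return None

# Shapes reported in this issue.
print(check_message_inputs((28302, 512), (28302, 2048)))
# prints: feature dim differs: 512 vs 2048 (edge_attr is 4x wider)
```

The edge count (28302) matches, so both tensors are indexed over the same edges; only the feature width differs, with edge_attr exactly 4 times wider than x_j.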

Hi @wtzhao1631,

The error was caused by an incorrect dimensionality setting in the GNN model class.
This should now be fixed in the main branch after PR #10.

Best regards,
Tuan
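One way to surface this kind of misconfiguration earlier is to validate the widths when the model is built, instead of failing inside message() during the first training step. The sketch below is purely illustrative: HypotheticalDims and its field names are assumptions, not phc-gnn's actual configuration, and the code is not taken from PR #10.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HypotheticalDims:
    """Illustrative dimension settings; field names are assumptions, not phc-gnn's."""
    node_hidden_dim: int  # width of the node embeddings passed to the conv as x
    edge_embed_dim: int   # width produced by the (hypothetical) edge-feature encoder

    def __post_init__(self):
        # Fail at construction time rather than deep inside message().
        if self.node_hidden_dim != self.edge_embed_dim:
            raise ValueError(
                f"edge_embed_dim ({self.edge_embed_dim}) must equal "
                f"node_hidden_dim ({self.node_hidden_dim})"
            )


ok = HypotheticalDims(node_hidden_dim=512, edge_embed_dim=512)
try:
    HypotheticalDims(node_hidden_dim=512, edge_embed_dim=2048)
except ValueError as err:
    print(err)
```

Checking at construction time trades a little boilerplate for an error message that names the misconfigured setting, instead of a bare AssertionError several frames deep in propagate().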

It runs now, thank you!

Great. I'm going to close the issue. Feel free to reopen it if you have further questions or remarks.