Training with subgraph augmentation?
Closed this issue · 1 comment
Hi! Thank you for your work! However, on TUDatasets the code seems to fail when using subgraphs as the augmentation:
Traceback (most recent call last):
File "rgcl.py", line 172, in <module>
for data in dataloader:
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
return self._process_data(data)
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
data.reraise()
File "/opt/conda/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
raise exception
AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/user/.local/lib/python3.7/site-packages/torch_geometric/data/dataset.py", line 239, in __getitem__
data = self.get(self.indices()[idx])
File "/workspace/external_src/RGCL/unsupervised_TU/aug.py", line 251, in get
assert False
AssertionError
In the get method in aug.py only 'drop_ra' seems to be implemented. Is drop_ra equivalent to node dropping? I'm also not sure I understand the role of n = np.random.randint(2) and m = np.random.randint(2) in the get method. Thanks in advance for your answer!
Thanks for your comments!
(1) The subgraph function appends nodes into the candidate set (nodes to be preserved) one by one. In TUDatasets, some graphs contain thousands of nodes, making this function dramatically inefficient. Thus we suggest not using it on these datasets.
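To make the inefficiency concrete, here is a minimal sketch of that kind of one-node-at-a-time subgraph growth. The function name `subgraph_sample` and the `keep_ratio` parameter are illustrative assumptions, not the actual code in aug.py:

```python
import numpy as np

def subgraph_sample(edge_index, num_nodes, keep_ratio=0.2, seed=0):
    """Grow a preserved-node set one node at a time.

    edge_index: 2 x E array of (source, target) node pairs.
    Illustrative sketch only; names/defaults are not from aug.py.
    """
    rng = np.random.default_rng(seed)
    target = max(1, int(keep_ratio * num_nodes))
    # build an adjacency list for neighbor lookups
    adj = [[] for _ in range(num_nodes)]
    for s, t in zip(edge_index[0], edge_index[1]):
        adj[int(s)].append(int(t))
    start = int(rng.integers(num_nodes))         # random start node
    candidate = {start}                          # nodes to be preserved
    frontier = set(adj[start])
    while len(candidate) < target and frontier:
        # exactly one node is appended per iteration -- this loop is
        # what becomes slow on graphs with thousands of nodes
        node = frontier.pop()
        candidate.add(node)
        frontier |= set(adj[node]) - candidate
    return sorted(candidate)

# tiny 6-node path graph: 0-1-2-3-4-5 (both edge directions)
edges = np.array([[0, 1, 2, 3, 4, 1, 2, 3, 4, 5],
                  [1, 2, 3, 4, 5, 0, 1, 2, 3, 4]])
print(subgraph_sample(edges, num_nodes=6, keep_ratio=0.5))
```

Since the loop admits a single node per step, the cost scales with the number of nodes to preserve times the neighbor-set work per step, which is why the maintainer advises skipping this augmentation on the larger TUDataset graphs.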
(2) n = np.random.randint(2), m = np.random.randint(2): thanks for pointing this out. These lines were left over from debugging and are unnecessary. I've already cleaned up the aug.py file.