Runtime Error when enumerating train_loader during training
qcxia20 opened this issue · 2 comments
Hi! I really appreciate your fantastic work and code, and I've reproduced your work by following the guidance in README.md.
However, I get the following error when running the training process with `train.py`.
Describe the error
Starting training...
0%| | 0/625 [00:00<?, ?it/s][11:18:30] Explicit valence for atom # 0 N, 4, is greater than permitted
0%| | 0/625 [22:56<?, ?it/s]
Traceback (most recent call last):
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/pubhome/qcxia02/.vscode-server/extensions/ms-python.python-2021.11.1422169775/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
cli.main()
File "/pubhome/qcxia02/.vscode-server/extensions/ms-python.python-2021.11.1422169775/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
run()
File "/pubhome/qcxia02/.vscode-server/extensions/ms-python.python-2021.11.1422169775/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 263, in run_path
pkg_name=pkg_name, script_name=fname)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 96, in _run_module_code
mod_name, mod_spec, pkg_name, script_name)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/pubhome/qcxia02/git-repo/AI-CONF/GeoMol/train.py", line 74, in <module>
train_loss = train(model, train_loader, optimizer, device, scheduler, logger if args.verbose else None, epoch, writer)
File "/pubhome/qcxia02/git-repo/AI-CONF/GeoMol/model/training.py", line 18, in train
for i, data in tqdm(enumerate(loader), total=len(loader)):
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/tqdm/std.py", line 1178, in __iter__
for obj in iterable:
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
return self._process_data(data)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
data.reraise()
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
return self.collate_fn(data)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/loader/dataloader.py", line 39, in __call__
return self.collate(batch)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/loader/dataloader.py", line 20, in collate
self.exclude_keys)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/batch.py", line 75, in from_data_list
exclude_keys=exclude_keys,
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 86, in collate
increment)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 142, in _collate
data_list, stores, increment)
File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 162, in _collate
value = torch.cat(values, dim=cat_dim or 0)
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 19 but got size 21 for tensor number 1 in the list.
To Reproduce
`python train.py --data_dir data/QM9/qm9/ --split_path data/QM9/splits/split0.npy --log_dir ./test_run --n_epochs 250 --dataset qm9`
Expected behavior
Training completed smoothly without error
Environments:
The environment is based on the provided `environment.yml` file; the relevant package versions are listed below:
- OS: CentOS Linux release 8.4.2105
- Package Version:
- python=3.7.10
- pytorch=1.10.0=py3.7_cpu_0
- torchaudio=0.10.0=py37_cpu
- torchvision=0.11.1=py37_cpu
- pytorch-cluster=1.5.9=py37_torch_1.10.0_cpu
- pytorch-mutex=1.0=cpu
- pytorch-scatter=2.0.9=py37_torch_1.10.0_cpu
- pytorch-sparse=0.6.12=py37_torch_1.10.0_cpu
- pytorch-spline-conv=1.2.1=py37_torch_1.10.0_cpu
- torch-geometric=2.0.2
Additional context:
This error was raised when the dataloader was enumerated during training, i.e. at `for i, data in tqdm(enumerate(loader), total=len(loader)):`. The `Expected size 19 but got size 21` error in `torch.cat` arises because it tried to concatenate tensor B (2nd molecule, shape 10x21x3) with tensor A (1st molecule, shape 10x19x3) along dimension 0 (size 10), which requires all other dimensions to match, but dimension 1 differs (19 vs. 21). I'm not sure whether this behaviour is expected on your side, or where to make modifications (if any are needed).
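To make the shape rule concrete, here is a minimal pure-Python sketch of the check that concatenation performs. The helper `cat_shape` is hypothetical (it is not part of torch or torch-geometric) and works on plain shape tuples, so it only mirrors the rule and the wording of the error message above.

```python
def cat_shape(shapes, dim=0):
    """Return the shape produced by concatenating along `dim`.

    Hypothetical helper: mirrors the rule that every dimension
    except `dim` must match, and raises a similar error otherwise.
    """
    first = shapes[0]
    out = list(first)
    for number, shape in enumerate(shapes[1:], start=1):
        if len(shape) != len(first):
            raise RuntimeError("All tensors must have the same number of dimensions.")
        for axis, (expected, got) in enumerate(zip(first, shape)):
            if axis != dim and expected != got:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {expected} but got size {got} "
                    f"for tensor number {number} in the list."
                )
        # Only the concatenation dimension grows.
        out[dim] += shape[dim]
    return tuple(out)


# Two molecules with the same atom count concatenate fine along dimension 0.
print(cat_shape([(10, 19, 3), (10, 19, 3)]))

# The shapes from this issue fail, because dimension 1 differs (19 vs. 21).
try:
    cat_shape([(10, 19, 3), (10, 21, 3)])
except RuntimeError as err:
    print(err)

# The same two shapes are compatible when joined along dimension 1 instead.
print(cat_shape([(10, 19, 3), (10, 21, 3)], dim=1))
```

The last call only shows that the two shapes are compatible along dimension 1; whether the collate step is meant to use that dimension is not something this sketch can confirm.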
Looking forward to your reply :)
Hmm I'm not sure immediately what the issue is, but I have a few suggestions. First, could you try downgrading torch-geometric to 1.6.3? I think that's the primary difference between my local versions and the versions you have listed.
That solved my issue. Thanks for your reply.
Note that pytorch 1.10.0 has some problems with torch-geometric 1.6.3, so I also downgraded pytorch to 1.7.0.
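For anyone hitting the same error, a quick way to compare an environment against the combination reported to work above is sketched below. The names `REPORTED_WORKING`, `installed_versions` and `mismatches` are made up for illustration, and the version pair comes only from this thread, not from an official compatibility matrix.

```python
import importlib

# Versions reported to work in this thread (an assumption, not an official matrix).
REPORTED_WORKING = {"torch": "1.7.0", "torch_geometric": "1.6.3"}


def installed_versions(modules):
    """Import each module and read its __version__; None if it cannot be imported."""
    found = {}
    for name in modules:
        try:
            found[name] = getattr(importlib.import_module(name), "__version__", None)
        except ImportError:
            found[name] = None
    return found


def mismatches(found, expected=REPORTED_WORKING):
    """Return {module: (found, expected)} for every version that differs."""
    problems = {}
    for name, want in expected.items():
        have = found.get(name)
        # Drop local build suffixes such as "+cpu" before comparing.
        base = have.split("+")[0] if have else None
        if base != want:
            problems[name] = (have, want)
    return problems


if __name__ == "__main__":
    report = mismatches(installed_versions(REPORTED_WORKING))
    for name, (have, want) in report.items():
        print(f"{name}: found {have}, this thread reports {want} as working")
```

Running it inside the GeoMol conda environment lists any package whose version differs from the pair above, and prints nothing when both match.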