PattanaikL/GeoMol

Runtime Error when enumerating train_loader during training

qcxia20 opened this issue · 2 comments

Hi! I really appreciate your fantastic work and code. I've set up the project following the guidance in README.md.
However, I get the following error when running the training process with train.py.

Describe the error

Starting training...
  0%|                                                                                                                                                       | 0/625 [00:00<?, ?it/s][11:18:30] Explicit valence for atom # 0 N, 4, is greater than permitted
  0%|                                                                                                                                                       | 0/625 [22:56<?, ?it/s]
Traceback (most recent call last):
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/pubhome/qcxia02/.vscode-server/extensions/ms-python.python-2021.11.1422169775/pythonFiles/lib/python/debugpy/__main__.py", line 45, in <module>
    cli.main()
  File "/pubhome/qcxia02/.vscode-server/extensions/ms-python.python-2021.11.1422169775/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
    run()
  File "/pubhome/qcxia02/.vscode-server/extensions/ms-python.python-2021.11.1422169775/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
    runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/pubhome/qcxia02/git-repo/AI-CONF/GeoMol/train.py", line 74, in <module>
    train_loss = train(model, train_loader, optimizer, device, scheduler, logger if args.verbose else None, epoch, writer)
  File "/pubhome/qcxia02/git-repo/AI-CONF/GeoMol/model/training.py", line 18, in train
    for i, data in tqdm(enumerate(loader), total=len(loader)):
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/loader/dataloader.py", line 39, in __call__
    return self.collate(batch)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/loader/dataloader.py", line 20, in collate
    self.exclude_keys)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/batch.py", line 75, in from_data_list
    exclude_keys=exclude_keys,
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 86, in collate
    increment)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 142, in _collate
    data_list, stores, increment)
  File "/pubhome/qcxia02/miniconda3/envs/GeoMol/lib/python3.7/site-packages/torch_geometric/data/collate.py", line 162, in _collate
    value = torch.cat(values, dim=cat_dim or 0)
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 19 but got size 21 for tensor number 1 in the list.

To Reproduce

`python train.py --data_dir data/QM9/qm9/ --split_path data/QM9/splits/split0.npy --log_dir ./test_run --n_epochs 250 --dataset qm9`

Expected behavior

Training completes without errors.

Environments:

The environment is based on the provided environment.yml file. The versions of torch and related packages are listed below:
- OS: CentOS Linux release 8.4.2105
- Package Version:

  • python=3.7.10
  • pytorch=1.10.0=py3.7_cpu_0
  • torchaudio=0.10.0=py37_cpu
  • torchvision=0.11.1=py37_cpu
  • pytorch-cluster=1.5.9=py37_torch_1.10.0_cpu
  • pytorch-mutex=1.0=cpu
  • pytorch-scatter=2.0.9=py37_torch_1.10.0_cpu
  • pytorch-sparse=0.6.12=py37_torch_1.10.0_cpu
  • pytorch-spline-conv=1.2.1=py37_torch_1.10.0_cpu
  • torch-geometric=2.0.2

Additional context:

This error is raised when the dataloader is enumerated during training, i.e. at `for i, data in tqdm(enumerate(loader), total=len(loader)):`. The `Expected size 19 but got size 21` error arises because `torch.cat` tries to concatenate tensor B (2nd molecule, shape 10x21x3) with tensor A (1st molecule, shape 10x19x3) along dimension 0 (size 10), which requires every other dimension to match, yet dimension 1 differs (19 vs. 21). I'm not sure whether this behavior is expected, or where to make modifications (if any are needed).
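To make the shape rule concrete, here is a minimal sketch in plain Python that mimics the check concatenation performs: all dimensions except the concatenation dimension must agree. The helper `cat_result_shape` is hypothetical and written only for illustration; it is not part of torch, torch-geometric, or GeoMol, and it operates on shape tuples rather than real tensors.

```python
def cat_result_shape(shapes, dim=0):
    """Return the shape that concatenating along `dim` would produce.

    Hypothetical helper for illustration only: it mirrors the rule that
    every dimension other than `dim` must match across all inputs.
    """
    first = shapes[0]
    total = 0
    for idx, shape in enumerate(shapes):
        if len(shape) != len(first):
            raise ValueError(f"tensor number {idx} has a different number of dimensions")
        for axis, (expected, got) in enumerate(zip(first, shape)):
            if axis != dim and expected != got:
                raise ValueError(
                    f"Sizes must match except in dimension {dim}. "
                    f"Expected size {expected} but got size {got} "
                    f"for tensor number {idx} in the list."
                )
        total += shape[dim]
    result = list(first)
    result[dim] = total
    return tuple(result)


# Two molecules with the same second dimension concatenate cleanly.
same_size = cat_result_shape([(10, 19, 3), (10, 19, 3)], dim=0)

# The shapes from this issue (10x19x3 and 10x21x3) fail the check.
try:
    cat_result_shape([(10, 19, 3), (10, 21, 3)], dim=0)
    mismatch_message = None
except ValueError as err:
    mismatch_message = str(err)
```

With the shapes reported above, the second call fails on dimension 1 (19 vs. 21), which is the same condition the traceback reports for tensor number 1.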

Looking forward to your reply :)

Hmm, I'm not immediately sure what the issue is, but I have a suggestion to start with: could you try downgrading torch-geometric to 1.6.3? I think that's the primary difference between my local versions and the versions you have listed.

That solved my issue. Thanks for your reply.
Note that pytorch 1.10.0 has some problems with torch-geometric 1.6.3, so I also downgraded pytorch to 1.7.0.
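For anyone hitting the same error, a quick way to compare an environment against the versions reported to work in this thread is sketched below. The pins come only from the comments above, not from an official compatibility matrix. The helper names (`installed_version`, `find_mismatches`) are hypothetical, and the distribution names `torch` and `torch-geometric` are assumed to be the names under which the packages are registered in your environment.

```python
# Versions reported to work in this thread (an assumption, not an official matrix).
WORKING_PINS = {"torch": "1.7.0", "torch-geometric": "1.6.3"}


def installed_version(dist_name):
    """Return the installed version string for a distribution, or None if absent."""
    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:
        # Python 3.7 has no importlib.metadata; fall back to setuptools.
        import pkg_resources
        try:
            return pkg_resources.get_distribution(dist_name).version
        except pkg_resources.DistributionNotFound:
            return None
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Map each package whose version differs from its pin to (wanted, found)."""
    mismatches = {}
    for name, wanted in pins.items():
        found = lookup(name)
        # Drop local build suffixes such as "+cpu" before comparing.
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(WORKING_PINS).items():
        print(f"{name}: wanted {wanted}, found {found}")
```

The `lookup` parameter exists only so the comparison can be tried against a plain dictionary of versions without touching the real environment.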