Problem with the "Preparing your own data" tutorial
Closed this issue · 5 comments
I am getting the following error at the end of the second part of the first tutorial (https://github.com/atomistic-machine-learning/schnetpack/blob/master/examples/tutorials/tutorial_01_preparing_data.ipynb):
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[13], line 20
4 custom_data = spk.data.AtomsDataModule(
5 './new_dataset.db',
6 batch_size=10,
(...)
17 pin_memory=True, # set to false, when not using a GPU
18 )
19 custom_data.prepare_data()
---> 20 custom_data.setup()
File ~/anaconda3/lib/python3.9/site-packages/schnetpack/data/datamodule.py:198, in AtomsDataModule.setup(self, stage)
196 self._val_dataset = self.dataset.subset(self.val_idx)
197 self._test_dataset = self.dataset.subset(self.test_idx)
--> 198 self._setup_transforms()
File ~/anaconda3/lib/python3.9/site-packages/schnetpack/data/datamodule.py:338, in AtomsDataModule._setup_transforms(self)
336 def _setup_transforms(self):
337 for t in self.train_transforms:
--> 338 t.datamodule(self)
339 for t in self.val_transforms:
340 t.datamodule(self)
File ~/anaconda3/lib/python3.9/site-packages/schnetpack/transform/atomistic.py:126, in RemoveOffsets.datamodule(self, _datamodule)
123 self.atomref = atrefs[self._property].detach()
125 if self.remove_mean and not self._mean_initialized:
--> 126 stats = _datamodule.get_stats(
127 self._property, self.is_extensive, self.remove_atomrefs
128 )
129 self.mean = stats[0].detach()
File ~/anaconda3/lib/python3.9/site-packages/schnetpack/data/datamodule.py:354, in AtomsDataModule.get_stats(self, property, divide_by_atoms, remove_atomref)
351 if key in self._stats:
352 return self._stats[key]
--> 354 stats = calculate_stats(
355 self.train_dataloader(),
356 divide_by_atoms={property: divide_by_atoms},
357 atomref=self.train_dataset.atomrefs if remove_atomref else None,
358 )[property]
359 self._stats[key] = stats
360 return stats
File ~/anaconda3/lib/python3.9/site-packages/schnetpack/data/stats.py:44, in calculate_stats(dataloader, divide_by_atoms, atomref)
41 mean = torch.zeros_like(norm_mask)
42 M2 = torch.zeros_like(norm_mask)
---> 44 for props in tqdm(dataloader):
45 sample_values = []
46 for p in property_names:
File ~/anaconda3/lib/python3.9/site-packages/tqdm/std.py:1182, in tqdm.__iter__(self)
1179 time = self._time
1181 try:
-> 1182 for obj in iterable:
1183 yield obj
1184 # Update and possibly print the progressbar.
1185 # Note: does not call self.update(1) for speed optimisation.
File ~/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:
File ~/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1345, in _MultiProcessingDataLoaderIter._next_data(self)
1343 else:
1344 del self._task_info[idx]
-> 1345 return self._process_data(data)
File ~/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1371, in _MultiProcessingDataLoaderIter._process_data(self, data)
1369 self._try_put_index()
1370 if isinstance(data, ExceptionWrapper):
-> 1371 data.reraise()
1372 return data
File ~/anaconda3/lib/python3.9/site-packages/torch/_utils.py:694, in ExceptionWrapper.reraise(self)
690 except TypeError:
691 # If the exception takes multiple arguments, don't try to
692 # instantiate since we don't know how to
693 raise RuntimeError(msg) from None
--> 694 raise exception
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/investigator/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/home/investigator/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/investigator/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/investigator/anaconda3/lib/python3.9/site-packages/schnetpack/data/atoms.py", line 269, in __getitem__
props = self._get_properties(
File "/home/investigator/anaconda3/lib/python3.9/site-packages/schnetpack/data/atoms.py", line 339, in _get_properties
row = conn.get(idx + 1)
File "/home/investigator/anaconda3/lib/python3.9/site-packages/ase/db/core.py", line 432, in get
raise KeyError('no match')
KeyError: 'no match'
The notebook has not been altered, the SchNetPack version used corresponds to the latest one in the repository, and the ASE version is 3.22.1.
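For context, the cell that fails is abbreviated in the traceback (`(...)`). In the tutorial it looks roughly like this (a sketch reconstructed from the notebook, so individual argument values may differ):

```python
import schnetpack as spk
import schnetpack.transform as trn

custom_data = spk.data.AtomsDataModule(
    './new_dataset.db',
    batch_size=10,
    distance_unit='Ang',
    property_units={'energy': 'kcal/mol', 'forces': 'kcal/mol/Ang'},
    num_train=1000,
    num_val=100,
    transforms=[
        trn.ASENeighborList(cutoff=5.0),
        # RemoveOffsets triggers the statistics calculation seen in the traceback
        trn.RemoveOffsets('energy', remove_mean=True, remove_atomrefs=False),
        trn.CastTo32(),
    ],
    num_workers=2,
    pin_memory=True,  # set to False when not using a GPU
)
custom_data.prepare_data()
custom_data.setup()
```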
Dear @mariofp77 ,
I cannot reproduce your error. Could you please try running the notebook again after deleting all remaining files associated with the custom dataset in the tutorials directory?
md17_uracil.npz
new_dataset.db
split.npz
splitting.lock
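For example, from the tutorials directory (a minimal sketch that just removes the files listed above, if present):

```python
import os

# remove leftover artifacts from a previous run of the tutorial
for f in ['md17_uracil.npz', 'new_dataset.db', 'split.npz', 'splitting.lock']:
    if os.path.exists(f):
        os.remove(f)
```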
Best,
Jonas
Dear @jnsLs,
Many thanks for your response.
I have tried removing the files you mention and rerunning the notebook, but I get the same error.
Best,
Mario
I think I found the cause of the error.
`custom_data = spk.data.AtomsDataModule(...)` loads the split file (`split.npz`) that was created for the QM9 dataset in one of the previous cells.
The QM9 split file contains 110k training indices, but the uracil dataset does not contain that many structures. Hence, for one of the larger indices, the database lookup fails with `KeyError: 'no match'`.
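The underlying failure can be reproduced in isolation: asking an ASE database for a row id it does not contain raises exactly the `KeyError: 'no match'` from the traceback. A minimal sketch (with a hypothetical `tiny.db`):

```python
from ase import Atoms
from ase.db import connect

# write a single structure to a fresh database
conn = connect('tiny.db')
conn.write(Atoms('H2'))

conn.get(1)  # ok: a row with id 1 exists
conn.get(2)  # KeyError: 'no match', as in the DataLoader worker
```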
Could you please try to verify this by running the notebook again and deleting the split file before running the very last cell?
Yes, after doing that, it works.
Many thanks!
Thank you, Mario.
We will adapt the tutorial soon to avoid this error.
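In the meantime, the clash can also be avoided by giving the custom dataset its own split file. If I recall the `AtomsDataModule` signature correctly, the `split_file` argument defaults to `split.npz`, which is why the QM9 split is picked up. A sketch:

```python
custom_data = spk.data.AtomsDataModule(
    './new_dataset.db',
    batch_size=10,
    split_file='./custom_split.npz',  # separate split file, so the QM9 split.npz is not reused
    # ... remaining arguments as in the tutorial ...
)
```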
Best, Jonas