nttcslab-sp/mamba-diarization

pytorch_lightning.Trainer.fit() error

Closed this issue · 3 comments

When running the 1_training_from_scratch script, trainer.validate(mamba_diar) completes without any problem. However, trainer.fit(mamba_diar) fails with the error below: the task's prepared_data has no "metadata" key when evaluating *[self.prepared_data["metadata"][key] for key in balance]. Have you ever encountered a similar problem? Do you have any suggestions for solving it?


{
"name": "KeyError",
"message": "Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
data.append(next(self.dataset_iter))
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in train__iter__

*[self.prepared_data["metadata"][key] for key in balance]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in <listcomp>
*[self.prepared_data["metadata"][key] for key in balance]
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'metadata'
",
"stack": "---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[17], line 1
----> 1 trainer.fit(mamba_diar)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
536 self.state.status = TrainerStatus.RUNNING
537 self.training = True
--> 538 call._call_and_handle_interrupt(
539 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
540 )

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
45 if trainer.strategy.launcher is not None:
46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47 return trainer_fn(*args, **kwargs)
49 except _TunerExitException:
50 _call_teardown_hook(trainer)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
567 assert self.state.fn is not None
568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
569 self.state.fn,
570 ckpt_path,
571 model_provided=True,
572 model_connected=self.lightning_module is not None,
573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
576 assert self.state.stopped
577 self.training = False

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:981, in Trainer._run(self, model, ckpt_path)
976 self._signal_connector.register_signal_handlers()
978 # ----------------------------
979 # RUN THE TRAINER
980 # ----------------------------
--> 981 results = self._run_stage()
983 # ----------------------------
984 # POST-Training CLEAN UP
985 # ----------------------------
986 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:1025, in Trainer._run_stage(self)
1023 self._run_sanity_check()
1024 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1025 self.fit_loop.run()
1026 return None
1027 raise RuntimeError(f"Unexpected state {self.state}")

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
203 try:
204 self.on_advance_start()
--> 205 self.advance()
206 self.on_advance_end()
207 self._restarting = False

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
361 with self.trainer.profiler.profile("run_training_epoch"):
362 assert self._data_fetcher is not None
--> 363 self.epoch_loop.run(self._data_fetcher)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140, in _TrainingEpochLoop.run(self, data_fetcher)
138 while not self.done:
139 try:
--> 140 self.advance(data_fetcher)
141 self.on_advance_end(data_fetcher)
142 self._restarting = False

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:212, in _TrainingEpochLoop.advance(self, data_fetcher)
210 else:
211 dataloader_iter = None
--> 212 batch, _, __ = next(data_fetcher)
213 # TODO: we should instead use the batch_idx returned by the fetcher, however, that will require saving the
214 # fetcher state so that the batch_idx is correct after restarting
215 batch_idx = self.batch_idx + 1

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fetchers.py:133, in _PrefetchDataFetcher.__next__(self)
130 self.done = not self.batches
131 elif not self.done:
132 # this will run only when no pre-fetching was done.
--> 133 batch = super().__next__()
134 else:
135 # the iterator is empty
136 raise StopIteration

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/loops/fetchers.py:60, in _DataFetcher.__next__(self)
58 self._start_profiler()
59 try:
---> 60 batch = next(self.iterator)
61 except StopIteration:
62 self.done = True

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/utilities/combined_loader.py:341, in CombinedLoader.__next__(self)
339 def __next__(self) -> _ITERATOR_RETURN:
340 assert self._iterator is not None
--> 341 out = next(self._iterator)
342 if isinstance(self._iterator, _Sequential):
343 return out

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pytorch_lightning/utilities/combined_loader.py:78, in _MaxSizeCycle.__next__(self)
76 for i in range(n):
77 try:
---> 78 out[i] = next(self.iterators[i])
79 except StopIteration:
80 self._consumed[i] = True

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(pytorch/pytorch#76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1344, in _MultiProcessingDataLoaderIter._next_data(self)
1342 else:
1343 del self._task_info[idx]
-> 1344 return self._process_data(data)

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/dataloader.py:1370, in _MultiProcessingDataLoaderIter._process_data(self, data)
1368 self._try_put_index()
1369 if isinstance(data, ExceptionWrapper):
-> 1370 data.reraise()
1371 return data

File /home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/_utils.py:706, in ExceptionWrapper.reraise(self)
702 except TypeError:
703 # If the exception takes multiple arguments, don't try to
704 # instantiate since we don't know how to
705 raise RuntimeError(msg) from None
--> 706 raise exception

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
data.append(next(self.dataset_iter))
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in train__iter__

*[self.prepared_data["metadata"][key] for key in balance]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/data/xxx/anaconda3/envs/mamba/lib/python3.11/site-packages/pyannote/audio/tasks/segmentation/mixins.py", line 187, in <listcomp>
*[self.prepared_data["metadata"][key] for key in balance]
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'metadata'
"
}

Thanks for the issue and the detailed logs! I pushed a PR that should fix this, provided you can install pyannote from git.

You can also sidestep the problem by not using balance (do not pass balance to the task, or pass None); this may be fine for your use case. The balance parameter makes training samples be drawn uniformly according to a criterion (in the case of this paper, uniformly from each dataset).
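To illustrate what balancing means here, below is a minimal, self-contained sketch. It is not pyannote's implementation: the function name balanced_file_sampler, the file dictionaries, and the corpus names BigCorpus and SmallCorpus are all hypothetical. It only shows the idea of first picking a group (here, a database) uniformly and then picking a file inside that group, so a small dataset is seen as often as a large one.

```python
import random
from collections import defaultdict


def balanced_file_sampler(files, balance, rng):
    """Yield files so that every combination of `balance` values is equally likely.

    Hypothetical helper for illustration only, not pyannote's code.
    """
    # Group files by the values of the keys listed in `balance`.
    groups = defaultdict(list)
    for f in files:
        groups[tuple(f[key] for key in balance)].append(f)
    group_keys = sorted(groups)
    while True:
        # Pick a group uniformly, then a file uniformly within that group.
        group = rng.choice(group_keys)
        yield rng.choice(groups[group])


# Made-up, imbalanced collection: 90 files from one corpus, 10 from another.
files = (
    [{"uri": f"big_{i}", "database": "BigCorpus"} for i in range(90)]
    + [{"uri": f"small_{i}", "database": "SmallCorpus"} for i in range(10)]
)

rng = random.Random(0)
sampler = balanced_file_sampler(files, balance=["database"], rng=rng)

counts = defaultdict(int)
for _ in range(10_000):
    counts[next(sampler)["database"]] += 1

# Both corpora end up drawn roughly equally often despite the 90/10 split.
print(dict(counts))
```

Without balancing (sampling files directly), the larger corpus would be drawn about nine times as often in this toy setup.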

Actually, I realized there have been some breaking changes in the latest pyannote versions, so I'm not 100% confident this PR will work right away, especially with this repository (and I don't have the time to really test things for now). But if you apply the commit from this PR to pyannote 3.1 (git cherry-pick?), it should work. Or again, if you are OK with not using balance, that is probably the easiest option.

Sorry for the inconvenience!

Thank you for taking the time to reply amidst your busy schedule.

I have tried both of your suggestions. First, removing balance=['database'] from the task worked: after removing it, training ran successfully!
Your pull request (PR) was also effective. I didn't know how to apply it quickly, so I copied the corresponding changes to task.py and mixins.py into site-packages/pyannote/audio in my conda environment (under anaconda3/envs). After running it again, training also worked normally!
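For anyone patching files by hand in the same way, a small sketch like the one below can confirm which file Python actually imports, so the edit lands in the right environment. The helper name locate_module_file is hypothetical; the module name is taken from the path in the traceback above (the location of task.py is not stated in this thread, so it is not listed).

```python
import importlib.util


def locate_module_file(module_name):
    """Return the file a module would be imported from, or None if unavailable."""
    try:
        # For dotted names, find_spec imports the parent packages first,
        # which raises if they are missing or fail to import.
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        return None
    return spec.origin if spec is not None else None


# Module path taken from the traceback; prints None if pyannote is not installed.
name = "pyannote.audio.tasks.segmentation.mixins"
print(name, "->", locate_module_file(name))
```

Note that files edited directly in site-packages are overwritten the next time the package is reinstalled or upgraded.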

Thank you very much for your reply! Best regards!