open-mmlab/mmengine

[Feature] Speed up the resume process of IterBased loop

Opened this issue · 2 comments

What is the feature?

next(self.dataloader_iterator)

现有的恢复方式会对dataloader 迭代 n 个step,当n较大时,速度会很慢,因为执行了实际的数据加载和处理逻辑。 是否有比较好的方式只迭代index,不执行实际的数据加载流程。

  1. 一种可能的方式是和用户约定一个返回虚拟数据的数据集接口,在恢复时返回虚拟数据,
class Dataset:

    def __getitem__(self, index):
        if self._skip_flag:
            return # Fake data
        # 处理数据
        return Real data

    def skip(self):
        self._skip_flag = True

    def resume(self):
        self._skip_flag = False



# loop中的处理逻辑
            if (
                hasattr(self.dataloader.dataset, "skip")
                and callable(self.dataloader.dataset.skip)
                and hasattr(self.dataloader.dataset, "resume")
                and callable(self.dataloader.dataset.resume)
            ):
                self.dataloader.dataset.skip()
                for _ in range(self._iter):
                    next(self.dataloader_iterator)
                self.dataloader.dataset.resume()
            else:
                for _ in range(self._iter):
                    next(self.dataloader_iterator)
  1. 方式一还是需要用户进行配合,是否可以对dataloader进行操作从而无感知的快速跳过?
                iter_batch_sampler = iter(self.dataloader.batch_sampler)
                for _ in range(self._iter):
                    next(iter_batch_sampler)

尝试直接迭代batch_sampler 在worker=0的时候是正常的,在多worker的时候恢复数据顺序出现错误。 像知道有没有什么比较好的解决方案

Any other context?

https://discuss.pytorch.org/t/is-there-any-way-to-skip-steps-in-a-dataloader/123201
https://pytorch.org/data/main/dataloader2.html

Snapshot the state of data-preprocessing pipeline (WIP)

一个最小改动的方案是在迭代前 mock dataset 的__getitem__方法:

    def run(self) -> None:
        """Launch training."""
        self.runner.call_hook('before_train')
        # In iteration-based training loop, we treat the whole training process
        # as a big epoch and execute the corresponding hook.
        self.runner.call_hook('before_train_epoch')
        if self._iter > 0:
            print_log(
                f'Advance dataloader {self._iter} steps to skip data '
                'that has already been trained',
                logger='current',
                level=logging.WARNING)
            # mock
            old_getitem = self.dataloader_iterator.dataset.__getitem__
            self.dataloader_iterator.dataset.__getitem__ = a_new_getitem_method
            for _ in range(self._iter):
                next(self.dataloader_iterator)
            self.dataloader_iterator.dataset.__getitem__ = old_getitem
chtzs commented

I believe this PR is the cause of the issue: #1471.
While it fixed the resume iteration problem, it also led to slow resume speed. A suitable solution would be to call the _next_index() method of the DataLoader's built-in iterator to skip a batch without reading the data.