[Feature] Speed up the resume process of IterBased loop
Opened this issue · 2 comments
YinAoXiong commented
What is the feature?
mmengine/mmengine/runner/loops.py
Line 281 in 2c4516c
现有的恢复方式会对dataloader 迭代 n 个step,当n较大时,速度会很慢,因为执行了实际的数据加载和处理逻辑。 是否有比较好的方式只迭代index,不执行实际的数据加载流程。
- 一种可能的方式是和用户约定一个返回虚拟数据的数据集接口,在恢复时返回虚拟数据,
class Dataset:
def __getitem__(self, index):
if self._skip_flag:
return # Fake data
# 处理数据
return Real data
def skip(self):
self._skip_flag = True
def resume(self):
self._skip_flag = False
# loop中的处理逻辑
if (
hasattr(self.dataloader.dataset, "skip")
and callable(self.dataloader.dataset.skip)
and hasattr(self.dataloader.dataset, "resume")
and callable(self.dataloader.dataset.resume)
):
self.dataloader.dataset.skip()
for _ in range(self._iter):
next(self.dataloader_iterator)
self.dataloader.dataset.resume()
else:
for _ in range(self._iter):
next(self.dataloader_iterator)
- 方式一还是需要用户进行配合,是否可以对dataloader进行操作从而无感知的快速跳过?
iter_batch_sampler = iter(self.dataloader.batch_sampler)
for _ in range(self._iter):
next(iter_batch_sampler)
尝试直接迭代batch_sampler 在worker=0的时候是正常的,在多worker的时候恢复数据顺序出现错误。 像知道有没有什么比较好的解决方案
Any other context?
https://discuss.pytorch.org/t/is-there-any-way-to-skip-steps-in-a-dataloader/123201
https://pytorch.org/data/main/dataloader2.html
Snapshot the state of data-preprocessing pipeline (WIP)
zhouzaida commented
一个最小改动的方案是在迭代前 mock dataset 的__getitem__
方法:
def run(self) -> None:
"""Launch training."""
self.runner.call_hook('before_train')
# In iteration-based training loop, we treat the whole training process
# as a big epoch and execute the corresponding hook.
self.runner.call_hook('before_train_epoch')
if self._iter > 0:
print_log(
f'Advance dataloader {self._iter} steps to skip data '
'that has already been trained',
logger='current',
level=logging.WARNING)
# mock
old_getitem = self.dataloader_iterator.dataset.__getitem__
self.dataloader_iterator.dataset.__getitem__ = a_new_getitem_method
for _ in range(self._iter):
next(self.dataloader_iterator)
self.dataloader_iterator.dataset.__getitem__ = old_getitem