After some iterations when pretraining an LLM, IndexError is raised related to dataset chunking
MusulmonLolayev opened this issue · 0 comments
MusulmonLolayev commented
After some iterations, the pretraining script suddenly raised an IndexError when resuming pretraining from a checkpoint. Here are some logs:
Epoch 1 | iter 82002 step 41001 | loss train: 1.772, val: n/a | iter time: 243.87 ms (step) remaining time: 395762 days, 18:00:45
Traceback (most recent call last):
File "/home/user/miniconda3/envs/llama2/bin/litgpt", line 8, in <module>
sys.exit(main())
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/__main__.py", line 143, in main
fn(**kwargs)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/pretrain.py", line 121, in setup
main(
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/pretrain.py", line 213, in main
fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/pretrain.py", line 265, in fit
for train_data in train_iterator:
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/utils.py", line 382, in __next__
return next(self._iterator)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/lightning/fabric/wrappers.py", line 315, in __iter__
for item in self._dataloader:
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litdata/streaming/dataloader.py", line 598, in __iter__
for batch in super().__iter__():
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
return self._process_data(data)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
data.reraise()
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/_utils.py", line 722, in reraise
raise exception
IndexError: Caught IndexError in DataLoader worker process 3.
Original Traceback (most recent call last):
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
self.dataset_iter = iter(dataset)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litdata/streaming/dataset.py", line 187, in __iter__
self._resume(chunks_replica, intervals_replica)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litdata/streaming/dataset.py", line 246, in _resume
interval = self.worker_intervals[self.chunk_index]
IndexError: list index out of range
Printing `print(self.worker_intervals, self.chunk_index)` at the failing line shows:

[[251476, 269466], [179572, 197595], [35858, 53769], [53769, 71771]] 4

So `self.worker_intervals` contains only 4 items (valid indices 0–3), while `self.chunk_index` is 4, which is out of range. There is a TODO in litdata saying "Implement elastic sampling where the number of workers, ranks can change.", so it seems the resume path does not yet handle a changed worker/rank layout, which is why this error is raised.
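For illustration, here is a minimal, hypothetical sketch (not litdata's actual resume logic) of how the mismatch reported above produces the exception: a `chunk_index` restored from a checkpoint can exceed the length of the interval list computed for the current worker layout, and the plain list indexing on the failing line then raises.

```python
# Hypothetical simplified sketch: a restored chunk_index that was valid under
# the old worker layout can exceed the per-worker interval list after resuming
# with a different number of workers/ranks.

def resume_interval(worker_intervals, chunk_index):
    # Mirrors the failing line in litdata/streaming/dataset.py (_resume):
    #   interval = self.worker_intervals[self.chunk_index]
    return worker_intervals[chunk_index]

# Values taken from the log above: 4 intervals, restored chunk_index == 4.
worker_intervals = [[251476, 269466], [179572, 197595], [35858, 53769], [53769, 71771]]
chunk_index = 4

try:
    resume_interval(worker_intervals, chunk_index)
except IndexError as e:
    print(f"IndexError: {e}")  # prints "IndexError: list index out of range"
```

The sketch only demonstrates the indexing failure itself; the underlying cause (how `chunk_index` ends up one past the end after a resume) is what the elastic-sampling TODO would need to address.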