Lightning-AI/lit-llama

Running into StopIteration with single-node multi-GPU pretraining against the RedPajama sample

cabal-daniel opened this issue · 6 comments

I'm running the pre-training script against the RedPajama sample with 4 A100 80GB GPUs on a single node. Per the advice given in issue #301, I reduced `max_iters` to 1, but the error is still there. Any ideas?

(root) root@C.7378502:~/lit-llama$ python pretrain/redpajama.py --devices 4 --train_data_dir data/lit-redpajama-sample
/root/lit-llama/pretrain/redpajama.py:320: JsonargparseDeprecationWarning: 
    Only use the public API as described in https://jsonargparse.readthedocs.io/en/stable/#api-reference.
    Importing from jsonargparse.cli is kept only to avoid breaking code that does not correctly use the public
    API. It will no longer be available from v5.0.0.

  from jsonargparse.cli import CLI
/root/lib/python3.10/site-packages/lightning/fabric/strategies/fsdp.py:699: `FSDPStrategy(activation_checkpointing=<class 'lit_llama.model.Block'>)` is deprecated, use `FSDPStrategy(activation_checkpointing_policy={<class 'lit_llama.model.Block'>})` instead.
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
[the same JsonargparseDeprecationWarning is printed by each of the remaining three ranks]
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

[rank: 0] Seed set to 1337
[rank: 3] Seed set to 1337
[rank: 1] Seed set to 1337
[rank: 2] Seed set to 1337
Traceback (most recent call last):
  File "/root/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 274, in __iter__
    for item in self._dataloader:
  File "/root/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 438, in __iter__
    return self._get_iterator()
  File "/root/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 383, in _get_iterator
    return _SingleProcessDataLoaderIter(self)
  File "/root/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 669, in __init__
    self._dataset_fetcher = _DatasetKind.create_fetcher(
  File "/root/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
    return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
  File "/root/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
    self.dataset_iter = iter(dataset)
  File "/root/lit-llama/lit_llama/packed_dataset.py", line 249, in __iter__
    return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
  File "/root/lit-llama/lit_llama/packed_dataset.py", line 254, in __init__
    self._datasets = [iter(el) for el in datasets]
  File "/root/lit-llama/lit_llama/packed_dataset.py", line 254, in <listcomp>
    self._datasets = [iter(el) for el in datasets]
  File "/root/lit-llama/lit_llama/packed_dataset.py", line 58, in __iter__
    return PackedDatasetIterator(
  File "/root/lit-llama/lit_llama/packed_dataset.py", line 164, in __init__
    self._load_n_chunks()
  File "/root/lit-llama/lit_llama/packed_dataset.py", line 188, in _load_n_chunks
    raise StopIteration
StopIteration

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/lit-llama/pretrain/redpajama.py", line 322, in <module>
    CLI(main)
  File "/root/lib/python3.10/site-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
  File "/root/lib/python3.10/site-packages/jsonargparse/_cli.py", line 181, in _run_component
    return component(**cfg)
  File "/root/lit-llama/pretrain/redpajama.py", line 122, in main
    train(fabric, model, optimizer, train_dataloader, val_dataloader, gradient_accumulation_iters, devices)
  File "/root/lit-llama/pretrain/redpajama.py", line 146, in train
    for iter_num, train_data in enumerate(train_dataloader):
RuntimeError: generator raised StopIteration
[the identical StopIteration traceback and RuntimeError are repeated by each of the remaining three ranks]

I also tried setting `micro_batch_size` to 1.

Does it work on a single GPU? In my experience, the `RuntimeError: generator raised StopIteration` error usually meant I had passed the wrong data folder.
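
A quick way to rule that out: this hypothetical snippet just lists the packed `.bin` shards that the loader would pick up in the training data directory (the path is the one from the command above; adjust as needed):

```python
from pathlib import Path

# Hypothetical sanity check: list the packed .bin shards in the training
# data directory. If this glob comes back empty or with very few files,
# the PackedDataset loader has nothing (or too little) to iterate over.
data_dir = Path("data/lit-redpajama-sample")  # path from the command above
bin_files = sorted(data_dir.glob("*.bin"))
print(f"{len(bin_files)} .bin files found in {data_dir}")
for f in bin_files:
    print(" ", f.name)
```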

Yeah, actually I found that the fix when running against the sample was to use only the Common Crawl dataset. I was passing in the right folder. Closing the issue...
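
For anyone trying the same workaround, a minimal sketch of building a Common-Crawl-only copy of the sample folder to pass as `--train_data_dir`. The `common_crawl` prefix is an assumption; check the actual shard names in your prepared sample and substitute the right one:

```python
import shutil
from pathlib import Path

# Hypothetical workaround sketch: copy only the Common Crawl shards into a
# separate folder, then pass that folder as --train_data_dir. The
# "common_crawl" prefix is an assumption about the shard filenames.
src = Path("data/lit-redpajama-sample")
dst = Path("data/lit-redpajama-sample-cc")
dst.mkdir(parents=True, exist_ok=True)
for f in src.glob("common_crawl*.bin"):
    shutil.copy2(f, dst / f.name)
print(f"copied {len(list(dst.glob('*.bin')))} shards to {dst}")
```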

How did we end up resolving this? @cabal-daniel @rasbt

Hi, I ran into the same problem with the RedPajama-sample dataset. Could you please tell me how you solved it? @cabal-daniel

Hi, if you look at the code in `lit_llama/packed_dataset.py`, you will notice that the sample dataset only has 12 bin files. If you use 4 devices (the default), each device is assigned only 3 bin files. The check `if self._n_chunks > len(self._filenames[self._file_idx:]):` in `_load_n_chunks` then raises `StopIteration`, because 4 > 3 with the default settings. If you set the number of devices to 2, there is no problem.
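
To make the arithmetic concrete, here is a minimal sketch of that check (not the real `PackedDataset` code, just the condition it implements, assuming `n_chunks=4` to match the 4 > 3 comparison above): the shards are split across ranks with a stride, and each rank's iterator tries to keep `n_chunks` files loaded, so it raises `StopIteration` when fewer remain.

```python
# Minimal sketch of the failing check, assuming 12 shards, 4 ranks, and
# n_chunks=4 -- not the real PackedDataset code, just its arithmetic.
total_files = 12   # .bin files in the sample dataset
world_size = 4     # --devices 4
n_chunks = 4       # files each rank's iterator loads at once

filenames = [f"shard_{i:02d}.bin" for i in range(total_files)]
for rank in range(world_size):
    rank_files = filenames[rank::world_size]  # each rank sees a strided slice
    if n_chunks > len(rank_files):
        print(f"rank {rank}: {len(rank_files)} files < n_chunks={n_chunks} -> StopIteration")
    else:
        print(f"rank {rank}: {len(rank_files)} files, OK")
```

With `--devices 2`, each rank gets 6 shards and the check passes; presumably lowering `n_chunks` or preparing more shards would also avoid the error.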