SFT may crash if input data exceeds the context length
odelalleau opened this issue · 1 comment
odelalleau commented
Describe the bug
One can run into a crash like the following, e.g. when running a model with a 512-token context length:

```
  File "/opt/NeMo/nemo/collections/nlp/modules/common/megatron/language_model.py", line 352, in forward
    embeddings = words_embeddings + position_embeddings
RuntimeError: CUDA error: device-side assert triggered
```
This may or may not be a bug, but we should ideally:
- have a clearer error message
- mention it in the tutorial and/or update the Dolly15k conversion script to filter out overly long samples
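For the second point, such a filter could be a small pre-processing pass over the dataset before training. The sketch below is illustrative only (the function name and the `input`/`output` fields are assumptions, not the NeMo API); `tokenize` stands for whatever callable maps text to token ids for your model (e.g. the model tokenizer's encode method):

```python
def filter_long_samples(samples, tokenize, max_seq_len):
    """Drop samples whose tokenized input + output would exceed max_seq_len.

    `samples` is a list of dicts with "input" and "output" text fields
    (illustrative schema); `tokenize` is any callable text -> list of ids.
    """
    kept = []
    for sample in samples:
        n_tokens = len(tokenize(sample["input"])) + len(tokenize(sample["output"]))
        if n_tokens <= max_seq_len:
            kept.append(sample)
    return kept


# Toy whitespace tokenizer as a stand-in for a real BPE/sentencepiece tokenizer.
toy_tokenize = lambda text: text.split()

samples = [
    {"input": "short prompt", "output": "short answer"},
    {"input": "tok " * 600, "output": "too long for a 512-token context"},
]
filtered = filter_long_samples(samples, toy_tokenize, max_seq_len=512)
```

Running the same pass with the real tokenizer during dataset conversion would avoid the device-side assert at training time.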
ryxli commented
Encountering a possibly related issue:

```
ValueError: Caught ValueError in DataLoader worker process 0.
  File "nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py", line 459, in collate_fn
    contexts = torch.LongTensor(self._collate_item(contexts, max_length=max_length, pad_id=self.tokenizer.eos_id))
ValueError: expected sequence of length 4096 at dim 1 (got 4870)
```
This seems to be caused by customizing `prompt_template` and `truncation_fields`: if the truncation fields are not long enough to absorb the truncation, the `context_ids` can end up longer than `max_seq_len`, and the Dataset then hard-fails in `collate_fn`:

```
contexts = torch.LongTensor(self._collate_item(contexts, max_length=max_length, pad_id=self.tokenizer.eos_id))
```