LUMIA-Group/rasat

Indexing error if max_train_samples is not 7000

tonyzhao6 opened this issue

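If `max_train_samples` is not 7000, the line below breaks. With fewer than 7000 training samples, `dataset[i]` raises an `IndexError` as soon as `i` reaches `len(dataset)`. With more than 7000, every example past index 6999 would be silently skipped.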

It seems the code should always take the `else` branch instead, since that branch loops over exactly `len(dataset)` examples.

```python
# Hard-coded length: raises IndexError whenever len(dataset) < 7000
train_input_ids = [dataset[i]['input_ids'] for i in range(7000)]
```
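A minimal sketch of the suggested fix, assuming `dataset` supports `len()` and integer indexing (as a Hugging Face `Dataset` does). The helper name `collect_input_ids` and the toy list standing in for a truncated training set are hypothetical, for illustration only; this is not the repository's actual code. It reproduces the `IndexError` from the hard-coded `range(7000)` and shows that bounding the loop by `len(dataset)` works for any sample count.

```python
from typing import Any, Dict, List, Sequence


def collect_input_ids(dataset: Sequence[Dict[str, Any]]) -> List[Any]:
    """Hypothetical helper: gather input_ids from every example in dataset."""
    # Bound the loop by the real dataset length instead of a hard-coded 7000,
    # so it works for any max_train_samples value.
    return [dataset[i]["input_ids"] for i in range(len(dataset))]


# Hypothetical toy dataset standing in for a training set truncated to
# 3 examples (as if max_train_samples=3).
toy_dataset = [{"input_ids": [101, j, 102]} for j in range(3)]

try:
    # The original hard-coded loop fails on anything shorter than 7000 rows.
    [toy_dataset[i]["input_ids"] for i in range(7000)]
except IndexError as err:
    print("hard-coded range:", type(err).__name__)

print(collect_input_ids(toy_dataset))
```

Iterating directly (`[ex["input_ids"] for ex in dataset]`) would be equivalent; the `range(len(dataset))` form is kept here only to mirror the original line.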