syncdoth/RetNet

gradient_checkpointing=True issue in TrainingArguments

lolshuo opened this issue · 1 comment

I'm using the RetNet base config with the following TrainingArguments:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="/content/retnet-xsum",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    eval_steps=370,
    logging_steps=370,
    num_train_epochs=10,
    weight_decay=0.01,
    warmup_steps=10,
    lr_scheduler_type="cosine",
    learning_rate=6e-4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    # torch_compile=True,
    # checkpointing:
    save_steps=370,
    optim="adafactor",
    # optim="adamw_torch",
    fp16=True,
    # push_to_hub=True,
)
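For reference, this is roughly how the arguments are wired into the Trainer before calling trainer.train(); the model and dataset variables below are placeholders for my RetNet model and the tokenized XSum splits, not the exact setup:

from transformers import Trainer

# Placeholders: `model` is the RetNet base model, `train_dataset` /
# `eval_dataset` are tokenized XSum splits.
trainer = Trainer(
    model=model,
    args=args,                    # the TrainingArguments defined above
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()  # fails when gradient_checkpointing=True is set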

But I'm getting this error when running trainer.train():

[screenshot of the error traceback]

A recent commit should solve this issue.
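For anyone hitting this on an older checkout, a hedged sketch of the usual requirement: the Trainer's gradient_checkpointing=True flag works only if the model class opts in. This assumes the RetNet model subclasses transformers.PreTrainedModel; the attribute and method below are the standard transformers ones, not anything specific to this repo:

# Trainer calls model.gradient_checkpointing_enable() when
# gradient_checkpointing=True; that requires the model class to declare
# supports_gradient_checkpointing = True (standard transformers convention).
assert getattr(model, "supports_gradient_checkpointing", False), (
    "Model does not support gradient checkpointing; update to the latest "
    "commit or remove gradient_checkpointing from TrainingArguments."
)

# Alternative workaround: enable it on the model directly and drop the
# flag from TrainingArguments.
model.gradient_checkpointing_enable()
model.config.use_cache = False  # caching is incompatible with checkpointing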