gradient_checkpointing=True issue in TrainingArguments
lolshuo opened this issue · 1 comment
lolshuo commented
I'm using the Retnet base config with the following TrainingArguments:
```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="/content/retnet-xsum",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    eval_steps=370,
    logging_steps=370,
    num_train_epochs=10,
    weight_decay=0.01,
    warmup_steps=10,
    lr_scheduler_type="cosine",
    learning_rate=6e-4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    dataloader_pin_memory=True,
    dataloader_num_workers=4,
    # torch_compile=True,
    # checkpointing:
    save_steps=370,
    optim="adafactor",
    # optim="adamw_torch",
    fp16=True,
    # push_to_hub=True,
)
```
But I'm getting an error when running trainer.train().
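For context, here is a minimal sketch of how these arguments would be wired into the Trainer; the `model`, `tokenized_train`, and `tokenized_eval` names are placeholders standing in for the RetNet model and the XSum splits, not objects from the original report:

```python
from transformers import Trainer

# Sketch only: `model`, `tokenized_train`, and `tokenized_eval` are assumed to
# have been created earlier (RetNet model + tokenized XSum train/validation sets).
trainer = Trainer(
    model=model,
    args=args,                      # the TrainingArguments shown above
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
)
trainer.train()                     # the call that triggers the error
```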
syncdoth commented
A recent commit should solve this issue.
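For anyone hitting this before pulling that commit, a common workaround with Hugging Face models is to enable checkpointing on the model itself rather than through TrainingArguments. This is only a sketch, assuming the RetNet model subclasses transformers.PreTrainedModel and therefore exposes gradient_checkpointing_enable(); it is not the fix from the referenced commit:

```python
# Workaround sketch, assuming a standard PreTrainedModel interface.
# If the RetNet implementation does not expose gradient_checkpointing_enable(),
# this will raise AttributeError.
model.config.use_cache = False         # checkpointing and the KV cache conflict
model.gradient_checkpointing_enable()  # enable checkpointing on the model directly
```

With checkpointing enabled on the model, `gradient_checkpointing=True` can be dropped from TrainingArguments.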