ThilinaRajapakse/simpletransformers

save_steps not working, checkpoints generated more than once per epoch

guruprasaad123 opened this issue · 2 comments

Describe the bug
We are trying not to save the model's checkpoints too frequently, because we are running low on storage. Our idea is to save a checkpoint only every 4 or 5 epochs. Kindly check the code below, which is meant to accomplish this.

import math
import os
from multiprocessing import cpu_count

from simpletransformers.classification import ClassificationArgs, ClassificationModel

# train_df, test_df, output_dir and num_labels are defined earlier in our script
SAVE_EVERY_N_EPOCHS = 4
train_batch_size = 8
steps_per_epoch = math.floor(len(train_df) / train_batch_size)
save_steps = steps_per_epoch * SAVE_EVERY_N_EPOCHS

print(f'save_steps : {save_steps}')

args = {
    "output_dir": os.path.join(output_dir,"outputs"),
    "cache_dir": os.path.join(output_dir, "cache_dir"),
    "fp16": True,
    "train_batch_size": train_batch_size,
    "num_train_epochs": 16,

    "save_steps": save_steps,
    "save_model_every_epoch": False,
    "overwrite_output_dir": True,
    "reprocess_input_data": False,
    "evaluate_during_training": True,
    "evaluate_during_training_verbose":True,

    "process_count": cpu_count() - 2 if cpu_count() > 2 else 1,
    "n_gpu": 1,
}

# Optional model configuration
model_args = ClassificationArgs(**args)

# Create a ClassificationModel
model = ClassificationModel(
    'bert',
    'bert-base-uncased',
    num_labels=num_labels,
    args=model_args
) 

# Train the model
model.train_model(train_df, eval_df=test_df)
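
For concreteness, with the dataset size quoted later in this thread (len(train_df) of about 40500, per the follow-up comment), the computation above works out as follows; this is a quick sanity check of the arithmetic, not output from the library:

import math

len_train_df = 40500  # figure taken from the follow-up comment below
train_batch_size = 8
SAVE_EVERY_N_EPOCHS = 4

steps_per_epoch = math.floor(len_train_df / train_batch_size)  # 5062
save_steps = steps_per_epoch * SAVE_EVERY_N_EPOCHS             # 20248
print(save_steps)  # so a checkpoint was expected only every ~4 epochs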

To Reproduce
Steps to reproduce the behavior:

  1. Copy the code above into a file named train.py
  2. Run it with python3 train.py

Expected behavior
We were expecting the process to generate a checkpoint only after the 4th epoch, 8th epoch, and so on. Instead it is generating checkpoints on every epoch, and on top of that every 2000 steps.

Screenshots
The process generates all of these checkpoints:

outputs
|----checkpoint-4000
|----checkpoint-10000
|----best_model
|----checkpoint-5063-epoch-1
|----checkpoint-6000
|----checkpoint-10126-epoch-2
|----checkpoint-8000
|----checkpoint-2000
|----eval_results.txt
|----training_progress_scores.csv
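
For what it's worth, the checkpoint names above line up with the library's defaults rather than with the computed save_steps. The -epoch- checkpoints fall at multiples of ceil(40500 / 8) = 5063 steps (one per epoch), and the flat checkpoint-2000/4000/... entries match simpletransformers' evaluate_during_training_steps default of 2000 together with save_eval_checkpoints, which defaults to True; treat the defaults claim as an assumption based on the library's args, not something confirmed in this thread. A quick sanity check of that mapping:

import math

len_train_df = 40500  # figure quoted in the follow-up comment below
train_batch_size = 8
steps_per_epoch = math.ceil(len_train_df / train_batch_size)  # 5063

# Epoch checkpoints observed above: checkpoint-5063-epoch-1, checkpoint-10126-epoch-2
print([steps_per_epoch * epoch for epoch in (1, 2)])  # [5063, 10126]

# Step checkpoints observed above: multiples of 2000 up to 10000
print([2000 * k for k in range(1, 6)])  # [2000, 4000, 6000, 8000, 10000]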

Desktop (please complete the following information):

  • Linux

Additional context
None

@guruprasaad123 I am wondering whether this is correct:

SAVE_EVERY_N_EPOCHS = 4
train_batch_size = 8
steps_per_epoch = math.floor(len(train_df) / train_batch_size)
save_steps = steps_per_epoch * SAVE_EVERY_N_EPOCHS

Basically, save_steps is more or less equal to len(train_df)/2, right? (steps_per_epoch * SAVE_EVERY_N_EPOCHS = (len(train_df) / 8) * 4 = len(train_df) / 2.)

So if you do the calculation with your len(train_df), and if you look at the logic in the library,
[screenshot: checkpoint-saving logic from classification_model.py]
it seems the behavior could be correct, right?

The screenshot above is from classification_model.py.
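
The screenshot itself is not reproduced here, but the check it shows is roughly of this shape (a paraphrase with hypothetical names, not the actual simpletransformers source):

def should_save(global_step: int, save_steps: int) -> bool:
    # A step checkpoint is written whenever global_step hits a save_steps boundary.
    return save_steps > 0 and global_step % save_steps == 0

# With save_steps around len(train_df)/2 = 20248, only a couple of steps qualify:
print([s for s in range(1, 50001) if should_save(s, 20248)])  # [20248, 40496]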

@DamithDR I believe not, because I took the code directly from the official documentation itself and produced the results with the code below:

import math
import os
from multiprocessing import cpu_count

from simpletransformers.classification import ClassificationArgs, ClassificationModel

SAVE_EVERY_N_EPOCHS = 4
train_batch_size = 8
steps_per_epoch = math.floor(len(train_df) / train_batch_size)  # train_df -> 40500, so steps_per_epoch -> 5062
save_steps = steps_per_epoch * SAVE_EVERY_N_EPOCHS  # save_steps -> 20248
print(f'save_steps : {save_steps}')

args = {
    "output_dir": os.path.join(output_dir,"outputs"),
    "cache_dir": os.path.join(output_dir, "cache_dir"),
    "fp16": True,
    "train_batch_size": train_batch_size,
    "num_train_epochs": 16,

    "save_steps": save_steps,
    "save_model_every_epoch": False,
    "overwrite_output_dir": True,
    "reprocess_input_data": False,
    "evaluate_during_training": True,
    "evaluate_during_training_verbose":True,

    "process_count": cpu_count() - 2 if cpu_count() > 2 else 1,
    "n_gpu": 1,
}

# Optional model configuration
model_args = ClassificationArgs(**args)

# Create a ClassificationModel
model = ClassificationModel(
    'bert',
    'bert-base-uncased',
    num_labels=num_labels,
    args=model_args
) 

# Train the model
model.train_model(train_df, eval_df=test_df)

Source: simpletransformers documentation | Tips and Tricks

Results:

[screenshots: checkpoint directory listing and training log]

Conclusion

Since I ran the training script for 4 epochs, it was expected to create a checkpoint only after the 4th epoch, which isn't the case here. Please let me know if I am wrong.
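
For anyone hitting the same behaviour, a possible workaround, assuming the extra checkpoints come from the evaluation defaults mentioned above (evaluate_during_training_steps defaulting to 2000 and save_eval_checkpoints defaulting to True), is to align the evaluation interval with save_steps and turn off per-evaluation checkpoints. This is a sketch against the args as I understand them, not a confirmed fix:

args = {
    # ... same settings as in the snippets above ...
    "save_steps": save_steps,                      # e.g. 20248 for a 4-epoch interval
    "save_model_every_epoch": False,               # no per-epoch checkpoints
    "evaluate_during_training": True,
    "evaluate_during_training_steps": save_steps,  # evaluate on the same schedule
    "save_eval_checkpoints": False,                # don't checkpoint on every evaluation
}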