Question about model save strategy
svjack opened this issue
The model save strategy in the project is defined by:
for i, l in enumerate(avg_loss):
    print(f"Unet {i} avg validation loss: ", l)
    if l < best_loss[i]:
        best_loss[i] = l
        with training_dir("state_dicts"):
            model_path = f"unet_{i}_state_{timestamp}.pth"
            torch.save(imagen.unets[i].state_dict(), model_path)
This can cause an inconsistency: when unet_0's validation loss decreases but unet_1's does not, only unet_0's checkpoint is overwritten.
In other words, the final saved unets can come from two different training stages (epochs) of the same model.
Doesn't this have a bad effect?
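To make the concern concrete, here is a small pure-Python sketch (no torch) that mimics the save loop above and records which epoch each unet's checkpoint would come from. For comparison it also shows one possible alternative: saving all unets together from a single epoch, assuming the summed validation loss is an acceptable selection criterion. The function names and loss values are hypothetical, for illustration only.

```python
from typing import List


def per_unet_best_epochs(val_losses: List[List[float]]) -> List[int]:
    """Mimic the per-unet strategy: each unet's checkpoint is overwritten
    independently whenever its *own* validation loss improves.

    val_losses[epoch][i] is the (hypothetical) avg validation loss of unet i.
    Returns, for each unet, the epoch its saved checkpoint comes from.
    """
    num_unets = len(val_losses[0])
    best_loss = [float("inf")] * num_unets
    saved_epoch = [-1] * num_unets
    for epoch, avg_loss in enumerate(val_losses):
        for i, l in enumerate(avg_loss):
            if l < best_loss[i]:
                best_loss[i] = l
                saved_epoch[i] = epoch  # stands in for torch.save(...) of unet i
    return saved_epoch


def joint_best_epoch(val_losses: List[List[float]]) -> int:
    """Hypothetical alternative: treat the cascade as one model and keep a
    single checkpoint of all unets from the epoch with the lowest summed loss."""
    best_epoch, best_total = -1, float("inf")
    for epoch, avg_loss in enumerate(val_losses):
        total = sum(avg_loss)
        if total < best_total:
            best_epoch, best_total = epoch, total
    return best_epoch


# Hypothetical losses: unet_0 keeps improving, unet_1 is best early on.
losses = [
    [0.50, 0.40],  # epoch 0
    [0.45, 0.30],  # epoch 1
    [0.40, 0.35],  # epoch 2
    [0.35, 0.38],  # epoch 3
]
print(per_unet_best_epochs(losses))  # [3, 1]
print(joint_best_epoch(losses))      # 3
```

With these numbers the per-unet strategy keeps unet_0 from epoch 3 and unet_1 from epoch 1, while the joint variant keeps both from epoch 3, at the cost of a worse unet_1 validation loss.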