AssemblyAI-Community/MinImagen

Question about model save strategy

svjack opened this issue

The model save strategy in this project is defined by:

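    # After each validation pass: one average loss per U-Net in the cascade
    for i, l in enumerate(avg_loss):
        print(f"Unet {i} avg validation loss: ", l)
        # Each U-Net is saved independently, only when its own loss improves
        if l < best_loss[i]:
            best_loss[i] = l
            with training_dir("state_dicts"):
                model_path = f"unet_{i}_state_{timestamp}.pth"
                torch.save(imagen.unets[i].state_dict(), model_path)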

This can cause an inconsistency when unet_0's validation loss decreases but unet_1's does not, i.e. the final saved U-Nets come from two different training stages rather than from one consistent checkpoint.
Doesn't this have a bad effect?
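To make the concern concrete, here is a small sketch that replays the per-U-Net rule on made-up validation losses for two U-Nets. The loss values and the helper names `per_unet_best_epochs` and `joint_best_epoch` are hypothetical and not part of MinImagen; instead of calling `torch.save`, the sketch just records which epoch each U-Net's saved weights would come from. For comparison it also shows one possible alternative (an assumption, not the project's method): choosing a single epoch for all U-Nets by their summed validation loss.

```python
# Hypothetical per-epoch average validation losses for two U-Nets
# (made-up numbers, only to illustrate the save logic).
val_losses = [
    [0.90, 0.80],  # epoch 0
    [0.70, 0.85],  # epoch 1: unet 0 improves, unet 1 does not
    [0.75, 0.60],  # epoch 2: unet 1 improves, unet 0 does not
]


def per_unet_best_epochs(history):
    """Mirror the per-U-Net rule: each U-Net is re-saved only when its own loss improves.

    Returns, for each U-Net, the epoch whose weights would end up on disk.
    """
    num_unets = len(history[0])
    best_loss = [float("inf")] * num_unets
    saved_epoch = [None] * num_unets
    for epoch, avg_loss in enumerate(history):
        for i, l in enumerate(avg_loss):
            if l < best_loss[i]:
                best_loss[i] = l
                saved_epoch[i] = epoch  # stands in for torch.save(...)
    return saved_epoch


def joint_best_epoch(history):
    """Hypothetical alternative: pick one epoch for all U-Nets by summed loss."""
    return min(range(len(history)), key=lambda e: sum(history[e]))


print(per_unet_best_epochs(val_losses))  # → [1, 2]
print(joint_best_epoch(val_losses))  # → 2
```

With these numbers the saved files pair unet_0 from epoch 1 with unet_1 from epoch 2, a combination that was never validated together, whereas the joint rule keeps both U-Nets from the same epoch.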