Wrong model file names for tfgloro
Closed this issue · 3 comments
For tfgloro, the model files are named model.gloronet. However, Juneberry will claim that it saves a model named model.h5.
2021-10-28 20:18:00,612 INFO # ====================================================================================================
2021-10-28 20:18:00,612 INFO # ============================================ Finalizing ============================================
2021-10-28 20:18:00,612 INFO # ====================================================================================================
2021-10-28 20:18:00,613 INFO Writing output file: models/gloro/train/output.json
2021-10-28 20:18:00,617 INFO Saving model to 'models/gloro/model.h5'
Additionally, ModelManager.get_model_path() can output model.pt because it does not check for the tfgloro platform (https://github.com/cmu-sei/juneberry/blob/shadow-model-experiment-creation/juneberry/filesystem.py#L510-L515):
def get_model_path(self):
    """ :return: The path to a pytorch-compatible model file. """
    if self.model_platform in ['tensorflow']:
        return self.get_tensorflow_model_path()
    else:
        return self.get_pytorch_model_path()
Should model file names be more generic? For example, something specified by the config.
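To make the suggestion concrete, here is a minimal sketch of a lookup that prefers a file name from the config and otherwise falls back to a per-platform default. The function name, the mapping, and the `config_filename` parameter are all hypothetical and not part of Juneberry; only the three file names come from this issue.

```python
from pathlib import Path
from typing import Optional

# File names are the ones mentioned in this issue; the mapping itself is
# a hypothetical illustration, not Juneberry's actual table.
DEFAULT_MODEL_FILENAMES = {
    "pytorch": "model.pt",
    "tensorflow": "model.h5",
    "tfgloro": "model.gloronet",
}


def resolve_model_path(model_dir: str, platform: str,
                       config_filename: Optional[str] = None) -> Path:
    """Return the model file path, preferring a name given in the config."""
    if config_filename is not None:
        # An explicit name from the config wins over any platform default.
        return Path(model_dir) / config_filename
    try:
        return Path(model_dir) / DEFAULT_MODEL_FILENAMES[platform]
    except KeyError:
        # Fail loudly instead of silently falling back to a pytorch name.
        raise ValueError(f"Unknown model platform: {platform!r}") from None
```

With something like this, an unrecognized platform raises an error rather than quietly producing a model.pt path, which is the fallthrough behavior described above.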
Good catch! Yeah, that informational message is wrong. Gloro chooses its own extension and saves it that way.
We are actually in the process of removing the explicit platform flag from the model config, to allow pluggable models like tfgloro. We definitely need to move the model name / extension into the trainer.
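As a rough sketch of what moving the file name into the trainer could look like: each trainer declares its own model file name, and both the path and the log message are derived from it. The class names, the `model_filename` attribute, and the base-class default of model.pt are assumptions for illustration, not Juneberry's actual trainer API.

```python
from pathlib import Path


class Trainer:
    """Hypothetical base trainer that owns its model file name."""

    # Assumed default for illustration; subclasses override it as needed.
    model_filename = "model.pt"

    def __init__(self, model_dir: str):
        self.model_dir = Path(model_dir)

    def get_model_path(self) -> Path:
        # The path is built from the trainer's own file name,
        # so no platform check is required.
        return self.model_dir / self.model_filename

    def save_message(self) -> str:
        # The log message uses the same path the trainer would write to.
        return f"Saving model to '{self.get_model_path().as_posix()}'"


class GloroTrainer(Trainer):
    """Hypothetical tfgloro trainer that declares its own extension."""

    model_filename = "model.gloronet"
```

Because the message and the path share one source, the informational message cannot drift from the file that actually gets written.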
With the extensible trainer, is this still an issue?
Nope, thanks for following up!