Potential Bug: get_weights_list
Opened this issue · 0 comments
2320sharon commented
Update `get_weights_list` so that it reads only a single line (the first) from the BEST_MODEL.txt file.
def get_weights_list(self, model_choice: str = "ENSEMBLE") -> List[str]:
    """Return the paths of model weights files (.h5) in the weights directory.

    Args:
        model_choice (str, optional): Which weights to return.
            'ENSEMBLE' (default) returns every ``*.h5`` file found in
            ``self.weights_directory``; 'BEST' returns only the file named
            on the first line of ``BEST_MODEL.txt`` in that directory.

    Returns:
        list: File paths (strings) to the selected weights files. For
        'ENSEMBLE' the list is in whatever order ``glob`` yields and may be
        empty; for 'BEST' it always holds exactly one path.

    Raises:
        FileNotFoundError: If ``model_choice`` is 'BEST' and BEST_MODEL.txt
            is not found in the weights directory.
        ValueError: If ``model_choice`` is not 'ENSEMBLE' or 'BEST', or if
            the first line of BEST_MODEL.txt contains no file name.
    """
    if model_choice == "ENSEMBLE":
        return glob(os.path.join(self.weights_directory, "*.h5"))

    if model_choice != "BEST":
        # Previously an unknown choice fell through and returned None,
        # contradicting the declared List[str] return type.
        raise ValueError(
            f"Invalid model_choice {model_choice!r}; "
            "expected 'ENSEMBLE' or 'BEST'."
        )

    best_model_file = os.path.join(self.weights_directory, "BEST_MODEL.txt")
    with open(best_model_file) as f:
        # Only the first line names the model (e.g. fullmodel.h5); any
        # further lines in the file are ignored. Surrounding whitespace and
        # the trailing newline are stripped.
        model_name = f.readline().strip()

    if not model_name:
        # An empty name would otherwise join to the directory path itself.
        raise ValueError(
            f"No model name found on the first line of {best_model_file}."
        )

    return [os.path.join(self.weights_directory, model_name)]