Doodleverse/seg2map

Potential Bug: get_weights_list

Opened this issue · 0 comments

Update `get_weights_list` to read only a single line (the first) from the `BEST_MODEL.txt` file.

def get_weights_list(self, model_choice: str = "ENSEMBLE") -> List[str]:
        """Return the paths of the model weights files (.h5) in the weights directory.

        Args:
            model_choice (str, optional): The type of model weights to return.
                Valid choices are 'ENSEMBLE' (default) to return all available
                weights files or 'BEST' to return only the best model weights file.

        Returns:
            list: A list of strings representing the file paths to the model weights
            files in the weights directory. For 'ENSEMBLE' the paths are sorted;
            for 'BEST' the list holds the single path named on the first line of
            BEST_MODEL.txt.

        Raises:
            FileNotFoundError: If the BEST_MODEL.txt file is not found in the weights directory.
            ValueError: If ``model_choice`` is not 'ENSEMBLE' or 'BEST', or if the
                first line of BEST_MODEL.txt is blank.
        """
        if model_choice == "ENSEMBLE":
            # glob() returns matches in arbitrary, OS-dependent order; sort so the
            # list of weights is reproducible from run to run.
            return sorted(glob(os.path.join(self.weights_directory, "*.h5")))

        if model_choice == "BEST":
            best_model_file = os.path.join(self.weights_directory, "BEST_MODEL.txt")
            # Only the first line is read; anything after it in the file is ignored.
            with open(best_model_file, encoding="utf-8") as f:
                # strip() removes leading/trailing whitespace and the newline
                model_name = f.readline().strip()
            if not model_name:
                # Joining an empty name would return the weights directory itself
                # instead of a weights file, so fail loudly here.
                raise ValueError(
                    f"The first line of {best_model_file} does not contain a model name."
                )
            return [os.path.join(self.weights_directory, model_name)]

        # Previously an unknown choice fell through and implicitly returned None.
        raise ValueError(
            f"Invalid model_choice {model_choice!r}: expected 'ENSEMBLE' or 'BEST'."
        )