larq/zookeeper

The `plot()` CLI command doesn't work

AdamHillier opened this issue · 1 comment

To reproduce: clone the research template using cookiecutter, install the requirements, and then run `name_of_project plot cifar10`.

This results in the following error:

zookeeper.registry.PreprocessNotFoundError: No preprocessing functions registered for dataset cifar10.

This is definitely wrong, because preprocessing functions are defined for cifar10. The same error occurs for every dataset I try. I suspect that the code paths which register the preprocessing functions are never executed.
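To illustrate why a lookup can fail even though the functions are defined: in a decorator-based registry, registration is a side effect of importing the module that contains the decorated function. The following is a minimal, hypothetical sketch; the names `register_preprocess`, `get_preprocess` and `PreprocessNotFoundError` are assumptions chosen to mirror the error message, not zookeeper's actual API.

```python
from typing import Callable, Dict

# Hypothetical registry: dataset name -> {preprocess name -> function}.
_PREPROCESS: Dict[str, Dict[str, Callable]] = {}


class PreprocessNotFoundError(LookupError):
    """Raised when no preprocessing function is registered for a dataset."""


def register_preprocess(dataset: str, name: str = "default"):
    """Decorator that records `fn` when the defining module is imported."""
    def decorator(fn: Callable) -> Callable:
        _PREPROCESS.setdefault(dataset, {})[name] = fn
        return fn
    return decorator


def get_preprocess(dataset: str, name: str = "default") -> Callable:
    if dataset not in _PREPROCESS:
        raise PreprocessNotFoundError(
            f"No preprocessing functions registered for dataset {dataset}."
        )
    return _PREPROCESS[dataset][name]


# Before the decorated definition has run, the lookup fails,
# just like the reported error.
try:
    get_preprocess("cifar10")
except PreprocessNotFoundError as err:
    print("lookup failed:", err)


# In a real project this definition would live in a separate module
# that must be imported before any lookup happens.
@register_preprocess("cifar10")
def cifar10_default(image):
    return image / 255.0


print(get_preprocess("cifar10")(255.0))  # 1.0
```

If the module holding the decorated functions is never imported by the code path a CLI command takes, the registry stays empty for that command, which would match the behaviour described above.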

This is clearly a little-used CLI command, so I'm not sure how long it has been broken. It seems odd to me that the `plot` command is defined in zookeeper/cli.py, whereas other commands such as `train` and `netron` are defined in name_of_project/train.py. I suspect this is what prevents the preprocessing functions from being registered, but I don't know for sure.

I was able to get the command to work by adding the following to train.py in the project:

import os

import click

# Assumes `cli` (the project's click group) and `utils` (which provides
# prepare_registry()) are already defined or imported in train.py.


@cli.command()
@click.argument("dataset", type=str)
@click.option(
    "--preprocess-fn", default="default", help="Function used to preprocess dataset."
)
@click.option("--data-dir", type=str, help="Directory with training data.")
@click.option(
    "--output-prefix",
    default=os.path.join(os.path.expanduser("~/zookeeper-logs"), "plots"),
    help="Directory prefix used to save plots",
)
@click.option(
    "--format", default="pdf", type=click.Choice(["png", "pdf", "ps", "eps", "svg"])
)
def plot(dataset, preprocess_fn, data_dir, output_prefix, format):
    """Plot data examples."""
    from pathlib import Path
    from zookeeper import registry, data_vis

    # Register datasets and preprocessing functions before any lookup;
    # without this call, PreprocessNotFoundError is raised.
    utils.prepare_registry()

    output_dir = Path(output_prefix).joinpath(dataset, preprocess_fn)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(output_dir)

    # Named `data` rather than `set` to avoid shadowing the built-in.
    data = registry.get_dataset(dataset, preprocess_fn, data_dir=data_dir)
    figs = data_vis.plot_all_examples(data)
    # Save the raw, train and eval example figures in the requested format.
    for fig, filename in zip(figs, ("raw", "train", "eval")):
        fig.savefig(f"{output_dir.joinpath(filename).absolute()}.{format}")

This is essentially the same as the `plot` function currently defined in zookeeper/cli.py, except for the added call to `utils.prepare_registry()` before the dataset is loaded; `prepare_registry` is a function defined in research_project_template.
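The issue doesn't show what `utils.prepare_registry()` does. One plausible approach, offered only as a hypothetical sketch, is to import every submodule of the project package so that all module-level registration code runs. The helper name `import_all_submodules` is an assumption; the real function in research_project_template may work differently.

```python
import importlib
import pkgutil
from types import ModuleType
from typing import List


def import_all_submodules(package_name: str) -> List[str]:
    """Import every submodule of `package_name` so registration side effects run.

    Hypothetical stand-in for something like utils.prepare_registry().
    """
    package: ModuleType = importlib.import_module(package_name)
    imported: List[str] = []
    # walk_packages needs the package's __path__ (only packages have one)
    # and recurses into subpackages; the prefix yields fully qualified names.
    for info in pkgutil.walk_packages(package.__path__, prefix=package.__name__ + "."):
        importlib.import_module(info.name)
        imported.append(info.name)
    return imported


if __name__ == "__main__":
    # Demonstrated on a standard-library package so the example runs anywhere.
    print(import_all_submodules("json"))
```

In the project, a call such as `import_all_submodules("name_of_project")` at the start of `plot()` would play the same role as `utils.prepare_registry()`. Importing everything eagerly is simple but slows CLI start-up, which may be one reason to keep such a call inside the command rather than at module level.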

I propose removing the `plot()` CLI command from zookeeper/cli.py and moving it into research_project_template, as described above.