The `plot()` CLI command doesn't work
AdamHillier opened this issue
To reproduce: clone the research template using cookiecutter, install the requirements, and then run `name_of_project plot cifar10`.
This results in the following error:

```
zookeeper.registry.PreprocessNotFoundError: No preprocessing functions registered for dataset cifar10.
```
This is definitely a bug, because there are preprocessing functions defined for `cifar10`. The same thing happens for every dataset I try. I think the code paths that register the preprocessing functions aren't being run for some reason.
This is clearly a little-used CLI command, so I'm not sure how long it has been broken. It seems odd to me that the `plot` CLI command is defined in `zookeeper/cli`, whereas other commands such as `train` and `netron` are defined in `name_of_project/train.py`. I suspect this is what causes the preprocessing functions not to be registered, but I don't know for sure.
I was able to get the command to work by adding the following to `train.py` in the project:
```python
import os
from pathlib import Path

import click

# `cli` (the project's click group) and `utils` (which provides
# `prepare_registry()`) are already defined/imported in the template's train.py.


@cli.command()
@click.argument("dataset", type=str)
@click.option(
    "--preprocess-fn", default="default", help="Function used to preprocess dataset."
)
@click.option("--data-dir", type=str, help="Directory with training data.")
@click.option(
    "--output-prefix",
    default=os.path.join(os.path.expanduser("~/zookeeper-logs"), "plots"),
    help="Directory prefix used to save plots",
)
@click.option(
    "--format", default="pdf", type=click.Choice(["png", "pdf", "ps", "eps", "svg"])
)
def plot(dataset, preprocess_fn, data_dir, output_prefix, format):
    """Plot data examples."""
    from zookeeper import registry, data_vis

    # Added relative to zookeeper/cli.py: populate the registry with the
    # project's datasets and preprocessing functions before looking them up.
    utils.prepare_registry()

    output_dir = Path(output_prefix).joinpath(dataset, preprocess_fn)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(output_dir)

    # Named `dataset_obj` rather than `set` to avoid shadowing the builtin.
    dataset_obj = registry.get_dataset(dataset, preprocess_fn, data_dir=data_dir)
    figs = data_vis.plot_all_examples(dataset_obj)
    # Save the figures as raw.<format>, train.<format> and eval.<format>.
    for fig, filename in zip(figs, ("raw", "train", "eval")):
        fig.savefig(f"{output_dir.joinpath(filename).absolute()}.{format}")
```
This is exactly the same as the `plot` function currently defined in `zookeeper/cli.py`, but with the addition of the line `utils.prepare_registry()` before loading the dataset (for reference, `prepare_registry()` lives in `research_project_template`).
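For context, here is a hypothetical sketch of what a `prepare_registry()`-style function might do. It is not the template's actual implementation: the approach (importing every submodule of a package with `pkgutil.iter_modules` and `importlib.import_module` so that their registration side effects run) and the throwaway `demo_project` package built in a temporary directory are assumptions for illustration only.

```python
import importlib
import pkgutil
import sys
import tempfile
from pathlib import Path
from typing import List


def prepare_registry(package_name: str) -> List[str]:
    """Hypothetical: import every submodule of `package_name` so any registration
    code they contain runs. Returns the names of the imported modules."""
    package = importlib.import_module(package_name)
    imported = []
    for info in pkgutil.iter_modules(package.__path__, prefix=package_name + "."):
        importlib.import_module(info.name)
        imported.append(info.name)
    return imported


# Build a tiny throwaway package whose `datasets` module registers something on import.
with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp) / "demo_project"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("")
    (pkg / "registry.py").write_text("PREPROCESS = {}\n")
    (pkg / "datasets.py").write_text(
        "from demo_project.registry import PREPROCESS\n"
        "PREPROCESS['cifar10'] = ['default']\n"
    )
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()  # make sure the new files are visible to the importer

    from demo_project import registry

    print(registry.PREPROCESS)  # {} : datasets.py has not been imported yet
    imported = prepare_registry("demo_project")
    print(registry.PREPROCESS)  # {'cifar10': ['default']}
```

The point of the sketch is only that the lookup succeeds once the modules defining the registrations have been imported, which is what calling `utils.prepare_registry()` inside the project's `plot` command appears to achieve.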
I propose removing the `plot()` CLI command from `zookeeper/cli.py` and instead moving it into the `research_project_template`, as described above.