ganslate-team/ganslate

Simplified Inference Architecture


Use cases:

  1. The Inferer can be reused as the Validator during training.

  2. A separate inference config might not be necessary, as most of the parameters are already defined in train_config.yaml. The parameters needed for inference should be made mandatory as CLI arguments or supplied through a config, but a separate config should not be required. A design decision is needed on which parameters must be provided.

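    from pathlib import Path

    from omegaconf import OmegaConf

    cli = OmegaConf.from_cli()
    # Membership test instead of `cli.config`, which raises on newer OmegaConf versions if the key is absent
    if "config" in cli:
        inference_conf = OmegaConf.load(cli.pop("config"))
        inference_conf = OmegaConf.merge(inference_conf, cli)
    else:
        inference_conf = cli

    # Locate the config that was used during training of this specific run
    train_conf = Path(inference_conf.logging.checkpoint_dir) / "training_config.yaml"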
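Use case 1 amounts to sharing one inference loop between the standalone inference script and validation during training. The plain-Python sketch below illustrates only that idea; `Inferer`, `Validator`, `run`, `validate`, and the toy `model`/`metric` callables are hypothetical and do not reflect ganslate's actual classes.

```python
from typing import Callable, Iterable, List, Tuple


class Inferer:
    """Hypothetical: runs a model over inputs, shared by inference and validation."""

    def __init__(self, model: Callable[[float], float]):
        self.model = model

    def run(self, inputs: Iterable[float]) -> List[float]:
        # Forward passes only; no training state is touched here.
        return [self.model(x) for x in inputs]


class Validator(Inferer):
    """Hypothetical: reuses Inferer.run and only adds metric computation."""

    def __init__(self, model: Callable[[float], float],
                 metric: Callable[[float, float], float]):
        super().__init__(model)
        self.metric = metric

    def validate(self, pairs: Iterable[Tuple[float, float]]) -> float:
        pairs = list(pairs)
        if not pairs:
            raise ValueError("validate() needs at least one (input, target) pair")
        predictions = self.run(x for x, _ in pairs)
        scores = [self.metric(p, y) for p, (_, y) in zip(predictions, pairs)]
        return sum(scores) / len(scores)


double = lambda x: 2 * x
mae = lambda p, y: abs(p - y)
print(Inferer(double).run([1, 2, 3]))                # [2, 4, 6]
print(Validator(double, mae).validate([(1, 2), (2, 5)]))  # 0.5
```

Keeping the forward pass in one place means that a change to inference (for example, sliding-window inference) would apply to validation automatically.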

Perhaps config could point directly to the training config itself, so that separate inference_conf and train_conf objects are not needed.

Proposed new structure:

from omegaconf import OmegaConf

# init_config and InferenceConfig come from ganslate (imports omitted here)


def build_inference_conf():
    # Load the config passed via `config=` (the training config of the run)
    cli = OmegaConf.from_cli()
    conf = OmegaConf.load(cli.pop("config"))

    # Inference-time defaults (computed before masking, since they read conf.logging)
    inference_defaults = get_inference_defaults(conf)

    # Copy only the run-specific options that are important for inference
    train_to_inference_options = ["project_dir", "gan", "generator",
                                  "use_cuda", "mixed_precision", "opt_level"]
    conf = OmegaConf.masked_copy(conf, train_to_inference_options)

    # Merge: training options < inference defaults < CLI overrides, then init
    conf = OmegaConf.merge(conf, inference_defaults, cli)
    return init_config(conf, InferenceConfig)
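The precedence in build_inference_conf (masked training options, then inference defaults, then CLI overrides) can be illustrated without OmegaConf. The sketch below uses plain dicts and two hypothetical helpers, `masked_copy` and `merge`, that mimic, under the assumption stated in their comments, how `OmegaConf.masked_copy` keeps selected top-level keys and how `OmegaConf.merge` lets later configs win while merging nested mappings recursively. It is not a reimplementation of OmegaConf, and the config values are made up.

```python
import copy
from typing import Any, Dict, List


def masked_copy(conf: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
    # Keep only the listed top-level keys (deep-copied); absent keys are skipped.
    return {k: copy.deepcopy(conf[k]) for k in keys if k in conf}


def merge(*confs: Dict[str, Any]) -> Dict[str, Any]:
    # Later configs win; nested dicts merge recursively, other values (incl. lists) are replaced.
    result: Dict[str, Any] = {}
    for conf in confs:
        for key, value in conf.items():
            if isinstance(value, dict) and isinstance(result.get(key), dict):
                result[key] = merge(result[key], value)
            else:
                result[key] = copy.deepcopy(value)
    return result


# Made-up training config for illustration only
train_conf = {
    "project_dir": "/proj",
    "batch_size": 8,
    "gan": {"is_train": True},
    "dataset": {"name": "TrainDataset", "shuffle": True},
    "logging": {"checkpoint_dir": "checkpoints/cbct_ex3"},
}
inference_defaults = {
    "batch_size": 1,
    "dataset": {"shuffle": False},
    "gan": {"is_train": False},
    "logging": {"checkpoint_dir": train_conf["logging"]["checkpoint_dir"]},
}
cli = {"load_checkpoint": {"iter": 40000}, "dataset": {"root": "/data/test"}}

kept = masked_copy(train_conf, ["project_dir", "gan", "use_cuda"])  # "use_cuda" is absent here
final = merge(kept, inference_defaults, cli)
print(final["gan"], final["batch_size"], final["dataset"])
# {'is_train': False} 1 {'shuffle': False, 'root': '/data/test'}
```

The training batch_size and dataset name never reach the inference config because they fall outside the mask, and the training value of gan.is_train is overridden by the inference default.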



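def get_inference_defaults(conf):
    # Inference-time overrides; checkpoint_dir is carried over from the training run.
    # Built from a dict so that the path is not re-parsed as YAML.
    inference_defaults = {
        "batch_size": 1,
        "dataset": {"shuffle": False},
        "gan": {"is_train": False},
        "logging": {"checkpoint_dir": conf.logging.checkpoint_dir},
    }
    return OmegaConf.create(inference_defaults)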

This would be invoked as:

python tools/infer.py config=checkpoints/cbct_ex3/training_config.yaml \
    logging.inference_dir=results \
    load_checkpoint.iter=40000 \
    dataset.name="CBCTtoCTInferenceDataset" \
    dataset.root=/home/rt/workspace_suraj/s.pai/test_resampled \
    sliding_window.window_size="[16,128,128]"
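To make the command concrete: each key.subkey=value argument becomes a nested config entry, which is roughly what OmegaConf.from_cli() builds from sys.argv. The sketch below is a simplified, hypothetical parser (`dotlist_to_dict`, `parse_value`) written only for illustration; it handles ints, booleans, flat integer lists, and strings, whereas OmegaConf parses each value as YAML.

```python
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    # Simplified stand-in for YAML value parsing: flat int lists, booleans, ints, else strings.
    if text.startswith("[") and text.endswith("]"):
        return [int(v) for v in text[1:-1].split(",") if v.strip()]
    if text in ("True", "true"):
        return True
    if text in ("False", "false"):
        return False
    try:
        return int(text)
    except ValueError:
        return text


def dotlist_to_dict(args: List[str]) -> Dict[str, Any]:
    # Turn "a.b.c=value" arguments into {"a": {"b": {"c": value}}}.
    conf: Dict[str, Any] = {}
    for arg in args:
        dotted_key, _, raw = arg.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = conf
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw)
    return conf


# The arguments from the command above, after the shell has removed the quotes
args = [
    "config=checkpoints/cbct_ex3/training_config.yaml",
    "logging.inference_dir=results",
    "load_checkpoint.iter=40000",
    "dataset.name=CBCTtoCTInferenceDataset",
    "dataset.root=/home/rt/workspace_suraj/s.pai/test_resampled",
    "sliding_window.window_size=[16,128,128]",
]
conf = dotlist_to_dict(args)
print(conf["sliding_window"])  # {'window_size': [16, 128, 128]}
```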

Major changes:

  • config now points to training_config.yaml. This avoids user-level confusion, since there is only a single config.

  • Many more parameters now have to be defined via the CLI. This needs to be improved, as a more structured way of providing them would lead to better organization.
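One way to address both the open question of which parameters are mandatory and the growing number of CLI options is to validate the merged config up front and report every missing key at once. The sketch below assumes plain nested dicts; `REQUIRED_INFERENCE_KEYS`, `get_dotted`, and `check_required` are hypothetical names, and the key list is only an example drawn from the command above. OmegaConf's own mechanism for this is marking a value as mandatory with `???`, which raises an error when the value is accessed while still unset.

```python
from typing import Any, Dict, Optional, Sequence

# Hypothetical list of inference-only keys; choosing the real list is the open design decision.
REQUIRED_INFERENCE_KEYS = (
    "logging.inference_dir",
    "load_checkpoint.iter",
    "dataset.name",
    "dataset.root",
)


def get_dotted(conf: Dict[str, Any], dotted_key: str) -> Optional[Any]:
    # Walk nested dicts along "a.b.c"; return None if any level is missing.
    node: Any = conf
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return None
        node = node[part]
    return node


def check_required(conf: Dict[str, Any],
                   required: Sequence[str] = REQUIRED_INFERENCE_KEYS) -> None:
    # Report every missing key at once instead of failing on the first one.
    missing = [key for key in required if get_dotted(conf, key) is None]
    if missing:
        raise ValueError("Missing required inference options: " + ", ".join(missing))


conf = {
    "logging": {"inference_dir": "results"},
    "load_checkpoint": {"iter": 40000},
    "dataset": {"name": "CBCTtoCTInferenceDataset"},
}
try:
    check_required(conf)
except ValueError as err:
    print(err)  # Missing required inference options: dataset.root
```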