ELEKTRONN/elektronn3

Missing assignment to mode in train_unet_neurodata.py

Closed this issue · 1 comment

Around line 99 of train_unet_neurodata.py, this code appears:

if args.jit == 'onsave':
    # Make sure that tracing works
    tracedmodel = torch.jit.trace(model, example_input.to(device))
elif args.jit == 'train':
    if getattr(model, 'checkpointing', False):
        raise NotImplementedError(
            'Traced models with checkpointing currently don\'t '
            'work, so either run with --disable-trace or disable '
            'checkpointing.')
    tracedmodel = torch.jit.trace(model, example_input.to(device))
    model = tracedmodel

Is a model = tracedmodel assignment missing from the true branch of if args.jit == 'onsave'? That is, should the code read:

if args.jit == 'onsave':
    # Make sure that tracing works
    tracedmodel = torch.jit.trace(model, example_input.to(device))
    model = tracedmodel
elif args.jit == 'train':
    if getattr(model, 'checkpointing', False):
        raise NotImplementedError(
            'Traced models with checkpointing currently don\'t '
            'work, so either run with --disable-trace or disable '
            'checkpointing.')
    tracedmodel = torch.jit.trace(model, example_input.to(device))
    model = tracedmodel
mdraw commented

Here we run the tracing only to check, as early as possible, whether tracing actually works with the chosen model.
With --jit=onsave, we "Use regular Python model for training, but trace it on-demand for saving training state" (see https://github.com/ELEKTRONN/elektronn3/blob/35dbc1bc/examples/train_unet_neurodata.py#L46). The rationale for this mode is that JIT-traced models sometimes have issues when used for training (mainly control-flow limitations and slightly different behavior due to graph optimizations), so in many cases it is better to train the plain Python model and only JIT-trace it on demand when saving it to disk.
Replacing model with tracedmodel here would have the same effect as the --jit=train option.

I will change this line to _ = torch.jit.trace(model, example_input.to(device)) to make it clearer that we don't actually want to use the resulting trace.
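For illustration, here is a minimal sketch of the --jit=onsave behavior described above. The model, input shape, and save path are stand-ins, not the script's actual ones: the trace is run once as an early sanity check and discarded, training continues with the plain Python model, and tracing happens again on demand when serializing.

```python
import tempfile

import torch
import torch.nn as nn

# Illustrative stand-ins for the script's model and example input.
model = nn.Conv2d(1, 2, kernel_size=3)
example_input = torch.randn(1, 1, 8, 8)

# --jit=onsave: trace once as an early sanity check, then discard the
# result -- training continues with the plain Python model.
_ = torch.jit.trace(model, example_input)
out = model(example_input)

# Later, when saving training state, trace on demand and serialize.
with tempfile.NamedTemporaryFile(suffix='.pts') as f:
    traced = torch.jit.trace(model, example_input)
    traced.save(f.name)
    loaded = torch.jit.load(f.name)
    assert torch.allclose(loaded(example_input), out)
```

Note that the traced copy is only created at save time, so any training-time control flow in the Python model is unaffected by tracing.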