optuna/optuna-examples

Callback Error for example pytorch_lightning_simple.py

leoguen opened this issue · 5 comments

When I try to run the example optuna-examples/pytorch_lightning_simple.py, I get the following error: RuntimeError: The on_init_start callback hook was deprecated in v1.6 and is no longer supported as of v1.8.

Environment

  • Optuna version: 3.1.0
  • Python version: 3.8.10
  • OS: Ubuntu 22.04
  • (Optional) Other libraries and their versions:

Error messages, stack traces, or logs

RuntimeError: The on_init_start callback hook was deprecated in v1.6 and is no longer supported as of v1.8.

Additional context (optional)

If I comment out the callback, the code runs without problems, but that disables pruning, which is an important part of the example. Unfortunately, I could not find a working example, so I cannot suggest a fix for Lightning.

    trainer = pl.Trainer(
        logger=True,
        limit_val_batches=PERCENT_VALID_EXAMPLES,
        enable_checkpointing=False,
        max_epochs=EPOCHS,
        gpus=1 if torch.cuda.is_available() else None,
        #callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
    )
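Commenting out the callback removes the error because the callback being passed defines the `on_init_start` hook, and Lightning 1.8 rejects any callback that still defines it. The sketch below illustrates, in plain Python, roughly how such a check can work. It is a simplified assumption, not Lightning's actual source: `OldStyleCallback`, `NewStyleCallback`, `REMOVED_HOOKS`, `find_removed_hooks`, and `check_removed_hooks` are all hypothetical names, and only one removed hook is listed.

```python
# Hypothetical illustration of a "removed hook" check; not Lightning's real code.

class OldStyleCallback:
    """Hypothetical callback written against the pre-1.8 hook set."""

    def on_init_start(self, trainer):
        pass


class NewStyleCallback:
    """Hypothetical callback that only uses a still-supported hook."""

    def on_validation_end(self, trainer, pl_module):
        pass


# Assumption: only one of the hooks removed in Lightning 1.8 is listed here.
REMOVED_HOOKS = ("on_init_start",)


def find_removed_hooks(callbacks):
    """Return (callback class name, hook) pairs for hooks that are no longer supported."""
    found = []
    for cb in callbacks:
        for hook in REMOVED_HOOKS:
            # Once the hook no longer exists on the base class, merely defining it is an error.
            if callable(getattr(cb, hook, None)):
                found.append((type(cb).__name__, hook))
    return found


def check_removed_hooks(callbacks):
    """Raise RuntimeError if any callback still defines a removed hook."""
    found = find_removed_hooks(callbacks)
    if found:
        name, hook = found[0]
        raise RuntimeError(f"{name} defines the `{hook}` hook, which is no longer supported.")


check_removed_hooks([NewStyleCallback()])  # passes silently
try:
    check_removed_hooks([OldStyleCallback()])
except RuntimeError as err:
    print(err)  # OldStyleCallback defines the `on_init_start` hook, which is no longer supported.
```

This is why the fix discussed below is a callback that only implements a hook still supported in recent versions (`on_validation_end`), rather than a change to the `Trainer` arguments.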

Hi,
I understand that the example's requirements pin Lightning to a version between 1.5 and 1.6. As someone building on this example, is there another common way to implement pruning with recent Lightning versions?
Cheers,
Leo

For distributed (multi-process) training, no. For single-process training, the older version of the callback (https://github.com/optuna/optuna/blob/release-v2.10.1/optuna/integration/pytorch_lightning.py) might work, but I have not tested it.

The shared code had a minor issue caused by PyTorch Lightning's default sanity check (`num_sanity_val_steps`), which runs a few validation batches before training starts. The following callback should work:

import warnings

import optuna
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback

class PyTorchLightningPruningCallback(Callback):
    """PyTorch Lightning callback to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna-examples/blob/
    main/pytorch/pytorch_lightning_simple.py>`__
    if you want to add a pruning callback which observes accuracy.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. The metrics are obtained from the returned dictionaries from e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on
            how this dictionary is formatted.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()

        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # When the trainer calls `on_validation_end` for sanity check,
        # do not call `trial.report` to avoid calling `trial.report` multiple times
        # at epoch 0. The related page is
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
        if trainer.sanity_checking:
            return

        epoch = pl_module.current_epoch

        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            message = (
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self.monitor)
            )
            warnings.warn(message)
            return

        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
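To see why the `trainer.sanity_checking` guard matters, the sketch below simulates the order in which validation-end hooks fire: first a short sanity-check pass with `current_epoch == 0`, then the real epoch-0 validation, then epoch 1. It assumes, as in recent Optuna versions, that a second `report` for an already reported step is ignored with a warning. `FakeTrial` and `run_fake_training` are hypothetical stand-ins, not Optuna or Lightning APIs, and the metric values are made up for illustration.

```python
import warnings
from dataclasses import dataclass, field


@dataclass
class FakeTrial:
    """Hypothetical stand-in for optuna.trial.Trial (not the real class).

    Assumed behavior: a second report for an already reported step is
    ignored with a warning, and the first value is kept.
    """

    intermediate_values: dict = field(default_factory=dict)

    def report(self, value: float, step: int) -> None:
        if step in self.intermediate_values:
            warnings.warn(f"step {step} already reported; value ignored")
            return
        self.intermediate_values[step] = value


def run_fake_training(trial: FakeTrial, skip_sanity_check: bool) -> dict:
    """Simulate the firing order of validation-end hooks and report each value."""
    # (sanity_checking, current_epoch, val_acc) in firing order; values are made up.
    events = [(True, 0, 0.10), (False, 0, 0.55), (False, 1, 0.70)]
    for sanity_checking, epoch, val_acc in events:
        if skip_sanity_check and sanity_checking:
            continue  # same guard as `if trainer.sanity_checking: return` above
        trial.report(val_acc, step=epoch)
    return trial.intermediate_values


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the duplicate-step warning
    unguarded = run_fake_training(FakeTrial(), skip_sanity_check=False)
guarded = run_fake_training(FakeTrial(), skip_sanity_check=True)

print(unguarded)  # {0: 0.1, 1: 0.7}
print(guarded)    # {0: 0.55, 1: 0.7}
```

Under these assumptions, without the guard the value from the short sanity-check pass occupies step 0 and the real epoch-0 accuracy is discarded, so the pruner would compare a misleading number at the first epoch.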

This worked for me, great, thank you!