Callback Error for example pytorch_lightning_simple.py
leoguen opened this issue · 5 comments
When trying to run the example optuna-examples/pytorch_lightning_simple.py, I get the runtime error: RuntimeError: The on_init_start callback hook was deprecated in v1.6 and is no longer supported as of v1.8.
Environment
- Optuna version: 3.1.0
- Python version: 3.8.10
- OS: Ubuntu 22.04
- (Optional) Other libraries and their versions:
Error messages, stack traces, or logs
RuntimeError: The on_init_start callback hook was deprecated in v1.6 and is no longer supported as of v1.8.
Additional context (optional)
If I comment out the callback, the code runs without problems, but that removes pruning, which is quite important for the example. Sadly, I could not find a working example, so I cannot really suggest a fix for the Lightning integration.
trainer = pl.Trainer(
    logger=True,
    limit_val_batches=PERCENT_VALID_EXAMPLES,
    enable_checkpointing=False,
    max_epochs=EPOCHS,
    gpus=1 if torch.cuda.is_available() else None,
    # callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
)
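Since the error message ties the failure to a version boundary, checking the installed Lightning version shows whether the stock callback will hit it. This is a minimal sketch, assuming the package is installed under the distribution name pytorch-lightning and that its version string starts with a numeric major.minor. The helper names are hypothetical and are not part of Optuna or Lightning.

```python
import re
from importlib import metadata  # available in the standard library since Python 3.8


def release_tuple(version):
    """Parse the leading 'major.minor' of a version string into a tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("Unrecognized version string: {!r}".format(version))
    return int(match.group(1)), int(match.group(2))


def on_init_start_removed(version):
    """Return True if this version is at or past v1.8 (hypothetical helper).

    The v1.8 threshold is taken from the error message quoted in this issue.
    """
    return release_tuple(version) >= (1, 8)


def installed_lightning_version():
    """Return the installed pytorch-lightning version string, or None if absent."""
    try:
        return metadata.version("pytorch-lightning")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_lightning_version()
    if version is None:
        print("pytorch-lightning is not installed")
    elif on_init_start_removed(version):
        print("Lightning {}: expect the on_init_start RuntimeError".format(version))
    else:
        print("Lightning {}: on_init_start is still accepted".format(version))
```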
Hi,
I understand that the example's requirements pin lightning to between 1.5 and 1.6. As someone trying to build on this example, is there another common way to implement pruning with a recent lightning version?
Cheers,
Leo
For distributed (multi-process) training, no. For single-process training, the callback from the old version, https://github.com/optuna/optuna/blob/release-v2.10.1/optuna/integration/pytorch_lightning.py, might work fine, but I have not tested it.
The shared code had a minor issue caused by PyTorch Lightning's default sanity-check setting. The following callback should be okay.
import warnings

import optuna
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class PyTorchLightningPruningCallback(Callback):
    """PyTorch Lightning callback to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna-examples/blob/
    main/pytorch/pytorch_lightning_simple.py>`__
    if you want to add a pruning callback which observes accuracy.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. The metrics are obtained from the returned dictionaries from e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on
            how this dictionary is formatted.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()
        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # When the trainer calls `on_validation_end` for sanity check,
        # do not call `trial.report` to avoid calling `trial.report` multiple times
        # at epoch 0. The related page is
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
        if trainer.sanity_checking:
            return

        epoch = pl_module.current_epoch

        # The monitored metric is missing if it was never logged under this name.
        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            message = (
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self.monitor)
            )
            warnings.warn(message)
            return

        # Report once per epoch, then let the pruner decide whether to stop the trial.
        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
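To see why the sanity-check guard matters, the decision order in `on_validation_end` can be traced without installing anything. This is only a sketch of the control flow: `FakeTrial`, `Pruned`, and `run` are hypothetical stand-ins, not Optuna or Lightning APIs, and the threshold-based pruning rule is a toy replacement for a real pruner.

```python
class FakeTrial:
    """Hypothetical stand-in for optuna.trial.Trial that only records reports."""

    def __init__(self, prune_below):
        self.prune_below = prune_below
        self.reports = []  # list of (step, value) pairs in the order reported

    def report(self, value, step):
        self.reports.append((step, value))

    def should_prune(self):
        # Toy rule: prune when the most recent value falls below a threshold.
        return self.reports[-1][1] < self.prune_below


class Pruned(Exception):
    """Hypothetical stand-in for optuna.TrialPruned."""


def on_validation_end(trial, epoch, score, sanity_checking):
    """Same decision order as the callback: guard, missing metric, report, prune."""
    if sanity_checking:
        return  # the sanity-check pass must not produce a report
    if score is None:
        return  # metric missing: nothing to report
    trial.report(score, step=epoch)
    if trial.should_prune():
        raise Pruned("Trial was pruned at epoch {}.".format(epoch))


def run(scores, prune_below):
    """Simulate one sanity-check pass followed by one validation pass per epoch."""
    trial = FakeTrial(prune_below)
    # The sanity check runs validation before training, while the epoch is still 0.
    on_validation_end(trial, epoch=0, score=scores[0], sanity_checking=True)
    try:
        for epoch, score in enumerate(scores):
            on_validation_end(trial, epoch, score, sanity_checking=False)
    except Pruned:
        return trial.reports, True
    return trial.reports, False


if __name__ == "__main__":
    reports, pruned = run([0.5, 0.6, 0.2, 0.9], prune_below=0.3)
    print(reports, pruned)
```

With the guard in place, step 0 is reported exactly once. Without it, the sanity-check pass would add a second report at epoch 0, which is the duplicate the comment in the callback warns about.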
This worked for me, great thank you!