RuntimeError: The `on_init_start` callback hook was deprecated in v1.6 and is no longer supported as of v1.8.
RiccardoRiglietti opened this issue · 1 comment
RiccardoRiglietti commented
## Expected behavior
The example below, taken from the Optuna examples GitHub repository, runs without error:
"""
Optuna example that optimizes multi-layer perceptrons using PyTorch Lightning.
In this example, we optimize the validation accuracy of fashion product recognition using
PyTorch Lightning, and FashionMNIST. We optimize the neural network architecture. As it is too time
consuming to use the whole FashionMNIST dataset, we here use a small subset of it.
You can run this example as follows, pruning can be turned on and off with the `--pruning`
argument.
$ python pytorch_lightning_simple.py [--pruning]
"""
import argparse
import os
from typing import List
from typing import Optional
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from packaging import version
import pytorch_lightning as pl
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import datasets
from torchvision import transforms
if version.parse(pl.__version__) < version.parse("1.0.2"):
raise RuntimeError("PyTorch Lightning>=1.0.2 is required for this example.")
PERCENT_VALID_EXAMPLES = 0.1
BATCHSIZE = 128
CLASSES = 10
EPOCHS = 10
DIR = os.getcwd()
class Net(nn.Module):
def __init__(self, dropout: float, output_dims: List[int]):
super().__init__()
layers: List[nn.Module] = []
input_dim: int = 28 * 28
for output_dim in output_dims:
layers.append(nn.Linear(input_dim, output_dim))
layers.append(nn.ReLU())
layers.append(nn.Dropout(dropout))
input_dim = output_dim
layers.append(nn.Linear(input_dim, CLASSES))
self.layers: nn.Module = nn.Sequential(*layers)
def forward(self, data: torch.Tensor) -> torch.Tensor:
logits = self.layers(data)
return F.log_softmax(logits, dim=1)
class LightningNet(pl.LightningModule):
def __init__(self, dropout: float, output_dims: List[int]):
super().__init__()
self.model = Net(dropout, output_dims)
def forward(self, data: torch.Tensor) -> torch.Tensor:
return self.model(data.view(-1, 28 * 28))
def training_step(self, batch, batch_idx: int) -> torch.Tensor:
data, target = batch
output = self(data)
return F.nll_loss(output, target)
def validation_step(self, batch, batch_idx: int) -> None:
data, target = batch
output = self(data)
pred = output.argmax(dim=1, keepdim=True)
accuracy = pred.eq(target.view_as(pred)).float().mean()
self.log("val_acc", accuracy)
self.log("hp_metric", accuracy, on_step=False, on_epoch=True)
def configure_optimizers(self) -> optim.Optimizer:
return optim.Adam(self.model.parameters())
class FashionMNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str, batch_size: int):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def setup(self, stage: Optional[str] = None) -> None:
self.mnist_test = datasets.FashionMNIST(
self.data_dir, train=False, download=True, transform=transforms.ToTensor()
)
mnist_full = datasets.FashionMNIST(
self.data_dir, train=True, download=True, transform=transforms.ToTensor()
)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
def train_dataloader(self) -> DataLoader:
return DataLoader(
self.mnist_train, batch_size=self.batch_size, shuffle=True, pin_memory=True
)
def val_dataloader(self) -> DataLoader:
return DataLoader(
self.mnist_val, batch_size=self.batch_size, shuffle=False, pin_memory=True
)
def test_dataloader(self) -> DataLoader:
return DataLoader(
self.mnist_test, batch_size=self.batch_size, shuffle=False, pin_memory=True
)
def objective(trial: optuna.trial.Trial) -> float:
# We optimize the number of layers, hidden units in each layer and dropouts.
n_layers = trial.suggest_int("n_layers", 1, 3)
dropout = trial.suggest_float("dropout", 0.2, 0.5)
output_dims = [
trial.suggest_int("n_units_l{}".format(i), 4, 128, log=True) for i in range(n_layers)
]
model = LightningNet(dropout, output_dims)
datamodule = FashionMNISTDataModule(data_dir=DIR, batch_size=BATCHSIZE)
trainer = pl.Trainer(
logger=True,
limit_val_batches=PERCENT_VALID_EXAMPLES,
enable_checkpointing=False,
max_epochs=EPOCHS,
gpus=1 if torch.cuda.is_available() else None,
callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
)
hyperparameters = dict(n_layers=n_layers, dropout=dropout, output_dims=output_dims)
trainer.logger.log_hyperparams(hyperparameters)
trainer.fit(model, datamodule=datamodule)
return trainer.callback_metrics["val_acc"].item()
pruning = True
pruner: optuna.pruners.BasePruner = (
optuna.pruners.MedianPruner() if pruning else optuna.pruners.NopPruner()
)
study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=100, timeout=600)
print("Number of finished trials: {}".format(len(study.trials)))
print("Best trial:")
trial = study.best_trial
print(" Value: {}".format(trial.value))
print(" Params: ")
for key, value in trial.params.items():
print(" {}: {}".format(key, value))
## Environment

`requirements.txt`:

```
absl-py==1.3.0
aeppl==0.0.33
aesara==2.7.9
aiohttp==3.8.3
aiosignal==1.3.1
alabaster==0.7.12
albumentations==1.2.1
alembic==1.9.1
altair==4.2.0
appdirs==1.4.4
arviz==0.12.1
astor==0.8.1
astropy==4.3.1
astunparse==1.6.3
async-timeout==4.0.2
atari-py==0.2.9
atomicwrites==1.4.1
attrs==22.2.0
audioread==3.0.0
autograd==1.5
autopage==0.5.1
Babel==2.11.0
backcall==0.2.0
beautifulsoup4==4.6.3
bleach==5.0.1
blis==0.7.9
bokeh==2.3.3
branca==0.6.0
bs4==0.0.1
CacheControl==0.12.11
cachetools==5.2.0
catalogue==2.0.8
certifi==2022.12.7
cffi==1.15.1
cftime==1.6.2
chardet==4.0.0
charset-normalizer==2.1.1
click==7.1.2
cliff==4.1.0
clikit==0.6.2
cloudpickle==1.5.0
cmaes==0.9.0
cmake==3.22.6
cmd2==2.4.2
cmdstanpy==1.0.8
colorcet==3.0.1
colorlog==6.7.0
colorlover==0.3.0
community==1.0.0b1
confection==0.0.3
cons==0.4.5
contextlib2==0.5.5
convertdate==2.4.0
crashtest==0.3.1
crcmod==1.7
cufflinks==0.17.3
cvxopt==1.3.0
cvxpy==1.2.2
cycler==0.11.0
cymem==2.0.7
Cython==0.29.32
daft==0.0.4
dask==2022.2.1
datascience==0.17.5
db-dtypes==1.0.5
debugpy==1.0.0
decorator==4.4.2
defusedxml==0.7.1
descartes==1.1.0
dill==0.3.6
distributed==2022.2.1
dlib==19.24.0
dm-tree==0.1.8
dnspython==2.2.1
docutils==0.17.1
dopamine-rl==1.0.5
earthengine-api==0.1.335
easydict==1.10
ecos==2.0.11
editdistance==0.5.3
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
entrypoints==0.4
ephem==4.1.4
et-xmlfile==1.1.0
etils==0.9.0
etuples==0.3.8
fa2==0.3.5
fastai==2.7.10
fastcore==1.5.27
fastdownload==0.0.7
fastdtw==0.3.4
fastjsonschema==2.16.2
fastprogress==1.0.3
fastrlock==0.8.1
feather-format==0.4.1
filelock==3.8.2
firebase-admin==5.3.0
fix-yahoo-finance==0.0.22
Flask==1.1.4
flatbuffers==1.12
folium==0.12.1.post1
frozenlist==1.3.3
fsspec==2022.11.0
future==0.16.0
gast==0.4.0
GDAL==2.2.3
gdown==4.4.0
gensim==3.6.0
geographiclib==1.52
geopy==1.17.0
gin-config==0.5.0
glob2==0.7
google==2.0.3
google-api-core==2.11.0
google-api-python-client==2.70.0
google-auth==2.15.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-cloud-bigquery==3.4.1
google-cloud-bigquery-storage==2.17.0
google-cloud-core==2.3.2
google-cloud-datastore==2.11.0
google-cloud-firestore==2.7.3
google-cloud-language==2.6.1
google-cloud-storage==2.7.0
google-cloud-translate==3.8.4
google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz
google-crc32c==1.5.0
google-pasta==0.2.0
google-resumable-media==2.4.0
googleapis-common-protos==1.57.0
googledrivedownloader==0.4
graphviz==0.10.1
greenlet==2.0.1
grpcio==1.51.1
grpcio-status==1.48.2
gspread==3.4.2
gspread-dataframe==3.0.8
gym==0.25.2
gym-notices==0.0.8
h5py==3.1.0
HeapDict==1.0.1
hijri-converter==2.2.4
holidays==0.17.2
holoviews==1.14.9
html5lib==1.0.1
httpimport==0.5.18
httplib2==0.17.4
httpstan==4.6.1
humanize==0.5.1
hyperopt==0.1.2
idna==2.10
imageio==2.9.0
imagesize==1.4.1
imbalanced-learn==0.8.1
imblearn==0.0
imgaug==0.4.0
importlib-metadata==4.13.0
importlib-resources==5.10.1
imutils==0.5.4
inflect==2.1.0
intel-openmp==2023.0.0
intervaltree==2.1.0
ipykernel==5.3.4
ipython==7.9.0
ipython-genutils==0.2.0
ipython-sql==0.3.9
ipywidgets==7.7.1
itsdangerous==1.1.0
jax==0.3.25
jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.25+cuda11.cudnn805-cp38-cp38-manylinux2014_x86_64.whl
jieba==0.42.1
Jinja2==2.11.3
joblib==1.2.0
jpeg4py==0.1.4
jsonschema==4.3.3
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter_core==5.1.1
jupyterlab-widgets==3.0.5
kaggle==1.5.12
kapre==0.3.7
keras==2.9.0
Keras-Preprocessing==1.1.2
keras-vis==0.4.1
kiwisolver==1.4.4
korean-lunar-calendar==0.3.1
langcodes==3.3.0
libclang==14.0.6
librosa==0.8.1
lightgbm==2.2.3
lightning-utilities==0.5.0
llvmlite==0.39.1
lmdb==0.99
locket==1.0.0
logical-unification==0.4.5
LunarCalendar==0.0.9
lxml==4.9.2
Mako==1.2.4
Markdown==3.4.1
MarkupSafe==2.0.1
marshmallow==3.19.0
matplotlib==3.2.2
matplotlib-venn==0.11.7
miniKanren==1.0.3
missingno==0.5.1
mistune==0.8.4
mizani==0.7.3
mkl==2019.0
mlxtend==0.14.0
more-itertools==9.0.0
moviepy==0.2.3.5
mpmath==1.2.1
msgpack==1.0.4
multidict==6.0.3
multipledispatch==0.6.0
multitasking==0.0.11
murmurhash==1.0.9
music21==5.5.0
natsort==5.5.0
nbconvert==5.6.1
nbformat==5.7.1
netCDF4==1.6.2
networkx==2.8.8
nibabel==3.0.2
nltk==3.7
notebook==5.7.16
numba==0.56.4
numexpr==2.8.4
numpy==1.21.6
oauth2client==4.1.3
oauthlib==3.2.2
okgrade==0.4.3
opencv-contrib-python==4.6.0.66
opencv-python==4.6.0.66
opencv-python-headless==4.6.0.66
openpyxl==3.0.10
opt-einsum==3.3.0
optuna==3.0.5
osqp==0.6.2.post0
packaging==21.3
palettable==3.3.0
pandas==1.3.5
pandas-datareader==0.9.0
pandas-gbq==0.17.9
pandas-profiling==1.4.1
pandocfilters==1.5.0
panel==0.12.1
param==1.12.3
parso==0.8.3
partd==1.3.0
pastel==0.2.1
pathlib==1.0.1
pathy==0.10.1
patsy==0.5.3
pbr==5.11.0
pep517==0.13.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==7.1.2
pip-tools==6.6.2
platformdirs==2.6.0
plotly==5.5.0
plotnine==0.8.0
pluggy==0.7.1
pooch==1.6.0
portpicker==1.3.9
prefetch-generator==1.0.3
preshed==3.0.8
prettytable==3.5.0
progressbar2==3.38.0
prometheus-client==0.15.0
promise==2.3
prompt-toolkit==2.0.10
prophet==1.1.1
proto-plus==1.22.1
protobuf==3.19.6
psutil==5.4.8
psycopg2==2.9.5
ptyprocess==0.7.0
py==1.11.0
pyarrow==9.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycocotools==2.0.6
pycparser==2.21
pyct==0.4.8
pydantic==1.10.2
pydata-google-auth==1.4.0
pydot==1.3.0
pydot-ng==2.0.0
pydotplus==2.0.2
PyDrive==1.3.1
pyemd==0.5.1
pyerfa==2.0.0.1
Pygments==2.6.1
PyGObject==3.26.1
pylev==1.4.0
pymc==4.1.4
PyMeeus==0.5.12
pymongo==4.3.3
pymystem3==0.2.0
PyOpenGL==3.1.6
pyparsing==3.0.9
pyperclip==1.8.2
pyrsistent==0.19.2
pysimdjson==3.2.0
pysndfile==1.3.8
PySocks==1.7.1
pystan==3.3.0
pytest==3.6.4
python-apt==0.0.0
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==7.0.0
python-utils==3.4.5
pytorch-lightning==1.8.6
pytz==2022.7
pyviz-comms==2.2.1
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==23.2.1
qdldl==0.1.5.post2
qudida==0.0.4
regex==2022.6.2
requests==2.25.1
requests-oauthlib==1.3.1
resampy==0.4.2
rpy2==3.5.5
rsa==4.9
scikit-image==0.18.3
scikit-learn==1.0.2
scipy==1.7.3
screen-resolution-extra==0.0.0
scs==3.2.2
seaborn==0.11.2
Send2Trash==1.8.0
setuptools-git==1.2
shapely==2.0.0
six==1.15.0
sklearn-pandas==1.8.0
smart-open==6.3.0
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.11.0
spacy==3.4.4
spacy-legacy==3.0.10
spacy-loggers==1.0.4
Sphinx==1.8.6
sphinxcontrib-serializinghtml==1.1.5
sphinxcontrib-websupport==1.2.4
SQLAlchemy==1.4.45
sqlparse==0.4.3
srsly==2.4.5
statsmodels==0.12.2
stevedore==4.1.1
sympy==1.7.1
tables==3.7.0
tabulate==0.8.10
tblib==1.7.0
tenacity==8.1.0
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5.1
tensorflow==2.9.2
tensorflow-datasets==4.6.0
tensorflow-estimator==2.9.0
tensorflow-gcs-config==2.9.1
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.29.0
tensorflow-metadata==1.12.0
tensorflow-probability==0.17.0
termcolor==2.1.1
terminado==0.13.3
testpath==0.6.0
text-unidecode==1.3
textblob==0.15.3
thinc==8.1.6
threadpoolctl==3.1.0
tifffile==2022.10.10
toml==0.10.2
tomli==2.0.1
toolz==0.12.0
torch @ https://download.pytorch.org/whl/cu116/torch-1.13.0%2Bcu116-cp38-cp38-linux_x86_64.whl
torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.0%2Bcu116-cp38-cp38-linux_x86_64.whl
torchmetrics==0.11.0
torchsummary==1.5.1
torchtext==0.14.0
torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.0%2Bcu116-cp38-cp38-linux_x86_64.whl
tornado==6.0.4
tqdm==4.64.1
traitlets==5.7.1
tweepy==3.10.0
typeguard==2.7.1
typer==0.7.0
typing_extensions==4.4.0
tzlocal==1.5.1
uritemplate==4.1.1
urllib3==1.24.3
vega-datasets==0.9.0
wasabi==0.10.1
wcwidth==0.2.5
webargs==8.2.0
webencodings==0.5.1
Werkzeug==1.0.1
widgetsnbextension==3.6.1
wordcloud==1.8.2.2
wrapt==1.14.1
xarray==2022.12.0
xarray-einstats==0.4.0
xgboost==0.90
xkit==0.0.0
xlrd==1.2.0
xlwt==1.3.0
yarl==1.8.2
yellowbrick==1.5
zict==2.2.0
zipp==3.11.0
```
## Error messages, stack traces, or logs

```
[W 2023-01-03 18:30:22,015] Trial 0 failed because of the following error: RuntimeError('The `on_init_start` callback hook was deprecated in v1.6 and is no longer supported as of v1.8.')
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "<ipython-input-3-b120a10c1659>", line 138, in objective
    trainer.fit(model, datamodule=datamodule)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1024, in _run
    verify_loop_configurations(self)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py", line 53, in verify_loop_configurations
    _check_deprecated_callback_hooks(trainer)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py", line 221, in _check_deprecated_callback_hooks
    raise RuntimeError(
RuntimeError: The `on_init_start` callback hook was deprecated in v1.6 and is no longer supported as of v1.8.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-b120a10c1659> in <module>
    146 
    147 study = optuna.create_study(direction="maximize", pruner=pruner)
--> 148 study.optimize(objective, n_trials=100, timeout=600)
    149 
    150 print("Number of finished trials: {}".format(len(study.trials)))

11 frames
/usr/local/lib/python3.8/dist-packages/optuna/study/study.py in optimize(self, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)
    417         """
    418 
--> 419         _optimize(
    420             study=self,
    421             func=func,

/usr/local/lib/python3.8/dist-packages/optuna/study/_optimize.py in _optimize(study, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)
     64     try:
     65         if n_jobs == 1:
---> 66             _optimize_sequential(
     67                 study,
     68                 func,

/usr/local/lib/python3.8/dist-packages/optuna/study/_optimize.py in _optimize_sequential(study, func, n_trials, timeout, catch, callbacks, gc_after_trial, reseed_sampler_rng, time_start, progress_bar)
    158 
    159     try:
--> 160         frozen_trial = _run_trial(study, func, catch)
    161     finally:
    162         # The following line mitigates memory problems that can be occurred in some

/usr/local/lib/python3.8/dist-packages/optuna/study/_optimize.py in _run_trial(study, func, catch)
    232         and not isinstance(func_err, catch)
    233     ):
--> 234         raise func_err
    235     return frozen_trial
    236 

/usr/local/lib/python3.8/dist-packages/optuna/study/_optimize.py in _run_trial(study, func, catch)
    194     with get_heartbeat_thread(trial._trial_id, study._storage):
    195         try:
--> 196             value_or_values = func(trial)
    197         except exceptions.TrialPruned as e:
    198             # TODO(mamu): Handle multi-objective cases.

<ipython-input-3-b120a10c1659> in objective(trial)
    136     hyperparameters = dict(n_layers=n_layers, dropout=dropout, output_dims=output_dims)
    137     trainer.logger.log_hyperparams(hyperparameters)
--> 138     trainer.fit(model, datamodule=datamodule)
    139 
    140     return trainer.callback_metrics["val_acc"].item()

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    601             raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
    602         self.strategy._lightning_module = model
--> 603         call._call_and_handle_interrupt(
    604             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    605         )

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     36             return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     37         else:
---> 38             return trainer_fn(*args, **kwargs)
     39 
     40     except _TunerExitException:

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    643             model_connected=self.lightning_module is not None,
    644         )
--> 645         self._run(model, ckpt_path=self.ckpt_path)
    646 
    647         assert self.state.stopped

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1022         self._callback_connector._attach_model_logging_functions()
   1023 
-> 1024         verify_loop_configurations(self)
   1025 
   1026         # hook

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py in verify_loop_configurations(trainer)
     51     __verify_batch_transfer_support(trainer)
     52     # TODO: Delete this check in v2.0
---> 53     _check_deprecated_callback_hooks(trainer)
     54     # TODO: Delete this check in v2.0
     55     _check_on_epoch_start_end(model)

/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py in _check_deprecated_callback_hooks(trainer)
    219     for callback in trainer.callbacks:
    220         if callable(getattr(callback, "on_init_start", None)):
--> 221             raise RuntimeError(
    222                 "The `on_init_start` callback hook was deprecated in v1.6 and is no longer supported as of v1.8."
    223             )

RuntimeError: The `on_init_start` callback hook was deprecated in v1.6 and is no longer supported as of v1.8.
```
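The validator frame at the bottom of the trace shows what triggers the error: PyTorch Lightning 1.8 raises for any callback that still defines `on_init_start`. The minimal sketch below (with a hypothetical callback name) reproduces the same `RuntimeError` without Optuna; since `PyTorchLightningPruningCallback` is the only user-supplied callback in the example, it is presumably the one defining the hook:

```python
import pytorch_lightning as pl


class LegacyHookCallback(pl.Callback):
    # Merely defining this hook is enough: PL 1.8's configuration validator
    # checks `callable(getattr(callback, "on_init_start", None))` and raises.
    def on_init_start(self, trainer: "pl.Trainer") -> None:
        pass


trainer = pl.Trainer(callbacks=[LegacyHookCallback()], max_epochs=1)
# Any call that reaches verify_loop_configurations, e.g. trainer.fit(model),
# then raises: RuntimeError: The `on_init_start` callback hook was deprecated
# in v1.6 and is no longer supported as of v1.8.
```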
## Steps to reproduce
1. Run the example code above with `optuna==3.0.5` and `pytorch-lightning==1.8.6` (the versions in the environment listed above), e.g. in Google Colab.
## Reproducible examples (optional)
See above.
## Additional context (optional)
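Not part of the original report, but two stopgaps that appear consistent with the trace. The straightforward one is pinning the earlier release, `pip install "pytorch-lightning<1.8"` (or upgrading Optuna once a release supports PL 1.8). Alternatively, the sketch below strips the offending hook from Optuna's callback class before the `Trainer` is built; this assumes the hook is not essential for single-process runs (it may well matter for distributed training), so treat it as a hack rather than a fix:

```python
from optuna.integration import PyTorchLightningPruningCallback

# Hack, not a fix: if this optuna version defines the removed hook directly
# on the class, delete it so PL 1.8's validator no longer sees it. Other
# hooks that rely on state set in `on_init_start` could break, especially
# under DDP.
if "on_init_start" in vars(PyTorchLightningPruningCallback):
    del PyTorchLightningPruningCallback.on_init_start
```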