enable_checkpointing = False results in MisconfigurationException
Description
I would like to train a `TemporalFusionTransformerEstimator` and set `trainer_kwargs["enable_checkpointing"] = False`. However, on line 196 of `torch/model/estimator.py` a `ModelCheckpoint` is nevertheless created and added to the list of callbacks on line 204. This results in the following error:

lightning.fabric.utilities.exceptions.MisconfigurationException: Trainer was configured with `enable_checkpointing=False` but found `ModelCheckpoint` in callbacks list.
I need to disable checkpoints, because I would like to run this on Snowflake which does not allow writing to a filesystem. Workaround ideas for now would be highly appreciated.
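For reference, the same exception can be reproduced with Lightning alone, independent of GluonTS. The sketch below (not part of the original report) combines the disabled flag with an explicit ModelCheckpoint, which is exactly the combination that estimator.py ends up constructing:

import lightning.pytorch as pl

# Lightning rejects this combination at Trainer construction time:
# checkpointing is disabled, yet a ModelCheckpoint callback is passed in.
trainer = pl.Trainer(
    enable_checkpointing=False,
    callbacks=[pl.callbacks.ModelCheckpoint()],
)  # raises MisconfigurationException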
Edit
Snowflake would allow me to write to `/tmp/checkpoints`, but it seems impossible to set the `dirpath` of the `ModelCheckpoint` created in `estimator.py` on lines 195-198:
monitor = "train_loss" if validation_data is None else "val_loss"
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor=monitor, mode="min", verbose=True
)
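Because the `ModelCheckpoint` above is created without an explicit `dirpath`, Lightning should resolve its save location from the Trainer (its `default_root_dir`, or the logger's save directory, both of which default to the current working directory). A possible workaround, sketched below and untested on Snowflake, is therefore to leave `enable_checkpointing` at its default and instead point `default_root_dir` at a writable location such as `/tmp/checkpoints` via `trainer_kwargs`, which the estimator passes through to `pl.Trainer`:

from gluonts.torch import TemporalFusionTransformerEstimator

# Sketch of a possible workaround: redirect all Trainer artifacts (logs and
# checkpoints) to a writable directory instead of disabling checkpointing.
estimator = TemporalFusionTransformerEstimator(
    freq="M",
    prediction_length=1,
    trainer_kwargs={"default_root_dir": "/tmp/checkpoints"},
)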
To Reproduce
import pandas as pd
from gluonts.torch import TemporalFusionTransformerEstimator
from gluonts.dataset.pandas import PandasDataset
data = {
    "item_id": [1, 1, 1],
    "ts": ['2024-01-01', '2024-02-01', '2024-03-01'],
    "target": [1, 2, 3],
}
ds = PandasDataset.from_long_dataframe(pd.DataFrame(data), target="target", item_id='item_id', timestamp='ts')
trainer_kwargs = {'enable_checkpointing': False}
estimator = TemporalFusionTransformerEstimator(freq='M', prediction_length=1, trainer_kwargs=trainer_kwargs)
predictor = estimator.train(training_data=ds)
Error message or code output
Traceback (most recent call last):
  File "C:\projects\test\venv\test.py", line 15, in <module>
    predictor = estimator.train(training_data=ds)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\gluonts\torch\model\estimator.py", line 246, in train
    return self.train_model(
           ^^^^^^^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\gluonts\torch\model\estimator.py", line 201, in train_model
    trainer = pl.Trainer(
              ^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\utilities\argparse.py", line 70, in insert_env_defaults
    return fn(self, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\trainer\trainer.py", line 431, in __init__
    self._callback_connector.on_trainer_init(
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\trainer\connectors\callback_connector.py", line 66, in on_trainer_init
    self._configure_checkpoint_callbacks(enable_checkpointing)
  File "C:\projects\test\venv\Lib\site-packages\lightning\pytorch\trainer\connectors\callback_connector.py", line 88, in _configure_checkpoint_callbacks
    raise MisconfigurationException(
lightning.fabric.utilities.exceptions.MisconfigurationException: Trainer was configured with `enable_checkpointing=False` but found `ModelCheckpoint` in callbacks list.
Environment
- Operating system: Windows
- Python version: 3.12
- GluonTS version: 0.15.1