Pruning Not Working in Pytorch
Serendipity31 opened this issue · 1 comments
I am trying to use Optuna to tune hyperparameters for a DeepAR model. I'm using the PyTorch-based DeepAREstimator published by gluonTS (info here). Additionally, I have built the objective function according to the gluonTS tutorial (here).
I've done various tests, including some with 10 or 20 trials, and in none of my tests have any trials been pruned. Optuna appears to work in that trials are completed and hyperparameter values are returned. It's just that none of the trials are ever pruned.
I have tried two ways of getting pruning to work. In the first, I use the pruner argument of optuna.create_study(), defining the pruner just before creating the study:
pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(direction="minimize", pruner=pruner)
In the second, I use optuna.integration.PyTorchLightningPruningCallback() in the objective function. It is defined before the DeepAREstimator, and then the callback is passed via the trainer_kwargs argument, which is forwarded to the PyTorch Lightning trainer. I have also tried both approaches simultaneously.
class DeepARTuningObjective:
    # ... for brevity's sake I'm excluding everything in this class definition
    # above __call__, but you can find it in the link above
    def __call__(self, trial):
        params = self.get_params(trial)
        pruning_callback = optuna.integration.PyTorchLightningPruningCallback(trial, "val_loss")
        estimator = DeepAREstimator(
            # Basic parameters
            freq=self.freq,
            prediction_length=self.prediction_length,
            context_length=self.prediction_length,
            # Hyperparameters to be tuned
            num_layers=params["num_layers"],
            hidden_size=params["hidden_size"],
            lr=params["learning_rate"],
            weight_decay=params["weight_decay"],
            dropout_rate=params["dropout_rate"],
            batch_size=params["batch_size"],
            lags_seq=self.lag_options[params["lag_options_index"]],
            # Exogenous variables
            num_feat_dynamic_real=self.valid_pickle.num_feat_dynamic_real,
            num_feat_static_cat=self.valid_pickle.num_feat_static_cat,
            cardinality=self.valid_pickle.static_cardinalities.tolist(),
            # Kwargs forwarded to pl.Trainer, including the pruning callback
            # and a tuned hyperparameter
            trainer_kwargs={
                "max_epochs": params["max_epochs"] + 1,
                "deterministic": True,
                "callbacks": [pruning_callback],
            },
        )
        predictor = estimator.train(self.train, cache_data=True)
        forecast_it = predictor.predict(self.validation_input)
        forecasts = list(forecast_it)
        evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])
        agg_metrics, item_metrics = evaluator(
            self.validation_label, forecasts, num_series=len(self.valid_pickle)
        )
        return agg_metrics[self.metric_type]  # agg_metrics is a dict; returning it directly errors
Expected behavior
In calls to optimize() with numerous trials, I would expect at least some trials to be pruned. Because I'm not seeing this, I'm left wondering how to get pruning to work in this case.
If someone could troubleshoot this and/or tell me where I am going wrong, I would really appreciate it! I would very much like to be able to take advantage of the pruning functionality within Optuna.
Environment
- Optuna version: 3.4.0
- Python version: 3.11.5
- OS: Windows 10 Enterprise
- gluonTS version: 0.13.7
- lightning version: 2.1.0
- pytorch-lightning version: 2.1.0
- torch version: 2.1.0
Steps to reproduce
- Create the objective function as per the gluonTS tutorial. I replaced 'dataset' with 'valid_pickle' because my dataset is saved as a pickle, and my prediction_length is 1, but otherwise it's basically the same.
Approach 1
2. In an interactive environment, create the pruner as shown above. Then create the study as shown above. Then call study.optimize() with n_trials > 3.
Approach 2
2. Modify the class from step 1. Create the pruning callback as per the example code above and pass it via the trainer_kwargs argument of the estimator.
3. Create the study and call study.optimize() with n_trials > 3.
Since optuna/optuna#5068 is self-contained, let me close this one.