optuna_integration.PyTorchLightningPruningCallback
- class optuna_integration.PyTorchLightningPruningCallback(trial, monitor)[source]
PyTorch Lightning callback to prune unpromising trials.
See the example if you want to add a pruning callback which observes accuracy.
- Parameters:
  - trial (optuna.trial.Trial) – A Trial corresponding to the current evaluation of the objective function.
  - monitor (str) – An evaluation metric for pruning, e.g., val_loss or val_acc. The metrics are obtained from the dictionaries returned by, e.g., lightning.pytorch.LightningModule.training_step or lightning.pytorch.LightningModule.validation_epoch_end, so the names depend on how these dictionaries are formatted.
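For instance, a minimal sketch of observing val_loss with the callback might look as follows. The toy Net module, the random dataset, and the hidden hyperparameter are illustrative assumptions, not part of the API:

```python
import lightning.pytorch as pl
import optuna
import torch
from optuna_integration import PyTorchLightningPruningCallback
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class Net(pl.LightningModule):
    """Toy classifier; stands in for any LightningModule."""

    def __init__(self, hidden):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 2))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.model(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # The logged name must match the ``monitor`` argument of the callback.
        self.log("val_loss", nn.functional.cross_entropy(self.model(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


def objective(trial):
    # Random data only so the sketch runs end to end.
    data = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
    loader = DataLoader(data, batch_size=16)
    model = Net(hidden=trial.suggest_int("hidden", 4, 64))
    trainer = pl.Trainer(
        max_epochs=5,
        enable_checkpointing=False,
        logger=False,
        # The callback reports val_loss to the trial after every validation
        # epoch and raises optuna.TrialPruned if the pruner decides to stop.
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_loss")],
    )
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)
    return trainer.callback_metrics["val_loss"].item()


study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
```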
Note
For distributed data parallel training, the version of PyTorch Lightning needs to be v1.6.0 or higher. In addition, Study should be instantiated with RDB storage.
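A hedged sketch of creating such a study; the SQLite URL and study name are illustrative placeholders, and any RDB backend supported by Optuna works:

```python
import optuna

# RDB-backed study, required for pruning under distributed data parallel
# training. The storage URL and study name below are placeholders.
study = optuna.create_study(
    storage="sqlite:///example.db",
    study_name="ddp-pruning",
    load_if_exists=True,
    pruner=optuna.pruners.MedianPruner(),
)
```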
Note
If you would like to use PyTorchLightningPruningCallback in a distributed training environment, you need to invoke PyTorchLightningPruningCallback.check_pruned() manually so that TrialPruned is properly handled.
Methods
- check_pruned() – Raise optuna.TrialPruned manually if pruned.
- on_fit_start(trainer, pl_module)
- on_validation_end(trainer, pl_module)
(trainer, pl_module)- check_pruned()[source]
Raise optuna.TrialPruned manually if pruned.
Currently, intermediate_values are not properly propagated between processes due to the storage cache. Therefore, the necessary information is kept in trial_system_attrs when the trial runs in a distributed situation. Please call this method right after calling lightning.pytorch.Trainer.fit(). If a callback doesn’t have any backend storage for DDP, this method does nothing.
- Return type: None
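A hedged sketch of a distributed objective under these assumptions: the toy Net and loader from the earlier sketch, a study backed by RDB storage as noted above, and an illustrative ddp_spawn strategy with two devices:

```python
import lightning.pytorch as pl
from optuna_integration import PyTorchLightningPruningCallback


def objective(trial):
    # ``Net`` and ``loader`` are the toy module and dataloader from the
    # sketch above; any LightningModule and dataloaders work the same way.
    model = Net(hidden=trial.suggest_int("hidden", 4, 64))
    callback = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    trainer = pl.Trainer(
        max_epochs=5,
        devices=2,
        strategy="ddp_spawn",  # illustrative multi-process strategy
        callbacks=[callback],
    )
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)
    # Under DDP the pruning signal is kept in trial_system_attrs rather than
    # raised inside the worker processes, so re-raise optuna.TrialPruned here,
    # right after fit() returns.
    callback.check_pruned()
    return trainer.callback_metrics["val_loss"].item()
```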