from __future__ import annotations
from typing import Any
import optuna
from optuna_integration._imports import try_import
with try_import() as _imports:
from skorch.callbacks import Callback
from skorch.net import NeuralNet
if not _imports.is_successful():
Callback = object # NOQA
[docs]
class SkorchPruningCallback(Callback):
    """Skorch callback that stops a trial early when Optuna decides to prune it.

    At the end of every epoch the monitored metric is read from the most
    recent row of ``net.history`` and reported to the trial. If the trial
    then signals that it should be pruned, :class:`optuna.TrialPruned` is
    raised.

    .. versionadded:: 2.1.0

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current
            evaluation of the objective function.
        monitor:
            Name of the evaluation metric used for pruning, e.g. ``val_loss``
            or ``val_acc``. The value is looked up in ``net.history``, so the
            valid names depend on which keys that history records.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        # Validate the deferred skorch import first; ``check()`` is expected
        # to raise when the optional dependency could not be imported.
        _imports.check()

        super().__init__()
        self._trial = trial
        self._monitor = monitor

    def on_epoch_end(self, net: "NeuralNet", **kwargs: Any) -> None:
        """Report the latest monitored value and prune the trial if requested.

        Args:
            net:
                The network being trained; only its ``history`` is read.
            **kwargs:
                Additional callback arguments; accepted but not used.

        Raises:
            optuna.TrialPruned:
                If the trial reports that it should be pruned after the
                current value has been recorded.
        """
        records = net.history
        if not records:
            # No epoch has been recorded yet, so there is nothing to report.
            return

        # The reporting step is the zero-based index of the latest history row.
        last_epoch = len(records) - 1
        self._trial.report(records[-1, self._monitor], last_epoch)

        if not self._trial.should_prune():
            return
        raise optuna.TrialPruned("Trial was pruned at epoch {}.".format(last_epoch))