Source code for optuna_integration.tfkeras.tfkeras

from __future__ import annotations

from typing import Any
import warnings

import optuna

from optuna_integration._imports import try_import

# Import the Keras base class through ``try_import`` so that a missing
# TensorFlow installation does not break importing this module; success is
# queried below via ``_imports.is_successful()``.
with try_import() as _imports:
    from tensorflow.keras.callbacks import Callback

# Fallback base class so the callback class definition in this module can
# still be created without TensorFlow. The missing dependency is surfaced
# later, when the callback's constructor calls ``_imports.check()``
# (presumably raising ImportError — confirm in optuna_integration._imports).
if not _imports.is_successful():
    Callback = object  # NOQA

class TFKerasPruningCallback(Callback):
    """tf.keras callback to prune unpromising trials.

    This callback is intended to be compatible with TensorFlow v1 and v2,
    but it is only tested with TensorFlow v2.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    tfkeras/tfkeras_integration.py>`__
    if you want to add a pruning callback which observes the validation accuracy.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or ``val_acc``.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()

        # Fail fast if TensorFlow could not be imported; in that case the base
        # class is only an ``object`` placeholder.
        _imports.check()

        self._trial = trial
        self._monitor = monitor

    def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
        """Report the monitored metric for ``epoch`` and prune the trial if needed.

        Args:
            epoch:
                Index of the epoch that just finished; used as the report step.
            logs:
                Metric values for the finished epoch, keyed by metric name.
                ``None`` is treated as an empty mapping.

        Raises:
            optuna.TrialPruned:
                If the trial should be pruned after the score has been reported.
        """
        logs = logs or {}
        current_score = logs.get(self._monitor)

        if current_score is None:
            # Nothing to report this epoch: warn instead of failing so that
            # training continues.
            message = (
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self._monitor)
            )
            warnings.warn(message)
            return

        # Report current score and epoch to Optuna's trial. ``float`` normalizes
        # the logged value (presumably a NumPy/tensor scalar) to a Python number.
        self._trial.report(float(current_score), step=epoch)

        # Prune trial if needed.
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)