Source code for optuna_integration.allennlp._executor

from __future__ import annotations

import json
import os
from typing import Any
import warnings

import optuna
from optuna import TrialPruned
from optuna._deprecated import deprecated_class

from optuna_integration._imports import try_import
from optuna_integration.allennlp._environment import _environment_variables
from optuna_integration.allennlp._variables import _VariableManager
from optuna_integration.allennlp._variables import OPTUNA_ALLENNLP_DISTRIBUTED_FLAG


with try_import() as _imports:
    import allennlp
    import allennlp.commands
    import allennlp.common.cached_transformers
    import allennlp.common.util

# _jsonnet, psutil, and torch are conditionally imported here because, like
# allennlp, they may be unavailable in the environment that builds the
# documentation.
if _imports.is_successful():
    import _jsonnet
    import psutil
    from torch.multiprocessing.spawn import ProcessRaisedException
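
# --- Note (illustrative; not part of this module) -----------------------------
# ``try_import`` defers an ImportError instead of raising it, so this file can
# be imported even when allennlp is missing (e.g. while building documentation).
# The error surfaces only when ``_imports.check()`` is called, as in
# ``AllenNLPExecutor.__init__`` below:
#
#     with try_import() as _imports:
#         import allennlp  # failure is recorded, not raised
#
#     _imports.is_successful()  # -> False if the import failed
#     _imports.check()          # raises the recorded ImportError at use time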


def _fetch_pruner_config(trial: optuna.Trial) -> dict[str, Any]:
    pruner = trial.study.pruner
    kwargs: dict[str, Any] = {}

    if isinstance(pruner, optuna.pruners.HyperbandPruner):
        kwargs["min_resource"] = pruner._min_resource
        kwargs["max_resource"] = pruner._max_resource
        kwargs["reduction_factor"] = pruner._reduction_factor

    elif isinstance(pruner, optuna.pruners.MedianPruner):
        kwargs["n_startup_trials"] = pruner._n_startup_trials
        kwargs["n_warmup_steps"] = pruner._n_warmup_steps
        kwargs["interval_steps"] = pruner._interval_steps

    elif isinstance(pruner, optuna.pruners.PercentilePruner):
        kwargs["percentile"] = pruner._percentile
        kwargs["n_startup_trials"] = pruner._n_startup_trials
        kwargs["n_warmup_steps"] = pruner._n_warmup_steps
        kwargs["interval_steps"] = pruner._interval_steps

    elif isinstance(pruner, optuna.pruners.SuccessiveHalvingPruner):
        kwargs["min_resource"] = pruner._min_resource
        kwargs["reduction_factor"] = pruner._reduction_factor
        kwargs["min_early_stopping_rate"] = pruner._min_early_stopping_rate

    elif isinstance(pruner, optuna.pruners.ThresholdPruner):
        kwargs["lower"] = pruner._lower
        kwargs["upper"] = pruner._upper
        kwargs["n_warmup_steps"] = pruner._n_warmup_steps
        kwargs["interval_steps"] = pruner._interval_steps
    elif isinstance(pruner, optuna.pruners.NopPruner):
        pass
    else:
        raise ValueError("Unsupported pruner is specified: {}".format(type(pruner)))

    return kwargs
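
# --- Usage sketch (illustrative; not part of this module) ---------------------
# _fetch_pruner_config mirrors the constructor arguments of the study's pruner
# so that the training process can rebuild an equivalent pruner. For example,
# with a MedianPruner:
#
#     import optuna
#
#     study = optuna.create_study(
#         pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)
#     )
#     trial = study.ask()
#     _fetch_pruner_config(trial)
#     # -> {"n_startup_trials": 5, "n_warmup_steps": 10, "interval_steps": 1}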


@deprecated_class("3.5.0", "5.0.0")
class AllenNLPExecutor:
    """AllenNLP extension to use Optuna with a Jsonnet config file.

    See the example of an `objective function
    <https://github.com/optuna/optuna-examples/tree/main/allennlp/allennlp_jsonnet.py>`_.

    You can also see the tutorial of our AllenNLP integration on
    `AllenNLP Guide <https://guide.allennlp.org/hyperparameter-optimization>`_.

    .. note::
        From Optuna v2.1.0, users have to cast their parameters by using methods in Jsonnet.
        Call ``std.parseInt`` for integers, or ``std.parseJson`` for floating-point numbers.
        Please see the `example configuration
        <https://github.com/optuna/optuna-examples/tree/main/allennlp/classifier.jsonnet>`_.

    .. note::
        In :class:`~optuna_integration.AllenNLPExecutor`,
        you can pass parameters to AllenNLP by either defining a search space using
        Optuna suggest methods or setting environment variables just like AllenNLP CLI.
        If a value is set in both a search space in Optuna and the environment variables,
        the executor uses the value specified in the search space in Optuna.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation
            of the objective function.
        config_file:
            A config file for AllenNLP.
            Hyperparameters should be masked with ``std.extVar``.
            Please refer to `the config example <https://github.com/allenai/allentune/blob/
            master/examples/classifier.jsonnet>`_.
        serialization_dir:
            A path to which model weights and logs are saved.
        metrics:
            An evaluation metric. `GradientDescentTrainer.train() <https://docs.allennlp.org/
            main/api/training/gradient_descent_trainer/#train>`_ of AllenNLP returns
            a dictionary containing metrics after training.
            :class:`~optuna_integration.AllenNLPExecutor` accesses the dictionary
            by the key ``metrics`` you specify and uses it as an objective value.
        force:
            If :obj:`True`, an executor overwrites the output directory if it exists.
        file_friendly_logging:
            If :obj:`True`, tqdm status is printed on separate lines and slows tqdm
            refresh rate.
        include_package:
            Additional packages to include.
            For more information, please see
            `AllenNLP documentation <https://docs.allennlp.org/master/api/commands/train/>`_.
""" def __init__( self, trial: optuna.Trial, config_file: str, serialization_dir: str, metrics: str = "best_validation_accuracy", *, include_package: str | list[str] | None = None, force: bool = False, file_friendly_logging: bool = False, ): _imports.check() self._params = trial.params self._config_file = config_file self._serialization_dir = serialization_dir self._metrics = metrics self._force = force self._file_friendly_logging = file_friendly_logging if include_package is None: include_package = [] if isinstance(include_package, str): include_package = [include_package] self._include_package = include_package + ["optuna_integration.allennlp"] storage = trial.study._storage if isinstance(storage, optuna.storages.RDBStorage): url = storage.url elif isinstance(storage, optuna.storages._CachedStorage): assert isinstance(storage._backend, optuna.storages.RDBStorage) url = storage._backend.url else: url = "" target_pid = psutil.Process().ppid() variable_manager = _VariableManager(target_pid) pruner_kwargs = _fetch_pruner_config(trial) variable_manager.set_value("study_name", trial.study.study_name) variable_manager.set_value("trial_id", trial._trial_id) variable_manager.set_value("storage_name", url) variable_manager.set_value("monitor", metrics) if trial.study.pruner is not None: variable_manager.set_value("pruner_class", type(trial.study.pruner).__name__) variable_manager.set_value("pruner_kwargs", pruner_kwargs) def _build_params(self) -> dict[str, Any]: """Create a dict of params for AllenNLP. _build_params is based on allentune's ``train_func``. For more detail, please refer to https://github.com/allenai/allentune/blob/master/allentune/modules/allennlp_runner.py#L34-L65 """ params = _environment_variables() params.update({key: str(value) for key, value in self._params.items()}) return json.loads(_jsonnet.evaluate_file(self._config_file, ext_vars=params)) def _set_environment_variables(self) -> None: for key, value in _environment_variables().items(): if key is None: continue os.environ[key] = value
    def run(self) -> float:
        """Train a model using AllenNLP."""
        for package_name in self._include_package:
            allennlp.common.util.import_module_and_submodules(package_name)

        # Without the following lines, the transformer model construction only takes place in
        # the first trial (which would consume some random numbers), and the cached model will
        # be used in trials afterwards (which would not consume random numbers), leading to
        # inconsistent results between single trial and multiple trials. To make results
        # reproducible in multiple trials, we clear the cache before each trial.
        # TODO(MagiaSN) When AllenNLP has introduced a better API to do this, one should remove
        # these lines and use the new API instead. For example, use the `_clear_caches()` method
        # which will be in the next AllenNLP release after 2.4.0.
        allennlp.common.cached_transformers._model_cache.clear()
        allennlp.common.cached_transformers._tokenizer_cache.clear()

        self._set_environment_variables()
        params = allennlp.common.params.Params(self._build_params())

        if "distributed" in params:
            if OPTUNA_ALLENNLP_DISTRIBUTED_FLAG in os.environ:
                warnings.warn(
                    "Another process may already exist."
                    " If you have trouble, please unset the environment"
                    " variable `OPTUNA_ALLENNLP_USE_DISTRIBUTED`"
                    " and try again."
                )
            os.environ[OPTUNA_ALLENNLP_DISTRIBUTED_FLAG] = "1"

        try:
            allennlp.commands.train.train_model(
                params=params,
                serialization_dir=self._serialization_dir,
                file_friendly_logging=self._file_friendly_logging,
                force=self._force,
                include_package=self._include_package,
            )
        except ProcessRaisedException as e:
            if "raise TrialPruned()" in str(e):
                raise TrialPruned()

        with open(os.path.join(self._serialization_dir, "metrics.json")) as f:
            metrics = json.load(f)
        return metrics[self._metrics]
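
# --- Usage sketch (illustrative; not part of this module) ---------------------
# A minimal objective function built on AllenNLPExecutor, per the docstring
# above. The hyperparameter names ("lr", "dropout") are assumptions; they must
# match the ``std.extVar`` keys in your Jsonnet config, e.g.:
#
#     // classifier.jsonnet (fragment, per the casting note above):
#     //   local lr = std.parseJson(std.extVar('lr'));
#     //   local dropout = std.parseJson(std.extVar('dropout'));
#
#     import optuna
#     from optuna_integration import AllenNLPExecutor
#
#     def objective(trial: optuna.Trial) -> float:
#         trial.suggest_float("lr", 1e-4, 1e-1, log=True)
#         trial.suggest_float("dropout", 0.0, 0.5)
#         executor = AllenNLPExecutor(
#             trial,
#             config_file="classifier.jsonnet",
#             serialization_dir=f"result/{trial.number}",
#             metrics="best_validation_accuracy",
#         )
#         return executor.run()
#
#     study = optuna.create_study(direction="maximize")
#     study.optimize(objective, n_trials=20)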