Source code for optuna_integration._lightgbm_tuner._train

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from optuna._imports import try_import
from optuna.study import Study
from optuna.trial import FrozenTrial

from optuna_integration._lightgbm_tuner.optimize import _imports
from optuna_integration._lightgbm_tuner.optimize import LightGBMTuner

with try_import():
    import lightgbm as lgb

def train(
    params: dict[str, Any],
    train_set: "lgb.Dataset",
    num_boost_round: int = 1000,
    valid_sets: list["lgb.Dataset"] | tuple["lgb.Dataset", ...] | "lgb.Dataset" | None = None,
    valid_names: Any | None = None,
    feval: Callable[..., Any] | None = None,
    feature_name: str = "auto",
    categorical_feature: str = "auto",
    keep_training_booster: bool = False,
    callbacks: list[Callable[..., Any]] | None = None,
    time_budget: int | None = None,
    sample_size: int | None = None,
    study: Study | None = None,
    optuna_callbacks: list[Callable[[Study, FrozenTrial], None]] | None = None,
    model_dir: str | None = None,
    *,
    show_progress_bar: bool = True,
    optuna_seed: int | None = None,
) -> "lgb.Booster":
    """Wrapper of LightGBM Training API to tune hyperparameters.

    It optimizes the following hyperparameters in a stepwise manner:
    ``lambda_l1``, ``lambda_l2``, ``num_leaves``, ``feature_fraction``, ``bagging_fraction``,
    ``bagging_freq`` and ``min_child_samples``. It is a drop-in replacement for
    `lightgbm.train()`_. See
    `a simple example of LightGBM Tuner <https://github.com/optuna/optuna-examples/blob/main/
    lightgbm/lightgbm_tuner_simple.py>`_ which optimizes the validation log loss of cancer
    detection.

    :func:`~optuna_integration.lightgbm.train` is a wrapper function of
    :class:`~optuna_integration.lightgbm.LightGBMTuner`. To use features in Optuna such as
    suspended/resumed optimization and/or parallelization, refer to
    :class:`~optuna_integration.lightgbm.LightGBMTuner` instead of this function.

    .. note::
        Arguments and keyword arguments for `lightgbm.train()`_ can be passed.
        For ``params``, please check `the official documentation for LightGBM
        <https://lightgbm.readthedocs.io/en/latest/Parameters.html>`_.

    Args:
        params, train_set, num_boost_round, valid_sets, valid_names, feval, feature_name,
        categorical_feature, keep_training_booster, callbacks, sample_size:
            Forwarded unchanged to :class:`~optuna_integration.lightgbm.LightGBMTuner`;
            they correspond to the arguments of `lightgbm.train()`_.
        time_budget:
            A time budget for parameter tuning in seconds.
        study:
            A :class:`~optuna.study.Study` instance to store optimization results. The
            :class:`~optuna.trial.Trial` instances in it have the following user attributes:
            ``elapsed_secs`` is the elapsed time since the optimization starts.
            ``average_iteration_time`` is the average time of iteration to train the booster
            model in the trial. ``lgbm_params`` is a JSON-serialized dictionary of LightGBM
            parameters used in the trial.
        optuna_callbacks:
            List of Optuna callback functions that are invoked at the end of each trial.
            Each function must accept two parameters with the following types in this order:
            :class:`~optuna.study.Study` and :class:`~optuna.trial.FrozenTrial`.
            Please note that this is not a ``callbacks`` argument of `lightgbm.train()`_ .
        model_dir:
            A directory to save boosters. By default, it is set to :obj:`None` and no boosters
            are saved. Please set shared directory (e.g., directories on NFS) if you want to
            access :meth:`~optuna_integration.lightgbm.LightGBMTuner.get_best_booster` in
            distributed environments. Otherwise, it may raise :obj:`ValueError`. If the
            directory does not exist, it will be created. The filenames of the boosters will be
            ``{model_dir}/{trial_number}.pkl`` (e.g., ``./boosters/0.pkl``).
        show_progress_bar:
            Flag to show progress bars or not. To disable progress bar, set this :obj:`False`.

            .. note::
                Progress bars will be fragmented by logging messages of LightGBM and Optuna.
                Please suppress such messages to show the progress bars properly.
        optuna_seed:
            ``seed`` of :class:`~optuna.samplers.TPESampler` for random number generator
            that affects sampling for ``num_leaves``, ``bagging_fraction``, ``bagging_freq``,
            ``lambda_l1``, and ``lambda_l2``.

            .. note::
                The `deterministic`_ parameter of LightGBM makes training reproducible.
                Please enable it when you use this argument.

    Returns:
        The booster reported by
        :meth:`~optuna_integration.lightgbm.LightGBMTuner.get_best_booster` after the
        tuning run has finished.

    .. _lightgbm.train(): https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.train.html
    .. _LightGBM's verbosity: https://lightgbm.readthedocs.io/en/latest/Parameters.html#verbosity
    .. _deterministic: https://lightgbm.readthedocs.io/en/latest/Parameters.html#deterministic
    """
    # Fail early with an informative error if the optional lightgbm dependency is missing
    # (the module-level import is wrapped in ``try_import``).
    _imports.check()

    auto_booster = LightGBMTuner(
        params=params,
        train_set=train_set,
        num_boost_round=num_boost_round,
        valid_sets=valid_sets,
        valid_names=valid_names,
        feval=feval,
        feature_name=feature_name,
        categorical_feature=categorical_feature,
        keep_training_booster=keep_training_booster,
        callbacks=callbacks,
        time_budget=time_budget,
        sample_size=sample_size,
        study=study,
        optuna_callbacks=optuna_callbacks,
        model_dir=model_dir,
        show_progress_bar=show_progress_bar,
        optuna_seed=optuna_seed,
    )
    # Execute the stepwise search before asking for the result: get_best_booster() only
    # reports on trials that have already been run, so skipping this performs no tuning.
    auto_booster.run()
    return auto_booster.get_best_booster()