Source code for optuna_integration.pytorch_distributed.pytorch_distributed

from __future__ import annotations

from collections.abc import Callable
from collections.abc import Sequence
from datetime import datetime
import functools
from typing import Any
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar

import optuna
from optuna._deprecated import deprecated_func
from optuna._experimental import experimental_class
from optuna._imports import try_import
from optuna.distributions import BaseDistribution
from optuna.distributions import CategoricalChoiceType


with try_import() as _imports:
    import torch
    import torch.distributed as dist
    from torch.distributed import ProcessGroup


# Typing-only helpers: ``_T`` and ``_P`` exist only while type checking, so the rest of this
# module refers to them exclusively inside quoted (string) annotations.
if TYPE_CHECKING:
    from typing_extensions import ParamSpec

    _T = TypeVar("_T")
    _P = ParamSpec("_P")


# Template for the deprecation text of the legacy ``suggest_*`` methods; the ``{args}``
# placeholder is filled through ``.format(args=...)`` where those methods are decorated.
_suggest_deprecated_msg = "Use suggest_float{args} instead."

# Module-wide process group cache. It starts as ``None`` and is assigned by
# ``TorchDistributedTrial.__init__`` the first time a trial is built without an explicit
# ``group``; later trials built without a ``group`` reuse the cached value.
_g_pg: "ProcessGroup" | None = None


def broadcast_properties(f: "Callable[_P, _T]") -> "Callable[_P, _T]":
    """Method decorator that re-syncs the cached trial properties once ``f`` has run.

    The wrapper calls ``f`` and then, whether ``f`` returned or raised, takes a snapshot of the
    delegate's ``number``, ``params``, ``distributions``, ``user_attrs``, ``system_attrs`` and
    ``datetime_start`` through the trial's ``_call_and_communicate_obj`` and stores the shared
    snapshot in the matching underscore-prefixed attributes of the trial.

    Apply it to every :class:`~optuna_integration.TorchDistributedTrial` method that may
    change one of those property values.

    Args:
        f: The method to wrap. Its first positional argument must be the trial instance.

    Returns:
        The wrapped method, with the same signature and return value as ``f``.
    """

    @functools.wraps(f)
    def wrapped(*args: "_P.args", **kwargs: "_P.kwargs") -> "_T":
        # TODO(nlgranger): Remove type ignore after mypy includes
        # https://github.com/python/mypy/pull/12668
        trial: TorchDistributedTrial = args[0]  # type: ignore[assignment]

        def snapshot() -> Sequence:
            # Read the live values straight from the wrapped trial, in a fixed order that
            # matches the unpacking below.
            delegate = trial._delegate
            assert delegate is not None
            return (
                delegate.number,
                delegate.params,
                delegate.distributions,
                delegate.user_attrs,
                delegate.system_attrs,
                delegate.datetime_start,
            )

        try:
            outcome = f(*args, **kwargs)
        finally:
            # Runs on success and on failure alike, so the cached copies never go stale.
            synced = trial._call_and_communicate_obj(snapshot)
            (
                trial._number,
                trial._params,
                trial._distributions,
                trial._user_attrs,
                trial._system_attrs,
                trial._datetime_start,
            ) = synced
        return outcome

    return wrapped


@experimental_class("2.6.0")
class TorchDistributedTrial(optuna.trial.BaseTrial):
    """A wrapper of :class:`~optuna.trial.Trial` to incorporate Optuna with PyTorch distributed.

    .. seealso::
        :class:`~optuna_integration.TorchDistributedTrial` provides the same interface as
        :class:`~optuna.trial.Trial`. Please refer to :class:`optuna.trial.Trial` for further
        details.

    See `the example <https://github.com/optuna/optuna-examples/blob/main/
    pytorch/pytorch_distributed_simple.py>`__
    if you want to optimize an objective function that trains neural network
    written with PyTorch distributed data parallel.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` object or :obj:`None`. Please set trial object in
            rank-0 node and set :obj:`None` in the other rank node.
        group:
            A `torch.distributed.ProcessGroup` to communicate with the other nodes.
            TorchDistributedTrial uses CPU tensors to communicate, make sure the group
            supports CPU tensors communications.

            When group is None, the default process group (WORLD) is used, unless its
            backend is `nccl`; in that case a global `gloo` group is created once and
            reused by every later trial constructed without a group.

    Raises:
        RuntimeError:
            If group is None and torch distributed is not initialized.
        ValueError:
            If rank 0 does not receive a trial object, or another rank receives
            anything but :obj:`None`.

    .. note::
        The methods of :class:`~optuna_integration.TorchDistributedTrial` are expected to be
        called by all workers at once. They invoke synchronous data transmission to share
        processing results and synchronize timing. An exception raised by the wrapped trial
        on rank 0 is transmitted as well and re-raised on every worker.

    """

    def __init__(
        self,
        trial: optuna.trial.BaseTrial | None,
        group: "ProcessGroup" | None = None,
    ) -> None:
        _imports.check()

        global _g_pg
        if group is not None:
            self._group: "ProcessGroup" = group
        else:
            if _g_pg is None:
                if dist.group.WORLD is None:
                    raise RuntimeError("torch distributed is not initialized.")
                default_pg: "ProcessGroup" = dist.group.WORLD
                if dist.get_backend(default_pg) == "nccl":
                    # A dedicated gloo group is created instead of communicating over nccl.
                    new_group: "ProcessGroup" = dist.new_group(backend="gloo")
                    _g_pg = new_group
                else:
                    _g_pg = default_pg
            self._group = _g_pg

        if dist.get_rank(self._group) == 0:
            if not isinstance(trial, optuna.trial.BaseTrial):
                raise ValueError(
                    "Rank 0 node expects an optuna.trial.Trial instance as the trial argument."
                )
        elif trial is not None:
            raise ValueError(
                "Non-rank 0 node is supposed to receive None as the trial argument."
            )

        self._delegate = trial
        # Only rank 0 holds a delegate; the other ranks contribute None and receive rank 0's
        # values from the broadcast.
        self._number = self._broadcast(getattr(self._delegate, "number", None))
        self._params = self._broadcast(getattr(self._delegate, "params", None))
        self._distributions = self._broadcast(getattr(self._delegate, "distributions", None))
        self._user_attrs = self._broadcast(getattr(self._delegate, "user_attrs", None))
        self._system_attrs = self._broadcast(getattr(self._delegate, "system_attrs", None))
        self._datetime_start = self._broadcast(getattr(self._delegate, "datetime_start", None))

    @broadcast_properties
    def suggest_float(
        self,
        name: str,
        low: float,
        high: float,
        *,
        step: float | None = None,
        log: bool = False,
    ) -> float:
        """Suggest a float on rank 0 and return the same value on every worker."""

        def func() -> float:
            assert self._delegate is not None
            return self._delegate.suggest_float(name, low, high, step=step, log=log)

        # float64 matches Python's float, so the value is shared without rounding. A float32
        # buffer would return a value different from the one stored in ``params`` and could
        # even push it outside ``[low, high]``.
        return self._call_and_communicate(func, torch.float64)

    @deprecated_func("3.0.0", "6.0.0", text=_suggest_deprecated_msg.format(args=""))
    def suggest_uniform(self, name: str, low: float, high: float) -> float:
        """Deprecated alias of ``suggest_float(name, low, high)``."""
        return self.suggest_float(name, low, high)

    @deprecated_func(
        "3.0.0", "6.0.0", text=_suggest_deprecated_msg.format(args="(..., log=True)")
    )
    def suggest_loguniform(self, name: str, low: float, high: float) -> float:
        """Deprecated alias of ``suggest_float(name, low, high, log=True)``."""
        return self.suggest_float(name, low, high, log=True)

    @deprecated_func(
        "3.0.0", "6.0.0", text=_suggest_deprecated_msg.format(args="(..., step=...)")
    )
    def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) -> float:
        """Deprecated alias of ``suggest_float(name, low, high, step=q)``."""
        return self.suggest_float(name, low, high, step=q)

    @broadcast_properties
    def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int:
        """Suggest an integer on rank 0 and return the same value on every worker."""

        def func() -> int:
            assert self._delegate is not None
            return self._delegate.suggest_int(name, low, high, step=step, log=log)

        # int64 rather than int32 so that values beyond +/-2**31 survive the transfer.
        return self._call_and_communicate(func, torch.int64)

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[None]) -> None: ...  # noqa: E704

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[bool]) -> bool: ...  # noqa: E704

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[int]) -> int: ...  # noqa: E704

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[float]) -> float: ...  # noqa: E704

    @overload
    def suggest_categorical(self, name: str, choices: Sequence[str]) -> str: ...  # noqa: E704

    @overload
    def suggest_categorical(  # noqa: E704
        self, name: str, choices: Sequence[CategoricalChoiceType]
    ) -> CategoricalChoiceType: ...

    @broadcast_properties
    def suggest_categorical(
        self, name: str, choices: Sequence[CategoricalChoiceType]
    ) -> CategoricalChoiceType:
        """Suggest a categorical value on rank 0 and return it on every worker."""

        def func() -> CategoricalChoiceType:
            assert self._delegate is not None
            return self._delegate.suggest_categorical(name, choices)

        return self._call_and_communicate_obj(func)

    @broadcast_properties
    def report(self, value: float, step: int) -> None:
        """Report ``value`` at ``step`` through the rank-0 trial.

        Raises:
            Exception: Whatever the rank-0 trial raised, re-raised on every worker.
        """

        def func() -> None:
            assert self._delegate is not None
            self._delegate.report(value, step)

        self._call_and_communicate_obj(func)

    @broadcast_properties
    def should_prune(self) -> bool:
        """Ask the rank-0 trial whether to prune and return the answer on every worker."""

        def func() -> bool:
            assert self._delegate is not None
            # Some pruners return numpy.bool_, which is incompatible with bool.
            return bool(self._delegate.should_prune())

        # torch.bool seems to be the correct type, but the communication fails
        # due to the RuntimeError. The uint8 buffer yields an int, hence the final bool().
        return bool(self._call_and_communicate(func, torch.uint8))

    @broadcast_properties
    def set_user_attr(self, key: str, value: Any) -> None:
        """Set a user attribute on the rank-0 trial.

        Raises:
            Exception: Whatever the rank-0 trial raised, re-raised on every worker.
        """

        def func() -> None:
            assert self._delegate is not None
            self._delegate.set_user_attr(key, value)

        self._call_and_communicate_obj(func)

    @broadcast_properties
    @deprecated_func("3.1.0", "5.0.0")
    def set_system_attr(self, key: str, value: Any) -> None:
        """Set a system attribute through the storage of the rank-0 trial.

        Raises:
            Exception: Whatever rank 0 raised, re-raised on every worker.
        """

        def func() -> None:
            assert self._delegate is not None
            self._delegate.storage.set_trial_system_attr(  # type: ignore[attr-defined]
                self._delegate._trial_id, key, value  # type: ignore[attr-defined]
            )

        self._call_and_communicate_obj(func)

    @property
    def number(self) -> int:
        return self._number

    @property
    def params(self) -> dict[str, Any]:
        return self._params

    @property
    def distributions(self) -> dict[str, BaseDistribution]:
        return self._distributions

    @property
    def user_attrs(self) -> dict[str, Any]:
        return self._user_attrs

    @property
    @deprecated_func("3.1.0", "5.0.0")
    def system_attrs(self) -> dict[str, Any]:
        return self._system_attrs

    @property
    def datetime_start(self) -> datetime | None:
        return self._datetime_start

    def _call_and_communicate(self, func: Callable, dtype: "torch.dtype") -> Any:
        """Run ``func`` on rank 0 and share its scalar result, coerced to ``dtype``.

        Args:
            func: Zero-argument callable, executed on rank 0 only.
            dtype: Torch dtype the result is converted to before it is shared.

        Returns:
            The Python scalar obtained from a one-element tensor of ``dtype``; identical on
            every worker.

        Raises:
            Exception: Whatever ``func`` or the conversion raised on rank 0, re-raised on
                every worker.
        """

        def compute() -> Any:
            # Going through a one-element tensor applies the dtype conversion (and its
            # overflow check) on rank 0, inside the error-capturing path.
            buffer = torch.empty(1, dtype=dtype)
            buffer[0] = func()
            return buffer.item()

        return self._call_and_communicate_obj(compute)

    def _call_and_communicate_obj(self, func: Callable) -> Any:
        """Run ``func`` on rank 0 and share its result, or its exception, with every worker.

        Args:
            func: Zero-argument callable, executed on rank 0 only.

        Returns:
            The value returned by ``func`` on rank 0; identical on every worker.

        Raises:
            Exception: Whatever ``func`` raised on rank 0, re-raised on every worker.
        """
        result = None
        err = None
        if dist.get_rank(self._group) == 0:
            try:
                result = func()
            except Exception as e:
                # Keep rank 0 inside the collective below. If it left early, the other
                # workers would stay blocked in a broadcast that never gets its counterpart.
                err = e
        result, err = self._broadcast((result, err))
        if err is not None:
            raise err
        return result

    def _broadcast(self, value: Any | None) -> Any:
        """Send ``value`` from rank 0 to all workers and return what rank 0 sent."""
        obj_list = [value]
        dist.broadcast_object_list(obj_list, src=0, group=self._group)
        return obj_list[0]