Source code for decent_bench.algorithms.federated._fed_yogi
from dataclasses import dataclass
from typing import TYPE_CHECKING
import decent_bench.utils.interoperability as iop
from decent_bench.utils._tags import tags
from ._fed_opt import FedOpt
if TYPE_CHECKING:
from decent_bench.utils.array import Array
[docs]
@tags("federated")
@dataclass(eq=False)
class FedYogi(FedOpt):
    r"""
    Federated optimization with local SGD clients and a Yogi-style adaptive server step :footcite:p:`Alg_FedOpt`.

    In round :math:`t`, every selected client begins at the broadcast global model :math:`\mathbf{x}_t` and runs
    ``num_local_steps`` local SGD steps with the client step size ``step_size``:

    .. math::

        \mathbf{x}_{i, t}^{(k+1)} = \mathbf{x}_{i, t}^{(k)} - \eta_l
        \nabla f_i(\mathbf{x}_{i, t}^{(k)}).

    The delta uploaded by a client is the difference between its final local model and the broadcast model:

    .. math::

        \delta_i^t = \mathbf{x}_{i, t}^{(K)} - \mathbf{x}_t.

    The server averages the deltas uniformly over the participating clients:

    .. math::

        \Delta_t = \frac{1}{|S_t|} \sum_{i \in S_t} \delta_i^t.

    The moment buffers and the global model are then updated as

    .. math::

        \mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1 - \beta_1) \Delta_t

    .. math::

        \mathbf{v}_t = \mathbf{v}_{t-1} - (1 - \beta_2) \Delta_t^2
        \operatorname{sign}(\mathbf{v}_{t-1} - \Delta_t^2)

    .. math::

        \mathbf{x}_{t+1} = \mathbf{x}_t + \eta
        \frac{\mathbf{m}_t}{\sqrt{\mathbf{v}_t} + \epsilon}.

    Notation and the corresponding arguments:

    - :math:`\eta_l`: client learning rate (``step_size``).
    - :math:`K`: number of local SGD steps (``num_local_steps``).
    - :math:`\eta`: server learning rate (``server_step_size``).
    - :math:`\beta_1`, :math:`\beta_2`: first- and second-moment coefficients (``beta_1`` and ``beta_2``).
    - :math:`\epsilon`: numerical stability term (``epsilon``).
    - :math:`S_t`: set of clients whose uploads are actually received in round :math:`t`.

    Aggregation is always uniform across the received clients. Costs that preserve the
    :class:`~decent_bench.costs.EmpiricalRiskCost` abstraction use mini-batch local updates; generic costs use
    their usual full-gradient updates.

    .. footbibliography::

    """

    beta_2: float = 0.99
    name: str = "FedYogi"

    def __post_init__(self) -> None:
        """
        Run the parent validation, then check the Yogi second-moment coefficient.

        Raises:
            ValueError: if ``beta_2`` is outside ``[0, 1)``.

        """
        super().__post_init__()
        # Early return on the valid range; anything else (including NaN, for which the chained
        # comparison is False) falls through to the error.
        if 0 <= self.beta_2 < 1:
            return
        raise ValueError("`beta_2` must satisfy 0 <= beta_2 < 1")

    def _update_second_moment(self, second_moment: "Array", average_delta: "Array") -> "Array":
        """
        Apply the Yogi second-moment rule to ``second_moment``.

        With ``d = average_delta * average_delta`` the result is
        ``second_moment - (1 - beta_2) * d * sign(second_moment - d)``.

        Args:
            second_moment: second-moment buffer to update.
            average_delta: averaged client delta used for the update.

        Returns:
            The updated second-moment buffer.

        """
        squared_delta = average_delta * average_delta
        # The sign term decides whether the buffer moves down (buffer above the squared delta)
        # or up (buffer below it).
        direction = iop.sign(second_moment - squared_delta)
        correction = (1 - self.beta_2) * squared_delta * direction
        return second_moment - correction