Source code for decent_bench.algorithms.federated._fed_adagrad
from dataclasses import dataclass
from typing import TYPE_CHECKING
from decent_bench.utils._tags import tags
from ._fed_opt import FedOpt
if TYPE_CHECKING:
from decent_bench.utils.array import Array
[docs]
@tags("federated")
@dataclass(eq=False)
class FedAdagrad(FedOpt):
    r"""
    FedAdagrad: local SGD on the clients combined with an Adagrad-style adaptive server step
    :footcite:p:`Alg_FedOpt`.

    **Client side.** Every selected client is initialised with the broadcast global model
    :math:`\mathbf{x}_t` and runs ``num_local_steps`` SGD steps using the client step size ``step_size``:

    .. math::

        \mathbf{x}_{i, t}^{(k+1)} = \mathbf{x}_{i, t}^{(k)} - \eta_l
        \nabla f_i(\mathbf{x}_{i, t}^{(k)}).

    After the last local step, the client uploads the difference between its final model and the global model:

    .. math::

        \delta_i^t = \mathbf{x}_{i, t}^{(K)} - \mathbf{x}_t.

    **Server side.** The server averages the received deltas with equal weights:

    .. math::

        \Delta_t = \frac{1}{|S_t|} \sum_{i \in S_t} \delta_i^t.

    It then refreshes its moment buffers and moves the global model. Squares, square roots and the division are
    all taken element-wise:

    .. math::

        \mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1 - \beta_1) \Delta_t

    .. math::

        \mathbf{v}_t = \mathbf{v}_{t-1} + \Delta_t^2

    .. math::

        \mathbf{x}_{t+1} = \mathbf{x}_t + \eta
        \frac{\mathbf{m}_t}{\sqrt{\mathbf{v}_t} + \epsilon}.

    The Adagrad accumulator :math:`\mathbf{v}_t` is never decayed; each round adds the squared aggregate delta.

    Symbols and their constructor arguments:

    - :math:`\eta_l`: client learning rate (``step_size``).
    - :math:`K`: number of local SGD steps (``num_local_steps``).
    - :math:`\eta`: server learning rate (``server_step_size``).
    - :math:`\beta_1`: first-moment coefficient (``beta_1``).
    - :math:`\epsilon`: numerical-stability term (``epsilon``).
    - :math:`S_t`: the clients whose uploads actually arrive in round :math:`t`.

    Aggregation is always uniform across the received clients.

    Costs that preserve the :class:`~decent_bench.costs.EmpiricalRiskCost` abstraction use mini-batch local
    updates; generic costs use their usual full-gradient updates.

    .. footbibliography::
    """

    # Name string for this algorithm variant (presumably used to label it — confirm against FedOpt).
    name: str = "FedAdagrad"

    def _update_second_moment(self, second_moment: "Array", average_delta: "Array") -> "Array":
        r"""
        Apply the Adagrad accumulator rule :math:`\mathbf{v}_t = \mathbf{v}_{t-1} + \Delta_t^2`.

        This appears to be a hook of the ``FedOpt`` base class; the caller is not visible here.

        Args:
            second_moment: previous accumulator :math:`\mathbf{v}_{t-1}`.
            average_delta: uniformly averaged client delta :math:`\Delta_t` for this round.

        Returns:
            The new accumulator :math:`\mathbf{v}_t`. It is the old value plus the element-wise product of
            ``average_delta`` with itself; no decay factor is applied.
        """
        # Element-wise product rather than ``** 2`` to stay within the Array operators used originally.
        squared_delta = average_delta * average_delta
        return second_moment + squared_delta