"""Fed-LT (Federated Local Training) algorithm, module ``decent_bench.algorithms.federated._fed_lt``."""

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import decent_bench.utils.interoperability as iop
from decent_bench.algorithms.utils import initial_states
from decent_bench.networks import FedNetwork
from decent_bench.schemes import ClientSelectionScheme, UniformSelection
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates

from ._fed_algorithm import FedAlgorithm

if TYPE_CHECKING:
    # Imported only for static type checking: the code in this module refers to these names through string
    # annotations ("Agent", "Array"), so they are not needed at runtime.
    # NOTE(review): presumably guarded to keep runtime imports light / avoid import cycles — confirm.
    from decent_bench.agents import Agent
    from decent_bench.utils.array import Array

    # Precise alias seen by type checkers: solver hyperparameter name -> numeric value.
    SolverArgs = dict[str, float]
else:
    # At runtime the alias is just the plain ``dict`` type (used as the ``solver_args`` field annotation).
    SolverArgs = dict


@tags("federated")
@dataclass(eq=False)
class FedLT(FedAlgorithm):
    r"""
    Federated Local Training (Fed-LT) with cost-driven local gradients :footcite:p:`Alg_FedPLT,Alg_FedLT`.

    Fed-LT maintains one auxiliary variable :math:`z_i` per client. This running :math:`z_i` state is maintained
    during execution and is not a constructor argument (only its initial value ``z0`` is). At the start of round
    :math:`k`, the server computes the broadcast variable

    .. math::
        y_{k+1} = \operatorname{prox}_{\rho h / N}\left(\frac{1}{N}\sum_{i=1}^N z_{i,k}\right)

    where :math:`N` is the number of clients, and :math:`\rho` is the regularization strength (the corresponding
    argument is ``penalty``). The server sends :math:`y_{k+1}` to the selected clients.

    The initial auxiliary states ``z0`` follow the same :obj:`~decent_bench.utils.types.InitialStates` convention
    as ``x0``. If ``z0`` is ``None``, Fed-LT initializes :math:`z_{i,0}=x_{i,0}` for every client.

    In this implementation, each client cost :math:`f_i` is treated as the full local objective already available
    on that client. The global regularizer :math:`h` is represented by the server cost's proximal operator; with
    the default :class:`~decent_bench.costs.ZeroCost` server, this is the documented :math:`h=0` case and the
    server step is plain averaging.

    **Local solvers.** A selected client :math:`i` sets :math:`w^0_{i,k}=x_{i,k}` and
    :math:`v_{i,k}=2y_{k+1}-z_{i,k}`, then uses ``local_solver`` to approximately minimize the regularized local
    objective

    .. math::
        f_i(w) + \frac{1}{2\rho}\|w - v_{i,k}\|^2.

    The local gradient of this subproblem is

    .. math::
        \nabla f_i(w^\ell_{i,k}) + \frac{1}{\rho}(w^\ell_{i,k} - v_{i,k}).

    Costs preserving the :class:`~decent_bench.costs.EmpiricalRiskCost` abstraction use its default mini-batch
    sampling, so gradient-based local solvers use mini-batches. Generic :class:`~decent_bench.costs.Cost` objects
    use their normal full-gradient behavior. Solver-specific hyperparameters are passed through ``solver_args``.

    **Gradient descent.** The default ``local_solver="gd"`` uses ``step_size`` as the local step size and expects
    empty ``solver_args``:

    .. math::
        w^{\ell+1}_{i,k} = w^\ell_{i,k} - \gamma\left(\nabla f_i(w^\ell_{i,k}) + \frac{1}{\rho}(w^\ell_{i,k} - v_{i,k})\right).

    Here :math:`\gamma` is the local step size (the corresponding argument is ``step_size``).

    **Nesterov.** The ``local_solver="nesterov"`` option applies a Nesterov-style update to the same local
    gradient. It initializes :math:`u_i^0=w^0_{i,k}` and uses ``step_size`` as the local step size. Its
    ``solver_args`` may contain ``"momentum"``; the default is ``0.9``:

    .. math::
        u_i^{\ell+1} = w^\ell_{i,k} - \gamma\left(\nabla f_i(w^\ell_{i,k}) + \frac{1}{\rho}(w^\ell_{i,k} - v_{i,k})\right)

    .. math::
        w^{\ell+1}_{i,k} = u_i^{\ell+1} + \beta\left(u_i^{\ell+1} - u_i^\ell\right).

    Here :math:`\beta` is the Nesterov momentum coefficient (the corresponding argument is
    ``solver_args["momentum"]``). One possible centralized strongly-convex choice is
    :math:`\beta=(\sqrt{L_i + 1/\rho} - \sqrt{\mu_i + 1/\rho}) / (\sqrt{L_i + 1/\rho} + \sqrt{\mu_i + 1/\rho})`,
    with local step size :math:`1/(L_i + 1/\rho)`, where ``m_smooth`` supplies :math:`L_i` and ``m_cvx`` supplies
    :math:`\mu_i`.

    **Adam.** The ``local_solver="adam"`` option applies Adam to the same local gradient. Adam moments are reset
    at the start of every local solve because Fed-LT locally trains the current round's subproblem rather than
    maintaining a persistent optimizer state across rounds. Its ``solver_args`` may contain ``"beta1"``,
    ``"beta2"``, and ``"epsilon"``; the defaults are ``0.9``, ``0.999``, and ``1e-8``:

    .. math::
        g^\ell_{i,k} = \nabla f_i(w^{\ell-1}_{i,k}) + \frac{1}{\rho}(w^{\ell-1}_{i,k} - v_{i,k})

    .. math::
        m^\ell_{i,k} = \beta_1 m^{\ell-1}_{i,k} + (1-\beta_1)g^\ell_{i,k}, \qquad s^\ell_{i,k} = \beta_2 s^{\ell-1}_{i,k} + (1-\beta_2)(g^\ell_{i,k})^2

    .. math::
        w^\ell_{i,k} = w^{\ell-1}_{i,k} - \gamma\frac{\hat m^\ell_{i,k}}{\sqrt{\hat s^\ell_{i,k}} + \epsilon},

    for :math:`\ell=1,\ldots,N_e`, with :math:`m^0_{i,k}=s^0_{i,k}=0`. Here :math:`N_e` is the local step count
    (the corresponding argument is ``num_local_steps``), :math:`\beta_1` and :math:`\beta_2` are the Adam moment
    coefficients (the corresponding arguments are ``solver_args["beta1"]`` and ``solver_args["beta2"]``), and
    :math:`\epsilon` is the Adam numerical stability term (the corresponding argument is
    ``solver_args["epsilon"]``). The terms
    :math:`\hat m^\ell_{i,k} = m^\ell_{i,k} / (1-\beta_1^\ell)` and
    :math:`\hat s^\ell_{i,k} = s^\ell_{i,k} / (1-\beta_2^\ell)` are the Adam bias-corrected moments.

    After local training, the client sets

    .. math::
        x_{i,k+1}=w^{N_e}_{i,k}, \qquad z_{i,k+1}=z_{i,k}+2(x_{i,k+1}-y_{k+1}),

    and uploads :math:`z_{i,k+1}` to the server. Clients that do not run a local update in a round keep
    :math:`x_{i,k+1}=x_{i,k}` and :math:`z_{i,k+1}=z_{i,k}`. For later server averages, the server stores a
    received fresh :math:`z_{i,k+1}` when the upload arrives and otherwise keeps its previous stored :math:`z_i`.

    Fed-PLT is the privacy-noise version of Fed-LT :footcite:p:`Alg_FedPLT`.

    .. footbibliography::
    """

    iterations: int = 100
    step_size: float = 0.001
    num_local_steps: int = 1
    penalty: float = 1.0
    local_solver: str = "gd"
    solver_args: SolverArgs = field(default_factory=dict)
    selection_scheme: ClientSelectionScheme | None = field(
        default_factory=lambda: UniformSelection(fraction_selected_clients=1.0)
    )
    x0: InitialStates = None
    z0: InitialStates = None
    name: str = "FedLT"

    def __post_init__(self) -> None:
        """
        Validate hyperparameters.

        Raises:
            ValueError: if hyperparameters are invalid.
        """
        # ``not (x > 0)`` instead of ``x <= 0``: every comparison with NaN is False, so the latter lets NaN through.
        if not (self.step_size > 0):
            raise ValueError("`step_size` must be positive")
        if self.num_local_steps <= 0:
            raise ValueError("`num_local_steps` must be positive")
        if not (self.penalty > 0):
            raise ValueError("`penalty` must be positive")
        if self.local_solver not in {"gd", "nesterov", "adam"}:
            raise ValueError("`local_solver` must be one of 'gd', 'nesterov', or 'adam'")
        self._validate_solver_args()

    def _validate_solver_args(self) -> None:
        """
        Reject unknown ``solver_args`` keys, fill in solver defaults, and range-check the values.

        On success ``self.solver_args`` is replaced by a new dict holding exactly the keys used by
        ``local_solver`` (missing ones set to their defaults); the caller's mapping is not mutated.

        Raises:
            ValueError: if ``solver_args`` contains keys not supported by ``local_solver`` or values out of range.
        """
        user_args = self.solver_args
        default_args: dict[str, float]
        if self.local_solver == "adam":
            default_args = {"beta1": 0.9, "beta2": 0.999, "epsilon": 1e-8}
        elif self.local_solver == "nesterov":
            default_args = {"momentum": 0.9}
        else:
            # Plain gradient descent has no solver-specific hyperparameters.
            default_args = {}

        unknown_args = set(user_args) - set(default_args)
        if unknown_args:
            names = ", ".join(sorted(unknown_args))
            raise ValueError(f"Unsupported solver_args for local_solver='{self.local_solver}': {names}")

        self.solver_args = {key: user_args.get(key, default) for key, default in default_args.items()}

        # Chained comparisons below are already NaN-safe (they evaluate to False for NaN).
        if self.local_solver == "adam":
            if not (0 <= self.solver_args["beta1"] < 1):
                raise ValueError("`solver_args['beta1']` must satisfy 0 <= beta1 < 1")
            if not (0 <= self.solver_args["beta2"] < 1):
                raise ValueError("`solver_args['beta2']` must satisfy 0 <= beta2 < 1")
            if not (self.solver_args["epsilon"] > 0):
                raise ValueError("`solver_args['epsilon']` must be positive")
        elif self.local_solver == "nesterov":
            if not (0 <= self.solver_args["momentum"] < 1):
                raise ValueError("`solver_args['momentum']` must satisfy 0 <= momentum < 1")

    def initialize(self, network: FedNetwork) -> None:
        """
        Resolve the initial states and seed server and client state.

        ``x0`` (and ``z0`` when given) are resolved with ``initial_states``. When ``z0`` is ``None`` it reuses the
        resolved ``x0``, i.e. :math:`z_{i,0}=x_{i,0}`. The server stores a client -> ``z`` mapping in
        ``aux_vars["z_by_client"]``; each client stores its own ``z`` in ``aux_vars["z"]``.
        """
        self.x0 = initial_states(self.x0, network)
        self.z0 = self.x0 if self.z0 is None else initial_states(self.z0, network)
        server = network.server()
        z_by_client = {client: self.z0[client] for client in network.clients()}
        server.initialize(x=self.x0[server], aux_vars={"z_by_client": z_by_client})
        for client in network.clients():
            client.initialize(x=self.x0[client], aux_vars={"z": self.z0[client]})

    def step(self, network: FedNetwork, iteration: int) -> None:
        """
        Run one Fed-LT round.

        Computes :math:`y_{k+1}` and stores it as the server model. It then selects clients, broadcasts to them,
        runs local training on the clients that have the broadcast, and stores their uploaded ``z`` values.
        The round returns early when no client is selected or none has the broadcast.
        """
        y = self._compute_server_y(network)
        network.server().x = y
        selected_clients = self.select_clients(network, iteration)
        if not selected_clients:
            return
        self.server_broadcast(network, selected_clients)
        participating_clients = self._clients_with_server_broadcast(network, selected_clients)
        if not participating_clients:
            return
        self._run_local_updates(network, participating_clients)
        self.aggregate(network, participating_clients)

    def _compute_server_y(self, network: FedNetwork) -> "Array":
        """
        Compute the server broadcast :math:`y_{k+1}` from the server-side stored ``z`` values.

        Averages the values in ``aux_vars["z_by_client"]`` and applies the server cost's proximal operator with
        parameter ``penalty / N``, where ``N`` is the number of clients in the network.

        Raises:
            ValueError: if the server holds no client ``z`` values (e.g. a network without clients).
        """
        z_values = list(network.server().aux_vars["z_by_client"].values())
        if not z_values:
            raise ValueError("Fed-LT requires at least one client `z` value on the server")
        z_sum = iop.zeros_like(z_values[0])
        for z_value in z_values:
            z_sum += z_value
        average_z = z_sum / len(z_values)
        return network.server().cost.proximal(average_z, self.penalty / len(network.clients()))

    def _run_local_updates(self, network: FedNetwork, participating_clients: Sequence["Agent"]) -> None:
        """Run local training on each participating client and send its new ``z`` to the server."""
        server = network.server()
        for client in participating_clients:
            client.x, client.aux_vars["z"] = self._compute_local_update(client, server)
            network.send(sender=client, receiver=server, msg=client.aux_vars["z"])

    def _local_gradient(self, client: "Agent", w: "Array", v: "Array") -> "Array":
        """
        Return the gradient of the regularized local subproblem at ``w``: ``grad f_i(w) + (w - v) / penalty``.

        The gradient call intentionally delegates batching to ``client.cost.gradient``. For
        :class:`~decent_bench.costs.EmpiricalRiskCost`, that default call samples mini-batches; for generic costs
        it is a full-gradient call.
        """
        return client.cost.gradient(w) + ((w - v) / self.penalty)

    def _compute_local_update(self, client: "Agent", server: "Agent") -> tuple["Array", "Array"]:
        """
        Run Fed-LT local training and return the updated local model and auxiliary variable.

        Returns:
            ``(x_next, z_next)`` where ``x_next`` is the result of ``num_local_steps`` solver steps on the
            regularized subproblem and ``z_next = z + 2 * (x_next - y)``.
        """
        y = self._get_server_broadcast(client, server)
        z = client.aux_vars["z"]
        v = (2 * y) - z
        # Solvers update ``local_x`` in place (``-=``), so work on a copy to leave ``client.x`` untouched until
        # the caller assigns the result.
        local_x = iop.copy(client.x)
        if self.local_solver == "nesterov":
            local_x = self._compute_nesterov_local_update(client, local_x, v)
        elif self.local_solver == "adam":
            local_x = self._compute_adam_local_update(client, local_x, v)
        else:
            local_x = self._compute_gd_local_update(client, local_x, v)
        z_next = z + (2 * (local_x - y))
        return local_x, z_next

    def _compute_gd_local_update(self, client: "Agent", local_x: "Array", v: "Array") -> "Array":
        """Run ``num_local_steps`` plain gradient steps on the local subproblem (updates ``local_x`` in place)."""
        for _ in range(self.num_local_steps):
            local_x -= self.step_size * self._local_gradient(client, local_x, v)
        return local_x

    def _compute_nesterov_local_update(self, client: "Agent", local_x: "Array", v: "Array") -> "Array":
        """Run ``num_local_steps`` Nesterov-style steps on the local subproblem, starting from ``u^0 = local_x``."""
        momentum = self.solver_args["momentum"]
        u_previous = iop.copy(local_x)
        for _ in range(self.num_local_steps):
            grad = self._local_gradient(client, local_x, v)
            u_next = local_x - (self.step_size * grad)
            local_x = u_next + (momentum * (u_next - u_previous))
            u_previous = u_next
        return local_x

    def _compute_adam_local_update(self, client: "Agent", local_x: "Array", v: "Array") -> "Array":
        """
        Run ``num_local_steps`` Adam steps on the local subproblem.

        Moments start at zero on every call (no optimizer state is kept across rounds), and ``local_x`` is
        updated in place.
        """
        beta1 = self.solver_args["beta1"]
        beta2 = self.solver_args["beta2"]
        epsilon = self.solver_args["epsilon"]
        m = iop.zeros_like(local_x)
        s = iop.zeros_like(local_x)
        for step in range(1, self.num_local_steps + 1):
            grad = self._local_gradient(client, local_x, v)
            m = (beta1 * m) + ((1 - beta1) * grad)
            s = (beta2 * s) + ((1 - beta2) * (grad * grad))
            # Bias correction; the denominators are nonzero because validation enforces beta < 1.
            m_hat = m / (1 - (beta1**step))
            s_hat = s / (1 - (beta2**step))
            local_x -= self.step_size * m_hat / (iop.sqrt(s_hat) + epsilon)
        return local_x

    def aggregate(
        self,
        network: FedNetwork,
        participating_clients: Sequence["Agent"],
    ) -> None:
        """
        Store received Fed-LT ``z`` uploads for future server averages.

        Clients whose uploads are not received keep their previous server-side ``z`` value, matching the
        stale-value aggregation in partial participation and lossy communication settings.
        """
        server = network.server()
        z_by_client = server.aux_vars["z_by_client"]
        for client in participating_clients:
            if client in server.messages():
                z_by_client[client] = server.message(client)