Source code for decent_bench.algorithms.federated._fed_opt

from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import decent_bench.utils.interoperability as iop
from decent_bench.algorithms.utils import initial_states
from decent_bench.networks import FedNetwork
from decent_bench.schemes import ClientSelectionScheme, UniformSelection
from decent_bench.utils.types import InitialStates

from ._fed_algorithm import FedAlgorithm

if TYPE_CHECKING:
    from decent_bench.agents import Agent
    from decent_bench.utils.array import Array


@dataclass(eq=False)
class FedOpt(FedAlgorithm, ABC):
    r"""
    Shared FedOpt template with client local SGD and server adaptive optimization :footcite:p:`Alg_FedOpt`.

    Each selected client starts from the broadcast global model :math:`\mathbf{x}_t` and performs
    ``num_local_steps`` local SGD steps with client step size ``step_size``:

    .. math::
        \mathbf{x}_{i, t}^{(k+1)} = \mathbf{x}_{i, t}^{(k)} - \eta_l \nabla f_i(\mathbf{x}_{i, t}^{(k)}).

    After :math:`K` local steps, client :math:`i` forms the model delta

    .. math::
        \delta_i^t = \mathbf{x}_{i, t}^{(K)} - \mathbf{x}_t

    and uploads :math:`\delta_i^t` to the server. The server averages these client deltas uniformly:

    .. math::
        \Delta_t = \frac{1}{|S_t|} \sum_{i \in S_t} \delta_i^t.

    The server then applies the shared FedOpt first-moment and model updates

    .. math::
        \mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1 - \beta_1) \Delta_t

    .. math::
        \mathbf{v}_t = \Phi(\mathbf{v}_{t-1}, \Delta_t)

    .. math::
        \mathbf{x}_{t+1} = \mathbf{x}_t + \eta \frac{\mathbf{m}_t}{\sqrt{\mathbf{v}_t} + \epsilon}.

    Here :math:`\eta_l` is the client learning rate (the corresponding argument is ``step_size``),
    :math:`K` is the number of local SGD steps (the corresponding argument is ``num_local_steps``),
    :math:`\eta` is the server learning rate (the corresponding argument is ``server_step_size``),
    :math:`\beta_1` is the first-moment coefficient (the corresponding argument is ``beta_1``),
    :math:`\epsilon` is the numerical stability term (the corresponding argument is ``epsilon``),
    and :math:`S_t` is the set of clients whose uploads are actually received in round :math:`t`.
    The second-moment update :math:`\Phi` is variant-specific and is defined by subclasses.
    Aggregation is always uniform across the received clients. Costs that preserve the
    :class:`~decent_bench.costs.EmpiricalRiskCost` abstraction use mini-batch local updates;
    generic costs use their usual full-gradient updates.

    .. footbibliography::

    """

    iterations: int = 100
    step_size: float = 0.001
    num_local_steps: int = 1
    server_step_size: float = 0.001
    beta_1: float = 0.9
    epsilon: float = 1e-6
    selection_scheme: ClientSelectionScheme | None = field(
        default_factory=lambda: UniformSelection(fraction_selected_clients=1.0)
    )
    x0: InitialStates = None
    name: str = "FedOpt"

    def __post_init__(self) -> None:
        """
        Validate hyperparameters.

        Raises:
            ValueError: if hyperparameters are invalid.

        """
        if self.step_size <= 0:
            raise ValueError("`step_size` must be positive")
        if self.num_local_steps <= 0:
            raise ValueError("`num_local_steps` must be positive")
        if self.server_step_size <= 0:
            raise ValueError("`server_step_size` must be positive")
        if not (0 <= self.beta_1 < 1):
            raise ValueError("`beta_1` must satisfy 0 <= beta_1 < 1")
        if self.epsilon <= 0:
            raise ValueError("`epsilon` must be positive")

    def initialize(self, network: FedNetwork) -> None:
        """
        Set the initial model on the server and on every client.

        ``x0`` is resolved through ``initial_states`` and stored back on the algorithm. The server
        additionally receives the optimizer states ``"m"`` and ``"v"``, both created with
        ``iop.zeros_like`` from the server's initial model.

        Args:
            network: federated network providing the server and the clients.

        """
        self.x0 = initial_states(self.x0, network)
        server = network.server()
        server_x0 = self.x0[server]
        server.initialize(
            x=server_x0,
            aux_vars={"m": iop.zeros_like(server_x0), "v": iop.zeros_like(server_x0)},
        )
        for client in network.clients():
            client.initialize(x=self.x0[client])

    def step(self, network: FedNetwork, iteration: int) -> None:
        """
        Run one round: select clients, broadcast, update locally, and aggregate on the server.

        The round ends early, leaving the server state untouched, when no client is selected or
        when none of the selected clients is reported as holding the server broadcast.

        Args:
            network: federated network providing the server and the clients.
            iteration: index of the current round, forwarded to the client selection.

        """
        selected_clients = self.select_clients(network, iteration)
        if not selected_clients:
            return
        self.server_broadcast(network, selected_clients)
        participating_clients = self._clients_with_server_broadcast(network, selected_clients)
        if not participating_clients:
            return
        self._run_local_updates(network, participating_clients)
        self.aggregate(network, participating_clients)

    def _run_local_updates(self, network: FedNetwork, participating_clients: Sequence["Agent"]) -> None:
        """
        Update each participating client locally and send its model delta to the server.

        For every client the local model is stored in ``client.x`` and the difference between that
        model and the broadcast reference is sent to the server.

        Args:
            network: federated network used to send the client uploads.
            participating_clients: clients that take part in this round.

        """
        server = network.server()
        for client in participating_clients:
            reference_x = self._get_server_broadcast(client, server)
            client.x = self._compute_local_update(client, server)
            network.send(sender=client, receiver=server, msg=client.x - reference_x)

    def _compute_local_update(self, client: "Agent", server: "Agent") -> "Array":
        """
        Run local SGD steps using the batching semantics of ``client.cost.gradient``.

        Costs that preserve the empirical-risk abstraction default ``gradient`` to
        ``indices="batch"``, so FedOpt variants perform mini-batch local updates automatically.
        Generic costs keep their usual full-gradient behavior.

        The broadcast model is used as the starting point and is never modified in place.

        Args:
            client: client whose cost supplies the gradients.
            server: server whose broadcast is the starting point.

        Returns:
            The local model after ``num_local_steps`` SGD steps.

        """
        local_x = self._get_server_broadcast(client, server)
        for _ in range(self.num_local_steps):
            grad = client.cost.gradient(local_x)
            # Rebind rather than use ``-=``: if the array type implements in-place subtraction,
            # ``-=`` would mutate the broadcast object itself. The caller holds that same
            # broadcast as its reference model, so the uploaded delta would collapse to zero.
            local_x = local_x - self.step_size * grad
        return local_x

    def aggregate(
        self,
        network: FedNetwork,
        participating_clients: Sequence["Agent"],
    ) -> None:
        """
        Aggregate client model deltas uniformly, then apply the server adaptive optimizer.

        This method assumes clients upload final local model deltas. Only clients with a message
        present in ``server.messages()`` are included; if there are none, the server state is left
        unchanged.

        Args:
            network: federated network providing the server.
            participating_clients: clients that took part in this round.

        """
        server = network.server()
        received_clients = [client for client in participating_clients if client in server.messages()]
        if not received_clients:
            return
        server_x = iop.copy(server.x)
        model_deltas = [server.message(client) for client in received_clients]
        # Uniform aggregation: every received client carries the same weight.
        weights = [1.0] * len(received_clients)
        total_weight = float(len(received_clients))
        average_delta = self._weighted_average(model_deltas, weights, total_weight)
        server.aux_vars["m"] = self.beta_1 * server.aux_vars["m"] + (1 - self.beta_1) * average_delta
        server.aux_vars["v"] = self._update_second_moment(server.aux_vars["v"], average_delta)
        server.x = server_x + (
            self.server_step_size * server.aux_vars["m"] / (iop.sqrt(server.aux_vars["v"]) + self.epsilon)
        )

    @abstractmethod
    def _update_second_moment(self, second_moment: "Array", average_delta: "Array") -> "Array":
        """Return the updated server second-moment state for the current round."""