Source code for decent_bench.algorithms.federated._fed_nova

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import decent_bench.utils.interoperability as iop
from decent_bench.algorithms.utils import initial_states
from decent_bench.networks import FedNetwork
from decent_bench.schemes import ClientSelectionScheme, UniformSelection
from decent_bench.utils._tags import tags
from decent_bench.utils.agent_utils import infer_client_data_size
from decent_bench.utils.types import InitialStates, LocalSteps

from ._fed_algorithm import FedAlgorithm

if TYPE_CHECKING:
    from decent_bench.agents import Agent
    from decent_bench.utils.array import Array


# Channel on which each client uploads its scalar FedNova coefficient ``a_i`` (first upload of a round).
_NORMALIZER_CHANNEL = "normalizer"
# Channel on which each client uploads its cumulative local update (second upload of a round).
_CUMULATIVE_GRADIENT_CHANNEL = "cumulative_gradient"


@tags("federated")
@dataclass(eq=False)
class FedNova(FedAlgorithm):
    r"""
    FedNova with optional local momentum, proximal correction, and server momentum :footcite:p:`Alg_FedNova`.

    Each selected client starts from the broadcast global model :math:`\mathbf{x}_t = \mathbf{x}^{(t,0)}` and
    performs :math:`\tau_i` local steps with client step size ``step_size``. At local step :math:`k`, client
    :math:`i` computes the gradient

    .. math::
        \mathbf{g}_{i, t}^{(k)} = \nabla F_i(\mathbf{x}_{i, t}^{(k)}) + \mu \left(\mathbf{x}_{i, t}^{(k)} - \mathbf{x}_t\right),

    where the proximal term is present only when ``use_prox=True``. If local momentum is enabled, the momentum
    buffer and local direction update as

    .. math::
        \mathbf{v}_{i, t}^{(k+1)} = \beta \mathbf{v}_{i, t}^{(k)} + \mathbf{g}_{i, t}^{(k)},
        \qquad
        \mathbf{d}_{i, t}^{(k)} = \mathbf{v}_{i, t}^{(k+1)},

    otherwise :math:`\mathbf{d}_{i, t}^{(k)} = \mathbf{g}_{i, t}^{(k)}`. The local model update is

    .. math::
        \mathbf{x}_{i, t}^{(k+1)} = \mathbf{x}_{i, t}^{(k)} - \eta_l \mathbf{d}_{i, t}^{(k)}.

    Client :math:`i` accumulates the local update

    .. math::
        \mathbf{c}_i^t = \sum_{k=0}^{\tau_i - 1} \eta_l \mathbf{d}_{i, t}^{(k)},

    and maintains the FedNova scalar recurrences

    .. math::
        s_i^{(k+1)} = \beta s_i^{(k)} + 1

    when ``use_momentum=True``, and :math:`s_i^{(k+1)} = 1` otherwise, together with

    .. math::
        a_i^{(k+1)} = (1 - \eta_l \mu) a_i^{(k)} + s_i^{(k+1)}

    when ``use_prox=True`` and :math:`a_i^{(k+1)} = a_i^{(k)} + s_i^{(k+1)}` otherwise.

    During ``initialize``, the server resolves and stores each client's sample count :math:`n_i`. After
    :math:`\tau_i` local steps, client :math:`i` first uploads the FedNova coefficient :math:`a_i` and then uploads
    the cumulative local update :math:`\mathbf{c}_i^t` in a second transmission. For the clients in the current
    round whose two uploads are both actually received, the server forms the data-proportional client weight

    .. math::
        p_i = \frac{n_i}{\sum_{j \in S_t} n_j}.

    The server forms the weighted effective local-step coefficient

    .. math::
        \tau_{\mathrm{eff}, t} = \bar{a}_t = \sum_{i \in S_t} p_i a_i,

    and the normalized FedNova aggregate

    .. math::
        \mathbf{G}_t = \sum_{i \in S_t} p_i \frac{\tau_{\mathrm{eff}, t}}{a_i} \mathbf{c}_i^t.

    Without server momentum, the server update is

    .. math::
        \mathbf{x}_{t+1} = \mathbf{x}_t - \mathbf{G}_t.

    With ``use_server_momentum=True``, the server momentum buffer and model update become

    .. math::
        \mathbf{m}_{t+1} = \gamma \mathbf{m}_t + \mathbf{G}_t,
        \qquad
        \mathbf{x}_{t+1} = \mathbf{x}_t - \mathbf{m}_{t+1}.

    When ``use_momentum=False``, ``use_prox=False``, and ``use_server_momentum=False``, this reduces to the plain
    local-SGD FedNova variant. In that plain setting, FedNova reduces to FedAvg if and only if all participating
    clients use the same number of local steps (:math:`\tau_i = \tau_j` for all :math:`i, j \in S_t`) and FedNova
    and FedAvg both use data-proportional aggregation weights.

    Here :math:`\tau_i` is the number of local SGD steps used by client :math:`i` (the corresponding argument is
    ``num_local_steps``), :math:`\eta_l` is the local step size (the corresponding argument is ``step_size``),
    :math:`\mu` is the proximal coefficient (the corresponding argument is ``penalty``), :math:`\beta` is the local
    momentum coefficient (the corresponding argument is ``momentum``), and :math:`\gamma` is the server momentum
    coefficient (the corresponding argument is ``server_momentum``).

    In this implementation, :math:`n_i` is inferred once during ``initialize`` from each client's local cost via
    :func:`~decent_bench.utils.agent_utils.infer_client_data_size`, then stored on the server for later rounds. If
    no first-phase ``a_i`` uploads are received in a round under network impairments, the server skips that round
    without error.

    .. footbibliography::

    """

    iterations: int = 100
    step_size: float = 0.001
    num_local_steps: LocalSteps = 1
    use_momentum: bool = False
    momentum: float = 0.9
    use_prox: bool = False
    penalty: float = 0.01
    use_server_momentum: bool = False
    server_momentum: float = 0.9
    selection_scheme: ClientSelectionScheme | None = field(
        default_factory=lambda: UniformSelection(fraction_selected_clients=1.0)
    )
    x0: InitialStates = None
    name: str = "FedNova"
    # Per-client local step counts; filled in by ``initialize`` and read by ``_compute_local_update``.
    _num_local_steps_by_client: dict["Agent", int] = field(init=False, repr=False, default_factory=dict)

    def __post_init__(self) -> None:
        """
        Validate hyperparameters.

        Raises:
            ValueError: if hyperparameters are invalid.

        """
        if self.step_size <= 0:
            raise ValueError("`step_size` must be positive")
        if not (0 <= self.momentum < 1):
            raise ValueError("`momentum` must satisfy 0 <= momentum < 1")
        if self.penalty < 0:
            raise ValueError("`penalty` must be non-negative")
        if not (0 <= self.server_momentum < 1):
            raise ValueError("`server_momentum` must satisfy 0 <= server_momentum < 1")
        self._validate_num_local_steps()

    def _resolve_client_sample_counts(self, network: FedNetwork) -> dict["Agent", float]:
        """Return a mapping from every client in ``network`` to its inferred local data size."""
        return {client: infer_client_data_size(client) for client in network.clients()}

    def initialize(self, network: FedNetwork) -> None:
        """
        Set up server and client state before the first round.

        The server is initialized with its initial model plus auxiliary variables holding the per-client sample
        counts, an empty ``received_a_i`` mapping, and (only when ``use_server_momentum`` is set) a zero momentum
        buffer ``m``. Every client is initialized with its own initial model, and ``num_local_steps`` is replaced
        by the per-client mapping returned by ``_settle_num_local_steps``.
        """
        self.x0 = initial_states(self.x0, network)
        server = network.server()
        server_x0 = self.x0[server]
        aux_vars: dict[str, Any] = {
            "client_sample_counts": self._resolve_client_sample_counts(network),
            "received_a_i": {},
        }
        if self.use_server_momentum:
            aux_vars["m"] = iop.zeros_like(server_x0)
        server.initialize(x=server_x0, aux_vars=aux_vars)
        for client in network.clients():
            client.initialize(x=self.x0[client])
        self._num_local_steps_by_client = self._settle_num_local_steps(network)
        self.num_local_steps = self._num_local_steps_by_client

    def step(self, network: FedNetwork, iteration: int) -> None:
        """
        Run one FedNova round: broadcast, local updates, two-phase upload, and aggregation.

        The round ends early, leaving the server model untouched, when no client is selected, when
        ``_clients_with_server_broadcast`` returns no client, or when no ``a_i`` upload reached the server.
        """
        selected_clients = self.select_clients(network, iteration)
        if not selected_clients:
            return

        self.server_broadcast(network, selected_clients)
        participating_clients = self._clients_with_server_broadcast(network, selected_clients)
        if not participating_clients:
            return

        self._run_local_updates(network, participating_clients)
        self._collect_received_normalizers(network, participating_clients)
        if not network.server().aux_vars["received_a_i"]:
            # No first-phase upload arrived: drop the cached local updates so they do not linger on the clients.
            for client in participating_clients:
                client.aux_vars.pop("_fednova_cumulative_gradient", None)
            return

        self._communicate_cumulative_gradients(network, participating_clients)
        self.aggregate(network, participating_clients)

    def _run_local_updates(self, network: FedNetwork, participating_clients: Sequence["Agent"]) -> None:
        """Run the local steps on each participating client, cache its update, and upload its ``a_i``."""
        server = network.server()
        for client in participating_clients:
            local_x, cumulative_gradient, a_i = self._compute_local_update(client, server)
            client.x = local_x
            # Kept on the client until the second upload phase (or discarded if the round is skipped).
            client.aux_vars["_fednova_cumulative_gradient"] = cumulative_gradient
            normalizer_upload = iop.reshape(iop.to_array_like(a_i, cumulative_gradient), (1,))
            network.send(sender=client, receiver=server, msg=normalizer_upload, channel=_NORMALIZER_CHANNEL)

    def _compute_local_update(self, client: "Agent", server: "Agent") -> tuple["Array", "Array", float]:
        """
        Run local SGD steps and return the cumulative local SGD update and FedNova coefficient ``a_i``.

        Costs that preserve the empirical-risk abstraction default ``gradient`` to ``indices="batch"``, so FedNova
        performs mini-batch local updates automatically. Generic costs keep their usual full-gradient behavior.

        This method assumes ``initialize`` has already normalized ``num_local_steps`` to a per-client mapping.

        Returns:
            The final local model, the accumulated step-size-scaled local directions, and the coefficient ``a_i``.

        """
        reference_x = self._get_server_broadcast(client, server)
        local_x = iop.copy(reference_x)
        cumulative_gradient = iop.zeros_like(reference_x)
        local_momentum = iop.zeros_like(reference_x)
        tau_i = self._num_local_steps_by_client[client]
        a_i = 0.0
        momentum_scalar = 0.0
        for _ in range(tau_i):
            grad = client.cost.gradient(local_x)
            if self.use_prox:
                grad += self.penalty * (local_x - reference_x)
            if self.use_momentum:
                local_momentum = (self.momentum * local_momentum) + grad
                direction = local_momentum
            else:
                direction = grad
            local_step_update = self.step_size * direction
            local_x -= local_step_update
            cumulative_gradient += local_step_update
            # Scalar recurrences mirror the vector updates above so the server can normalize by ``a_i``.
            momentum_scalar = ((self.momentum * momentum_scalar) + 1.0) if self.use_momentum else 1.0
            if self.use_prox:
                a_i = ((1 - (self.step_size * self.penalty)) * a_i) + momentum_scalar
            else:
                a_i += momentum_scalar
        return local_x, cumulative_gradient, a_i

    def _collect_received_normalizers(self, network: FedNetwork, participating_clients: Sequence["Agent"]) -> None:
        """Store on the server the ``a_i`` values of the participating clients present in its normalizer inbox."""
        server = network.server()
        received_normalizers = {
            client: iop.astype(server.message(client, _NORMALIZER_CHANNEL), float)
            for client in participating_clients
            if client in server.messages(_NORMALIZER_CHANNEL)
        }
        server.aux_vars["received_a_i"] = received_normalizers

    def _communicate_cumulative_gradients(
        self, network: FedNetwork, participating_clients: Sequence["Agent"]
    ) -> None:
        """Send each participating client's cached cumulative local update to the server and clear the cache."""
        server = network.server()
        for client in participating_clients:
            cumulative_gradient = client.aux_vars.pop("_fednova_cumulative_gradient")
            network.send(
                sender=client,
                receiver=server,
                msg=cumulative_gradient,
                channel=_CUMULATIVE_GRADIENT_CHANNEL,
            )

    def aggregate(
        self,
        network: FedNetwork,
        participating_clients: Sequence["Agent"],
    ) -> None:
        r"""
        Aggregate FedNova client uploads following the Local-SGD FedNova pseudocode.

        This method assumes the current round has already cached the received FedNova coefficients ``a_i`` in
        ``server.aux_vars["received_a_i"]`` and that cumulative local updates :math:`\mathbf{c}_i^t` are read from
        the server inbox under the cumulative-gradient channel. Client sample counts are looked up from the mapping
        stored on the server during ``initialize``. Only clients with both uploads available in the current round
        are aggregated; if none are available, this method returns without updating the server model.

        Raises:
            ValueError: if any received FedNova coefficient ``a_i`` is non-positive, or if the total sample count
                of the aggregated clients is non-positive.

        """
        server = network.server()
        received_normalizers = server.aux_vars["received_a_i"]
        received_gradient_clients = [
            client for client in participating_clients if client in server.messages(_CUMULATIVE_GRADIENT_CHANNEL)
        ]
        received_clients = [client for client in received_gradient_clients if client in received_normalizers]
        if not received_clients:
            return

        server_sample_counts = server.aux_vars["client_sample_counts"]
        server_x = iop.copy(server.x)
        cumulative_gradients = [server.message(client, _CUMULATIVE_GRADIENT_CHANNEL) for client in received_clients]
        a_values = [received_normalizers[client] for client in received_clients]
        if any(a_i <= 0 for a_i in a_values):
            raise ValueError("FedNova coefficients `a_i` must be positive")

        sample_counts = [server_sample_counts[client] for client in received_clients]
        total_samples = float(sum(sample_counts))
        # The weights below divide by the total; fail with a clear message instead of a division error or NaNs.
        if total_samples <= 0:
            raise ValueError("Total sample count of the aggregated clients must be positive")
        client_weights = [n_i / total_samples for n_i in sample_counts]
        tau_eff = sum(client_weight * a_i for client_weight, a_i in zip(client_weights, a_values, strict=True))
        weighted_terms = [
            client_weight * (tau_eff / a_i) * cumulative_gradient
            for cumulative_gradient, a_i, client_weight in zip(
                cumulative_gradients, a_values, client_weights, strict=True
            )
        ]

        global_update = iop.zeros_like(server_x)
        for weighted_term in weighted_terms:
            global_update += weighted_term

        if self.use_server_momentum:
            server.aux_vars["m"] = (self.server_momentum * server.aux_vars["m"]) + global_update
            server.x = server_x - server.aux_vars["m"]
        else:
            server.x = server_x - global_update