Source code for decent_bench.algorithms.federated._fed_algorithm

from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING

import decent_bench.utils.interoperability as iop
from decent_bench.algorithms._algorithm import Algorithm
from decent_bench.networks import FedNetwork
from decent_bench.schemes import ClientSelectionScheme
from decent_bench.utils.types import LocalSteps

if TYPE_CHECKING:
    from decent_bench.agents import Agent
    from decent_bench.utils.array import Array


[docs] class FedAlgorithm(Algorithm[FedNetwork]): """Federated algorithm - clients collaborate via a central server.""" selection_scheme: ClientSelectionScheme | None = None num_local_steps: LocalSteps = 1
[docs] def cleanup_agents(self, network: FedNetwork) -> Iterable["Agent"]: return [network.server(), *network.clients()]
def _validate_num_local_steps(self) -> None: """ Validate homogeneous or per-client local step counts. Raises: TypeError: if ``num_local_steps`` is not an integer or client mapping. ValueError: if ``num_local_steps`` contains non-positive step counts. """ if isinstance(self.num_local_steps, int): if self.num_local_steps <= 0: raise ValueError("`num_local_steps` must be positive") return if isinstance(self.num_local_steps, dict): for step in self.num_local_steps.values(): if not isinstance(step, int): raise TypeError("`num_local_steps` mapping values must be integers") if step <= 0: raise ValueError("`num_local_steps` must have positive values") return raise TypeError("`num_local_steps` must be an int or a mapping from Agent to integer values") def _settle_num_local_steps(self, network: FedNetwork) -> dict["Agent", int]: """ Resolve homogeneous or per-client local step counts for the network clients. Raises: ValueError: if a per-client mapping is missing a network client. """ clients = network.clients() if isinstance(self.num_local_steps, int): return dict.fromkeys(clients, self.num_local_steps) missing_clients = [client for client in clients if client not in self.num_local_steps] if missing_clients: raise ValueError( "`num_local_steps` mapping must provide a value for every network client; " f"missing clients: {missing_clients}" ) return {client: self.num_local_steps[client] for client in clients}
[docs] def select_clients(self, network: FedNetwork, iteration: int) -> list["Agent"]: """ Select clients for the current federated round. The method selects the subset of active clients that will receive the server broadcast and perform local training. The clients are selected using ``self.selection_scheme``. If ``self.selection_scheme`` is ``None``, all active clients are selected. """ active_clients = network.active_clients() if not active_clients: return [] if self.selection_scheme is None: selected_clients = list(active_clients) else: selected_clients = self.selection_scheme.select(active_clients, iteration) self._record_selected_clients(selected_clients) return selected_clients
@staticmethod def _record_selected_clients(selected_clients: Sequence["Agent"]) -> None: """Record clients selected by a federated round.""" for client in selected_clients: client._n_times_selected += 1 # noqa: SLF001
[docs] def server_broadcast( self, network: FedNetwork, selected_clients: Sequence["Agent"], channel: str = "default", ) -> None: """Send the current server model to the selected clients under ``channel``.""" network.send(sender=network.server(), receiver=selected_clients, msg=network.server().x, channel=channel)
def _clients_with_server_broadcast( self, network: FedNetwork, selected_clients: Sequence["Agent"], channel: str = "default", ) -> list["Agent"]: """Return selected clients that received :meth:`server_broadcast` under ``channel``.""" return [client for client in selected_clients if network.server() in client.messages(channel)] @staticmethod def _get_server_broadcast(client: "Agent", server: "Agent", channel: str = "default") -> "Array": """ Return the current server broadcast received by the client under ``channel``. Raises: ValueError: if the client did not receive the current server broadcast. """ if server not in client.messages(channel): raise ValueError("Client did not receive the current server broadcast") return iop.copy(client.message(server, channel)) @staticmethod def _weighted_average(values: Sequence["Array"], weights: Sequence[float], total_weight: float) -> "Array": """Compute a weighted average of same-shaped arrays.""" weighted_sum = iop.zeros_like(values[0]) for value, weight in zip(values, weights, strict=True): weighted_sum += value * weight return weighted_sum / total_weight
[docs] def aggregate( self, network: FedNetwork, participating_clients: Sequence["Agent"], ) -> None: """ Aggregate client model uploads at the server using uniform averaging. This default federated aggregation assumes clients upload final local model states. """ received_clients = [client for client in participating_clients if client in network.server().messages()] if not received_clients: return updates = [network.server().message(client) for client in received_clients] weights = [1.0] * len(received_clients) total_weight = float(len(received_clients)) network.server().x = self._weighted_average(updates, weights, total_weight)