from __future__ import annotations
import random
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import decent_bench.utils.interoperability as iop
from decent_bench.costs import EmpiricalRiskCost
from decent_bench.utils.agent_utils import infer_client_data_size
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks
if TYPE_CHECKING:
from decent_bench.agents import Agent
[docs]
class AgentActivationScheme(ABC):
    """
    Base class for schemes that decide when an agent is active during algorithm execution.

    Networks attach an activation scheme to each agent, and algorithms query it while they run to decide whether
    the agent participates in the current iteration.
    """

    @abstractmethod
    def is_active(self, iteration: int) -> bool:
        """
        Report whether the agent is active.

        Args:
            iteration: index of the current iteration of the algorithm execution

        Returns:
            ``True`` if the agent is active at this iteration, ``False`` otherwise.

        """
[docs]
class AlwaysActive(AgentActivationScheme):
    """Activation scheme under which the agent participates in every iteration."""

    def is_active(self, iteration: int) -> bool:  # noqa: D102, ARG002
        # The iteration index is irrelevant: this agent never goes inactive.
        always = True
        return always
[docs]
class MarkovChainActivation(AgentActivationScheme):
    """
    Scheme modeling activation with a 2-state Markov chain.

    The scheme models activation with a 2-state (active and inactive) Markov chain. The agent transitions
    between the two states with the given probabilities. The chain starts in the active state and is advanced by
    one step on every call to :meth:`is_active`, before the new state is reported.

    Args:
        inactive_to_active: transition probability from inactive to active
        active_to_inactive: transition probability from active to inactive

    Raises:
        ValueError: if `inactive_to_active` or `active_to_inactive` are not in :math:`[0, 1]` (NaN included)

    """

    def __init__(self, inactive_to_active: float = 0.5, active_to_inactive: float = 0.5):
        # Negated chained comparison: NaN fails every comparison, so it is rejected here instead of surfacing later
        # as an opaque NumPy error inside ``is_active``.
        if not (0 <= inactive_to_active <= 1 and 0 <= active_to_inactive <= 1):
            raise ValueError("Transition probabilities must be in [0, 1]")
        self.inactive_to_active = inactive_to_active
        self.active_to_inactive = active_to_inactive
        self._states = np.array([0, 1])  # inactive = 0, active = 1
        # Transition matrix: row ``s`` holds the probabilities of moving from state ``s`` to states 0 and 1.
        self._P = np.array(
            [
                [1 - inactive_to_active, inactive_to_active],
                [active_to_inactive, 1 - active_to_inactive],
            ]
        )
        # Always yields the active state (p=[0, 1]). It is kept as a random draw rather than a constant so the
        # consumption of the shared RNG is unchanged.
        self._current_state = iop.rng_numpy().choice(self._states, p=[0, 1])

    def is_active(self, iteration: int) -> bool:  # noqa: D102, ARG002
        # Advance the chain by one step, then report whether the new state is "active".
        self._current_state = iop.rng_numpy().choice(
            self._states,
            p=self._P[self._current_state],
        )
        return bool(self._current_state)
[docs]
class PoissonActivation(AgentActivationScheme):
    """
    Scheme in which the agent wakes up after random, Poisson-distributed periods of inactivity.

    Before each activation the agent stays inactive for a number of iterations drawn from a Poisson distribution
    with mean ``mean_interval``. It is then active for a single iteration, after which a new inactive period is drawn.

    Args:
        mean_interval: mean length of an inactive period

    Raises:
        ValueError: if `mean_interval` is negative

    """

    def __init__(self, mean_interval: float = 1.0):
        if mean_interval < 0:
            raise ValueError("`mean_interval` must be non-negative")
        self.mean_interval = mean_interval
        self._countdown = self._sample_inactive_length()

    def _sample_inactive_length(self) -> int:
        """Draw how many inactive iterations precede the next activation."""
        return int(iop.rng_numpy().poisson(self.mean_interval))

    def is_active(self, iteration: int) -> bool:  # noqa: D102, ARG002
        if self._countdown != 0:
            # Still waiting out the current inactive period.
            self._countdown -= 1
            return False
        # The period has elapsed: activate now and schedule the next inactive period.
        self._countdown = self._sample_inactive_length()
        return True
[docs]
class CyclicActivation(AgentActivationScheme):
    """
    Activation scheme that alternates between fixed-length active and inactive windows.

    Each cycle consists of ``active_for`` active iterations followed by ``inactive_for`` inactive ones. When
    ``inactive_for`` is omitted it equals ``active_for``. A non-zero ``offset`` shifts the phase of the cycle, so
    agents that share the same window lengths can have staggered active windows.

    Args:
        active_for: length of the active window of each cycle.
        inactive_for: length of the inactive window of each cycle. ``None`` means "same as ``active_for``".
        offset: phase shift added to the iteration index before locating it within the cycle.

    Raises:
        ValueError: if ``active_for``, ``inactive_for``, or ``offset`` is negative, if both window lengths are zero,
            or (from :meth:`is_active`) if ``iteration`` is negative.

    """

    def __init__(self, active_for: int, inactive_for: int | None = None, offset: int = 0):
        if inactive_for is None:
            inactive_for = active_for
        if active_for < 0 or inactive_for < 0:
            raise ValueError("active_for and inactive_for must be non-negative")
        if offset < 0:
            raise ValueError("offset must be non-negative")
        if not (active_for or inactive_for):
            raise ValueError("At least one of active_for or inactive_for must be positive")
        self.active_for, self.inactive_for, self.offset = active_for, inactive_for, offset

    def is_active(self, iteration: int) -> bool:  # noqa: D102
        if iteration < 0:
            raise ValueError("iteration must be non-negative")
        cycle_length = self.active_for + self.inactive_for
        # The first ``active_for`` positions of every cycle are the active ones.
        return (iteration + self.offset) % cycle_length < self.active_for
[docs]
class ClientSelectionScheme(ABC):
    """
    Scheme defining how to select a subset of available clients.

    Federated algorithms call :meth:`select` once per round with the currently active clients. Implementations
    should return a subset without modifying the input sequence.
    """

    @staticmethod
    def _validate_selection_size(
        num_selected_clients: int | None,
        fraction_selected_clients: float | None,
    ) -> None:
        """
        Validate that exactly one selection-size parameter is provided.

        Args:
            num_selected_clients: absolute number of clients to select, or ``None``.
            fraction_selected_clients: fraction of the client pool to select, or ``None``.

        Raises:
            ValueError: if neither or both size parameters are provided, or if the provided value is outside the
                accepted range (``num_selected_clients > 0`` or ``0 < fraction_selected_clients <= 1``).

        """
        if num_selected_clients is None and fraction_selected_clients is None:
            raise ValueError("Provide num_selected_clients or fraction_selected_clients")
        if num_selected_clients is not None and fraction_selected_clients is not None:
            raise ValueError("Provide only one of num_selected_clients or fraction_selected_clients")
        if num_selected_clients is not None and num_selected_clients <= 0:
            raise ValueError("num_selected_clients must be positive")
        if fraction_selected_clients is not None and not (0 < fraction_selected_clients <= 1):
            raise ValueError("fraction_selected_clients must be in (0, 1]")

    @staticmethod
    def _resolve_num_selected_clients(
        clients: Sequence[Agent],
        num_selected_clients: int | None,
        fraction_selected_clients: float | None,
    ) -> int:
        """
        Resolve the number of selected clients for a given input client pool.

        If ``num_selected_clients`` is provided, it is capped at ``len(clients)``. If
        ``fraction_selected_clients`` is provided, the count is ``floor(fraction * len(clients))``, raised to at
        least one client for a non-empty input. An empty input always resolves to zero.

        Args:
            clients: the client pool being selected from.
            num_selected_clients: absolute number of clients to select, or ``None``.
            fraction_selected_clients: fraction of the pool to select, used when ``num_selected_clients`` is ``None``.

        Returns:
            The number of clients to select, in ``[0, len(clients)]``.

        """
        n_clients = len(clients)
        if num_selected_clients is not None:
            return min(num_selected_clients, n_clients)
        # Round away floating-point noise before flooring. For example, ``0.29 * 100`` evaluates to
        # ``28.999999999999996`` and would otherwise truncate to 28 instead of 29.
        k = max(1, int(round(fraction_selected_clients * n_clients, 9)))  # type: ignore[operator]
        return min(k, n_clients)

    @staticmethod
    def _client_loss(client: Agent) -> float:
        """
        Evaluate a client's current local loss for selection.

        Empirical-risk costs are evaluated on all local samples (``indices="all"``) to avoid consuming a stochastic
        mini-batch during client selection. Other costs are evaluated at ``client.x`` directly.
        """
        if isinstance(client.cost, EmpiricalRiskCost):
            return client.cost.function(client.x, indices="all")
        return client.cost.function(client.x)

    @abstractmethod
    def select(
        self,
        clients: Sequence[Agent],
        iteration: int,
    ) -> list[Agent]:
        """
        Select a subset of available clients.

        Args:
            clients: available clients
            iteration: current iteration of algorithm execution

        Returns:
            The selected clients, as a new list.

        """
[docs]
class DataSizeSelection(ClientSelectionScheme):
    r"""
    Data-size weighted client selection :footcite:p:`Scheme_FedSampling`.

    The scheme samples clients without replacement with probability proportional to each client's local data size.
    The sampling probability for client :math:`i` is

    .. math::
        p_i = \frac{n_i}{\sum_{j \in \mathcal{C}} n_j},

    where :math:`n_i` is the client's inferred local data size and :math:`\mathcal{C}` is the client pool passed to
    :meth:`select`. When the resolved selection size covers the whole pool, every client is returned and no data
    sizes are inferred.

    Args:
        num_selected_clients: number of provided clients to sample.
        fraction_selected_clients: fraction of provided clients to sample.

    Raises:
        ValueError: if the selection size is invalid, any client's data size cannot be inferred, the inferred data
            sizes are negative or non-finite, or fewer clients have a positive data size than must be selected.

    .. footbibliography::

    """

    def __init__(
        self,
        *,
        num_selected_clients: int | None = None,
        fraction_selected_clients: float | None = None,
    ) -> None:
        self._validate_selection_size(num_selected_clients, fraction_selected_clients)
        self.num_selected_clients = num_selected_clients
        self.fraction_selected_clients = fraction_selected_clients

    def select(  # noqa: D102
        self,
        clients: Sequence[Agent],
        iteration: int,  # noqa: ARG002
    ) -> list[Agent]:
        if not clients:
            return []
        k = self._resolve_num_selected_clients(clients, self.num_selected_clients, self.fraction_selected_clients)
        if k == len(clients):
            return list(clients)
        clients_list = list(clients)
        data_sizes = np.array(
            [infer_client_data_size(client) for client in clients_list],
            dtype=np.float64,
        )
        # Validate before normalizing. A zero or infinite total would produce NaN probabilities (with a
        # RuntimeWarning), and sampling ``k`` distinct clients is impossible when fewer than ``k`` of them have a
        # non-zero probability.
        if not np.all(np.isfinite(data_sizes)) or np.any(data_sizes < 0):
            raise ValueError("Client data sizes must be finite and non-negative")
        n_positive = int(np.count_nonzero(data_sizes))
        if n_positive < k:
            raise ValueError(
                f"Cannot sample {k} clients by data size: only {n_positive} clients have a positive data size"
            )
        probabilities = data_sizes / data_sizes.sum()
        selected_indices = iop.rng_numpy().choice(len(clients_list), size=k, replace=False, p=probabilities)
        return [clients_list[int(index)] for index in selected_indices]
[docs]
class FairSelection(ClientSelectionScheme):
    r"""
    Fair client selection inspired by fairness-aware client selection :footcite:p:`Scheme_FairFedCS`.

    This is a simplified, count-based fairness rule. The scheme remembers how many times each client has been
    selected and, at every round, prefers the clients that have participated least. It therefore works as a
    participation-balancing exploration rule: under-represented clients keep getting chances instead of the same
    clients being picked over and over.

    At round :math:`t`, let :math:`c_i(t)` be the number of previous rounds in which client :math:`i` was selected.
    For the client pool :math:`\mathcal{C}_t` passed to :meth:`select`, the selected set is

    .. math::
        S_t \in \operatorname{arg\,min}_{S \subseteq \mathcal{C}_t,\ |S| = m}
        \sum_{i \in S} c_i(t),

    where :math:`m` is the resolved number of selected clients. Ties between clients with equal counts are resolved
    by the order in which the clients were passed to :meth:`select`. Once :math:`S_t` is chosen, the counts become

    .. math::
        c_i(t+1) = c_i(t) + \mathbf{1}\{i \in S_t\}.

    Args:
        num_selected_clients: number of provided clients to sample.
        fraction_selected_clients: fraction of provided clients to sample.

    Raises:
        ValueError: if the selection size is invalid.

    .. footbibliography::

    """

    def __init__(
        self,
        *,
        num_selected_clients: int | None = None,
        fraction_selected_clients: float | None = None,
    ) -> None:
        self._validate_selection_size(num_selected_clients, fraction_selected_clients)
        self.num_selected_clients = num_selected_clients
        self.fraction_selected_clients = fraction_selected_clients
        # Number of past rounds in which each client was selected; clients never seen count as 0.
        self._selection_counts: dict[Agent, int] = {}

    def select(  # noqa: D102
        self,
        clients: Sequence[Agent],
        iteration: int,  # noqa: ARG002
    ) -> list[Agent]:
        if not clients:
            return []
        n_pick = self._resolve_num_selected_clients(
            clients, self.num_selected_clients, self.fraction_selected_clients
        )
        counts = self._selection_counts
        if n_pick == len(clients):
            chosen = list(clients)
        else:
            # ``sorted`` is stable, so clients with equal counts keep their input order.
            chosen = sorted(clients, key=lambda candidate: counts.get(candidate, 0))[:n_pick]
        for candidate in chosen:
            counts[candidate] = counts.get(candidate, 0) + 1
        return chosen
[docs]
class HighLossSelection(ClientSelectionScheme):
    r"""
    High-loss client selection inspired by Power-of-Choice :footcite:p:`Scheme_PowerOfChoice`.

    Every candidate client's local loss is evaluated at the client's current local state ``x``. The clients with the
    largest losses are selected, and equal losses are ordered by a random permutation. In contrast to the original
    Power-of-Choice strategy, no extra communication is triggered to evaluate losses at the current server model.

    At round :math:`t`, for the client pool :math:`\mathcal{C}_t` passed to :meth:`select`, the selected set is

    .. math::
        S_t \in \operatorname{arg\,max}_{S \subseteq \mathcal{C}_t,\ |S| = m}
        \sum_{i \in S} F_i(x_i),

    where :math:`m` is the resolved number of selected clients, :math:`F_i` is client :math:`i`'s local cost, and
    :math:`x_i` is its current local state.

    Args:
        num_selected_clients: number of provided clients to sample.
        fraction_selected_clients: fraction of provided clients to sample.

    Raises:
        ValueError: if the selection size is invalid.
        RuntimeError: if any evaluated client's ``x`` has not been initialized.

    .. footbibliography::

    """

    def __init__(
        self,
        *,
        num_selected_clients: int | None = None,
        fraction_selected_clients: float | None = None,
    ) -> None:
        self._validate_selection_size(num_selected_clients, fraction_selected_clients)
        self.num_selected_clients = num_selected_clients
        self.fraction_selected_clients = fraction_selected_clients

    def select(  # noqa: D102
        self,
        clients: Sequence[Agent],
        iteration: int,  # noqa: ARG002
    ) -> list[Agent]:
        if not clients:
            return []
        budget = self._resolve_num_selected_clients(
            clients, self.num_selected_clients, self.fraction_selected_clients
        )
        if budget == len(clients):
            return list(clients)
        pool = list(clients)
        losses = [self._client_loss(candidate) for candidate in pool]
        # A random permutation gives every client a distinct rank used only to break loss ties.
        tie_ranks = iop.rng_numpy().permutation(len(pool)).tolist()
        scored = list(zip(losses, tie_ranks, pool))
        scored.sort(key=lambda entry: (-entry[0], entry[1]))
        return [candidate for _, _, candidate in scored[:budget]]
[docs]
class CompressionScheme(ABC):
    """Base class for schemes that compress messages before they are sent over the network."""

    @abstractmethod
    def compress(self, msg: Array) -> Array:
        """Return a new message holding the compressed version of *msg*."""

    def compressed_msg_size(self, msg: Array) -> int:
        """
        Return the size of the compressed version of *msg*.

        This default counts every entry of *msg*, i.e. it assumes compression does not reduce the size.
        """
        n_entries = np.prod(iop.shape(msg))  # TODO: use msg.size once Array exposes it
        return int(n_entries)
[docs]
class NoCompression(CompressionScheme):
    """Identity compression: messages are transmitted exactly as given."""

    def compress(self, msg: Array) -> Array:  # noqa: D102
        # Nothing to compress; hand back the very same object.
        unchanged = msg
        return unchanged
[docs]
class Quantization(CompressionScheme):
    r"""
    Scheme applying uniform quantization to the message.

    Given a message :math:`x` and quantization step :math:`\Delta`, the scheme returns

    .. math:: q(x) = \Delta \operatorname{round}(x / \Delta)

    where :math:`\operatorname{round}(\cdot)` represents rounding to the nearest integer (computed with
    :func:`numpy.rint`, which rounds halves to the nearest even integer).

    Args:
        quantization_step: quantization step :math:`\Delta`; must be positive and finite.

    Raises:
        ValueError: if ``quantization_step`` is not a positive, finite number.

    """

    def __init__(self, quantization_step: float):
        # The chained comparison also rejects NaN (all comparisons are False). An infinite step would turn every
        # entry into ``inf * rint(x / inf) = inf * 0 = nan``.
        if not (0 < quantization_step < np.inf):
            raise ValueError("`quantization_step` must be a positive float")
        self.quantization_step = quantization_step

    def compress(self, msg: Array) -> Array:  # noqa: D102
        # Quantize in float64 NumPy space, then convert back to the framework/type of the input message.
        msg_np = iop.to_numpy(msg, dtype=np.float64)
        return iop.to_array_like(self.quantization_step * np.rint(msg_np / self.quantization_step), msg)
[docs]
class StochasticQuantization(CompressionScheme):
    r"""
    Stochastic quantization used in QSGD :footcite:p:`Scheme_QSGD`.

    Each coordinate is mapped onto one of ``n_levels`` levels scaled by the message norm, choosing between the two
    neighbouring levels at random. The compressed message keeps the original shape and is unbiased in expectation.
    Given a message :math:`x` and :math:`s=\texttt{n\_levels}`, the quantizer computes

    .. math::
        a_i = \frac{s |x_i|}{\lVert x \rVert_2}, \qquad
        \ell_i = \lfloor a_i \rfloor, \qquad
        p_i = a_i - \ell_i.

    The quantization level is sampled as

    .. math::
        \xi_i =
        \begin{cases}
            \ell_i + 1, & \text{with probability } p_i, \\
            \ell_i, & \text{with probability } 1 - p_i,
        \end{cases}

    and the compressed coordinate is

    .. math::
        Q_s(x_i) = \lVert x \rVert_2 \operatorname{sign}(x_i) \frac{\xi_i}{s}.

    Args:
        n_levels: number of stochastic quantization levels. More levels give a finer grid and usually a smaller
            quantization error; fewer levels give coarser quantization and stronger compression noise.

    Raises:
        ValueError: if ``n_levels`` is not positive.

    Warning:
        The :math:`\ell_2` norm of every message is computed, which can be expensive for large messages or when
        messages live on accelerator devices.

    .. footbibliography::

    """

    def __init__(self, n_levels: int):
        if n_levels <= 0:
            raise ValueError("`n_levels` must be a positive integer")
        self.n_levels = n_levels

    def compress(self, msg: Array) -> Array:  # noqa: D102
        norm = float(iop.norm(msg))
        if norm == 0:
            # An all-zero message quantizes to zeros; no random draws are needed.
            return iop.zeros_like(msg)
        values = iop.to_numpy(msg, dtype=np.float64)
        levels = self.n_levels
        scaled = levels * np.abs(values) / norm  # a_i
        floor_levels = np.floor(scaled)  # l_i
        # Move up one level with probability equal to the fractional part p_i = a_i - l_i.
        round_up = iop.rng_numpy().random(size=scaled.shape) < (scaled - floor_levels)
        return iop.to_array_like(norm * np.sign(values) * (floor_levels + round_up) / levels, msg)
[docs]
class TopK(CompressionScheme):
    """
    Top-k compression which transmits only a subset of elements with largest magnitude.

    The parameter ``k`` can be either:

    - an ``int`` (Python or NumPy integer): transmit exactly ``k`` elements, or all of them if the message has fewer,
    - a ``float`` in :math:`(0, 1]`: transmit a fraction ``k`` of elements, rounded up.

    Message size is preserved by transmitting zeros in place of non-transmitted elements.

    Args:
        k: number (integer) or fraction (float) of elements to transmit.

    Raises:
        ValueError: if ``k`` is a float and not in :math:`(0, 1]`
        ValueError: if ``k`` is an int and less than 1

    Note:
        If ``k * n_elements < 1``, at least one element is still transmitted (for non-empty messages).

    """

    def __init__(self, k: float):
        if isinstance(k, (int, np.integer)):
            if k < 1:
                raise ValueError(f"If `k` is an integer, it must be at least 1, got {k}")
        elif not (0 < k <= 1):  # negated form also rejects NaN
            raise ValueError(f"If `k` is a float, it must be in (0, 1], got {k}")
        self.k = k
        self.is_integer_k = isinstance(self.k, (int, np.integer))

    def _num_kept(self, n_elements: int) -> int:
        """Return how many entries are transmitted for a message with ``n_elements`` entries."""
        if n_elements == 0:
            return 0
        if self.is_integer_k:
            return min(int(self.k), n_elements)
        # Round away float noise before ceil: ``0.07 * 100 == 7.000000000000001`` would otherwise keep 8 entries.
        return min(n_elements, max(1, int(np.ceil(round(self.k * n_elements, 9)))))

    def compress(self, msg: Array) -> Array:  # noqa: D102
        msg_np = iop.to_numpy(msg)
        flat_msg = msg_np.reshape(-1)
        k_count = self._num_kept(flat_msg.size)
        compressed_flat = np.zeros_like(flat_msg)
        if k_count > 0:  # argpartition cannot be called with kth=-0 on an empty message
            idx = np.argpartition(np.abs(flat_msg), -k_count)[-k_count:]
            compressed_flat[idx] = flat_msg[idx]
        return iop.to_array_like(compressed_flat.reshape(msg_np.shape), msg)

    def compressed_msg_size(self, msg: Array) -> int:
        """Compute the size of the compressed version of *msg* (the number of entries :meth:`compress` keeps)."""
        n_elements = int(np.prod(iop.shape(msg)))  # replace with msg.size once available
        return self._num_kept(n_elements)
[docs]
class RandK(CompressionScheme):
    """
    Rand-k compression which transmits only a random subset of elements.

    The parameter ``k`` can be either:

    - an ``int`` (Python or NumPy integer): transmit exactly ``k`` elements chosen uniformly at random (without
      replacement), or all of them if the message has fewer,
    - a ``float`` in :math:`(0, 1]`: transmit a fraction ``k`` of elements, rounded up.

    Message size is preserved by transmitting zeros in place of non-transmitted elements.

    Args:
        k: number (integer) or fraction (float) of elements to transmit.

    Raises:
        ValueError: if ``k`` is a float and not in :math:`(0, 1]`
        ValueError: if ``k`` is an int and less than 1

    Note:
        If ``k * n_elements < 1``, at least one element is still transmitted (for non-empty messages).

    """

    def __init__(self, k: float):
        if isinstance(k, (int, np.integer)):
            if k < 1:
                raise ValueError(f"`k` must be at least 1 if an integer, got {k}")
        elif not (0 < k <= 1):  # negated form also rejects NaN
            raise ValueError(f"`k` must be in (0, 1], got {k}")
        self.k = k
        self.is_integer_k = isinstance(self.k, (int, np.integer))

    def _num_kept(self, n_elements: int) -> int:
        """Return how many entries are transmitted for a message with ``n_elements`` entries."""
        if n_elements == 0:
            return 0
        if self.is_integer_k:
            return min(int(self.k), n_elements)
        # Round away float noise before ceil: ``0.07 * 100 == 7.000000000000001`` would otherwise keep 8 entries.
        return min(n_elements, max(1, int(np.ceil(round(self.k * n_elements, 9)))))

    def compress(self, msg: Array) -> Array:  # noqa: D102
        msg_np = iop.to_numpy(msg)
        flat_msg = msg_np.reshape(-1)
        n_elements = flat_msg.size
        k_count = self._num_kept(n_elements)
        compressed_flat = np.zeros_like(flat_msg)
        if k_count > 0:
            # Uniformly random support of size k_count, drawn without replacement.
            idx = iop.rng_numpy().choice(n_elements, size=k_count, replace=False)
            compressed_flat[idx] = flat_msg[idx]
        return iop.to_array_like(compressed_flat.reshape(msg_np.shape), msg)

    def compressed_msg_size(self, msg: Array) -> int:
        """Compute the size of the compressed version of *msg* (the number of entries :meth:`compress` keeps)."""
        n_elements = int(np.prod(iop.shape(msg)))  # replace with msg.size once available
        return self._num_kept(n_elements)
[docs]
class DropScheme(ABC):
    """Base class for schemes deciding whether a message is lost in transit."""

    @abstractmethod
    def should_drop(self) -> bool:
        """Return ``True`` if the message currently being sent should be dropped, ``False`` otherwise."""
[docs]
class NoDrops(DropScheme):
    """Lossless network: every message is delivered."""

    def should_drop(self) -> bool:  # noqa: D102
        # No message is ever lost under this scheme.
        dropped = False
        return dropped
[docs]
class GilbertElliott(DropScheme):
    """
    Drop scheme based on the Gilbert-Elliott model :footcite:p:`Scheme_GilbertElliott`.

    The Gilbert-Elliott model is characterized by a Markov chain with two states (good and bad), which
    can stay the same or transition into each other. In the bad state message drops occur with probability
    `drop_rate`, while in the good state no message drops occur. The initial state is drawn uniformly at random,
    and the chain is advanced by one step on every call to :meth:`should_drop`.

    Args:
        drop_rate: message drop rate while in the bad state
        bad_to_good: transition probability from bad to good state
        good_to_bad: transition probability from good to bad state

    Raises:
        ValueError: if `drop_rate`, `bad_to_good` or `good_to_bad` are not in :math:`[0, 1]` (NaN included)

    .. footbibliography::

    """

    def __init__(self, drop_rate: float, bad_to_good: float = 0.5, good_to_bad: float = 0.5):
        # Negated chained comparisons also reject NaN. A NaN drop_rate would otherwise pass validation and make
        # ``random() < drop_rate`` always False, silently disabling all drops.
        if not (0 <= drop_rate <= 1):
            raise ValueError("Drop rate must be in [0, 1]")
        if not (0 <= bad_to_good <= 1 and 0 <= good_to_bad <= 1):
            raise ValueError("Transition probabilities `bad_to_good` and `good_to_bad` must be in [0, 1]")
        self.drop_rate = drop_rate
        self.bad_to_good = bad_to_good
        self.good_to_bad = good_to_bad
        self._states = np.array([0, 1])  # good = 0, bad = 1
        # Transition matrix: row ``s`` holds the probabilities of moving from state ``s`` to states 0 and 1.
        self._P = np.array([[1 - good_to_bad, good_to_bad], [bad_to_good, 1 - bad_to_good]])
        self._current_state = iop.rng_numpy().choice(self._states)  # initialize uniformly at random

    def should_drop(self) -> bool:  # noqa: D102
        # Evolve the Markov chain by one step.
        self._current_state = iop.rng_numpy().choice(self._states, p=self._P[self._current_state])
        if not self._current_state:
            # Good state: never drop (and draw no extra random number).
            return False
        return bool(iop.rng_numpy().random() < self.drop_rate)
# TODO: remove the `framework` and `device` parameters from the noise schemes once `iop` is refactored
[docs]
class NoiseScheme(ABC):
    """Base class for schemes describing the noise that corrupts transmitted messages."""

    @abstractmethod
    def make_noise(
        self, shape: tuple[int, ...], framework: SupportedFrameworks, device: SupportedDevices
    ) -> Array | None:
        """
        Produce a noise array.

        Args:
            shape: shape of the noise array to generate.
            framework: array framework the noise should be created in.
            device: device the noise should be created on.

        Returns:
            The noise array, or ``None`` when the scheme adds no noise.

        """
[docs]
class NoNoise(NoiseScheme):
    """Noise scheme for noiseless transmission: no noise array is ever produced."""

    def make_noise(  # noqa: D102
        self,
        _: tuple[int, ...],
        _framework: SupportedFrameworks,
        _device: SupportedDevices,
    ) -> Array | None:
        # Shape, framework and device are irrelevant; ``None`` stands for "no noise".
        noise = None
        return noise
[docs]
class GaussianNoise(NoiseScheme):
    """
    Scheme generating normal noise.

    For each message entry, the scheme draws an independent sample from a normal distribution with mean ``mean``
    and standard deviation ``std``.

    Args:
        mean: mean of the normal noise; must be finite.
        std: standard deviation of the normal noise; must be finite and non-negative.

    Raises:
        ValueError: if ``mean`` or ``std`` is not finite (NaN or infinite), or if ``std`` is negative.

    """

    def __init__(self, mean: float, std: float):
        # NaN would slip past a plain ``std < 0`` check; NaN/inf parameters would make every noise entry NaN/inf.
        if not (np.isfinite(mean) and np.isfinite(std)):
            raise ValueError("Mean and standard deviation (std) must be finite for Gaussian noise.")
        if std < 0:
            raise ValueError("Standard deviation (std) must be non-negative for Gaussian noise.")
        self.mean = mean
        self.std = std

    def make_noise(  # noqa: D102
        self, shape: tuple[int, ...], framework: SupportedFrameworks, device: SupportedDevices
    ) -> Array:
        return iop.normal(framework=framework, device=device, shape=shape, mean=self.mean, std=self.std)