# Source code for decent_bench.agents

from __future__ import annotations

import bisect
import contextlib
from collections.abc import Iterator, Mapping, Sequence
from types import MappingProxyType
from typing import Any, Self, cast
from uuid import UUID, uuid4

import decent_bench.utils.interoperability as iop
from decent_bench.costs import Cost, EmpiricalRiskCost
from decent_bench.schemes import AgentActivationScheme, AlwaysActive
from decent_bench.utils.array import Array


[docs] class Agent: """ Agent with local cost function, activation scheme, state snapshot period, and optional data. At initialization, the agent is assigned a unique id (accessible via ``Agent.id``) which serves as its hash. The agent can also be assigned an index (``Agent.index``). This assignment is performed when initializing a network, and is useful to index arrays by Agent or for user-friendly representation of Agents. Args: cost: local cost function; once assigned, it should not be modified activation: activation scheme to model synchrony/asynchrony; defaults to synchrony (activate at all iterations) state_snapshot_period: how often to record the agent's state when executing an algorithm data: dictionary for arbitrary agent data Raises: ValueError: if ``state_snapshot_period`` is not a positive int """ _id: UUID _index: int def __new__(cls, *_: object, _id: UUID | None = None, **__: object) -> Self: """Ensure agent id and index are defined early (including during unpickling/deepcopy).""" obj = super().__new__(cls) obj._id = uuid4() if _id is None else _id obj._index = -1 return obj def __init__( self, cost: Cost, activation: AgentActivationScheme | None = None, state_snapshot_period: int = 1, data: dict[str, Any] | None = None, ): if state_snapshot_period <= 0: raise ValueError("state_snapshot_period must be a positive integer") self._index = -1 self._cost = cost self._activation = AlwaysActive() if activation is None else activation self._state_snapshot_period = state_snapshot_period self.data = {} if data is None else data self._current_x: Array | None = None self._x_history: AgentHistory = AgentHistory() self._auxiliary_variables: dict[str, Any] = {} self._received_messages = ReceivedMessages() self._n_x_updates = 0 self._n_sent_messages: float = 0 self._n_received_messages: float = 0 self._n_sent_messages_dropped: float = 0 self._n_times_selected = 0 self._is_server = False self._n_function_calls: float = 0 self._n_gradient_calls: float = 0 
self._n_hessian_calls: float = 0 self._n_proximal_calls: float = 0 self._no_count_depth: int = 0 # Nesting counter; counting disabled when > 0 cost.function = self._call_counting_function # type: ignore[method-assign] cost.gradient = self._call_counting_gradient # type: ignore[method-assign] cost.hessian = self._call_counting_hessian # type: ignore[method-assign] cost.proximal = self._call_counting_proximal # type: ignore[method-assign] @property def id(self) -> UUID: """Unique id for the agent.""" return self._id @property def index(self) -> int: """Agent index within a network, -1 if unassigned.""" return self._index @index.setter def index(self, value: int) -> None: self._index = value @property def cost(self) -> Cost: """ Local cost function. Alias: :class:`f`, :class:`loss` """ return self._cost # Aliases for cost f = cost loss = cost @property def x(self) -> Array: """ Local optimization variable x. Raises: RuntimeError: if x is retrieved before being set or initialized """ if self._current_x is None: raise RuntimeError("x must be initialized before being accessed") return self._current_x @x.setter def x(self, x: Array) -> None: self._n_x_updates += 1 self._current_x = x @property def state_snapshot_period(self) -> int: """Number of iterations between snapshots of the agent's state.""" return self._state_snapshot_period
[docs] def messages(self, channel: str = "default") -> Mapping[Agent, Array]: """Received messages with ``channel``, keyed by sender.""" return self._received_messages.by_channel(channel)
[docs] def message(self, sender: Agent, channel: str = "default") -> Array: """Received message from ``sender`` with ``channel``.""" return self._received_messages.get(sender, channel)
@property def aux_vars(self) -> dict[str, Any]: """Auxiliary optimization variables used by algorithms that require more variables than x.""" return self._auxiliary_variables
[docs] def initialize( self, *, x: Array | None = None, aux_vars: dict[str, Any] | None = None, ) -> None: """ Initialize local variables and messages before running an algorithm. Args: x: initial x aux_vars: initial auxiliary variables Raises: ValueError: if initialized x has incorrect shape """ self._x_history = AgentHistory() self._auxiliary_variables = {} self._received_messages = ReceivedMessages() self._n_x_updates = 0 self._n_sent_messages = 0 self._n_received_messages = 0 self._n_sent_messages_dropped = 0 self._n_times_selected = 0 self._n_function_calls = 0 self._n_gradient_calls = 0 self._n_hessian_calls = 0 self._n_proximal_calls = 0 if x is not None: if iop.shape(x) != self.cost.shape: raise ValueError(f"Initialized x has shape {iop.shape(x)}, expected {self.cost.shape}") self._x_history[0] = iop.copy(x) self._current_x = iop.copy(x) if aux_vars is not None: self._auxiliary_variables = {k: self._copy_aux_var(v) for k, v in aux_vars.items()}
def _copy_aux_var(self, value: object) -> object: if isinstance(value, Mapping): return {k: self._copy_aux_var(v) for k, v in value.items()} return iop.copy(cast("Array", value)) def _snapshot(self, iteration: int, force: bool = False) -> None: """ Snapshot the agent's state. This saves the current optimization variable x every :attr:`state_snapshot_period` iterations. Warning: This method is called automatically by algorithms during execution. It should not be called manually during algorithm execution, as this may lead to unexpected behaviour of the agent's state history and metrics. Args: iteration: Algorithm iteration force: If true, skip :attr:`state_snapshot_period` and forcefully snapshot the agent state. Useful when saving the agents final state. """ if (force or iteration % self.state_snapshot_period == 0) and self._current_x is not None: self._x_history[iteration] = iop.copy(self._current_x) def _call_counting_function(self, x: Array, *args: Any, **kwargs: Any) -> float: # noqa: ANN401 # Call the function first so "batch_used" is populated for EmpiricalRiskCost before counting function calls res = self._cost.__class__.function(self.cost, x, *args, **kwargs) if self._no_count_depth > 0: return res if isinstance(self._cost, EmpiricalRiskCost): self._n_function_calls += len(self._cost.batch_used) else: self._n_function_calls += 1 return res def _call_counting_gradient(self, x: Array, *args: Any, **kwargs: Any) -> Array: # noqa: ANN401 res = self._cost.__class__.gradient(self.cost, x, *args, **kwargs) if self._no_count_depth > 0: return res if isinstance(self._cost, EmpiricalRiskCost): self._n_gradient_calls += len(self._cost.batch_used) else: self._n_gradient_calls += 1 return res def _call_counting_hessian(self, x: Array, *args: Any, **kwargs: Any) -> Array: # noqa: ANN401 res = self._cost.__class__.hessian(self.cost, x, *args, **kwargs) if self._no_count_depth > 0: return res if isinstance(self._cost, EmpiricalRiskCost): self._n_hessian_calls += 
len(self._cost.batch_used) else: self._n_hessian_calls += 1 return res def _call_counting_proximal(self, x: Array, penalty: float, *args: Any, **kwargs: Any) -> Array: # noqa: ANN401 res = self._cost.__class__.proximal(self.cost, x, penalty, *args, **kwargs) if self._no_count_depth > 0: return res if isinstance(self._cost, EmpiricalRiskCost): self._n_proximal_calls += len(self._cost.batch_used) else: self._n_proximal_calls += 1 return res def __index__(self) -> int: """Enable using agent as index, for example ``W[a1, a2]`` instead of ``W[a1.index, a2.index]``.""" return self._index def __repr__(self) -> str: """Human readable representation of the agent.""" return f"Agent {hash(self._id) if self._index == -1 else self._index} (instance_id={id(self)})" def __getnewargs_ex__(self) -> tuple[tuple[()], dict[str, UUID]]: """Preserve Agent._id in pickle/deepcopy by passing it to :meth:`__new__`.""" return (), {"_id": self._id} def __hash__(self) -> int: """Hash of the agent, which coincides with the unique identifier.""" return hash(self._id) def __eq__(self, other: object) -> bool: """Agent instances are equal if they have the same id.""" if not isinstance(other, Agent): return NotImplemented return self._id == other._id
[docs] @staticmethod @contextlib.contextmanager def no_count(agents: Sequence[Agent]) -> Iterator[None]: """ Context manager that disables call counting for a sequence of agents. Use this when computing metrics or other operations that should not be counted as algorithm function/gradient calls. Args: agents: sequence of agents to disable call counting for Example:: with Agent.no_count(agents): value = metric.compute(problem, agents, iteration) """ for agent in agents: agent._no_count_depth += 1 # noqa: SLF001 try: yield finally: for agent in agents: agent._no_count_depth -= 1 # noqa: SLF001
class ReceivedMessages:
    """Store of received messages, organised as one sender->message bucket per channel."""

    def __init__(self) -> None:
        # channel name -> {sender -> latest message}
        self._messages: dict[str, dict[Agent, Array]] = {}

    def put(self, sender: Agent, msg: Array, channel: str = "default") -> None:
        """Save ``msg`` as the message from ``sender`` on ``channel``, replacing any earlier one."""
        self._messages.setdefault(channel, {})[sender] = msg

    def get(self, sender: Agent, channel: str = "default") -> Array:
        """Look up the message stored for ``sender`` on ``channel``."""
        bucket = self._messages[channel]
        return bucket[sender]

    def has(self, sender: Agent, channel: str = "default") -> bool:
        """Tell whether ``channel`` currently holds a message from ``sender``."""
        if channel not in self._messages:
            return False
        return sender in self._messages[channel]

    def by_channel(self, channel: str = "default") -> Mapping[Agent, Array]:
        """Give a read-only sender->message view of ``channel`` (empty if the channel is unknown)."""
        bucket = self._messages.get(channel, {})
        return MappingProxyType(bucket)

    def clear(self, sender: Agent | Sequence[Agent] | None = None, channel: str | None = None) -> None:
        """
        Remove stored messages, optionally restricted to given sender(s) and/or a channel.

        Args:
            sender: a single agent, a sequence of agents, or ``None`` for all senders
            channel: channel name, or ``None`` for all channels

        """
        store = self._messages

        # No sender restriction: drop a whole channel, or everything.
        if sender is None:
            if channel is not None:
                store.pop(channel, None)
            else:
                store.clear()
            return

        senders = list(sender) if not isinstance(sender, Agent) else [sender]

        # Sender(s) within one specific channel.
        if channel is not None:
            bucket = store.get(channel)
            if bucket is None:
                return
            for who in senders:
                bucket.pop(who, None)
            if not bucket:
                store.pop(channel, None)
            return

        # Sender(s) across every channel; channels left empty are removed afterwards.
        for bucket in store.values():
            for who in senders:
                bucket.pop(who, None)
        drained = [name for name, bucket in store.items() if len(bucket) == 0]
        for name in drained:
            store.pop(name, None)
class AgentHistory:
    """
    Ordered history of an agent's optimization variable x, indexed by algorithm iteration.

    Snapshots are stored sparsely — only iterations at which ``Agent._snapshot`` was called are recorded.
    Lookups for iterations between snapshots fall back to the nearest preceding snapshot.

    Internally, snapshots are kept in a dict for exact lookup and a parallel sorted list of iterations
    for predecessor search via :mod:`bisect`.
    """

    def __init__(self) -> None:
        self._x_history: dict[int, Array] = {}
        self._sorted_keys: list[int] = []  # Always ascending; mirrors the keys of _x_history

    def max(self) -> int:
        """
        Return the latest iteration for which a snapshot exists.

        Raises:
            ValueError: if no snapshots have been recorded yet.

        """
        if not self._sorted_keys:
            raise ValueError("No history available")
        return self._sorted_keys[-1]

    def min(self) -> int:
        """
        Return the earliest iteration for which a snapshot exists.

        Raises:
            ValueError: if no snapshots have been recorded yet.

        """
        if not self._sorted_keys:
            raise ValueError("No history available")
        return self._sorted_keys[0]

    def items(self) -> Iterator[tuple[int, Array]]:
        """Yield ``(iteration, x)`` pairs for every snapshot, in ascending iteration order."""
        return ((iteration, self._x_history[iteration]) for iteration in self._sorted_keys)

    def values(self) -> Iterator[Array]:
        """Yield the x snapshot for every recorded iteration, in ascending iteration order."""
        return (self._x_history[iteration] for iteration in self._sorted_keys)

    def keys(self) -> list[int]:
        """Return a sorted list of all iterations for which a snapshot has been recorded."""
        # Hand out a copy so callers cannot disturb the internal ordering.
        return self._sorted_keys.copy()

    def set_x(self, iteration: int, x: Array) -> None:
        """
        Record ``x`` at ``iteration``, replacing any existing snapshot at that iteration.

        Also available as ``history[iteration] = x``.

        Args:
            iteration: The algorithm iteration the snapshot belongs to.
            x: The value to store; it is stored as given, without copying.

        Raises:
            ValueError: if ``iteration`` is negative.

        """
        if iteration < 0:
            raise ValueError(f"Iteration must be non-negative, got {iteration}")
        # Only new iterations enter the sorted list; overwriting keeps the existing position.
        if iteration not in self._x_history:
            bisect.insort(self._sorted_keys, iteration)
        self._x_history[iteration] = x

    def get_x(self, iteration: int) -> Array:
        """
        Return x at ``iteration``, falling back to the nearest preceding snapshot if needed.

        Snapshots are not necessarily recorded at every iteration (controlled by
        :attr:`Agent.state_snapshot_period`). When the exact iteration is not found, the closest snapshot
        with an iteration number ``<= iteration`` is returned instead. For example, if snapshots exist at
        iterations 0, 10, 20 and iteration 23 is requested, the snapshot from iteration 20 is returned.

        Passing ``-1`` returns the latest snapshot.

        Also available as ``value = history[iteration]``.

        Args:
            iteration: The algorithm iteration to retrieve x for, or ``-1`` for the latest snapshot.

        Raises:
            ValueError: if ``iteration`` is before the first recorded snapshot, if ``iteration < -1``,
                or if ``iteration == -1`` and no snapshots have been recorded yet.

        """
        if iteration < -1:
            raise ValueError(f"Iteration must be non-negative or -1 for the latest snapshot, got {iteration}")
        if iteration == -1:
            return self._x_history[self.max()]
        if iteration in self._x_history:
            return self._x_history[iteration]

        # Binary search for the closest previous snapshot
        idx = bisect.bisect_right(self._sorted_keys, iteration) - 1
        if idx < 0:
            raise ValueError(f"No snapshot available for iteration {iteration}")
        return self._x_history[self._sorted_keys[idx]]

    def __setitem__(self, iteration: int, x: Array) -> None:
        """Record ``x`` at ``iteration``, replacing any existing snapshot at that iteration."""
        self.set_x(iteration, x)

    def __getitem__(self, iteration: int) -> Array:
        """
        Return x at ``iteration``.

        Falls back to the nearest preceding snapshot when no exact match exists.
        See :meth:`get_x` for full semantics.
        """
        return self.get_x(iteration)

    def __contains__(self, iteration: int) -> bool:
        """Return ``True`` if an exact snapshot was recorded at ``iteration``."""
        return iteration in self._x_history

    def __iter__(self) -> Iterator[int]:
        """Iterate over recorded iteration numbers in ascending order."""
        return iter(self._sorted_keys)

    def __len__(self) -> int:
        """Return the number of snapshots recorded."""
        return len(self._x_history)

    def __repr__(self) -> str:
        """Human readable representation of the history, in ascending iteration order."""
        # Build from the sorted view so the output matches iteration order, not insertion order.
        return repr(dict(self.items()))