Source code for decent_bench.metrics._metrics_view

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from uuid import UUID

import networkx as nx

from decent_bench.agents import Agent, AgentHistory
from decent_bench.costs import Cost
from decent_bench.metrics import utils
from decent_bench.networks import FedNetwork, Network, P2PNetwork


[docs] @dataclass(frozen=True, eq=False) class AgentMetricsView: """Immutable view of agent that exposes useful properties for calculating metrics.""" id: UUID cost: Cost x_history: AgentHistory n_x_updates: int n_function_calls: float n_gradient_calls: float n_hessian_calls: float n_proximal_calls: float n_sent_messages: float n_received_messages: float n_sent_messages_dropped: float n_times_selected: int
[docs] @staticmethod def from_agent(agent: Agent) -> AgentMetricsView: """Create from agent.""" return AgentMetricsView( id=agent.id, cost=agent.cost, x_history=agent._x_history, # noqa: SLF001 n_x_updates=agent._n_x_updates, # noqa: SLF001 n_function_calls=agent._n_function_calls, # noqa: SLF001 n_gradient_calls=agent._n_gradient_calls, # noqa: SLF001 n_hessian_calls=agent._n_hessian_calls, # noqa: SLF001 n_proximal_calls=agent._n_proximal_calls, # noqa: SLF001 n_sent_messages=agent._n_sent_messages, # noqa: SLF001 n_received_messages=agent._n_received_messages, # noqa: SLF001 n_sent_messages_dropped=agent._n_sent_messages_dropped, # noqa: SLF001 n_times_selected=agent._n_times_selected, # noqa: SLF001 )
[docs] class NetworkType(Enum): """Supported network types for metric views.""" P2P = "p2p" FEDERATED = "federated"
[docs] @dataclass(frozen=True) class NetworkMetricsView: """ Immutable view of a network that exposes useful properties for calculating metrics. The underlying data structure is a frozen ``nx.Graph`` whose nodes are ``AgentMetricsView`` objects. The object is created using ``from_network`` passing a ``FedNetwork`` or ``P2PNetwork``. Available methods are: - ``agents()`` and ``connected_agents(agent)`` - Fed-only: ``clients()``, ``server()``, and ``coordinator()`` - P2P-only: ``neighbors(agent)`` """ graph: nx.Graph[AgentMetricsView] network_type: NetworkType _server: AgentMetricsView | None = None
[docs] @staticmethod def from_network(network: Network) -> NetworkMetricsView: """Create a network metrics view from a network.""" snapshot_agents = network.snapshot_agents() agent_views = [AgentMetricsView.from_agent(agent) for agent in snapshot_agents] agent_map = dict(zip(snapshot_agents, agent_views, strict=True)) relabeled_graph = nx.relabel_nodes(network.graph, agent_map, copy=True) frozen_graph = nx.freeze(relabeled_graph.copy()) if isinstance(network, FedNetwork): server_view = agent_map[network.server()] return NetworkMetricsView( graph=frozen_graph, network_type=NetworkType.FEDERATED, _server=server_view, ) if isinstance(network, P2PNetwork): return NetworkMetricsView(graph=frozen_graph, network_type=NetworkType.P2P) raise ValueError(f"Unsupported network type: {type(network)!r}")
[docs] def agents(self) -> list[AgentMetricsView]: """Return agents exposed by network semantics (clients for federated, all nodes for P2P).""" if self.network_type is NetworkType.FEDERATED: return [agent for agent in self.graph.nodes if agent is not self._server] return list(self.graph.nodes)
[docs] def clients(self) -> list[AgentMetricsView]: """Return clients in a federated network (alias of agents()).""" if self.network_type is not NetworkType.FEDERATED: raise ValueError("clients() is only available for federated networks") return self.agents()
[docs] def server(self) -> AgentMetricsView: """Return the server node in a federated network.""" if self.network_type is not NetworkType.FEDERATED or self._server is None: raise ValueError("server() is only available for federated networks") return self._server
[docs] def coordinator(self) -> AgentMetricsView: """Alias for server().""" return self.server()
[docs] def connected_agents(self, agent: AgentMetricsView) -> list[AgentMetricsView]: """Return agents in the network connected to an agent.""" if agent not in self.graph: raise ValueError("agent is not in the network") return list(self.graph.neighbors(agent))
[docs] def neighbors(self, agent: AgentMetricsView) -> list[AgentMetricsView]: """Return neighbors in a peer-to-peer network.""" if self.network_type is not NetworkType.P2P: raise ValueError("neighbors() is only available for p2p networks") return self.connected_agents(agent)
@property def iterations(self) -> list[int]: """List of iterations reached by any agent (plus server) in the network.""" return utils.all_sorted_iterations(list(self.graph.nodes))