"""Utilities for algorithm initialization and general helpers."""
import decent_bench.utils.interoperability as iop
from decent_bench.agents import Agent
from decent_bench.networks import FedNetwork, Network
from decent_bench.utils.array import Array
from decent_bench.utils.types import InitialStates
# Public API of this module: the names exported by a star-import.
# NOTE(review): "uniform_initialization" is exported here but is not defined in the
# visible part of this file — confirm it is defined elsewhere in the module.
__all__ = [
"initial_states",
"normal_initialization",
"pytorch_initialization",
"uniform_initialization",
]
# Initial states built from a user-supplied x0.
def initial_states(x0: InitialStates, network: Network) -> "dict[Agent, Array]":
    """
    Build per-agent initial states, for use in :meth:`~decent_bench.algorithms.Algorithm.initialize`.

    Args:
        x0 (InitialStates):
            - ``None``: initialize all agents to zeros, using each agent's native shape/framework/device.
            - ``Array``: apply the same state to all agents.
            - ``dict[Agent, Array]``: explicit per-agent states.
        network (Network): network instance containing the target agents.

    Returns:
        dict[Agent, Array]: mapping from each agent in ``network.graph`` to its initial state,
        converted to that agent's ``cost.framework`` and ``cost.device``.

    Raises:
        TypeError: if ``x0`` is a dict containing a key that is not an :class:`Agent`.
        ValueError: if ``x0`` is neither ``None``, a dict, nor a supported array type; if a provided
            state's shape differs from the corresponding agent's ``cost.shape``; if a dict ``x0`` is
            missing an entry for an agent returned by ``network.agents()``; or if the server state of a
            :class:`~decent_bench.networks.FedNetwork` must be inferred but the network has no clients.

    Notes:
        For :class:`~decent_bench.networks.FedNetwork`, explicit ``x0`` dictionaries must provide client entries.
        If the server entry is missing, it is inferred as the average of client initial states.

        Keys in ``x0`` not referring to agents in the network are dropped from the result, but they are
        still validated (key type and state shape) before being dropped.

    """
    if x0 is None:
        x0s = {a: iop.zeros(a.cost.framework, a.cost.device, a.cost.shape) for a in network.graph}
    elif isinstance(x0, dict):
        x0s = _initial_states_from_dict(x0, network)
    elif iop.is_supported_array_type(x0):
        for a in network.graph:
            _check_state_shape(a, x0)
        # Every agent maps to the very same ``x0`` object here.
        # NOTE(review): whether agents end up sharing storage depends on whether
        # ``iop.to_array`` copies its input — confirm if states are mutated in place.
        x0s = dict.fromkeys(network.graph, x0)
    else:
        # ValueError (not TypeError) is kept so that existing callers catching it keep working.
        raise ValueError(f"Invalid x0: expected None, Array, or dict[Agent, Array], got {type(x0)}")

    # ignore keys that are not network agents and normalize to the target framework/device
    return {a: iop.to_array(x0s[a], framework=a.cost.framework, device=a.cost.device) for a in network.graph}


def _check_state_shape(agent: Agent, state: "Array") -> None:
    """Raise ``ValueError`` unless ``state`` has the same shape as ``agent.cost.shape``."""
    actual_shape = iop.shape(state)
    if actual_shape != agent.cost.shape:
        raise ValueError(
            f"Shape of initial state for agent {agent} does not match cost shape: "
            f"expected {agent.cost.shape}, got {actual_shape}"
        )


def _initial_states_from_dict(x0: "dict[Agent, Array]", network: Network) -> "dict[Agent, Array]":
    """Resolve an explicit ``x0`` dict against the network's agents, inferring a missing server state."""
    # match by agent.id to handle deep-copied dicts whose keys are different instances
    x0_by_id = {}
    for a, v in x0.items():
        if not isinstance(a, Agent):
            raise TypeError(f"``x0`` must have keys of type Agent, got {type(a)}")
        _check_state_shape(a, v)
        x0_by_id[a.id] = v

    # NOTE(review): for a FedNetwork this loop must not see the server, otherwise a missing
    # server entry would be rejected here instead of inferred below — presumably
    # ``network.agents()`` excludes the server there; confirm against ``FedNetwork``.
    x0s = {}
    for a in network.agents():
        if a.id not in x0_by_id:
            raise ValueError(f"x0 not provided for agent {a}")
        x0s[a] = x0_by_id[a.id]

    if isinstance(network, FedNetwork):
        server = network.server()
        if server.id in x0_by_id:
            x0s[server] = x0_by_id[server.id]
        else:
            client_initial_states = [x0s[a] for a in network.clients()]
            if not client_initial_states:
                # Without this guard the averaging below fails with a bare IndexError.
                raise ValueError("Cannot infer the server initial state: the network has no clients")
            # Accumulate into a fresh zero array so no client state is modified in place.
            client_state_sum = iop.zeros_like(client_initial_states[0])
            for client_initial_state in client_initial_states:
                client_state_sum += client_initial_state
            x0s[server] = client_state_sum / len(client_initial_states)

    return x0s
# Initial states sampled from a normal distribution.
def normal_initialization(
    network: Network,
    mean: float = 0.0,
    std: float = 1.0,
) -> "dict[Agent, Array]":
    """
    Draw a normally distributed initial state for every agent in the network.

    Args:
        network (Network): network whose ``graph`` is iterated to obtain the target agents.
        mean (float): mean forwarded to ``iop.normal`` for every sampled state.
        std (float): standard deviation forwarded to ``iop.normal`` for every sampled state.

    Returns:
        dict[Agent, Array]: mapping from each agent to its own sampled state
        (one separate ``iop.normal`` call per agent).

    Notes:
        Each state is requested with the agent's own ``cost.shape``, ``cost.framework``, and ``cost.device``.

    """
    sampled_states = {}
    for agent in network.graph:
        cost = agent.cost
        sampled_states[agent] = iop.normal(
            shape=cost.shape,
            framework=cost.framework,
            device=cost.device,
            mean=mean,
            std=std,
        )
    return sampled_states
# Initial states taken from PyTorch model parameters.
def pytorch_initialization(
    network: Network,
    all_same: bool = False,
) -> "dict[Agent, Array] | Array":
    """
    Build initial states using the ``PyTorchCost.model`` initialization routine.

    Reads the initialized model parameters of an agent's
    :class:`~decent_bench.costs.PyTorchCost` and converts them to an array on the
    cost's own framework and device.

    Both return forms can be passed as ``x0`` to :func:`initial_states`: the dict as
    explicit per-agent states, the single array as one state shared by all agents.

    Args:
        network (Network): network instance containing the target agents.
            All agents must have a :class:`~decent_bench.costs.PyTorchCost`.
        all_same (bool): if ``True``, return only the state of the first agent yielded by
            ``network.agents()``, to be used for all agents.

    Returns:
        dict[Agent, Array] | Array: with ``all_same=False``, a mapping from each agent in
        ``network.graph`` to its initial state; with ``all_same=True``, the first agent's state alone.

    Raises:
        TypeError: if an inspected agent's cost is not a :class:`~decent_bench.costs.PyTorchCost`.
            With ``all_same=True`` only the first agent is inspected.
        ValueError: if ``all_same=True`` and the network has no agents.

    """
    if all_same:
        # use the first agent's initialized model as the shared initial state for all agents;
        # the ``None`` default keeps a bare StopIteration from escaping on an empty network
        first_agent = next(iter(network.agents()), None)
        if first_agent is None:
            raise ValueError("Cannot build a shared initial state: the network has no agents")
        return _pytorch_model_state(first_agent)

    return {a: _pytorch_model_state(a) for a in network.graph}


def _pytorch_model_state(agent: Agent) -> "Array":
    """Return ``agent``'s model parameters as an array, or raise ``TypeError`` if its cost is not a PyTorchCost."""
    # Kept function-local as in the original code (presumably to avoid a circular import — confirm).
    from decent_bench.costs import PyTorchCost  # noqa: PLC0415

    cost = agent.cost
    if not isinstance(cost, PyTorchCost):
        raise TypeError(f"Agent {agent} has cost of type {type(cost).__name__!r}, expected PyTorchCost.")
    return iop.to_array(
        cost._get_model_parameters(),  # noqa: SLF001
        framework=cost.framework,
        device=cost.device,
    )