Source code for decent_bench.algorithms._algorithm
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, final
from decent_bench.networks import Network
if TYPE_CHECKING:
from decent_bench.agents import Agent
[docs]
class Algorithm[NetworkT: Network](ABC):
"""Base class for decentralized algorithms."""
def __post_init__(self) -> None:
"""Optional hook to be called by dataclasses after __init__.""" # noqa: D401
return
def __init_subclass__(cls, **kwargs: dict[str, Any]) -> None:
"""Validate `iterations` for all subclasses."""
super().__init_subclass__(**kwargs)
# override __post_init__ to inject `iterations` validation
original_post_init: Callable[[Algorithm[NetworkT]], None] | None = getattr(cls, "__post_init__", None)
def __post_init__(self: "Algorithm[NetworkT]") -> None: # noqa: N807
# inject `iterations` validation
if self.iterations <= 0:
raise ValueError("`iterations` must be positive")
# add subclass's __post_init__ if any
if original_post_init:
original_post_init(self)
setattr(cls, "__post_init__", __post_init__) # noqa: B010
iterations: int
"""Number of iterations to run the algorithm for."""
@property
@abstractmethod
def name(self) -> str:
"""Name of the algorithm."""
[docs]
@abstractmethod
def initialize(self, network: NetworkT) -> None:
"""
Initialize the algorithm.
Args:
network: provides the agents and topology for this algorithm.
"""
[docs]
@abstractmethod
def step(self, network: NetworkT, iteration: int) -> None:
"""
Perform one iteration of the algorithm.
Args:
network: provides the agents and topology for this algorithm.
iteration: current iteration number.
"""
[docs]
@abstractmethod
def cleanup_agents(self, network: NetworkT) -> Iterable["Agent"]:
"""
Return the agents whose auxiliary variables should be cleared.
Args:
network: provides the agents and topology for this algorithm.
"""
[docs]
def cleanup(self, network: NetworkT) -> None:
"""
Clean up the algorithm state by clearing auxiliary variables from agents.
This method is used to free up memory used by auxiliary variables that are not needed after training.
Can be overridden to control what gets cleaned up.
Note:
Override :meth:`~decent_bench.algorithms.Algorithm.cleanup_agents` to control which
agents are cleaned up.
Args:
network: provides the agents and topology for this algorithm.
"""
for agent in self.cleanup_agents(network):
if agent.aux_vars is not None:
agent.aux_vars.clear()
@final
def _snapshot_agents(self, network: NetworkT, iteration: int) -> None:
for i in network.snapshot_agents():
# Forcefully save a snapshot on the final iteration
i._snapshot(iteration=iteration, force=iteration == self.iterations) # noqa: SLF001
[docs]
@final
def run(
self,
network: NetworkT,
start_iteration: int = 0,
progress_callback: Callable[[int], None] | None = None,
) -> None:
"""
Run the algorithm.
This method first calls :meth:`initialize`, then :meth:`step` for the specified number of iterations.
Optionally call :meth:`cleanup` after :meth:`run` to clear auxiliary variables
and free up memory.
Args:
network: provides the agents and topology for this algorithm.
start_iteration: iteration number to start from, used when resuming from a checkpoint. If greater than 0,
:meth:`initialize` will be skipped.
progress_callback: optional callback to report progress after each iteration.
Raises:
ValueError: if start_iteration is not in [0, iterations]
Warning:
Do not override this method. Instead, override :meth:`initialize` and :meth:`step` as needed.
Note:
The algorithm saves the agents' states every :attr:`~decent_bench.agents.Agent.state_snapshot_period`.
"""
if start_iteration < 0 or start_iteration > self.iterations:
raise ValueError(
f"Invalid start_iteration {start_iteration} for algorithm with {self.iterations} iterations"
)
if start_iteration == 0:
self.initialize(network)
for k in range(start_iteration, self.iterations):
network._step(k) # noqa: SLF001
self.step(network, k)
# Already completed the iteration, so snapshot with k+1 to indicate the state after iteration k
self._snapshot_agents(network, k + 1)
if progress_callback is not None:
progress_callback(k)