# Source code for decent_bench.algorithms.p2p._led
from dataclasses import dataclass
import decent_bench.utils.interoperability as iop
from decent_bench.agents import Agent
from decent_bench.algorithms.utils import initial_states
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates
from ._p2p_algorithm import P2PAlgorithm
[docs]
@tags("peer-to-peer", "gradient-based")
@dataclass(eq=False)
class LED(P2PAlgorithm):
    """
    Local Exact-Diffusion (LED) algorithm :footcite:p:`Alg_LED`.

    Every call to :meth:`step` runs three phases over the active agents: ``num_local_steps`` local primal
    updates, a broadcast of the resulting local iterate followed by a weighted mix (diffusion), and a
    dual variable update.

    Args:
        iterations: Total number of communication rounds (r)
        num_local_steps: Number of local updates (tau)
        step_size: Step size alpha for gradient steps
        aux_step_size: Step size beta for dual variable
        x0: Initial parameters (optional)
        name: Algorithm name (default "LED")

    .. footbibliography::

    """

    iterations: int = 100  # Total number of communication rounds (r)
    num_local_steps: int = 5  # Number of local updates (tau)
    step_size: float = 0.01  # Step size alpha for gradient steps
    aux_step_size: float = 0.01  # Step size beta for dual variable
    x0: InitialStates = None  # Initial parameters (optional)
    name: str = "LED"

    def __post_init__(self) -> None:
        """
        Validate parameters.

        Raises:
            ValueError: If ``num_local_steps`` is not positive, or if ``step_size`` or ``aux_step_size``
                is a non-positive ``int`` or ``float``.

        """
        if self.num_local_steps <= 0:
            raise ValueError("num_local_steps must be positive")
        # Check ints as well as floats: with a float-only guard, values such as ``step_size=0`` or
        # ``step_size=-1`` would silently pass validation. Other types are left unchecked.
        if isinstance(self.step_size, (int, float)) and self.step_size <= 0:
            raise ValueError("step_size must be positive")
        if isinstance(self.aux_step_size, (int, float)) and self.aux_step_size <= 0:
            raise ValueError("aux_step_size must be positive")

    def initialize(self, network: P2PNetwork) -> None:
        """
        Initialize agents with x_i^0, y_i^0, and phi_i,0^r.

        Also stores ``network.weights`` on the instance as ``self.W`` for use in the diffusion step.

        Args:
            network: Network whose agents are initialized and whose weights are used for mixing

        """
        x0 = initial_states(self.x0, network)
        self.W = network.weights
        for i in network.agents():
            # Initialize y_i^0 = 0 (simplified initialization)
            y_0 = iop.zeros_like(x0[i])
            # Initialize auxiliary variables
            aux_vars = {
                "y": y_0,  # Dual variable y_i^r
                "phi": x0[i],  # phi_i,tau^r (to be broadcasted); reset from x at the start of each round
            }
            i.initialize(
                x=x0[i],
                aux_vars=aux_vars,
            )

    def step(self, network: P2PNetwork, _: int) -> None:
        """
        Run one LED round over the currently active agents.

        Each phase is completed for all active agents before the next one starts, so every agent
        broadcasts its final local iterate before any agent mixes.

        Args:
            network: Network providing the active agents and the broadcast mechanism
            _: Iteration index (unused)

        """
        # Step 1: Local primal updates (tau steps)
        for i in network.active_agents():
            self._local_primal_updates(i)
        # Step 2: Diffusion (communication and mixing)
        for i in network.active_agents():
            network.broadcast(i, i.aux_vars["phi"])
        for i in network.active_agents():
            self._diffusion(i)
        # Step 3: Local dual update
        for i in network.active_agents():
            self._local_dual_update(i)

    def _local_primal_updates(self, agent: Agent) -> None:
        """
        Step 1: Local primal updates (tau steps).

        Starting from a copy of ``agent.x``, repeats ``num_local_steps`` times:
        ``phi <- phi - (step_size * gradient(phi) + aux_step_size * y)``.
        The result is stored in ``agent.aux_vars["phi"]``; ``agent.x`` and ``y`` are not modified here.
        """
        # Set phi_i,0^r = x_i^r; copying keeps the in-place updates below from touching agent.x
        agent.aux_vars["phi"] = iop.copy(agent.x)
        # Perform tau local updates (Equation 2a); y stays fixed during the local steps
        for _ in range(self.num_local_steps):
            gradient = agent.cost.gradient(agent.aux_vars["phi"])
            agent.aux_vars["phi"] -= self.step_size * gradient + self.aux_step_size * agent.aux_vars["y"]

    def _diffusion(self, agent: Agent) -> None:
        """
        Step 2: Diffusion.

        Sets ``agent.x`` to the weighted sum of the agent's own ``phi`` and the ``phi`` values found in
        its received messages, using the entries of ``self.W`` as weights.
        """
        # Own contribution first; the product is a new value, so the += below does not alter phi
        weighted_sum = self.W[agent, agent] * agent.aux_vars["phi"]
        for j, phi_j in agent.messages().items():
            weighted_sum += self.W[agent, j] * phi_j
        agent.x = weighted_sum

    def _local_dual_update(self, agent: Agent) -> None:
        """
        Step 3: Local dual update.

        Update the dual variable for exact tracking (Equation 2c): adds the difference between the
        agent's local iterate ``phi`` and its mixed iterate ``agent.x`` to ``y``.
        """
        agent.aux_vars["y"] += agent.aux_vars["phi"] - agent.x