Source code for decent_bench.algorithms.p2p._kgt
from dataclasses import dataclass
import decent_bench.utils.interoperability as iop
from decent_bench.agents import Agent
from decent_bench.algorithms.utils import initial_states
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates
from ._p2p_algorithm import P2PAlgorithm
_STATE_CHANNEL = "state"
_GRADIENT_TRACKER_CHANNEL = "gradient_tracker"
[docs]
@tags("peer-to-peer", "gradient-based")
@dataclass(eq=False)
class KGT(P2PAlgorithm):
"""
K-GT: Gradient Sum Tracking algorithm :footcite:p:`Alg_K_GT`.
Args:
iterations: Total number of communication rounds (T)
num_local_steps: Number of local gradient steps (K)
step_size: Local step size (eta_c)
aux_step_size: Communication step size (eta_s)
x0: Initial parameters (optional)
name: Algorithm name (default "K-GT")
.. footbibliography::
"""
iterations: int = 100 # Total number of communication rounds (T)
num_local_steps: int = 5 # Number of local gradient steps (K)
step_size: float = 0.01 # Local step size (eta_c)
aux_step_size: float = 0.01 # Communication step size (eta_s)
x0: InitialStates = None # Initial parameters (optional)
name: str = "K-GT"
def __post_init__(self) -> None:
"""
Validate parameters.
Raises:
ValueError: If any of the parameters are invalid (e.g., non-positive iterations, local_steps,
step_size, penalty, or alpha).
"""
if self.num_local_steps <= 0:
raise ValueError("local_steps must be positive")
if isinstance(self.step_size, float) and self.step_size <= 0:
raise ValueError("step_size must be positive")
if isinstance(self.aux_step_size, float) and self.aux_step_size <= 0:
raise ValueError("aux_step_size must be positive")
def initialize(self, network: P2PNetwork) -> None:
"""
Initialize agents with x_i^(0) and c_i^(0).
Algorithm 1, line 2.
"""
x0 = initial_states(self.x0, network)
# Get mixing matrix weights
self.W = network.weights
for i in network.agents():
# Initialize c_i^(0) according to line 2
# In practice, we can initialize c_i^(0) = 0 (as noted in the paper)
c_0 = iop.zeros_like(x0[i])
i.initialize(
x=x0[i],
aux_vars={
"c": c_0,
"x_before_local": x0[i],
"z_i": x0[i],
},
)
def step(self, network: P2PNetwork, _: int) -> None:
# Main algorithm loop (lines 4-12)
# Step 1: Local training phase (lines 5-7)
for i in network.active_agents():
self._local_training(i)
# Step 2: Compute z_i and store in aux_vars
multiplier = self.num_local_steps * self.aux_step_size * self.step_size
for i in network.active_agents():
# Compute z_i^(t) = (1/K eta_c)(x_i^(t) - x_i^(t+K))
i.aux_vars["z_i"] = (i.aux_vars["x_before_local"] - i.x) / (self.num_local_steps * self.step_size)
msg = i.aux_vars["x_before_local"] - multiplier * i.aux_vars["z_i"]
i.aux_vars["msg"] = msg
# Step 3: Communication phase
network.broadcast(i, msg, channel=_STATE_CHANNEL)
network.broadcast(i, i.aux_vars["z_i"], channel=_GRADIENT_TRACKER_CHANNEL)
# Step 5: Update tracking variable and model parameters (lines 9-10)
for i in network.active_agents():
self._update_tracking_and_params(i)
def _local_training(self, agent: Agent) -> None:
"""
Perform K local gradient steps.
Algorithm 1, lines 5-7.
"""
# Store x_i^(t) before local steps for z_i computation
agent.aux_vars["x_before_local"] = agent.x
x_k = iop.copy(agent.x)
# Perform K local gradient steps (line 6)
# x_i^(t)+k+1 = x_i^(t)+k - eta_c(grad F_i(x_i^(t)+k; xi_i^(t)+k) + c_i^(t))
for _ in range(self.num_local_steps):
gradient = agent.cost.gradient(x_k)
x_k -= self.step_size * (gradient + agent.aux_vars["c"])
agent.x = x_k
def _update_tracking_and_params(
self,
agent: Agent,
) -> None:
"""
Update tracking variable c_i and model parameters x_i.
Algorithm 1, lines 9-10.
"""
# Get z_i (already computed)
z_i = agent.aux_vars["z_i"]
# Line 9: Update tracking variable
# c_i^(t+1) = c_i^(t) - z_i^(t) + ∑_j w_ij z_j^(t)
# Line 10: Update model parameters
weighted_neighbor_z = self.W[agent, agent] * z_i
weighted_sum = self.W[agent, agent] * agent.aux_vars["msg"]
received_state_updates = agent.messages(_STATE_CHANNEL)
received_tracking_updates = agent.messages(_GRADIENT_TRACKER_CHANNEL)
received_both = set(received_state_updates).intersection(received_tracking_updates)
for j in received_both:
weighted_sum += self.W[agent, j] * received_state_updates[j]
weighted_neighbor_z += self.W[agent, j] * received_tracking_updates[j]
agent.aux_vars["c"] = agent.aux_vars["c"] - z_i + weighted_neighbor_z
agent.x = weighted_sum