Source code for decent_bench.algorithms.p2p._prox_skip
import random
from dataclasses import dataclass
import decent_bench.utils.interoperability as iop
from decent_bench.agents import Agent
from decent_bench.algorithms.utils import initial_states
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates
from ._p2p_algorithm import P2PAlgorithm
[docs]
@tags("peer-to-peer", "gradient-based")
@dataclass(eq=False)
class ProxSkip(P2PAlgorithm):
    """
    Proximal Skip with local gradient steps :footcite:p:`Alg_Prox_Skip`.

    Each iteration, every active agent takes a local gradient step corrected by
    its control variable ``y`` to form a prediction ``z``. A single coin with
    success probability ``comm_probability`` is then drawn for the whole network.
    On success, active agents average their predictions with the mixing matrix
    ``W_a = I - (I - W) / (2 * chi)`` and update ``y``. Otherwise they adopt
    their prediction as the new iterate and ``y`` is left unchanged.

    Args:
        iterations: Total number of iterations (T)
        step_size: Step size alpha > 0 for primal updates
        aux_step_size: Step size beta > 0 for dual updates
        comm_probability: Probability 0 < p <= 1 that an iteration communicates;
            with probability 1 - p communication is skipped
        chi: chi >= 1, averaging weight parameter for weighted averaging during communication
        x0: Initial parameters (optional)
        name: Algorithm name (default "ProxSkip")

    .. footbibliography::

    """

    iterations: int = 100  # Total number of iterations (T)
    step_size: float = 0.01  # Step size alpha > 0 for primal updates
    aux_step_size: float = 0.01  # Step size beta > 0 for dual updates
    comm_probability: float = 0.7  # Communication probability 0 < p <= 1
    chi: float = 1.0  # chi >= 1, averaging weight parameter
    x0: InitialStates = None  # Initial parameters (optional)
    name: str = "ProxSkip"

    def __post_init__(self) -> None:
        """
        Validate parameters.

        Raises:
            ValueError: If ``step_size`` or ``aux_step_size`` is a number that is not
                strictly positive, if ``comm_probability`` is not in (0, 1], or if
                ``chi`` is not >= 1. NaN fails every one of these checks.

        """
        # Only plain numeric step sizes are checked; values of other types pass
        # through unvalidated, as before. ``int`` is included so that e.g.
        # ``step_size=0`` is rejected instead of being silently accepted.
        # Negated comparisons (``not x > 0``) make NaN fail the check too.
        if isinstance(self.step_size, (int, float)) and not self.step_size > 0:
            raise ValueError("step_size must be positive")
        if isinstance(self.aux_step_size, (int, float)) and not self.aux_step_size > 0:
            raise ValueError("aux_step_size must be positive")
        if not 0 < self.comm_probability <= 1:
            raise ValueError("comm_probability must be in (0, 1]")
        if not self.chi >= 1:
            raise ValueError("chi must be >= 1")

    def initialize(self, network: P2PNetwork) -> None:
        """
        Initialize the agents and the mixing matrix (Algorithm 1, line 1).

        Each agent starts at ``x_i^0`` (resolved from ``self.x0`` by
        ``initial_states``), with dual variable ``y_i^0 = 0`` and prediction
        ``z_i^0 = x_i^0``. The mixing matrix ``W_a = I - (I - W) / (2 * chi)`` is
        stored in ``self.W_a``.

        Args:
            network: Network whose agents are initialized.

        """
        x0 = initial_states(self.x0, network)
        agents = network.agents()

        # W_a = I - 1/(2*chi) * (I - W). The identity is built with the framework
        # and device of the first agent's cost (presumably matching W's).
        n = len(agents)
        W = network.weights  # noqa: N806
        I = iop.eye(n=n, framework=agents[0].cost.framework, device=agents[0].cost.device)  # noqa: E741, N806
        self.W_a = I - (1.0 / (2.0 * self.chi)) * (I - W)

        for agent in agents:
            aux_vars = {
                "y": iop.zeros_like(x0[agent]),  # dual/control variable y_i^t, starts at 0
                "z": x0[agent],  # prediction variable z_i^t, starts at x_i^0
            }
            agent.initialize(x=x0[agent], aux_vars=aux_vars)

    def step(self, network: P2PNetwork, _: int) -> None:
        """
        Run one ProxSkip iteration (Algorithm 1, lines 3-11).

        1. Every active agent takes a local gradient step and forms ``z_i``.
        2. One Bernoulli(``comm_probability``) coin is drawn for the whole network.
        3. On success, active agents broadcast ``z_i``, then average and update ``y_i``.
           Otherwise, active agents set ``x_i = z_i`` without communicating.

        Args:
            network: Network the algorithm runs on.
            _: Iteration index (unused).

        """
        # Query the active set once, so that prediction, broadcast and update act on
        # the same agents within this iteration. Materializing a list also keeps this
        # safe if ``active_agents`` returns a one-shot iterator.
        active = list(network.active_agents())

        # Lines 4-5: local gradient step and prediction.
        for agent in active:
            self._compute_prediction(agent)

        # theta ~ Bernoulli(p), P(theta = 1) = p; one draw shared by all agents.
        communicate = random.random() < self.comm_probability

        if communicate:
            # Line 7: every broadcast is issued before any agent processes its messages.
            for agent in active:
                network.broadcast(agent, agent.aux_vars["z"])
            # Lines 7-8: weighted averaging and dual update.
            for agent in active:
                self._communication_update(agent)
        else:
            # Line 10: skip communication.
            for agent in active:
                self._no_communication_update(agent)

    def _compute_prediction(self, agent: Agent) -> None:
        """
        Take a local gradient step and store the prediction (Algorithm 1, lines 4-5).

        Computes ``z_i = x_i - alpha * g_i - y_i``, where ``g_i`` is the agent's cost
        gradient at ``x_i``, and writes it to ``agent.aux_vars["z"]``.

        Args:
            agent: Agent whose prediction is updated.

        """
        # g_i^t. Whether it is stochastic depends on the cost's ``gradient``
        # implementation.
        gradient = agent.cost.gradient(agent.x)
        agent.aux_vars["z"] = agent.x - self.step_size * gradient - agent.aux_vars["y"]

    def _communication_update(self, agent: Agent) -> None:
        """
        Average received predictions and update the dual variable (Algorithm 1, lines 7-8).

        Used when this iteration's coin selected communication:

            x_i^{t+1} = W_a[i, i] z_i^t + sum_{j in messages} W_a[i, j] z_j^t
            y_i^{t+1} = y_i^t + beta * (z_i^t - x_i^{t+1})

        Args:
            agent: Agent to update; neighbor predictions ``z_j`` are read from
                ``agent.messages()``.

        """
        z_i = agent.aux_vars["z"]
        # The sum runs over received messages only. This equals the full row sum
        # sum_j W_a[i, j] z_j whenever the messages cover every j with W[i, j] != 0,
        # because off-diagonal W_a[i, j] = W[i, j] / (2 * chi).
        # NOTE(review): if messages() can still hold z_j from an earlier round
        # (e.g. from a currently inactive neighbor), that stale value is used
        # here. Confirm the buffer semantics.
        weighted_sum = self.W_a[agent, agent] * z_i
        for j, z_j in agent.messages().items():
            weighted_sum += self.W_a[agent, j] * z_j
        agent.x = weighted_sum
        agent.aux_vars["y"] += self.aux_step_size * (z_i - agent.x)

    def _no_communication_update(self, agent: Agent) -> None:
        """
        Keep the local prediction when communication is skipped (Algorithm 1, line 10).

        Sets ``x_i^{t+1} = z_i^t``; the dual variable ``y_i`` is left unchanged.

        Args:
            agent: Agent to update.

        """
        # x and z may now refer to the same object. That is harmless here: z is
        # always rebound, never mutated in place, by ``_compute_prediction``.
        agent.x = agent.aux_vars["z"]