Source code for decent_bench.algorithms.p2p._dinno
from dataclasses import dataclass
from typing import TYPE_CHECKING
import decent_bench.utils.interoperability as iop
from decent_bench.agents import Agent
from decent_bench.algorithms.utils import initial_states
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates
from ._p2p_algorithm import P2PAlgorithm
if TYPE_CHECKING:
from decent_bench.utils.array import Array
[docs]
@tags("peer-to-peer", "gradient-based")
@dataclass(eq=False)
class DiNNO(P2PAlgorithm):
    r"""
    Distributed Neural Network Optimization (DiNNO) algorithm :footcite:p:`Alg_DiNNO`.

    Each outer iteration, every active agent broadcasts its current parameters to its
    neighbors, updates its dual variable from the parameters it received, and then
    approximately minimizes its augmented Lagrangian with a fixed number of gradient steps.

    Args:
        iterations: Total number of outer iterations (K)
        step_size: Step size for primal updates
        num_local_steps: Number of inner iterations (B) for approximate primal update
        penalty: Penalty parameter (rho) for augmented Lagrangian
        x0: Initial parameters (optional)
        name: Algorithm name (default "DiNNO")

    .. footbibliography::

    """

    iterations: int = 100  # Total number of outer iterations (K)
    step_size: float = 0.01  # Step size of the inner gradient steps
    num_local_steps: int = 5  # Number of inner iterations (B) for approximate primal update
    penalty: float = 0.5  # Penalty parameter (rho) for augmented Lagrangian
    x0: InitialStates = None  # Initial parameters (optional)
    name: str = "DiNNO"

    def __post_init__(self) -> None:
        """
        Validate parameters.

        Raises:
            ValueError: If ``num_local_steps``, ``step_size`` or ``penalty`` is not positive.

        """
        if self.num_local_steps <= 0:
            raise ValueError("num_local_steps must be positive")
        # Integers are accepted wherever a float is expected, so they must be range-checked
        # as well; checking only ``float`` let ``step_size=0`` or ``step_size=-1`` through.
        if isinstance(self.step_size, (int, float)) and self.step_size <= 0:
            raise ValueError("step_size must be positive")
        if self.penalty <= 0:
            raise ValueError("penalty must be positive")

    def initialize(self, network: P2PNetwork) -> None:
        """Give every agent its initial parameters and a zero dual variable ``p``."""
        # Initialize agents (lines 2-5)
        self.x0 = initial_states(self.x0, network)
        for i in network.agents():
            # Initialize dual variable p_i^0 = 0 (line 3)
            p_0 = iop.zeros_like(self.x0[i])
            i.initialize(
                x=self.x0[i],  # theta_i^0 = theta_initial (line 4)
                aux_vars={"p": p_0},
            )

    def step(self, network: P2PNetwork, _: int) -> None:
        """
        Run one outer iteration: communicate, update duals, then update primals.

        The three phases run as separate passes over the active agents so that every agent
        has broadcast its current parameters before any agent overwrites its own.
        """
        # Main optimization loop (line 7)
        # Step 1: Communication - send theta_i^k to neighbors (line 8)
        for i in network.active_agents():
            network.broadcast(i, i.x)

        # Step 2: Dual variable update (line 10) - Equation (4a)
        for i in network.active_agents():
            self._auxiliary_update(i)

        # Step 3: Approximate primal update (lines 11-15) - Equation (4b)
        for i in network.active_agents():
            self._local_training(i)

    @staticmethod
    def _sum_received(agent: Agent) -> "Array | None":
        """
        Return the sum of all messages held by ``agent``, or ``None`` if it holds none.

        The sum is accumulated into a private copy, so the stored messages are never
        modified. They are read again later in the same iteration, and augmented
        assignment directly on the first message would overwrite it with the running sum.
        """
        total = None
        for received in agent.messages().values():
            if total is None:
                total = iop.copy(received)
            else:
                total += received
        return total

    def _auxiliary_update(self, agent: Agent) -> None:
        """Update the dual variable ``p`` of ``agent`` from its received messages."""
        # p_i^(k+1) = p_i^k + rho * sum_{j in N_i}(theta_i^k - theta_j^k)
        neighbor_sum = self._sum_received(agent)
        if neighbor_sum is None:
            # Nothing received: the consensus error is an empty sum, p stays as it is.
            return
        consensus_error = agent.x * len(agent.messages()) - neighbor_sum
        agent.aux_vars["p"] += self.penalty * consensus_error

    def _local_training(self, agent: Agent) -> None:
        """Run ``num_local_steps`` gradient steps on the augmented Lagrangian of ``agent``."""
        # Initialize psi^0 = theta_i^k (line 11); copy so the in-place steps below
        # do not modify agent.x before the final assignment.
        psi = iop.copy(agent.x)

        # Everything that does not depend on psi is computed once, outside the inner loop.
        neighbor_sum = self._sum_received(agent)
        grad_dual = agent.aux_vars["p"]  # Term 2: grad(theta^T p_i^(k+1)) = p_i^(k+1)
        if neighbor_sum is not None:
            num_neighbors = len(agent.messages())
            half_neighbor_sum = neighbor_sum / 2.0
            half_self = agent.x / 2
        else:
            num_neighbors = 0
            half_neighbor_sum = None
            half_self = None

        # Approximate the primal update with B iterations (lines 12-14)
        for _ in range(self.num_local_steps):
            # Term 1: grad l(psi; D_i)
            grad_loss = agent.cost.gradient(psi)

            # Term 3: grad[rho sum_{j in N_i} ||theta - (theta_i^k + theta_j^k)/2||^2]
            #       = 2 rho sum_{j in N_i} (theta - (theta_i^k + theta_j^k)/2)
            if half_neighbor_sum is not None:
                consensus_term = (psi - half_self) * num_neighbors - half_neighbor_sum
                # Note: factor of 2 from derivative of squared norm
                grad_consensus: Array | float = 2.0 * self.penalty * consensus_term
            else:
                grad_consensus = 0.0

            # Gradient step: psi^(tau+1) = psi^tau - step_size * grad L_augmented
            total_gradient = grad_loss + grad_dual + grad_consensus
            psi -= self.step_size * total_gradient

        # Update primal variable theta_i^(k+1) = psi^B (line 15)
        agent.x = psi