Source code for decent_bench.algorithms.p2p._gt_sarah

from dataclasses import dataclass
from typing import TYPE_CHECKING

import decent_bench.utils.interoperability as iop
from decent_bench.agents import Agent
from decent_bench.algorithms.utils import initial_states
from decent_bench.costs import EmpiricalRiskCost
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates

from ._p2p_algorithm import P2PAlgorithm

# Message channel on which agents broadcast their gradient-tracking variable ``y``.
_GRADIENT_TRACKER_CHANNEL: str = "gradient_tracker"
# Message channel on which agents broadcast their local iterate ``x``.
_STATE_CHANNEL: str = "state"


@tags("peer-to-peer", "gradient-tracking")
@dataclass(eq=False)
class GT_SARAH(P2PAlgorithm):  # noqa: N801
    """
    GT-SARAH: Gradient Tracking with SARAH variance reduction :footcite:p:`Alg_GT_SARAH`.

    Each call to :meth:`step` runs one outer loop:

    1. a full-batch gradient evaluation;
    2. one gradient-tracking + consensus round;
    3. ``num_local_steps`` inner rounds, each driven by the SARAH mini-batch estimator.

    Args:
        iterations: Total number of outer loops (S)
        num_local_steps: Number of inner loop iterations (q)
        step_size: Step size (alpha) for updates
        x0: Initial parameters (optional)
        name: Algorithm name (default "GT-SARAH")

    Raises:
        ValueError: If ``num_local_steps`` is not positive, or a numeric ``step_size``
            is not positive.
        TypeError: If any agent's cost function is not an instance of EmpiricalRiskCost.

    .. footbibliography::

    """

    iterations: int = 100  # S: number of outer loops
    num_local_steps: int = 5  # q: number of inner loop iterations
    step_size: float = 0.01  # alpha: step size
    x0: InitialStates = None  # Initial parameters (optional)
    name: str = "GT-SARAH"

    def __post_init__(self) -> None:
        """
        Validate parameters.

        Raises:
            ValueError: If ``num_local_steps`` is not positive, or if ``step_size`` is a
                plain number (``int`` or ``float``) that is not positive.

        """
        if self.num_local_steps <= 0:
            raise ValueError("num_local_steps must be positive")
        # ``int`` is checked as well as ``float`` so that e.g. ``step_size=0`` is
        # rejected instead of silently producing a frozen iterate.
        if isinstance(self.step_size, (int, float)) and self.step_size <= 0:
            raise ValueError("step_size must be positive")

    def initialize(self, network: P2PNetwork) -> None:
        """
        Initialize agents with x_i^{0,1}, y_i^{0,1}, v_i^{-1,1}.

        Resolves ``self.x0`` via ``initial_states``. Caches the mixing weights
        ``network.weights`` as ``self.W`` (indexed ``W[agent, neighbor]`` by the
        update helpers). Seeds every agent's auxiliary variables.

        Args:
            network: Peer-to-peer network whose agents are initialized.

        Raises:
            TypeError: If any agent's cost function is not an instance of EmpiricalRiskCost.

        """
        self.x0 = initial_states(self.x0, network)
        self.W = network.weights
        for i in network.agents():
            # The SARAH estimator needs mini-batch gradients and ``batch_used``,
            # which only EmpiricalRiskCost provides.
            if not isinstance(i.cost, EmpiricalRiskCost):
                raise TypeError("GT-SARAH only supports EmpiricalRiskCost instances.")
            # y_i^{0,1} = 0 and v_i^{-1,1} = 0
            y0 = iop.zeros_like(self.x0[i])
            v_minus1 = iop.zeros_like(self.x0[i])
            # "v" and "v_prev" start out sharing the same zero array.
            # This is safe because the update helpers rebind these entries
            # rather than mutating them in place.
            aux_vars = {
                "y": y0,  # Gradient tracking variable y_i^{0,1} = 0
                "v": v_minus1,  # Current SARAH estimator
                "v_prev": v_minus1,  # v_i^{-1,1} = 0 (for outer loop tracking)
                "x_prev": self.x0[i],  # Store x_{t-1} for SARAH
            }
            # Hand the initial iterate and auxiliary state to the agent.
            i.initialize(
                x=self.x0[i],
                aux_vars=aux_vars,
            )

    def step(self, network: P2PNetwork, _: int) -> None:
        """
        Run one outer loop of GT-SARAH (Algorithm 2.1).

        Args:
            network: Peer-to-peer network whose active agents are updated.
            _: Unused integer required by the ``step`` interface (presumably the
                iteration index; confirm against P2PAlgorithm).

        """
        # Line 2: full-batch gradient v_i^{0,s} at the start of the outer loop.
        for i in network.active_agents():
            self._compute_batch_grad(i)
        # Lines 3-4: gradient-tracking update followed by the consensus/state update.
        self._tracking_and_consensus_round(network)
        # Lines 7-10: inner loop driven by the SARAH estimator.
        self._inner_loop(network)

    def _tracking_and_consensus_round(self, network: P2PNetwork) -> None:
        """
        Exchange and update y, then exchange and update x, for all active agents.

        Shared by the outer step (Algorithm 2.1, lines 3-4) and every inner
        iteration (lines 9-10). All broadcasts on a channel happen before any
        agent consumes that channel's messages.
        """
        for i in network.active_agents():
            network.broadcast(i, i.aux_vars["y"], channel=_GRADIENT_TRACKER_CHANNEL)
        for i in network.active_agents():
            self._update_gradient_tracker(i)
        for i in network.active_agents():
            network.broadcast(i, i.x, channel=_STATE_CHANNEL)
        for i in network.active_agents():
            self._state_update(i)

    def _compute_batch_grad(self, agent: Agent) -> None:
        """
        Compute the full gradient at the beginning of each outer loop.

        Algorithm 2.1, line 2. The previous estimator is kept in ``v_prev`` so that
        the following tracker update can subtract it.
        """
        agent.aux_vars["v_prev"] = agent.aux_vars["v"]
        # v_i^{0,s} = grad f_i(x_i^{0,s}) over all local samples
        agent.aux_vars["v"] = agent.cost.gradient(agent.x, indices="all")

    def _update_sarah_estimator(self, agent: Agent) -> None:
        """
        Update the SARAH variance-reduced gradient estimator.

        Algorithm 2.1, line 8::

            v_i^{t,s} = g(x_i^{t,s}; B) - g(x_i^{t-1,s}; B) + v_i^{t-1,s}

        Both mini-batch gradients use the same sample batch ``B``. The batch is
        read back from ``agent.cost.batch_used`` after the first gradient call.

        Raises:
            TypeError: If the agent's cost function is not an instance of EmpiricalRiskCost
                (checked only under static type checking; runtime validation happens in
                :meth:`initialize`).

        """
        if TYPE_CHECKING:
            if not isinstance(agent.cost, EmpiricalRiskCost):
                raise TypeError("GT-SARAH is only compatible with EmpiricalRiskCost.")
        # Keep v_i^{t-1,s} for the subsequent gradient-tracker update.
        agent.aux_vars["v_prev"] = agent.aux_vars["v"]
        # Mini-batch gradient at the current iterate; this call selects the batch.
        grad_current = agent.cost.gradient(agent.x)
        batch_used = agent.cost.batch_used
        # Same batch evaluated at the previous inner iterate.
        grad_prev = agent.cost.gradient(agent.aux_vars["x_prev"], indices=batch_used)
        agent.aux_vars["v"] = grad_current - grad_prev + agent.aux_vars["v_prev"]

    def _update_gradient_tracker(self, agent: Agent) -> None:
        """
        Update the gradient tracker y_i.

        Algorithm 2.1, lines 3 and 9::

            y_i <- W_ii * y_i + sum_j W_ij * y_j + v_i - v_i^prev

        The neighbor values y_j are read from the gradient-tracker message channel.
        """
        weighted_sum = self.W[agent, agent] * agent.aux_vars["y"]
        for j, y in agent.messages(_GRADIENT_TRACKER_CHANNEL).items():
            weighted_sum += self.W[agent, j] * y
        agent.aux_vars["y"] = weighted_sum + agent.aux_vars["v"] - agent.aux_vars["v_prev"]

    def _state_update(self, agent: Agent) -> None:
        """
        Update the local iterate via consensus plus a tracked-gradient step.

        Algorithm 2.1, lines 4 and 10::

            x_i <- W_ii * x_i + sum_j W_ij * x_j - step_size * y_i

        The pre-update iterate is saved in ``x_prev`` for the next SARAH estimate.
        """
        agent.aux_vars["x_prev"] = agent.x
        weighted_sum = self.W[agent, agent] * agent.x
        for j, x in agent.messages(_STATE_CHANNEL).items():
            weighted_sum += self.W[agent, j] * x
        agent.x = weighted_sum - self.step_size * agent.aux_vars["y"]

    def _inner_loop(self, network: P2PNetwork) -> None:
        """
        Run the ``num_local_steps`` inner iterations of GT-SARAH.

        Algorithm 2.1, lines 7-10. Each iteration refreshes the SARAH estimator
        (line 8), then performs one tracking + consensus round (lines 9-10).
        """
        for _ in range(self.num_local_steps):
            for i in network.active_agents():
                self._update_sarah_estimator(i)
            self._tracking_and_consensus_round(network)