"""Source code for ``decent_bench.algorithms.p2p._gt_saga``: the GT-SAGA peer-to-peer algorithm."""

from dataclasses import dataclass
from typing import TYPE_CHECKING

import decent_bench.utils.interoperability as iop
from decent_bench.agents import Agent
from decent_bench.algorithms.utils import initial_states
from decent_bench.costs import EmpiricalRiskCost
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates

from ._p2p_algorithm import P2PAlgorithm

# Channel carrying y_i + g_i - g_i_old, the messages mixed in the gradient-tracking step.
_GRADIENT_TRACKER_CHANNEL: str = "gradient_tracker"
# Channel carrying x_i - step_size * y_i, the messages mixed in the consensus step.
_STATE_CHANNEL: str = "state"


@tags("peer-to-peer", "gradient-tracking")
@dataclass(eq=False)
class GT_SAGA(P2PAlgorithm):  # noqa: N801
    """
    Gradient Tracking with SAGA variance reduction :footcite:p:`Alg_GT_SAGA_2020` :footcite:p:`Alg_GT_SAGA_2022`.

    In every iteration, each active agent i performs the following steps:

    1. Draw a batch tau_i and form the SAGA estimator
       g_i^k = grad f_{i,tau_i}(x_i^k) - grad f_{i,tau_i}(z_{i,tau_i}^k) + (1/m) sum_j grad f_{i,j}(z_{i,j}^k).
       Then refresh the sampled gradient-table entries at the current iterate (z_{i,tau_i}^{k+1} = x_i^k).
    2. Update the gradient tracker: y_i^{k+1} = sum_r w_ir (y_r^k + g_r^k - g_r^{k-1}).
    3. Update the local estimate: x_i^{k+1} = sum_r w_ir (x_r^k - step_size * y_r^{k+1}).

    Args:
        iterations: Total number of iterations
        step_size: Step size for local updates
        x0: Initial parameters (optional)
        name: Algorithm name (default "GT-SAGA")

    Raises:
        TypeError: If any agent's cost function is not an instance of EmpiricalRiskCost.

    .. footbibliography::

    """

    iterations: int = 100
    step_size: float = 0.01
    x0: InitialStates = None  # Initial parameters (optional)
    name: str = "GT-SAGA"

    def __post_init__(self) -> None:
        """
        Validate parameters.

        Raises:
            ValueError: If ``step_size`` is a non-positive number.

        """
        # Only plain numbers are validated; other step_size values pass through unchecked.
        # ``int`` is included so that e.g. ``step_size=0`` is rejected as well.
        if isinstance(self.step_size, (int, float)) and self.step_size <= 0:
            raise ValueError("step_size must be positive")

    def initialize(self, network: P2PNetwork) -> None:
        """
        Set up the per-agent GT-SAGA state.

        First resolves ``x0`` into per-agent initial states and caches the network weights.
        Then, for every agent, it creates:

        - the gradient table (per-sample gradients at x_i^0);
        - the tracker y_i^0 = 0;
        - zeroed gradient estimators.

        Args:
            network: The peer-to-peer network whose agents are initialized.

        Raises:
            TypeError: If an agent's cost function is not an instance of EmpiricalRiskCost.

        """
        self.x0 = initial_states(self.x0, network)
        self.W = network.weights
        for i in network.agents():
            # SAGA needs per-sample gradients, which this algorithm takes from EmpiricalRiskCost
            if not isinstance(i.cost, EmpiricalRiskCost):
                raise TypeError("GT-SAGA only supports EmpiricalRiskCost instances.")
            # Gradient table: entry j holds grad f_{i,j}(z_{i,j}^0) with z_{i,j}^0 = x_i^0 for all j
            z_grads = i.cost.gradient(self.x0[i], indices="all", reduction=None)
            aux_vars = {
                "z_grads": z_grads,  # Gradient table z_{i,j}
                "y": iop.zeros_like(self.x0[i]),  # Gradient tracking variable y_i^0 = 0
                "g_old": iop.zeros_like(self.x0[i]),  # Previous gradient estimator
                # Copied into "g_old" at the start of the first step, so it plays the role of g_i^{-1} = 0
                "g": iop.zeros_like(self.x0[i]),
            }
            i.initialize(x=self.x0[i], aux_vars=aux_vars)

    def step(self, network: P2PNetwork, _: int) -> None:
        """
        Run one GT-SAGA iteration on the currently active agents.

        Args:
            network: The peer-to-peer network used for broadcasting and message exchange.
            _: Iteration index (unused).

        """
        # Step 1: sample a batch, form g_i^k at x_i^k and refresh the sampled table entries at x_i^k
        for i in network.active_agents():
            self._update_gradient_estimator(i)

        # Step 2: update gradient tracker
        # y_i^{k+1} = sum_{r=1}^n w_ir (y_r^k + g_r^k - g_r^{k-1})
        for i in network.active_agents():
            # Broadcast y_i + g_i - g_i^{k-1}; the own copy is kept for the self-weight term
            y_plus_delta_g = i.aux_vars["y"] + i.aux_vars["g"] - i.aux_vars["g_old"]
            i.aux_vars["y_plus_delta_g"] = y_plus_delta_g
            network.broadcast(i, y_plus_delta_g, channel=_GRADIENT_TRACKER_CHANNEL)
        for i in network.active_agents():
            self._update_gradient_tracker(i)

        # Step 3: update local estimate of the solution
        # x_i^{k+1} = sum_{r=1}^n w_ir (x_r^k - alpha*y_r^{k+1})
        for i in network.active_agents():
            # Broadcast x_i - alpha*y_i (one message instead of sending x_i and y_i separately)
            x_minus_alpha_y = i.x - self.step_size * i.aux_vars["y"]
            i.aux_vars["x_minus_alpha_y"] = x_minus_alpha_y
            network.broadcast(i, x_minus_alpha_y, channel=_STATE_CHANNEL)
        for i in network.active_agents():
            self._consensus_update(i)

    def _update_gradient_estimator(self, agent: Agent) -> None:
        """
        Form the SAGA estimator at the current iterate and refresh the sampled gradient-table entries.

        The estimator is computed as

            g_i^k = grad f_{i,tau}(x_i^k) - mean_{j in tau} grad f_{i,j}(z_{i,j}^k) + (1/m) sum_j grad f_{i,j}(z_{i,j}^k).

        The sampled table entries are then overwritten with gradients at x_i^k, i.e. z_{i,tau}^{k+1} = x_i^k.
        This happens here, before the consensus step replaces ``agent.x`` with x_i^{k+1}.

        Args:
            agent: The agent whose estimator and gradient table are updated.

        Raises:
            TypeError: If the agent's cost function is not an instance of EmpiricalRiskCost. The check only runs
                under static type checking (for narrowing); ``initialize`` enforces it at runtime.

        """
        if TYPE_CHECKING:
            if not isinstance(agent.cost, EmpiricalRiskCost):
                raise TypeError("GT-SAGA is only compatible with EmpiricalRiskCost.")
        # Store old g_i for the gradient tracking update
        agent.aux_vars["g_old"] = agent.aux_vars["g"]
        # grad f_{i,tau}(x_i^k): gradient at the current point; the batch it used is reported by cost.batch_used
        grad_current = agent.cost.gradient(agent.x)
        batch_used = agent.cost.batch_used
        table = agent.aux_vars["z_grads"]
        # Both table statistics must be read before the sampled entries are overwritten below
        stale_batch_grad = iop.mean(table[batch_used], dim=0)
        avg_table_grad = iop.mean(table, dim=0)
        agent.aux_vars["g"] = grad_current - stale_batch_grad + avg_table_grad
        # z_{i,tau}^{k+1} = x_i^k: agent.x is still the iterate the estimator was evaluated at
        self._update_gradient_table(agent)

    def _update_gradient_tracker(self, agent: Agent) -> None:
        """
        Mix the own and received ``y + g - g_old`` messages into the new gradient tracker.

        Args:
            agent: The agent whose tracker ``y`` is updated.

        """
        weighted_sum = self.W[agent, agent] * agent.aux_vars["y_plus_delta_g"]
        for j, y_plus_delta_g in agent.messages(_GRADIENT_TRACKER_CHANNEL).items():
            weighted_sum += self.W[agent, j] * y_plus_delta_g
        agent.aux_vars["y"] = weighted_sum

    def _consensus_update(self, agent: Agent) -> None:
        """
        Mix the own and received ``x - step_size * y`` messages into the new local estimate.

        Args:
            agent: The agent whose iterate ``x`` is updated.

        """
        weighted_sum = self.W[agent, agent] * agent.aux_vars["x_minus_alpha_y"]
        for j, x_minus_alpha_y in agent.messages(_STATE_CHANNEL).items():
            weighted_sum += self.W[agent, j] * x_minus_alpha_y
        agent.x = weighted_sum

    def _update_gradient_table(self, agent: Agent) -> None:
        """
        Overwrite the gradient-table entries of the current batch with per-sample gradients at ``agent.x``.

        The batch is the one selected by ``indices="batch"``, as reported afterwards by ``cost.batch_used``.
        Presumably this is the batch drawn by the preceding ``cost.gradient(agent.x)`` call.

        Args:
            agent: The agent whose gradient table is updated.

        Raises:
            TypeError: If the agent's cost function is not an instance of EmpiricalRiskCost. The check only runs
                under static type checking (for narrowing); ``initialize`` enforces it at runtime.

        """
        if TYPE_CHECKING:
            if not isinstance(agent.cost, EmpiricalRiskCost):
                raise TypeError("GT-SAGA is only compatible with EmpiricalRiskCost.")
        z_grads = agent.cost.gradient(agent.x, indices="batch", reduction=None)
        batch_used = agent.cost.batch_used
        # Update the gradient table entries for the selected samples
        agent.aux_vars["z_grads"][batch_used] = z_grads
        # All other entries remain unchanged (implicit)