Source code for decent_bench.algorithms.p2p._gt_vr

import random
from dataclasses import dataclass
from typing import TYPE_CHECKING

from decent_bench.agents import Agent
from decent_bench.algorithms.utils import initial_states
from decent_bench.costs import EmpiricalRiskCost
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates

from ._p2p_algorithm import P2PAlgorithm

# Message channel on which each active agent broadcasts x_i - step_size * y_i before the consensus update.
_STATE_CHANNEL = "state"
# Message channel on which each active agent broadcasts y_i + v_i - v_old_i before the gradient-tracker update.
_GRADIENT_TRACKER_CHANNEL = "gradient_tracker"


@tags("peer-to-peer", "gradient-tracking")
@dataclass(eq=False)
class GT_VR(P2PAlgorithm):  # noqa: N801
    """
    GT-VR: Gradient Tracking with Variance Reduction algorithm :footcite:p:`Alg_GT_VR`.

    Args:
        iterations: Total number of iterations
        step_size: Step size for primal updates
        snapshot_prob: Probability of performing a snapshot update (P in the paper)
        x0: Initial parameters (optional)
        name: Algorithm name (default "GT-VR")

    Raises:
        TypeError: If any agent's cost function is not an instance of EmpiricalRiskCost.

    .. footbibliography::

    """

    iterations: int = 100
    step_size: float = 0.01
    snapshot_prob: float = 0.3  # P in the algorithm
    x0: InitialStates = None  # Initial parameters (optional)
    name: str = "GT-VR"

    def __post_init__(self) -> None:
        """
        Validate parameters.

        Raises:
            ValueError: If ``step_size`` is a number that is not strictly positive,
                or if ``snapshot_prob`` is outside ``(0, 1]``.

        """
        # Integers are validated too (``step_size=0`` or ``-1`` used to slip through because only
        # ``float`` instances were checked). ``not x > 0`` is used instead of ``x <= 0`` so that NaN,
        # for which every comparison is False, is rejected as well. Non-numeric values are left
        # untouched, exactly as before.
        if isinstance(self.step_size, (int, float)) and not self.step_size > 0:
            raise ValueError("step_size must be positive")
        if not 0 < self.snapshot_prob <= 1:
            raise ValueError("snapshot_prob must be in (0, 1]")

    def initialize(self, network: P2PNetwork) -> None:
        """
        Initialize agents.

        Algorithm 1, line 1

        Args:
            network: Network whose agents are initialized and whose weights are stored for the
                consensus and gradient-tracking updates.

        Raises:
            TypeError: If any agent's cost function is not an instance of EmpiricalRiskCost, since GT-VR relies on
                variance reduction techniques that require access to individual sample gradients. Using GT-VR with
                incompatible cost functions may lead to errors or undefined behavior.

        """
        self.x0 = initial_states(self.x0, network)
        self.W = network.weights
        for i in network.agents():
            # Check that cost function supports variance reduction
            if not isinstance(i.cost, EmpiricalRiskCost):
                raise TypeError("GT-VR only supports EmpiricalRiskCost instances.")

            # Compute full gradient at initialization: grad f_i(x_i^1)
            full_grad = i.cost.gradient(self.x0[i], indices="all")

            # Initialize auxiliary variables according to line 1
            aux_vars = {
                "tau": self.x0[i],  # tau_i^1 = x_i^1 (for snapshot updates)
                "full_grad_tau": full_grad,  # grad f_i(tau_i) - cached to avoid recomputation
                "y": full_grad,  # y_i^1 = grad f_i(x_i^1)
                "v": full_grad,  # v_i^1 = grad f_i(x_i^1)
                "v_old": full_grad,  # Store v_i^k for gradient tracking update
            }

            i.initialize(
                x=self.x0[i],
                aux_vars=aux_vars,
            )

    def step(self, network: P2PNetwork, _: int) -> None:
        """
        Perform one iteration of GT-VR over the active agents.

        Each phase is run for all active agents before the next phase starts, so that every agent
        has broadcast its value before any agent consumes the received messages.

        Args:
            network: Network providing the active agents and the broadcast mechanism.
            _: Second positional argument of the step interface; not used by this algorithm.

        """
        # Main algorithm loop (line 2)
        for i in network.active_agents():
            x_minus_eta_y = i.x - self.step_size * i.aux_vars["y"]
            # Keep the agent's own value: it is needed for the self-weighted term of the consensus update
            i.aux_vars["x_minus_eta_y"] = x_minus_eta_y
            network.broadcast(i, x_minus_eta_y, channel=_STATE_CHANNEL)

        # Step 1: Update local estimate of the solution (line 3)
        for i in network.active_agents():
            self._consensus_update(i)

        # Step 2: Probabilistic snapshot update (line 4)
        # Select l_i^{k+1} ~ Bernoulli(P)
        for i in network.active_agents():
            if random.random() < self.snapshot_prob:  # l_i^{k+1} = 1
                # Update: tau_i^{k+1} = x_i^{k+1} and recompute full gradient
                self._snapshot_update(i)

        # Step 3: Select batch and update local gradient estimator (lines 5-6)
        for i in network.active_agents():
            self._update_gradient_estimator(i)

        # We broadcast y_i + v_i - v_old to reduce communication
        for i in network.active_agents():
            y_plus_delta_v = i.aux_vars["y"] + i.aux_vars["v"] - i.aux_vars["v_old"]
            i.aux_vars["y_plus_delta_v"] = y_plus_delta_v
            network.broadcast(i, y_plus_delta_v, channel=_GRADIENT_TRACKER_CHANNEL)

        # Step 4: Update gradient tracker (line 7)
        for i in network.active_agents():
            self._update_gradient_tracker(i)

    def _consensus_update(self, agent: Agent) -> None:
        """
        Update local estimate via consensus.

        Algorithm 1, line 3: the new ``agent.x`` is the weighted sum of the agent's own
        ``x_minus_eta_y`` and the values received on the state channel.

        Args:
            agent: Agent whose ``x`` is overwritten.

        """
        weighted_sum = self.W[agent, agent] * agent.aux_vars["x_minus_eta_y"]
        for j, x_minus_eta_y in agent.messages(_STATE_CHANNEL).items():
            weighted_sum += self.W[agent, j] * x_minus_eta_y
        agent.x = weighted_sum

    def _snapshot_update(self, agent: Agent) -> None:
        """
        Update snapshot point when l_i^{k+1} = 1.

        Algorithm 1, line 4: sets ``tau`` to the current ``agent.x`` and refreshes the cached
        full gradient at that point.

        Args:
            agent: Agent whose ``tau`` and ``full_grad_tau`` auxiliary variables are updated.

        """
        agent.aux_vars["tau"] = agent.x
        # Compute and cache the full gradient at the new snapshot point
        full_grad_tau = agent.cost.gradient(agent.aux_vars["tau"], indices="all")
        agent.aux_vars["full_grad_tau"] = full_grad_tau

    def _update_gradient_estimator(self, agent: Agent) -> None:
        """
        Update local stochastic gradient estimator with variance reduction.

        Algorithm 1, lines 5-6: This implements the variance reduction technique (Equation 3)

        Args:
            agent: Agent whose ``v`` and ``v_old`` auxiliary variables are updated.

        Note:
            The ``isinstance`` check below sits under ``TYPE_CHECKING`` and therefore only serves
            static type checkers (type narrowing of ``agent.cost``); it never raises at runtime.
            The runtime validation of the cost type is performed in :meth:`initialize`.

        """
        if TYPE_CHECKING:
            if not isinstance(agent.cost, EmpiricalRiskCost):
                raise TypeError("GT-VR is only compatible with EmpiricalRiskCost.")

        # Store old v_i for gradient tracking update
        agent.aux_vars["v_old"] = agent.aux_vars["v"]

        # Select s_i^{k+1} uniformly at random (this is done by the cost function)
        # Compute stochastic gradient at current point: grad f_{i,s_i}(x_i^{k+1})
        grad_current = agent.cost.gradient(agent.x)
        # Reuse the very same batch at the snapshot point so both stochastic gradients are correlated
        batch_indices = agent.cost.batch_used

        # Compute stochastic gradient at snapshot point: grad f_{i,s_i}(tau_i^{k+1})
        grad_snapshot = agent.cost.gradient(agent.aux_vars["tau"], indices=batch_indices)

        # Use cached full gradient at snapshot point: grad f_i(tau_i^{k+1})
        full_grad_snapshot = agent.aux_vars["full_grad_tau"]

        # Update variance-reduced gradient estimator (Equation 3)
        # v_i^{k+1} = grad f_{i,s_i}(x_i) - grad f_{i,s_i}(tau_i) + grad f_i(tau_i)
        agent.aux_vars["v"] = grad_current - grad_snapshot + full_grad_snapshot

    def _update_gradient_tracker(self, agent: Agent) -> None:
        """
        Update local gradient tracker.

        Algorithm 1, line 7:
        Note: We receive y_r + v_r - v_r_old directly to reduce communication.

        Args:
            agent: Agent whose ``y`` auxiliary variable is overwritten.

        """
        weighted_sum = self.W[agent, agent] * agent.aux_vars["y_plus_delta_v"]
        for j, y_plus_delta_v in agent.messages(_GRADIENT_TRACKER_CHANNEL).items():
            weighted_sum += self.W[agent, j] * y_plus_delta_v
        agent.aux_vars["y"] = weighted_sum