Source code for decent_bench.algorithms.p2p._atc_tracking
from dataclasses import dataclass
import decent_bench.utils.interoperability as iop
from decent_bench.algorithms.utils import initial_states
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates
from ._p2p_algorithm import P2PAlgorithm
# Message channel names. Each iteration of ATC_Tracking.step performs two
# broadcasts (the locally updated state, then the gradient tracker); sending
# them on distinct channels lets each agent read them back separately.
_STATE_CHANNEL = "state"
_GRADIENT_TRACKER_CHANNEL = "gradient_tracker"
[docs]
@tags("peer-to-peer", "gradient-tracking")
@dataclass(eq=False)
class ATC_Tracking(P2PAlgorithm):  # noqa: N801
    r"""
    ATC-Tracking :footcite:p:`Alg_ATCT_1, Alg_ATCT_2, Alg_ATCT_3` gradient tracking algorithm.

    Every iteration each agent applies the two updates below.

    .. math::
        \mathbf{x}_{i, k+1} = \sum_j \mathbf{W}_{ij} (\mathbf{x}_{j, k} - \rho \mathbf{y}_{j, k})

    .. math::
        \mathbf{y}_{i, k+1} = \sum_j \mathbf{W}_{ij} \mathbf{y}_{j, k}
        + \nabla f_i(\mathbf{x}_{i,k+1}) - \nabla f_i(\mathbf{x}_{i,k})

    where
    :math:`\mathbf{x}_{i, k}` is agent i's local optimization variable at iteration k,
    :math:`\rho` is the step size,
    :math:`f_i` is agent i's local cost function,
    j is a neighbor of i or i itself,
    and :math:`\mathbf{W}_{ij}` is the metropolis weight between agent i and j.

    Aliases: :class:`SONATA`, :class:`NEXT`, :class:`ATCT`

    .. footbibliography::
    """

    iterations: int = 100
    step_size: float = 0.001
    x0: InitialStates = None
    name: str = "ATC-Tracking"

    def __post_init__(self) -> None:
        """
        Check the hyperparameters right after construction.

        Raises:
            ValueError: if ``step_size`` is zero or negative.

        """
        if self.step_size <= 0:
            raise ValueError("`step_size` must be positive")

    def initialize(self, network: P2PNetwork) -> None:
        """
        Set up every agent's state and cache the network's mixing weights.

        Each agent starts at its entry of ``x0``. The tracker ``y`` and the stored
        gradient ``g`` both start at the local gradient evaluated at that point,
        while ``g_new`` and ``s`` start as zeros shaped like the initial state.
        """
        self.x0 = initial_states(self.x0, network)
        for agent in network.agents():
            start = self.x0[agent]
            initial_gradient = agent.cost.gradient(start)
            blank = iop.zeros_like(start)
            agent.initialize(
                x=start,
                aux_vars={
                    "y": initial_gradient,
                    "g": initial_gradient,
                    "g_new": blank,
                    "s": blank,
                },
            )
        self.W = network.weights

    def _weighted_average(self, agent, own_value, channel):  # noqa: ANN001, ANN202
        """Combine ``own_value`` with the messages ``agent`` holds on ``channel``, weighted by ``self.W``."""
        # The product creates a fresh object, so the in-place additions below
        # never modify ``own_value`` or any received message.
        total = self.W[agent, agent] * own_value
        for sender, received in agent.messages(channel).items():
            total += self.W[agent, sender] * received
        return total

    def step(self, network: P2PNetwork, _: int) -> None:
        """
        Run one ATC-Tracking iteration over the currently active agents.

        The iteration index (second positional argument) is not used.
        """
        rho = self.step_size

        # --- Communication round 1: locally updated states ---
        # Every agent first takes its own step along the tracker direction ...
        for agent in network.active_agents():
            agent.aux_vars["s"] = agent.x - rho * agent.aux_vars["y"]
        # ... then shares the result with its neighbors ...
        for agent in network.active_agents():
            network.broadcast(agent, agent.aux_vars["s"], channel=_STATE_CHANNEL)
        # ... and finally mixes what it received and evaluates the gradient
        # at the new state.
        for agent in network.active_agents():
            agent.x = self._weighted_average(agent, agent.aux_vars["s"], _STATE_CHANNEL)
            agent.aux_vars["g_new"] = agent.cost.gradient(agent.x)

        # --- Communication round 2: gradient trackers ---
        for agent in network.active_agents():
            network.broadcast(agent, agent.aux_vars["y"], channel=_GRADIENT_TRACKER_CHANNEL)
        # Mix the trackers, add the change in the local gradient, and keep the
        # new gradient as the reference for the next iteration.
        for agent in network.active_agents():
            mixed_tracker = self._weighted_average(agent, agent.aux_vars["y"], _GRADIENT_TRACKER_CHANNEL)
            latest_gradient = agent.aux_vars["g_new"]
            agent.aux_vars["y"] = mixed_tracker + latest_gradient - agent.aux_vars["g"]
            agent.aux_vars["g"] = latest_gradient
# Alternative names bound to the very same class object as ATC_Tracking
# (these are the aliases listed in the class docstring).
SONATA = ATC_Tracking # alias
NEXT = ATC_Tracking # alias
ATCT = ATC_Tracking # alias