Source code for decent_bench.algorithms.p2p._lt_admm_vr
from dataclasses import dataclass
from typing import TYPE_CHECKING
import decent_bench.utils.interoperability as iop
from decent_bench.agents import Agent
from decent_bench.algorithms.utils import initial_states
from decent_bench.costs import EmpiricalRiskCost
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from ._lt_admm import LT_ADMM
[docs]
@tags("peer-to-peer", "gradient-based", "dual method", "ADMM", "variance-reduction")
@dataclass(eq=False)
class LT_ADMM_VR(LT_ADMM): # noqa: N801
"""
Local Training ADMM with Variance Reduction (LT-ADMM-VR) :footcite:p:`Alg_LT_ADMM_VR`.
Extends LT-ADMM with variance reduction techniques for improved convergence.
This variant implements additional gradient variance reduction mechanisms
during the local training phase.
Args:
iterations: Total number of communication rounds (K)
num_local_steps: Number of local training steps (tau)
step_size: Local step size (gamma)
aux_step_size: Local step size (beta)
penalty: Penalty parameter (rho)
alpha: Relaxation parameter (alpha)
x0: Initial parameters (optional)
v2: Whether to use the LT-ADMM-VR-2 variant with improved
variance reduction techniques which is less computational heavy (default True).
name: Algorithm name (default "LT-ADMM-VR")
Raises:
TypeError: If any agent's cost function is not an instance of EmpiricalRiskCost.
.. footbibliography::
"""
v2: bool = True # Whether to use the LT-ADMM-VR-2 variant
name: str = "LT-ADMM-VR"
def initialize(self, network: P2PNetwork) -> None:
self.x0 = initial_states(self.x0, network)
# Initialize agents with auxiliary variables
for i in network.agents():
if not isinstance(i.cost, EmpiricalRiskCost):
raise TypeError("LT-ADMM-VR is only compatible with EmpiricalRiskCost.")
neighbors = network.neighbors(i)
z_i = iop.zeros(
shape=(len(neighbors), *iop.shape(self.x0[i])),
framework=i.cost.framework,
device=i.cost.device,
)
neighbor_to_idx: dict[Agent, int] = {} # Mapping from neighbor to index in z_i array
for idx, j in enumerate(neighbors):
z_i[idx] = iop.copy(self.x0[i])
neighbor_to_idx[j] = idx
r_grads = i.cost.gradient(self.x0[i], indices="all", reduction=None) if self.v2 else None
# Initialize auxiliary variables for LT-ADMM
aux_vars = {
"phi": self.x0[i], # phi_i,k - model parameters
"r_grads": r_grads, # shape (m_i, dim) - nabla f_{i,h}(r_{i,h,k})
"z_i": z_i, # z_ij,k+1 - auxiliary consensus variable
"neighbor_to_idx": neighbor_to_idx,
}
i.initialize(x=self.x0[i], aux_vars=aux_vars)
def _local_training(self, agent: Agent, network: P2PNetwork) -> None:
"""
Enhanced local training with variance reduction.
Raises:
TypeError: If the agent's cost is not an instance of EmpiricalRiskCost, as LT-ADMM-VR is only compatible
with EmpiricalRiskCost.
"""
if TYPE_CHECKING:
if not isinstance(agent.cost, EmpiricalRiskCost):
raise TypeError("LT-ADMM-VR is only compatible with EmpiricalRiskCost.")
agent.aux_vars["phi"] = iop.copy(agent.x)
z_sum = iop.sum(agent.aux_vars["z_i"], dim=0)
# Always use the number of neighbors for the penalty term to ensure proper scaling
multiplier = self.penalty * len(network.neighbors(agent))
correction = self.aux_step_size * (multiplier * agent.x - z_sum)
if not self.v2:
r_grads = agent.cost.gradient(agent.x, indices="all", reduction=None)
agent.aux_vars["r_grads"] = r_grads
for _ in range(self.num_local_steps):
batch_grad = agent.cost.gradient(agent.aux_vars["phi"])
batch_used = agent.cost.batch_used
r_grads = iop.mean(agent.aux_vars["r_grads"][batch_used], dim=0)
current_gradient = (batch_grad - r_grads) + iop.mean(agent.aux_vars["r_grads"], dim=0)
step = self.step_size * current_gradient + correction
agent.aux_vars["phi"] -= step
r_grads = agent.cost.gradient(agent.aux_vars["phi"], indices=batch_used, reduction=None)
agent.aux_vars["r_grads"][batch_used] = r_grads
# Update agent's main parameter (line 10)
agent.x = agent.aux_vars["phi"]