Source code for decent_bench.algorithms.p2p._admm
from dataclasses import dataclass
import decent_bench.utils.interoperability as iop
from decent_bench.algorithms.utils import initial_states
from decent_bench.networks import P2PNetwork
from decent_bench.utils._tags import tags
from decent_bench.utils.types import InitialStates
from ._p2p_algorithm import P2PAlgorithm
[docs]
@tags("peer-to-peer", "dual method", "ADMM")
@dataclass(eq=False)
class ADMM(P2PAlgorithm):
r"""
Distributed Alternating Direction Method of Multipliers characterized by the update step below.
.. math::
\mathbf{x}_{i, k+1} = \operatorname{prox}_{\frac{1}{\rho N_i} f_i}
\left(\sum_j \mathbf{z}_{ij, k} \frac{1}{\rho N_i} \right)
.. math::
\mathbf{z}_{ij, k+1} = (1-\alpha) \mathbf{z}_{ij, k} - \alpha (\mathbf{z}_{ji, k} - 2 \rho \mathbf{x}_{j, k+1})
where
:math:`\mathbf{x}_{i, k}` is agent i's local optimization variable at iteration k,
:math:`\operatorname{prox}` is the proximal operator described in :meth:`Cost.proximal()
<decent_bench.costs.Cost.proximal>`,
:math:`\rho > 0` is the Lagrangian penalty parameter (the corresponding argument is ``penalty``),
:math:`N_i` is the number of neighbors of i,
:math:`f_i` is i's local cost function,
j is a neighbor of i,
and :math:`\alpha \in (0, 1)` is the relaxation parameter (the corresponding argument is ``relaxation``).
Note:
``x0`` and ``z0`` follow the :obj:`~decent_bench.utils.types.InitialStates` convention and are resolved
per agent during ``initialize`` via
:func:`~decent_bench.algorithms.utils.initial_states`.
If ``x0`` is ``None`` and ``z0`` is provided, each agent initializes ``x0`` from ``z0`` with one proximal
update:
.. math::
x_{i,0} = \operatorname{prox}_{\frac{1}{\rho N_i} f_i}\left(\frac{z_{i,0}}{\rho}\right)
The :math:`\mathbf{z}_{ij}` variables of an agent are all initialized to
the same value specified in ``z0`` (if any).
"""
iterations: int = 100
penalty: float = 1
relaxation: float = 0.5
x0: InitialStates = None
z0: InitialStates = None
name: str = "ADMM"
def __post_init__(self) -> None:
"""
Validate hyperparameters.
Raises:
ValueError: if hyperparameters are invalid.
"""
if self.penalty <= 0:
raise ValueError("`penalty` must be positive")
if not (0 < self.relaxation < 1):
raise ValueError("`relaxation` must be in (0, 1)")
def initialize(self, network: P2PNetwork) -> None:
self.rho_i = {i: 1 / (self.penalty * len(network.neighbors(i))) for i in network.agents()}
x_from_z = self.x0 is None and self.z0 is not None # if x0 needs to be created from z0
self.x0 = initial_states(self.x0, network)
self.z0 = initial_states(self.z0, network)
for i in network.agents():
if x_from_z:
self.x0[i] = i.cost.proximal(self.z0[i] / self.penalty, penalty=self.rho_i[i])
z0 = iop.stack([self.z0[i] for _ in network.neighbors(i)])
neighbor_to_idx = {j: idx for idx, j in enumerate(network.neighbors(i))}
i.initialize(x=self.x0[i], aux_vars={"z": z0, "neighbor_to_idx": neighbor_to_idx})
def step(self, network: P2PNetwork, _: int) -> None:
for i in network.active_agents():
i.x = i.cost.proximal(iop.sum(i.aux_vars["z"], dim=0) * self.rho_i[i], penalty=self.rho_i[i])
for i in network.active_agents():
for j in network.active_neighbors(i):
idx = i.aux_vars["neighbor_to_idx"][j]
network.send(i, j, i.aux_vars["z"][idx] - 2 * self.penalty * i.x)
for i in network.active_agents():
for j, msg in i.messages().items():
idx = i.aux_vars["neighbor_to_idx"][j]
i.aux_vars["z"][idx] = (1 - self.relaxation) * i.aux_vars["z"][idx] - self.relaxation * (msg)