PyTorch Optimizer Integration
The following is a short example of how to use a PyTorch optimizer for local training in the LT-ADMM algorithm. It defines a new algorithm class, LT_ADMM_TORCH, that inherits from LTADMM and overrides the local training step. The initialize method sets up the PyTorch optimizer for each agent, and the _local_training method performs local training using that optimizer, falling back to plain gradient steps when no optimizer class is given.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import decent_bench.utils.interoperability as iop
from decent_bench.agents import Agent
from decent_bench.algorithms.p2p import LTADMM
from decent_bench.costs import PyTorchCost
from decent_bench.networks import P2PNetwork
if TYPE_CHECKING:
    import torch

# torch is also needed at runtime by the class in this example, so a missing
# installation is reported immediately with an actionable message.
try:
    import torch
except ImportError as e:
    raise ImportError(
        "PyTorch is required for LT-ADMM-TORCH algorithm, but it is not installed. "
        "Please install PyTorch to use this algorithm."
    ) from e
@dataclass(eq=False)
class LT_ADMM_TORCH(LTADMM):
    """LT-ADMM variant whose local training step can be driven by a PyTorch optimizer.

    When ``opt_cls`` is set, :meth:`initialize` asks every agent's cost to set
    up local training via ``init_local_training`` and :meth:`_local_training`
    delegates to ``agent.cost.local_training``. When ``opt_cls`` is ``None``,
    :meth:`_local_training` falls back to explicit gradient steps.

    ``step_size``, ``aux_step_size``, ``penalty`` and ``num_local_steps`` are
    read from ``self`` but not defined here (presumably inherited from
    ``LTADMM`` -- confirm against the base class).
    """

    # PyTorch optimizer class to use for local training; None selects the
    # plain gradient-step fallback.
    opt_cls: type[torch.optim.Optimizer] | None = None
    # Keyword arguments for the PyTorch optimizer; "lr" defaults to step_size
    # when not given. The dict is never modified by this class.
    opt_kwargs: dict[str, Any] | None = None
    # PyTorch scheduler class for local training.
    sched_cls: type[torch.optim.lr_scheduler.LRScheduler] | None = None
    # Keyword arguments for the PyTorch scheduler.
    sched_kwargs: dict[str, Any] | None = None
    name: str = "LT-ADMM-TORCH"

    def initialize(self, network: P2PNetwork) -> None:
        """Run the base initialization, validate agent costs and set up optimizers.

        Args:
            network: Network whose agents are validated and, when ``opt_cls``
                is set, prepared for optimizer-based local training.

        Raises:
            TypeError: If an agent's cost is not a ``PyTorchCost``.
        """
        super().initialize(network)

        # Merge the default learning rate into a private copy. Mutating
        # self.opt_kwargs in place would leak "lr" into the caller's dict, so
        # a dict shared between algorithm instances (or reused after
        # step_size changes) would silently keep the first learning rate.
        merged_opt_kwargs: dict[str, Any] = {"lr": self.step_size, **(self.opt_kwargs or {})}

        for agent in network.agents():
            if not isinstance(agent.cost, PyTorchCost):
                raise TypeError(
                    f"LT-ADMM-TORCH requires PyTorchCost, but agent {agent} "
                    f"has cost of type {type(agent.cost)}"
                )
            # Initialize the PyTorch optimizer for local training only when an
            # optimizer class was supplied.
            if self.opt_cls is not None:
                agent.cost.init_local_training(
                    opt_cls=self.opt_cls,
                    # Fresh dict per agent so agents never share kwargs state.
                    opt_kwargs=dict(merged_opt_kwargs),
                    sched_cls=self.sched_cls,
                    sched_kwargs=self.sched_kwargs,
                )

    def _local_training(self, agent: Agent, network: P2PNetwork) -> None:
        """Perform the local training step for ``agent`` and update ``agent.x``.

        Args:
            agent: Agent whose ``aux_vars["phi"]`` and ``x`` are updated.
            network: Network used to count the agent's neighbors.
        """
        # TYPE_CHECKING is False at runtime, so this check only narrows the
        # type of agent.cost for static analysis; the runtime validation is
        # done once per agent in initialize().
        if TYPE_CHECKING:
            if not isinstance(agent.cost, PyTorchCost):
                raise TypeError(
                    f"LT-ADMM-TORCH requires PyTorchCost, but agent {agent} has cost of type {type(agent.cost)}"
                )

        # Start local training from a copy so that agent.x stays unchanged
        # while the correction term below is computed from it.
        agent.aux_vars["phi"] = iop.copy(agent.x)

        # Correction term; it is computed once and reused for every local step.
        z_sum = iop.sum(agent.aux_vars["z_i"], dim=0)
        multiplier = self.penalty * len(network.neighbors(agent))
        correction = self.aux_step_size * (multiplier * agent.x - z_sum)

        if self.opt_cls is not None:
            # Delegate the local steps to the cost's optimizer-based routine.
            agent.aux_vars["phi"] = agent.cost.local_training(
                x=agent.aux_vars["phi"],
                iterations=self.num_local_steps,
                regularization=correction,
                agent=agent,
            )
        else:
            for _ in range(self.num_local_steps):
                current_gradient = agent.cost.gradient(agent.aux_vars["phi"])
                step = self.step_size * current_gradient + correction
                # Update phi_i,k according to gradient step (line 7)
                agent.aux_vars["phi"] -= step

        # Update agent's main parameter (line 10)
        agent.x = agent.aux_vars["phi"]