Source code for decent_bench.centralized_algorithms

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING, final

import numpy as np
from rich.progress import track

import decent_bench.utils.interoperability as iop
from decent_bench.utils import logger
from decent_bench.utils.array import Array
from decent_bench.utils.logger import LOGGER

if TYPE_CHECKING:
    from decent_bench.costs import Cost


def solve(
    cost: "Cost",
    max_iter: int = 100,
    stop_tol: float | None = None,
    max_tol: float | None = None,
    show_progress: bool = True,
) -> Array:
    """
    Minimize a cost function using a suitable solver.

    The method is chosen from the type of ``cost`` and its constants:

    - :class:`~decent_bench.costs.QuadraticCost`: closed form via :func:`~numpy.linalg.solve`.
    - :class:`~decent_bench.costs.SumCost` whose terms are all
      :class:`~decent_bench.costs.LinearRegressionCost`: the summed Hessian/gradient system
      (evaluated at zero) is solved with :func:`~numpy.linalg.solve`, falling back to
      :func:`~numpy.linalg.lstsq` when that system is singular.
    - finite ``m_smooth > 0`` and finite ``m_cvx``: :class:`AcceleratedGradientDescent`.
    - anything else (non-smooth or non-convex): :class:`GradientDescent`.

    Args:
        cost: cost function to minimize.
        max_iter: maximum number of iterations of the iterative solvers. Defaults to 100.
        stop_tol: optional early stopping tolerance; iteration stops once
            ``||x_new - x_old||^2 <= stop_tol``.
        max_tol: optional final tolerance; RuntimeError is raised if
            ``||x_new - x_old||^2 > max_tol`` when the iterative solver finishes.
        show_progress: whether to display a progress bar during iterative solves.
            Defaults to ``True``.

    Returns:
        Approximate minimizer or stationary point.

    Raises:
        ValueError: when ``m_smooth`` and ``m_cvx`` are finite and ``m_smooth == 0``.

    """
    if not LOGGER.handlers:
        logger.start_logger()
    LOGGER.info("Finding the optimal solution to the problem ...")

    # Human-readable description of the stopping rule, logged before iterative solves.
    message_parts = [f"Stopping after {max_iter} iterations"]
    if stop_tol is not None:
        message_parts.append(f" or when ||x_new - x_old||^2 <= {stop_tol}")
    message_parts.append(".")
    if max_tol is not None:
        message_parts.append(f" Will raise if ||x_new - x_old||^2 > {max_tol} at the end.")
    stop_criteria = "".join(message_parts)

    # Local import (the module only imports Cost under TYPE_CHECKING).
    from decent_bench.costs import LinearRegressionCost, QuadraticCost, SumCost  # noqa: PLC0415

    if isinstance(cost, QuadraticCost):
        x_optimal = Array(np.linalg.solve(cost.A, -cost.b))
    elif isinstance(cost, SumCost) and all(isinstance(term, LinearRegressionCost) for term in cost.costs):
        reference = cost.costs[0]
        origin = iop.zeros(framework=reference.framework, device=reference.device, shape=reference.shape)
        hessian_sum = np.asarray(sum(term.hessian(origin, indices="all") for term in cost.costs))
        gradient_sum = np.asarray(sum(term.gradient(origin, indices="all") for term in cost.costs))
        try:
            solution = np.linalg.solve(hessian_sum, -gradient_sum)
        except np.linalg.LinAlgError:
            # Singular system: take the least-squares solution instead.
            solution = np.linalg.lstsq(hessian_sum, -gradient_sum, rcond=None)[0]
        x_optimal = Array(solution)
    else:
        smooth_and_convex = np.isfinite(cost.m_smooth) and np.isfinite(cost.m_cvx)
        if smooth_and_convex and cost.m_smooth == 0:
            raise ValueError("Costs with m_smooth = 0 are not supported.")
        solver_cls: type[Solver] = (
            AcceleratedGradientDescent if smooth_and_convex and cost.m_smooth > 0 else GradientDescent
        )
        LOGGER.info(stop_criteria)
        x_optimal = solver_cls(cost).run(
            max_iter=max_iter,
            stop_tol=stop_tol,
            max_tol=max_tol,
            show_progress=show_progress,
        )

    LOGGER.info("... done!")
    return x_optimal
class Solver(ABC):
    """
    Base class for centralized solvers.

    Stores the cost, the current iterate ``x`` and the previous iterate ``x_old``.
    Subclasses implement :meth:`step` to define one iteration of their algorithm;
    :meth:`run` drives the iterations and the stopping checks.

    Args:
        cost: cost function to minimize.
        x0: optional initial iterate; defaults to zeros with the cost's shape, framework
            and device. A copy is stored, so the solver never modifies the caller's array.

    Raises:
        ValueError: if ``x0`` does not have the same shape as the cost function domain.

    """

    def __init__(self, cost: "Cost", x0: Array | None = None):
        if x0 is None:
            x0 = iop.zeros(shape=cost.shape, framework=cost.framework, device=cost.device)
        if iop.shape(x0) != cost.shape:
            raise ValueError("x0 and cost function domain must have same shape")
        # Copy so that subclasses updating ``self.x`` in place (e.g. ``self.x -= ...``)
        # cannot mutate the caller's ``x0``.
        self.x = iop.copy(x0)
        self.x_old = iop.copy(self.x)
        self.cost = cost

    @abstractmethod
    def step(self, iteration: int) -> None:
        """
        Perform one iteration of the solver.

        Subclasses must update ``self.x`` exactly once per step. :meth:`run` stores the
        previous iterate in ``self.x_old`` before each call. Use the iteration counter for
        algorithms with iteration-dependent parameters (e.g., step schedules).

        Args:
            iteration: current iteration number (0-based).

        """

    @final
    def run(
        self,
        max_iter: int = 100,
        stop_tol: float | None = None,
        max_tol: float | None = None,
        check_frequency: float = 0.01,
        show_progress: bool = True,
    ) -> Array:
        """
        Run the solver.

        Executes :meth:`step` for up to ``max_iter`` iterations. Stops early if the squared
        norm of the iterate change drops to ``stop_tol`` or below. After completion,
        verifies that the final iterate change is at most ``max_tol``.

        Args:
            max_iter: maximum number of iterations; must be positive. Defaults to 100.
            stop_tol: optional early stopping tolerance; stops if
                ``||x_new - x_old||^2 <= stop_tol``. Must be positive (not NaN) if provided.
            max_tol: optional final tolerance; raises RuntimeError if
                ``||x_new - x_old||^2 > max_tol`` when the loop ends. Must be positive
                (not NaN) if provided.
            check_frequency: float in (0, 1] defining how often the early stopping
                condition is checked, as a fraction of ``max_iter``. A smaller value means
                the condition is checked more often. Applies only if ``stop_tol`` is not None.
            show_progress: whether to display a progress bar during iteration.
                Defaults to True.

        Returns:
            Final iterate x as an Array.

        Raises:
            ValueError: if max_iter < 1, if stop_tol or max_tol are provided and not
                positive, or if check_frequency is not in (0, 1].
            RuntimeError: if max_tol is provided and the final iterate change exceeds max_tol.

        Warning:
            Do not override this method. Instead, override :meth:`step` to define
            solver behavior.

        """
        # ``not (v > 0)`` style checks also reject NaN, which would otherwise pass
        # validation and silently disable the tolerance checks below.
        if max_iter < 1:
            raise ValueError("`max_iter` must be positive")
        if stop_tol is not None and not stop_tol > 0:
            raise ValueError("`stop_tol` must be positive or None")
        if max_tol is not None and not max_tol > 0:
            raise ValueError("`max_tol` must be positive or None")
        if not 0 < check_frequency <= 1:
            raise ValueError("`check_frequency` must be a float in (0, 1]")

        check_every = max(1, int(check_frequency * max_iter))
        # Squared change of the latest step, or None if not computed for that step.
        delta: float | None = None
        for k in track(
            range(max_iter),
            description="Solving...",
            disable=not show_progress,
            update_period=0.0,
        ):
            self.x_old = iop.copy(self.x)
            self.step(k)
            delta = None
            if stop_tol is not None and k % check_every == 0:
                delta = self._squared_change()
                if delta <= stop_tol:
                    break

        if max_tol is not None:
            if delta is None:
                delta = self._squared_change()
            if delta > max_tol:
                raise RuntimeError(
                    f"Solver failed to converge within {max_iter} iterations: delta {delta} > max delta {max_tol}."
                )
        return self.x

    def _squared_change(self) -> float:
        """Return the squared change ``||x - x_old||^2`` of the most recent step."""
        # NOTE(review): ``d^T d`` reduces to a scalar for 1-D iterates; for matrix-shaped
        # domains ``float`` of the product would likely fail — confirm supported shapes.
        diff = self.x - self.x_old
        return float(iop.transpose(diff) @ diff)
class GradientDescent(Solver):
    r"""
    Gradient descent solver.

    Each step updates ``x <- x - step_size(k) * gradient(x)``.

    If step_size is not provided, defaults to:

    - Non-smooth or non-convex: :math:`1/\sqrt{k+1}`
    - Strongly convex: :math:`2/(L+\mu)`
    - Convex: :math:`1/L`

    where :math:`L` is ``cost.m_smooth``, :math:`\mu` is ``cost.m_cvx`` and :math:`k` is the
    iteration index.

    Args:
        cost: cost function to minimize.
        step_size: constant step size (any real scalar, e.g. ``1`` or ``0.1``), a callable
            mapping the iteration index to a step size, or None for the defaults above.
        x0: optional initial iterate, forwarded to :class:`Solver`.

    """

    def __init__(
        self,
        cost: "Cost",
        step_size: float | Callable[[int], float] | None = None,
        x0: Array | None = None,
    ):
        if callable(step_size):
            step_size_k: Callable[[int], float] = step_size
        elif step_size is not None:
            # Accept any real scalar (int, numpy scalar, ...), not only ``float`` instances;
            # otherwise e.g. ``step_size=1`` would silently fall back to the default schedule.
            constant_step = float(step_size)
            step_size_k = lambda _: constant_step  # noqa: E731
        elif np.isnan(cost.m_smooth) or np.isinf(cost.m_smooth) or np.isnan(cost.m_cvx):
            # non-smooth or non-convex: diminishing step size
            step_size_k = lambda k: float(1 / np.sqrt(k + 1))  # noqa: E731
        elif cost.m_cvx > 0:
            # strongly convex; constants are read once here instead of on every step
            strongly_convex_step = float(2 / (cost.m_smooth + cost.m_cvx))
            step_size_k = lambda _: strongly_convex_step  # noqa: E731
        else:
            # convex
            convex_step = float(1 / cost.m_smooth)
            step_size_k = lambda _: convex_step  # noqa: E731
        super().__init__(cost, x0)
        self.step_size = step_size_k

    def step(self, iteration: int) -> None:
        """Take one gradient step ``x <- x - step_size(iteration) * gradient(x)``."""
        self.x -= self.step_size(iteration) * self.cost.gradient(self.x)
class AcceleratedGradientDescent(Solver):
    r"""
    Accelerated gradient descent (Nesterov momentum) solver.

    Each step computes ``x <- y - step_size * gradient(y)`` and then the extrapolated
    point ``y <- x + momentum(k) * (x - x_old)``.

    If step_size is not provided, defaults to: :math:`1/L`.

    If momentum is not provided, defaults to:

    - Strongly convex: :math:`(\sqrt{L}-\sqrt{\mu}) / (\sqrt{L}+\sqrt{\mu})`
    - Otherwise: :math:`k / (k+3)`

    where :math:`L` is ``cost.m_smooth``, :math:`\mu` is ``cost.m_cvx`` and :math:`k` is the
    iteration index.

    Args:
        cost: cost function to minimize.
        step_size: constant step size (any real scalar), or None for :math:`1/L`.
        momentum: constant momentum (any real scalar, e.g. ``0`` to disable momentum), a
            callable mapping the iteration index to a momentum value, or None for the
            defaults above.
        x0: optional initial iterate, forwarded to :class:`Solver`.

    """

    def __init__(
        self,
        cost: "Cost",
        step_size: float | None = None,
        momentum: float | Callable[[int], float] | None = None,
        x0: Array | None = None,
    ):
        # Any real scalar is accepted (not only ``float`` instances), so e.g. ``momentum=0``
        # or ``step_size=1`` are honored instead of silently replaced by the defaults.
        step = float(step_size) if step_size is not None else float(1 / cost.m_smooth)
        if callable(momentum):
            momentum_k: Callable[[int], float] = momentum
        elif momentum is not None:
            constant_momentum = float(momentum)
            momentum_k = lambda _: constant_momentum  # noqa: E731
        elif cost.m_cvx > 0:
            # strongly convex: constant momentum, computed once from L and mu
            sqrt_l = np.sqrt(cost.m_smooth)
            sqrt_mu = np.sqrt(cost.m_cvx)
            strongly_convex_momentum = float((sqrt_l - sqrt_mu) / (sqrt_l + sqrt_mu))
            momentum_k = lambda _: strongly_convex_momentum  # noqa: E731
        else:
            momentum_k = lambda k: k / (k + 3)  # noqa: E731
        super().__init__(cost, x0)
        self.step_size = step
        self.momentum = momentum_k
        # Extrapolated point at which gradients are evaluated; starts at the initial iterate.
        self.y = iop.copy(self.x)

    def step(self, iteration: int) -> None:
        """Gradient step from ``y``, then extrapolate using ``x_old`` (set by :meth:`Solver.run`)."""
        self.x = self.y - self.step_size * self.cost.gradient(self.y)
        self.y = self.x + self.momentum(iteration) * (self.x - self.x_old)
def proximal_solver(cost: "Cost", y: Array, penalty: float, max_iter: int = 100) -> Array:
    """
    Approximate the cost's proximal operator at ``y`` with accelerated gradient descent.

    The result approximates the solution of the proximal problem:

    .. include:: snippets/proximal_operator.rst

    The cost must be differentiable, L-smooth, and convex.

    Args:
        cost: cost function whose proximal is computed.
        y: point at which the proximal is evaluated; also used as the starting iterate.
        penalty: penalty parameter; must be positive.
        max_iter: maximum number of solver iterations.

    Returns:
        Approximate proximal at ``y``.

    Raises:
        ValueError: if the cost's domain and ``y`` have different shapes, or if
            ``penalty`` is not positive.
        NotImplementedError: if the resulting proximal cost is not differentiable,
            L-smooth, and convex.

    """
    if cost.shape != iop.shape(y):
        raise ValueError("Cost function domain and y need to have the same shape")
    if penalty <= 0:
        raise ValueError("Penalty term `penalty` must be greater than 0")

    from decent_bench.costs import QuadraticCost  # noqa: PLC0415

    # Quadratic term with A = I / penalty and b = -y / penalty added to the cost.
    # NOTE(review): presumably QuadraticCost is 1/2 x^T A x + b^T x, making this term
    # ||x - y||^2 / (2 * penalty) up to a constant — confirm against QuadraticCost.
    scaled_identity = iop.eye_like(y) / penalty
    linear_term = -y / penalty
    proximal_cost = QuadraticCost(A=scaled_identity, b=linear_term) + cost

    smoothness = proximal_cost.m_smooth
    unsupported = smoothness == np.inf or np.isnan(smoothness) or np.isnan(proximal_cost.m_cvx)
    if unsupported:
        raise NotImplementedError("Proximal solver requires the cost to be differentiable, L-smooth, and convex.")

    solver = AcceleratedGradientDescent(proximal_cost, x0=y)
    return solver.run(max_iter=max_iter, stop_tol=1e-10, max_tol=None, show_progress=False)