Source code for decent_bench.costs._base._cost

from __future__ import annotations

from abc import ABC, abstractmethod
from functools import cached_property
from math import prod
from numbers import Real
from typing import Any

from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks


class Cost(ABC):
    """
    Used by agents to evaluate the cost and its derivatives at a certain x.

    Subclasses implement the abstract properties :attr:`shape`, :attr:`framework`,
    :attr:`device`, :attr:`m_smooth` and :attr:`m_cvx`, and the abstract methods
    :meth:`function`, :meth:`gradient`, :meth:`hessian` and :meth:`proximal`.
    Costs can be combined with ``+`` / ``-`` (with other costs), negated with unary
    ``-``, and scaled by real scalars with ``*`` and ``/``.
    """

    def _validate_cost_operation(
        self,
        other: object,
        *,
        check_framework: bool = True,
        check_device: bool = True,
    ) -> None:
        """
        Validate that another object can participate in a binary cost operation.

        Args:
            other: The other operand of the operation.
            check_framework: If True, require ``other.framework == self.framework``.
            check_device: If True, require ``other.device == self.device``.

        Raises:
            TypeError: If other is not a Cost.
            ValueError: If shapes, frameworks, or devices are mismatched.

        """
        if not isinstance(other, Cost):
            raise TypeError(f"Cost can only be combined with another Cost, got {type(other)}.")
        if self.shape != other.shape:
            raise ValueError(f"Mismatched domain shapes: {self.shape} vs {other.shape}")
        if check_framework and self.framework != other.framework:
            raise ValueError(f"Mismatching frameworks: {self.framework} vs {other.framework}")
        if check_device and self.device != other.device:
            raise ValueError(f"Mismatching devices: {self.device} vs {other.device}")

    @property
    @abstractmethod
    def shape(self) -> tuple[int, ...]:
        """Required shape of x."""

    @property
    def domain_shape(self) -> tuple[int, ...]:
        """Alias for :attr:`shape`."""
        return self.shape

    @cached_property
    def size(self) -> int:
        """Number of elements in x (product of :attr:`shape`; computed once per instance)."""
        return prod(self.shape)

    @property
    @abstractmethod
    def framework(self) -> SupportedFrameworks:
        """
        The framework used by this cost function.

        Make sure that all :class:`decent_bench.utils.array.Array` objects returned by this cost function's methods
        use this framework.
        """

    @property
    @abstractmethod
    def device(self) -> SupportedDevices:
        """
        The device used by this cost function.

        Make sure that all :class:`decent_bench.utils.array.Array` objects returned by this cost function's methods
        use this device.
        """

    @property
    @abstractmethod
    def m_smooth(self) -> float:
        r"""
        Lipschitz constant of the cost function's gradient.

        The gradient's Lipschitz constant m_smooth is the smallest value such that

        .. math::
            \| \nabla f(\mathbf{x_1}) - \nabla f(\mathbf{x_2}) \|
            \leq m_{\text{smooth}} \cdot \|\mathbf{x_1} - \mathbf{x_2}\|

        for all :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`.

        Returns:
            - non-negative finite number if function is L-smooth
            - ``np.inf`` if function is differentiable everywhere but not L-smooth
            - ``np.nan`` if function is not differentiable everywhere

        """

    @property
    @abstractmethod
    def m_cvx(self) -> float:
        r"""
        Convexity constant of the cost function.

        The convexity constant m_cvx is the largest value such that

        .. math::
            f(\mathbf{x_1}) \geq f(\mathbf{x_2})
            + \nabla f(\mathbf{x_2})^T (\mathbf{x_1} - \mathbf{x_2})
            + \frac{m_{\text{cvx}}}{2} \|\mathbf{x_1} - \mathbf{x_2}\|^2

        for all :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`.

        Returns:
            - positive finite number if function is strongly convex
            - ``0`` if function is convex but not strongly convex
            - ``np.nan`` if function is not guaranteed to be convex

        """

    @abstractmethod
    def function(self, x: Array, **kwargs: Any) -> float:  # noqa: ANN401
        """Evaluate function at x."""

    def evaluate(self, x: Array, **kwargs: Any) -> float:  # noqa: ANN401
        """Alias for :meth:`function`."""
        return self.function(x, **kwargs)

    def loss(self, x: Array, **kwargs: Any) -> float:  # noqa: ANN401
        """Alias for :meth:`function`."""
        return self.function(x, **kwargs)

    def f(self, x: Array, **kwargs: Any) -> float:  # noqa: ANN401
        """Alias for :meth:`function`."""
        return self.function(x, **kwargs)

    @abstractmethod
    def gradient(self, x: Array, **kwargs: Any) -> Array:  # noqa: ANN401
        """Gradient at x."""

    @abstractmethod
    def hessian(self, x: Array, **kwargs: Any) -> Array:  # noqa: ANN401
        """Hessian at x."""

    @abstractmethod
    def proximal(self, x: Array, penalty: float, **kwargs: Any) -> Array:  # noqa: ANN401
        r"""
        Proximal at x.

        The proximal operator is defined as:

        .. include:: snippets/proximal_operator.rst

        If the cost function's proximal does not have a closed form solution, it can be solved iteratively using
        :meth:`~decent_bench.centralized_algorithms.proximal_solver`.
        """

    def __add__(self, other: Cost) -> Cost:
        """
        Add another cost function to create a new one.

        The generic fallback returns ``SumCost([self, other])``. Subclasses can override this to preserve specialized
        structure when the result remains in the same abstraction. For example, the addition of two
        :class:`~decent_bench.costs.QuadraticCost` objects benefits from returning a new
        :class:`~decent_bench.costs.QuadraticCost` instead of a :class:`~decent_bench.costs.SumCost` as this preserves
        the closed form proximal solution and only requires one evaluation instead of two when calling
        :meth:`function`, :meth:`gradient`, and :meth:`hessian`.

        Raises:
            TypeError: If other is not a Cost.
            ValueError: If shapes, frameworks, or devices are mismatched.

        """
        self._validate_cost_operation(other)
        # Local import avoids a circular import between Cost and SumCost.
        from decent_bench.costs._base._sum_cost import SumCost  # noqa: PLC0415

        return SumCost([self, other])

    def __mul__(self, other: float) -> Cost:
        """
        Multiply by a scalar to create a weighted cost.

        Raises:
            TypeError: If other is not a real scalar.

        """
        if not self._is_valid_scalar(other):
            raise TypeError(f"Cost can only be multiplied by a real number, got {type(other)}.")
        # Local import avoids a circular import between Cost and ScaledCost.
        from decent_bench.costs._base._scaled_cost import ScaledCost  # noqa: PLC0415

        return ScaledCost(self, float(other))

    def __rmul__(self, other: float) -> Cost:
        """Right-side scalar multiplication."""
        return self.__mul__(other)

    def __truediv__(self, other: float) -> Cost:
        """
        Divide by a scalar.

        Raises:
            TypeError: If other is not a real scalar.
            ZeroDivisionError: If other is zero.

        """
        if not self._is_valid_scalar(other):
            raise TypeError(f"Cost can only be divided by a real number, got {type(other)}.")
        if other == 0:
            raise ZeroDivisionError("Division by zero is not allowed for Cost objects.")
        return self.__mul__(1.0 / float(other))

    def __rtruediv__(self, other: float) -> Cost:
        """
        Right-side scalar division is not supported.

        Raises:
            TypeError: Always, since scalar / cost is not supported.

        """
        raise TypeError("Right-side division is not supported for Cost objects.")

    def __neg__(self) -> Cost:
        """Negate this cost function."""
        return self.__mul__(-1.0)

    def __sub__(self, other: Cost) -> Cost:
        """
        Subtract another cost function as sum with its negation.

        Raises:
            TypeError: If other is not a Cost.
            ValueError: If shapes, frameworks, or devices are mismatched.

        """
        # Validate before negating so a non-Cost operand fails with the Cost-specific TypeError.
        self._validate_cost_operation(other)
        return self + (-other)

    def __radd__(self, other: object) -> Cost:
        """
        Right-side addition, used to make sum(costs) work.

        ``sum`` starts from the integer ``0``, so ``0 + cost`` returns ``cost`` itself. If ``other`` is a Cost, the
        result is ``other.__add__(self)``.

        Raises:
            TypeError: If other is neither 0 nor a Cost, or if ``other.__add__`` does not support ``self``.

        """
        if isinstance(other, Cost):
            # Call ``__add__`` directly rather than evaluating ``other + self``: re-entering the ``+`` operator can
            # dispatch straight back to this method (e.g. when ``type(self)`` is a subclass of ``type(other)`` that
            # overrides ``__radd__``, or after ``other.__add__`` returned NotImplemented) and recurse indefinitely.
            result = other.__add__(self)
            if result is NotImplemented:
                raise TypeError(
                    f"Unsupported Cost addition: {type(other).__name__} + {type(self).__name__}."
                )
            return result
        if other == 0:
            return self
        raise TypeError(f"Cost can only be added to another Cost, got {type(other)}.")

    @staticmethod
    def _is_valid_scalar(value: object) -> bool:
        """Return True if value is a real scalar and not bool."""
        return isinstance(value, Real) and not isinstance(value, bool)