Source code for decent_bench.costs._base._sum_cost

from __future__ import annotations

from functools import cached_property
from typing import Any

import numpy as np

import decent_bench.utils.interoperability as iop
from decent_bench import centralized_algorithms as ca
from decent_bench.costs._base._cost import Cost
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks


class SumCost(Cost):
    """
    Fallback composite representing the plain sum of several costs.

    Adding two costs yields a ``SumCost`` whenever no more specialized composite is available. The result keeps the
    core :class:`~decent_bench.costs.Cost` interface, but regularizer-specific and empirical-risk-specific behavior
    of the individual terms is not carried over.

    The wrapped cost objects are stored by reference and are never copied implicitly; call :func:`copy.deepcopy`
    explicitly when independent objects are required.
    """

    def __init__(self, costs: list[Cost]):
        """
        Collect the terms of the sum, flattening any ``SumCost`` found among them.

        Args:
            costs: cost functions to add together. Terms that are themselves ``SumCost`` instances contribute their
                own ``costs`` instead of being nested.

        Raises:
            ValueError: if ``costs`` is empty.

        """
        if len(costs) == 0:
            raise ValueError("SumCost must contain at least one cost function.")

        self.costs: list[Cost] = []
        terms = self.costs  # alias of the attribute, so it is populated in place
        for term in costs:
            # A nested SumCost is replaced by its own terms to keep the structure flat.
            terms.extend(term.costs if isinstance(term, SumCost) else [term])

        # The first term validates the operation against every other term.
        head = terms[0]
        for other in terms[1:]:
            head._validate_cost_operation(other)  # noqa: SLF001

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape reported by the first term of the sum."""
        leading = self.costs[0]
        return leading.shape

    @property
    def framework(self) -> SupportedFrameworks:
        """Framework reported by the first term of the sum."""
        leading = self.costs[0]
        return leading.framework

    @property
    def device(self) -> SupportedDevices:
        """Device reported by the first term of the sum."""
        leading = self.costs[0]
        return leading.device

    @cached_property
    def m_smooth(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        r"""
        Smoothness constant of the summed cost, computed once and then cached.

        .. math::
            \sum m_{\text{smooth}, k}

        where :math:`m_{\text{smooth}, k}` is the smoothness constant of the individual cost function :math:`f_k`.
        The result is :math:`\text{NaN}` as soon as any :math:`m_{\text{smooth}, k} = \text{NaN}`.

        For the general definition, see
        :attr:`Cost.m_smooth <decent_bench.costs.Cost.m_smooth>`.
        """
        # Read every term's constant before inspecting any of them.
        constants = [term.m_smooth for term in self.costs]
        if any(map(np.isnan, constants)):
            return np.nan
        return sum(constants)

    @cached_property
    def m_cvx(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        r"""
        Convexity constant of the summed cost, computed once and then cached.

        .. math::
            \sum m_{\text{cvx}, k}

        where :math:`m_{\text{cvx}, k}` is the convexity constant of the individual cost function :math:`f_k`.
        The result is :math:`\text{NaN}` as soon as any :math:`m_{\text{cvx}, k} = \text{NaN}`.

        For the general definition, see
        :attr:`Cost.m_cvx <decent_bench.costs.Cost.m_cvx>`.
        """
        # Read every term's constant before inspecting any of them.
        constants = [term.m_cvx for term in self.costs]
        if any(map(np.isnan, constants)):
            return np.nan
        return sum(constants)

    def function(self, x: Array, *args: Any, **kwargs: Any) -> float:  # noqa: ANN401
        """Add up :meth:`Cost.function <decent_bench.costs.Cost.function>` over all terms, forwarding all arguments."""
        term_values = (term.function(x, *args, **kwargs) for term in self.costs)
        return sum(term_values)

    def gradient(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """Add up :meth:`Cost.gradient <decent_bench.costs.Cost.gradient>` over all terms, forwarding all arguments."""
        term_gradients = [term.gradient(x, *args, **kwargs) for term in self.costs]
        # Stack the per-term results, then reduce over the stacking dimension.
        stacked = iop.stack(term_gradients)
        return iop.sum(stacked, dim=0)

    def hessian(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """Add up :meth:`Cost.hessian <decent_bench.costs.Cost.hessian>` over all terms, forwarding all arguments."""
        term_hessians = [term.hessian(x, *args, **kwargs) for term in self.costs]
        # Stack the per-term results, then reduce over the stacking dimension.
        stacked = iop.stack(term_hessians)
        return iop.sum(stacked, dim=0)

    def proximal(self, x: Array, penalty: float, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """
        Approximate the proximal operator of the whole summed objective.

        The computation is delegated to :func:`decent_bench.centralized_algorithms.proximal_solver`, which solves the
        proximal subproblem of the full sum using accelerated gradient descent. Any extra ``args`` and ``kwargs`` are
        accepted but ignored.

        Raises:
            NotImplementedError: If the accelerated-gradient backend assumptions are not satisfied.

        """
        del args, kwargs
        try:
            solution = ca.proximal_solver(self, x, penalty)
        except NotImplementedError as exc:
            # Re-raise with a message that names the requirements of the solver backend.
            message = (
                "SumCost.proximal uses centralized_algorithms.proximal_solver and requires the summed objective to be "
                "differentiable, globally L-smooth, and convex under the accelerated-gradient backend."
            )
            raise NotImplementedError(message) from exc
        return solution

    def __add__(self, other: Cost) -> SumCost:
        """Return a new ``SumCost`` made of this sum and ``other``."""
        operands = [self, other]
        return SumCost(operands)