Source code for decent_bench.costs._base._sum_cost
from __future__ import annotations
from functools import cached_property
from typing import Any
import numpy as np
import decent_bench.utils.interoperability as iop
from decent_bench import centralized_algorithms as ca
from decent_bench.costs._base._cost import Cost
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks
[docs]
class SumCost(Cost):
    """
    Plain additive composite used when no specialized sum of costs exists.

    Adding two costs yields a ``SumCost`` whenever a more specific composite type is not available. The result
    still exposes the core :class:`~decent_bench.costs.Cost` interface, but any regularizer-specific or
    empirical-risk-specific behavior of the individual terms is not carried over.

    The wrapped cost objects are stored by reference. Nothing is copied implicitly; call :func:`copy.deepcopy`
    yourself if independent objects are needed.
    """

    def __init__(self, costs: list[Cost]):
        """
        Collect the terms of the sum, flattening nested ``SumCost`` instances.

        Args:
            costs: the cost functions to add together. Any element that is itself a ``SumCost`` contributes its
                own terms instead of being stored as a nested sum.

        Raises:
            ValueError: If ``costs`` is empty.

        """
        if len(costs) == 0:
            raise ValueError("SumCost must contain at least one cost function.")

        # Build a flat list of terms so that sums of sums never nest.
        terms: list[Cost] = []
        for candidate in costs:
            if isinstance(candidate, SumCost):
                terms.extend(candidate.costs)
                continue
            terms.append(candidate)
        self.costs: list[Cost] = terms

        # Every remaining term is validated against the first one only.
        anchor = self.costs[0]
        for other in self.costs[1:]:
            anchor._validate_cost_operation(other)  # noqa: SLF001

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape reported by the first term of the sum."""
        leading_term = self.costs[0]
        return leading_term.shape

    @property
    def framework(self) -> SupportedFrameworks:
        """Framework reported by the first term of the sum."""
        leading_term = self.costs[0]
        return leading_term.framework

    @property
    def device(self) -> SupportedDevices:
        """Device reported by the first term of the sum."""
        leading_term = self.costs[0]
        return leading_term.device

    @cached_property
    def m_smooth(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        r"""
        Smoothness constant of the summed cost.

        .. math::
            \sum m_{\text{smooth}, k}

        with :math:`m_{\text{smooth}, k}` the smoothness constant of the individual term :math:`f_k`.
        The result is :math:`\text{NaN}` as soon as one :math:`m_{\text{smooth}, k}` is :math:`\text{NaN}`.

        The value is computed on first access and cached afterwards.

        See :attr:`Cost.m_smooth <decent_bench.costs.Cost.m_smooth>` for the general definition.
        """
        constants = [term.m_smooth for term in self.costs]
        if any(np.isnan(value) for value in constants):
            return np.nan
        return sum(constants)

    @cached_property
    def m_cvx(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        r"""
        Convexity constant of the summed cost.

        .. math::
            \sum m_{\text{cvx}, k}

        with :math:`m_{\text{cvx}, k}` the convexity constant of the individual term :math:`f_k`.
        The result is :math:`\text{NaN}` as soon as one :math:`m_{\text{cvx}, k}` is :math:`\text{NaN}`.

        The value is computed on first access and cached afterwards.

        See :attr:`Cost.m_cvx <decent_bench.costs.Cost.m_cvx>` for the general definition.
        """
        constants = [term.m_cvx for term in self.costs]
        if any(np.isnan(value) for value in constants):
            return np.nan
        return sum(constants)

    def function(self, x: Array, *args: Any, **kwargs: Any) -> float:  # noqa: ANN401
        """Evaluate every term's :meth:`Cost.function <decent_bench.costs.Cost.function>` at ``x`` and add them."""
        term_values = (term.function(x, *args, **kwargs) for term in self.costs)
        return sum(term_values)

    def gradient(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """Evaluate every term's :meth:`Cost.gradient <decent_bench.costs.Cost.gradient>` at ``x`` and add them."""
        term_gradients = [term.gradient(x, *args, **kwargs) for term in self.costs]
        # Stack along a new leading axis, then reduce over it.
        stacked = iop.stack(term_gradients)
        return iop.sum(stacked, dim=0)

    def hessian(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """Evaluate every term's :meth:`Cost.hessian <decent_bench.costs.Cost.hessian>` at ``x`` and add them."""
        term_hessians = [term.hessian(x, *args, **kwargs) for term in self.costs]
        # Stack along a new leading axis, then reduce over it.
        stacked = iop.stack(term_hessians)
        return iop.sum(stacked, dim=0)

    def proximal(self, x: Array, penalty: float, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """
        Approximate the proximal operator of the whole summed objective.

        The computation is delegated to :func:`decent_bench.centralized_algorithms.proximal_solver`, called with
        this ``SumCost`` itself, ``x`` and ``penalty``, so the proximal subproblem is solved for the full sum
        rather than term by term. Any additional ``args`` and ``kwargs`` are discarded.

        Raises:
            NotImplementedError: If the solver raises ``NotImplementedError``; it is re-raised with a message
                describing the requirements on the summed objective, chained to the original exception.

        """
        # Extra arguments are accepted for interface compatibility only.
        del args, kwargs
        try:
            return ca.proximal_solver(self, x, penalty)
        except NotImplementedError as exc:
            raise NotImplementedError(
                "SumCost.proximal uses centralized_algorithms.proximal_solver and requires the summed objective to be "
                "differentiable, globally L-smooth, and convex under the accelerated-gradient backend."
            ) from exc

    def __add__(self, other: Cost) -> SumCost:
        """Return a new ``SumCost`` over this sum and ``other``; the constructor flattens the nesting."""
        operands = [self, other]
        return SumCost(operands)