Source code for decent_bench.costs._base._scaled_cost
from __future__ import annotations
from functools import cached_property
from typing import Any
import numpy as np
from decent_bench.costs._base._cost import Cost
from decent_bench.costs._base._sum_cost import SumCost
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks
[docs]
class ScaledCost(Cost):
    """
    Cost of the form ``scalar * cost`` for an arbitrary wrapped cost.

    ``ScaledCost`` is the generic fallback produced by scalar arithmetic on costs when no more specialized wrapper
    exists. Shape, framework, and device are read straight from the wrapped cost. Function values, gradients, and
    Hessians of the wrapped cost are multiplied by ``scalar``.

    Wrapping an existing ``ScaledCost`` does not add a second layer: the two scalars are multiplied together and the
    inner cost is stored directly. The proximal operator is only supported when ``scalar`` is nonnegative.

    The wrapped cost is held by reference and never copied; call :func:`copy.deepcopy` yourself when an independent
    object is required.
    """

    def __init__(self, cost: Cost, scalar: float):
        """
        Wrap ``cost`` so that it is multiplied by ``scalar``.

        If ``cost`` is itself a ``ScaledCost``, its inner cost is unwrapped and its scalar is folded into ``scalar``.
        """
        base: Cost = cost
        factor: float = scalar
        if isinstance(base, ScaledCost):
            # scalar * (inner_scalar * inner_cost) == (scalar * inner_scalar) * inner_cost
            factor = scalar * base.scalar
            base = base.cost
        self.cost: Cost = base
        self.scalar: float = factor

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape reported by the wrapped cost."""
        return self.cost.shape

    @property
    def framework(self) -> SupportedFrameworks:
        """Framework reported by the wrapped cost."""
        return self.cost.framework

    @property
    def device(self) -> SupportedDevices:
        """Device reported by the wrapped cost."""
        return self.cost.device

    @cached_property
    def m_smooth(self) -> float:
        """
        The wrapped cost's ``m_smooth`` multiplied by ``abs(scalar)``.

        A zero scalar yields ``0.0`` without consulting the wrapped cost. The value is computed on first access and
        cached on the instance.
        """
        factor = self.scalar
        return 0.0 if factor == 0 else float(abs(factor) * self.cost.m_smooth)

    @cached_property
    def m_cvx(self) -> float:
        """
        The wrapped cost's ``m_cvx`` multiplied by ``scalar`` when ``scalar > 0``.

        A zero scalar yields ``0.0``. Any other scalar (negative, or NaN) yields ``nan``, since a negative multiple of
        a convex function is in general not convex. The value is computed on first access and cached on the instance.
        """
        factor = self.scalar
        if factor > 0:
            return float(factor * self.cost.m_cvx)
        return 0.0 if factor == 0 else np.nan

    def function(self, x: Array, *args: Any, **kwargs: Any) -> float:  # noqa: ANN401
        """Evaluate the wrapped cost at ``x`` (forwarding extra arguments) and multiply the result by ``scalar``."""
        raw_value = self.cost.function(x, *args, **kwargs)
        return float(self.scalar * raw_value)

    def gradient(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """Gradient of the wrapped cost at ``x`` (forwarding extra arguments), multiplied by ``scalar``."""
        raw_grad = self.cost.gradient(x, *args, **kwargs)
        return raw_grad * self.scalar

    def hessian(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """Hessian of the wrapped cost at ``x`` (forwarding extra arguments), multiplied by ``scalar``."""
        raw_hess = self.cost.hessian(x, *args, **kwargs)
        return raw_hess * self.scalar

    def proximal(self, x: Array, penalty: float, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """
        Proximal operator of ``scalar * cost`` with parameter ``penalty``.

        The call is delegated to the wrapped cost with an effective penalty of ``penalty * scalar``. This relies on
        prox_{penalty * (scalar * f)} = prox_{(penalty * scalar) * f}, assuming the wrapped cost's ``proximal`` treats
        ``penalty`` as a multiplier on the cost. A zero scalar makes the cost identically zero, so ``x`` itself (not a
        copy) is returned.

        Raises:
            ValueError: if ``penalty <= 0``, or if ``scalar < 0``.

        """
        if penalty <= 0:
            raise ValueError("The penalty parameter penalty must be positive.")
        factor = self.scalar
        if factor < 0:
            raise ValueError("The proximal operator is not defined for negative scaling.")
        if factor == 0:
            # The proximal map of the zero function is the identity.
            return x
        effective_penalty = penalty * factor
        return self.cost.proximal(x, effective_penalty, *args, **kwargs)

    def __add__(self, other: Cost) -> Cost:
        """Return ``SumCost([self, other])`` after validating ``other`` with ``_validate_cost_operation``."""
        self._validate_cost_operation(other)
        summands = [self, other]
        return SumCost(summands)