Source code for decent_bench.costs._base._scaled_cost

from __future__ import annotations

from functools import cached_property
from typing import Any

import numpy as np

from decent_bench.costs._base._cost import Cost
from decent_bench.costs._base._sum_cost import SumCost
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks


class ScaledCost(Cost):
    """
    Wrap a cost and multiply it by a constant scalar.

    ``ScaledCost`` is the generic fallback produced by scalar arithmetic when no
    more specialized wrapper exists. Function values, gradients and Hessians are
    those of the wrapped cost multiplied by ``scalar``. ``shape``, ``framework``
    and ``device`` are forwarded unchanged. The proximal operator is only
    available for nonnegative scalars.

    Wrapping a ``ScaledCost`` collapses the nesting: the new instance wraps the
    innermost cost with the product of both scalars.

    The wrapped cost is held by reference; nothing is copied. Call
    :func:`copy.deepcopy` explicitly when an independent object is required.
    """

    def __init__(self, cost: Cost, scalar: float):
        """
        Args:
            cost: cost to scale. If it is itself a ``ScaledCost``, its inner cost
                is stored directly and the two scalars are multiplied together.
            scalar: multiplicative factor applied to ``cost``.
        """
        # Flatten nested wrappers so that c2 * (c1 * f) is stored as (c2 * c1) * f
        # rather than as a chain of ScaledCost objects.
        if isinstance(cost, ScaledCost):
            inner, factor = cost.cost, scalar * cost.scalar
        else:
            inner, factor = cost, scalar
        self.cost: Cost = inner
        self.scalar: float = factor

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape reported by the wrapped cost."""
        return self.cost.shape

    @property
    def framework(self) -> SupportedFrameworks:
        """Framework reported by the wrapped cost."""
        return self.cost.framework

    @property
    def device(self) -> SupportedDevices:
        """Device reported by the wrapped cost."""
        return self.cost.device

    @cached_property
    def m_smooth(self) -> float:
        """
        Smoothness constant of the scaled cost.

        Equals ``|scalar| * cost.m_smooth``, or ``0.0`` when ``scalar`` is zero.
        The value is cached on first access, so later changes to ``scalar`` are
        not reflected.
        """
        # The wrapped constant is not read at all for a zero factor, so a
        # non-finite wrapped value cannot turn into nan (0 * inf == nan).
        return 0.0 if self.scalar == 0 else float(abs(self.scalar) * self.cost.m_smooth)

    @cached_property
    def m_cvx(self) -> float:
        """
        Convexity constant of the scaled cost.

        Equals ``scalar * cost.m_cvx`` for a positive ``scalar`` and ``0.0`` for a
        zero ``scalar``. It is ``nan`` for any other value (negative or NaN). The
        value is cached on first access.
        """
        factor = self.scalar
        if factor > 0:
            return float(factor * self.cost.m_cvx)
        # Non-positive factors: zero reports 0.0; negative or NaN factors
        # fall through to nan because no constant is reported for them.
        return 0.0 if factor == 0 else np.nan

    def function(self, x: Array, *args: Any, **kwargs: Any) -> float:  # noqa: ANN401
        """Return ``scalar * cost.function(x, ...)`` as a Python ``float``; extra arguments are forwarded."""
        inner_value = self.cost.function(x, *args, **kwargs)
        return float(self.scalar * inner_value)

    def gradient(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """Return the wrapped cost's gradient at ``x`` multiplied by ``scalar``; extra arguments are forwarded."""
        inner_grad = self.cost.gradient(x, *args, **kwargs)
        return inner_grad * self.scalar

    def hessian(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """Return the wrapped cost's Hessian at ``x`` multiplied by ``scalar``; extra arguments are forwarded."""
        inner_hess = self.cost.hessian(x, *args, **kwargs)
        return inner_hess * self.scalar

    def proximal(self, x: Array, penalty: float, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
        """
        Proximal operator of the scaled cost.

        The wrapped cost's proximal operator is called with penalty
        ``penalty * scalar``. This is the standard identity
        ``prox_{penalty * (scalar * f)} = prox_{(penalty * scalar) * f}``, assuming
        ``Cost.proximal(x, penalty)`` scales ``f`` by ``penalty`` (TODO confirm
        against the base class).

        When ``scalar`` is zero, ``x`` itself is returned (the same object, not a
        copy). Mutating the result in place therefore mutates ``x``.

        Raises:
            ValueError: if ``penalty`` is not positive, or if ``scalar`` is negative.
        """
        if penalty <= 0:
            raise ValueError("The penalty parameter penalty must be positive.")
        factor = self.scalar
        if factor < 0:
            raise ValueError("The proximal operator is not defined for negative scaling.")
        return x if factor == 0 else self.cost.proximal(x, penalty * factor, *args, **kwargs)

    def __add__(self, other: Cost) -> Cost:
        """Return ``SumCost([self, other])`` once ``other`` has passed ``_validate_cost_operation``."""
        self._validate_cost_operation(other)
        terms = [self, other]
        return SumCost(terms)