Source code for decent_bench.costs._base._regularizer_costs

from __future__ import annotations

from functools import cached_property
from typing import Any, overload

import numpy as np

import decent_bench.utils.interoperability as iop
from decent_bench.costs._base._cost import Cost
from decent_bench.utils._tags import tags
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks

# Public API of this module: the regularizer base class and its concrete regularizer implementations.
__all__ = ["BaseRegularizerCost", "FractionalQuadraticRegularizerCost", "L1RegularizerCost", "L2RegularizerCost"]


class BaseRegularizerCost(Cost):
    """
    Base class for regularizers with regularizer-preserving arithmetic.

    Adding, subtracting, negating, scaling, or dividing regularizers returns another regularizer subclass instead of
    falling back immediately to generic :class:`~decent_bench.costs.SumCost` or
    :class:`~decent_bench.costs.ScaledCost`. This preserves regularizer-specific structure and can improve
    performance. Scaling is supported from either side (``reg * c`` and ``c * reg``).

    Mixing a regularizer with an arbitrary non-regularizer still falls back to generic cost composition.
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        *,
        framework: SupportedFrameworks = SupportedFrameworks.NUMPY,
        device: SupportedDevices = SupportedDevices.CPU,
    ):
        """
        Validate and store the input shape, framework, and device of the regularizer.

        Args:
            shape: Shape of the decision variable. Any sequence of positive ints is accepted; it is stored as a tuple.
            framework: Array framework the regularizer operates on.
            device: Device on which the regularizer's arrays are created.

        Raises:
            ValueError: If ``shape`` is empty or contains a non-positive dimension.
        """
        if len(shape) == 0:
            raise ValueError("Regularizer shape must be non-empty.")
        if any(dim <= 0 for dim in shape):
            raise ValueError(f"Regularizer shape must be positive, got {shape}.")
        # Normalise after validation (so error messages are unchanged) to honour the ``shape`` property's tuple type
        # even when a list is passed in.
        self._shape = tuple(shape)
        # Flattened dimension; subclasses use it to size their (dim, dim) Hessian matrices.
        self._dim = int(np.prod(shape))
        self._framework = framework
        self._device = device
        # Lazily filled by subclasses whose Hessian does not depend on x (e.g. constant zero / identity matrices).
        self._hessian_cache: Array | None = None

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the decision variable this regularizer accepts."""
        return self._shape

    @property
    def framework(self) -> SupportedFrameworks:
        """Array framework used by this regularizer."""
        return self._framework

    @property
    def device(self) -> SupportedDevices:
        """Device on which this regularizer's arrays are created."""
        return self._device

    @overload
    def __add__(self, other: BaseRegularizerCost) -> BaseRegularizerCost: ...

    @overload
    def __add__(self, other: Cost) -> Cost: ...

    def __add__(self, other: Cost) -> Cost:
        """Add another cost, preserving the regularizer abstraction when possible."""
        self._validate_cost_operation(other)
        if isinstance(other, BaseRegularizerCost):
            return _CompositeRegularizerCost([self, other])
        return super().__add__(other)

    def __mul__(self, other: float) -> Cost:
        """
        Multiply by a scalar while preserving the regularizer abstraction.

        Raises:
            TypeError: If other is not a real scalar.
        """
        if not self._is_valid_scalar(other):
            raise TypeError(f"Cost can only be multiplied by a real number, got {type(other)}.")
        return _CompositeRegularizerCost([self], weights=[float(other)])

    def __rmul__(self, other: float) -> Cost:
        """
        Right-multiply by a scalar (``other * self``) while preserving the regularizer abstraction.

        Scalar multiplication commutes, so this delegates to :meth:`__mul__`. Defining it here guarantees that
        ``c * reg`` stays a regularizer regardless of how the generic base class handles reflected multiplication.

        Raises:
            TypeError: If other is not a real scalar.
        """
        return self.__mul__(other)

    def __truediv__(self, other: float) -> Cost:
        """
        Divide by a scalar while preserving the regularizer abstraction.

        Raises:
            TypeError: If other is not a real scalar.
            ZeroDivisionError: If other is zero.
        """
        if not self._is_valid_scalar(other):
            raise TypeError(f"Cost can only be divided by a real number, got {type(other)}.")
        if other == 0:
            raise ZeroDivisionError("Division by zero is not allowed for Cost objects.")
        return self.__mul__(1.0 / float(other))

    def __neg__(self) -> Cost:
        """Negate this regularizer while preserving the regularizer abstraction."""
        return self.__mul__(-1.0)

    def __sub__(self, other: Cost) -> Cost:
        """Subtract another cost, preserving the regularizer abstraction when possible."""
        self._validate_cost_operation(other)
        if isinstance(other, BaseRegularizerCost):
            return _CompositeRegularizerCost([self, other], weights=[1.0, -1.0])
        return super().__sub__(other)
class _CompositeRegularizerCost(BaseRegularizerCost):
    """
    Weighted combination of regularizers that preserves the regularizer abstraction.

    This wrapper represents sums and scalar rescalings of regularizers while keeping the
    :class:`BaseRegularizerCost` interface. It combines function, gradient, and Hessian termwise.

    A generic proximal is intentionally not implemented. It is only available when the combination reduces to a
    single positively scaled regularizer once terms referring to the same regularizer object are merged and zero
    weights are dropped (e.g. ``2 * r``, ``r + r``, ``2 * r - r``). It is also available when everything cancels
    (e.g. ``r - r``), where it is the identity.

    Instances keep references to the wrapped cost objects. No implicit copying is performed; use
    :func:`copy.deepcopy` explicitly if independent objects are required.
    """

    def __init__(self, regularizers: list[BaseRegularizerCost], weights: list[float] | None = None):
        """
        Build a weighted sum ``sum_i weights[i] * regularizers[i]``.

        Nested composites are flattened so that the stored terms only reference leaf regularizers.

        Args:
            regularizers: Non-empty list of regularizers to combine.
            weights: One weight per regularizer; defaults to all ones.

        Raises:
            ValueError: If ``regularizers`` is empty or ``weights`` has a different length. Errors raised by
                ``_validate_cost_operation`` for incompatible terms also propagate.
            TypeError: If an entry of ``regularizers`` is not a :class:`BaseRegularizerCost`.
        """
        if len(regularizers) == 0:
            raise ValueError("Composite regularizer must contain at least one regularizer.")
        first = regularizers[0]
        super().__init__(first.shape, framework=first.framework, device=first.device)
        if weights is None:
            weights = [1.0] * len(regularizers)
        if len(regularizers) != len(weights):
            raise ValueError("Composite regularizer weights must match the number of regularizers.")
        self._terms: list[tuple[BaseRegularizerCost, float]] = []
        for regularizer, weight in zip(regularizers, weights, strict=True):
            if not isinstance(regularizer, BaseRegularizerCost):
                raise TypeError(f"Composite regularizer can only contain regularizers, got {type(regularizer)}.")
            self._validate_cost_operation(regularizer, check_framework=True, check_device=True)
            if isinstance(regularizer, _CompositeRegularizerCost):
                # Flatten: distribute the outer weight over the inner terms.
                for inner_regularizer, inner_weight in regularizer._terms:  # noqa: SLF001
                    self._terms.append((inner_regularizer, float(weight) * inner_weight))
            else:
                self._terms.append((regularizer, float(weight)))

    @cached_property
    def m_smooth(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Sum of ``|weight| * m_smooth`` over the terms, or ``nan`` if any term reports ``nan``."""
        # Negating a function does not change its gradient's Lipschitz constant, hence abs(weight).
        m_smooth_vals = [abs(weight) * regularizer.m_smooth for regularizer, weight in self._terms]
        return np.nan if any(np.isnan(v) for v in m_smooth_vals) else float(sum(m_smooth_vals))

    @cached_property
    def m_cvx(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Sum of ``weight * m_cvx`` over the terms; ``nan`` if any weight is negative or any term reports ``nan``."""
        # A negatively weighted term can destroy convexity, so no modulus is reported in that case.
        if any(weight < 0 for _, weight in self._terms):
            return np.nan
        m_cvx_vals = [weight * regularizer.m_cvx for regularizer, weight in self._terms]
        return np.nan if any(np.isnan(v) for v in m_cvx_vals) else float(sum(m_cvx_vals))

    def function(self, x: Array, **kwargs: Any) -> float:  # noqa: ANN401
        """Weighted sum of the term function values."""
        return float(sum(weight * regularizer.function(x, **kwargs) for regularizer, weight in self._terms))

    def gradient(self, x: Array, **kwargs: Any) -> Array:  # noqa: ANN401
        """Weighted sum of the term gradients."""
        return iop.sum(
            iop.stack([regularizer.gradient(x, **kwargs) * weight for regularizer, weight in self._terms]),
            dim=0,
        )

    def hessian(self, x: Array, **kwargs: Any) -> Array:  # noqa: ANN401
        """Weighted sum of the term Hessians."""
        return iop.sum(
            iop.stack([regularizer.hessian(x, **kwargs) * weight for regularizer, weight in self._terms]),
            dim=0,
        )

    def _merged_terms(self) -> list[tuple[BaseRegularizerCost, float]]:
        """
        Return the terms with weights of the same regularizer object summed and zero total weights dropped.

        Merging is by object identity only, which is always exact: ``a * r + b * r == (a + b) * r`` for the same
        function ``r``. Distinct instances (even of the same class) are never merged.
        """
        merged: dict[int, tuple[BaseRegularizerCost, float]] = {}
        for regularizer, weight in self._terms:
            key = id(regularizer)  # safe: self._terms keeps every regularizer alive
            if key in merged:
                merged[key] = (regularizer, merged[key][1] + weight)
            else:
                merged[key] = (regularizer, weight)
        return [(regularizer, weight) for regularizer, weight in merged.values() if weight != 0.0]

    def proximal(self, x: Array, penalty: float, **kwargs: Any) -> Array:  # noqa: ANN401
        """
        Proximal operator, supported only when the composite reduces to one positively scaled regularizer.

        After merging terms that refer to the same regularizer object and dropping zero weights:

        - no terms left: the composite is identically zero and its proximal is the identity (a copy of ``x``);
        - one term ``w * r`` with ``w > 0``: ``prox_{penalty * w * r}(x)`` is delegated to ``r.proximal``.

        For genuine sums of different regularizers or negative scaling, composing or summing individual proximal
        operators is not mathematically valid in general.

        Raises:
            ValueError: If penalty is not positive.
            NotImplementedError: If the composite is not a single positively scaled term.
        """
        if penalty <= 0:
            raise ValueError("The penalty parameter penalty must be positive.")
        terms = self._merged_terms()
        if not terms:
            # argmin_y 0 + ||y - x||^2 / (2 * penalty) == x
            return iop.copy(x)
        if len(terms) == 1:
            regularizer, weight = terms[0]
            if weight > 0:
                # prox of penalty * (weight * r) equals prox of r with penalty scaled by weight.
                return regularizer.proximal(x, penalty * weight, **kwargs)
        raise NotImplementedError(
            "Composite regularizers do not implement a generic proximal operator because sums of regularizers and "
            "negative scaling do not admit a proximal from simple composition in general. Use a specialized proximal "
            "if available."
        )
@tags("regularizer")
class L1RegularizerCost(BaseRegularizerCost):
    r"""
    L1 regularizer cost.

    .. math::
        f(\mathbf{x}) = \|\mathbf{x}\|_1 = \sum_i |x_i|
    """

    @cached_property
    def m_smooth(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        """The L1 norm is nonsmooth, so no smoothness constant is reported (``nan``)."""
        return np.nan

    @cached_property
    def m_cvx(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        """The L1 norm is convex but not strongly convex, so the modulus is ``0``."""
        return 0.0

    def function(self, x: Array, **kwargs: Any) -> float:  # noqa: ARG002, ANN401
        """Sum of absolute values of the entries of ``x``."""
        magnitudes = iop.absolute(x)
        total = iop.sum(magnitudes)
        return float(iop.astype(total, float))

    def gradient(self, x: Array, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """Elementwise sign of ``x``, used as a subgradient of the L1 norm."""
        return iop.sign(x)

    def hessian(self, x: Array, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """All-zero ``(dim, dim)`` matrix, built once and reused on later calls."""
        cached = self._hessian_cache
        if cached is None:
            cached = iop.zeros(shape=(self._dim, self._dim), framework=self.framework, device=self.device)
            self._hessian_cache = cached
        return cached

    def proximal(self, x: Array, penalty: float, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """
        Soft-thresholding: shrink every magnitude by ``penalty``, clamp at zero, and restore the sign.

        Raises:
            ValueError: If penalty is not positive.
        """
        if penalty <= 0:
            raise ValueError("The penalty parameter penalty must be positive.")
        signs = iop.sign(x)
        shrunk_magnitudes = iop.maximum(iop.absolute(x) - penalty, 0.0)
        return signs * shrunk_magnitudes
@tags("regularizer")
class L2RegularizerCost(BaseRegularizerCost):
    r"""
    L2 regularizer cost.

    .. math::
        f(\mathbf{x}) = \frac{1}{2}\|\mathbf{x}\|_2^2
    """

    @cached_property
    def m_smooth(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        """The Hessian is the identity, so the gradient is 1-Lipschitz."""
        return 1.0

    @cached_property
    def m_cvx(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        """The Hessian is the identity, so the function is 1-strongly convex."""
        return 1.0

    def function(self, x: Array, **kwargs: Any) -> float:  # noqa: ARG002, ANN401
        """Half the squared Euclidean norm of ``x``."""
        return float(iop.astype(0.5 * iop.sum(iop.mul(x, x)), float))

    def gradient(self, x: Array, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """
        Gradient of ``0.5 * ||x||^2``, which equals ``x``.

        A copy is returned rather than ``x`` itself, so that a caller updating the gradient in place
        (e.g. ``g *= step``) cannot silently modify its own iterate.
        """
        return iop.copy(x)

    def hessian(self, x: Array, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """Identity ``(dim, dim)`` matrix, built once and reused on later calls."""
        if self._hessian_cache is None:
            self._hessian_cache = iop.eye(n=self._dim, framework=self.framework, device=self.device)
        return self._hessian_cache

    def proximal(self, x: Array, penalty: float, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """
        Closed-form proximal operator.

        ``argmin_y 0.5 * ||y||^2 + ||y - x||^2 / (2 * penalty)`` is ``x / (1 + penalty)``.

        Raises:
            ValueError: If penalty is not positive.
        """
        if penalty <= 0:
            raise ValueError("The penalty parameter penalty must be positive.")
        return x / (1.0 + penalty)
@tags("regularizer")
class FractionalQuadraticRegularizerCost(BaseRegularizerCost):
    r"""
    Nonconvex fractional quadratic regularizer.

    .. math::
        f(\mathbf{x}) = \sum_i \frac{x_i^2}{1 + x_i^2}

    The proximal operator has no closed form and is approximated iteratively (see :meth:`proximal`).
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        *,
        framework: SupportedFrameworks = SupportedFrameworks.NUMPY,
        device: SupportedDevices = SupportedDevices.CPU,
        prox_max_iter: int = 100,
        prox_tol: float | None = 1e-10,
    ):
        """
        Create the regularizer and configure its iterative proximal solver.

        Args:
            shape: Shape of the decision variable.
            framework: Array framework the regularizer operates on.
            device: Device on which the regularizer's arrays are created.
            prox_max_iter: Maximum number of gradient-descent iterations used by :meth:`proximal`.
            prox_tol: Stop :meth:`proximal` early once the norm of the change between successive iterates is at most
                this value. ``None`` disables early stopping, so exactly ``prox_max_iter`` iterations run.

        Raises:
            ValueError: If ``shape`` is invalid, ``prox_max_iter`` is not positive, or ``prox_tol`` is negative or
                NaN.
        """
        super().__init__(shape, framework=framework, device=device)
        if prox_max_iter <= 0:
            raise ValueError("prox_max_iter must be positive.")
        # A negative (or NaN) tolerance can never be met, which would silently force every prox call to run all
        # ``prox_max_iter`` iterations; reject it up front. ``not >= 0`` also catches NaN.
        if prox_tol is not None and not prox_tol >= 0:
            raise ValueError(f"prox_tol must be non-negative or None, got {prox_tol}.")
        self._prox_max_iter = prox_max_iter
        self._prox_tol = prox_tol

    @cached_property
    def m_smooth(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Bound on ``|f''|``: the second derivative ``2(1 - 3t^2)/(1 + t^2)^3`` peaks in magnitude at ``t = 0``."""
        return 2.0

    @cached_property
    def m_cvx(self) -> float:  # pyright: ignore[reportIncompatibleMethodOverride]
        """The regularizer is nonconvex, so no strong convexity modulus is reported (``nan``)."""
        return np.nan

    def function(self, x: Array, **kwargs: Any) -> float:  # noqa: ARG002, ANN401
        """Sum over entries of ``x_i^2 / (1 + x_i^2)``."""
        x2 = x * x
        return float(iop.astype(iop.sum(x2 / (1.0 + x2)), float))

    def gradient(self, x: Array, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """Elementwise derivative ``2 x / (1 + x^2)^2``."""
        x2 = x * x
        denom = (1.0 + x2) ** 2
        return 2.0 * x / denom

    def hessian(self, x: Array, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """Diagonal ``(dim, dim)`` matrix with entries ``2 (1 - 3 x^2) / (1 + x^2)^3`` (flattened ``x``)."""
        x2 = x * x
        denom = (1.0 + x2) ** 3
        second = 2.0 * (1.0 - 3.0 * x2) / denom
        return iop.diag(iop.reshape(second, (self._dim,)))

    def proximal(self, x: Array, penalty: float, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """
        Approximate ``argmin_y f(y) + ||y - x||^2 / (2 * penalty)`` by fixed-step gradient descent from ``x``.

        The step size ``1 / (2 + 1 / penalty)`` is the reciprocal of a curvature bound for the subproblem: ``|f''|``
        is at most ``2`` (see :attr:`m_smooth`), and the quadratic term adds ``1 / penalty``. Because ``f`` is
        nonconvex, the subproblem may itself be nonconvex for large ``penalty``. In that case the iteration only
        targets a stationary point.

        Raises:
            ValueError: If penalty is not positive.
        """
        if penalty <= 0:
            raise ValueError("The penalty parameter penalty must be positive.")
        step_size = 1.0 / (2.0 + 1.0 / penalty)
        current = iop.copy(x)
        for _ in range(self._prox_max_iter):
            # Gradient of the prox subproblem: f'(y) + (y - x) / penalty.
            grad = self.gradient(current) + (current - x) / penalty
            next_x = current - step_size * grad
            if self._prox_tol is not None and iop.astype(iop.norm(next_x - current), float) <= self._prox_tol:
                return next_x
            current = next_x
        return current