Source code for decent_bench.costs._empirical_risk._empirical_regularized_cost

from __future__ import annotations

from functools import cached_property
from typing import Any

import numpy as np

import decent_bench.utils.interoperability as iop
from decent_bench.costs._base._cost import Cost
from decent_bench.costs._base._regularizer_costs import BaseRegularizerCost
from decent_bench.utils.array import Array
from decent_bench.utils.types import (
    Dataset,
    EmpiricalRiskIndices,
    EmpiricalRiskReduction,
    SupportedDevices,
    SupportedFrameworks,
)

from ._empirical_risk_cost import EmpiricalRiskCost


[docs] class EmpiricalRegularizedCost(EmpiricalRiskCost): """ Composite objective of an empirical risk term plus a regularizer. This wrapper preserves empirical-risk-specific behavior from the empirical term, including :meth:`predict`, dataset access, and batch metadata, while combining function, gradient, and Hessian values with the regularizer. When :meth:`gradient` is called with ``reduction=None``, the regularizer gradient is broadcast across the leading sample dimension so that averaging over samples recovers the composite mean gradient. A generic proximal is intentionally not implemented. Instances keep references to the wrapped cost objects. No implicit copying is performed; use :func:`copy.deepcopy` explicitly if independent objects are required. """
[docs] def __init__(self, empirical_cost: EmpiricalRiskCost, regularizer: BaseRegularizerCost): empirical_cost._validate_cost_operation(regularizer, check_framework=True, check_device=True) # noqa: SLF001 self.empirical_cost = empirical_cost self.regularizer = regularizer
@property def shape(self) -> tuple[int, ...]: return self.empirical_cost.shape @property def framework(self) -> SupportedFrameworks: return self.empirical_cost.framework @property def device(self) -> SupportedDevices: return self.empirical_cost.device @property def n_samples(self) -> int: return self.empirical_cost.n_samples @property def batch_size(self) -> int: return self.empirical_cost.batch_size @property def batch_used(self) -> list[int]: return self.empirical_cost.batch_used @property def dataset(self) -> Dataset: return self.empirical_cost.dataset
[docs] @cached_property def m_smooth(self) -> float: m_smooth_vals = [self.empirical_cost.m_smooth, self.regularizer.m_smooth] return np.nan if any(np.isnan(v) for v in m_smooth_vals) else float(sum(m_smooth_vals))
[docs] @cached_property def m_cvx(self) -> float: m_cvx_vals = [self.empirical_cost.m_cvx, self.regularizer.m_cvx] return np.nan if any(np.isnan(v) for v in m_cvx_vals) else float(sum(m_cvx_vals))
[docs] def predict(self, x: Array, data: list[Array]) -> Array: """Predictions are determined by the empirical term.""" return self.empirical_cost.predict(x, data)
[docs] def function(self, x: Array, indices: EmpiricalRiskIndices = "batch", **kwargs: Any) -> float: # noqa: ANN401 return self.empirical_cost.function(x, indices=indices, **kwargs) + self.regularizer.function(x)
[docs] def gradient( self, x: Array, indices: EmpiricalRiskIndices = "batch", reduction: EmpiricalRiskReduction = "mean", **kwargs: Any, # noqa: ANN401 ) -> Array: """ Gradient of the empirical objective plus regularizer. When ``reduction="mean"``, this returns the mean empirical gradient over the selected samples plus the regularizer gradient. When ``reduction=None``, this returns one gradient per selected sample with the regularizer gradient broadcast along the leading sample dimension. Averaging the result over that leading dimension recovers the composite gradient returned by ``reduction="mean"``. """ empirical_gradient = self.empirical_cost.gradient(x, indices=indices, reduction=reduction, **kwargs) regularizer_gradient = self.regularizer.gradient(x) if reduction is None: batch_count = iop.shape(empirical_gradient)[0] regularizer_gradients = [regularizer_gradient for _ in range(batch_count)] return empirical_gradient + iop.stack(regularizer_gradients) return empirical_gradient + regularizer_gradient
[docs] def hessian(self, x: Array, indices: EmpiricalRiskIndices = "batch", **kwargs: Any) -> Array: # noqa: ANN401 return self.empirical_cost.hessian(x, indices=indices, **kwargs) + self.regularizer.hessian(x)
[docs] def proximal(self, x: Array, penalty: float, **kwargs: Any) -> Array: # noqa: ANN401 """ Raise ``NotImplementedError`` for the generic proximal of an empirical cost plus regularizer. This wrapper preserves evaluation, gradient, and Hessian structure, but does not imply a closed-form proximal. Use a specialized composite cost if one exists, or use :func:`decent_bench.centralized_algorithms.proximal_solver` when its assumptions are satisfied. Raises: NotImplementedError: Always, because no generic closed-form proximal is provided. """ raise NotImplementedError( "EmpiricalRegularizedCost does not implement a generic proximal operator. Use a specialized proximal if " "available, or use centralized_algorithms.proximal_solver when applicable." )
[docs] def _sample_batch_indices(self, indices: EmpiricalRiskIndices = "batch") -> list[int]: return self.empirical_cost._sample_batch_indices(indices) # noqa: SLF001
[docs] def _get_batch_data(self, indices: EmpiricalRiskIndices = "batch") -> Any: # noqa: ANN401 return self.empirical_cost._get_batch_data(indices) # noqa: SLF001
[docs] def __add__(self, other: Cost) -> Cost: if isinstance(other, BaseRegularizerCost): return EmpiricalRegularizedCost(self.empirical_cost, self.regularizer + other) return super().__add__(other)