Source code for decent_bench.costs._base._zero_cost
from __future__ import annotations
from functools import cached_property
from typing import Any
import decent_bench.utils.interoperability as iop
from decent_bench.costs._base._cost import Cost
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks
[docs]
class ZeroCost(Cost):
    """
    A cost function that is identically zero.

    This function is used as default for the server in :class:`~decent_bench.networks.FedNetwork`.

    Every evaluation method (:meth:`function`, :meth:`gradient`, :meth:`hessian`, :meth:`proximal`) first
    checks that ``x`` has the configured domain shape and raises :class:`ValueError` otherwise.
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        framework: SupportedFrameworks = SupportedFrameworks.NUMPY,
        device: SupportedDevices = SupportedDevices.CPU,
    ):
        """
        Store the domain shape together with the framework and device used for created arrays.

        Args:
            shape: domain shape; every entry must be a non-negative ``int``. Other sequences (e.g. a list)
                are accepted and stored as a tuple.
            framework: framework forwarded to ``iop.zeros`` by :meth:`gradient` and :meth:`hessian`.
            device: device forwarded to ``iop.zeros`` by :meth:`gradient` and :meth:`hessian`.

        Raises:
            ValueError: if any entry of ``shape`` is not a non-negative integer.

        """
        # Normalise once, up front: ``_check_shape`` compares shapes with ``!=`` and a list never compares
        # equal to a tuple, so a list-typed shape would make every later shape check fail.
        # Converting before validating also keeps one-shot iterables from being exhausted by ``all``.
        shape = tuple(shape)
        if not all(isinstance(d, int) and d >= 0 for d in shape):
            raise ValueError("shape must be a tuple of non-negative integers")
        self._shape = shape
        self._framework = framework
        self._device = device

    @property
    def shape(self) -> tuple[int, ...]:
        """Domain shape given at construction, as a tuple."""
        return self._shape

    @property
    def framework(self) -> SupportedFrameworks:
        """Framework given at construction."""
        return self._framework

    @property
    def device(self) -> SupportedDevices:
        """Device given at construction."""
        return self._device

    @cached_property
    def m_smooth(self) -> float:
        """Value of ``m_smooth`` for this cost: always ``0.0``."""
        return 0.0

    @cached_property
    def m_cvx(self) -> float:
        """Value of ``m_cvx`` for this cost: always ``0.0``."""
        return 0.0

    def _check_shape(self, x: Array) -> None:
        """
        Verify that ``x`` has the domain shape of this cost.

        Raises:
            ValueError: if ``iop.shape(x)`` differs from :attr:`shape`.

        """
        if iop.shape(x) != self.shape:
            raise ValueError(f"Mismatching domain shapes: {iop.shape(x)} vs {self.shape}")

    def function(self, x: Array, **kwargs: Any) -> float:  # noqa: ARG002, ANN401
        """
        Return ``0.0`` after validating the shape of ``x``.

        Raises:
            ValueError: if ``x`` has the wrong shape.

        """
        self._check_shape(x)
        return 0.0

    def gradient(self, x: Array, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """
        Return a zero array of shape :attr:`shape` after validating the shape of ``x``.

        The array is created with ``iop.zeros`` using this cost's framework and device.

        Raises:
            ValueError: if ``x`` has the wrong shape.

        """
        self._check_shape(x)
        return iop.zeros(shape=self.shape, framework=self.framework, device=self.device)

    def hessian(self, x: Array, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """
        Return a zero array of shape ``shape + shape`` after validating the shape of ``x``.

        The array is created with ``iop.zeros`` using this cost's framework and device.

        Raises:
            ValueError: if ``x`` has the wrong shape.

        """
        self._check_shape(x)
        return iop.zeros(shape=self.shape + self.shape, framework=self.framework, device=self.device)

    def proximal(self, x: Array, penalty: float, **kwargs: Any) -> Array:  # noqa: ARG002, ANN401
        """
        Return ``x`` unchanged.

        Since :class:`ZeroCost` is identically zero, its proximal operator is the identity map. The method still
        validates that ``penalty`` is positive and that ``x`` has the expected shape. The returned object is
        ``x`` itself, not a copy.

        Raises:
            ValueError: if ``penalty`` is not positive or ``x`` has the wrong shape.

        """
        if penalty <= 0:
            raise ValueError("The penalty parameter penalty must be positive.")
        self._check_shape(x)
        return x

    def __add__(self, other: Cost) -> Cost:
        """
        Add another cost to this zero cost.

        ``other`` is first passed to ``self._validate_cost_operation``. Because this cost contributes nothing,
        the result is ``self`` when ``other`` is also a :class:`ZeroCost` and ``other`` itself otherwise.
        """
        self._validate_cost_operation(other)
        if isinstance(other, ZeroCost):
            return self
        return other