Source code for decent_bench.utils.interoperability._decorators

from __future__ import annotations

from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, TypeVar, cast

from decent_bench.utils.array import Array
from decent_bench.utils.logger import LOGGER
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks

from ._functions import to_array_like, to_jax, to_numpy, to_tensorflow, to_torch
from ._helpers import framework_device_of_array

if TYPE_CHECKING:
    from decent_bench.costs import Cost

# Module-level generic type variable bound to any callable.
# NOTE: autodecorate_cost_method declares its own PEP 695 type parameter named ``T``,
# which shadows this one inside that function's signature and body.
T = TypeVar("T", bound=Callable[..., Any])
"""A generic callable type variable."""


def _get_converter(framework: SupportedFrameworks) -> Callable[[Array | Any, SupportedDevices], Any]:
    if framework == SupportedFrameworks.NUMPY:
        return to_numpy
    if framework == SupportedFrameworks.PYTORCH:
        return to_torch
    if framework == SupportedFrameworks.TENSORFLOW:
        return to_tensorflow
    if framework == SupportedFrameworks.JAX:
        return to_jax

    raise ValueError(f"Unsupported framework: {framework}")


[docs] def autodecorate_cost_method[T: Callable[..., Any]](superclass_method: T) -> Callable[[Callable[..., Any]], T]: """ Decorate Cost methods to automatically convert :class:`~decent_bench.utils.array.Array` args and return types. It automatically converts input :class:`~decent_bench.utils.array.Array` arguments to the cost's framework-specific array type and wraps the output based on the superclass method's return type annotation. Args: superclass_method: The method from the superclass (e.g., `Cost.function`) that is being overridden. Note: * Only arguments that are instances of :class:`~decent_bench.utils.array.Array` are converted. Other types are passed through unchanged. * The first input argument of the decorated function must be ``x``. This is to determine the target array type for output conversion. Otherwise a :class:`ValueError` is raised. * Emits a warning if an input array's framework differs from the cost's framework. This may lead to unexpected behavior or performance issues. """ def decorator(func: Callable[..., Any]) -> T: # Determine the expected return type from the superclass method's annotations. try: return_type_annotation = superclass_method.__annotations__["return"] except (AttributeError, KeyError): return_type_annotation = None @wraps(func) def wrapper(self: Cost, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 converter = _get_converter(self.framework) if len(args) > 0: x_like = args[0] elif "x" in kwargs: x_like = kwargs["x"] else: raise ValueError("First argument must be 'x' for autodecorate_cost_method to work.") new_args = [] for arg in args: if isinstance(arg, Array): framework, _ = framework_device_of_array(arg) if framework != self.framework: LOGGER.warning( f"Converting array from framework {framework} to {self.framework}" f" in method {func.__name__}. This may lead to unexpected behavior or performance issues." 
) new_args.append(converter(arg, self.device)) else: new_args.append(arg) new_kwargs = {} for key, value in kwargs.items(): if isinstance(value, Array): framework, _ = framework_device_of_array(value) if framework != self.framework: LOGGER.warning( f"Converting array from framework {framework} to {self.framework}" f" in method {func.__name__}. This may lead to unexpected behavior or performance issues." ) new_kwargs[key] = converter(value, self.device) else: new_kwargs[key] = value result = func(self, *new_args, **new_kwargs) if return_type_annotation is Array: return to_array_like(result, x_like) return result # Cast the wrapper to the type of the superclass method. # This tells mypy that the decorated method is compatible with the superclass. return cast("T", wrapper) return decorator