# Source code for decent_bench.utils.interoperability._helpers

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedArrayTypes, SupportedDevices, SupportedFrameworks

from ._imports_types import _jnp_types, _np_types, _tf_types, _torch_types, jax, jnp, tf, torch


def device_to_framework_device(device: SupportedDevices, framework: SupportedFrameworks) -> Any:  # noqa: ANN401
    """
    Convert a SupportedDevices literal to the framework-specific device representation.

    Args:
        device (SupportedDevices): Device literal (CPU, GPU or MPS).
        framework (SupportedFrameworks): Framework literal ("numpy", "torch", "tensorflow", "jax").

    Returns:
        Any: Framework-specific device representation:

        - NumPy: ``device`` unchanged (NumPy has no explicit device management).
        - PyTorch: ``"cpu"``, ``"cuda"`` or ``"mps"``.
        - TensorFlow: ``f"/{device.value}:0"`` for CPU or GPU.
        - JAX: the first device returned by ``jax.devices("cpu")`` or ``jax.devices("gpu")``.

    Raises:
        ValueError: If the framework is unsupported (or its library could not be imported), or if
            the device is not supported by the requested framework (e.g. MPS with TensorFlow or JAX).
    """
    if framework == SupportedFrameworks.NUMPY:
        return device  # NumPy does not have explicit device management
    if torch and framework == SupportedFrameworks.PYTORCH:
        return _identify_torch_device(device)
    if tf and framework == SupportedFrameworks.TENSORFLOW:
        if device in (SupportedDevices.CPU, SupportedDevices.GPU):
            return f"/{device.value}:0"
        # Previously MPS fell through to "/mps:0", which is not a TensorFlow device spec.
        raise ValueError(f"Unsupported device for TensorFlow: {device}")
    if jax and framework == SupportedFrameworks.JAX:
        if device == SupportedDevices.CPU:
            return jax.devices("cpu")[0]
        if device == SupportedDevices.GPU:
            return jax.devices("gpu")[0]
        # Previously every non-CPU device (including MPS) was sent to the "gpu" backend,
        # failing later with an opaque backend error instead of a clear ValueError.
        raise ValueError(f"Unsupported device for JAX: {device}")
    raise ValueError(f"Unsupported framework: {framework}")
def _identify_torch_device(device: SupportedDevices) -> str:
    """
    Translate a SupportedDevices literal into the matching PyTorch device string.

    Args:
        device (SupportedDevices): Device literal to translate.

    Returns:
        str: ``"cpu"`` for CPU, ``"cuda"`` for GPU, ``"mps"`` for MPS.

    Raises:
        ValueError: If ``device`` is none of CPU, GPU or MPS.
    """
    torch_names = (
        (SupportedDevices.CPU, "cpu"),
        (SupportedDevices.GPU, "cuda"),
        (SupportedDevices.MPS, "mps"),
    )
    for candidate, torch_name in torch_names:
        if device == candidate:
            return torch_name
    raise ValueError(f"Unsupported device: {device}")


def _return_array(array: SupportedArrayTypes) -> Array:
    """
    Give a framework-native array the static type ``Array`` without wrapping it at runtime.

    Interoperability helpers funnel their results through this function so that
    static type checkers see a uniform ``Array`` return type. At runtime nothing is
    wrapped: the exact object passed in is returned.

    Args:
        array (SupportedArrayTypes): Framework-native array (NumPy, torch, tf, jax).

    Returns:
        Array: ``Array`` for static analysis; at runtime, ``array`` itself.
    """
    # TYPE_CHECKING is False at runtime, so this early return always fires when executed;
    # type checkers treat TYPE_CHECKING as True and only "see" the wrapped return below.
    if not TYPE_CHECKING:
        return array
    return Array(array)
def framework_device_of_array(array: Array | SupportedArrayTypes) -> tuple[SupportedFrameworks, SupportedDevices]:
    """
    Identify which framework owns an array and which device it is placed on.

    ``Array`` wrappers are unwrapped via their ``value`` attribute first. Frameworks are
    probed in the order NumPy, PyTorch, TensorFlow, JAX; the optional frameworks are only
    probed when their module is available.

    Args:
        array (Array | SupportedArrayTypes): Array to inspect, wrapped or framework-native.

    Returns:
        tuple[SupportedFrameworks, SupportedDevices]: Owning framework and device.
        NumPy arrays are always reported as CPU.

    Raises:
        TypeError: If the array's type belongs to no supported framework, or if it is a
            PyTorch tensor whose device type is not ``"cpu"``, ``"cuda"`` or ``"mps"``.
    """
    native = array if not isinstance(array, Array) else array.value

    if isinstance(native, _np_types):
        return SupportedFrameworks.NUMPY, SupportedDevices.CPU

    if torch and isinstance(native, _torch_types):
        torch_kind = native.device.type  # type: ignore[union-attr]
        by_kind = {
            "mps": SupportedDevices.MPS,
            "cuda": SupportedDevices.GPU,
            "cpu": SupportedDevices.CPU,
        }
        if torch_kind not in by_kind:
            raise TypeError(f"Unsupported PyTorch device type: {torch_kind}")
        return SupportedFrameworks.PYTORCH, by_kind[torch_kind]

    if tf and isinstance(native, _tf_types):
        placement = native.device.lower()  # type: ignore[union-attr]
        # Any placement string mentioning a GPU/CUDA device counts as GPU; everything else is CPU.
        on_accelerator = any(tag in placement for tag in ("gpu", "cuda"))
        return SupportedFrameworks.TENSORFLOW, (SupportedDevices.GPU if on_accelerator else SupportedDevices.CPU)

    if jnp and isinstance(native, _jnp_types):
        platform = native.device.platform
        if platform == "gpu":
            return SupportedFrameworks.JAX, SupportedDevices.GPU
        # Any non-"gpu" JAX platform is reported as CPU.
        return SupportedFrameworks.JAX, SupportedDevices.CPU

    raise TypeError(f"Unsupported framework type: {type(native)}")
def is_supported_array_type(array: Any) -> bool:  # noqa: ANN401
    """
    Report whether an object (or the value inside an ``Array`` wrapper) is a supported array.

    NumPy is always checked; PyTorch, TensorFlow and JAX types are only checked when the
    corresponding module is available.

    Args:
        array (Any): Object to check; ``Array`` wrappers are unwrapped first.

    Returns:
        bool: True if the unwrapped object is a NumPy, PyTorch, TensorFlow or JAX array type,
        False otherwise.
    """
    native = array.value if isinstance(array, Array) else array
    if isinstance(native, _np_types):
        return True
    optional_frameworks = ((torch, _torch_types), (tf, _tf_types), (jnp, _jnp_types))
    # bool(module) short-circuits so the isinstance check is skipped for missing frameworks.
    return any(bool(module) and isinstance(native, kinds) for module, kinds in optional_frameworks)