# Source code for decent_bench.utils.interoperability._helpers
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedArrayTypes, SupportedDevices, SupportedFrameworks
from ._imports_types import _jnp_types, _np_types, _tf_types, _torch_types, jax, jnp, tf, torch
# [docs]
def device_to_framework_device(device: SupportedDevices, framework: SupportedFrameworks) -> Any:  # noqa: ANN401
    """
    Translate a ``SupportedDevices`` member into the device handle a framework expects.

    Args:
        device (SupportedDevices): Device to translate.
        framework (SupportedFrameworks): Framework whose device representation is wanted.

    Returns:
        Any: Depends on ``framework``:

            * NumPy: ``device`` itself, unchanged.
            * PyTorch: the string produced by ``_identify_torch_device``.
            * TensorFlow: the string ``"/<device.value>:0"``.
            * JAX: the first entry of ``jax.devices("cpu")`` when ``device`` is
              ``SupportedDevices.CPU``, otherwise the first entry of ``jax.devices("gpu")``.

    Raises:
        ValueError: If no branch matches ``framework``. This also happens when the
            framework is recognised but its module (``torch``, ``tf`` or ``jax``) is
            falsy, because each optional branch is gated on the module object.

    """
    if framework == SupportedFrameworks.NUMPY:
        # NumPy has no explicit device management, so the input is handed back as-is.
        resolved: Any = device
    elif torch and framework == SupportedFrameworks.PYTORCH:
        resolved = _identify_torch_device(device)
    elif tf and framework == SupportedFrameworks.TENSORFLOW:
        resolved = f"/{device.value}:0"
    elif jax and framework == SupportedFrameworks.JAX:
        # Every device other than CPU is looked up on the "gpu" backend.
        backend_name = "cpu" if device == SupportedDevices.CPU else "gpu"
        resolved = jax.devices(backend_name)[0]
    else:
        raise ValueError(f"Unsupported framework: {framework}")
    return resolved
def _identify_torch_device(device: SupportedDevices) -> str:
    """
    Map a ``SupportedDevices`` member to the matching PyTorch device string.

    Args:
        device (SupportedDevices): Device to translate.

    Returns:
        str: ``"cpu"``, ``"cuda"`` or ``"mps"``.

    Raises:
        ValueError: If ``device`` equals none of CPU, GPU or MPS.

    """
    # Checked in order with ``==``, first match wins.
    known_devices = (
        (SupportedDevices.CPU, "cpu"),
        (SupportedDevices.GPU, "cuda"),
        (SupportedDevices.MPS, "mps"),
    )
    for candidate, torch_name in known_devices:
        if device == candidate:
            return torch_name
    raise ValueError(f"Unsupported device: {device}")
def _return_array(array: SupportedArrayTypes) -> Array:
    """
    Present a framework-native array as an `Array` to type checkers without wrapping it at runtime.

    This helper standardizes return types across interoperability functions:
    the annotated return type is `Array`, but the value actually returned when
    the program runs is the very object that was passed in.

    Args:
        array (SupportedArrayTypes): Input array (NumPy, torch, tf, jax).

    Returns:
        Array: Typed as `Array` for static analysis; at runtime this is the
            original framework-native value, not an `Array` instance.

    """
    # ``typing.TYPE_CHECKING`` is False when the program runs, so this branch is
    # always taken at runtime and no wrapper object is allocated.
    if not TYPE_CHECKING:
        return array

    # Never executed at runtime; kept so that static analysis sees a return
    # value matching the ``Array`` annotation.
    return Array(array)
# [docs]
def framework_device_of_array(array: Array | SupportedArrayTypes) -> tuple[SupportedFrameworks, SupportedDevices]:
    """
    Identify which framework an array belongs to and which device it is placed on.

    An `Array` wrapper is unwrapped through its ``value`` attribute first; any
    other input is inspected directly.

    Args:
        array (Array | SupportedArrayTypes): Input array.

    Returns:
        tuple[SupportedFrameworks, SupportedDevices]: Framework and device of the array.
            NumPy arrays are always reported on CPU. TensorFlow values are reported
            on GPU when their lower-cased device string contains ``"gpu"`` or
            ``"cuda"``, and JAX values when the device platform equals ``"gpu"``;
            both fall back to CPU otherwise.

    Raises:
        TypeError: If the framework type of `array` is unsupported, or if a PyTorch
            value sits on a device type other than ``"mps"``, ``"cuda"`` or ``"cpu"``.

    """
    value = array
    if isinstance(array, Array):
        value = array.value

    if isinstance(value, _np_types):
        return SupportedFrameworks.NUMPY, SupportedDevices.CPU

    if torch and isinstance(value, _torch_types):
        torch_device_table = {
            "mps": SupportedDevices.MPS,
            "cuda": SupportedDevices.GPU,
            "cpu": SupportedDevices.CPU,
        }
        located = torch_device_table.get(value.device.type)  # type: ignore[union-attr]
        if located is None:
            raise TypeError(f"Unsupported PyTorch device type: {value.device.type}")  # type: ignore[union-attr]
        return SupportedFrameworks.PYTORCH, located

    if tf and isinstance(value, _tf_types):
        placement = value.device.lower()  # type: ignore[union-attr]
        on_accelerator = any(marker in placement for marker in ("gpu", "cuda"))
        return (
            SupportedFrameworks.TENSORFLOW,
            SupportedDevices.GPU if on_accelerator else SupportedDevices.CPU,
        )

    if jnp and isinstance(value, _jnp_types):
        if value.device.platform == "gpu":
            return SupportedFrameworks.JAX, SupportedDevices.GPU
        return SupportedFrameworks.JAX, SupportedDevices.CPU

    raise TypeError(f"Unsupported framework type: {type(value)}")
# [docs]
def is_supported_array_type(array: Any) -> bool:  # noqa: ANN401
    """
    Report whether a value is an array type from one of the supported frameworks.

    An `Array` wrapper is unwrapped through its ``value`` attribute before the check.

    Args:
        array (Any): Input array.

    Returns:
        bool: True if the (unwrapped) value is an instance of the NumPy, PyTorch,
            TensorFlow or JAX array types, False otherwise. Types of a framework
            whose module object is falsy are never matched.

    """
    payload = array
    if isinstance(array, Array):
        payload = array.value

    # Each pair is (framework is usable, accepted types). NumPy is always
    # checked; the others are gated on their module object. Order matches the
    # short-circuit order NumPy -> PyTorch -> TensorFlow -> JAX.
    framework_checks = (
        (True, _np_types),
        (torch, _torch_types),
        (tf, _tf_types),
        (jnp, _jnp_types),
    )
    return any(available and isinstance(payload, accepted) for available, accepted in framework_checks)