from __future__ import annotations
import contextlib
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, cast
import numpy as np
from numpy.typing import NDArray
from decent_bench.utils.array import Array
from decent_bench.utils.types import ArrayKey, SupportedArrayTypes, SupportedDevices, SupportedFrameworks
from ._helpers import _return_array, device_to_framework_device, framework_device_of_array
from ._imports_types import (
_jnp_types,
_np_types,
_tf_types,
_torch_types,
)
jax = None
jnp = None
tf = None
torch = None
with contextlib.suppress(ImportError, ModuleNotFoundError):
import torch as _torch
torch = _torch
with contextlib.suppress(ImportError, ModuleNotFoundError):
import tensorflow as _tf
tf = _tf
with contextlib.suppress(ImportError, ModuleNotFoundError):
import jax.numpy as _jnp
jnp = _jnp
with contextlib.suppress(ImportError, ModuleNotFoundError):
import jax as _jax
jax = _jax
if TYPE_CHECKING:
from jax import Array as JaxArray
from tensorflow import Tensor as TensorflowTensor
from torch import Tensor as TorchTensor
[docs]
def to_numpy(
array: Array | SupportedArrayTypes,
device: SupportedDevices = SupportedDevices.CPU, # noqa: ARG001
dtype: Any | None = None, # noqa: ANN401
) -> NDArray[Any]:
"""
Convert input array to a NumPy array.
Args:
array (Array | SupportedArrayTypes): Input Array
device (SupportedDevices): Device of the input array.
dtype (Any | None): Optional data type for the converted array. If None,
the data type of the input array is preserved.
Returns:
NDArray: Converted NumPy array.
Note:
The `device` parameter is currently not used in this function but is included for API consistency.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return value
if torch and isinstance(value, torch.Tensor):
res: np.ndarray[Any] = value.cpu().numpy()
if dtype:
res = res.astype(dtype)
return res
if tf and isinstance(value, tf.Tensor):
res = value.numpy()
if dtype:
res = res.astype(dtype)
return res
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return np.array(value, dtype=dtype)
return np.array(value, dtype=dtype)
[docs]
def to_torch(array: Array | SupportedArrayTypes, device: SupportedDevices, dtype: Any | None = None) -> TorchTensor: # noqa: ANN401
"""
Convert input array to a PyTorch tensor.
Args:
array (Array | SupportedArrayTypes): Input Array
device (SupportedDevices): Device of the input array.
dtype (Any | None): Optional data type for the converted array. If None,
the data type of the input array is preserved.
Returns:
torch.Tensor: Converted PyTorch tensor.
Raises:
ImportError: if PyTorch is not installed.
"""
if not torch:
raise ImportError("PyTorch is not installed.")
value = array.value if isinstance(array, Array) else array
framework_device = device_to_framework_device(device, SupportedFrameworks.PYTORCH)
if dtype == np.float64:
# PyTorch does not support float64 on some devices, so we convert to float32 if float64 is requested
dtype = torch.float32
if isinstance(value, torch.Tensor):
return cast("TorchTensor", value.to(device=framework_device, dtype=dtype))
if isinstance(value, np.ndarray | np.generic):
return cast("TorchTensor", torch.tensor(value, dtype=dtype, device=framework_device))
if tf and isinstance(value, tf.Tensor):
return cast("TorchTensor", torch.tensor(value.cpu(), dtype=dtype, device=framework_device))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return cast("TorchTensor", torch.tensor(np.array(value), dtype=dtype, device=framework_device))
return cast("TorchTensor", torch.tensor(value, dtype=dtype, device=framework_device))
[docs]
def to_tensorflow(
array: Array | SupportedArrayTypes,
device: SupportedDevices,
dtype: Any | None = None, # noqa: ANN401
) -> TensorflowTensor:
"""
Convert input array to a TensorFlow tensor.
Args:
array (Array | SupportedArrayTypes): Input Array
device (SupportedDevices): Device of the input array.
dtype (Any | None): Optional data type for the converted array. If None,
the data type of the input array is preserved.
Returns:
tf.Tensor: Converted TensorFlow tensor.
Raises:
ImportError: if TensorFlow is not installed.
"""
if not tf:
raise ImportError("TensorFlow is not installed.")
value = array.value if isinstance(array, Array) else array
framework_device = device_to_framework_device(device, SupportedFrameworks.TENSORFLOW)
if isinstance(value, tf.Tensor):
with tf.device(framework_device):
return cast("TensorflowTensor", value)
if isinstance(value, np.ndarray | np.generic):
with tf.device(framework_device):
return cast("TensorflowTensor", tf.convert_to_tensor(value, dtype=dtype))
if torch and isinstance(value, torch.Tensor):
with tf.device(framework_device):
return cast("TensorflowTensor", tf.convert_to_tensor(value.cpu(), dtype=dtype))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
with tf.device(framework_device):
return cast("TensorflowTensor", tf.convert_to_tensor(value, dtype=dtype))
with tf.device(framework_device):
return cast("TensorflowTensor", tf.convert_to_tensor(value, dtype=dtype))
[docs]
def to_jax(array: Array | SupportedArrayTypes, device: SupportedDevices, dtype: Any | None = None) -> JaxArray: # noqa: ANN401
"""
Convert input array to a JAX array.
Args:
array (Array | SupportedArrayTypes): Input Array
device (SupportedDevices): Device of the input array.
dtype (Any | None): Optional data type for the converted array. If None,
the data type of the input array is preserved.
Returns:
jax.Array: Converted JAX array.
Raises:
ImportError: if JAX is not installed.
"""
if not jnp:
raise ImportError("JAX is not installed.")
value = array.value if isinstance(array, Array) else array
framework_device = device_to_framework_device(device, SupportedFrameworks.JAX)
if isinstance(value, jnp.ndarray | jnp.generic):
return cast("JaxArray", value.to_device(framework_device))
if isinstance(value, np.ndarray | np.generic):
return cast("JaxArray", jnp.array(value, dtype=dtype, device=framework_device))
if torch and isinstance(value, torch.Tensor):
return cast("JaxArray", jnp.array(value, dtype=dtype, device=framework_device))
if tf and isinstance(value, tf.Tensor):
return cast("JaxArray", jnp.array(value, dtype=dtype, device=framework_device))
return cast("JaxArray", jnp.array(value, dtype=dtype, device=framework_device))
[docs]
def to_array(
array: Array | SupportedArrayTypes,
framework: SupportedFrameworks,
device: SupportedDevices,
dtype: Any | None = None, # noqa: ANN401
) -> Array:
"""
Convert an array to the specified framework type.
See :func:`decent_bench.utils.interoperability.to_array_like` if you want to convert an array to match
the framework and device of another array.
Args:
array (Array | SupportedArrayTypes): Input array.
framework (SupportedFrameworks): Target framework type (e.g., "torch", "tf").
device (SupportedDevices): Target device ("cpu" or "gpu").
dtype (Any | None): Optional data type for the converted array. If None,
the data type of the input array is preserved.
Returns:
Array: Converted array in the specified framework type.
Raises:
TypeError: if the framework type of `framework` is unsupported.
"""
if framework == SupportedFrameworks.NUMPY:
return _return_array(to_numpy(array, device, dtype))
if torch and framework == SupportedFrameworks.PYTORCH:
return _return_array(to_torch(array, device, dtype))
if tf and framework == SupportedFrameworks.TENSORFLOW:
return _return_array(to_tensorflow(array, device, dtype))
if jnp and framework == SupportedFrameworks.JAX:
return _return_array(to_jax(array, device, dtype))
raise TypeError(f"Unsupported framework type: {framework}")
[docs]
def to_array_like(array: Array | SupportedArrayTypes, like: Array) -> Array:
"""
Convert an array to the framework/device of `like`.
Args:
array (Array | SupportedArrayTypes): Input array.
like (Array): Array whose framework and device to match.
Returns:
Array: Converted array in the specified framework type.
"""
value = like.value if isinstance(like, Array) else like
framework, device = framework_device_of_array(value)
dtype = None
if hasattr(value, "dtype"):
dtype = value.dtype
return to_array(array, framework, device, dtype)
[docs]
def sum( # noqa: A001
array: Array,
dim: int | tuple[int, ...] | None = None,
keepdims: bool = False,
) -> Array:
"""
Sum elements of an array.
Args:
array (Array): Input array.
dim (int | tuple[int, ...] | None): Dimension or dimensions along which to sum.
If None, sums over flattened array.
keepdims (bool): If True, retains reduced dimensions with length 1.
Returns:
Array: Summed value in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.sum(value, axis=dim, keepdims=keepdims))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.sum(value, dim=dim, keepdim=keepdims))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.reduce_sum(value, axis=dim, keepdims=keepdims))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.sum(value, axis=dim, keepdims=keepdims))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def mean(
array: Array,
dim: int | tuple[int, ...] | None = None,
keepdims: bool = False,
) -> Array:
"""
Compute mean of array elements.
Args:
array (Array): Input array.
dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute the mean.
If None, computes mean of flattened array.
keepdims (bool): If True, retains reduced dimensions with length 1.
Returns:
Array: Mean value in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.mean(value, axis=dim, keepdims=keepdims))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.mean(value, dim=dim, keepdim=keepdims))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.reduce_mean(value, axis=dim, keepdims=keepdims))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.mean(value, axis=dim, keepdims=keepdims))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def min( # noqa: A001
array: Array,
dim: int | tuple[int, ...] | None = None,
keepdims: bool = False,
) -> Array:
"""
Compute minimum of array elements.
Args:
array (Array): Input array.
dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute minimum.
If None, finds minimum over flattened array.
keepdims (bool): If True, retains reduced dimensions with length 1.
Returns:
Array: Minimum value in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.min(value, axis=dim, keepdims=keepdims))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.amin(value, dim=dim, keepdim=keepdims)) # pyright: ignore[reportArgumentType]
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.reduce_min(value, axis=dim, keepdims=keepdims))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.min(value, axis=dim, keepdims=keepdims))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def max( # noqa: A001
array: Array,
dim: int | tuple[int, ...] | None = None,
keepdims: bool = False,
) -> Array:
"""
Compute maximum of array elements.
Args:
array (Array): Input array.
dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute maximum.
If None, finds maximum over flattened array.
keepdims (bool): If True, retains reduced dimensions with length 1.
Returns:
Array: Maximum value in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.max(value, axis=dim, keepdims=keepdims))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.amax(value, dim=dim, keepdim=keepdims)) # pyright: ignore[reportArgumentType]
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.reduce_max(value, axis=dim, keepdims=keepdims))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.max(value, axis=dim, keepdims=keepdims))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def argmax(array: Array, dim: int | None = None, keepdims: bool = False) -> Array:
"""
Compute index of maximum value.
Args:
array (Array): Input array.
dim (int | None): Dimension along which to find maximum. If None, finds maximum over flattened array.
keepdims (bool): If True, retains reduced dimensions with length 1.
Returns:
Array: Indices of maximum values in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.argmax(value, axis=dim, keepdims=keepdims))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.argmax(value, dim=dim, keepdim=keepdims))
if tf and isinstance(value, tf.Tensor):
if dim is None:
# TensorFlow's argmax does not support dim=None directly
dims = value.ndim if value.ndim is not None else 0
reshaped_array = tf.reshape(value, [-1])
amax = tf.math.argmax(reshaped_array, axis=0)
ret = _return_array(amax) if not keepdims else _return_array(tf.reshape(amax, [1] * dims))
else:
ret = (
_return_array(tf.math.argmax(value, axis=dim))
if not keepdims
else _return_array(tf.expand_dims(tf.math.argmax(value, axis=dim), axis=dim))
)
return ret
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.argmax(value, axis=dim, keepdims=keepdims))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def argmin(array: Array, dim: int | None = None, keepdims: bool = False) -> Array:
"""
Compute index of minimum value.
Args:
array (Array): Input array.
dim (int | None): Dimension along which to find minimum. If None, finds minimum over flattened array.
keepdims (bool): If True, retains reduced dimensions with length 1.
Returns:
Array: Indices of minimum values in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.argmin(value, axis=dim, keepdims=keepdims))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.argmin(value, dim=dim, keepdim=keepdims))
if tf and isinstance(value, tf.Tensor):
ret = None
if dim is None:
# TensorFlow's argmin does not support dim=None directly
dims = value.ndim if value.ndim is not None else 0
tf_array = tf.reshape(value, [-1])
amin = tf.math.argmin(tf_array, axis=0)
ret = _return_array(amin) if not keepdims else _return_array(tf.reshape(amin, [1] * dims))
else:
ret = (
_return_array(tf.math.argmin(value, axis=dim))
if not keepdims
else _return_array(tf.expand_dims(tf.math.argmin(value, axis=dim), axis=dim))
)
return ret
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.argmin(value, axis=dim, keepdims=keepdims))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def copy(array: Array) -> Array:
"""
Create a copy of the input array.
Args:
array (Array): Input array.
Returns:
Array: A copy of the input array in the same framework type.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.copy(value))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.clone(value.detach()))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.identity(value))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.array(value, copy=True))
return deepcopy(array)
[docs]
def stack(arrays: Sequence[Array], dim: int = 0) -> Array:
"""
Stack a sequence of arrays along a new dimension.
Args:
arrays (Sequence[Array]): Sequence of input arrays.
or nested containers (list, tuple).
dim (int): Dimension along which to stack the arrays.
Returns:
Array: Stacked array in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported.
ValueError: if the input sequence is empty.
"""
if len(arrays) == 0:
raise ValueError("Input sequence is empty.")
arrs = [arr.value for arr in arrays] if isinstance(arrays[0], Array) else arrays
if isinstance(arrs[0], np.ndarray | np.generic):
return _return_array(np.stack(arrs, axis=dim)) # pyright: ignore[reportArgumentType, reportCallIssue]
if torch and isinstance(arrs[0], torch.Tensor):
return _return_array(torch.stack(arrs, dim=dim)) # pyright: ignore[reportArgumentType]
if tf and isinstance(arrs[0], tf.Tensor):
return _return_array(tf.stack(arrs, axis=dim))
if jnp and isinstance(arrs[0], jnp.ndarray | jnp.generic):
return _return_array(jnp.stack(arrs, axis=dim)) # pyright: ignore[reportArgumentType]
raise TypeError(f"Unsupported framework type or mixed types: {[type(arr) for arr in arrs]}")
[docs]
def reshape(array: Array, shape: tuple[int, ...]) -> Array:
"""
Reshape an array to the specified shape.
Args:
array (Array): Input array.
shape (tuple[int, ...]): Desired shape for the output array.
Returns:
Array: Reshaped array in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.reshape(value, shape))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.reshape(value, shape))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.reshape(value, shape))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.reshape(value, shape))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def zeros_like(array: Array) -> Array:
"""
Create an array of zeros with the same shape and type as the input.
Args:
array (Array): Input array.
Returns:
Array: Array of zeros in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.zeros_like(value))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.zeros_like(value))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.zeros_like(value))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.zeros_like(value))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def ones_like(array: Array) -> Array:
"""
Create an array of ones with the same shape and type as the input.
Args:
array (Array): Input array.
Returns:
Array: Array of ones in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.ones_like(value))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.ones_like(value))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.ones_like(value))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.ones_like(value))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def eye_like(array: Array) -> Array:
"""
Create an identity matrix with the same shape as the input.
Args:
array (Array): Input array.
Returns:
Array: Identity matrix in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.eye(*value.shape[-2:], dtype=value.dtype))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.eye(*value.shape[-2:], dtype=value.dtype, device=value.device))
if tf and isinstance(value, tf.Tensor):
shape = tf.shape(value)
return _return_array(tf.eye(*shape[-2:], dtype=value.dtype))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.eye(*value.shape[-2:], dtype=value.dtype, device=value.device))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def eye(framework: SupportedFrameworks, device: SupportedDevices, n: int) -> Array:
"""
Create an identity matrix of size n x n in the specified framework.
Args:
framework (SupportedFrameworks): Target framework type (e.g., "torch", "tf").
device (SupportedDevices): Target device ("cpu" or "gpu").
n (int): Size of the identity matrix.
Returns:
Array: Identity matrix in the specified framework type.
Raises:
TypeError: if the framework type of `framework` is unsupported.
"""
if framework == SupportedFrameworks.NUMPY:
return _return_array(np.eye(n))
framework_device = device_to_framework_device(device, framework)
if torch and framework == SupportedFrameworks.PYTORCH:
return _return_array(torch.eye(n, device=framework_device))
if tf and framework == SupportedFrameworks.TENSORFLOW:
with tf.device(framework_device):
return _return_array(tf.eye(n))
if jnp and framework == SupportedFrameworks.JAX:
return _return_array(jnp.eye(n, device=framework_device))
raise TypeError(f"Unsupported framework type: {framework}")
[docs]
def transpose(array: Array, dim: tuple[int, ...] | None = None) -> Array:
"""
Transpose an array.
Args:
array (Array): Input array.
dim (tuple[int, ...] | None): Desired dim order. If None, reverses the dimensions.
Returns:
Array: Transposed array in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.transpose(value, axes=dim))
if torch and isinstance(value, torch.Tensor):
# Handle None case for PyTorch
return (
_return_array(torch.permute(value, dims=dim))
if dim
else _return_array(torch.permute(value, dims=list(reversed(range(value.ndim)))))
)
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.transpose(value, perm=dim))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.transpose(value, axes=dim))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def shape(array: Array) -> tuple[int, ...]:
"""
Get the shape of an array.
Args:
array (Array): Input array.
Returns:
tuple[int, ...]: Shape of the input array.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return value.shape
if torch and isinstance(value, torch.Tensor):
return tuple(value.shape)
if tf and isinstance(value, tf.Tensor):
tf_shape = tuple(value.shape)
return cast("tuple[int, ...]", tf_shape)
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return cast("tuple[int, ...]", value.shape)
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def zeros(framework: SupportedFrameworks, device: SupportedDevices, shape: tuple[int, ...]) -> Array:
"""
Create a Array of zeros.
Args:
framework (SupportedFrameworks): The framework to use.
device (SupportedDevices): The device to place the tensor on.
shape (tuple[int, ...]): Shape of the output array.
Returns:
Array: Array of zeros.
Raises:
TypeError: If the framework type of `framework` is unsupported.
"""
framework_device = device_to_framework_device(device, framework)
if framework == SupportedFrameworks.NUMPY:
return _return_array(np.zeros(shape))
if torch and framework == SupportedFrameworks.PYTORCH:
return _return_array(torch.zeros(shape, device=framework_device))
if tf and framework == SupportedFrameworks.TENSORFLOW:
with tf.device(framework_device):
return _return_array(tf.zeros(shape))
if jnp and framework == SupportedFrameworks.JAX:
return _return_array(jnp.zeros(shape, device=framework_device))
raise TypeError(f"Unsupported framework type: {framework}")
[docs]
def set_item(
array: Array | SupportedArrayTypes,
key: ArrayKey,
value: Array | SupportedArrayTypes,
) -> None:
"""
Set the item at the specified index of the array to the given value.
Args:
array (Array | SupportedArrayTypes): The tensor.
key (ArrayKey): The key or index to set.
value (Array | SupportedArrayTypes): The value to set.
Raises:
TypeError: If the type is not supported.
NotImplementedError: If the operation is not supported due to immutability.
"""
array_value = array.value if isinstance(array, Array) else array
value_value = value.value if isinstance(value, Array) else value
if isinstance(array_value, np.ndarray | np.generic):
array_value[key] = value_value
return
if torch and isinstance(array_value, torch.Tensor) and isinstance(value_value, _torch_types):
array_value[key] = value_value
return
if tf and isinstance(array_value, tf.Tensor) and isinstance(value_value, _tf_types):
raise NotImplementedError("Setting items in TensorFlow tensors is not supported due to immutability.")
if jnp and isinstance(array_value, jnp.ndarray | jnp.generic) and isinstance(value_value, _jnp_types):
raise NotImplementedError("Setting items in JAX arrays is not supported due to immutability.")
raise TypeError(f"Unsupported type: {type(array_value)} with value: {type(value_value)}")
[docs]
def get_item(array: Array, key: ArrayKey) -> Array:
"""
Get the item at the specified index of the array.
Args:
array (Array): The tensor.
key (ArrayKey): The key or index to get.
Returns:
Array: The item at the specified index.
"""
value = array.value if isinstance(array, Array) else array
return _return_array(value[key]) # type: ignore[index]
[docs]
def astype(array: Array, dtype: type[float | int | bool]) -> float | int | bool:
"""
Cast a single-element array to a Python scalar of the specified type.
Args:
array (Array): The tensor.
dtype (float | int | bool): The target data type.
Returns:
float | int | bool: The casted scalar value.
Raises:
TypeError: If the type is not supported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, _np_types):
return dtype(value.item() if hasattr(value, "item") else value) # pyright: ignore[reportAttributeAccessIssue]
if torch and isinstance(value, torch.Tensor):
return dtype(value.item())
if tf and isinstance(value, tf.Tensor):
return dtype(to_numpy(value).item())
if jnp and isinstance(value, _jnp_types):
return dtype(value.item())
raise TypeError(f"Unsupported type: {type(value)}")
[docs]
def norm(
array: Array,
p: float = 2,
dim: int | tuple[int, ...] | None = None,
keepdims: bool = False,
) -> Array:
"""
Compute the norm of an array.
Args:
array (Array): The tensor.
p (float): The order of the norm.
dim (int | tuple[int, ...] | None): Dimension or dimensions along which to compute the norm.
If None, computes norm over flattened array.
keepdims (bool): If True, retains reduced dimensions with length 1.
Returns:
Array: The norm of the tensor.
Raises:
TypeError: If the type is not supported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(cast("SupportedArrayTypes", np.linalg.norm(value, ord=p, axis=dim, keepdims=keepdims)))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.linalg.norm(value, ord=p, dim=dim, keepdim=keepdims))
if tf and isinstance(value, tf.Tensor):
if dim is None and value.ndim == 2:
dim = (-2, -1)
return _return_array(tf.norm(value, ord=p, axis=dim, keepdims=keepdims))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.linalg.norm(value, ord=p, axis=dim, keepdims=keepdims))
raise TypeError(f"Unsupported type: {type(value)}")
[docs]
def squeeze(array: Array, dim: int | tuple[int, ...] | None = None) -> Array:
"""
Remove single-dimensional entries from the shape of an array.
Args:
array (Array): Input array.
dim (int | tuple[int, ...] | None): Dimension or dimensions to squeeze.
If None, squeezes all single-dimensional entries.
Returns:
Array: Squeezed array in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.squeeze(value, axis=dim))
if torch and isinstance(value, torch.Tensor):
if dim is None: # Bug where dim=None is not supported in torch.squeeze
return _return_array(torch.squeeze(value))
return _return_array(torch.squeeze(value, dim=dim))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.squeeze(value, axis=dim))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.squeeze(value, axis=dim))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def diag(array: Array) -> Array:
"""
Create a diagonal matrix from a vector or extract a diagonal from a matrix.
Args:
array (Array): Input array.
Returns:
Array: Diagonal matrix or diagonal vector in the same framework type as the input.
Raises:
TypeError: if the framework type of `array` is unsupported.
ValueError: if the input does not have rank 1 or 2.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.diag(value))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.diag(value))
if tf and isinstance(value, tf.Tensor):
if value.shape.ndims == 1:
return _return_array(tf.linalg.diag(value))
if value.shape.ndims == 2:
return _return_array(tf.linalg.diag_part(value))
if value.shape.ndims is not None:
raise ValueError("Input must be 1- or 2-d for diag.")
rank = tf.rank(value)
tf.debugging.assert_rank_in(value, [1, 2])
return _return_array(
tf.cond(
rank == 1,
lambda: tf.linalg.diag(value),
lambda: tf.linalg.diag_part(value),
)
)
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.diag(value))
raise TypeError(f"Unsupported framework type: {type(value)}")