from __future__ import annotations
import importlib
from numbers import Real
from typing import TYPE_CHECKING, cast
import numpy as np
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedArrayTypes
from ._helpers import _return_array
from ._imports_types import _jnp_types, _np_types, _tf_types, _torch_types, jax, jnp, tf, torch
if TYPE_CHECKING:
from torch import Tensor as TorchTensor
[docs]
def add(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array:
"""
Element-wise addition of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of element-wise addition in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, _np_types):
return _return_array(value1 + value2)
if torch and isinstance(value1, _torch_types) and isinstance(value2, _torch_types):
return _return_array(torch.add(value1, value2))
if tf and isinstance(value1, _tf_types) and isinstance(value2, _tf_types):
return _return_array(tf.add(value1, value2))
if jnp and isinstance(value1, _jnp_types) and isinstance(value2, _jnp_types):
return _return_array(jnp.add(value1, value2))
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
def iadd[T: Array](array1: T, array2: Array | SupportedArrayTypes) -> T:
"""
Element-wise in-place addition of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of element-wise in-place addition in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, np.ndarray | np.generic):
value1 += value2
return cast("T", _return_array(value1))
if torch and isinstance(value1, torch.Tensor) and isinstance(value2, _torch_types):
value1 += value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
if tf and isinstance(value1, tf.Tensor) and isinstance(value2, _tf_types):
value1 += value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, _jnp_types):
value1 += value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
[docs]
def sub(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array:
"""
Element-wise subtraction of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of element-wise subtraction in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, _np_types):
return _return_array(value1 - value2)
if torch and isinstance(value1, _torch_types) and isinstance(value2, _torch_types):
return _return_array(torch.sub(value1, value2))
if tf and isinstance(value1, _tf_types) and isinstance(value2, _tf_types):
return _return_array(tf.subtract(value1, value2))
if jnp and isinstance(value1, _jnp_types) and isinstance(value2, _jnp_types):
return _return_array(jnp.subtract(value1, value2))
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
def isub[T: Array](array1: T, array2: Array | SupportedArrayTypes) -> T:
"""
Element-wise in-place subtraction of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of element-wise in-place subtraction in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, np.ndarray | np.generic):
value1 -= value2
return cast("T", _return_array(value1))
if torch and isinstance(value1, torch.Tensor) and isinstance(value2, _torch_types):
value1 -= value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
if tf and isinstance(value1, tf.Tensor) and isinstance(value2, _tf_types):
value1 -= value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, _jnp_types):
value1 -= value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
[docs]
def mul(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array:
"""
Element-wise multiplication of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of element-wise multiplication in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, _np_types):
return _return_array(value1 * value2) # pyright: ignore[reportOperatorIssue]
if torch and isinstance(value1, _torch_types) and isinstance(value2, _torch_types):
return _return_array(torch.mul(value1, value2))
if tf and isinstance(value1, _tf_types) and isinstance(value2, _tf_types):
return _return_array(tf.multiply(value1, value2))
if jnp and isinstance(value1, _jnp_types) and isinstance(value2, _jnp_types):
return _return_array(jnp.multiply(value1, value2))
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
def imul[T: Array](array1: T, array2: Array | SupportedArrayTypes) -> T:
"""
Element-wise in-place multiplication of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of element-wise in-place multiplication in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, np.ndarray | np.generic):
value1 *= value2
return cast("T", _return_array(value1))
if torch and isinstance(value1, torch.Tensor) and isinstance(value2, _torch_types):
value1 *= value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
if tf and isinstance(value1, tf.Tensor) and isinstance(value2, _tf_types):
value1 *= value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, _jnp_types):
value1 *= value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
[docs]
def div(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array:
"""
Element-wise division of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of element-wise division in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, _np_types):
return _return_array(value1 / value2)
if torch and isinstance(value1, _torch_types) and isinstance(value2, _torch_types):
return _return_array(torch.div(value1, value2))
if tf and isinstance(value1, _tf_types) and isinstance(value2, _tf_types):
return _return_array(tf.divide(value1, value2))
if jnp and isinstance(value1, _jnp_types) and isinstance(value2, _jnp_types):
return _return_array(jnp.divide(value1, value2))
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
def idiv[T: Array](array1: T, array2: Array | SupportedArrayTypes) -> T:
"""
Element-wise in-place division of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of element-wise in-place division in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, np.ndarray | np.generic):
value1 /= value2
return cast("T", _return_array(value1))
if torch and isinstance(value1, torch.Tensor) and isinstance(value2, _torch_types):
value1 /= value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
if tf and isinstance(value1, tf.Tensor) and isinstance(value2, _tf_types):
value1 /= value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, _jnp_types):
value1 /= value2 # pyright: ignore[reportOperatorIssue]
return cast("T", _return_array(value1))
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
[docs]
def matmul(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array:
"""
Matrix multiplication of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of matrix multiplication in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, np.ndarray | np.generic):
return _return_array(value1 @ value2)
if torch and isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor):
return _return_array(value1 @ value2) # pyright: ignore[reportOperatorIssue]
if tf and isinstance(value1, tf.Tensor) and isinstance(value2, tf.Tensor):
return _return_array(value1 @ value2) # pyright: ignore[reportOperatorIssue]
if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, jnp.ndarray | jnp.generic):
return _return_array(value1 @ value2) # pyright: ignore[reportOperatorIssue]
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
[docs]
def dot(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array:
"""
Dot product of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of the dot product in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
if isinstance(value1, np.ndarray | np.generic):
return _return_array(value1.dot(value2))
if torch and isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor):
return _return_array(value1.dot(value2)) # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]
if tf and isinstance(value1, tf.Tensor) and isinstance(value2, tf.Tensor):
return _return_array(value1.dot(value2)) # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]
if jnp and isinstance(value1, jnp.ndarray | jnp.generic) and isinstance(value2, jnp.ndarray | jnp.generic):
return _return_array(value1.dot(value2)) # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
[docs]
def power(array: Array | SupportedArrayTypes, p: float) -> Array:
"""
Raise array to p power.
Args:
array (Array | SupportedArrayTypes): The tensor.
p (float): The power.
Returns:
Array: The result of the operation.
Raises:
TypeError: If the type is not supported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.power(value, p))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.pow(value, p))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.pow(value, p))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.power(value, p))
raise TypeError(f"Unsupported type: {type(value)}")
def ipow[T: Array](array: T, p: float) -> T:
"""
Element-wise in-place power of an array.
Args:
array (Array | SupportedArrayTypes): Input array.
p (float): The power.
Returns:
Array: Result of element-wise in-place power in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
value **= p
return cast("T", _return_array(value))
if torch and isinstance(value, torch.Tensor):
value **= p
return cast("T", _return_array(value))
if tf and isinstance(value, tf.Tensor):
value **= p
return cast("T", _return_array(value))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
value **= p
return cast("T", _return_array(value))
raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def negative(array: Array | SupportedArrayTypes) -> Array:
"""
Negate array.
Args:
array (Array | SupportedArrayTypes): The tensor.
Returns:
Array: The negated tensor.
Raises:
TypeError: If the type is not supported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.negative(value))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.neg(value))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.negative(value))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.negative(value))
raise TypeError(f"Unsupported type: {type(value)}")
[docs]
def absolute(array: Array | SupportedArrayTypes) -> Array:
"""
Return the absolute value of a tensor.
Args:
array (Array | SupportedArrayTypes): The tensor.
Returns:
Array: The absolute value tensor.
Raises:
TypeError: If the type is not supported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.abs(value))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.abs(value))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.abs(value))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.abs(value))
raise TypeError(f"Unsupported type: {type(value)}")
[docs]
def sqrt(array: Array | SupportedArrayTypes) -> Array:
"""
Return the square root of a tensor.
Args:
array (Array | SupportedArrayTypes): The tensor.
Returns:
Array: The square root tensor.
Raises:
TypeError: If the type is not supported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.sqrt(value))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.sqrt(value))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.sqrt(value))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.sqrt(value))
raise TypeError(f"Unsupported type: {type(value)}")
[docs]
def sign(array: Array | SupportedArrayTypes) -> Array:
"""
Return the sign of a tensor.
Args:
array (Array | SupportedArrayTypes): The tensor.
Returns:
Array: The sign tensor.
Raises:
TypeError: If the type is not supported.
"""
value = array.value if isinstance(array, Array) else array
if isinstance(value, np.ndarray | np.generic):
return _return_array(np.sign(value))
if torch and isinstance(value, torch.Tensor):
return _return_array(torch.sign(value))
if tf and isinstance(value, tf.Tensor):
return _return_array(tf.sign(value))
if jnp and isinstance(value, jnp.ndarray | jnp.generic):
return _return_array(jnp.sign(value))
raise TypeError(f"Unsupported type: {type(value)}")
[docs]
def maximum(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) -> Array:
"""
Element-wise maximum of two arrays.
Args:
array1 (Array | SupportedArrayTypes): First input array.
array2 (Array | SupportedArrayTypes): Second input array.
Returns:
Array: Result of element-wise maximum in the same framework type as the inputs.
Raises:
TypeError: if the framework type of the input arrays is unsupported
or if the input arrays are not of the same framework type.
"""
value1 = array1.value if isinstance(array1, Array) else array1
value2 = array2.value if isinstance(array2, Array) else array2
def _is_scalar(value: object) -> bool:
return isinstance(value, Real) and not isinstance(value, bool)
def _is_jax_array(value: object) -> bool:
return (
isinstance(value, _jnp_types)
or (jax is not None and hasattr(jax, "Array") and isinstance(value, jax.Array))
or type(value).__module__.startswith(("jax", "jaxlib"))
)
result = None
if (isinstance(value1, _np_types) and (isinstance(value2, _np_types) or _is_scalar(value2))) or (
isinstance(value2, _np_types) and _is_scalar(value1)
):
result = np.maximum(value1, value2)
elif torch and (isinstance(value1, torch.Tensor) or isinstance(value2, torch.Tensor)):
tensor = value1 if isinstance(value1, torch.Tensor) else value2
other = value2 if tensor is value1 else value1
tensor_t = cast("TorchTensor", tensor)
if isinstance(other, torch.Tensor):
result = torch.maximum(tensor_t, other)
elif _is_scalar(other):
result = torch.maximum(tensor_t, torch.tensor(other, device=tensor_t.device, dtype=tensor_t.dtype))
elif tf and (isinstance(value1, tf.Tensor) or isinstance(value2, tf.Tensor)):
tensor = value1 if isinstance(value1, tf.Tensor) else value2
other = value2 if tensor is value1 else value1
if isinstance(other, _tf_types) or _is_scalar(other):
result = tf.maximum(tensor, other)
elif _is_jax_array(value1) or _is_jax_array(value2):
jnp_module = jnp or importlib.import_module("jax.numpy")
tensor = value1 if not _is_scalar(value1) else value2
other = value2 if tensor is value1 else value1
if _is_jax_array(other) or _is_scalar(other):
result = jnp_module.maximum(tensor, other)
if result is None:
raise TypeError(f"Unsupported framework type: {type(value1)} and {type(value2)}")
return _return_array(result)