decent_bench.utils.interoperability

Utilities for operating on arrays from different deep learning and linear algebra frameworks.

Mirrors NumPy’s functionality for interoperability across frameworks.

decent_bench.utils.interoperability.argmax(array: Array, dim: int | None = None, keepdims: bool = False) → Array

Compute index of maximum value.

Parameters:
  • array (Array) – Input array.

  • dim (int | None) – Dimension along which to find maximum. If None, finds maximum over flattened array.

  • keepdims (bool) – If True, retains reduced dimensions with length 1.

Returns:

Indices of maximum values in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.
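
As a hedged illustration of the dim and keepdims semantics, the sketch below uses NumPy directly. It assumes, based on the module description, that argmax mirrors numpy.argmax with dim playing the role of NumPy's axis; it does not call decent_bench itself.

```python
import numpy as np

a = np.array([[1, 9, 3],
              [7, 2, 8]])

# dim=None analogue: the index refers to the flattened array.
flat_index = int(np.argmax(a))
row, col = np.unravel_index(flat_index, a.shape)

# dim=1 analogue: one index per row.
per_row = np.argmax(a, axis=1)

# keepdims=True analogue: the reduced dimension is kept with length 1.
per_row_kept = np.expand_dims(per_row, axis=1)
```

np.unravel_index is only used here to show how a flattened index maps back to a position in the original shape.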

decent_bench.utils.interoperability.argmin(array: Array, dim: int | None = None, keepdims: bool = False) → Array

Compute index of minimum value.

Parameters:
  • array (Array) – Input array.

  • dim (int | None) – Dimension along which to find minimum. If None, finds minimum over flattened array.

  • keepdims (bool) – If True, retains reduced dimensions with length 1.

Returns:

Indices of minimum values in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.astype(array: Array, dtype: type[float | int | bool]) → float | int | bool

Cast a single-element array to a Python scalar of the specified type.

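Parameters:
  • array (Array) – Single-element input array.

  • dtype (type[float | int | bool]) – Target Python scalar type.

Returns: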

The casted scalar value.

Return type:

float | int | bool

Raises:

TypeError – If the type is not supported.
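
The idea of casting a single-element array to a Python scalar can be sketched with NumPy alone. This is an assumption-laden illustration rather than the library's implementation: in particular, the int cast below follows Python's int(), which truncates toward zero, and whether astype behaves identically is not stated in this reference.

```python
import numpy as np

single = np.array([3.7])

# .item() extracts the lone element as a Python scalar; the built-in
# constructors then perform the cast to the requested type.
as_float = float(single.item())
as_int = int(single.item())      # Python's int() truncates toward zero
as_bool = bool(single.item())    # any non-zero value is True
```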

decent_bench.utils.interoperability.copy(array: Array) → Array

Create a copy of the input array.

Parameters:

array (Array) – Input array.

Returns:

A copy of the input array in the same framework type.

Return type:

Array

decent_bench.utils.interoperability.diag(array: Array) → Array

Create a diagonal matrix from a vector or extract a diagonal from a matrix.

Parameters:

array (Array) – Input array.

Returns:

Diagonal matrix or diagonal vector in the same framework type as the input.

Return type:

Array

Raises:
  • TypeError – if the framework type of array is unsupported.

  • ValueError – if the input does not have rank 1 or 2.
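
The two directions of diag, and the rank restriction, can be sketched with numpy.diag, which this function is assumed to mirror. The example runs against NumPy only.

```python
import numpy as np

v = np.array([1, 2, 3])
m = np.diag(v)   # rank 1 input -> 3x3 matrix with v on the diagonal
d = np.diag(m)   # rank 2 input -> the diagonal as a vector

# Rank 3 input is rejected by NumPy with a ValueError, which matches
# the error documented above.
try:
    np.diag(np.zeros((2, 2, 2)))
    rank3_rejected = False
except ValueError:
    rank3_rejected = True
```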

decent_bench.utils.interoperability.eye(framework: SupportedFrameworks, device: SupportedDevices, n: int) → Array

Create an identity matrix of size n x n in the specified framework.

Parameters:
  • framework (SupportedFrameworks) – Target framework type (e.g., “torch”, “tf”).

  • device (SupportedDevices) – Target device (“cpu” or “gpu”).

  • n (int) – Size of the identity matrix.

Returns:

Identity matrix in the specified framework type.

Return type:

Array

Raises:

TypeError – if framework is unsupported.

decent_bench.utils.interoperability.eye_like(array: Array) → Array

Create an identity matrix with the same shape as the input.

Parameters:

array (Array) – Input array.

Returns:

Identity matrix in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.get_item(array: Array, key: SupportsIndex | slice | tuple[SupportsIndex | slice, ...]) → Array

Get the item at the specified index of the array.

Parameters:
  • array (Array) – The tensor.

  • key (ArrayKey) – The key or index to get.

Returns:

The item at the specified index.

Return type:

Array

decent_bench.utils.interoperability.max(array: Array, dim: int | tuple[int, ...] | None = None, keepdims: bool = False) → Array

Compute maximum of array elements.

Parameters:
  • array (Array) – Input array.

  • dim (int | tuple[int, ...] | None) – Dimension or dimensions along which to compute maximum. If None, finds maximum over flattened array.

  • keepdims (bool) – If True, retains reduced dimensions with length 1.

Returns:

Maximum value in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.mean(array: Array, dim: int | tuple[int, ...] | None = None, keepdims: bool = False) → Array

Compute mean of array elements.

Parameters:
  • array (Array) – Input array.

  • dim (int | tuple[int, ...] | None) – Dimension or dimensions along which to compute the mean. If None, computes mean of flattened array.

  • keepdims (bool) – If True, retains reduced dimensions with length 1.

Returns:

Mean value in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.min(array: Array, dim: int | tuple[int, ...] | None = None, keepdims: bool = False) → Array

Compute minimum of array elements.

Parameters:
  • array (Array) – Input array.

  • dim (int | tuple[int, ...] | None) – Dimension or dimensions along which to compute minimum. If None, finds minimum over flattened array.

  • keepdims (bool) – If True, retains reduced dimensions with length 1.

Returns:

Minimum value in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.norm(array: Array, p: float = 2, dim: int | tuple[int, ...] | None = None, keepdims: bool = False) → Array

Compute the norm of an array.

Parameters:
  • array (Array) – The tensor.

  • p (float) – The order of the norm.

  • dim (int | tuple[int, ...] | None) – Dimension or dimensions along which to compute the norm. If None, computes norm over flattened array.

  • keepdims (bool) – If True, retains reduced dimensions with length 1.

Returns:

The norm of the tensor.

Return type:

Array

Raises:

TypeError – If the type is not supported.
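
To make the role of p, dim and keepdims concrete, the sketch below computes the vector p-norm, (sum |x_i|^p)^(1/p), with NumPy. It assumes that norm treats the selected dimensions as a flat vector and that p is finite and positive; the helper name p_norm is hypothetical and is not part of decent_bench.

```python
import numpy as np


def p_norm(x, p=2.0, axis=None, keepdims=False):
    """Hypothetical helper: (sum |x_i|^p)^(1/p) over `axis` (all elements if None)."""
    return np.sum(np.abs(x) ** p, axis=axis, keepdims=keepdims) ** (1.0 / p)


x = np.array([[3.0, 4.0],
              [6.0, 8.0]])

total = p_norm(x)                              # over the flattened array
rows = p_norm(x, axis=1)                       # one norm per row
rows_kept = p_norm(x, axis=1, keepdims=True)   # reduced dimension kept
l1 = p_norm(x, p=1.0)                          # sum of absolute values
```

For p = 2 and no dim, the result agrees with numpy.linalg.norm applied to the flattened array.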

decent_bench.utils.interoperability.ones_like(array: Array) → Array

Create an array of ones with the same shape and type as the input.

Parameters:

array (Array) – Input array.

Returns:

Array of ones in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.reshape(array: Array, shape: tuple[int, ...]) → Array

Reshape an array to the specified shape.

Parameters:
  • array (Array) – Input array.

  • shape (tuple[int, ...]) – Desired shape for the output array.

Returns:

Reshaped array in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.set_item(array: Array | SupportedArrayTypes, key: ArrayKey, value: Array | SupportedArrayTypes) → None

Set the item at the specified index of the array to the given value.

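Parameters:
  • array (Array | SupportedArrayTypes) – The tensor.

  • key (ArrayKey) – The key or index to set.

  • value (Array | SupportedArrayTypes) – The value to set.

Raises:

decent_bench.utils.interoperability.shape(array: Array) → tuple[int, ...]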

Get the shape of an array.

Parameters:

array (Array) – Input array.

Returns:

Shape of the input array.

Return type:

tuple[int, ...]

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.squeeze(array: Array, dim: int | tuple[int, ...] | None = None) → Array

Remove single-dimensional entries from the shape of an array.

Parameters:
  • array (Array) – Input array.

  • dim (int | tuple[int, ...] | None) – Dimension or dimensions to squeeze. If None, squeezes all single-dimensional entries.

Returns:

Squeezed array in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.stack(arrays: Sequence[Array], dim: int = 0) → Array

Stack a sequence of arrays along a new dimension.

Parameters:
  • arrays (Sequence[Array]) – Sequence of input arrays or nested containers (list, tuple).

  • dim (int) – Dimension along which to stack the arrays.

Returns:

Stacked array in the same framework type as the inputs.

Return type:

Array

Raises:
  • TypeError – if the framework type of the input arrays is unsupported.

  • ValueError – if the input sequence is empty.
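
How the new dimension is placed, and the empty-sequence error, can be sketched with numpy.stack. The example assumes that stack mirrors NumPy's behaviour with dim in place of axis; it runs against NumPy only.

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

rows = np.stack([a, b], axis=0)   # new leading dimension: shape (2, 3)
cols = np.stack([a, b], axis=1)   # new trailing dimension: shape (3, 2)

# An empty sequence cannot be stacked; NumPy raises ValueError here.
try:
    np.stack([])
    empty_rejected = False
except ValueError:
    empty_rejected = True
```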

decent_bench.utils.interoperability.sum(array: Array, dim: int | tuple[int, ...] | None = None, keepdims: bool = False) → Array

Sum elements of an array.

Parameters:
  • array (Array) – Input array.

  • dim (int | tuple[int, ...] | None) – Dimension or dimensions along which to sum. If None, sums over flattened array.

  • keepdims (bool) – If True, retains reduced dimensions with length 1.

Returns:

Summed value in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.
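
The dim and keepdims options shared by sum, mean, max and min can be illustrated with NumPy, on the assumption that they correspond to NumPy's axis and keepdims. The sketch also shows why keepdims is useful: the retained length-1 dimension broadcasts against the original array.

```python
import numpy as np

a = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

total = np.sum(a)                             # dim=None analogue
col_sums = np.sum(a, axis=0)                  # shape (3,)
row_sums = np.sum(a, axis=1, keepdims=True)   # shape (2, 1)
both = np.sum(a, axis=(0, 1))                 # a tuple of dimensions

# Shape (2, 1) broadcasts against shape (2, 3), so each row is
# divided by its own sum.
normalized = a / row_sums
```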

decent_bench.utils.interoperability.to_array(array: Array | SupportedArrayTypes, framework: SupportedFrameworks, device: SupportedDevices, dtype: Any | None = None) → Array

Convert an array to the specified framework type.

See decent_bench.utils.interoperability.to_array_like() if you want to convert an array to match the framework and device of another array.

Parameters:
  • array (Array | SupportedArrayTypes) – Input array.

  • framework (SupportedFrameworks) – Target framework type (e.g., “torch”, “tf”).

  • device (SupportedDevices) – Target device (“cpu” or “gpu”).

  • dtype (Any | None) – Optional data type for the converted array. If None, the data type of the input array is preserved.

Returns:

Converted array in the specified framework type.

Return type:

Array

Raises:

TypeError – if framework is unsupported.

decent_bench.utils.interoperability.to_array_like(array: Array | SupportedArrayTypes, like: Array) → Array

Convert an array to the framework/device of like.

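Parameters:
  • array (Array | SupportedArrayTypes) – Input array.

  • like (Array) – Array whose framework and device to match.

Returns: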

Converted array in the specified framework type.

Return type:

Array

decent_bench.utils.interoperability.to_numpy(array: Array | SupportedArrayTypes, device: SupportedDevices = SupportedDevices.CPU, dtype: Any | None = None) → NDArray[Any]

Convert input array to a NumPy array.

Parameters:
  • array (Array | SupportedArrayTypes) – Input array.

  • device (SupportedDevices) – Device of the input array.

  • dtype (Any | None) – Optional data type for the converted array. If None, the data type of the input array is preserved.

Returns:

Converted NumPy array.

Return type:

NDArray

Note

The device parameter is currently not used in this function but is included for API consistency.

decent_bench.utils.interoperability.to_torch(array: Array | SupportedArrayTypes, device: SupportedDevices, dtype: Any | None = None) → TorchTensor

Convert input array to a PyTorch tensor.

Parameters:
  • array (Array | SupportedArrayTypes) – Input array.

  • device (SupportedDevices) – Device of the input array.

  • dtype (Any | None) – Optional data type for the converted array. If None, the data type of the input array is preserved.

Returns:

Converted PyTorch tensor.

Return type:

torch.Tensor

Raises:

ImportError – if PyTorch is not installed.

decent_bench.utils.interoperability.to_tensorflow(array: Array | SupportedArrayTypes, device: SupportedDevices, dtype: Any | None = None) → TensorflowTensor

Convert input array to a TensorFlow tensor.

Parameters:
  • array (Array | SupportedArrayTypes) – Input array.

  • device (SupportedDevices) – Device of the input array.

  • dtype (Any | None) – Optional data type for the converted array. If None, the data type of the input array is preserved.

Returns:

Converted TensorFlow tensor.

Return type:

tf.Tensor

Raises:

ImportError – if TensorFlow is not installed.

decent_bench.utils.interoperability.to_jax(array: Array | SupportedArrayTypes, device: SupportedDevices, dtype: Any | None = None) → JaxArray

Convert input array to a JAX array.

Parameters:
  • array (Array | SupportedArrayTypes) – Input array.

  • device (SupportedDevices) – Device of the input array.

  • dtype (Any | None) – Optional data type for the converted array. If None, the data type of the input array is preserved.

Returns:

Converted JAX array.

Return type:

jax.Array

Raises:

ImportError – if JAX is not installed.

decent_bench.utils.interoperability.transpose(array: Array, dim: tuple[int, ...] | None = None) → Array

Transpose an array.

Parameters:
  • array (Array) – Input array.

  • dim (tuple[int, ...] | None) – Desired dim order. If None, reverses the dimensions.

Returns:

Transposed array in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.
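
The difference between the default (reversed) order and an explicit dim order is easiest to see on shapes. The sketch uses numpy.transpose, which this function is assumed to mirror, and does not call decent_bench.

```python
import numpy as np

a = np.zeros((2, 3, 4))

reversed_dims = np.transpose(a)           # dim=None analogue: shape (4, 3, 2)
reordered = np.transpose(a, (1, 0, 2))    # explicit order: shape (3, 2, 4)
```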

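decent_bench.utils.interoperability.zeros(framework: SupportedFrameworks, device: SupportedDevices, shape: tuple[int, ...]) → Array

Create an array of zeros.

Parameters:
  • framework (SupportedFrameworks) – Target framework type.

  • device (SupportedDevices) – Target device.

  • shape (tuple[int, ...]) – Shape of the output array.

Returns: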

Array of zeros.

Return type:

Array

Raises:

TypeError – If framework is unsupported.

decent_bench.utils.interoperability.zeros_like(array: Array) → Array

Create an array of zeros with the same shape and type as the input.

Parameters:

array (Array) – Input array.

Returns:

Array of zeros in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.absolute(array: Array | SupportedArrayTypes) → Array

Return the absolute value of a tensor.

Parameters:

array (Array | SupportedArrayTypes) – The tensor.

Returns:

The absolute value tensor.

Return type:

Array

Raises:

TypeError – If the type is not supported.

decent_bench.utils.interoperability.add(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) → Array

Element-wise addition of two arrays.

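Parameters:
  • array1 (Array | SupportedArrayTypes) – First input array.

  • array2 (Array | SupportedArrayTypes) – Second input array.

Returns: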

Result of element-wise addition in the same framework type as the inputs.

Return type:

Array

Raises:

TypeError – if the framework type of the input arrays is unsupported or if the input arrays are not of the same framework type.

decent_bench.utils.interoperability.div(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) → Array

Element-wise division of two arrays.

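Parameters:
  • array1 (Array | SupportedArrayTypes) – First input array.

  • array2 (Array | SupportedArrayTypes) – Second input array.

Returns: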

Result of element-wise division in the same framework type as the inputs.

Return type:

Array

Raises:

TypeError – if the framework type of the input arrays is unsupported or if the input arrays are not of the same framework type.

decent_bench.utils.interoperability.dot(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) → Array

Dot product of two arrays.

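Parameters:
  • array1 (Array | SupportedArrayTypes) – First input array.

  • array2 (Array | SupportedArrayTypes) – Second input array.

Returns: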

Result of the dot product in the same framework type as the inputs.

Return type:

Array

Raises:

TypeError – if the framework type of the input arrays is unsupported or if the input arrays are not of the same framework type.

decent_bench.utils.interoperability.matmul(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) → Array

Matrix multiplication of two arrays.

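Parameters:
  • array1 (Array | SupportedArrayTypes) – First input array.

  • array2 (Array | SupportedArrayTypes) – Second input array.

Returns: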

Result of matrix multiplication in the same framework type as the inputs.

Return type:

Array

Raises:

TypeError – if the framework type of the input arrays is unsupported or if the input arrays are not of the same framework type.

decent_bench.utils.interoperability.maximum(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) → Array

Element-wise maximum of two arrays.

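Parameters:
  • array1 (Array | SupportedArrayTypes) – First input array.

  • array2 (Array | SupportedArrayTypes) – Second input array.

Returns: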

Result of element-wise maximum in the same framework type as the inputs.

Return type:

Array

Raises:

TypeError – if the framework type of the input arrays is unsupported or if the input arrays are not of the same framework type.

decent_bench.utils.interoperability.mul(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) → Array

Element-wise multiplication of two arrays.

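Parameters:
  • array1 (Array | SupportedArrayTypes) – First input array.

  • array2 (Array | SupportedArrayTypes) – Second input array.

Returns: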

Result of element-wise multiplication in the same framework type as the inputs.

Return type:

Array

Raises:

TypeError – if the framework type of the input arrays is unsupported or if the input arrays are not of the same framework type.

decent_bench.utils.interoperability.negative(array: Array | SupportedArrayTypes) → Array

Negate an array element-wise.

Parameters:

array (Array | SupportedArrayTypes) – The tensor.

Returns:

The negated tensor.

Return type:

Array

Raises:

TypeError – If the type is not supported.

decent_bench.utils.interoperability.power(array: Array | SupportedArrayTypes, p: float) → Array

Raise an array to the power p.

Parameters:
  • array (Array | SupportedArrayTypes) – Input array.

  • p (float) – The exponent.

Returns:

The result of the operation.

Return type:

Array

Raises:

TypeError – If the type is not supported.

decent_bench.utils.interoperability.sign(array: Array | SupportedArrayTypes) → Array

Return the sign of a tensor.

Parameters:

array (Array | SupportedArrayTypes) – The tensor.

Returns:

The sign tensor.

Return type:

Array

Raises:

TypeError – If the type is not supported.

decent_bench.utils.interoperability.sqrt(array: Array | SupportedArrayTypes) → Array

Return the square root of a tensor.

Parameters:

array (Array | SupportedArrayTypes) – The tensor.

Returns:

The square root tensor.

Return type:

Array

Raises:

TypeError – If the type is not supported.

decent_bench.utils.interoperability.sub(array1: Array | SupportedArrayTypes, array2: Array | SupportedArrayTypes) → Array

Element-wise subtraction of two arrays.

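Parameters:
  • array1 (Array | SupportedArrayTypes) – First input array.

  • array2 (Array | SupportedArrayTypes) – Second input array.

Returns: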

Result of element-wise subtraction in the same framework type as the inputs.

Return type:

Array

Raises:

TypeError – if the framework type of the input arrays is unsupported or if the input arrays are not of the same framework type.

decent_bench.utils.interoperability.device_to_framework_device(device: SupportedDevices, framework: SupportedFrameworks) → Any

Convert SupportedDevices literal to framework-specific device representation.

Parameters:
  • device (SupportedDevices) – Device literal (“cpu” or “gpu”).

  • framework (SupportedFrameworks) – Framework literal (“numpy”, “torch”, “tensorflow”, “jax”).

Returns:

Framework-specific device representation.

Return type:

Any

Raises:

ValueError – If the framework is unsupported.

decent_bench.utils.interoperability.framework_device_of_array(array: Array | SupportedArrayTypes) → tuple[SupportedFrameworks, SupportedDevices]

Determine the framework and device of the given Array.

Parameters:

array (Array | SupportedArrayTypes) – Input array.

Returns:

Framework and device of the array.

Return type:

tuple[SupportedFrameworks, SupportedDevices]

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.is_supported_array_type(array: Any) → bool

Check if the given array is of a supported type.

Parameters:

array (Any) – Input array.

Returns:

True if the array is of a supported type, False otherwise.

Return type:

bool

decent_bench.utils.interoperability.autodecorate_cost_method(superclass_method: T) → Callable[[Callable[[...], Any]], T]

Decorate Cost methods to automatically convert Array args and return types.

It automatically converts input Array arguments to the cost’s framework-specific array type and wraps the output based on the superclass method’s return type annotation.

Parameters:

superclass_method – The method from the superclass (e.g., Cost.function) that is being overridden.

Note

  • Only arguments that are instances of Array are converted.

    Other types are passed through unchanged.

  • The first input argument of the decorated function must be x.

    This is to determine the target array type for output conversion. Otherwise a ValueError is raised.

  • Emits a warning if an input array’s framework differs from the cost’s framework.

    This may lead to unexpected behavior or performance issues.
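
The general pattern described above (unwrap Array arguments, call the method, wrap the result according to the superclass annotation) can be sketched as a decorator factory. Everything in this example is hypothetical and for illustration only: the Array stand-in class, the name autodecorate_sketch, and the Cost and Quadratic classes are assumptions, and the sketch omits the framework conversion and the warning that the real decorator is documented to perform.

```python
import functools
import inspect

import numpy as np


class Array:
    """Hypothetical stand-in for the library's Array wrapper."""

    def __init__(self, value):
        self.value = value


def autodecorate_sketch(superclass_method):
    """Illustrative decorator factory; not the library's implementation."""
    annotation = inspect.signature(superclass_method).return_annotation
    returns_array = annotation is Array

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, x, *args, **kwargs):
            # Unwrap only Array instances; pass everything else through.
            def unwrap(arg):
                return arg.value if isinstance(arg, Array) else arg

            result = func(
                self,
                unwrap(x),
                *(unwrap(a) for a in args),
                **{k: unwrap(v) for k, v in kwargs.items()},
            )
            # Wrap the output only if the superclass annotation asks for it.
            return Array(result) if returns_array else result

        return wrapper

    return decorator


class Cost:
    def function(self, x: Array) -> float:
        raise NotImplementedError

    def gradient(self, x: Array) -> Array:
        raise NotImplementedError


class Quadratic(Cost):
    @autodecorate_sketch(Cost.function)
    def function(self, x):
        return float(np.sum(x ** 2))

    @autodecorate_sketch(Cost.gradient)
    def gradient(self, x):
        return 2 * x


cost = Quadratic()
point = Array(np.array([1.0, 2.0]))
value = cost.function(point)   # plain float, per the superclass annotation
grad = cost.gradient(point)    # wrapped in Array, per the superclass annotation
```

Reading the return annotation from the superclass method, rather than from the override, lets subclasses write their bodies in terms of raw framework arrays without repeating type annotations.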

decent_bench.utils.interoperability.choice(array: Array, size: int, replace: bool = True) → Array

Randomly sample elements from an array.

Parameters:
  • array (Array) – Input array to sample from.

  • size (int) – Number of samples to draw.

  • replace (bool) – Whether to sample with replacement.

Returns:

Sampled values in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.
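
The effect of replace can be sketched with NumPy's Generator.choice. This is an illustration under the assumption that choice follows NumPy's semantics; in particular, the ValueError for drawing more samples than elements without replacement is NumPy's behaviour and is not confirmed for decent_bench by this reference.

```python
import numpy as np

rng = np.random.default_rng(0)
population = np.array([10, 20, 30, 40, 50])

# With replacement, the sample may be larger than the population
# and may contain repeats.
with_replacement = rng.choice(population, size=8, replace=True)

# Without replacement, every drawn element is distinct.
without_replacement = rng.choice(population, size=3, replace=False)

# NumPy refuses to draw 8 distinct elements from a population of 5.
try:
    rng.choice(population, size=8, replace=False)
    oversample_rejected = False
except ValueError:
    oversample_rejected = True
```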

decent_bench.utils.interoperability.rng_numpy() → Generator

Return the shared NumPy generator used by interoperability random functions.

decent_bench.utils.interoperability.get_rng_state(frameworks: Iterable[SupportedFrameworks] | None = None) → dict[str, Any]

Return a picklable snapshot of all managed RNG states.

Parameters:

frameworks – Optional subset of frameworks to include in the snapshot. If None, all are included.

decent_bench.utils.interoperability.get_seed() → int | None

Return the current global seed if one was set explicitly.

decent_bench.utils.interoperability.rng_tensorflow() → TensorflowGenerator

Return a TensorFlow random generator.

Raises:

RuntimeError – if TensorFlow is not installed.

decent_bench.utils.interoperability.rng_torch(device: SupportedDevices = SupportedDevices.CPU) → TorchGenerator

Return a torch.Generator for a given device.

Raises:

RuntimeError – if PyTorch is not installed.

decent_bench.utils.interoperability.set_rng_state(state: dict[str, Any]) → None

Restore an RNG snapshot created by get_rng_state.

decent_bench.utils.interoperability.set_seed(seed: int, frameworks: Iterable[SupportedFrameworks] | None = None) → None

Set random seeds across supported frameworks.

Parameters:
  • seed – Base seed to use.

  • frameworks – Optional subset of frameworks to seed. If None, all are seeded.
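
The two reproducibility mechanisms in this module, seeding and state snapshots, can be sketched for the NumPy case alone. The example uses numpy.random.Generator directly and is only an analogy: how decent_bench stores and restores state for each framework is not described in this reference.

```python
import numpy as np

# Seeding: two generators built from the same seed produce the same stream.
first = np.random.default_rng(42).random(3)
second = np.random.default_rng(42).random(3)

# Snapshot and restore: capture the state, draw, rewind, and draw again.
rng = np.random.default_rng(7)
snapshot = rng.bit_generator.state   # a plain dict, so it can be pickled
before = rng.random(3)
rng.bit_generator.state = snapshot
after = rng.random(3)
```

A snapshot is useful when a run must be resumed mid-stream, whereas a seed only reproduces a run from its start.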

decent_bench.utils.interoperability.uniform_like(array: Array, low: float = 0.0, high: float = 1.0) → Array

Create an array of random values with the same shape and type as the input.

Values are drawn uniformly from [low, high).

Parameters:
  • array (Array) – Input array.

  • low (float) – Lower bound of the uniform distribution.

  • high (float) – Upper bound of the uniform distribution.

Returns:

Array of random values in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.uniform(framework: SupportedFrameworks, device: SupportedDevices, low: float = 0.0, high: float = 1.0, shape: tuple[int, ...] = ()) → Array

Create an array of random values with the specified shape and framework.

Values are drawn uniformly from [low, high).

Parameters:
  • framework (SupportedFrameworks) – Target framework type.

  • device (SupportedDevices) – Target device.

  • low (float) – Lower bound of the uniform distribution.

  • high (float) – Upper bound of the uniform distribution.

  • shape (tuple[int, ...]) – Shape of the output array.

Returns:

Array of random values in the specified framework type.

Return type:

Array

Raises:

TypeError – if framework is unsupported.
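
The half-open interval [low, high) means low can be drawn but high cannot. The sketch below checks this on a NumPy sample; it uses numpy.random.Generator.uniform as a stand-in and assumes the same convention applies in every supported framework.

```python
import numpy as np

rng = np.random.default_rng(0)
low, high = -2.0, 3.0

samples = rng.uniform(low, high, size=(4, 5))

# Equivalent construction: scale and shift draws from [0, 1).
scaled = low + (high - low) * rng.random((4, 5))
```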

decent_bench.utils.interoperability.normal(framework: SupportedFrameworks, device: SupportedDevices, mean: float = 0.0, std: float = 1.0, shape: tuple[int, ...] = ()) → Array

Create an array of random values with the specified shape and framework.

Values are drawn from a normal distribution with mean mean and standard deviation std.

Parameters:
  • framework (SupportedFrameworks) – Target framework type.

  • device (SupportedDevices) – Target device.

  • mean (float) – Mean of the normal distribution.

  • std (float) – Standard deviation of the normal distribution.

  • shape (tuple[int, ...]) – Shape of the output array.

Returns:

Array of random values in the specified framework type.

Return type:

Array

Raises:

TypeError – if framework is unsupported.

decent_bench.utils.interoperability.normal_like(array: Array, mean: float = 0.0, std: float = 1.0) → Array

Create an array of random values with the same shape and type as the input.

Values are drawn from a normal distribution with mean mean and standard deviation std.

Parameters:
  • array (Array) – Input array.

  • mean (float) – Mean of the normal distribution.

  • std (float) – Standard deviation of the normal distribution.

Returns:

Array of random values in the same framework type as the input.

Return type:

Array

Raises:

TypeError – if the framework type of array is unsupported.

decent_bench.utils.interoperability.rng_jax() → JaxArray

Split and return the next JAX sub-key while advancing global JAX RNG state.

Raises:

RuntimeError – if JAX is not installed.