Source code for decent_bench.utils.types

"""Type definitions for optimization variables."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Literal, SupportsIndex, TypeAlias, TypeVar, Union

if TYPE_CHECKING:
    import jax
    import numpy
    import tensorflow as tf
    import torch

    from decent_bench.agents import Agent
    from decent_bench.networks import Network
    from decent_bench.utils.array import Array

ArrayLike: TypeAlias = Union["numpy.ndarray", "torch.Tensor", "tf.Tensor", "jax.Array"]  # noqa: UP040
"""
Type alias for array-like types supported in decent-bench, including NumPy arrays,
PyTorch tensors, TensorFlow tensors, and JAX arrays.
"""

SupportedArrayTypes: TypeAlias = ArrayLike | float | int  # noqa: UP040
"""
Type alias for supported types for optimization variables in decent-bench,
including array-like types and scalars.
"""

if TYPE_CHECKING:
    NetworkT = TypeVar("NetworkT", bound=Network)
else:
    NetworkT = TypeVar("NetworkT")
"""
Type variable for algorithms operating on a :class:`~decent_bench.networks.Network`.
"""

type InitialStates = Union["Array", "dict[Agent, Array]", None]  # noqa: UP007
"""
Type alias for what can be passed to
:func:`~decent_bench.algorithms.utils.initial_states`.
"""

type LocalSteps = int | dict["Agent", int]
"""
Type alias for specifying local step counts in federated algorithms.

Can be a single positive integer shared by all clients, or a dictionary
mapping each client agent to its own positive local step count.
"""

ArrayKey: TypeAlias = SupportsIndex | slice | tuple[SupportsIndex | slice, ...]  # noqa: UP040
"""
Type alias for valid keys used to index into supported array types.
Includes single indices, tuples of indices, slices, and tuples of slices.
"""

type EmpiricalRiskIndices = list[int] | Literal["all", "batch"] | int
"""
Type alias for specifying indices in empirical risk computations.
Can be a list of integers, the string "all" for full dataset, the string "batch" for a mini-batch,
or an integer specifying a single datapoint.
"""

type EmpiricalRiskReduction = Literal["mean"] | None
"""
Type alias for specifying reduction methods in empirical risk computations.
Can be "mean" to average over samples or None for no reduction and the result
is returned as a list of gradients for each sample.
"""

type EmpiricalRiskBatchSize = int | Literal["all"]
"""
Type alias for specifying batch size in empirical risk initialization.
Can be an integer for mini-batch size or the string "all" for full dataset.
"""

type Datapoint = tuple["Array", "Array"]
"""Tuple of (x, y) representing one datapoint where x are features and y is the target."""

type Dataset = list[Datapoint]
"""
List of datapoints, where each datapoint is a tuple of (features, targets).

In decentralized optimization each agent has their own local dataset. This
type alias represents such datasets. This local dataset can be a subset of a larger
global dataset or the entire dataset itself. These subsets can be obtained
by using the :class:`~decent_bench.datasets.DatasetHandler` class, specifically the
:meth:`~decent_bench.datasets.DatasetHandler.get_partitions` method.

Features and targets are represented as :class:`~decent_bench.utils.array.Array`
objects or framework-specific tensor objects in special cases. For unsupervised learning,
targets are usually None.

The expected shapes depend on the specific dataset and cost function requirements,
but typically it is:

- Features: 1-dimensional vector (n_features,)
- Targets: 1-dimensional vector (n_targets,), or None for unsupervised learning.
"""


class SupportedFrameworks(Enum):
    """Array/tensor frameworks supported by decent-bench.

    Each member's value is the framework's name as a lowercase string.
    """

    NUMPY = "numpy"
    PYTORCH = "pytorch"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
class SupportedDevices(Enum):
    """Compute devices supported by decent-bench.

    Each member's value is the device name as a lowercase string.
    """

    CPU = "cpu"
    GPU = "gpu"
    MPS = "mps"