# Source code for decent_bench.utils.types
"""Type definitions for optimization variables."""
from __future__ import annotations

from enum import Enum, unique
from typing import TYPE_CHECKING, Literal, SupportsIndex, TypeAlias, TypeVar, Union

if TYPE_CHECKING:
    import jax
    import numpy
    import tensorflow as tf
    import torch

    from decent_bench.agents import Agent
    from decent_bench.networks import Network
    from decent_bench.utils.array import Array
# Every member is a quoted forward reference: the framework modules appear to
# be imported only under TYPE_CHECKING, so building this alias at import time
# does not require any of numpy/torch/tensorflow/jax to be importable.
# This is a plain runtime assignment (not a ``type`` statement), so the
# resulting Union object can be reused in other runtime expressions.
ArrayLike: TypeAlias = Union["numpy.ndarray", "torch.Tensor", "tf.Tensor", "jax.Array"] # noqa: UP040
"""
Type alias for array-like types supported in decent-bench, including NumPy arrays,
PyTorch tensors, TensorFlow tensors, and JAX arrays.
"""
# The right-hand side is evaluated eagerly with the ``|`` operator, so
# ``ArrayLike`` must already exist as a runtime object when this line runs.
# Scalars are limited to the built-in ``float`` and ``int`` types.
SupportedArrayTypes: TypeAlias = ArrayLike | float | int # noqa: UP040
"""
Type alias for supported types for optimization variables in decent-bench,
including array-like types and scalars.
"""
# ``Network`` is referenced as a bare name only inside the TYPE_CHECKING
# branch, so the bound is visible to static type checkers while the runtime
# branch creates an unbounded TypeVar and never needs ``Network`` to exist.
if TYPE_CHECKING:
    NetworkT = TypeVar("NetworkT", bound=Network)
else:
    NetworkT = TypeVar("NetworkT")
"""
Type variable for algorithms operating on a :class:`~decent_bench.networks.Network`.

For static type checkers the variable is bound to ``Network``; at runtime it is
created without a bound.
"""
# Three accepted forms: a single ``Array``, a per-agent mapping
# ``dict[Agent, Array]``, or ``None``. The members are quoted forward
# references; how each form is interpreted is decided by the consumer and is
# not visible in this module.
type InitialStates = Union["Array", "dict[Agent, Array]", None] # noqa: UP007
"""
Type alias for what can be passed to
:func:`~decent_bench.algorithms.utils.initial_states`.
"""
# The alias itself only constrains values to ``int``; the positivity
# requirement described below is not expressed in the type.
# NOTE(review): presumably validated where the value is consumed - confirm.
type LocalSteps = int | dict["Agent", int]
"""
Type alias for specifying local step counts in federated algorithms.
Can be a single positive integer shared by all clients, or a dictionary
mapping each client agent to its own positive local step count.
"""
# A key is one index-like object, one slice, or a variable-length tuple whose
# elements are each an index-like object or a slice (the two may be mixed).
# ``Ellipsis``, ``None`` and array-valued keys are not part of this alias.
# Plain runtime assignment: the union is built when the module is imported.
ArrayKey: TypeAlias = SupportsIndex | slice | tuple[SupportsIndex | slice, ...] # noqa: UP040
"""
Type alias for valid keys used to index into supported array types.
Includes single indices, tuples of indices, slices, and tuples of slices.
"""
# Three accepted forms: an explicit list of integer indices, one of the two
# sentinel strings "all" / "batch", or a single integer index.
type EmpiricalRiskIndices = list[int] | Literal["all", "batch"] | int
"""
Type alias for specifying indices in empirical risk computations.
Can be a list of integers, the string "all" for full dataset, the string "batch" for a mini-batch,
or an integer specifying a single datapoint.
"""
# Only two values are admitted: the literal string "mean" or ``None``.
# Per the description below, ``None`` means the per-sample results are
# returned unreduced rather than averaged.
type EmpiricalRiskReduction = Literal["mean"] | None
"""
Type alias for specifying reduction methods in empirical risk computations.
Can be "mean" to average over samples or None for no reduction and the result
is returned as a list of gradients for each sample.
"""
# Either an integer batch size or the sentinel string "all". The type does
# not restrict the integer's range.
# NOTE(review): presumably must be positive - confirm against the consumer.
type EmpiricalRiskBatchSize = int | Literal["all"]
"""
Type alias for specifying batch size in empirical risk initialization.
Can be an integer for mini-batch size or the string "all" for full dataset.
"""
# A two-element tuple; both positions are typed as ``Array`` (quoted forward
# references).
# NOTE(review): the ``Dataset`` description in this module says targets can be
# None for unsupervised learning, which this alias does not admit - confirm
# whether the second element should be ``Array | None``.
type Datapoint = tuple["Array", "Array"]
"""Tuple of (x, y) representing one datapoint where x are features and y is the target."""
# A dataset is a plain Python ``list`` of ``Datapoint`` tuples; no dedicated
# container class is involved.
# NOTE(review): the description below mentions framework-specific tensors and
# None targets, but ``Datapoint`` is typed as a pair of ``Array`` objects only -
# confirm which of the two is authoritative.
type Dataset = list[Datapoint]
"""
List of datapoints, where each datapoint is a tuple of (features, targets).
In decentralized optimization each agent has their own local dataset. This
type alias represents such datasets. This local dataset can be a subset of a larger
global dataset or the entire dataset itself. These subsets can be obtained
by using the :class:`~decent_bench.datasets.DatasetHandler` class, specifically the
:meth:`~decent_bench.datasets.DatasetHandler.get_partitions` method.
Features and targets are represented as :class:`~decent_bench.utils.array.Array`
objects or framework-specific tensor objects in special cases. For unsupervised learning,
targets are usually None.
The expected shapes depend on the specific dataset and cost function requirements,
but typically it is:
- Features: 1-dimensional vector (n_features,)
- Targets: 1-dimensional vector (n_targets,), or None for unsupervised learning.
"""
@unique
class SupportedFrameworks(Enum):
    """
    Enum for supported frameworks in decent-bench.

    Each member's value is the lowercase string name of the framework, so a
    member can be looked up from its string form, e.g.
    ``SupportedFrameworks("numpy")``.

    The class is decorated with :func:`enum.unique` so that adding a member
    whose value duplicates an existing one fails at import time instead of
    silently creating an alias.
    """

    NUMPY = "numpy"
    PYTORCH = "pytorch"
    TENSORFLOW = "tensorflow"
    JAX = "jax"
@unique
class SupportedDevices(Enum):
    """
    Enum for supported devices in decent-bench.

    Each member's value is the lowercase string name of the device, so a
    member can be looked up from its string form, e.g.
    ``SupportedDevices("cpu")``.

    The class is decorated with :func:`enum.unique` so that adding a member
    whose value duplicates an existing one fails at import time instead of
    silently creating an alias.
    """

    CPU = "cpu"
    GPU = "gpu"
    MPS = "mps"