from __future__ import annotations
import contextlib
import random
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast
import numpy as np
from decent_bench.utils.array import Array
from decent_bench.utils.types import SupportedDevices, SupportedFrameworks
from ._helpers import _return_array, device_to_framework_device
# Optional framework imports.
# Each public alias (``torch``, ``tf``, ``jnp``, ``jax``) is bound to the module
# when its import succeeds and to ``None`` otherwise, so the rest of this module
# can gate framework-specific code on the alias being truthy.
try:
    import torch as _torch
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    torch = None
else:
    torch = _torch

try:
    import tensorflow as _tf
except ImportError:
    tf = None
else:
    tf = _tf

try:
    import jax.numpy as _jnp
except ImportError:
    jnp = None
else:
    jnp = _jnp

try:
    import jax as _jax
except ImportError:
    jax = None
else:
    jax = _jax

if TYPE_CHECKING:
    # Needed for annotations only; skipped at runtime so the frameworks stay optional.
    from jax import Array as JaxArray
    from tensorflow.random import Generator as TensorflowGenerator
    from torch import Generator as TorchGenerator
@dataclass
class _RngState:
    """
    Mutable container for the random-number state managed by this module.

    The seeding, accessor and snapshot functions of this module read and update
    these fields on a shared module-level instance.
    """

    # Seed recorded as the global one; stays ``None`` until a seed is recorded.
    global_seed: int | None = None
    # NumPy generator handed out by ``rng_numpy``; a fresh ``default_rng()`` by default.
    numpy_rng: np.random.Generator = field(default_factory=np.random.default_rng)
    # Current JAX key; ``rng_jax`` splits it to produce sub-keys. ``None`` when unset.
    jax_key: JaxArray | None = None
    # TensorFlow generator returned by ``rng_tensorflow``. ``None`` when unset.
    tf_generator: TensorflowGenerator | None = None
    # Torch generators cached per device; populated on demand by ``rng_torch``.
    torch_generators: dict[SupportedDevices, TorchGenerator] = field(default_factory=dict)
# Module-wide RNG state, created unseeded at import time:
#   * NumPy uses a generator constructed without a seed,
#   * the JAX key (if JAX is installed) is built from an integer drawn from Python's ``random``,
#   * TensorFlow (if installed) uses a generator created from non-deterministic state,
#   * torch generators start empty and are created on demand by ``rng_torch``.
_STATE = _RngState(
    numpy_rng=np.random.default_rng(),
    jax_key=(None if not jax else jax.random.key(random.randint(0, 2**32 - 1))),
    tf_generator=(None if not tf else tf.random.Generator.from_non_deterministic_state()),
    torch_generators={},
)
def _selected_frameworks(frameworks: Iterable[SupportedFrameworks] | None) -> set[SupportedFrameworks]:
    """
    Normalise an optional framework selection into a set.

    Args:
        frameworks: Frameworks to select, or ``None`` to select every member of
            :class:`SupportedFrameworks`.

    Returns:
        A new set containing the selected frameworks.

    """
    chosen = SupportedFrameworks if frameworks is None else frameworks
    return set(chosen)
[docs]
def set_seed(
    seed: int,
    frameworks: Iterable[SupportedFrameworks] | None = None,
) -> None:
    """
    Seed the random number generators of the supported frameworks.

    The seed is also recorded as the global seed reported by :func:`get_seed`.

    Args:
        seed: Base seed to use.
        frameworks: Optional subset of frameworks to seed. If ``None``, all are seeded.

    """
    _set_seed(seed, frameworks, set_global_seed=True)
def _set_seed(
    seed: int,
    frameworks: Iterable[SupportedFrameworks] | None = None,
    *,
    set_global_seed: bool = True,
) -> None:
    """
    Seed the random number generators of the supported frameworks.

    Python's built-in ``random`` module is always reseeded. The remaining
    frameworks are reseeded only when they are selected and, for the optional
    ones, installed.

    Args:
        seed: Base seed to use.
        frameworks: Optional subset of frameworks to seed. If ``None``, all are seeded.
        set_global_seed: Whether to update the globally tracked seed returned by
            :func:`get_seed`. Set this to ``False`` for trial-local reseeding where
            preserving the external base seed is required.

    """
    targets = _selected_frameworks(frameworks)

    # The standard-library RNG is reseeded regardless of the selection.
    random.seed(seed)

    if SupportedFrameworks.NUMPY in targets:
        # Also seed the legacy global state for callers that use ``np.random.*`` directly.
        np.random.seed(seed)  # noqa: NPY002
        _STATE.numpy_rng = np.random.default_rng(seed)

    if torch and SupportedFrameworks.PYTORCH in targets:
        torch.manual_seed(seed)
        # Drop the cached per-device generators; ``rng_torch`` recreates them on demand.
        _STATE.torch_generators.clear()

    if tf and SupportedFrameworks.TENSORFLOW in targets:
        tf.random.set_seed(seed)
        _STATE.tf_generator = tf.random.Generator.from_seed(seed, alg="philox")

    if jax and SupportedFrameworks.JAX in targets:
        _STATE.jax_key = jax.random.key(seed)

    if set_global_seed:
        _STATE.global_seed = seed
[docs]
def get_seed() -> int | None:
    """
    Return the current global seed if one was set explicitly.

    Returns:
        The seed stored in the module's RNG state, or ``None`` when no seed has been recorded.

    """
    return _STATE.global_seed
[docs]
def rng_numpy() -> np.random.Generator:
    """
    Return the shared NumPy generator used by interoperability random functions.

    Returns:
        The generator object currently stored in the module's RNG state. Drawing
        from it advances that shared state.

    """
    return _STATE.numpy_rng
[docs]
def rng_jax() -> JaxArray:
    """
    Produce a fresh JAX sub-key and advance the module's JAX key.

    The stored key is split in two: one half replaces the stored key, the other
    half is returned to the caller.

    Returns:
        The sub-key obtained from the split.

    Raises:
        RuntimeError: if JAX is not installed.

    """
    current_key = _STATE.jax_key
    if not jax or current_key is None:
        raise RuntimeError("JAX is not installed.")
    next_key, sub_key = jax.random.split(current_key)
    _STATE.jax_key = next_key
    return cast("JaxArray", sub_key)
[docs]
def rng_torch(device: SupportedDevices = SupportedDevices.CPU) -> TorchGenerator:
    """
    Return the ``torch.Generator`` associated with a device, creating it on first use.

    Generators are cached per device. A newly created generator is seeded with
    ``torch.initial_seed()`` when a global seed has been recorded; otherwise it
    is left as constructed.

    Args:
        device: Device the generator should live on.

    Returns:
        The cached or newly created generator for ``device``.

    Raises:
        RuntimeError: if PyTorch is not installed.

    """
    if not torch:
        raise RuntimeError("PyTorch is not installed.")

    cached = _STATE.torch_generators.get(device)
    if cached is not None:
        return cached

    new_generator: TorchGenerator = torch.Generator(
        device=device_to_framework_device(device, SupportedFrameworks.PYTORCH),
    )
    if _STATE.global_seed is not None:
        new_generator.manual_seed(torch.initial_seed())
    _STATE.torch_generators[device] = new_generator
    return new_generator
[docs]
def rng_tensorflow() -> TensorflowGenerator:
    """
    Return the shared TensorFlow random generator.

    Returns:
        The generator currently stored in the module's RNG state.

    Raises:
        RuntimeError: if TensorFlow is not installed, or if no generator has been initialized.

    """
    if not tf:
        raise RuntimeError("TensorFlow is not installed.")
    generator = _STATE.tf_generator
    if generator is None:
        # Mainly for type checking: the generator is expected to be initialized whenever
        # TensorFlow is available.
        raise RuntimeError("TensorFlow random generator is not initialized.")
    return generator
[docs]
def get_rng_state(frameworks: Iterable[SupportedFrameworks] | None = None) -> dict[str, Any]:
    """
    Return a picklable snapshot of all managed RNG states.

    The global seed, the state of Python's ``random`` module and both NumPy
    states (shared generator and legacy global state) are always captured.
    PyTorch, TensorFlow and JAX state is captured only when the framework is
    selected and installed.

    Args:
        frameworks: Optional subset of frameworks whose state is captured. If ``None``,
            all are captured.

    Returns:
        A dictionary that can be passed to :func:`set_rng_state`.

    """
    wanted = _selected_frameworks(frameworks)

    snapshot: dict[str, Any] = {"seed": _STATE.global_seed}
    snapshot["python_random_state"] = random.getstate()
    # Copy so that later draws from the shared generator cannot alter the snapshot.
    snapshot["numpy_bit_generator_state"] = deepcopy(_STATE.numpy_rng.bit_generator.state)
    # Include legacy state for users who use legacy np.random functions
    snapshot["numpy_rng_state"] = np.random.get_state()  # noqa: NPY002

    if torch and SupportedFrameworks.PYTORCH in wanted:
        snapshot["torch_rng_state"] = torch.random.get_rng_state()
        if torch.cuda.is_available():
            snapshot["torch_cuda_rng_state"] = torch.cuda.get_rng_state_all()
        per_device_states = {device: gen.get_state() for device, gen in _STATE.torch_generators.items()}
        snapshot["torch_generators"] = per_device_states

    tf_generator = _STATE.tf_generator
    if tf and tf_generator is not None and SupportedFrameworks.TENSORFLOW in wanted:
        snapshot["tf_generator_state"] = tf_generator.state.numpy()

    jax_key = _STATE.jax_key
    if jax and jax_key is not None and SupportedFrameworks.JAX in wanted:
        snapshot["jax_key"] = jax.random.key_data(jax_key)

    return snapshot
[docs]
def set_rng_state(state: dict[str, Any]) -> None:
    """
    Restore a RNG snapshot created by ``get_rng_state``.

    Only the entries present in ``state`` are restored. Entries that belong to a
    framework which is not installed are ignored.

    Args:
        state: Snapshot dictionary as returned by :func:`get_rng_state`.

    """
    if "seed" in state:
        _STATE.global_seed = state["seed"]

    if "python_random_state" in state:
        random.setstate(state["python_random_state"])

    if "numpy_bit_generator_state" in state:
        # Start from a fresh generator, then overwrite its bit-generator state.
        _STATE.numpy_rng = np.random.default_rng()
        _STATE.numpy_rng.bit_generator.state = state["numpy_bit_generator_state"]

    if "numpy_rng_state" in state:
        np.random.set_state(state["numpy_rng_state"])  # noqa: NPY002

    if torch:
        if "torch_rng_state" in state:
            torch.random.set_rng_state(state["torch_rng_state"])
        if "torch_cuda_rng_state" in state and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(state["torch_cuda_rng_state"])
        if "torch_generators" in state:
            # Rebuild the per-device cache from scratch so it mirrors the snapshot exactly.
            _STATE.torch_generators.clear()
            for device, saved_state in state["torch_generators"].items():
                restored = torch.Generator(
                    device=device_to_framework_device(device, SupportedFrameworks.PYTORCH),
                )
                restored.set_state(saved_state)
                _STATE.torch_generators[device] = restored

    if tf and "tf_generator_state" in state:
        _STATE.tf_generator = tf.random.Generator.from_state(state["tf_generator_state"], alg="philox")

    if jax and "jax_key" in state:
        _STATE.jax_key = jax.random.wrap_key_data(state["jax_key"])
[docs]
def normal(
    framework: SupportedFrameworks,
    device: SupportedDevices,
    mean: float = 0.0,
    std: float = 1.0,
    shape: tuple[int, ...] = (),
) -> Array:
    """
    Draw normally distributed random values in the requested framework.

    Values are drawn from a normal distribution with mean `mean` and standard deviation `std`.
    The NumPy branch does not use `device`.

    Args:
        framework (SupportedFrameworks): Target framework type.
        device (SupportedDevices): Target device.
        mean (float): Mean of the normal distribution.
        std (float): Standard deviation of the normal distribution.
        shape (tuple[int, ...]): Shape of the output array.

    Returns:
        Array: Array of random values in the requested framework.

    Raises:
        TypeError: if `framework` is unsupported or its library is not installed.

    """
    # Resolved up front, before any framework dispatch.
    target_device = device_to_framework_device(device, framework)

    if framework == SupportedFrameworks.NUMPY:
        return _return_array(rng_numpy().normal(loc=mean, scale=std, size=shape))

    if torch and framework == SupportedFrameworks.PYTORCH:
        torch_samples = torch.normal(mean=mean, std=std, size=shape, device=target_device)
        return _return_array(torch_samples)

    if tf and framework == SupportedFrameworks.TENSORFLOW:
        with tf.device(target_device):
            tf_samples = rng_tensorflow().normal(shape=shape, mean=mean, stddev=std)
            return _return_array(tf_samples)

    if jax and framework == SupportedFrameworks.JAX:
        # Standard-normal draw moved to the target device, then shifted and scaled.
        standard = jax.random.normal(rng_jax(), shape=shape).to_device(target_device)
        return _return_array(mean + std * standard)

    raise TypeError(f"Unsupported framework type: {framework}")
[docs]
def normal_like(array: Array, mean: float = 0.0, std: float = 1.0) -> Array:
    """
    Create an array of random values with the same shape and type as the input.

    Values are drawn from a normal distribution with mean `mean` and standard deviation `std`.
    The result uses the dtype of the input. The one exception is a NumPy input with a
    non-floating dtype, for which the result stays ``float64``.

    Args:
        array (Array): Input array.
        mean (float): Mean of the normal distribution.
        std (float): Standard deviation of the normal distribution.

    Returns:
        Array: Array of random values in the same framework type as the input.

    Raises:
        TypeError: if the framework type of `array` is unsupported.

    """
    value = array.value if isinstance(array, Array) else array

    if isinstance(value, np.ndarray | np.generic):
        random_array = rng_numpy().normal(loc=mean, scale=std, size=value.shape)
        # ``Generator.normal`` has no dtype argument and always produces float64. Cast to the
        # input dtype so this branch matches the other backends, which all pass
        # ``dtype=value.dtype``. Non-floating inputs keep float64 to avoid truncating samples.
        if np.issubdtype(value.dtype, np.floating) and random_array.dtype != value.dtype:
            random_array = random_array.astype(value.dtype)
        return _return_array(random_array)

    if torch and isinstance(value, torch.Tensor):
        return _return_array(torch.normal(mean=mean, std=std, size=value.shape, dtype=value.dtype, device=value.device))

    if tf and isinstance(value, tf.Tensor):
        # Use the runtime shape of the tensor as the sample shape.
        shape = tf.shape(value)
        return _return_array(rng_tensorflow().normal(shape=shape, mean=mean, stddev=std, dtype=value.dtype))

    if jnp and jax and isinstance(value, jnp.ndarray | jnp.generic):
        sub_key = rng_jax()
        # Standard-normal draw, then shifted and scaled.
        return _return_array(mean + std * jax.random.normal(sub_key, shape=value.shape, dtype=value.dtype))

    raise TypeError(f"Unsupported framework type: {type(value)}")
[docs]
def choice(array: Array, size: int, replace: bool = True) -> Array:
    """
    Draw a random sample of entries from an array.

    Entries are selected along the first axis of the input.

    Args:
        array (Array): Input array to sample from.
        size (int): Number of samples to draw.
        replace (bool): Whether to sample with replacement.

    Returns:
        Array: Sampled values in the same framework type as the input.

    Raises:
        TypeError: if the framework type of `array` is unsupported.

    """
    value = array.value if isinstance(array, Array) else array

    if isinstance(value, np.ndarray | np.generic):
        return _return_array(rng_numpy().choice(value, size=size, replace=replace))

    if torch and isinstance(value, torch.Tensor):
        # Equal weights give every leading index the same selection probability.
        weights = torch.ones(value.shape[0], device=value.device)
        picks = torch.multinomial(weights, num_samples=size, replacement=replace)
        return _return_array(value[picks])

    if tf and isinstance(value, tf.Tensor):
        population = value.shape[0]
        tf_rng = rng_tensorflow()
        if replace:
            picks = tf_rng.uniform(shape=(size,), minval=0, maxval=population, dtype=tf.int32)
        else:
            # One random score per entry; the positions of the ``size`` largest scores
            # are distinct, which yields sampling without replacement.
            scores = tf_rng.uniform(shape=(population,), dtype=tf.float32)
            picks = tf.math.top_k(scores, k=size).indices
        return _return_array(tf.gather(value, picks))

    if jnp and jax and isinstance(value, jnp.ndarray | jnp.generic):
        picks = jax.random.choice(rng_jax(), a=value.shape[0], shape=(size,), replace=replace)
        return _return_array(value[picks])

    raise TypeError(f"Unsupported framework type: {type(value)}")