# Source code for decent_bench.datasets._pytorch_handler
from __future__ import annotations
import random
from collections import defaultdict
from functools import cached_property
from typing import TYPE_CHECKING, Any, cast
import decent_bench.utils.interoperability as iop
from decent_bench.utils.logger import LOGGER
from decent_bench.utils.types import Dataset, SupportedDevices
from ._dataset_handler import DatasetHandler
if TYPE_CHECKING:
    # Keep ``torch`` visible to static type checkers even if the runtime import below fails.
    import torch

# PyTorch is optional: attempt the import and remember the outcome, so that constructing
# PyTorchDatasetHandler can raise a helpful ImportError instead of this module failing to import.
try:
    import torch
    from torch.utils.data import ConcatDataset as TorchConcatDataset
    from torch.utils.data import Subset as TorchSubset
    from torch.utils.data import random_split as torch_random_split
except ImportError:
    TORCH_AVAILABLE = False
else:
    TORCH_AVAILABLE = True
class PyTorchDatasetHandler(DatasetHandler):
    def __init__(
        self,
        torch_dataset: torch.utils.data.Dataset[Any],
        n_features: int,
        n_targets: int,
        n_partitions: int = 1,
        *,
        samples_per_partition: int | None = None,
        heterogeneity: bool = False,
        targets_per_partition: int = 1,
    ) -> None:
        """
        Dataset wrapper for PyTorch datasets which represents datapoints as tuples (features, targets).

        This class will preserve the properties of the underlying PyTorch dataset
        such as transforms and lazy-loading. This class can create either random partitions where
        each partition is drawn uniformly at random from the dataset without replacement (heterogeneity=False),
        or heterogeneous partitions (heterogeneity=True) where each partition contains unique classes.
        Heterogeneity only works for datasets where the targets are categorical.

        Partitions are created lazily on the first call to :meth:`get_partitions` and cached afterwards.

        Args:
            torch_dataset: PyTorch dataset to wrap
            n_features: Number of feature dimensions
            n_targets: Number of target dimensions
            n_partitions: Number of partitions to split the dataset into, must be at least 1
            samples_per_partition: Number of samples per partition, if None, will split evenly.
                Must be at least 1 when given.
            heterogeneity: Whether to create heterogeneous partitions with unique classes
            targets_per_partition: Number of unique classes per partition (only if heterogeneity is True),
                must be at least 1 when heterogeneity is True

        Raises:
            ImportError: If PyTorch is not installed
            ValueError: If n_partitions < 1, if samples_per_partition < 1, if targets_per_partition < 1
                while heterogeneity is True, or if heterogeneity is True and
                n_partitions * targets_per_partition > n_targets

        Note:
            If heterogeneity is True, each partition will contain unique classes.
            Ensure that n_partitions * targets_per_partition <= n_targets. Be aware that
            this may lead to some classes being unused if the condition is not tight,
            the :attr:`n_targets` property will be updated accordingly.

            Random partitioning (heterogeneity=False) needs ``len()`` of the underlying PyTorch dataset,
            also when samples_per_partition is set. If the dataset has not implemented __len__, set
            heterogeneity to True; otherwise a TypeError is raised when the partitions are created.

        """
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch is required to use PyTorchDatasetHandler. Install it with: pip install torch")
        # Validate eagerly: otherwise these values only fail once the partitions are built, with
        # unhelpful errors (division by zero, a zero range step, or empty partition lists).
        if n_partitions < 1:
            raise ValueError(f"n_partitions ({n_partitions}) must be >= 1")
        if samples_per_partition is not None and samples_per_partition < 1:
            raise ValueError(f"samples_per_partition ({samples_per_partition}) must be >= 1 or None")
        if heterogeneity and targets_per_partition < 1:
            raise ValueError(f"targets_per_partition ({targets_per_partition}) must be >= 1")

        self.torch_dataset = torch_dataset
        self._n_targets = n_targets
        self._n_features = n_features
        self._n_partitions = n_partitions
        self.samples_per_partition = samples_per_partition
        self.heterogeneity = heterogeneity
        self.targets_per_partition = targets_per_partition
        self._partitions: list[Dataset] | None = None  # filled lazily (and cached) by get_partitions()

        if self.heterogeneity:
            if (self.n_partitions * self.targets_per_partition) > self.n_targets:
                raise ValueError(
                    f"n_partitions ({self.n_partitions}) * n_targets per partition ({self.targets_per_partition})"
                    f" must be <= n_targets ({self.n_targets})"
                )
            # Only n_partitions * targets_per_partition classes end up in the partitions; report that many.
            self._n_targets = self.n_partitions * self.targets_per_partition

    @cached_property
    def n_samples(self) -> int:
        """
        Total number of datapoints over all partitions.

        Computed from the partition lengths, so no datapoint is loaded (unlike materializing
        :meth:`get_datapoints`). Creates the partitions if that has not happened yet.
        """
        return sum(len(partition) for partition in self.get_partitions())

    @property
    def n_partitions(self) -> int:
        """Number of partitions the dataset is split into."""
        return self._n_partitions

    @property
    def n_features(self) -> int:
        """Number of feature dimensions, as given to the constructor."""
        return self._n_features

    @property
    def n_targets(self) -> int:
        """Number of targets; reduced to n_partitions * targets_per_partition when heterogeneity is True."""
        return self._n_targets

    def get_datapoints(self) -> Dataset:
        """
        Return all datapoints in the dataset.

        Can be used for evaluation on the full dataset or creation of test datasets.
        The result is the concatenation of all partitions from :meth:`get_partitions` (datapoints
        left out during partitioning are not included), fetched one by one into a list.
        """
        return cast("Dataset", list(TorchConcatDataset(self.get_partitions())))  # type: ignore[arg-type, call-overload]

    def get_partitions(self) -> list[Dataset]:
        """
        Return the dataset divided into partitions for distribution among agents.

        This method provides the core partitioning functionality for decentralized
        optimization. Each partition represents the local dataset of an agent in
        the network.

        Each partition is sampled uniformly at random from the dataset without replacement
        if heterogeneity is False, otherwise each partition contains unique classes (targets_per_partition).
        In the heterogeneous case every partition is first capped at samples_per_partition (if set) and
        then all partitions are truncated to the size of the smallest one, so they have equal length.
        The partitions are computed on the first call and cached.

        Returns:
            list[Dataset]: One ``torch.utils.data.Subset`` per partition, each an indexable collection of
            (features, targets) tuples.

        Raises:
            ValueError: If samples_per_partition * n_partitions exceeds the dataset length (random split),
                or if the dataset contains fewer distinct targets than n_partitions * targets_per_partition
                (heterogeneous split).

        """
        if self._partitions is None:
            if self.heterogeneity:
                self._partitions = self._heterogeneous_split()
            else:
                self._partitions = self._random_split()
        return self._partitions

    def _random_split(self) -> list[Dataset]:
        """
        Split the wrapped dataset uniformly at random into ``n_partitions`` disjoint partitions.

        Without samples_per_partition the whole dataset is used and partition sizes differ by at most
        one (the first ``len % n_partitions`` partitions get one extra datapoint). Otherwise every
        partition gets exactly samples_per_partition datapoints and the leftover datapoints are dropped.
        """
        torch_dataset_len = len(self.torch_dataset)  # type: ignore[arg-type]
        if self.samples_per_partition is None:
            # Integer lengths instead of the fractions [1 / n] * n: these are exactly the sizes random_split
            # derives from such fractions (floor, then round-robin remainder from index 0), but they do not
            # depend on the floating-point fractions summing to at most 1, which random_split requires.
            base, remainder = divmod(torch_dataset_len, self.n_partitions)
            parts = [base + 1 if i < remainder else base for i in range(self.n_partitions)]
        elif self.samples_per_partition * self.n_partitions <= torch_dataset_len:
            parts = [self.samples_per_partition] * self.n_partitions
            # random_split needs the lengths to sum to the dataset length: put the leftover datapoints
            # into an extra trailing split, which is discarded below.
            parts.append(torch_dataset_len - sum(parts))
        else:
            raise ValueError(
                f"samples_per_partition ({self.samples_per_partition}) * n_partitions ({self.n_partitions}) "
                f"must be <= datapoints in the torch dataset ({torch_dataset_len})"
            )
        partitions = cast(
            "list[Dataset]",
            torch_random_split(self.torch_dataset, parts, generator=iop.rng_torch(SupportedDevices.CPU)),
        )
        return partitions[: self.n_partitions]

    def _heterogeneous_split(self) -> list[Dataset]:
        """
        Split the dataset so that each partition contains its own, disjoint set of classes.

        The first ``n_partitions * targets_per_partition`` distinct labels encountered while iterating
        the dataset are used. They are sorted and consecutive groups of ``targets_per_partition`` labels
        form one partition each. Every partition is shuffled, capped at samples_per_partition (if set),
        and finally all partitions are truncated to the size of the smallest one.

        Raises:
            ValueError: If the dataset contains fewer distinct labels than
                ``n_partitions * targets_per_partition``.

        """
        n_used_classes = self.n_partitions * self.targets_per_partition

        # Group indices by class in a single pass over the dataset (this fetches every datapoint once).
        class_to_indices: dict[int, list[int]] = defaultdict(list)
        for idx, (_, label) in enumerate(cast("list[tuple[torch.Tensor, int | torch.Tensor]]", self.torch_dataset)):
            label_key = int(label.item()) if isinstance(label, torch.Tensor) else int(label)
            # Keep only the first n_used_classes distinct labels; datapoints of any other label are ignored.
            if label_key in class_to_indices or len(class_to_indices) < n_used_classes:
                class_to_indices[label_key].append(idx)

        if len(class_to_indices) < n_used_classes:
            # Otherwise fewer than n_partitions partitions (or partitions with too few classes) would be built.
            raise ValueError(
                f"The dataset contains only {len(class_to_indices)} distinct targets, but n_partitions "
                f"({self.n_partitions}) * targets_per_partition ({self.targets_per_partition}) = "
                f"{n_used_classes} are required for heterogeneous partitioning"
            )

        sorted_classes = sorted(class_to_indices)
        idx_partitions: list[list[int]] = []
        for start in range(0, n_used_classes, self.targets_per_partition):
            class_group = sorted_classes[start : start + self.targets_per_partition]
            indices = [idx for label_key in class_group for idx in class_to_indices[label_key]]
            # NOTE(review): shuffling uses Python's global ``random`` state, whereas _random_split draws from
            # iop.rng_torch(...); confirm that seeding covers both so heterogeneous splits are reproducible.
            random.shuffle(indices)
            if self.samples_per_partition is not None:
                indices = indices[: self.samples_per_partition]
            idx_partitions.append(indices)

        # Balance the partitions: all of them are truncated to the size of the smallest one.
        min_n_datapoints = min(len(part) for part in idx_partitions)
        if self.samples_per_partition is not None and min_n_datapoints < self.samples_per_partition:
            LOGGER.warning(
                f"Warning: Some partitions have less datapoints ({min_n_datapoints}) than "
                f"samples_per_partition ({self.samples_per_partition}) due to class distribution. "
                f"All partitions will be truncated to {min_n_datapoints} datapoints."
            )
        partitions = [TorchSubset(self.torch_dataset, part[:min_n_datapoints]) for part in idx_partitions]  # pyright: ignore[reportPossiblyUnboundVariable]
        return cast("list[Dataset]", partitions)