Source code for decent_bench.datasets._pytorch_handler

from __future__ import annotations

import random
from collections import defaultdict
from functools import cached_property
from typing import TYPE_CHECKING, Any, cast

import decent_bench.utils.interoperability as iop
from decent_bench.utils.logger import LOGGER
from decent_bench.utils.types import Dataset, SupportedDevices

from ._dataset_handler import DatasetHandler

if TYPE_CHECKING:
    import torch

try:
    import torch
    from torch.utils.data import ConcatDataset as TorchConcatDataset
    from torch.utils.data import Subset as TorchSubset
    from torch.utils.data import random_split as torch_random_split

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False


class PyTorchDatasetHandler(DatasetHandler):
    def __init__(
        self,
        torch_dataset: torch.utils.data.Dataset[Any],
        n_features: int,
        n_targets: int,
        n_partitions: int = 1,
        *,
        samples_per_partition: int | None = None,
        heterogeneity: bool = False,
        targets_per_partition: int = 1,
    ) -> None:
        """
        Dataset wrapper for PyTorch datasets which represents datapoints as tuples (features, targets).

        The underlying PyTorch dataset is only ever wrapped in ``Subset`` objects, so its own properties
        such as transforms and lazy-loading are preserved.

        This class can create either random partitions where each partition is drawn uniformly at random
        from the dataset without replacement (heterogeneity=False), or heterogeneous partitions
        (heterogeneity=True) where each partition contains unique classes. Heterogeneity only works for
        datasets where the targets are categorical.

        Args:
            torch_dataset: PyTorch dataset to wrap
            n_features: Number of feature dimensions
            n_targets: Number of target dimensions
            n_partitions: Number of partitions to split the dataset into, must be >= 1
            samples_per_partition: Number of samples per partition, if None, will split evenly
            heterogeneity: Whether to create heterogeneous partitions with unique classes
            targets_per_partition: Number of unique classes per partition (only if heterogeneity is True)

        Raises:
            ImportError: If PyTorch is not installed
            ValueError: If n_partitions < 1, if samples_per_partition is negative, or if heterogeneity is
                True and either targets_per_partition < 1 or
                n_partitions * targets_per_partition > n_targets

        Note:
            If heterogeneity is True, each partition will contain unique classes. Ensure that
            n_partitions * targets_per_partition <= n_targets. Be aware that this may lead to some classes
            being unused if the condition is not tight, the :meth:`n_targets` attribute will be updated
            accordingly.

            Random partitioning (heterogeneity=False) always calls ``len()`` on the underlying PyTorch
            dataset, so the dataset must implement ``__len__`` in that mode. Heterogeneous partitioning
            iterates over the dataset instead and does not call ``len()`` on it.

        """
        if not TORCH_AVAILABLE:
            raise ImportError(
                "PyTorch is required to use PyTorchDatasetHandler. Install it with: pip install torch"
            )

        # Reject values that would otherwise only fail later, inside the split methods, with
        # unrelated errors (ZeroDivisionError, zero range step, negative slicing).
        if n_partitions < 1:
            raise ValueError(f"n_partitions ({n_partitions}) must be >= 1")
        if samples_per_partition is not None and samples_per_partition < 0:
            raise ValueError(f"samples_per_partition ({samples_per_partition}) must be >= 0")
        if heterogeneity and targets_per_partition < 1:
            raise ValueError(f"targets_per_partition ({targets_per_partition}) must be >= 1")

        self.torch_dataset = torch_dataset
        self._n_targets = n_targets
        self._n_features = n_features
        self._n_partitions = n_partitions
        self.samples_per_partition = samples_per_partition
        self.heterogeneity = heterogeneity
        self.targets_per_partition = targets_per_partition
        # Partitions are computed lazily on the first get_partitions() call and then reused.
        self._partitions: list[Dataset] | None = None

        if self.heterogeneity:
            if (self.n_partitions * self.targets_per_partition) > self.n_targets:
                raise ValueError(
                    f"n_partitions ({self.n_partitions}) * n_targets per partition ({self.targets_per_partition})"
                    f" must be <= n_targets ({self.n_targets})"
                )
            # Set the new number of used targets
            self._n_targets = self.n_partitions * self.targets_per_partition

    @cached_property
    def n_samples(self) -> int:
        """Total number of datapoints across all partitions (computed once, then cached)."""
        # Sum the partition lengths instead of materialising every sample just to count them.
        return sum(len(partition) for partition in self.get_partitions())

    @property
    def n_partitions(self) -> int:
        """Number of partitions the dataset is split into."""
        return self._n_partitions

    @property
    def n_features(self) -> int:
        """Number of feature dimensions, as given at construction."""
        return self._n_features

    @property
    def n_targets(self) -> int:
        """Number of targets; with heterogeneity this is n_partitions * targets_per_partition."""
        return self._n_targets

    def get_datapoints(self) -> Dataset:
        """
        Return all datapoints in the dataset.

        The result is the concatenation of all partitions, materialised into a list.
        Can be used for evaluation on the full dataset or creation of test datasets.

        """
        return cast("Dataset", list(TorchConcatDataset(self.get_partitions())))  # type: ignore[arg-type, call-overload]

    def get_partitions(self) -> list[Dataset]:
        """
        Return the dataset divided into partitions for distribution among agents.

        This method provides the core partitioning functionality for decentralized optimization.
        Each partition represents the local dataset of an agent in the network. Each partition is
        sampled uniformly at random from the dataset without replacement if heterogeneity is False,
        otherwise each partition contains unique classes (targets_per_partition) with number of
        datapoints per partition equal to min(samples_per_partition, number of available datapoints
        for the selected classes).

        The split is computed on the first call and the same list is returned on later calls.

        Returns:
            list[Dataset]: List of Dataset objects, one per partition, where each partition yields
            (features, targets) tuples.

        Raises:
            ValueError: If the requested partition sizes or class counts cannot be satisfied by the
                underlying dataset.

        """
        if self._partitions is None:
            if self.heterogeneity:
                self._partitions = self._heterogeneous_split()
            else:
                self._partitions = self._random_split()
        return self._partitions

    def _random_split(self) -> list[Dataset]:
        """
        Split the dataset uniformly at random, without replacement, into n_partitions parts.

        Raises:
            ValueError: If samples_per_partition * n_partitions exceeds the dataset length.

        """
        torch_dataset_len = len(self.torch_dataset)  # type: ignore[arg-type]
        if self.samples_per_partition is None:
            # Use explicit integer lengths rather than float fractions [1 / n] * n: the float sum can
            # round above 1.0, in which case random_split no longer treats the values as fractions and
            # rejects them. The first `remainder` partitions receive one extra sample.
            base, remainder = divmod(torch_dataset_len, self.n_partitions)
            lengths = [base + 1 if i < remainder else base for i in range(self.n_partitions)]
        elif self.samples_per_partition * self.n_partitions <= torch_dataset_len:
            lengths = [self.samples_per_partition] * self.n_partitions
            # random_split requires the lengths to sum to the dataset size, so the unused samples go
            # into one extra trailing partition that is dropped again below.
            lengths.append(torch_dataset_len - sum(lengths))
        else:
            raise ValueError(
                f"samples_per_partition ({self.samples_per_partition}) * n_partitions ({self.n_partitions}) "
                f"must be <= datapoints in the torch dataset ({torch_dataset_len})"
            )

        partitions = cast(
            "list[Dataset]",
            torch_random_split(self.torch_dataset, lengths, generator=iop.rng_torch(SupportedDevices.CPU)),
        )
        return partitions[: self.n_partitions]

    def _heterogeneous_split(self) -> list[Dataset]:
        """
        Split dataset so each partition contains unique classes.

        Requires that partitions * classes_per_partition <= classes. The classes that are used are the
        first n_partitions * targets_per_partition distinct labels encountered while iterating over the
        dataset; they are then sorted and assigned to partitions in consecutive groups. All partitions
        are truncated to the size of the smallest one.

        Raises:
            ValueError: If the dataset contains fewer distinct classes than
                n_partitions * targets_per_partition.

        """
        n_classes_needed = self.n_partitions * self.targets_per_partition

        # Group indices by class in a single pass
        class_to_indices: dict[int, list[int]] = defaultdict(list)
        for idx, (_, label) in enumerate(cast("list[tuple[torch.Tensor, int | torch.Tensor]]", self.torch_dataset)):
            label_key = int(label.item()) if isinstance(label, torch.Tensor) else int(label)
            # Only start tracking a new class while fewer than the required number are tracked.
            if label_key in class_to_indices or len(class_to_indices) < n_classes_needed:
                class_to_indices[label_key].append(idx)

        # Without this check a dataset with too few classes would silently produce fewer (or
        # under-filled) partitions than n_partitions.
        if len(class_to_indices) < n_classes_needed:
            raise ValueError(
                f"The torch dataset contains only {len(class_to_indices)} distinct classes, but "
                f"n_partitions ({self.n_partitions}) * targets_per_partition ({self.targets_per_partition}) "
                f"= {n_classes_needed} are required"
            )

        # Group classes for each partition
        class_idxs = sorted(class_to_indices)
        class_idxs_groups = [
            class_idxs[i : i + self.targets_per_partition]
            for i in range(0, len(class_idxs), self.targets_per_partition)
        ]

        # Create partitions from class-grouped indices
        idx_partitions: list[list[int]] = []
        for class_idx_group in class_idxs_groups:
            indices: list[int] = []
            for class_idx in class_idx_group:
                indices.extend(class_to_indices[class_idx])

            # Shuffle and select subset if needed
            random.shuffle(indices)
            if self.samples_per_partition is not None:
                indices = indices[: self.samples_per_partition]
            idx_partitions.append(indices)

        # Every partition is cut down to the size of the smallest one.
        min_n_datapoints = min((len(indices) for indices in idx_partitions), default=0)

        if self.samples_per_partition is not None and min_n_datapoints < self.samples_per_partition:
            LOGGER.warning(
                f"Warning: Some partitions have less datapoints ({min_n_datapoints}) than "
                f"samples_per_partition ({self.samples_per_partition}) due to class distribution. "
                f"All partitions will be truncated to {min_n_datapoints} datapoints."
            )

        partitions = [TorchSubset(self.torch_dataset, idx[:min_n_datapoints]) for idx in idx_partitions]  # pyright: ignore[reportPossiblyUnboundVariable]
        return cast("list[Dataset]", partitions)