Source code for decent_bench.costs._empirical_risk._pytorch_cost

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, cast, override

import decent_bench.utils.interoperability as iop
from decent_bench.costs._base._cost import Cost
from decent_bench.costs._empirical_risk._empirical_risk_cost import EmpiricalRiskCost
from decent_bench.utils._tags import tags
from decent_bench.utils.logger import LOGGER
from decent_bench.utils.types import (
    Dataset,
    EmpiricalRiskBatchSize,
    EmpiricalRiskIndices,
    EmpiricalRiskReduction,
    SupportedDevices,
    SupportedFrameworks,
)

if TYPE_CHECKING:
    import torch

    from decent_bench.agents import Agent

try:
    import torch

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False


class _IndexDataset:
    """A simple dataset wrapper to handle indexing when using a PyTorch dataloader."""

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

        if not hasattr(self.dataset, "__len__"):
            raise ValueError("Dataset must implement __len__ method.")

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, int]:
        return cast("tuple[torch.Tensor, torch.Tensor, int]", (*self.dataset[idx], idx))


[docs] @tags("regression", "classification", "empirical-risk") class PyTorchCost(EmpiricalRiskCost): """ Cost function wrapper for PyTorch neural networks that integrates with the decentralized optimization framework. Supports batch-based training and gradient computation for decentralized learning scenarios. Note: It is generally recommended to set `agent_state_snapshot_period` to a value greater than 1 when using PyTorchCost, as recording the full model parameters at every iteration can be expensive. """ _NON_PICKLABLE_STATE_KEYS = ( "_ft_compute_grad", "_ft_compute_sample_grad", "_last_batch_x", "_last_batch_y", "_params_list", "param_shapes", "param_sizes", "total_params", "param_names", "_dataloader_iter", )
[docs] def __init__( self, dataset: Dataset, model: torch.nn.Module, loss_fn: torch.nn.Module, final_activation: torch.nn.Module | None = None, *, batch_size: EmpiricalRiskBatchSize = "all", max_batch_size: int | None = None, device: SupportedDevices = SupportedDevices.CPU, use_dataloader: bool = False, dataloader_kwargs: dict[str, Any] | None = None, load_dataset: bool = True, compile_model: bool = False, compile_kwargs: dict[str, Any] | None = None, ): """ Initialize the PyTorch cost function. Args: dataset (Dataset): Dataset partition containing features and targets. Transformations should be applied beforehand such as converting to tensors. See torch.utils.data.Dataset for details. model (torch.nn.Module): PyTorch neural network model. loss_fn: (torch.nn.Module): PyTorch loss function. final_activation (torch.nn.Module | None): Optional final activation layer to apply after model output when predicting targets using :meth:`predict`. E.g., argmax if classification and model outputs logits. batch_size (EmpiricalRiskBatchSize): Size of mini-batches for stochastic methods, or "all" for full-batch. max_batch_size (int | None): Optional maximum batch size to perform computations in, which can be used to avoid out-of-memory errors for large models/datasets. If specified, computations will be calculated in chunks of size at most max_batch_size. This limit will be applied to all computations irregardless of the batch_size or indices parameters; the result will still be the same. This is especially useful for when `indices` is set to "all" but the dataset is too large to fit in memory at once. If not specified, it will default to the batch_size (if batch_size is an int) or the total number of samples (if batch_size is "all"). device (SupportedDevices): Device to run computations on. Make sure to test CPU vs GPU performance for your specific model and dataset, as it can vary. use_dataloader (bool): Whether to use DataLoader for batching. Can be beneficial for large datasets which can't fit into memory or when using an accelerator. Dataloaders cannot be pickled so resumption of iterrupted runs will start with a new random batch order. dataloader_kwargs (dict | None): Additional arguments for the DataLoader. load_dataset (bool): If True, loads the entire dataset into memory to optimize data access. This may lead to major speedups if the dataset is lazily loaded (e.g., loading data from disk), but it might increase memory usage so set to False if memory is an issue. Setting this to False might break checkpointing if the underlying dataset is not pickleable. compile_model (bool): Whether to compile the model using torch.compile for performance. May improve speed after warm-up. Might need to try different modes based on the model and OS, use compile_kwargs. compile_kwargs (dict | None): Additional arguments for torch.compile. Commonly used mode is "reduce_overhead" for performance optimization. See https://pytorch.org/docs/stable/generated/torch.compile.html for details. Raises: ImportError: If PyTorch is not available ValueError: If batch_size is larger than the number of samples in the dataset """ if not TORCH_AVAILABLE: raise ImportError("PyTorch is not available. Please install PyTorch to use PyTorchCost.") if isinstance(batch_size, int) and (batch_size <= 0 or batch_size > len(dataset)): raise ValueError( f"Batch size must be positive and at most the number of samples, " f"got: {batch_size} and number of samples is: {len(dataset)}." ) if isinstance(batch_size, str) and batch_size != "all": raise ValueError(f"Invalid batch size string. Supported value is 'all', got {batch_size}.") if load_dataset: # Loads the dataset into memory in case it is lazily loaded self._dataset = _IndexDataset([(x, y) for x, y in dataset]) else: self._dataset = _IndexDataset(dataset) self.model = model self.loss_fn = loss_fn self.final_activation = final_activation if final_activation is not None else torch.nn.Identity() self._batch_size = self.n_samples if batch_size == "all" else batch_size self._max_batch_size = max_batch_size if max_batch_size is not None else self.n_samples self._device = device self._use_dataloader = use_dataloader self._dataloader_kwargs = dataloader_kwargs if dataloader_kwargs is not None else {} self._load_dataset = load_dataset self._compile_model = compile_model self._compile_kwargs = compile_kwargs if compile_kwargs is not None else {} self._optimizer: torch.optim.Optimizer | None = None self._scheduler: torch.optim.lr_scheduler.LRScheduler | None = None self._pytorch_device: str = iop.device_to_framework_device(device, framework=self.framework) self.model = self.model.to(self._pytorch_device) self.loss_fn = self.loss_fn.to(self._pytorch_device) if self._compile_model: torch.set_float32_matmul_precision("high") # For better torch.compile performance self.model = cast("torch.nn.Module", torch.compile(self.model, **self._compile_kwargs)) self._dataloader: torch.utils.data.DataLoader[Any] | None = None self._last_batch_used: list[int] = [] self._remaining_batch_indices: list[int] = [] # Used for tracking remaining indices when not using dataloader self._init_param_caches()
def __getstate__(self) -> dict[str, Any]: """Return a checkpoint-safe state for pickling.""" try: state = self.__dict__.copy() for key in self._NON_PICKLABLE_STATE_KEYS: state.pop(key, None) state["model"] = self.model.to("cpu") state["loss_fn"] = self.loss_fn.to("cpu") finally: self.model = self.model.to(self._pytorch_device) self.loss_fn = self.loss_fn.to(self._pytorch_device) return state def __setstate__(self, state: dict[str, Any]) -> None: """Restore state and clear transient runtime caches.""" for key, value in state.items(): setattr(self, key, value) self.model = self.model.to(self._pytorch_device) self.loss_fn = self.loss_fn.to(self._pytorch_device) if self._dataloader is not None: self._dataloader_iter = iter(self._dataloader) self._init_param_caches() @property def shape(self) -> tuple[int, ...]: return (self.total_params,) @property def framework(self) -> SupportedFrameworks: return SupportedFrameworks.PYTORCH @property def device(self) -> SupportedDevices: return self._device @property def n_samples(self) -> int: return len(self.dataset) @property def batch_size(self) -> int: return self._batch_size @property def dataset(self) -> Dataset: return self._dataset.dataset @property def m_smooth(self) -> float: return float("nan") @property def m_cvx(self) -> float: return float("nan") def _set_model_parameters(self, x: torch.Tensor) -> None: """ Set model parameters from a tensor. Args: x (torch.Tensor): Flattened parameter tensor. Raises: ValueError: If the size of x does not match the total number of model parameters. """ if x.numel() != self.total_params: raise ValueError( f"Parameter vector size {x.numel()} does not match total model parameters {self.total_params}" ) # Move to target device once (assumes all params on same device) x = x.to(self._pytorch_device) # Split tensor efficiently and unflatten param_chunks = torch.split(x, self.param_sizes) with torch.no_grad(): for param, chunk, shape in zip(self._params_list, param_chunks, self.param_shapes, strict=True): param.data = chunk.view(shape) def _get_model_parameters(self) -> torch.Tensor: """Get model parameters as a flattened tensor.""" params = [p.detach().flatten() for p in self._params_list] return torch.cat(params).to(self._pytorch_device)
[docs] @iop.autodecorate_cost_method(EmpiricalRiskCost.predict) def predict(self, x: torch.Tensor, data: list[torch.Tensor]) -> list[torch.Tensor]: """ Make predictions at x on the given data. Args: x: Point to make predictions at. data: List of torch.Tensor containing features to make predictions on. Returns: Predicted targets as an array Raises: TypeError: If data is not a list of torch.Tensor or a single torch.Tensor. """ self._set_model_parameters(x) if self.model.training: self.model.eval() with torch.no_grad(): if isinstance(data, list): inputs = torch.stack(data) elif isinstance(data, torch.Tensor): inputs = data else: raise TypeError(f"Data must be a list of torch.Tensor or a single torch.Tensor, got {type(data)}.") final_outputs: list[torch.Tensor] = [] for i in range(0, inputs.shape[0], self._max_batch_size): inputs_chunk = inputs[i : i + self._max_batch_size].to(self._pytorch_device) outputs: torch.Tensor = self.model(inputs_chunk) outputs = self.final_activation(outputs) final_outputs.extend(outputs.detach().cpu().tolist()) return final_outputs
[docs] @iop.autodecorate_cost_method(EmpiricalRiskCost.function) def function(self, x: torch.Tensor, indices: EmpiricalRiskIndices = "batch") -> float: self._set_model_parameters(x) if self.model.training: self.model.eval() batches = self._get_batch_data(indices) total_loss = 0.0 with torch.no_grad(): for batch_x, batch_y in batches: outputs = self.model(batch_x) loss: torch.Tensor = self.loss_fn(outputs, batch_y) total_loss += loss.item() * batch_x.shape[0] return float(total_loss / len(self.batch_used))
[docs] @iop.autodecorate_cost_method(EmpiricalRiskCost.gradient) def gradient( self, x: torch.Tensor, indices: EmpiricalRiskIndices = "batch", reduction: EmpiricalRiskReduction = "mean", ) -> torch.Tensor: if reduction is None: return self._per_sample_gradients(x, indices) self._set_model_parameters(x) if not self.model.training: self.model.train() batches = self._get_batch_data(indices) if len(batches) == 1: # Slightly faster path for single batch batch_x, batch_y = batches[0] outputs = self.model(batch_x) loss = self.loss_fn(outputs, batch_y) gradients = torch.autograd.grad(loss, self._params_list) return torch.cat([g.flatten() for g in gradients]) # Accumulate chunk gradients with sample-count weighting so the final # result is the mean gradient over all selected samples. inv_total_samples = 1.0 / len(self.batch_used) grad_acc = [torch.zeros_like(param) for param in self._params_list] for batch_x, batch_y in batches: outputs = self.model(batch_x) loss = self.loss_fn(outputs, batch_y) gradients = torch.autograd.grad(loss, self._params_list) weight = batch_x.shape[0] * inv_total_samples for acc, grad in zip(grad_acc, gradients, strict=True): acc.add_(grad, alpha=weight) return torch.cat([g.flatten() for g in grad_acc])
def _per_sample_gradients(self, x: torch.Tensor, indices: EmpiricalRiskIndices = "batch") -> torch.Tensor: """Compute per-sample gradients for the specified indices.""" # Credit: https://docs.pytorch.org/tutorials/intermediate/per_sample_grads.html self._init_per_sample_grad() self._set_model_parameters(x) if not self.model.training: self.model.train() batches = self._get_batch_data(indices) params = {k: v.detach() for k, v in self.model.named_parameters()} buffers = {k: v.detach() for k, v in self.model.named_buffers()} if len(batches) == 1: batch_x, batch_y = batches[0] ft_per_sample_grads = self._ft_compute_sample_grad(params, buffers, batch_x, batch_y) batch_size = batch_x.shape[0] grad_list = [ ft_per_sample_grads[name].view(batch_size, size) for name, size in zip(self.param_names, self.param_sizes, strict=True) ] return torch.cat(grad_list, dim=1) param_offsets = [0] for size in self.param_sizes: param_offsets.append(param_offsets[-1] + size) all_grads = torch.empty( (len(self.batch_used), self.total_params), dtype=self._params_list[0].dtype, device=self._pytorch_device, ) row_start = 0 for batch_x, batch_y in batches: ft_per_sample_grads = self._ft_compute_sample_grad(params, buffers, batch_x, batch_y) batch_size = batch_x.shape[0] row_end = row_start + batch_size for name, size, col_start, col_end in zip( self.param_names, self.param_sizes, param_offsets[:-1], param_offsets[1:], strict=True, ): all_grads[row_start:row_end, col_start:col_end] = ft_per_sample_grads[name].view(batch_size, size) row_start = row_end return all_grads
[docs] @iop.autodecorate_cost_method(EmpiricalRiskCost.hessian) def hessian(self, x: torch.Tensor, indices: EmpiricalRiskIndices = "batch") -> torch.Tensor: """ Compute the Hessian matrix. Note: This is computationally expensive for neural networks and typically not used. Raises: NotImplementedError: Always raised to indicate Hessian computation is not implemented. """ raise NotImplementedError("Hessian computation is not implemented for PyTorchCost.")
[docs] @iop.autodecorate_cost_method(EmpiricalRiskCost.proximal) def proximal(self, x: torch.Tensor, penalty: float) -> torch.Tensor: """ Compute the proximal operator. Note: This is computationally expensive for neural networks and typically not used. Raises: NotImplementedError: Always raised to indicate proximal computation is not implemented. """ raise NotImplementedError("Proximal operator is not implemented for PyTorchCost.")
[docs] def init_local_training( self, opt_cls: type[torch.optim.Optimizer], opt_kwargs: dict[str, Any] | None = None, sched_cls: type[torch.optim.lr_scheduler.LRScheduler] | None = None, sched_kwargs: dict[str, Any] | None = None, ) -> None: """ Initialize the optimizer and scheduler for local training. This method is required to be called before using :meth:`local_training` to set up the optimizer and scheduler. Args: opt_cls (type[torch.optim.Optimizer]): PyTorch optimizer class to use for local training. opt_kwargs (dict[str, Any] | None): Keyword arguments for initializing the optimizer. The model parameters will be passed as the first argument, so do not include them in opt_kwargs. sched_cls (type[torch.optim.lr_scheduler.LRScheduler] | None): Optional PyTorch learning rate scheduler class to use for local training. The scheduler will be stepped once at the end of each call to :meth:`local_training`. sched_kwargs (dict[str, Any] | None): Keyword arguments for initializing the scheduler. The optimizer will be passed as the first argument, so do not include it in sched_kwargs. Raises: RuntimeError: If the optimizer is already initialized. This method is intended to be called only once to set the optimizer for local training. """ if self._optimizer is not None: raise RuntimeError( "Optimizer is already initialized. This method is intended to be called " "only once to set the optimizer for local training." ) self._optimizer = opt_cls(self._params_list, **(opt_kwargs or {})) if sched_cls is not None: self._scheduler = sched_cls(self._optimizer, **(sched_kwargs or {}))
[docs] def local_training( self, x: torch.Tensor, iterations: int, agent: Agent, regularization: torch.Tensor | Callable[[torch.Tensor], torch.Tensor] | None, indices: EmpiricalRiskIndices = "batch", ) -> torch.Tensor: r""" Perform local training steps using the provided optimizer. Note: This method is intended to be used in decentralized algorithms that support local training. Args: x (torch.Tensor): Initial parameters to start local training from. iterations (int): Number of local training iterations to perform. agent (Agent): The agent performing the local training. regularization (torch.Tensor | Callable[[torch.Tensor], torch.Tensor] | None): Optional regularization. Two forms are supported: - Scalar tensor (or callable returning a scalar): interpreted as an additive loss penalty. - Flat tensor with the same number of elements as the flattened parameter vector: interpreted as a parameter-space correction step applied after each optimizer step (i.e., :math:`x \leftarrow x - r`). indices (EmpiricalRiskIndices): Indices of the samples to use for local training. Returns: torch.Tensor: Updated parameters after local training. Raises: RuntimeError: If no optimizer was provided during initialization. ValueError: If `regularization` is a non-scalar tensor but does not have the same number of elements as the flattened parameter vector. TypeError: If `regularization` is not a torch.Tensor or a callable returning a torch.Tensor. """ if self._optimizer is None: raise RuntimeError( "Local training is not available because no optimizer was provided." " Please call init_local_training to set up the optimizer before using local_training." ) self._set_model_parameters(x) if not self.model.training: self.model.train() for _ in range(iterations): batches = self._get_batch_data(indices) for batch_x, batch_y in batches: self._optimizer.zero_grad() outputs = self.model(batch_x) loss = self.loss_fn(outputs, batch_y) reg_value: torch.Tensor | None = None if regularization is not None: reg_value = ( regularization(torch.cat([p.flatten() for p in self._params_list]).to(self._pytorch_device)) if callable(regularization) else regularization ) if not isinstance(reg_value, torch.Tensor): raise TypeError( "`regularization` must be a torch.Tensor or a callable returning a torch.Tensor, " f"got {type(reg_value)}." ) reg_value = reg_value.to( device=self._pytorch_device, dtype=self._params_list[0].dtype, ) # Two supported forms: # 1) Scalar tensor: interpreted as an additive loss penalty. # 2) Flat vector of model-parameter shape: interpreted as an additive parameter update step. # This is useful for algorithms that add a correction term directly in parameter space. if reg_value.ndim == 0 or reg_value.numel() == 1: loss += reg_value.reshape(()) elif reg_value.numel() != self.total_params: raise ValueError( "If `regularization` is non-scalar, it must have the same number of elements as the " f"flattened parameter vector (expected {self.total_params}, got {reg_value.numel()})." ) loss.backward() self._optimizer.step() if reg_value is not None and reg_value.ndim != 0 and reg_value.numel() != 1: # Apply parameter-space correction step: x <- x - reg_value reg_value = reg_value.detach().flatten() reg_chunks = torch.split(reg_value, self.param_sizes) with torch.no_grad(): for param, chunk, shape in zip( self._params_list, reg_chunks, self.param_shapes, strict=True, ): param.data.sub_(chunk.view(shape)) # Since we are not calling the gradient method, # we need to manually update the agent's gradient call count for benchmarking purposes. agent._n_gradient_calls += len(self.batch_used) # noqa: SLF001 if self._scheduler is not None: self._scheduler.step() return self._get_model_parameters()
[docs] @override def _sample_batch_indices(self, indices: EmpiricalRiskIndices = "batch") -> list[int]: """Not used in PyTorchCost, implemented in _get_batch_data.""" raise NotImplementedError("_sample_batch_indices is not used in PyTorchCost, implemented in _get_batch_data.")
def _collate_xy_idx( self, batch: list[tuple[torch.Tensor, torch.Tensor, int]], ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: xs, ys, idx = zip(*batch, strict=True) return torch.stack(xs), torch.stack(ys), list(idx) def _init_dataloader(self) -> torch.utils.data.DataLoader[Any]: self._dataloader_kwargs.setdefault("shuffle", True) return torch.utils.data.DataLoader( cast("torch.utils.data.Dataset[Any]", self._dataset), batch_size=self.batch_size, generator=iop.rng_torch(SupportedDevices.CPU), # DataLoader shuffling must be done on CPU collate_fn=self._collate_xy_idx, **self._dataloader_kwargs, ) def _handle_dataloader(self) -> tuple[torch.Tensor, torch.Tensor, list[int]]: if self._dataloader is None: self._dataloader = self._init_dataloader() self._dataloader_iter = iter(self._dataloader) try: batch_x, batch_y, batch_idx = next(self._dataloader_iter) except StopIteration: # Restart the iterator if we reach the end self._dataloader_iter = iter(self._dataloader) batch_x, batch_y, batch_idx = next(self._dataloader_iter) return batch_x, batch_y, batch_idx
[docs] def _get_batch_data(self, indices: EmpiricalRiskIndices = "batch") -> list[tuple[torch.Tensor, torch.Tensor]]: """ Get a list of batch data for the specified indices, each list item contains a tuple of (batch_x, batch_y). The max size of each batch is determined by self._max_batch_size. Raises: RuntimeError: If batch data could not be retrieved, which should not happen under normal circumstances """ batch_x, batch_y, batch_idx = self._handle_indices(indices) if batch_x is None or batch_y is None or batch_idx is None: raise RuntimeError("Batch data could not be retrieved. Please report this error.") self._last_batch_used = list(batch_idx) self._last_batch_x = batch_x.to(self._pytorch_device) self._last_batch_y = batch_y.to(self._pytorch_device) batches = [] for i in range(0, batch_x.shape[0], self._max_batch_size): batch_x_chunk = self._last_batch_x[i : i + self._max_batch_size] batch_y_chunk = self._last_batch_y[i : i + self._max_batch_size] batches.append((batch_x_chunk, batch_y_chunk)) return batches
def _handle_indices( self, indices: EmpiricalRiskIndices, ) -> tuple[torch.Tensor | None, torch.Tensor | None, list[int] | None]: batch_x: torch.Tensor | None = None batch_y: torch.Tensor | None = None batch_idx: list[int] | None = None if isinstance(indices, str): if indices == "batch": if self.batch_size < self.n_samples: if self._use_dataloader: batch_x, batch_y, batch_idx = self._handle_dataloader() else: if len(self._remaining_batch_indices) < self.batch_size: # Refill remaining indices if not enough left for a full batch self._remaining_batch_indices = torch.randperm(self.n_samples).tolist() indices = self._remaining_batch_indices[: self.batch_size] self._remaining_batch_indices = self._remaining_batch_indices[self.batch_size :] else: # Use full dataset indices = list(range(self.n_samples)) elif indices == "all": indices = list(range(self.n_samples)) else: raise ValueError(f"Invalid indices string: {indices}. Only 'all' and 'batch' are supported.") elif isinstance(indices, int): indices = [indices] if isinstance(indices, list): if len(indices) == len(self.batch_used) and indices == self.batch_used and hasattr(self, "_last_batch_x"): # Use cached batch so we don't have to re-stack return self._last_batch_x, self._last_batch_y, self._last_batch_used batch = ( self._dataset if len(indices) == len(self._dataset) and self._load_dataset else [self._dataset[i] for i in indices] ) list_batch_x, list_batch_y, batch_idx = zip(*batch, strict=True) # type: ignore[misc, assignment] batch_x = torch.stack(list_batch_x) batch_y = torch.stack(list_batch_y) return batch_x, batch_y, batch_idx def _init_per_sample_grad(self) -> None: """Initialize per-sample gradient function using functorch.""" if hasattr(self, "_ft_compute_sample_grad"): return # Already initialized def compute_loss( params: dict[str, torch.Tensor], buffers: dict[str, torch.Tensor], sample: torch.Tensor, target: torch.Tensor, ) -> torch.Tensor: batch = sample.unsqueeze(0) targets = target.unsqueeze(0) predictions = torch.func.functional_call(self.model, (params, buffers), (batch,)) loss: torch.Tensor = self.loss_fn(predictions, targets) return loss self._ft_compute_grad = torch.func.grad(compute_loss) self._ft_compute_sample_grad = torch.func.vmap(self._ft_compute_grad, in_dims=(None, None, 0, 0)) if not self._compile_model: return try: self._ft_compute_sample_grad = torch.compile(self._ft_compute_sample_grad, **self._compile_kwargs) except Exception as e: LOGGER.warning(f"Error compiling per-sample gradient function: {e}\n\nContinuing without compilation.") def _init_param_caches(self) -> None: """Initialize parameter caches for flattening/unflattening.""" # Store parameter shapes for flattening/unflattening self._params_list = list(self.model.parameters()) # Cache for faster access self.param_shapes = [p.shape for p in self._params_list] self.param_sizes = [p.numel() for p in self._params_list] self.total_params = sum(self.param_sizes) self.param_names = [n for n, _ in self.model.named_parameters()]
[docs] def __add__(self, other: Cost) -> Cost: return super().__add__(other)