Source code for decent_bench.utils.pytorch_utils

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

TORCH_AVAILABLE = True

if TYPE_CHECKING:
    import torch
else:
    try:
        import torch
    except ImportError:
        TORCH_AVAILABLE = False
        # Create mock torch module for documentation
        from types import ModuleType

        torch = ModuleType("torch")  # type: ignore[assignment]
        nn = ModuleType("torch.nn")

        class _MockModule:
            """Mock base class for PyTorch nn.Module."""

        _MockModule.__module__ = "torch.nn"
        _MockModule.__qualname__ = "Module"
        _MockModule.__name__ = "Module"

        nn.Module = _MockModule  # type: ignore[attr-defined]
        torch.nn = nn  # type: ignore[attr-defined]


[docs] class SimpleLinearModel(torch.nn.Module): """ Simple feedforward neural network model with linear layers and optional activations. Args: input_size (int): The size of the input features. hidden_sizes (list[int]): A list of sizes for the hidden layers. output_size (int): The size of the output layer. activation (Literal["relu", "tanh", "sigmoid"] | None): The activation function to use for hidden layers. output_activation (Literal["relu", "tanh", "sigmoid"] | None): The final activation after the output layer. Raises: ImportError: If PyTorch is not installed. """ def __init__( self, input_size: int, hidden_sizes: list[int], output_size: int, activation: Literal["relu", "tanh", "sigmoid"] | None = "relu", output_activation: Literal["relu", "tanh", "sigmoid"] | None = None, ): if not TORCH_AVAILABLE: raise ImportError("PyTorch must be installed to use SimpleLinearModel") super().__init__() layers: list[torch.nn.Module] = [] prev_size = input_size # Hidden layers for hidden_size in hidden_sizes: layers.append(torch.nn.Linear(prev_size, hidden_size)) if activation == "relu": layers.append(torch.nn.ReLU()) elif activation == "tanh": layers.append(torch.nn.Tanh()) elif activation == "sigmoid": layers.append(torch.nn.Sigmoid()) prev_size = hidden_size # Output layer layers.append(torch.nn.Linear(prev_size, output_size)) if output_activation is not None: if output_activation == "relu": layers.append(torch.nn.ReLU()) elif output_activation == "tanh": layers.append(torch.nn.Tanh()) elif output_activation == "sigmoid": layers.append(torch.nn.Sigmoid()) self.network = torch.nn.Sequential(*layers)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the network.""" res: torch.Tensor = self.network(x) return res
[docs] class ArgmaxActivation(torch.nn.Module): """ Applies the argmax function as an activation. Args: dim (int): The dimension along which to compute the argmax. Raises: ImportError: If PyTorch is not installed. """ def __init__(self, dim: int = 1): if not TORCH_AVAILABLE: raise ImportError("PyTorch must be installed to use ArgmaxActivation") super().__init__() self.dim = dim
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass applying argmax.""" return torch.argmax(x, dim=self.dim)
[docs] class ArgminActivation(torch.nn.Module): """ Applies the argmin function as an activation. Args: dim (int): The dimension along which to compute the argmin. Raises: ImportError: If PyTorch is not installed. """ def __init__(self, dim: int = 1): if not TORCH_AVAILABLE: raise ImportError("PyTorch must be installed to use ArgminActivation") super().__init__() self.dim = dim
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass applying argmin.""" return torch.argmin(x, dim=self.dim)