decent_bench.utils.pytorch_utils
- class decent_bench.utils.pytorch_utils.SimpleLinearModel(input_size: int, hidden_sizes: list[int], output_size: int, activation: Literal['relu', 'tanh', 'sigmoid'] | None = 'relu', output_activation: Literal['relu', 'tanh', 'sigmoid'] | None = None)
Bases: torch.nn.Module
Simple feedforward neural network model with linear layers and optional activations.
- Parameters:
input_size (int) – The size of the input features.
hidden_sizes (list[int]) – A list of sizes for the hidden layers.
output_size (int) – The size of the output layer.
activation (Literal["relu", "tanh", "sigmoid"] | None) – The activation function applied to the hidden layers. Defaults to "relu".
output_activation (Literal["relu", "tanh", "sigmoid"] | None) – The activation function applied after the output layer. Defaults to None.
- Raises:
ImportError – If PyTorch is not installed.
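To show how `input_size`, `hidden_sizes`, and `output_size` fit together, the sketch below lists the `(in_features, out_features)` shape of each linear layer in a standard feedforward stack. It assumes `SimpleLinearModel` chains one linear layer per entry of `hidden_sizes` followed by one output layer; `linear_layer_shapes` is a hypothetical helper, not part of `decent_bench`.

```python
def linear_layer_shapes(
    input_size: int, hidden_sizes: list[int], output_size: int
) -> list[tuple[int, int]]:
    """Hypothetical helper: (in, out) widths of each linear layer in a plain MLP.

    Assumes one linear layer per hidden size, followed by one output layer.
    """
    sizes = [input_size, *hidden_sizes, output_size]
    # Consecutive pairs of widths give each layer's input and output size.
    return list(zip(sizes[:-1], sizes[1:]))


print(linear_layer_shapes(4, [16, 8], 3))  # [(4, 16), (16, 8), (8, 3)]
print(linear_layer_shapes(4, [], 3))       # [(4, 3)] in this sketch: a single linear layer
```

Under this assumption, a model with `k` hidden sizes has `k + 1` linear layers.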
- forward(x: torch.Tensor) → torch.Tensor
Forward pass through the network.
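The sketch below is a hypothetical plain-PyTorch stand-in (`SketchMLP`) that mirrors the documented constructor arguments, to show what a forward pass through such a network typically does. It assumes the hidden activation is applied after every hidden linear layer and the output activation after the final one; the actual `SimpleLinearModel` internals are not shown on this page and may differ.

```python
from __future__ import annotations

from typing import Literal

import torch
from torch import nn

ActName = Literal["relu", "tanh", "sigmoid"]

# Map the documented activation names to torch.nn modules.
_ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid}


class SketchMLP(nn.Module):
    """Hypothetical stand-in mirroring SimpleLinearModel's documented signature."""

    def __init__(
        self,
        input_size: int,
        hidden_sizes: list[int],
        output_size: int,
        activation: ActName | None = "relu",
        output_activation: ActName | None = None,
    ) -> None:
        super().__init__()
        sizes = [input_size, *hidden_sizes, output_size]
        self.linears = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])
        )
        # Assumption: the same activation type follows every hidden layer.
        self.hidden_act = _ACTIVATIONS[activation]() if activation else nn.Identity()
        self.output_act = (
            _ACTIVATIONS[output_activation]() if output_activation else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Hidden layers: linear map followed by the hidden activation.
        for linear in self.linears[:-1]:
            x = self.hidden_act(linear(x))
        # Output layer, then the optional output activation.
        return self.output_act(self.linears[-1](x))


torch.manual_seed(0)
model = SketchMLP(4, [16, 8], 3, output_activation="sigmoid")
y = model(torch.randn(5, 4))  # calling the module invokes forward()
print(y.shape)  # torch.Size([5, 3])
```

Calling `model(x)` rather than `model.forward(x)` is the usual PyTorch idiom, because `nn.Module.__call__` also runs any registered hooks.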
- class decent_bench.utils.pytorch_utils.ArgmaxActivation(dim: int = 1)
Bases: torch.nn.Module
Applies the argmax function as an activation.
- Parameters:
dim (int) – The dimension along which to compute the argmax.
- Raises:
ImportError – If PyTorch is not installed.
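A minimal sketch of an argmax activation module, assuming it simply wraps `torch.argmax` along the configured `dim`. `SketchArgmax` is a hypothetical name used for illustration, not the `decent_bench` class itself.

```python
import torch
from torch import nn


class SketchArgmax(nn.Module):
    """Hypothetical stand-in for ArgmaxActivation: index of the largest entry along `dim`."""

    def __init__(self, dim: int = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.argmax(x, dim=self.dim)


scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.9, 0.05, 0.05]])  # shape (batch=2, classes=3)
print(SketchArgmax()(scores))       # tensor([1, 0])     one class index per row
print(SketchArgmax(dim=0)(scores))  # tensor([1, 0, 0])  one row index per column
```

Because `nn.Module` subclasses compose, a module like this can be appended to a classifier, for example `nn.Sequential(classifier, SketchArgmax())`, to turn per-class scores into class indices at inference time.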
- forward(x: torch.Tensor) → torch.Tensor
Forward pass applying argmax.
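Because argmax returns integer indices, its output carries no gradient, so an argmax activation is normally placed at the end of a network for inference rather than inside a training loss. The check below uses plain `torch.argmax`, on the assumption (not confirmed on this page) that `ArgmaxActivation.forward` delegates to it.

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5],
                       [0.0, 3.0, 1.0]], requires_grad=True)
preds = torch.argmax(logits, dim=1)

print(preds)                # tensor([0, 1])
print(preds.dtype)          # torch.int64
print(preds.requires_grad)  # False: integer indices are not differentiable
```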
- class decent_bench.utils.pytorch_utils.ArgminActivation(dim: int = 1)
Bases: torch.nn.Module
Applies the argmin function as an activation.
- Parameters:
dim (int) – The dimension along which to compute the argmin.
- Raises:
ImportError – If PyTorch is not installed.
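The `dim` argument decides which axis is reduced. For a batch shaped `(batch, features)`, the default `dim=1` yields one index per sample, while `dim=0` yields one index per feature column. The snippet below uses `torch.argmin` directly, assuming `ArgminActivation` delegates to it.

```python
import torch

x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.5, 4.0, 6.0]])  # shape (2, 3)

per_row = torch.argmin(x, dim=1)  # dim=1 (the documented default): one index per row
per_col = torch.argmin(x, dim=0)  # dim=0: one index per column

print(per_row)  # tensor([1, 0])
print(per_col)  # tensor([1, 0, 0])
```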
- forward(x: torch.Tensor) → torch.Tensor
Forward pass applying argmin.
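One natural use of an argmin activation is nearest-centroid assignment: compute distances from each point to a set of centroids, then take the argmin over the centroid axis. The sketch below uses a hypothetical `SketchArgmin` module standing in for `ArgminActivation`, assumed to wrap `torch.argmin`; the centroids and points are made-up illustration data.

```python
import torch
from torch import nn


class SketchArgmin(nn.Module):
    """Hypothetical stand-in for ArgminActivation: index of the smallest entry along `dim`."""

    def __init__(self, dim: int = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.argmin(x, dim=self.dim)


centroids = torch.tensor([[0.0, 0.0],
                          [5.0, 5.0],
                          [0.0, 5.0]])
points = torch.tensor([[0.2, -0.1],
                       [4.8, 5.3],
                       [0.1, 4.6],
                       [3.0, 3.5]])

# Pairwise Euclidean distances, shape (num_points, num_centroids).
distances = torch.cdist(points, centroids)
# Argmin over dim=1 picks the nearest centroid for each point.
labels = SketchArgmin(dim=1)(distances)
print(labels)  # tensor([0, 1, 2, 1])
```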