decent_bench.utils.pytorch_utils

class decent_bench.utils.pytorch_utils.SimpleLinearModel(input_size: int, hidden_sizes: list[int], output_size: int, activation: Literal['relu', 'tanh', 'sigmoid'] | None = 'relu', output_activation: Literal['relu', 'tanh', 'sigmoid'] | None = None)

Bases: torch.nn.Module

Simple feedforward neural network model with linear layers and optional activations.

Parameters:
  • input_size (int) – The size of the input features.

  • hidden_sizes (list[int]) – A list of sizes for the hidden layers.

  • output_size (int) – The size of the output layer.

  • activation (Literal["relu", "tanh", "sigmoid"] | None) – The activation applied after each hidden layer (default "relu"); None means no hidden activation.

  • output_activation (Literal["relu", "tanh", "sigmoid"] | None) – The activation applied after the output layer; defaults to None, i.e. no output activation.

Raises:

ImportError – If PyTorch is not installed.

forward(x: torch.Tensor) → torch.Tensor

Forward pass through the network.
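The reference above does not spell out how the layers are assembled. The following is a minimal sketch of one plausible equivalent, assuming each hidden layer is a `torch.nn.Linear` followed by the chosen activation, and the output layer is a plain `Linear` followed by the optional output activation. `build_feedforward` and `_ACTIVATIONS` are hypothetical names, not part of `decent_bench`, and the real class may differ internally.

```python
from typing import Literal, Optional

import torch
from torch import nn

ActivationName = Optional[Literal["relu", "tanh", "sigmoid"]]

# Hypothetical mapping from activation names to torch.nn modules.
_ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid}


def build_feedforward(
    input_size: int,
    hidden_sizes: list[int],
    output_size: int,
    activation: ActivationName = "relu",
    output_activation: ActivationName = None,
) -> nn.Sequential:
    """Hypothetical stand-in mirroring SimpleLinearModel's parameters."""
    layers: list[nn.Module] = []
    in_features = input_size
    for width in hidden_sizes:
        layers.append(nn.Linear(in_features, width))
        if activation is not None:  # assumed: None skips the hidden activation
            layers.append(_ACTIVATIONS[activation]())
        in_features = width
    # Output layer, then the optional output activation.
    layers.append(nn.Linear(in_features, output_size))
    if output_activation is not None:
        layers.append(_ACTIVATIONS[output_activation]())
    return nn.Sequential(*layers)


# Usage: a batch of 8 samples with 4 features each, mapped to 3 outputs.
model = build_feedforward(4, [16, 8], 3, activation="tanh", output_activation="sigmoid")
out = model(torch.randn(8, 4))
print(out.shape)  # torch.Size([8, 3])
```

The input is expected to have shape `(batch, input_size)` and the output has shape `(batch, output_size)`; with an empty `hidden_sizes` list the sketch reduces to a single linear map.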

class decent_bench.utils.pytorch_utils.ArgmaxActivation(dim: int = 1)

Bases: torch.nn.Module

Applies argmax as an activation, returning the indices of the maximum values along dim.

Parameters:

dim (int) – The dimension along which to compute the argmax (default: 1).

Raises:

ImportError – If PyTorch is not installed.

forward(x: torch.Tensor) → torch.Tensor

Forward pass applying argmax.
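For intuition, here is a minimal sketch of an argmax activation module, assuming it wraps `torch.argmax` along `dim`. `ArgmaxSketch` is a hypothetical name; the actual `ArgmaxActivation` may differ in details.

```python
import torch
from torch import nn


class ArgmaxSketch(nn.Module):
    """Hypothetical stand-in: returns the index of the max along `dim`."""

    def __init__(self, dim: int = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.argmax returns an int64 tensor of indices with `dim` removed.
        return torch.argmax(x, dim=self.dim)


logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.5, 0.2]])
print(ArgmaxSketch()(logits))       # tensor([1, 0])
print(ArgmaxSketch(dim=0)(logits))  # tensor([1, 0, 1])
```

Because argmax returns integer indices, its output carries no gradient, so a layer like this suits inference (for example, turning logits into class labels) rather than sitting inside a model trained end to end with a gradient-based loss.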

class decent_bench.utils.pytorch_utils.ArgminActivation(dim: int = 1)

Bases: torch.nn.Module

Applies argmin as an activation, returning the indices of the minimum values along dim.

Parameters:

dim (int) – The dimension along which to compute the argmin (default: 1).

Raises:

ImportError – If PyTorch is not installed.

forward(x: torch.Tensor) → torch.Tensor

Forward pass applying argmin.
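A typical use of an argmin activation is picking the closest candidate from a distance matrix. The sketch below assumes the module wraps `torch.argmin` along `dim`; `ArgminSketch` is a hypothetical name, and the nearest-centroid setup is purely illustrative.

```python
import torch
from torch import nn


class ArgminSketch(nn.Module):
    """Hypothetical stand-in: returns the index of the min along `dim`."""

    def __init__(self, dim: int = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.argmin returns an int64 tensor of indices with `dim` removed.
        return torch.argmin(x, dim=self.dim)


# Illustrative example: assign each point to its nearest centroid.
points = torch.tensor([[0.0, 0.0], [5.0, 5.0], [0.9, 1.2]])
centroids = torch.tensor([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
distances = torch.cdist(points, centroids)  # shape (3 points, 3 centroids)
print(ArgminSketch(dim=1)(distances))  # tensor([0, 2, 1])
```

With the default `dim=1`, each row of the distance matrix is reduced to the column index of its smallest entry, i.e. one centroid label per point.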