# Source code for decent_bench.metrics._metric
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING
from decent_bench.metrics._metrics_view import NetworkMetricsView
if TYPE_CHECKING:
from decent_bench.benchmark import BenchmarkProblem
# Type alias: a callable that takes a sequence of floats and returns a single float.
Statistic = Callable[[Sequence[float]], float]
[docs]
class Metric(ABC):
    """
    Base class from which every metric derives.

    A concrete metric is obtained by subclassing this class and providing the abstract members
    :attr:`description` and :meth:`compute`. Metrics with preconditions may additionally
    override :meth:`is_available`.

    Args:
        fmt: format string applied to the values shown in the table, ".2e" unless specified.
            Frequently used formats:

            - ".2e": scientific notation with 2 decimal places
            - ".3f": fixed-point notation with 3 decimal places
            - ".4g": general format with 4 significant digits
            - ".1%": percentage format with 1 decimal place

            In each case the integer sets the precision. The complete list of options is
            covered by the :meth:`str.format` documentation.
        x_log: enables log scaling of the x-axis in plots, disabled unless specified.
        y_log: enables log scaling of the y-axis in plots, enabled unless specified.

    """

    def __init__(
        self,
        fmt: str = ".2e",
        x_log: bool = False,
        y_log: bool = True,
    ) -> None:
        # Presentation settings are stored exactly as given.
        self.x_log, self.y_log, self.fmt = x_log, y_log, fmt

    @property
    @abstractmethod
    def description(self) -> str:
        """Text describing the metric; serves as the table row label and as the y-axis label in plots."""

    def is_available(
        self,
        problem: "BenchmarkProblem",  # noqa: ARG002
    ) -> tuple[bool, str | None]:
        """
        Report whether the metric is computable for ``problem``.

        This base implementation has no preconditions and therefore always reports the metric
        as available. Subclasses that do have preconditions (for example needing
        ``problem.x_optimal`` or ``problem.test_data``) should override it.

        Args:
            problem: the benchmark problem being evaluated

        Returns:
            A pair ``(available, reason)``. *reason* is ``None`` whenever *available* is
            ``True``; otherwise it holds a human-readable explanation of why the metric
            cannot be computed.

        """
        # No preconditions at this level, hence no reason to give.
        return (True, None)

    @abstractmethod
    def compute(
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",
        iteration: int,
    ) -> Sequence[float]:
        """
        Compute the metric from the results of a trial.

        Args:
            network: the snapshotted network view being evaluated
            problem: the benchmark problem being evaluated
            iteration: the iteration to evaluate at; pass -1 to use the agents' final x

        Returns:
            a sequence of metric values

        """