Source code for decent_bench.metrics._runtime_metric
import contextlib
from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any
from decent_bench.agents import Agent
if TYPE_CHECKING:
import queue
from decent_bench.benchmark import BenchmarkProblem
[docs]
class RuntimeMetric(ABC):
"""
Abstract base class for runtime metrics.
Runtime metrics are computed during algorithm execution to provide live feedback
for early stopping or monitoring. Unlike post-hoc metrics, they don't store historical
data and are designed to be lightweight.
To create a new runtime metric, subclass this class and implement :meth:`description`,
:meth:`x_log`, :meth:`y_log`, and :meth:`compute`.
Args:
update_interval: Number of iterations between metric updates, do not update more frequently than necessary as
this can slow down the algorithm.
save_path: Path to save the plot when the metric is updated, if None, the plot will not be saved
Note:
The :meth:`compute` method should be efficient as it's called during algorithm execution.
Avoid expensive computations or operations that might significantly slow down
the algorithm.
"""
def __init__(self, update_interval: int, save_path: str | Path | None = None) -> None:
"""
Initialize runtime metric.
Args:
update_interval: Number of iterations between metric updates, do not update more frequently than necessary
as this can slow down the algorithm.
save_path: Path to save the plot when the metric is updated, if None, the plot will not be saved
"""
self._update_interval = update_interval
self._save_path = Path(save_path) if save_path is not None else None
self._queue: queue.Queue[Any] | None = None
self._metric_id: str = ""
self._algorithm_name: str = ""
self._trial: int = 0
@property
@abstractmethod
def description(self) -> str:
"""Description of the metric, used as the y-axis label."""
@property
@abstractmethod
def x_log(self) -> bool:
"""Whether the x-axis should be logarithmic."""
@property
@abstractmethod
def y_log(self) -> bool:
"""Whether the y-axis should be logarithmic."""
@property
def update_interval(self) -> int:
"""
Number of iterations between metric updates.
Returns:
Number of iterations between updates.
"""
return self._update_interval
[docs]
@abstractmethod
def compute(self, problem: "BenchmarkProblem", agents: Sequence["Agent"], iteration: int) -> float:
"""
Compute the metric value for the current iteration.
Args:
problem: benchmark problem being solved
agents: sequence of agents with their current state
iteration: current iteration number
Returns:
The computed metric value as a float.
"""
[docs]
def initialize_plot(self, algorithm_name: str, trial: int, queue: "queue.Queue[Any]") -> None:
"""
Initialize the plot for this metric.
Sends initialization message to plotter process to create the figure.
Args:
algorithm_name: name of the algorithm being run
trial: trial number (0-indexed)
queue: multiprocessing queue for sending data to the plotter
"""
self._algorithm_name = algorithm_name
self._trial = trial
self._queue = queue
# Use class name as metric_id so all instances of the same metric type share one figure
self._metric_id = self.__class__.__name__
# Send initialization message to plotter process to create figure
# The plotter will handle deduplication (won't create duplicate figures)
if self._queue is not None:
with contextlib.suppress(Exception):
self._queue.put(
("init", self._metric_id, self.description, self.x_log, self.y_log, self._save_path), block=False
)
[docs]
def update_plot(self, problem: "BenchmarkProblem", agents: Sequence["Agent"], iteration: int) -> None:
"""
Update the plot with a new data point.
Computes the metric value and sends it to the centralized plotter via queue.
Args:
problem: benchmark problem being solved
agents: sequence of agents with their current state
iteration: current iteration number
"""
# Compute metric value without counting calls made by the metric
with Agent.no_count(agents):
value = self.compute(problem, agents, iteration)
# Send data to plotter process queue
if self._queue is not None:
with contextlib.suppress(Exception):
self._queue.put((self._metric_id, self._algorithm_name, self._trial, iteration, value), block=False)
[docs]
def should_update(self, iteration: int) -> bool:
"""
Check if the metric should be updated at this iteration.
Args:
iteration: current iteration number
Returns:
True if the metric should be updated, False otherwise.
"""
return iteration % self.update_interval == 0 or iteration == 0