import logging
from json import JSONDecodeError
from typing import TYPE_CHECKING, Literal
from rich.status import Status
from decent_bench.algorithms import Algorithm
from decent_bench.benchmark._benchmark_result import BenchmarkResult
from decent_bench.benchmark._compute.compute_plots import (
aggregate_plot_metrics,
compute_plot_metrics,
)
from decent_bench.benchmark._compute.compute_tables import aggregate_table_metrics, compute_table_metrics
from decent_bench.benchmark._metric_result import MetricResult
from decent_bench.metrics import Metric, utils
from decent_bench.metrics import metric_library as ml
from decent_bench.metrics._metrics_view import NetworkMetricsView
from decent_bench.networks import Network
from decent_bench.utils._metric_helpers import _find_duplicates
from decent_bench.utils.logger import LOGGER, start_logger
if TYPE_CHECKING:
from decent_bench.benchmark import BenchmarkProblem
if TYPE_CHECKING:
from decent_bench.benchmark import BenchmarkProblem
from decent_bench.utils.checkpoint_manager import CheckpointManager
[docs]
def compute_metrics(
    benchmark_result: BenchmarkResult | None = None,
    checkpoint_manager: "CheckpointManager | None" = None,
    *,
    table_metrics: list[Metric] | None = None,
    plot_metrics: list[Metric] | None = None,
    statistics_across_agents: list[str] | None = None,
    log_level: int = logging.INFO,
) -> MetricResult:
    """
    Compute metrics from a benchmark result.

    Args:
        benchmark_result: result of a benchmark execution. If not provided, the result will be loaded from the
            checkpoint manager
        checkpoint_manager: if provided, will be used to save results of metrics computation and/or load benchmark
            result.
        table_metrics: metrics to be displayed in a table of results. Table metrics are computed only at the
            *most recent* iteration reached during benchmarking. If ``None``, all table
            metrics available for the benchmark problem will be used. For example, federated-only metrics are removed
            when a non-federated network is passed.
        plot_metrics: metrics to be plotted over algorithm iterations. Plot metrics are computed at *all* the
            iterations reached during benchmarking. If ``None``, all plot metrics available for the
            benchmark problem will be used.
        statistics_across_agents: statistics to compute across agents for metrics that return one value per agent
            (like ``ConsensusError`` or ``Accuracy``). Available statistics are "mean" (aliases "average", "avg"),
            "std", "max" (alias "maximum"), "min" (alias "minimum"), and "median" (alias "mdn"). If ``None``, "mean"
            and "std" are used.
        log_level: minimum level to log, e.g. :data:`logging.INFO`

    Returns:
        MetricResult containing the computed metrics.

    Raises:
        ValueError: If neither ``benchmark_result`` nor ``checkpoint_manager`` is provided, or
            if the checkpoint manager does not contain a valid benchmark result to load.
        ValueError: If duplicate metrics (i.e. with same ``description``) are provided
            in ``table_metrics`` or ``plot_metrics``.

    Note:
        If ``benchmark_result`` is not provided, it will be loaded from the checkpoint manager. If both are provided,
        then the results from the provided ``benchmark_result`` will be used and the checkpoint manager will only be
        used to save the computed metrics result. If neither is provided, an error will be raised.

        All used table- and plot-metrics will be saved to the checkpoints' metadata if a checkpoint manager is provided,
        in order to know which metrics were computed and can be displayed later.

        Metrics that return ``False`` from :meth:`~decent_bench.metrics.Metric.is_available` for the given problem are
        filtered out from the returned metric lists. Warnings are emitted with the omitted metric names.

        Plot metrics can still be available even when their final table value is ``inf/nan``: plot computation keeps the
        finite part of a trajectory, while table metrics are evaluated at the final iteration.

        The metric utility caches are cleared before this function exits, whether it returns normally or raises
        during availability filtering, computation, aggregation or saving.

    """
    start_logger(log_level=log_level)
    LOGGER.info("Starting metrics computation")

    # 1) user input validation
    if benchmark_result is None:
        if checkpoint_manager is None:
            raise ValueError(
                "If ``benchmark_result`` is not provided, ``checkpoint_manager`` must be provided "
                "to load the benchmark result from."
            )
        try:
            benchmark_result = checkpoint_manager.load_benchmark_result()
        except (FileNotFoundError, KeyError) as e:
            raise ValueError(f"Invalid checkpoint directory: missing or corrupted metadata - {e}") from e
        except JSONDecodeError as e:
            raise ValueError(f"Invalid checkpoint directory: metadata is not valid JSON - {e}") from e
        # NOTE(review): this check is assumed to apply only to results loaded from the checkpoint manager,
        # as its message and the ``Raises`` section suggest - confirm against the original indentation.
        if not benchmark_result.states:
            raise ValueError("No benchmark result found in checkpoint manager to compute metrics")

    table_metrics, plot_metrics = _resolve_default_metrics(table_metrics, plot_metrics)

    # check metrics are unique
    _validate_unique_descriptions(table_metrics, "table")
    _validate_unique_descriptions(plot_metrics, "plot")

    try:
        # remove unavailable metrics
        table_metrics = _remove_unavailable(table_metrics, benchmark_result.problem, "table")
        plot_metrics = _remove_unavailable(plot_metrics, benchmark_result.problem, "plot")

        # 2) compute table and plot metrics
        network_views: dict[Algorithm[Network], list[NetworkMetricsView]] = {
            alg: [NetworkMetricsView.from_network(nw) for nw in networks]
            for alg, networks in benchmark_result.states.items()
        }
        iterations = _all_iterations(network_views)

        # compute metrics; the raw plot results are handed to the table computation as well
        raw_plot_results = compute_plot_metrics(network_views, benchmark_result.problem, plot_metrics, iterations)
        raw_table_results = compute_table_metrics(
            network_views, benchmark_result.problem, table_metrics, iterations, raw_plot_results
        )

        # aggregate metrics
        aggregated_plot_metrics = aggregate_plot_metrics(raw_plot_results)
        aggregated_table_metrics = aggregate_table_metrics(raw_table_results, statistics_across_agents)

        # 3) create MetricResult
        result = MetricResult(
            network_views=network_views,
            raw_table_results=raw_table_results,
            raw_plot_results=raw_plot_results,
            table_results=aggregated_table_metrics,
            plot_results=aggregated_plot_metrics,
        )

        if checkpoint_manager is not None:
            with Status("Saving computed metrics..."):
                # Record which metrics were actually computed (after availability filtering).
                metadata = {
                    "table_metrics": [metric.description for metric in table_metrics],
                    "plot_metrics": [metric.description for metric in plot_metrics],
                }
                checkpoint_manager.save_metrics_result(result)
                checkpoint_manager.append_metadata(metadata)
    finally:
        # Always drop cached intermediate values so a failed run cannot leak them into the next call.
        utils._clear_caches()  # noqa: SLF001

    return result
def _resolve_default_metrics(
    table_metrics: list[Metric] | None,
    plot_metrics: list[Metric] | None,
) -> tuple[list[Metric], list[Metric]]:
    """
    Replace any metric list that is ``None`` with the corresponding library default.

    Args:
        table_metrics: user-supplied table metrics, or ``None`` to use the default table metrics
        plot_metrics: user-supplied plot metrics, or ``None`` to use the default plot metrics

    Returns:
        The ``(table_metrics, plot_metrics)`` pair with ``None`` entries substituted. Lists are passed through
        as-is (no copy is made), including the module-level default lists.

    """
    resolved_table = ml._DEFAULT_TABLE_METRICS if table_metrics is None else table_metrics  # noqa: SLF001
    resolved_plot = ml._DEFAULT_PLOT_METRICS if plot_metrics is None else plot_metrics  # noqa: SLF001
    return resolved_table, resolved_plot
def _validate_unique_descriptions(metrics: list[Metric], type_: Literal["table", "plot"]) -> None:
    """
    Ensure that no two metrics in ``metrics`` share the same ``description``.

    Args:
        metrics: metrics whose descriptions are checked
        type_: kind of metric list being checked; only used to word the error message

    Raises:
        ValueError: if at least one description occurs more than once

    """
    descriptions = [metric.description for metric in metrics]
    repeated = _find_duplicates(descriptions)
    if not repeated:
        return
    duplicates = ", ".join(repeated)
    raise ValueError(f"{type_.capitalize()} metric descriptions must be unique, duplicates found: {duplicates}")
def _remove_unavailable(
    metrics: list[Metric], problem: "BenchmarkProblem", type_: Literal["table", "plot"]
) -> list[Metric]:
    """
    Keep only the metrics that report themselves as available for ``problem``.

    Args:
        metrics: candidate metrics, checked in order
        problem: benchmark problem passed to each metric's ``is_available``
        type_: kind of metric list being filtered; only used to word the warning

    Returns:
        A new list with the available metrics in their original order. A warning is logged for every
        metric that is dropped, including the reason returned by ``is_available``.

    """
    kept: list[Metric] = []
    for metric in metrics:
        is_usable, reason = metric.is_available(problem)
        if is_usable:
            kept.append(metric)
        else:
            LOGGER.warning(f"Skipping {type_} metric '{metric.description}' because it is unavailable: {reason}")
    return kept
def _all_iterations(network_views: dict[Algorithm[Network], list[NetworkMetricsView]]) -> list[int]:
    """
    Find all the iterations that were reached in at least one trial by at least one algorithm.

    Args:
        network_views: per-algorithm list of network views, one view per trial

    Returns:
        The distinct values found in the views' ``iterations``, sorted in ascending order.

    """
    reached: set[int] = set()
    for views_by_trial in network_views.values():
        for view in views_by_trial:
            # A set collapses iterations shared between trials and algorithms.
            reached.update(view.iterations)
    return sorted(reached)