"""Collection of pre-defined table and plot metrics."""
from typing import TYPE_CHECKING
import numpy as np
import decent_bench.utils.interoperability as iop
from decent_bench.costs import Cost, EmpiricalRiskCost
from decent_bench.metrics import utils
from decent_bench.metrics._metric import Metric
from decent_bench.metrics._metrics_view import NetworkMetricsView
from decent_bench.networks import FedNetwork
if TYPE_CHECKING:
from decent_bench.benchmark import BenchmarkProblem
[docs]
class Regret(Metric):
    r"""
    Global regret.

    Table:
        Global regret evaluated at the agents'/clients' final x.

    Plot:
        Global regret (y-axis) against iteration (x-axis).

    The global regret is defined as:

    .. include:: snippets/global_cost_error.rst

    Note:
        Only available when ``problem.x_optimal`` is provided.

    """

    description: str = "regret"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        # The metric needs the problem's optimal solution; without it the regret is undefined here.
        has_optimum = getattr(problem, "x_optimal", None) is not None
        return (True, None) if has_optimum else (False, "requires problem.x_optimal")

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        regret = utils._regret(network.agents(), problem, iteration)  # noqa: SLF001
        return [regret]
[docs]
class GradientNorm(Metric):
    r"""
    Global gradient norm.

    Table:
        Gradient norm evaluated at the agents'/clients' final x.

    Plot:
        Gradient norm (y-axis) against iteration (x-axis).

    The gradient norm is defined as:

    .. include:: snippets/global_gradient_optimality.rst

    """

    description: str = "gradient norm"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        # Only the agent views and the iteration index are passed on; the problem is not consulted.
        grad_norm = utils._gradient_norm(network.agents(), iteration)  # noqa: SLF001
        return [grad_norm]
[docs]
class XError(Metric):
    r"""
    Distance to the optimal solution.

    Table:
        Distance to the optimal solution, using the mean of the agents'/clients' final x.

    Plot:
        Distance to the optimal solution (y-axis) against iteration (x-axis).

    The x error is defined as:

    .. math::
        \|\mathbf{\bar{x}} - \mathbf{x}^\star\|

    where :math:`\mathbf{\bar{x}}` is the mean x across all agents/clients and
    :math:`\mathbf{x}^\star` is the optimal x defined in the *problem*.

    Note:
        Only available when ``problem.x_optimal`` is provided.

    """

    description: str = "x error"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        if getattr(problem, "x_optimal", None) is not None:
            return True, None
        # Without a reference optimum there is nothing to measure the distance to.
        return False, "requires problem.x_optimal"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        error = utils._x_error(network.agents(), problem, iteration)  # noqa: SLF001
        return [error]
[docs]
class ConsensusError(Metric):
    r"""
    Distance to consensus.

    Table:
        Distance of each agent's/client's state from the current average state.

    Plot:
        Distance to consensus (y-axis) against iteration (x-axis).

    The consensus error of each agent/client is defined as:

    .. math::
        \{ \|\mathbf{x}_i - \bar{\mathbf{x}}\|, \|\mathbf{x}_j - \bar{\mathbf{x}}\|, ... \}

    where :math:`\mathbf{x}_i` is agent/client i's current state, :math:`\bar{\mathbf{x}}` is the average of
    all agents'/clients' states, and :math:`\| \cdot \|` is the 2-norm.

    .. seealso::
        :class:`~decent_bench.metrics.runtime_library.RuntimeConsensusError` for the runtime version.

    """

    description: str = "consensus error"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        views = network.agents()
        # Average state over all agents at this iteration; every agent is compared against it.
        mean_x = utils.x_mean(tuple(views), iteration)
        return [float(iop.norm(mean_x - view.x_history[iteration])) for view in views]
[docs]
class XUpdates(Metric):
    r"""
    Number of x iterations/updates.

    Table:
        Number of x iterations/updates of each agent.

    Plot:
        Number of x iterations/updates (y-axis) against iteration (x-axis).
        The line is flat because the count is only read at the end of the trial,
        not tracked per iteration.

    """

    description: str = "nr x updates"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        __: int,
    ) -> list[int]:
        # The iteration argument is ignored, so the same final counters are reported for every iteration.
        views = network.agents()
        return [view.n_x_updates for view in views]
[docs]
class FunctionCalls(Metric):
    r"""
    Number of function calls.

    Table:
        Number of function calls made by each agent.

    Plot:
        Number of function calls (y-axis) against iteration (x-axis).
        The line is flat because the count is only read at the end of the trial,
        not tracked per iteration.

    Note:
        The count can be fractional when :class:`~decent_bench.costs.EmpiricalRiskCost` is used with a
        batch size smaller than the full dataset.

    """

    description: str = "nr function calls"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        __: int,
    ) -> list[float]:
        # The iteration argument is ignored, so the same final counters are reported for every iteration.
        views = network.agents()
        return [view.n_function_calls for view in views]
[docs]
class GradientCalls(Metric):
    r"""
    Number of gradient calls.

    Table:
        Number of gradient calls made by each agent.

    Plot:
        Number of gradient calls (y-axis) against iteration (x-axis).
        The line is flat because the count is only read at the end of the trial,
        not tracked per iteration.

    Note:
        The count can be fractional when :class:`~decent_bench.costs.EmpiricalRiskCost` is used with a
        batch size smaller than the full dataset.

    """

    description: str = "nr gradient calls"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        __: int,
    ) -> list[float]:
        # The iteration argument is ignored, so the same final counters are reported for every iteration.
        views = network.agents()
        return [view.n_gradient_calls for view in views]
[docs]
class HessianCalls(Metric):
    r"""
    Number of Hessian calls.

    Table:
        Number of Hessian calls made by each agent.

    Plot:
        Number of Hessian calls (y-axis) against iteration (x-axis).
        The line is flat because the count is only read at the end of the trial,
        not tracked per iteration.

    Note:
        The count can be fractional when :class:`~decent_bench.costs.EmpiricalRiskCost` is used with a
        batch size smaller than the full dataset.

    """

    description: str = "nr Hessian calls"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        __: int,
    ) -> list[float]:
        # The iteration argument is ignored, so the same final counters are reported for every iteration.
        views = network.agents()
        return [view.n_hessian_calls for view in views]
[docs]
class ProximalCalls(Metric):
    r"""
    Number of proximal calls.

    Table:
        Number of proximal calls made by each agent.

    Plot:
        Number of proximal calls (y-axis) against iteration (x-axis).
        The line is flat because the count is only read at the end of the trial,
        not tracked per iteration.

    """

    description: str = "nr proximal calls"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        __: int,
    ) -> list[float]:
        # The iteration argument is ignored, so the same final counters are reported for every iteration.
        views = network.agents()
        return [view.n_proximal_calls for view in views]
[docs]
class SentMessages(Metric):
    r"""
    Number of sent messages.

    Table:
        Number of messages sent by each agent. For federated networks, the server is included.

    Plot:
        Number of sent messages (y-axis) against iteration (x-axis).
        The line is flat because the count is only read at the end of the trial,
        not tracked per iteration.

    """

    description: str = "nr sent messages"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        __: int,
    ) -> list[float]:
        # Iterate every graph node rather than ``network.agents()`` so that non-agent nodes
        # (the server in a federated network) are counted as well.
        nodes = network.graph.nodes
        return [node.n_sent_messages for node in nodes]
[docs]
class ReceivedMessages(Metric):
    r"""
    Number of received messages.

    Table:
        Number of received messages per agent. For federated networks, this includes the server.

    Plot:
        Number of received messages (y-axis) per iteration (x-axis).
        Will be a flat line as the number of received messages is calculated at the end of the trial,
        not per iteration.

    """

    description: str = "nr received messages"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        __: int,
    ) -> list[float]:
        # Iterate every graph node rather than ``network.agents()`` so that non-agent nodes
        # (the server in a federated network) are counted as well.
        return [a.n_received_messages for a in network.graph.nodes]
[docs]
class SentMessagesDropped(Metric):
    r"""
    Number of sent messages that were dropped.

    Table:
        Number of dropped sent messages of each agent. For federated networks, the server is included.

    Plot:
        Number of dropped sent messages (y-axis) against iteration (x-axis).
        The line is flat because the count is only read at the end of the trial,
        not tracked per iteration.

    """

    description: str = "nr sent messages dropped"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        __: int,
    ) -> list[float]:
        # Iterate every graph node so the server of a federated network is included.
        nodes = network.graph.nodes
        return [node.n_sent_messages_dropped for node in nodes]
[docs]
class Accuracy(Metric):
    r"""
    Accuracy of the agents'/clients' predictions.

    Table:
        Accuracy of the agents'/clients' final x.

    Plot:
        Accuracy (y-axis) against iteration (x-axis).

    Accuracy is the mean accuracy across agents/clients, where each agent's/client's accuracy is computed
    from its recorded x at that iteration. Only available for
    :class:`~decent_bench.costs.EmpiricalRiskCost` and integer targets.

    Accuracy measures the proportion of correct predictions:

    .. math::
        \text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}}

    where TP, TN, FP, and FN are true positives, true negatives, false positives, and false negatives.

    Note:
        Only available when:

        - ``problem.test_data`` is provided,
        - all agent costs are :class:`~decent_bench.costs.EmpiricalRiskCost`,
        - target labels are integer-valued.

    """

    description: str = "accuracy"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        if getattr(problem, "test_data", None) is None:
            return False, "requires problem.test_data"
        agents = problem.network.agents()
        if any(not isinstance(agent.cost, EmpiricalRiskCost) for agent in agents):
            return False, "accuracy only applies if all agents have EmpiricalRiskCost"
        # Classification metrics only make sense for signed/unsigned integer labels.
        _, targets = utils._split_dataset(problem.test_data)  # type: ignore[arg-type] # noqa: SLF001
        integer_kinds = {"i", "u"}
        if targets.dtype.kind in integer_kinds:
            return True, None
        return False, f"accuracy only applies for integer targets, dtype {targets.dtype} found"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        views = network.agents()
        return utils._accuracy(views, problem, iteration)  # noqa: SLF001
[docs]
class MSE(Metric):
    r"""
    Mean squared error of the agents'/clients' predictions.

    Table:
        Mean squared error of the agents'/clients' final x.

    Plot:
        Mean squared error (MSE) (y-axis) against iteration (x-axis).

    The MSE is the mean of the per-agent/client MSEs, where each agent's/client's MSE is computed from its
    recorded x at that iteration. Only available for :class:`~decent_bench.costs.EmpiricalRiskCost`.

    MSE measures the average squared difference between predictions and true values:

    .. math::
        \text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2

    where :math:`\hat{y}_i` are the predicted values, :math:`y_i` are the true values, and :math:`n` is
    the number of samples.

    Note:
        Only available when ``problem.test_data`` is provided and all agent costs are
        :class:`~decent_bench.costs.EmpiricalRiskCost`.

    """

    description: str = "mse"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        if getattr(problem, "test_data", None) is None:
            reason = "requires problem.test_data"
        elif any(not isinstance(agent.cost, EmpiricalRiskCost) for agent in problem.network.agents()):
            reason = "MSE only applies if all agents have EmpiricalRiskCost"
        else:
            return True, None
        return False, reason

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        views = network.agents()
        return utils._mse(views, problem, iteration)  # noqa: SLF001
[docs]
class Precision(Metric):
    r"""
    Precision of the agents'/clients' predictions.

    Table:
        Precision of the agents'/clients' final x.

    Plot:
        Precision (y-axis) against iteration (x-axis).

    Precision is the mean precision across agents/clients, where each agent's/client's precision is
    computed from its recorded x at that iteration. Only available for
    :class:`~decent_bench.costs.EmpiricalRiskCost` and integer targets.

    Precision measures the proportion of positive predictions that are correct:

    .. math::
        \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}

    where TP is the number of true positives and FP is the number of false positives.

    Note:
        Only available when:

        - ``problem.test_data`` is provided,
        - all agent costs are :class:`~decent_bench.costs.EmpiricalRiskCost`,
        - target labels are integer-valued.

    """

    description: str = "precision"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        if getattr(problem, "test_data", None) is None:
            return False, "requires problem.test_data"
        agents = problem.network.agents()
        if any(not isinstance(agent.cost, EmpiricalRiskCost) for agent in agents):
            return False, "precision only applies if all agents have EmpiricalRiskCost"
        # Classification metrics only make sense for signed/unsigned integer labels.
        _, targets = utils._split_dataset(problem.test_data)  # type: ignore[arg-type] # noqa: SLF001
        integer_kinds = {"i", "u"}
        if targets.dtype.kind in integer_kinds:
            return True, None
        return False, f"precision only applies for integer targets, dtype {targets.dtype} found"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        views = network.agents()
        return utils._precision(views, problem, iteration)  # noqa: SLF001
[docs]
class Recall(Metric):
    r"""
    Recall of the agents'/clients' predictions.

    Table:
        Recall of the agents'/clients' final x.

    Plot:
        Recall (y-axis) against iteration (x-axis).

    Recall is the mean recall across agents/clients, where each agent's/client's recall is computed from
    its recorded x at that iteration. Only available for :class:`~decent_bench.costs.EmpiricalRiskCost`
    and integer targets.

    Recall measures the proportion of actual positives that are correctly identified:

    .. math::
        \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}

    where TP is the number of true positives and FN is the number of false negatives.

    Note:
        Only available when:

        - ``problem.test_data`` is provided,
        - all agent costs are :class:`~decent_bench.costs.EmpiricalRiskCost`,
        - target labels are integer-valued.

    """

    description: str = "recall"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        if getattr(problem, "test_data", None) is None:
            return False, "requires problem.test_data"
        agents = problem.network.agents()
        if any(not isinstance(agent.cost, EmpiricalRiskCost) for agent in agents):
            return False, "recall only applies if all agents have EmpiricalRiskCost"
        # Classification metrics only make sense for signed/unsigned integer labels.
        _, targets = utils._split_dataset(problem.test_data)  # type: ignore[arg-type] # noqa: SLF001
        integer_kinds = {"i", "u"}
        if targets.dtype.kind in integer_kinds:
            return True, None
        return False, f"recall only applies for integer targets, dtype {targets.dtype} found"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        views = network.agents()
        return utils._recall(views, problem, iteration)  # noqa: SLF001
[docs]
class Loss(Metric):
    r"""
    Loss of the agents'/clients' predictions.

    Table:
        Loss of the agents'/clients' final x.

    Plot:
        Loss (y-axis) against iteration (x-axis).

    The loss is the mean loss across agents/clients, where each agent's/client's loss is computed from its
    recorded x at that iteration.

    """

    description: str = "loss"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        # Only the agent views and the iteration index are passed on; the problem is not consulted.
        views = network.agents()
        return utils._losses(views, iteration)  # noqa: SLF001
def _requires_fednetwork(problem: "BenchmarkProblem", metric_name: str) -> tuple[bool, str | None]:
    """
    Check whether *problem* runs on a :class:`~decent_bench.networks.FedNetwork`.

    Args:
        problem: benchmark problem whose ``network`` is inspected.
        metric_name: name inserted into the explanation when the check fails.

    Returns:
        ``(True, None)`` for a federated network, otherwise ``(False, reason)``.

    """
    is_federated = isinstance(problem.network, FedNetwork)
    if is_federated:
        return True, None
    return False, f"{metric_name} only applies to FedNetwork"
def _server_metric_cost(network: NetworkMetricsView, metric_name: str) -> Cost:
    """
    Pick the cost object used to evaluate server-side metrics.

    The cost of the first agent/client view is used to score the server model.

    Args:
        network: metrics view of the network.
        metric_name: name inserted into the error message.

    Returns:
        The ``cost`` of the first view returned by ``network.agents()``.

    Raises:
        ValueError: If ``network.agents()`` is empty.

    """
    views = network.agents()
    if views:
        return views[0].cost
    raise ValueError(f"{metric_name} requires at least one client metrics view")
[docs]
class ClientDriftFromServer(Metric):
    r"""
    Distance between the clients' local models and the server model.

    Table:
        Distance of each client's final state from the final server state.

    Plot:
        Client drift from the server (y-axis) against iteration (x-axis).

    The drift of each client is defined as:

    .. math::
        \{ \|\mathbf{x}_i - \mathbf{x}_s\|, \|\mathbf{x}_j - \mathbf{x}_s\|, ... \}

    where :math:`\mathbf{x}_s` is the current server state.

    Note:
        Only available for :class:`~decent_bench.networks.FedNetwork`.

    """

    description: str = "client drift from server"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        return _requires_fednetwork(problem, self.description)

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",  # noqa: ARG002
        iteration: int,
    ) -> list[float]:
        clients = network.clients()
        server = network.server()
        return utils._drifts(clients, server, iteration)  # noqa: SLF001
[docs]
class FractionSelectedClients(Metric):
    r"""
    Fraction of clients selected by the federated algorithm to perform local training.

    Table:
        Fraction of selected clients over the algorithm run, computed as the total number of client
        selections divided by ``observed rounds * number of clients``. NaN is reported when there are
        no clients or no observed rounds.

    Note:
        Available only for :class:`~decent_bench.networks.FedNetwork`.

    """

    description: str = "fraction selected clients"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        return _requires_fednetwork(problem, self.description)

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        _: "BenchmarkProblem",
        __: int,
    ) -> list[float]:
        agent_views = network.clients()
        # Guard the empty case *before* asking how many rounds were observed, so the helper is never
        # handed an empty client sequence. The result is undefined (NaN) either way.
        if not agent_views:
            return [np.nan]
        n_rounds = utils._observed_rounds(agent_views)  # noqa: SLF001
        if n_rounds == 0:
            return [np.nan]
        total_selections = sum(agent.n_times_selected for agent in agent_views)
        return [total_selections / (n_rounds * len(agent_views))]
[docs]
class ServerMSE(Metric):
    r"""
    Mean squared error of the server model's predictions.

    Table:
        Mean squared error of the final server x.

    Plot:
        Server MSE (y-axis) against iteration (x-axis).

    The server x is scored with the cost object of the first client.

    Note:
        Only available for :class:`~decent_bench.networks.FedNetwork` with ``problem.test_data`` and
        empirical-risk client costs.

    """

    description: str = "server mse"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        is_fed, why_not = _requires_fednetwork(problem, self.description)
        if not is_fed:
            return False, why_not
        if getattr(problem, "test_data", None) is None:
            return False, "requires problem.test_data"
        clients = problem.network.agents()
        if any(not isinstance(client.cost, EmpiricalRiskCost) for client in clients):
            return False, "server MSE only applies if all clients have EmpiricalRiskCost"
        return True, None

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        cost = _server_metric_cost(network, self.description)
        server_x = network.server().x_history[iteration]
        return [utils._mse_at_x(cost, server_x, problem)]  # noqa: SLF001
[docs]
class ServerAccuracy(Metric):
    r"""
    Accuracy of the server model's predictions.

    Table:
        Accuracy of the final server x.

    Plot:
        Server accuracy (y-axis) against iteration (x-axis).

    The server x is scored with the cost object of the first client.

    Note:
        Only available for :class:`~decent_bench.networks.FedNetwork` with ``problem.test_data``,
        empirical-risk client costs, and integer-valued targets.

    """

    description: str = "server accuracy"

    def is_available(  # noqa: D102
        self,
        problem: "BenchmarkProblem",
    ) -> tuple[bool, str | None]:
        is_fed, why_not = _requires_fednetwork(problem, self.description)
        if not is_fed:
            return False, why_not
        if getattr(problem, "test_data", None) is None:
            return False, "requires problem.test_data"
        clients = problem.network.agents()
        if any(not isinstance(client.cost, EmpiricalRiskCost) for client in clients):
            return False, "server accuracy only applies if all clients have EmpiricalRiskCost"
        # Accuracy only makes sense for signed/unsigned integer labels.
        _, targets = utils._split_dataset(problem.test_data)  # type: ignore[arg-type] # noqa: SLF001
        integer_kinds = {"i", "u"}
        if targets.dtype.kind in integer_kinds:
            return True, None
        return False, f"server accuracy only applies for integer targets, dtype {targets.dtype} found"

    def compute(  # noqa: D102
        self,
        network: NetworkMetricsView,
        problem: "BenchmarkProblem",
        iteration: int,
    ) -> list[float]:
        cost = _server_metric_cost(network, self.description)
        server_x = network.server().x_history[iteration]
        return [utils._accuracy_at_x(cost, server_x, problem)]  # noqa: SLF001
_BASE_TABLE_METRICS: list[Metric] = [
Regret(),
GradientNorm(),
XError(),
ConsensusError(),
Loss(),
XUpdates(),
FunctionCalls(),
GradientCalls(),
HessianCalls(),
ProximalCalls(),
SentMessages(),
ReceivedMessages(),
SentMessagesDropped(),
]
"""
- :class:`Regret`
- :class:`GradientNorm`
- :class:`XError`
- :class:`ConsensusError`
- :class:`Loss`
- :class:`XUpdates`
- :class:`FunctionCalls`
- :class:`GradientCalls`
- :class:`HessianCalls`
- :class:`ProximalCalls`
- :class:`SentMessages`
- :class:`ReceivedMessages`
- :class:`SentMessagesDropped`
:meta hide-value:
"""
_REGRESSION_TABLE_METRICS: list[Metric] = [
MSE(x_log=False, y_log=True),
]
"""
- :class:`MSE` - :func:`min`, :func:`~numpy.average`, :func:`max`
:meta hide-value:
"""
_CLASSIFICATION_TABLE_METRICS: list[Metric] = [
Accuracy(fmt=".2%", x_log=False, y_log=False),
Precision(fmt=".2%", x_log=False, y_log=False),
Recall(fmt=".2%", x_log=False, y_log=False),
]
"""
- :class:`Accuracy` - with percentage format
- :class:`Precision` - with percentage format
- :class:`Recall` - with percentage format
:meta hide-value:
"""
# Statistics are only used by table metrics, so the plot metrics below do not specify them.
# A Metric object shared between a table list and a plot list would need its statistics specified,
# so every plot list builds its own instances instead of reusing (or aliasing) a table list.
_BASE_PLOT_METRICS: list[Metric] = [
    Regret(x_log=False, y_log=True),
    GradientNorm(x_log=False, y_log=True),
    ConsensusError(x_log=False, y_log=True),
    Loss(x_log=False, y_log=False),
]
"""
- :class:`Regret` (semi-log)
- :class:`GradientNorm` (semi-log)
- :class:`ConsensusError` (semi-log)
- :class:`Loss` (linear)

:meta hide-value:
"""
# Separate instances from ``_REGRESSION_TABLE_METRICS``: previously this name aliased that very list,
# so table and plot shared the same list object and the same Metric objects.
_REGRESSION_PLOT_METRICS: list[Metric] = [
    MSE(x_log=False, y_log=True),
]
"""
- :class:`MSE` (semi-log)

:meta hide-value:
"""
# Separate instances from ``_CLASSIFICATION_TABLE_METRICS`` for the same reason as above.
_CLASSIFICATION_PLOT_METRICS: list[Metric] = [
    Accuracy(fmt=".2%", x_log=False, y_log=False),
    Precision(fmt=".2%", x_log=False, y_log=False),
    Recall(fmt=".2%", x_log=False, y_log=False),
]
"""
- :class:`Accuracy` (linear)
- :class:`Precision` (linear)
- :class:`Recall` (linear)

:meta hide-value:
"""
_FEDERATED_TABLE_METRICS: list[Metric] = [
ClientDriftFromServer(),
FractionSelectedClients(fmt=".2%", x_log=False, y_log=False),
]
"""
- :class:`ClientDriftFromServer`
- :class:`FractionSelectedClients` - with percentage format
:meta hide-value:
"""
_FEDERATED_PLOT_METRICS: list[Metric] = [
ClientDriftFromServer(x_log=False, y_log=True),
]
"""
- :class:`ClientDriftFromServer` (semi-log)
:meta hide-value:
"""
_FEDERATED_REGRESSION_TABLE_METRICS: list[Metric] = [
ServerMSE(x_log=False, y_log=True),
]
"""
- :class:`ServerMSE`
:meta hide-value:
"""
_FEDERATED_REGRESSION_PLOT_METRICS: list[Metric] = [
ServerMSE(x_log=False, y_log=True),
]
"""
- :class:`ServerMSE` (semi-log)
:meta hide-value:
"""
_FEDERATED_CLASSIFICATION_TABLE_METRICS: list[Metric] = [
ServerAccuracy(fmt=".2%", x_log=False, y_log=False),
]
"""
- :class:`ServerAccuracy` - with percentage format
:meta hide-value:
"""
_FEDERATED_CLASSIFICATION_PLOT_METRICS: list[Metric] = [
ServerAccuracy(fmt=".2%", x_log=False, y_log=False),
]
"""
- :class:`ServerAccuracy` (linear)
:meta hide-value:
"""
# Every pre-defined table metric, in display order. Metrics that do not apply to a given problem
# report so through their ``is_available`` method.
_DEFAULT_TABLE_METRICS: list[Metric] = [
    *_BASE_TABLE_METRICS,
    *_REGRESSION_TABLE_METRICS,
    *_CLASSIFICATION_TABLE_METRICS,
    *_FEDERATED_TABLE_METRICS,
    *_FEDERATED_REGRESSION_TABLE_METRICS,
    *_FEDERATED_CLASSIFICATION_TABLE_METRICS,
]
# Every pre-defined plot metric, in display order (same availability rule as the table metrics).
_DEFAULT_PLOT_METRICS: list[Metric] = [
    *_BASE_PLOT_METRICS,
    *_REGRESSION_PLOT_METRICS,
    *_CLASSIFICATION_PLOT_METRICS,
    *_FEDERATED_PLOT_METRICS,
    *_FEDERATED_REGRESSION_PLOT_METRICS,
    *_FEDERATED_CLASSIFICATION_PLOT_METRICS,
]