Source code for decent_bench.utils.agent_utils
from typing import TYPE_CHECKING
from decent_bench.costs import EmpiricalRiskCost
if TYPE_CHECKING:
from decent_bench.agents import Agent
[docs]
def infer_client_data_size(client: "Agent") -> float:
"""
Infer a client's local data size from an empirical-risk cost.
Args:
client: client agent whose data size should be inferred.
Raises:
ValueError: if the client cost is not an empirical-risk cost.
"""
cost = client.cost
if isinstance(cost, EmpiricalRiskCost):
return float(cost.n_samples)
raise ValueError(
"Cannot infer client data size. Use an EmpiricalRiskCost to provide the number of local samples.",
)