Source code for decent_bench.utils.network_utils
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Literal
import matplotlib.axes
import networkx as nx
from decent_bench.networks import Network
# Registry of the layout names accepted by ``plot_network``: each key is a valid value of its
# ``layout`` argument and maps to the NetworkX function that computes node positions for that choice.
# ``plot_network`` looks names up here and lists these keys (sorted) in its "Unsupported layout" error.
_LAYOUT_FUNCS: dict[Literal["spring", "kamada_kawai", "circular", "random", "shell"], Any] = {
    "spring": nx.drawing.layout.spring_layout,
    "kamada_kawai": nx.drawing.layout.kamada_kawai_layout,
    "circular": nx.drawing.layout.circular_layout,
    "random": nx.drawing.layout.random_layout,
    "shell": nx.drawing.layout.shell_layout,
}
[docs]
def plot_network(
    network: Network,
    *,
    ax: matplotlib.axes.Axes | None = None,
    layout: Literal["spring", "kamada_kawai", "circular", "random", "shell"] = "spring",
    layout_kwargs: Mapping[str, Any] | None = None,
    **draw_kwargs: Any,
) -> matplotlib.axes.Axes:
    """
    Plot a Network using NetworkX drawing utilities.

    Args:
        network: Network to be plotted.
        ax: optional :class:`matplotlib.axes.Axes` to draw on. If ``None`` a new figure is created.
        layout: layout algorithm to position nodes (e.g. :func:`networkx.drawing.layout.spring_layout`,
            :func:`networkx.drawing.layout.kamada_kawai_layout`,
            :func:`networkx.drawing.layout.circular_layout`,
            :func:`networkx.drawing.layout.random_layout`,
            :func:`networkx.drawing.layout.shell_layout`).
            The name is always validated, but the layout is not computed when ``pos`` is
            supplied through ``draw_kwargs``.
        layout_kwargs: optional keyword arguments forwarded to the selected layout function,
            e.g. ``{"seed": 0}`` for a reproducible ``"spring"`` or ``"random"`` layout.
        draw_kwargs: forwarded to :func:`networkx.drawing.nx_pylab.draw_networkx`.
            If ``pos`` is given, it is used as the node positions instead of computing ``layout``.
            If ``labels`` is not given, each node is labelled with its ``index`` attribute.

    Returns:
        The matplotlib :class:`matplotlib.axes.Axes` containing the plot.

    Raises:
        RuntimeError: if matplotlib is not available.
        ValueError: if an unsupported layout is requested.
    """
    try:
        import matplotlib.pyplot as plt  # noqa: PLC0415
    except Exception as exc:  # pragma: no cover - runtime dependency guard
        raise RuntimeError("matplotlib is required for plotting the network") from exc

    layout_func = _LAYOUT_FUNCS.get(layout)
    if layout_func is None:
        supported = ", ".join(sorted(_LAYOUT_FUNCS))
        raise ValueError(f"Unsupported layout '{layout}'. Supported layouts: {supported}")

    draw_kwargs_dict: dict[str, Any] = dict(draw_kwargs)

    # ``pos`` is passed explicitly to draw_networkx below, so a user-supplied one must be removed
    # from the forwarded kwargs; otherwise the call fails with
    # "got multiple values for keyword argument 'pos'".
    pos = draw_kwargs_dict.pop("pos", None)
    if pos is None:
        pos = layout_func(network.graph, **(layout_kwargs or {}))

    if ax is None:
        _, ax = plt.subplots()

    # use agents' indices within the network as labels (if custom labels are not specified by the user)
    if "labels" not in draw_kwargs_dict:
        draw_kwargs_dict["labels"] = {a: a.index for a in network.graph}

    nx.drawing.nx_pylab.draw_networkx(
        network.graph,
        pos=pos,
        ax=ax,
        **draw_kwargs_dict,
    )
    return ax