from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from math import ceil
from multiprocessing.managers import SyncManager
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING, Any
from rich.progress import (
BarColumn,
Progress,
ProgressColumn,
TaskProgressColumn,
TextColumn,
TimeRemainingColumn,
)
from rich.table import Column, Table
from rich.text import Text
from decent_bench.algorithms import Algorithm
if TYPE_CHECKING:
from rich.progress import Task, TaskID
else:
TaskID = int
@dataclass(eq=False)
class _ProgressRecord:
    """
    Message describing a single progress bar update.

    Records are placed on the progress queue and consumed by the progress listener.
    """

    # Identifier of the progress bar the update applies to
    progress_bar_id: TaskID
    # Number of units to advance the progress bar by
    increment: int
    # Trial number to display for the bar, or ``None`` to leave the displayed trial untouched
    trial: int | None
class TrialColumn(ProgressColumn):
    """
    Progress column that renders the current trial as ``<trial>/<n_trials>``.

    The trial is looked up defensively: a task that carries no trial information is rendered with a ``?``
    placeholder instead of raising ``KeyError``.
    """

    def __init__(self, n_trials: int, style: str = "", finished_style: str = "") -> None:
        """
        Create the column.

        Args:
            n_trials: total number of trials, shown as the denominator
            style: text style used while the task is unfinished
            finished_style: text style used once the task has finished

        """
        super().__init__()
        self.n_trials = n_trials
        self.style = style
        self.finished_style = finished_style

    def render(self, task: "Task") -> Text:
        """Return the ``<trial>/<n_trials>`` text for *task*."""
        if task.finished:
            # A finished task always shows the full trial count
            return Text(f"{self.n_trials}/{self.n_trials}", style=self.finished_style)

        # The trial may live in a nested "fields" mapping or directly among the task's fields
        field_source = task.fields.get("fields", task.fields)
        current_trial = field_source.get("trial", "?")
        return Text(f"{current_trial}/{self.n_trials}", style=self.style)
class SpeedColumn(ProgressColumn):
    """Progress column that renders a task's speed as iterations per second."""

    def __init__(self, progress_step: int | None) -> None:
        """
        Create the column.

        Args:
            progress_step: factor the task's speed is multiplied by before rendering;
                if ``None`` the speed is rendered unscaled

        """
        super().__init__()
        self.progress_step = progress_step

    def render(self, task: "Task") -> Text:
        """Return the speed text for *task*, or a placeholder while no speed is available."""
        if task.speed is None and task.finished_speed is None:
            return Text("--.-- it/s", style="progress.percentage", justify="right")

        # Prefer the speed recorded at completion and fall back to the live estimate
        rate = task.finished_speed or task.speed
        scale = self.progress_step
        if not (rate is None or scale is None):
            # Normalize the speed to iterations/second depending on the progress step
            rate = rate * scale

        rendered = TaskProgressColumn.render_speed(rate)
        rendered.justify = "right"
        return rendered
class ProgressWithHeader(Progress):
    """
    Custom Progress display that renders a header row above the task rows.

    Args:
        *columns: column definitions, forwarded unchanged to the parent ``Progress``
        headers: header renderables, matched to *columns* by position. Surplus headers or columns are left without
            a counterpart. If ``None`` or empty, the parent's header-less table is rendered instead.
        **kwargs: remaining keyword arguments, forwarded unchanged to the parent ``Progress``

    """

    def __init__(  # type: ignore[no-untyped-def]
        self,
        *columns: str | ProgressColumn,
        headers: Iterable[Text] | None = None,
        **kwargs,  # noqa: ANN003
    ) -> None:
        # The tasks table is rebuilt on every refresh, so the headers are iterated many times. Materialize them once
        # so that a one-shot iterable (e.g. a generator) does not lose its headers after the first render and so
        # that an empty iterable is reliably detected as "no headers".
        self.headers: tuple[Text, ...] | None = None if headers is None else tuple(headers)
        super().__init__(*columns, **kwargs)

    def make_tasks_table(self, tasks: Iterable["Task"]) -> Table:
        """
        Override to add a header row to the table.

        Falls back to the parent implementation when there are no tasks or no headers.

        Args:
            tasks: tasks to render; only visible tasks get a row

        Returns:
            Table with one column per progress column and one row per visible task.

        """
        if not tasks or not self.headers:
            return super().make_tasks_table(tasks)

        # Mimic super() implementation but render headers of columns
        table_columns = [
            (Column(no_wrap=True) if isinstance(col, str) else col.get_table_column().copy()) for col in self.columns
        ]
        # strict=False tolerates a header/column count mismatch, unmatched entries are simply skipped
        for col, header in zip(table_columns, self.headers, strict=False):
            col.header = header

        table = Table(
            *table_columns,
            padding=(0, 1),
            expand=self.expand,
            show_header=True,
            box=None,
            collapse_padding=True,
            show_footer=False,
            show_edge=False,
            pad_edge=False,
        )

        # Add each task as a row
        for task in tasks:
            if task.visible:
                table.add_row(
                    *(
                        (column.format(task=task) if isinstance(column, str) else column(task))
                        for column in self.columns
                    )
                )
        return table
[docs]
@dataclass
class ProgressBarHandle:
    """
    A picklable handle for worker processes to update :class:`ProgressBarController`.

    The handle carries only the picklable parts needed by worker processes, keeping them apart from the
    unpicklable Thread components in ProgressBarController.
    """

    # Queue on which progress records are placed for the progress listener
    _progress_increment_queue: Queue[Any]
    # Progress bar id of each algorithm
    _progress_bar_ids: dict[Algorithm[Any], TaskID]
    # Number of iterations per progress bar unit, or ``None`` to report only when a trial's last iteration is reached
    _progress_step: int | None

    def start_progress_bar(self, algorithm: Algorithm[Any], trial: int, initial_progress: int) -> None:
        """
        Start the clock of *algorithm*'s progress bar.

        Internally, this is done through sending a record to the progress listener. When the listener receives the
        first record for a progress bar it resets that bar, which restarts the clock that started when the progress
        bar was first rendered.

        Args:
            algorithm: algorithm whose progress bar is started
            trial: trial being started, reported to the listener as ``trial + 1``
            initial_progress: progress already made; the record's increment is
                ``initial_progress // progress_step``, or 0 when no (or a zero) progress step is set

        """
        bar_id = self._progress_bar_ids[algorithm]
        units_done = initial_progress // self._progress_step if self._progress_step else 0
        record = _ProgressRecord(bar_id, units_done, trial + 1)
        self._progress_increment_queue.put(record)

    def advance_progress_bar(self, algorithm: Algorithm[Any], iteration: int) -> None:
        """
        Advance *algorithm*'s progress bar by one unit if *iteration* calls for it.

        Without a progress step, the bar only advances once ``iteration + 1`` reaches ``algorithm.iterations``.
        With a progress step, it additionally advances whenever ``iteration + 1`` is a multiple of the step.

        Args:
            algorithm: algorithm whose progress bar is advanced
            iteration: iteration that was just performed

        """
        completed = iteration + 1
        step = self._progress_step
        if step is None:
            should_report = not completed < algorithm.iterations
        else:
            should_report = completed % step == 0 or not completed < algorithm.iterations
        if not should_report:
            return

        bar_id = self._progress_bar_ids[algorithm]
        self._progress_increment_queue.put(_ProgressRecord(bar_id, 1, None))
[docs]
class ProgressBarController:
    """
    Controller of progress bars showing how far each algorithm has progressed and the estimated time remaining.

    Args:
        manager: A multiprocessing :class:`~multiprocessing.managers.SyncManager` instance used to create a shared queue
            for coordinating progress updates across multiple processes. This enables thread-safe communication between
            worker processes and the progress bar listener thread. If ``None``, a local in-process queue is used.
        algorithms: algorithms that will be run, each gets its own bar
        n_trials: number of trials the algorithms will run
        progress_step: if provided, the progress bar will step every `progress_step`.
            When provided, each algorithm's task total becomes `n_trials * ceil(algorithm.iterations / progress_step)`.
            If `None`, the progress bar uses 1 unit per trial.
        show_speed: if ``True``, add a "Speed" column
        show_trial: if ``True``, add an "Active Trials" column

    Note:
        If `progress_step` is too small performance may degrade due to the
        overhead of updating the progress bar too often.

    """

    def __init__(
        self,
        manager: SyncManager | None,
        algorithms: Sequence[Algorithm[Any]],
        n_trials: int,
        progress_step: int | None,
        show_speed: bool = False,
        show_trial: bool = False,
    ) -> None:
        # Use a local queue for single-process runs to avoid multiprocessing manager overhead.
        self._progress_increment_queue: Queue[_ProgressRecord | None] = (
            manager.Queue() if manager is not None else Queue()
        )
        self.progress_step = progress_step

        # Each entry pairs a column with the header shown above it
        columns_with_headers: list[tuple[ProgressColumn, Text]] = [
            (
                TextColumn("{task.description}", table_column=Column(no_wrap=True, max_width=24)),
                Text("Algorithm", style="bold"),
            ),
            (BarColumn(finished_style="bold green", pulse_style="none"), Text("Progress Bar", style="bold")),
            (TaskProgressColumn(), Text("", style="bold")),  # Skip % Completed header as it's part of progress bar
        ]
        if show_speed:
            columns_with_headers.append((SpeedColumn(progress_step), Text("Speed", style="bold")))
        columns_with_headers.append((TimeRemainingColumn(elapsed_when_finished=True), Text("Time", style="bold")))
        if show_trial:
            columns_with_headers.append(
                (
                    TrialColumn(n_trials=n_trials, style="progress.remaining", finished_style="progress.elapsed"),
                    Text("Active Trials", style="bold"),
                )
            )
        cols, headers = zip(*columns_with_headers, strict=True)
        orchestrator = ProgressWithHeader(*cols, headers=headers)

        # Always define the number of progress bar units per trial so the attribute exists in both modes:
        # 1 unit per trial without a progress step, otherwise the number of steps needed to cover all iterations.
        if progress_step is None:
            self.steps_per_trial = dict.fromkeys(algorithms, 1)
        else:
            self.steps_per_trial = {alg: max(1, ceil(alg.iterations / progress_step)) for alg in algorithms}
        self._progress_bar_ids = {
            alg: orchestrator.add_task(alg.name, total=n_trials * self.steps_per_trial[alg]) for alg in algorithms
        }

        orchestrator.start()
        self.listener_thread = Thread(
            target=self._progress_listener, args=(orchestrator, self._progress_increment_queue)
        )
        self.listener_thread.start()
        self._handle = ProgressBarHandle(
            _progress_increment_queue=self._progress_increment_queue,
            _progress_bar_ids=self._progress_bar_ids,
            _progress_step=self.progress_step,
        )

    def mark_one_trial_as_complete(self, algorithm: Algorithm[Any], trial: int) -> None:
        """
        Mark a trial of *algorithm* as complete in the progress bar.

        The bar is advanced by all units of one trial at once and the displayed trial is set to ``trial + 1``.
        """
        progress_bar_id = self._progress_bar_ids[algorithm]
        increment = self.steps_per_trial[algorithm]
        self._progress_increment_queue.put(_ProgressRecord(progress_bar_id, increment, trial + 1))

    def get_handle(self) -> ProgressBarHandle:
        """
        Get a picklable handle for worker processes.

        Returns a handle containing only the queue and metadata needed by worker
        processes, without the unpicklable Thread component.
        """
        return self._handle

    def stop(self) -> None:
        """
        Stop the progress bar and wait for the listener thread to finish.

        The wait is bounded to two seconds, so this method returns even if the listener has not finished by then.
        """
        # Signal the listener thread to stop if it is still running
        # because an exception occurred in one algorithm's execution.
        self._progress_increment_queue.put(None)
        if hasattr(self, "listener_thread"):
            self.listener_thread.join(timeout=2.0)

    @staticmethod
    def _visible_capacity(orchestrator: Progress) -> int:
        """Return how many progress bars are kept visible: the console height minus 5 lines, but at least 1."""
        # NOTE(review): the 5 reserved lines presumably leave room for the header row and surrounding output - confirm
        return max(1, orchestrator.console.size.height - 5)

    @staticmethod
    def _progress_listener(orchestrator: Progress, queue: Queue[_ProgressRecord | None]) -> None:
        """
        Apply progress records from *queue* to *orchestrator* until all bars are finished or ``None`` is received.

        The first record received for a progress bar resets that bar, restarting its clock. The display is stopped
        when the listener exits, including when applying a record raises.
        """
        started_progress_bar_ids: set[TaskID] = set()
        started_order: list[TaskID] = []
        height = ProgressBarController._visible_capacity(orchestrator)
        try:
            while not orchestrator.finished:
                progress_record = queue.get()
                if progress_record is None:
                    # Stop signal sent by `stop`
                    break
                bar_id = progress_record.progress_bar_id
                if bar_id not in started_progress_bar_ids:
                    orchestrator.reset(bar_id)
                    started_progress_bar_ids.add(bar_id)
                    started_order.append(bar_id)
                    # If we have more progress bars than visible capacity, keep the most recently started ones visible
                    while len(started_order) > height:
                        orchestrator.update(started_order.pop(0), visible=False)
                if progress_record.trial is not None:
                    # Stored under the "fields" key, which is where TrialColumn looks for the trial first
                    orchestrator.update(bar_id, fields={"trial": str(progress_record.trial)})
                orchestrator.advance(bar_id, progress_record.increment)
        finally:
            # Always stop the live display, otherwise an error above would leave it running
            orchestrator.stop()