diff --git a/other/materials_designer/workflows/band_gap.ipynb b/other/materials_designer/workflows/band_gap.ipynb
index 5febaf86..83cf69b3 100644
--- a/other/materials_designer/workflows/band_gap.ipynb
+++ b/other/materials_designer/workflows/band_gap.ipynb
@@ -589,9 +589,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from utils.api import wait_for_jobs_to_finish\n",
+    "from utils.api import wait_for_jobs_to_finish_async\n",
     "\n",
-    "wait_for_jobs_to_finish(client.jobs, [job_id], poll_interval=POLL_INTERVAL)"
+    "await wait_for_jobs_to_finish_async(client.jobs, [job_id], poll_interval=POLL_INTERVAL)"
    ]
   },
   {
diff --git a/utils/api.py b/utils/api.py
index d8ada34d..be3a2e13 100644
--- a/utils/api.py
+++ b/utils/api.py
@@ -1,8 +1,8 @@
 import datetime
 import json
 import os
-import time
 import urllib.request
+from collections import Counter
 from typing import List, Optional
 
 from mat3ra.api_client.endpoints.bank_workflows import BankWorkflowEndpoints
@@ -12,6 +12,8 @@ from mat3ra.api_client.endpoints.workflows import WorkflowEndpoints
 from tabulate import tabulate
 
+from utils.interrupts import interruptible_polling_loop
+
 
 def save_files(job_id: str, job_endpoint: JobEndpoints, filename_on_cloud: str, filename_on_disk: str) -> None:
     """
@@ -57,7 +59,22 @@ def get_jobs_statuses_by_ids(endpoint: JobEndpoints, job_ids: List[str]) -> List
     return [job["status"] for job in jobs]
 
 
-def wait_for_jobs_to_finish(endpoint: JobEndpoints, job_ids: list, poll_interval: int = 10) -> None:
+def _print_jobs_status_table(counts: Counter) -> None:
+    """Print a one-row table with the current time and the number of jobs per status."""
+    headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"]
+    now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
+    row = [
+        now,
+        counts.get("submitted", 0),
+        counts.get("active", 0),
+        counts.get("finished", 0),
+        counts.get("error", 0),
+    ]
+    print(tabulate([row], headers, tablefmt="grid", stralign="center"))
+
+
+@interruptible_polling_loop()
+def wait_for_jobs_to_finish_async(endpoint: JobEndpoints, job_ids: List[str]) -> bool:
     """
     Waits for jobs to finish and prints their statuses.
     A job is considered finished if it is not in "pre-submission", "submitted", or, "active" status.
@@ -65,25 +82,14 @@ def wait_for_jobs_to_finish(endpoint: JobEndpoints, job_ids: list, poll_interval
     Args:
         endpoint (JobEndpoints): Job endpoint object from the Exabyte API Client
         job_ids (list): list of job IDs to wait for
         poll_interval (int): poll interval for job information in seconds. Defaults to 10.
     """
-    print("Wait for jobs to finish, poll interval: {0} sec".format(poll_interval))
-    while True:
-        statuses = get_jobs_statuses_by_ids(endpoint, job_ids)
-
-        errored_jobs = len([status for status in statuses if status == "error"])
-        active_jobs = len([status for status in statuses if status == "active"])
-        finished_jobs = len([status for status in statuses if status == "finished"])
-        submitted_jobs = len([status for status in statuses if status == "submitted"])
-
-        headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"]
-        now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
-        row = [now, submitted_jobs, active_jobs, finished_jobs, errored_jobs]
-        print(tabulate([row], headers, tablefmt="grid", stralign="center"))
+    statuses = get_jobs_statuses_by_ids(endpoint, job_ids)
+    counts = Counter(statuses)
+    _print_jobs_status_table(counts)
 
-        if all([status not in ["pre-submission", "submitted", "active"] for status in statuses]):
-            break
-        time.sleep(poll_interval)
+    active_statuses = {"pre-submission", "submitted", "active"}
+    return any(status in active_statuses for status in statuses)
 
 
 def copy_bank_workflow_by_system_name(endpoint: BankWorkflowEndpoints, system_name: str, account_id: str) -> dict:
diff --git a/utils/interrupts.py b/utils/interrupts.py
new file mode 100644
index 00000000..fa2dbfcf
--- /dev/null
+++ b/utils/interrupts.py
@@ -0,0 +1,240 @@
+import asyncio
+import inspect
+import uuid
+from dataclasses import dataclass
+from functools import wraps
+from typing import Any, Awaitable, Callable
+
+from mat3ra.utils.jupyterlite.environment import ENVIRONMENT, EnvironmentsEnum
+
+try:
+    from IPython.display import HTML, display  # type: ignore
+except Exception:
+    HTML = None
+    display = None
+
+
+class UserAbortError(RuntimeError):
+    """Raised when the user aborts a polling loop via the abort channel."""
+
+
+def display_abort_controls_in_current_cell_output(
+    channel_name: str = "mat3ra_abort_channel",
+    abort_button_text: str = "Abort",
+) -> None:
+    """
+    Shows:
+      [Abort] Press ESC to abort
+
+    Only works in notebook frontends that support HTML output.
+    Safe no-op otherwise.
+    """
+    if HTML is None or display is None:
+        return
+
+    element_id = f"abort_controls_{uuid.uuid4().hex}"
+
+    display(
+        HTML(
+            f"""
+
+
+
+                Press ESC to abort
+
+
+
+
+            """
+        )
+    )
+
+
+@dataclass
+class BroadcastChannelAbortController:
+    """
+    WebWorker-side receiver. Works only in pyodide (emscripten).
+    In regular Python: start() does nothing and is_aborted stays False.
+    """
+
+    channel_name: str = "mat3ra_abort_channel"
+    is_aborted: bool = False
+
+    def __post_init__(self) -> None:
+        self._broadcast_channel = None
+        self._on_message_proxy = None
+
+    def start(self) -> None:
+        """Open the channel and subscribe to abort messages (no-op outside pyodide or if already started)."""
+        if ENVIRONMENT != EnvironmentsEnum.PYODIDE:
+            return
+        if self._broadcast_channel is not None:
+            return
+
+        import js  # type: ignore
+        from pyodide.ffi import create_proxy  # type: ignore
+
+        self._broadcast_channel = js.BroadcastChannel.new(self.channel_name)
+
+        def on_message(event) -> None:
+            message = getattr(event, "data", None)
+            if message and getattr(message, "type", None) == "abort":
+                self.is_aborted = True
+
+        self._on_message_proxy = create_proxy(on_message)
+        self._broadcast_channel.onmessage = self._on_message_proxy  # type: ignore
+
+    def stop(self) -> None:
+        """Close the channel and destroy the message proxy (no-op if the channel was never opened)."""
+        if self._broadcast_channel is None:
+            return
+
+        self._broadcast_channel.close()
+        self._broadcast_channel = None
+
+        if self._on_message_proxy is not None:
+            self._on_message_proxy.destroy()
+            self._on_message_proxy = None
+
+
+async def run_interruptible_loop_async(
+    loop_body: Callable[[], Awaitable[bool]],
+    poll_interval_seconds: float,
+    *,
+    channel_name: str = "mat3ra_abort_channel",
+    check_interval_seconds: float = 0.05,
+    show_controls: bool = True,
+) -> None:
+    """
+    Run `loop_body` repeatedly until it returns False or the user aborts.
+
+    loop_body():
+      - do one "poll" iteration
+      - return True to keep looping, False to stop normally
+
+    Between iterations we sleep in small slices so:
+      - pyodide: ESC/button can be received and stop the loop
+      - regular Python: yields control (Ctrl+C/Stop works where supported)
+
+    Args:
+        loop_body: awaitable callable performing one poll iteration.
+        poll_interval_seconds: pause between two iterations.
+        channel_name: name of the BroadcastChannel the abort message is expected on.
+        check_interval_seconds: length of one sleep slice; must be positive.
+        show_controls: whether to render the abort controls (pyodide only).
+
+    Raises:
+        UserAbortError: if an abort message was received while waiting.
+        ValueError: if `check_interval_seconds` is not positive.
+    """
+    if check_interval_seconds <= 0:
+        # A non-positive slice never reduces the remaining time: the wait would spin forever.
+        raise ValueError("check_interval_seconds must be positive.")
+
+    broadcast_channel_abort_controller = BroadcastChannelAbortController(channel_name=channel_name)
+
+    try:
+        # Started inside `try` so that `finally` releases the channel even if rendering fails.
+        broadcast_channel_abort_controller.start()
+
+        # NOTE(review): the receiver only runs under pyodide, so the controls are shown there.
+        if show_controls and ENVIRONMENT == EnvironmentsEnum.PYODIDE:
+            display_abort_controls_in_current_cell_output(channel_name=channel_name, abort_button_text="Abort")
+
+        while True:
+            should_continue = await loop_body()
+            if not should_continue:
+                return
+
+            remaining_seconds = float(poll_interval_seconds)
+            while remaining_seconds > 0:
+                if broadcast_channel_abort_controller.is_aborted:
+                    raise UserAbortError("Aborted by user.")
+                await asyncio.sleep(min(check_interval_seconds, remaining_seconds))
+                remaining_seconds -= check_interval_seconds
+
+            # An abort can arrive during the last slice: do not start another poll in that case.
+            if broadcast_channel_abort_controller.is_aborted:
+                raise UserAbortError("Aborted by user.")
+
+    finally:
+        broadcast_channel_abort_controller.stop()
+
+
+def interruptible_polling_loop(
+    poll_interval_kwarg_name: str = "poll_interval",
+    *,
+    default_poll_interval_seconds: float = 10.0,
+    channel_name: str = "mat3ra_abort_channel",
+    check_interval_seconds: float = 0.05,
+    show_controls: bool = True,
+):
+    """
+    Decorator for single-iteration polling functions.
+
+    The decorated function must:
+      - return True to continue
+      - return False to stop
+
+    The decorated function becomes an `async def` that runs the polling loop until completion.
+
+    The polling interval can be passed at call time using `poll_interval_kwarg_name` (defaults to
+    `"poll_interval"`). If not provided, `default_poll_interval_seconds` is used.
+    """
+
+    def decorator(poll_step_function: Callable[..., Any]) -> Callable[..., Any]:
+        @wraps(poll_step_function)
+        async def wrapped(*args: Any, **kwargs: Any) -> None:
+            poll_interval_seconds = float(kwargs.pop(poll_interval_kwarg_name, default_poll_interval_seconds))
+
+            async def loop_body() -> bool:
+                result = poll_step_function(*args, **kwargs)
+                should_continue = await result if inspect.isawaitable(result) else result
+                return bool(should_continue)
+
+            await run_interruptible_loop_async(
+                loop_body,
+                poll_interval_seconds,
+                channel_name=channel_name,
+                check_interval_seconds=check_interval_seconds,
+                show_controls=show_controls,
+            )
+
+        return wrapped
+
+    return decorator