From 06f5e11b7b0edddfa25f781a971d42da32b5da38 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Thu, 26 Feb 2026 20:06:47 -0800 Subject: [PATCH 1/4] update: add interruptible execution of loops decorator --- utils/api.py | 30 +++--- utils/interrupts.py | 216 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 230 insertions(+), 16 deletions(-) create mode 100644 utils/interrupts.py diff --git a/utils/api.py b/utils/api.py index d8ada34d..903f255e 100644 --- a/utils/api.py +++ b/utils/api.py @@ -1,7 +1,6 @@ import datetime import json import os -import time import urllib.request from typing import List, Optional @@ -12,6 +11,8 @@ from mat3ra.api_client.endpoints.workflows import WorkflowEndpoints from tabulate import tabulate +from utils.interrupts import interruptible_polling_loop + def save_files(job_id: str, job_endpoint: JobEndpoints, filename_on_cloud: str, filename_on_disk: str) -> None: """ @@ -57,7 +58,8 @@ def get_jobs_statuses_by_ids(endpoint: JobEndpoints, job_ids: List[str]) -> List return [job["status"] for job in jobs] -def wait_for_jobs_to_finish(endpoint: JobEndpoints, job_ids: list, poll_interval: int = 10) -> None: +@interruptible_polling_loop() +def wait_for_jobs_to_finish_async(endpoint: JobEndpoints, job_ids: list) -> bool: """ Waits for jobs to finish and prints their statuses. A job is considered finished if it is not in "pre-submission", "submitted", or, "active" status. @@ -67,23 +69,19 @@ def wait_for_jobs_to_finish(endpoint: JobEndpoints, job_ids: list, poll_interval job_ids (list): list of job IDs to wait for poll_interval (int): poll interval for job information in seconds. Defaults to 10. 
""" - print("Wait for jobs to finish, poll interval: {0} sec".format(poll_interval)) - while True: - statuses = get_jobs_statuses_by_ids(endpoint, job_ids) + statuses = get_jobs_statuses_by_ids(endpoint, job_ids) - errored_jobs = len([status for status in statuses if status == "error"]) - active_jobs = len([status for status in statuses if status == "active"]) - finished_jobs = len([status for status in statuses if status == "finished"]) - submitted_jobs = len([status for status in statuses if status == "submitted"]) + errored_jobs = sum(status == "error" for status in statuses) + active_jobs = sum(status == "active" for status in statuses) + finished_jobs = sum(status == "finished" for status in statuses) + submitted_jobs = sum(status == "submitted" for status in statuses) - headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"] - now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") - row = [now, submitted_jobs, active_jobs, finished_jobs, errored_jobs] - print(tabulate([row], headers, tablefmt="grid", stralign="center")) + headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"] + now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + row = [now, submitted_jobs, active_jobs, finished_jobs, errored_jobs] + print(tabulate([row], headers, tablefmt="grid", stralign="center")) - if all([status not in ["pre-submission", "submitted", "active"] for status in statuses]): - break - time.sleep(poll_interval) + return any(status in ["pre-submission", "submitted", "active"] for status in statuses) def copy_bank_workflow_by_system_name(endpoint: BankWorkflowEndpoints, system_name: str, account_id: str) -> dict: diff --git a/utils/interrupts.py b/utils/interrupts.py new file mode 100644 index 00000000..f7435bf3 --- /dev/null +++ b/utils/interrupts.py @@ -0,0 +1,216 @@ +import asyncio +import inspect +import sys +import uuid +from dataclasses import dataclass +from functools import wraps +from typing import 
Any, Awaitable, Callable + +try: + from IPython.display import HTML, display # type: ignore +except Exception: + HTML = None + display = None + + +class UserAbortError(RuntimeError): + pass + + +def display_abort_controls_in_current_cell_output( + channel_name: str = "mat3ra_abort_channel", + abort_button_text: str = "Abort", +) -> None: + """ + Shows: + [Abort] Press ESC to abort + + Only works in notebook frontends that support HTML output. + Safe no-op otherwise. + """ + if HTML is None or display is None: + return + + element_id = f"abort_controls_{uuid.uuid4().hex}" + + display( + HTML( + f""" +
+ + + Press ESC to abort + +
+ + + """ + ) + ) + + +@dataclass +class BroadcastChannelAbortController: + """ + WebWorker-side receiver. Works only in pyodide (emscripten). + In regular Python: start() does nothing and is_aborted stays False. + """ + + channel_name: str = "mat3ra_abort_channel" + is_aborted: bool = False + + def __post_init__(self) -> None: + self._broadcast_channel = None + self._on_message_proxy = None + + def start(self) -> None: + if sys.platform != "emscripten": + return + if self._broadcast_channel is not None: + return + + import js # type: ignore + from pyodide.ffi import create_proxy # type: ignore + + self._broadcast_channel = js.BroadcastChannel.new(self.channel_name) + + def on_message(event) -> None: + message = getattr(event, "data", None) + if message and getattr(message, "type", None) == "abort": + self.is_aborted = True + + self._on_message_proxy = create_proxy(on_message) + self._broadcast_channel.onmessage = self._on_message_proxy + + def stop(self) -> None: + if self._broadcast_channel is None: + return + + self._broadcast_channel.close() + self._broadcast_channel = None + + if self._on_message_proxy is not None: + self._on_message_proxy.destroy() + self._on_message_proxy = None + + +async def run_interruptible_loop_async( + loop_body: Callable[[], Awaitable[bool]], + poll_interval_seconds: float, + *, + channel_name: str = "mat3ra_abort_channel", + check_interval_seconds: float = 0.05, + show_controls: bool = True, +) -> None: + """ + Minimal wrapper. 
+ + loop_body(): + - do one "poll" iteration + - return True to keep looping, False to stop normally + + Between iterations we sleep in small slices so: + - pyodide: ESC/button can be received and stop the loop + - regular Python: yields control (Ctrl+C/Stop works where supported) + """ + broadcast_channel_abort_controller = BroadcastChannelAbortController(channel_name=channel_name) + broadcast_channel_abort_controller.start() + + if show_controls and sys.platform == "emscripten": + display_abort_controls_in_current_cell_output(channel_name=channel_name, abort_button_text="Abort") + + try: + while True: + should_continue = await loop_body() + if not should_continue: + return + + remaining_seconds = float(poll_interval_seconds) + while remaining_seconds > 0: + if broadcast_channel_abort_controller.is_aborted: + raise UserAbortError("Aborted by user.") + await asyncio.sleep(min(check_interval_seconds, remaining_seconds)) + remaining_seconds -= check_interval_seconds + + finally: + broadcast_channel_abort_controller.stop() + + +def interruptible_polling_loop(poll_interval_kwarg_name: str = "poll_interval"): + """ + Decorator for single-iteration polling functions. + + The decorated function must: + - return True to continue + - return False to stop + + poll_interval is passed at call time. 
+ """ + + def decorator(poll_step_function: Callable[..., Any]) -> Callable[..., Any]: + @wraps(poll_step_function) + async def wrapped(*args: Any, **kwargs: Any) -> None: + poll_interval_seconds = kwargs.pop(poll_interval_kwarg_name) + + broadcast_channel_abort_controller = BroadcastChannelAbortController() + broadcast_channel_abort_controller.start() + + if sys.platform == "emscripten": + display_abort_controls_in_current_cell_output() + + try: + while True: + result = poll_step_function(*args, **kwargs) + should_continue = await result if inspect.isawaitable(result) else result + + if not should_continue: + return + + remaining_seconds = float(poll_interval_seconds) + while remaining_seconds > 0: + if broadcast_channel_abort_controller.is_aborted: + raise UserAbortError("Aborted by user.") + await asyncio.sleep(min(0.05, remaining_seconds)) + remaining_seconds -= 0.05 + + finally: + broadcast_channel_abort_controller.stop() + + return wrapped + + return decorator From 337824442afb04e9d446b0accc310e227c66bd62 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Fri, 27 Feb 2026 11:05:10 -0800 Subject: [PATCH 2/4] update: cleanup --- utils/interrupts.py | 62 ++++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/utils/interrupts.py b/utils/interrupts.py index f7435bf3..fa2dbfcf 100644 --- a/utils/interrupts.py +++ b/utils/interrupts.py @@ -1,11 +1,12 @@ import asyncio import inspect -import sys import uuid from dataclasses import dataclass from functools import wraps from typing import Any, Awaitable, Callable +from mat3ra.utils.jupyterlite.environment import ENVIRONMENT, EnvironmentsEnum + try: from IPython.display import HTML, display # type: ignore except Exception: @@ -99,7 +100,7 @@ def __post_init__(self) -> None: self._on_message_proxy = None def start(self) -> None: - if sys.platform != "emscripten": + if ENVIRONMENT != EnvironmentsEnum.PYODIDE: return if self._broadcast_channel is not None: return @@ -115,7 
+116,7 @@ def on_message(event) -> None: self.is_aborted = True self._on_message_proxy = create_proxy(on_message) - self._broadcast_channel.onmessage = self._on_message_proxy + self._broadcast_channel.onmessage = self._on_message_proxy # type: ignore def stop(self) -> None: if self._broadcast_channel is None: @@ -151,7 +152,7 @@ async def run_interruptible_loop_async( broadcast_channel_abort_controller = BroadcastChannelAbortController(channel_name=channel_name) broadcast_channel_abort_controller.start() - if show_controls and sys.platform == "emscripten": + if show_controls and ENVIRONMENT == EnvironmentsEnum.PYODIDE: display_abort_controls_in_current_cell_output(channel_name=channel_name, abort_button_text="Abort") try: @@ -171,7 +172,14 @@ async def run_interruptible_loop_async( broadcast_channel_abort_controller.stop() -def interruptible_polling_loop(poll_interval_kwarg_name: str = "poll_interval"): +def interruptible_polling_loop( + poll_interval_kwarg_name: str = "poll_interval", + *, + default_poll_interval_seconds: float = 10.0, + channel_name: str = "mat3ra_abort_channel", + check_interval_seconds: float = 0.05, + show_controls: bool = True, +): """ Decorator for single-iteration polling functions. @@ -179,37 +187,29 @@ def interruptible_polling_loop(poll_interval_kwarg_name: str = "poll_interval"): - return True to continue - return False to stop - poll_interval is passed at call time. + The decorated function becomes an `async def` that runs the polling loop until completion. + + The polling interval can be passed at call time using `poll_interval_kwarg_name` (defaults to + `"poll_interval"`). If not provided, `default_poll_interval_seconds` is used. 
""" def decorator(poll_step_function: Callable[..., Any]) -> Callable[..., Any]: @wraps(poll_step_function) async def wrapped(*args: Any, **kwargs: Any) -> None: - poll_interval_seconds = kwargs.pop(poll_interval_kwarg_name) - - broadcast_channel_abort_controller = BroadcastChannelAbortController() - broadcast_channel_abort_controller.start() - - if sys.platform == "emscripten": - display_abort_controls_in_current_cell_output() - - try: - while True: - result = poll_step_function(*args, **kwargs) - should_continue = await result if inspect.isawaitable(result) else result - - if not should_continue: - return - - remaining_seconds = float(poll_interval_seconds) - while remaining_seconds > 0: - if broadcast_channel_abort_controller.is_aborted: - raise UserAbortError("Aborted by user.") - await asyncio.sleep(min(0.05, remaining_seconds)) - remaining_seconds -= 0.05 - - finally: - broadcast_channel_abort_controller.stop() + poll_interval_seconds = float(kwargs.pop(poll_interval_kwarg_name, default_poll_interval_seconds)) + + async def loop_body() -> bool: + result = poll_step_function(*args, **kwargs) + should_continue = await result if inspect.isawaitable(result) else result + return bool(should_continue) + + await run_interruptible_loop_async( + loop_body, + poll_interval_seconds, + channel_name=channel_name, + check_interval_seconds=check_interval_seconds, + show_controls=show_controls, + ) return wrapped From f90b3187602cbcab2616a44d24efddf44ac37eeb Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Fri, 27 Feb 2026 11:09:04 -0800 Subject: [PATCH 3/4] update: cleanup 2 --- utils/api.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/utils/api.py b/utils/api.py index 903f255e..be3a2e13 100644 --- a/utils/api.py +++ b/utils/api.py @@ -2,6 +2,7 @@ import json import os import urllib.request +from collections import Counter from typing import List, Optional from mat3ra.api_client.endpoints.bank_workflows import 
BankWorkflowEndpoints @@ -58,8 +59,21 @@ def get_jobs_statuses_by_ids(endpoint: JobEndpoints, job_ids: List[str]) -> List return [job["status"] for job in jobs] +def _print_jobs_status_table(counts: Counter) -> None: + headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"] + now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + row = [ + now, + counts.get("submitted", 0), + counts.get("active", 0), + counts.get("finished", 0), + counts.get("error", 0), + ] + print(tabulate([row], headers, tablefmt="grid", stralign="center")) + + @interruptible_polling_loop() -def wait_for_jobs_to_finish_async(endpoint: JobEndpoints, job_ids: list) -> bool: +def wait_for_jobs_to_finish_async(endpoint: JobEndpoints, job_ids: List[str]) -> bool: """ Waits for jobs to finish and prints their statuses. A job is considered finished if it is not in "pre-submission", "submitted", or, "active" status. @@ -67,21 +81,13 @@ def wait_for_jobs_to_finish_async(endpoint: JobEndpoints, job_ids: list) -> bool Args: endpoint (JobEndpoints): Job endpoint object from the Exabyte API Client job_ids (list): list of job IDs to wait for - poll_interval (int): poll interval for job information in seconds. Defaults to 10. 
""" statuses = get_jobs_statuses_by_ids(endpoint, job_ids) + counts = Counter(statuses) + _print_jobs_status_table(counts) - errored_jobs = sum(status == "error" for status in statuses) - active_jobs = sum(status == "active" for status in statuses) - finished_jobs = sum(status == "finished" for status in statuses) - submitted_jobs = sum(status == "submitted" for status in statuses) - - headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"] - now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") - row = [now, submitted_jobs, active_jobs, finished_jobs, errored_jobs] - print(tabulate([row], headers, tablefmt="grid", stralign="center")) - - return any(status in ["pre-submission", "submitted", "active"] for status in statuses) + active_statuses = {"pre-submission", "submitted", "active"} + return any(status in active_statuses for status in statuses) def copy_bank_workflow_by_system_name(endpoint: BankWorkflowEndpoints, system_name: str, account_id: str) -> dict: From 3335ee12512cb76b0ea02fd219b0aa4cdd2709a8 Mon Sep 17 00:00:00 2001 From: VsevolodX Date: Fri, 27 Feb 2026 11:21:14 -0800 Subject: [PATCH 4/4] update: use interrupt in nb --- other/materials_designer/workflows/band_gap.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/other/materials_designer/workflows/band_gap.ipynb b/other/materials_designer/workflows/band_gap.ipynb index 5febaf86..83cf69b3 100644 --- a/other/materials_designer/workflows/band_gap.ipynb +++ b/other/materials_designer/workflows/band_gap.ipynb @@ -589,9 +589,9 @@ "metadata": {}, "outputs": [], "source": [ - "from utils.api import wait_for_jobs_to_finish\n", + "from utils.api import wait_for_jobs_to_finish_async\n", "\n", - "wait_for_jobs_to_finish(client.jobs, [job_id], poll_interval=POLL_INTERVAL)" + "await wait_for_jobs_to_finish_async(client.jobs, [job_id], poll_interval=POLL_INTERVAL)" ] }, {