Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions other/materials_designer/workflows/band_gap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,9 @@
"metadata": {},
"outputs": [],
"source": [
"from utils.api import wait_for_jobs_to_finish\n",
"from utils.api import wait_for_jobs_to_finish_async\n",
"\n",
"wait_for_jobs_to_finish(client.jobs, [job_id], poll_interval=POLL_INTERVAL)"
"await wait_for_jobs_to_finish_async(client.jobs, [job_id], poll_interval=POLL_INTERVAL)"
]
},
{
Expand Down
42 changes: 23 additions & 19 deletions utils/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import datetime
import json
import os
import time
import urllib.request
from collections import Counter
from typing import List, Optional

from mat3ra.api_client.endpoints.bank_workflows import BankWorkflowEndpoints
Expand All @@ -12,6 +12,8 @@
from mat3ra.api_client.endpoints.workflows import WorkflowEndpoints
from tabulate import tabulate

from utils.interrupts import interruptible_polling_loop


def save_files(job_id: str, job_endpoint: JobEndpoints, filename_on_cloud: str, filename_on_disk: str) -> None:
"""
Expand Down Expand Up @@ -57,33 +59,35 @@ def get_jobs_statuses_by_ids(endpoint: JobEndpoints, job_ids: List[str]) -> List
return [job["status"] for job in jobs]


def wait_for_jobs_to_finish(endpoint: JobEndpoints, job_ids: list, poll_interval: int = 10) -> None:
def _print_jobs_status_table(counts: Counter) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"]
now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
row = [
now,
counts.get("submitted", 0),
counts.get("active", 0),
counts.get("finished", 0),
counts.get("error", 0),
]
print(tabulate([row], headers, tablefmt="grid", stralign="center"))


@interruptible_polling_loop()
def wait_for_jobs_to_finish_async(endpoint: JobEndpoints, job_ids: List[str]) -> bool:
"""
Waits for jobs to finish and prints their statuses.
A job is considered finished if it is not in "pre-submission", "submitted", or, "active" status.

Args:
endpoint (JobEndpoints): Job endpoint object from the Exabyte API Client
job_ids (list): list of job IDs to wait for
poll_interval (int): poll interval for job information in seconds. Defaults to 10.
"""
print("Wait for jobs to finish, poll interval: {0} sec".format(poll_interval))
while True:
statuses = get_jobs_statuses_by_ids(endpoint, job_ids)

errored_jobs = len([status for status in statuses if status == "error"])
active_jobs = len([status for status in statuses if status == "active"])
finished_jobs = len([status for status in statuses if status == "finished"])
submitted_jobs = len([status for status in statuses if status == "submitted"])

headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"]
now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
row = [now, submitted_jobs, active_jobs, finished_jobs, errored_jobs]
print(tabulate([row], headers, tablefmt="grid", stralign="center"))
statuses = get_jobs_statuses_by_ids(endpoint, job_ids)
counts = Counter(statuses)
_print_jobs_status_table(counts)

if all([status not in ["pre-submission", "submitted", "active"] for status in statuses]):
break
time.sleep(poll_interval)
active_statuses = {"pre-submission", "submitted", "active"}
return any(status in active_statuses for status in statuses)


def copy_bank_workflow_by_system_name(endpoint: BankWorkflowEndpoints, system_name: str, account_id: str) -> dict:
Expand Down
216 changes: 216 additions & 0 deletions utils/interrupts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import asyncio
import inspect
import uuid
from dataclasses import dataclass
from functools import wraps
from typing import Any, Awaitable, Callable

from mat3ra.utils.jupyterlite.environment import ENVIRONMENT, EnvironmentsEnum
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's put this to utils


try:
from IPython.display import HTML, display # type: ignore
except Exception:
HTML = None
display = None


class UserAbortError(RuntimeError):
pass


def display_abort_controls_in_current_cell_output(
channel_name: str = "mat3ra_abort_channel",
abort_button_text: str = "Abort",
) -> None:
"""
Shows:
[Abort] Press ESC to abort

Only works in notebook frontends that support HTML output.
Safe no-op otherwise.
"""
if HTML is None or display is None:
return

element_id = f"abort_controls_{uuid.uuid4().hex}"

display(
HTML(
f"""
<div style="display:flex; align-items:center; gap:12px; margin:8px 0;">
<button
id="{element_id}_button"
style="
background:#d32f2f; color:white; border:none; padding:8px 14px;
border-radius:6px; cursor:pointer; font-weight:600;
"
>{abort_button_text}</button>

<span style="font-family:monospace; opacity:0.85;">Press ESC to abort</span>
<span id="{element_id}_status" style="font-family:monospace; opacity:0.85;"></span>
</div>

<script>
(function() {{
const channelName = {channel_name!r};

// Install ESC broadcaster once per page
if (!window.__mat3raEscapeAbortInstalled) {{
window.__mat3raEscapeAbortInstalled = true;
const escChannel = new BroadcastChannel(channelName);
document.addEventListener("keydown", (event) => {{
if (event.key === "Escape") {{
escChannel.postMessage({{ type: "abort", source: "escape" }});
}}
}}, true);
}}

// Button broadcaster (this output)
const buttonChannel = new BroadcastChannel(channelName);
const buttonElement = document.getElementById("{element_id}_button");
const statusElement = document.getElementById("{element_id}_status");
if (!buttonElement) return;

buttonElement.addEventListener("click", () => {{
buttonChannel.postMessage({{ type: "abort", source: "button" }});
if (statusElement) {{
statusElement.textContent = "Abort sent";
statusElement.style.color = "#d32f2f";
}}
}});
}})();
</script>
"""
)
)


@dataclass
class BroadcastChannelAbortController:
"""
WebWorker-side receiver. Works only in pyodide (emscripten).
In regular Python: start() does nothing and is_aborted stays False.
"""

channel_name: str = "mat3ra_abort_channel"
is_aborted: bool = False

def __post_init__(self) -> None:
self._broadcast_channel = None
self._on_message_proxy = None

def start(self) -> None:
if ENVIRONMENT != EnvironmentsEnum.PYODIDE:
return
if self._broadcast_channel is not None:
return

import js # type: ignore
from pyodide.ffi import create_proxy # type: ignore

self._broadcast_channel = js.BroadcastChannel.new(self.channel_name)

def on_message(event) -> None:
message = getattr(event, "data", None)
if message and getattr(message, "type", None) == "abort":
self.is_aborted = True

self._on_message_proxy = create_proxy(on_message)
self._broadcast_channel.onmessage = self._on_message_proxy # type: ignore

def stop(self) -> None:
if self._broadcast_channel is None:
return

self._broadcast_channel.close()
self._broadcast_channel = None

if self._on_message_proxy is not None:
self._on_message_proxy.destroy()
self._on_message_proxy = None


async def run_interruptible_loop_async(
loop_body: Callable[[], Awaitable[bool]],
poll_interval_seconds: float,
*,
channel_name: str = "mat3ra_abort_channel",
check_interval_seconds: float = 0.05,
show_controls: bool = True,
) -> None:
"""
Minimal wrapper.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?


loop_body():
- do one "poll" iteration
- return True to keep looping, False to stop normally

Between iterations we sleep in small slices so:
- pyodide: ESC/button can be received and stop the loop
- regular Python: yields control (Ctrl+C/Stop works where supported)
"""
broadcast_channel_abort_controller = BroadcastChannelAbortController(channel_name=channel_name)
broadcast_channel_abort_controller.start()

if show_controls and ENVIRONMENT != EnvironmentsEnum.PYODIDE:
display_abort_controls_in_current_cell_output(channel_name=channel_name, abort_button_text="Abort")

try:
while True:
should_continue = await loop_body()
if not should_continue:
return

remaining_seconds = float(poll_interval_seconds)
while remaining_seconds > 0:
if broadcast_channel_abort_controller.is_aborted:
raise UserAbortError("Aborted by user.")
await asyncio.sleep(min(check_interval_seconds, remaining_seconds))
remaining_seconds -= check_interval_seconds

finally:
broadcast_channel_abort_controller.stop()


def interruptible_polling_loop(
poll_interval_kwarg_name: str = "poll_interval",
*,
default_poll_interval_seconds: float = 10.0,
channel_name: str = "mat3ra_abort_channel",
check_interval_seconds: float = 0.05,
show_controls: bool = True,
):
"""
Decorator for single-iteration polling functions.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not clear


The decorated function must:
- return True to continue
- return False to stop

The decorated function becomes an `async def` that runs the polling loop until completion.

The polling interval can be passed at call time using `poll_interval_kwarg_name` (defaults to
`"poll_interval"`). If not provided, `default_poll_interval_seconds` is used.
"""

def decorator(poll_step_function: Callable[..., Any]) -> Callable[..., Any]:
@wraps(poll_step_function)
async def wrapped(*args: Any, **kwargs: Any) -> None:
poll_interval_seconds = float(kwargs.pop(poll_interval_kwarg_name, default_poll_interval_seconds))

async def loop_body() -> bool:
result = poll_step_function(*args, **kwargs)
should_continue = await result if inspect.isawaitable(result) else result
return bool(should_continue)

await run_interruptible_loop_async(
loop_body,
poll_interval_seconds,
channel_name=channel_name,
check_interval_seconds=check_interval_seconds,
show_controls=show_controls,
)

return wrapped

return decorator