diff --git a/other/materials_designer/workflows/band_gap.ipynb b/other/materials_designer/workflows/band_gap.ipynb index 5febaf86..83cf69b3 100644 --- a/other/materials_designer/workflows/band_gap.ipynb +++ b/other/materials_designer/workflows/band_gap.ipynb @@ -589,9 +589,9 @@ "metadata": {}, "outputs": [], "source": [ - "from utils.api import wait_for_jobs_to_finish\n", + "from utils.api import wait_for_jobs_to_finish_async\n", "\n", - "wait_for_jobs_to_finish(client.jobs, [job_id], poll_interval=POLL_INTERVAL)" + "await wait_for_jobs_to_finish_async(client.jobs, [job_id], poll_interval=POLL_INTERVAL)" ] }, { diff --git a/utils/api.py b/utils/api.py index d8ada34d..be3a2e13 100644 --- a/utils/api.py +++ b/utils/api.py @@ -1,8 +1,8 @@ import datetime import json import os -import time import urllib.request +from collections import Counter from typing import List, Optional from mat3ra.api_client.endpoints.bank_workflows import BankWorkflowEndpoints @@ -12,6 +12,8 @@ from mat3ra.api_client.endpoints.workflows import WorkflowEndpoints from tabulate import tabulate +from utils.interrupts import interruptible_polling_loop + def save_files(job_id: str, job_endpoint: JobEndpoints, filename_on_cloud: str, filename_on_disk: str) -> None: """ @@ -57,7 +59,21 @@ def get_jobs_statuses_by_ids(endpoint: JobEndpoints, job_ids: List[str]) -> List return [job["status"] for job in jobs] -def wait_for_jobs_to_finish(endpoint: JobEndpoints, job_ids: list, poll_interval: int = 10) -> None: +def _print_jobs_status_table(counts: Counter) -> None: + headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"] + now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + row = [ + now, + counts.get("submitted", 0), + counts.get("active", 0), + counts.get("finished", 0), + counts.get("error", 0), + ] + print(tabulate([row], headers, tablefmt="grid", stralign="center")) + + +@interruptible_polling_loop() +def wait_for_jobs_to_finish_async(endpoint: JobEndpoints, job_ids: List[str]) -> bool: """ Waits for jobs to finish and prints their statuses. A job is considered finished if it is not in "pre-submission", "submitted", or, "active" status. @@ -65,25 +81,13 @@ def wait_for_jobs_to_finish(endpoint: JobEndpoints, job_ids: list, poll_interval Args: endpoint (JobEndpoints): Job endpoint object from the Exabyte API Client job_ids (list): list of job IDs to wait for - poll_interval (int): poll interval for job information in seconds. Defaults to 10. """ - print("Wait for jobs to finish, poll interval: {0} sec".format(poll_interval)) - while True: - statuses = get_jobs_statuses_by_ids(endpoint, job_ids) - - errored_jobs = len([status for status in statuses if status == "error"]) - active_jobs = len([status for status in statuses if status == "active"]) - finished_jobs = len([status for status in statuses if status == "finished"]) - submitted_jobs = len([status for status in statuses if status == "submitted"]) - - headers = ["TIME", "SUBMITTED-JOBS", "ACTIVE-JOBS", "FINISHED-JOBS", "ERRORED-JOBS"] - now = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") - row = [now, submitted_jobs, active_jobs, finished_jobs, errored_jobs] - print(tabulate([row], headers, tablefmt="grid", stralign="center")) + statuses = get_jobs_statuses_by_ids(endpoint, job_ids) + counts = Counter(statuses) + _print_jobs_status_table(counts) - if all([status not in ["pre-submission", "submitted", "active"] for status in statuses]): - break - time.sleep(poll_interval) + active_statuses = {"pre-submission", "submitted", "active"} + return any(status in active_statuses for status in statuses) def copy_bank_workflow_by_system_name(endpoint: BankWorkflowEndpoints, system_name: str, account_id: str) -> dict: diff --git a/utils/interrupts.py b/utils/interrupts.py new file mode 100644 index 00000000..fa2dbfcf --- /dev/null +++ b/utils/interrupts.py @@ -0,0 +1,216 @@ +import asyncio +import inspect +import uuid +from dataclasses import dataclass +from functools import wraps +from typing import Any, Awaitable, Callable + +from mat3ra.utils.jupyterlite.environment import ENVIRONMENT, EnvironmentsEnum + +try: + from IPython.display import HTML, display # type: ignore +except Exception: + HTML = None + display = None + + +class UserAbortError(RuntimeError): + pass + + +def display_abort_controls_in_current_cell_output( + channel_name: str = "mat3ra_abort_channel", + abort_button_text: str = "Abort", +) -> None: + """ + Shows: + [Abort] Press ESC to abort + + Only works in notebook frontends that support HTML output. + Safe no-op otherwise. + """ + if HTML is None or display is None: + return + + element_id = f"abort_controls_{uuid.uuid4().hex}" + + display( + HTML( + f""" +