def stream_pod_logs(pod_name: str) -> subprocess.Popen[str]:
  """Starts following a pod's logs via `kubectl logs -f`.

  The child's stderr is merged into its stdout, and the stdout pipe is opened
  in text mode with line buffering.

  Args:
    pod_name: Name of the pod whose logs should be followed.

  Returns:
    The `subprocess.Popen` handle of the running `kubectl` process.

  Raises:
    Exception: Whatever `subprocess.Popen` raised while launching the command;
      the error is logged and then re-raised unchanged.
  """
  kubectl_args = ["kubectl", "logs", "-f", pod_name]
  try:
    log_process = subprocess.Popen(
        kubectl_args,
        stdout=subprocess.PIPE,
        # Fold stderr into stdout so callers only have one stream to read.
        stderr=subprocess.STDOUT,
        text=True,
        # bufsize=1 selects line buffering in text mode.
        bufsize=1,
    )
  except Exception:
    _logger.exception("Error streaming logs for pod %s", pod_name)
    raise
  else:
    return log_process
def _wait_for_placement(
    pod_name: str,
    num_slices: int,
    stream_logs_func=gke_utils.stream_pod_logs,
) -> None:
  """Follows the proxy pod's logs until every TPU slice reports placement.

  Log lines are matched case-insensitively. Lines that mention one of the
  placement-related keywords are echoed to the logger, and every line that
  contains "unplaced -> placed" counts as one placed slice. The function
  returns once `num_slices` such lines were seen, or when the log stream ends
  (in which case a warning is logged).

  Args:
    pod_name: Name of the proxy pod whose logs are streamed.
    num_slices: Number of "unplaced -> placed" transitions to wait for.
    stream_logs_func: Callable that takes the pod name and returns a
      `subprocess.Popen`-like object: usable as a context manager and exposing
      `stdout`, `terminate()` and `communicate()`.

  Raises:
    RuntimeError: If the log process does not expose a usable stdout.
  """
  _logger.info("Streaming proxy logs until the placement is complete...")

  # Lowercased once up front so the per-line check is a plain substring test.
  keywords = tuple(
      keyword.lower()
      for keyword in (
          "placement",
          "Signaling to RM",
          "Transition slice",
          "FAILED_PRECONDITION",
      )
  )
  end_phrase = "unplaced -> placed"
  placement_count = 0

  with stream_logs_func(pod_name) as log_process:
    if not log_process.stdout:
      _logger.error("Log streaming process stdout is empty. Terminating.")
      log_process.terminate()
      _, stderr = log_process.communicate()
      raise RuntimeError(
          "Failed to stream proxy logs: stdout not available.\n"
          f"STDERR: {stderr}"
      )

    try:
      for line in log_process.stdout:
        line_lower = line.lower()
        if any(keyword in line_lower for keyword in keywords):
          _logger.info("Proxy log: %s", line.strip())

        if end_phrase not in line_lower:
          continue

        placement_count += 1
        if placement_count < num_slices:
          _logger.info(
              "TPU slice %d/%d placed!",
              placement_count,
              num_slices,
          )
        else:
          _logger.info("TPU placement for %d slice(s) complete!", num_slices)
          break
      else:
        # The stream hit EOF without the loop breaking, i.e. the log process
        # stopped producing output before all slices were reported as placed.
        _logger.warning(
            "Proxy log stream for pod %s ended after %d/%d slice(s) were"
            " placed.",
            pod_name,
            placement_count,
            num_slices,
        )
    finally:
      # Leaving the `with` block waits for the child to exit. A log follower
      # (`kubectl logs -f`) keeps running on its own, so stop it explicitly;
      # otherwise this call could block forever and leak the process.
      log_process.terminate()
@@ -351,4 +400,20 @@ def connect( proxy_server_image=proxy_server_image, proxy_options=proxy_options, ) as t: + if t.proxy_pod_name: + num_slices = sum(t.expected_tpu_instances.values()) + placement_thread = threading.Thread( + target=_wait_for_placement, + args=( + t.proxy_pod_name, + num_slices, + ), + daemon=True, + ) + placement_thread.start() + else: + _logger.warning( + "proxy_pod_name not set on _ISCPathways instance, skipping background" + " _wait_for_placement." + ) yield t