From 2cb53bbe918851ee9769e7678cc8d768890b44e9 Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Tue, 7 Apr 2026 13:18:58 -0700 Subject: [PATCH] Add background log streaming to detect TPU placement completion A background thread watches for specific log messages indicating that the proxy pod is waiting for placement until the TPU placement process has finished. This allows for better tracking of the Pathways service readiness. Continued "waiting" messages from proxy might indicate that the Pathways service doesn't have enough TPU availability to process the request. PiperOrigin-RevId: 896055974 --- .../shared_pathways_service/gke_utils.py | 26 +++++++ .../shared_pathways_service/isc_pathways.py | 69 ++++++++++++++++++- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/pathwaysutils/experimental/shared_pathways_service/gke_utils.py b/pathwaysutils/experimental/shared_pathways_service/gke_utils.py index a47199a..09ad42d 100644 --- a/pathwaysutils/experimental/shared_pathways_service/gke_utils.py +++ b/pathwaysutils/experimental/shared_pathways_service/gke_utils.py @@ -298,6 +298,32 @@ def enable_port_forwarding( return (port_available, port_forward_process) +def stream_pod_logs(pod_name: str) -> subprocess.Popen[str]: + """Streams logs from the given pod. + + Args: + pod_name: The name of the pod. + + Returns: + The process for streaming the logs. + + Raises: + Exception: If the log streaming fails. + """ + command = ["kubectl", "logs", "-f", pod_name] + try: + return subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, # Line buffered + ) + except Exception as _: + _logger.exception("Error streaming logs for pod %s", pod_name) + raise + + def delete_gke_job(job_name: str) -> None: """Deletes the given job from the GKE cluster. diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index 5e3c618..3c3b95a 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -9,6 +9,7 @@ import random import string import subprocess +import threading from typing import Any import jax @@ -123,6 +124,50 @@ def _deploy_pathways_proxy_server( _logger.info("Successfully deployed Pathways proxy.") +def _wait_for_placement( + pod_name: str, + num_slices: int, + stream_logs_func=gke_utils.stream_pod_logs, +) -> None: + """Waits for the placement to be complete by checking proxy logs.""" + _logger.info("Streaming proxy logs until the placement is complete...") + with stream_logs_func(pod_name) as log_process: + keywords = [ + "placement", + "Signaling to RM", + "Transition slice", + "FAILED_PRECONDITION", + ] + end_phrase = "unplaced -> placed" + placement_count = 0 + + if not log_process.stdout: + _logger.error("Log streaming process stdout is empty. Terminating.") + log_process.terminate() + _, stderr = log_process.communicate() + raise RuntimeError( + "Failed to stream proxy logs: stdout not available.\n" + f"STDERR: {stderr}" + ) + + for line in log_process.stdout: + line_lower = line.lower() + if any(keyword.lower() in line_lower for keyword in keywords): + _logger.info("Proxy log: %s", line.strip()) + + if end_phrase.lower() in line_lower: + placement_count += 1 + if placement_count < num_slices: + _logger.info( + "TPU slice %d/%d placed!", + placement_count, + num_slices, + ) + else: + _logger.info("TPU placement for %d slice(s) complete!", num_slices) + break + + def _restore_env_var(key: str, original_value: str | None) -> None: """Restores an environment variable to its original value or unsets it.""" if original_value is None: @@ -147,6 +192,7 @@ class _ISCPathways: expected_tpu_instances: A dictionary mapping TPU machine types to the number of instances. proxy_job_name: The name to use for the deployed proxy. + proxy_pod_name: The name of the proxy pod, assigned during deployment. proxy_server_image: The image to use for the proxy server. proxy_options: Configuration options for the Pathways proxy. """ @@ -171,6 +217,7 @@ def __init__( self.pathways_service = pathways_service self.expected_tpu_instances = expected_tpu_instances self._proxy_job_name = proxy_job_name + self.proxy_pod_name: str = "" self._port_forward_process = None self._proxy_port = None self.proxy_server_image = proxy_server_image @@ -220,9 +267,11 @@ def __enter__(self): ) _logger.info("View proxy logs in Cloud Logging: %s", cloud_logging_link) - proxy_pod = gke_utils.wait_for_pod(self._proxy_job_name) + self.proxy_pod_name = gke_utils.wait_for_pod(self._proxy_job_name) self._proxy_port, self._port_forward_process = ( - gke_utils.enable_port_forwarding(proxy_pod, PROXY_SERVER_PORT) + gke_utils.enable_port_forwarding( + self.proxy_pod_name, PROXY_SERVER_PORT + ) ) # Update the JAX backend to use the proxy. @@ -351,4 +400,20 @@ def connect( proxy_server_image=proxy_server_image, proxy_options=proxy_options, ) as t: + if t.proxy_pod_name: + num_slices = sum(t.expected_tpu_instances.values()) + placement_thread = threading.Thread( + target=_wait_for_placement, + args=( + t.proxy_pod_name, + num_slices, + ), + daemon=True, + ) + placement_thread.start() + else: + _logger.warning( + "proxy_pod_name not set on _ISCPathways instance, skipping background" + " _wait_for_placement." + ) yield t