Skip to content

Commit 6fff187

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add a keep_alive endpoint in Fast api server template and implement it for Adk template
PiperOrigin-RevId: 879663369
1 parent 2b0a98c commit 6fff187

2 files changed

Lines changed: 179 additions & 0 deletions

File tree

tests/unit/vertex_adk/test_agent_engine_templates_adk.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,106 @@ def test_register_operations(self):
368368
for operation in operations:
369369
assert operation in dir(app)
370370

371+
@mock.patch("os.rename")
372+
@mock.patch("tempfile.NamedTemporaryFile")
373+
@mock.patch("time.time", return_value=1000.0)
374+
@mock.patch("os.getpid", return_value=12345)
375+
@mock.patch("os.path.exists", return_value=True)
376+
@mock.patch("os.path.isdir", return_value=True)
377+
def test__update_keep_alive_timestamp(
378+
self,
379+
isdir_mock,
380+
exists_mock,
381+
getpid_mock,
382+
time_mock,
383+
tempfile_mock,
384+
rename_mock,
385+
):
386+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
387+
388+
mock_tf_ret = mock.MagicMock() # this is result of NamedTemporaryFile call
389+
mock_tf_ret.name = "/dev/shm/tmp_xyz123"
390+
mock_tf_ret.__enter__.return_value = mock_tf_ret
391+
mock_tf_ret.__exit__.return_value = (None, None, None)
392+
tempfile_mock.return_value = mock_tf_ret
393+
394+
app._update_keep_alive_timestamp()
395+
396+
tempfile_mock.assert_called_once_with(
397+
"w",
398+
dir="/dev/shm",
399+
delete=False,
400+
prefix="tmp_keep_alive_timestamp_12345_",
401+
)
402+
pid = 12345
403+
lease = 60 * 60
404+
expected_timestamp = str(1000.0 + lease)
405+
mock_tf_ret.write.assert_called_once_with(expected_timestamp)
406+
rename_mock.assert_called_once_with(
407+
"/dev/shm/tmp_xyz123", f"/dev/shm/keep_alive_timestamp_{pid}"
408+
)
409+
410+
@mock.patch("glob.glob")
411+
@mock.patch("os.kill")
412+
@mock.patch("os.remove")
413+
@mock.patch("time.time")
414+
def test_keep_alive_no_files(self, time_mock, remove_mock, kill_mock, glob_mock):
415+
glob_mock.return_value = []
416+
time_mock.return_value = 1000.0
417+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
418+
assert not app.keep_alive()
419+
420+
@mock.patch("glob.glob")
421+
@mock.patch("os.kill")
422+
@mock.patch("os.remove")
423+
@mock.patch("time.time")
424+
def test_keep_alive_one_file_busy(
425+
self, time_mock, remove_mock, kill_mock, glob_mock
426+
):
427+
pid = 12345
428+
glob_mock.return_value = [f"/dev/shm/keep_alive_timestamp_{pid}"]
429+
time_mock.return_value = 1500.0
430+
mock_read_data = str(2000.0) # Timestamp in file is in future
431+
with mock.patch("builtins.open", mock.mock_open(read_data=mock_read_data)):
432+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
433+
assert app.keep_alive()
434+
kill_mock.assert_called_once_with(pid, 0)
435+
remove_mock.assert_not_called()
436+
437+
@mock.patch("glob.glob")
438+
@mock.patch("os.kill")
439+
@mock.patch("os.remove")
440+
@mock.patch("time.time")
441+
def test_keep_alive_one_file_not_busy(
442+
self, time_mock, remove_mock, kill_mock, glob_mock
443+
):
444+
pid = 12345
445+
glob_mock.return_value = [f"/dev/shm/keep_alive_timestamp_{pid}"]
446+
time_mock.return_value = 2500.0
447+
mock_read_data = str(2000.0) # Timestamp in file is in past
448+
with mock.patch("builtins.open", mock.mock_open(read_data=mock_read_data)):
449+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
450+
assert not app.keep_alive()
451+
kill_mock.assert_called_once_with(pid, 0)
452+
remove_mock.assert_not_called()
453+
454+
@mock.patch("glob.glob")
455+
@mock.patch("os.kill")
456+
@mock.patch("os.remove")
457+
@mock.patch("time.time")
458+
def test_keep_alive_stale_file_process_dead(
459+
self, time_mock, remove_mock, kill_mock, glob_mock
460+
):
461+
pid = 12345
462+
stale_file = f"/dev/shm/keep_alive_timestamp_{pid}"
463+
glob_mock.return_value = [stale_file]
464+
kill_mock.side_effect = ProcessLookupError()
465+
time_mock.return_value = 1000.0
466+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
467+
assert not app.keep_alive()
468+
kill_mock.assert_called_once_with(pid, 0)
469+
remove_mock.assert_called_once_with(stale_file)
470+
371471
def test_stream_query(
372472
self,
373473
default_instrumentor_builder_mock: mock.Mock,

vertexai/agent_engines/templates/adk.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@
9595

9696
_DEFAULT_APP_NAME = "default-app-name"
9797
_DEFAULT_USER_ID = "default-user-id"
98+
_KEEP_ALIVE_DIR = "/dev/shm"
99+
_KEEP_ALIVE_FILENAME_PREFIX = "keep_alive_timestamp"
100+
_KEEP_ALIVE_TEMP_FILENAME_PREFIX = "tmp_keep_alive_timestamp"
101+
_KEEP_ALIVE_LEASE_SECONDS = 60 #TODO: Change to 60 * 60 for 1 hour.
98102
_TELEMETRY_API_DISABLED_WARNING = (
99103
"Tracing integration for Agent Engine has migrated to a new API.\n"
100104
"The 'telemetry.googleapis.com' has not been enabled in project %s. \n"
@@ -992,6 +996,77 @@ def set_up(self):
992996
memory_service=self._tmpl_attrs.get("in_memory_memory_service"),
993997
)
994998

999+
def _update_keep_alive_timestamp(self):
1000+
"""Updates the keep-alive timestamp.
1001+
1002+
It writes the current timestamp to a file
1003+
/dev/shm/keep_alive_timestamp_{pid} where pid is the process id.
1004+
This is done atomically by writing to a temporary file and then renaming it.
1005+
This file can be checked by other processes to see if this agent process
1006+
is still alive and processing requests.
1007+
"""
1008+
import os
1009+
import tempfile
1010+
import time
1011+
1012+
try:
1013+
pid = os.getpid()
1014+
timestamp = str(time.time() + _KEEP_ALIVE_LEASE_SECONDS)
1015+
filename = f"{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_{pid}"
1016+
tmp_dir = _KEEP_ALIVE_DIR
1017+
if not os.path.exists(tmp_dir) or not os.path.isdir(tmp_dir):
1018+
return
1019+
with tempfile.NamedTemporaryFile(
1020+
"w",
1021+
dir=tmp_dir,
1022+
delete=False,
1023+
prefix=f"{_KEEP_ALIVE_TEMP_FILENAME_PREFIX}_{pid}_",
1024+
) as fp:
1025+
fp.write(timestamp)
1026+
tmp_path = fp.name
1027+
os.rename(tmp_path, filename)
1028+
except Exception as e:
1029+
# If there's any issue writing the timestamp, we log a warning
1030+
# and ignore it.
1031+
_warn(f"Failed to update keep-alive timestamp: {e}")
1032+
1033+
def keep_alive(self) -> bool:
1034+
"""Checks if the agent is busy."""
1035+
import glob
1036+
import os
1037+
import time
1038+
1039+
max_timestamp = -1.0
1040+
try:
1041+
timestamp_files = glob.glob(
1042+
f"{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_*"
1043+
)
1044+
for timestamp_file in timestamp_files:
1045+
try:
1046+
# Extract PID from filename (e.g., keep_alive_timestamp_1234)
1047+
basename = os.path.basename(timestamp_file)
1048+
pid_str = basename[len(_KEEP_ALIVE_FILENAME_PREFIX) + 1 :]
1049+
pid = int(pid_str)
1050+
1051+
# Check if the process that created the file is still running
1052+
os.kill(pid, 0)
1053+
1054+
with open(timestamp_file, "r") as f:
1055+
timestamp = float(f.read())
1056+
if timestamp > max_timestamp:
1057+
max_timestamp = timestamp
1058+
except (ProcessLookupError, ValueError, FileNotFoundError):
1059+
# If process is dead or file is missing/corrupt, remove the stale file
1060+
try:
1061+
os.remove(timestamp_file)
1062+
except FileNotFoundError:
1063+
pass
1064+
continue
1065+
except Exception as e:
1066+
_warn(f"Failed to read timestamp files: {e}")
1067+
1068+
return time.time() <= max_timestamp
1069+
9951070
async def async_stream_query(
9961071
self,
9971072
*,
@@ -1035,6 +1110,7 @@ async def async_stream_query(
10351110
from vertexai.agent_engines import _utils
10361111
from google.genai import types
10371112

1113+
self._update_keep_alive_timestamp()
10381114
if isinstance(message, Dict):
10391115
content = types.Content.model_validate(message)
10401116
elif isinstance(message, str):
@@ -1140,6 +1216,7 @@ def stream_query(
11401216
from vertexai.agent_engines import _utils
11411217
from google.genai import types
11421218

1219+
self._update_keep_alive_timestamp()
11431220
if isinstance(message, Dict):
11441221
content = types.Content.model_validate(message)
11451222
elif isinstance(message, str):
@@ -1191,6 +1268,7 @@ async def streaming_agent_run_with_events(self, request_json: str):
11911268
from google.genai import types
11921269
from google.genai.errors import ClientError
11931270

1271+
self._update_keep_alive_timestamp()
11941272
request = _StreamRunRequest(**json.loads(request_json))
11951273
if not any(
11961274
self._tmpl_attrs.get(service)
@@ -1659,6 +1737,7 @@ def register_operations(self) -> Dict[str, List[str]]:
16591737
"async_stream_query",
16601738
"streaming_agent_run_with_events",
16611739
],
1740+
"keep_alive": ["keep_alive"],
16621741
}
16631742

16641743
def _telemetry_enabled(self) -> Optional[bool]:

0 commit comments

Comments
 (0)