Skip to content

Commit 3015b8d

Browse files
Tongzhou-Jiangcopybara-github
authored andcommitted
fix: save artifact in streaming agent run with events when multiturn
PiperOrigin-RevId: 873078156
1 parent 5705565 commit 3015b8d

2 files changed

Lines changed: 44 additions & 5 deletions

File tree

  • vertexai
    • agent_engines/templates
    • preview/reasoning_engines/templates

vertexai/agent_engines/templates/adk.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,20 @@ async def _init_session(
683683
if request.events:
684684
for event in request.events:
685685
await session_service.append_event(session, Event(**event))
686+
if request.artifacts:
687+
await self._save_artifacts(
688+
session.id, artifact_service, request
689+
)
690+
return session
691+
692+
async def _save_artifacts(
693+
self,
694+
session_id: str,
695+
artifact_service: "BaseArtifactService",
696+
request: _StreamRunRequest,
697+
):
698+
"""Saves the artifacts."""
699+
app = self._tmpl_attrs.get("app")
686700
if request.artifacts:
687701
for artifact in request.artifacts:
688702
artifact = _Artifact(**artifact)
@@ -693,7 +707,7 @@ async def _init_session(
693707
saved_version = await artifact_service.save_artifact(
694708
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
695709
user_id=request.user_id,
696-
session_id=session.id,
710+
session_id=session_id,
697711
filename=artifact.file_name,
698712
artifact=version_data.data,
699713
)
@@ -707,7 +721,6 @@ async def _init_session(
707721
saved_version,
708722
version_data.version,
709723
)
710-
return session
711724

712725
async def _convert_response_events(
713726
self,
@@ -1209,6 +1222,11 @@ async def streaming_agent_run_with_events(self, request_json: str):
12091222
user_id=request.user_id,
12101223
session_id=request.session_id,
12111224
)
1225+
await self._save_artifacts(
1226+
session_id=request.session_id,
1227+
artifact_service=artifact_service,
1228+
request=request,
1229+
)
12121230
except ClientError:
12131231
pass
12141232
if not session:

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,23 @@ async def _init_session(
616616
if request.events:
617617
for event in request.events:
618618
await session_service.append_event(session, Event(**event))
619+
if request.artifacts:
620+
await self._save_artifacts(
621+
session_id=session.id,
622+
artifact_service=artifact_service,
623+
request=request,
624+
)
625+
626+
return session
627+
628+
async def _save_artifacts(
629+
self,
630+
session_id: str,
631+
artifact_service: "BaseArtifactService",
632+
request: _StreamRunRequest,
633+
):
634+
"""Saves the artifacts."""
635+
app = self._tmpl_attrs.get("app")
619636
if request.artifacts:
620637
for artifact in request.artifacts:
621638
artifact = _Artifact(**artifact)
@@ -624,9 +641,9 @@ async def _init_session(
624641
):
625642
version_data = _ArtifactVersion(**version_data)
626643
saved_version = await artifact_service.save_artifact(
627-
app_name=self._tmpl_attrs.get("app_name"),
644+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
628645
user_id=request.user_id,
629-
session_id=session.id,
646+
session_id=session_id,
630647
filename=artifact.file_name,
631648
artifact=version_data.data,
632649
)
@@ -640,7 +657,6 @@ async def _init_session(
640657
saved_version,
641658
version_data.version,
642659
)
643-
return session
644660

645661
async def _convert_response_events(
646662
self,
@@ -1054,6 +1070,11 @@ async def _invoke_agent_async():
10541070
artifact_service=artifact_service,
10551071
request=request,
10561072
)
1073+
await self._save_artifacts(
1074+
session_id=request.session_id,
1075+
artifact_service=artifact_service,
1076+
request=request,
1077+
)
10571078
else:
10581079
# Not providing a session ID will create a new in-memory session.
10591080
session_service = self._tmpl_attrs.get("in_memory_session_service")

0 commit comments

Comments
 (0)