@@ -683,6 +683,20 @@ async def _init_session(
683683 if request .events :
684684 for event in request .events :
685685 await session_service .append_event (session , Event (** event ))
686+ if request .artifacts :
687+ await self ._save_artifacts (
688+ session .id , artifact_service , request
689+ )
690+ return session
691+
692+ async def _save_artifacts (
693+ self ,
694+ session_id : str ,
695+ artifact_service : "BaseArtifactService" ,
696+ request : _StreamRunRequest ,
697+ ):
698+ """Saves the artifacts."""
699+ app = self ._tmpl_attrs .get ("app" )
686700 if request .artifacts :
687701 for artifact in request .artifacts :
688702 artifact = _Artifact (** artifact )
@@ -693,7 +707,7 @@ async def _init_session(
693707 saved_version = await artifact_service .save_artifact (
694708 app_name = app .name if app else self ._tmpl_attrs .get ("app_name" ),
695709 user_id = request .user_id ,
696- session_id = session . id ,
710+ session_id = session_id ,
697711 filename = artifact .file_name ,
698712 artifact = version_data .data ,
699713 )
@@ -707,7 +721,6 @@ async def _init_session(
707721 saved_version ,
708722 version_data .version ,
709723 )
710- return session
711724
712725 async def _convert_response_events (
713726 self ,
@@ -1209,6 +1222,11 @@ async def streaming_agent_run_with_events(self, request_json: str):
12091222 user_id = request .user_id ,
12101223 session_id = request .session_id ,
12111224 )
1225+ self ._save_artifacts (
1226+ session_id = request .session_id ,
1227+ artifact_service = artifact_service ,
1228+ request = request
1229+ )
12121230 except ClientError :
12131231 pass
12141232 if not session :
0 commit comments