@@ -629,7 +629,8 @@ def __init__(
629629 if app :
630630 if app_name :
631631 raise ValueError (
632- "When app is provided, app_name should not be provided."
632+ "When app is provided, app_name should not be provided, "
633+ "since it will be derived from app.name."
633634 )
634635 if agent :
635636 raise ValueError ("When app is provided, agent should not be provided." )
@@ -656,6 +657,11 @@ def __init__(
656657 ),
657658 }
658659
660+ def _app_name (self ) -> str :
661+ """Returns the app name."""
662+ app = self ._tmpl_attrs .get ("app" )
663+ return app .name if app else self ._tmpl_attrs .get ("app_name" )
664+
659665 async def _init_session (
660666 self ,
661667 session_service : "BaseSessionService" ,
@@ -672,9 +678,8 @@ async def _init_session(
672678 auth = _Authorization (** auth )
673679 session_state [auth_id ] = auth .access_token
674680
675- app = self ._tmpl_attrs .get ("app" )
676681 session = await session_service .create_session (
677- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
682+ app_name = self ._app_name ( ),
678683 user_id = request .user_id ,
679684 state = session_state ,
680685 )
@@ -694,7 +699,6 @@ async def _save_artifacts(
694699 request : _StreamRunRequest ,
695700 ):
696701 """Saves the artifacts."""
697- app = self ._tmpl_attrs .get ("app" )
698702 if request .artifacts :
699703 for artifact in request .artifacts :
700704 artifact = _Artifact (** artifact )
@@ -703,7 +707,7 @@ async def _save_artifacts(
703707 ):
704708 version_data = _ArtifactVersion (** version_data )
705709 saved_version = await artifact_service .save_artifact (
706- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
710+ app_name = self ._app_name ( ),
707711 user_id = request .user_id ,
708712 session_id = session_id ,
709713 filename = artifact .file_name ,
@@ -749,7 +753,7 @@ async def _convert_response_events(
749753 _ArtifactVersion (
750754 version = version ,
751755 data = await artifact_service .load_artifact (
752- app_name = self ._tmpl_attrs . get ( "app_name" ),
756+ app_name = self ._app_name ( ),
753757 user_id = user_id ,
754758 session_id = session_id ,
755759 filename = key ,
@@ -1206,7 +1210,6 @@ async def streaming_agent_run_with_events(self, request_json: str):
12061210 )
12071211 ):
12081212 self .set_up ()
1209- app = self ._tmpl_attrs .get ("app" )
12101213
12111214 # Try to get the session, if it doesn't exist, create a new one.
12121215 if request .session_id :
@@ -1216,7 +1219,7 @@ async def streaming_agent_run_with_events(self, request_json: str):
12161219 session = None
12171220 try :
12181221 session = await session_service .get_session (
1219- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1222+ app_name = self ._app_name ( ),
12201223 user_id = request .user_id ,
12211224 session_id = request .session_id ,
12221225 )
@@ -1269,7 +1272,7 @@ async def streaming_agent_run_with_events(self, request_json: str):
12691272 if session and not request .session_id :
12701273 app = self ._tmpl_attrs .get ("app" )
12711274 await session_service .delete_session (
1272- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1275+ app_name = self ._app_name ( ),
12731276 user_id = request .user_id ,
12741277 session_id = session .id ,
12751278 )
@@ -1306,9 +1309,8 @@ async def async_get_session(
13061309 """
13071310 if not self ._tmpl_attrs .get ("session_service" ):
13081311 self .set_up ()
1309- app = self ._tmpl_attrs .get ("app" )
13101312 session = await self ._tmpl_attrs .get ("session_service" ).get_session (
1311- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1313+ app_name = self ._app_name ( ),
13121314 user_id = user_id ,
13131315 session_id = session_id ,
13141316 ** kwargs ,
@@ -1384,9 +1386,8 @@ async def async_list_sessions(self, *, user_id: str, **kwargs):
13841386 """
13851387 if not self ._tmpl_attrs .get ("session_service" ):
13861388 self .set_up ()
1387- app = self ._tmpl_attrs .get ("app" )
13881389 return await self ._tmpl_attrs .get ("session_service" ).list_sessions (
1389- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1390+ app_name = self ._app_name ( ),
13901391 user_id = user_id ,
13911392 ** kwargs ,
13921393 )
@@ -1457,9 +1458,8 @@ async def async_create_session(
14571458 """
14581459 if not self ._tmpl_attrs .get ("session_service" ):
14591460 self .set_up ()
1460- app = self ._tmpl_attrs .get ("app" )
14611461 session = await self ._tmpl_attrs .get ("session_service" ).create_session (
1462- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1462+ app_name = self ._app_name ( ),
14631463 user_id = user_id ,
14641464 session_id = session_id ,
14651465 state = state ,
@@ -1539,9 +1539,8 @@ async def async_delete_session(
15391539 """
15401540 if not self ._tmpl_attrs .get ("session_service" ):
15411541 self .set_up ()
1542- app = self ._tmpl_attrs .get ("app" )
15431542 await self ._tmpl_attrs .get ("session_service" ).delete_session (
1544- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1543+ app_name = self ._app_name ( ),
15451544 user_id = user_id ,
15461545 session_id = session_id ,
15471546 ** kwargs ,
@@ -1630,9 +1629,8 @@ async def async_search_memory(self, *, user_id: str, query: str):
16301629 """
16311630 if not self ._tmpl_attrs .get ("memory_service" ):
16321631 self .set_up ()
1633- app = self ._tmpl_attrs .get ("app" )
16341632 return await self ._tmpl_attrs .get ("memory_service" ).search_memory (
1635- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1633+ app_name = self ._app_name ( ),
16361634 user_id = user_id ,
16371635 query = query ,
16381636 )
0 commit comments