diff --git a/veadk/toolkits/apps/reverse_mcp/client_with_reverse_mcp.py b/veadk/toolkits/apps/reverse_mcp/client_with_reverse_mcp.py index b6e97ccc..5d9a8b7c 100644 --- a/veadk/toolkits/apps/reverse_mcp/client_with_reverse_mcp.py +++ b/veadk/toolkits/apps/reverse_mcp/client_with_reverse_mcp.py @@ -23,14 +23,24 @@ class ClientWithReverseMCP: - def __init__(self, ws_url: str, mcp_server_url: str, client_id: str): + def __init__( + self, + ws_url: str, + mcp_server_url: str, + client_id: str, + filters: list[str] | None = None, + ): """Start a client with reverse mcp, Args: ws_url: The url of the websocket server (cloud). Like example.com:8000 mcp_server_url: The url of the mcp server (local). + client_id: The client id for the websocket connection. + filters: Optional list of tool names to filter (whitelist). If None, all tools are available. """ self.ws_url = f"ws://{ws_url}/ws?id={client_id}" + if filters: + self.ws_url += f"&filters={','.join(filters)}" self.mcp_server_url = mcp_server_url # set timeout for httpx client diff --git a/veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py b/veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py index 442ce6bf..defdb4ba 100644 --- a/veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py +++ b/veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py @@ -15,13 +15,20 @@ import asyncio import json import uuid -from typing import TYPE_CHECKING - -from fastapi import FastAPI, Request, Response, WebSocket +from typing import TYPE_CHECKING, Any, Optional + +from fastapi import FastAPI, HTTPException, Request, Response, WebSocket +from fastapi.responses import StreamingResponse +from google.adk.agents.run_config import StreamingMode +from google.adk.artifacts import InMemoryArtifactService +from google.adk.cli.adk_web_server import RunAgentRequest +from google.adk.runners import Runner as GoogleRunner, RunConfig +from google.adk.sessions import InMemorySessionService, Session from google.adk.tools.mcp_tool.mcp_session_manager import ( StreamableHTTPConnectionParams, ) from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +from google.adk.utils.context_utils import Aclosing from pydantic import BaseModel from veadk import Runner @@ -93,11 +100,15 @@ def __init__( self.port = port self.app = FastAPI() + + self.artifact_service = InMemoryArtifactService() + # build routes for self.app self.build() self.ws_session_mgr = WebsocketSessionManager() self.ws_agent_mgr: dict[str, "Agent"] = {} + self.ws_session_service_mgr: dict[str, "InMemorySessionService"] = {} def build(self): logger.info("Build routes for server with reverse mcp") @@ -126,19 +137,6 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse: agent = self.ws_agent_mgr[payload.websocket_id] - if not agent.tools: - logger.debug("Mount fake MCPToolset to agent") - - # we hard code the mcp url with `/mcp` to obey the mcp protocol - agent.tools.append( - MCPToolset( - connection_params=StreamableHTTPConnectionParams( - url=f"http://127.0.0.1:{self.port}/mcp", - headers={REVERSE_MCP_HEADER_KEY: payload.websocket_id}, - ), - ) - ) - runner = Runner(app_name=payload.app_name, agent=agent) response = await runner.run( messages=[prompt], @@ -152,6 +150,7 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse: @self.app.websocket("/ws") async def ws_endpoint(ws: WebSocket): client_id = ws.query_params.get("id") + if not client_id: await ws.close( code=400, @@ -159,11 +158,35 @@ async def ws_endpoint(ws: WebSocket): ) return + # Parse filters from query params, comma-separated string + filters_str = ws.query_params.get("filters") + filters = None + if filters_str: + filters = [t.strip() for t in filters_str.split(",") if t.strip()] + logger.info(f"Register websocket {client_id} to session manager.") self.ws_session_mgr.connections[client_id] = ws logger.info(f"Fork agent for websocket {client_id}") - self.ws_agent_mgr[client_id] = self.agent.clone() + agent = self.agent.clone() + + # Mount MCPToolset when creating agent + mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp" + mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: client_id} + logger.debug(f"Mount MCPToolset to agent for websocket {client_id}") + agent.tools.append( + MCPToolset( + connection_params=StreamableHTTPConnectionParams( + url=mcp_toolset_url, + headers=mcp_toolset_headers, + ), + tool_filter=filters, + ) + ) + self.ws_agent_mgr[client_id] = agent + + logger.info(f"Create session service for websocket {client_id}") + self.ws_session_service_mgr[client_id] = InMemorySessionService() await ws.accept() logger.info(f"Websocket {client_id} connected") @@ -172,8 +195,160 @@ async def ws_endpoint(ws: WebSocket): raw = await ws.receive_text() await self.ws_session_mgr.handle_ws_message(client_id, raw) + class CreateSessionRequest(BaseModel): + state: Optional[dict[str, Any]] = None + session_id: Optional[str] = None + websocket_id: str + + class RunAgentRequestWithWsId(RunAgentRequest): + websocket_id: str + + def _get_session_service(websocket_id: str) -> InMemorySessionService: + """Get session service for the websocket client.""" + if websocket_id not in self.ws_session_service_mgr: + raise HTTPException( + status_code=404, detail=f"WebSocket client {websocket_id} not found" + ) + return self.ws_session_service_mgr[websocket_id] + + @self.app.post( + "/apps/{app_name}/users/{user_id}/sessions", + response_model_exclude_none=True, + ) + async def create_session( + app_name: str, + user_id: str, + req: CreateSessionRequest, + ) -> Session: + """Create a new session.""" + session_id = req.session_id if req.session_id else str(uuid.uuid4()) + session = Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=req.state if req.state else {}, + ) + session_service = _get_session_service(req.websocket_id) + await session_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state=req.state if req.state else {}, + ) + logger.info( + f"Created session: {session_id} for user {user_id} in app {app_name}" + ) + return session + + @self.app.post( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", + response_model_exclude_none=True, + ) + async def create_session_with_id( + app_name: str, + user_id: str, + session_id: str, + req: CreateSessionRequest, + ) -> Session: + """Create a session with specific ID.""" + session_service = _get_session_service(req.websocket_id) + await session_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state=req.state if req.state else {}, + ) + session = Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=req.state if req.state else {}, + ) + logger.info(f"Created session with ID: {session_id} for user {user_id}") + return session + + @self.app.post("/run_sse") + async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse: + """Run agent with SSE streaming.""" + session_service = _get_session_service(req.websocket_id) + + # Get session + session = await session_service.get_session( + app_name=req.app_name, + user_id=req.user_id, + session_id=req.session_id, + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + # Get agent for this websocket + if req.websocket_id in self.ws_agent_mgr: + agent = self.ws_agent_mgr[req.websocket_id] + logger.debug(f"Using agent from websocket {req.websocket_id}") + else: + raise HTTPException( + status_code=404, + detail=f"WebSocket client {req.websocket_id} not found", + ) + + # Create runner + runner = GoogleRunner( + agent=agent, + app_name=req.app_name, + session_service=session_service, + artifact_service=self.artifact_service, + ) + + # Determine streaming mode from request + stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE + + async def event_generator(): + try: + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + run_config=RunConfig(streaming_mode=stream_mode), + invocation_id=req.invocation_id, + ) + ) as agen: + async for event in agen: + # ADK Web renders artifacts from `actions.artifactDelta` + # during part processing *and* during action processing + # 1) the original event with `artifactDelta` cleared (content) + # 2) a content-less "action-only" event carrying `artifactDelta` + events_to_stream = [event] + if ( + event.actions.artifact_delta + and event.content + and event.content.parts + ): + content_event = event.model_copy(deep=True) + content_event.actions.artifact_delta = {} + artifact_event = event.model_copy(deep=True) + artifact_event.content = None + events_to_stream = [content_event, artifact_event] + + for event_to_stream in events_to_stream: + sse_event = event_to_stream.model_dump_json( + exclude_none=True, by_alias=True + ) + logger.debug(f"SSE event: {sse_event}") + yield f"data: {sse_event}\n\n" + except Exception as e: + logger.exception(f"Error in event_generator: {e}") + yield f"data: {json.dumps({'error': str(e)})}\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + ) + # build the fake MPC server, # and intercept all requests to the client websocket client. + # NOTE: This catch-all route must be defined LAST @self.app.api_route("/{path:path}", methods=["GET", "POST"]) async def mcp_proxy(path: str, request: Request): client_id = request.headers.get(REVERSE_MCP_HEADER_KEY) @@ -202,10 +377,27 @@ async def mcp_proxy(path: str, request: Request): logger.debug(f"[Reverse mcp proxy] Response from local: {resp}") + # Filter hop-by-hop headers to avoid Content-Length mismatch + headers = resp["payload"]["headers"] + hop_by_hop_headers = { + "content-length", + "transfer-encoding", + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "upgrade", + } + filtered_headers = { + k: v for k, v in headers.items() if k.lower() not in hop_by_hop_headers + } + return Response( content=resp["payload"]["body"], # type: ignore status_code=resp["payload"]["status"], # type: ignore - headers=resp["payload"]["headers"], # type: ignore + headers=filtered_headers, # type: ignore ) def run(self):