Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion veadk/toolkits/apps/reverse_mcp/client_with_reverse_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,24 @@


class ClientWithReverseMCP:
def __init__(self, ws_url: str, mcp_server_url: str, client_id: str):
def __init__(
self,
ws_url: str,
mcp_server_url: str,
client_id: str,
filters: list[str] | None = None,
):
"""Start a client with reverse mcp,

Args:
ws_url: The url of the websocket server (cloud). Like example.com:8000
mcp_server_url: The url of the mcp server (local).
client_id: The client id for the websocket connection.
filters: Optional list of tool names to filter (whitelist). If None, all tools are available.
"""
self.ws_url = f"ws://{ws_url}/ws?id={client_id}"
if filters:
self.ws_url += f"&filters={','.join(filters)}"
self.mcp_server_url = mcp_server_url

# set timeout for httpx client
Expand Down
228 changes: 210 additions & 18 deletions veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@
import asyncio
import json
import uuid
from typing import TYPE_CHECKING

from fastapi import FastAPI, Request, Response, WebSocket
from typing import TYPE_CHECKING, Any, Optional

from fastapi import FastAPI, HTTPException, Request, Response, WebSocket
from fastapi.responses import StreamingResponse
from google.adk.agents.run_config import StreamingMode
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.adk_web_server import RunAgentRequest
from google.adk.runners import Runner as GoogleRunner, RunConfig
from google.adk.sessions import InMemorySessionService, Session
from google.adk.tools.mcp_tool.mcp_session_manager import (
StreamableHTTPConnectionParams,
)
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
from google.adk.utils.context_utils import Aclosing
from pydantic import BaseModel

from veadk import Runner
Expand Down Expand Up @@ -93,11 +100,15 @@ def __init__(
self.port = port

self.app = FastAPI()

self.artifact_service = InMemoryArtifactService()

# build routes for self.app
self.build()

self.ws_session_mgr = WebsocketSessionManager()
self.ws_agent_mgr: dict[str, "Agent"] = {}
self.ws_session_service_mgr: dict[str, "InMemorySessionService"] = {}

def build(self):
logger.info("Build routes for server with reverse mcp")
Expand Down Expand Up @@ -126,19 +137,6 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:

agent = self.ws_agent_mgr[payload.websocket_id]

if not agent.tools:
logger.debug("Mount fake MCPToolset to agent")

# we hard code the mcp url with `/mcp` to obey the mcp protocol
agent.tools.append(
MCPToolset(
connection_params=StreamableHTTPConnectionParams(
url=f"http://127.0.0.1:{self.port}/mcp",
headers={REVERSE_MCP_HEADER_KEY: payload.websocket_id},
),
)
)

runner = Runner(app_name=payload.app_name, agent=agent)
response = await runner.run(
messages=[prompt],
Expand All @@ -152,18 +150,43 @@ async def invoke(payload: InvokeRequest) -> InvokeResponse:
@self.app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
client_id = ws.query_params.get("id")

if not client_id:
await ws.close(
code=400,
reason="WebSocket `id` is required like `/ws?id=my_id`",
)
return

# Parse filters from query params, comma-separated string
filters_str = ws.query_params.get("filters")
filters = None
if filters_str:
filters = [t.strip() for t in filters_str.split(",") if t.strip()]

logger.info(f"Register websocket {client_id} to session manager.")
self.ws_session_mgr.connections[client_id] = ws

logger.info(f"Fork agent for websocket {client_id}")
self.ws_agent_mgr[client_id] = self.agent.clone()
agent = self.agent.clone()

# Mount MCPToolset when creating agent
mcp_toolset_url = f"http://127.0.0.1:{self.port}/mcp"
mcp_toolset_headers = {REVERSE_MCP_HEADER_KEY: client_id}
logger.debug(f"Mount MCPToolset to agent for websocket {client_id}")
agent.tools.append(
MCPToolset(
connection_params=StreamableHTTPConnectionParams(
url=mcp_toolset_url,
headers=mcp_toolset_headers,
),
tool_filter=filters,
)
)
self.ws_agent_mgr[client_id] = agent

logger.info(f"Create session service for websocket {client_id}")
self.ws_session_service_mgr[client_id] = InMemorySessionService()

await ws.accept()
logger.info(f"Websocket {client_id} connected")
Expand All @@ -172,8 +195,160 @@ async def ws_endpoint(ws: WebSocket):
raw = await ws.receive_text()
await self.ws_session_mgr.handle_ws_message(client_id, raw)

class CreateSessionRequest(BaseModel):
state: Optional[dict[str, Any]] = None
session_id: Optional[str] = None
websocket_id: str

class RunAgentRequestWithWsId(RunAgentRequest):
websocket_id: str

def _get_session_service(websocket_id: str) -> InMemorySessionService:
"""Get session service for the websocket client."""
if websocket_id not in self.ws_session_service_mgr:
raise HTTPException(
status_code=404, detail=f"WebSocket client {websocket_id} not found"
)
return self.ws_session_service_mgr[websocket_id]

@self.app.post(
"/apps/{app_name}/users/{user_id}/sessions",
response_model_exclude_none=True,
)
async def create_session(
app_name: str,
user_id: str,
req: CreateSessionRequest,
) -> Session:
"""Create a new session."""
session_id = req.session_id if req.session_id else str(uuid.uuid4())
session = Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=req.state if req.state else {},
)
session_service = _get_session_service(req.websocket_id)
await session_service.create_session(
app_name=app_name,
user_id=user_id,
session_id=session_id,
state=req.state if req.state else {},
)
logger.info(
f"Created session: {session_id} for user {user_id} in app {app_name}"
)
return session

@self.app.post(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
async def create_session_with_id(
app_name: str,
user_id: str,
session_id: str,
req: CreateSessionRequest,
) -> Session:
"""Create a session with specific ID."""
session_service = _get_session_service(req.websocket_id)
await session_service.create_session(
app_name=app_name,
user_id=user_id,
session_id=session_id,
state=req.state if req.state else {},
)
session = Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=req.state if req.state else {},
)
logger.info(f"Created session with ID: {session_id} for user {user_id}")
return session

@self.app.post("/run_sse")
async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
"""Run agent with SSE streaming."""
session_service = _get_session_service(req.websocket_id)

# Get session
session = await session_service.get_session(
app_name=req.app_name,
user_id=req.user_id,
session_id=req.session_id,
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")

# Get agent for this websocket
if req.websocket_id in self.ws_agent_mgr:
agent = self.ws_agent_mgr[req.websocket_id]
logger.debug(f"Using agent from websocket {req.websocket_id}")
else:
raise HTTPException(
status_code=404,
detail=f"WebSocket client {req.websocket_id} not found",
)

# Create runner
runner = GoogleRunner(
agent=agent,
app_name=req.app_name,
session_service=session_service,
artifact_service=self.artifact_service,
)

# Determine streaming mode from request
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE

async def event_generator():
try:
async with Aclosing(
runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
run_config=RunConfig(streaming_mode=stream_mode),
invocation_id=req.invocation_id,
)
) as agen:
async for event in agen:
# ADK Web renders artifacts from `actions.artifactDelta`
# during part processing *and* during action processing
# 1) the original event with `artifactDelta` cleared (content)
# 2) a content-less "action-only" event carrying `artifactDelta`
events_to_stream = [event]
if (
event.actions.artifact_delta
and event.content
and event.content.parts
):
content_event = event.model_copy(deep=True)
content_event.actions.artifact_delta = {}
artifact_event = event.model_copy(deep=True)
artifact_event.content = None
events_to_stream = [content_event, artifact_event]

for event_to_stream in events_to_stream:
sse_event = event_to_stream.model_dump_json(
exclude_none=True, by_alias=True
)
logger.debug(f"SSE event: {sse_event}")
yield f"data: {sse_event}\n\n"
except Exception as e:
logger.exception(f"Error in event_generator: {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"

return StreamingResponse(
event_generator(),
media_type="text/event-stream",
)

# build the fake MPC server,
# and intercept all requests to the client websocket client.
# NOTE: This catch-all route must be defined LAST
@self.app.api_route("/{path:path}", methods=["GET", "POST"])
async def mcp_proxy(path: str, request: Request):
client_id = request.headers.get(REVERSE_MCP_HEADER_KEY)
Expand Down Expand Up @@ -202,10 +377,27 @@ async def mcp_proxy(path: str, request: Request):

logger.debug(f"[Reverse mcp proxy] Response from local: {resp}")

# Filter hop-by-hop headers to avoid Content-Length mismatch
headers = resp["payload"]["headers"]
hop_by_hop_headers = {
"content-length",
"transfer-encoding",
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailers",
"upgrade",
}
filtered_headers = {
k: v for k, v in headers.items() if k.lower() not in hop_by_hop_headers
}

return Response(
content=resp["payload"]["body"], # type: ignore
status_code=resp["payload"]["status"], # type: ignore
headers=resp["payload"]["headers"], # type: ignore
headers=filtered_headers, # type: ignore
)

def run(self):
Expand Down