Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 68 additions & 3 deletions astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,34 @@
from .run_context import TContext
from .tool import FunctionTool


class _McpSseNoiseFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
try:
msg = record.getMessage().strip()
except Exception:
return True
if msg.startswith("Unknown SSE event:"):
event_name = msg.split(":", 1)[1].strip()
if event_name in {"stream", "connection"}:
return False
return True


def _install_mcp_noise_filters() -> None:
for logger_name in ("mcp.client.streamable_http", "mcp.client.sse"):
log = logging.getLogger(logger_name)
if any(isinstance(f, _McpSseNoiseFilter) for f in log.filters):
continue
log.addFilter(_McpSseNoiseFilter())


try:
import anyio
import mcp
from mcp.client.sse import sse_client

_install_mcp_noise_filters()
except (ModuleNotFoundError, ImportError):
logger.warning(
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
Expand All @@ -47,6 +71,8 @@ def _prepare_config(config: dict) -> dict:

async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"""Quick test MCP server connectivity"""
import json

import aiohttp

cfg = _prepare_config(config.copy())
Expand All @@ -55,6 +81,40 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
headers = cfg.get("headers", {})
timeout = cfg.get("timeout", 10)

async def _format_http_error(response: aiohttp.ClientResponse) -> str:
reason = response.reason or ""
detail = ""
try:
raw = await response.content.read(2048)
if raw:
text = raw.decode(errors="replace").strip()
if text:
try:
data = json.loads(text)
except Exception:
detail = text
else:
if isinstance(data, dict):
msg = (
data.get("message")
or data.get("error")
or data.get("detail")
)
code = data.get("code")
if msg is not None:
detail = (
f"{code}: {msg}" if code is not None else str(msg)
)
else:
detail = text
else:
detail = text
except Exception:
detail = ""
if detail:
return f"HTTP {response.status}: {reason} ({detail})"
return f"HTTP {response.status}: {reason}"

try:
if "transport" in cfg:
transport_type = cfg["transport"]
Expand All @@ -70,7 +130,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
"method": "initialize",
"id": 0,
"params": {
"protocolVersion": "2024-11-05",
"protocolVersion": mcp.types.LATEST_PROTOCOL_VERSION,
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.2.3"},
},
Expand All @@ -87,7 +147,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
return False, await _format_http_error(response)
else:
async with session.get(
url,
Expand All @@ -99,7 +159,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
) as response:
if response.status == 200:
return True, ""
return False, f"HTTP {response.status}: {response.reason}"
return False, await _format_http_error(response)

except asyncio.TimeoutError:
return False, f"Connection timeout: {timeout} seconds"
Expand Down Expand Up @@ -146,6 +206,11 @@ async def connect_to_server(self, mcp_server_config: dict, name: str) -> None:

def logging_callback(msg: str) -> None:
# Handle MCP service error logs
normalized = msg.strip()
if normalized.startswith("Unknown SSE event:"):
event_name = normalized.split(":", 1)[1].strip()
if event_name in {"stream", "connection"}:
return
Comment on lines +210 to +213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

这段过滤SSE噪音事件的逻辑与文件顶部的 _McpSseNoiseFilter 类中的逻辑重复了。为了提高代码的可维护性和复用性,建议将这部分逻辑提取到一个独立的辅助函数中,供两处调用。

另外,当前的 split 操作没有处理不包含':'的字符串,可能会导致 IndexError。建议增加对 split 结果的长度检查以增强代码的健壮性,并可以简化逻辑。

Suggested change
if normalized.startswith("Unknown SSE event:"):
event_name = normalized.split(":", 1)[1].strip()
if event_name in {"stream", "connection"}:
return
if normalized.startswith("Unknown SSE event:"):
parts = normalized.split(":", 1)
if len(parts) > 1 and parts[1].strip() in {"stream", "connection"}:
return

print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The server_errlogs list grows indefinitely as error messages are received from the MCP server. A malicious or malfunctioning server could flood the bot with messages, leading to memory exhaustion and a Denial of Service (DoS). Consider using a fixed-size buffer like collections.deque(maxlen=100) to store error logs.


Expand Down
Loading