diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 18f4d47e04..d654592986 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -19,10 +19,34 @@ from .run_context import TContext from .tool import FunctionTool + +class _McpSseNoiseFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + try: + msg = record.getMessage().strip() + except Exception: + return True + if msg.startswith("Unknown SSE event:"): + event_name = msg.split(":", 1)[1].strip() + if event_name in {"stream", "connection"}: + return False + return True + + +def _install_mcp_noise_filters() -> None: + for logger_name in ("mcp.client.streamable_http", "mcp.client.sse"): + log = logging.getLogger(logger_name) + if any(isinstance(f, _McpSseNoiseFilter) for f in log.filters): + continue + log.addFilter(_McpSseNoiseFilter()) + + try: import anyio import mcp from mcp.client.sse import sse_client + + _install_mcp_noise_filters() except (ModuleNotFoundError, ImportError): logger.warning( "Warning: Missing 'mcp' dependency, MCP services will be unavailable." @@ -47,6 +71,8 @@ def _prepare_config(config: dict) -> dict: async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: """Quick test MCP server connectivity""" + import json + import aiohttp cfg = _prepare_config(config.copy()) @@ -55,6 +81,40 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: headers = cfg.get("headers", {}) timeout = cfg.get("timeout", 10) + async def _format_http_error(response: aiohttp.ClientResponse) -> str: + reason = response.reason or "" + detail = "" + try: + raw = await response.content.read(2048) + if raw: + text = raw.decode(errors="replace").strip() + if text: + try: + data = json.loads(text) + except Exception: + detail = text + else: + if isinstance(data, dict): + msg = ( + data.get("message") + or data.get("error") + or data.get("detail") + ) + code = data.get("code") + if msg is not None: + detail = ( + f"{code}: {msg}" if code is not None else str(msg) + ) + else: + detail = text + else: + detail = text + except Exception: + detail = "" + if detail: + return f"HTTP {response.status}: {reason} ({detail})" + return f"HTTP {response.status}: {reason}" + try: if "transport" in cfg: transport_type = cfg["transport"] @@ -70,7 +130,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: "method": "initialize", "id": 0, "params": { - "protocolVersion": "2024-11-05", + "protocolVersion": mcp.types.LATEST_PROTOCOL_VERSION, "capabilities": {}, "clientInfo": {"name": "test-client", "version": "1.2.3"}, }, @@ -87,7 +147,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: ) as response: if response.status == 200: return True, "" - return False, f"HTTP {response.status}: {response.reason}" + return False, await _format_http_error(response) else: async with session.get( url, @@ -99,7 +159,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: ) as response: if response.status == 200: return True, "" - return False, f"HTTP {response.status}: {response.reason}" + return False, await _format_http_error(response) except asyncio.TimeoutError: return False, f"Connection timeout: {timeout} seconds" @@ -146,6 +206,11 @@ async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: def logging_callback(msg: str) -> None: # Handle MCP service error logs + normalized = msg.strip() + if normalized.startswith("Unknown SSE event:"): + event_name = normalized.split(":", 1)[1].strip() + if event_name in {"stream", "connection"}: + return print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 068c63c5ad..21853fda25 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -11,14 +11,18 @@ from types import MappingProxyType from typing import Any -import aiohttp - from astrbot import logger from astrbot.core import sp -from astrbot.core.agent.mcp_client import MCPClient, MCPTool +from astrbot.core.agent.mcp_client import ( + MCPClient, + MCPTool, + _quick_test_mcp_connection, +) from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from .mcp_sync_providers import SyncedMcpServer, get_mcp_sync_provider + DEFAULT_MCP_CONFIG = {"mcpServers": {}} DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 20.0 @@ -140,70 +144,6 @@ def _resolve_timeout( FuncTool = FunctionTool -def _prepare_config(config: dict) -> dict: - """准备配置,处理嵌套格式""" - if config.get("mcpServers"): - first_key = next(iter(config["mcpServers"])) - config = config["mcpServers"][first_key] - config.pop("active", None) - return config - - -async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: - """快速测试 MCP 服务器可达性""" - import aiohttp - - cfg = _prepare_config(config.copy()) - - url = cfg["url"] - headers = cfg.get("headers", {}) - timeout = cfg.get("timeout", 10) - - try: - async with aiohttp.ClientSession() as session: - if cfg.get("transport") == "streamable_http": - test_payload = { - "jsonrpc": "2.0", - "method": "initialize", - "id": 0, - "params": { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "clientInfo": {"name": "test-client", "version": "1.2.3"}, - }, - } - async with session.post( - url, - headers={ - **headers, - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - }, - json=test_payload, - timeout=aiohttp.ClientTimeout(total=timeout), - ) as response: - if response.status == 200: - return True, "" - return False, f"HTTP {response.status}: {response.reason}" - else: - async with session.get( - url, - headers={ - **headers, - "Accept": "application/json, text/event-stream", - }, - timeout=aiohttp.ClientTimeout(total=timeout), - ) as response: - if response.status == 200: - return True, "" - return False, f"HTTP {response.status}: {response.reason}" - - except asyncio.TimeoutError: - return False, f"连接超时: {timeout}秒" - except Exception as e: - return False, f"{e!s}" - - class FunctionToolManager: def __init__(self) -> None: self.func_list: list[FuncTool] = [] @@ -480,8 +420,19 @@ async def _start_mcp_server( raise MCPInitTimeoutError( f"MCP 服务 {name} 初始化超时({timeout:g} 秒)" ) from exc - except Exception: - logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True) + except Exception as e: + msg = str(e).lower() + is_invalid_key = ( + "invalid apikey" in msg or "invalid authorization key" in msg + ) + if is_invalid_key and str(cfg.get("provider", "")).lower() == "mcprouter": + logger.warning( + f"初始化 MCP 客户端 {name} 失败:MCPRouter API Key 无效(请重新同步/更新 API Key): {e!s}", + ) + elif is_invalid_key: + logger.warning(f"初始化 MCP 客户端 {name} 失败: {e!s}") + else: + logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True) raise finally: if mcp_client is None: @@ -841,71 +792,117 @@ def save_mcp_config(self, config: dict) -> bool: logger.error(f"保存 MCP 配置失败: {e}") return False - async def sync_modelscope_mcp_servers(self, access_token: str) -> None: - """从 ModelScope 平台同步 MCP 服务器配置""" - base_url = "https://www.modelscope.cn/openapi/v1" - url = f"{base_url}/mcp/servers/operational" - headers = { - "Authorization": f"Bearer {access_token.strip()}", - "Content-Type": "application/json", + async def _enable_mcp_servers_with_concurrency_limit( + self, + server_names: list[str], + config: dict, + *, + max_concurrency: int = 5, + timeout: int = 30, + ) -> tuple[int, dict[str, str]]: + sem = asyncio.Semaphore(max_concurrency) + failures: dict[str, str] = {} + + async def _enable_one(name: str) -> bool: + async with sem: + try: + if name in self.mcp_client_dict: + await self.disable_mcp_server(name, timeout=10) + await self.enable_mcp_server( + name=name, + config=config["mcpServers"][name], + timeout=timeout, + ) + return True + except Exception as e: + failures[name] = str(e) + logger.warning(f"启用 MCP 服务器失败: {name}, err={e!s}") + return False + + results = await asyncio.gather(*[_enable_one(n) for n in server_names]) + enabled_count = sum(1 for ok in results if ok) + return enabled_count, failures + + async def sync_mcp_servers_from_provider( + self, + provider_name: str, + payload: dict[str, Any], + *, + max_concurrency: int = 5, + ) -> dict[str, Any]: + provider = get_mcp_sync_provider(provider_name) + servers: list[SyncedMcpServer] = await provider.fetch(payload) + if not servers: + return { + "provider": provider_name, + "synced": 0, + "enabled": 0, + "failed": 0, + "failed_servers": [], + } + + local_mcp_config = self.load_mcp_config() + local_mcp_config.setdefault("mcpServers", {}) + + for item in servers: + local_mcp_config["mcpServers"][item.name] = item.config + + if not self.save_mcp_config(local_mcp_config): + raise RuntimeError("保存 MCP 配置失败,已取消同步启用") + + enabled_count, failures = await self._enable_mcp_servers_with_concurrency_limit( + [item.name for item in servers], + local_mcp_config, + max_concurrency=max_concurrency, + ) + + return { + "provider": provider_name, + "synced": len(servers), + "enabled": enabled_count, + "failed": len(failures), + "failed_servers": sorted(failures.keys()), } - try: - async with aiohttp.ClientSession() as session: - async with session.get(url, headers=headers) as response: - if response.status == 200: - data = await response.json() - mcp_server_list = data.get("data", {}).get( - "mcp_server_list", - [], - ) - local_mcp_config = self.load_mcp_config() - - synced_count = 0 - for server in mcp_server_list: - server_name = server["name"] - operational_urls = server.get("operational_urls", []) - if not operational_urls: - continue - url_info = operational_urls[0] - server_url = url_info.get("url") - if not server_url: - continue - # 添加到配置中(同名会覆盖) - local_mcp_config["mcpServers"][server_name] = { - "url": server_url, - "transport": "sse", - "active": True, - "provider": "modelscope", - } - synced_count += 1 - - if synced_count > 0: - self.save_mcp_config(local_mcp_config) - tasks = [] - for server in mcp_server_list: - name = server["name"] - tasks.append( - self.enable_mcp_server( - name=name, - config=local_mcp_config["mcpServers"][name], - ), - ) - await asyncio.gather(*tasks) - logger.info( - f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器", - ) - else: - logger.warning("没有找到可用的 ModelScope MCP 服务器") - else: - raise Exception( - f"ModelScope API 请求失败: HTTP {response.status}", - ) - - except aiohttp.ClientError as e: - raise Exception(f"网络连接错误: {e!s}") - except Exception as e: - raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}") + async def list_mcp_servers_from_provider( + self, + provider_name: str, + payload: dict[str, Any], + ) -> list[dict[str, Any]]: + provider = get_mcp_sync_provider(provider_name) + return await provider.list_servers(payload) + + async def sync_modelscope_mcp_servers(self, access_token: str) -> None: + await self.sync_mcp_servers_from_provider( + "modelscope", + {"access_token": access_token}, + ) + + async def sync_mcprouter_mcp_servers( + self, + api_key: str, + *, + app_url: str = "", + app_name: str = "AstrBot", + api_base: str = "https://api.mcprouter.to/v1", + limit: int = 100, + max_servers: int = 30, + query: str = "", + server_keys: str | list[str] | None = None, + ) -> dict[str, Any]: + return await self.sync_mcp_servers_from_provider( + "mcprouter", + { + "api_key": api_key, + "app_url": app_url, + "app_name": app_name, + "api_base": api_base, + "limit": limit, + "max_servers": max_servers, + "query": query, + "server_keys": server_keys, + }, + ) def __str__(self) -> str: return str(self.func_list) diff --git a/astrbot/core/provider/mcp_sync_providers.py b/astrbot/core/provider/mcp_sync_providers.py new file mode 100644 index 0000000000..57f76334ed --- /dev/null +++ b/astrbot/core/provider/mcp_sync_providers.py @@ -0,0 +1,560 @@ +from __future__ import annotations + +import abc +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from typing import Any, ClassVar + +import aiohttp + + +@dataclass(frozen=True, slots=True) +class SyncedMcpServer: + name: str + config: dict[str, Any] + + +class McpServerSyncProvider(abc.ABC): + provider: ClassVar[str] + + @abc.abstractmethod + async def fetch(self, payload: dict[str, Any]) -> list[SyncedMcpServer]: + raise NotImplementedError + + async def list_servers(self, payload: dict[str, Any]) -> list[dict[str, Any]]: + return [] + + +_provider_registry: dict[str, type[McpServerSyncProvider]] = {} + + +def register_mcp_sync_provider(provider: str): + def decorator(cls: type[McpServerSyncProvider]) -> type[McpServerSyncProvider]: + if provider in _provider_registry: + raise ValueError(f"MCP sync provider already registered: {provider}") + cls.provider = provider # type: ignore[attr-defined] + _provider_registry[provider] = cls + return cls + + return decorator + + +def get_mcp_sync_provider(provider: str) -> McpServerSyncProvider: + cls = _provider_registry.get(provider) + if not cls: + raise ValueError(f"Unknown MCP sync provider: {provider}") + return cls() + + +@register_mcp_sync_provider("modelscope") +class ModelscopeMcpServerSyncProvider(McpServerSyncProvider): + async def fetch(self, payload: dict[str, Any]) -> list[SyncedMcpServer]: + access_token = str(payload.get("access_token", "")).strip() + if not access_token: + raise ValueError("Missing required field: access_token") + + base_url = "https://www.modelscope.cn/openapi/v1" + url = f"{base_url}/mcp/servers/operational" + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as response: + if response.status != 200: + raise RuntimeError( + f"ModelScope API request failed: HTTP {response.status}" + ) + data = await response.json() + + mcp_server_list = data.get("data", {}).get("mcp_server_list", []) or [] + items: list[SyncedMcpServer] = [] + for server in mcp_server_list: + server_name = server.get("name") + operational_urls = server.get("operational_urls") or [] + if not server_name or not operational_urls: + continue + server_url = (operational_urls[0] or {}).get("url") + if not server_url: + continue + items.append( + SyncedMcpServer( + name=server_name, + config={ + "url": server_url, + "transport": "sse", + "active": True, + "provider": "modelscope", + }, + ) + ) + return items + + +@register_mcp_sync_provider("mcprouter") +class McpRouterMcpServerSyncProvider(McpServerSyncProvider): + @staticmethod + def _build_error_detail(data: dict[str, Any]) -> str: + message = data.get("message") or data.get("error") or data.get("detail") + code = data.get("code") + if message is None: + return "" + message_text = str(message).strip() + if not message_text: + return "" + if code is None: + return message_text + return f"{code}: {message_text}" + + async def _post_json( + self, + *, + session: aiohttp.ClientSession, + url: str, + payload: dict[str, Any], + headers: dict[str, str], + action: str, + ) -> dict[str, Any]: + async with session.post(url, json=payload, headers=headers) as response: + body_text = "" + data: dict[str, Any] = {} + try: + parsed = await response.json(content_type=None) + if isinstance(parsed, dict): + data = parsed + except Exception: + body_text = (await response.text()).strip() + + if response.status != 200: + reason = response.reason or "" + detail = self._build_error_detail(data) or body_text[:300] + if detail: + raise RuntimeError( + f"{action} failed: HTTP {response.status} {reason} ({detail})" + ) + raise RuntimeError(f"{action} failed: HTTP {response.status} {reason}") + + if not data: + detail = body_text[:300] if body_text else "empty or non-json response" + raise RuntimeError(f"{action} failed: invalid response ({detail})") + + return data + + def _ensure_api_success(self, data: dict[str, Any], *, action: str) -> None: + if data.get("code") == 0: + return + detail = self._build_error_detail(data) or "unknown error" + raise RuntimeError(f"{action} failed: {detail}") + + def _normalize_api_key(self, value: str) -> str: + raw = value.strip() + if not raw: + return raw + lower = raw.lower() + if lower.startswith("bearer "): + return raw[7:].strip() + if lower.startswith("authorization:"): + after = raw.split(":", 1)[1].strip() + if after.lower().startswith("bearer "): + return after[7:].strip() + return after + return raw + + def _build_api_headers( + self, + *, + api_key: str, + app_url: str, + app_name: str, + ) -> dict[str, str]: + headers: dict[str, str] = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + if app_url: + headers["HTTP-Referer"] = app_url + if app_name: + headers["X-Title"] = app_name + return headers + + def _build_mcp_headers( + self, + *, + api_key: str, + app_url: str, + app_name: str, + ) -> dict[str, str]: + return self._build_api_headers( + api_key=api_key, + app_url=app_url, + app_name=app_name, + ) + + @staticmethod + def _parse_server_keys(value: Any) -> list[str]: + if isinstance(value, list): + parts = [str(item).strip() for item in value] + elif isinstance(value, str): + raw = value.replace(",", "\n").replace(";", "\n") + parts = [line.strip() for line in raw.splitlines()] + else: + return [] + + keys = [item for item in parts if item] + seen: set[str] = set() + result: list[str] = [] + for item in keys: + if item in seen: + continue + seen.add(item) + result.append(item) + return result + + @staticmethod + def _matches(server: dict[str, Any], q: str) -> bool: + if not q: + return True + haystacks = [ + server.get("config_name"), + server.get("server_key"), + server.get("name"), + server.get("title"), + server.get("description"), + server.get("author_name"), + ] + combined = " ".join(str(v) for v in haystacks if v) + return q in combined.lower() + + @staticmethod + def _resolve_server_name( + server: dict[str, Any], + *, + fallback: str | None = None, + ) -> str | None: + return ( + server.get("config_name") + or server.get("server_key") + or server.get("name") + or server.get("title") + or fallback + ) + + def _make_item( + self, + *, + name: str, + url: str, + used_names: set[str], + headers: dict[str, str], + server_key: str | None = None, + ) -> SyncedMcpServer: + final_name = name + if final_name in used_names: + suffix = server_key or "dup" + final_name = f"{final_name}-{suffix}" + i = 2 + while final_name in used_names: + final_name = f"{name}-{i}" + i += 1 + used_names.add(final_name) + return SyncedMcpServer( + name=final_name, + config={ + "url": url, + "transport": "streamable_http", + "headers": headers, + "active": True, + "provider": "mcprouter", + }, + ) + + async def _iter_list_servers_batches( + self, + *, + list_url: str, + api_headers: dict[str, str], + limit: int, + max_pages: int, + ) -> AsyncGenerator[list[dict[str, Any]], None]: + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + for page in range(1, max_pages + 1): + data = await self._post_json( + session=session, + url=list_url, + payload={"page": page, "limit": limit}, + headers=api_headers, + action="MCPRouter list-servers", + ) + self._ensure_api_success(data, action="MCPRouter list-servers") + raw_batch = data.get("data", {}).get("servers", []) or [] + if not raw_batch: + break + batch = [item for item in raw_batch if isinstance(item, dict)] + if batch: + yield batch + if len(raw_batch) < limit: + break + + async def _fetch_from_provided_servers( + self, + *, + provided_servers: list[Any], + raw_max_servers: Any, + max_servers: int, + mcp_headers: dict[str, str], + ) -> list[SyncedMcpServer]: + used_names: set[str] = set() + items: list[SyncedMcpServer] = [] + selected_servers = ( + provided_servers[:max_servers] + if raw_max_servers is not None + else provided_servers + ) + for server in selected_servers: + if not isinstance(server, dict): + continue + server_name = self._resolve_server_name(server) + server_url = server.get("server_url") + if not server_name or not server_url: + continue + server_key = server.get("server_key") + items.append( + self._make_item( + name=str(server_name), + url=str(server_url), + used_names=used_names, + headers=mcp_headers, + server_key=str(server_key) if server_key else None, + ) + ) + return items + + async def _fetch_from_server_keys( + self, + *, + server_keys: list[str], + max_servers: int, + get_url: str, + api_headers: dict[str, str], + mcp_headers: dict[str, str], + ) -> list[SyncedMcpServer]: + timeout = aiohttp.ClientTimeout(total=30) + used_names: set[str] = set() + items: list[SyncedMcpServer] = [] + async with aiohttp.ClientSession(timeout=timeout) as session: + for server_key in server_keys[:max_servers]: + data = await self._post_json( + session=session, + url=get_url, + payload={"server": server_key}, + headers=api_headers, + action="MCPRouter get-server", + ) + self._ensure_api_success(data, action="MCPRouter get-server") + server = data.get("data") or {} + if not isinstance(server, dict): + continue + server_url = server.get("server_url") + server_name = self._resolve_server_name(server, fallback=server_key) + if not server_url or not server_name: + continue + items.append( + self._make_item( + name=str(server_name), + url=str(server_url), + used_names=used_names, + headers=mcp_headers, + server_key=server_key, + ) + ) + return items + + async def _fetch_from_listing( + self, + *, + list_url: str, + api_headers: dict[str, str], + mcp_headers: dict[str, str], + query: str, + max_servers: int, + limit: int, + max_pages: int, + ) -> list[SyncedMcpServer]: + used_names: set[str] = set() + items: list[SyncedMcpServer] = [] + async for batch in self._iter_list_servers_batches( + list_url=list_url, + api_headers=api_headers, + limit=limit, + max_pages=max_pages, + ): + for server in batch: + if not self._matches(server, query): + continue + server_url = server.get("server_url") + server_name = self._resolve_server_name(server) + if not server_url or not server_name: + continue + server_key = server.get("server_key") + items.append( + self._make_item( + name=str(server_name), + url=str(server_url), + used_names=used_names, + headers=mcp_headers, + server_key=str(server_key) if server_key else None, + ) + ) + if len(items) >= max_servers: + return items + return items + + async def _validate_api_key( + self, + *, + api_key: str, + app_url: str, + app_name: str, + base_url: str, + ) -> None: + url = f"{base_url}/list-servers" + headers = self._build_api_headers( + api_key=api_key, + app_url=app_url, + app_name=app_name, + ) + timeout = aiohttp.ClientTimeout(total=20) + async with aiohttp.ClientSession(timeout=timeout) as session: + data = await self._post_json( + session=session, + url=url, + payload={"page": 1, "limit": 1}, + headers=headers, + action="MCPRouter API key validation", + ) + + if data.get("code") != 0: + detail = self._build_error_detail(data) or "unknown" + raise ValueError(f"MCPRouter API key validation failed: {detail}") + + async def fetch(self, payload: dict[str, Any]) -> list[SyncedMcpServer]: + api_key = self._normalize_api_key(str(payload.get("api_key", ""))) + if not api_key: + raise ValueError("Missing required field: api_key") + + app_url = str(payload.get("app_url", "")).strip() + app_name = str(payload.get("app_name", "")).strip() + + base_url = str(payload.get("api_base", "https://api.mcprouter.to/v1")).rstrip( + "/" + ) + list_url = f"{base_url}/list-servers" + get_url = f"{base_url}/get-server" + + await self._validate_api_key( + api_key=api_key, + app_url=app_url, + app_name=app_name, + base_url=base_url, + ) + + api_headers = self._build_api_headers( + api_key=api_key, + app_url=app_url, + app_name=app_name, + ) + mcp_headers = self._build_mcp_headers( + api_key=api_key, + app_url=app_url, + app_name=app_name, + ) + + query = str(payload.get("query", "")).strip().lower() + raw_max_servers = payload.get("max_servers") + max_servers = int(raw_max_servers or 30) + max_servers = max(1, min(max_servers, 500)) + + limit = int(payload.get("limit", 30) or 30) + limit = max(1, min(limit, 100)) + + max_pages = int(payload.get("max_pages", 10) or 10) + max_pages = max(1, min(max_pages, 50)) + + provided_servers = payload.get("servers") + if isinstance(provided_servers, list) and provided_servers: + return await self._fetch_from_provided_servers( + provided_servers=provided_servers, + raw_max_servers=raw_max_servers, + max_servers=max_servers, + mcp_headers=mcp_headers, + ) + + server_keys = self._parse_server_keys(payload.get("server_keys")) + if server_keys: + return await self._fetch_from_server_keys( + server_keys=server_keys, + max_servers=max_servers, + get_url=get_url, + api_headers=api_headers, + mcp_headers=mcp_headers, + ) + + return await self._fetch_from_listing( + list_url=list_url, + api_headers=api_headers, + mcp_headers=mcp_headers, + query=query, + max_servers=max_servers, + limit=limit, + max_pages=max_pages, + ) + + async def list_servers(self, payload: dict[str, Any]) -> list[dict[str, Any]]: + api_key = self._normalize_api_key(str(payload.get("api_key", ""))) + if not api_key: + raise ValueError("Missing required field: api_key") + + app_url = str(payload.get("app_url", "")).strip() + app_name = str(payload.get("app_name", "")).strip() + base_url = str(payload.get("api_base", "https://api.mcprouter.to/v1")).rstrip( + "/" + ) + list_url = f"{base_url}/list-servers" + + await self._validate_api_key( + api_key=api_key, + app_url=app_url, + app_name=app_name, + base_url=base_url, + ) + + api_headers = self._build_api_headers( + api_key=api_key, + app_url=app_url, + app_name=app_name, + ) + + limit = 100 + max_pages = 20 + max_items = 2000 + + servers: list[dict[str, Any]] = [] + async for batch in self._iter_list_servers_batches( + list_url=list_url, + api_headers=api_headers, + limit=limit, + max_pages=max_pages, + ): + for item in batch: + server_url = item.get("server_url") + server_key = item.get("server_key") + config_name = item.get("config_name") + if not server_url or not (server_key or config_name): + continue + servers.append(item) + if len(servers) >= max_items: + return servers + + return servers diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index b19385c285..d22c6be325 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -26,6 +26,10 @@ def __init__( "/tools/mcp/update": ("POST", self.update_mcp_server), "/tools/mcp/delete": ("POST", self.delete_mcp_server), "/tools/mcp/test": ("POST", self.test_mcp_connection), + "/tools/mcp/providers/mcprouter/list-servers": ( + "POST", + self.list_mcprouter_servers, + ), "/tools/list": ("GET", self.get_tool_list), "/tools/toggle-tool": ("POST", self.toggle_tool), "/tools/mcp/sync-provider": ("POST", self.sync_provider), @@ -377,19 +381,77 @@ async def toggle_tool(self): logger.error(traceback.format_exc()) return Response().error(f"操作工具失败: {e!s}").__dict__ + async def list_mcprouter_servers(self): + """List MCP servers from MCPRouter.""" + try: + data = await request.json + api_key = str((data or {}).get("api_key", "")).strip() + if not api_key: + return Response().error("缺少必要参数: api_key").__dict__ + + app_url = str((data or {}).get("app_url", "")).strip() + if not app_url: + app_url = ( + request.headers.get("Origin") + or request.headers.get("Referer") + or "" + ) + app_name = str((data or {}).get("app_name", "")).strip() or "AstrBot" + api_base = ( + str((data or {}).get("api_base", "https://api.mcprouter.to/v1")).strip() + or "https://api.mcprouter.to/v1" + ) + + servers = await self.tool_mgr.list_mcp_servers_from_provider( + "mcprouter", + { + "api_key": api_key, + "app_url": app_url, + "app_name": app_name, + "api_base": api_base, + }, + ) + return ( + Response() + .ok(data=servers, message=f"已获取 {len(servers)} 个服务器") + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取 MCPRouter 服务器列表失败: {e!s}").__dict__ + async def sync_provider(self): """同步 MCP 提供者配置""" try: - data = await request.json + data = (await request.json) or {} provider_name = data.get("name") # modelscope, or others - match provider_name: - case "modelscope": - access_token = data.get("access_token", "") - await self.tool_mgr.sync_modelscope_mcp_servers(access_token) - case _: - return Response().error(f"未知: {provider_name}").__dict__ - - return Response().ok(message="同步成功").__dict__ + if not provider_name: + return Response().error("缺少必要参数: name").__dict__ + + if provider_name == "mcprouter": + data.setdefault( + "app_url", + request.headers.get("Origin") + or request.headers.get("Referer") + or "", + ) + data.setdefault("app_name", "AstrBot") + + result = await self.tool_mgr.sync_mcp_servers_from_provider( + provider_name, + data, + ) + synced = int(result.get("synced", 0) or 0) + enabled = int(result.get("enabled", 0) or 0) + failed = int(result.get("failed", 0) or 0) + + if synced == 0: + return Response().ok(message="未找到可同步的 MCP 服务器").__dict__ + + msg = f"同步完成:同步 {synced} 个,启用 {enabled} 个" + if failed: + msg += f",失败 {failed} 个" + return Response().ok(message=msg).__dict__ except Exception as e: logger.error(traceback.format_exc()) return Response().error(f"同步失败: {e!s}").__dict__ diff --git a/dashboard/src/components/extension/McpServersSection.vue b/dashboard/src/components/extension/McpServersSection.vue index 95b6795809..69aab79617 100644 --- a/dashboard/src/components/extension/McpServersSection.vue +++ b/dashboard/src/components/extension/McpServersSection.vue @@ -152,14 +152,14 @@ - - - 同步外部平台 MCP 服务器 - + + + {{ tm('syncProvider.title') }} + + :label="tm('syncProvider.fields.provider')" variant="outlined" required>
@@ -192,6 +192,70 @@
+
+ + +
+
{{ tm('syncProvider.timeline.mcprouter.createApiKeyTitle') }}
+

+ {{ tm('syncProvider.timeline.mcprouter.createApiKeyDescPrefix') }} + MCPRouter + {{ tm('syncProvider.timeline.mcprouter.createApiKeyDescSuffix') }} +

+
+
+ + +
+
{{ tm('syncProvider.timeline.mcprouter.inputApiKeyTitle') }}
+

+ {{ tm('syncProvider.timeline.mcprouter.inputApiKeyDesc') }} +

+ + + {{ tm('syncProvider.buttons.fetchServers') }} + + +
+
+ {{ tm('syncProvider.status.fetchedServers', { count: mcprouterServers.length }) }} +
+ + + + + {{ server.title || server.config_name || server.server_key || server.name }} + ({{ server.server_key }}) + + + {{ server.description || server.author_name || server.config_name || '' }} + + + + + +
+
+
+ + +
+
{{ tm('syncProvider.timeline.mcprouter.optionalAppInfoTitle') }}
+

+ {{ tm('syncProvider.timeline.mcprouter.optionalAppInfoDesc') }} +

+ + +
+
+
+
@@ -241,8 +305,13 @@ export default { mcpServers: [], showMcpServerDialog: false, selectedMcpServerProvider: 'modelscope', - mcpServerProviderList: ['modelscope'], + mcpServerProviderList: ['modelscope', 'mcprouter'], mcpProviderToken: '', + mcprouterApiKey: '', + mcprouterAppUrl: '', + mcprouterAppName: 'AstrBot', + mcprouterServersLoading: false, + mcprouterServers: [], showSyncMcpServerDialog: false, addServerDialogMessage: '', loading: false, @@ -283,6 +352,7 @@ export default { }, mounted() { this.getServers(); + this.mcprouterAppUrl = window.location.origin; this.refreshInterval = setInterval(() => { this.getServers(); }, 5000); @@ -486,6 +556,43 @@ export default { this.save_message_success = 'error'; this.save_message_snack = true; }, + async fetchMcpRouterServers() { + if (!this.mcprouterApiKey.trim()) { + this.showError(this.tm('syncProvider.status.enterApiKey')); + return; + } + this.mcprouterServersLoading = true; + try { + const requestData = { + api_key: this.mcprouterApiKey.trim() + }; + if (this.mcprouterAppUrl.trim()) { + requestData.app_url = this.mcprouterAppUrl.trim(); + } + if (this.mcprouterAppName.trim()) { + requestData.app_name = this.mcprouterAppName.trim(); + } + const response = await axios.post('/api/tools/mcp/providers/mcprouter/list-servers', requestData); + if (response.data.status === 'ok') { + this.mcprouterServers = response.data.data || []; + this.showSuccess(response.data.message || this.tm('syncProvider.messages.fetchServersSuccess', { count: this.mcprouterServers.length })); + } else { + this.showError(response.data.message || this.tm('syncProvider.messages.fetchServersError', { error: 'Unknown error' })); + } + } catch (error) { + this.showError(this.tm('syncProvider.messages.fetchServersError', { + error: + error.response?.data?.message || + error.message || + this.tm('syncProvider.messages.networkOrApiKeyIssue') + })); + } finally { + this.mcprouterServersLoading = false; + } + }, + removeMcpRouterServer(index) { + this.mcprouterServers.splice(index, 1); + }, async syncMcpServers() { if (!this.selectedMcpServerProvider) { this.showError(this.tm('syncProvider.status.selectProvider')); @@ -503,19 +610,43 @@ export default { return; } requestData.access_token = this.mcpProviderToken.trim(); + } else if (this.selectedMcpServerProvider === 'mcprouter') { + if (!this.mcprouterApiKey.trim()) { + this.showError(this.tm('syncProvider.status.enterApiKey')); + this.loading = false; + return; + } + if (!this.mcprouterServers.length) { + this.showError(this.tm('syncProvider.status.fetchServersFirst')); + this.loading = false; + return; + } + requestData.api_key = this.mcprouterApiKey.trim(); + if (this.mcprouterAppUrl.trim()) { + requestData.app_url = this.mcprouterAppUrl.trim(); + } + if (this.mcprouterAppName.trim()) { + requestData.app_name = this.mcprouterAppName.trim(); + } + requestData.servers = this.mcprouterServers; } const response = await axios.post('/api/tools/mcp/sync-provider', requestData); if (response.data.status === 'ok') { this.showSuccess(response.data.message || this.tm('syncProvider.messages.syncSuccess')); this.showSyncMcpServerDialog = false; this.mcpProviderToken = ''; + this.mcprouterApiKey = ''; + this.mcprouterServers = []; this.getServers(); } else { this.showError(response.data.message || this.tm('syncProvider.messages.syncError', { error: 'Unknown error' })); } } catch (error) { this.showError(this.tm('syncProvider.messages.syncError', { - error: error.response?.data?.message || error.message || '网络连接或访问令牌问题' + error: + error.response?.data?.message || + error.message || + this.tm('syncProvider.messages.networkOrTokenIssue') })); } finally { this.loading = false; @@ -538,4 +669,9 @@ export default { margin-top: 4px; overflow: hidden; } + +.mcprouter-server-list { + max-height: 260px; + overflow-y: auto; +} diff --git a/dashboard/src/i18n/locales/en-US/features/tool-use.json b/dashboard/src/i18n/locales/en-US/features/tool-use.json index 2c68b82435..6d548ca7ac 100644 --- a/dashboard/src/i18n/locales/en-US/features/tool-use.json +++ b/dashboard/src/i18n/locales/en-US/features/tool-use.json @@ -109,29 +109,52 @@ }, "providers": { "modelscope": "ModelScope", + "mcprouter": "MCPRouter", "description": "ModelScope is an open model community providing MCP servers for various machine learning and AI services" }, "fields": { "provider": "Select Provider", "accessToken": "Access Token", + "apiKey": "API Key", + "appUrl": "HTTP-Referer", + "appName": "X-Title", "tokenRequired": "Access token is required", "tokenHint": "Please enter your ModelScope access token" }, + "timeline": { + "mcprouter": { + "createApiKeyTitle": "Create API Key", + "createApiKeyDescPrefix": "Visit ", + "createApiKeyDescSuffix": " to create and copy your API key.", + "inputApiKeyTitle": "Enter API Key", + "inputApiKeyDesc": "Enter your API key to sync MCP servers.", + "optionalAppInfoTitle": "Optional: App Identity", + "optionalAppInfoDesc": "Some MCPRouter services may validate the following headers (optional)." + } + }, "buttons": { "cancel": "Cancel", "previous": "Previous", "next": "Next", "sync": "Start Sync", - "getToken": "Get Token" + "getToken": "Get Token", + "fetchServers": "Fetch Servers" }, "status": { "selectProvider": "Please select an MCP server provider", "enterToken": "Please enter the access token to continue", + "enterApiKey": "Please enter the API key to continue", + "fetchServersFirst": "Please fetch the server list first", + "fetchedServers": "Fetched servers ({count})", "readyToSync": "Ready to sync server configurations" }, "messages": { "syncSuccess": "MCP servers synced successfully!", "syncError": "Sync failed: {error}", + "fetchServersSuccess": "Fetched {count} servers", + "fetchServersError": "Failed to fetch server list: {error}", + "networkOrApiKeyIssue": "Network connection issue or invalid API key", + "networkOrTokenIssue": "Network connection issue or invalid access token", "tokenHelp": "How to get a ModelScope access token? Click the button on the right for instructions" } }, diff --git a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json index f6e6c4407a..db2c152d38 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json +++ b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json @@ -97,42 +97,65 @@ "importConfig": "导入配置" } }, - "confirmDelete": "确定要删除服务器 {name} 吗?", - "syncProvider": { - "title": "同步 MCP 服务器", - "subtitle": "从提供商同步 MCP 服务器配置到本地", - "steps": { - "selectProvider": "步骤 1: 选择提供商", - "configureAuth": "步骤 2: 配置认证", - "syncServers": "步骤 3: 同步服务器" - }, - "providers": { - "modelscope": "ModelScope", - "description": "ModelScope 是一个开源的模型社区,提供各种机器学习和AI服务的MCP服务器" - }, - "fields": { - "provider": "选择提供商", - "accessToken": "访问令牌", - "tokenRequired": "访问令牌是必填项", - "tokenHint": "请输入您的 ModelScope 访问令牌" - }, - "buttons": { - "cancel": "取消", - "previous": "上一步", - "next": "下一步", - "sync": "开始同步", - "getToken": "获取令牌" - }, - "status": { - "selectProvider": "请选择一个 MCP 服务器提供商", - "enterToken": "请输入访问令牌以继续", - "readyToSync": "准备同步服务器配置" - }, - "messages": { - "syncSuccess": "MCP 服务器同步成功!", - "syncError": "同步失败: {error}", - "tokenHelp": "如何获取 ModelScope 访问令牌?点击右侧按钮查看说明" + "confirmDelete": "确定要删除服务器 {name} 吗?" + }, + "syncProvider": { + "title": "同步外部平台 MCP 服务器", + "subtitle": "从提供商同步 MCP 服务器配置到本地", + "steps": { + "selectProvider": "步骤 1: 选择提供商", + "configureAuth": "步骤 2: 配置认证", + "syncServers": "步骤 3: 同步服务器" + }, + "providers": { + "modelscope": "ModelScope", + "mcprouter": "MCPRouter", + "description": "ModelScope 是一个开源的模型社区,提供各种机器学习和AI服务的MCP服务器" + }, + "fields": { + "provider": "选择提供商", + "accessToken": "访问令牌", + "apiKey": "API Key", + "appUrl": "HTTP-Referer", + "appName": "X-Title", + "tokenRequired": "访问令牌是必填项", + "tokenHint": "请输入您的 ModelScope 访问令牌" + }, + "timeline": { + "mcprouter": { + "createApiKeyTitle": "创建 API Key", + "createApiKeyDescPrefix": "访问 ", + "createApiKeyDescSuffix": " 创建并复制您的 API Key。", + "inputApiKeyTitle": "输入 API Key", + "inputApiKeyDesc": "输入 API Key 以同步 MCP 服务器。", + "optionalAppInfoTitle": "可选:应用标识", + "optionalAppInfoDesc": "部分 MCPRouter 服务可能会校验以下标识(可留空)。" } + }, + "buttons": { + "cancel": "取消", + "previous": "上一步", + "next": "下一步", + "sync": "开始同步", + "getToken": "获取令牌", + "fetchServers": "获取服务器列表" + }, + "status": { + "selectProvider": "请选择一个 MCP 服务器提供商", + "enterToken": "请输入访问令牌以继续", + "enterApiKey": "请输入 API Key 以继续", + "fetchServersFirst": "请先获取服务器列表", + "fetchedServers": "已获取服务器 ({count})", + "readyToSync": "准备同步服务器配置" + }, + "messages": { + "syncSuccess": "MCP 服务器同步成功!", + "syncError": "同步失败: {error}", + "fetchServersSuccess": "已获取 {count} 个服务器", + "fetchServersError": "获取服务器列表失败: {error}", + "networkOrApiKeyIssue": "网络连接异常或 API Key 无效", + "networkOrTokenIssue": "网络连接异常或访问令牌无效", + "tokenHelp": "如何获取 ModelScope 访问令牌?点击右侧按钮查看说明" } }, "messages": { @@ -156,4 +179,4 @@ "toggleToolError": "工具状态切换失败: {error}", "testError": "测试连接失败: {error}" } -} \ No newline at end of file +}