diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 333700410..456385b48 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -12,6 +12,32 @@ DEFAULT_MCP_CONFIG = {"mcpServers": {}} +class EmptyMcpServersError(ValueError): + """mcpServers 为空时抛出""" + + pass + + +def _extract_mcp_server_config(mcp_servers_value: object) -> dict: + """从用户提交的 mcpServers 字段中提取服务器配置。 + + Raises: + ValueError: 配置不合法 + """ + if not isinstance(mcp_servers_value, dict): + raise ValueError("mcpServers 必须是一个 JSON 对象") + if not mcp_servers_value: + raise EmptyMcpServersError("mcpServers 配置不能为空") + key_0 = next(iter(mcp_servers_value)) + extracted = mcp_servers_value[key_0] + if not isinstance(extracted, dict): + raise ValueError( + "mcpServers 配置格式不正确。请确保 mcpServers 内部的 key 是服务器名称," + "其值为包含 command/url 等字段的对象。" + ) + return extracted + + class ToolsRoute(Route): def __init__( self, @@ -33,13 +59,37 @@ def __init__( self.register_routes() self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools + def _rollback_mcp_server(self, name: str) -> bool: + try: + rollback_config = self.tool_mgr.load_mcp_config() + if name in rollback_config["mcpServers"]: + rollback_config["mcpServers"].pop(name) + return self.tool_mgr.save_mcp_config(rollback_config) + return True + except Exception: + logger.error(traceback.format_exc()) + return False + async def get_mcp_servers(self): try: config = self.tool_mgr.load_mcp_config() servers = [] + mcp_servers = config.get("mcpServers", {}) + + if not isinstance(mcp_servers, dict): + logger.warning( + f"MCP 服务器配置无效(类型为 {type(mcp_servers).__name__}),应为对象/字典类型,已跳过所有 MCP 服务器" + ) + mcp_servers = {} # 获取所有服务器并添加它们的工具列表 - for name, server_config in config["mcpServers"].items(): + for name, server_config in mcp_servers.items(): + if not isinstance(server_config, dict): + logger.warning( + f"MCP 服务器 '{name}' 的配置无效(类型为 {type(server_config).__name__}),已跳过" + ) + continue + server_info = { "name": name, "active": server_config.get("active", True), @@ -87,10 +137,12 @@ async def add_mcp_server(self): for key, value in server_data.items(): if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段 if key == "mcpServers": - key_0 = list(server_data["mcpServers"].keys())[ - 0 - ] # 不考虑为空的情况 - server_config = server_data["mcpServers"][key_0] + try: + server_config = _extract_mcp_server_config( + server_data["mcpServers"] + ) + except ValueError as e: + return Response().error(f"{e!s}").__dict__ else: server_config[key] = value has_valid_config = True @@ -103,6 +155,12 @@ async def add_mcp_server(self): if name in config["mcpServers"]: return Response().error(f"服务器 {name} 已存在").__dict__ + try: + await self.tool_mgr.test_mcp_server_connection(server_config) + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"测试 MCP 连接失败: {e!s}").__dict__ + config["mcpServers"][name] = server_config if self.tool_mgr.save_mcp_config(config): @@ -113,12 +171,18 @@ async def add_mcp_server(self): timeout=30, ) except TimeoutError: - return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__ + rollback_ok = self._rollback_mcp_server(name) + err_msg = f"启用 MCP 服务器 {name} 超时。" + if not rollback_ok: + err_msg += " 配置回滚失败,请手动检查配置。" + return Response().error(err_msg).__dict__ except Exception as e: logger.error(traceback.format_exc()) - return ( - Response().error(f"启用 MCP 服务器 {name} 失败: {e!s}").__dict__ - ) + rollback_ok = self._rollback_mcp_server(name) + err_msg = f"启用 MCP 服务器 {name} 失败: {e!s}" + if not rollback_ok: + err_msg += " 配置回滚失败,请手动检查配置。" + return Response().error(err_msg).__dict__ return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__ return Response().error("保存配置失败").__dict__ except Exception as e: @@ -146,10 +210,12 @@ async def update_mcp_server(self): return Response().error(f"服务器 {name} 已存在").__dict__ # 获取活动状态 - active = server_data.get( - "active", - config["mcpServers"][old_name].get("active", True), - ) + old_config = config["mcpServers"][old_name] + if isinstance(old_config, dict): + old_active = old_config.get("active", True) + else: + old_active = True + active = server_data.get("active", old_active) # 创建新的配置对象 server_config = {"active": active} @@ -167,17 +233,19 @@ async def update_mcp_server(self): "oldName", ]: # 排除特殊字段 if key == "mcpServers": - key_0 = list(server_data["mcpServers"].keys())[ - 0 - ] # 不考虑为空的情况 - server_config = server_data["mcpServers"][key_0] + try: + server_config = _extract_mcp_server_config( + server_data["mcpServers"] + ) + except ValueError as e: + return Response().error(f"{e!s}").__dict__ else: server_config[key] = value only_update_active = False # 如果只更新活动状态,保留原始配置 - if only_update_active: - for key, value in config["mcpServers"][old_name].items(): + if only_update_active and isinstance(old_config, dict): + for key, value in old_config.items(): if key != "active": # 除了active之外的所有字段都保留 server_config[key] = value @@ -302,12 +370,15 @@ async def test_mcp_connection(self): return Response().error("无效的 MCP 服务器配置").__dict__ if "mcpServers" in config: - keys = list(config["mcpServers"].keys()) - if not keys: - return Response().error("MCP 服务器配置不能为空").__dict__ - if len(keys) > 1: + mcp_servers = config["mcpServers"] + if isinstance(mcp_servers, dict) and len(mcp_servers) > 1: return Response().error("一次只能配置一个 MCP 服务器配置").__dict__ - config = config["mcpServers"][keys[0]] + try: + config = _extract_mcp_server_config(mcp_servers) + except EmptyMcpServersError: + return Response().error("MCP 服务器配置不能为空").__dict__ + except ValueError as e: + return Response().error(f"{e!s}").__dict__ elif not config: return Response().error("MCP 服务器配置不能为空").__dict__ diff --git a/dashboard/src/components/extension/McpServersSection.vue b/dashboard/src/components/extension/McpServersSection.vue index 95b679580..d24bcec58 100644 --- a/dashboard/src/components/extension/McpServersSection.vue +++ b/dashboard/src/components/extension/McpServersSection.vue @@ -300,6 +300,10 @@ export default { this.loadingGettingServers = true; axios.get('/api/tools/mcp/servers') .then(response => { + if (response.data.status === 'error') { + this.showError(response.data.message || this.tm('messages.getServersError', { error: 'Unknown error' })); + return; + } this.mcpServers = response.data.data || []; this.mcpServers.forEach(server => { if (!this.mcpServerUpdateLoaders[server.name]) { @@ -372,6 +376,10 @@ export default { axios.post(endpoint, serverData) .then(response => { this.loading = false; + if (response.data.status === 'error') { + this.showError(response.data.message || this.tm('messages.saveError', { error: 'Unknown error' })); + return; + } this.showMcpServerDialog = false; this.addServerDialogMessage = ''; this.getServers();