Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 95 additions & 24 deletions astrbot/dashboard/routes/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,32 @@
DEFAULT_MCP_CONFIG = {"mcpServers": {}}


class EmptyMcpServersError(ValueError):
"""mcpServers 为空时抛出"""

pass


def _extract_mcp_server_config(mcp_servers_value: object) -> dict:
"""从用户提交的 mcpServers 字段中提取服务器配置。

Raises:
ValueError: 配置不合法
"""
if not isinstance(mcp_servers_value, dict):
raise ValueError("mcpServers 必须是一个 JSON 对象")
if not mcp_servers_value:
raise EmptyMcpServersError("mcpServers 配置不能为空")
key_0 = next(iter(mcp_servers_value))
extracted = mcp_servers_value[key_0]
if not isinstance(extracted, dict):
raise ValueError(
"mcpServers 配置格式不正确。请确保 mcpServers 内部的 key 是服务器名称,"
"其值为包含 command/url 等字段的对象。"
)
return extracted


class ToolsRoute(Route):
def __init__(
self,
Expand All @@ -33,13 +59,37 @@ def __init__(
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools

def _rollback_mcp_server(self, name: str) -> bool:
try:
rollback_config = self.tool_mgr.load_mcp_config()
if name in rollback_config["mcpServers"]:
rollback_config["mcpServers"].pop(name)
return self.tool_mgr.save_mcp_config(rollback_config)
return True
except Exception:
logger.error(traceback.format_exc())
return False

async def get_mcp_servers(self):
try:
config = self.tool_mgr.load_mcp_config()
servers = []
mcp_servers = config.get("mcpServers", {})

if not isinstance(mcp_servers, dict):
logger.warning(
f"MCP 服务器配置无效(类型为 {type(mcp_servers).__name__}),应为对象/字典类型,已跳过所有 MCP 服务器"
)
mcp_servers = {}

# 获取所有服务器并添加它们的工具列表
for name, server_config in config["mcpServers"].items():
for name, server_config in mcp_servers.items():
if not isinstance(server_config, dict):
logger.warning(
f"MCP 服务器 '{name}' 的配置无效(类型为 {type(server_config).__name__}),已跳过"
)
continue

server_info = {
"name": name,
"active": server_config.get("active", True),
Expand Down Expand Up @@ -87,10 +137,12 @@ async def add_mcp_server(self):
for key, value in server_data.items():
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
if key == "mcpServers":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): 建议把重复的 mcpServers 解析逻辑和 old_config 归一化逻辑抽取到辅助函数中,在保留所有校验的同时,减少重复代码和分散的类型检查。

你可以通过集中处理重复的 mcpServers 逻辑并把 old_config 归一化到辅助函数中,来降低新增的复杂度。这样既能保留所有现有的校验和行为,又能去除重复代码和分散的类型检查。

1. 将 mcpServers 解析提取为一个辅助函数

add_mcp_serverupdate_mcp_serverkey == "mcpServers" 的分支是相同的。可以将其提取到一个小的辅助函数中,该函数返回已验证的 dict 或错误的 Response

def _extract_mcp_server_config(self, server_data: dict):
    mcp_servers = server_data.get("mcpServers") or {}
    mcp_keys = list(mcp_servers.keys())
    if not mcp_keys:
        return None, Response().error("mcpServers 配置不能为空").__dict__

    key_0 = mcp_keys[0]
    extracted = mcp_servers[key_0]
    if not isinstance(extracted, dict):
        return None, Response().error(
            "mcpServers 配置格式不正确。请确保 mcpServers 内部的 key 是服务器名称,"
            "其值为包含 command/url 等字段的对象。"
        ).__dict__

    return extracted, None

然后在两个位置中复用它,简化分支逻辑:

# in add_mcp_server
for key, value in server_data.items():
    if key not in ["name", "active", "tools", "errlogs"]:
        if key == "mcpServers":
            extracted, error = self._extract_mcp_server_config(server_data)
            if error:
                return error
            server_config = extracted
        else:
            server_config[key] = value
        has_valid_config = True
# in update_mcp_server
for key, value in server_data.items():
    if key not in ["name", "active", "tools", "errlogs", "oldName"]:
        if key == "mcpServers":
            extracted, error = self._extract_mcp_server_config(server_data)
            if error:
                return error
            server_config = extracted
        else:
            server_config[key] = value
        only_update_active = False

这样可以保留当前所有的错误信息和行为,同时去除重复逻辑和分支。

2. 将 old_config 归一化为一个安全的 dict

可以把多处重复的 isinstance(old_config, dict) 检查封装到一个归一化函数中,让路由逻辑可以直接假定拿到的是一个 dict:

def _normalize_old_config(self, config: dict, old_name: str) -> dict:
    old_config = config["mcpServers"][old_name]
    if isinstance(old_config, dict):
        return old_config
    # Preserve current “fallback to defaults” behavior
    logger.warning(
        f"MCP 服务器 '{old_name}' 的配置无效(类型为 {type(old_config).__name__}),"
        "将使用默认 active 并忽略其他字段"
    )
    return {}

然后可以简化 update_mcp_server

old_config = self._normalize_old_config(config, old_name)
old_active = old_config.get("active", True)
active = server_data.get("active", old_active)

server_config = {"active": active}
only_update_active = True

# ... loop as above ...

# 如果只更新活动状态,保留原始配置
if only_update_active:
    for key, value in old_config.items():
        if key != "active":
            server_config[key] = value

这样可以在保留当前回退逻辑(非 dict -> 使用默认 active,忽略其他字段)的同时,移除分散的 isinstance(old_config, dict) 检查。

Original comment in English

issue (complexity): Consider extracting the repeated mcpServers parsing and old_config normalization into helper functions to keep validations while reducing duplication and scattered type checks.

You can reduce the new complexity by centralizing the repeated mcpServers handling and normalizing old_config into helpers. This keeps all new validations and behavior but removes duplication and scattered type checks.

1. Extract mcpServers parsing into a helper

The key == "mcpServers" blocks in add_mcp_server and update_mcp_server are identical. Move that into a small helper that returns either a validated dict or an error Response:

def _extract_mcp_server_config(self, server_data: dict):
    mcp_servers = server_data.get("mcpServers") or {}
    mcp_keys = list(mcp_servers.keys())
    if not mcp_keys:
        return None, Response().error("mcpServers 配置不能为空").__dict__

    key_0 = mcp_keys[0]
    extracted = mcp_servers[key_0]
    if not isinstance(extracted, dict):
        return None, Response().error(
            "mcpServers 配置格式不正确。请确保 mcpServers 内部的 key 是服务器名称,"
            "其值为包含 command/url 等字段的对象。"
        ).__dict__

    return extracted, None

Then use it in both places to simplify branching:

# in add_mcp_server
for key, value in server_data.items():
    if key not in ["name", "active", "tools", "errlogs"]:
        if key == "mcpServers":
            extracted, error = self._extract_mcp_server_config(server_data)
            if error:
                return error
            server_config = extracted
        else:
            server_config[key] = value
        has_valid_config = True
# in update_mcp_server
for key, value in server_data.items():
    if key not in ["name", "active", "tools", "errlogs", "oldName"]:
        if key == "mcpServers":
            extracted, error = self._extract_mcp_server_config(server_data)
            if error:
                return error
            server_config = extracted
        else:
            server_config[key] = value
        only_update_active = False

This keeps all current error messages and behavior, but removes duplicated logic and branching.

2. Normalize old_config into a safe dict

The repeated isinstance(old_config, dict) checks can be hidden behind a normalizer, so the route code can assume a dict:

def _normalize_old_config(self, config: dict, old_name: str) -> dict:
    old_config = config["mcpServers"][old_name]
    if isinstance(old_config, dict):
        return old_config
    # Preserve current “fallback to defaults” behavior
    logger.warning(
        f"MCP 服务器 '{old_name}' 的配置无效(类型为 {type(old_config).__name__}),"
        "将使用默认 active 并忽略其他字段"
    )
    return {}

Then simplify update_mcp_server:

old_config = self._normalize_old_config(config, old_name)
old_active = old_config.get("active", True)
active = server_data.get("active", old_active)

server_config = {"active": active}
only_update_active = True

# ... loop as above ...

# 如果只更新活动状态,保留原始配置
if only_update_active:
    for key, value in old_config.items():
        if key != "active":
            server_config[key] = value

This removes the scattered isinstance(old_config, dict) checks while preserving the current fallback logic (non-dict -> default active, ignore other fields).

key_0 = list(server_data["mcpServers"].keys())[
0
] # 不考虑为空的情况
server_config = server_data["mcpServers"][key_0]
try:
server_config = _extract_mcp_server_config(
server_data["mcpServers"]
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
else:
server_config[key] = value
has_valid_config = True
Expand All @@ -103,6 +155,12 @@ async def add_mcp_server(self):
if name in config["mcpServers"]:
return Response().error(f"服务器 {name} 已存在").__dict__

try:
await self.tool_mgr.test_mcp_server_connection(server_config)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"测试 MCP 连接失败: {e!s}").__dict__

config["mcpServers"][name] = server_config

if self.tool_mgr.save_mcp_config(config):
Expand All @@ -113,12 +171,18 @@ async def add_mcp_server(self):
timeout=30,
)
except TimeoutError:
return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
rollback_ok = self._rollback_mcp_server(name)
err_msg = f"启用 MCP 服务器 {name} 超时。"
if not rollback_ok:
err_msg += " 配置回滚失败,请手动检查配置。"
return Response().error(err_msg).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return (
Response().error(f"启用 MCP 服务器 {name} 失败: {e!s}").__dict__
)
rollback_ok = self._rollback_mcp_server(name)
err_msg = f"启用 MCP 服务器 {name} 失败: {e!s}"
if not rollback_ok:
err_msg += " 配置回滚失败,请手动检查配置。"
return Response().error(err_msg).__dict__
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
return Response().error("保存配置失败").__dict__
except Exception as e:
Expand Down Expand Up @@ -146,10 +210,12 @@ async def update_mcp_server(self):
return Response().error(f"服务器 {name} 已存在").__dict__

# 获取活动状态
active = server_data.get(
"active",
config["mcpServers"][old_name].get("active", True),
)
old_config = config["mcpServers"][old_name]
if isinstance(old_config, dict):
old_active = old_config.get("active", True)
else:
old_active = True
active = server_data.get("active", old_active)

# 创建新的配置对象
server_config = {"active": active}
Expand All @@ -167,17 +233,19 @@ async def update_mcp_server(self):
"oldName",
]: # 排除特殊字段
if key == "mcpServers":
key_0 = list(server_data["mcpServers"].keys())[
0
] # 不考虑为空的情况
server_config = server_data["mcpServers"][key_0]
try:
server_config = _extract_mcp_server_config(
server_data["mcpServers"]
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
else:
server_config[key] = value
only_update_active = False

# 如果只更新活动状态,保留原始配置
if only_update_active:
for key, value in config["mcpServers"][old_name].items():
if only_update_active and isinstance(old_config, dict):
for key, value in old_config.items():
if key != "active": # 除了active之外的所有字段都保留
server_config[key] = value

Expand Down Expand Up @@ -302,12 +370,15 @@ async def test_mcp_connection(self):
return Response().error("无效的 MCP 服务器配置").__dict__

if "mcpServers" in config:
keys = list(config["mcpServers"].keys())
if not keys:
return Response().error("MCP 服务器配置不能为空").__dict__
if len(keys) > 1:
mcp_servers = config["mcpServers"]
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1:
return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
config = config["mcpServers"][keys[0]]
try:
config = _extract_mcp_server_config(mcp_servers)
except EmptyMcpServersError:
return Response().error("MCP 服务器配置不能为空").__dict__
except ValueError as e:
return Response().error(f"{e!s}").__dict__
elif not config:
return Response().error("MCP 服务器配置不能为空").__dict__

Expand Down
8 changes: 8 additions & 0 deletions dashboard/src/components/extension/McpServersSection.vue
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ export default {
this.loadingGettingServers = true;
axios.get('/api/tools/mcp/servers')
.then(response => {
if (response.data.status === 'error') {
this.showError(response.data.message || this.tm('messages.getServersError', { error: 'Unknown error' }));
return;
}
this.mcpServers = response.data.data || [];
this.mcpServers.forEach(server => {
if (!this.mcpServerUpdateLoaders[server.name]) {
Expand Down Expand Up @@ -372,6 +376,10 @@ export default {
axios.post(endpoint, serverData)
.then(response => {
this.loading = false;
if (response.data.status === 'error') {
this.showError(response.data.message || this.tm('messages.saveError', { error: 'Unknown error' }));
return;
}
this.showMcpServerDialog = false;
this.addServerDialogMessage = '';
this.getServers();
Expand Down