diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 8475009d3f..0363e2d55d 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -62,4 +62,4 @@ def default_parameters(self) -> dict: def default_description(self, agent_name: str | None) -> str: agent_name = agent_name or "another" - return f"Delegate tasks to {self.name} agent to handle the request." + return f"Delegate tasks to {agent_name} agent to handle the request." diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 0f51a29c05..8ff66ced5d 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -390,14 +390,9 @@ async def _ensure_persona_and_skills( persona_tools = None pid = a.get("persona_id") if pid: - persona_tools = next( - ( - p.get("tools") - for p in plugin_context.persona_manager.personas_v3 - if p["name"] == pid - ), - None, - ) + persona = plugin_context.persona_manager.get_persona_v3_by_id(pid) + if persona is not None: + persona_tools = persona.get("tools") tools = a.get("tools", []) if persona_tools is not None: tools = persona_tools diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index d141f40e43..6320ac3bbc 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -44,6 +44,22 @@ async def get_persona(self, persona_id: str): raise ValueError(f"Persona with ID {persona_id} does not exist.") return persona + def get_persona_v3_by_id(self, persona_id: str | None) -> Personality | None: + """Resolve a v3 persona object by id. + + - None/empty id returns None. + - "default" maps to in-memory DEFAULT_PERSONALITY. + - Otherwise search in personas_v3 by persona name. + """ + if not persona_id: + return None + if persona_id == "default": + return DEFAULT_PERSONALITY + return next( + (persona for persona in self.personas_v3 if persona["name"] == persona_id), + None, + ) + async def get_default_persona_v3( self, umo: str | MessageSession | None = None, @@ -54,12 +70,7 @@ async def get_default_persona_v3( "default_personality", "default", ) - if not default_persona_id or default_persona_id == "default": - return DEFAULT_PERSONALITY - try: - return next(p for p in self.personas_v3 if p["name"] == default_persona_id) - except Exception: - return DEFAULT_PERSONALITY + return self.get_persona_v3_by_id(default_persona_id) or DEFAULT_PERSONALITY async def resolve_selected_persona( self, diff --git a/astrbot/core/subagent_orchestrator.py b/astrbot/core/subagent_orchestrator.py index 205c554cb8..c6c595dfc9 100644 --- a/astrbot/core/subagent_orchestrator.py +++ b/astrbot/core/subagent_orchestrator.py @@ -1,13 +1,16 @@ from __future__ import annotations -from typing import Any +import copy +from typing import TYPE_CHECKING, Any from astrbot import logger from astrbot.core.agent.agent import Agent from astrbot.core.agent.handoff import HandoffTool -from astrbot.core.persona_mgr import PersonaManager from astrbot.core.provider.func_tool_manager import FunctionToolManager +if TYPE_CHECKING: + from astrbot.core.persona_mgr import PersonaManager + class SubAgentOrchestrator: """Loads subagent definitions from config and registers handoff tools. @@ -43,15 +46,14 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: continue persona_id = item.get("persona_id") - persona_data = None - if persona_id: - try: - persona_data = await self._persona_mgr.get_persona(persona_id) - except StopIteration: - logger.warning( - "SubAgent persona %s not found, fallback to inline prompt.", - persona_id, - ) + if persona_id is not None: + persona_id = str(persona_id).strip() or None + persona_data = self._persona_mgr.get_persona_v3_by_id(persona_id) + if persona_id and persona_data is None: + logger.warning( + "SubAgent persona %s not found, fallback to inline prompt.", + persona_id, + ) instructions = str(item.get("system_prompt", "")).strip() public_description = str(item.get("public_description", "")).strip() @@ -62,11 +64,15 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None: begin_dialogs = None if persona_data: - instructions = persona_data.system_prompt or instructions - begin_dialogs = persona_data.begin_dialogs - tools = persona_data.tools - if public_description == "" and persona_data.system_prompt: - public_description = persona_data.system_prompt[:120] + prompt = str(persona_data.get("prompt", "")).strip() + if prompt: + instructions = prompt + begin_dialogs = copy.deepcopy( + persona_data.get("_begin_dialogs_processed") + ) + tools = persona_data.get("tools") + if public_description == "" and prompt: + public_description = prompt[:120] if tools is None: tools = None elif not isinstance(tools, list): diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index cbaec66c7c..9ff4636d25 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -1,4 +1,5 @@ import asyncio +import copy import os import sys from types import SimpleNamespace @@ -101,6 +102,55 @@ async def test_get_stat(app: Quart, authenticated_header: dict): assert data["status"] == "ok" and "platform" in data["data"] +@pytest.mark.asyncio +async def test_subagent_config_accepts_default_persona( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, +): + test_client = app.test_client() + old_cfg = copy.deepcopy( + core_lifecycle_td.astrbot_config.get("subagent_orchestrator", {}) + ) + payload = { + "main_enable": True, + "remove_main_duplicate_tools": True, + "agents": [ + { + "name": "planner", + "persona_id": "default", + "public_description": "planner", + "system_prompt": "", + "enabled": True, + } + ], + } + + try: + response = await test_client.post( + "/api/subagent/config", + json=payload, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + get_response = await test_client.get( + "/api/subagent/config", headers=authenticated_header + ) + assert get_response.status_code == 200 + get_data = await get_response.get_json() + assert get_data["status"] == "ok" + assert get_data["data"]["agents"][0]["persona_id"] == "default" + finally: + await test_client.post( + "/api/subagent/config", + json=old_cfg, + headers=authenticated_header, + ) + + @pytest.mark.asyncio async def test_plugins( app: Quart, diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 3ce974f419..bd4d476d2b 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -39,6 +39,7 @@ def mock_context(): ctx.persona_manager.resolve_selected_persona = AsyncMock( return_value=(None, None, None, False) ) + ctx.persona_manager.get_persona_v3_by_id = MagicMock(return_value=None) ctx.get_llm_tool_manager.return_value = MagicMock() ctx.subagent_orchestrator = None return ctx @@ -562,6 +563,63 @@ async def test_ensure_tools_from_persona(self, mock_event, mock_context): assert req.func_tool is not None + @pytest.mark.asyncio + async def test_subagent_dedupe_uses_default_persona_tools( + self, mock_event, mock_context + ): + """Test dedupe uses resolved default persona tools in subagent mode.""" + module = ama + mock_context.persona_manager.resolve_selected_persona = AsyncMock( + return_value=(None, None, None, False) + ) + mock_context.persona_manager.get_persona_v3_by_id = MagicMock( + return_value={"name": "default", "tools": ["tool_a"]} + ) + + tool_a = FunctionTool( + name="tool_a", + parameters={"type": "object", "properties": {}}, + description="tool a", + ) + tool_b = FunctionTool( + name="tool_b", + parameters={"type": "object", "properties": {}}, + description="tool b", + ) + tmgr = mock_context.get_llm_tool_manager.return_value + tmgr.func_list = [tool_a, tool_b] + tmgr.get_full_tool_set.return_value = ToolSet([tool_a, tool_b]) + tmgr.get_func.side_effect = lambda name: {"tool_a": tool_a, "tool_b": tool_b}.get( + name + ) + + handoff = MagicMock() + handoff.name = "transfer_to_planner" + mock_context.subagent_orchestrator = MagicMock(handoffs=[handoff]) + mock_context.get_config.return_value = { + "subagent_orchestrator": { + "main_enable": True, + "remove_main_duplicate_tools": True, + "agents": [ + { + "name": "planner", + "enabled": True, + "persona_id": "default", + } + ], + } + } + + req = ProviderRequest() + req.conversation = MagicMock(persona_id=None) + + await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) + + assert req.func_tool is not None + assert "transfer_to_planner" in req.func_tool.names() + assert "tool_a" not in req.func_tool.names() + assert "tool_b" in req.func_tool.names() + class TestDecorateLlmRequest: """Tests for _decorate_llm_request function.""" diff --git a/tests/unit/test_subagent_orchestrator.py b/tests/unit/test_subagent_orchestrator.py new file mode 100644 index 0000000000..9befac8872 --- /dev/null +++ b/tests/unit/test_subagent_orchestrator.py @@ -0,0 +1,110 @@ +from copy import deepcopy +from unittest.mock import MagicMock, patch + +import pytest + +from astrbot.core.subagent_orchestrator import SubAgentOrchestrator + + +def _build_cfg(agent_overrides: dict) -> dict: + agent = { + "name": "planner", + "enabled": True, + "persona_id": None, + "system_prompt": "inline prompt", + "public_description": "", + "tools": ["tool_a", " ", "tool_b"], + } + agent.update(agent_overrides) + return {"agents": [agent]} + + +@pytest.mark.asyncio +async def test_reload_from_config_default_persona_is_resolved(): + tool_mgr = MagicMock() + persona_mgr = MagicMock() + default_persona = { + "name": "default", + "prompt": "You are a helpful and friendly assistant.", + "tools": None, + "_begin_dialogs_processed": [], + } + persona_mgr.get_persona_v3_by_id.return_value = deepcopy(default_persona) + orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr) + + await orchestrator.reload_from_config(_build_cfg({"persona_id": "default"})) + + assert len(orchestrator.handoffs) == 1 + handoff = orchestrator.handoffs[0] + assert handoff.agent.instructions == default_persona["prompt"] + assert handoff.agent.tools is None + assert handoff.agent.begin_dialogs == default_persona["_begin_dialogs_processed"] + + +@pytest.mark.asyncio +async def test_reload_from_config_missing_persona_falls_back_to_inline_and_warns(): + tool_mgr = MagicMock() + persona_mgr = MagicMock() + persona_mgr.get_persona_v3_by_id.return_value = None + orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr) + + with patch("astrbot.core.subagent_orchestrator.logger") as mock_logger: + await orchestrator.reload_from_config(_build_cfg({"persona_id": "not_exists"})) + + assert len(orchestrator.handoffs) == 1 + handoff = orchestrator.handoffs[0] + assert handoff.agent.instructions == "inline prompt" + assert handoff.agent.tools == ["tool_a", "tool_b"] + assert handoff.agent.begin_dialogs is None + mock_logger.warning.assert_called_once_with( + "SubAgent persona %s not found, fallback to inline prompt.", + "not_exists", + ) + + +@pytest.mark.asyncio +async def test_reload_from_config_uses_processed_begin_dialogs_and_deepcopy(): + tool_mgr = MagicMock() + persona_mgr = MagicMock() + processed_dialogs = [{"role": "user", "content": "hello", "_no_save": True}] + persona_mgr.get_persona_v3_by_id.return_value = { + "name": "custom", + "prompt": "persona prompt", + "tools": ["tool_from_persona"], + "_begin_dialogs_processed": processed_dialogs, + } + orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr) + + await orchestrator.reload_from_config(_build_cfg({"persona_id": "custom"})) + processed_dialogs[0]["content"] = "mutated" + + handoff = orchestrator.handoffs[0] + assert handoff.agent.instructions == "persona prompt" + assert handoff.agent.tools == ["tool_from_persona"] + assert handoff.agent.begin_dialogs[0]["content"] == "hello" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("raw_tools", "expected_tools"), + [ + (None, None), + ([], []), + ("not-a-list", []), + ], +) +async def test_reload_from_config_tool_normalization(raw_tools, expected_tools): + tool_mgr = MagicMock() + persona_mgr = MagicMock() + persona_mgr.get_persona_v3_by_id.return_value = { + "name": "custom", + "prompt": "persona prompt", + "tools": raw_tools, + "_begin_dialogs_processed": [], + } + orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr) + + await orchestrator.reload_from_config(_build_cfg({"persona_id": "custom"})) + + handoff = orchestrator.handoffs[0] + assert handoff.agent.tools == expected_tools