Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion astrbot/core/agent/handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ def default_parameters(self) -> dict:

def default_description(self, agent_name: str | None) -> str:
agent_name = agent_name or "another"
return f"Delegate tasks to {self.name} agent to handle the request."
return f"Delegate tasks to {agent_name} agent to handle the request."
11 changes: 3 additions & 8 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,14 +390,9 @@ async def _ensure_persona_and_skills(
persona_tools = None
pid = a.get("persona_id")
if pid:
persona_tools = next(
(
p.get("tools")
for p in plugin_context.persona_manager.personas_v3
if p["name"] == pid
),
None,
)
persona = plugin_context.persona_manager.get_persona_v3_by_id(pid)
if persona is not None:
persona_tools = persona.get("tools")
tools = a.get("tools", [])
if persona_tools is not None:
tools = persona_tools
Expand Down
23 changes: 17 additions & 6 deletions astrbot/core/persona_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ async def get_persona(self, persona_id: str):
raise ValueError(f"Persona with ID {persona_id} does not exist.")
return persona

def get_persona_v3_by_id(self, persona_id: str | None) -> Personality | None:
"""Resolve a v3 persona object by id.

- None/empty id returns None.
- "default" maps to in-memory DEFAULT_PERSONALITY.
- Otherwise search in personas_v3 by persona name.
"""
if not persona_id:
return None
if persona_id == "default":
return DEFAULT_PERSONALITY
return next(
(persona for persona in self.personas_v3 if persona["name"] == persona_id),
None,
)

async def get_default_persona_v3(
self,
umo: str | MessageSession | None = None,
Expand All @@ -54,12 +70,7 @@ async def get_default_persona_v3(
"default_personality",
"default",
)
if not default_persona_id or default_persona_id == "default":
return DEFAULT_PERSONALITY
try:
return next(p for p in self.personas_v3 if p["name"] == default_persona_id)
except Exception:
return DEFAULT_PERSONALITY
return self.get_persona_v3_by_id(default_persona_id) or DEFAULT_PERSONALITY

async def resolve_selected_persona(
self,
Expand Down
38 changes: 22 additions & 16 deletions astrbot/core/subagent_orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

from typing import Any
import copy
from typing import TYPE_CHECKING, Any

from astrbot import logger
from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.func_tool_manager import FunctionToolManager

if TYPE_CHECKING:
from astrbot.core.persona_mgr import PersonaManager


class SubAgentOrchestrator:
"""Loads subagent definitions from config and registers handoff tools.
Expand Down Expand Up @@ -43,15 +46,14 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None:
continue

persona_id = item.get("persona_id")
persona_data = None
if persona_id:
try:
persona_data = await self._persona_mgr.get_persona(persona_id)
except StopIteration:
logger.warning(
"SubAgent persona %s not found, fallback to inline prompt.",
persona_id,
)
if persona_id is not None:
persona_id = str(persona_id).strip() or None
persona_data = self._persona_mgr.get_persona_v3_by_id(persona_id)
if persona_id and persona_data is None:
logger.warning(
"SubAgent persona %s not found, fallback to inline prompt.",
persona_id,
)

instructions = str(item.get("system_prompt", "")).strip()
public_description = str(item.get("public_description", "")).strip()
Expand All @@ -62,11 +64,15 @@ async def reload_from_config(self, cfg: dict[str, Any]) -> None:
begin_dialogs = None

if persona_data:
instructions = persona_data.system_prompt or instructions
begin_dialogs = persona_data.begin_dialogs
tools = persona_data.tools
if public_description == "" and persona_data.system_prompt:
public_description = persona_data.system_prompt[:120]
prompt = str(persona_data.get("prompt", "")).strip()
if prompt:
instructions = prompt
begin_dialogs = copy.deepcopy(
persona_data.get("_begin_dialogs_processed")
)
tools = persona_data.get("tools")
if public_description == "" and prompt:
public_description = prompt[:120]
if tools is None:
tools = None
elif not isinstance(tools, list):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_dashboard.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import copy
import os
import sys
from types import SimpleNamespace
Expand Down Expand Up @@ -101,6 +102,55 @@ async def test_get_stat(app: Quart, authenticated_header: dict):
assert data["status"] == "ok" and "platform" in data["data"]


@pytest.mark.asyncio
async def test_subagent_config_accepts_default_persona(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
):
test_client = app.test_client()
old_cfg = copy.deepcopy(
core_lifecycle_td.astrbot_config.get("subagent_orchestrator", {})
)
payload = {
"main_enable": True,
"remove_main_duplicate_tools": True,
"agents": [
{
"name": "planner",
"persona_id": "default",
"public_description": "planner",
"system_prompt": "",
"enabled": True,
}
],
}

try:
response = await test_client.post(
"/api/subagent/config",
json=payload,
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"

get_response = await test_client.get(
"/api/subagent/config", headers=authenticated_header
)
assert get_response.status_code == 200
get_data = await get_response.get_json()
assert get_data["status"] == "ok"
assert get_data["data"]["agents"][0]["persona_id"] == "default"
finally:
await test_client.post(
"/api/subagent/config",
json=old_cfg,
headers=authenticated_header,
)


@pytest.mark.asyncio
async def test_plugins(
app: Quart,
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/test_astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def mock_context():
ctx.persona_manager.resolve_selected_persona = AsyncMock(
return_value=(None, None, None, False)
)
ctx.persona_manager.get_persona_v3_by_id = MagicMock(return_value=None)
ctx.get_llm_tool_manager.return_value = MagicMock()
ctx.subagent_orchestrator = None
return ctx
Expand Down Expand Up @@ -562,6 +563,63 @@ async def test_ensure_tools_from_persona(self, mock_event, mock_context):

assert req.func_tool is not None

@pytest.mark.asyncio
async def test_subagent_dedupe_uses_default_persona_tools(
self, mock_event, mock_context
):
"""Test dedupe uses resolved default persona tools in subagent mode."""
module = ama
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
return_value=(None, None, None, False)
)
mock_context.persona_manager.get_persona_v3_by_id = MagicMock(
return_value={"name": "default", "tools": ["tool_a"]}
)

tool_a = FunctionTool(
name="tool_a",
parameters={"type": "object", "properties": {}},
description="tool a",
)
tool_b = FunctionTool(
name="tool_b",
parameters={"type": "object", "properties": {}},
description="tool b",
)
tmgr = mock_context.get_llm_tool_manager.return_value
tmgr.func_list = [tool_a, tool_b]
tmgr.get_full_tool_set.return_value = ToolSet([tool_a, tool_b])
tmgr.get_func.side_effect = lambda name: {"tool_a": tool_a, "tool_b": tool_b}.get(
name
)

handoff = MagicMock()
handoff.name = "transfer_to_planner"
mock_context.subagent_orchestrator = MagicMock(handoffs=[handoff])
mock_context.get_config.return_value = {
"subagent_orchestrator": {
"main_enable": True,
"remove_main_duplicate_tools": True,
"agents": [
{
"name": "planner",
"enabled": True,
"persona_id": "default",
}
],
}
}

req = ProviderRequest()
req.conversation = MagicMock(persona_id=None)

await module._ensure_persona_and_skills(req, {}, mock_context, mock_event)

assert req.func_tool is not None
assert "transfer_to_planner" in req.func_tool.names()
assert "tool_a" not in req.func_tool.names()
assert "tool_b" in req.func_tool.names()


class TestDecorateLlmRequest:
"""Tests for _decorate_llm_request function."""
Expand Down
110 changes: 110 additions & 0 deletions tests/unit/test_subagent_orchestrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from copy import deepcopy
from unittest.mock import MagicMock, patch

import pytest

from astrbot.core.subagent_orchestrator import SubAgentOrchestrator


def _build_cfg(agent_overrides: dict) -> dict:
agent = {
"name": "planner",
"enabled": True,
"persona_id": None,
"system_prompt": "inline prompt",
"public_description": "",
"tools": ["tool_a", " ", "tool_b"],
}
agent.update(agent_overrides)
return {"agents": [agent]}


@pytest.mark.asyncio
async def test_reload_from_config_default_persona_is_resolved():
tool_mgr = MagicMock()
persona_mgr = MagicMock()
default_persona = {
"name": "default",
"prompt": "You are a helpful and friendly assistant.",
"tools": None,
"_begin_dialogs_processed": [],
}
persona_mgr.get_persona_v3_by_id.return_value = deepcopy(default_persona)
orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr)

await orchestrator.reload_from_config(_build_cfg({"persona_id": "default"}))

assert len(orchestrator.handoffs) == 1
handoff = orchestrator.handoffs[0]
assert handoff.agent.instructions == default_persona["prompt"]
assert handoff.agent.tools is None
assert handoff.agent.begin_dialogs == default_persona["_begin_dialogs_processed"]


@pytest.mark.asyncio
async def test_reload_from_config_missing_persona_falls_back_to_inline_and_warns():
tool_mgr = MagicMock()
persona_mgr = MagicMock()
persona_mgr.get_persona_v3_by_id.return_value = None
orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr)

with patch("astrbot.core.subagent_orchestrator.logger") as mock_logger:
await orchestrator.reload_from_config(_build_cfg({"persona_id": "not_exists"}))

assert len(orchestrator.handoffs) == 1
handoff = orchestrator.handoffs[0]
assert handoff.agent.instructions == "inline prompt"
assert handoff.agent.tools == ["tool_a", "tool_b"]
assert handoff.agent.begin_dialogs is None
mock_logger.warning.assert_called_once_with(
"SubAgent persona %s not found, fallback to inline prompt.",
"not_exists",
)


@pytest.mark.asyncio
async def test_reload_from_config_uses_processed_begin_dialogs_and_deepcopy():
tool_mgr = MagicMock()
persona_mgr = MagicMock()
processed_dialogs = [{"role": "user", "content": "hello", "_no_save": True}]
persona_mgr.get_persona_v3_by_id.return_value = {
"name": "custom",
"prompt": "persona prompt",
"tools": ["tool_from_persona"],
"_begin_dialogs_processed": processed_dialogs,
}
orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr)

await orchestrator.reload_from_config(_build_cfg({"persona_id": "custom"}))
processed_dialogs[0]["content"] = "mutated"

handoff = orchestrator.handoffs[0]
assert handoff.agent.instructions == "persona prompt"
assert handoff.agent.tools == ["tool_from_persona"]
assert handoff.agent.begin_dialogs[0]["content"] == "hello"


@pytest.mark.asyncio
@pytest.mark.parametrize(
("raw_tools", "expected_tools"),
[
(None, None),
([], []),
("not-a-list", []),
],
)
async def test_reload_from_config_tool_normalization(raw_tools, expected_tools):
tool_mgr = MagicMock()
persona_mgr = MagicMock()
persona_mgr.get_persona_v3_by_id.return_value = {
"name": "custom",
"prompt": "persona prompt",
"tools": raw_tools,
"_begin_dialogs_processed": [],
}
orchestrator = SubAgentOrchestrator(tool_mgr=tool_mgr, persona_mgr=persona_mgr)

await orchestrator.reload_from_config(_build_cfg({"persona_id": "custom"}))

handoff = orchestrator.handoffs[0]
assert handoff.agent.tools == expected_tools