diff --git a/AGENTS.md b/AGENTS.md index 9f3617ce9..9e89ab222 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -19,6 +19,16 @@ pnpm dev Runs on `http://localhost:3000` by default. +## Test + +```bash +uv run pytest +``` + +Please run the tests after modifying the code to ensure everything works as expected and to prevent regressions. + + + ## Dev environment tips 1. When modifying the WebUI, be sure to maintain componentization and clean code. Avoid duplicate code. diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 8475009d3..e9cbeca69 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -33,6 +33,10 @@ def __init__( # Optional provider override for this subagent. When set, the handoff # execution will use this chat provider id instead of the global/default. self.provider_id: str | None = None + # Human-readable display name shown in orchestration UI/logs. + self.agent_display_name: str | None = None + # Optional per-subagent max steps override. + self.max_steps: int | None = None # Note: Must assign after super().__init__() to prevent parent class from overriding this attribute self.agent = agent @@ -62,4 +66,4 @@ def default_parameters(self) -> dict: def default_description(self, agent_name: str | None) -> str: agent_name = agent_name or "another" - return f"Delegate tasks to {self.name} agent to handle the request." + return f"Delegate tasks to {agent_name} agent to handle the request." 
diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 743b28007..74b607301 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -750,6 +750,7 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None: executor = self.tool_executor.execute( tool=func_tool, run_context=self.run_context, + tool_call_id=func_tool_id, **valid_params, # 只传递有效的参数 ) diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 0dc8b9eeb..428f0f8c6 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -1,122 +1,30 @@ import asyncio import inspect -import json import traceback import typing as T import uuid -from collections.abc import Sequence -from collections.abc import Set as AbstractSet import mcp from astrbot import logger from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.mcp_client import MCPTool -from astrbot.core.agent.message import Message from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.agent.tool import FunctionTool from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.astr_main_agent_resources import ( - BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT, - EXECUTE_SHELL_TOOL, - FILE_DOWNLOAD_TOOL, - FILE_UPLOAD_TOOL, - LOCAL_EXECUTE_SHELL_TOOL, - LOCAL_PYTHON_TOOL, - PYTHON_TOOL, - SEND_MESSAGE_TO_USER_TOOL, -) -from astrbot.core.cron.events import CronMessageEvent -from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( CommandResult, MessageChain, MessageEventResult, ) -from astrbot.core.platform.message_session import MessageSession -from astrbot.core.provider.entites import ProviderRequest -from 
astrbot.core.provider.register import llm_tools -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path -from astrbot.core.utils.history_saver import persist_agent_history -from astrbot.core.utils.image_ref_utils import is_supported_image_ref -from astrbot.core.utils.string_utils import normalize_and_dedupe_strings +from astrbot.core.subagent.background_notifier import ( + wake_main_agent_for_background_result, +) +from astrbot.core.subagent.handoff_executor import HandoffExecutor class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): - @classmethod - def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]: - if image_urls_raw is None: - return [] - - if isinstance(image_urls_raw, str): - return [image_urls_raw] - - if isinstance(image_urls_raw, (Sequence, AbstractSet)) and not isinstance( - image_urls_raw, (str, bytes, bytearray) - ): - return [item for item in image_urls_raw if isinstance(item, str)] - - logger.debug( - "Unsupported image_urls type in handoff tool args: %s", - type(image_urls_raw).__name__, - ) - return [] - - @classmethod - async def _collect_image_urls_from_message( - cls, run_context: ContextWrapper[AstrAgentContext] - ) -> list[str]: - urls: list[str] = [] - event = getattr(run_context.context, "event", None) - message_obj = getattr(event, "message_obj", None) - message = getattr(message_obj, "message", None) - if message: - for idx, component in enumerate(message): - if not isinstance(component, Image): - continue - try: - path = await component.convert_to_file_path() - if path: - urls.append(path) - except Exception as e: - logger.error( - "Failed to convert handoff image component at index %d: %s", - idx, - e, - exc_info=True, - ) - return urls - - @classmethod - async def _collect_handoff_image_urls( - cls, - run_context: ContextWrapper[AstrAgentContext], - image_urls_raw: T.Any, - ) -> list[str]: - candidates: list[str] = [] - 
candidates.extend(cls._collect_image_urls_from_args(image_urls_raw)) - candidates.extend(await cls._collect_image_urls_from_message(run_context)) - - normalized = normalize_and_dedupe_strings(candidates) - extensionless_local_roots = (get_astrbot_temp_path(),) - sanitized = [ - item - for item in normalized - if is_supported_image_ref( - item, - allow_extensionless_existing_local_file=True, - extensionless_local_roots=extensionless_local_roots, - ) - ] - dropped_count = len(normalized) - len(sanitized) - if dropped_count > 0: - logger.debug( - "Dropped %d invalid image_urls entries in handoff image inputs.", - dropped_count, - ) - return sanitized - @classmethod async def execute(cls, tool, run_context, **tool_args): """执行函数调用。 @@ -129,15 +37,21 @@ async def execute(cls, tool, run_context, **tool_args): AsyncGenerator[None | mcp.types.CallToolResult, None] """ + tool_call_id = tool_args.pop("tool_call_id", None) if isinstance(tool, HandoffTool): is_bg = tool_args.pop("background_task", False) if is_bg: - async for r in cls._execute_handoff_background( - tool, run_context, **tool_args + async for r in HandoffExecutor.submit_background( + tool, + run_context, + tool_call_id=tool_call_id, + **tool_args, ): yield r return - async for r in cls._execute_handoff(tool, run_context, **tool_args): + async for r in HandoffExecutor.execute_foreground( + tool, run_context, **tool_args + ): yield r return @@ -176,229 +90,6 @@ async def _run_in_background() -> None: yield r return - @classmethod - def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]: - if runtime == "sandbox": - return { - EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL, - PYTHON_TOOL.name: PYTHON_TOOL, - FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL, - FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL, - } - if runtime == "local": - return { - LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL, - LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL, - } - return {} - - @classmethod - def 
_build_handoff_toolset( - cls, - run_context: ContextWrapper[AstrAgentContext], - tools: list[str | FunctionTool] | None, - ) -> ToolSet | None: - ctx = run_context.context.context - event = run_context.context.event - cfg = ctx.get_config(umo=event.unified_msg_origin) - provider_settings = cfg.get("provider_settings", {}) - runtime = str(provider_settings.get("computer_use_runtime", "local")) - runtime_computer_tools = cls._get_runtime_computer_tools(runtime) - - # Keep persona semantics aligned with the main agent: tools=None means - # "all tools", including runtime computer-use tools. - if tools is None: - toolset = ToolSet() - for registered_tool in llm_tools.func_list: - if isinstance(registered_tool, HandoffTool): - continue - if registered_tool.active: - toolset.add_tool(registered_tool) - for runtime_tool in runtime_computer_tools.values(): - toolset.add_tool(runtime_tool) - return None if toolset.empty() else toolset - - if not tools: - return None - - toolset = ToolSet() - for tool_name_or_obj in tools: - if isinstance(tool_name_or_obj, str): - registered_tool = llm_tools.get_func(tool_name_or_obj) - if registered_tool and registered_tool.active: - toolset.add_tool(registered_tool) - continue - runtime_tool = runtime_computer_tools.get(tool_name_or_obj) - if runtime_tool: - toolset.add_tool(runtime_tool) - elif isinstance(tool_name_or_obj, FunctionTool): - toolset.add_tool(tool_name_or_obj) - return None if toolset.empty() else toolset - - @classmethod - async def _execute_handoff( - cls, - tool: HandoffTool, - run_context: ContextWrapper[AstrAgentContext], - *, - image_urls_prepared: bool = False, - **tool_args: T.Any, - ): - tool_args = dict(tool_args) - input_ = tool_args.get("input") - if image_urls_prepared: - prepared_image_urls = tool_args.get("image_urls") - if isinstance(prepared_image_urls, list): - image_urls = prepared_image_urls - else: - logger.debug( - "Expected prepared handoff image_urls as list[str], got %s.", - 
type(prepared_image_urls).__name__, - ) - image_urls = [] - else: - image_urls = await cls._collect_handoff_image_urls( - run_context, - tool_args.get("image_urls"), - ) - tool_args["image_urls"] = image_urls - - # Build handoff toolset from registered tools plus runtime computer tools. - toolset = cls._build_handoff_toolset(run_context, tool.agent.tools) - - ctx = run_context.context.context - event = run_context.context.event - umo = event.unified_msg_origin - - # Use per-subagent provider override if configured; otherwise fall back - # to the current/default provider resolution. - prov_id = getattr( - tool, "provider_id", None - ) or await ctx.get_current_chat_provider_id(umo) - - # prepare begin dialogs - contexts = None - dialogs = tool.agent.begin_dialogs - if dialogs: - contexts = [] - for dialog in dialogs: - try: - contexts.append( - dialog - if isinstance(dialog, Message) - else Message.model_validate(dialog) - ) - except Exception: - continue - - prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {}) - agent_max_step = int(prov_settings.get("max_agent_step", 30)) - stream = prov_settings.get("streaming_response", False) - llm_resp = await ctx.tool_loop_agent( - event=event, - chat_provider_id=prov_id, - prompt=input_, - image_urls=image_urls, - system_prompt=tool.agent.instructions, - tools=toolset, - contexts=contexts, - max_steps=agent_max_step, - stream=stream, - ) - yield mcp.types.CallToolResult( - content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)] - ) - - @classmethod - async def _execute_handoff_background( - cls, - tool: HandoffTool, - run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): - """Execute a handoff as a background task. - - Immediately yields a success response with a task_id, then runs - the subagent asynchronously. 
When the subagent finishes, a - ``CronMessageEvent`` is created so the main LLM can inform the - user of the result – the same pattern used by - ``_execute_background`` for regular background tasks. - """ - task_id = uuid.uuid4().hex - - async def _run_handoff_in_background() -> None: - try: - await cls._do_handoff_background( - tool=tool, - run_context=run_context, - task_id=task_id, - **tool_args, - ) - except Exception as e: # noqa: BLE001 - logger.error( - f"Background handoff {task_id} ({tool.name}) failed: {e!s}", - exc_info=True, - ) - - asyncio.create_task(_run_handoff_in_background()) - - text_content = mcp.types.TextContent( - type="text", - text=( - f"Background task dedicated to subagent '{tool.agent.name}' submitted. task_id={task_id}. " - f"The subagent '{tool.agent.name}' is working on the task on hehalf you. " - f"You will be notified when it finishes." - ), - ) - yield mcp.types.CallToolResult(content=[text_content]) - - @classmethod - async def _do_handoff_background( - cls, - tool: HandoffTool, - run_context: ContextWrapper[AstrAgentContext], - task_id: str, - **tool_args, - ) -> None: - """Run the subagent handoff and, on completion, wake the main agent.""" - result_text = "" - tool_args = dict(tool_args) - tool_args["image_urls"] = await cls._collect_handoff_image_urls( - run_context, - tool_args.get("image_urls"), - ) - try: - async for r in cls._execute_handoff( - tool, - run_context, - image_urls_prepared=True, - **tool_args, - ): - if isinstance(r, mcp.types.CallToolResult): - for content in r.content: - if isinstance(content, mcp.types.TextContent): - result_text += content.text + "\n" - except Exception as e: - result_text = ( - f"error: Background task execution failed, internal error: {e!s}" - ) - - event = run_context.context.event - - await cls._wake_main_agent_for_background_result( - run_context=run_context, - task_id=task_id, - tool_name=tool.name, - result_text=result_text, - tool_args=tool_args, - note=( - 
event.get_extra("background_note") - or f"Background task for subagent '{tool.agent.name}' finished." - ), - summary_name=f"Dedicated to subagent `{tool.agent.name}`", - extra_result_fields={"subagent_name": tool.agent.name}, - ) - @classmethod async def _execute_background( cls, @@ -426,7 +117,7 @@ async def _execute_background( event = run_context.context.event - await cls._wake_main_agent_for_background_result( + await wake_main_agent_for_background_result( run_context=run_context, task_id=task_id, tool_name=tool.name, @@ -439,115 +130,6 @@ async def _execute_background( summary_name=tool.name, ) - @classmethod - async def _wake_main_agent_for_background_result( - cls, - run_context: ContextWrapper[AstrAgentContext], - *, - task_id: str, - tool_name: str, - result_text: str, - tool_args: dict[str, T.Any], - note: str, - summary_name: str, - extra_result_fields: dict[str, T.Any] | None = None, - ) -> None: - from astrbot.core.astr_main_agent import ( - MainAgentBuildConfig, - _get_session_conv, - build_main_agent, - ) - - event = run_context.context.event - ctx = run_context.context.context - - task_result = { - "task_id": task_id, - "tool_name": tool_name, - "result": result_text or "", - "tool_args": tool_args, - } - if extra_result_fields: - task_result.update(extra_result_fields) - extras = {"background_task_result": task_result} - - session = MessageSession.from_str(event.unified_msg_origin) - cron_event = CronMessageEvent( - context=ctx, - session=session, - message=note, - extras=extras, - message_type=session.message_type, - ) - cron_event.role = event.role - config = MainAgentBuildConfig( - tool_call_timeout=3600, - streaming_response=ctx.get_config() - .get("provider_settings", {}) - .get("stream", False), - ) - - req = ProviderRequest() - conv = await _get_session_conv(event=cron_event, plugin_context=ctx) - req.conversation = conv - context = json.loads(conv.history) - if context: - req.contexts = context - context_dump = req._print_friendly_context() 
- req.contexts = [] - req.system_prompt += ( - "\n\nBellow is you and user previous conversation history:\n" - f"{context_dump}" - ) - - bg = json.dumps(extras["background_task_result"], ensure_ascii=False) - req.system_prompt += BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT.format( - background_task_result=bg - ) - req.prompt = ( - "Proceed according to your system instructions. " - "Output using same language as previous conversation. " - "If you need to deliver the result to the user immediately, " - "you MUST use `send_message_to_user` tool to send the message directly to the user, " - "otherwise the user will not see the result. " - "After completing your task, summarize and output your actions and results. " - ) - if not req.func_tool: - req.func_tool = ToolSet() - req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) - - result = await build_main_agent( - event=cron_event, plugin_context=ctx, config=config, req=req - ) - if not result: - logger.error(f"Failed to build main agent for background task {tool_name}.") - return - - runner = result.agent_runner - async for _ in runner.step_until_done(30): - # agent will send message to user via using tools - pass - llm_resp = runner.get_final_llm_resp() - task_meta = extras.get("background_task_result", {}) - summary_note = ( - f"[BackgroundTask] {summary_name} " - f"(task_id={task_meta.get('task_id', task_id)}) finished. 
" - f"Result: {task_meta.get('result') or result_text or 'no content'}" - ) - if llm_resp and llm_resp.completion_text: - summary_note += ( - f"I finished the task, here is the result: {llm_resp.completion_text}" - ) - await persist_agent_history( - ctx.conversation_manager, - event=cron_event, - req=req, - summary_note=summary_note, - ) - if not llm_resp: - logger.warning("background task agent got no response") - return - @classmethod async def _execute_local( cls, diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index f18b49a43..4e7ae1c39 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -11,7 +11,6 @@ from dataclasses import dataclass, field from astrbot.core import logger -from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.agent.message import TextPart from astrbot.core.agent.tool import ToolSet @@ -139,6 +138,63 @@ class MainAgentBuildConfig: max_quoted_fallback_images: int = 20 """Maximum number of images injected from quoted-message fallback extraction.""" + @classmethod + def from_provider_settings( + cls, + provider_settings: dict, + cfg: dict | None = None, + **overrides, + ) -> MainAgentBuildConfig: + """Build config from a ``provider_settings`` dict and an optional + top-level config dict *cfg*. + + Any extra keyword argument overrides the corresponding field so that + call-sites can still customise individual values (e.g. a different + ``tool_call_timeout`` for background tasks). 
+ """ + cfg = cfg or {} + ps = provider_settings + proactive_cfg = ps.get("proactive_capability", {}) + file_extract_cfg = ps.get("file_extract", {}) + + defaults: dict = { + "tool_call_timeout": int(ps.get("tool_call_timeout", 60)), + "tool_schema_mode": str(ps.get("tool_schema_mode", "full")), + "streaming_response": bool( + ps.get("streaming_response", ps.get("stream", False)) + ), + "sanitize_context_by_modalities": bool( + ps.get("sanitize_context_by_modalities", False) + ), + "kb_agentic_mode": bool(cfg.get("kb_agentic_mode", False)), + "file_extract_enabled": bool(file_extract_cfg.get("enable", False)), + "file_extract_prov": str(file_extract_cfg.get("provider", "moonshotai")), + "file_extract_msh_api_key": str( + file_extract_cfg.get("moonshotai_api_key", "") + ), + "context_limit_reached_strategy": str( + ps.get("context_limit_reached_strategy", "truncate_by_turns") + ), + "llm_compress_instruction": str(ps.get("llm_compress_instruction", "")), + "llm_compress_keep_recent": int(ps.get("llm_compress_keep_recent", 6)), + "llm_compress_provider_id": str(ps.get("llm_compress_provider_id", "")), + "max_context_length": int(ps.get("max_context_length", -1)), + "dequeue_context_length": int(ps.get("dequeue_context_length", 1)), + "llm_safety_mode": bool(ps.get("llm_safety_mode", True)), + "safety_mode_strategy": str( + ps.get("safety_mode_strategy", "system_prompt") + ), + "computer_use_runtime": str(ps.get("computer_use_runtime", "local")), + "sandbox_cfg": ps.get("sandbox", {}) or {}, + "add_cron_tools": bool(proactive_cfg.get("add_cron_tools", True)), + "provider_settings": ps, + "subagent_orchestrator": cfg.get("subagent_orchestrator", {}) or {}, + "timezone": cfg.get("timezone"), + "max_quoted_fallback_images": int(ps.get("max_quoted_fallback_images", 20)), + } + defaults.update(overrides) + return cls(**defaults) + @dataclass(slots=True) class MainAgentBuildResult: @@ -307,6 +363,7 @@ async def _ensure_persona_and_skills( """Ensure persona and skills are 
applied to the request's system prompt or user prompt.""" if not req.conversation: return + req.system_prompt = req.system_prompt or "" ( persona_id, @@ -373,72 +430,30 @@ async def _ensure_persona_and_skills( else: req.func_tool.merge(persona_toolset) - # sub agents integration - orch_cfg = plugin_context.get_config().get("subagent_orchestrator", {}) + # sub agents integration (deterministic mount plan from orchestrator) so = plugin_context.subagent_orchestrator - if orch_cfg.get("main_enable", False) and so: - remove_dup = bool(orch_cfg.get("remove_main_duplicate_tools", False)) - - assigned_tools: set[str] = set() - agents = orch_cfg.get("agents", []) - if isinstance(agents, list): - for a in agents: - if not isinstance(a, dict): - continue - if a.get("enabled", True) is False: - continue - persona_tools = None - pid = a.get("persona_id") - if pid: - persona_tools = next( - ( - p.get("tools") - for p in plugin_context.persona_manager.personas_v3 - if p["name"] == pid - ), - None, - ) - tools = a.get("tools", []) - if persona_tools is not None: - tools = persona_tools - if tools is None: - assigned_tools.update( - [ - tool.name - for tool in tmgr.func_list - if not isinstance(tool, HandoffTool) - ] - ) - continue - if not isinstance(tools, list): - continue - for t in tools: - name = str(t).strip() - if name: - assigned_tools.add(name) - - if req.func_tool is None: - req.func_tool = ToolSet() - - # add subagent handoff tools - for tool in so.handoffs: - req.func_tool.add_tool(tool) - - # check duplicates - if remove_dup: - handoff_names = {tool.name for tool in so.handoffs} - for tool_name in assigned_tools: + if so: + plan = so.get_mount_plan() + so_cfg = so.get_config() + cfg_main_enable = getattr(so_cfg, "main_enable", False) is True + if plan and cfg_main_enable: + if req.func_tool is None: + req.func_tool = ToolSet() + + for tool in plan.handoffs: + req.func_tool.add_tool(tool) + + handoff_names = {tool.name for tool in plan.handoffs} + for tool_name in 
plan.main_tool_exclude_set: if tool_name in handoff_names: continue req.func_tool.remove_tool(tool_name) - router_prompt = ( - plugin_context.get_config() - .get("subagent_orchestrator", {}) - .get("router_system_prompt", "") - ).strip() - if router_prompt: - req.system_prompt += f"\n{router_prompt}\n" + if plan.router_prompt: + req.system_prompt += f"\n{plan.router_prompt}\n" + + for diagnostic in plan.diagnostics: + logger.warning("Subagent plan diagnostic: %s", diagnostic) try: event.trace.record( "sel_persona", @@ -846,8 +861,7 @@ def _apply_sandbox_tools( ) -> None: if req.func_tool is None: req.func_tool = ToolSet() - if req.system_prompt is None: - req.system_prompt = "" + system_prompt = req.system_prompt or "" booter = config.sandbox_cfg.get("booter", "shipyard_neo") if booter == "shipyard": ep = config.sandbox_cfg.get("shipyard_endpoint", "") @@ -865,14 +879,14 @@ def _apply_sandbox_tools( if booter == "shipyard_neo": # Neo-specific path rule: filesystem tools operate relative to sandbox # workspace root. Do not prepend "/workspace". - req.system_prompt += ( + system_prompt += ( "\n[Shipyard Neo File Path Rule]\n" "When using sandbox filesystem tools (upload/download/read/write/list/delete), " "always pass paths relative to the sandbox workspace root. 
" "Example: use `baidu_homepage.png` instead of `/workspace/baidu_homepage.png`.\n" ) - req.system_prompt += ( + system_prompt += ( "\n[Neo Skill Lifecycle Workflow]\n" "When user asks to create/update a reusable skill in Neo mode, use lifecycle tools instead of directly writing local skill folders.\n" "Preferred sequence:\n" @@ -914,7 +928,7 @@ def _apply_sandbox_tools( req.func_tool.add_tool(ROLLBACK_SKILL_RELEASE_TOOL) req.func_tool.add_tool(SYNC_SKILL_RELEASE_TOOL) - req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n" + req.system_prompt = f"{system_prompt}\n{SANDBOX_MODE_PROMPT}\n" def _proactive_cron_job_tools(req: ProviderRequest) -> None: diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index cbadb5c18..ed852927d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -151,6 +151,19 @@ "subagent_orchestrator": { "main_enable": False, "remove_main_duplicate_tools": False, + "max_concurrent_subagent_runs": 8, + "max_nested_depth": 2, + "error_classifier": { + "type": "default", + "fatal_exceptions": ["ValueError", "PermissionError", "KeyError"], + "transient_exceptions": [ + "asyncio.TimeoutError", + "TimeoutError", + "ConnectionError", + "ConnectionResetError", + ], + "default_class": "transient", + }, "router_system_prompt": ( "You are a task router. Your job is to chat naturally, recognize user intent, " "and delegate work to the most suitable subagent using transfer_to_* tools. 
" diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index fe6b1c351..a5e5f1bef 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -57,6 +57,7 @@ def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None: self.db = db # 初始化数据库 self.subagent_orchestrator: SubAgentOrchestrator | None = None + self._subagent_worker_started = False self.cron_manager: CronJobManager | None = None self.temp_dir_cleaner: TempDirCleaner | None = None @@ -91,9 +92,13 @@ async def _init_or_reload_subagent_orchestrator(self) -> None: self.provider_manager.llm_tools, self.persona_mgr, ) - await self.subagent_orchestrator.reload_from_config( + diagnostics = await self.subagent_orchestrator.reload_from_config( self.astrbot_config.get("subagent_orchestrator", {}), ) + if not isinstance(diagnostics, list): + diagnostics = [] + for diagnostic in diagnostics: + logger.warning("Subagent diagnostic: %s", diagnostic) except Exception as e: logger.error(f"Subagent orchestrator init failed: {e}", exc_info=True) @@ -193,6 +198,8 @@ async def initialize(self) -> None: self.cron_manager, self.subagent_orchestrator, ) + if self.subagent_orchestrator: + self.subagent_orchestrator.bind_context(self.star_context) # 初始化插件管理器 self.plugin_manager = PluginManager(self.star_context, self.astrbot_config) @@ -263,6 +270,11 @@ def _load(self) -> None: tasks_.append(cron_task) if temp_dir_cleaner_task: tasks_.append(temp_dir_cleaner_task) + if self.subagent_orchestrator and not self._subagent_worker_started: + worker_task = self.subagent_orchestrator.start_worker() + if isinstance(worker_task, asyncio.Task): + tasks_.append(worker_task) + self._subagent_worker_started = True for task in tasks_: self.curr_tasks.append( asyncio.create_task(self._task_wrapper(task), name=task.get_name()), @@ -314,6 +326,12 @@ async def start(self) -> None: async def stop(self) -> None: """停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器.""" + if self.subagent_orchestrator and 
self._subagent_worker_started: + stop_result = self.subagent_orchestrator.stop_worker() + if asyncio.iscoroutine(stop_result): + await stop_result + self._subagent_worker_started = False + if self.temp_dir_cleaner: await self.temp_dir_cleaner.stop() diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 166f770a5..c6753acfa 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -23,6 +23,7 @@ Preference, SessionProjectRelation, Stats, + SubagentTask, ) @@ -624,6 +625,87 @@ async def list_cron_jobs(self, job_type: str | None = None) -> list[CronJob]: """List cron jobs, optionally filtered by job_type.""" ... + # ==== + # Subagent Background Task Management + # ==== + + @abc.abstractmethod + async def create_subagent_task( + self, + *, + task_id: str, + idempotency_key: str, + umo: str, + subagent_name: str, + handoff_tool_name: str, + payload_json: str, + max_attempts: int = 3, + ) -> SubagentTask: ... + + @abc.abstractmethod + async def get_subagent_task_by_idempotency( + self, idempotency_key: str + ) -> SubagentTask | None: ... + + @abc.abstractmethod + async def claim_due_subagent_tasks( + self, + *, + now: datetime.datetime, + limit: int = 20, + ) -> list[SubagentTask]: ... + + @abc.abstractmethod + async def mark_subagent_task_running(self, task_id: str) -> SubagentTask | None: ... + + @abc.abstractmethod + async def mark_subagent_task_retrying( + self, + *, + task_id: str, + next_run_at: datetime.datetime, + error_class: str, + last_error: str, + ) -> bool: ... + + @abc.abstractmethod + async def reschedule_subagent_task( + self, + *, + task_id: str, + next_run_at: datetime.datetime, + error_class: str, + last_error: str, + ) -> bool: ... + + @abc.abstractmethod + async def mark_subagent_task_succeeded( + self, + task_id: str, + *, + result_text: str, + ) -> bool: ... + + @abc.abstractmethod + async def mark_subagent_task_failed( + self, + *, + task_id: str, + error_class: str, + last_error: str, + ) -> bool: ... 
+ + @abc.abstractmethod + async def cancel_subagent_task(self, task_id: str) -> bool: ... + + @abc.abstractmethod + async def list_subagent_tasks( + self, + *, + status: str | None = None, + limit: int = 100, + ) -> list[SubagentTask]: ... + # ==== # Platform Session Management # ==== diff --git a/astrbot/core/db/migration/migra_subagent_tasks.py b/astrbot/core/db/migration/migra_subagent_tasks.py new file mode 100644 index 000000000..5e9ef0a17 --- /dev/null +++ b/astrbot/core/db/migration/migra_subagent_tasks.py @@ -0,0 +1,50 @@ +"""Migration script for subagent background task table.""" + +from sqlalchemy import text + +from astrbot.api import logger, sp +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import SQLModel + + +async def migrate_subagent_tasks(db_helper: BaseDatabase) -> None: + table_marker = "migration_done_subagent_tasks_1" + index_marker = "migration_done_subagent_tasks_index_2" + table_migration_done = await db_helper.get_preference( + "global", "global", table_marker + ) + index_migration_done = await db_helper.get_preference( + "global", "global", index_marker + ) + if table_migration_done and index_migration_done: + return + + logger.info("Start migration for subagent_tasks table...") + try: + async with db_helper.engine.begin() as conn: + if not table_migration_done: + await conn.run_sync(SQLModel.metadata.create_all) + result = await conn.execute( + text( + "SELECT name FROM sqlite_master WHERE type='table' AND name='subagent_tasks'" + ) + ) + if not result.fetchone(): + raise RuntimeError("subagent_tasks table was not created") + + if not index_migration_done: + await conn.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_subagent_tasks_status_next_run_created " + "ON subagent_tasks(status, next_run_at, created_at)" + ) + ) + + if not table_migration_done: + await sp.put_async("global", "global", table_marker, True) + if not index_migration_done: + await sp.put_async("global", "global", index_marker, True) + 
logger.info("subagent_tasks migration completed.") + except Exception as exc: + logger.error("Migration for subagent_tasks failed: %s", exc, exc_info=True) + raise diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 451f054f6..17ce79685 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from typing import TypedDict +from sqlalchemy import Index from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint @@ -172,6 +173,37 @@ class CronJob(TimestampMixin, SQLModel, table=True): last_error: str | None = Field(default=None, sa_type=Text) +class SubagentTask(TimestampMixin, SQLModel, table=True): + """Persistent subagent background task.""" + + __tablename__: str = "subagent_tasks" + __table_args__ = ( + Index( + "idx_subagent_tasks_status_next_run_created", + "status", + "next_run_at", + "created_at", + ), + ) + + task_id: str = Field(primary_key=True, max_length=64) + idempotency_key: str = Field( + nullable=False, max_length=128, unique=True, index=True + ) + umo: str = Field(nullable=False, max_length=255, index=True) + subagent_name: str = Field(nullable=False, max_length=255) + handoff_tool_name: str = Field(nullable=False, max_length=255) + status: str = Field(default="pending", nullable=False, max_length=32) + attempt: int = Field(default=0, nullable=False) + max_attempts: int = Field(default=3, nullable=False) + next_run_at: datetime | None = Field(default=None, nullable=True) + error_class: str | None = Field(default=None, max_length=32) + last_error: str | None = Field(default=None, sa_type=Text) + payload_json: str = Field(default="{}", nullable=False, sa_type=Text) + result_text: str | None = Field(default=None, sa_type=Text) + finished_at: datetime | None = Field(default=None, nullable=True) + + class Preference(TimestampMixin, SQLModel, table=True): """This class represents preferences for bots.""" diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py 
index f496e19d5..ce088ea5e 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -25,6 +25,7 @@ Preference, SessionProjectRelation, SQLModel, + SubagentTask, ) from astrbot.core.db.po import ( Platform as DeprecatedPlatformStat, @@ -1851,3 +1852,256 @@ async def list_cron_jobs(self, job_type: str | None = None) -> list[CronJob]: query = query.order_by(desc(CronJob.created_at)) result = await session.execute(query) return list(result.scalars().all()) + + # ==== + # Subagent Background Task Management + # ==== + + async def create_subagent_task( + self, + *, + task_id: str, + idempotency_key: str, + umo: str, + subagent_name: str, + handoff_tool_name: str, + payload_json: str, + max_attempts: int = 3, + ) -> SubagentTask: + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now(timezone.utc) + task = SubagentTask( + task_id=task_id, + idempotency_key=idempotency_key, + umo=umo, + subagent_name=subagent_name, + handoff_tool_name=handoff_tool_name, + payload_json=payload_json, + max_attempts=max(1, int(max_attempts)), + status="pending", + next_run_at=now, + created_at=now, + updated_at=now, + ) + session.add(task) + await session.flush() + await session.refresh(task) + return task + + async def get_subagent_task_by_idempotency( + self, idempotency_key: str + ) -> SubagentTask | None: + async with self.get_db() as session: + session: AsyncSession + result = await session.execute( + select(SubagentTask).where( + col(SubagentTask.idempotency_key) == idempotency_key + ) + ) + return result.scalar_one_or_none() + + async def claim_due_subagent_tasks( + self, + *, + now: datetime, + limit: int = 20, + ) -> list[SubagentTask]: + async with self.get_db() as session: + session: AsyncSession + query = ( + select(SubagentTask) + .where( + col(SubagentTask.status).in_(["pending", "retrying"]), + or_( + col(SubagentTask.next_run_at).is_(None), + col(SubagentTask.next_run_at) <= now, + ), + ) + 
.order_by(col(SubagentTask.created_at)) + .limit(max(1, int(limit))) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def mark_subagent_task_running(self, task_id: str) -> SubagentTask | None: + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now(timezone.utc) + stmt = ( + update(SubagentTask) + .where( + col(SubagentTask.task_id) == task_id, + col(SubagentTask.status).in_(["pending", "retrying"]), + ) + .values( + status="running", + attempt=col(SubagentTask.attempt) + 1, + updated_at=now, + ) + ) + res = await session.execute(stmt) + if res.rowcount == 0: + return None + result = await session.execute( + select(SubagentTask).where(col(SubagentTask.task_id) == task_id) + ) + return result.scalar_one_or_none() + + async def mark_subagent_task_retrying( + self, + *, + task_id: str, + next_run_at: datetime, + error_class: str, + last_error: str, + ) -> bool: + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now(timezone.utc) + stmt = ( + update(SubagentTask) + .where( + col(SubagentTask.task_id) == task_id, + col(SubagentTask.status) == "running", + ) + .values( + status="retrying", + next_run_at=next_run_at, + error_class=error_class, + last_error=last_error, + updated_at=now, + ) + ) + res = await session.execute(stmt) + return bool(res.rowcount) + + async def reschedule_subagent_task( + self, + *, + task_id: str, + next_run_at: datetime, + error_class: str, + last_error: str, + ) -> bool: + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now(timezone.utc) + stmt = ( + update(SubagentTask) + .where( + col(SubagentTask.task_id) == task_id, + col(SubagentTask.status).in_( + ["failed", "canceled", "succeeded", "pending", "retrying"] + ), + ) + .values( + status="retrying", + attempt=0, + next_run_at=next_run_at, + error_class=error_class, + 
last_error=last_error, + result_text=None, + finished_at=None, + updated_at=now, + ) + ) + res = await session.execute(stmt) + return bool(res.rowcount) + + async def mark_subagent_task_succeeded( + self, + task_id: str, + *, + result_text: str, + ) -> bool: + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now(timezone.utc) + stmt = ( + update(SubagentTask) + .where( + col(SubagentTask.task_id) == task_id, + col(SubagentTask.status) == "running", + ) + .values( + status="succeeded", + result_text=result_text, + finished_at=now, + updated_at=now, + ) + ) + res = await session.execute(stmt) + return bool(res.rowcount) + + async def mark_subagent_task_failed( + self, + *, + task_id: str, + error_class: str, + last_error: str, + ) -> bool: + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now(timezone.utc) + stmt = ( + update(SubagentTask) + .where( + col(SubagentTask.task_id) == task_id, + col(SubagentTask.status) == "running", + ) + .values( + status="failed", + error_class=error_class, + last_error=last_error, + finished_at=now, + updated_at=now, + ) + ) + res = await session.execute(stmt) + return bool(res.rowcount) + + async def cancel_subagent_task(self, task_id: str) -> bool: + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now(timezone.utc) + stmt = ( + update(SubagentTask) + .where( + col(SubagentTask.task_id) == task_id, + col(SubagentTask.status).in_( + ["pending", "retrying", "running"] + ), + ) + .values( + status="canceled", + finished_at=now, + updated_at=now, + ) + ) + res = await session.execute(stmt) + return bool(res.rowcount) + + async def list_subagent_tasks( + self, + *, + status: str | None = None, + limit: int = 100, + ) -> list[SubagentTask]: + async with self.get_db() as session: + session: AsyncSession + query = select(SubagentTask) + if status: + query = 
query.where(col(SubagentTask.status) == status) + query = query.order_by(desc(SubagentTask.created_at)).limit( + max(1, int(limit)) + ) + result = await session.execute(query) + return list(result.scalars().all()) diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py index d15876526..a16cffbb3 100644 --- a/astrbot/core/skills/skill_manager.py +++ b/astrbot/core/skills/skill_manager.py @@ -85,12 +85,22 @@ def build_skills_prompt(skills: list[SkillInfo]) -> str: skills_lines: list[str] = [] example_path = "" for skill in skills: - description = skill.description or "No description" - skills_lines.append( - f"- **{skill.name}**: {description}\n File: `{skill.path}`" + name = str(getattr(skill, "name", "") or "").strip() or "unknown-skill" + # 验证 name 格式(防御性编程,防止注入) + if not _SKILL_NAME_RE.match(name): + name = "unknown-skill" + description = str(getattr(skill, "description", "") or "").strip() + # 清理换行符,防止 Indirect Prompt Injection + description = ( + (description or "No description").replace("\n", " ").replace("\r", " ") ) + path = str(getattr(skill, "path", "") or "").strip() + path = path or "//SKILL.md" + # 清理路径中的危险字符 + path = _SAFE_PATH_RE.sub("", path) + skills_lines.append(f"- **{name}**: {description}\n File: `{path}`") if not example_path: - example_path = skill.path + example_path = path skills_block = "\n".join(skills_lines) # Sanitize example_path — it may originate from sandbox cache (untrusted) example_path = _SAFE_PATH_RE.sub("", example_path) if example_path else "" diff --git a/astrbot/core/subagent/__init__.py b/astrbot/core/subagent/__init__.py new file mode 100644 index 000000000..290ca48f5 --- /dev/null +++ b/astrbot/core/subagent/__init__.py @@ -0,0 +1,30 @@ +from .codec import decode_subagent_config, encode_subagent_config +from .error_classifier import DefaultErrorClassifier, ErrorClassifier +from .hooks import NoopSubagentHooks, SubagentHooks +from .models import ( + SubagentAgentSpec, + SubagentConfig, + 
SubagentErrorClassifierConfig, + SubagentMountPlan, + SubagentTaskData, + SubagentTaskStatus, + ToolsScope, + build_safe_handoff_agent_name, +) + +__all__ = [ + "ToolsScope", + "SubagentAgentSpec", + "SubagentConfig", + "SubagentErrorClassifierConfig", + "SubagentMountPlan", + "SubagentTaskStatus", + "SubagentTaskData", + "SubagentHooks", + "NoopSubagentHooks", + "ErrorClassifier", + "DefaultErrorClassifier", + "decode_subagent_config", + "encode_subagent_config", + "build_safe_handoff_agent_name", +] diff --git a/astrbot/core/subagent/background_notifier.py b/astrbot/core/subagent/background_notifier.py new file mode 100644 index 000000000..48c488bb6 --- /dev/null +++ b/astrbot/core/subagent/background_notifier.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import json +import typing as T + +from astrbot import logger +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import ToolSet +from astrbot.core.cron.events import CronMessageEvent +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.provider.entites import ProviderRequest +from astrbot.core.utils.history_saver import persist_agent_history + +if T.TYPE_CHECKING: + from astrbot.core.astr_agent_context import AstrAgentContext + + +async def wake_main_agent_for_background_result( + run_context: ContextWrapper[AstrAgentContext], + *, + task_id: str, + tool_name: str, + result_text: str, + tool_args: dict[str, T.Any], + note: str, + summary_name: str, + extra_result_fields: dict[str, T.Any] | None = None, +) -> None: + from astrbot.core.astr_main_agent import ( + MainAgentBuildConfig, + _get_session_conv, + build_main_agent, + ) + from astrbot.core.astr_main_agent_resources import ( + BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT, + SEND_MESSAGE_TO_USER_TOOL, + ) + + event = run_context.context.event + ctx = run_context.context.context + + task_result = { + "task_id": task_id, + "tool_name": tool_name, + "result": result_text or "", + 
"tool_args": tool_args, + } + if extra_result_fields: + task_result.update(extra_result_fields) + extras = {"background_task_result": task_result} + + session = MessageSession.from_str(event.unified_msg_origin) + cron_event = CronMessageEvent( + context=ctx, + session=session, + message=note, + extras=extras, + message_type=session.message_type, + ) + cron_event.role = event.role + cfg = ctx.get_config(umo=event.unified_msg_origin) + provider_settings = cfg.get("provider_settings", {}) + config = MainAgentBuildConfig.from_provider_settings( + provider_settings, + cfg=cfg, + # Background tasks use a longer timeout and disable local computer use + # by default – these overrides preserve the original behaviour. + tool_call_timeout=int(provider_settings.get("tool_call_timeout", 3600)), + computer_use_runtime=str(provider_settings.get("computer_use_runtime", "none")), + ) + + req = ProviderRequest() + conv = await _get_session_conv(event=cron_event, plugin_context=ctx) + req.conversation = conv + context = json.loads(conv.history) + if context: + req.contexts = context + context_dump = req._print_friendly_context() + req.contexts = [] + req.system_prompt += ( + f"\n\nBellow is you and user previous conversation history:\n{context_dump}" + ) + + bg = json.dumps(extras["background_task_result"], ensure_ascii=False) + req.system_prompt += BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT.format( + background_task_result=bg + ) + req.prompt = ( + "Proceed according to your system instructions. " + "Output using same language as previous conversation. " + "If you need to deliver the result to the user immediately, " + "you MUST use `send_message_to_user` tool to send the message directly to the user, " + "otherwise the user will not see the result. " + "After completing your task, summarize and output your actions and results. 
" + ) + if not req.func_tool: + req.func_tool = ToolSet() + req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) + + result = await build_main_agent( + event=cron_event, plugin_context=ctx, config=config, req=req + ) + if not result: + logger.error("Failed to build main agent for background task %s.", tool_name) + return + + runner = result.agent_runner + async for _ in runner.step_until_done(30): + pass + llm_resp = runner.get_final_llm_resp() + task_meta = extras.get("background_task_result", {}) + summary_note = ( + f"[BackgroundTask] {summary_name} " + f"(task_id={task_meta.get('task_id', task_id)}) finished. " + f"Result: {task_meta.get('result') or result_text or 'no content'}" + ) + if llm_resp and llm_resp.completion_text: + summary_note += ( + f"I finished the task, here is the result: {llm_resp.completion_text}" + ) + await persist_agent_history( + ctx.conversation_manager, + event=cron_event, + req=req, + summary_note=summary_note, + ) + if not llm_resp: + logger.warning("background task agent got no response") diff --git a/astrbot/core/subagent/codec.py b/astrbot/core/subagent/codec.py new file mode 100644 index 000000000..4fe241ec1 --- /dev/null +++ b/astrbot/core/subagent/codec.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +from typing import Any, Literal, cast + +from .constants import ( + DEFAULT_FATAL_EXCEPTION_NAMES, + DEFAULT_TRANSIENT_EXCEPTION_NAMES, +) +from .models import ( + SubagentAgentSpec, + SubagentConfig, + SubagentErrorClassifierConfig, + ToolsScope, +) + +_DEFAULT_CLASS_ALLOWED = {"fatal", "transient", "retryable"} +_DefaultClassLiteral = Literal["fatal", "transient", "retryable"] + +_CONFIG_KEYS = { + "main_enable", + "enable", + "remove_main_duplicate_tools", + "router_system_prompt", + "agents", + "max_concurrent_subagent_runs", + "max_nested_depth", + "error_classifier", + "diagnostics", + "compat_warnings", +} +_AGENT_KEYS = { + "name", + "enabled", + "enable", + "persona_id", + "provider_id", + "public_description", + 
"tools_scope", + "tools", + "instructions", + "system_prompt", + "max_steps", +} + + +def _parse_bool(value: Any, *, field_name: str) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, int) and value in {0, 1}: + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "1", "yes", "on"}: + return True + if normalized in {"false", "0", "no", "off"}: + return False + raise ValueError(f"`{field_name}` must be a boolean value") + + +def _validate_no_unknown_keys(data: dict[str, Any], allowed: set[str]) -> None: + unknown = [k for k in data.keys() if k not in allowed and not k.startswith("x-")] + if unknown: + keys = ", ".join(sorted(unknown)) + raise ValueError(f"Unknown subagent config fields: {keys}") + + +def _infer_tools_scope(item: dict[str, Any]) -> ToolsScope: + scope_raw = item.get("tools_scope") + if scope_raw: + return ToolsScope(str(scope_raw)) + + tools = item.get("tools") + persona_id = str(item.get("persona_id") or "").strip() + if isinstance(tools, list): + return ToolsScope.NONE if len(tools) == 0 else ToolsScope.LIST + if tools is None and persona_id: + return ToolsScope.PERSONA + if tools is None: + return ToolsScope.ALL + raise ValueError("Invalid `tools` type, expected list or null.") + + +def decode_subagent_config(raw: dict[str, Any]) -> tuple[SubagentConfig, list[str]]: + if not isinstance(raw, dict): + raise ValueError("subagent config must be a JSON object") + + _validate_no_unknown_keys(raw, _CONFIG_KEYS) + diagnostics: list[str] = [] + compat_warnings: list[str] = [] + + if "main_enable" in raw: + main_enable = _parse_bool(raw["main_enable"], field_name="main_enable") + elif "enable" in raw: + main_enable = _parse_bool(raw["enable"], field_name="enable") + compat_warnings.append( + "legacy field `enable` is accepted and mapped to `main_enable`." 
+ ) + else: + main_enable = False + + agents_raw = raw.get("agents", []) + if agents_raw is None: + agents_raw = [] + if not isinstance(agents_raw, list): + raise ValueError("`agents` must be a list") + + agents: list[SubagentAgentSpec] = [] + for idx, item in enumerate(agents_raw): + if not isinstance(item, dict): + raise ValueError(f"`agents[{idx}]` must be an object") + _validate_no_unknown_keys(item, _AGENT_KEYS) + + scope = _infer_tools_scope(item) + if "system_prompt" in item and "instructions" not in item: + compat_warnings.append( + f"legacy field `agents[{idx}].system_prompt` is accepted and mapped to `instructions`." + ) + + extensions = {k: v for k, v in item.items() if k.startswith("x-")} + + tools_raw = item.get("tools") + tools: list[str] | None + if scope == ToolsScope.LIST: + if tools_raw is None: + tools = [] + elif isinstance(tools_raw, list): + tools = [str(t).strip() for t in tools_raw if str(t).strip()] + else: + raise ValueError( + f"`agents[{idx}].tools` must be a list when tools_scope=list" + ) + else: + tools = None + + try: + if "enabled" in item: + enabled = _parse_bool( + item["enabled"], field_name=f"agents[{idx}].enabled" + ) + elif "enable" in item: + enabled = _parse_bool( + item["enable"], field_name=f"agents[{idx}].enable" + ) + else: + enabled = True + spec = SubagentAgentSpec( + name=str(item.get("name", "")).strip(), + enabled=enabled, + persona_id=( + str(item.get("persona_id")).strip() + if item.get("persona_id") is not None + else None + ) + or None, + provider_id=( + str(item.get("provider_id")).strip() + if item.get("provider_id") is not None + else None + ) + or None, + public_description=str(item.get("public_description", "")).strip(), + tools_scope=scope, + tools=tools, + instructions=str( + item.get("instructions", item.get("system_prompt", "")) + ).strip(), + max_steps=( + int(item["max_steps"]) + if item.get("max_steps") is not None + else None + ), + extensions=extensions, + ) + except Exception as exc: + raise 
ValueError(f"invalid subagent at agents[{idx}]: {exc}") from exc + agents.append(spec) + + error_classifier_raw = raw.get("error_classifier", {}) + if error_classifier_raw is None: + error_classifier_raw = {} + if not isinstance(error_classifier_raw, dict): + raise ValueError("`error_classifier` must be an object") + + fatal_exceptions_raw = error_classifier_raw.get("fatal_exceptions") + transient_exceptions_raw = error_classifier_raw.get("transient_exceptions") + if fatal_exceptions_raw is None: + fatal_exceptions_raw = DEFAULT_FATAL_EXCEPTION_NAMES + if transient_exceptions_raw is None: + transient_exceptions_raw = DEFAULT_TRANSIENT_EXCEPTION_NAMES + if not isinstance(fatal_exceptions_raw, list): + raise ValueError("`error_classifier.fatal_exceptions` must be a list") + if not isinstance(transient_exceptions_raw, list): + raise ValueError("`error_classifier.transient_exceptions` must be a list") + + error_classifier = SubagentErrorClassifierConfig( + type=str(error_classifier_raw.get("type", "default")).strip() or "default", + fatal_exceptions=[ + str(item).strip() for item in fatal_exceptions_raw if str(item).strip() + ], + transient_exceptions=[ + str(item).strip() for item in transient_exceptions_raw if str(item).strip() + ], + default_class=cast( + _DefaultClassLiteral, + ( + default_class + if ( + default_class := str( + error_classifier_raw.get("default_class", "transient") + ).strip() + ) + in _DEFAULT_CLASS_ALLOWED + else "transient" + ), + ), + ) + + extensions = {k: v for k, v in raw.items() if k.startswith("x-")} + config = SubagentConfig( + main_enable=main_enable, + remove_main_duplicate_tools=( + _parse_bool( + raw["remove_main_duplicate_tools"], + field_name="remove_main_duplicate_tools", + ) + if "remove_main_duplicate_tools" in raw + else False + ), + router_system_prompt=str(raw.get("router_system_prompt", "")).strip(), + agents=agents, + max_concurrent_subagent_runs=int(raw.get("max_concurrent_subagent_runs", 8)), + 
max_nested_depth=int(raw.get("max_nested_depth", 2)), + error_classifier=error_classifier, + extensions=extensions, + ) + diagnostics.extend(compat_warnings) + return config, diagnostics + + +def encode_subagent_config( + config: SubagentConfig, + *, + diagnostics: list[str] | None = None, + compat_warnings: list[str] | None = None, +) -> dict[str, Any]: + payload: dict[str, Any] = { + "main_enable": bool(config.main_enable), + "remove_main_duplicate_tools": bool(config.remove_main_duplicate_tools), + "router_system_prompt": config.router_system_prompt or "", + "max_concurrent_subagent_runs": int(config.max_concurrent_subagent_runs), + "max_nested_depth": int(config.max_nested_depth), + "error_classifier": { + "type": str(config.error_classifier.type or "default"), + "fatal_exceptions": list(config.error_classifier.fatal_exceptions), + "transient_exceptions": list(config.error_classifier.transient_exceptions), + "default_class": str(config.error_classifier.default_class), + }, + "agents": [], + } + if config.extensions: + payload.update(config.extensions) + + for spec in config.agents: + if spec.tools_scope == ToolsScope.LIST: + tools = list(spec.tools or []) + elif spec.tools_scope == ToolsScope.NONE: + tools = [] + else: + tools = None + + item: dict[str, Any] = { + "name": spec.name, + "enabled": bool(spec.enabled), + "persona_id": spec.persona_id, + "provider_id": spec.provider_id, + "public_description": spec.public_description, + "tools_scope": spec.tools_scope.value, + "tools": tools, + "instructions": spec.instructions, + # TRANSITIONAL: `system_prompt` is a deprecated mirror of + # `instructions`. Both are emitted during the transition period + # so older dashboard versions and plugins continue to work. + # Remove `system_prompt` once all consumers migrate to + # `instructions`. 
+ "system_prompt": spec.instructions, + "max_steps": spec.max_steps, + } + if spec.extensions: + item.update(spec.extensions) + payload["agents"].append(item) + + if diagnostics: + payload["diagnostics"] = diagnostics + if compat_warnings: + payload["compat_warnings"] = compat_warnings + return payload diff --git a/astrbot/core/subagent/constants.py b/astrbot/core/subagent/constants.py new file mode 100644 index 000000000..d5b841488 --- /dev/null +++ b/astrbot/core/subagent/constants.py @@ -0,0 +1,93 @@ +"""Constants for subagent configuration and error classification. + +This module centralizes default values and configuration constants +for the subagent orchestration system. +""" + +import asyncio +from typing import Literal + +# ============================================================================ +# Error Classifier Defaults +# ============================================================================ + +DEFAULT_FATAL_EXCEPTIONS: tuple[type[Exception], ...] = ( + ValueError, + PermissionError, + KeyError, +) + +DEFAULT_TRANSIENT_EXCEPTIONS: tuple[type[Exception], ...] 
= ( + TimeoutError, + ConnectionError, + ConnectionResetError, +) + +# Exception name strings for configuration serialization +DEFAULT_FATAL_EXCEPTION_NAMES: list[str] = [ + "ValueError", + "PermissionError", + "KeyError", +] + +DEFAULT_TRANSIENT_EXCEPTION_NAMES: list[str] = [ + "asyncio.TimeoutError", + "TimeoutError", + "ConnectionError", + "ConnectionResetError", +] + +# Default error classification for unclassified exceptions +ErrorClass = Literal["fatal", "transient", "retryable"] +DEFAULT_ERROR_CLASS: ErrorClass = "transient" + +# ============================================================================ +# Subagent Runtime Defaults +# ============================================================================ + +DEFAULT_MAX_CONCURRENT_TASKS: int = 8 +DEFAULT_MAX_ATTEMPTS: int = 3 +DEFAULT_BASE_DELAY_MS: int = 500 +DEFAULT_MAX_DELAY_MS: int = 30000 +DEFAULT_JITTER_RATIO: float = 0.1 + +# Limits for runtime parameters +MIN_CONCURRENT_TASKS: int = 1 +MAX_CONCURRENT_TASKS: int = 64 +MIN_ATTEMPTS: int = 1 +MIN_BASE_DELAY_MS: int = 100 + +# ============================================================================ +# Subagent Worker Defaults +# ============================================================================ + +DEFAULT_POLL_INTERVAL: float = 1.0 +DEFAULT_BATCH_SIZE: int = 8 +MIN_POLL_INTERVAL: float = 0.1 +MIN_BATCH_SIZE: int = 1 + +# ============================================================================ +# Handoff Execution Limits +# ============================================================================ + +# Maximum nested depth for subagent handoffs +MAX_NESTED_DEPTH_LIMIT: int = 8 +MIN_NESTED_DEPTH_LIMIT: int = 1 +DEFAULT_MAX_NESTED_HANDOFF_DEPTH: int = 2 + +# Default max steps for subagent execution +DEFAULT_MAX_STEPS: int = 30 + +# ============================================================================ +# Allowed Exception Types for Configuration +# ============================================================================ 
+ +EXCEPTION_ALLOWLIST: dict[str, type[Exception]] = { + "ValueError": ValueError, + "PermissionError": PermissionError, + "KeyError": KeyError, + "TimeoutError": TimeoutError, + "ConnectionError": ConnectionError, + "ConnectionResetError": ConnectionResetError, + "asyncio.TimeoutError": asyncio.TimeoutError, +} diff --git a/astrbot/core/subagent/error_classifier.py b/astrbot/core/subagent/error_classifier.py new file mode 100644 index 000000000..b5b762917 --- /dev/null +++ b/astrbot/core/subagent/error_classifier.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from typing import Literal, Protocol + +from .constants import ( + DEFAULT_ERROR_CLASS, + DEFAULT_FATAL_EXCEPTION_NAMES, + DEFAULT_FATAL_EXCEPTIONS, + DEFAULT_TRANSIENT_EXCEPTION_NAMES, + DEFAULT_TRANSIENT_EXCEPTIONS, + EXCEPTION_ALLOWLIST, +) +from .models import SubagentErrorClassifierConfig + +ErrorClass = Literal["fatal", "transient", "retryable"] + + +class ErrorClassifier(Protocol): + def classify(self, exc: Exception) -> ErrorClass: ... + + +class DefaultErrorClassifier: + def __init__( + self, + *, + fatal_types: tuple[type[Exception], ...] | None = None, + transient_types: tuple[type[Exception], ...] | None = None, + default_class: ErrorClass = DEFAULT_ERROR_CLASS, + ) -> None: + self.fatal_types = fatal_types or DEFAULT_FATAL_EXCEPTIONS + self.transient_types = transient_types or DEFAULT_TRANSIENT_EXCEPTIONS + self.default_class: ErrorClass = ( + default_class + if default_class in {"fatal", "transient", "retryable"} + else DEFAULT_ERROR_CLASS + ) + + def classify(self, exc: Exception) -> ErrorClass: + if isinstance(exc, self.fatal_types): + return "fatal" + if isinstance(exc, self.transient_types): + return "transient" + return self.default_class + + +def get_error_classifier_defaults() -> dict[str, str | list[str]]: + """Return default configuration for error classifier. + + This function provides serializable default values for configuration. 
+ """ + return { + "fatal_exceptions": DEFAULT_FATAL_EXCEPTION_NAMES, + "transient_exceptions": DEFAULT_TRANSIENT_EXCEPTION_NAMES, + "default_class": DEFAULT_ERROR_CLASS, + } + + +def build_error_classifier_from_config( + cfg: SubagentErrorClassifierConfig | None, +) -> tuple[ErrorClassifier, list[str]]: + config = cfg or SubagentErrorClassifierConfig() + diagnostics: list[str] = [] + + if config.type != "default": + diagnostics.append( + f"Unsupported error_classifier.type '{config.type}', fallback to 'default'." + ) + + fatal_types = _resolve_exception_types( + config.fatal_exceptions, "error_classifier.fatal_exceptions", diagnostics + ) + transient_types = _resolve_exception_types( + config.transient_exceptions, + "error_classifier.transient_exceptions", + diagnostics, + ) + + classifier = DefaultErrorClassifier( + fatal_types=fatal_types, + transient_types=transient_types, + default_class=config.default_class, + ) + return classifier, diagnostics + + +def _resolve_exception_types( + names: list[str], + field_name: str, + diagnostics: list[str], +) -> tuple[type[Exception], ...]: + resolved: list[type[Exception]] = [] + for raw_name in names: + name = str(raw_name or "").strip() + if not name: + continue + exc_type = EXCEPTION_ALLOWLIST.get(name) + if exc_type is None: + diagnostics.append(f"{field_name}: unsupported exception '{name}' ignored.") + continue + if exc_type not in resolved: + resolved.append(exc_type) + return tuple(resolved) diff --git a/astrbot/core/subagent/handoff_executor.py b/astrbot/core/subagent/handoff_executor.py new file mode 100644 index 000000000..dbbc61e8a --- /dev/null +++ b/astrbot/core/subagent/handoff_executor.py @@ -0,0 +1,576 @@ +from __future__ import annotations + +import json +import typing as T +from collections.abc import Sequence +from collections.abc import Set as AbstractSet +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import mcp + +from astrbot import logger +from astrbot.core.agent.agent 
import Agent +from astrbot.core.agent.handoff import HandoffTool +from astrbot.core.agent.message import Message +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.cron.events import CronMessageEvent +from astrbot.core.message.components import Image +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.provider.register import llm_tools +from astrbot.core.subagent.background_notifier import ( + wake_main_agent_for_background_result, +) +from astrbot.core.subagent.constants import ( + DEFAULT_MAX_STEPS, + MAX_NESTED_DEPTH_LIMIT, + MIN_NESTED_DEPTH_LIMIT, +) +from astrbot.core.subagent.models import SubagentTaskData +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.core.utils.image_ref_utils import is_supported_image_ref +from astrbot.core.utils.string_utils import normalize_and_dedupe_strings + +if TYPE_CHECKING: + from astrbot.core.astr_agent_context import AstrAgentContext + + +@dataclass(slots=True) +class _HandoffExecutionSettings: + runtime: str + max_nested_depth: int + default_max_steps: int + streaming_response: bool + + +class HandoffExecutor: + @staticmethod + def _safe_int(value: T.Any, default: int) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + @classmethod + def _get_provider_settings( + cls, run_context: ContextWrapper[AstrAgentContext] + ) -> dict[str, T.Any]: + ctx = run_context.context.context + event = run_context.context.event + cfg = ctx.get_config(umo=event.unified_msg_origin) + provider_settings = cfg.get("provider_settings", {}) + if not isinstance(provider_settings, dict): + return {} + return provider_settings + + @classmethod + def _get_orchestrator( + cls, run_context: ContextWrapper[AstrAgentContext] + ) -> T.Any | None: + return getattr(run_context.context.context, "subagent_orchestrator", None) + + @classmethod + def _resolve_max_nested_depth( + 
cls, run_context: ContextWrapper[AstrAgentContext] + ) -> int: + """Resolve max nested handoff depth from the orchestrator. + + Falls back to the default constant if the orchestrator is unavailable. + """ + orchestrator = cls._get_orchestrator(run_context) + orchestrator_depth_getter = getattr(orchestrator, "get_max_nested_depth", None) + if callable(orchestrator_depth_getter): + depth = cls._safe_int(orchestrator_depth_getter(), 2) + return min(MAX_NESTED_DEPTH_LIMIT, max(MIN_NESTED_DEPTH_LIMIT, depth)) + + from astrbot.core.subagent.constants import DEFAULT_MAX_NESTED_HANDOFF_DEPTH + + return DEFAULT_MAX_NESTED_HANDOFF_DEPTH + + @classmethod + def _resolve_execution_settings( + cls, run_context: ContextWrapper[AstrAgentContext] + ) -> _HandoffExecutionSettings: + provider_settings = cls._get_provider_settings(run_context) + return _HandoffExecutionSettings( + runtime=str(provider_settings.get("computer_use_runtime", "none")), + max_nested_depth=cls._resolve_max_nested_depth(run_context), + default_max_steps=cls._safe_int( + provider_settings.get("max_agent_step", DEFAULT_MAX_STEPS), + DEFAULT_MAX_STEPS, + ), + streaming_response=bool(provider_settings.get("streaming_response", False)), + ) + + @classmethod + def _resolve_agent_max_steps(cls, tool: HandoffTool, default_max_steps: int) -> int: + configured_max_step = getattr(tool, "max_steps", None) + if isinstance(configured_max_step, int) and configured_max_step > 0: + return configured_max_step + return default_max_steps + + @classmethod + def _build_handoff_from_snapshot(cls, snapshot: T.Any) -> HandoffTool | None: + if not isinstance(snapshot, dict): + return None + agent_name = str(snapshot.get("agent_name", "")).strip() + if not agent_name: + return None + instructions = str(snapshot.get("instructions", "") or "") + tools_raw = snapshot.get("tools") + tools: list[str] | None + if tools_raw is None: + tools = None + elif isinstance(tools_raw, list): + tools = [str(item).strip() for item in tools_raw if 
str(item).strip()] + else: + tools = [] + + agent = Agent[T.Any]( + name=agent_name, + instructions=instructions, + tools=tools, # type: ignore[arg-type] + ) + dialogs_raw = snapshot.get("begin_dialogs") + if isinstance(dialogs_raw, list): + agent.begin_dialogs = dialogs_raw + + description_raw = snapshot.get("tool_description") + description = ( + str(description_raw).strip() + if isinstance(description_raw, str) and description_raw.strip() + else None + ) + handoff = HandoffTool(agent=agent, tool_description=description) + expected_name = str(snapshot.get("name", "")).strip() + if expected_name and handoff.name != expected_name: + logger.warning( + "Subagent snapshot handoff name mismatch: expected=%s actual=%s", + expected_name, + handoff.name, + ) + provider_id_raw = snapshot.get("provider_id") + handoff.provider_id = ( + str(provider_id_raw).strip() + if isinstance(provider_id_raw, str) and provider_id_raw.strip() + else None + ) + display_name_raw = snapshot.get("agent_display_name") + handoff.agent_display_name = ( + str(display_name_raw).strip() + if isinstance(display_name_raw, str) and display_name_raw.strip() + else agent_name + ) + max_steps_raw = snapshot.get("max_steps") + handoff.max_steps = ( + int(max_steps_raw) + if isinstance(max_steps_raw, int) and max_steps_raw > 0 + else None + ) + return handoff + + @classmethod + def collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]: + if image_urls_raw is None: + return [] + + if isinstance(image_urls_raw, str): + return [image_urls_raw] + + if isinstance(image_urls_raw, (Sequence, AbstractSet)) and not isinstance( + image_urls_raw, (str, bytes, bytearray) + ): + return [item for item in image_urls_raw if isinstance(item, str)] + + logger.debug( + "Unsupported image_urls type in handoff tool args: %s", + type(image_urls_raw).__name__, + ) + return [] + + @classmethod + async def collect_image_urls_from_message( + cls, run_context: ContextWrapper[AstrAgentContext] + ) -> list[str]: + 
urls: list[str] = [] + event = getattr(run_context.context, "event", None) + message_obj = getattr(event, "message_obj", None) + message = getattr(message_obj, "message", None) + if message: + for idx, component in enumerate(message): + if not isinstance(component, Image): + continue + try: + path = await component.convert_to_file_path() + if path: + urls.append(path) + except Exception as exc: # noqa: BLE001 + logger.error( + "Failed to convert handoff image component at index %d: %s", + idx, + exc, + exc_info=True, + ) + return urls + + @classmethod + async def collect_handoff_image_urls( + cls, + run_context: ContextWrapper[AstrAgentContext], + image_urls_raw: T.Any, + ) -> list[str]: + candidates: list[str] = [] + candidates.extend(cls.collect_image_urls_from_args(image_urls_raw)) + candidates.extend(await cls.collect_image_urls_from_message(run_context)) + + normalized = normalize_and_dedupe_strings(candidates) + extensionless_local_roots = (get_astrbot_temp_path(),) + sanitized = [ + item + for item in normalized + if is_supported_image_ref( + item, + allow_extensionless_existing_local_file=True, + extensionless_local_roots=extensionless_local_roots, + ) + ] + dropped_count = len(normalized) - len(sanitized) + if dropped_count > 0: + logger.debug( + "Dropped %d invalid image_urls entries in handoff image inputs.", + dropped_count, + ) + return sanitized + + @classmethod + def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]: + from astrbot.core.astr_main_agent_resources import ( + EXECUTE_SHELL_TOOL, + FILE_DOWNLOAD_TOOL, + FILE_UPLOAD_TOOL, + LOCAL_EXECUTE_SHELL_TOOL, + LOCAL_PYTHON_TOOL, + PYTHON_TOOL, + ) + + if runtime == "sandbox": + return { + EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL, + PYTHON_TOOL.name: PYTHON_TOOL, + FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL, + FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL, + } + if runtime == "local": + return { + LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL, + 
LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL, + } + return {} + + @classmethod + def build_handoff_toolset( + cls, + run_context: ContextWrapper[AstrAgentContext], + tools: list[str | FunctionTool] | None, + ) -> ToolSet | None: + provider_settings = cls._get_provider_settings(run_context) + runtime = str(provider_settings.get("computer_use_runtime", "none")) + runtime_computer_tools = cls._get_runtime_computer_tools(runtime) + + if tools is None: + toolset = ToolSet() + for registered_tool in llm_tools.func_list: + if isinstance(registered_tool, HandoffTool): + continue + if registered_tool.active: + toolset.add_tool(registered_tool) + for runtime_tool in runtime_computer_tools.values(): + toolset.add_tool(runtime_tool) + return None if toolset.empty() else toolset + + if not tools: + return None + + toolset = ToolSet() + for tool_name_or_obj in tools: + if isinstance(tool_name_or_obj, str): + registered_tool = llm_tools.get_func(tool_name_or_obj) + if ( + registered_tool + and registered_tool.active + and not isinstance(registered_tool, HandoffTool) + ): + toolset.add_tool(registered_tool) + continue + runtime_tool = runtime_computer_tools.get(tool_name_or_obj) + if runtime_tool: + toolset.add_tool(runtime_tool) + elif isinstance(tool_name_or_obj, FunctionTool) and not isinstance( + tool_name_or_obj, HandoffTool + ): + toolset.add_tool(tool_name_or_obj) + return None if toolset.empty() else toolset + + @classmethod + async def execute_foreground( + cls, + tool: HandoffTool, + run_context: ContextWrapper[AstrAgentContext], + *, + image_urls_prepared: bool = False, + **tool_args: T.Any, + ): + args = dict(tool_args) + input_ = args.get("input") + if image_urls_prepared: + prepared_image_urls = args.get("image_urls") + if isinstance(prepared_image_urls, list): + image_urls = prepared_image_urls + else: + logger.debug( + "Expected prepared handoff image_urls as list[str], got %s.", + type(prepared_image_urls).__name__, + ) + image_urls = [] + else: + image_urls = await 
cls.collect_handoff_image_urls( + run_context, + args.get("image_urls"), + ) + args["image_urls"] = image_urls + + toolset = cls.build_handoff_toolset(run_context, tool.agent.tools) + execution_settings = cls._resolve_execution_settings(run_context) + + ctx = run_context.context.context + event = run_context.context.event + umo = event.unified_msg_origin + event_get_extra = getattr(event, "get_extra", None) + current_depth = ( + cls._safe_int(event_get_extra("subagent_handoff_depth"), 0) + if callable(event_get_extra) + else 0 + ) + max_depth = execution_settings.max_nested_depth + if current_depth >= max_depth: + yield mcp.types.CallToolResult( + content=[ + mcp.types.TextContent( + type="text", + text=( + f"error: nested subagent handoff depth limit reached ({max_depth}). " + "Please continue in current agent." + ), + ) + ] + ) + return + + prov_id = getattr( + tool, "provider_id", None + ) or await ctx.get_current_chat_provider_id(umo) + + contexts = None + dialogs = tool.agent.begin_dialogs + if dialogs: + contexts = [] + for dialog in dialogs: + try: + contexts.append( + dialog + if isinstance(dialog, Message) + else Message.model_validate(dialog) + ) + except Exception: # noqa: BLE001 + continue + + agent_max_step = cls._resolve_agent_max_steps( + tool, execution_settings.default_max_steps + ) + stream = execution_settings.streaming_response + event_set_extra = getattr(event, "set_extra", None) + if callable(event_set_extra): + event_set_extra("subagent_handoff_depth", current_depth + 1) + try: + llm_resp = await ctx.tool_loop_agent( + event=event, + chat_provider_id=prov_id, + prompt=input_, + image_urls=image_urls, + system_prompt=tool.agent.instructions, + tools=toolset, + contexts=contexts, + max_steps=agent_max_step, + stream=stream, + ) + finally: + if callable(event_set_extra): + event_set_extra("subagent_handoff_depth", current_depth) + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)] + ) + 
+ @classmethod + async def submit_background( + cls, + tool: HandoffTool, + run_context: ContextWrapper[AstrAgentContext], + *, + tool_call_id: str | None = None, + **tool_args: T.Any, + ): + prepared_tool_args = dict(tool_args) + prepared_tool_args["image_urls"] = await cls.collect_handoff_image_urls( + run_context, + prepared_tool_args.get("image_urls"), + ) + orchestrator = getattr( + run_context.context.context, "subagent_orchestrator", None + ) + if orchestrator is None: + yield mcp.types.CallToolResult( + content=[ + mcp.types.TextContent( + type="text", + text=( + "error: subagent orchestrator is not available, " + "background handoff cannot be submitted." + ), + ) + ] + ) + return + + try: + task_id = await orchestrator.submit_handoff( + handoff=tool, + run_context=run_context, + payload=prepared_tool_args, + background=True, + tool_call_id=tool_call_id, + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "Failed to submit handoff to subagent orchestrator runtime: %s", + exc, + exc_info=True, + ) + yield mcp.types.CallToolResult( + content=[ + mcp.types.TextContent( + type="text", + text="error: failed to submit subagent background task to orchestrator.", + ) + ] + ) + return + + if not task_id: + yield mcp.types.CallToolResult( + content=[ + mcp.types.TextContent( + type="text", + text=( + "error: failed to submit subagent background task " + "because orchestrator returned no task id." + ), + ) + ] + ) + return + + yield mcp.types.CallToolResult( + content=[ + mcp.types.TextContent( + type="text", + text=( + f"Background task dedicated to subagent '{tool.agent.name}' submitted. " + f"task_id={task_id}. You will be notified when it finishes." 
+ ), + ) + ] + ) + + @classmethod + async def execute_queued_task( + cls, + *, + task: SubagentTaskData, + plugin_context: T.Any, + handoff: HandoffTool | None, + ) -> str: + payload = json.loads(task.payload_json) + if not isinstance(payload, dict): + raise ValueError("Invalid task payload.") + + snapshot_handoff = cls._build_handoff_from_snapshot( + payload.get("_handoff_snapshot") + ) + queued_handoff = snapshot_handoff or handoff + if queued_handoff is None: + raise ValueError(f"Handoff tool `{task.handoff_tool_name}` not found.") + + tool_args = payload.get("tool_args", {}) + if not isinstance(tool_args, dict): + raise ValueError("Invalid task tool_args payload.") + meta = payload.get("_meta", {}) + if not isinstance(meta, dict): + meta = {} + + session = MessageSession.from_str(task.umo) + cron_event = CronMessageEvent( + context=plugin_context, + session=session, + message=str( + tool_args.get("input") or f"[SubagentTask] {task.subagent_name}" + ), + extras={ + "background_note": meta.get("background_note") + or f"Background task for subagent '{task.subagent_name}' finished." 
+ }, + message_type=session.message_type, + ) + if role := meta.get("role"): + cron_event.role = role + + from astrbot.core.astr_agent_context import ( + AgentContextWrapper, + AstrAgentContext, + ) + + agent_ctx = AstrAgentContext(context=plugin_context, event=cron_event) + wrapper = AgentContextWrapper( + context=agent_ctx, + tool_call_timeout=int(meta.get("tool_call_timeout", 3600)), + ) + + handoff_args = dict(tool_args) + handoff_args["image_urls"] = await cls.collect_handoff_image_urls( + wrapper, + handoff_args.get("image_urls"), + ) + result_text = "" + async for result in cls.execute_foreground( + queued_handoff, + wrapper, + image_urls_prepared=True, + **handoff_args, + ): + if isinstance(result, mcp.types.CallToolResult): + for content in result.content: + if isinstance(content, mcp.types.TextContent): + result_text += content.text + "\n" + + await wake_main_agent_for_background_result( + run_context=wrapper, + task_id=task.task_id, + tool_name=queued_handoff.name, + result_text=result_text, + tool_args=handoff_args, + note=meta.get("background_note") + or f"Background task for subagent '{task.subagent_name}' finished.", + summary_name=f"Dedicated to subagent `{task.subagent_name}`", + extra_result_fields={"subagent_name": task.subagent_name}, + ) + return result_text or "ok" diff --git a/astrbot/core/subagent/hooks.py b/astrbot/core/subagent/hooks.py new file mode 100644 index 000000000..6aebe12f0 --- /dev/null +++ b/astrbot/core/subagent/hooks.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Protocol + +from .models import SubagentTaskData + + +class SubagentHooks(Protocol): + async def on_task_enqueued(self, task: SubagentTaskData) -> None: ... + + async def on_task_started(self, task: SubagentTaskData) -> None: ... + + async def on_task_retrying( + self, + task: SubagentTaskData, + *, + delay_seconds: float, + error_class: str, + error: Exception, + ) -> None: ... 
+ + async def on_task_succeeded(self, task: SubagentTaskData, result: str) -> None: ... + + async def on_task_failed( + self, task: SubagentTaskData, *, error_class: str, error: Exception + ) -> None: ... + + async def on_task_canceled(self, task_id: str) -> None: ... + + async def on_task_result_ignored( + self, task: SubagentTaskData, *, reason: str + ) -> None: ... + + +class NoopSubagentHooks: + async def on_task_enqueued(self, task: SubagentTaskData) -> None: + _ = task + + async def on_task_started(self, task: SubagentTaskData) -> None: + _ = task + + async def on_task_retrying( + self, + task: SubagentTaskData, + *, + delay_seconds: float, + error_class: str, + error: Exception, + ) -> None: + _ = task + _ = delay_seconds + _ = error_class + _ = error + + async def on_task_succeeded(self, task: SubagentTaskData, result: str) -> None: + _ = task + _ = result + + async def on_task_failed( + self, task: SubagentTaskData, *, error_class: str, error: Exception + ) -> None: + _ = task + _ = error_class + _ = error + + async def on_task_canceled(self, task_id: str) -> None: + _ = task_id + + async def on_task_result_ignored( + self, task: SubagentTaskData, *, reason: str + ) -> None: + _ = task + _ = reason diff --git a/astrbot/core/subagent/models.py b/astrbot/core/subagent/models.py new file mode 100644 index 000000000..b47fdc6b9 --- /dev/null +++ b/astrbot/core/subagent/models.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import hashlib +import re +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from astrbot.core.agent.handoff import HandoffTool + +_TOOL_SAFE_CHARS = re.compile(r"[^a-z0-9_-]+") +_TOOL_ALLOWED = re.compile(r"^[A-Za-z0-9_-]{1,64}$") + + +class ToolsScope(str, Enum): + ALL = "all" + NONE = "none" + LIST = "list" + PERSONA = "persona" + + +def 
build_safe_handoff_agent_name(display_name: str) -> str: + raw = str(display_name or "").strip() + if not raw: + raise ValueError("Subagent name cannot be empty.") + lowered = raw.lower() + slug = _TOOL_SAFE_CHARS.sub("_", lowered).strip("_") + if not slug: + slug = "subagent" + candidate = slug + tool_name = f"transfer_to_{candidate}" + if _TOOL_ALLOWED.match(tool_name): + return candidate + + digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:8] + max_base_len = max(1, 64 - len("transfer_to__") - len(digest)) + trimmed = slug[:max_base_len].strip("_") or "subagent" + candidate = f"{trimmed}_{digest}" + tool_name = f"transfer_to_{candidate}" + if not _TOOL_ALLOWED.match(tool_name): + raise ValueError( + f"Invalid subagent name '{display_name}', cannot derive a safe handoff tool name." + ) + return candidate + + +class SubagentAgentSpec(BaseModel): + name: str = Field(min_length=1, max_length=256) + enabled: bool = True + persona_id: str | None = None + provider_id: str | None = None + public_description: str = "" + tools_scope: ToolsScope = ToolsScope.ALL + tools: list[str] | None = None + instructions: str = "" + max_steps: int | None = None + extensions: dict[str, Any] = Field(default_factory=dict) + + model_config = ConfigDict(extra="allow") + + @field_validator("name") + @classmethod + def _validate_name(cls, value: str) -> str: + _ = build_safe_handoff_agent_name(value) + return value.strip() + + @property + def handoff_agent_name(self) -> str: + return build_safe_handoff_agent_name(self.name) + + +class SubagentErrorClassifierConfig(BaseModel): + type: str = "default" + fatal_exceptions: list[str] = Field( + default_factory=lambda: ["ValueError", "PermissionError", "KeyError"] + ) + transient_exceptions: list[str] = Field( + default_factory=lambda: [ + "asyncio.TimeoutError", + "TimeoutError", + "ConnectionError", + "ConnectionResetError", + ] + ) + default_class: Literal["fatal", "transient", "retryable"] = "transient" + + model_config = 
ConfigDict(extra="allow") + + +class SubagentConfig(BaseModel): + main_enable: bool = False + remove_main_duplicate_tools: bool = False + router_system_prompt: str = "" + agents: list[SubagentAgentSpec] = Field(default_factory=list) + max_concurrent_subagent_runs: int = 8 + max_nested_depth: int = 2 + error_classifier: SubagentErrorClassifierConfig = Field( + default_factory=SubagentErrorClassifierConfig + ) + extensions: dict[str, Any] = Field(default_factory=dict) + + model_config = ConfigDict(extra="allow") + + @field_validator("max_concurrent_subagent_runs") + @classmethod + def _validate_max_concurrent(cls, value: int) -> int: + if value < 1: + return 1 + if value > 64: + return 64 + return value + + @field_validator("max_nested_depth") + @classmethod + def _validate_max_nested_depth(cls, value: int) -> int: + if value < 1: + return 1 + if value > 8: + return 8 + return value + + +@dataclass(slots=True) +class SubagentMountPlan: + handoffs: list[HandoffTool] = field(default_factory=list) + handoff_by_tool_name: dict[str, HandoffTool] = field(default_factory=dict) + main_tool_exclude_set: set[str] = field(default_factory=set) + router_prompt: str | None = None + diagnostics: list[str] = field(default_factory=list) + + +class SubagentTaskStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + RETRYING = "retrying" + SUCCEEDED = "succeeded" + FAILED = "failed" + CANCELED = "canceled" + + +@dataclass(slots=True) +class SubagentTaskData: + task_id: str + idempotency_key: str + umo: str + subagent_name: str + handoff_tool_name: str + status: SubagentTaskStatus | str + attempt: int + max_attempts: int + next_run_at: datetime | None + payload_json: str + error_class: str | None + last_error: str | None + result_text: str | None + created_at: datetime + updated_at: datetime + finished_at: datetime | None diff --git a/astrbot/core/subagent/planner.py b/astrbot/core/subagent/planner.py new file mode 100644 index 000000000..bd86794f5 --- /dev/null +++ 
b/astrbot/core/subagent/planner.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from typing import Any + +from astrbot import logger +from astrbot.core.agent.agent import Agent +from astrbot.core.agent.handoff import HandoffTool + +from .models import ( + SubagentAgentSpec, + SubagentConfig, + SubagentMountPlan, + ToolsScope, + build_safe_handoff_agent_name, +) + + +class SubagentPlanner: + """Build a deterministic mount plan from canonical subagent configuration.""" + + def __init__( + self, + tool_mgr: Any, + persona_mgr: Any, + ) -> None: + self._tool_mgr = tool_mgr + self._persona_mgr = persona_mgr + + async def build_mount_plan(self, config: SubagentConfig) -> SubagentMountPlan: + diagnostics: list[str] = [] + handoffs: list[HandoffTool] = [] + handoff_map: dict[str, HandoffTool] = {} + exclude_set: set[str] = set() + + all_active_tools = { + tool.name + for tool in self._tool_mgr.func_list + if getattr(tool, "active", True) and not isinstance(tool, HandoffTool) + } + seen_agent_names: set[str] = set() + + for spec in config.agents: + if not spec.enabled: + continue + + safe_agent_name = spec.handoff_agent_name + resolved_agent_name = safe_agent_name + if resolved_agent_name in seen_agent_names: + suffix = 2 + while suffix <= 99 and resolved_agent_name in seen_agent_names: + resolved_agent_name = build_safe_handoff_agent_name( + f"{spec.name}-{suffix}" + ) + suffix += 1 + if resolved_agent_name in seen_agent_names: + diagnostics.append( + f"ERROR: duplicate subagent tool name generated from `{spec.name}`." + ) + continue + diagnostics.append( + "WARN: duplicate subagent tool name generated from " + f"`{spec.name}`, renamed to `{resolved_agent_name}`." 
+ ) + seen_agent_names.add(resolved_agent_name) + + persona = await self._resolve_persona(spec, diagnostics) + instructions = self._resolve_instructions(spec, persona) + public_description = self._resolve_public_description(spec, persona) + tools = self._resolve_tools(spec, persona, all_active_tools, diagnostics) + begin_dialogs = getattr(persona, "begin_dialogs", None) if persona else None + + agent = Agent[Any]( + name=resolved_agent_name, + instructions=instructions, + tools=tools, # type: ignore[arg-type] + ) + agent.begin_dialogs = begin_dialogs + + handoff = HandoffTool( + agent=agent, + tool_description=public_description or None, + ) + handoff.provider_id = spec.provider_id + handoff.agent_display_name = spec.name # type: ignore[attr-defined] + handoff.max_steps = spec.max_steps # type: ignore[attr-defined] + handoffs.append(handoff) + handoff_map[handoff.name] = handoff + + if config.remove_main_duplicate_tools: + if tools is None: + exclude_set.update(all_active_tools) + else: + exclude_set.update( + {name for name in tools if name in all_active_tools} + ) + + for handoff in handoffs: + logger.info("Registered subagent handoff tool: %s", handoff.name) + + return SubagentMountPlan( + handoffs=handoffs, + handoff_by_tool_name=handoff_map, + main_tool_exclude_set=exclude_set, + router_prompt=(config.router_system_prompt or "").strip() or None, + diagnostics=diagnostics, + ) + + async def _resolve_persona( + self, + spec: SubagentAgentSpec, + diagnostics: list[str], + ) -> Any | None: + if not spec.persona_id: + return None + try: + return await self._persona_mgr.get_persona(spec.persona_id) + except (ValueError, StopIteration): + diagnostics.append( + f"WARN: subagent `{spec.name}` persona `{spec.persona_id}` not found, fallback to inline settings." 
+ ) + return None + + @staticmethod + def _resolve_instructions(spec: SubagentAgentSpec, persona: Any | None) -> str: + if persona and getattr(persona, "system_prompt", None): + return str(persona.system_prompt) + return spec.instructions + + @staticmethod + def _resolve_public_description( + spec: SubagentAgentSpec, persona: Any | None + ) -> str: + if spec.public_description: + return spec.public_description + if persona and getattr(persona, "system_prompt", None): + return str(persona.system_prompt)[:120] + return "" + + @staticmethod + def _resolve_tools( + spec: SubagentAgentSpec, + persona: Any | None, + all_active_tools: set[str], + diagnostics: list[str], + ) -> list[str] | None: + if spec.tools_scope == ToolsScope.ALL: + return None + if spec.tools_scope == ToolsScope.NONE: + return [] + if spec.tools_scope == ToolsScope.PERSONA: + if persona is None: + diagnostics.append( + f"WARN: subagent `{spec.name}` uses tools_scope=persona but persona is missing." + ) + return [] + tools = getattr(persona, "tools", None) + if tools is None: + return None + if not isinstance(tools, list): + return [] + return [str(t).strip() for t in tools if str(t).strip() in all_active_tools] + + tools = spec.tools or [] + filtered: list[str] = [] + for name in tools: + tool_name = str(name).strip() + if not tool_name: + continue + if tool_name.startswith("transfer_to_"): + diagnostics.append( + f"WARN: subagent `{spec.name}` tool `{tool_name}` ignored to prevent recursive handoff." 
+ ) + continue + if tool_name in all_active_tools: + filtered.append(tool_name) + return filtered diff --git a/astrbot/core/subagent/runtime.py b/astrbot/core/subagent/runtime.py new file mode 100644 index 000000000..d91e40cf3 --- /dev/null +++ b/astrbot/core/subagent/runtime.py @@ -0,0 +1,462 @@ +from __future__ import annotations + +import asyncio +import hashlib +import json +import random +import typing as T +import uuid +from collections.abc import Awaitable, Callable +from datetime import datetime, timedelta, timezone +from typing import Any + +from sqlalchemy.exc import IntegrityError + +from astrbot import logger +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import SubagentTask + +from .constants import ( + DEFAULT_BASE_DELAY_MS, + DEFAULT_JITTER_RATIO, + DEFAULT_MAX_ATTEMPTS, + DEFAULT_MAX_CONCURRENT_TASKS, + DEFAULT_MAX_DELAY_MS, + MAX_CONCURRENT_TASKS, + MIN_ATTEMPTS, + MIN_BASE_DELAY_MS, + MIN_CONCURRENT_TASKS, +) +from .error_classifier import DefaultErrorClassifier, ErrorClassifier +from .hooks import NoopSubagentHooks, SubagentHooks +from .models import SubagentTaskData + + +class SubagentRuntime: + """Background runtime for subagent handoff tasks. + + It provides queueing, lane-based concurrency control, retry classification, + and lifecycle logging. Task execution is delegated via a callback. 
+ """ + + def __init__( + self, + db: BaseDatabase | None, + *, + max_concurrent: int = DEFAULT_MAX_CONCURRENT_TASKS, + max_attempts: int = DEFAULT_MAX_ATTEMPTS, + base_delay_ms: int = DEFAULT_BASE_DELAY_MS, + max_delay_ms: int = DEFAULT_MAX_DELAY_MS, + jitter_ratio: float = DEFAULT_JITTER_RATIO, + hooks: SubagentHooks | None = None, + error_classifier: ErrorClassifier | None = None, + ) -> None: + self._db = db + self._max_concurrent = max( + MIN_CONCURRENT_TASKS, min(int(max_concurrent), MAX_CONCURRENT_TASKS) + ) + self._max_attempts = max(MIN_ATTEMPTS, int(max_attempts)) + self._base_delay_ms = max(MIN_BASE_DELAY_MS, int(base_delay_ms)) + self._max_delay_ms = max(self._base_delay_ms, int(max_delay_ms)) + self._jitter_ratio = max(0.0, min(float(jitter_ratio), 1.0)) + self._active_lanes: dict[str, str] = {} + self._task_executor: Callable[[SubagentTaskData], Awaitable[str]] | None = None + self._running_recovery_done = False + self._hooks: SubagentHooks = hooks or NoopSubagentHooks() + self._error_classifier: ErrorClassifier = ( + error_classifier or DefaultErrorClassifier() + ) + + def set_max_concurrent(self, value: int) -> None: + self._max_concurrent = max( + MIN_CONCURRENT_TASKS, min(int(value), MAX_CONCURRENT_TASKS) + ) + + def set_task_executor( + self, + executor: Callable[[SubagentTaskData], Awaitable[str]], + ) -> None: + self._task_executor = executor + + def set_hooks(self, hooks: SubagentHooks | None) -> None: + self._hooks = hooks or NoopSubagentHooks() + + def set_error_classifier(self, error_classifier: ErrorClassifier | None) -> None: + self._error_classifier = error_classifier or DefaultErrorClassifier() + + async def enqueue( + self, + *, + umo: str, + subagent_name: str, + handoff_tool_name: str, + payload: dict[str, Any], + tool_call_id: str | None, + ) -> str: + if not self._db: + raise RuntimeError("Subagent runtime database is not available.") + payload_json = json.dumps( + payload, ensure_ascii=False, sort_keys=True, default=str + ) + idem 
= self._build_idempotency_key( + umo=umo, + handoff_tool_name=handoff_tool_name, + tool_call_id=tool_call_id, + payload_json=payload_json, + ) + existing = await self._db.get_subagent_task_by_idempotency(idem) + if existing: + return existing.task_id + + task_id = uuid.uuid4().hex + try: + created = await self._db.create_subagent_task( + task_id=task_id, + idempotency_key=idem, + umo=umo, + subagent_name=subagent_name, + handoff_tool_name=handoff_tool_name, + payload_json=payload_json, + max_attempts=self._max_attempts, + ) + except IntegrityError: + # Handle concurrent enqueue with the same idempotency key. + existing_after_race = await self._db.get_subagent_task_by_idempotency(idem) + if existing_after_race: + return existing_after_race.task_id + raise + self._emit_event("task_enqueued", task_id, subagent_name, 0, umo) + await self._call_hook("on_task_enqueued", _to_task_data(created)) + return task_id + + async def process_once(self, *, batch_size: int = 8) -> int: + if not self._db: + return 0 + if not self._task_executor: + return 0 + if not self._running_recovery_done: + try: + recovered = await self._recover_interrupted_running_tasks() + if recovered > 0: + logger.info( + "[SubagentRuntime] recovered %d interrupted running task(s).", + recovered, + ) + finally: + self._running_recovery_done = True + now = datetime.now(timezone.utc) + candidates = await self._db.claim_due_subagent_tasks( + now=now, limit=batch_size * 2 + ) + selected: list[SubagentTaskData] = [] + for task in candidates: + lane = self._lane_key(task.umo, task.subagent_name) + if lane in self._active_lanes: + continue + if len(self._active_lanes) >= self._max_concurrent: + break + running = await self._db.mark_subagent_task_running(task.task_id) + if not running: + continue + self._active_lanes[lane] = running.task_id + selected.append(_to_task_data(running)) + if len(selected) >= batch_size: + break + + if not selected: + return 0 + await asyncio.gather(*(self._run_one(task) for task in 
selected)) + return len(selected) + + async def list_tasks( + self, *, status: str | None = None, limit: int = 100 + ) -> list[dict]: + if not self._db: + return [] + rows = await self._db.list_subagent_tasks(status=status, limit=limit) + return [_serialize_task(row) for row in rows] + + async def retry_task(self, task_id: str) -> bool: + if not self._db: + return False + return await self._db.reschedule_subagent_task( + task_id=task_id, + next_run_at=datetime.now(timezone.utc), + error_class="manual", + last_error="manual retry requested", + ) + + async def cancel_task(self, task_id: str) -> bool: + if not self._db: + return False + canceled = await self._db.cancel_subagent_task(task_id) + if canceled: + await self._call_hook("on_task_canceled", task_id) + return canceled + + async def _run_one(self, task: SubagentTaskData) -> None: + db = self._db + task_executor = self._task_executor + if db is None or task_executor is None: + raise RuntimeError("Subagent runtime is not fully initialized.") + + lane = self._lane_key(task.umo, task.subagent_name) + self._emit_event( + "task_started", task.task_id, task.subagent_name, task.attempt, task.umo + ) + await self._call_hook("on_task_started", task) + try: + result = await task_executor(task) + updated = await db.mark_subagent_task_succeeded( + task.task_id, result_text=result + ) + if not updated: + self._emit_event( + "task_result_ignored", + task.task_id, + task.subagent_name, + task.attempt, + task.umo, + error_class="state_changed", + error_message="task status changed before success commit", + ) + await self._call_hook( + "on_task_result_ignored", + task, + reason="task status changed before success commit", + ) + return + self._emit_event( + "task_succeeded", + task.task_id, + task.subagent_name, + task.attempt, + task.umo, + ) + await self._call_hook("on_task_succeeded", task, result) + return + except Exception as exc: # noqa: BLE001 + error_class = self._classify_error(exc) + if ( + error_class in {"transient", 
"retryable"} + and task.attempt < task.max_attempts + ): + delay = self._compute_delay_seconds(task.attempt) + next_run = datetime.now(timezone.utc) + timedelta(seconds=delay) + updated = await db.mark_subagent_task_retrying( + task_id=task.task_id, + next_run_at=next_run, + error_class=error_class, + last_error=str(exc), + ) + if not updated: + self._emit_event( + "task_result_ignored", + task.task_id, + task.subagent_name, + task.attempt, + task.umo, + error_class="state_changed", + error_message="task status changed before retry commit", + ) + await self._call_hook( + "on_task_result_ignored", + task, + reason="task status changed before retry commit", + ) + return + self._emit_event( + "task_retrying", + task.task_id, + task.subagent_name, + task.attempt, + task.umo, + delay_seconds=delay, + error_class=error_class, + error_message=str(exc), + ) + await self._call_hook( + "on_task_retrying", + task, + delay_seconds=delay, + error_class=error_class, + error=exc, + ) + return + + updated = await db.mark_subagent_task_failed( + task_id=task.task_id, + error_class=error_class, + last_error=str(exc), + ) + if not updated: + self._emit_event( + "task_result_ignored", + task.task_id, + task.subagent_name, + task.attempt, + task.umo, + error_class="state_changed", + error_message="task status changed before failure commit", + ) + await self._call_hook( + "on_task_result_ignored", + task, + reason="task status changed before failure commit", + ) + return + self._emit_event( + "task_failed", + task.task_id, + task.subagent_name, + task.attempt, + task.umo, + error_class=error_class, + error_message=str(exc), + ) + await self._call_hook( + "on_task_failed", + task, + error_class=error_class, + error=exc, + ) + return + finally: + self._active_lanes.pop(lane, None) + + @staticmethod + def _lane_key(umo: str, subagent_name: str) -> str: + return f"session:{umo}:{subagent_name}" + + @staticmethod + def _build_idempotency_key( + *, + umo: str, + handoff_tool_name: str, + 
tool_call_id: str | None, + payload_json: str, + ) -> str: + raw = f"{umo}:{handoff_tool_name}:{tool_call_id or ''}:{payload_json}" + return hashlib.sha256(raw.encode("utf-8")).hexdigest() + + def _compute_delay_seconds(self, attempt: int) -> float: + delay_ms = min(self._max_delay_ms, self._base_delay_ms * (2 ** max(0, attempt))) + jitter_ms = delay_ms * self._jitter_ratio * random.random() # noqa: S311 + return (delay_ms + jitter_ms) / 1000.0 + + def _classify_error(self, exc: Exception) -> str: + classified = self._error_classifier.classify(exc) + if classified in {"fatal", "transient", "retryable"}: + return classified + return "transient" + + async def _recover_interrupted_running_tasks(self) -> int: + db = self._db + if not db: + return 0 + rows = await db.list_subagent_tasks(status="running", limit=200) + if not rows: + return 0 + now = datetime.now(timezone.utc) + # Only recover tasks that have been stale for at least 5 minutes; + # recently-updated tasks may still be executing on another worker. 
+ stale_threshold = timedelta(minutes=5) + recovered = 0 + for row in rows: + if row.updated_at and (now - row.updated_at) < stale_threshold: + continue + ok = await db.mark_subagent_task_retrying( + task_id=row.task_id, + next_run_at=now, + error_class="transient", + last_error="Recovered interrupted running task after worker restart.", + ) + if ok: + recovered += 1 + return recovered + + async def _call_hook(self, hook_name: str, *args, **kwargs) -> None: + hook = getattr(self._hooks, hook_name, None) + if not callable(hook): + return + typed_hook = T.cast(Callable[..., Awaitable[None]], hook) + try: + await typed_hook(*args, **kwargs) + except Exception as exc: + exc_type = type(exc).__name__ + exc_module = type(exc).__module__ + full_exc_name = ( + exc_type if exc_module == "builtins" else f"{exc_module}.{exc_type}" + ) + logger.error( + "[SubagentRuntime] hook=%s failed (type=%s): %s", + hook_name, + full_exc_name, + exc, + exc_info=True, + ) + + @staticmethod + def _emit_event( + event_type: str, + task_id: str, + subagent_name: str, + attempt: int, + umo: str, + *, + delay_seconds: float | None = None, + error_class: str | None = None, + error_message: str | None = None, + ) -> None: + logger.debug( + "[SubagentRuntime] event=%s task_id=%s subagent=%s attempt=%s umo=%s delay=%s error_class=%s error=%s", + event_type, + task_id, + subagent_name, + attempt, + umo, + delay_seconds, + error_class, + error_message, + ) + + +def _to_task_data(task: SubagentTask) -> SubagentTaskData: + return SubagentTaskData( + task_id=task.task_id, + idempotency_key=task.idempotency_key, + umo=task.umo, + subagent_name=task.subagent_name, + handoff_tool_name=task.handoff_tool_name, + status=task.status, + attempt=task.attempt, + max_attempts=task.max_attempts, + next_run_at=task.next_run_at, + payload_json=task.payload_json, + error_class=task.error_class, + last_error=task.last_error, + result_text=task.result_text, + created_at=task.created_at, + updated_at=task.updated_at, + 
finished_at=task.finished_at, + ) + + +def _serialize_task(task: SubagentTask) -> dict[str, Any]: + return { + "task_id": task.task_id, + "idempotency_key": task.idempotency_key, + "umo": task.umo, + "subagent_name": task.subagent_name, + "handoff_tool_name": task.handoff_tool_name, + "status": task.status, + "attempt": task.attempt, + "max_attempts": task.max_attempts, + "next_run_at": task.next_run_at.isoformat() if task.next_run_at else None, + "error_class": task.error_class, + "last_error": task.last_error, + "result_text": task.result_text, + "created_at": task.created_at.isoformat(), + "updated_at": task.updated_at.isoformat(), + "finished_at": task.finished_at.isoformat() if task.finished_at else None, + } diff --git a/astrbot/core/subagent/worker.py b/astrbot/core/subagent/worker.py new file mode 100644 index 000000000..650a67c06 --- /dev/null +++ b/astrbot/core/subagent/worker.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +from astrbot import logger + +from .constants import DEFAULT_BATCH_SIZE, DEFAULT_POLL_INTERVAL +from .runtime import SubagentRuntime + +if TYPE_CHECKING: + pass + + +class SubagentWorker: + def __init__( + self, + runtime: SubagentRuntime, + *, + poll_interval: float = DEFAULT_POLL_INTERVAL, + batch_size: int = DEFAULT_BATCH_SIZE, + ) -> None: + self._runtime = runtime + self._poll_interval = max(0.1, float(poll_interval)) + self._batch_size = max(1, int(batch_size)) + self._task: asyncio.Task[None] | None = None + self._stop_event = asyncio.Event() + + def start(self) -> asyncio.Task[None]: + if self._task and not self._task.done(): + return self._task + self._stop_event.clear() + self._task = asyncio.create_task(self._run_loop(), name="subagent_worker") + return self._task + + async def stop(self) -> None: + self._stop_event.set() + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + async 
def _run_loop(self) -> None: + logger.info("Subagent worker started.") + while not self._stop_event.is_set(): + try: + processed = await self._runtime.process_once( + batch_size=self._batch_size + ) + if processed <= 0: + await asyncio.sleep(self._poll_interval) + except asyncio.CancelledError: + raise + except Exception as exc: + exc_type = type(exc).__name__ + exc_module = type(exc).__module__ + full_exc_name = ( + exc_type if exc_module == "builtins" else f"{exc_module}.{exc_type}" + ) + logger.error( + "Subagent worker loop error (type=%s): %s", + full_exc_name, + exc, + exc_info=True, + ) + await asyncio.sleep(self._poll_interval) + logger.info("Subagent worker stopped.") diff --git a/astrbot/core/subagent_orchestrator.py b/astrbot/core/subagent_orchestrator.py index 205c554cb..be86ec125 100644 --- a/astrbot/core/subagent_orchestrator.py +++ b/astrbot/core/subagent_orchestrator.py @@ -3,17 +3,28 @@ from typing import Any from astrbot import logger -from astrbot.core.agent.agent import Agent from astrbot.core.agent.handoff import HandoffTool from astrbot.core.persona_mgr import PersonaManager from astrbot.core.provider.func_tool_manager import FunctionToolManager +from astrbot.core.subagent.codec import decode_subagent_config +from astrbot.core.subagent.error_classifier import build_error_classifier_from_config +from astrbot.core.subagent.handoff_executor import HandoffExecutor +from astrbot.core.subagent.models import ( + SubagentConfig, + SubagentMountPlan, + SubagentTaskData, +) +from astrbot.core.subagent.planner import SubagentPlanner +from astrbot.core.subagent.runtime import SubagentRuntime +from astrbot.core.subagent.worker import SubagentWorker class SubAgentOrchestrator: - """Loads subagent definitions from config and registers handoff tools. + """Subagent orchestration facade. - This is intentionally lightweight: it does not execute agents itself. - Execution happens via HandoffTool in FunctionToolExecutor. 
+ This class holds canonical config + mount plan, and delegates heavy lifting to: + - planner: deterministic handoff plan generation + - runtime/worker: background task queue and retries """ def __init__( @@ -21,78 +32,188 @@ def __init__( ) -> None: self._tool_mgr = tool_mgr self._persona_mgr = persona_mgr + self._planner = SubagentPlanner(tool_mgr, persona_mgr) + self._config = SubagentConfig() + self._mount_plan = SubagentMountPlan() + self._context = None + db = getattr(persona_mgr, "db", None) + self._runtime = SubagentRuntime(db=db) + self._runtime.set_task_executor(self._execute_background_task) + self._worker = SubagentWorker(self._runtime) self.handoffs: list[HandoffTool] = [] - async def reload_from_config(self, cfg: dict[str, Any]) -> None: - from astrbot.core.astr_agent_context import AstrAgentContext - - agents = cfg.get("agents", []) - if not isinstance(agents, list): - logger.warning("subagent_orchestrator.agents must be a list") - return - - handoffs: list[HandoffTool] = [] - for item in agents: - if not isinstance(item, dict): - continue - if not item.get("enabled", True): - continue - - name = str(item.get("name", "")).strip() - if not name: - continue - - persona_id = item.get("persona_id") - persona_data = None - if persona_id: - try: - persona_data = await self._persona_mgr.get_persona(persona_id) - except StopIteration: - logger.warning( - "SubAgent persona %s not found, fallback to inline prompt.", - persona_id, - ) - - instructions = str(item.get("system_prompt", "")).strip() - public_description = str(item.get("public_description", "")).strip() - provider_id = item.get("provider_id") - if provider_id is not None: - provider_id = str(provider_id).strip() or None - tools = item.get("tools", []) - begin_dialogs = None - - if persona_data: - instructions = persona_data.system_prompt or instructions - begin_dialogs = persona_data.begin_dialogs - tools = persona_data.tools - if public_description == "" and persona_data.system_prompt: - 
public_description = persona_data.system_prompt[:120] - if tools is None: - tools = None - elif not isinstance(tools, list): - tools = [] - else: - tools = [str(t).strip() for t in tools if str(t).strip()] - - agent = Agent[AstrAgentContext]( - name=name, - instructions=instructions, - tools=tools, # type: ignore + def bind_context(self, context) -> None: + self._context = context + + @staticmethod + def _build_handoff_snapshot(handoff: HandoffTool) -> dict[str, Any]: + tools_raw = getattr(handoff.agent, "tools", None) + serialized_tools: list[str] | None + if tools_raw is None: + serialized_tools = None + elif isinstance(tools_raw, list): + serialized_tools = [] + for item in tools_raw: + if isinstance(item, str): + tool_name = item.strip() + else: + tool_name = str(getattr(item, "name", "")).strip() + if tool_name: + serialized_tools.append(tool_name) + else: + serialized_tools = [] + + dialogs_raw = getattr(handoff.agent, "begin_dialogs", None) + serialized_dialogs: list[dict[str, Any]] | None = None + if isinstance(dialogs_raw, list): + serialized_dialogs = [] + for item in dialogs_raw: + if isinstance(item, dict): + serialized_dialogs.append(item) + continue + model_dump = getattr(item, "model_dump", None) + if callable(model_dump): + dumped = model_dump(mode="python") + if isinstance(dumped, dict): + serialized_dialogs.append(dumped) + + max_steps_raw = getattr(handoff, "max_steps", None) + max_steps = ( + int(max_steps_raw) + if isinstance(max_steps_raw, int) and max_steps_raw > 0 + else None + ) + provider_id_raw = getattr(handoff, "provider_id", None) + provider_id = ( + str(provider_id_raw).strip() + if isinstance(provider_id_raw, str) and provider_id_raw.strip() + else None + ) + tool_description_raw = getattr(handoff, "description", None) + tool_description = ( + str(tool_description_raw).strip() + if isinstance(tool_description_raw, str) and tool_description_raw.strip() + else None + ) + display_name_raw = getattr(handoff, "agent_display_name", None) + 
display_name = ( + str(display_name_raw).strip() + if isinstance(display_name_raw, str) and display_name_raw.strip() + else handoff.agent.name + ) + + return { + "name": handoff.name, + "agent_name": handoff.agent.name, + "agent_display_name": display_name, + "instructions": str(getattr(handoff.agent, "instructions", "") or ""), + "tools": serialized_tools, + "begin_dialogs": serialized_dialogs, + "provider_id": provider_id, + "max_steps": max_steps, + "tool_description": tool_description, + } + + def start_worker(self): + return self._worker.start() + + async def stop_worker(self) -> None: + await self._worker.stop() + + async def reload_from_config(self, cfg: dict[str, Any]) -> list[str]: + try: + canonical, diagnostics = decode_subagent_config(cfg) + except Exception as exc: + logger.error("Invalid subagent config: %s", exc) + self._config = SubagentConfig() + self._mount_plan = SubagentMountPlan( + diagnostics=[f"ERROR: invalid config: {exc}"] ) - agent.begin_dialogs = begin_dialogs - # The tool description should be a short description for the main LLM, - # while the subagent system prompt can be longer/more specific. 
- handoff = HandoffTool( - agent=agent, - tool_description=public_description or None, + self.handoffs = [] + return self._mount_plan.diagnostics + + self._config = canonical + self._runtime.set_max_concurrent(canonical.max_concurrent_subagent_runs) + classifier, classifier_diagnostics = build_error_classifier_from_config( + canonical.error_classifier + ) + self._runtime.set_error_classifier(classifier) + diagnostics.extend(classifier_diagnostics) + mount_plan = await self._planner.build_mount_plan(canonical) + mount_plan.diagnostics = diagnostics + mount_plan.diagnostics + self._mount_plan = mount_plan + self.handoffs = mount_plan.handoffs + return mount_plan.diagnostics + + def get_mount_plan(self) -> SubagentMountPlan: + return self._mount_plan + + def get_config(self) -> SubagentConfig: + return self._config + + def get_max_nested_depth(self) -> int: + return int(self._config.max_nested_depth) + + async def submit_handoff( + self, + *, + handoff: HandoffTool, + run_context, + payload: dict[str, Any], + background: bool, + tool_call_id: str | None = None, + ) -> str | None: + if not background: + return None + + event = run_context.context.event + event_get_extra = getattr(event, "get_extra", None) + background_note = ( + event_get_extra("background_note") if callable(event_get_extra) else None + ) + umo = getattr(event, "unified_msg_origin", None) + if not isinstance(umo, str) or not umo: + raise ValueError( + "Cannot submit subagent handoff without unified_msg_origin" ) - - # Optional per-subagent chat provider override. 
- handoff.provider_id = provider_id - - handoffs.append(handoff) - - for handoff in handoffs: - logger.info(f"Registered subagent handoff tool: {handoff.name}") - - self.handoffs = handoffs + task_payload = { + "tool_args": payload, + "_handoff_snapshot": self._build_handoff_snapshot(handoff), + "_meta": { + "role": getattr(event, "role", None), + "background_note": background_note, + "tool_call_timeout": int( + getattr(run_context, "tool_call_timeout", 3600) + ), + }, + } + return await self._runtime.enqueue( + umo=umo, + subagent_name=getattr(handoff, "agent_display_name", handoff.agent.name), + handoff_tool_name=handoff.name, + payload=task_payload, + tool_call_id=tool_call_id, + ) + + async def list_tasks( + self, status: str | None = None, limit: int = 100 + ) -> list[dict]: + return await self._runtime.list_tasks(status=status, limit=limit) + + async def retry_task(self, task_id: str) -> bool: + return await self._runtime.retry_task(task_id) + + async def cancel_task(self, task_id: str) -> bool: + return await self._runtime.cancel_task(task_id) + + def find_handoff(self, handoff_tool_name: str) -> HandoffTool | None: + return self._mount_plan.handoff_by_tool_name.get(handoff_tool_name) + + async def _execute_background_task(self, task: SubagentTaskData) -> str: + if not self._context: + raise RuntimeError("Subagent orchestrator context is not bound.") + return await HandoffExecutor.execute_queued_task( + task=task, + plugin_context=self._context, + handoff=self.find_handoff(task.handoff_tool_name), + ) diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index 40b899620..f83f50dc8 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -7,6 +7,7 @@ ) from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 +from astrbot.core.db.migration.migra_subagent_tasks import migrate_subagent_tasks from 
astrbot.core.db.migration.migra_token_usage import migrate_token_usage from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session @@ -156,6 +157,13 @@ async def migra( logger.error(f"Migration for token_usage column failed: {e!s}") logger.error(traceback.format_exc()) + # migration for subagent_tasks table + try: + await migrate_subagent_tasks(db) + except Exception as e: + logger.error(f"Migration for subagent_tasks table failed: {e!s}") + logger.error(traceback.format_exc()) + # migra third party agent runner configs _c = False providers = astrbot_config["provider"] diff --git a/astrbot/dashboard/routes/subagent.py b/astrbot/dashboard/routes/subagent.py index e3d77f73a..5d1d4a5f7 100644 --- a/astrbot/dashboard/routes/subagent.py +++ b/astrbot/dashboard/routes/subagent.py @@ -1,3 +1,4 @@ +import re import traceback from quart import jsonify, request @@ -5,9 +6,17 @@ from astrbot.core import logger from astrbot.core.agent.handoff import HandoffTool from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.subagent.codec import decode_subagent_config, encode_subagent_config from .route import Response, Route, RouteContext +_TASK_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$") + + +def _validate_task_id(task_id: str) -> bool: + """Validate task_id format to prevent injection attacks.""" + return bool(_TASK_ID_PATTERN.match(task_id)) + class SubAgentRoute(Route): def __init__( @@ -23,42 +32,42 @@ def __init__( ("/subagent/config", ("GET", self.get_config)), ("/subagent/config", ("POST", self.update_config)), ("/subagent/available-tools", ("GET", self.get_available_tools)), + ("/subagent/tasks", ("GET", self.get_tasks)), + ("/subagent/tasks/<task_id>/retry", ("POST", self.retry_task)), + ("/subagent/tasks/<task_id>/cancel", ("POST", self.cancel_task)), ] self.register_routes() + @staticmethod + def _split_compat_warnings( + diagnostics: list[str] | None, + ) -> tuple[list[str], list[str]]: + if not diagnostics: + return [], [] + 
compat_warnings: list[str] = [] + normal_diagnostics: list[str] = [] + for item in diagnostics: + if "legacy field" in item: + compat_warnings.append(item) + else: + normal_diagnostics.append(item) + return normal_diagnostics, compat_warnings + async def get_config(self): try: cfg = self.core_lifecycle.astrbot_config - data = cfg.get("subagent_orchestrator") - - # First-time access: return a sane default instead of erroring. - if not isinstance(data, dict): - data = { - "main_enable": False, - "remove_main_duplicate_tools": False, - "agents": [], - } - - # Backward compatibility: older config used `enable`. - if ( - isinstance(data, dict) - and "main_enable" not in data - and "enable" in data - ): - data["main_enable"] = bool(data.get("enable", False)) - - # Ensure required keys exist. - data.setdefault("main_enable", False) - data.setdefault("remove_main_duplicate_tools", False) - data.setdefault("agents", []) - - # Backward/forward compatibility: ensure each agent contains provider_id. - # None means follow global/default provider settings. - if isinstance(data.get("agents"), list): - for a in data["agents"]: - if isinstance(a, dict): - a.setdefault("provider_id", None) - a.setdefault("persona_id", None) + raw = cfg.get("subagent_orchestrator") + if not isinstance(raw, dict): + raw = {} + canonical, diagnostics = decode_subagent_config(raw) + normal_diagnostics, compat_warnings = self._split_compat_warnings( + diagnostics + ) + data = encode_subagent_config( + canonical, + diagnostics=normal_diagnostics, + compat_warnings=compat_warnings, + ) return jsonify(Response().ok(data=data).__dict__) except Exception as e: logger.error(traceback.format_exc()) @@ -69,9 +78,13 @@ async def update_config(self): data = await request.json if not isinstance(data, dict): return jsonify(Response().error("配置必须为 JSON 对象").__dict__) + # Canonical field is `instructions`; `system_prompt` is accepted for + # backward compatibility and serialized as a deprecated mirror field. 
+ canonical, diagnostics = decode_subagent_config(data) + normalized = encode_subagent_config(canonical) cfg = self.core_lifecycle.astrbot_config - cfg["subagent_orchestrator"] = data + cfg["subagent_orchestrator"] = normalized # Persist to cmd_config.json # AstrBotConfigManager does not expose a `save()` method; persist via AstrBotConfig. @@ -79,10 +92,27 @@ async def update_config(self): # Reload dynamic handoff tools if orchestrator exists orch = getattr(self.core_lifecycle, "subagent_orchestrator", None) + reload_diagnostics: list[str] = [] if orch is not None: - await orch.reload_from_config(data) - - return jsonify(Response().ok(message="保存成功").__dict__) + res = await orch.reload_from_config(normalized) + if isinstance(res, list): + reload_diagnostics = res + merged_diagnostics = diagnostics + reload_diagnostics + normal_diagnostics, compat_warnings = self._split_compat_warnings( + merged_diagnostics + ) + + return jsonify( + Response() + .ok( + message="保存成功", + data={ + "diagnostics": normal_diagnostics, + "compat_warnings": compat_warnings, + }, + ) + .__dict__ + ) except Exception as e: logger.error(traceback.format_exc()) return jsonify(Response().error(f"保存 subagent 配置失败: {e!s}").__dict__) @@ -115,3 +145,54 @@ async def get_available_tools(self): except Exception as e: logger.error(traceback.format_exc()) return jsonify(Response().error(f"获取可用工具失败: {e!s}").__dict__) + + async def get_tasks(self): + try: + status = request.args.get("status", default=None, type=str) + limit = request.args.get("limit", default=100, type=int) + if limit < 1: + limit = 1 + if limit > 1000: + limit = 1000 + orch = getattr(self.core_lifecycle, "subagent_orchestrator", None) + if orch is None: + return jsonify(Response().ok(data=[]).__dict__) + tasks = await orch.list_tasks(status=status, limit=limit) + return jsonify(Response().ok(data=tasks).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify(Response().error(f"获取任务列表失败: 
{e!s}").__dict__) + + async def retry_task(self, task_id: str): + try: + if not _validate_task_id(task_id): + return jsonify(Response().error("无效的任务ID格式").__dict__) + orch = getattr(self.core_lifecycle, "subagent_orchestrator", None) + if orch is None: + return jsonify( + Response().error("subagent orchestrator 不存在").__dict__ + ) + ok = await orch.retry_task(task_id) + if not ok: + return jsonify(Response().error("任务不存在或无法重试").__dict__) + return jsonify(Response().ok(message="重试已提交").__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify(Response().error(f"重试任务失败: {e!s}").__dict__) + + async def cancel_task(self, task_id: str): + try: + if not _validate_task_id(task_id): + return jsonify(Response().error("无效的任务ID格式").__dict__) + orch = getattr(self.core_lifecycle, "subagent_orchestrator", None) + if orch is None: + return jsonify( + Response().error("subagent orchestrator 不存在").__dict__ + ) + ok = await orch.cancel_task(task_id) + if not ok: + return jsonify(Response().error("任务不存在或无法取消").__dict__) + return jsonify(Response().ok(message="任务已取消").__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify(Response().error(f"取消任务失败: {e!s}").__dict__) diff --git a/dashboard/src/views/SubAgentPage.vue b/dashboard/src/views/SubAgentPage.vue index 029cc5a82..663cb75d2 100644 --- a/dashboard/src/views/SubAgentPage.vue +++ b/dashboard/src/views/SubAgentPage.vue @@ -3,13 +3,13 @@
-

{{ tm('page.title') }}

+

{{ tf('page.title', 'SubAgent Orchestration') }}

- {{ tm('page.beta') }} + {{ tf('page.beta', 'Experimental') }}
- {{ tm('page.subtitle') }} + {{ tf('page.subtitle', 'The main LLM can use its own tools directly and delegate tasks to SubAgents via handoff.') }}
@@ -21,7 +21,7 @@ :loading="loading" @click="reload" > - {{ tm('actions.refresh') }} + {{ tf('actions.refresh', 'Refresh') }} - {{ tm('actions.save') }} + {{ tf('actions.save', 'Save') }}
@@ -40,7 +40,7 @@
-
{{ tm('section.globalSettings') || 'Global Settings' }}
+
{{ tf('section.globalSettings', 'Global Settings') }}
{{ mainStateDescription }}
@@ -53,7 +53,7 @@ @@ -71,7 +71,7 @@ @@ -93,7 +93,7 @@
-
{{ tm('section.title') }}
+
{{ tf('section.title', 'SubAgents') }}
{{ cfg.agents.length }} @@ -103,7 +103,7 @@ color="primary" @click="addAgent" > - {{ tm('actions.add') }} + {{ tf('actions.add', 'Add SubAgent') }}
@@ -129,11 +129,11 @@
- {{ agent.name || tm('cards.unnamed') }} + {{ agent.name || tf('cards.unnamed', 'Untitled SubAgent') }}
- {{ agent.public_description || tm('cards.noDescription') }} + {{ agent.public_description || tf('cards.noDescription', 'No description') }}
@@ -165,8 +165,11 @@
-
{{ tm('form.providerLabel') }}
+
{{ tf('form.providerLabel', 'Chat Provider (optional)') }}
-
{{ tm('form.personaLabel') }}
+
{{ tf('form.personaLabel', 'Choose Persona') }}
- {{ tm('cards.personaPreview') }} + {{ tf('cards.personaPreview', 'Persona Preview') }}
-
{{ tm('empty.title') }}
-
{{ tm('empty.subtitle') }}
+
{{ tf('empty.title', 'No Agents Configured') }}
+
{{ tf('empty.subtitle', 'Add a new sub-agent to get started') }}
- {{ tm('empty.action') }} + {{ tf('empty.action', 'Create First Agent') }}
{{ snackbar.message }}
@@ -255,6 +258,8 @@ import PersonaSelector from '@/components/shared/PersonaSelector.vue' import PersonaQuickPreview from '@/components/shared/PersonaQuickPreview.vue' import { useModuleI18n } from '@/i18n/composables' +type ToolsScope = 'all' | 'none' | 'list' | 'persona' + type SubAgentItem = { __key: string @@ -263,11 +268,21 @@ type SubAgentItem = { public_description: string enabled: boolean provider_id?: string + tools_scope?: ToolsScope + tools?: string[] + max_steps?: number + instructions?: string } type SubAgentConfig = { main_enable: boolean remove_main_duplicate_tools: boolean + error_classifier?: { + type?: string + fatal_exceptions?: string[] + transient_exceptions?: string[] + default_class?: string + } agents: SubAgentItem[] } @@ -286,27 +301,102 @@ function toast(message: string, color: 'success' | 'error' | 'warning' = 'succes snackbar.value = { show: true, message, color } } +function tf( + key: string, + fallback: string, + params?: Record +): string { + const translated = tm(key, params) + if ( + !translated || + translated.startsWith('[MISSING:') || + translated.startsWith('[INVALID:') + ) { + return fallback + } + return translated +} + const cfg = ref({ main_enable: false, remove_main_duplicate_tools: false, + error_classifier: { + type: 'default', + fatal_exceptions: ['ValueError', 'PermissionError', 'KeyError'], + transient_exceptions: [ + 'asyncio.TimeoutError', + 'TimeoutError', + 'ConnectionError', + 'ConnectionResetError' + ], + default_class: 'transient' + }, agents: [] }) const mainStateDescription = computed(() => - cfg.value.main_enable ? tm('description.enabled') : tm('description.disabled') + cfg.value.main_enable + ? tf( + 'description.enabled', + 'When on: the main LLM keeps its own tools and mounts transfer_to_* delegate tools.' + ) + : tf( + 'description.disabled', + 'When off: SubAgent is disabled and the main LLM calls tools directly.' 
+ ) ) +function inferToolsScope(a: any): ToolsScope { + const explicitScope = (a?.tools_scope ?? '').toString().toLowerCase() + if (explicitScope === 'all' || explicitScope === 'none' || explicitScope === 'list' || explicitScope === 'persona') { + return explicitScope as ToolsScope + } + if (Array.isArray(a?.tools)) { + return a.tools.length === 0 ? 'none' : 'list' + } + if ((a?.persona_id ?? '').toString().trim()) { + return 'persona' + } + return 'all' +} + function normalizeConfig(raw: any): SubAgentConfig { const main_enable = !!raw?.main_enable const remove_main_duplicate_tools = !!raw?.remove_main_duplicate_tools + const error_classifier = raw?.error_classifier && typeof raw.error_classifier === 'object' + ? { + type: (raw.error_classifier.type ?? 'default').toString(), + fatal_exceptions: Array.isArray(raw.error_classifier.fatal_exceptions) + ? raw.error_classifier.fatal_exceptions.map((x: any) => (x ?? '').toString()).filter((x: string) => !!x) + : ['ValueError', 'PermissionError', 'KeyError'], + transient_exceptions: Array.isArray(raw.error_classifier.transient_exceptions) + ? raw.error_classifier.transient_exceptions.map((x: any) => (x ?? '').toString()).filter((x: string) => !!x) + : ['asyncio.TimeoutError', 'TimeoutError', 'ConnectionError', 'ConnectionResetError'], + default_class: (raw.error_classifier.default_class ?? 'transient').toString() + } + : { + type: 'default', + fatal_exceptions: ['ValueError', 'PermissionError', 'KeyError'], + transient_exceptions: ['asyncio.TimeoutError', 'TimeoutError', 'ConnectionError', 'ConnectionResetError'], + default_class: 'transient' + } const agentsRaw = Array.isArray(raw?.agents) ? raw.agents : [] const agents: SubAgentItem[] = agentsRaw.map((a: any, i: number) => { const name = (a?.name ?? '').toString() const persona_id = (a?.persona_id ?? '').toString() const public_description = (a?.public_description ?? '').toString() + const instructions = (a?.instructions ?? a?.system_prompt ?? 
'').toString() const enabled = a?.enabled !== false const provider_id = (a?.provider_id ?? undefined) as string | undefined + const tools_scope = inferToolsScope(a) + const tools = Array.isArray(a?.tools) + ? a.tools.map((t: any) => (t ?? '').toString().trim()).filter((t: string) => !!t) + : undefined + const max_steps = + a?.max_steps === null || a?.max_steps === undefined || a?.max_steps === '' + ? undefined + : Number(a.max_steps) return { __key: `${Date.now()}_${i}_${Math.random().toString(16).slice(2)}`, @@ -314,11 +404,15 @@ function normalizeConfig(raw: any): SubAgentConfig { persona_id, public_description, enabled, - provider_id + provider_id, + tools_scope, + tools, + max_steps: Number.isFinite(max_steps) ? max_steps : undefined, + instructions } }) - return { main_enable, remove_main_duplicate_tools, agents } + return { main_enable, remove_main_duplicate_tools, error_classifier, agents } } async function loadConfig() { @@ -328,10 +422,10 @@ async function loadConfig() { if (res.data.status === 'ok') { cfg.value = normalizeConfig(res.data.data) } else { - toast(res.data.message || tm('messages.loadConfigFailed'), 'error') + toast(res.data.message || tf('messages.loadConfigFailed', 'Failed to load config'), 'error') } } catch (e: any) { - toast(e?.response?.data?.message || tm('messages.loadConfigFailed'), 'error') + toast(e?.response?.data?.message || tf('messages.loadConfigFailed', 'Failed to load config'), 'error') } finally { loading.value = false } @@ -344,7 +438,11 @@ function addAgent() { persona_id: '', public_description: '', enabled: true, - provider_id: undefined + provider_id: undefined, + tools_scope: 'persona', + tools: [], + max_steps: undefined, + instructions: '' }) } @@ -353,27 +451,25 @@ function removeAgent(idx: number) { } function validateBeforeSave(): boolean { - const nameRe = /^[a-z][a-z0-9_]{0,63}$/ const seen = new Set() for (const a of cfg.value.agents) { const name = (a.name || '').trim() if (!name) { - 
toast(tm('messages.nameMissing'), 'warning') + toast(tf('messages.nameMissing', 'A SubAgent is missing a name'), 'warning') return false } - if (!nameRe.test(name)) { - toast(tm('messages.nameInvalid'), 'warning') + if (name.length > 256) { + toast(tf('messages.nameInvalid', 'Invalid SubAgent name'), 'warning') return false } if (seen.has(name)) { - toast(tm('messages.nameDuplicate', { name }), 'warning') + toast( + tf('messages.nameDuplicate', `Duplicate SubAgent name: ${name}`, { name }), + 'warning' + ) return false } seen.add(name) - if (!a.persona_id) { - toast(tm('messages.personaMissing', { name }), 'warning') - return false - } } return true } @@ -385,23 +481,39 @@ async function save() { const payload = { main_enable: cfg.value.main_enable, remove_main_duplicate_tools: cfg.value.remove_main_duplicate_tools, + error_classifier: cfg.value.error_classifier ?? { + type: 'default', + fatal_exceptions: ['ValueError', 'PermissionError', 'KeyError'], + transient_exceptions: ['asyncio.TimeoutError', 'TimeoutError', 'ConnectionError', 'ConnectionResetError'], + default_class: 'transient' + }, agents: cfg.value.agents.map((a) => ({ name: a.name, persona_id: a.persona_id, public_description: a.public_description, enabled: a.enabled, - provider_id: a.provider_id + provider_id: a.provider_id, + tools_scope: a.tools_scope || inferToolsScope(a), + tools: + (a.tools_scope || inferToolsScope(a)) === 'list' + ? Array.isArray(a.tools) + ? a.tools + : [] + : null, + max_steps: a.max_steps ?? null, + instructions: a.instructions ?? '', + system_prompt: a.instructions ?? 
'' })) } const res = await axios.post('/api/subagent/config', payload) if (res.data.status === 'ok') { - toast(res.data.message || tm('messages.saveSuccess'), 'success') + toast(res.data.message || tf('messages.saveSuccess', 'Saved successfully'), 'success') } else { - toast(res.data.message || tm('messages.saveFailed'), 'error') + toast(res.data.message || tf('messages.saveFailed', 'Failed to save'), 'error') } } catch (e: any) { - toast(e?.response?.data?.message || tm('messages.saveFailed'), 'error') + toast(e?.response?.data?.message || tf('messages.saveFailed', 'Failed to save'), 'error') } finally { saving.value = false } diff --git a/tests/unit/_fake_subagent_db.py b/tests/unit/_fake_subagent_db.py new file mode 100644 index 000000000..8d603adef --- /dev/null +++ b/tests/unit/_fake_subagent_db.py @@ -0,0 +1,137 @@ +"""Shared in-memory fake database for subagent tests. + +Provides ``FakeSubagentDb`` — a minimal, dictionary-backed double that +implements the DB methods used by ``SubagentRuntime``. 
+""" + +from __future__ import annotations + +from datetime import datetime, timezone + +from astrbot.core.db.po import SubagentTask + + +class FakeSubagentDb: + """In-memory fake of the subset of ``BaseDatabase`` used by the subagent + runtime, planner and hooks tests.""" + + def __init__(self) -> None: + self.tasks: dict[str, SubagentTask] = {} + + async def create_subagent_task(self, **kwargs) -> SubagentTask: + now = datetime.now(timezone.utc) + task = SubagentTask( + task_id=kwargs["task_id"], + idempotency_key=kwargs["idempotency_key"], + umo=kwargs["umo"], + subagent_name=kwargs["subagent_name"], + handoff_tool_name=kwargs["handoff_tool_name"], + payload_json=kwargs["payload_json"], + max_attempts=kwargs.get("max_attempts", 3), + status="pending", + attempt=0, + next_run_at=now, + created_at=now, + updated_at=now, + ) + self.tasks[task.task_id] = task + return task + + async def get_subagent_task_by_idempotency(self, idempotency_key: str): + for task in self.tasks.values(): + if task.idempotency_key == idempotency_key: + return task + return None + + async def claim_due_subagent_tasks(self, *, now: datetime, limit: int = 20): + rows = [ + t + for t in self.tasks.values() + if t.status in {"pending", "retrying"} + and (t.next_run_at is None or t.next_run_at <= now) + ] + rows.sort(key=lambda item: item.created_at) + return rows[:limit] + + async def mark_subagent_task_running(self, task_id: str): + task = self.tasks.get(task_id) + if task is None or task.status not in {"pending", "retrying"}: + return None + task.status = "running" + task.attempt += 1 + task.updated_at = datetime.now(timezone.utc) + return task + + async def mark_subagent_task_retrying( + self, *, task_id: str, next_run_at: datetime, error_class: str, last_error: str + ): + task = self.tasks.get(task_id) + if task is None or task.status not in {"running", "retrying"}: + return False + task.status = "retrying" + task.next_run_at = next_run_at + task.error_class = error_class + task.last_error 
= last_error + task.updated_at = datetime.now(timezone.utc) + return True + + async def reschedule_subagent_task( + self, *, task_id: str, next_run_at: datetime, error_class: str, last_error: str + ): + task = self.tasks.get(task_id) + if task is None or task.status not in { + "failed", + "canceled", + "succeeded", + "pending", + "retrying", + }: + return False + task.status = "retrying" + task.attempt = 0 + task.next_run_at = next_run_at + task.error_class = error_class + task.last_error = last_error + task.result_text = None + task.finished_at = None + task.updated_at = datetime.now(timezone.utc) + return True + + async def mark_subagent_task_succeeded(self, task_id: str, *, result_text: str): + task = self.tasks.get(task_id) + if task is None: + return False + task.status = "succeeded" + task.result_text = result_text + task.finished_at = datetime.now(timezone.utc) + task.updated_at = task.finished_at + return True + + async def mark_subagent_task_failed( + self, *, task_id: str, error_class: str, last_error: str + ): + task = self.tasks.get(task_id) + if task is None: + return False + task.status = "failed" + task.error_class = error_class + task.last_error = last_error + task.finished_at = datetime.now(timezone.utc) + task.updated_at = task.finished_at + return True + + async def cancel_subagent_task(self, task_id: str): + task = self.tasks.get(task_id) + if task is None: + return False + task.status = "canceled" + task.finished_at = datetime.now(timezone.utc) + task.updated_at = task.finished_at + return True + + async def list_subagent_tasks(self, *, status: str | None = None, limit: int = 100): + rows = list(self.tasks.values()) + if status: + rows = [r for r in rows if r.status == status] + rows.sort(key=lambda item: item.created_at, reverse=True) + return rows[:limit] diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py index 9d405f1ab..4d1af1da6 100644 --- a/tests/unit/test_astr_agent_tool_exec.py +++ 
b/tests/unit/test_astr_agent_tool_exec.py @@ -1,17 +1,25 @@ +from datetime import datetime, timezone +import json from types import SimpleNamespace +from unittest.mock import AsyncMock, patch import mcp import pytest from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor from astrbot.core.message.components import Image +from astrbot.core.subagent.background_notifier import ( + wake_main_agent_for_background_result, +) +from astrbot.core.subagent.handoff_executor import HandoffExecutor +from astrbot.core.subagent.models import SubagentTaskData class _DummyEvent: def __init__(self, message_components: list[object] | None = None) -> None: self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session" self.message_obj = SimpleNamespace(message=message_components or []) + self.role = "assistant" def get_extra(self, _key: str): return None @@ -46,7 +54,7 @@ async def _fake_convert_to_file_path(self): 123, ) - image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + image_urls = await HandoffExecutor.collect_handoff_image_urls( run_context, image_urls_input, ) @@ -68,7 +76,7 @@ async def _fake_convert_to_file_path(self): monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) run_context = _build_run_context([Image(file="file:///tmp/original.png")]) - image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + image_urls = await HandoffExecutor.collect_handoff_image_urls( run_context, ["https://example.com/a.png"], ) @@ -107,9 +115,7 @@ async def test_collect_handoff_image_urls_filters_supported_schemes_and_extensio expected_supported_refs: set[str], ): run_context = _build_run_context([]) - result = await FunctionToolExecutor._collect_handoff_image_urls( - run_context, image_refs - ) + result = await HandoffExecutor.collect_handoff_image_urls(run_context, image_refs) assert set(result) == expected_supported_refs @@ -123,7 +129,7 @@ async def 
_fake_convert_to_file_path(self): monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) run_context = _build_run_context([Image(file="file:///tmp/original.png")]) - image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + image_urls = await HandoffExecutor.collect_handoff_image_urls( run_context, None, ) @@ -132,54 +138,64 @@ async def _fake_convert_to_file_path(self): @pytest.mark.asyncio -async def test_do_handoff_background_reports_prepared_image_urls( +async def test_execute_handoff_skips_renormalize_when_image_urls_prepared( monkeypatch: pytest.MonkeyPatch, ): captured: dict = {} - async def _fake_execute_handoff( - cls, tool, run_context, image_urls_prepared=False, **tool_args - ): - assert image_urls_prepared is True - yield mcp.types.CallToolResult( - content=[mcp.types.TextContent(type="text", text="ok")] - ) + def _boom(_items): + raise RuntimeError("normalize should not be called") + + async def _fake_get_current_chat_provider_id(_umo): + return "provider-id" - async def _fake_wake(cls, run_context, **kwargs): + async def _fake_tool_loop_agent(**kwargs): captured.update(kwargs) + return SimpleNamespace(completion_text="ok") - monkeypatch.setattr( - FunctionToolExecutor, - "_execute_handoff", - classmethod(_fake_execute_handoff), + context = SimpleNamespace( + get_current_chat_provider_id=_fake_get_current_chat_provider_id, + tool_loop_agent=_fake_tool_loop_agent, + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + tool = SimpleNamespace( + name="transfer_to_subagent", + provider_id=None, + agent=SimpleNamespace( + name="subagent", + tools=[], + instructions="subagent-instructions", + begin_dialogs=[], + run_hooks=None, + ), ) + monkeypatch.setattr( - FunctionToolExecutor, - "_wake_main_agent_for_background_result", - classmethod(_fake_wake), + 
"astrbot.core.subagent.handoff_executor.normalize_and_dedupe_strings", _boom ) - run_context = _build_run_context() - await FunctionToolExecutor._do_handoff_background( - tool=_DummyTool(), - run_context=run_context, - task_id="task-id", + results = [] + async for result in HandoffExecutor.execute_foreground( + tool, + run_context, + image_urls_prepared=True, input="hello", - image_urls="https://example.com/raw.png", - ) + image_urls=["https://example.com/raw.png"], + ): + results.append(result) - assert captured["tool_args"]["image_urls"] == ["https://example.com/raw.png"] + assert len(results) == 1 + assert captured["image_urls"] == ["https://example.com/raw.png"] @pytest.mark.asyncio -async def test_execute_handoff_skips_renormalize_when_image_urls_prepared( +async def test_execute_handoff_uses_subagent_max_steps_override( monkeypatch: pytest.MonkeyPatch, ): captured: dict = {} - def _boom(_items): - raise RuntimeError("normalize should not be called") - async def _fake_get_current_chat_provider_id(_umo): return "provider-id" @@ -190,13 +206,14 @@ async def _fake_tool_loop_agent(**kwargs): context = SimpleNamespace( get_current_chat_provider_id=_fake_get_current_chat_provider_id, tool_loop_agent=_fake_tool_loop_agent, - get_config=lambda **_kwargs: {"provider_settings": {}}, + get_config=lambda **_kwargs: {"provider_settings": {"max_agent_step": 30}}, ) event = _DummyEvent([]) run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) tool = SimpleNamespace( name="transfer_to_subagent", provider_id=None, + max_steps=5, agent=SimpleNamespace( name="subagent", tools=[], @@ -206,22 +223,294 @@ async def _fake_tool_loop_agent(**kwargs): ), ) - monkeypatch.setattr( - "astrbot.core.astr_agent_tool_exec.normalize_and_dedupe_strings", _boom - ) - results = [] - async for result in FunctionToolExecutor._execute_handoff( + async for result in HandoffExecutor.execute_foreground( tool, run_context, - image_urls_prepared=True, input="hello", - 
image_urls=["https://example.com/raw.png"], + image_urls=[], ): results.append(result) assert len(results) == 1 - assert captured["image_urls"] == ["https://example.com/raw.png"] + assert captured["max_steps"] == 5 + + +@pytest.mark.asyncio +async def test_execute_queued_task_uses_prepared_image_urls_and_notifies( + monkeypatch: pytest.MonkeyPatch, +): + captured: dict = {} + + class _FakeAstrAgentContext: + def __init__(self, *, context, event): + self.context = context + self.event = event + + class _FakeAgentContextWrapper: + def __init__(self, *, context, tool_call_timeout): + self.context = context + self.tool_call_timeout = tool_call_timeout + + async def _fake_execute_foreground(*_args, **kwargs): + assert kwargs["image_urls_prepared"] is True + assert kwargs["image_urls"] == ["https://example.com/a.png"] + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="done from queued task")] + ) + + async def _fake_notify(**kwargs): + captured.update(kwargs) + + monkeypatch.setattr(HandoffExecutor, "execute_foreground", _fake_execute_foreground) + monkeypatch.setattr( + "astrbot.core.subagent.handoff_executor.wake_main_agent_for_background_result", + _fake_notify, + ) + monkeypatch.setattr( + "astrbot.core.astr_agent_context.AstrAgentContext", + _FakeAstrAgentContext, + ) + monkeypatch.setattr( + "astrbot.core.astr_agent_context.AgentContextWrapper", + _FakeAgentContextWrapper, + ) + + task = SubagentTaskData( + task_id="task_queued_1", + idempotency_key="idem", + umo="webchat:FriendMessage:webchat!user!session", + subagent_name="subagent", + handoff_tool_name="transfer_to_subagent", + status="running", + attempt=1, + max_attempts=3, + next_run_at=None, + payload_json=json.dumps( + { + "_meta": {"background_note": "finished", "tool_call_timeout": 90}, + "tool_args": { + "image_urls": ["https://example.com/a.png"], + "input": "hello", + }, + } + ), + error_class=None, + last_error=None, + result_text=None, + 
created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + finished_at=None, + ) + plugin_context = SimpleNamespace() + handoff = SimpleNamespace( + name="transfer_to_subagent", + agent=SimpleNamespace( + name="subagent", + tools=[], + instructions="subagent-instructions", + begin_dialogs=[], + run_hooks=None, + ), + provider_id=None, + max_steps=None, + ) + + result = await HandoffExecutor.execute_queued_task( + task=task, + plugin_context=plugin_context, + handoff=handoff, + ) + + assert "done from queued task" in result + assert captured["task_id"] == "task_queued_1" + assert captured["note"] == "finished" + assert captured["tool_args"]["image_urls"] == ["https://example.com/a.png"] + + +@pytest.mark.asyncio +async def test_execute_queued_task_restores_handoff_from_snapshot_when_missing( + monkeypatch: pytest.MonkeyPatch, +): + captured: dict = {} + + class _FakeAstrAgentContext: + def __init__(self, *, context, event): + self.context = context + self.event = event + + class _FakeAgentContextWrapper: + def __init__(self, *, context, tool_call_timeout): + self.context = context + self.tool_call_timeout = tool_call_timeout + + async def _fake_execute_foreground(tool, *_args, **kwargs): + captured["tool"] = tool + captured["kwargs"] = kwargs + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="done from snapshot")] + ) + + async def _fake_notify(**kwargs): + captured["notify"] = kwargs + + monkeypatch.setattr(HandoffExecutor, "execute_foreground", _fake_execute_foreground) + monkeypatch.setattr( + "astrbot.core.subagent.handoff_executor.wake_main_agent_for_background_result", + _fake_notify, + ) + monkeypatch.setattr( + "astrbot.core.astr_agent_context.AstrAgentContext", + _FakeAstrAgentContext, + ) + monkeypatch.setattr( + "astrbot.core.astr_agent_context.AgentContextWrapper", + _FakeAgentContextWrapper, + ) + + task = SubagentTaskData( + task_id="task_queued_snapshot", + idempotency_key="idem_snapshot", + 
umo="webchat:FriendMessage:webchat!user!session", + subagent_name="subagent", + handoff_tool_name="transfer_to_subagent", + status="running", + attempt=1, + max_attempts=3, + next_run_at=None, + payload_json=json.dumps( + { + "_handoff_snapshot": { + "name": "transfer_to_subagent", + "agent_name": "subagent", + "agent_display_name": "Sub Agent", + "instructions": "snapshot prompt", + "tools": ["tool_a"], + "begin_dialogs": [{"role": "assistant", "content": "hello"}], + "provider_id": "provider-x", + "max_steps": 9, + "tool_description": "snapshot desc", + }, + "_meta": {"background_note": "done", "tool_call_timeout": 60}, + "tool_args": {"image_urls": ["https://example.com/a.png"], "input": "hi"}, + } + ), + error_class=None, + last_error=None, + result_text=None, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + finished_at=None, + ) + + result = await HandoffExecutor.execute_queued_task( + task=task, + plugin_context=SimpleNamespace(), + handoff=None, + ) + + restored_tool = captured["tool"] + assert restored_tool.name == "transfer_to_subagent" + assert restored_tool.provider_id == "provider-x" + assert restored_tool.max_steps == 9 + assert restored_tool.agent_display_name == "Sub Agent" + assert restored_tool.agent.instructions == "snapshot prompt" + assert restored_tool.agent.tools == ["tool_a"] + assert restored_tool.agent.begin_dialogs == [ + {"role": "assistant", "content": "hello"} + ] + assert "done from snapshot" in result + + +@pytest.mark.asyncio +async def test_build_handoff_toolset_defaults_runtime_to_none(): + event = _DummyEvent([]) + context = SimpleNamespace( + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + toolset = HandoffExecutor.build_handoff_toolset(run_context, []) + assert toolset is None + + +@pytest.mark.asyncio +async def test_wake_main_agent_for_background_result_uses_provider_settings( + monkeypatch: 
pytest.MonkeyPatch, +): + captured: dict = {} + + class _Runner: + async def step_until_done(self, _max_steps): + if False: + yield None + + def get_final_llm_resp(self): + return SimpleNamespace(completion_text="done") + + async def _fake_get_session_conv(*, event, plugin_context): + _ = event + _ = plugin_context + return SimpleNamespace(history="[]") + + async def _fake_build_main_agent(*, event, plugin_context, config, req): + captured["event"] = event + captured["plugin_context"] = plugin_context + captured["config"] = config + captured["req"] = req + return SimpleNamespace(agent_runner=_Runner()) + + async def _fake_persist_agent_history(*args, **kwargs): + _ = args + _ = kwargs + return None + + monkeypatch.setattr( + "astrbot.core.astr_main_agent._get_session_conv", + _fake_get_session_conv, + ) + monkeypatch.setattr( + "astrbot.core.astr_main_agent.build_main_agent", + _fake_build_main_agent, + ) + monkeypatch.setattr( + "astrbot.core.subagent.background_notifier.persist_agent_history", + _fake_persist_agent_history, + ) + + provider_settings = { + "tool_call_timeout": 123, + "streaming_response": False, + "computer_use_runtime": "none", + "proactive_capability": {"add_cron_tools": False}, + "llm_safety_mode": True, + } + context = SimpleNamespace( + get_config=lambda **_kwargs: { + "provider_settings": provider_settings, + "subagent_orchestrator": {"main_enable": True}, + "kb_agentic_mode": False, + "timezone": "UTC", + }, + conversation_manager=SimpleNamespace(), + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + + await wake_main_agent_for_background_result( + run_context=run_context, + task_id="task_1", + tool_name="transfer_to_subagent", + result_text="ok", + tool_args={"input": "hello"}, + note="background finished", + summary_name="subagent-summary", + ) + + cfg = captured["config"] + assert cfg.tool_call_timeout == 123 + assert cfg.computer_use_runtime == "none" + assert 
cfg.add_cron_tools is False + assert cfg.provider_settings["computer_use_runtime"] == "none" @pytest.mark.asyncio @@ -233,14 +522,14 @@ async def _fake_convert_to_file_path(self): monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) monkeypatch.setattr( - "astrbot.core.astr_agent_tool_exec.get_astrbot_temp_path", lambda: "/tmp" + "astrbot.core.subagent.handoff_executor.get_astrbot_temp_path", lambda: "/tmp" ) monkeypatch.setattr( "astrbot.core.utils.image_ref_utils.os.path.exists", lambda _: True ) run_context = _build_run_context([Image(file="file:///tmp/original.png")]) - image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + image_urls = await HandoffExecutor.collect_handoff_image_urls( run_context, [], ) @@ -257,14 +546,14 @@ async def _fake_convert_to_file_path(self): monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) monkeypatch.setattr( - "astrbot.core.astr_agent_tool_exec.get_astrbot_temp_path", lambda: "/tmp" + "astrbot.core.subagent.handoff_executor.get_astrbot_temp_path", lambda: "/tmp" ) monkeypatch.setattr( "astrbot.core.utils.image_ref_utils.os.path.exists", lambda _: False ) run_context = _build_run_context([Image(file="file:///tmp/original.png")]) - image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + image_urls = await HandoffExecutor.collect_handoff_image_urls( run_context, [], ) @@ -281,16 +570,65 @@ async def _fake_convert_to_file_path(self): monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) monkeypatch.setattr( - "astrbot.core.astr_agent_tool_exec.get_astrbot_temp_path", lambda: "/tmp" + "astrbot.core.subagent.handoff_executor.get_astrbot_temp_path", lambda: "/tmp" ) monkeypatch.setattr( "astrbot.core.utils.image_ref_utils.os.path.exists", lambda _: True ) run_context = _build_run_context([Image(file="file:///tmp/original.png")]) - image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + image_urls = await 
HandoffExecutor.collect_handoff_image_urls( run_context, [], ) assert image_urls == [] + + +@pytest.mark.asyncio +async def test_execute_handoff_background_strict_failover_without_orchestrator(): + run_context = _build_run_context() + tool = _DummyTool() + + with patch( + "astrbot.core.astr_agent_tool_exec.asyncio.create_task" + ) as create_task_mock: + results = [] + async for result in HandoffExecutor.submit_background( + tool, + run_context, + tool_call_id="call_1", + input="hello", + image_urls=["https://example.com/raw.png"], + ): + results.append(result) + + assert len(results) == 1 + assert "error:" in results[0].content[0].text + create_task_mock.assert_not_called() + + +@pytest.mark.asyncio +async def test_execute_handoff_background_strict_failover_submit_error(): + orchestrator = SimpleNamespace( + submit_handoff=AsyncMock(side_effect=RuntimeError("boom")) + ) + event = _DummyEvent([]) + context = SimpleNamespace(subagent_orchestrator=orchestrator) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + + with patch( + "astrbot.core.astr_agent_tool_exec.asyncio.create_task" + ) as create_task_mock: + results = [] + async for result in HandoffExecutor.submit_background( + _DummyTool(), + run_context, + tool_call_id="call_1", + input="hello", + ): + results.append(result) + + assert len(results) == 1 + assert "error:" in results[0].content[0].text + create_task_mock.assert_not_called() diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index e0682ae06..3c46914c3 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -1511,3 +1511,50 @@ def test_apply_sandbox_tools_with_none_system_prompt(self): assert isinstance(req.system_prompt, str) assert "sandboxed environment" in req.system_prompt + + +class TestSubagentMainEnableSource: + @pytest.mark.asyncio + async def test_subagent_mount_uses_orchestrator_canonical_main_enable( + self, mock_event, mock_context + 
): + module = ama + mock_context.persona_manager.resolve_selected_persona = AsyncMock( + return_value=(None, None, None, False) + ) + mock_context.get_config.return_value = { + "subagent_orchestrator": {"main_enable": False} + } + mock_context.get_llm_tool_manager.return_value.get_full_tool_set.return_value = ( + ToolSet() + ) + mock_context.get_llm_tool_manager.return_value.get_func.return_value = None + + handoff_tool = FunctionTool( + name="transfer_to_writer", + description="handoff", + parameters={"type": "object", "properties": {}}, + handler=None, + ) + plan = MagicMock() + plan.handoffs = [handoff_tool] + plan.main_tool_exclude_set = set() + plan.router_prompt = None + plan.diagnostics = [] + + orchestrator = MagicMock() + orchestrator.get_mount_plan.return_value = plan + orchestrator.get_config.return_value = MagicMock(main_enable=True) + mock_context.subagent_orchestrator = orchestrator + + req = ProviderRequest() + req.conversation = MagicMock(persona_id=None) + + with patch("astrbot.core.astr_main_agent.SkillManager") as mock_skill_mgr_cls: + mock_skill_mgr = MagicMock() + mock_skill_mgr.list_skills.return_value = [] + mock_skill_mgr_cls.return_value = mock_skill_mgr + await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) + + assert req.func_tool is not None + assert req.func_tool.get_tool("transfer_to_writer") is not None diff --git a/tests/unit/test_core_lifecycle.py b/tests/unit/test_core_lifecycle.py index fc8300bf9..e2d48bfec 100644 --- a/tests/unit/test_core_lifecycle.py +++ b/tests/unit/test_core_lifecycle.py @@ -744,6 +744,32 @@ async def test_stop_handles_plugin_termination_error( # Verify warning was logged about plugin termination failure mock_logger.warning.assert_called() + @pytest.mark.asyncio + async def test_stop_worker_is_idempotent(self, mock_log_broker, mock_db): + lifecycle = AstrBotCoreLifecycle(mock_log_broker, mock_db) + lifecycle.subagent_orchestrator = MagicMock() + 
lifecycle.subagent_orchestrator.stop_worker = AsyncMock() + lifecycle._subagent_worker_started = True + + lifecycle.temp_dir_cleaner = None + lifecycle.cron_manager = None + lifecycle.plugin_manager = MagicMock() + lifecycle.plugin_manager.context = MagicMock() + lifecycle.plugin_manager.context.get_all_stars = MagicMock(return_value=[]) + lifecycle.provider_manager = MagicMock() + lifecycle.provider_manager.terminate = AsyncMock() + lifecycle.platform_manager = MagicMock() + lifecycle.platform_manager.terminate = AsyncMock() + lifecycle.kb_manager = MagicMock() + lifecycle.kb_manager.terminate = AsyncMock() + lifecycle.dashboard_shutdown_event = asyncio.Event() + lifecycle.curr_tasks = [] + + await lifecycle.stop() + await lifecycle.stop() + + lifecycle.subagent_orchestrator.stop_worker.assert_awaited_once() + class TestAstrBotCoreLifecycleRestart: """Tests for AstrBotCoreLifecycle.restart method.""" diff --git a/tests/unit/test_subagent_codec.py b/tests/unit/test_subagent_codec.py new file mode 100644 index 000000000..2c5d29cfe --- /dev/null +++ b/tests/unit/test_subagent_codec.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import pytest + +from astrbot.core.subagent.codec import decode_subagent_config, encode_subagent_config +from astrbot.core.subagent.models import ToolsScope + + +def test_decode_subagent_config_accepts_legacy_fields_and_infers_scope(): + config, diagnostics = decode_subagent_config( + { + "enable": True, + "agents": [ + { + "name": "writer", + "enable": True, + "persona_id": "p1", + "system_prompt": "legacy prompt", + "x-note": "abc", + } + ], + "x-ext": {"k": "v"}, + } + ) + assert config.main_enable is True + assert config.agents[0].tools_scope == ToolsScope.PERSONA + assert config.agents[0].instructions == "legacy prompt" + assert config.extensions["x-ext"] == {"k": "v"} + assert any("legacy field `enable`" in d for d in diagnostics) + + +def test_decode_subagent_config_rejects_unknown_non_extension_fields(): + with 
pytest.raises(ValueError): + decode_subagent_config( + { + "main_enable": True, + "unknown_field": 1, + } + ) + + +def test_encode_subagent_config_to_transitional_dual_fields(): + config, _ = decode_subagent_config( + { + "main_enable": True, + "agents": [ + { + "name": "planner", + "enabled": True, + "tools_scope": "list", + "tools": ["tool_a"], + "instructions": "hello", + } + ], + } + ) + payload = encode_subagent_config(config) + assert payload["main_enable"] is True + assert payload["agents"][0]["name"] == "planner" + assert payload["agents"][0]["tools"] == ["tool_a"] + assert payload["agents"][0]["instructions"] == "hello" + assert payload["agents"][0]["system_prompt"] == "hello" + + +def test_decode_subagent_config_agent_extension_passthrough(): + config, _ = decode_subagent_config( + { + "main_enable": True, + "agents": [ + { + "name": "writer", + "enabled": True, + "tools_scope": "none", + "x-tag": "tag-1", + } + ], + } + ) + assert config.agents[0].extensions["x-tag"] == "tag-1" + + +def test_decode_subagent_config_explicit_tools_scope_overrides_tools_inference(): + config, _ = decode_subagent_config( + { + "main_enable": True, + "agents": [ + { + "name": "writer", + "enabled": True, + "tools_scope": "none", + "tools": ["tool_a", "tool_b"], + } + ], + } + ) + assert config.agents[0].tools_scope == ToolsScope.NONE + assert config.agents[0].tools is None + + +def test_encode_decode_roundtrip_with_instructions_has_no_legacy_warning(): + config, _ = decode_subagent_config( + { + "main_enable": True, + "agents": [ + { + "name": "writer", + "enabled": True, + "instructions": "do work", + } + ], + } + ) + encoded = encode_subagent_config(config) + _, diagnostics = decode_subagent_config(encoded) + assert not any("legacy field `agents[0].system_prompt`" in d for d in diagnostics) + + +def test_decode_subagent_config_parses_boolean_strings(): + config, _ = decode_subagent_config( + { + "main_enable": "false", + "remove_main_duplicate_tools": "1", + "agents": [ + { 
+ "name": "writer", + "enabled": "0", + "instructions": "do work", + } + ], + } + ) + assert config.main_enable is False + assert config.remove_main_duplicate_tools is True + assert config.agents[0].enabled is False + + +def test_decode_subagent_config_rejects_invalid_boolean_value(): + with pytest.raises(ValueError): + decode_subagent_config( + { + "main_enable": "not-a-bool", + "agents": [{"name": "writer"}], + } + ) diff --git a/tests/unit/test_subagent_error_classifier.py b/tests/unit/test_subagent_error_classifier.py new file mode 100644 index 000000000..04a81136a --- /dev/null +++ b/tests/unit/test_subagent_error_classifier.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from astrbot.core.subagent.error_classifier import ( + DefaultErrorClassifier, + build_error_classifier_from_config, +) +from astrbot.core.subagent.models import SubagentErrorClassifierConfig + + +def test_build_error_classifier_from_config_maps_allowlisted_types(): + classifier, diagnostics = build_error_classifier_from_config( + SubagentErrorClassifierConfig( + type="default", + fatal_exceptions=["ValueError"], + transient_exceptions=["TimeoutError"], + default_class="retryable", + ) + ) + assert diagnostics == [] + assert classifier.classify(ValueError("x")) == "fatal" + assert classifier.classify(TimeoutError("x")) == "transient" + assert classifier.classify(RuntimeError("x")) == "retryable" + + +def test_build_error_classifier_ignores_unknown_exception_name(): + _, diagnostics = build_error_classifier_from_config( + SubagentErrorClassifierConfig( + fatal_exceptions=["ValueError", "NotExistError"], + transient_exceptions=["TimeoutError"], + default_class="transient", + ) + ) + assert any("NotExistError" in item for item in diagnostics) + + +def test_default_error_classifier_defaults_to_transient_for_unknown(): + classifier = DefaultErrorClassifier() + assert classifier.classify(RuntimeError("unknown")) == "transient" diff --git a/tests/unit/test_subagent_hooks.py 
b/tests/unit/test_subagent_hooks.py new file mode 100644 index 000000000..fb489d9a2 --- /dev/null +++ b/tests/unit/test_subagent_hooks.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from _fake_subagent_db import FakeSubagentDb as _FakeDb + +from astrbot.core.subagent.runtime import SubagentRuntime + + +class _RecorderHooks: + def __init__(self) -> None: + self.events: list[str] = [] + + async def on_task_enqueued(self, task) -> None: + self.events.append(f"enqueued:{task.task_id}") + + async def on_task_started(self, task) -> None: + self.events.append(f"started:{task.task_id}") + + async def on_task_retrying( + self, task, *, delay_seconds: float, error_class: str, error: Exception + ) -> None: + _ = delay_seconds + _ = error + self.events.append(f"retrying:{task.task_id}:{error_class}") + + async def on_task_succeeded(self, task, result: str) -> None: + _ = result + self.events.append(f"succeeded:{task.task_id}") + + async def on_task_failed(self, task, *, error_class: str, error: Exception) -> None: + _ = error + self.events.append(f"failed:{task.task_id}:{error_class}") + + async def on_task_canceled(self, task_id: str) -> None: + self.events.append(f"canceled:{task_id}") + + async def on_task_result_ignored(self, task, *, reason: str) -> None: + _ = reason + self.events.append(f"ignored:{task.task_id}") + + +@pytest.mark.asyncio +async def test_runtime_hooks_called_in_success_order(): + db = _FakeDb() + hooks = _RecorderHooks() + runtime = SubagentRuntime(db=db, hooks=hooks) + + async def _executor(_task): + return "done" + + runtime.set_task_executor(_executor) + task_id = await runtime.enqueue( + umo="umo:1", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "x"}}, + tool_call_id="call_1", + ) + await runtime.process_once(batch_size=8) + + assert hooks.events == [ + f"enqueued:{task_id}", + f"started:{task_id}", + f"succeeded:{task_id}", 
+ ] + + +@pytest.mark.asyncio +async def test_runtime_hook_failure_does_not_block_task_processing(): + db = _FakeDb() + hooks = _RecorderHooks() + + async def _broken_started(task) -> None: + hooks.events.append(f"started:{task.task_id}") + raise RuntimeError("hook boom") + + hooks.on_task_started = _broken_started + runtime = SubagentRuntime(db=db, hooks=hooks) + + async def _executor(_task): + return "done" + + runtime.set_task_executor(_executor) + task_id = await runtime.enqueue( + umo="umo:2", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "x"}}, + tool_call_id="call_2", + ) + await runtime.process_once(batch_size=8) + + assert db.tasks[task_id].status == "succeeded" + assert f"succeeded:{task_id}" in hooks.events + + +@pytest.mark.asyncio +async def test_runtime_hooks_called_on_retry_and_failure(): + db = _FakeDb() + hooks = _RecorderHooks() + runtime = SubagentRuntime(db=db, hooks=hooks, max_attempts=2) + calls = {"n": 0} + + async def _executor(_task): + calls["n"] += 1 + if calls["n"] == 1: + raise TimeoutError("retry") + raise ValueError("fatal") + + runtime.set_task_executor(_executor) + task_id = await runtime.enqueue( + umo="umo:3", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "x"}}, + tool_call_id="call_3", + ) + await runtime.process_once(batch_size=8) + db.tasks[task_id].next_run_at = datetime.now(timezone.utc) + await runtime.process_once(batch_size=8) + + assert f"retrying:{task_id}:transient" in hooks.events + assert f"failed:{task_id}:fatal" in hooks.events diff --git a/tests/unit/test_subagent_persistence.py b/tests/unit/test_subagent_persistence.py new file mode 100644 index 000000000..741969b95 --- /dev/null +++ b/tests/unit/test_subagent_persistence.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from pathlib import Path +from types import SimpleNamespace +from 
unittest.mock import AsyncMock + +import pytest + +from astrbot.core.db.migration.migra_subagent_tasks import migrate_subagent_tasks +from astrbot.core.db.sqlite import SQLiteDatabase + + +class _FakeMigrationConn: + async def run_sync(self, _fn): + return None + + async def execute(self, _stmt): + return SimpleNamespace(fetchone=lambda: ("subagent_tasks",)) + + +class _FakeMigrationEngine: + class _BeginCtx: + async def __aenter__(self): + return _FakeMigrationConn() + + async def __aexit__(self, exc_type, exc, tb): + return False + + def begin(self): + return _FakeMigrationEngine._BeginCtx() + + +class _FakeMigrationDb: + def __init__(self): + self.markers: dict[str, bool] = {} + self.engine = _FakeMigrationEngine() + + async def get_preference(self, _scope, _scope_id, key): + return self.markers.get(key, False) + + +@pytest.mark.asyncio +async def test_subagent_migration_is_idempotent(monkeypatch: pytest.MonkeyPatch): + db = _FakeMigrationDb() + + async def _fake_put_async(_scope, _scope_id, key, value): + db.markers[key] = bool(value) + + put_async = AsyncMock(side_effect=_fake_put_async) + monkeypatch.setattr( + "astrbot.core.db.migration.migra_subagent_tasks.sp.put_async", put_async + ) + + await migrate_subagent_tasks(db) + await migrate_subagent_tasks(db) + + assert put_async.await_count == 2 + + +@pytest.mark.asyncio +async def test_sqlite_subagent_task_status_transitions(tmp_path: Path): + db_path = tmp_path / "subagent_tasks.db" + db = SQLiteDatabase(str(db_path)) + await db.initialize() + + now = datetime.now(timezone.utc) + task = await db.create_subagent_task( + task_id="task_1", + idempotency_key="idem_1", + umo="umo:1", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload_json='{"input":"hello"}', + max_attempts=3, + ) + claimed = await db.claim_due_subagent_tasks( + now=now + timedelta(seconds=1), limit=10 + ) + assert any(row.task_id == task.task_id for row in claimed) + + running = await 
db.mark_subagent_task_running(task.task_id) + assert running is not None + assert running.status == "running" + assert running.attempt == 1 + + retried = await db.mark_subagent_task_retrying( + task_id=task.task_id, + next_run_at=now - timedelta(seconds=1), + error_class="transient", + last_error="timeout", + ) + assert retried is True + + running_again = await db.mark_subagent_task_running(task.task_id) + assert running_again is not None + assert running_again.status == "running" + assert running_again.attempt == 2 + + succeeded = await db.mark_subagent_task_succeeded(task.task_id, result_text="done") + assert succeeded is True + + failed_task = await db.create_subagent_task( + task_id="task_2", + idempotency_key="idem_2", + umo="umo:2", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload_json='{"input":"boom"}', + max_attempts=3, + ) + running_failed = await db.mark_subagent_task_running(failed_task.task_id) + assert running_failed is not None + failed = await db.mark_subagent_task_failed( + task_id=failed_task.task_id, + error_class="fatal", + last_error="bad input", + ) + assert failed is True + retried_failed = await db.reschedule_subagent_task( + task_id=failed_task.task_id, + next_run_at=now - timedelta(seconds=1), + error_class="manual", + last_error="manual retry requested", + ) + assert retried_failed is True + retried_rows = await db.list_subagent_tasks(status="retrying", limit=10) + retried_row = next(row for row in retried_rows if row.task_id == failed_task.task_id) + assert retried_row.attempt == 0 + running_failed_retry = await db.mark_subagent_task_running(failed_task.task_id) + assert running_failed_retry is not None + assert running_failed_retry.attempt == 1 + succeeded_after_retry = await db.mark_subagent_task_succeeded( + failed_task.task_id, result_text="done_after_manual_retry" + ) + assert succeeded_after_retry is True + + canceled_task = await db.create_subagent_task( + task_id="task_3", + idempotency_key="idem_3", 
+ umo="umo:3", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload_json='{"input":"cancel"}', + max_attempts=3, + ) + canceled = await db.cancel_subagent_task(canceled_task.task_id) + assert canceled is True + + canceled_running_task = await db.create_subagent_task( + task_id="task_4", + idempotency_key="idem_4", + umo="umo:4", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload_json='{"input":"cancel_running"}', + max_attempts=3, + ) + running_canceled = await db.mark_subagent_task_running(canceled_running_task.task_id) + assert running_canceled is not None + canceled_running = await db.cancel_subagent_task(canceled_running_task.task_id) + assert canceled_running is True + failed_after_cancel = await db.mark_subagent_task_failed( + task_id=canceled_running_task.task_id, + error_class="fatal", + last_error="should_not_override_cancel", + ) + assert failed_after_cancel is False + + succeeded_rows = await db.list_subagent_tasks(status="succeeded", limit=10) + failed_rows = await db.list_subagent_tasks(status="failed", limit=10) + canceled_rows = await db.list_subagent_tasks(status="canceled", limit=10) + assert any(row.task_id == "task_1" for row in succeeded_rows) + assert not any(row.task_id == "task_2" for row in failed_rows) + assert any(row.task_id == "task_2" for row in succeeded_rows) + assert any(row.task_id == "task_3" for row in canceled_rows) + assert any(row.task_id == "task_4" for row in canceled_rows) diff --git a/tests/unit/test_subagent_planner.py b/tests/unit/test_subagent_planner.py new file mode 100644 index 000000000..b7a10c91e --- /dev/null +++ b/tests/unit/test_subagent_planner.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from astrbot.core.agent.tool import FunctionTool +from astrbot.core.subagent.codec import decode_subagent_config +from astrbot.core.subagent.planner import SubagentPlanner + + +class _FakeToolMgr: + def 
__init__(self): + self.func_list = [ + FunctionTool( + name="tool_a", + description="A", + parameters={"type": "object", "properties": {}}, + handler=None, + ), + FunctionTool( + name="tool_b", + description="B", + parameters={"type": "object", "properties": {}}, + handler=None, + ), + ] + + +class _FakePersonaMgr: + async def get_persona(self, persona_id: str): + if persona_id == "missing": + raise ValueError("missing") + return SimpleNamespace( + system_prompt="persona prompt", + tools=["tool_b"], + begin_dialogs=[{"role": "assistant", "content": "hi"}], + ) + + +@pytest.mark.asyncio +async def test_planner_builds_handoff_and_dedupe_set(): + config, _ = decode_subagent_config( + { + "main_enable": True, + "remove_main_duplicate_tools": True, + "agents": [ + { + "name": "writer", + "enabled": True, + "tools_scope": "list", + "tools": ["tool_a", "transfer_to_x", "not_exist"], + "instructions": "x", + "max_steps": 7, + } + ], + } + ) + planner = SubagentPlanner(_FakeToolMgr(), _FakePersonaMgr()) + plan = await planner.build_mount_plan(config) + assert len(plan.handoffs) == 1 + assert "transfer_to_writer" in plan.handoff_by_tool_name + assert "tool_a" in plan.main_tool_exclude_set + assert all("transfer_to_x" not in msg for msg in plan.main_tool_exclude_set) + assert any("recursive handoff" in d for d in plan.diagnostics) + assert getattr(plan.handoffs[0], "max_steps", None) == 7 + + +@pytest.mark.asyncio +async def test_planner_uses_persona_tools_and_prompt(): + config, _ = decode_subagent_config( + { + "main_enable": True, + "agents": [ + { + "name": "persona_agent", + "enabled": True, + "persona_id": "p1", + "tools_scope": "persona", + } + ], + } + ) + planner = SubagentPlanner(_FakeToolMgr(), _FakePersonaMgr()) + plan = await planner.build_mount_plan(config) + assert len(plan.handoffs) == 1 + handoff = plan.handoffs[0] + assert handoff.agent.tools == ["tool_b"] + assert handoff.agent.instructions == "persona prompt" + + +@pytest.mark.asyncio +async def 
test_planner_detects_safe_tool_name_conflict(): + config, _ = decode_subagent_config( + { + "main_enable": True, + "agents": [ + {"name": "A A", "enabled": True, "tools_scope": "none"}, + {"name": "a_a", "enabled": True, "tools_scope": "none"}, + ], + } + ) + planner = SubagentPlanner(_FakeToolMgr(), _FakePersonaMgr()) + plan = await planner.build_mount_plan(config) + assert len(plan.handoffs) == 2 + assert plan.handoffs[0].name == "transfer_to_a_a" + assert plan.handoffs[1].name == "transfer_to_a_a-2" + assert any( + "duplicate subagent tool name" in d and "renamed" in d + for d in plan.diagnostics + ) + + +@pytest.mark.asyncio +async def test_planner_respects_tools_scope_all_and_none(): + config, _ = decode_subagent_config( + { + "main_enable": True, + "agents": [ + {"name": "all_agent", "enabled": True, "tools_scope": "all"}, + {"name": "none_agent", "enabled": True, "tools_scope": "none"}, + ], + } + ) + planner = SubagentPlanner(_FakeToolMgr(), _FakePersonaMgr()) + plan = await planner.build_mount_plan(config) + by_name = {handoff.agent.name: handoff for handoff in plan.handoffs} + assert by_name["all_agent"].agent.tools is None + assert by_name["none_agent"].agent.tools == [] + + +@pytest.mark.asyncio +async def test_planner_safe_tool_name_is_stable(): + config, _ = decode_subagent_config( + { + "main_enable": True, + "agents": [ + { + "name": "Writer Agent !!!", + "enabled": True, + "tools_scope": "none", + } + ], + } + ) + planner = SubagentPlanner(_FakeToolMgr(), _FakePersonaMgr()) + plan1 = await planner.build_mount_plan(config) + plan2 = await planner.build_mount_plan(config) + assert len(plan1.handoffs) == 1 + assert len(plan2.handoffs) == 1 + assert plan1.handoffs[0].name == plan2.handoffs[0].name diff --git a/tests/unit/test_subagent_route.py b/tests/unit/test_subagent_route.py new file mode 100644 index 000000000..1cadf8b1e --- /dev/null +++ b/tests/unit/test_subagent_route.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +from 
unittest.mock import AsyncMock, MagicMock + +import pytest +from quart import Quart + +from astrbot.core.agent.tool import FunctionTool +from astrbot.dashboard.routes.route import RouteContext +from astrbot.dashboard.routes.subagent import SubAgentRoute + + +class _FakeConfig(dict): + def save_config(self) -> None: + self["_saved"] = True + + +@pytest.fixture +def subagent_app(): + app = Quart(__name__) + astrbot_config = _FakeConfig( + { + "subagent_orchestrator": { + "enable": True, + "agents": [ + { + "name": "writer", + "system_prompt": "legacy prompt", + } + ], + } + } + ) + lifecycle = MagicMock() + lifecycle.astrbot_config = astrbot_config + lifecycle.subagent_orchestrator = MagicMock() + lifecycle.subagent_orchestrator.reload_from_config = AsyncMock( + return_value=["WARN: reload"] + ) + lifecycle.subagent_orchestrator.list_tasks = AsyncMock(return_value=[]) + lifecycle.subagent_orchestrator.retry_task = AsyncMock(return_value=True) + lifecycle.subagent_orchestrator.cancel_task = AsyncMock(return_value=True) + + normal_tool = FunctionTool( + name="tool_a", + description="A", + parameters={"type": "object", "properties": {}}, + handler=None, + ) + hidden_tool = FunctionTool( + name="tool_b", + description="B", + parameters={"type": "object", "properties": {}}, + handler=None, + ) + hidden_tool.handler_module_path = "core.subagent_orchestrator" + lifecycle.provider_manager.llm_tools.func_list = [normal_tool, hidden_tool] + + route_ctx = RouteContext(config=MagicMock(), app=app) + SubAgentRoute(route_ctx, lifecycle) + return app, lifecycle, astrbot_config + + +@pytest.mark.asyncio +async def test_get_subagent_config_returns_compatible_shape(subagent_app): + app, _, _ = subagent_app + async with app.test_app(): + client = app.test_client() + resp = await client.get("/api/subagent/config") + body = await resp.get_json() + assert resp.status_code == 200 + assert body["status"] == "ok" + assert "agents" in body["data"] + assert 
body["data"]["agents"][0]["system_prompt"] == "legacy prompt" + assert "compat_warnings" in body["data"] + + +@pytest.mark.asyncio +async def test_post_subagent_config_returns_diagnostics_and_compat_warnings(subagent_app): + app, lifecycle, astrbot_config = subagent_app + payload = { + "enable": True, + "agents": [ + { + "name": "writer", + "system_prompt": "legacy prompt", + } + ], + } + async with app.test_app(): + client = app.test_client() + resp = await client.post("/api/subagent/config", json=payload) + body = await resp.get_json() + + assert resp.status_code == 200 + assert body["status"] == "ok" + assert "diagnostics" in body["data"] + assert "compat_warnings" in body["data"] + assert any("legacy field" in item for item in body["data"]["compat_warnings"]) + assert astrbot_config.get("_saved") is True + lifecycle.subagent_orchestrator.reload_from_config.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_subagent_tasks_clamps_limit(subagent_app): + app, lifecycle, _ = subagent_app + async with app.test_app(): + client = app.test_client() + resp = await client.get("/api/subagent/tasks?limit=0") + body = await resp.get_json() + assert resp.status_code == 200 + assert body["status"] == "ok" + lifecycle.subagent_orchestrator.list_tasks.assert_awaited_once_with( + status=None, limit=1 + ) + + +@pytest.mark.asyncio +async def test_subagent_task_actions(subagent_app): + app, lifecycle, _ = subagent_app + async with app.test_app(): + client = app.test_client() + retry_resp = await client.post("/api/subagent/tasks/task-1/retry") + cancel_resp = await client.post("/api/subagent/tasks/task-1/cancel") + retry_body = await retry_resp.get_json() + cancel_body = await cancel_resp.get_json() + assert retry_resp.status_code == 200 + assert cancel_resp.status_code == 200 + assert retry_body["status"] == "ok" + assert cancel_body["status"] == "ok" + lifecycle.subagent_orchestrator.retry_task.assert_awaited_once_with("task-1") + 
lifecycle.subagent_orchestrator.cancel_task.assert_awaited_once_with("task-1") diff --git a/tests/unit/test_subagent_runtime.py b/tests/unit/test_subagent_runtime.py new file mode 100644 index 000000000..46a35a7aa --- /dev/null +++ b/tests/unit/test_subagent_runtime.py @@ -0,0 +1,344 @@ +from __future__ import annotations + +import asyncio +from datetime import datetime, timedelta, timezone + +import pytest +from _fake_subagent_db import FakeSubagentDb as _FakeDb +from sqlalchemy.exc import IntegrityError + +from astrbot.core.db.po import SubagentTask +from astrbot.core.subagent.error_classifier import ErrorClassifier +from astrbot.core.subagent.runtime import SubagentRuntime + + +@pytest.mark.asyncio +async def test_runtime_enqueue_is_idempotent(): + db = _FakeDb() + runtime = SubagentRuntime(db) + payload = {"tool_args": {"input": "hello"}} + task1 = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload=payload, + tool_call_id="call_1", + ) + task2 = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload=payload, + tool_call_id="call_1", + ) + assert task1 == task2 + + +@pytest.mark.asyncio +async def test_runtime_enqueue_diff_tool_call_id_not_deduped(): + db = _FakeDb() + runtime = SubagentRuntime(db) + payload = {"tool_args": {"input": "hello"}} + task1 = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload=payload, + tool_call_id="call_1", + ) + task2 = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload=payload, + tool_call_id="call_2", + ) + assert task1 != task2 + + +@pytest.mark.asyncio +async def test_runtime_retries_transient_then_succeeds(): + db = _FakeDb() + runtime = SubagentRuntime(db, 
base_delay_ms=100, max_delay_ms=100, max_attempts=3) + calls = {"n": 0} + + async def _executor(_task): + calls["n"] += 1 + if calls["n"] == 1: + raise TimeoutError("transient") + return "done" + + runtime.set_task_executor(_executor) + task_id = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "x"}}, + tool_call_id="call_2", + ) + processed = await runtime.process_once(batch_size=8) + assert processed == 1 + assert db.tasks[task_id].status == "retrying" + + db.tasks[task_id].next_run_at = datetime.now(timezone.utc) - timedelta(seconds=1) + processed = await runtime.process_once(batch_size=8) + assert processed == 1 + assert db.tasks[task_id].status == "succeeded" + assert calls["n"] == 2 + + +@pytest.mark.asyncio +async def test_runtime_respects_session_lane_serialization(): + db = _FakeDb() + runtime = SubagentRuntime(db) + + async def _executor(_task): + await asyncio.sleep(0.05) + return "ok" + + runtime.set_task_executor(_executor) + + await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "a"}}, + tool_call_id="call_a", + ) + await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "b"}}, + tool_call_id="call_b", + ) + + processed = await runtime.process_once(batch_size=8) + assert processed == 1 + pending = [task for task in db.tasks.values() if task.status == "pending"] + assert len(pending) == 1 + + +@pytest.mark.asyncio +async def test_runtime_fatal_error_fails_without_retry(): + db = _FakeDb() + runtime = SubagentRuntime(db, max_attempts=3) + + async def _executor(_task): + raise ValueError("fatal") + + runtime.set_task_executor(_executor) + task_id = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + 
subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "fatal"}}, + tool_call_id="call_fatal", + ) + processed = await runtime.process_once(batch_size=8) + assert processed == 1 + assert db.tasks[task_id].status == "failed" + assert db.tasks[task_id].error_class == "fatal" + + +@pytest.mark.asyncio +async def test_runtime_respects_global_concurrency_limit(): + db = _FakeDb() + runtime = SubagentRuntime(db, max_concurrent=1) + + async def _executor(_task): + await asyncio.sleep(0.05) + return "ok" + + runtime.set_task_executor(_executor) + await runtime.enqueue( + umo="umo:1", + subagent_name="writer_a", + handoff_tool_name="transfer_to_writer_a", + payload={"tool_args": {"input": "a"}}, + tool_call_id="call_a", + ) + await runtime.enqueue( + umo="umo:2", + subagent_name="writer_b", + handoff_tool_name="transfer_to_writer_b", + payload={"tool_args": {"input": "b"}}, + tool_call_id="call_b", + ) + processed = await runtime.process_once(batch_size=8) + assert processed == 1 + pending = [task for task in db.tasks.values() if task.status == "pending"] + assert len(pending) == 1 + + +@pytest.mark.asyncio +async def test_runtime_canceled_task_is_not_claimed(): + db = _FakeDb() + runtime = SubagentRuntime(db) + + async def _executor(_task): + return "ok" + + runtime.set_task_executor(_executor) + task_id = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "x"}}, + tool_call_id="call_cancel", + ) + canceled = await runtime.cancel_task(task_id) + assert canceled is True + processed = await runtime.process_once(batch_size=8) + assert processed == 0 + assert db.tasks[task_id].status == "canceled" + + +@pytest.mark.asyncio +async def test_runtime_recovers_running_tasks_after_restart(): + db = _FakeDb() + runtime = SubagentRuntime(db) + + async def _executor(_task): + return "ok" + + 
runtime.set_task_executor(_executor) + task_id = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "recover"}}, + tool_call_id="call_recover", + ) + running = await db.mark_subagent_task_running(task_id) + assert running is not None + assert db.tasks[task_id].status == "running" + # Simulate a stale task by setting updated_at beyond the recovery threshold. + db.tasks[task_id].updated_at = datetime.now(timezone.utc) - timedelta(minutes=10) + + restarted_runtime = SubagentRuntime(db) + restarted_runtime.set_task_executor(_executor) + processed = await restarted_runtime.process_once(batch_size=8) + assert processed == 1 + assert db.tasks[task_id].status == "succeeded" + assert db.tasks[task_id].attempt == 2 + + +class _RetryableClassifier(ErrorClassifier): + def classify(self, exc: Exception) -> str: + _ = exc + return "retryable" + + +@pytest.mark.asyncio +async def test_runtime_retryable_classification_follows_retry_branch(): + db = _FakeDb() + runtime = SubagentRuntime( + db, + base_delay_ms=100, + max_delay_ms=100, + max_attempts=2, + error_classifier=_RetryableClassifier(), + ) + + async def _executor(_task): + raise RuntimeError("retryable") + + runtime.set_task_executor(_executor) + task_id = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "x"}}, + tool_call_id="call_retryable", + ) + processed = await runtime.process_once(batch_size=8) + assert processed == 1 + assert db.tasks[task_id].status == "retrying" + + +@pytest.mark.asyncio +async def test_runtime_manual_retry_reschedules_failed_task(): + db = _FakeDb() + runtime = SubagentRuntime(db, max_attempts=3) + calls = {"n": 0} + + async def _executor(_task): + calls["n"] += 1 + if calls["n"] == 1: + raise ValueError("fatal") + return "done" + + 
runtime.set_task_executor(_executor) + task_id = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "manual-retry"}}, + tool_call_id="call_manual_retry", + ) + + processed = await runtime.process_once(batch_size=8) + assert processed == 1 + assert db.tasks[task_id].status == "failed" + + retried = await runtime.retry_task(task_id) + assert retried is True + assert db.tasks[task_id].status == "retrying" + assert db.tasks[task_id].attempt == 0 + + db.tasks[task_id].next_run_at = datetime.now(timezone.utc) - timedelta(seconds=1) + processed = await runtime.process_once(batch_size=8) + assert processed == 1 + assert db.tasks[task_id].status == "succeeded" + assert calls["n"] == 2 + + +class _RaceDb(_FakeDb): + def __init__(self): + super().__init__() + self._race_injected = False + + async def create_subagent_task(self, **kwargs) -> SubagentTask: + if not self._race_injected: + self._race_injected = True + now = datetime.now(timezone.utc) + winner = SubagentTask( + task_id="winner_task", + idempotency_key=kwargs["idempotency_key"], + umo=kwargs["umo"], + subagent_name=kwargs["subagent_name"], + handoff_tool_name=kwargs["handoff_tool_name"], + payload_json=kwargs["payload_json"], + max_attempts=kwargs.get("max_attempts", 3), + status="pending", + attempt=0, + next_run_at=now, + created_at=now, + updated_at=now, + ) + self.tasks[winner.task_id] = winner + raise IntegrityError( + statement="INSERT INTO subagent_tasks ...", + params={}, + orig=Exception( + "UNIQUE constraint failed: subagent_tasks.idempotency_key" + ), + ) + return await super().create_subagent_task(**kwargs) + + +@pytest.mark.asyncio +async def test_runtime_enqueue_handles_idempotency_race_and_returns_existing_task(): + db = _RaceDb() + runtime = SubagentRuntime(db) + + task_id = await runtime.enqueue( + umo="webchat:FriendMessage:webchat!u!s", + subagent_name="writer", + 
handoff_tool_name="transfer_to_writer", + payload={"tool_args": {"input": "race"}}, + tool_call_id="call_race", + ) + + assert task_id == "winner_task"