From fe25c3493d712c2aa78ad441c08b0c56350dfb46 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 26 Feb 2026 21:19:43 +0800 Subject: [PATCH 1/2] feat: implement follow-up message handling in ToolLoopAgentRunner --- .../agent/runners/tool_loop_agent_runner.py | 167 +++++---- .../core/pipeline/process_stage/follow_up.py | 227 ++++++++++++ .../method/agent_sub_stages/internal.py | 340 ++++++++++-------- tests/test_tool_loop_agent_runner.py | 84 +++++ 4 files changed, 606 insertions(+), 212 deletions(-) create mode 100644 astrbot/core/pipeline/process_stage/follow_up.py diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 10cf2e96c6..94069089d9 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -1,9 +1,10 @@ +import asyncio import copy import sys import time import traceback import typing as T -from dataclasses import dataclass +from dataclasses import dataclass, field from mcp.types import ( BlobResourceContents, @@ -68,6 +69,14 @@ def from_cached_image(cls, image: T.Any) -> "_HandleFunctionToolsResult": return cls(kind="cached_image", cached_image=image) +@dataclass(slots=True) +class FollowUpTicket: + seq: int + text: str + consumed: bool = False + resolved: asyncio.Event = field(default_factory=asyncio.Event) + + class ToolLoopAgentRunner(BaseAgentRunner[TContext]): @override async def reset( @@ -139,6 +148,8 @@ async def reset( self.run_context = run_context self._stop_requested = False self._aborted = False + self._pending_follow_ups: list[FollowUpTicket] = [] + self._follow_up_seq = 0 # These two are used for tool schema mode handling # We now have two modes: @@ -277,6 +288,55 @@ def _simple_print_message_role(self, tag: str = ""): roles.append(message.role) logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}") + def follow_up( + self, + *, + message_text: str, + ) -> FollowUpTicket | None: + """Queue a follow-up message for the next tool result.""" + if self.done(): + return None + text = (message_text or "").strip() + if not text: + return None + ticket = FollowUpTicket(seq=self._follow_up_seq, text=text) + self._follow_up_seq += 1 + self._pending_follow_ups.append(ticket) + return ticket + + def _resolve_unconsumed_follow_ups(self) -> None: + if not self._pending_follow_ups: + return + follow_ups = self._pending_follow_ups + self._pending_follow_ups = [] + for ticket in follow_ups: + ticket.resolved.set() + + def _consume_follow_up_notice(self) -> str: + if not self._pending_follow_ups: + return "" + follow_ups = self._pending_follow_ups + self._pending_follow_ups = [] + for ticket in follow_ups: + ticket.consumed = True + ticket.resolved.set() + follow_up_lines = "\n".join( + f"{idx}. {ticket.text}" for idx, ticket in enumerate(follow_ups, start=1) + ) + return ( + "\n\n[SYSTEM NOTICE] User sent follow-up messages while tool execution " + "was in progress. Prioritize these follow-up instructions in your next " + "actions. In your very next action, briefly acknowledge to the user " + "that their follow-up message(s) were received before continuing.\n" + f"{follow_up_lines}" + ) + + def _merge_follow_up_notice(self, content: str) -> str: + notice = self._consume_follow_up_notice() + if not notice: + return content + return f"{content}{notice}" + @override async def step(self): """Process a single step of the agent. @@ -391,6 +451,7 @@ async def step(self): type="aborted", data=AgentResponseData(chain=MessageChain(type="aborted")), ) + self._resolve_unconsumed_follow_ups() return # 处理 LLM 响应 @@ -401,6 +462,7 @@ async def step(self): self.final_llm_resp = llm_resp self.stats.end_time = time.time() self._transition_state(AgentState.ERROR) + self._resolve_unconsumed_follow_ups() yield AgentResponse( type="err", data=AgentResponseData( @@ -439,6 +501,7 @@ async def step(self): await self.agent_hooks.on_agent_done(self.run_context, llm_resp) except Exception as e: logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + self._resolve_unconsumed_follow_ups() # 返回 LLM 结果 if llm_resp.result_chain: @@ -583,6 +646,15 @@ async def _handle_function_tools( tool_call_result_blocks: list[ToolCallMessageSegment] = [] logger.info(f"Agent 使用工具: {llm_response.tools_call_name}") + def _append_tool_call_result(tool_call_id: str, content: str) -> None: + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=tool_call_id, + content=self._merge_follow_up_notice(content), + ), + ) + # 执行函数调用 for func_tool_name, func_tool_args, func_tool_id in zip( llm_response.tools_call_name, @@ -622,12 +694,9 @@ async def _handle_function_tools( if not func_tool: logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。") - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=f"error: Tool {func_tool_name} not found.", - ), + _append_tool_call_result( + func_tool_id, + f"error: Tool {func_tool_name} not found.", ) continue @@ -680,12 +749,9 @@ async def _handle_function_tools( res = resp _final_resp = resp if isinstance(res.content[0], TextContent): - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=res.content[0].text, - ), + _append_tool_call_result( + func_tool_id, + res.content[0].text, ) elif isinstance(res.content[0], ImageContent): # Cache the image instead of sending directly @@ -696,15 +762,12 @@ async def _handle_function_tools( index=0, mime_type=res.content[0].mimeType or "image/png", ) - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=( - f"Image returned and cached at path='{cached_img.file_path}'. " - f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " - f"with type='image' and path='{cached_img.file_path}'." - ), + _append_tool_call_result( + func_tool_id, + ( + f"Image returned and cached at path='{cached_img.file_path}'. " + f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " + f"with type='image' and path='{cached_img.file_path}'." ), ) # Yield image info for LLM visibility (will be handled in step()) @@ -714,12 +777,9 @@ async def _handle_function_tools( elif isinstance(res.content[0], EmbeddedResource): resource = res.content[0].resource if isinstance(resource, TextResourceContents): - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resource.text, - ), + _append_tool_call_result( + func_tool_id, + resource.text, ) elif ( isinstance(resource, BlobResourceContents) @@ -734,15 +794,12 @@ async def _handle_function_tools( index=0, mime_type=resource.mimeType, ) - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=( - f"Image returned and cached at path='{cached_img.file_path}'. " - f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " - f"with type='image' and path='{cached_img.file_path}'." - ), + _append_tool_call_result( + func_tool_id, + ( + f"Image returned and cached at path='{cached_img.file_path}'. " + f"Review the image below. Use send_message_to_user to send it to the user if satisfied, " + f"with type='image' and path='{cached_img.file_path}'." ), ) # Yield image info for LLM visibility @@ -750,12 +807,9 @@ async def _handle_function_tools( cached_img ) else: - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="The tool has returned a data type that is not supported.", - ), + _append_tool_call_result( + func_tool_id, + "The tool has returned a data type that is not supported.", ) elif resp is None: @@ -767,24 +821,18 @@ async def _handle_function_tools( ) self._transition_state(AgentState.DONE) self.stats.end_time = time.time() - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="The tool has no return value, or has sent the result directly to the user.", - ), + _append_tool_call_result( + func_tool_id, + "The tool has no return value, or has sent the result directly to the user.", ) else: # 不应该出现其他类型 logger.warning( f"Tool 返回了不支持的类型: {type(resp)}。", ) - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*", - ), + _append_tool_call_result( + func_tool_id, + "*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*", ) try: @@ -798,12 +846,9 @@ async def _handle_function_tools( logger.error(f"Error in on_tool_end hook: {e}", exc_info=True) except Exception as e: logger.warning(traceback.format_exc()) - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=f"error: {e!s}", - ), + _append_tool_call_result( + func_tool_id, + f"error: {e!s}", ) # yield the last tool call result diff --git a/astrbot/core/pipeline/process_stage/follow_up.py b/astrbot/core/pipeline/process_stage/follow_up.py new file mode 100644 index 0000000000..6c1a4fa06b --- /dev/null +++ b/astrbot/core/pipeline/process_stage/follow_up.py @@ -0,0 +1,227 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass + +from astrbot import logger +from astrbot.core.agent.runners.tool_loop_agent_runner import FollowUpTicket +from astrbot.core.astr_agent_run_util import AgentRunner +from astrbot.core.platform.astr_message_event import AstrMessageEvent + +_ACTIVE_AGENT_RUNNERS: dict[str, AgentRunner] = {} +_FOLLOW_UP_ORDER_STATE: dict[str, dict[str, object]] = {} +"""UMO-level follow-up order state. + +State fields: +- `statuses`: seq -> {"pending"|"active"|"consumed"|"finished"} +- `next_order`: monotonically increasing sequence allocator +- `next_turn`: next sequence allowed to proceed when not consumed +""" + + +@dataclass(slots=True) +class FollowUpCapture: + umo: str + ticket: FollowUpTicket + order_seq: int + monitor_task: asyncio.Task[None] + + +def _event_follow_up_text(event: AstrMessageEvent) -> str: + text = (event.get_message_str() or "").strip() + if text: + return text + return event.get_message_outline().strip() + + +def register_active_runner(umo: str, runner: AgentRunner) -> None: + _ACTIVE_AGENT_RUNNERS[umo] = runner + + +def unregister_active_runner(umo: str, runner: AgentRunner) -> None: + if _ACTIVE_AGENT_RUNNERS.get(umo) is runner: + _ACTIVE_AGENT_RUNNERS.pop(umo, None) + + +def _get_follow_up_order_state(umo: str) -> dict[str, object]: + state = _FOLLOW_UP_ORDER_STATE.get(umo) + if state is None: + state = { + "condition": asyncio.Condition(), + # Sequence status map for strict in-order resume after unresolved follow-ups. + "statuses": {}, + # Stable allocator for arrival order; never decreases for the same UMO state. + "next_order": 0, + # The sequence currently allowed to continue main internal flow. + "next_turn": 0, + } + _FOLLOW_UP_ORDER_STATE[umo] = state + return state + + +def _advance_follow_up_turn_locked(state: dict[str, object]) -> None: + # Skip slots that are already handled, and stop at the first unfinished slot. + statuses = state["statuses"] + assert isinstance(statuses, dict) + next_turn = state["next_turn"] + assert isinstance(next_turn, int) + + while True: + curr = statuses.get(next_turn) + if curr in ("consumed", "finished"): + statuses.pop(next_turn, None) + next_turn += 1 + continue + break + + state["next_turn"] = next_turn + + +def _allocate_follow_up_order(umo: str) -> int: + state = _get_follow_up_order_state(umo) + next_order = state["next_order"] + assert isinstance(next_order, int) + seq = next_order + state["next_order"] = seq + 1 + statuses = state["statuses"] + assert isinstance(statuses, dict) + statuses[seq] = "pending" + return seq + + +async def _mark_follow_up_consumed(umo: str, seq: int) -> None: + state = _FOLLOW_UP_ORDER_STATE.get(umo) + if not state: + return + condition = state["condition"] + assert isinstance(condition, asyncio.Condition) + async with condition: + statuses = state["statuses"] + assert isinstance(statuses, dict) + if seq in statuses and statuses[seq] != "finished": + statuses[seq] = "consumed" + _advance_follow_up_turn_locked(state) + condition.notify_all() + + # Release state only when this UMO has no pending statuses and no active runner. + if not statuses and _ACTIVE_AGENT_RUNNERS.get(umo) is None: + _FOLLOW_UP_ORDER_STATE.pop(umo, None) + + +async def _activate_and_wait_follow_up_turn(umo: str, seq: int) -> None: + state = _FOLLOW_UP_ORDER_STATE.get(umo) + if not state: + return + condition = state["condition"] + assert isinstance(condition, asyncio.Condition) + async with condition: + statuses = state["statuses"] + assert isinstance(statuses, dict) + if seq in statuses: + statuses[seq] = "active" + + # Strict ordering: only the head (`next_turn`) can continue. + while True: + next_turn = state["next_turn"] + assert isinstance(next_turn, int) + if next_turn == seq: + break + await condition.wait() + + +async def _finish_follow_up_turn(umo: str, seq: int) -> None: + state = _FOLLOW_UP_ORDER_STATE.get(umo) + if not state: + return + condition = state["condition"] + assert isinstance(condition, asyncio.Condition) + async with condition: + statuses = state["statuses"] + assert isinstance(statuses, dict) + if seq in statuses: + statuses[seq] = "finished" + _advance_follow_up_turn_locked(state) + condition.notify_all() + + if not statuses and _ACTIVE_AGENT_RUNNERS.get(umo) is None: + _FOLLOW_UP_ORDER_STATE.pop(umo, None) + + +async def _monitor_follow_up_ticket( + umo: str, + ticket: FollowUpTicket, + order_seq: int, +) -> None: + """Advance consumed slots immediately on resolution to avoid wake-order drift.""" + await ticket.resolved.wait() + if ticket.consumed: + await _mark_follow_up_consumed(umo, order_seq) + + +def try_capture_follow_up(event: AstrMessageEvent) -> FollowUpCapture | None: + sender_id = event.get_sender_id() + if not sender_id: + return None + runner = _ACTIVE_AGENT_RUNNERS.get(event.unified_msg_origin) + if not runner: + return None + runner_event = getattr(getattr(runner.run_context, "context", None), "event", None) + if runner_event is None: + return None + active_sender_id = runner_event.get_sender_id() + if not active_sender_id or active_sender_id != sender_id: + return None + + ticket = runner.follow_up(message_text=_event_follow_up_text(event)) + if not ticket: + return None + # Allocate strict order at capture time (arrival order), not at wake time. + order_seq = _allocate_follow_up_order(event.unified_msg_origin) + monitor_task = asyncio.create_task( + _monitor_follow_up_ticket( + event.unified_msg_origin, + ticket, + order_seq, + ) + ) + logger.info( + "Captured follow-up message for active agent run, umo=%s, order_seq=%s", + event.unified_msg_origin, + order_seq, + ) + return FollowUpCapture( + umo=event.unified_msg_origin, + ticket=ticket, + order_seq=order_seq, + monitor_task=monitor_task, + ) + + +async def prepare_follow_up_capture(capture: FollowUpCapture) -> tuple[bool, bool]: + """Return `(consumed_marked, activated)` for internal stage branch handling.""" + await capture.ticket.resolved.wait() + if capture.ticket.consumed: + await _mark_follow_up_consumed(capture.umo, capture.order_seq) + return True, False + await _activate_and_wait_follow_up_turn(capture.umo, capture.order_seq) + return False, True + + +async def finalize_follow_up_capture( + capture: FollowUpCapture, + *, + activated: bool, + consumed_marked: bool, +) -> None: + # Best-effort cancellation: monitor task is auxiliary and should not leak. + if not capture.monitor_task.done(): + capture.monitor_task.cancel() + try: + await capture.monitor_task + except asyncio.CancelledError: + pass + + if activated: + await _finish_follow_up_turn(capture.umo, capture.order_seq) + elif not consumed_marked: + await _mark_follow_up_consumed(capture.umo, capture.order_seq) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 98cf77fcc9..5f86a05967 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -29,8 +29,16 @@ from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager -from .....astr_agent_run_util import run_agent, run_live_agent +from .....astr_agent_run_util import AgentRunner, run_agent, run_live_agent from ....context import PipelineContext, call_event_hook +from ....follow_up import ( + FollowUpCapture, + finalize_follow_up_capture, + prepare_follow_up_capture, + register_active_runner, + try_capture_follow_up, + unregister_active_runner, +) class InternalAgentSubStage(Stage): @@ -130,6 +138,9 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent, provider_wake_prefix: str ) -> AsyncGenerator[None, None]: + follow_up_capture: FollowUpCapture | None = None + follow_up_consumed_marked = False + follow_up_activated = False try: streaming_response = self.streaming_response if (enable_streaming := event.get_extra("enable_streaming")) is not None: @@ -150,188 +161,208 @@ async def process( return logger.debug("ready to request llm provider") + follow_up_capture = try_capture_follow_up(event) + if follow_up_capture: + ( + follow_up_consumed_marked, + follow_up_activated, + ) = await prepare_follow_up_capture(follow_up_capture) + if follow_up_consumed_marked: + logger.info( + "Follow-up ticket already consumed, stopping processing. umo=%s, seq=%s", + event.unified_msg_origin, + follow_up_capture.ticket.seq, + ) + return await event.send_typing() await call_event_hook(event, EventType.OnWaitingLLMRequestEvent) async with session_lock_manager.acquire_lock(event.unified_msg_origin): logger.debug("acquired session lock for llm request") + agent_runner: AgentRunner | None = None + runner_registered = False + try: + build_cfg = replace( + self.main_agent_cfg, + provider_wake_prefix=provider_wake_prefix, + streaming_response=streaming_response, + ) - build_cfg = replace( - self.main_agent_cfg, - provider_wake_prefix=provider_wake_prefix, - streaming_response=streaming_response, - ) + build_result: MainAgentBuildResult | None = await build_main_agent( + event=event, + plugin_context=self.ctx.plugin_manager.context, + config=build_cfg, + apply_reset=False, + ) - build_result: MainAgentBuildResult | None = await build_main_agent( - event=event, - plugin_context=self.ctx.plugin_manager.context, - config=build_cfg, - apply_reset=False, - ) + if build_result is None: + return - if build_result is None: - return + agent_runner = build_result.agent_runner + req = build_result.provider_request + provider = build_result.provider + reset_coro = build_result.reset_coro + + api_base = provider.provider_config.get("api_base", "") + for host in decoded_blocked: + if host in api_base: + logger.error( + "Provider API base %s is blocked due to security reasons. Please use another ai provider.", + api_base, + ) + return - agent_runner = build_result.agent_runner - req = build_result.provider_request - provider = build_result.provider - reset_coro = build_result.reset_coro - - api_base = provider.provider_config.get("api_base", "") - for host in decoded_blocked: - if host in api_base: - logger.error( - "Provider API base %s is blocked due to security reasons. Please use another ai provider.", - api_base, - ) - return + stream_to_general = ( + self.unsupported_streaming_strategy == "turn_off" + and not event.platform_meta.support_streaming_message + ) - stream_to_general = ( - self.unsupported_streaming_strategy == "turn_off" - and not event.platform_meta.support_streaming_message - ) + if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + if reset_coro: + reset_coro.close() + return - if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + # apply reset if reset_coro: - reset_coro.close() - return - - # apply reset - if reset_coro: - await reset_coro - - action_type = event.get_extra("action_type") - - event.trace.record( - "astr_agent_prepare", - system_prompt=req.system_prompt, - tools=req.func_tool.names() if req.func_tool else [], - stream=streaming_response, - chat_provider={ - "id": provider.provider_config.get("id", ""), - "model": provider.get_model(), - }, - ) + await reset_coro + + register_active_runner(event.unified_msg_origin, agent_runner) + runner_registered = True + action_type = event.get_extra("action_type") + + event.trace.record( + "astr_agent_prepare", + system_prompt=req.system_prompt, + tools=req.func_tool.names() if req.func_tool else [], + stream=streaming_response, + chat_provider={ + "id": provider.provider_config.get("id", ""), + "model": provider.get_model(), + }, + ) - # 检测 Live Mode - if action_type == "live": - # Live Mode: 使用 run_live_agent - logger.info("[Internal Agent] 检测到 Live Mode,启用 TTS 处理") + # 检测 Live Mode + if action_type == "live": + # Live Mode: 使用 run_live_agent + logger.info("[Internal Agent] 检测到 Live Mode,启用 TTS 处理") - # 获取 TTS Provider - tts_provider = ( - self.ctx.plugin_manager.context.get_using_tts_provider( - event.unified_msg_origin + # 获取 TTS Provider + tts_provider = ( + self.ctx.plugin_manager.context.get_using_tts_provider( + event.unified_msg_origin + ) ) - ) - if not tts_provider: - logger.warning( - "[Live Mode] TTS Provider 未配置,将使用普通流式模式" + if not tts_provider: + logger.warning( + "[Live Mode] TTS Provider 未配置,将使用普通流式模式" + ) + + # 使用 run_live_agent,总是使用流式响应 + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream( + run_live_agent( + agent_runner, + tts_provider, + self.max_step, + self.show_tool_use, + self.show_tool_call_result, + show_reasoning=self.show_reasoning, + ), + ), ) + yield - # 使用 run_live_agent,总是使用流式响应 - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_live_agent( - agent_runner, - tts_provider, - self.max_step, - self.show_tool_use, - self.show_tool_call_result, - show_reasoning=self.show_reasoning, + # 保存历史记录 + if agent_runner.done() and ( + not event.is_stopped() or agent_runner.was_aborted() + ): + await self._save_to_history( + event, + req, + agent_runner.get_final_llm_resp(), + agent_runner.run_context.messages, + agent_runner.stats, + user_aborted=agent_runner.was_aborted(), + ) + + elif streaming_response and not stream_to_general: + # 流式响应 + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream( + run_agent( + agent_runner, + self.max_step, + self.show_tool_use, + self.show_tool_call_result, + show_reasoning=self.show_reasoning, + ), ), - ), + ) + yield + if agent_runner.done(): + if final_llm_resp := agent_runner.get_final_llm_resp(): + if final_llm_resp.completion_text: + chain = ( + MessageChain() + .message(final_llm_resp.completion_text) + .chain + ) + elif final_llm_resp.result_chain: + chain = final_llm_resp.result_chain.chain + else: + chain = MessageChain().chain + event.set_result( + MessageEventResult( + chain=chain, + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + else: + async for _ in run_agent( + agent_runner, + self.max_step, + self.show_tool_use, + self.show_tool_call_result, + stream_to_general, + show_reasoning=self.show_reasoning, + ): + yield + + final_resp = agent_runner.get_final_llm_resp() + + event.trace.record( + "astr_agent_complete", + stats=agent_runner.stats.to_dict(), + resp=final_resp.completion_text if final_resp else None, ) - yield - # 保存历史记录 - if agent_runner.done() and ( - not event.is_stopped() or agent_runner.was_aborted() - ): + # 检查事件是否被停止,如果被停止则不保存历史记录 + if not event.is_stopped() or agent_runner.was_aborted(): await self._save_to_history( event, req, - agent_runner.get_final_llm_resp(), + final_resp, agent_runner.run_context.messages, agent_runner.stats, user_aborted=agent_runner.was_aborted(), ) - elif streaming_response and not stream_to_general: - # 流式响应 - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_agent( - agent_runner, - self.max_step, - self.show_tool_use, - self.show_tool_call_result, - show_reasoning=self.show_reasoning, - ), + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=agent_runner.provider.get_model(), + provider_type=agent_runner.provider.meta().type, ), ) - yield - if agent_runner.done(): - if final_llm_resp := agent_runner.get_final_llm_resp(): - if final_llm_resp.completion_text: - chain = ( - MessageChain() - .message(final_llm_resp.completion_text) - .chain - ) - elif final_llm_resp.result_chain: - chain = final_llm_resp.result_chain.chain - else: - chain = MessageChain().chain - event.set_result( - MessageEventResult( - chain=chain, - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) - else: - async for _ in run_agent( - agent_runner, - self.max_step, - self.show_tool_use, - self.show_tool_call_result, - stream_to_general, - show_reasoning=self.show_reasoning, - ): - yield - - final_resp = agent_runner.get_final_llm_resp() - - event.trace.record( - "astr_agent_complete", - stats=agent_runner.stats.to_dict(), - resp=final_resp.completion_text if final_resp else None, - ) - - # 检查事件是否被停止,如果被停止则不保存历史记录 - if not event.is_stopped() or agent_runner.was_aborted(): - await self._save_to_history( - event, - req, - final_resp, - agent_runner.run_context.messages, - agent_runner.stats, - user_aborted=agent_runner.was_aborted(), - ) - - asyncio.create_task( - Metric.upload( - llm_tick=1, - model_name=agent_runner.provider.get_model(), - provider_type=agent_runner.provider.meta().type, - ), - ) + finally: + if runner_registered and agent_runner is not None: + unregister_active_runner(event.unified_msg_origin, agent_runner) except Exception as e: logger.error(f"Error occurred while processing agent: {e}") @@ -340,6 +371,13 @@ async def process( f"Error occurred while processing agent request: {e}" ) ) + finally: + if follow_up_capture: + await finalize_follow_up_capture( + follow_up_capture, + activated=follow_up_activated, + consumed_marked=follow_up_consumed_marked, + ) async def _save_to_history( self, diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 0b5190407d..c8925416b6 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -149,6 +149,20 @@ async def on_agent_done(self, run_context, llm_response): self.agent_done_called = True +class MockEvent: + def __init__(self, umo: str, sender_id: str): + self.unified_msg_origin = umo + self._sender_id = sender_id + + def get_sender_id(self): + return self._sender_id + + +class MockAgentContext: + def __init__(self, event): + self.event = event + + @pytest.fixture def mock_provider(): return MockProvider() @@ -451,6 +465,76 @@ async def test_stop_signal_returns_aborted_and_persists_partial_message( assert runner.run_context.messages[-1].role == "assistant" +@pytest.mark.asyncio +async def test_tool_result_injects_follow_up_notice( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + mock_event = MockEvent("test:FriendMessage:follow_up", "u1") + run_context = ContextWrapper(context=MockAgentContext(mock_event)) + + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=run_context, + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + ticket1 = runner.follow_up( + message_text="follow up 1", + ) + ticket2 = runner.follow_up( + message_text="follow up 2", + ) + assert ticket1 is not None + assert ticket2 is not None + + async for _ in runner.step(): + pass + + assert provider_request.tool_calls_result is not None + assert isinstance(provider_request.tool_calls_result, list) + assert provider_request.tool_calls_result + tool_result = str( + provider_request.tool_calls_result[0].tool_calls_result[0].content + ) + assert "SYSTEM NOTICE" in tool_result + assert "1. follow up 1" in tool_result + assert "2. follow up 2" in tool_result + assert ticket1.resolved.is_set() is True + assert ticket2.resolved.is_set() is True + assert ticket1.consumed is True + assert ticket2.consumed is True + + +@pytest.mark.asyncio +async def test_follow_up_ticket_not_consumed_when_no_next_tool_call( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + mock_provider.should_call_tools = False + mock_event = MockEvent("test:FriendMessage:follow_up_no_tool", "u1") + run_context = ContextWrapper(context=MockAgentContext(mock_event)) + + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=run_context, + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + ticket = runner.follow_up(message_text="follow up without tool") + assert ticket is not None + + async for _ in runner.step(): + pass + + assert ticket.resolved.is_set() is True + assert ticket.consumed is False + + if __name__ == "__main__": # 运行测试 pytest.main([__file__, "-v"]) From 495a54e9e6776cc5e9b850bb151b3e927f868d59 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 26 Feb 2026 21:36:38 +0800 Subject: [PATCH 2/2] fix: correct import path for follow-up module in InternalAgentSubStage --- .../pipeline/process_stage/method/agent_sub_stages/internal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 5f86a05967..d95f7f86cc 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -31,7 +31,7 @@ from .....astr_agent_run_util import AgentRunner, run_agent, run_live_agent from ....context import PipelineContext, call_event_hook -from ....follow_up import ( +from ...follow_up import ( FollowUpCapture, finalize_follow_up_capture, prepare_follow_up_capture,