Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 106 additions & 61 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
import copy
import sys
import time
import traceback
import typing as T
from dataclasses import dataclass
from dataclasses import dataclass, field

from mcp.types import (
BlobResourceContents,
Expand Down Expand Up @@ -68,6 +69,14 @@ def from_cached_image(cls, image: T.Any) -> "_HandleFunctionToolsResult":
return cls(kind="cached_image", cached_image=image)


@dataclass(slots=True)
class FollowUpTicket:
    # One user follow-up message queued while a tool call was in progress.
    seq: int  # runner-assigned ordering number (monotonically increasing)
    text: str  # follow-up message text, already stripped of whitespace
    consumed: bool = False  # True once the text was merged into a tool-result notice
    resolved: asyncio.Event = field(default_factory=asyncio.Event)  # set when the ticket is consumed or discarded


class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
@override
async def reset(
Expand Down Expand Up @@ -139,6 +148,8 @@ async def reset(
self.run_context = run_context
self._stop_requested = False
self._aborted = False
self._pending_follow_ups: list[FollowUpTicket] = []
self._follow_up_seq = 0
Comment on lines +151 to +152
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

这里存在一个竞态条件。_pending_follow_ups 列表会在没有同步机制的情况下被多个协程访问。follow_up 方法(从一个协程调用)会向列表追加内容,而 _consume_follow_up_notice 和 _resolve_unconsumed_follow_ups 等方法(从代理的 step 协程调用)会读取并清空它。这可能导致后续消息丢失。

为了修复这个问题,建议使用 asyncio.Lock 来保护对 self._pending_follow_ups 的所有访问。

  1. 在这里初始化一个锁: self._follow_up_lock = asyncio.Lock()
  2. follow_up, _resolve_unconsumed_follow_ups, 和 _consume_follow_up_notice 方法改为 async
  3. 在这些方法中,将修改 _pending_follow_ups 的临界区代码包裹在 async with self._follow_up_lock: 中。
  4. 相应地更新这些方法的调用方,使用 await。例如, _merge_follow_up_notice 也需要变成 async
Suggested change
self._pending_follow_ups: list[FollowUpTicket] = []
self._follow_up_seq = 0
self._pending_follow_ups: list[FollowUpTicket] = []
self._follow_up_seq = 0
self._follow_up_lock = asyncio.Lock()


# These two are used for tool schema mode handling
# We now have two modes:
Expand Down Expand Up @@ -277,6 +288,55 @@ def _simple_print_message_role(self, tag: str = ""):
roles.append(message.role)
logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}")

def follow_up(
    self,
    *,
    message_text: str,
) -> FollowUpTicket | None:
    """Queue a follow-up message for the next tool result.

    The text is stripped of surrounding whitespace; empty messages and
    messages arriving after the runner has finished are dropped.

    Returns:
        The queued FollowUpTicket, or None when nothing was queued.
    """
    stripped = (message_text or "").strip()
    if self.done() or not stripped:
        return None
    queued = FollowUpTicket(seq=self._follow_up_seq, text=stripped)
    self._follow_up_seq += 1
    self._pending_follow_ups.append(queued)
    return queued

def _resolve_unconsumed_follow_ups(self) -> None:
    """Release waiters on all queued follow-ups without consuming them.

    Used on abort/error paths: tickets are drained and their events set,
    but ``consumed`` stays False so callers can tell the text was dropped.
    """
    pending, self._pending_follow_ups = self._pending_follow_ups, []
    for ticket in pending:
        ticket.resolved.set()

def _consume_follow_up_notice(self) -> str:
    """Drain queued follow-ups and render them as a system notice.

    Each ticket is marked consumed and its event set, then all texts are
    joined into a numbered list appended to a fixed notice header.

    Returns:
        The notice string, or "" when no follow-ups were pending.
    """
    pending, self._pending_follow_ups = self._pending_follow_ups, []
    if not pending:
        return ""
    numbered: list[str] = []
    for position, ticket in enumerate(pending, start=1):
        ticket.consumed = True
        ticket.resolved.set()
        numbered.append(f"{position}. {ticket.text}")
    follow_up_lines = "\n".join(numbered)
    return (
        "\n\n[SYSTEM NOTICE] User sent follow-up messages while tool execution "
        "was in progress. Prioritize these follow-up instructions in your next "
        "actions. In your very next action, briefly acknowledge to the user "
        "that their follow-up message(s) were received before continuing.\n"
        f"{follow_up_lines}"
    )

def _merge_follow_up_notice(self, content: str) -> str:
    """Append any pending follow-up notice to *content*, consuming the queue."""
    notice = self._consume_follow_up_notice()
    return f"{content}{notice}" if notice else content

@override
async def step(self):
"""Process a single step of the agent.
Expand Down Expand Up @@ -391,6 +451,7 @@ async def step(self):
type="aborted",
data=AgentResponseData(chain=MessageChain(type="aborted")),
)
self._resolve_unconsumed_follow_ups()
return

# 处理 LLM 响应
Expand All @@ -401,6 +462,7 @@ async def step(self):
self.final_llm_resp = llm_resp
self.stats.end_time = time.time()
self._transition_state(AgentState.ERROR)
self._resolve_unconsumed_follow_ups()
yield AgentResponse(
type="err",
data=AgentResponseData(
Expand Down Expand Up @@ -439,6 +501,7 @@ async def step(self):
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
self._resolve_unconsumed_follow_ups()

# 返回 LLM 结果
if llm_resp.result_chain:
Expand Down Expand Up @@ -583,6 +646,15 @@ async def _handle_function_tools(
tool_call_result_blocks: list[ToolCallMessageSegment] = []
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")

def _append_tool_call_result(tool_call_id: str, content: str) -> None:
    # Append one tool-result segment, merging any pending user follow-up
    # notice into the content so the LLM sees it on the next turn.
    tool_call_result_blocks.append(
        ToolCallMessageSegment(
            role="tool",
            tool_call_id=tool_call_id,
            content=self._merge_follow_up_notice(content),
        ),
    )

# 执行函数调用
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
Expand Down Expand Up @@ -622,12 +694,9 @@ async def _handle_function_tools(

if not func_tool:
logger.warning(f"未找到指定的工具: {func_tool_name},将跳过。")
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: Tool {func_tool_name} not found.",
),
_append_tool_call_result(
func_tool_id,
f"error: Tool {func_tool_name} not found.",
)
continue

Expand Down Expand Up @@ -680,12 +749,9 @@ async def _handle_function_tools(
res = resp
_final_resp = resp
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
),
_append_tool_call_result(
func_tool_id,
res.content[0].text,
)
elif isinstance(res.content[0], ImageContent):
# Cache the image instead of sending directly
Expand All @@ -696,15 +762,12 @@ async def _handle_function_tools(
index=0,
mime_type=res.content[0].mimeType or "image/png",
)
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
),
_append_tool_call_result(
func_tool_id,
(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
),
)
# Yield image info for LLM visibility (will be handled in step())
Expand All @@ -714,12 +777,9 @@ async def _handle_function_tools(
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
),
_append_tool_call_result(
func_tool_id,
resource.text,
)
elif (
isinstance(resource, BlobResourceContents)
Expand All @@ -734,28 +794,22 @@ async def _handle_function_tools(
index=0,
mime_type=resource.mimeType,
)
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
),
_append_tool_call_result(
func_tool_id,
(
f"Image returned and cached at path='{cached_img.file_path}'. "
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
f"with type='image' and path='{cached_img.file_path}'."
),
)
# Yield image info for LLM visibility
yield _HandleFunctionToolsResult.from_cached_image(
cached_img
)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="The tool has returned a data type that is not supported.",
),
_append_tool_call_result(
func_tool_id,
"The tool has returned a data type that is not supported.",
)

elif resp is None:
Expand All @@ -767,24 +821,18 @@ async def _handle_function_tools(
)
self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="The tool has no return value, or has sent the result directly to the user.",
),
_append_tool_call_result(
func_tool_id,
"The tool has no return value, or has sent the result directly to the user.",
)
else:
# 不应该出现其他类型
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)}。",
)
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*",
),
_append_tool_call_result(
func_tool_id,
"*The tool has returned an unsupported type. Please tell the user to check the definition and implementation of this tool.*",
)

try:
Expand All @@ -798,12 +846,9 @@ async def _handle_function_tools(
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
except Exception as e:
logger.warning(traceback.format_exc())
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {e!s}",
),
_append_tool_call_result(
func_tool_id,
f"error: {e!s}",
)

# yield the last tool call result
Expand Down
Loading