diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 46ec4346b3..fd6db5a98e 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -4,6 +4,8 @@ import traceback import typing as T import uuid +from collections.abc import Sequence +from collections.abc import Set as AbstractSet import mcp @@ -26,6 +28,7 @@ SEND_MESSAGE_TO_USER_TOOL, ) from astrbot.core.cron.events import CronMessageEvent +from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( CommandResult, MessageChain, @@ -34,10 +37,86 @@ from astrbot.core.platform.message_session import MessageSession from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.history_saver import persist_agent_history +from astrbot.core.utils.image_ref_utils import is_supported_image_ref +from astrbot.core.utils.string_utils import normalize_and_dedupe_strings class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): + @classmethod + def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]: + if image_urls_raw is None: + return [] + + if isinstance(image_urls_raw, str): + return [image_urls_raw] + + if isinstance(image_urls_raw, (Sequence, AbstractSet)) and not isinstance( + image_urls_raw, (str, bytes, bytearray) + ): + return [item for item in image_urls_raw if isinstance(item, str)] + + logger.debug( + "Unsupported image_urls type in handoff tool args: %s", + type(image_urls_raw).__name__, + ) + return [] + + @classmethod + async def _collect_image_urls_from_message( + cls, run_context: ContextWrapper[AstrAgentContext] + ) -> list[str]: + urls: list[str] = [] + event = getattr(run_context.context, "event", None) + message_obj = getattr(event, "message_obj", None) + message = getattr(message_obj, "message", None) + if message: + for idx, component in enumerate(message): + if not isinstance(component, Image): + continue + try: + path = await component.convert_to_file_path() + if path: + urls.append(path) + except Exception as e: + logger.error( + "Failed to convert handoff image component at index %d: %s", + idx, + e, + exc_info=True, + ) + return urls + + @classmethod + async def _collect_handoff_image_urls( + cls, + run_context: ContextWrapper[AstrAgentContext], + image_urls_raw: T.Any, + ) -> list[str]: + candidates: list[str] = [] + candidates.extend(cls._collect_image_urls_from_args(image_urls_raw)) + candidates.extend(await cls._collect_image_urls_from_message(run_context)) + + normalized = normalize_and_dedupe_strings(candidates) + extensionless_local_roots = (get_astrbot_temp_path(),) + sanitized = [ + item + for item in normalized + if is_supported_image_ref( + item, + allow_extensionless_existing_local_file=True, + extensionless_local_roots=extensionless_local_roots, + ) + ] + dropped_count = len(normalized) - len(sanitized) + if dropped_count > 0: + logger.debug( + "Dropped %d invalid image_urls entries in handoff image inputs.", + dropped_count, + ) + return sanitized + @classmethod async def execute(cls, tool, run_context, **tool_args): """执行函数调用。 @@ -161,10 +240,28 @@ async def _execute_handoff( cls, tool: HandoffTool, run_context: ContextWrapper[AstrAgentContext], - **tool_args, + *, + image_urls_prepared: bool = False, + **tool_args: T.Any, ): + tool_args = dict(tool_args) input_ = tool_args.get("input") - image_urls = tool_args.get("image_urls") + if image_urls_prepared: + prepared_image_urls = tool_args.get("image_urls") + if isinstance(prepared_image_urls, list): + image_urls = prepared_image_urls + else: + logger.debug( + "Expected prepared handoff image_urls as list[str], got %s.", + type(prepared_image_urls).__name__, + ) + image_urls = [] + else: + image_urls = await cls._collect_handoff_image_urls( + run_context, + tool_args.get("image_urls"), + ) + tool_args["image_urls"] = image_urls # Build handoff toolset from registered tools plus runtime computer tools. toolset = cls._build_handoff_toolset(run_context, tool.agent.tools) @@ -263,8 +360,18 @@ async def _do_handoff_background( ) -> None: """Run the subagent handoff and, on completion, wake the main agent.""" result_text = "" + tool_args = dict(tool_args) + tool_args["image_urls"] = await cls._collect_handoff_image_urls( + run_context, + tool_args.get("image_urls"), + ) try: - async for r in cls._execute_handoff(tool, run_context, **tool_args): + async for r in cls._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + **tool_args, + ): if isinstance(r, mcp.types.CallToolResult): for content in r.content: if isinstance(content, mcp.types.TextContent): diff --git a/astrbot/core/utils/image_ref_utils.py b/astrbot/core/utils/image_ref_utils.py new file mode 100644 index 0000000000..204e576312 --- /dev/null +++ b/astrbot/core/utils/image_ref_utils.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import os +from collections.abc import Sequence +from pathlib import Path +from urllib.parse import unquote, urlparse + +ALLOWED_IMAGE_EXTENSIONS = { + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ".bmp", + ".tif", + ".tiff", + ".svg", + ".heic", +} + + +def resolve_file_url_path(image_ref: str) -> str: + parsed = urlparse(image_ref) + if parsed.scheme != "file": + return image_ref + + path = unquote(parsed.path or "") + netloc = unquote(parsed.netloc or "") + + # Keep support for file:///path and file:// forms. + if netloc and netloc.lower() != "localhost": + path = f"//{netloc}{path}" if path else netloc + elif not path and netloc: + path = netloc + + if os.name == "nt" and len(path) > 2 and path[0] == "/" and path[2] == ":": + path = path[1:] + + return path or image_ref + + +def _is_path_within_roots(path: str, roots: Sequence[str]) -> bool: + try: + candidate = Path(path).resolve(strict=False) + except Exception: + return False + + for root in roots: + try: + root_path = Path(root).resolve(strict=False) + candidate.relative_to(root_path) + return True + except Exception: + continue + return False + + +def is_supported_image_ref( + image_ref: str, + *, + allow_extensionless_existing_local_file: bool = False, + extensionless_local_roots: Sequence[str] | None = None, +) -> bool: + if not image_ref: + return False + + lowered = image_ref.lower() + if lowered.startswith(("http://", "https://", "base64://")): + return True + + file_path = ( + resolve_file_url_path(image_ref) if lowered.startswith("file://") else image_ref + ) + ext = os.path.splitext(file_path)[1].lower() + if ext in ALLOWED_IMAGE_EXTENSIONS: + return True + if not allow_extensionless_existing_local_file: + return False + if not extensionless_local_roots: + return False + # Keep support for extension-less temp files returned by image converters. + return ( + ext == "" + and os.path.exists(file_path) + and _is_path_within_roots(file_path, extensionless_local_roots) + ) diff --git a/astrbot/core/utils/quoted_message/image_refs.py b/astrbot/core/utils/quoted_message/image_refs.py index 009d6844a2..a1ea815516 100644 --- a/astrbot/core/utils/quoted_message/image_refs.py +++ b/astrbot/core/utils/quoted_message/image_refs.py @@ -3,16 +3,9 @@ import os from urllib.parse import urlsplit -IMAGE_EXTENSIONS = { - ".jpg", - ".jpeg", - ".png", - ".webp", - ".bmp", - ".tif", - ".tiff", - ".gif", -} +from astrbot.core.utils.image_ref_utils import ALLOWED_IMAGE_EXTENSIONS + +IMAGE_EXTENSIONS = ALLOWED_IMAGE_EXTENSIONS def normalize_file_like_url(path: str | None) -> str | None: diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py new file mode 100644 index 0000000000..9d405f1ab5 --- /dev/null +++ b/tests/unit/test_astr_agent_tool_exec.py @@ -0,0 +1,296 @@ +from types import SimpleNamespace + +import mcp +import pytest + +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor +from astrbot.core.message.components import Image + + +class _DummyEvent: + def __init__(self, message_components: list[object] | None = None) -> None: + self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session" + self.message_obj = SimpleNamespace(message=message_components or []) + + def get_extra(self, _key: str): + return None + + +class _DummyTool: + def __init__(self) -> None: + self.name = "transfer_to_subagent" + self.agent = SimpleNamespace(name="subagent") + + +def _build_run_context(message_components: list[object] | None = None): + event = _DummyEvent(message_components=message_components) + ctx = SimpleNamespace(event=event, context=SimpleNamespace()) + return ContextWrapper(context=ctx) + + +@pytest.mark.asyncio +async def test_collect_handoff_image_urls_normalizes_filters_and_appends_event_image( + monkeypatch: pytest.MonkeyPatch, +): + async def _fake_convert_to_file_path(self): + return "/tmp/event_image.png" + + monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) + + run_context = _build_run_context([Image(file="file:///tmp/original.png")]) + image_urls_input = ( + " https://example.com/a.png ", + "/tmp/not_an_image.txt", + "/tmp/local.webp", + 123, + ) + + image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + run_context, + image_urls_input, + ) + + assert image_urls == [ + "https://example.com/a.png", + "/tmp/local.webp", + "/tmp/event_image.png", + ] + + +@pytest.mark.asyncio +async def test_collect_handoff_image_urls_skips_failed_event_image_conversion( + monkeypatch: pytest.MonkeyPatch, +): + async def _fake_convert_to_file_path(self): + raise RuntimeError("boom") + + monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) + + run_context = _build_run_context([Image(file="file:///tmp/original.png")]) + image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + run_context, + ["https://example.com/a.png"], + ) + + assert image_urls == ["https://example.com/a.png"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("image_refs", "expected_supported_refs"), + [ + pytest.param( + ( + "https://example.com/valid.png", + "base64://iVBORw0KGgoAAAANSUhEUgAAAAUA", + "file:///tmp/photo.heic", + "file://localhost/tmp/vector.svg", + "file://fileserver/share/image.webp", + "file:///tmp/not-image.txt", + "mailto:user@example.com", + "random-string-without-scheme-or-extension", + ), + { + "https://example.com/valid.png", + "base64://iVBORw0KGgoAAAANSUhEUgAAAAUA", + "file:///tmp/photo.heic", + "file://localhost/tmp/vector.svg", + "file://fileserver/share/image.webp", + }, + id="mixed_supported_and_unsupported_refs", + ), + ], +) +async def test_collect_handoff_image_urls_filters_supported_schemes_and_extensions( + image_refs: tuple[str, ...], + expected_supported_refs: set[str], +): + run_context = _build_run_context([]) + result = await FunctionToolExecutor._collect_handoff_image_urls( + run_context, image_refs + ) + assert set(result) == expected_supported_refs + + +@pytest.mark.asyncio +async def test_collect_handoff_image_urls_collects_event_image_when_args_is_none( + monkeypatch: pytest.MonkeyPatch, +): + async def _fake_convert_to_file_path(self): + return "/tmp/event_only.png" + + monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) + + run_context = _build_run_context([Image(file="file:///tmp/original.png")]) + image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + run_context, + None, + ) + + assert image_urls == ["/tmp/event_only.png"] + + +@pytest.mark.asyncio +async def test_do_handoff_background_reports_prepared_image_urls( + monkeypatch: pytest.MonkeyPatch, +): + captured: dict = {} + + async def _fake_execute_handoff( + cls, tool, run_context, image_urls_prepared=False, **tool_args + ): + assert image_urls_prepared is True + yield mcp.types.CallToolResult( + content=[mcp.types.TextContent(type="text", text="ok")] + ) + + async def _fake_wake(cls, run_context, **kwargs): + captured.update(kwargs) + + monkeypatch.setattr( + FunctionToolExecutor, + "_execute_handoff", + classmethod(_fake_execute_handoff), + ) + monkeypatch.setattr( + FunctionToolExecutor, + "_wake_main_agent_for_background_result", + classmethod(_fake_wake), + ) + + run_context = _build_run_context() + await FunctionToolExecutor._do_handoff_background( + tool=_DummyTool(), + run_context=run_context, + task_id="task-id", + input="hello", + image_urls="https://example.com/raw.png", + ) + + assert captured["tool_args"]["image_urls"] == ["https://example.com/raw.png"] + + +@pytest.mark.asyncio +async def test_execute_handoff_skips_renormalize_when_image_urls_prepared( + monkeypatch: pytest.MonkeyPatch, +): + captured: dict = {} + + def _boom(_items): + raise RuntimeError("normalize should not be called") + + async def _fake_get_current_chat_provider_id(_umo): + return "provider-id" + + async def _fake_tool_loop_agent(**kwargs): + captured.update(kwargs) + return SimpleNamespace(completion_text="ok") + + context = SimpleNamespace( + get_current_chat_provider_id=_fake_get_current_chat_provider_id, + tool_loop_agent=_fake_tool_loop_agent, + get_config=lambda **_kwargs: {"provider_settings": {}}, + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + tool = SimpleNamespace( + name="transfer_to_subagent", + provider_id=None, + agent=SimpleNamespace( + name="subagent", + tools=[], + instructions="subagent-instructions", + begin_dialogs=[], + run_hooks=None, + ), + ) + + monkeypatch.setattr( + "astrbot.core.astr_agent_tool_exec.normalize_and_dedupe_strings", _boom + ) + + results = [] + async for result in FunctionToolExecutor._execute_handoff( + tool, + run_context, + image_urls_prepared=True, + input="hello", + image_urls=["https://example.com/raw.png"], + ): + results.append(result) + + assert len(results) == 1 + assert captured["image_urls"] == ["https://example.com/raw.png"] + + +@pytest.mark.asyncio +async def test_collect_handoff_image_urls_keeps_extensionless_existing_event_file( + monkeypatch: pytest.MonkeyPatch, +): + async def _fake_convert_to_file_path(self): + return "/tmp/astrbot-handoff-image" + + monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) + monkeypatch.setattr( + "astrbot.core.astr_agent_tool_exec.get_astrbot_temp_path", lambda: "/tmp" + ) + monkeypatch.setattr( + "astrbot.core.utils.image_ref_utils.os.path.exists", lambda _: True + ) + + run_context = _build_run_context([Image(file="file:///tmp/original.png")]) + image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + run_context, + [], + ) + + assert image_urls == ["/tmp/astrbot-handoff-image"] + + +@pytest.mark.asyncio +async def test_collect_handoff_image_urls_filters_extensionless_missing_event_file( + monkeypatch: pytest.MonkeyPatch, +): + async def _fake_convert_to_file_path(self): + return "/tmp/astrbot-handoff-missing-image" + + monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) + monkeypatch.setattr( + "astrbot.core.astr_agent_tool_exec.get_astrbot_temp_path", lambda: "/tmp" + ) + monkeypatch.setattr( + "astrbot.core.utils.image_ref_utils.os.path.exists", lambda _: False + ) + + run_context = _build_run_context([Image(file="file:///tmp/original.png")]) + image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + run_context, + [], + ) + + assert image_urls == [] + + +@pytest.mark.asyncio +async def test_collect_handoff_image_urls_filters_extensionless_file_outside_temp_root( + monkeypatch: pytest.MonkeyPatch, +): + async def _fake_convert_to_file_path(self): + return "/var/tmp/astrbot-handoff-image" + + monkeypatch.setattr(Image, "convert_to_file_path", _fake_convert_to_file_path) + monkeypatch.setattr( + "astrbot.core.astr_agent_tool_exec.get_astrbot_temp_path", lambda: "/tmp" + ) + monkeypatch.setattr( + "astrbot.core.utils.image_ref_utils.os.path.exists", lambda _: True + ) + + run_context = _build_run_context([Image(file="file:///tmp/original.png")]) + image_urls = await FunctionToolExecutor._collect_handoff_image_urls( + run_context, + [], + ) + + assert image_urls == []