Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 110 additions & 3 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import traceback
import typing as T
import uuid
from collections.abc import Sequence
from collections.abc import Set as AbstractSet

import mcp

Expand All @@ -26,6 +28,7 @@
SEND_MESSAGE_TO_USER_TOOL,
)
from astrbot.core.cron.events import CronMessageEvent
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
CommandResult,
MessageChain,
Expand All @@ -34,10 +37,86 @@
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider.register import llm_tools
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.history_saver import persist_agent_history
from astrbot.core.utils.image_ref_utils import is_supported_image_ref
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings


class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@classmethod
def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]:
if image_urls_raw is None:
return []

if isinstance(image_urls_raw, str):
return [image_urls_raw]

if isinstance(image_urls_raw, (Sequence, AbstractSet)) and not isinstance(
image_urls_raw, (str, bytes, bytearray)
):
return [item for item in image_urls_raw if isinstance(item, str)]

logger.debug(
"Unsupported image_urls type in handoff tool args: %s",
type(image_urls_raw).__name__,
)
return []

@classmethod
async def _collect_image_urls_from_message(
cls, run_context: ContextWrapper[AstrAgentContext]
) -> list[str]:
urls: list[str] = []
event = getattr(run_context.context, "event", None)
message_obj = getattr(event, "message_obj", None)
message = getattr(message_obj, "message", None)
if message:
for idx, component in enumerate(message):
if not isinstance(component, Image):
continue
try:
path = await component.convert_to_file_path()
if path:
urls.append(path)
except Exception as e:
logger.error(
"Failed to convert handoff image component at index %d: %s",
idx,
e,
exc_info=True,
)
return urls

@classmethod
async def _collect_handoff_image_urls(
cls,
run_context: ContextWrapper[AstrAgentContext],
image_urls_raw: T.Any,
) -> list[str]:
candidates: list[str] = []
candidates.extend(cls._collect_image_urls_from_args(image_urls_raw))
candidates.extend(await cls._collect_image_urls_from_message(run_context))

normalized = normalize_and_dedupe_strings(candidates)
extensionless_local_roots = (get_astrbot_temp_path(),)
sanitized = [
item
for item in normalized
if is_supported_image_ref(
item,
allow_extensionless_existing_local_file=True,
extensionless_local_roots=extensionless_local_roots,
)
]
dropped_count = len(normalized) - len(sanitized)
if dropped_count > 0:
logger.debug(
"Dropped %d invalid image_urls entries in handoff image inputs.",
dropped_count,
)
return sanitized

@classmethod
async def execute(cls, tool, run_context, **tool_args):
"""执行函数调用。
Expand Down Expand Up @@ -161,10 +240,28 @@ async def _execute_handoff(
cls,
tool: HandoffTool,
run_context: ContextWrapper[AstrAgentContext],
**tool_args,
*,
image_urls_prepared: bool = False,
**tool_args: T.Any,
):
tool_args = dict(tool_args)
input_ = tool_args.get("input")
image_urls = tool_args.get("image_urls")
if image_urls_prepared:
prepared_image_urls = tool_args.get("image_urls")
if isinstance(prepared_image_urls, list):
image_urls = prepared_image_urls
else:
logger.debug(
"Expected prepared handoff image_urls as list[str], got %s.",
type(prepared_image_urls).__name__,
)
image_urls = []
else:
image_urls = await cls._collect_handoff_image_urls(
run_context,
tool_args.get("image_urls"),
)
tool_args["image_urls"] = image_urls

# Build handoff toolset from registered tools plus runtime computer tools.
toolset = cls._build_handoff_toolset(run_context, tool.agent.tools)
Expand Down Expand Up @@ -263,8 +360,18 @@ async def _do_handoff_background(
) -> None:
"""Run the subagent handoff and, on completion, wake the main agent."""
result_text = ""
tool_args = dict(tool_args)
tool_args["image_urls"] = await cls._collect_handoff_image_urls(
run_context,
tool_args.get("image_urls"),
)
try:
async for r in cls._execute_handoff(tool, run_context, **tool_args):
async for r in cls._execute_handoff(
tool,
run_context,
image_urls_prepared=True,
**tool_args,
):
if isinstance(r, mcp.types.CallToolResult):
for content in r.content:
if isinstance(content, mcp.types.TextContent):
Expand Down
86 changes: 86 additions & 0 deletions astrbot/core/utils/image_ref_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

import os
from collections.abc import Sequence
from pathlib import Path
from urllib.parse import unquote, urlparse

ALLOWED_IMAGE_EXTENSIONS = {
".png",
".jpg",
".jpeg",
".gif",
".webp",
".bmp",
".tif",
".tiff",
".svg",
".heic",
}


def resolve_file_url_path(image_ref: str) -> str:
parsed = urlparse(image_ref)
if parsed.scheme != "file":
return image_ref

path = unquote(parsed.path or "")
netloc = unquote(parsed.netloc or "")

# Keep support for file://<host>/path and file://<path> forms.
if netloc and netloc.lower() != "localhost":
path = f"//{netloc}{path}" if path else netloc
elif not path and netloc:
path = netloc

if os.name == "nt" and len(path) > 2 and path[0] == "/" and path[2] == ":":
path = path[1:]

return path or image_ref


def _is_path_within_roots(path: str, roots: Sequence[str]) -> bool:
try:
candidate = Path(path).resolve(strict=False)
except Exception:
return False

for root in roots:
try:
root_path = Path(root).resolve(strict=False)
candidate.relative_to(root_path)
return True
except Exception:
continue
return False


def is_supported_image_ref(
image_ref: str,
*,
allow_extensionless_existing_local_file: bool = False,
extensionless_local_roots: Sequence[str] | None = None,
) -> bool:
if not image_ref:
return False

lowered = image_ref.lower()
if lowered.startswith(("http://", "https://", "base64://")):
return True

file_path = (
resolve_file_url_path(image_ref) if lowered.startswith("file://") else image_ref
)
ext = os.path.splitext(file_path)[1].lower()
if ext in ALLOWED_IMAGE_EXTENSIONS:
return True
if not allow_extensionless_existing_local_file:
return False
if not extensionless_local_roots:
return False
# Keep support for extension-less temp files returned by image converters.
return (
ext == ""
and os.path.exists(file_path)
and _is_path_within_roots(file_path, extensionless_local_roots)
)
13 changes: 3 additions & 10 deletions astrbot/core/utils/quoted_message/image_refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,9 @@
import os
from urllib.parse import urlsplit

IMAGE_EXTENSIONS = {
".jpg",
".jpeg",
".png",
".webp",
".bmp",
".tif",
".tiff",
".gif",
}
from astrbot.core.utils.image_ref_utils import ALLOWED_IMAGE_EXTENSIONS

IMAGE_EXTENSIONS = ALLOWED_IMAGE_EXTENSIONS


def normalize_file_like_url(path: str | None) -> str | None:
Expand Down
Loading