From 6d7b62d9ebdb761e05106a9212fa7bd6f1f2d2b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A8=E3=82=A4=E3=82=AB=E3=82=AF?= <62183434+zouyonghe@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:51:00 +0900 Subject: [PATCH 1/4] =?UTF-8?q?fix:=20=E5=B7=A5=E7=A8=8B=E5=8C=96=E6=94=B6?= =?UTF-8?q?=E6=95=9B=E5=B9=B6=E7=A7=BB=E9=99=A4=20ASYNC230/ASYNC240=20?= =?UTF-8?q?=E5=BF=BD=E7=95=A5=20(#5729)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test(skills): align sandbox cache tests with readonly behavior * ci(release): enforce core quality gate before publish * ci: enforce locked dependency installs in workflows * security: remove curl-pipe-shell installs * chore: align project python baseline to 3.12 * ci(dashboard): add explicit typecheck gate * chore(pre-commit): align ruff hook version with project * ci(codeql): add javascript-typescript analysis * chore(ruff): defer py312 migration lint rules * fix: resolve ruff violations without new ignores * fix: resolve ASYNC230 and ASYNC240 without ignores * fix(auth): replace utcnow with timezone-aware UTC now * fix: avoid blocking file read in file_to_base64 --- .github/workflows/codeql.yml | 2 + .github/workflows/coverage_test.yml | 9 +- .github/workflows/dashboard_ci.yml | 17 +- .github/workflows/docker-image.yml | 44 ++- .github/workflows/release.yml | 29 +- .pre-commit-config.yaml | 4 +- Dockerfile | 5 +- README.md | 2 +- README_fr.md | 2 +- README_ja.md | 2 +- README_ru.md | 2 +- README_zh-TW.md | 2 +- README_zh.md | 2 +- astrbot/cli/utils/plugin.py | 4 +- astrbot/core/agent/agent.py | 5 +- astrbot/core/agent/handoff.py | 5 +- astrbot/core/agent/hooks.py | 6 +- astrbot/core/agent/mcp_client.py | 6 +- astrbot/core/agent/runners/base.py | 4 +- .../agent/runners/coze/coze_agent_runner.py | 9 +- .../agent/runners/coze/coze_api_client.py | 18 +- .../dashscope/dashscope_agent_runner.py | 7 +- .../runners/deerflow/deerflow_agent_runner.py | 15 +- .../runners/deerflow/deerflow_api_client.py | 12 +- .../agent/runners/dify/dify_agent_runner.py | 11 +- .../agent/runners/dify/dify_api_client.py | 29 +- .../agent/runners/tool_loop_agent_runner.py | 7 +- astrbot/core/agent/tool.py | 6 +- astrbot/core/agent/tool_executor.py | 6 +- astrbot/core/astr_agent_run_util.py | 4 +- astrbot/core/astr_agent_tool_exec.py | 2 +- astrbot/core/astr_main_agent_resources.py | 3 +- astrbot/core/backup/exporter.py | 20 +- astrbot/core/backup/importer.py | 42 +-- astrbot/core/computer/booters/bay_manager.py | 6 +- astrbot/core/computer/booters/boxlite.py | 6 +- astrbot/core/computer/booters/local.py | 8 +- astrbot/core/computer/booters/shipyard_neo.py | 28 +- astrbot/core/computer/computer_client.py | 5 +- astrbot/core/computer/olayer/browser.py | 6 +- astrbot/core/computer/olayer/python.py | 2 +- astrbot/core/computer/olayer/shell.py | 2 +- astrbot/core/computer/tools/browser.py | 24 +- astrbot/core/computer/tools/fs.py | 5 +- astrbot/core/cron/manager.py | 4 +- astrbot/core/db/migration/helper.py | 3 +- astrbot/core/db/migration/migra_3_to_4.py | 6 +- astrbot/core/db/po.py | 8 +- astrbot/core/db/sqlite.py | 12 +- astrbot/core/file_token_service.py | 18 +- astrbot/core/knowledge_base/models.py | 16 +- astrbot/core/message/components.py | 59 ++-- .../sources/dingtalk/dingtalk_adapter.py | 4 +- .../core/platform/sources/discord/client.py | 7 +- .../discord/discord_platform_adapter.py | 8 +- .../platform/sources/kook/kook_adapter.py | 2 +- .../core/platform/sources/kook/kook_client.py | 8 +- .../platform/sources/lark/lark_adapter.py | 2 +- .../core/platform/sources/lark/lark_event.py | 77 ++--- .../core/platform/sources/line/line_event.py | 6 +- .../sources/misskey/misskey_adapter.py | 3 +- .../platform/sources/misskey/misskey_api.py | 22 +- .../qqofficial/qqofficial_message_event.py | 8 +- .../platform/sources/telegram/tg_adapter.py | 8 +- .../sources/webchat/message_parts_helper.py | 13 +- .../platform/sources/webchat/webchat_event.py | 12 +- .../platform/sources/wecom/wecom_adapter.py | 18 +- .../platform/sources/wecom/wecom_event.py | 265 ++++++++++-------- .../sources/wecom_ai_bot/wecomai_utils.py | 3 +- .../sources/wecom_ai_bot/wecomai_webhook.py | 7 +- .../weixin_offacc_adapter.py | 14 +- .../weixin_offacc_event.py | 99 ++++--- astrbot/core/provider/entities.py | 9 +- astrbot/core/provider/func_tool_manager.py | 74 +++-- astrbot/core/provider/provider.py | 25 +- .../core/provider/sources/anthropic_source.py | 11 +- .../core/provider/sources/dashscope_tts.py | 6 +- .../core/provider/sources/edge_tts_source.py | 255 ++++++++--------- .../sources/fishaudio_tts_api_source.py | 9 +- .../core/provider/sources/gemini_source.py | 8 +- astrbot/core/provider/sources/genie_tts.py | 8 +- .../provider/sources/gsv_selfhosted_source.py | 4 +- .../core/provider/sources/gsvi_tts_source.py | 121 ++++---- .../sources/minimax_tts_api_source.py | 5 +- .../core/provider/sources/openai_source.py | 8 +- .../provider/sources/openai_tts_api_source.py | 9 +- .../sources/sensevoice_selfhosted_source.py | 8 +- .../core/provider/sources/volcengine_tts.py | 8 +- .../provider/sources/whisper_api_source.py | 12 +- .../sources/whisper_selfhosted_source.py | 7 +- .../sources/xinference_stt_provider.py | 15 +- astrbot/core/skills/neo_skill_sync.py | 4 +- astrbot/core/skills/skill_manager.py | 4 +- astrbot/core/star/star_handler.py | 6 +- astrbot/core/star/star_manager.py | 56 ++-- astrbot/core/utils/datetime_utils.py | 6 +- astrbot/core/utils/io.py | 59 ++-- astrbot/core/utils/media_utils.py | 12 +- astrbot/core/utils/session_waiter.py | 10 +- astrbot/core/utils/temp_dir_cleaner.py | 2 +- astrbot/core/utils/tencent_record_helper.py | 40 +-- astrbot/dashboard/routes/api_key.py | 8 +- astrbot/dashboard/routes/auth.py | 2 +- astrbot/dashboard/routes/backup.py | 68 +++-- astrbot/dashboard/routes/chat.py | 24 +- astrbot/dashboard/routes/config.py | 14 +- astrbot/dashboard/routes/knowledge_base.py | 2 +- astrbot/dashboard/routes/live_chat.py | 6 +- astrbot/dashboard/routes/open_api.py | 2 +- astrbot/dashboard/routes/plugin.py | 19 +- astrbot/dashboard/routes/skills.py | 3 +- astrbot/dashboard/routes/stat.py | 24 +- main.py | 4 +- pyproject.toml | 6 +- tests/test_skill_manager_sandbox_cache.py | 11 +- tests/unit/test_io_file_to_base64.py | 16 ++ 116 files changed, 1177 insertions(+), 980 deletions(-) create mode 100644 tests/unit/test_io_file_to_base64.py diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 5aeef1eff0..8bccae959a 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -46,6 +46,8 @@ jobs: include: - language: python build-mode: none + - language: javascript-typescript + build-mode: none # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' # Use `c-cpp` to analyze code written in C, C++ or both # Use 'java-kotlin' to analyze code written in Java, Kotlin or both diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml index f0019ee7e6..bd7beceb1c 100644 --- a/.github/workflows/coverage_test.yml +++ b/.github/workflows/coverage_test.yml @@ -23,12 +23,13 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 + with: + python-version: "3.12" - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install pytest pytest-asyncio pytest-cov - pip install --editable . + python -m pip install --upgrade pip uv + uv sync --group dev - name: Run tests run: | @@ -37,7 +38,7 @@ jobs: mkdir -p data/temp export TESTING=true export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }} - pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG + uv run pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG - name: Upload results to Codecov uses: codecov/codecov-action@v5 diff --git a/.github/workflows/dashboard_ci.yml b/.github/workflows/dashboard_ci.yml index 46d2fea735..921fc49922 100644 --- a/.github/workflows/dashboard_ci.yml +++ b/.github/workflows/dashboard_ci.yml @@ -13,18 +13,23 @@ jobs: - name: Checkout repository uses: actions/checkout@v6 + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 10.28.2 + - name: Setup Node.js uses: actions/setup-node@v6 with: node-version: '24.13.0' + cache: "pnpm" + cache-dependency-path: dashboard/pnpm-lock.yaml - - name: npm install, build + - name: Install and build run: | - cd dashboard - npm install pnpm -g - pnpm install - pnpm i --save-dev @types/markdown-it - pnpm run build + pnpm --dir dashboard install --frozen-lockfile + pnpm --dir dashboard run typecheck + pnpm --dir dashboard run build - name: Inject Commit SHA id: get_sha diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 18c8d49269..6300a65a03 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -25,6 +25,18 @@ jobs: fetch-depth: 1 fetch-tag: true + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 10.28.2 + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '24.13.0' + cache: "pnpm" + cache-dependency-path: dashboard/pnpm-lock.yaml + - name: Check for new commits today if: github.event_name == 'schedule' id: check-commits @@ -46,12 +58,10 @@ jobs: - name: Build Dashboard run: | - cd dashboard - npm install - npm run build - mkdir -p dist/assets - echo $(git rev-parse HEAD) > dist/assets/version - cd .. + pnpm --dir dashboard install --frozen-lockfile + pnpm --dir dashboard run build + mkdir -p dashboard/dist/assets + echo $(git rev-parse HEAD) > dashboard/dist/assets/version mkdir -p data cp -r dashboard/dist data/ @@ -123,6 +133,18 @@ jobs: fetch-depth: 1 fetch-tag: true + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 10.28.2 + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '24.13.0' + cache: "pnpm" + cache-dependency-path: dashboard/pnpm-lock.yaml + - name: Get latest tag (only on manual trigger) id: get-latest-tag if: github.event_name == 'workflow_dispatch' @@ -153,12 +175,10 @@ jobs: - name: Build Dashboard run: | - cd dashboard - npm install - npm run build - mkdir -p dist/assets - echo $(git rev-parse HEAD) > dist/assets/version - cd .. + pnpm --dir dashboard install --frozen-lockfile + pnpm --dir dashboard run build + mkdir -p dashboard/dist/assets + echo $(git rev-parse HEAD) > dashboard/dist/assets/version mkdir -p data cp -r dashboard/dist data/ diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 41f59f0a61..4950b7a4bf 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,6 +18,29 @@ permissions: contents: write jobs: + verify-core: + name: Verify Core Quality Gate + runs-on: ubuntu-24.04 + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + ref: ${{ inputs.ref || github.ref }} + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Install uv + shell: bash + run: python -m pip install uv + + - name: Run local PR gate checks + shell: bash + run: make pr-test-neo + build-dashboard: name: Build Dashboard runs-on: ubuntu-24.04 @@ -85,7 +108,8 @@ jobs: VERSION_TAG: ${{ steps.tag.outputs.tag }} shell: bash run: | - curl https://rclone.org/install.sh | sudo bash + sudo apt-get update + sudo apt-get install -y rclone mkdir -p ~/.config/rclone cat < ~/.config/rclone/rclone.conf @@ -106,6 +130,7 @@ jobs: name: Publish GitHub Release runs-on: ubuntu-24.04 needs: + - verify-core - build-dashboard steps: - name: Checkout repository @@ -226,7 +251,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: "3.10" + python-version: "3.12" - name: Install uv shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8611e26984..4c2a126e9d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.14.1 + rev: v0.15.1 hooks: # Run the linter. - id: ruff-check @@ -22,4 +22,4 @@ repos: rev: v3.21.0 hooks: - id: pyupgrade - args: [--py310-plus] + args: [--py312-plus] diff --git a/Dockerfile b/Dockerfile index 992060d6ea..544c6d6ced 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,10 +13,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ bash \ ffmpeg \ curl \ - gnupg \ git \ - && curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \ - && apt-get install -y --no-install-recommends nodejs \ + nodejs \ + npm \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* diff --git a/README.md b/README.md index e3b096a324..9ac80f2874 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_fr.md b/README_fr.md index 3a586adfcb..a6e778df92 100644 --- a/README_fr.md +++ b/README_fr.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_ja.md b/README_ja.md index 43b73884db..c34106143d 100644 --- a/README_ja.md +++ b/README_ja.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_ru.md b/README_ru.md index 8848dd92d7..1bc1f5554d 100644 --- a/README_ru.md +++ b/README_ru.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_zh-TW.md b/README_zh-TW.md index e3291d0b0f..3bd2455b2e 100644 --- a/README_zh-TW.md +++ b/README_zh-TW.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_zh.md b/README_zh.md index 7a85217b40..dc2c015f01 100644 --- a/README_zh.md +++ b/README_zh.md @@ -17,7 +17,7 @@
-python +python zread Docker pull diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index c06dda3500..2764935d02 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -1,6 +1,6 @@ import shutil import tempfile -from enum import Enum +from enum import StrEnum from io import BytesIO from pathlib import Path from zipfile import ZipFile @@ -12,7 +12,7 @@ from .version_comparator import VersionComparator -class PluginStatus(str, Enum): +class PluginStatus(StrEnum): INSTALLED = "installed" NEED_UPDATE = "needs-update" NOT_INSTALLED = "not-installed" diff --git a/astrbot/core/agent/agent.py b/astrbot/core/agent/agent.py index d6e2e7cb41..f4606b6da5 100644 --- a/astrbot/core/agent/agent.py +++ b/astrbot/core/agent/agent.py @@ -1,13 +1,12 @@ from dataclasses import dataclass -from typing import Any, Generic +from typing import Any from .hooks import BaseAgentRunHooks -from .run_context import TContext from .tool import FunctionTool @dataclass -class Agent(Generic[TContext]): +class Agent[TContext]: name: str instructions: str | None = None tools: list[str | FunctionTool] | None = None diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 8475009d3f..01fc5159c6 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,11 +1,8 @@ -from typing import Generic - from .agent import Agent -from .run_context import TContext from .tool import FunctionTool -class HandoffTool(FunctionTool, Generic[TContext]): +class HandoffTool[TContext](FunctionTool): """Handoff tool for delegating tasks to another agent.""" def __init__( diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index 74ca6335b3..451a957539 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -1,14 +1,12 @@ -from typing import Generic - import mcp from astrbot.core.agent.tool import FunctionTool from astrbot.core.provider.entities import LLMResponse -from .run_context import ContextWrapper, TContext +from .run_context import ContextWrapper -class BaseAgentRunHooks(Generic[TContext]): +class BaseAgentRunHooks[TContext]: async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ... async def on_tool_start( self, diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 18f4d47e04..5c4c19fade 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -2,7 +2,6 @@ import logging from contextlib import AsyncExitStack from datetime import timedelta -from typing import Generic from tenacity import ( before_sleep_log, @@ -16,7 +15,6 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe -from .run_context import TContext from .tool import FunctionTool try: @@ -101,7 +99,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return True, "" return False, f"HTTP {response.status}: {response.reason}" - except asyncio.TimeoutError: + except TimeoutError: return False, f"Connection timeout: {timeout} seconds" except Exception as e: return False, f"{e!s}" @@ -360,7 +358,7 @@ async def cleanup(self) -> None: self.running_event.set() -class MCPTool(FunctionTool, Generic[TContext]): +class MCPTool[TContext](FunctionTool): """A function tool that calls an MCP service.""" def __init__( diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index 21e7964335..e1e3ff8e39 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -7,7 +7,7 @@ from ..hooks import BaseAgentRunHooks from ..response import AgentResponse -from ..run_context import ContextWrapper, TContext +from ..run_context import ContextWrapper class AgentState(Enum): @@ -19,7 +19,7 @@ class AgentState(Enum): ERROR = auto() # Error state -class BaseAgentRunner(T.Generic[TContext]): +class BaseAgentRunner[TContext]: @abc.abstractmethod async def reset( self, diff --git a/astrbot/core/agent/runners/coze/coze_agent_runner.py b/astrbot/core/agent/runners/coze/coze_agent_runner.py index a8300bb711..0d7fab2070 100644 --- a/astrbot/core/agent/runners/coze/coze_agent_runner.py +++ b/astrbot/core/agent/runners/coze/coze_agent_runner.py @@ -1,7 +1,7 @@ import base64 import json -import sys import typing as T +from typing import override import astrbot.core.message.components as Comp from astrbot import logger @@ -18,11 +18,6 @@ from ..base import AgentResponse, AgentState, BaseAgentRunner from .coze_api_client import CozeAPIClient -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class CozeAgentRunner(BaseAgentRunner[TContext]): """Coze Agent Runner""" @@ -251,7 +246,7 @@ async def _execute_coze_request(self): conversation_id=conversation_id, auto_save_history=self.auto_save_history, stream=True, - timeout=self.timeout, + timeout_seconds=self.timeout, ): event_type = chunk.get("event") data = chunk.get("data", {}) diff --git a/astrbot/core/agent/runners/coze/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py index f5799dfbb7..03dbe64cc3 100644 --- a/astrbot/core/agent/runners/coze/coze_api_client.py +++ b/astrbot/core/agent/runners/coze/coze_api_client.py @@ -2,6 +2,7 @@ import io import json from collections.abc import AsyncGenerator +from pathlib import Path from typing import Any import aiohttp @@ -90,7 +91,7 @@ async def upload_file( logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}") return file_id - except asyncio.TimeoutError: + except TimeoutError: logger.error("文件上传超时") raise Exception("文件上传超时") except Exception as e: @@ -128,7 +129,7 @@ async def chat_messages( conversation_id: str | None = None, auto_save_history: bool = True, stream: bool = True, - timeout: float = 120, + timeout_seconds: float = 120, ) -> AsyncGenerator[dict[str, Any], None]: """发送聊天消息并返回流式响应 @@ -139,7 +140,7 @@ async def chat_messages( conversation_id: 会话ID auto_save_history: 是否自动保存历史 stream: 是否流式响应 - timeout: 超时时间 + timeout_seconds: 超时时间 """ session = await self._ensure_session() @@ -166,7 +167,7 @@ async def chat_messages( url, json=payload, params=params, - timeout=aiohttp.ClientTimeout(total=timeout), + timeout=aiohttp.ClientTimeout(total=timeout_seconds), ) as response: if response.status == 401: raise Exception("Coze API 认证失败,请检查 API Key 是否正确") @@ -203,8 +204,8 @@ async def chat_messages( except json.JSONDecodeError: event_data = {"content": data_str} - except asyncio.TimeoutError: - raise Exception(f"Coze API 流式请求超时 ({timeout}秒)") + except TimeoutError: + raise Exception(f"Coze API 流式请求超时 ({timeout_seconds}秒)") except Exception as e: raise Exception(f"Coze API 流式请求失败: {e!s}") @@ -236,7 +237,7 @@ async def clear_context(self, conversation_id: str): except json.JSONDecodeError: raise Exception("Coze API 返回非JSON格式") - except asyncio.TimeoutError: + except TimeoutError: raise Exception("Coze API 请求超时") except aiohttp.ClientError as e: raise Exception(f"Coze API 请求失败: {e!s}") @@ -294,8 +295,7 @@ async def test_coze_api_client() -> None: client = CozeAPIClient(api_key=api_key) try: - with open("README.md", "rb") as f: - file_data = f.read() + file_data = await asyncio.to_thread(Path("README.md").read_bytes) file_id = await client.upload_file(file_data) print(f"Uploaded file_id: {file_id}") async for event in client.chat_messages( diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py index 1aaf6e3b9c..080e627d52 100644 --- a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -2,9 +2,9 @@ import functools import queue import re -import sys import threading import typing as T +from typing import override from dashscope import Application from dashscope.app.application_response import ApplicationResponse @@ -22,11 +22,6 @@ from ...run_context import ContextWrapper, TContext from ..base import AgentResponse, AgentState, BaseAgentRunner -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class DashscopeAgentRunner(BaseAgentRunner[TContext]): """Dashscope Agent Runner""" diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 50ec7c8262..9e4a114719 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -1,10 +1,10 @@ import asyncio import hashlib import json -import sys import typing as T from collections import deque from dataclasses import dataclass, field +from typing import override from uuid import uuid4 import astrbot.core.message.components as Comp @@ -40,11 +40,6 @@ get_message_id, ) -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class DeerFlowAgentRunner(BaseAgentRunner[TContext]): """DeerFlow Agent Runner via LangGraph HTTP API.""" @@ -378,7 +373,9 @@ async def _ensure_thread_id(self, session_id: str) -> str: if thread_id: return thread_id - thread = await self.api_client.create_thread(timeout=min(30, self.timeout)) + thread = await self.api_client.create_thread( + timeout_seconds=min(30, self.timeout) + ) thread_id = thread.get("thread_id", "") if not thread_id: raise Exception( @@ -639,7 +636,7 @@ async def _execute_deerflow_request(self): async for event in self.api_client.stream_run( thread_id=thread_id, payload=payload, - timeout=self.timeout, + timeout_seconds=self.timeout, ): event_type = event.get("event") data = event.get("data") @@ -666,7 +663,7 @@ async def _execute_deerflow_request(self): if event_type == "end": break - except (asyncio.TimeoutError, TimeoutError): + except TimeoutError: logger.warning( "DeerFlow stream timed out after %ss for thread_id=%s; returning partial result.", self.timeout, diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index 37a23f2432..4ae9432e09 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -139,7 +139,7 @@ async def __aexit__( ) -> None: await self.close() - async def create_thread(self, timeout: float = 20) -> dict[str, Any]: + async def create_thread(self, timeout_seconds: float = 20) -> dict[str, Any]: session = self._get_session() url = f"{self.api_base}/api/langgraph/threads" payload = {"metadata": {}} @@ -147,7 +147,7 @@ async def create_thread(self, timeout: float = 20) -> dict[str, Any]: url, json=payload, headers=self.headers, - timeout=timeout, + timeout=timeout_seconds, proxy=self.proxy, ) as resp: if resp.status not in (200, 201): @@ -161,7 +161,7 @@ async def stream_run( self, thread_id: str, payload: dict[str, Any], - timeout: float = 120, + timeout_seconds: float = 120, ) -> AsyncGenerator[dict[str, Any], None]: session = self._get_session() url = f"{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream" @@ -183,9 +183,9 @@ async def stream_run( # Use socket read timeout so active heartbeats/chunks can keep the stream alive. stream_timeout = ClientTimeout( total=None, - connect=min(timeout, 30), - sock_connect=min(timeout, 30), - sock_read=timeout, + connect=min(timeout_seconds, 30), + sock_connect=min(timeout_seconds, 30), + sock_read=timeout_seconds, ) async with session.post( url, diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py index 93f8d3570d..1630ebf08e 100644 --- a/astrbot/core/agent/runners/dify/dify_agent_runner.py +++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py @@ -1,7 +1,7 @@ import base64 import os -import sys import typing as T +from typing import override import astrbot.core.message.components as Comp from astrbot.core import logger, sp @@ -19,11 +19,6 @@ from ..base import AgentResponse, AgentState, BaseAgentRunner from .dify_api_client import DifyAPIClient -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class DifyAgentRunner(BaseAgentRunner[TContext]): """Dify Agent Runner""" @@ -176,7 +171,7 @@ async def _execute_dify_request(self): user=session_id, conversation_id=conversation_id, files=files_payload, - timeout=self.timeout, + timeout_seconds=self.timeout, ): logger.debug(f"dify resp chunk: {chunk}") if chunk["event"] == "message" or chunk["event"] == "agent_message": @@ -216,7 +211,7 @@ async def _execute_dify_request(self): }, user=session_id, files=files_payload, - timeout=self.timeout, + timeout_seconds=self.timeout, ): logger.debug(f"dify workflow resp chunk: {chunk}") match chunk["event"]: diff --git a/astrbot/core/agent/runners/dify/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py index 26da6dfe9a..db7b923fcd 100644 --- a/astrbot/core/agent/runners/dify/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -1,6 +1,8 @@ +import asyncio import codecs import json from collections.abc import AsyncGenerator +from pathlib import Path from typing import Any from aiohttp import ClientResponse, ClientSession, FormData @@ -47,20 +49,20 @@ async def chat_messages( response_mode: str = "streaming", conversation_id: str = "", files: list[dict[str, Any]] | None = None, - timeout: float = 60, + timeout_seconds: float = 60, ) -> AsyncGenerator[dict[str, Any], None]: if files is None: files = [] url = f"{self.api_base}/chat-messages" payload = locals() payload.pop("self") - payload.pop("timeout") + payload.pop("timeout_seconds") logger.info(f"chat_messages payload: {payload}") async with self.session.post( url, json=payload, headers=self.headers, - timeout=timeout, + timeout=timeout_seconds, ) as resp: if resp.status != 200: text = await resp.text() @@ -76,20 +78,20 @@ async def workflow_run( user: str, response_mode: str = "streaming", files: list[dict[str, Any]] | None = None, - timeout: float = 60, + timeout_seconds: float = 60, ): if files is None: files = [] url = f"{self.api_base}/workflows/run" payload = locals() payload.pop("self") - payload.pop("timeout") + payload.pop("timeout_seconds") logger.info(f"workflow_run payload: {payload}") async with self.session.post( url, json=payload, headers=self.headers, - timeout=timeout, + timeout=timeout_seconds, ) as resp: if resp.status != 200: text = await resp.text() @@ -134,14 +136,13 @@ async def file_upload( # 使用文件路径 import os - with open(file_path, "rb") as f: - file_content = f.read() - form.add_field( - "file", - file_content, - filename=os.path.basename(file_path), - content_type=mime_type or "application/octet-stream", - ) + file_content = await asyncio.to_thread(Path(file_path).read_bytes) + form.add_field( + "file", + file_content, + filename=os.path.basename(file_path), + content_type=mime_type or "application/octet-stream", + ) else: raise ValueError("file_path 和 file_data 不能同时为 None") diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 743b280070..cc231be693 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -1,10 +1,10 @@ import asyncio import copy -import sys import time import traceback import typing as T from dataclasses import dataclass, field +from typing import override from mcp.types import ( BlobResourceContents, @@ -44,11 +44,6 @@ from ..tool_executor import BaseFunctionToolExecutor from .base import AgentResponse, AgentState, BaseAgentRunner -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - @dataclass(slots=True) class _HandleFunctionToolsResult: diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index c2536708e6..98f354ae41 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,6 +1,6 @@ import copy from collections.abc import AsyncGenerator, Awaitable, Callable -from typing import Any, Generic +from typing import Any import jsonschema import mcp @@ -10,7 +10,7 @@ from astrbot.core.message.message_event_result import MessageEventResult -from .run_context import ContextWrapper, TContext +from .run_context import ContextWrapper ParametersType = dict[str, Any] ToolExecResult = str | mcp.types.CallToolResult @@ -38,7 +38,7 @@ def validate_parameters(self) -> "ToolSchema": @dataclass -class FunctionTool(ToolSchema, Generic[TContext]): +class FunctionTool[TContext](ToolSchema): """A callable tool, for function calling.""" handler: ( diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 2704119d4f..8708fd97d2 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -1,13 +1,13 @@ from collections.abc import AsyncGenerator -from typing import Any, Generic +from typing import Any import mcp -from .run_context import ContextWrapper, TContext +from .run_context import ContextWrapper from .tool import FunctionTool -class BaseFunctionToolExecutor(Generic[TContext]): +class BaseFunctionToolExecutor[TContext]: @classmethod async def execute( cls, diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index dd65f92e69..f7178fe3bf 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -3,6 +3,7 @@ import time import traceback from collections.abc import AsyncGenerator +from pathlib import Path from astrbot.core import logger from astrbot.core.agent.message import Message @@ -509,8 +510,7 @@ async def _simulated_stream_tts( audio_path = await tts_provider.get_audio(text) if audio_path: - with open(audio_path, "rb") as f: - audio_data = f.read() + audio_data = await asyncio.to_thread(Path(audio_path).read_bytes) await audio_queue.put((text, audio_data)) except Exception as e: logger.error( diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 0dc8b9eeb7..be51ced720 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -625,7 +625,7 @@ async def _execute_local( exc_info=True, ) yield None - except asyncio.TimeoutError: + except TimeoutError: raise Exception( f"tool {tool.name} execution timeout after {tool_call_timeout or run_context.tool_call_timeout} seconds.", ) diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 2e0d8b0aa7..933346d1fe 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -1,3 +1,4 @@ +import asyncio import base64 import json import os @@ -241,7 +242,7 @@ async def _resolve_path_from_sandbox( bool: indicates whether the file was downloaded from sandbox. """ - if os.path.exists(path): + if await asyncio.to_thread(os.path.exists, path): return path, False # Try to check if the file exists in the sandbox diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py index a922375998..d65ac7a843 100644 --- a/astrbot/core/backup/exporter.py +++ b/astrbot/core/backup/exporter.py @@ -4,11 +4,12 @@ 导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。 """ +import asyncio import hashlib import json import os import zipfile -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING, Any @@ -83,7 +84,7 @@ async def export_all( output_dir = get_astrbot_backups_path() # 确保输出目录存在 - Path(output_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(Path(output_dir).mkdir, parents=True, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") zip_filename = f"astrbot_backup_{timestamp}.zip" @@ -160,9 +161,10 @@ async def export_all( # 3. 导出配置文件 if progress_callback: await progress_callback("config", 0, 100, "正在导出配置文件...") - if os.path.exists(self.config_path): - with open(self.config_path, encoding="utf-8") as f: - config_content = f.read() + if await asyncio.to_thread(os.path.exists, self.config_path): + config_content = await asyncio.to_thread( + Path(self.config_path).read_text, encoding="utf-8" + ) zf.writestr("config/cmd_config.json", config_content) self._add_checksum("config/cmd_config.json", config_content) if progress_callback: @@ -199,7 +201,7 @@ async def export_all( except Exception as e: logger.error(f"备份导出失败: {e}") # 清理失败的文件 - if os.path.exists(zip_path): + if await asyncio.to_thread(os.path.exists, zip_path): os.remove(zip_path) raise @@ -317,7 +319,7 @@ async def _export_directories( for dir_name, dir_path in backup_directories.items(): full_path = Path(dir_path) - if not full_path.exists(): + if not await asyncio.to_thread(full_path.exists): logger.debug(f"目录不存在,跳过: {full_path}") continue @@ -362,7 +364,7 @@ async def _export_attachments( for attachment in attachments: try: file_path = attachment.get("path", "") - if file_path and os.path.exists(file_path): + if file_path and await asyncio.to_thread(os.path.exists, file_path): # 使用 attachment_id 作为文件名 attachment_id = attachment.get("attachment_id", "") ext = os.path.splitext(file_path)[1] @@ -446,7 +448,7 @@ def _generate_manifest( manifest = { "version": BACKUP_MANIFEST_VERSION, "astrbot_version": VERSION, - "exported_at": datetime.now(timezone.utc).isoformat(), + "exported_at": datetime.now(UTC).isoformat(), "origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传 "schema_version": { "main_db": "v4", diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index b51c7d9560..5362ab3cbf 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -7,12 +7,13 @@ - 版本匹配时也需要用户确认 """ +import asyncio import json import os import shutil import zipfile from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING, Any @@ -364,7 +365,7 @@ async def import_all( """ result = ImportResult() - if not os.path.exists(zip_path): + if not await asyncio.to_thread(os.path.exists, zip_path): result.add_error(f"备份文件不存在: {zip_path}") return result @@ -446,12 +447,13 @@ async def import_all( try: config_content = zf.read("config/cmd_config.json") # 备份现有配置 - if os.path.exists(self.config_path): + if await asyncio.to_thread(os.path.exists, self.config_path): backup_path = f"{self.config_path}.bak" shutil.copy2(self.config_path, backup_path) - with open(self.config_path, "wb") as f: - f.write(config_content) + await asyncio.to_thread( + Path(self.config_path).write_bytes, config_content + ) result.imported_files["config"] = 1 except Exception as e: result.add_warning(f"导入配置文件失败: {e}") @@ -675,9 +677,9 @@ def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: if isinstance(value, datetime): dt = value if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) + dt = dt.replace(tzinfo=UTC) else: - dt = dt.astimezone(timezone.utc) + dt = dt.astimezone(UTC) return dt.isoformat() if isinstance(value, str): timestamp = value.strip() @@ -688,9 +690,9 @@ def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: try: dt = datetime.fromisoformat(timestamp) if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) + dt = dt.replace(tzinfo=UTC) else: - dt = dt.astimezone(timezone.utc) + dt = dt.astimezone(UTC) return dt.isoformat() except ValueError: return None @@ -753,8 +755,8 @@ async def _import_knowledge_bases( if faiss_path in zf.namelist(): try: target_path = kb_dir / "index.faiss" - with zf.open(faiss_path) as src, open(target_path, "wb") as dst: - dst.write(src.read()) + with zf.open(faiss_path) as src: + await asyncio.to_thread(target_path.write_bytes, src.read()) except Exception as e: result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}") @@ -766,8 +768,8 @@ async def _import_knowledge_bases( rel_path = name[len(media_prefix) :] target_path = kb_dir / rel_path target_path.parent.mkdir(parents=True, exist_ok=True) - with zf.open(name) as src, open(target_path, "wb") as dst: - dst.write(src.read()) + with zf.open(name) as src: + await asyncio.to_thread(target_path.write_bytes, src.read()) except Exception as e: result.add_warning(f"导入媒体文件 {name} 失败: {e}") @@ -828,8 +830,8 @@ async def _import_attachments( target_path = attachments_dir / os.path.basename(name) target_path.parent.mkdir(parents=True, exist_ok=True) - with zf.open(name) as src, open(target_path, "wb") as dst: - dst.write(src.read()) + with zf.open(name) as src: + await asyncio.to_thread(target_path.write_bytes, src.read()) count += 1 except Exception as e: logger.warning(f"导入附件 {name} 失败: {e}") @@ -885,15 +887,15 @@ async def _import_directories( continue # 备份现有目录(如果存在) - if target_dir.exists(): + if await asyncio.to_thread(target_dir.exists): backup_path = Path(f"{target_dir}.bak") - if backup_path.exists(): + if await asyncio.to_thread(backup_path.exists): shutil.rmtree(backup_path) shutil.move(str(target_dir), str(backup_path)) logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}") # 创建目标目录 - target_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True) # 解压文件 for name in dir_files: @@ -906,8 +908,8 @@ async def _import_directories( target_path = target_dir / rel_path target_path.parent.mkdir(parents=True, exist_ok=True) - with zf.open(name) as src, open(target_path, "wb") as dst: - dst.write(src.read()) + with zf.open(name) as src: + await asyncio.to_thread(target_path.write_bytes, src.read()) file_count += 1 except Exception as e: result.add_warning(f"导入文件 {name} 失败: {e}") diff --git a/astrbot/core/computer/booters/bay_manager.py b/astrbot/core/computer/booters/bay_manager.py index 24fa379e82..da429eb652 100644 --- a/astrbot/core/computer/booters/bay_manager.py +++ b/astrbot/core/computer/booters/bay_manager.py @@ -118,10 +118,10 @@ async def ensure_running(self) -> str: return f"http://127.0.0.1:{self._host_port}" - async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None: + async def wait_healthy(self, timeout_seconds: int = HEALTH_TIMEOUT_S) -> None: """Block until Bay's ``/health`` endpoint returns 200.""" url = f"http://127.0.0.1:{self._host_port}/health" - deadline = asyncio.get_event_loop().time() + timeout + deadline = asyncio.get_event_loop().time() + timeout_seconds last_error: str = "" async with aiohttp.ClientSession() as session: @@ -140,7 +140,7 @@ async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None: await asyncio.sleep(HEALTH_POLL_INTERVAL_S) raise TimeoutError( - f"Bay did not become healthy within {timeout}s (last error: {last_error})" + f"Bay did not become healthy within {timeout_seconds}s (last error: {last_error})" ) async def read_credentials(self) -> str: diff --git a/astrbot/core/computer/booters/boxlite.py b/astrbot/core/computer/booters/boxlite.py index 70064fdd48..337f5a68e2 100644 --- a/astrbot/core/computer/booters/boxlite.py +++ b/astrbot/core/computer/booters/boxlite.py @@ -1,5 +1,6 @@ import asyncio import random +from pathlib import Path from typing import Any import aiohttp @@ -46,8 +47,7 @@ async def upload_file(self, path: str, remote_path: str) -> dict: try: # Read file content - with open(path, "rb") as f: - file_content = f.read() + file_content = await asyncio.to_thread(Path(path).read_bytes) # Create multipart form data data = aiohttp.FormData() @@ -88,7 +88,7 @@ async def upload_file(self, path: str, remote_path: str) -> dict: "error": f"Connection error: {str(e)}", "message": "File upload failed", } - except asyncio.TimeoutError: + except TimeoutError: return { "success": False, "error": "File upload timeout", diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py index a80ef0da28..011ac45f42 100644 --- a/astrbot/core/computer/booters/local.py +++ b/astrbot/core/computer/booters/local.py @@ -59,7 +59,7 @@ async def exec( command: str, cwd: str | None = None, env: dict[str, str] | None = None, - timeout: int | None = 30, + timeout_seconds: int | None = 30, shell: bool = True, background: bool = False, ) -> dict[str, Any]: @@ -87,7 +87,7 @@ def _run() -> dict[str, Any]: shell=shell, cwd=working_dir, env=run_env, - timeout=timeout, + timeout=timeout_seconds, capture_output=True, text=True, ) @@ -106,14 +106,14 @@ async def exec( self, code: str, kernel_id: str | None = None, - timeout: int = 30, + timeout_seconds: int = 30, silent: bool = False, ) -> dict[str, Any]: def _run() -> dict[str, Any]: try: result = subprocess.run( [os.environ.get("PYTHON", sys.executable), "-c", code], - timeout=timeout, + timeout=timeout_seconds, capture_output=True, text=True, ) diff --git a/astrbot/core/computer/booters/shipyard_neo.py b/astrbot/core/computer/booters/shipyard_neo.py index 6304696ad2..6c6f62bb5f 100644 --- a/astrbot/core/computer/booters/shipyard_neo.py +++ b/astrbot/core/computer/booters/shipyard_neo.py @@ -1,7 +1,9 @@ from __future__ import annotations +import asyncio import os import shlex +from pathlib import Path from typing import Any, cast from astrbot.api import logger @@ -33,11 +35,11 @@ async def exec( self, code: str, kernel_id: str | None = None, - timeout: int = 30, + timeout_seconds: int = 30, silent: bool = False, ) -> dict[str, Any]: _ = kernel_id # Bay runtime does not expose kernel_id in current SDK. - result = await self._sandbox.python.exec(code, timeout=timeout) + result = await self._sandbox.python.exec(code, timeout=timeout_seconds) payload = _maybe_model_dump(result) output_text = payload.get("output", "") or "" @@ -75,7 +77,7 @@ async def exec( command: str, cwd: str | None = None, env: dict[str, str] | None = None, - timeout: int | None = 30, + timeout_seconds: int | None = 30, shell: bool = True, background: bool = False, ) -> dict[str, Any]: @@ -99,7 +101,7 @@ async def exec( result = await self._sandbox.shell.exec( run_command, - timeout=timeout or 30, + timeout=timeout_seconds or 30, cwd=cwd, ) payload = _maybe_model_dump(result) @@ -192,7 +194,7 @@ def __init__(self, sandbox: Any) -> None: async def exec( self, cmd: str, - timeout: int = 30, + timeout_seconds: int = 30, description: str | None = None, tags: str | None = None, learn: bool = False, @@ -200,7 +202,7 @@ async def exec( ) -> dict[str, Any]: result = await self._sandbox.browser.exec( cmd, - timeout=timeout, + timeout=timeout_seconds, description=description, tags=tags, learn=learn, @@ -211,7 +213,7 @@ async def exec( async def exec_batch( self, commands: list[str], - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, description: str | None = None, tags: str | None = None, @@ -220,7 +222,7 @@ async def exec_batch( ) -> dict[str, Any]: result = await self._sandbox.browser.exec_batch( commands, - timeout=timeout, + timeout=timeout_seconds, stop_on_error=stop_on_error, description=description, tags=tags, @@ -232,7 +234,7 @@ async def exec_batch( async def run_skill( self, skill_key: str, - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, include_trace: bool = False, description: str | None = None, @@ -240,7 +242,7 @@ async def run_skill( ) -> dict[str, Any]: result = await self._sandbox.browser.run_skill( skill_key=skill_key, - timeout=timeout, + timeout=timeout_seconds, stop_on_error=stop_on_error, include_trace=include_trace, description=description, @@ -468,8 +470,7 @@ def browser(self) -> BrowserComponent: async def upload_file(self, path: str, file_name: str) -> dict: if self._sandbox is None: raise RuntimeError("ShipyardNeoBooter is not initialized.") - with open(path, "rb") as f: - content = f.read() + content = await asyncio.to_thread(Path(path).read_bytes) remote_path = file_name.lstrip("/") await self._sandbox.filesystem.upload(remote_path, content) logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path) @@ -486,8 +487,7 @@ async def download_file(self, remote_path: str, local_path: str) -> None: local_dir = os.path.dirname(local_path) if local_dir: os.makedirs(local_dir, exist_ok=True) - with open(local_path, "wb") as f: - f.write(cast(bytes, content)) + await asyncio.to_thread(Path(local_path).write_bytes, cast(bytes, content)) logger.info( "[Computer] File downloaded from Neo sandbox: %s -> %s", remote_path, diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index aa10d125e7..1adaeae08c 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -1,3 +1,4 @@ +import asyncio import json import os import shutil @@ -372,12 +373,12 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: splitting into `apply` and `scan` phases. """ skills_root = Path(get_astrbot_skills_path()) - if not skills_root.is_dir(): + if not await asyncio.to_thread(skills_root.is_dir): return local_skill_dirs = _list_local_skill_dirs(skills_root) temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) zip_base = temp_dir / "skills_bundle" zip_path = zip_base.with_suffix(".zip") diff --git a/astrbot/core/computer/olayer/browser.py b/astrbot/core/computer/olayer/browser.py index aa69f4501d..5bc40a4462 100644 --- a/astrbot/core/computer/olayer/browser.py +++ b/astrbot/core/computer/olayer/browser.py @@ -11,7 +11,7 @@ class BrowserComponent(Protocol): async def exec( self, cmd: str, - timeout: int = 30, + timeout_seconds: int = 30, description: str | None = None, tags: str | None = None, learn: bool = False, @@ -23,7 +23,7 @@ async def exec( async def exec_batch( self, commands: list[str], - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, description: str | None = None, tags: str | None = None, @@ -36,7 +36,7 @@ async def exec_batch( async def run_skill( self, skill_key: str, - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, include_trace: bool = False, description: str | None = None, diff --git a/astrbot/core/computer/olayer/python.py b/astrbot/core/computer/olayer/python.py index 6255041463..09bf497db4 100644 --- a/astrbot/core/computer/olayer/python.py +++ b/astrbot/core/computer/olayer/python.py @@ -12,7 +12,7 @@ async def exec( self, code: str, kernel_id: str | None = None, - timeout: int = 30, + timeout_seconds: int = 30, silent: bool = False, ) -> dict[str, Any]: """Execute Python code""" diff --git a/astrbot/core/computer/olayer/shell.py b/astrbot/core/computer/olayer/shell.py index df2263b65a..67d9f95efd 100644 --- a/astrbot/core/computer/olayer/shell.py +++ b/astrbot/core/computer/olayer/shell.py @@ -13,7 +13,7 @@ async def exec( command: str, cwd: str | None = None, env: dict[str, str] | None = None, - timeout: int | None = 30, + timeout_seconds: int | None = 30, shell: bool = True, background: bool = False, ) -> dict[str, Any]: diff --git a/astrbot/core/computer/tools/browser.py b/astrbot/core/computer/tools/browser.py index 70061ac313..80a9be11a2 100644 --- a/astrbot/core/computer/tools/browser.py +++ b/astrbot/core/computer/tools/browser.py @@ -71,19 +71,23 @@ async def call( self, context: ContextWrapper[AstrAgentContext], cmd: str, - timeout: int = 30, + timeout_seconds: int = 30, description: str | None = None, tags: str | None = None, learn: bool = False, include_trace: bool = False, + **kwargs: Any, ) -> ToolExecResult: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = int(legacy_timeout) if err := _ensure_admin(context): return err try: browser = await _get_browser_component(context) result = await browser.exec( cmd=cmd, - timeout=timeout, + timeout_seconds=timeout_seconds, description=description, tags=tags, learn=learn, @@ -133,20 +137,24 @@ async def call( self, context: ContextWrapper[AstrAgentContext], commands: list[str], - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, description: str | None = None, tags: str | None = None, learn: bool = False, include_trace: bool = False, + **kwargs: Any, ) -> ToolExecResult: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = int(legacy_timeout) if err := _ensure_admin(context): return err try: browser = await _get_browser_component(context) result = await browser.exec_batch( commands=commands, - timeout=timeout, + timeout_seconds=timeout_seconds, stop_on_error=stop_on_error, description=description, tags=tags, @@ -181,19 +189,23 @@ async def call( self, context: ContextWrapper[AstrAgentContext], skill_key: str, - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, include_trace: bool = False, description: str | None = None, tags: str | None = None, + **kwargs: Any, ) -> ToolExecResult: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = int(legacy_timeout) if err := _ensure_admin(context): return err try: browser = await _get_browser_component(context) result = await browser.run_skill( skill_key=skill_key, - timeout=timeout, + timeout_seconds=timeout_seconds, stop_on_error=stop_on_error, include_trace=include_trace, description=description, diff --git a/astrbot/core/computer/tools/fs.py b/astrbot/core/computer/tools/fs.py index 31b7f3f513..d50025f4d4 100644 --- a/astrbot/core/computer/tools/fs.py +++ b/astrbot/core/computer/tools/fs.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid from dataclasses import dataclass, field @@ -111,10 +112,10 @@ async def call( ) try: # Check if file exists - if not os.path.exists(local_path): + if not await asyncio.to_thread(os.path.exists, local_path): return f"Error: File does not exist: {local_path}" - if not os.path.isfile(local_path): + if not await asyncio.to_thread(os.path.isfile, local_path): return f"Error: Path is not a file: {local_path}" # Use basename if sandbox_filename is not provided diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index d12878be3e..211514f7a2 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -1,7 +1,7 @@ import asyncio import json from collections.abc import Awaitable, Callable -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any from zoneinfo import ZoneInfo @@ -192,7 +192,7 @@ async def _run_job(self, job_id: str) -> None: job = await self.db.get_cron_job(job_id) if not job or not job.enabled: return - start_time = datetime.now(timezone.utc) + start_time = datetime.now(UTC) await self.db.update_cron_job( job_id, status="running", last_run_at=start_time, last_error=None ) diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index d7bca30678..47ecadf040 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -1,3 +1,4 @@ +import asyncio import os from astrbot.api import logger, sp @@ -22,7 +23,7 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: data_dir = get_astrbot_data_path() data_v3_db = os.path.join(data_dir, "data_v3.db") - if not os.path.exists(data_v3_db): + if not await asyncio.to_thread(os.path.exists, data_v3_db): return False migration_done = await db_helper.get_preference( "global", diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 727d97b29b..d7a57a6d61 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -106,8 +106,8 @@ async def migration_platform_table( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) secs_from_2023_4_10_to_now = ( - datetime.datetime.now(datetime.timezone.utc) - - datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc) + datetime.datetime.now(datetime.UTC) + - datetime.datetime(2023, 4, 10, tzinfo=datetime.UTC) ).total_seconds() offset_sec = int(secs_from_2023_4_10_to_now) logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。") @@ -162,7 +162,7 @@ async def migration_platform_table( { "timestamp": datetime.datetime.fromtimestamp( bucket_end, - tz=datetime.timezone.utc, + tz=datetime.UTC, ), "platform_id": platform_id, "platform_type": platform_type, diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 451f054f62..1b8179f074 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -1,16 +1,16 @@ import uuid from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import TypedDict from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint class TimestampMixin(SQLModel): - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)}, + default_factory=lambda: datetime.now(UTC), + sa_column_kwargs={"onupdate": lambda: datetime.now(UTC)}, ) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index f496e19d59..e356e85aa3 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -2,7 +2,7 @@ import threading import typing as T from collections.abc import Awaitable, Callable -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from sqlalchemy import CursorResult, Row from sqlalchemy.ext.asyncio import AsyncSession @@ -633,7 +633,7 @@ async def get_active_api_key_by_hash(self, key_hash: str) -> ApiKey | None: """Get an active API key by hash (not revoked, not expired).""" async with self.get_db() as session: session: AsyncSession - now = datetime.now(timezone.utc) + now = datetime.now(UTC) query = select(ApiKey).where( ApiKey.key_hash == key_hash, col(ApiKey.revoked_at).is_(None), @@ -650,7 +650,7 @@ async def touch_api_key(self, key_id: str) -> None: await session.execute( update(ApiKey) .where(col(ApiKey.key_id) == key_id) - .values(last_used_at=datetime.now(timezone.utc)), + .values(last_used_at=datetime.now(UTC)), ) async def revoke_api_key(self, key_id: str) -> bool: @@ -661,7 +661,7 @@ async def revoke_api_key(self, key_id: str) -> bool: query = ( update(ApiKey) .where(col(ApiKey.key_id) == key_id) - .values(revoked_at=datetime.now(timezone.utc)) + .values(revoked_at=datetime.now(UTC)) ) result = T.cast(CursorResult, await session.execute(query)) return result.rowcount > 0 @@ -1534,7 +1534,7 @@ async def update_platform_session( async with self.get_db() as session: session: AsyncSession async with session.begin(): - values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} + values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)} if display_name is not None: values["display_name"] = display_name @@ -1622,7 +1622,7 @@ async def update_chatui_project( async with self.get_db() as session: session: AsyncSession async with session.begin(): - values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} + values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)} if title is not None: values["title"] = title if emoji is not None: diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index 42fbd23dfe..5aa897e79b 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -28,12 +28,17 @@ async def check_token_expired(self, file_token: str) -> bool: await self._cleanup_expired_tokens() return file_token not in self.staged_files - async def register_file(self, file_path: str, timeout: float | None = None) -> str: + async def register_file( + self, + file_path: str, + timeout_seconds: float | None = None, + **kwargs, + ) -> str: """向令牌服务注册一个文件。 Args: file_path(str): 文件路径 - timeout(float): 超时时间,单位秒(可选) + timeout_seconds(float): 超时时间,单位秒(可选) Returns: str: 一个单次令牌 @@ -58,15 +63,18 @@ async def register_file(self, file_path: str, timeout: float | None = None) -> s async with self.lock: await self._cleanup_expired_tokens() + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = float(legacy_timeout) - if not os.path.exists(local_path): + if not await asyncio.to_thread(os.path.exists, local_path): raise FileNotFoundError( f"文件不存在: {local_path} (原始输入: {file_path})", ) file_token = str(uuid.uuid4()) expire_time = time.time() + ( - timeout if timeout is not None else self.default_timeout + timeout_seconds if timeout_seconds is not None else self.default_timeout ) # 存储转换后的真实路径 self.staged_files[file_token] = (local_path, expire_time) @@ -93,6 +101,6 @@ async def handle_file(self, file_token: str) -> str: raise KeyError(f"无效或过期的文件 token: {file_token}") file_path, _ = self.staged_files.pop(file_token) - if not os.path.exists(file_path): + if not await asyncio.to_thread(os.path.exists, file_path): raise FileNotFoundError(f"文件不存在: {file_path}") return file_path diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index da919a384a..10277a926b 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timezone +from datetime import UTC, datetime from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint @@ -40,10 +40,10 @@ class KnowledgeBase(BaseKBModel, table=True): top_k_dense: int | None = Field(default=50, nullable=True) top_k_sparse: int | None = Field(default=50, nullable=True) top_m_final: int | None = Field(default=5, nullable=True) - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + default_factory=lambda: datetime.now(UTC), + sa_column_kwargs={"onupdate": datetime.now(UTC)}, ) doc_count: int = Field(default=0, nullable=False) chunk_count: int = Field(default=0, nullable=False) @@ -83,10 +83,10 @@ class KBDocument(BaseKBModel, table=True): file_path: str = Field(max_length=512, nullable=False) chunk_count: int = Field(default=0, nullable=False) media_count: int = Field(default=0, nullable=False) - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + default_factory=lambda: datetime.now(UTC), + sa_column_kwargs={"onupdate": datetime.now(UTC)}, ) @@ -117,4 +117,4 @@ class KBMedia(BaseKBModel, table=True): file_path: str = Field(max_length=512, nullable=False) file_size: int = Field(nullable=False) mime_type: str = Field(max_length=100, nullable=False) - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 15265c38d1..038e997424 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -27,7 +27,8 @@ import os import sys import uuid -from enum import Enum +from enum import StrEnum +from pathlib import Path if sys.version_info >= (3, 14): from pydantic import BaseModel @@ -39,7 +40,7 @@ from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 -class ComponentType(str, Enum): +class ComponentType(StrEnum): # Basic Segment Types Plain = "Plain" # plain text message Image = "Image" # image @@ -158,18 +159,17 @@ async def convert_to_file_path(self) -> str: return self.file[8:] if self.file.startswith("http"): file_path = await download_image_by_url(self.file) - return os.path.abspath(file_path) + return await asyncio.to_thread(os.path.abspath, file_path) if self.file.startswith("base64://"): bs64_data = self.file.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) file_path = os.path.join( get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg" ) - with open(file_path, "wb") as f: - f.write(image_bytes) - return os.path.abspath(file_path) - if os.path.exists(self.file): - return os.path.abspath(self.file) + await asyncio.to_thread(Path(file_path).write_bytes, image_bytes) + return await asyncio.to_thread(os.path.abspath, file_path) + if await asyncio.to_thread(os.path.exists, self.file): + return await asyncio.to_thread(os.path.abspath, self.file) raise Exception(f"not a valid file: {self.file}") async def convert_to_base64(self) -> str: @@ -183,14 +183,14 @@ async def convert_to_base64(self) -> str: if not self.file: raise Exception(f"not a valid file: {self.file}") if self.file.startswith("file:///"): - bs64_data = file_to_base64(self.file[8:]) + bs64_data = await file_to_base64(self.file[8:]) elif self.file.startswith("http"): file_path = await download_image_by_url(self.file) - bs64_data = file_to_base64(file_path) + bs64_data = await file_to_base64(file_path) elif self.file.startswith("base64://"): bs64_data = self.file - elif os.path.exists(self.file): - bs64_data = file_to_base64(self.file) + elif await asyncio.to_thread(os.path.exists, self.file): + bs64_data = await file_to_base64(self.file) else: raise Exception(f"not a valid file: {self.file}") bs64_data = bs64_data.removeprefix("base64://") @@ -256,11 +256,11 @@ async def convert_to_file_path(self) -> str: get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}" ) await download_file(url, video_file_path) - if os.path.exists(video_file_path): - return os.path.abspath(video_file_path) + if await asyncio.to_thread(os.path.exists, video_file_path): + return await asyncio.to_thread(os.path.abspath, video_file_path) raise Exception(f"download failed: {url}") - if os.path.exists(url): - return os.path.abspath(url) + if await asyncio.to_thread(os.path.exists, url): + return await asyncio.to_thread(os.path.abspath, url) raise Exception(f"not a valid file: {url}") async def register_to_file_service(self) -> str: @@ -449,18 +449,17 @@ async def convert_to_file_path(self) -> str: return url[8:] if url.startswith("http"): image_file_path = await download_image_by_url(url) - return os.path.abspath(image_file_path) + return await asyncio.to_thread(os.path.abspath, image_file_path) if url.startswith("base64://"): bs64_data = url.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) image_file_path = os.path.join( get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg" ) - with open(image_file_path, "wb") as f: - f.write(image_bytes) - return os.path.abspath(image_file_path) - if os.path.exists(url): - return os.path.abspath(url) + await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes) + return await asyncio.to_thread(os.path.abspath, image_file_path) + if await asyncio.to_thread(os.path.exists, url): + return await asyncio.to_thread(os.path.abspath, url) raise Exception(f"not a valid file: {url}") async def convert_to_base64(self) -> str: @@ -475,14 +474,14 @@ async def convert_to_base64(self) -> str: if not url: raise ValueError("No valid file or URL provided") if url.startswith("file:///"): - bs64_data = file_to_base64(url[8:]) + bs64_data = await file_to_base64(url[8:]) elif url.startswith("http"): image_file_path = await download_image_by_url(url) - bs64_data = file_to_base64(image_file_path) + bs64_data = await file_to_base64(image_file_path) elif url.startswith("base64://"): bs64_data = url - elif os.path.exists(url): - bs64_data = file_to_base64(url) + elif await asyncio.to_thread(os.path.exists, url): + bs64_data = await file_to_base64(url) else: raise Exception(f"not a valid file: {url}") bs64_data = bs64_data.removeprefix("base64://") @@ -735,8 +734,8 @@ async def get_file(self, allow_return_url: bool = False) -> str: ): path = path[1:] - if os.path.exists(path): - return os.path.abspath(path) + if await asyncio.to_thread(os.path.exists, path): + return await asyncio.to_thread(os.path.abspath, path) if self.url: await self._download_file() @@ -751,7 +750,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: and path[2] == ":" ): path = path[1:] - return os.path.abspath(path) + return await asyncio.to_thread(os.path.abspath, path) return "" @@ -767,7 +766,7 @@ async def _download_file(self) -> None: filename = f"fileseg_{uuid.uuid4().hex}" file_path = os.path.join(download_dir, filename) await download_file(self.url, file_path) - self.file_ = os.path.abspath(file_path) + self.file_ = await asyncio.to_thread(os.path.abspath, file_path) async def register_to_file_service(self) -> str: """将文件注册到文件服务。 diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 2d9b45cc19..e823aac9d0 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -254,7 +254,7 @@ async def download_ding_file( "robotCode": robot_code, } temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) f_path = temp_dir / f"dingtalk_{uuid.uuid4()}.{ext}" async with ( aiohttp.ClientSession() as session, @@ -412,7 +412,7 @@ async def upload_media(self, file_path: str, media_type: str) -> str: form = aiohttp.FormData() form.add_field( "media", - media_file_path.read_bytes(), + await asyncio.to_thread(media_file_path.read_bytes), filename=media_file_path.name, content_type="application/octet-stream", ) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index ebd32c471a..36ee5710b4 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -1,15 +1,10 @@ -import sys from collections.abc import Awaitable, Callable +from typing import override import discord from astrbot import logger -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - # Discord Bot客户端 class DiscordBotClient(discord.Bot): diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 7657962a11..40be87a633 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -1,7 +1,6 @@ import asyncio import re -import sys -from typing import Any, cast +from typing import Any, cast, override import discord from discord.abc import GuildChannel, Messageable, PrivateChannel @@ -27,11 +26,6 @@ from .client import DiscordBotClient from .discord_platform_event import DiscordPlatformEvent -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - # 注册平台适配器 @register_platform_adapter( diff --git a/astrbot/core/platform/sources/kook/kook_adapter.py b/astrbot/core/platform/sources/kook/kook_adapter.py index 1124c6841d..b7d047291e 100644 --- a/astrbot/core/platform/sources/kook/kook_adapter.py +++ b/astrbot/core/platform/sources/kook/kook_adapter.py @@ -130,7 +130,7 @@ async def _main_loop(self): await asyncio.wait_for( self.client.wait_until_closed(), timeout=1.0 ) - except asyncio.TimeoutError: + except TimeoutError: # 正常超时,继续下一轮 while 检查 continue diff --git a/astrbot/core/platform/sources/kook/kook_client.py b/astrbot/core/platform/sources/kook/kook_client.py index 9a452a9c3f..34078e2ac2 100644 --- a/astrbot/core/platform/sources/kook/kook_client.py +++ b/astrbot/core/platform/sources/kook/kook_client.py @@ -171,7 +171,7 @@ async def listen(self): # 处理不同类型的信令 await self._handle_signal(data) - except asyncio.TimeoutError: + except TimeoutError: # 超时检查,继续循环 continue except websockets.exceptions.ConnectionClosed: @@ -362,12 +362,14 @@ async def upload_asset(self, file_url: str | None) -> str: b64_str = file_url.removeprefix("base64://") bytes_data = base64.b64decode(b64_str) - elif file_url.startswith("file://") or os.path.exists(file_url): + elif file_url.startswith("file://") or await asyncio.to_thread( + os.path.exists, file_url + ): file_url = file_url.removeprefix("file:///") file_url = file_url.removeprefix("file://") try: - target_path = Path(file_url).resolve() + target_path = await asyncio.to_thread(Path(file_url).resolve) except Exception as exp: logger.error(f'[KOOK] 获取文件 "{file_url}" 绝对路径失败: "{exp}"') raise FileNotFoundError( diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index be1c81c26e..6b500dc5ca 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -429,7 +429,7 @@ async def _download_file_resource_to_temp( suffix = Path(file_name).suffix if file_name else default_suffix temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) temp_path = ( temp_dir / f"lark_{message_type}_{file_name}_{uuid4().hex[:4]}{suffix}" ) diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 92e3a32b9e..a513f45005 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -1,8 +1,10 @@ +import asyncio import base64 import json import os import uuid from io import BytesIO +from pathlib import Path import lark_oapi as lark from lark_oapi.api.im.v1 import ( @@ -136,7 +138,7 @@ async def _upload_lark_file( Returns: 成功返回file_key,失败返回None """ - if not path or not os.path.exists(path): + if not path or not await asyncio.to_thread(os.path.exists, path): logger.error(f"[Lark] 文件不存在: {path}") return None @@ -145,36 +147,32 @@ async def _upload_lark_file( return None try: - with open(path, "rb") as file_obj: - body_builder = ( - CreateFileRequestBody.builder() - .file_type(file_type) - .file_name(os.path.basename(path)) - .file(file_obj) - ) - if duration is not None: - body_builder.duration(duration) + file_obj = BytesIO(await asyncio.to_thread(Path(path).read_bytes)) + body_builder = ( + CreateFileRequestBody.builder() + .file_type(file_type) + .file_name(os.path.basename(path)) + .file(file_obj) + ) + if duration is not None: + body_builder.duration(duration) - request = ( - CreateFileRequest.builder() - .request_body(body_builder.build()) - .build() - ) - response = await lark_client.im.v1.file.acreate(request) + request = ( + CreateFileRequest.builder().request_body(body_builder.build()).build() + ) + response = await lark_client.im.v1.file.acreate(request) - if not response.success(): - logger.error( - f"[Lark] 无法上传文件({response.code}): {response.msg}" - ) - return None + if not response.success(): + logger.error(f"[Lark] 无法上传文件({response.code}): {response.msg}") + return None - if response.data is None: - logger.error("[Lark] 上传文件成功但未返回数据(data is None)") - return None + if response.data is None: + logger.error("[Lark] 上传文件成功但未返回数据(data is None)") + return None - file_key = response.data.file_key - logger.debug(f"[Lark] 文件上传成功: {file_key}") - return file_key + file_key = response.data.file_key + logger.debug(f"[Lark] 文件上传成功: {file_key}") + return file_key except Exception as e: logger.error(f"[Lark] 无法打开或上传文件: {e}") @@ -207,8 +205,9 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l temp_dir, f"lark_image_{uuid.uuid4().hex[:8]}.jpg", ) - with open(file_path, "wb") as f: - f.write(BytesIO(image_data).getvalue()) + await asyncio.to_thread( + Path(file_path).write_bytes, BytesIO(image_data).getvalue() + ) else: file_path = comp.file if comp.file else "" @@ -217,7 +216,9 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l logger.error("[Lark] 图片路径为空,无法上传") continue try: - image_file = open(file_path, "rb") + image_file = BytesIO( + await asyncio.to_thread(Path(file_path).read_bytes) + ) except Exception as e: logger.error(f"[Lark] 无法打开图片文件: {e}") continue @@ -412,7 +413,9 @@ async def _send_audio_message( logger.error(f"[Lark] 无法获取音频文件路径: {e}") return - if not original_audio_path or not os.path.exists(original_audio_path): + if not original_audio_path or not await asyncio.to_thread( + os.path.exists, original_audio_path + ): logger.error(f"[Lark] 音频文件不存在: {original_audio_path}") return @@ -442,7 +445,9 @@ async def _send_audio_message( ) # 清理转换后的临时音频文件 - if converted_audio_path and os.path.exists(converted_audio_path): + if converted_audio_path and await asyncio.to_thread( + os.path.exists, converted_audio_path + ): try: os.remove(converted_audio_path) logger.debug(f"[Lark] 已删除转换后的音频文件: {converted_audio_path}") @@ -485,7 +490,9 @@ async def _send_media_message( logger.error(f"[Lark] 无法获取视频文件路径: {e}") return - if not original_video_path or not os.path.exists(original_video_path): + if not original_video_path or not await asyncio.to_thread( + os.path.exists, original_video_path + ): logger.error(f"[Lark] 视频文件不存在: {original_video_path}") return @@ -515,7 +522,9 @@ async def _send_media_message( ) # 清理转换后的临时视频文件 - if converted_video_path and os.path.exists(converted_video_path): + if converted_video_path and await asyncio.to_thread( + os.path.exists, converted_video_path + ): try: os.remove(converted_video_path) logger.debug(f"[Lark] 已删除转换后的视频文件: {converted_video_path}") diff --git a/astrbot/core/platform/sources/line/line_event.py b/astrbot/core/platform/sources/line/line_event.py index 8b82ad1820..a16bdd18f8 100644 --- a/astrbot/core/platform/sources/line/line_event.py +++ b/astrbot/core/platform/sources/line/line_event.py @@ -161,7 +161,7 @@ async def _resolve_video_preview_url(segment: Video) -> str: try: video_path = await segment.convert_to_file_path() temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) thumb_path = temp_dir / f"line_video_preview_{uuid.uuid4().hex}.jpg" process = await asyncio.create_subprocess_exec( @@ -201,8 +201,8 @@ async def _resolve_file_url(segment: File) -> str: async def _resolve_file_size(segment: File) -> int: try: file_path = await segment.get_file(allow_return_url=False) - if file_path and os.path.exists(file_path): - return int(os.path.getsize(file_path)) + if file_path and await asyncio.to_thread(os.path.exists, file_path): + return int(await asyncio.to_thread(os.path.getsize, file_path)) except Exception as e: logger.debug("[LINE] resolve file size failed: %s", e) return 0 diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index fd61c3e506..e1169decb4 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -499,7 +499,8 @@ async def _upload_comp(comp) -> object | None: # 清理临时文件 if local_path and isinstance(local_path, str): data_temp = get_astrbot_temp_path() - if local_path.startswith(data_temp) and os.path.exists( + if local_path.startswith(data_temp) and await asyncio.to_thread( + os.path.exists, local_path, ): try: diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 3e5eb9a90e..64728a5616 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -3,6 +3,7 @@ import random import uuid from collections.abc import Awaitable, Callable +from pathlib import Path from typing import Any, NoReturn try: @@ -555,22 +556,19 @@ async def upload_file( form.add_field("folderId", str(folder_id)) try: - f = open(file_path, "rb") + file_bytes = await asyncio.to_thread(Path(file_path).read_bytes) except FileNotFoundError as e: logger.error(f"[Misskey API] 本地文件不存在: {file_path}") raise APIError(f"File not found: {file_path}") from e - try: - form.add_field("file", f, filename=filename) - async with self.session.post(url, data=form) as resp: - result = await self._process_response(resp, "drive/files/create") - file_id = FileIDExtractor.extract_file_id(result) - logger.debug( - f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}", - ) - return {"id": file_id, "raw": result} - finally: - f.close() + form.add_field("file", file_bytes, filename=filename) + async with self.session.post(url, data=form) as resp: + result = await self._process_response(resp, "drive/files/create") + file_id = FileIDExtractor.extract_file_id(result) + logger.debug( + f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}", + ) + return {"id": file_id, "raw": result} except aiohttp.ClientError as e: logger.error(f"[Misskey API] 文件上传网络错误: {e}") raise APIConnectionError(f"Upload failed: {e}") from e diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 868ec8a657..55050a8219 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -339,7 +339,7 @@ async def upload_group_and_c2c_record( payload = {"file_type": file_type, "srv_send_msg": srv_send_msg} # 处理文件数据 - if os.path.exists(file_source): + if await asyncio.to_thread(os.path.exists, file_source): # 读取本地文件 async with aiofiles.open(file_source, "rb") as f: file_content = await f.read() @@ -421,15 +421,15 @@ async def _parse_to_qqofficial(message: MessageChain): plain_text += i.text elif isinstance(i, Image) and not image_base64: if i.file and i.file.startswith("file:///"): - image_base64 = file_to_base64(i.file[8:]) + image_base64 = await file_to_base64(i.file[8:]) image_file_path = i.file[8:] elif i.file and i.file.startswith("http"): image_file_path = await download_image_by_url(i.file) - image_base64 = file_to_base64(image_file_path) + image_base64 = await file_to_base64(image_file_path) elif i.file and i.file.startswith("base64://"): image_base64 = i.file elif i.file: - image_base64 = file_to_base64(i.file) + image_base64 = await file_to_base64(i.file) else: raise ValueError("Unsupported image file format") image_base64 = image_base64.removeprefix("base64://") diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 2dd72bd0ca..76e3f9d985 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -1,9 +1,8 @@ import asyncio import os import re -import sys import uuid -from typing import cast +from typing import cast, override from apscheduler.schedulers.asyncio import AsyncIOScheduler from telegram import BotCommand, Update @@ -33,11 +32,6 @@ from .tg_event import TelegramPlatformEvent -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - @register_platform_adapter("telegram", "telegram 适配器") class TelegramPlatformAdapter(Platform): diff --git a/astrbot/core/platform/sources/webchat/message_parts_helper.py b/astrbot/core/platform/sources/webchat/message_parts_helper.py index 43072ec1c8..3a1371e723 100644 --- a/astrbot/core/platform/sources/webchat/message_parts_helper.py +++ b/astrbot/core/platform/sources/webchat/message_parts_helper.py @@ -1,3 +1,4 @@ +import asyncio import json import mimetypes import shutil @@ -139,13 +140,15 @@ async def parse_webchat_message_parts( continue file_path = Path(str(path)) - if verify_media_path_exists and not file_path.exists(): + if verify_media_path_exists and not await asyncio.to_thread(file_path.exists): if strict: raise ValueError(f"file not found: {file_path!s}") continue file_path_str = ( - str(file_path.resolve()) if verify_media_path_exists else str(file_path) + str(await asyncio.to_thread(file_path.resolve)) + if verify_media_path_exists + else str(file_path) ) has_content = True if part_type == "image": @@ -366,7 +369,7 @@ async def message_chain_to_storage_message_parts( attachments_dir: str | Path, ) -> list[dict]: target_dir = Path(attachments_dir) - target_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True) parts: list[dict] = [] for comp in message_chain.chain: @@ -442,7 +445,9 @@ async def _copy_file_to_attachment_part( display_name: str | None = None, ) -> dict | None: src_path = Path(file_path) - if not src_path.exists() or not src_path.is_file(): + if not await asyncio.to_thread(src_path.exists) or not await asyncio.to_thread( + src_path.is_file + ): return None suffix = src_path.suffix diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index b7da864aae..aacb0e12da 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -1,8 +1,10 @@ +import asyncio import base64 import json import os import shutil import uuid +from pathlib import Path from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain @@ -80,8 +82,9 @@ async def _send( filename = f"{str(uuid.uuid4())}.jpg" path = os.path.join(attachments_dir, filename) image_base64 = await comp.convert_to_base64() - with open(path, "wb") as f: - f.write(base64.b64decode(image_base64)) + await asyncio.to_thread( + Path(path).write_bytes, base64.b64decode(image_base64) + ) data = f"[IMAGE]{filename}" await web_chat_back_queue.put( { @@ -96,8 +99,9 @@ async def _send( filename = f"{str(uuid.uuid4())}.wav" path = os.path.join(attachments_dir, filename) record_base64 = await comp.convert_to_base64() - with open(path, "wb") as f: - f.write(base64.b64decode(record_base64)) + await asyncio.to_thread( + Path(path).write_bytes, base64.b64decode(record_base64) + ) data = f"[RECORD]{filename}" await web_chat_back_queue.put( { diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 6647db89f0..b77a0da3e4 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -1,9 +1,9 @@ import asyncio import os -import sys import uuid from collections.abc import Awaitable, Callable -from typing import Any, cast +from pathlib import Path +from typing import Any, cast, override import quart from requests import Response @@ -33,11 +33,6 @@ from .wecom_kf import WeChatKF from .wecom_kf_message import WeChatKFMessage -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: @@ -346,8 +341,7 @@ async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: ) temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr") - with open(path, "wb") as f: - f.write(resp.content) + await asyncio.to_thread(Path(path).write_bytes, resp.content) try: path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav") @@ -402,8 +396,7 @@ async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: ) temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"weixinkefu_{media_id}.jpg") - with open(path, "wb") as f: - f.write(resp.content) + await asyncio.to_thread(Path(path).write_bytes, resp.content) abm.message = [Image(file=path, url=path)] elif msgtype == "voice": media_id = msg.get("voice", {}).get("media_id", "") @@ -415,8 +408,7 @@ async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr") - with open(path, "wb") as f: - f.write(resp.content) + await asyncio.to_thread(Path(path).write_bytes, resp.content) try: path_wav = os.path.join(temp_dir, f"weixinkefu_{media_id}.wav") diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 7aee26e47f..83a91a872b 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -12,6 +12,13 @@ from .wecom_kf_message import WeChatKFMessage +def _upload_media_from_path( + client: WeChatClient, media_type: str, file_path: str +) -> dict: + with open(file_path, "rb") as f: + return client.media.upload(media_type, f) + + class WecomPlatformEvent(AstrMessageEvent): def __init__( self, @@ -100,45 +107,52 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, Image): img_path = await comp.convert_to_file_path() - with open(img_path, "rb") as f: + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "image", + img_path, + ) + except Exception as e: + logger.error(f"微信客服上传图片失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传图片失败: {e}"), + ) + return + logger.debug(f"微信客服上传图片返回: {response}") + kf_message_api.send_image( + user_id, + self.get_self_id(), + response["media_id"], + ) + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + record_path_amr = await convert_audio_to_amr(record_path) + + try: try: - response = self.client.media.upload("image", f) + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "voice", + record_path_amr, + ) except Exception as e: - logger.error(f"微信客服上传图片失败: {e}") + logger.error(f"微信客服上传语音失败: {e}") await self.send( - MessageChain().message(f"微信客服上传图片失败: {e}"), + MessageChain().message(f"微信客服上传语音失败: {e}"), ) return - logger.debug(f"微信客服上传图片返回: {response}") - kf_message_api.send_image( + logger.info(f"微信客服上传语音返回: {response}") + kf_message_api.send_voice( user_id, self.get_self_id(), response["media_id"], ) - elif isinstance(comp, Record): - record_path = await comp.convert_to_file_path() - record_path_amr = await convert_audio_to_amr(record_path) - - try: - with open(record_path_amr, "rb") as f: - try: - response = self.client.media.upload("voice", f) - except Exception as e: - logger.error(f"微信客服上传语音失败: {e}") - await self.send( - MessageChain().message( - f"微信客服上传语音失败: {e}" - ), - ) - return - logger.info(f"微信客服上传语音返回: {response}") - kf_message_api.send_voice( - user_id, - self.get_self_id(), - response["media_id"], - ) finally: - if record_path_amr != record_path and os.path.exists( + if record_path_amr != record_path and await asyncio.to_thread( + os.path.exists, record_path_amr, ): try: @@ -148,39 +162,47 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, File): file_path = await comp.get_file() - with open(file_path, "rb") as f: - try: - response = self.client.media.upload("file", f) - except Exception as e: - logger.error(f"微信客服上传文件失败: {e}") - await self.send( - MessageChain().message(f"微信客服上传文件失败: {e}"), - ) - return - logger.debug(f"微信客服上传文件返回: {response}") - kf_message_api.send_file( - user_id, - self.get_self_id(), - response["media_id"], + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "file", + file_path, + ) + except Exception as e: + logger.error(f"微信客服上传文件失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传文件失败: {e}"), ) + return + logger.debug(f"微信客服上传文件返回: {response}") + kf_message_api.send_file( + user_id, + self.get_self_id(), + response["media_id"], + ) elif isinstance(comp, Video): video_path = await comp.convert_to_file_path() - with open(video_path, "rb") as f: - try: - response = self.client.media.upload("video", f) - except Exception as e: - logger.error(f"微信客服上传视频失败: {e}") - await self.send( - MessageChain().message(f"微信客服上传视频失败: {e}"), - ) - return - logger.debug(f"微信客服上传视频返回: {response}") - kf_message_api.send_video( - user_id, - self.get_self_id(), - response["media_id"], + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "video", + video_path, ) + except Exception as e: + logger.error(f"微信客服上传视频失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传视频失败: {e}"), + ) + return + logger.debug(f"微信客服上传视频返回: {response}") + kf_message_api.send_video( + user_id, + self.get_self_id(), + response["media_id"], + ) else: logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") else: @@ -199,45 +221,52 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, Image): img_path = await comp.convert_to_file_path() - with open(img_path, "rb") as f: + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "image", + img_path, + ) + except Exception as e: + logger.error(f"企业微信上传图片失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传图片失败: {e}"), + ) + return + logger.debug(f"企业微信上传图片返回: {response}") + self.client.message.send_image( + message_obj.self_id, + message_obj.session_id, + response["media_id"], + ) + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + record_path_amr = await convert_audio_to_amr(record_path) + + try: try: - response = self.client.media.upload("image", f) + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "voice", + record_path_amr, + ) except Exception as e: - logger.error(f"企业微信上传图片失败: {e}") + logger.error(f"企业微信上传语音失败: {e}") await self.send( - MessageChain().message(f"企业微信上传图片失败: {e}"), + MessageChain().message(f"企业微信上传语音失败: {e}"), ) return - logger.debug(f"企业微信上传图片返回: {response}") - self.client.message.send_image( + logger.info(f"企业微信上传语音返回: {response}") + self.client.message.send_voice( message_obj.self_id, message_obj.session_id, response["media_id"], ) - elif isinstance(comp, Record): - record_path = await comp.convert_to_file_path() - record_path_amr = await convert_audio_to_amr(record_path) - - try: - with open(record_path_amr, "rb") as f: - try: - response = self.client.media.upload("voice", f) - except Exception as e: - logger.error(f"企业微信上传语音失败: {e}") - await self.send( - MessageChain().message( - f"企业微信上传语音失败: {e}" - ), - ) - return - logger.info(f"企业微信上传语音返回: {response}") - self.client.message.send_voice( - message_obj.self_id, - message_obj.session_id, - response["media_id"], - ) finally: - if record_path_amr != record_path and os.path.exists( + if record_path_amr != record_path and await asyncio.to_thread( + os.path.exists, record_path_amr, ): try: @@ -247,39 +276,47 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, File): file_path = await comp.get_file() - with open(file_path, "rb") as f: - try: - response = self.client.media.upload("file", f) - except Exception as e: - logger.error(f"企业微信上传文件失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传文件失败: {e}"), - ) - return - logger.debug(f"企业微信上传文件返回: {response}") - self.client.message.send_file( - message_obj.self_id, - message_obj.session_id, - response["media_id"], + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "file", + file_path, ) + except Exception as e: + logger.error(f"企业微信上传文件失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传文件失败: {e}"), + ) + return + logger.debug(f"企业微信上传文件返回: {response}") + self.client.message.send_file( + message_obj.self_id, + message_obj.session_id, + response["media_id"], + ) elif isinstance(comp, Video): video_path = await comp.convert_to_file_path() - with open(video_path, "rb") as f: - try: - response = self.client.media.upload("video", f) - except Exception as e: - logger.error(f"企业微信上传视频失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传视频失败: {e}"), - ) - return - logger.debug(f"企业微信上传视频返回: {response}") - self.client.message.send_video( - message_obj.self_id, - message_obj.session_id, - response["media_id"], + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "video", + video_path, + ) + except Exception as e: + logger.error(f"企业微信上传视频失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传视频失败: {e}"), ) + return + logger.debug(f"企业微信上传视频返回: {response}") + self.client.message.send_video( + message_obj.self_id, + message_obj.session_id, + response["media_id"], + ) else: logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py index f7cbe380d4..6dbfda7b41 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -2,7 +2,6 @@ 提供常量定义、工具函数和辅助方法 """ -import asyncio import base64 import hashlib import secrets @@ -174,7 +173,7 @@ async def process_encrypted_image( response.raise_for_status() encrypted_data = await response.read() logger.info("图片下载成功,大小: %d 字节", len(encrypted_data)) - except (aiohttp.ClientError, asyncio.TimeoutError) as e: + except (TimeoutError, aiohttp.ClientError) as e: error_msg = f"下载图片失败: {e!s}" logger.error(error_msg) return False, error_msg diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py index 6f42f264b9..c305411d4e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import base64 import hashlib import mimetypes @@ -103,7 +104,9 @@ async def send_image_base64(self, image_base64: str) -> None: async def upload_media( self, file_path: Path, media_type: Literal["file", "voice"] ) -> str: - if not file_path.exists() or not file_path.is_file(): + if not await asyncio.to_thread(file_path.exists) or not await asyncio.to_thread( + file_path.is_file + ): raise WecomAIBotWebhookError(f"文件不存在: {file_path}") content_type = ( @@ -112,7 +115,7 @@ async def upload_media( form = aiohttp.FormData() form.add_field( "media", - file_path.read_bytes(), + await asyncio.to_thread(file_path.read_bytes), filename=file_path.name, content_type=content_type, ) diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index c01355974a..59f8ebd8c7 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -1,10 +1,10 @@ import asyncio import os -import sys import time import uuid from collections.abc import Callable, Coroutine -from typing import Any, cast +from pathlib import Path +from typing import Any, cast, override import quart from requests import Response @@ -32,11 +32,6 @@ from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class WeixinOfficialAccountServer: def __init__( @@ -379,7 +374,7 @@ async def callback(msg: BaseMessage): ) # wait for 180s logger.debug(f"Got future result: {result}") return result - except asyncio.TimeoutError: + except TimeoutError: logger.info(f"callback 处理消息超时: message_id={msg.id}") return create_reply("处理消息超时,请稍后再试。", msg) except Exception as e: @@ -468,8 +463,7 @@ async def convert_message( ) temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"weixin_offacc_{msg.media_id}.amr") - with open(path, "wb") as f: - f.write(resp.content) + await asyncio.to_thread(Path(path).write_bytes, resp.content) try: path_wav = os.path.join( diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index ae536593c5..0797e4dee1 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -12,6 +12,13 @@ from astrbot.core.utils.media_utils import convert_audio_to_amr +def _upload_media_from_path( + client: WeChatClient, media_type: str, file_path: str +) -> dict: + with open(file_path, "rb") as f: + return client.media.upload(media_type, f) + + class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): def __init__( self, @@ -101,24 +108,63 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, Image): img_path = await comp.convert_to_file_path() - with open(img_path, "rb") as f: + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "image", + img_path, + ) + except Exception as e: + logger.error(f"微信公众平台上传图片失败: {e}") + await self.send( + MessageChain().message(f"微信公众平台上传图片失败: {e}"), + ) + return + logger.debug(f"微信公众平台上传图片返回: {response}") + + if active_send_mode: + self.client.message.send_image( + message_obj.sender.user_id, + response["media_id"], + ) + else: + reply = ImageReply( + media_id=response["media_id"], + message=cast(dict, self.message_obj.raw_message)["message"], + ) + xml = reply.render() + future = cast(dict, self.message_obj.raw_message)["future"] + assert isinstance(future, asyncio.Future) + future.set_result(xml) + + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + record_path_amr = await convert_audio_to_amr(record_path) + + try: try: - response = self.client.media.upload("image", f) + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "voice", + record_path_amr, + ) except Exception as e: - logger.error(f"微信公众平台上传图片失败: {e}") + logger.error(f"微信公众平台上传语音失败: {e}") await self.send( - MessageChain().message(f"微信公众平台上传图片失败: {e}"), + MessageChain().message(f"微信公众平台上传语音失败: {e}"), ) return - logger.debug(f"微信公众平台上传图片返回: {response}") + logger.info(f"微信公众平台上传语音返回: {response}") if active_send_mode: - self.client.message.send_image( + self.client.message.send_voice( message_obj.sender.user_id, response["media_id"], ) else: - reply = ImageReply( + reply = VoiceReply( media_id=response["media_id"], message=cast(dict, self.message_obj.raw_message)["message"], ) @@ -126,44 +172,9 @@ async def send(self, message: MessageChain) -> None: future = cast(dict, self.message_obj.raw_message)["future"] assert isinstance(future, asyncio.Future) future.set_result(xml) - - elif isinstance(comp, Record): - record_path = await comp.convert_to_file_path() - record_path_amr = await convert_audio_to_amr(record_path) - - try: - with open(record_path_amr, "rb") as f: - try: - response = self.client.media.upload("voice", f) - except Exception as e: - logger.error(f"微信公众平台上传语音失败: {e}") - await self.send( - MessageChain().message( - f"微信公众平台上传语音失败: {e}" - ), - ) - return - logger.info(f"微信公众平台上传语音返回: {response}") - - if active_send_mode: - self.client.message.send_voice( - message_obj.sender.user_id, - response["media_id"], - ) - else: - reply = VoiceReply( - media_id=response["media_id"], - message=cast(dict, self.message_obj.raw_message)[ - "message" - ], - ) - xml = reply.render() - future = cast(dict, self.message_obj.raw_message)["future"] - assert isinstance(future, asyncio.Future) - future.set_result(xml) finally: - if record_path_amr != record_path and os.path.exists( - record_path_amr + if record_path_amr != record_path and await asyncio.to_thread( + os.path.exists, record_path_amr ): try: os.remove(record_path_amr) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 20c5a7947d..aea04645d0 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -1,9 +1,11 @@ from __future__ import annotations +import asyncio import base64 import enum import json from dataclasses import dataclass, field +from pathlib import Path from typing import Any from anthropic.types import Message as AnthropicMessage @@ -218,9 +220,10 @@ async def _encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + image_bs64 = base64.b64encode( + await asyncio.to_thread(Path(image_url).read_bytes) + ).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 return "" diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 068c63c5ad..22e9a0766c 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -8,6 +8,7 @@ import urllib.parse from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from dataclasses import dataclass +from pathlib import Path from types import MappingProxyType from typing import Any @@ -198,7 +199,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return True, "" return False, f"HTTP {response.status}: {response.reason}" - except asyncio.TimeoutError: + except TimeoutError: return False, f"连接超时: {timeout}秒" except Exception as e: return False, f"{e!s}" @@ -373,15 +374,24 @@ async def init_mcp_clients( data_dir = get_astrbot_data_path() mcp_json_file = os.path.join(data_dir, "mcp_server.json") - if not os.path.exists(mcp_json_file): + if not await asyncio.to_thread(os.path.exists, mcp_json_file): # 配置文件不存在错误处理 - with open(mcp_json_file, "w", encoding="utf-8") as f: - json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) + config_text = json.dumps(DEFAULT_MCP_CONFIG, ensure_ascii=False, indent=4) + await asyncio.to_thread( + Path(mcp_json_file).write_text, + config_text, + encoding="utf-8", + ) logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}") return MCPInitSummary(total=0, success=0, failed=[]) - with open(mcp_json_file, encoding="utf-8") as f: - mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"] + mcp_json_content = await asyncio.to_thread( + Path(mcp_json_file).read_text, + encoding="utf-8", + ) + mcp_server_json_obj: dict[str, dict] = json.loads(mcp_json_content)[ + "mcpServers" + ] init_timeout = self._init_timeout_default timeout_display = f"{init_timeout:g}" @@ -451,7 +461,7 @@ async def _start_mcp_server( cfg: dict, *, shutdown_event: asyncio.Event | None = None, - timeout: float, + timeout_seconds: float, ) -> None: """Initialize MCP server with timeout and register task/event together. @@ -461,7 +471,7 @@ async def _start_mcp_server( async with self._runtime_lock: if name in self._mcp_server_runtime or name in self._mcp_starting: logger.warning( - f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。" + f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout_seconds:g})。" ) self._log_safe_mcp_debug_config(cfg) return @@ -474,11 +484,11 @@ async def _start_mcp_server( try: mcp_client = await asyncio.wait_for( self._init_mcp_client(name, cfg), - timeout=timeout, + timeout=timeout_seconds, ) - except asyncio.TimeoutError as exc: + except TimeoutError as exc: raise MCPInitTimeoutError( - f"MCP 服务 {name} 初始化超时({timeout:g} 秒)" + f"MCP 服务 {name} 初始化超时({timeout_seconds:g} 秒)" ) from exc except Exception: logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True) @@ -511,7 +521,7 @@ async def lifecycle() -> None: async def _shutdown_runtimes( self, runtimes: list[_MCPServerRuntime], - timeout: float, + timeout_seconds: float, *, strict: bool = True, ) -> list[str]: @@ -530,9 +540,9 @@ async def _shutdown_runtimes( try: results = await asyncio.wait_for( asyncio.gather(*lifecycle_tasks, return_exceptions=True), - timeout=timeout, + timeout=timeout_seconds, ) - except asyncio.TimeoutError: + except TimeoutError: pending_names = [ runtime.name for runtime in runtimes @@ -543,10 +553,10 @@ async def _shutdown_runtimes( task.cancel() await asyncio.gather(*lifecycle_tasks, return_exceptions=True) if strict: - raise MCPShutdownTimeoutError(pending_names, timeout) + raise MCPShutdownTimeoutError(pending_names, timeout_seconds) logger.warning( "MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s", - f"{timeout:g}", + f"{timeout_seconds:g}", ", ".join(pending_names), ) return pending_names @@ -657,7 +667,8 @@ async def enable_mcp_server( name: str, config: dict, shutdown_event: asyncio.Event | None = None, - timeout: float | int | str | None = None, + timeout_seconds: float | int | str | None = None, + **kwargs: Any, ) -> None: """Enable a new MCP server and initialize it. @@ -665,18 +676,22 @@ async def enable_mcp_server( name: The name of the MCP server. config: Configuration for the MCP server. shutdown_event: Event to signal when the MCP client should shut down. - timeout: Timeout in seconds for initialization. + timeout_seconds: Timeout in seconds for initialization. Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout). Raises: MCPInitTimeoutError: If initialization does not complete within timeout. Exception: If there is an error during initialization. """ - if timeout is None: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = legacy_timeout + + if timeout_seconds is None: timeout_value = self._enable_timeout_default else: timeout_value = _resolve_timeout( - timeout=timeout, + timeout=timeout_seconds, env_name=ENABLE_MCP_TIMEOUT_ENV, default=self._enable_timeout_default, ) @@ -684,36 +699,45 @@ async def enable_mcp_server( name=name, cfg=config, shutdown_event=shutdown_event, - timeout=timeout_value, + timeout_seconds=timeout_value, ) async def disable_mcp_server( self, name: str | None = None, - timeout: float = 10, + timeout_seconds: float = 10, + **kwargs: Any, ) -> None: """Disable an MCP server by its name. Args: name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled. - timeout (int): Timeout. + timeout_seconds (int): Timeout. Raises: MCPShutdownTimeoutError: If shutdown does not complete within timeout. Only raised when disabling a specific server (name is not None). """ + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = float(legacy_timeout) + if name: async with self._runtime_lock: runtime = self._mcp_server_runtime.get(name) if runtime is None: return - await self._shutdown_runtimes([runtime], timeout, strict=True) + await self._shutdown_runtimes( + [runtime], timeout_seconds=timeout_seconds, strict=True + ) else: async with self._runtime_lock: runtimes = list(self._mcp_server_runtime.values()) - await self._shutdown_runtimes(runtimes, timeout, strict=False) + await self._shutdown_runtimes( + runtimes, timeout_seconds=timeout_seconds, strict=False + ) def _warn_on_timeout_mismatch( self, diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 901efd0052..08c5254858 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,7 +2,8 @@ import asyncio import os from collections.abc import AsyncGenerator -from typing import TypeAlias, Union +from pathlib import Path +from typing import Any from astrbot.core.agent.message import ContentPart, Message from astrbot.core.agent.tool import ToolSet @@ -15,13 +16,9 @@ from astrbot.core.provider.register import provider_cls_map from astrbot.core.utils.astrbot_path import get_astrbot_path -Providers: TypeAlias = Union[ - "Provider", - "STTProvider", - "TTSProvider", - "EmbeddingProvider", - "RerankProvider", -] +type Providers = ( + "Provider" | "STTProvider" | "TTSProvider" | "EmbeddingProvider" | "RerankProvider" +) class AbstractProvider(abc.ABC): @@ -188,10 +185,13 @@ def _ensure_message_to_dicts( return dicts - async def test(self, timeout: float = 45.0) -> None: + async def test(self, timeout_seconds: float = 45.0, **kwargs: Any) -> None: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = float(legacy_timeout) await asyncio.wait_for( self.text_chat(prompt="REPLY `PONG` ONLY"), - timeout=timeout, + timeout=timeout_seconds, ) @@ -268,8 +268,9 @@ async def get_audio_stream( # 调用原有的 get_audio 方法获取音频文件路径 audio_path = await self.get_audio(accumulated_text) # 读取音频文件内容 - with open(audio_path, "rb") as f: - audio_data = f.read() + audio_data = await asyncio.to_thread( + Path(audio_path).read_bytes + ) await audio_queue.put((accumulated_text, audio_data)) except Exception: # 出错时也要发送 None 结束标记 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index ec3c395a46..7f7a51859f 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -1,6 +1,8 @@ +import asyncio import base64 import json from collections.abc import AsyncGenerator +from pathlib import Path import anthropic import httpx @@ -637,11 +639,10 @@ async def encode_image_bs64(self, image_url: str) -> tuple[str, str]: except Exception: mime_type = "image/jpeg" return f"data:{mime_type};base64,{raw_base64}", mime_type - with open(image_url, "rb") as f: - image_bytes = f.read() - mime_type = self._detect_image_mime_type(image_bytes) - image_bs64 = base64.b64encode(image_bytes).decode("utf-8") - return f"data:{mime_type};base64,{image_bs64}", mime_type + image_bytes = await asyncio.to_thread(Path(image_url).read_bytes) + mime_type = self._detect_image_mime_type(image_bytes) + image_bs64 = base64.b64encode(image_bytes).decode("utf-8") + return f"data:{mime_type};base64,{image_bs64}", mime_type return "", "image/jpeg" def get_current_key(self) -> str: diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index 9b6816859f..bd12f37b91 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -3,6 +3,7 @@ import logging import os import uuid +from pathlib import Path import aiohttp import dashscope @@ -59,8 +60,7 @@ async def get_audio(self, text: str) -> str: ) path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}") - with open(path, "wb") as f: - f.write(audio_bytes) + await asyncio.to_thread(Path(path).write_bytes, audio_bytes) return path def _call_qwen_tts(self, model: str, text: str): @@ -129,7 +129,7 @@ async def _download_audio_from_url(self, url: str) -> bytes | None: ) as response, ): return await response.read() - except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e: + except (TimeoutError, aiohttp.ClientError, OSError) as e: logging.exception(f"Failed to download audio from URL {url}: {e}") return None diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 503bd275b4..147c925ecf 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -1,126 +1,129 @@ -import asyncio -import os -import subprocess -import uuid - -import edge_tts - -from astrbot.core import logger -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..entities import ProviderType -from ..provider import TTSProvider -from ..register import register_provider_adapter - -""" -edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 -``` -pip install edge_tts -``` -Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot -""" - - -@register_provider_adapter( - "edge_tts", - "Microsoft Edge TTS", - provider_type=ProviderType.TEXT_TO_SPEECH, -) -class ProviderEdgeTTS(TTSProvider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - super().__init__(provider_config, provider_settings) - - # 设置默认语音,如果没有指定则使用中文小萱 - self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") - self.rate = provider_config.get("rate") - self.volume = provider_config.get("volume") - self.pitch = provider_config.get("pitch") - self.timeout = provider_config.get("timeout", 30) - - self.proxy = os.getenv("https_proxy", None) - - self.set_model("edge_tts") - - async def get_audio(self, text: str) -> str: - temp_dir = get_astrbot_temp_path() - mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") - wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") - - # 构建 Edge TTS 参数 - kwargs = {"text": text, "voice": self.voice} - if self.rate: - kwargs["rate"] = self.rate - if self.volume: - kwargs["volume"] = self.volume - if self.pitch: - kwargs["pitch"] = self.pitch - - try: - communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs) - await communicate.save(mp3_path) - - try: - from pyffmpeg import FFmpeg - - ff = FFmpeg() - ff.convert(input_file=mp3_path, output_file=wav_path) - except Exception as e: - logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") - # use ffmpeg command line - - # 使用ffmpeg将MP3转换为标准WAV格式 - p = await asyncio.create_subprocess_exec( - "ffmpeg", - "-y", # 覆盖输出文件 - "-i", - mp3_path, # 输入文件 - "-acodec", - "pcm_s16le", # 16位PCM编码 - "-ar", - "24000", # 采样率24kHz (适合微信语音) - "-ac", - "1", # 单声道 - "-af", - "apad=pad_dur=2", # 确保输出时长准确 - "-fflags", - "+genpts", # 强制生成时间戳 - "-hide_banner", # 隐藏版本信息 - wav_path, # 输出文件 - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - # 等待进程完成并获取输出 - stdout, stderr = await p.communicate() - logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}") - logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}") - logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}") - - os.remove(mp3_path) - if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0: - return wav_path - logger.error("生成的WAV文件不存在或为空") - raise RuntimeError("生成的WAV文件不存在或为空") - - except subprocess.CalledProcessError as e: - logger.error( - f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}", - ) - try: - if os.path.exists(mp3_path): - os.remove(mp3_path) - except Exception: - pass - raise RuntimeError(f"FFmpeg 转换失败: {e!s}") - - except Exception as e: - logger.error(f"音频生成失败: {e!s}") - try: - if os.path.exists(mp3_path): - os.remove(mp3_path) - except Exception: - pass - raise RuntimeError(f"音频生成失败: {e!s}") +import asyncio +import os +import subprocess +import uuid + +import edge_tts + +from astrbot.core import logger +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + +""" +edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 +``` +pip install edge_tts +``` +Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot +""" + + +@register_provider_adapter( + "edge_tts", + "Microsoft Edge TTS", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderEdgeTTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + + # 设置默认语音,如果没有指定则使用中文小萱 + self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") + self.rate = provider_config.get("rate") + self.volume = provider_config.get("volume") + self.pitch = provider_config.get("pitch") + self.timeout = provider_config.get("timeout", 30) + + self.proxy = os.getenv("https_proxy", None) + + self.set_model("edge_tts") + + async def get_audio(self, text: str) -> str: + temp_dir = get_astrbot_temp_path() + mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") + wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") + + # 构建 Edge TTS 参数 + kwargs = {"text": text, "voice": self.voice} + if self.rate: + kwargs["rate"] = self.rate + if self.volume: + kwargs["volume"] = self.volume + if self.pitch: + kwargs["pitch"] = self.pitch + + try: + communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs) + await communicate.save(mp3_path) + + try: + from pyffmpeg import FFmpeg + + ff = FFmpeg() + ff.convert(input_file=mp3_path, output_file=wav_path) + except Exception as e: + logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") + # use ffmpeg command line + + # 使用ffmpeg将MP3转换为标准WAV格式 + p = await asyncio.create_subprocess_exec( + "ffmpeg", + "-y", # 覆盖输出文件 + "-i", + mp3_path, # 输入文件 + "-acodec", + "pcm_s16le", # 16位PCM编码 + "-ar", + "24000", # 采样率24kHz (适合微信语音) + "-ac", + "1", # 单声道 + "-af", + "apad=pad_dur=2", # 确保输出时长准确 + "-fflags", + "+genpts", # 强制生成时间戳 + "-hide_banner", # 隐藏版本信息 + wav_path, # 输出文件 + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + # 等待进程完成并获取输出 + stdout, stderr = await p.communicate() + logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}") + logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}") + logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}") + + os.remove(mp3_path) + if ( + await asyncio.to_thread(os.path.exists, wav_path) + and await asyncio.to_thread(os.path.getsize, wav_path) > 0 + ): + return wav_path + logger.error("生成的WAV文件不存在或为空") + raise RuntimeError("生成的WAV文件不存在或为空") + + except subprocess.CalledProcessError as e: + logger.error( + f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}", + ) + try: + if await asyncio.to_thread(os.path.exists, mp3_path): + os.remove(mp3_path) + except Exception: + pass + raise RuntimeError(f"FFmpeg 转换失败: {e!s}") + + except Exception as e: + logger.error(f"音频生成失败: {e!s}") + try: + if await asyncio.to_thread(os.path.exists, mp3_path): + os.remove(mp3_path) + except Exception: + pass + raise RuntimeError(f"音频生成失败: {e!s}") diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index 35945b7b6f..c1b62ef399 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -1,6 +1,8 @@ +import asyncio import os import re import uuid +from pathlib import Path from typing import Annotated, Literal import ormsgpack @@ -159,9 +161,10 @@ async def get_audio(self, text: str) -> str: if response.status_code == 200 and response.headers.get( "content-type", "" ).startswith("audio/"): - with open(path, "wb") as f: - async for chunk in response.aiter_bytes(): - f.write(chunk) + audio_data = bytearray() + async for chunk in response.aiter_bytes(): + audio_data.extend(chunk) + await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data)) return path error_bytes = await response.aread() error_text = error_bytes.decode("utf-8", errors="replace")[:1024] diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 9557f3dbcd..25e5c6e3a7 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -4,6 +4,7 @@ import logging import random from collections.abc import AsyncGenerator +from pathlib import Path from typing import cast from google import genai @@ -924,9 +925,10 @@ async def encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + image_bs64 = base64.b64encode( + await asyncio.to_thread(Path(image_url).read_bytes) + ).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 async def terminate(self) -> None: if self.client: diff --git a/astrbot/core/provider/sources/genie_tts.py b/astrbot/core/provider/sources/genie_tts.py index 8f9b6d91d7..62b4b3f81d 100644 --- a/astrbot/core/provider/sources/genie_tts.py +++ b/astrbot/core/provider/sources/genie_tts.py @@ -1,6 +1,7 @@ import asyncio import os import uuid +from pathlib import Path from astrbot.core import logger from astrbot.core.provider.entities import ProviderType @@ -72,7 +73,7 @@ def _generate(save_path: str) -> None: try: await loop.run_in_executor(None, _generate, path) - if os.path.exists(path): + if await asyncio.to_thread(os.path.exists, path): return path raise RuntimeError("Genie TTS did not save to file.") @@ -109,9 +110,8 @@ def _generate(save_path: str, t: str) -> None: await loop.run_in_executor(None, _generate, path, text) - if os.path.exists(path): - with open(path, "rb") as f: - audio_data = f.read() + if await asyncio.to_thread(os.path.exists, path): + audio_data = await asyncio.to_thread(Path(path).read_bytes) # Put (text, bytes) into queue so frontend can display text await audio_queue.put((text, audio_data)) diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index fc8bccea84..a9ebfe9a66 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -1,6 +1,7 @@ import asyncio import os import uuid +from pathlib import Path import aiohttp @@ -129,8 +130,7 @@ async def get_audio(self, text: str) -> str: result = await self._make_request(endpoint, params) if isinstance(result, bytes): - with open(path, "wb") as f: - f.write(result) + await asyncio.to_thread(Path(path).write_bytes, result) return path raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}") diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py index 425e801f46..f92485b722 100644 --- a/astrbot/core/provider/sources/gsvi_tts_source.py +++ b/astrbot/core/provider/sources/gsvi_tts_source.py @@ -1,59 +1,62 @@ -import os -import urllib.parse -import uuid - -import aiohttp - -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..entities import ProviderType -from ..provider import TTSProvider -from ..register import register_provider_adapter - - -@register_provider_adapter( - "gsvi_tts_api", - "GSVI TTS API", - provider_type=ProviderType.TEXT_TO_SPEECH, -) -class ProviderGSVITTS(TTSProvider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - super().__init__(provider_config, provider_settings) - self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000") - self.api_base = self.api_base.removesuffix("/") - self.character = provider_config.get("character") - self.emotion = provider_config.get("emotion") - - async def get_audio(self, text: str) -> str: - temp_dir = get_astrbot_temp_path() - path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") - params = {"text": text} - - if self.character: - params["character"] = self.character - if self.emotion: - params["emotion"] = self.emotion - - query_parts = [] - for key, value in params.items(): - encoded_value = urllib.parse.quote(str(value)) - query_parts.append(f"{key}={encoded_value}") - - url = f"{self.api_base}/tts?{'&'.join(query_parts)}" - - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - with open(path, "wb") as f: - f.write(await response.read()) - else: - error_text = await response.text() - raise Exception( - f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}", - ) - - return path +import asyncio +import os +import urllib.parse +import uuid +from pathlib import Path + +import aiohttp + +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + + +@register_provider_adapter( + "gsvi_tts_api", + "GSVI TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderGSVITTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000") + self.api_base = self.api_base.removesuffix("/") + self.character = provider_config.get("character") + self.emotion = provider_config.get("emotion") + + async def get_audio(self, text: str) -> str: + temp_dir = get_astrbot_temp_path() + path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") + params = {"text": text} + + if self.character: + params["character"] = self.character + if self.emotion: + params["emotion"] = self.emotion + + query_parts = [] + for key, value in params.items(): + encoded_value = urllib.parse.quote(str(value)) + query_parts.append(f"{key}={encoded_value}") + + url = f"{self.api_base}/tts?{'&'.join(query_parts)}" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + await asyncio.to_thread( + Path(path).write_bytes, await response.read() + ) + else: + error_text = await response.text() + raise Exception( + f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}", + ) + + return path diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py index 69860111cf..ad2e345367 100644 --- a/astrbot/core/provider/sources/minimax_tts_api_source.py +++ b/astrbot/core/provider/sources/minimax_tts_api_source.py @@ -1,7 +1,9 @@ +import asyncio import json import os import uuid from collections.abc import AsyncIterator +from pathlib import Path import aiohttp @@ -155,8 +157,7 @@ async def get_audio(self, text: str) -> str: audio = await self._audio_play(audio_stream) # 结果保存至文件 - with open(path, "wb") as file: - file.write(audio) + await asyncio.to_thread(Path(path).write_bytes, audio) return path diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index adee24073d..3f0c007b36 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -5,6 +5,7 @@ import random import re from collections.abc import AsyncGenerator +from pathlib import Path from typing import Any import httpx @@ -949,9 +950,10 @@ async def encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + image_bs64 = base64.b64encode( + await asyncio.to_thread(Path(image_url).read_bytes) + ).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 async def terminate(self): if self.client: diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index 217b189251..35ac1d5a8c 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -1,5 +1,7 @@ +import asyncio import os import uuid +from pathlib import Path import httpx from openai import NOT_GIVEN, AsyncOpenAI @@ -54,9 +56,10 @@ async def get_audio(self, text: str) -> str: response_format="wav", input=text, ) as response: - with open(path, "wb") as f: - async for chunk in response.iter_bytes(chunk_size=1024): - f.write(chunk) + audio_data = bytearray() + async for chunk in response.iter_bytes(chunk_size=1024): + audio_data.extend(chunk) + await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data)) return path async def terminate(self): diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index af6c0f631e..b776657968 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -53,14 +53,12 @@ async def initialize(self) -> None: async def get_timestamped_path(self) -> str: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) return str(temp_dir / timestamp) async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" - with open(file_path, "rb") as f: - file_header = f.read(8) - + file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8] if silk_header in file_header: return True return False @@ -76,7 +74,7 @@ async def get_text(self, audio_url: str) -> str: await download_file(audio_url, path) audio_url = path - if not os.path.isfile(audio_url): + if not await asyncio.to_thread(os.path.isfile, audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith((".amr", ".silk")) or is_tencent: diff --git a/astrbot/core/provider/sources/volcengine_tts.py b/astrbot/core/provider/sources/volcengine_tts.py index 349815907d..5082200717 100644 --- a/astrbot/core/provider/sources/volcengine_tts.py +++ b/astrbot/core/provider/sources/volcengine_tts.py @@ -4,6 +4,7 @@ import os import traceback import uuid +from pathlib import Path import aiohttp @@ -100,10 +101,9 @@ async def get_audio(self, text: str) -> str: f"volcengine_tts_{uuid.uuid4()}.mp3", ) - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, - lambda: open(file_path, "wb").write(audio_data), + await asyncio.to_thread( + Path(file_path).write_bytes, + audio_data, ) return file_path diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 386da063db..00c87075db 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -1,5 +1,7 @@ +import asyncio import os import uuid +from pathlib import Path from openai import NOT_GIVEN, AsyncOpenAI @@ -44,8 +46,7 @@ async def _get_audio_format(self, file_path) -> str | None: amr_header = b"#!AMR" try: - with open(file_path, "rb") as f: - file_header = f.read(8) + file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8] except FileNotFoundError: return None @@ -73,7 +74,7 @@ async def get_text(self, audio_url: str) -> str: await download_file(audio_url, path) audio_url = path - if not os.path.exists(audio_url): + if not await asyncio.to_thread(os.path.exists, audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: @@ -100,13 +101,14 @@ async def get_text(self, audio_url: str) -> str: audio_url = output_path + audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes) result = await self.client.audio.transcriptions.create( model=self.model_name, - file=("audio.wav", open(audio_url, "rb")), + file=("audio.wav", audio_bytes), ) # remove temp file - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(audio_url) except Exception as e: diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index 678deb9481..d85c84f9be 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -1,6 +1,7 @@ import asyncio import os import uuid +from pathlib import Path from typing import cast import whisper @@ -42,9 +43,7 @@ async def initialize(self) -> None: async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" - with open(file_path, "rb") as f: - file_header = f.read(8) - + file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8] if silk_header in file_header: return True return False @@ -66,7 +65,7 @@ async def get_text(self, audio_url: str) -> str: await download_file(audio_url, path) audio_url = path - if not os.path.exists(audio_url): + if not await asyncio.to_thread(os.path.exists, audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py index 0a22e456ed..7b6068dd85 100644 --- a/astrbot/core/provider/sources/xinference_stt_provider.py +++ b/astrbot/core/provider/sources/xinference_stt_provider.py @@ -1,5 +1,7 @@ +import asyncio import os import uuid +from pathlib import Path import aiohttp from xinference_client.client.restful.async_restful_client import ( @@ -102,9 +104,8 @@ async def get_text(self, audio_url: str) -> str: f"Failed to download audio from {audio_url}, status: {resp.status}", ) return "" - elif os.path.exists(audio_url): - with open(audio_url, "rb") as f: - audio_bytes = f.read() + elif await asyncio.to_thread(os.path.exists, audio_url): + audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes) else: logger.error(f"File not found: {audio_url}") return "" @@ -143,8 +144,7 @@ async def get_text(self, audio_url: str) -> str: ) temp_files.extend([input_path, output_path]) - with open(input_path, "wb") as f: - f.write(audio_bytes) + await asyncio.to_thread(Path(input_path).write_bytes, audio_bytes) if conversion_type == "silk": logger.info("Converting silk to wav ...") @@ -153,8 +153,7 @@ async def get_text(self, audio_url: str) -> str: logger.info("Converting amr to wav ...") await convert_to_pcm_wav(input_path, output_path) - with open(output_path, "rb") as f: - audio_bytes = f.read() + audio_bytes = await asyncio.to_thread(Path(output_path).read_bytes) # 4. Transcribe # 官方asyncCLient的客户端似乎实现有点问题,这里直接用aiohttp实现openai标准兼容请求,提交issue等待官方修复后再改回来 @@ -199,7 +198,7 @@ async def get_text(self, audio_url: str) -> str: # 5. Cleanup for temp_file in temp_files: try: - if os.path.exists(temp_file): + if await asyncio.to_thread(os.path.exists, temp_file): os.remove(temp_file) logger.debug(f"Removed temporary file: {temp_file}") except Exception as e: diff --git a/astrbot/core/skills/neo_skill_sync.py b/astrbot/core/skills/neo_skill_sync.py index 5fe2b7832d..2bb4c50f8b 100644 --- a/astrbot/core/skills/neo_skill_sync.py +++ b/astrbot/core/skills/neo_skill_sync.py @@ -5,7 +5,7 @@ import os import re from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -19,7 +19,7 @@ def _now_iso() -> str: - return datetime.now(timezone.utc).isoformat() + return datetime.now(UTC).isoformat() def _to_jsonable(model_like: Any) -> dict[str, Any]: diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py index d15876526d..a24ddac9ea 100644 --- a/astrbot/core/skills/skill_manager.py +++ b/astrbot/core/skills/skill_manager.py @@ -7,7 +7,7 @@ import tempfile import zipfile from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path, PurePosixPath from astrbot.core.utils.astrbot_path import ( @@ -175,7 +175,7 @@ def _load_sandbox_skills_cache(self) -> dict: def _save_sandbox_skills_cache(self, cache: dict) -> None: cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION - cache["updated_at"] = datetime.now(timezone.utc).isoformat() + cache["updated_at"] = datetime.now(UTC).isoformat() with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f: json.dump(cache, f, ensure_ascii=False, indent=2) diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index d28ac726ae..f6afc08e1c 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -3,7 +3,7 @@ import enum from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeVar, overload +from typing import Any, Literal, TypeVar, overload from .filter import HandlerFilter from .star import star_map @@ -11,7 +11,7 @@ T = TypeVar("T", bound="StarHandlerMetadata") -class StarHandlerRegistry(Generic[T]): +class StarHandlerRegistry[T: "StarHandlerMetadata"]: def __init__(self) -> None: self.star_handlers_map: dict[str, StarHandlerMetadata] = {} self._handlers: list[StarHandlerMetadata] = [] @@ -227,7 +227,7 @@ class EventType(enum.Enum): @dataclass -class StarHandlerMetadata(Generic[H]): +class StarHandlerMetadata[H: Callable[..., Any]]: """描述一个 Star 所注册的某一个 Handler。""" event_type: EventType diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 68c58fdae5..c5fa63bee7 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -8,6 +8,7 @@ import os import sys import traceback +from pathlib import Path from types import ModuleType import yaml @@ -188,7 +189,7 @@ async def _check_plugin_dept_update( 如果 target_plugin 为 None,则检查所有插件的依赖 """ plugin_dir = self.plugin_store_path - if not os.path.exists(plugin_dir): + if not await asyncio.to_thread(os.path.exists, plugin_dir): return False to_update = [] if target_plugin: @@ -198,7 +199,9 @@ async def _check_plugin_dept_update( to_update.append(p.root_dir_name) for p in to_update: plugin_path = os.path.join(plugin_dir, p) - if os.path.exists(os.path.join(plugin_path, "requirements.txt")): + if await asyncio.to_thread( + os.path.exists, os.path.join(plugin_path, "requirements.txt") + ): pth = os.path.join(plugin_path, "requirements.txt") logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}") try: @@ -217,7 +220,7 @@ async def _import_plugin_with_dependency_recovery( try: return __import__(path, fromlist=[module_str]) except (ModuleNotFoundError, ImportError) as import_exc: - if os.path.exists(requirements_path): + if await asyncio.to_thread(os.path.exists, requirements_path): try: logger.info( f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}" @@ -651,16 +654,19 @@ async def load( plugin_dir_path, self.conf_schema_fname, ) - if os.path.exists(plugin_schema_path): + if await asyncio.to_thread(os.path.exists, plugin_schema_path): # 加载插件配置 - with open(plugin_schema_path, encoding="utf-8") as f: - plugin_config = AstrBotConfig( - config_path=os.path.join( - self.plugin_config_path, - f"{root_dir_name}_config.json", - ), - schema=json.loads(f.read()), - ) + plugin_schema_text = await asyncio.to_thread( + Path(plugin_schema_path).read_text, + encoding="utf-8", + ) + plugin_config = AstrBotConfig( + config_path=os.path.join( + self.plugin_config_path, + f"{root_dir_name}_config.json", + ), + schema=json.loads(plugin_schema_text), + ) logo_path = os.path.join(plugin_dir_path, self.logo_fname) if path in star_map: @@ -836,7 +842,7 @@ async def load( metadata.activated = False # Plugin logo path - if os.path.exists(logo_path): + if await asyncio.to_thread(os.path.exists, logo_path): metadata.logo_path = logo_path assert metadata.module_path, f"插件 {metadata.name} 模块路径为空" @@ -955,7 +961,7 @@ async def _cleanup_failed_plugin_install( except Exception: logger.warning(traceback.format_exc()) - if os.path.exists(plugin_path): + if await asyncio.to_thread(os.path.exists, plugin_path): try: remove_dir(plugin_path) logger.warning(f"已清理安装失败的插件目录: {plugin_path}") @@ -968,7 +974,7 @@ async def _cleanup_failed_plugin_install( self.plugin_config_path, f"{dir_name}_config.json", ) - if os.path.exists(plugin_config_path): + if await asyncio.to_thread(os.path.exists, plugin_config_path): try: os.remove(plugin_config_path) logger.warning(f"已清理安装失败插件配置: {plugin_config_path}") @@ -1100,13 +1106,14 @@ async def install_plugin( # Extract README.md content if exists readme_content = None readme_path = os.path.join(plugin_path, "README.md") - if not os.path.exists(readme_path): + if not await asyncio.to_thread(os.path.exists, readme_path): readme_path = os.path.join(plugin_path, "readme.md") - if os.path.exists(readme_path): + if await asyncio.to_thread(os.path.exists, readme_path): try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() + readme_content = await asyncio.to_thread( + Path(readme_path).read_text, encoding="utf-8" + ) except Exception as e: logger.warning( f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}", @@ -1211,7 +1218,7 @@ async def uninstall_failed_plugin( self._cleanup_plugin_state(dir_name) plugin_path = os.path.join(self.plugin_store_path, dir_name) - if os.path.exists(plugin_path): + if await asyncio.to_thread(os.path.exists, plugin_path): try: remove_dir(plugin_path) except Exception as e: @@ -1498,13 +1505,14 @@ async def install_plugin_from_file( # Extract README.md content if exists readme_content = None readme_path = os.path.join(desti_dir, "README.md") - if not os.path.exists(readme_path): + if not await asyncio.to_thread(os.path.exists, readme_path): readme_path = os.path.join(desti_dir, "readme.md") - if os.path.exists(readme_path): + if await asyncio.to_thread(os.path.exists, readme_path): try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() + readme_content = await asyncio.to_thread( + Path(readme_path).read_text, encoding="utf-8" + ) except Exception as e: logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}") diff --git a/astrbot/core/utils/datetime_utils.py b/astrbot/core/utils/datetime_utils.py index 97b8196dde..431c9cd50c 100644 --- a/astrbot/core/utils/datetime_utils.py +++ b/astrbot/core/utils/datetime_utils.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime def normalize_datetime_utc(dt: datetime | None) -> datetime | None: @@ -9,8 +9,8 @@ def normalize_datetime_utc(dt: datetime | None) -> datetime | None: if dt is None: return None if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None: - return dt.replace(tzinfo=timezone.utc) - return dt.astimezone(timezone.utc) + return dt.replace(tzinfo=UTC) + return dt.astimezone(UTC) def to_utc_isoformat(dt: datetime | None) -> str | None: diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index b565926749..d169dd32b6 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -1,3 +1,4 @@ +import asyncio import base64 import logging import os @@ -58,8 +59,7 @@ def save_temp_img(img: Image.Image | bytes) -> str: if isinstance(img, Image.Image): img.save(p) else: - with open(p, "wb") as f: - f.write(img) + Path(p).write_bytes(img) return p @@ -83,15 +83,13 @@ async def download_image_by_url( async with session.post(url, json=post_data) as resp: if not path: return save_temp_img(await resp.read()) - with open(path, "wb") as f: - f.write(await resp.read()) + await asyncio.to_thread(Path(path).write_bytes, await resp.read()) return path else: async with session.get(url) as resp: if not path: return save_temp_img(await resp.read()) - with open(path, "wb") as f: - f.write(await resp.read()) + await asyncio.to_thread(Path(path).write_bytes, await resp.read()) return path except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): # 关闭SSL验证(仅在证书验证失败时作为fallback) @@ -109,15 +107,13 @@ async def download_image_by_url( async with session.post(url, json=post_data, ssl=ssl_context) as resp: if not path: return save_temp_img(await resp.read()) - with open(path, "wb") as f: - f.write(await resp.read()) + await asyncio.to_thread(Path(path).write_bytes, await resp.read()) return path else: async with session.get(url, ssl=ssl_context) as resp: if not path: return save_temp_img(await resp.read()) - with open(path, "wb") as f: - f.write(await resp.read()) + await asyncio.to_thread(Path(path).write_bytes, await resp.read()) return path except Exception as e: raise e @@ -142,12 +138,13 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - with open(path, "wb") as f: + file_obj = await asyncio.to_thread(Path(path).open, "wb") + try: while True: chunk = await resp.content.read(8192) if not chunk: break - f.write(chunk) + await asyncio.to_thread(file_obj.write, chunk) downloaded_size += len(chunk) if show_progress: elapsed_time = ( @@ -160,6 +157,8 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end="", ) + finally: + await asyncio.to_thread(file_obj.close) except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): # 关闭SSL验证(仅在证书验证失败时作为fallback) logger.warning( @@ -181,12 +180,13 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - with open(path, "wb") as f: + file_obj = await asyncio.to_thread(Path(path).open, "wb") + try: while True: chunk = await resp.content.read(8192) if not chunk: break - f.write(chunk) + await asyncio.to_thread(file_obj.write, chunk) downloaded_size += len(chunk) if show_progress: elapsed_time = time.time() - start_time @@ -195,14 +195,15 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end="", ) + finally: + await asyncio.to_thread(file_obj.close) if show_progress: print() -def file_to_base64(file_path: str) -> str: - with open(file_path, "rb") as f: - data_bytes = f.read() - base64_str = base64.b64encode(data_bytes).decode() +async def file_to_base64(file_path: str) -> str: + data_bytes = await asyncio.to_thread(Path(file_path).read_bytes) + base64_str = base64.b64encode(data_bytes).decode() return "base64://" + base64_str @@ -221,17 +222,18 @@ def get_local_ip_addresses(): async def get_dashboard_version(): # First check user data directory (manually updated / downloaded dashboard). dist_dir = os.path.join(get_astrbot_data_path(), "dist") - if not os.path.exists(dist_dir): + if not await asyncio.to_thread(os.path.exists, dist_dir): # Fall back to the dist bundled inside the installed wheel. _bundled = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist" - if _bundled.exists(): + if await asyncio.to_thread(_bundled.exists): dist_dir = str(_bundled) - if os.path.exists(dist_dir): + if await asyncio.to_thread(os.path.exists, dist_dir): version_file = os.path.join(dist_dir, "assets", "version") - if os.path.exists(version_file): - with open(version_file, encoding="utf-8") as f: - v = f.read().strip() - return v + if await asyncio.to_thread(os.path.exists, version_file): + v = ( + await asyncio.to_thread(Path(version_file).read_text, encoding="utf-8") + ).strip() + return v return None @@ -244,9 +246,12 @@ async def download_dashboard( ) -> None: """下载管理面板文件""" if path is None: - zip_path = Path(get_astrbot_data_path()).absolute() / "dashboard.zip" + zip_path = ( + await asyncio.to_thread(Path(get_astrbot_data_path()).absolute) + / "dashboard.zip" + ) else: - zip_path = Path(path).absolute() + zip_path = await asyncio.to_thread(Path(path).absolute) if latest or len(str(version)) != 40: ver_name = "latest" if latest else version diff --git a/astrbot/core/utils/media_utils.py b/astrbot/core/utils/media_utils.py index 8d833514fb..7ecebcad43 100644 --- a/astrbot/core/utils/media_utils.py +++ b/astrbot/core/utils/media_utils.py @@ -108,7 +108,7 @@ async def convert_audio_to_opus(audio_path: str, output_path: str | None = None) if process.returncode != 0: # 清理可能已生成但无效的临时文件 - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(output_path) logger.debug( @@ -183,7 +183,7 @@ async def convert_video_format( if process.returncode != 0: # 清理可能已生成但无效的临时文件 - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(output_path) logger.debug( @@ -231,7 +231,7 @@ async def convert_audio_format( if output_path is None: temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) output_path = str(temp_dir / f"media_audio_{uuid.uuid4().hex}.{output_format}") args = ["ffmpeg", "-y", "-i", audio_path] @@ -249,7 +249,7 @@ async def convert_audio_format( ) _, stderr = await process.communicate() if process.returncode != 0: - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(output_path) except OSError as e: @@ -287,7 +287,7 @@ async def extract_video_cover( """从视频中提取封面图(JPG)。""" if output_path is None: temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) output_path = str(temp_dir / f"media_cover_{uuid.uuid4().hex}.jpg") try: @@ -306,7 +306,7 @@ async def extract_video_cover( ) _, stderr = await process.communicate() if process.returncode != 0: - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(output_path) except OSError as e: diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index b327a61843..a6c62c4952 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -71,11 +71,11 @@ def keep(self, timeout: float = 0, reset_timeout=False) -> None: asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep - async def _holding(self, event: asyncio.Event, timeout: float) -> None: + async def _holding(self, event: asyncio.Event, timeout_seconds: float) -> None: """等待事件结束或超时""" try: - await asyncio.wait_for(event.wait(), timeout) - except asyncio.TimeoutError: + await asyncio.wait_for(event.wait(), timeout_seconds) + except TimeoutError: if not self.future.done(): self.future.set_exception(TimeoutError("等待超时")) except asyncio.CancelledError: @@ -124,14 +124,14 @@ def __init__( async def register_wait( self, handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]], - timeout: int = 30, + timeout_seconds: int = 30, ) -> Any: """等待外部输入并处理""" self.handler = handler USER_SESSIONS[self.session_id] = self # 开始一个会话保持事件 - self.session_controller.keep(timeout, reset_timeout=True) + self.session_controller.keep(timeout_seconds, reset_timeout=True) try: return await self.session_controller.future diff --git a/astrbot/core/utils/temp_dir_cleaner.py b/astrbot/core/utils/temp_dir_cleaner.py index c0c0600982..668ee45135 100644 --- a/astrbot/core/utils/temp_dir_cleaner.py +++ b/astrbot/core/utils/temp_dir_cleaner.py @@ -141,7 +141,7 @@ async def run(self) -> None: self._stop_event.wait(), timeout=self.CHECK_INTERVAL_SECONDS, ) - except asyncio.TimeoutError: + except TimeoutError: continue logger.info("TempDirCleaner stopped.") diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py index f342484bdb..1abd6d1c0a 100644 --- a/astrbot/core/utils/tencent_record_helper.py +++ b/astrbot/core/utils/tencent_record_helper.py @@ -5,6 +5,7 @@ import tempfile import wave from io import BytesIO +from pathlib import Path from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_temp_path @@ -13,19 +14,18 @@ async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str: import pysilk - with open(silk_path, "rb") as f: - input_data = f.read() - if input_data.startswith(b"\x02"): - input_data = input_data[1:] - input_io = BytesIO(input_data) - output_io = BytesIO() - pysilk.decode(input_io, output_io, 24000) - output_io.seek(0) - with wave.open(output_path, "wb") as wav: - wav.setnchannels(1) - wav.setsampwidth(2) - wav.setframerate(24000) - wav.writeframes(output_io.read()) + input_data = await asyncio.to_thread(Path(silk_path).read_bytes) + if input_data.startswith(b"\x02"): + input_data = input_data[1:] + input_io = BytesIO(input_data) + output_io = BytesIO() + pysilk.decode(input_io, output_io, 24000) + output_io.seek(0) + with wave.open(output_path, "wb") as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(output_io.read()) return output_path @@ -97,7 +97,10 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str: logger.debug(f"[FFmpeg] stderr: {stderr.decode().strip()}") logger.info(f"[FFmpeg] return code: {p.returncode}") - if os.path.exists(output_path) and os.path.getsize(output_path) > 0: + if ( + await asyncio.to_thread(os.path.exists, output_path) + and await asyncio.to_thread(os.path.getsize, output_path) > 0 + ): return output_path raise RuntimeError("生成的WAV文件不存在或为空") @@ -156,13 +159,12 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]: tencent=True, ) - with open(silk_path, "rb") as f: - silk_bytes = await asyncio.to_thread(f.read) - silk_b64 = base64.b64encode(silk_bytes).decode("utf-8") + silk_bytes = await asyncio.to_thread(Path(silk_path).read_bytes) + silk_b64 = base64.b64encode(silk_bytes).decode("utf-8") return silk_b64, duration # 已是秒 finally: - if os.path.exists(wav_path) and wav_path != audio_path: + if await asyncio.to_thread(os.path.exists, wav_path) and wav_path != audio_path: os.remove(wav_path) - if os.path.exists(silk_path): + if await asyncio.to_thread(os.path.exists, silk_path): os.remove(silk_path) diff --git a/astrbot/dashboard/routes/api_key.py b/astrbot/dashboard/routes/api_key.py index 4b957fe8ea..6d89de910c 100644 --- a/astrbot/dashboard/routes/api_key.py +++ b/astrbot/dashboard/routes/api_key.py @@ -1,6 +1,6 @@ import hashlib import secrets -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from quart import g, request @@ -59,7 +59,7 @@ def _serialize_api_key(key) -> dict: "expires_at": ApiKeyRoute._serialize_datetime(key.expires_at), "revoked_at": ApiKeyRoute._serialize_datetime(key.revoked_at), "is_revoked": key.revoked_at is not None, - "is_expired": bool(expires_at and expires_at < datetime.now(timezone.utc)), + "is_expired": bool(expires_at and expires_at < datetime.now(UTC)), } async def list_api_keys(self): @@ -98,9 +98,7 @@ async def create_api_key(self): return ( Response().error("expires_in_days must be greater than 0").__dict__ ) - expires_at = datetime.now(timezone.utc) + timedelta( - days=expires_in_days_int - ) + expires_at = datetime.now(UTC) + timedelta(days=expires_in_days_int) raw_key = f"abk_{secrets.token_urlsafe(32)}" key_hash = self._hash_key(raw_key) diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index 40db1f60bd..f9bdc51d8f 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -82,7 +82,7 @@ async def edit_account(self): def generate_jwt(self, username): payload = { "username": username, - "exp": datetime.datetime.utcnow() + datetime.timedelta(days=7), + "exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=7), } jwt_token = self.config["dashboard"].get("jwt_secret", None) if not jwt_token: diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index 952806beb7..674bbbfdda 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -32,6 +32,18 @@ UPLOAD_EXPIRE_SECONDS = 3600 # 上传会话过期时间(1小时) +def _merge_backup_chunks(output_path: str, chunk_dir: str, total: int) -> None: + with open(output_path, "wb") as outfile: + for i in range(total): + chunk_path = os.path.join(chunk_dir, f"{i}.part") + with open(chunk_path, "rb") as chunk_file: + while True: + data_block = chunk_file.read(8192) + if not data_block: + break + outfile.write(data_block) + + def secure_filename(filename: str) -> str: """清洗文件名,移除路径遍历字符和危险字符 @@ -240,7 +252,7 @@ async def _cleanup_upload_session(self, upload_id: str) -> None: if upload_id in self.upload_sessions: session = self.upload_sessions[upload_id] chunk_dir = session.get("chunk_dir") - if chunk_dir and os.path.exists(chunk_dir): + if chunk_dir and await asyncio.to_thread(os.path.exists, chunk_dir): try: shutil.rmtree(chunk_dir) except Exception as e: @@ -283,7 +295,9 @@ async def list_backups(self): page_size = request.args.get("page_size", 20, type=int) # 确保备份目录存在 - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread( + Path(self.backup_dir).mkdir, parents=True, exist_ok=True + ) # 获取所有备份文件 backup_files = [] @@ -293,7 +307,7 @@ async def list_backups(self): continue file_path = os.path.join(self.backup_dir, filename) - if not os.path.isfile(file_path): + if not await asyncio.to_thread(os.path.isfile, file_path): continue # 读取 manifest.json 获取备份信息 @@ -403,7 +417,7 @@ async def _background_export_task(self, task_id: str) -> None: result={ "filename": os.path.basename(zip_path), "path": zip_path, - "size": os.path.getsize(zip_path), + "size": await asyncio.to_thread(os.path.getsize, zip_path), }, ) except Exception as e: @@ -437,7 +451,9 @@ async def upload_backup(self): unique_filename = generate_unique_filename(safe_filename) # 保存上传的文件 - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread( + Path(self.backup_dir).mkdir, parents=True, exist_ok=True + ) zip_path = os.path.join(self.backup_dir, unique_filename) await file.save(zip_path) @@ -451,7 +467,7 @@ async def upload_backup(self): { "filename": unique_filename, "original_filename": file.filename, - "size": os.path.getsize(zip_path), + "size": await asyncio.to_thread(os.path.getsize, zip_path), } ) .__dict__ @@ -499,7 +515,7 @@ async def upload_init(self): # 创建分片存储目录 chunk_dir = os.path.join(self.chunks_dir, upload_id) - Path(chunk_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(Path(chunk_dir).mkdir, parents=True, exist_ok=True) # 清洗文件名 safe_filename = secure_filename(filename) @@ -685,22 +701,20 @@ async def upload_complete(self): chunk_dir = session["chunk_dir"] filename = session["filename"] - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread( + Path(self.backup_dir).mkdir, parents=True, exist_ok=True + ) output_path = os.path.join(self.backup_dir, filename) try: - with open(output_path, "wb") as outfile: - for i in range(total): - chunk_path = os.path.join(chunk_dir, f"{i}.part") - with open(chunk_path, "rb") as chunk_file: - # 分块读取,避免内存溢出 - while True: - data_block = chunk_file.read(8192) - if not data_block: - break - outfile.write(data_block) - - file_size = os.path.getsize(output_path) + await asyncio.to_thread( + _merge_backup_chunks, + output_path, + chunk_dir, + total, + ) + + file_size = await asyncio.to_thread(os.path.getsize, output_path) # 标记备份为上传来源(修改 manifest.json 中的 origin 字段) self._mark_backup_as_uploaded(output_path) @@ -725,7 +739,7 @@ async def upload_complete(self): ) except Exception as e: # 如果合并失败,删除不完整的文件 - if os.path.exists(output_path): + if await asyncio.to_thread(os.path.exists, output_path): os.remove(output_path) raise e @@ -787,7 +801,7 @@ async def check_backup(self): return Response().error("无效的文件名").__dict__ zip_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(zip_path): + if not await asyncio.to_thread(os.path.exists, zip_path): return Response().error(f"备份文件不存在: {filename}").__dict__ # 获取知识库管理器(用于构造 importer) @@ -841,7 +855,7 @@ async def import_backup(self): return Response().error("无效的文件名").__dict__ zip_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(zip_path): + if not await asyncio.to_thread(os.path.exists, zip_path): return Response().error(f"备份文件不存在: {filename}").__dict__ # 生成任务ID @@ -988,7 +1002,7 @@ async def download_backup(self): return Response().error("无效的文件名").__dict__ file_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(file_path): + if not await asyncio.to_thread(os.path.exists, file_path): return Response().error("备份文件不存在").__dict__ return await send_file( @@ -1019,7 +1033,7 @@ async def delete_backup(self): return Response().error("无效的文件名").__dict__ file_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(file_path): + if not await asyncio.to_thread(os.path.exists, file_path): return Response().error("备份文件不存在").__dict__ os.remove(file_path) @@ -1067,12 +1081,12 @@ async def rename_backup(self): # 检查原文件是否存在 old_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(old_path): + if not await asyncio.to_thread(os.path.exists, old_path): return Response().error("备份文件不存在").__dict__ # 检查新文件名是否已存在 new_path = os.path.join(self.backup_dir, new_filename) - if os.path.exists(new_path): + if await asyncio.to_thread(os.path.exists, new_path): return Response().error(f"文件名 '{new_filename}' 已存在").__dict__ # 执行重命名 diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index a914f3cbf0..f76aa7f9f1 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -80,17 +80,23 @@ async def get_file(self): try: file_path = os.path.join(self.attachments_dir, os.path.basename(filename)) - real_file_path = os.path.realpath(file_path) - real_imgs_dir = os.path.realpath(self.attachments_dir) + real_file_path = await asyncio.to_thread(os.path.realpath, file_path) + real_imgs_dir = await asyncio.to_thread( + os.path.realpath, self.attachments_dir + ) - if not os.path.exists(real_file_path): + if not await asyncio.to_thread(os.path.exists, real_file_path): # try legacy file_path = os.path.join( self.legacy_img_dir, os.path.basename(filename) ) - if os.path.exists(file_path): - real_file_path = os.path.realpath(file_path) - real_imgs_dir = os.path.realpath(self.legacy_img_dir) + if await asyncio.to_thread(os.path.exists, file_path): + real_file_path = await asyncio.to_thread( + os.path.realpath, file_path + ) + real_imgs_dir = await asyncio.to_thread( + os.path.realpath, self.legacy_img_dir + ) if not real_file_path.startswith(real_imgs_dir): return Response().error("Invalid file path").__dict__ @@ -117,7 +123,7 @@ async def get_attachment(self): return Response().error("Attachment not found").__dict__ file_path = attachment.path - real_file_path = os.path.realpath(file_path) + real_file_path = await asyncio.to_thread(os.path.realpath, file_path) return await send_file(real_file_path, mimetype=attachment.mime_type) @@ -344,7 +350,7 @@ async def stream(): while True: try: result = await asyncio.wait_for(back_queue.get(), timeout=1) - except asyncio.TimeoutError: + except TimeoutError: continue except asyncio.CancelledError: logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") @@ -652,7 +658,7 @@ async def _delete_attachments(self, attachment_ids: list[str]) -> None: try: attachments = await self.db.get_attachments(attachment_ids) for attachment in attachments: - if not os.path.exists(attachment.path): + if not await asyncio.to_thread(os.path.exists, attachment.path): continue try: os.remove(attachment.path) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 823d0fb9dd..1ed80a218d 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1103,7 +1103,9 @@ async def upload_config_file(self): if not files: return Response().error("No files uploaded").__dict__ - storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) + storage_root_path = await asyncio.to_thread( + Path(get_astrbot_plugin_data_path()).resolve, strict=False + ) plugin_root_path = (storage_root_path / name).resolve(strict=False) try: plugin_root_path.relative_to(storage_root_path) @@ -1179,7 +1181,9 @@ async def delete_config_file(self): if not md: return Response().error(f"Plugin {name} not found").__dict__ - storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) + storage_root_path = await asyncio.to_thread( + Path(get_astrbot_plugin_data_path()).resolve, strict=False + ) plugin_root_path = (storage_root_path / name).resolve(strict=False) try: plugin_root_path.relative_to(storage_root_path) @@ -1207,7 +1211,9 @@ async def get_config_file_list(self): if not meta or meta.get("type") != "file": return Response().error("Config item not found or not file type").__dict__ - storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) + storage_root_path = await asyncio.to_thread( + Path(get_astrbot_plugin_data_path()).resolve, strict=False + ) plugin_root_path = (storage_root_path / name).resolve(strict=False) try: plugin_root_path.relative_to(storage_root_path) @@ -1375,7 +1381,7 @@ async def _register_platform_logo(self, platform, platform_default_tmpl) -> None logo_file_path = os.path.join(plugin_dir, platform.logo_path) # 检查文件是否存在并注册令牌 - if os.path.exists(logo_file_path): + if await asyncio.to_thread(os.path.exists, logo_file_path): logo_token = await file_token_service.register_file( logo_file_path, timeout=3600, diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index f0ac5d43d0..d06414c016 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -729,7 +729,7 @@ async def upload_document(self): ) finally: # 清理临时文件 - if os.path.exists(temp_file_path): + if await asyncio.to_thread(os.path.exists, temp_file_path): os.remove(temp_file_path) # 获取知识库 diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 8d0af938d0..58398d24c7 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -86,7 +86,7 @@ async def end_speaking(self, stamp: str) -> tuple[str | None, float]: self.temp_audio_path = audio_path logger.info( - f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {os.path.getsize(audio_path)} bytes" + f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {await asyncio.to_thread(os.path.getsize, audio_path)} bytes" ) return audio_path, time.time() - start_time @@ -491,7 +491,7 @@ async def _handle_chat_message( try: result = await asyncio.wait_for(back_queue.get(), timeout=1) - except asyncio.TimeoutError: + except TimeoutError: continue if not result: @@ -790,7 +790,7 @@ async def _process_audio( try: result = await asyncio.wait_for(back_queue.get(), timeout=0.5) - except asyncio.TimeoutError: + except TimeoutError: continue if not result: diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index 9a736b1763..763d05db03 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -369,7 +369,7 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: while True: try: result = await asyncio.wait_for(back_queue.get(), timeout=1) - except asyncio.TimeoutError: + except TimeoutError: continue if not result: diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index bb7769926a..f3d1d69eed 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -6,6 +6,7 @@ import traceback from dataclasses import dataclass from datetime import datetime +from pathlib import Path import aiohttp import certifi @@ -738,19 +739,20 @@ async def get_plugin_readme(self): plugin_obj.root_dir_name, ) - if not os.path.isdir(plugin_dir): + if not await asyncio.to_thread(os.path.isdir, plugin_dir): logger.warning(f"无法找到插件目录: {plugin_dir}") return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ readme_path = os.path.join(plugin_dir, "README.md") - if not os.path.isfile(readme_path): + if not await asyncio.to_thread(os.path.isfile, readme_path): logger.warning(f"插件 {plugin_name} 没有README文件") return Response().error(f"插件 {plugin_name} 没有README文件").__dict__ try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() + readme_content = await asyncio.to_thread( + Path(readme_path).read_text, encoding="utf-8" + ) return ( Response() @@ -799,7 +801,7 @@ async def get_plugin_changelog(self): plugin_obj.root_dir_name, ) - if not os.path.isdir(plugin_dir): + if not await asyncio.to_thread(os.path.isdir, plugin_dir): logger.warning(f"无法找到插件目录: {plugin_dir}") return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ @@ -807,10 +809,11 @@ async def get_plugin_changelog(self): changelog_names = ["CHANGELOG.md", "changelog.md", "CHANGELOG", "changelog"] for name in changelog_names: changelog_path = os.path.join(plugin_dir, name) - if os.path.isfile(changelog_path): + if await asyncio.to_thread(os.path.isfile, changelog_path): try: - with open(changelog_path, encoding="utf-8") as f: - changelog_content = f.read() + changelog_content = await asyncio.to_thread( + Path(changelog_path).read_text, encoding="utf-8" + ) return ( Response() .ok({"content": changelog_content}, "成功获取更新日志") diff --git a/astrbot/dashboard/routes/skills.py b/astrbot/dashboard/routes/skills.py index adad49615f..b003e2010d 100644 --- a/astrbot/dashboard/routes/skills.py +++ b/astrbot/dashboard/routes/skills.py @@ -1,3 +1,4 @@ +import asyncio import os import re import shutil @@ -182,7 +183,7 @@ async def upload_skill(self): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ finally: - if temp_path and os.path.exists(temp_path): + if temp_path and await asyncio.to_thread(os.path.exists, temp_path): try: os.remove(temp_path) except Exception: diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 532238ac7a..238b6aa4cc 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -1,3 +1,4 @@ +import asyncio import os import re import threading @@ -214,13 +215,17 @@ async def get_changelog(self): changelog_path = os.path.join(changelogs_dir, filename) # 规范化路径,防止符号链接攻击 - changelog_path = os.path.realpath(changelog_path) - changelogs_dir = os.path.realpath(changelogs_dir) + changelog_path = await asyncio.to_thread(os.path.realpath, changelog_path) + changelogs_dir = await asyncio.to_thread(os.path.realpath, changelogs_dir) # 验证最终路径在预期的 changelogs 目录内(防止路径遍历) # 确保规范化后的路径以 changelogs_dir 开头,且是目录内的文件 - changelog_path_normalized = os.path.normpath(changelog_path) - changelogs_dir_normalized = os.path.normpath(changelogs_dir) + changelog_path_normalized = await asyncio.to_thread( + os.path.normpath, changelog_path + ) + changelogs_dir_normalized = await asyncio.to_thread( + os.path.normpath, changelogs_dir + ) # 检查路径是否在预期目录内(必须是目录的子文件,不能是目录本身) expected_prefix = changelogs_dir_normalized + os.sep @@ -230,21 +235,22 @@ async def get_changelog(self): ) return Response().error("Invalid version format").__dict__ - if not os.path.exists(changelog_path): + if not await asyncio.to_thread(os.path.exists, changelog_path): return ( Response() .error(f"Changelog for version {version} not found") .__dict__ ) - if not os.path.isfile(changelog_path): + if not await asyncio.to_thread(os.path.isfile, changelog_path): return ( Response() .error(f"Changelog for version {version} not found") .__dict__ ) - with open(changelog_path, encoding="utf-8") as f: - content = f.read() + content = await asyncio.to_thread( + Path(changelog_path).read_text, encoding="utf-8" + ) return Response().ok({"content": content, "version": version}).__dict__ except Exception as e: @@ -257,7 +263,7 @@ async def list_changelog_versions(self): project_path = get_astrbot_path() changelogs_dir = os.path.join(project_path, "changelogs") - if not os.path.exists(changelogs_dir): + if not await asyncio.to_thread(os.path.exists, changelogs_dir): return Response().ok({"versions": []}).__dict__ versions = [] diff --git a/main.py b/main.py index 36c46fca33..b8c42d78ea 100644 --- a/main.py +++ b/main.py @@ -69,13 +69,13 @@ async def check_dashboard_files(webui_dir: str | None = None): """下载管理面板文件""" # 指定webui目录 if webui_dir: - if os.path.exists(webui_dir): + if await asyncio.to_thread(os.path.exists, webui_dir): logger.info(f"使用指定的 WebUI 目录: {webui_dir}") return webui_dir logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。") data_dist_path = os.path.join(get_astrbot_data_path(), "dist") - if os.path.exists(data_dist_path): + if await asyncio.to_thread(os.path.exists, data_dist_path): v = await get_dashboard_version() if v is not None: # 存在文件 diff --git a/pyproject.toml b/pyproject.toml index d981c24708..c818d48c70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ astrbot = "astrbot.cli.__main__:cli" [tool.ruff] exclude = ["astrbot/core/utils/t2i/local_strategy.py", "astrbot/api/all.py", "tests"] line-length = 88 -target-version = "py310" +target-version = "py312" [tool.ruff.lint] select = [ @@ -99,13 +99,11 @@ ignore = [ "F403", "F405", "E501", - "ASYNC230", # TODO: handle ASYNC230 in AstrBot - "ASYNC240", # TODO: handle ASYNC240 in AstrBot ] [tool.pyright] typeCheckingMode = "basic" -pythonVersion = "3.10" +pythonVersion = "3.12" reportMissingTypeStubs = false reportMissingImports = false include = ["astrbot"] diff --git a/tests/test_skill_manager_sandbox_cache.py b/tests/test_skill_manager_sandbox_cache.py index 88923ec10b..5707148c6d 100644 --- a/tests/test_skill_manager_sandbox_cache.py +++ b/tests/test_skill_manager_sandbox_cache.py @@ -2,6 +2,8 @@ from pathlib import Path +import pytest + from astrbot.core.skills.skill_manager import SkillManager @@ -56,7 +58,7 @@ def test_list_skills_merges_local_and_sandbox_cache(monkeypatch, tmp_path: Path) assert by_name["custom-local"].description == "local description" assert by_name["custom-local"].path == "skills/custom-local/SKILL.md" assert by_name["python-sandbox"].description == "ship built-in" - assert by_name["python-sandbox"].path == "skills/python-sandbox/SKILL.md" + assert by_name["python-sandbox"].path == "/workspace/skills/python-sandbox/SKILL.md" def test_sandbox_cached_skill_respects_active_and_display_path( @@ -98,7 +100,8 @@ def test_sandbox_cached_skill_respects_active_and_display_path( assert len(all_skills) == 1 assert all_skills[0].path == "/app/skills/browser-automation/SKILL.md" - mgr.set_skill_active("browser-automation", False) + with pytest.raises(PermissionError): + mgr.set_skill_active("browser-automation", False) active_skills = mgr.list_skills(runtime="sandbox", active_only=True) - assert active_skills == [] - + assert len(active_skills) == 1 + assert active_skills[0].name == "browser-automation" diff --git a/tests/unit/test_io_file_to_base64.py b/tests/unit/test_io_file_to_base64.py new file mode 100644 index 0000000000..b490ffed86 --- /dev/null +++ b/tests/unit/test_io_file_to_base64.py @@ -0,0 +1,16 @@ +import base64 + +import pytest + +from astrbot.core.utils.io import file_to_base64 + + +@pytest.mark.asyncio +async def test_file_to_base64_reads_file_async(tmp_path): + sample_file = tmp_path / "sample.bin" + sample_file.write_bytes(b"astrbot") + + result = await file_to_base64(str(sample_file)) + + expected = "base64://" + base64.b64encode(b"astrbot").decode() + assert result == expected From 9cfd9e9fd2d8d28af23c86f377b191bc18fa669a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A8=E3=82=A4=E3=82=AB=E3=82=AF?= <62183434+zouyonghe@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:26:34 +0900 Subject: [PATCH 2/4] feat: optimize async io performance and benchmark coverage (#5737) * docs: align deployment sections across multilingual readmes * docs: normalize deployment punctuation and AUR guidance * docs: fix french and russian deployment wording * perf: optimize async io hot paths and extend benchmarks * fix: address async io review feedback * fix: address follow-up async io review comments * fix: align base64 io error handling in message components * fix: harden attachment export ids and tune io chunking * fix: preserve best-effort attachment export and batch writes * test: expand path conversion and helper coverage --- astrbot/core/backup/exporter.py | 51 +++- astrbot/core/message/components.py | 63 +++-- astrbot/core/utils/io.py | 98 ++++--- tests/fixtures/helpers.py | 20 ++ tests/performance/test_benchmarks.py | 268 ++++++++++++++++++++ tests/test_backup.py | 98 +++++++ tests/unit/test_fixture_helpers.py | 43 ++++ tests/unit/test_io_download_file.py | 71 ++++++ tests/unit/test_message_components_paths.py | 178 +++++++++++++ 9 files changed, 824 insertions(+), 66 deletions(-) create mode 100644 tests/performance/test_benchmarks.py create mode 100644 tests/unit/test_fixture_helpers.py create mode 100644 tests/unit/test_io_download_file.py create mode 100644 tests/unit/test_message_components_paths.py diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py index d65ac7a843..5658bf23a2 100644 --- a/astrbot/core/backup/exporter.py +++ b/astrbot/core/backup/exporter.py @@ -161,10 +161,10 @@ async def export_all( # 3. 导出配置文件 if progress_callback: await progress_callback("config", 0, 100, "正在导出配置文件...") - if await asyncio.to_thread(os.path.exists, self.config_path): - config_content = await asyncio.to_thread( - Path(self.config_path).read_text, encoding="utf-8" - ) + config_content = await asyncio.to_thread( + self._read_text_if_exists, self.config_path + ) + if config_content is not None: zf.writestr("config/cmd_config.json", config_content) self._add_checksum("config/cmd_config.json", config_content) if progress_callback: @@ -361,17 +361,44 @@ async def _export_attachments( self, zf: zipfile.ZipFile, attachments: list[dict] ) -> None: """导出附件文件""" + await asyncio.to_thread(self._export_attachments_sync, zf, attachments) + + def _export_attachments_sync( + self, zf: zipfile.ZipFile, attachments: list[dict] + ) -> None: + """在单个线程中批量导出附件,减少高频线程切换。""" for attachment in attachments: + file_path = attachment.get("path", "") + attachment_id = attachment.get("attachment_id") try: - file_path = attachment.get("path", "") - if file_path and await asyncio.to_thread(os.path.exists, file_path): - # 使用 attachment_id 作为文件名 - attachment_id = attachment.get("attachment_id", "") - ext = os.path.splitext(file_path)[1] - archive_path = f"files/attachments/{attachment_id}{ext}" - zf.write(file_path, archive_path) + if not file_path: + continue + if not attachment_id: + logger.warning( + f"跳过附件导出:attachment_id 为空 (path={file_path})" + ) + continue + # 使用 attachment_id 作为文件名 + ext = os.path.splitext(file_path)[1] + archive_path = f"files/attachments/{attachment_id}{ext}" + zf.write(file_path, archive_path) + except FileNotFoundError: + # 和旧逻辑保持一致:缺失文件直接跳过。 + continue + except OSError as e: + logger.warning( + f"导出附件失败 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}" + ) except Exception as e: - logger.warning(f"导出附件失败: {e}") + logger.warning( + f"导出附件时发生非预期错误,已跳过 (path={file_path}, attachment_id={attachment_id or 'unknown'}): {e}" + ) + + def _read_text_if_exists(self, file_path: str) -> str | None: + """Read text file when it exists in a single synchronous call.""" + if not os.path.exists(file_path): + return None + return Path(file_path).read_text(encoding="utf-8") def _model_to_dict(self, record: Any) -> dict: """将 SQLModel 实例转换为字典 diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 038e997424..901dcd2ffb 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -40,6 +40,16 @@ from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 +def _absolute_path(path: str) -> str: + return os.path.abspath(path) + + +def _absolute_path_if_exists(path: str | None) -> str | None: + if not path or not os.path.exists(path): + return None + return os.path.abspath(path) + + class ComponentType(StrEnum): # Basic Segment Types Plain = "Plain" # plain text message @@ -159,7 +169,7 @@ async def convert_to_file_path(self) -> str: return self.file[8:] if self.file.startswith("http"): file_path = await download_image_by_url(self.file) - return await asyncio.to_thread(os.path.abspath, file_path) + return await asyncio.to_thread(_absolute_path, file_path) if self.file.startswith("base64://"): bs64_data = self.file.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) @@ -167,9 +177,10 @@ async def convert_to_file_path(self) -> str: get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg" ) await asyncio.to_thread(Path(file_path).write_bytes, image_bytes) - return await asyncio.to_thread(os.path.abspath, file_path) - if await asyncio.to_thread(os.path.exists, self.file): - return await asyncio.to_thread(os.path.abspath, self.file) + return await asyncio.to_thread(_absolute_path, file_path) + local_path = await asyncio.to_thread(_absolute_path_if_exists, self.file) + if local_path: + return local_path raise Exception(f"not a valid file: {self.file}") async def convert_to_base64(self) -> str: @@ -189,10 +200,11 @@ async def convert_to_base64(self) -> str: bs64_data = await file_to_base64(file_path) elif self.file.startswith("base64://"): bs64_data = self.file - elif await asyncio.to_thread(os.path.exists, self.file): - bs64_data = await file_to_base64(self.file) else: - raise Exception(f"not a valid file: {self.file}") + try: + bs64_data = await file_to_base64(self.file) + except OSError as exc: + raise Exception(f"not a valid file: {self.file}") from exc bs64_data = bs64_data.removeprefix("base64://") return bs64_data @@ -256,11 +268,15 @@ async def convert_to_file_path(self) -> str: get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}" ) await download_file(url, video_file_path) - if await asyncio.to_thread(os.path.exists, video_file_path): - return await asyncio.to_thread(os.path.abspath, video_file_path) + local_path = await asyncio.to_thread( + _absolute_path_if_exists, video_file_path + ) + if local_path: + return local_path raise Exception(f"download failed: {url}") - if await asyncio.to_thread(os.path.exists, url): - return await asyncio.to_thread(os.path.abspath, url) + local_path = await asyncio.to_thread(_absolute_path_if_exists, url) + if local_path: + return local_path raise Exception(f"not a valid file: {url}") async def register_to_file_service(self) -> str: @@ -449,7 +465,7 @@ async def convert_to_file_path(self) -> str: return url[8:] if url.startswith("http"): image_file_path = await download_image_by_url(url) - return await asyncio.to_thread(os.path.abspath, image_file_path) + return await asyncio.to_thread(_absolute_path, image_file_path) if url.startswith("base64://"): bs64_data = url.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) @@ -457,9 +473,10 @@ async def convert_to_file_path(self) -> str: get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg" ) await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes) - return await asyncio.to_thread(os.path.abspath, image_file_path) - if await asyncio.to_thread(os.path.exists, url): - return await asyncio.to_thread(os.path.abspath, url) + return await asyncio.to_thread(_absolute_path, image_file_path) + local_path = await asyncio.to_thread(_absolute_path_if_exists, url) + if local_path: + return local_path raise Exception(f"not a valid file: {url}") async def convert_to_base64(self) -> str: @@ -480,10 +497,11 @@ async def convert_to_base64(self) -> str: bs64_data = await file_to_base64(image_file_path) elif url.startswith("base64://"): bs64_data = url - elif await asyncio.to_thread(os.path.exists, url): - bs64_data = await file_to_base64(url) else: - raise Exception(f"not a valid file: {url}") + try: + bs64_data = await file_to_base64(url) + except OSError as exc: + raise Exception(f"not a valid file: {url}") from exc bs64_data = bs64_data.removeprefix("base64://") return bs64_data @@ -734,8 +752,9 @@ async def get_file(self, allow_return_url: bool = False) -> str: ): path = path[1:] - if await asyncio.to_thread(os.path.exists, path): - return await asyncio.to_thread(os.path.abspath, path) + local_path = await asyncio.to_thread(_absolute_path_if_exists, path) + if local_path: + return local_path if self.url: await self._download_file() @@ -750,7 +769,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: and path[2] == ":" ): path = path[1:] - return await asyncio.to_thread(os.path.abspath, path) + return await asyncio.to_thread(_absolute_path, path) return "" @@ -766,7 +785,7 @@ async def _download_file(self) -> None: filename = f"fileseg_{uuid.uuid4().hex}" file_path = os.path.join(download_dir, filename) await download_file(self.url, file_path) - self.file_ = await asyncio.to_thread(os.path.abspath, file_path) + self.file_ = await asyncio.to_thread(_absolute_path, file_path) async def register_to_file_service(self) -> str: """将文件注册到文件服务。 diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index d169dd32b6..d37e4fbb1b 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -9,6 +9,7 @@ import uuid import zipfile from pathlib import Path +from typing import BinaryIO import aiohttp import certifi @@ -18,6 +19,8 @@ from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path logger = logging.getLogger("astrbot") +_DOWNLOAD_READ_CHUNK_SIZE = 64 * 1024 +_DOWNLOAD_FLUSH_THRESHOLD = 256 * 1024 def on_error(func, path, exc_info) -> None: @@ -134,29 +137,18 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non if resp.status != 200: raise Exception(f"下载文件失败: {resp.status}") total_size = int(resp.headers.get("content-length", 0)) - downloaded_size = 0 start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") file_obj = await asyncio.to_thread(Path(path).open, "wb") try: - while True: - chunk = await resp.content.read(8192) - if not chunk: - break - await asyncio.to_thread(file_obj.write, chunk) - downloaded_size += len(chunk) - if show_progress: - elapsed_time = ( - time.time() - start_time - if time.time() - start_time > 0 - else 1 - ) - speed = downloaded_size / 1024 / elapsed_time # KB/s - print( - f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", - end="", - ) + await _stream_to_file( + resp.content, + file_obj, + total_size=total_size, + start_time=start_time, + show_progress=show_progress, + ) finally: await asyncio.to_thread(file_obj.close) except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): @@ -176,31 +168,73 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non async with aiohttp.ClientSession() as session: async with session.get(url, ssl=ssl_context, timeout=120) as resp: total_size = int(resp.headers.get("content-length", 0)) - downloaded_size = 0 start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") file_obj = await asyncio.to_thread(Path(path).open, "wb") try: - while True: - chunk = await resp.content.read(8192) - if not chunk: - break - await asyncio.to_thread(file_obj.write, chunk) - downloaded_size += len(chunk) - if show_progress: - elapsed_time = time.time() - start_time - speed = downloaded_size / 1024 / elapsed_time # KB/s - print( - f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", - end="", - ) + await _stream_to_file( + resp.content, + file_obj, + total_size=total_size, + start_time=start_time, + show_progress=show_progress, + ) finally: await asyncio.to_thread(file_obj.close) if show_progress: print() +async def _stream_to_file( + stream: aiohttp.StreamReader, + file_obj: BinaryIO, + *, + total_size: int, + start_time: float, + show_progress: bool, +) -> None: + """Stream HTTP response into file with buffered thread-offloaded writes.""" + downloaded_size = 0 + known_total = total_size if total_size > 0 else None + buffered = bytearray() + + try: + while True: + chunk = await stream.read(_DOWNLOAD_READ_CHUNK_SIZE) + if not chunk: + break + + buffered.extend(chunk) + downloaded_size += len(chunk) + + if len(buffered) >= _DOWNLOAD_FLUSH_THRESHOLD: + await asyncio.to_thread(file_obj.write, bytes(buffered)) + buffered.clear() + + if show_progress: + _print_download_progress(downloaded_size, known_total, start_time) + finally: + if buffered: + # Ensure buffered data is flushed even on cancellation. + await asyncio.shield(asyncio.to_thread(file_obj.write, bytes(buffered))) + + +def _print_download_progress( + downloaded_size: int, total_size: int | None, start_time: float +) -> None: + elapsed_time = max(time.time() - start_time, 1e-6) + speed = downloaded_size / 1024 / elapsed_time # KB/s + + if total_size: + percent = downloaded_size / total_size + msg = f"\r下载进度: {percent:.2%} 速度: {speed:.2f} KB/s" + else: + msg = f"\r已下载: {downloaded_size} 字节 速度: {speed:.2f} KB/s" + + print(msg, end="") + + async def file_to_base64(file_path: str) -> str: data_bytes = await asyncio.to_thread(Path(file_path).read_bytes) base64_str = base64.b64encode(data_bytes).decode() diff --git a/tests/fixtures/helpers.py b/tests/fixtures/helpers.py index 26edb761c5..5e948363ba 100644 --- a/tests/fixtures/helpers.py +++ b/tests/fixtures/helpers.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable +from urllib.parse import urlparse from unittest.mock import AsyncMock, MagicMock from astrbot.core.message.components import BaseMessageComponent @@ -24,6 +25,25 @@ def __await__(self): return None +def get_bound_tcp_port(site: Any) -> int: + """Resolve the bound aiohttp TCP site port for tests. + + We prefer the public ``site.name`` first. Some aiohttp test setups with + ephemeral ports may not expose a usable port there, so we fall back to + ``site._server.sockets`` as a test-only compatibility path. + """ + parsed = urlparse(getattr(site, "name", "")) + if parsed.port is not None and parsed.port > 0: + return parsed.port + + server = getattr(site, "_server", None) + sockets = getattr(server, "sockets", None) if server else None + if sockets: + return sockets[0].getsockname()[1] + + raise RuntimeError("Unable to resolve bound TCP port from aiohttp site") + + # ============================================================ # 平台配置工厂 # ============================================================ diff --git a/tests/performance/test_benchmarks.py b/tests/performance/test_benchmarks.py new file mode 100644 index 0000000000..f956951dd2 --- /dev/null +++ b/tests/performance/test_benchmarks.py @@ -0,0 +1,268 @@ +"""Performance benchmark tests for core AstrBot execution paths. + +Run with: + uv run pytest tests/performance/test_benchmarks.py -q -s + +Optional output: + ASTRBOT_BENCHMARK_OUTPUT=/tmp/astrbot_benchmark.json +""" + +from __future__ import annotations + +import asyncio +import json +import math +import os +import time +import zipfile +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Awaitable, Callable +from unittest.mock import MagicMock + +import pytest +from aiohttp import web + +from astrbot.core.backup.exporter import AstrBotExporter +from astrbot.core.message.components import File, Image, Record +from astrbot.core.utils.io import download_file, file_to_base64 +from tests.fixtures.helpers import get_bound_tcp_port + + +@dataclass(slots=True) +class BenchmarkResult: + name: str + iterations: int + warmup: int + min_ms: float + max_ms: float + mean_ms: float + p50_ms: float + p95_ms: float + ops_per_sec: float + + +def _percentile(values: list[float], q: float) -> float: + if not values: + return 0.0 + sorted_values = sorted(values) + if len(sorted_values) == 1: + return sorted_values[0] + rank = (len(sorted_values) - 1) * q + lower = math.floor(rank) + upper = math.ceil(rank) + if lower == upper: + return sorted_values[lower] + weight = rank - lower + return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight + + +async def run_async_benchmark( + name: str, + func: Callable[[], Awaitable[None]], + *, + iterations: int, + warmup: int = 5, +) -> BenchmarkResult: + for _ in range(warmup): + await func() + + samples_ms: list[float] = [] + for _ in range(iterations): + start_ns = time.perf_counter_ns() + await func() + elapsed_ms = (time.perf_counter_ns() - start_ns) / 1_000_000 + samples_ms.append(elapsed_ms) + + mean_ms = sum(samples_ms) / len(samples_ms) + return BenchmarkResult( + name=name, + iterations=iterations, + warmup=warmup, + min_ms=min(samples_ms), + max_ms=max(samples_ms), + mean_ms=mean_ms, + p50_ms=_percentile(samples_ms, 0.50), + p95_ms=_percentile(samples_ms, 0.95), + ops_per_sec=1000 / mean_ms if mean_ms > 0 else 0.0, + ) + + +def _print_report(results: list[BenchmarkResult]) -> None: + print("\nAstrBot Benchmark Report") + print("-" * 84) + print( + f"{'case':35} {'iters':>7} {'mean(ms)':>10} {'p50(ms)':>10} " + f"{'p95(ms)':>10} {'ops/s':>10}" + ) + print("-" * 84) + for result in results: + print( + f"{result.name:35} {result.iterations:7d} " + f"{result.mean_ms:10.4f} {result.p50_ms:10.4f} " + f"{result.p95_ms:10.4f} {result.ops_per_sec:10.1f}" + ) + + +def _scaled_iterations(value: int) -> int: + scale = int(os.environ.get("ASTRBOT_BENCHMARK_SCALE", "1")) + return max(1, value * scale) + + +@pytest.mark.asyncio +@pytest.mark.slow +async def test_core_performance_benchmarks(tmp_path: Path) -> None: + """Measure representative performance paths across core modules.""" + data = os.urandom(256 * 1024) + + payload_path = tmp_path / "payload.bin" + payload_path.write_bytes(data) + + image = Image.fromFileSystem(str(payload_path)) + record = Record.fromFileSystem(str(payload_path)) + file_component = File(name="payload.bin", file=str(payload_path)) + exists_path = tmp_path / "exists_target.txt" + exists_path.write_text("ok", encoding="utf-8") + + attachments_dir = tmp_path / "attachments" + attachments_dir.mkdir() + attachments: list[dict[str, str]] = [] + attachments_with_missing: list[dict[str, str]] = [] + for i in range(64): + file_path = attachments_dir / f"attachment_{i}.bin" + file_path.write_bytes(data[:2048]) + attachments.append({"attachment_id": f"att_{i}", "path": str(file_path)}) + if i % 4 == 0: + missing_path = attachments_dir / f"missing_{i}.bin" + attachments_with_missing.append( + {"attachment_id": f"att_missing_{i}", "path": str(missing_path)} + ) + attachments_with_missing.append( + {"attachment_id": f"att_existing_{i}", "path": str(file_path)} + ) + + exporter = AstrBotExporter(main_db=MagicMock()) + zip_path = tmp_path / "attachments_bench.zip" + micro_batch = 32 + download_target = tmp_path / "download_target.bin" + download_payload = os.urandom(512 * 1024) + + async def handle_download(_request): + return web.Response(body=download_payload) + + app = web.Application() + app.router.add_get("/download.bin", handle_download) + runner = web.AppRunner(app, access_log=None) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + port = get_bound_tcp_port(site) + download_url = f"http://127.0.0.1:{port}/download.bin" + + async def bench_file_to_base64() -> None: + await file_to_base64(str(payload_path)) + + async def bench_image_convert_to_base64() -> None: + await image.convert_to_base64() + + async def bench_record_convert_to_base64() -> None: + await record.convert_to_base64() + + async def bench_image_convert_to_file_path() -> None: + for _ in range(micro_batch): + await image.convert_to_file_path() + + async def bench_file_component_get_file() -> None: + await file_component.get_file() + + async def bench_to_thread_exists() -> None: + await asyncio.to_thread(exists_path.exists) + + async def bench_export_attachments_existing() -> None: + if zip_path.exists(): + zip_path.unlink() + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + await exporter._export_attachments(zf, attachments) + zip_path.unlink(missing_ok=True) + + async def bench_export_attachments_with_missing() -> None: + if zip_path.exists(): + zip_path.unlink() + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + await exporter._export_attachments(zf, attachments_with_missing) + zip_path.unlink(missing_ok=True) + + async def bench_download_file_local_http() -> None: + await download_file(download_url, str(download_target)) + download_target.unlink(missing_ok=True) + + try: + results = [ + await run_async_benchmark( + "utils.io.file_to_base64(256KB)", + bench_file_to_base64, + iterations=_scaled_iterations(120), + ), + await run_async_benchmark( + "components.Image.convert_to_base64", + bench_image_convert_to_base64, + iterations=_scaled_iterations(120), + ), + await run_async_benchmark( + "components.Record.convert_to_base64", + bench_record_convert_to_base64, + iterations=_scaled_iterations(120), + ), + await run_async_benchmark( + f"components.Image.convert_to_file_path(x{micro_batch})", + bench_image_convert_to_file_path, + iterations=_scaled_iterations(140), + ), + await run_async_benchmark( + "components.File.get_file(local)", + bench_file_component_get_file, + iterations=_scaled_iterations(140), + ), + await run_async_benchmark( + "asyncio.to_thread(Path.exists)", + bench_to_thread_exists, + iterations=_scaled_iterations(240), + ), + await run_async_benchmark( + "backup.exporter._export_attachments(existing)", + bench_export_attachments_existing, + iterations=_scaled_iterations(20), + warmup=2, + ), + await run_async_benchmark( + "backup.exporter._export_attachments(mixed)", + bench_export_attachments_with_missing, + iterations=_scaled_iterations(20), + warmup=2, + ), + await run_async_benchmark( + "utils.io.download_file(local_http_512KB)", + bench_download_file_local_http, + iterations=_scaled_iterations(12), + warmup=2, + ), + ] + finally: + await runner.cleanup() + + _print_report(results) + + output_path = os.environ.get("ASTRBOT_BENCHMARK_OUTPUT") + if output_path: + Path(output_path).write_text( + json.dumps([asdict(result) for result in results], indent=2), + encoding="utf-8", + ) + + # Keep assertions broad: benchmarks are for measurement, not strict gating. + assert len(results) == 9 + for result in results: + assert result.iterations > 0 + assert result.mean_ms > 0 + assert result.max_ms >= result.min_ms + assert result.p95_ms >= result.p50_ms diff --git a/tests/test_backup.py b/tests/test_backup.py index cf3c4d9494..c4049d4b1f 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -172,6 +172,15 @@ def test_add_checksum(self): assert "test.json" in exporter._checksums assert exporter._checksums["test.json"].startswith("sha256:") + def test_read_text_if_exists(self, tmp_path): + """测试 _read_text_if_exists 行为。""" + exporter = AstrBotExporter(main_db=MagicMock()) + file_path = tmp_path / "config.json" + file_path.write_text('{"k":"v"}', encoding="utf-8") + + assert exporter._read_text_if_exists(str(file_path)) == '{"k":"v"}' + assert exporter._read_text_if_exists(str(tmp_path / "missing.json")) is None + def test_generate_manifest(self, mock_main_db, mock_kb_manager): """测试生成清单""" exporter = AstrBotExporter( @@ -240,6 +249,95 @@ async def test_export_all_creates_zip( assert "databases/main_db.json" in namelist assert "config/cmd_config.json" in namelist + @pytest.mark.asyncio + async def test_export_attachments_exports_existing_and_skips_missing( + self, mock_main_db, tmp_path + ): + """测试附件导出:存在文件写入 ZIP,不存在文件跳过。""" + exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None) + + existing_file = tmp_path / "exists.txt" + existing_file.write_text("hello", encoding="utf-8") + missing_file = tmp_path / "missing.txt" + zip_path = tmp_path / "attachments.zip" + + attachments = [ + {"attachment_id": "att_ok", "path": str(existing_file)}, + {"attachment_id": "att_missing", "path": str(missing_file)}, + ] + + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + await exporter._export_attachments(zf, attachments) + + with zipfile.ZipFile(zip_path, "r") as zf: + namelist = zf.namelist() + + assert "files/attachments/att_ok.txt" in namelist + assert "files/attachments/att_missing.txt" not in namelist + + @pytest.mark.asyncio + async def test_export_attachments_skips_empty_attachment_id( + self, mock_main_db, tmp_path + ): + """测试附件导出:attachment_id 为空时跳过,避免覆盖冲突。""" + exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None) + + file_a = tmp_path / "a.txt" + file_b = tmp_path / "b.txt" + file_a.write_text("a", encoding="utf-8") + file_b.write_text("b", encoding="utf-8") + zip_path = tmp_path / "attachments_empty_id.zip" + + attachments = [ + {"attachment_id": "", "path": str(file_a)}, + {"path": str(file_b)}, + {"attachment_id": "att_ok", "path": str(file_a)}, + ] + + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + await exporter._export_attachments(zf, attachments) + + with zipfile.ZipFile(zip_path, "r") as zf: + namelist = zf.namelist() + + assert "files/attachments/att_ok.txt" in namelist + assert "files/attachments/.txt" not in namelist + + @pytest.mark.asyncio + async def test_export_attachments_keeps_best_effort_on_unexpected_write_error( + self, mock_main_db, tmp_path + ): + """测试附件导出:单个非 OSError 写入异常不会中断后续附件导出。""" + exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None) + + file_a = tmp_path / "a.txt" + file_b = tmp_path / "b.txt" + file_a.write_text("a", encoding="utf-8") + file_b.write_text("b", encoding="utf-8") + zip_path = tmp_path / "attachments_best_effort.zip" + + attachments = [ + {"attachment_id": "att_boom", "path": str(file_a)}, + {"attachment_id": "att_ok", "path": str(file_b)}, + ] + + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + original_write = zf.write + + def flaky_write(filename, arcname=None, *args, **kwargs): + if arcname == "files/attachments/att_boom.txt": + raise RuntimeError("boom") + return original_write(filename, arcname, *args, **kwargs) + + with patch.object(zf, "write", side_effect=flaky_write): + await exporter._export_attachments(zf, attachments) + + with zipfile.ZipFile(zip_path, "r") as zf: + namelist = zf.namelist() + + assert "files/attachments/att_boom.txt" not in namelist + assert "files/attachments/att_ok.txt" in namelist + class TestAstrBotImporter: """AstrBotImporter 类测试""" diff --git a/tests/unit/test_fixture_helpers.py b/tests/unit/test_fixture_helpers.py new file mode 100644 index 0000000000..7ed3fa9eef --- /dev/null +++ b/tests/unit/test_fixture_helpers.py @@ -0,0 +1,43 @@ +import pytest + +from tests.fixtures.helpers import get_bound_tcp_port + + +class _DummySiteNoAttrs: + pass + + +class _DummySocket: + def __init__(self, port: int) -> None: + self._port = port + + def getsockname(self): + return ("127.0.0.1", self._port) + + +class _DummyServer: + def __init__(self, port: int) -> None: + self.sockets = [_DummySocket(port)] + + +class _DummySiteWithName: + def __init__(self, port: int) -> None: + self.name = f"http://localhost:{port}" + + +class _DummySiteWithServer: + def __init__(self, port: int) -> None: + self._server = _DummyServer(port) + + +def test_get_bound_tcp_port_raises_on_unresolvable_site(): + with pytest.raises(RuntimeError, match="Unable to resolve bound TCP port"): + get_bound_tcp_port(_DummySiteNoAttrs()) + + +def test_get_bound_tcp_port_uses_name_port_when_available(): + assert get_bound_tcp_port(_DummySiteWithName(8081)) == 8081 + + +def test_get_bound_tcp_port_falls_back_to_server_sockets(): + assert get_bound_tcp_port(_DummySiteWithServer(9092)) == 9092 diff --git a/tests/unit/test_io_download_file.py b/tests/unit/test_io_download_file.py new file mode 100644 index 0000000000..7ccc2e403e --- /dev/null +++ b/tests/unit/test_io_download_file.py @@ -0,0 +1,71 @@ +import pytest +from aiohttp import web + +from astrbot.core.utils import io as io_module +from astrbot.core.utils.io import _stream_to_file, download_file +from tests.fixtures.helpers import get_bound_tcp_port + + +@pytest.mark.asyncio +async def test_download_file_downloads_content(tmp_path): + payload = b"astrbot-download-payload" * 256 + + async def handle(_request): + return web.Response(body=payload) + + app = web.Application() + app.router.add_get("/file.bin", handle) + runner = web.AppRunner(app, access_log=None) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + + try: + port = get_bound_tcp_port(site) + url = f"http://127.0.0.1:{port}/file.bin" + + out = tmp_path / "downloaded.bin" + await download_file(url, str(out)) + + assert out.read_bytes() == payload + finally: + await runner.cleanup() + + +class _DummyStream: + def __init__(self, chunks: list[bytes]) -> None: + self._chunks = chunks + + async def read(self, _size: int) -> bytes: + if not self._chunks: + return b"" + return self._chunks.pop(0) + + +class _RecordingFile: + def __init__(self) -> None: + self.writes: list[bytes] = [] + + def write(self, data: bytes) -> int: + self.writes.append(data) + return len(data) + + +@pytest.mark.asyncio +async def test_stream_to_file_batches_multiple_chunks_per_write(monkeypatch): + monkeypatch.setattr(io_module, "_DOWNLOAD_READ_CHUNK_SIZE", 4) + monkeypatch.setattr(io_module, "_DOWNLOAD_FLUSH_THRESHOLD", 10) + + stream = _DummyStream([b"aaaa", b"bbbb", b"cccc"]) + file_obj = _RecordingFile() + + await _stream_to_file( + stream, + file_obj, + total_size=12, + start_time=0.0, + show_progress=False, + ) + + assert len(file_obj.writes) == 1 + assert file_obj.writes[0] == b"aaaabbbbcccc" diff --git a/tests/unit/test_message_components_paths.py b/tests/unit/test_message_components_paths.py new file mode 100644 index 0000000000..c8bbc43f53 --- /dev/null +++ b/tests/unit/test_message_components_paths.py @@ -0,0 +1,178 @@ +import base64 +import os +from pathlib import Path + +import pytest +from aiohttp import web + +from astrbot.core.message import components as components_module +from astrbot.core.message.components import File, Image, Record +from tests.fixtures.helpers import get_bound_tcp_port + + +@pytest.mark.asyncio +async def test_image_convert_to_file_path_returns_absolute_path(tmp_path): + file_path = tmp_path / "img.bin" + file_path.write_bytes(b"img") + + image = Image(file=str(file_path)) + resolved = await image.convert_to_file_path() + + assert resolved == os.path.abspath(str(file_path)) + + +@pytest.mark.asyncio +async def test_record_convert_to_file_path_returns_absolute_path(tmp_path): + file_path = tmp_path / "record.bin" + file_path.write_bytes(b"record") + + record = Record(file=str(file_path)) + resolved = await record.convert_to_file_path() + + assert resolved == os.path.abspath(str(file_path)) + + +@pytest.mark.asyncio +async def test_file_component_get_file_returns_absolute_path(tmp_path): + file_path = tmp_path / "file.bin" + file_path.write_bytes(b"file") + + file_component = File(name="file.bin", file=str(file_path)) + resolved = await file_component.get_file() + + assert resolved == os.path.abspath(str(file_path)) + + +@pytest.mark.asyncio +async def test_image_convert_to_base64_raises_on_missing_file(tmp_path): + image = Image(file=str(tmp_path / "missing.bin")) + + with pytest.raises(Exception, match="not a valid file"): + await image.convert_to_base64() + + +@pytest.mark.asyncio +async def test_record_convert_to_base64_raises_on_missing_file(tmp_path): + record = Record(file=str(tmp_path / "missing.bin")) + + with pytest.raises(Exception, match="not a valid file"): + await record.convert_to_base64() + + +@pytest.mark.asyncio +async def test_image_convert_to_base64_reads_existing_local_file(tmp_path): + raw = b"image-bytes" + file_path = tmp_path / "exists_image.bin" + file_path.write_bytes(raw) + + image = Image(file=str(file_path)) + encoded = await image.convert_to_base64() + + assert base64.b64decode(encoded) == raw + + +@pytest.mark.asyncio +async def test_record_convert_to_base64_reads_existing_local_file(tmp_path): + raw = b"record-bytes" + file_path = tmp_path / "exists_record.bin" + file_path.write_bytes(raw) + + record = Record(file=str(file_path)) + encoded = await record.convert_to_base64() + + assert base64.b64decode(encoded) == raw + + +@pytest.mark.asyncio +async def test_image_convert_to_base64_maps_permission_error(monkeypatch): + async def _raise_permission_error(_path: str) -> str: + raise PermissionError("permission denied") + + monkeypatch.setattr(components_module, "file_to_base64", _raise_permission_error) + + image = Image(file="/tmp/forbidden-image") + with pytest.raises(Exception, match="not a valid file"): + await image.convert_to_base64() + + +@pytest.mark.asyncio +async def test_record_convert_to_base64_maps_permission_error(monkeypatch): + async def _raise_permission_error(_path: str) -> str: + raise PermissionError("permission denied") + + monkeypatch.setattr(components_module, "file_to_base64", _raise_permission_error) + + record = Record(file="/tmp/forbidden-record") + with pytest.raises(Exception, match="not a valid file"): + await record.convert_to_base64() + + +@pytest.mark.asyncio +async def test_image_convert_to_file_path_from_base64_creates_absolute_file(): + payload = b"image-base64-payload" + image = Image(file=f"base64://{base64.b64encode(payload).decode()}") + + resolved = await image.convert_to_file_path() + resolved_path = Path(resolved) + + assert resolved_path.is_absolute() + assert resolved_path.exists() + assert resolved_path.read_bytes() == payload + + +@pytest.mark.asyncio +async def test_record_convert_to_file_path_from_base64_creates_absolute_file(): + payload = b"record-base64-payload" + record = Record(file=f"base64://{base64.b64encode(payload).decode()}") + + resolved = await record.convert_to_file_path() + resolved_path = Path(resolved) + + assert resolved_path.is_absolute() + assert resolved_path.exists() + assert resolved_path.read_bytes() == payload + + +async def _serve_payload(payload: bytes, route: str): + async def handle(_request): + return web.Response(body=payload) + + app = web.Application() + app.router.add_get(route, handle) + runner = web.AppRunner(app, access_log=None) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + return runner, get_bound_tcp_port(site) + + +@pytest.mark.asyncio +async def test_image_convert_to_file_path_from_http_creates_absolute_file(): + payload = b"image-http-payload" + runner, port = await _serve_payload(payload, "/img.bin") + try: + image = Image(file=f"http://127.0.0.1:{port}/img.bin") + resolved = await image.convert_to_file_path() + resolved_path = Path(resolved) + + assert resolved_path.is_absolute() + assert resolved_path.exists() + assert resolved_path.read_bytes() == payload + finally: + await runner.cleanup() + + +@pytest.mark.asyncio +async def test_record_convert_to_file_path_from_http_creates_absolute_file(): + payload = b"record-http-payload" + runner, port = await _serve_payload(payload, "/record.bin") + try: + record = Record(file=f"http://127.0.0.1:{port}/record.bin") + resolved = await record.convert_to_file_path() + resolved_path = Path(resolved) + + assert resolved_path.is_absolute() + assert resolved_path.exists() + assert resolved_path.read_bytes() == payload + finally: + await runner.cleanup() From 516d9e26eba24ca07f5a97d4e0dfd1d9ce538034 Mon Sep 17 00:00:00 2001 From: shuiping233 <1944680304@qq.com> Date: Tue, 3 Mar 2026 16:23:13 +0800 Subject: [PATCH 3/4] =?UTF-8?q?refactor:=20=E7=BB=99kook=E9=80=82=E9=85=8D?= =?UTF-8?q?=E5=99=A8=E6=B7=BB=E5=8A=A0kook=E4=BA=8B=E4=BB=B6=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 6 - .../platform/sources/kook/kook_adapter.py | 217 ++++++---- .../core/platform/sources/kook/kook_client.py | 159 +++++--- .../core/platform/sources/kook/kook_config.py | 2 - .../core/platform/sources/kook/kook_event.py | 5 +- .../core/platform/sources/kook/kook_types.py | 374 +++++++++++++++--- .../en-US/features/config-metadata.json | 5 - .../zh-CN/features/config-metadata.json | 5 - tests/test_kook/data/kook_card_data.json | 54 +-- .../data/kook_ws_event_group_message.json | 119 ++++++ tests/test_kook/data/kook_ws_event_hello.json | 8 + .../kook_ws_event_message_with_card_1.json | 72 ++++ .../kook_ws_event_message_with_card_2.json | 79 ++++ tests/test_kook/data/kook_ws_event_ping.json | 4 + tests/test_kook/data/kook_ws_event_pong.json | 3 + .../data/kook_ws_event_private_message.json | 64 +++ .../kook_ws_event_private_system_message.json | 31 ++ .../data/kook_ws_event_reconnect_err.json | 7 + .../test_kook/data/kook_ws_event_resume.json | 4 + .../data/kook_ws_event_resume_ack.json | 6 + tests/test_kook/shared.py | 3 +- tests/test_kook/test_kook_event.py | 59 +-- tests/test_kook/test_kook_types.py | 43 +- 23 files changed, 1036 insertions(+), 293 deletions(-) create mode 100644 tests/test_kook/data/kook_ws_event_group_message.json create mode 100644 tests/test_kook/data/kook_ws_event_hello.json create mode 100644 tests/test_kook/data/kook_ws_event_message_with_card_1.json create mode 100644 tests/test_kook/data/kook_ws_event_message_with_card_2.json create mode 100644 tests/test_kook/data/kook_ws_event_ping.json create mode 100644 tests/test_kook/data/kook_ws_event_pong.json create mode 100644 tests/test_kook/data/kook_ws_event_private_message.json create mode 100644 tests/test_kook/data/kook_ws_event_private_system_message.json create mode 100644 tests/test_kook/data/kook_ws_event_reconnect_err.json create mode 100644 tests/test_kook/data/kook_ws_event_resume.json create mode 100644 tests/test_kook/data/kook_ws_event_resume_ack.json diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index cbadb5c18f..ec6bc423b4 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -454,7 +454,6 @@ class ChatProviderTemplate(TypedDict): "type": "kook", "enable": False, "kook_bot_token": "", - "kook_bot_nickname": "", "kook_reconnect_delay": 1, "kook_max_reconnect_delay": 60, "kook_max_retry_delay": 60, @@ -809,11 +808,6 @@ class ChatProviderTemplate(TypedDict): "type": "string", "hint": "必填项。从 KOOK 开发者平台获取的机器人 Token。", }, - "kook_bot_nickname": { - "description": "Bot Nickname", - "type": "string", - "hint": "可选项。若发送者昵称与此值一致,将忽略该消息以避免广播风暴。", - }, "kook_reconnect_delay": { "description": "重连延迟", "type": "int", diff --git a/astrbot/core/platform/sources/kook/kook_adapter.py b/astrbot/core/platform/sources/kook/kook_adapter.py index b7d047291e..52dd4fcde5 100644 --- a/astrbot/core/platform/sources/kook/kook_adapter.py +++ b/astrbot/core/platform/sources/kook/kook_adapter.py @@ -13,11 +13,28 @@ PlatformMetadata, register_platform_adapter, ) +from astrbot.core.message.components import File, Record, Video from astrbot.core.platform.astr_message_event import MessageSesion from .kook_client import KookClient from .kook_config import KookConfig from .kook_event import KookEvent +from .kook_types import ( + ContainerModule, + FileModule, + HeaderModule, + ImageGroupModule, + KmarkdownElement, + KookCardMessageContainer, + KookChannelType, + KookMessageEventData, + KookMessageType, + KookModuleType, + PlainTextElement, + SectionModule, +) + +KOOK_AT_SELECTOR_REGEX = re.compile(r"\(met\)([^()]+)\(met\)") @register_platform_adapter( @@ -57,35 +74,26 @@ def meta(self) -> PlatformMetadata: name="kook", description="KOOK 适配器", id=self.kook_config.id ) - def _should_ignore_event_by_bot_nickname(self, payload: dict) -> bool: - bot_nickname = self.kook_config.bot_nickname.strip() - if not bot_nickname: - return False - - author = payload.get("extra", {}).get("author", {}) - if not isinstance(author, dict): - return False - - author_nickname = author.get("nickname") or author.get("username") or "" - if not isinstance(author_nickname, str): - author_nickname = str(author_nickname) - - return author_nickname.strip().casefold() == bot_nickname.casefold() - - async def _on_received(self, data: dict): - logger.debug(f"KOOK 收到数据: {data}") - if "d" in data and data["s"] == 0: - payload = data["d"] - event_type = payload.get("type") - # 支持type=9(文本)和type=10(卡片) - if event_type in (9, 10): - if self._should_ignore_event_by_bot_nickname(payload): - return - try: - abm = await self.convert_message(payload) - await self.handle_msg(abm) - except Exception as e: - logger.error(f"[KOOK] 消息处理异常: {e}") + def _should_ignore_event_by_bot_nickname(self, author_id: str) -> bool: + return self.client.bot_id == author_id + + async def _on_received(self, event: KookMessageEventData): + logger.debug( + f'[KOOK] 收到来自"{event.channel_type.name}"渠道的消息, 消息类型为: {event.type.name}({event.type.value})' + ) + event_type = event.type + if event_type in (KookMessageType.KMARKDOWN, KookMessageType.CARD): + if self._should_ignore_event_by_bot_nickname(event.author_id): + logger.debug("[KOOK] 收到来自机器人自身的消息, 忽略此消息") + return + try: + abm = await self.convert_message(event) + await self.handle_msg(abm) + except Exception as e: + logger.error(f"[KOOK] 消息处理异常: {e}") + elif event_type == KookMessageType.SYSTEM: + logger.debug(f'[KOOK] 消息为系统通知, 通知类型为: "{event.extra.type}"') + logger.debug(f"[KOOK] 原始消息数据: {event.to_json()}") async def run(self): """主运行循环""" @@ -184,18 +192,26 @@ async def _cleanup(self): logger.info("[KOOK] 资源清理完成") def _parse_kmarkdown_text_message( - self, data: dict, self_id: str + self, data: KookMessageEventData, self_id: str ) -> tuple[list, str]: - kmarkdown = data.get("extra", {}).get("kmarkdown", {}) - content = data.get("content") or "" - raw_content = kmarkdown.get("raw_content") or content + kmarkdown = data.extra.kmarkdown + content = data.content or "" + if kmarkdown is None: + logger.error( + f'[KOOK] 无法转换"{KookMessageType.KMARKDOWN.name}"消息, 消息中找不到kmarkdown字段' + ) + logger.error(f"[KOOK] 原始消息内容: {data.to_json()}") + return [], "" + + raw_content = kmarkdown.raw_content or content if not isinstance(content, str): content = str(content) if not isinstance(raw_content, str): raw_content = str(raw_content) + # TODO 后面的pydantic类型替换,以后再来探索吧 :( mention_name_map: dict[str, str] = {} - mention_part = kmarkdown.get("mention_part", []) + mention_part = kmarkdown.mention_part if isinstance(mention_part, list): for item in mention_part: if not isinstance(item, dict): @@ -207,7 +223,7 @@ def _parse_kmarkdown_text_message( components = [] cursor = 0 - for match in re.finditer(r"\(met\)([^()]+)\(met\)", content): + for match in KOOK_AT_SELECTOR_REGEX.finditer(content): if match.start() > cursor: plain_text = content[cursor : match.start()] if plain_text: @@ -254,77 +270,109 @@ def _parse_kmarkdown_text_message( return components, message_str - def _parse_card_message(self, data: dict) -> tuple[list, str]: - content = data.get("content", "[]") + def _parse_card_message(self, data: KookMessageEventData) -> tuple[list, str]: + content = data.content if not isinstance(content, str): content = str(content) - card_list = json.loads(content) + + card_list = KookCardMessageContainer.from_dict(json.loads(content)) text_parts: list[str] = [] images: list[str] = [] + files: list[tuple[KookModuleType, str, str]] = [] for card in card_list: - if not isinstance(card, dict): - continue - for module in card.get("modules", []): - if not isinstance(module, dict): - continue + for module in card.modules: + match module: + case SectionModule(): + if content := self._handle_section_text(module): + text_parts.append(content) - module_type = module.get("type") - if module_type == "section": - section_text = module.get("text", {}).get("content", "") - if section_text: - text_parts.append(str(section_text)) - continue + case ContainerModule() | ImageGroupModule(): + urls = self._handle_image_group(module) + images.extend(urls) + text_parts.append(" [image]" * len(urls)) - if module_type != "container": - continue + case HeaderModule(): + text_parts.append(module.text.content) - for element in module.get("elements", []): - if not isinstance(element, dict): - continue - if element.get("type") != "image": - continue + case FileModule(): + files.append((module.type, module.title, module.src)) + text_parts.append(f" [{module.type.value}]") - image_src = element.get("src") - if not isinstance(image_src, str): - logger.warning( - f'[KOOK] 处理卡片中的图片时发生错误,图片url "{image_src}" 应该为str类型, 而不是 "{type(image_src)}" ' - ) - continue - if not image_src.startswith(("http://", "https://")): - logger.warning(f"[KOOK] 屏蔽非http图片url: {image_src}") - continue - images.append(image_src) + case _: + logger.debug(f"[KOOK] 跳过或未处理模块: {module.type}") text = "".join(text_parts) message = [] + if text: + for search in KOOK_AT_SELECTOR_REGEX.finditer(text): + search_text = search.group(1).strip() + if search_text == "all": + message.append(AtAll()) + continue + message.append(At(qq=search_text)) + text = text.replace(f"(met){search_text}(met)", "") + message.append(Plain(text=text)) + for img_url in images: message.append(Image(file=img_url)) + for file in files: + file_type = file[0] + file_name = file[1] + file_url = file[2] + if file_type == KookModuleType.FILE: + message.append(File(name=file_name, file=file_url)) + elif file_type == KookModuleType.VIDEO: + message.append(Video(file=file_url)) + elif file_type == KookModuleType.AUDIO: + message.append(Record(file=file_url)) + else: + logger.warning(f"[KOOK] 跳过未知文件类型: {file_type.name}") + return message, text - async def convert_message(self, data: dict) -> AstrBotMessage: + def _handle_section_text(self, module: SectionModule) -> str: + """专门处理 Section 里的文本提取""" + if isinstance(module.text, (KmarkdownElement, PlainTextElement)): + return module.text.content or "" + return "" + + def _handle_image_group( + self, module: ContainerModule | ImageGroupModule + ) -> list[str]: + """专门处理图片组/容器里的合法 URL 提取""" + valid_urls = [] + for el in module.elements: + image_src = el.src + if not el.src.startswith(("http://", "https://")): + logger.warning(f"[KOOK] 屏蔽非http图片url: {image_src}") + continue + valid_urls.append(el.src) + return valid_urls + + async def convert_message(self, data: KookMessageEventData) -> AstrBotMessage: abm = AstrBotMessage() - abm.raw_message = data + abm.raw_message = data.to_dict() abm.self_id = self.client.bot_id - channel_type = data.get("channel_type") - author_id = data.get("author_id", "unknown") + channel_type = data.channel_type + author_id = data.author_id # channel_type定义: https://developer.kookapp.cn/doc/event/event-introduction match channel_type: - case "GROUP": - session_id = data.get("target_id") or "unknown" + case KookChannelType.GROUP: + session_id = data.target_id or "unknown" abm.type = MessageType.GROUP_MESSAGE abm.group_id = session_id abm.session_id = session_id - case "PERSON": + case KookChannelType.PERSON: abm.type = MessageType.FRIEND_MESSAGE abm.group_id = "" - abm.session_id = data.get("author_id", "unknown") - case "BROADCAST": - session_id = data.get("target_id") or "unknown" + abm.session_id = data.author_id or "unknown" + case KookChannelType.BROADCAST: + session_id = data.target_id or "unknown" abm.type = MessageType.OTHER_MESSAGE abm.group_id = session_id abm.session_id = session_id @@ -333,28 +381,25 @@ async def convert_message(self, data: dict) -> AstrBotMessage: abm.sender = MessageMember( user_id=author_id, - nickname=data.get("extra", {}).get("author", {}).get("username", ""), + nickname=data.extra.author.username if data.extra.author else "unknown", ) - abm.message_id = data.get("msg_id", "unknown") + abm.message_id = data.msg_id or "unknown" - # 普通文本消息 - if data.get("type") == 9: - message, message_str = self._parse_kmarkdown_text_message( - data, str(abm.self_id) - ) + if data.type == KookMessageType.KMARKDOWN: + message, message_str = self._parse_kmarkdown_text_message(data, abm.self_id) abm.message = message abm.message_str = message_str - # 卡片消息 - elif data.get("type") == 10: + elif data.type == KookMessageType.CARD: try: abm.message, abm.message_str = self._parse_card_message(data) except Exception as exp: logger.error(f"[KOOK] 卡片消息解析失败: {exp}") + logger.error(f"[KOOK] 原始消息内容: {data.to_json()}") abm.message_str = "[卡片消息解析失败]" abm.message = [Plain(text="[卡片消息解析失败]")] else: - logger.warning(f'[KOOK] 不支持的kook消息类型: "{data.get("type")}"') + logger.warning(f'[KOOK] 不支持的kook消息类型: "{data.type.name}"') abm.message_str = "[不支持的消息类型]" abm.message = [Plain(text="[不支持的消息类型]")] diff --git a/astrbot/core/platform/sources/kook/kook_client.py b/astrbot/core/platform/sources/kook/kook_client.py index 34078e2ac2..fee7e08f1a 100644 --- a/astrbot/core/platform/sources/kook/kook_client.py +++ b/astrbot/core/platform/sources/kook/kook_client.py @@ -1,6 +1,5 @@ import asyncio import base64 -import json import os import random import time @@ -9,13 +8,23 @@ import aiofiles import aiohttp +import pydantic import websockets from astrbot import logger from astrbot.core.platform.message_type import MessageType from .kook_config import KookConfig -from .kook_types import KookApiPaths, KookMessageType +from .kook_types import ( + KookApiPaths, + KookGatewayIndexResponse, + KookHelloEventData, + KookMessageSignal, + KookMessageType, + KookResumeAckEventData, + KookUserMeResponse, + KookWebsocketEvent, +) class KookClient: @@ -23,7 +32,8 @@ def __init__(self, config: KookConfig, event_callback): # 数据字段 self.config = config self._bot_id = "" - self._bot_name = "" + self._bot_username = "" + self._bot_nickname = "" # 资源字段 self._http_client = aiohttp.ClientSession( @@ -48,37 +58,50 @@ def bot_id(self): return self._bot_id @property - def bot_name(self): - return self._bot_name + def bot_nickname(self): + return self._bot_nickname - async def get_bot_info(self) -> str: - """获取机器人账号ID""" + @property + def bot_username(self): + return self._bot_username + + async def get_bot_info(self) -> None: + """获取机器人账号信息""" url = KookApiPaths.USER_ME try: async with self._http_client.get(url) as resp: if resp.status != 200: - logger.error(f"[KOOK] 获取机器人账号ID失败,状态码: {resp.status}") - return "" + logger.error( + f"[KOOK] 获取机器人账号信息失败,状态码: {resp.status} , {await resp.text()}" + ) + return + try: + resp_content = KookUserMeResponse.from_dict(await resp.json()) + except pydantic.ValidationError as e: + logger.error( + f"[KOOK] 获取机器人账号信息失败, 响应数据格式错误: \n{e}" + ) + logger.error(f"[KOOK] 响应内容: {await resp.text()}") + return - data = await resp.json() - if data.get("code") != 0: - logger.error(f"[KOOK] 获取机器人账号ID失败: {data}") - return "" + if not resp_content.success(): + logger.error( + f"[KOOK] 获取机器人账号信息失败: {resp_content.model_dump_json()}" + ) + return - bot_id: str = data["data"]["id"] + bot_id: str = resp_content.data.id self._bot_id = bot_id logger.info(f"[KOOK] 获取机器人账号ID成功: {bot_id}") - bot_name: str = data["data"]["nickname"] or data["data"]["username"] - self._bot_name = bot_name - logger.info(f"[KOOK] 获取机器人名称成功: {self._bot_name}") + self._bot_nickname = resp_content.data.nickname + self._bot_username = resp_content.data.username + logger.info(f"[KOOK] 获取机器人名称成功: {self._bot_nickname}") - return bot_id except Exception as e: - logger.error(f"[KOOK] 获取机器人账号ID异常: {e}") - return "" + logger.error(f"[KOOK] 获取机器人账号信息异常: {e}") - async def get_gateway_url(self, resume=False, sn=0, session_id=None): + async def get_gateway_url(self, resume=False, sn=0, session_id=None) -> str | None: """获取网关连接地址""" url = KookApiPaths.GATEWAY_INDEX @@ -96,14 +119,20 @@ async def get_gateway_url(self, resume=False, sn=0, session_id=None): logger.error(f"[KOOK] 获取gateway失败,状态码: {resp.status}") return None - data = await resp.json() - if data.get("code") != 0: - logger.error(f"[KOOK] 获取gateway失败: {data}") + resp_content = KookGatewayIndexResponse.from_dict(await resp.json()) + if not resp_content.success(): + logger.error(f"[KOOK] 获取gateway失败: {resp_content}") return None - gateway_url: str = data["data"]["url"] + gateway_url: str = resp_content.data.url logger.info(f"[KOOK] 获取gateway成功: {gateway_url.split('?')[0]}") return gateway_url + + except pydantic.ValidationError as e: + logger.error(f"[KOOK] 获取gateway失败, 响应数据格式错误: \n{e}") + logger.error(f"[KOOK] 原始响应内容: {await resp.text()}") + return None + except Exception as e: logger.error(f"[KOOK] 获取gateway异常: {e}") return None @@ -156,7 +185,11 @@ async def listen(self): try: while self.running: try: - msg = await asyncio.wait_for(self.ws.recv(), timeout=10) # type: ignore + if self.ws is None: + logger.error("[KOOK] WebSocket 对象丢失,结束监听流程。") + break + + msg = await asyncio.wait_for(self.ws.recv(), timeout=10) if isinstance(msg, bytes): try: @@ -166,10 +199,15 @@ async def listen(self): continue msg = msg.decode("utf-8") - data = json.loads(msg) + event = KookWebsocketEvent.from_json(msg) # 处理不同类型的信令 - await self._handle_signal(data) + await self._handle_signal(event) + + except pydantic.ValidationError as e: + logger.error(f"[KOOK] 解析WebSocket事件数据格式失败: \n{e}") + logger.error(f"[KOOK] 原始响应内容: {msg}") + continue except TimeoutError: # 超时检查,继续循环 @@ -187,38 +225,41 @@ async def listen(self): self.running = False self._stop_event.set() - async def _handle_signal(self, data): + async def _handle_signal(self, event: KookWebsocketEvent): """处理不同类型的信令""" - signal_type = data.get("s") + data = event.data - if signal_type == 0: # 事件消息 - # 更新消息序号 - if "sn" in data: - self.last_sn = data["sn"] - await self.event_callback(data) + match event.signal: + case KookMessageSignal.MESSAGE: + if event.sn is not None: + self.last_sn = event.sn + await self.event_callback(data) - elif signal_type == 1: # HELLO握手 - await self._handle_hello(data) + case KookMessageSignal.HELLO: + assert isinstance(data, KookHelloEventData) + await self._handle_hello(data) - elif signal_type == 3: # PONG心跳响应 - await self._handle_pong(data) + case KookMessageSignal.RESUME_ACK: + assert isinstance(data, KookResumeAckEventData) + await self._handle_resume_ack(data) - elif signal_type == 5: # RECONNECT重连指令 - await self._handle_reconnect(data) + case KookMessageSignal.PONG: + await self._handle_pong() - elif signal_type == 6: # RESUME ACK - await self._handle_resume_ack(data) + case KookMessageSignal.RECONNECT: + await self._handle_reconnect() - else: - logger.debug(f"[KOOK] 未处理的信令类型: {signal_type}") + case _: + logger.debug( + f"[KOOK] 未处理的信令类型: {event.signal.name}({event.signal.value})" + ) - async def _handle_hello(self, data): + async def _handle_hello(self, data: KookHelloEventData): """处理HELLO握手""" - hello_data = data.get("d", {}) - code = hello_data.get("code", 0) + code = data.code if code == 0: - self.session_id = hello_data.get("session_id") + self.session_id = data.session_id logger.info(f"[KOOK] 握手成功,session_id: {self.session_id}") # TODO 重置重连延迟 # self.reconnect_delay = 1 @@ -228,12 +269,12 @@ async def _handle_hello(self, data): logger.error("[KOOK] Token已过期,需要重新获取") self.running = False - async def _handle_pong(self, data): + async def _handle_pong(self): """处理PONG心跳响应""" self.last_heartbeat_time = time.time() self.heartbeat_failed_count = 0 - async def _handle_reconnect(self, data): + async def _handle_reconnect(self): """处理重连指令""" logger.warning("[KOOK] 收到重连指令") # 清空本地状态 @@ -241,10 +282,9 @@ async def _handle_reconnect(self, data): self.session_id = None self.running = False - async def _handle_resume_ack(self, data): + async def _handle_resume_ack(self, data: KookResumeAckEventData): """处理RESUME确认""" - resume_data = data.get("d", {}) - self.session_id = resume_data.get("session_id") + self.session_id = data.session_id logger.info(f"[KOOK] Resume成功,session_id: {self.session_id}") async def _heartbeat_loop(self): @@ -292,9 +332,16 @@ async def _heartbeat_loop(self): async def _send_ping(self): """发送心跳PING""" + if self.ws is None: + logger.warning("[KOOK] 尚未连接kook WebSocket服务器, 跳过发送心跳包流程") + return try: - ping_data = {"s": 2, "sn": self.last_sn} - await self.ws.send(json.dumps(ping_data)) # type: ignore + ping_data = KookWebsocketEvent( + signal=KookMessageSignal.PING, + data=None, + sn=self.last_sn, + ) + await self.ws.send(ping_data.to_json()) except Exception as e: logger.error(f"[KOOK] 发送心跳失败: {e}") diff --git a/astrbot/core/platform/sources/kook/kook_config.py b/astrbot/core/platform/sources/kook/kook_config.py index 21f2547b03..0b9d180a29 100644 --- a/astrbot/core/platform/sources/kook/kook_config.py +++ b/astrbot/core/platform/sources/kook/kook_config.py @@ -9,7 +9,6 @@ class KookConfig: # 基础配置 token: str - bot_nickname: str = "" enable: bool = False id: str = "kook" @@ -41,7 +40,6 @@ def from_dict(cls, config_dict: dict) -> "KookConfig": # id=config_dict.get("id", "kook"), enable=config_dict.get("enable", False), token=config_dict.get("kook_bot_token", ""), - bot_nickname=config_dict.get("kook_bot_nickname", ""), reconnect_delay=config_dict.get( "kook_reconnect_delay", KookConfig.reconnect_delay, diff --git a/astrbot/core/platform/sources/kook/kook_event.py b/astrbot/core/platform/sources/kook/kook_event.py index 12f72a9790..884d066d8d 100644 --- a/astrbot/core/platform/sources/kook/kook_event.py +++ b/astrbot/core/platform/sources/kook/kook_event.py @@ -27,6 +27,7 @@ KookCardMessage, KookCardMessageContainer, KookMessageType, + KookModuleType, OrderMessage, ) @@ -111,7 +112,7 @@ async def handle_audio(index: int, f_item: Record): KookCardMessage( modules=[ FileModule( - type="audio", + type=KookModuleType.AUDIO, title=title, src=url, ) @@ -182,7 +183,7 @@ async def send(self, message: MessageChain): if item.reply_id: reply_id = item.reply_id if not item.text: - logger.debug(f'[Kook] 跳过空消息,类型为"{item.type}"') + logger.debug(f'[Kook] 跳过空消息,类型为"{item.type.name}"') continue try: await self.client.send_text( diff --git a/astrbot/core/platform/sources/kook/kook_types.py b/astrbot/core/platform/sources/kook/kook_types.py index dd18ac00f1..5efaf2a14c 100644 --- a/astrbot/core/platform/sources/kook/kook_types.py +++ b/astrbot/core/platform/sources/kook/kook_types.py @@ -1,10 +1,8 @@ import json -from dataclasses import field -from enum import IntEnum -from typing import Literal +from enum import Enum, IntEnum +from typing import Annotated, Any, Literal -from pydantic import BaseModel, ConfigDict -from pydantic.dataclasses import dataclass +from pydantic import BaseModel, ConfigDict, Field, model_validator class KookApiPaths: @@ -25,8 +23,9 @@ class KookApiPaths: DIRECT_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/direct-message/create" -# 定义参见kook事件结构文档: https://developer.kookapp.cn/doc/event/event-introduction class KookMessageType(IntEnum): + """定义参见kook事件结构文档: https://developer.kookapp.cn/doc/event/event-introduction""" + TEXT = 1 IMAGE = 2 VIDEO = 3 @@ -37,6 +36,26 @@ class KookMessageType(IntEnum): SYSTEM = 255 +class KookModuleType(str, Enum): + PLAIN_TEXT = "plain-text" + KMARKDOWN = "kmarkdown" + IMAGE = "image" + BUTTON = "button" + HEADER = "header" + SECTION = "section" + IMAGE_GROUP = "image-group" + CONTAINER = "container" + ACTION_GROUP = "action-group" + CONTEXT = "context" + DIVIDER = "divider" + FILE = "file" + AUDIO = "audio" + VIDEO = "video" + COUNTDOWN = "countdown" + INVITE = "invite" + CARD = "card" + + ThemeType = Literal[ "primary", "success", "danger", "warning", "info", "secondary", "none", "invisible" ] @@ -48,43 +67,81 @@ class KookMessageType(IntEnum): CountdownMode = Literal["day", "hour", "second"] -class KookCardColor(str): - """16 进制色值""" +class KookBaseDataClass(BaseModel): + model_config = ConfigDict( + extra="allow", + arbitrary_types_allowed=True, + populate_by_name=True, + ) + + @classmethod + def from_dict(cls, raw_data: dict): + return cls.model_validate(raw_data) + + @classmethod + def from_json(cls, raw_data: str | bytes | bytearray): + return cls.model_validate_json(raw_data) + + def to_dict( + self, + mode: Literal["json", "python"] | str = "python", + by_alias=True, + exclude_none=True, + exclude_unset=False, + ) -> dict: + return self.model_dump( + by_alias=by_alias, + exclude_none=exclude_none, + mode=mode, + exclude_unset=exclude_unset, + ) + + def to_json( + self, + indent: int | None = None, + ensure_ascii=False, + by_alias=True, + exclude_none=True, + exclude_unset=False, + ) -> str: + return self.model_dump_json( + indent=indent, + ensure_ascii=ensure_ascii, + by_alias=by_alias, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + ) -class KookCardModelBase: +class KookCardModelBase(KookBaseDataClass): """卡片模块基类""" type: str -@dataclass class PlainTextElement(KookCardModelBase): content: str - type: str = "plain-text" + type: Literal[KookModuleType.PLAIN_TEXT] = KookModuleType.PLAIN_TEXT emoji: bool = True -@dataclass class KmarkdownElement(KookCardModelBase): content: str - type: str = "kmarkdown" + type: Literal[KookModuleType.KMARKDOWN] = KookModuleType.KMARKDOWN -@dataclass class ImageElement(KookCardModelBase): src: str - type: str = "image" + type: Literal[KookModuleType.IMAGE] = KookModuleType.IMAGE alt: str = "" size: SizeType = "lg" circle: bool = False fallbackUrl: str | None = None -@dataclass class ButtonElement(KookCardModelBase): text: str - type: str = "button" + type: Literal[KookModuleType.BUTTON] = KookModuleType.BUTTON theme: ThemeType = "primary" value: str = "" """当为 link 时,会跳转到 value 代表的链接; @@ -96,93 +153,88 @@ class ButtonElement(KookCardModelBase): AnyElement = PlainTextElement | KmarkdownElement | ImageElement | ButtonElement | str -@dataclass class ParagraphStructure(KookCardModelBase): fields: list[PlainTextElement | KmarkdownElement] - type: str = "paragraph" + type: Literal["paragraph"] = "paragraph" cols: int = 1 """范围是 1-3 , 移动端忽略此参数""" -@dataclass class HeaderModule(KookCardModelBase): text: PlainTextElement - type: str = "header" + type: Literal[KookModuleType.HEADER] = KookModuleType.HEADER -@dataclass class SectionModule(KookCardModelBase): text: PlainTextElement | KmarkdownElement | ParagraphStructure - type: str = "section" + type: Literal[KookModuleType.SECTION] = KookModuleType.SECTION mode: SectionMode = "left" accessory: ImageElement | ButtonElement | None = None -@dataclass class ImageGroupModule(KookCardModelBase): """1 到多张图片的组合""" elements: list[ImageElement] - type: str = "image-group" + type: Literal[KookModuleType.IMAGE_GROUP] = KookModuleType.IMAGE_GROUP -@dataclass class ContainerModule(KookCardModelBase): """1 到多张图片的组合,与图片组模块(ImageGroupModule)不同,图片并不会裁切为正方形。多张图片会纵向排列。""" elements: list[ImageElement] - type: str = "container" + type: Literal[KookModuleType.CONTAINER] = KookModuleType.CONTAINER -@dataclass class ActionGroupModule(KookCardModelBase): + """用来放按钮的模块""" + elements: list[ButtonElement] - type: str = "action-group" + type: Literal[KookModuleType.ACTION_GROUP] = KookModuleType.ACTION_GROUP -@dataclass class ContextModule(KookCardModelBase): elements: list[PlainTextElement | KmarkdownElement | ImageElement] """最多包含10个元素""" - type: str = "context" + type: Literal[KookModuleType.CONTEXT] = KookModuleType.CONTEXT -@dataclass class DividerModule(KookCardModelBase): - type: str = "divider" + """展示分割线用的""" + + type: Literal[KookModuleType.DIVIDER] = KookModuleType.DIVIDER -@dataclass class FileModule(KookCardModelBase): src: str title: str = "" - type: Literal["file", "audio", "video"] = "file" + type: Literal[KookModuleType.FILE, KookModuleType.AUDIO, KookModuleType.VIDEO] = ( + KookModuleType.FILE + ) cover: str | None = None """cover 仅音频有效, 是音频的封面图""" -@dataclass class CountdownModule(KookCardModelBase): """startTime 和 endTime 为毫秒时间戳,startTime 和 endTime 不能小于服务器当前时间戳。""" endTime: int """毫秒时间戳""" - type: str = "countdown" + type: Literal[KookModuleType.COUNTDOWN] = KookModuleType.COUNTDOWN startTime: int | None = None """毫秒时间戳, 仅当mode为second才有这个字段""" mode: CountdownMode = "day" """mode 主要是倒计时的样式""" -@dataclass class InviteModule(KookCardModelBase): code: str """邀请链接或者邀请码""" - type: str = "invite" + type: Literal[KookModuleType.INVITE] = KookModuleType.INVITE # 所有模块的联合类型 -AnyModule = ( +AnyModule = Annotated[ HeaderModule | SectionModule | ImageGroupModule @@ -192,34 +244,29 @@ class InviteModule(KookCardModelBase): | DividerModule | FileModule | CountdownModule - | InviteModule -) + | InviteModule, + Field(discriminator="type"), +] -class KookCardMessage(BaseModel): +class KookCardMessage(KookBaseDataClass): """卡片定义文档详见 : https://developer.kookapp.cn/doc/cardmessage 此类型不能直接to_json后发送,因为kook要求卡片容器json顶层必须是**列表** 若要发送卡片消息,请使用KookCardMessageContainer """ model_config = ConfigDict(arbitrary_types_allowed=True) - type: str = "card" + type: Literal[KookModuleType.CARD] = KookModuleType.CARD theme: ThemeType | None = None size: SizeType | None = None - color: KookCardColor | None = None - modules: list[AnyModule] = field(default_factory=list) + color: str | None = None + """16 进制色值""" + modules: list[AnyModule] = Field(default_factory=list) """单个 card 模块数量不限制,但是一条消息中所有卡片的模块数量之和最多是 50""" def add_module(self, module: AnyModule): self.modules.append(module) - def to_dict(self, exclude_none: bool = True): - """exclude_none:去掉值为 None 字段,保留结构""" - return self.model_dump(exclude_none=exclude_none) - - def to_json(self, indent: int | None = None, ensure_ascii: bool = True): - return json.dumps(self.to_dict(), indent=indent, ensure_ascii=ensure_ascii) - class KookCardMessageContainer(list[KookCardMessage]): """卡片消息容器(列表),此类型可以直接to_json后发送出去""" @@ -232,10 +279,227 @@ def to_json(self, indent: int | None = None, ensure_ascii: bool = True) -> str: [i.to_dict() for i in self], indent=indent, ensure_ascii=ensure_ascii ) + @classmethod + def from_dict(cls, raw_data: list[dict[str, Any]]): + return cls(KookCardMessage.from_dict(item) for item in raw_data) + -@dataclass -class OrderMessage: +class OrderMessage(BaseModel): index: int text: str type: KookMessageType reply_id: str | int = "" + + +class KookMessageSignal(IntEnum): + """KOOK WebSocket 信令类型 + ws文档: https://developer.kookapp.cn/doc/websocket""" # noqa: W291 + + MESSAGE = 0 + """server->client 消息(s包含聊天和通知消息)""" + HELLO = 1 + """server->client 客户端连接 ws 时, 服务端返回握手结果""" + PING = 2 + """client->server 心跳,ping""" + PONG = 3 + """server->client 心跳,pong""" + RESUME = 4 + """client->server resume, 恢复会话""" + RECONNECT = 5 + """server->client reconnect, 要求客户端断开当前连接重新连接""" + RESUME_ACK = 6 + """server->client resume ack""" + + +class KookChannelType(str, Enum): + GROUP = "GROUP" + PERSON = "PERSON" + BROADCAST = "BROADCAST" + + +class KookAuthor(KookBaseDataClass): + id: str + username: str + identify_num: str + nickname: str + bot: bool + online: bool + avatar: str | None = None + vip_avatar: str | None = None + status: int + roles: list[int] = Field(default_factory=list) + + +class KookKMarkdown(KookBaseDataClass): + raw_content: str + mention_part: list[Any] = Field(default_factory=list) + mention_role_part: list[Any] = Field(default_factory=list) + + +class KookExtra(KookBaseDataClass): + type: int | str + code: str | None = None + body: dict[str, Any] | None = None + author: KookAuthor | None = None + kmarkdown: KookKMarkdown | None = None + last_msg_content: str | None = None + mention: list[str] = Field(default_factory=list) + mention_all: bool = False + mention_here: bool = False + + +class KookMessageEventData(KookBaseDataClass): + signal: Literal[KookMessageSignal.MESSAGE] = Field( + KookMessageSignal.MESSAGE, exclude=True + ) + """only for type hint""" + + channel_type: KookChannelType + type: KookMessageType + target_id: str + author_id: str + content: str | dict[str, Any] + msg_id: str + msg_timestamp: int + nonce: str + from_type: int + extra: KookExtra + + +class KookHelloEventData(KookBaseDataClass): + signal: Literal[KookMessageSignal.HELLO] = Field( + KookMessageSignal.HELLO, exclude=True + ) + """only for type hint""" + + code: int + session_id: str + + +class KookPingEventData(KookBaseDataClass): + signal: Literal[KookMessageSignal.PING] = Field( + KookMessageSignal.PING, exclude=True + ) + """only for type hint""" + + +class KookPongEventData(KookBaseDataClass): + signal: Literal[KookMessageSignal.PONG] = Field( + KookMessageSignal.PONG, exclude=True + ) + """only for type hint""" + + +class KookResumeEventData(KookBaseDataClass): + signal: Literal[KookMessageSignal.RESUME] = Field( + KookMessageSignal.RESUME, exclude=True + ) + """only for type hint""" + + +class KookReconnectEventData(KookBaseDataClass): + signal: Literal[KookMessageSignal.RECONNECT] = Field( + KookMessageSignal.RECONNECT, exclude=True + ) + """only for type hint""" + + code: int + err: str + + +class KookResumeAckEventData(KookBaseDataClass): + signal: Literal[KookMessageSignal.RESUME_ACK] = Field( + KookMessageSignal.RESUME_ACK, exclude=True + ) + """only for type hint""" + + session_id: str + + +class KookWebsocketEvent(KookBaseDataClass): + """KOOK WebSocket 原始推送结构""" + + signal: KookMessageSignal = Field( + ..., validation_alias="s", serialization_alias="s" + ) + """信令类型""" + data: Annotated[ + KookMessageEventData + | KookHelloEventData + | KookPingEventData + | KookPongEventData + | KookResumeEventData + | KookReconnectEventData + | KookResumeAckEventData + | None, + Field(discriminator="signal"), + ] = Field(None, validation_alias="d", serialization_alias="d") + """数据事件主体,对应原字段是'd'""" + sn: int | None = None + """消息序号 , 用来确定消息顺序和ws重连时使用 + 详见ws连接流程文档: https://developer.kookapp.cn/doc/websocket#%E8%BF%9E%E6%8E%A5%E6%B5%81%E7%A8%8B""" # noqa: W291 + + @model_validator(mode="before") + @classmethod + def _inject_signal_into_data(cls, data: Any) -> Any: + """在解析前,把外层的 s 同步到内层的 d 中,供 discriminator 使用""" + if isinstance(data, dict): + s_value = data.get("s") + d_value = data.get("d") + if s_value is not None and isinstance(d_value, dict): + d_value["signal"] = s_value + return data + + +class KookUserTag(KookBaseDataClass): + color: str + bg_color: str + text: str + + +class KookApiResponseBase(KookBaseDataClass): + code: int + message: str + data: Any + + def success(self) -> bool: + return self.code == 0 + + +class KookUserMeData(KookBaseDataClass): + """USER_ME 接口返回的 'data' 字段主体""" + + id: str + username: str + identify_num: str + nickname: str + bot: bool + online: bool + status: int + bot_status: int + avatar: str + vip_avatar: str | None = None + banner: str | None = None + roles: list[Any] = Field(default_factory=list) + is_vip: bool + vip_amp: bool + wealth_level: int + mobile_verified: bool + client_id: str + tag_info: KookUserTag | None = None + + +class KookUserMeResponse(KookApiResponseBase): + """USER_ME 完整响应结构""" + + data: KookUserMeData + + +class KookGatewayIndexData(KookBaseDataClass): + url: str + + +class KookGatewayIndexResponse(KookApiResponseBase): + """USER_ME 完整响应结构""" + + data: KookGatewayIndexData diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index a143678c23..b69225b078 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -590,11 +590,6 @@ "type": "string", "hint": "Required. The Bot Token obtained from the KOOK Developer Platform." }, - "kook_bot_nickname": { - "description": "Bot Nickname", - "type": "string", - "hint": "Optional. If the sender nickname matches this value, the message will be ignored to prevent broadcast storms." - }, "kook_reconnect_delay": { "description": "Reconnect Delay", "type": "int", diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 015ce3082c..e5ffa96df7 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -593,11 +593,6 @@ "type": "string", "hint": "必填项。从 KOOK 开发者平台获取的机器人 Token" }, - "kook_bot_nickname": { - "description": "Bot Nickname", - "type": "string", - "hint": "可选项。若发送者昵称与此值一致,将忽略该消息。" - }, "kook_reconnect_delay": { "description": "重连延迟", "type": "int", diff --git a/tests/test_kook/data/kook_card_data.json b/tests/test_kook/data/kook_card_data.json index f19bb40800..a142318e46 100644 --- a/tests/test_kook/data/kook_card_data.json +++ b/tests/test_kook/data/kook_card_data.json @@ -4,97 +4,97 @@ "size": "lg", "modules": [ { + "type": "header", "text": { - "content": "test1", "type": "plain-text", + "content": "test1", "emoji": true - }, - "type": "header" + } }, { + "type": "section", "text": { - "content": "test2", - "type": "kmarkdown" + "type": "kmarkdown", + "content": "test2" }, - "type": "section", "mode": "left" }, { "type": "divider" }, { + "type": "section", "text": { + "type": "paragraph", "fields": [ { - "content": "test3", - "type": "kmarkdown" + "type": "kmarkdown", + "content": "test3" }, { - "content": "**test4**", - "type": "kmarkdown" + "type": "kmarkdown", + "content": "**test4**" } ], - "type": "paragraph", "cols": 2 }, - "type": "section", "mode": "left" }, { + "type": "image-group", "elements": [ { - "src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg", "type": "image", + "src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg", "alt": "", "size": "lg", "circle": false } - ], - "type": "image-group" + ] }, { + "type": "file", "src": "https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg", - "title": "test5", - "type": "file" + "title": "test5" }, { - "endTime": 1772343427360, "type": "countdown", + "endTime": 1772343427360, "startTime": 1772343378259, "mode": "second" }, { + "type": "action-group", "elements": [ { - "text": "点我测试回调", "type": "button", + "text": "点我测试回调", "theme": "primary", "value": "btn_clicked", "click": "return-val" }, { - "text": "访问官网", "type": "button", + "text": "访问官网", "theme": "danger", "value": "https://www.kookapp.cn", "click": "link" } - ], - "type": "action-group" + ] }, { + "type": "context", "elements": [ { - "content": "test6", "type": "plain-text", + "content": "test6", "emoji": true } - ], - "type": "context" + ] }, { - "code": "test7", - "type": "invite" + "type": "invite", + "code": "test7" } ] } \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_group_message.json b/tests/test_kook/data/kook_ws_event_group_message.json new file mode 100644 index 0000000000..dcab6e901c --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_group_message.json @@ -0,0 +1,119 @@ +{ + "s": 0, + "d": { + "channel_type": "GROUP", + "type": 9, + "target_id": "2732467349811313213", + "author_id": "7324688132731983", + "content": "done!", + "extra": { + "quote": { + "id": "69a788adb0cfb9ece50eae1c", + "rong_id": "7baef72c-0cd7-49ad-9592-1615236136cb", + "type": 9, + "content": "/am 1", + "interact_res": null, + "create_at": 1772587180973, + "author": { + "id": "2701973210937821093781", + "username": "some_username", + "identify_num": "4198", + "online": true, + "os": "Websocket", + "status": 1, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "banner": "", + "nickname": "some_username", + "roles": [ + 63724577 + ], + "is_vip": false, + "vip_amp": false, + "bot": false, + "nameplate": [], + "kpm_vip": null, + "wealth_level": 0, + "decorations_id_map": null, + "mobile_verified": true, + "is_sys": false, + "joined_at": 1772259607000, + "active_time": 1772587181304 + }, + "can_jump": true, + "preview_content": null, + "kmarkdown": { + "mention_part": [], + "mention_role_part": [], + "channel_part": [], + "item_part": [] + } + }, + "type": 9, + "code": "", + "guild_id": "273902183210983210983", + "guild_type": 0, + "channel_name": "聊天大厅", + "author": { + "id": "7324688132731983", + "username": "Bot_Test", + "identify_num": "9561", + "online": true, + "os": "Websocket", + "status": 0, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "banner": "", + "nickname": "Bot_Test", + "roles": [ + 63725384 + ], + "is_vip": false, + "vip_amp": false, + "bot": true, + "nameplate": [], + "kpm_vip": null, + "wealth_level": 0, + "bot_status": 0, + "tag_info": { + "color": "#0096FF", + "bg_color": "#0096FF33", + "text": "机器人" + }, + "is_sys": false, + "client_id": "sAdiIHoGhdSFUOA", + "verified": false + }, + "visible_only": "", + "mention": [], + "mention_no_at": [], + "mention_all": false, + "mention_roles": [], + "mention_here": false, + "nav_channels": [], + "kmarkdown": { + "raw_content": "done!", + "mention_part": [], + "mention_role_part": [], + "channel_part": [], + "spl": [] + }, + "emoji": [], + "preview_content": "", + "channel_type": 1, + "last_msg_content": "Bot_Test:done!", + "send_msg_device": 0 + }, + "msg_id": "c51a8761-63bv-5l2a-5681-0ac16e140a1b", + "msg_timestamp": 1772587182234, + "nonce": "", + "from_type": 1 + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 3 +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_hello.json b/tests/test_kook/data/kook_ws_event_hello.json new file mode 100644 index 0000000000..a6ab68d984 --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_hello.json @@ -0,0 +1,8 @@ +{ + "s": 1, + "d": { + "sessionId": "67d7d497-2b10-4849-9c2c-dda2fe58ed60", + "session_id": "67d7d497-2b10-4849-9c2c-dda2fe58ed60", + "code": 0 + } +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_message_with_card_1.json b/tests/test_kook/data/kook_ws_event_message_with_card_1.json new file mode 100644 index 0000000000..d4456651e5 --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_message_with_card_1.json @@ -0,0 +1,72 @@ +{ + "s": 0, + "d": { + "channel_type": "PERSON", + "type": 10, + "target_id": "2732467349811313213", + "author_id": "7324688132731983", + "content": "[{\"theme\":\"primary\",\"color\":\"\",\"size\":\"lg\",\"expand\":false,\"modules\":[{\"type\":\"audio\",\"cover\":\"\",\"duration\":0,\"title\":\"dancing_shot5.wav\",\"src\":\"https:\\/\\/img.kookapp.cn\\/attachments\\/2026-03\\/03\\/69a6841c3125d.wav\",\"external\":false,\"size\":443414,\"canDownload\":true,\"elements\":[]}],\"type\":\"card\"}]", + "extra": { + "type": 10, + "code": "1738914789hd8fd91098he809h19y491", + "author": { + "id": "7324688132731983", + "username": "Bot_Test", + "identify_num": "9561", + "online": true, + "os": "Websocket", + "status": 0, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "banner": "", + "nickname": "Bot_Test", + "roles": [], + "is_vip": false, + "vip_amp": false, + "bot": true, + "nameplate": [], + "kpm_vip": null, + "wealth_level": 0, + "bot_status": 0, + "tag_info": { + "color": "#0096FF", + "bg_color": "#0096FF33", + "text": "机器人" + }, + "is_sys": false, + "client_id": "u109u3108h8ds0qsdaHUIOS", + "verified": false + }, + "visible_only": "", + "mention": [], + "mention_no_at": [], + "mention_all": false, + "mention_roles": [], + "mention_here": false, + "nav_channels": [], + "emoji": [], + "kmarkdown": { + "raw_content": "[音频]dancing_shot5.wav", + "mention_part": [], + "mention_role_part": [], + "channel_part": [] + }, + "editable": false, + "preview_content": "[音频]dancing_shot5.wav", + "preview_content_search": "[音频]dancing_shot5.wav", + "last_msg_content": "[音频]dancing_shot5.wav", + "send_msg_device": 0 + }, + "msg_id": "82c0b042-79b4-4066-a0f4-6c7a95c74e67", + "msg_timestamp": 1772587223043, + "nonce": "", + "from_type": 1 + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 5 +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_message_with_card_2.json b/tests/test_kook/data/kook_ws_event_message_with_card_2.json new file mode 100644 index 0000000000..fd122391e3 --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_message_with_card_2.json @@ -0,0 +1,79 @@ +{ + "s": 0, + "d": { + "channel_type": "GROUP", + "type": 10, + "target_id": "2723723449021809", + "author_id": "1237198731983", + "content": "[{\"theme\":\"invisible\",\"color\":\"\",\"size\":\"lg\",\"expand\":false,\"modules\":[{\"type\":\"section\",\"mode\":\"left\",\"accessory\":null,\"text\":{\"type\":\"kmarkdown\",\"content\":\"(met)(met) (met)all(met) #hello \\\\*\\\\*world\\\\*\\\\* \",\"elements\":[]},\"elements\":[]},{\"type\":\"audio\",\"cover\":\"\",\"duration\":0,\"title\":\"dancing_shot5.wav\",\"src\":\"https:\\/\\/img.kookapp.cn\\/attachments\\/2026-03\\/03\\/69a6841c3125d.wav\",\"external\":false,\"size\":443414,\"canDownload\":true,\"elements\":[]},{\"type\":\"section\",\"mode\":\"left\",\"accessory\":null,\"text\":{\"type\":\"kmarkdown\",\"content\":\"\\n😆 \",\"elements\":[]},\"elements\":[]}],\"type\":\"card\"}]", + "msg_id": "ec4046e9-ea43-4907-9fc3-8c6d0bd4ec56", + "msg_timestamp": 1772600762056, + "nonce": "sy8f91y248yda", + "from_type": 1, + "extra": { + "type": 10, + "code": "", + "author": { + "id": "1237198731983", + "username": "some_username", + "identify_num": "4198", + "nickname": "some_username", + "bot": false, + "online": true, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "status": 1, + "roles": [ + 12783219731984 + ], + "os": "Websocket", + "banner": "", + "is_vip": false, + "vip_amp": false, + "nameplate": [], + "wealth_level": 0, + "is_sys": false + }, + "kmarkdown": { + "raw_content": "@Bot_Test @全体成员 #hello **world**[音频]dancing_shot5.wav😆", + "mention_part": [ + { + "id": "", + "username": "Bot_Test", + "full_name": "Bot_Test#9561", + "avatar": "https://example.com", + "wealth_level": 0 + } + ], + "mention_role_part": [], + "channel_part": [] + }, + "last_msg_content": "some_username:@Bot_Test @ 全体成员 #hello **world**[音频]dancing_shot5.wav😆", + "mention": [ + "" + ], + "mention_all": true, + "mention_here": false, + "guild_id": "28321098321093", + "guild_type": 0, + "channel_name": "聊天大厅", + "visible_only": "", + "mention_no_at": [], + "mention_roles": [], + "nav_channels": [], + "emoji": [], + "editable": true, + "preview_content": "@Bot_Test @全体成员 #hello **world**[音频]dancing_shot5.wav😆", + "preview_content_search": "@Bot_Test @全体成员 #hello **world**[音频]dancing_shot5.wav😆", + "channel_type": 1, + "send_msg_device": 0 + } + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 5 +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_ping.json b/tests/test_kook/data/kook_ws_event_ping.json new file mode 100644 index 0000000000..1b4e8e7cfd --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_ping.json @@ -0,0 +1,4 @@ +{ + "s": 2, + "sn": 0 +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_pong.json b/tests/test_kook/data/kook_ws_event_pong.json new file mode 100644 index 0000000000..da07a35c6c --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_pong.json @@ -0,0 +1,3 @@ +{ + "s": 3 +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_private_message.json b/tests/test_kook/data/kook_ws_event_private_message.json new file mode 100644 index 0000000000..13b0180282 --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_private_message.json @@ -0,0 +1,64 @@ +{ + "s": 0, + "d": { + "channel_type": "PERSON", + "type": 9, + "target_id": "7324688132731983", + "author_id": "2732467349811313213", + "content": "/help", + "extra": { + "type": 9, + "code": "1738914789hd8fd91098he809h19y491", + "author": { + "id": "2732467349811313213", + "username": "shuiping233", + "identify_num": "4198", + "online": true, + "os": "Websocket", + "status": 1, + "avatar": "https://example.com", + "vip_avatar": "https://example.com", + "banner": "", + "nickname": "shuiping233", + "roles": [], + "is_vip": false, + "vip_amp": false, + "bot": false, + "nameplate": [], + "kpm_vip": null, + "wealth_level": 0, + "decorations_id_map": null, + "is_sys": false + }, + "visible_only": "", + "mention": [], + "mention_no_at": [], + "mention_all": false, + "mention_roles": [], + "mention_here": false, + "nav_channels": [], + "kmarkdown": { + "raw_content": "/help", + "mention_part": [], + "mention_role_part": [], + "channel_part": [], + "spl": [] + }, + "emoji": [], + "preview_content": "", + "last_msg_content": "/help", + "send_msg_device": 0 + }, + "msg_id": "b0f57b9e-2cd4-4e07-8f0e-9c1ecfeaa837", + "msg_timestamp": 1772587358662, + "nonce": "6AwzUe5YjgyC8pAfxcLGjewL", + "from_type": 1 + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 19 +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_private_system_message.json b/tests/test_kook/data/kook_ws_event_private_system_message.json new file mode 100644 index 0000000000..1a60adc4af --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_private_system_message.json @@ -0,0 +1,31 @@ +{ + "s": 0, + "d": { + "channel_type": "PERSON", + "type": 255, + "target_id": "7324688132731983", + "author_id": "1", + "content": "[系统消息]", + "extra": { + "type": "guild_member_offline", + "body": { + "user_id": "2732467349811313213", + "event_time": 1772589748914, + "guilds": [ + "78941897317309873120973" + ] + } + }, + "msg_id": "e91b4451-75ce-47bd-bda6-e4498ed8d30d", + "msg_timestamp": 1772589748933, + "nonce": "", + "from_type": 1 + }, + "extra": { + "verifyToken": "kW4FH_ASHio1hosd", + "encryptKey": "", + "callbackUrl": "", + "intent": 255 + }, + "sn": 1 +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_reconnect_err.json b/tests/test_kook/data/kook_ws_event_reconnect_err.json new file mode 100644 index 0000000000..5346680f2e --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_reconnect_err.json @@ -0,0 +1,7 @@ +{ + "s": 5, + "d": { + "code": 40108, + "err": "Invalid SN" + } +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_resume.json b/tests/test_kook/data/kook_ws_event_resume.json new file mode 100644 index 0000000000..427f4ca2a9 --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_resume.json @@ -0,0 +1,4 @@ +{ + "s": 4, + "sn": 100 +} \ No newline at end of file diff --git a/tests/test_kook/data/kook_ws_event_resume_ack.json b/tests/test_kook/data/kook_ws_event_resume_ack.json new file mode 100644 index 0000000000..da8edab146 --- /dev/null +++ b/tests/test_kook/data/kook_ws_event_resume_ack.json @@ -0,0 +1,6 @@ +{ + "s": 6, + "d": { + "session_id": "xxxx-xxxxxx-xxx-xxx" + } +} \ No newline at end of file diff --git a/tests/test_kook/shared.py b/tests/test_kook/shared.py index 5c5c9da86c..f5ef18b8b8 100644 --- a/tests/test_kook/shared.py +++ b/tests/test_kook/shared.py @@ -1,4 +1,5 @@ from pathlib import Path -TEST_DATA_DIR = Path(__file__).parent / "data" +CURRENT_DIR = Path(__file__).parent +TEST_DATA_DIR = CURRENT_DIR / "data" diff --git a/tests/test_kook/test_kook_event.py b/tests/test_kook/test_kook_event.py index 253839506e..5fe73a510a 100644 --- a/tests/test_kook/test_kook_event.py +++ b/tests/test_kook/test_kook_event.py @@ -60,7 +60,7 @@ def mock_astrbot_message(): Image("test image"), "test image", OrderMessage( - 1, + index=1, text="test image", type=KookMessageType.IMAGE, ), @@ -70,7 +70,7 @@ def mock_astrbot_message(): Video("test video"), "test video", OrderMessage( - 1, + index=1, text="test video", type=KookMessageType.VIDEO, ), @@ -80,7 +80,7 @@ def mock_astrbot_message(): mock_file_message("test file"), "test file", OrderMessage( - 1, + index=1, text="test file", type=KookMessageType.FILE, ), @@ -90,8 +90,8 @@ def mock_astrbot_message(): mock_record_message("./tests/file.wav"), "./tests/file.wav", OrderMessage( - 1, - text='[{"type": "card", "modules": [{"src": "./tests/file.wav", "title": "./tests/file.wav", "type": "audio"}]}]', + index=1, + text='[{"type": "card", "modules": [{"type": "audio", "src": "./tests/file.wav", "title": "./tests/file.wav"}]}]', type=KookMessageType.CARD, ), None, @@ -100,7 +100,7 @@ def mock_astrbot_message(): Plain("test plain"), "test plain", OrderMessage( - 1, + index=1, text="test plain", type=KookMessageType.KMARKDOWN, ), @@ -110,7 +110,7 @@ def mock_astrbot_message(): At(qq="test at"), "test at", OrderMessage( - 1, + index=1, text="(met)test at(met)", type=KookMessageType.KMARKDOWN, ), @@ -120,7 +120,7 @@ def mock_astrbot_message(): AtAll(qq="all"), "test atAll", OrderMessage( - 1, + index=1, text="(met)all(met)", type=KookMessageType.KMARKDOWN, ), @@ -130,7 +130,7 @@ def mock_astrbot_message(): Reply(id="test reply"), "test reply", OrderMessage( - 1, + index=1, text="", type=KookMessageType.KMARKDOWN, reply_id="test reply", @@ -141,7 +141,7 @@ def mock_astrbot_message(): Json(data={"test": "json"}), "test json", OrderMessage( - 1, + index=1, text='[{"test": "json"}]', type=KookMessageType.CARD, ), @@ -159,7 +159,7 @@ async def test_kook_event_warp_message( input_message: BaseMessageComponent, upload_asset_return: str, expected_output: OrderMessage, - expected_error: type[Exception] | None, + expected_error: type[BaseException] | None, ): client = await mock_kook_client( upload_asset_return, @@ -185,39 +185,4 @@ async def test_kook_event_warp_message( result = await event._wrap_message(1, input_message) assert result == expected_output - - -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# "message_chain,send_text_expected_output,expected_error", -# [ -# ( -# MessageChain( -# chain=[ -# Image(file="test image"), -# Plain(text="test plain"), -# ], -# ), -# "" -# ), -# ], -# ) -# async def test_kook_event_send(): -# client = await mock_kook_client( -# "", -# "", -# ) - -# event = KookEvent( -# "", -# mock_astrbot_message(), -# PlatformMetadata( -# name="test", -# id="test", -# description="test", -# ), -# "", -# client, -# ) - -# await event.send(message=mock_astrbot_message()) + \ No newline at end of file diff --git a/tests/test_kook/test_kook_types.py b/tests/test_kook/test_kook_types.py index 760e36c596..85c39622c1 100644 --- a/tests/test_kook/test_kook_types.py +++ b/tests/test_kook/test_kook_types.py @@ -16,6 +16,9 @@ InviteModule, KmarkdownElement, KookCardMessage, + KookMessageSignal, + KookModuleType, + KookWebsocketEvent, ParagraphStructure, PlainTextElement, SectionModule, @@ -77,7 +80,7 @@ def test_all_kook_card_type(): FileModule( src="https://img.kookapp.cn/attachments/2023-01/05/63b645851ff19.svg", title="test5", - type="file", + type=KookModuleType.FILE, ), CountdownModule( endTime=1772343427360, @@ -105,3 +108,41 @@ def test_all_kook_card_type(): ], ).to_json(indent=4, ensure_ascii=False) assert json_output == expect_json_data + +@pytest.mark.parametrize( + "expected_json_data_filename", + [ + ("kook_ws_event_group_message.json"), + ("kook_ws_event_hello.json"), + ("kook_ws_event_message_with_card_1.json"), + ("kook_ws_event_message_with_card_2.json"), + ("kook_ws_event_ping.json"), + ("kook_ws_event_pong.json"), + ("kook_ws_event_private_message.json"), + ("kook_ws_event_private_system_message.json"), + ("kook_ws_event_reconnect_err.json"), + ("kook_ws_event_resume_ack.json"), + ("kook_ws_event_resume.json"), + + ], +) +def test_websocket_event_type_parse(expected_json_data_filename:str): + expected_json_data_str =(TEST_DATA_DIR / expected_json_data_filename).read_text(encoding="utf-8") + event = KookWebsocketEvent.from_json( + expected_json_data_str, + ) + event_dict = event.to_dict(mode="json",exclude_unset=True,exclude_none=False) + assert event_dict == json.loads(expected_json_data_str) + + +def test_websocket_event_create(): + ping_data = KookWebsocketEvent( + signal=KookMessageSignal.PING, + data=None, + sn=0, + ) + assert ping_data.to_dict(mode="json")== { + "s": KookMessageSignal.PING.value, + "sn": 0, + } + \ No newline at end of file From 4b8b76cd466a18f60a7bcd24f394a726021ca1b8 Mon Sep 17 00:00:00 2001 From: shuiping233 <1944680304@qq.com> Date: Wed, 4 Mar 2026 16:05:01 +0800 Subject: [PATCH 4/4] =?UTF-8?q?format:=20=E4=BD=BF=E7=94=A8StrEnum?= =?UTF-8?q?=E6=9B=BF=E6=8D=A2kook=E9=80=82=E9=85=8D=E5=99=A8=E4=B8=AD?= =?UTF-8?q?=E7=9A=84(str,enum)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/platform/sources/kook/kook_types.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/astrbot/core/platform/sources/kook/kook_types.py b/astrbot/core/platform/sources/kook/kook_types.py index 5efaf2a14c..7256fbbd4a 100644 --- a/astrbot/core/platform/sources/kook/kook_types.py +++ b/astrbot/core/platform/sources/kook/kook_types.py @@ -1,5 +1,5 @@ import json -from enum import Enum, IntEnum +from enum import IntEnum, StrEnum from typing import Annotated, Any, Literal from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -36,7 +36,7 @@ class KookMessageType(IntEnum): SYSTEM = 255 -class KookModuleType(str, Enum): +class KookModuleType(StrEnum): PLAIN_TEXT = "plain-text" KMARKDOWN = "kmarkdown" IMAGE = "image" @@ -311,7 +311,7 @@ class KookMessageSignal(IntEnum): """server->client resume ack""" -class KookChannelType(str, Enum): +class KookChannelType(StrEnum): GROUP = "GROUP" PERSON = "PERSON" BROADCAST = "BROADCAST"