diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 5aeef1eff0..8bccae959a 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -46,6 +46,8 @@ jobs: include: - language: python build-mode: none + - language: javascript-typescript + build-mode: none # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' # Use `c-cpp` to analyze code written in C, C++ or both # Use 'java-kotlin' to analyze code written in Java, Kotlin or both diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml index f0019ee7e6..bd7beceb1c 100644 --- a/.github/workflows/coverage_test.yml +++ b/.github/workflows/coverage_test.yml @@ -23,12 +23,13 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 + with: + python-version: "3.12" - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install pytest pytest-asyncio pytest-cov - pip install --editable . + python -m pip install --upgrade pip uv + uv sync --group dev - name: Run tests run: | @@ -37,7 +38,7 @@ jobs: mkdir -p data/temp export TESTING=true export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }} - pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG + uv run pytest --cov=astrbot -v -o log_cli=true -o log_level=DEBUG - name: Upload results to Codecov uses: codecov/codecov-action@v5 diff --git a/.github/workflows/dashboard_ci.yml b/.github/workflows/dashboard_ci.yml index 46d2fea735..921fc49922 100644 --- a/.github/workflows/dashboard_ci.yml +++ b/.github/workflows/dashboard_ci.yml @@ -13,18 +13,23 @@ jobs: - name: Checkout repository uses: actions/checkout@v6 + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 10.28.2 + - name: Setup Node.js uses: actions/setup-node@v6 with: node-version: '24.13.0' + cache: "pnpm" + cache-dependency-path: dashboard/pnpm-lock.yaml - - name: npm install, build + - name: Install and build run: | - cd dashboard - npm install pnpm -g - pnpm install - pnpm i --save-dev @types/markdown-it - pnpm run build + pnpm --dir dashboard install --frozen-lockfile + pnpm --dir dashboard run typecheck + pnpm --dir dashboard run build - name: Inject Commit SHA id: get_sha diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 18c8d49269..6300a65a03 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -25,6 +25,18 @@ jobs: fetch-depth: 1 fetch-tag: true + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 10.28.2 + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '24.13.0' + cache: "pnpm" + cache-dependency-path: dashboard/pnpm-lock.yaml + - name: Check for new commits today if: github.event_name == 'schedule' id: check-commits @@ -46,12 +58,10 @@ jobs: - name: Build Dashboard run: | - cd dashboard - npm install - npm run build - mkdir -p dist/assets - echo $(git rev-parse HEAD) > dist/assets/version - cd .. + pnpm --dir dashboard install --frozen-lockfile + pnpm --dir dashboard run build + mkdir -p dashboard/dist/assets + echo $(git rev-parse HEAD) > dashboard/dist/assets/version mkdir -p data cp -r dashboard/dist data/ @@ -123,6 +133,18 @@ jobs: fetch-depth: 1 fetch-tag: true + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 10.28.2 + + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '24.13.0' + cache: "pnpm" + cache-dependency-path: dashboard/pnpm-lock.yaml + - name: Get latest tag (only on manual trigger) id: get-latest-tag if: github.event_name == 'workflow_dispatch' @@ -153,12 +175,10 @@ jobs: - name: Build Dashboard run: | - cd dashboard - npm install - npm run build - mkdir -p dist/assets - echo $(git rev-parse HEAD) > dist/assets/version - cd .. + pnpm --dir dashboard install --frozen-lockfile + pnpm --dir dashboard run build + mkdir -p dashboard/dist/assets + echo $(git rev-parse HEAD) > dashboard/dist/assets/version mkdir -p data cp -r dashboard/dist data/ diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d434d04950..c561f01733 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,6 +18,29 @@ permissions: contents: write jobs: + verify-core: + name: Verify Core Quality Gate + runs-on: ubuntu-24.04 + steps: + - name: Checkout repository + uses: actions/checkout@v6 + with: + fetch-depth: 0 + ref: ${{ inputs.ref || github.ref }} + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + + - name: Install uv + shell: bash + run: python -m pip install uv + + - name: Run local PR gate checks + shell: bash + run: make pr-test-neo + build-dashboard: name: Build Dashboard runs-on: ubuntu-24.04 @@ -85,7 +108,8 @@ jobs: VERSION_TAG: ${{ steps.tag.outputs.tag }} shell: bash run: | - curl https://rclone.org/install.sh | sudo bash + sudo apt-get update + sudo apt-get install -y rclone mkdir -p ~/.config/rclone cat < ~/.config/rclone/rclone.conf @@ -106,6 +130,7 @@ jobs: name: Publish GitHub Release runs-on: ubuntu-24.04 needs: + - verify-core - build-dashboard steps: - name: Checkout repository @@ -195,7 +220,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: "3.10" + python-version: "3.12" - name: Install uv shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8611e26984..4c2a126e9d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.14.1 + rev: v0.15.1 hooks: # Run the linter. - id: ruff-check @@ -22,4 +22,4 @@ repos: rev: v3.21.0 hooks: - id: pyupgrade - args: [--py310-plus] + args: [--py312-plus] diff --git a/Dockerfile b/Dockerfile index 992060d6ea..544c6d6ced 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,10 +13,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ bash \ ffmpeg \ curl \ - gnupg \ git \ - && curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \ - && apt-get install -y --no-install-recommends nodejs \ + nodejs \ + npm \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* diff --git a/README.md b/README.md index 4d9bf68c09..a000062577 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_fr.md b/README_fr.md index 5614f6bec3..15c9f91aa4 100644 --- a/README_fr.md +++ b/README_fr.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_ja.md b/README_ja.md index 56d369a425..550fea6fa2 100644 --- a/README_ja.md +++ b/README_ja.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_ru.md b/README_ru.md index c4d3b2fd0e..99617a998e 100644 --- a/README_ru.md +++ b/README_ru.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_zh-TW.md b/README_zh-TW.md index 8d1ea9b92c..8873bdf185 100644 --- a/README_zh-TW.md +++ b/README_zh-TW.md @@ -19,7 +19,7 @@
-python +python zread Docker pull diff --git a/README_zh.md b/README_zh.md index 1c3b40d0dd..39f324613e 100644 --- a/README_zh.md +++ b/README_zh.md @@ -17,7 +17,7 @@
-python +python zread Docker pull diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index c06dda3500..2764935d02 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -1,6 +1,6 @@ import shutil import tempfile -from enum import Enum +from enum import StrEnum from io import BytesIO from pathlib import Path from zipfile import ZipFile @@ -12,7 +12,7 @@ from .version_comparator import VersionComparator -class PluginStatus(str, Enum): +class PluginStatus(StrEnum): INSTALLED = "installed" NEED_UPDATE = "needs-update" NOT_INSTALLED = "not-installed" diff --git a/astrbot/core/agent/agent.py b/astrbot/core/agent/agent.py index d6e2e7cb41..f4606b6da5 100644 --- a/astrbot/core/agent/agent.py +++ b/astrbot/core/agent/agent.py @@ -1,13 +1,12 @@ from dataclasses import dataclass -from typing import Any, Generic +from typing import Any from .hooks import BaseAgentRunHooks -from .run_context import TContext from .tool import FunctionTool @dataclass -class Agent(Generic[TContext]): +class Agent[TContext]: name: str instructions: str | None = None tools: list[str | FunctionTool] | None = None diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 8475009d3f..01fc5159c6 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,11 +1,8 @@ -from typing import Generic - from .agent import Agent -from .run_context import TContext from .tool import FunctionTool -class HandoffTool(FunctionTool, Generic[TContext]): +class HandoffTool[TContext](FunctionTool): """Handoff tool for delegating tasks to another agent.""" def __init__( diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index 74ca6335b3..451a957539 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -1,14 +1,12 @@ -from typing import Generic - import mcp from astrbot.core.agent.tool import FunctionTool from astrbot.core.provider.entities import LLMResponse -from .run_context import ContextWrapper, TContext +from .run_context import ContextWrapper -class BaseAgentRunHooks(Generic[TContext]): +class BaseAgentRunHooks[TContext]: async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ... async def on_tool_start( self, diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 18f4d47e04..5c4c19fade 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -2,7 +2,6 @@ import logging from contextlib import AsyncExitStack from datetime import timedelta -from typing import Generic from tenacity import ( before_sleep_log, @@ -16,7 +15,6 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.utils.log_pipe import LogPipe -from .run_context import TContext from .tool import FunctionTool try: @@ -101,7 +99,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return True, "" return False, f"HTTP {response.status}: {response.reason}" - except asyncio.TimeoutError: + except TimeoutError: return False, f"Connection timeout: {timeout} seconds" except Exception as e: return False, f"{e!s}" @@ -360,7 +358,7 @@ async def cleanup(self) -> None: self.running_event.set() -class MCPTool(FunctionTool, Generic[TContext]): +class MCPTool[TContext](FunctionTool): """A function tool that calls an MCP service.""" def __init__( diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index 21e7964335..e1e3ff8e39 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -7,7 +7,7 @@ from ..hooks import BaseAgentRunHooks from ..response import AgentResponse -from ..run_context import ContextWrapper, TContext +from ..run_context import ContextWrapper class AgentState(Enum): @@ -19,7 +19,7 @@ class AgentState(Enum): ERROR = auto() # Error state -class BaseAgentRunner(T.Generic[TContext]): +class BaseAgentRunner[TContext]: @abc.abstractmethod async def reset( self, diff --git a/astrbot/core/agent/runners/coze/coze_agent_runner.py b/astrbot/core/agent/runners/coze/coze_agent_runner.py index a8300bb711..0d7fab2070 100644 --- a/astrbot/core/agent/runners/coze/coze_agent_runner.py +++ b/astrbot/core/agent/runners/coze/coze_agent_runner.py @@ -1,7 +1,7 @@ import base64 import json -import sys import typing as T +from typing import override import astrbot.core.message.components as Comp from astrbot import logger @@ -18,11 +18,6 @@ from ..base import AgentResponse, AgentState, BaseAgentRunner from .coze_api_client import CozeAPIClient -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class CozeAgentRunner(BaseAgentRunner[TContext]): """Coze Agent Runner""" @@ -251,7 +246,7 @@ async def _execute_coze_request(self): conversation_id=conversation_id, auto_save_history=self.auto_save_history, stream=True, - timeout=self.timeout, + timeout_seconds=self.timeout, ): event_type = chunk.get("event") data = chunk.get("data", {}) diff --git a/astrbot/core/agent/runners/coze/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py index f5799dfbb7..03dbe64cc3 100644 --- a/astrbot/core/agent/runners/coze/coze_api_client.py +++ b/astrbot/core/agent/runners/coze/coze_api_client.py @@ -2,6 +2,7 @@ import io import json from collections.abc import AsyncGenerator +from pathlib import Path from typing import Any import aiohttp @@ -90,7 +91,7 @@ async def upload_file( logger.debug(f"[Coze] 图片上传成功,file_id: {file_id}") return file_id - except asyncio.TimeoutError: + except TimeoutError: logger.error("文件上传超时") raise Exception("文件上传超时") except Exception as e: @@ -128,7 +129,7 @@ async def chat_messages( conversation_id: str | None = None, auto_save_history: bool = True, stream: bool = True, - timeout: float = 120, + timeout_seconds: float = 120, ) -> AsyncGenerator[dict[str, Any], None]: """发送聊天消息并返回流式响应 @@ -139,7 +140,7 @@ async def chat_messages( conversation_id: 会话ID auto_save_history: 是否自动保存历史 stream: 是否流式响应 - timeout: 超时时间 + timeout_seconds: 超时时间 """ session = await self._ensure_session() @@ -166,7 +167,7 @@ async def chat_messages( url, json=payload, params=params, - timeout=aiohttp.ClientTimeout(total=timeout), + timeout=aiohttp.ClientTimeout(total=timeout_seconds), ) as response: if response.status == 401: raise Exception("Coze API 认证失败,请检查 API Key 是否正确") @@ -203,8 +204,8 @@ async def chat_messages( except json.JSONDecodeError: event_data = {"content": data_str} - except asyncio.TimeoutError: - raise Exception(f"Coze API 流式请求超时 ({timeout}秒)") + except TimeoutError: + raise Exception(f"Coze API 流式请求超时 ({timeout_seconds}秒)") except Exception as e: raise Exception(f"Coze API 流式请求失败: {e!s}") @@ -236,7 +237,7 @@ async def clear_context(self, conversation_id: str): except json.JSONDecodeError: raise Exception("Coze API 返回非JSON格式") - except asyncio.TimeoutError: + except TimeoutError: raise Exception("Coze API 请求超时") except aiohttp.ClientError as e: raise Exception(f"Coze API 请求失败: {e!s}") @@ -294,8 +295,7 @@ async def test_coze_api_client() -> None: client = CozeAPIClient(api_key=api_key) try: - with open("README.md", "rb") as f: - file_data = f.read() + file_data = await asyncio.to_thread(Path("README.md").read_bytes) file_id = await client.upload_file(file_data) print(f"Uploaded file_id: {file_id}") async for event in client.chat_messages( diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py index 1aaf6e3b9c..080e627d52 100644 --- a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -2,9 +2,9 @@ import functools import queue import re -import sys import threading import typing as T +from typing import override from dashscope import Application from dashscope.app.application_response import ApplicationResponse @@ -22,11 +22,6 @@ from ...run_context import ContextWrapper, TContext from ..base import AgentResponse, AgentState, BaseAgentRunner -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class DashscopeAgentRunner(BaseAgentRunner[TContext]): """Dashscope Agent Runner""" diff --git a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py index 50ec7c8262..9e4a114719 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py @@ -1,10 +1,10 @@ import asyncio import hashlib import json -import sys import typing as T from collections import deque from dataclasses import dataclass, field +from typing import override from uuid import uuid4 import astrbot.core.message.components as Comp @@ -40,11 +40,6 @@ get_message_id, ) -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class DeerFlowAgentRunner(BaseAgentRunner[TContext]): """DeerFlow Agent Runner via LangGraph HTTP API.""" @@ -378,7 +373,9 @@ async def _ensure_thread_id(self, session_id: str) -> str: if thread_id: return thread_id - thread = await self.api_client.create_thread(timeout=min(30, self.timeout)) + thread = await self.api_client.create_thread( + timeout_seconds=min(30, self.timeout) + ) thread_id = thread.get("thread_id", "") if not thread_id: raise Exception( @@ -639,7 +636,7 @@ async def _execute_deerflow_request(self): async for event in self.api_client.stream_run( thread_id=thread_id, payload=payload, - timeout=self.timeout, + timeout_seconds=self.timeout, ): event_type = event.get("event") data = event.get("data") @@ -666,7 +663,7 @@ async def _execute_deerflow_request(self): if event_type == "end": break - except (asyncio.TimeoutError, TimeoutError): + except TimeoutError: logger.warning( "DeerFlow stream timed out after %ss for thread_id=%s; returning partial result.", self.timeout, diff --git a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py index 37a23f2432..4ae9432e09 100644 --- a/astrbot/core/agent/runners/deerflow/deerflow_api_client.py +++ b/astrbot/core/agent/runners/deerflow/deerflow_api_client.py @@ -139,7 +139,7 @@ async def __aexit__( ) -> None: await self.close() - async def create_thread(self, timeout: float = 20) -> dict[str, Any]: + async def create_thread(self, timeout_seconds: float = 20) -> dict[str, Any]: session = self._get_session() url = f"{self.api_base}/api/langgraph/threads" payload = {"metadata": {}} @@ -147,7 +147,7 @@ async def create_thread(self, timeout: float = 20) -> dict[str, Any]: url, json=payload, headers=self.headers, - timeout=timeout, + timeout=timeout_seconds, proxy=self.proxy, ) as resp: if resp.status not in (200, 201): @@ -161,7 +161,7 @@ async def stream_run( self, thread_id: str, payload: dict[str, Any], - timeout: float = 120, + timeout_seconds: float = 120, ) -> AsyncGenerator[dict[str, Any], None]: session = self._get_session() url = f"{self.api_base}/api/langgraph/threads/{thread_id}/runs/stream" @@ -183,9 +183,9 @@ async def stream_run( # Use socket read timeout so active heartbeats/chunks can keep the stream alive. stream_timeout = ClientTimeout( total=None, - connect=min(timeout, 30), - sock_connect=min(timeout, 30), - sock_read=timeout, + connect=min(timeout_seconds, 30), + sock_connect=min(timeout_seconds, 30), + sock_read=timeout_seconds, ) async with session.post( url, diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py index 93f8d3570d..1630ebf08e 100644 --- a/astrbot/core/agent/runners/dify/dify_agent_runner.py +++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py @@ -1,7 +1,7 @@ import base64 import os -import sys import typing as T +from typing import override import astrbot.core.message.components as Comp from astrbot.core import logger, sp @@ -19,11 +19,6 @@ from ..base import AgentResponse, AgentState, BaseAgentRunner from .dify_api_client import DifyAPIClient -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class DifyAgentRunner(BaseAgentRunner[TContext]): """Dify Agent Runner""" @@ -176,7 +171,7 @@ async def _execute_dify_request(self): user=session_id, conversation_id=conversation_id, files=files_payload, - timeout=self.timeout, + timeout_seconds=self.timeout, ): logger.debug(f"dify resp chunk: {chunk}") if chunk["event"] == "message" or chunk["event"] == "agent_message": @@ -216,7 +211,7 @@ async def _execute_dify_request(self): }, user=session_id, files=files_payload, - timeout=self.timeout, + timeout_seconds=self.timeout, ): logger.debug(f"dify workflow resp chunk: {chunk}") match chunk["event"]: diff --git a/astrbot/core/agent/runners/dify/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py index 26da6dfe9a..db7b923fcd 100644 --- a/astrbot/core/agent/runners/dify/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -1,6 +1,8 @@ +import asyncio import codecs import json from collections.abc import AsyncGenerator +from pathlib import Path from typing import Any from aiohttp import ClientResponse, ClientSession, FormData @@ -47,20 +49,20 @@ async def chat_messages( response_mode: str = "streaming", conversation_id: str = "", files: list[dict[str, Any]] | None = None, - timeout: float = 60, + timeout_seconds: float = 60, ) -> AsyncGenerator[dict[str, Any], None]: if files is None: files = [] url = f"{self.api_base}/chat-messages" payload = locals() payload.pop("self") - payload.pop("timeout") + payload.pop("timeout_seconds") logger.info(f"chat_messages payload: {payload}") async with self.session.post( url, json=payload, headers=self.headers, - timeout=timeout, + timeout=timeout_seconds, ) as resp: if resp.status != 200: text = await resp.text() @@ -76,20 +78,20 @@ async def workflow_run( user: str, response_mode: str = "streaming", files: list[dict[str, Any]] | None = None, - timeout: float = 60, + timeout_seconds: float = 60, ): if files is None: files = [] url = f"{self.api_base}/workflows/run" payload = locals() payload.pop("self") - payload.pop("timeout") + payload.pop("timeout_seconds") logger.info(f"workflow_run payload: {payload}") async with self.session.post( url, json=payload, headers=self.headers, - timeout=timeout, + timeout=timeout_seconds, ) as resp: if resp.status != 200: text = await resp.text() @@ -134,14 +136,13 @@ async def file_upload( # 使用文件路径 import os - with open(file_path, "rb") as f: - file_content = f.read() - form.add_field( - "file", - file_content, - filename=os.path.basename(file_path), - content_type=mime_type or "application/octet-stream", - ) + file_content = await asyncio.to_thread(Path(file_path).read_bytes) + form.add_field( + "file", + file_content, + filename=os.path.basename(file_path), + content_type=mime_type or "application/octet-stream", + ) else: raise ValueError("file_path 和 file_data 不能同时为 None") diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 743b280070..cc231be693 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -1,10 +1,10 @@ import asyncio import copy -import sys import time import traceback import typing as T from dataclasses import dataclass, field +from typing import override from mcp.types import ( BlobResourceContents, @@ -44,11 +44,6 @@ from ..tool_executor import BaseFunctionToolExecutor from .base import AgentResponse, AgentState, BaseAgentRunner -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - @dataclass(slots=True) class _HandleFunctionToolsResult: diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index c2536708e6..98f354ae41 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,6 +1,6 @@ import copy from collections.abc import AsyncGenerator, Awaitable, Callable -from typing import Any, Generic +from typing import Any import jsonschema import mcp @@ -10,7 +10,7 @@ from astrbot.core.message.message_event_result import MessageEventResult -from .run_context import ContextWrapper, TContext +from .run_context import ContextWrapper ParametersType = dict[str, Any] ToolExecResult = str | mcp.types.CallToolResult @@ -38,7 +38,7 @@ def validate_parameters(self) -> "ToolSchema": @dataclass -class FunctionTool(ToolSchema, Generic[TContext]): +class FunctionTool[TContext](ToolSchema): """A callable tool, for function calling.""" handler: ( diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 2704119d4f..8708fd97d2 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -1,13 +1,13 @@ from collections.abc import AsyncGenerator -from typing import Any, Generic +from typing import Any import mcp -from .run_context import ContextWrapper, TContext +from .run_context import ContextWrapper from .tool import FunctionTool -class BaseFunctionToolExecutor(Generic[TContext]): +class BaseFunctionToolExecutor[TContext]: @classmethod async def execute( cls, diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index dd65f92e69..f7178fe3bf 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -3,6 +3,7 @@ import time import traceback from collections.abc import AsyncGenerator +from pathlib import Path from astrbot.core import logger from astrbot.core.agent.message import Message @@ -509,8 +510,7 @@ async def _simulated_stream_tts( audio_path = await tts_provider.get_audio(text) if audio_path: - with open(audio_path, "rb") as f: - audio_data = f.read() + audio_data = await asyncio.to_thread(Path(audio_path).read_bytes) await audio_queue.put((text, audio_data)) except Exception as e: logger.error( diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 0dc8b9eeb7..be51ced720 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -625,7 +625,7 @@ async def _execute_local( exc_info=True, ) yield None - except asyncio.TimeoutError: + except TimeoutError: raise Exception( f"tool {tool.name} execution timeout after {tool_call_timeout or run_context.tool_call_timeout} seconds.", ) diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 2e0d8b0aa7..933346d1fe 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -1,3 +1,4 @@ +import asyncio import base64 import json import os @@ -241,7 +242,7 @@ async def _resolve_path_from_sandbox( bool: indicates whether the file was downloaded from sandbox. """ - if os.path.exists(path): + if await asyncio.to_thread(os.path.exists, path): return path, False # Try to check if the file exists in the sandbox diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py index a922375998..d65ac7a843 100644 --- a/astrbot/core/backup/exporter.py +++ b/astrbot/core/backup/exporter.py @@ -4,11 +4,12 @@ 导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。 """ +import asyncio import hashlib import json import os import zipfile -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING, Any @@ -83,7 +84,7 @@ async def export_all( output_dir = get_astrbot_backups_path() # 确保输出目录存在 - Path(output_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(Path(output_dir).mkdir, parents=True, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") zip_filename = f"astrbot_backup_{timestamp}.zip" @@ -160,9 +161,10 @@ async def export_all( # 3. 导出配置文件 if progress_callback: await progress_callback("config", 0, 100, "正在导出配置文件...") - if os.path.exists(self.config_path): - with open(self.config_path, encoding="utf-8") as f: - config_content = f.read() + if await asyncio.to_thread(os.path.exists, self.config_path): + config_content = await asyncio.to_thread( + Path(self.config_path).read_text, encoding="utf-8" + ) zf.writestr("config/cmd_config.json", config_content) self._add_checksum("config/cmd_config.json", config_content) if progress_callback: @@ -199,7 +201,7 @@ async def export_all( except Exception as e: logger.error(f"备份导出失败: {e}") # 清理失败的文件 - if os.path.exists(zip_path): + if await asyncio.to_thread(os.path.exists, zip_path): os.remove(zip_path) raise @@ -317,7 +319,7 @@ async def _export_directories( for dir_name, dir_path in backup_directories.items(): full_path = Path(dir_path) - if not full_path.exists(): + if not await asyncio.to_thread(full_path.exists): logger.debug(f"目录不存在,跳过: {full_path}") continue @@ -362,7 +364,7 @@ async def _export_attachments( for attachment in attachments: try: file_path = attachment.get("path", "") - if file_path and os.path.exists(file_path): + if file_path and await asyncio.to_thread(os.path.exists, file_path): # 使用 attachment_id 作为文件名 attachment_id = attachment.get("attachment_id", "") ext = os.path.splitext(file_path)[1] @@ -446,7 +448,7 @@ def _generate_manifest( manifest = { "version": BACKUP_MANIFEST_VERSION, "astrbot_version": VERSION, - "exported_at": datetime.now(timezone.utc).isoformat(), + "exported_at": datetime.now(UTC).isoformat(), "origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传 "schema_version": { "main_db": "v4", diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index b51c7d9560..5362ab3cbf 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -7,12 +7,13 @@ - 版本匹配时也需要用户确认 """ +import asyncio import json import os import shutil import zipfile from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import TYPE_CHECKING, Any @@ -364,7 +365,7 @@ async def import_all( """ result = ImportResult() - if not os.path.exists(zip_path): + if not await asyncio.to_thread(os.path.exists, zip_path): result.add_error(f"备份文件不存在: {zip_path}") return result @@ -446,12 +447,13 @@ async def import_all( try: config_content = zf.read("config/cmd_config.json") # 备份现有配置 - if os.path.exists(self.config_path): + if await asyncio.to_thread(os.path.exists, self.config_path): backup_path = f"{self.config_path}.bak" shutil.copy2(self.config_path, backup_path) - with open(self.config_path, "wb") as f: - f.write(config_content) + await asyncio.to_thread( + Path(self.config_path).write_bytes, config_content + ) result.imported_files["config"] = 1 except Exception as e: result.add_warning(f"导入配置文件失败: {e}") @@ -675,9 +677,9 @@ def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: if isinstance(value, datetime): dt = value if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) + dt = dt.replace(tzinfo=UTC) else: - dt = dt.astimezone(timezone.utc) + dt = dt.astimezone(UTC) return dt.isoformat() if isinstance(value, str): timestamp = value.strip() @@ -688,9 +690,9 @@ def _normalize_platform_stats_timestamp(self, value: Any) -> str | None: try: dt = datetime.fromisoformat(timestamp) if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) + dt = dt.replace(tzinfo=UTC) else: - dt = dt.astimezone(timezone.utc) + dt = dt.astimezone(UTC) return dt.isoformat() except ValueError: return None @@ -753,8 +755,8 @@ async def _import_knowledge_bases( if faiss_path in zf.namelist(): try: target_path = kb_dir / "index.faiss" - with zf.open(faiss_path) as src, open(target_path, "wb") as dst: - dst.write(src.read()) + with zf.open(faiss_path) as src: + await asyncio.to_thread(target_path.write_bytes, src.read()) except Exception as e: result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}") @@ -766,8 +768,8 @@ async def _import_knowledge_bases( rel_path = name[len(media_prefix) :] target_path = kb_dir / rel_path target_path.parent.mkdir(parents=True, exist_ok=True) - with zf.open(name) as src, open(target_path, "wb") as dst: - dst.write(src.read()) + with zf.open(name) as src: + await asyncio.to_thread(target_path.write_bytes, src.read()) except Exception as e: result.add_warning(f"导入媒体文件 {name} 失败: {e}") @@ -828,8 +830,8 @@ async def _import_attachments( target_path = attachments_dir / os.path.basename(name) target_path.parent.mkdir(parents=True, exist_ok=True) - with zf.open(name) as src, open(target_path, "wb") as dst: - dst.write(src.read()) + with zf.open(name) as src: + await asyncio.to_thread(target_path.write_bytes, src.read()) count += 1 except Exception as e: logger.warning(f"导入附件 {name} 失败: {e}") @@ -885,15 +887,15 @@ async def _import_directories( continue # 备份现有目录(如果存在) - if target_dir.exists(): + if await asyncio.to_thread(target_dir.exists): backup_path = Path(f"{target_dir}.bak") - if backup_path.exists(): + if await asyncio.to_thread(backup_path.exists): shutil.rmtree(backup_path) shutil.move(str(target_dir), str(backup_path)) logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}") # 创建目标目录 - target_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True) # 解压文件 for name in dir_files: @@ -906,8 +908,8 @@ async def _import_directories( target_path = target_dir / rel_path target_path.parent.mkdir(parents=True, exist_ok=True) - with zf.open(name) as src, open(target_path, "wb") as dst: - dst.write(src.read()) + with zf.open(name) as src: + await asyncio.to_thread(target_path.write_bytes, src.read()) file_count += 1 except Exception as e: result.add_warning(f"导入文件 {name} 失败: {e}") diff --git a/astrbot/core/computer/booters/bay_manager.py b/astrbot/core/computer/booters/bay_manager.py index 24fa379e82..da429eb652 100644 --- a/astrbot/core/computer/booters/bay_manager.py +++ b/astrbot/core/computer/booters/bay_manager.py @@ -118,10 +118,10 @@ async def ensure_running(self) -> str: return f"http://127.0.0.1:{self._host_port}" - async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None: + async def wait_healthy(self, timeout_seconds: int = HEALTH_TIMEOUT_S) -> None: """Block until Bay's ``/health`` endpoint returns 200.""" url = f"http://127.0.0.1:{self._host_port}/health" - deadline = asyncio.get_event_loop().time() + timeout + deadline = asyncio.get_event_loop().time() + timeout_seconds last_error: str = "" async with aiohttp.ClientSession() as session: @@ -140,7 +140,7 @@ async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None: await asyncio.sleep(HEALTH_POLL_INTERVAL_S) raise TimeoutError( - f"Bay did not become healthy within {timeout}s (last error: {last_error})" + f"Bay did not become healthy within {timeout_seconds}s (last error: {last_error})" ) async def read_credentials(self) -> str: diff --git a/astrbot/core/computer/booters/boxlite.py b/astrbot/core/computer/booters/boxlite.py index 70064fdd48..337f5a68e2 100644 --- a/astrbot/core/computer/booters/boxlite.py +++ b/astrbot/core/computer/booters/boxlite.py @@ -1,5 +1,6 @@ import asyncio import random +from pathlib import Path from typing import Any import aiohttp @@ -46,8 +47,7 @@ async def upload_file(self, path: str, remote_path: str) -> dict: try: # Read file content - with open(path, "rb") as f: - file_content = f.read() + file_content = await asyncio.to_thread(Path(path).read_bytes) # Create multipart form data data = aiohttp.FormData() @@ -88,7 +88,7 @@ async def upload_file(self, path: str, remote_path: str) -> dict: "error": f"Connection error: {str(e)}", "message": "File upload failed", } - except asyncio.TimeoutError: + except TimeoutError: return { "success": False, "error": "File upload timeout", diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py index a80ef0da28..011ac45f42 100644 --- a/astrbot/core/computer/booters/local.py +++ b/astrbot/core/computer/booters/local.py @@ -59,7 +59,7 @@ async def exec( command: str, cwd: str | None = None, env: dict[str, str] | None = None, - timeout: int | None = 30, + timeout_seconds: int | None = 30, shell: bool = True, background: bool = False, ) -> dict[str, Any]: @@ -87,7 +87,7 @@ def _run() -> dict[str, Any]: shell=shell, cwd=working_dir, env=run_env, - timeout=timeout, + timeout=timeout_seconds, capture_output=True, text=True, ) @@ -106,14 +106,14 @@ async def exec( self, code: str, kernel_id: str | None = None, - timeout: int = 30, + timeout_seconds: int = 30, silent: bool = False, ) -> dict[str, Any]: def _run() -> dict[str, Any]: try: result = subprocess.run( [os.environ.get("PYTHON", sys.executable), "-c", code], - timeout=timeout, + timeout=timeout_seconds, capture_output=True, text=True, ) diff --git a/astrbot/core/computer/booters/shipyard_neo.py b/astrbot/core/computer/booters/shipyard_neo.py index 6304696ad2..6c6f62bb5f 100644 --- a/astrbot/core/computer/booters/shipyard_neo.py +++ b/astrbot/core/computer/booters/shipyard_neo.py @@ -1,7 +1,9 @@ from __future__ import annotations +import asyncio import os import shlex +from pathlib import Path from typing import Any, cast from astrbot.api import logger @@ -33,11 +35,11 @@ async def exec( self, code: str, kernel_id: str | None = None, - timeout: int = 30, + timeout_seconds: int = 30, silent: bool = False, ) -> dict[str, Any]: _ = kernel_id # Bay runtime does not expose kernel_id in current SDK. - result = await self._sandbox.python.exec(code, timeout=timeout) + result = await self._sandbox.python.exec(code, timeout=timeout_seconds) payload = _maybe_model_dump(result) output_text = payload.get("output", "") or "" @@ -75,7 +77,7 @@ async def exec( command: str, cwd: str | None = None, env: dict[str, str] | None = None, - timeout: int | None = 30, + timeout_seconds: int | None = 30, shell: bool = True, background: bool = False, ) -> dict[str, Any]: @@ -99,7 +101,7 @@ async def exec( result = await self._sandbox.shell.exec( run_command, - timeout=timeout or 30, + timeout=timeout_seconds or 30, cwd=cwd, ) payload = _maybe_model_dump(result) @@ -192,7 +194,7 @@ def __init__(self, sandbox: Any) -> None: async def exec( self, cmd: str, - timeout: int = 30, + timeout_seconds: int = 30, description: str | None = None, tags: str | None = None, learn: bool = False, @@ -200,7 +202,7 @@ async def exec( ) -> dict[str, Any]: result = await self._sandbox.browser.exec( cmd, - timeout=timeout, + timeout=timeout_seconds, description=description, tags=tags, learn=learn, @@ -211,7 +213,7 @@ async def exec( async def exec_batch( self, commands: list[str], - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, description: str | None = None, tags: str | None = None, @@ -220,7 +222,7 @@ async def exec_batch( ) -> dict[str, Any]: result = await self._sandbox.browser.exec_batch( commands, - timeout=timeout, + timeout=timeout_seconds, stop_on_error=stop_on_error, description=description, tags=tags, @@ -232,7 +234,7 @@ async def exec_batch( async def run_skill( self, skill_key: str, - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, include_trace: bool = False, description: str | None = None, @@ -240,7 +242,7 @@ async def run_skill( ) -> dict[str, Any]: result = await self._sandbox.browser.run_skill( skill_key=skill_key, - timeout=timeout, + timeout=timeout_seconds, stop_on_error=stop_on_error, include_trace=include_trace, description=description, @@ -468,8 +470,7 @@ def browser(self) -> BrowserComponent: async def upload_file(self, path: str, file_name: str) -> dict: if self._sandbox is None: raise RuntimeError("ShipyardNeoBooter is not initialized.") - with open(path, "rb") as f: - content = f.read() + content = await asyncio.to_thread(Path(path).read_bytes) remote_path = file_name.lstrip("/") await self._sandbox.filesystem.upload(remote_path, content) logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path) @@ -486,8 +487,7 @@ async def download_file(self, remote_path: str, local_path: str) -> None: local_dir = os.path.dirname(local_path) if local_dir: os.makedirs(local_dir, exist_ok=True) - with open(local_path, "wb") as f: - f.write(cast(bytes, content)) + await asyncio.to_thread(Path(local_path).write_bytes, cast(bytes, content)) logger.info( "[Computer] File downloaded from Neo sandbox: %s -> %s", remote_path, diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index aa10d125e7..1adaeae08c 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -1,3 +1,4 @@ +import asyncio import json import os import shutil @@ -372,12 +373,12 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: splitting into `apply` and `scan` phases. """ skills_root = Path(get_astrbot_skills_path()) - if not skills_root.is_dir(): + if not await asyncio.to_thread(skills_root.is_dir): return local_skill_dirs = _list_local_skill_dirs(skills_root) temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) zip_base = temp_dir / "skills_bundle" zip_path = zip_base.with_suffix(".zip") diff --git a/astrbot/core/computer/olayer/browser.py b/astrbot/core/computer/olayer/browser.py index aa69f4501d..5bc40a4462 100644 --- a/astrbot/core/computer/olayer/browser.py +++ b/astrbot/core/computer/olayer/browser.py @@ -11,7 +11,7 @@ class BrowserComponent(Protocol): async def exec( self, cmd: str, - timeout: int = 30, + timeout_seconds: int = 30, description: str | None = None, tags: str | None = None, learn: bool = False, @@ -23,7 +23,7 @@ async def exec( async def exec_batch( self, commands: list[str], - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, description: str | None = None, tags: str | None = None, @@ -36,7 +36,7 @@ async def exec_batch( async def run_skill( self, skill_key: str, - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, include_trace: bool = False, description: str | None = None, diff --git a/astrbot/core/computer/olayer/python.py b/astrbot/core/computer/olayer/python.py index 6255041463..09bf497db4 100644 --- a/astrbot/core/computer/olayer/python.py +++ b/astrbot/core/computer/olayer/python.py @@ -12,7 +12,7 @@ async def exec( self, code: str, kernel_id: str | None = None, - timeout: int = 30, + timeout_seconds: int = 30, silent: bool = False, ) -> dict[str, Any]: """Execute Python code""" diff --git a/astrbot/core/computer/olayer/shell.py b/astrbot/core/computer/olayer/shell.py index df2263b65a..67d9f95efd 100644 --- a/astrbot/core/computer/olayer/shell.py +++ b/astrbot/core/computer/olayer/shell.py @@ -13,7 +13,7 @@ async def exec( command: str, cwd: str | None = None, env: dict[str, str] | None = None, - timeout: int | None = 30, + timeout_seconds: int | None = 30, shell: bool = True, background: bool = False, ) -> dict[str, Any]: diff --git a/astrbot/core/computer/tools/browser.py b/astrbot/core/computer/tools/browser.py index 70061ac313..80a9be11a2 100644 --- a/astrbot/core/computer/tools/browser.py +++ b/astrbot/core/computer/tools/browser.py @@ -71,19 +71,23 @@ async def call( self, context: ContextWrapper[AstrAgentContext], cmd: str, - timeout: int = 30, + timeout_seconds: int = 30, description: str | None = None, tags: str | None = None, learn: bool = False, include_trace: bool = False, + **kwargs: Any, ) -> ToolExecResult: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = int(legacy_timeout) if err := _ensure_admin(context): return err try: browser = await _get_browser_component(context) result = await browser.exec( cmd=cmd, - timeout=timeout, + timeout_seconds=timeout_seconds, description=description, tags=tags, learn=learn, @@ -133,20 +137,24 @@ async def call( self, context: ContextWrapper[AstrAgentContext], commands: list[str], - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, description: str | None = None, tags: str | None = None, learn: bool = False, include_trace: bool = False, + **kwargs: Any, ) -> ToolExecResult: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = int(legacy_timeout) if err := _ensure_admin(context): return err try: browser = await _get_browser_component(context) result = await browser.exec_batch( commands=commands, - timeout=timeout, + timeout_seconds=timeout_seconds, stop_on_error=stop_on_error, description=description, tags=tags, @@ -181,19 +189,23 @@ async def call( self, context: ContextWrapper[AstrAgentContext], skill_key: str, - timeout: int = 60, + timeout_seconds: int = 60, stop_on_error: bool = True, include_trace: bool = False, description: str | None = None, tags: str | None = None, + **kwargs: Any, ) -> ToolExecResult: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = int(legacy_timeout) if err := _ensure_admin(context): return err try: browser = await _get_browser_component(context) result = await browser.run_skill( skill_key=skill_key, - timeout=timeout, + timeout_seconds=timeout_seconds, stop_on_error=stop_on_error, include_trace=include_trace, description=description, diff --git a/astrbot/core/computer/tools/fs.py b/astrbot/core/computer/tools/fs.py index 31b7f3f513..d50025f4d4 100644 --- a/astrbot/core/computer/tools/fs.py +++ b/astrbot/core/computer/tools/fs.py @@ -1,3 +1,4 @@ +import asyncio import os import uuid from dataclasses import dataclass, field @@ -111,10 +112,10 @@ async def call( ) try: # Check if file exists - if not os.path.exists(local_path): + if not await asyncio.to_thread(os.path.exists, local_path): return f"Error: File does not exist: {local_path}" - if not os.path.isfile(local_path): + if not await asyncio.to_thread(os.path.isfile, local_path): return f"Error: Path is not a file: {local_path}" # Use basename if sandbox_filename is not provided diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index d12878be3e..211514f7a2 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -1,7 +1,7 @@ import asyncio import json from collections.abc import Awaitable, Callable -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any from zoneinfo import ZoneInfo @@ -192,7 +192,7 @@ async def _run_job(self, job_id: str) -> None: job = await self.db.get_cron_job(job_id) if not job or not job.enabled: return - start_time = datetime.now(timezone.utc) + start_time = datetime.now(UTC) await self.db.update_cron_job( job_id, status="running", last_run_at=start_time, last_error=None ) diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index d7bca30678..47ecadf040 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -1,3 +1,4 @@ +import asyncio import os from astrbot.api import logger, sp @@ -22,7 +23,7 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: data_dir = get_astrbot_data_path() data_v3_db = os.path.join(data_dir, "data_v3.db") - if not os.path.exists(data_v3_db): + if not await asyncio.to_thread(os.path.exists, data_v3_db): return False migration_done = await db_helper.get_preference( "global", diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 727d97b29b..d7a57a6d61 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -106,8 +106,8 @@ async def migration_platform_table( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) secs_from_2023_4_10_to_now = ( - datetime.datetime.now(datetime.timezone.utc) - - datetime.datetime(2023, 4, 10, tzinfo=datetime.timezone.utc) + datetime.datetime.now(datetime.UTC) + - datetime.datetime(2023, 4, 10, tzinfo=datetime.UTC) ).total_seconds() offset_sec = int(secs_from_2023_4_10_to_now) logger.info(f"迁移旧平台数据,offset_sec: {offset_sec} 秒。") @@ -162,7 +162,7 @@ async def migration_platform_table( { "timestamp": datetime.datetime.fromtimestamp( bucket_end, - tz=datetime.timezone.utc, + tz=datetime.UTC, ), "platform_id": platform_id, "platform_type": platform_type, diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 451f054f62..1b8179f074 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -1,16 +1,16 @@ import uuid from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import TypedDict from sqlmodel import JSON, Field, SQLModel, Text, UniqueConstraint class TimestampMixin(SQLModel): - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)}, + default_factory=lambda: datetime.now(UTC), + sa_column_kwargs={"onupdate": lambda: datetime.now(UTC)}, ) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index f496e19d59..e356e85aa3 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -2,7 +2,7 @@ import threading import typing as T from collections.abc import Awaitable, Callable -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from sqlalchemy import CursorResult, Row from sqlalchemy.ext.asyncio import AsyncSession @@ -633,7 +633,7 @@ async def get_active_api_key_by_hash(self, key_hash: str) -> ApiKey | None: """Get an active API key by hash (not revoked, not expired).""" async with self.get_db() as session: session: AsyncSession - now = datetime.now(timezone.utc) + now = datetime.now(UTC) query = select(ApiKey).where( ApiKey.key_hash == key_hash, col(ApiKey.revoked_at).is_(None), @@ -650,7 +650,7 @@ async def touch_api_key(self, key_id: str) -> None: await session.execute( update(ApiKey) .where(col(ApiKey.key_id) == key_id) - .values(last_used_at=datetime.now(timezone.utc)), + .values(last_used_at=datetime.now(UTC)), ) async def revoke_api_key(self, key_id: str) -> bool: @@ -661,7 +661,7 @@ async def revoke_api_key(self, key_id: str) -> bool: query = ( update(ApiKey) .where(col(ApiKey.key_id) == key_id) - .values(revoked_at=datetime.now(timezone.utc)) + .values(revoked_at=datetime.now(UTC)) ) result = T.cast(CursorResult, await session.execute(query)) return result.rowcount > 0 @@ -1534,7 +1534,7 @@ async def update_platform_session( async with self.get_db() as session: session: AsyncSession async with session.begin(): - values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} + values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)} if display_name is not None: values["display_name"] = display_name @@ -1622,7 +1622,7 @@ async def update_chatui_project( async with self.get_db() as session: session: AsyncSession async with session.begin(): - values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} + values: dict[str, T.Any] = {"updated_at": datetime.now(UTC)} if title is not None: values["title"] = title if emoji is not None: diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index 42fbd23dfe..5aa897e79b 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -28,12 +28,17 @@ async def check_token_expired(self, file_token: str) -> bool: await self._cleanup_expired_tokens() return file_token not in self.staged_files - async def register_file(self, file_path: str, timeout: float | None = None) -> str: + async def register_file( + self, + file_path: str, + timeout_seconds: float | None = None, + **kwargs, + ) -> str: """向令牌服务注册一个文件。 Args: file_path(str): 文件路径 - timeout(float): 超时时间,单位秒(可选) + timeout_seconds(float): 超时时间,单位秒(可选) Returns: str: 一个单次令牌 @@ -58,15 +63,18 @@ async def register_file(self, file_path: str, timeout: float | None = None) -> s async with self.lock: await self._cleanup_expired_tokens() + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = float(legacy_timeout) - if not os.path.exists(local_path): + if not await asyncio.to_thread(os.path.exists, local_path): raise FileNotFoundError( f"文件不存在: {local_path} (原始输入: {file_path})", ) file_token = str(uuid.uuid4()) expire_time = time.time() + ( - timeout if timeout is not None else self.default_timeout + timeout_seconds if timeout_seconds is not None else self.default_timeout ) # 存储转换后的真实路径 self.staged_files[file_token] = (local_path, expire_time) @@ -93,6 +101,6 @@ async def handle_file(self, file_token: str) -> str: raise KeyError(f"无效或过期的文件 token: {file_token}") file_path, _ = self.staged_files.pop(file_token) - if not os.path.exists(file_path): + if not await asyncio.to_thread(os.path.exists, file_path): raise FileNotFoundError(f"文件不存在: {file_path}") return file_path diff --git a/astrbot/core/knowledge_base/models.py b/astrbot/core/knowledge_base/models.py index da919a384a..10277a926b 100644 --- a/astrbot/core/knowledge_base/models.py +++ b/astrbot/core/knowledge_base/models.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timezone +from datetime import UTC, datetime from sqlmodel import Field, MetaData, SQLModel, Text, UniqueConstraint @@ -40,10 +40,10 @@ class KnowledgeBase(BaseKBModel, table=True): top_k_dense: int | None = Field(default=50, nullable=True) top_k_sparse: int | None = Field(default=50, nullable=True) top_m_final: int | None = Field(default=5, nullable=True) - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + default_factory=lambda: datetime.now(UTC), + sa_column_kwargs={"onupdate": datetime.now(UTC)}, ) doc_count: int = Field(default=0, nullable=False) chunk_count: int = Field(default=0, nullable=False) @@ -83,10 +83,10 @@ class KBDocument(BaseKBModel, table=True): file_path: str = Field(max_length=512, nullable=False) chunk_count: int = Field(default=0, nullable=False) media_count: int = Field(default=0, nullable=False) - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), - sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, + default_factory=lambda: datetime.now(UTC), + sa_column_kwargs={"onupdate": datetime.now(UTC)}, ) @@ -117,4 +117,4 @@ class KBMedia(BaseKBModel, table=True): file_path: str = Field(max_length=512, nullable=False) file_size: int = Field(nullable=False) mime_type: str = Field(max_length=100, nullable=False) - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 15265c38d1..038e997424 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -27,7 +27,8 @@ import os import sys import uuid -from enum import Enum +from enum import StrEnum +from pathlib import Path if sys.version_info >= (3, 14): from pydantic import BaseModel @@ -39,7 +40,7 @@ from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64 -class ComponentType(str, Enum): +class ComponentType(StrEnum): # Basic Segment Types Plain = "Plain" # plain text message Image = "Image" # image @@ -158,18 +159,17 @@ async def convert_to_file_path(self) -> str: return self.file[8:] if self.file.startswith("http"): file_path = await download_image_by_url(self.file) - return os.path.abspath(file_path) + return await asyncio.to_thread(os.path.abspath, file_path) if self.file.startswith("base64://"): bs64_data = self.file.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) file_path = os.path.join( get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg" ) - with open(file_path, "wb") as f: - f.write(image_bytes) - return os.path.abspath(file_path) - if os.path.exists(self.file): - return os.path.abspath(self.file) + await asyncio.to_thread(Path(file_path).write_bytes, image_bytes) + return await asyncio.to_thread(os.path.abspath, file_path) + if await asyncio.to_thread(os.path.exists, self.file): + return await asyncio.to_thread(os.path.abspath, self.file) raise Exception(f"not a valid file: {self.file}") async def convert_to_base64(self) -> str: @@ -183,14 +183,14 @@ async def convert_to_base64(self) -> str: if not self.file: raise Exception(f"not a valid file: {self.file}") if self.file.startswith("file:///"): - bs64_data = file_to_base64(self.file[8:]) + bs64_data = await file_to_base64(self.file[8:]) elif self.file.startswith("http"): file_path = await download_image_by_url(self.file) - bs64_data = file_to_base64(file_path) + bs64_data = await file_to_base64(file_path) elif self.file.startswith("base64://"): bs64_data = self.file - elif os.path.exists(self.file): - bs64_data = file_to_base64(self.file) + elif await asyncio.to_thread(os.path.exists, self.file): + bs64_data = await file_to_base64(self.file) else: raise Exception(f"not a valid file: {self.file}") bs64_data = bs64_data.removeprefix("base64://") @@ -256,11 +256,11 @@ async def convert_to_file_path(self) -> str: get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}" ) await download_file(url, video_file_path) - if os.path.exists(video_file_path): - return os.path.abspath(video_file_path) + if await asyncio.to_thread(os.path.exists, video_file_path): + return await asyncio.to_thread(os.path.abspath, video_file_path) raise Exception(f"download failed: {url}") - if os.path.exists(url): - return os.path.abspath(url) + if await asyncio.to_thread(os.path.exists, url): + return await asyncio.to_thread(os.path.abspath, url) raise Exception(f"not a valid file: {url}") async def register_to_file_service(self) -> str: @@ -449,18 +449,17 @@ async def convert_to_file_path(self) -> str: return url[8:] if url.startswith("http"): image_file_path = await download_image_by_url(url) - return os.path.abspath(image_file_path) + return await asyncio.to_thread(os.path.abspath, image_file_path) if url.startswith("base64://"): bs64_data = url.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) image_file_path = os.path.join( get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg" ) - with open(image_file_path, "wb") as f: - f.write(image_bytes) - return os.path.abspath(image_file_path) - if os.path.exists(url): - return os.path.abspath(url) + await asyncio.to_thread(Path(image_file_path).write_bytes, image_bytes) + return await asyncio.to_thread(os.path.abspath, image_file_path) + if await asyncio.to_thread(os.path.exists, url): + return await asyncio.to_thread(os.path.abspath, url) raise Exception(f"not a valid file: {url}") async def convert_to_base64(self) -> str: @@ -475,14 +474,14 @@ async def convert_to_base64(self) -> str: if not url: raise ValueError("No valid file or URL provided") if url.startswith("file:///"): - bs64_data = file_to_base64(url[8:]) + bs64_data = await file_to_base64(url[8:]) elif url.startswith("http"): image_file_path = await download_image_by_url(url) - bs64_data = file_to_base64(image_file_path) + bs64_data = await file_to_base64(image_file_path) elif url.startswith("base64://"): bs64_data = url - elif os.path.exists(url): - bs64_data = file_to_base64(url) + elif await asyncio.to_thread(os.path.exists, url): + bs64_data = await file_to_base64(url) else: raise Exception(f"not a valid file: {url}") bs64_data = bs64_data.removeprefix("base64://") @@ -735,8 +734,8 @@ async def get_file(self, allow_return_url: bool = False) -> str: ): path = path[1:] - if os.path.exists(path): - return os.path.abspath(path) + if await asyncio.to_thread(os.path.exists, path): + return await asyncio.to_thread(os.path.abspath, path) if self.url: await self._download_file() @@ -751,7 +750,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: and path[2] == ":" ): path = path[1:] - return os.path.abspath(path) + return await asyncio.to_thread(os.path.abspath, path) return "" @@ -767,7 +766,7 @@ async def _download_file(self) -> None: filename = f"fileseg_{uuid.uuid4().hex}" file_path = os.path.join(download_dir, filename) await download_file(self.url, file_path) - self.file_ = os.path.abspath(file_path) + self.file_ = await asyncio.to_thread(os.path.abspath, file_path) async def register_to_file_service(self) -> str: """将文件注册到文件服务。 diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 2d9b45cc19..e823aac9d0 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -254,7 +254,7 @@ async def download_ding_file( "robotCode": robot_code, } temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) f_path = temp_dir / f"dingtalk_{uuid.uuid4()}.{ext}" async with ( aiohttp.ClientSession() as session, @@ -412,7 +412,7 @@ async def upload_media(self, file_path: str, media_type: str) -> str: form = aiohttp.FormData() form.add_field( "media", - media_file_path.read_bytes(), + await asyncio.to_thread(media_file_path.read_bytes), filename=media_file_path.name, content_type="application/octet-stream", ) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index ebd32c471a..36ee5710b4 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -1,15 +1,10 @@ -import sys from collections.abc import Awaitable, Callable +from typing import override import discord from astrbot import logger -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - # Discord Bot客户端 class DiscordBotClient(discord.Bot): diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 7657962a11..40be87a633 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -1,7 +1,6 @@ import asyncio import re -import sys -from typing import Any, cast +from typing import Any, cast, override import discord from discord.abc import GuildChannel, Messageable, PrivateChannel @@ -27,11 +26,6 @@ from .client import DiscordBotClient from .discord_platform_event import DiscordPlatformEvent -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - # 注册平台适配器 @register_platform_adapter( diff --git a/astrbot/core/platform/sources/kook/kook_adapter.py b/astrbot/core/platform/sources/kook/kook_adapter.py index 1124c6841d..b7d047291e 100644 --- a/astrbot/core/platform/sources/kook/kook_adapter.py +++ b/astrbot/core/platform/sources/kook/kook_adapter.py @@ -130,7 +130,7 @@ async def _main_loop(self): await asyncio.wait_for( self.client.wait_until_closed(), timeout=1.0 ) - except asyncio.TimeoutError: + except TimeoutError: # 正常超时,继续下一轮 while 检查 continue diff --git a/astrbot/core/platform/sources/kook/kook_client.py b/astrbot/core/platform/sources/kook/kook_client.py index 9a452a9c3f..34078e2ac2 100644 --- a/astrbot/core/platform/sources/kook/kook_client.py +++ b/astrbot/core/platform/sources/kook/kook_client.py @@ -171,7 +171,7 @@ async def listen(self): # 处理不同类型的信令 await self._handle_signal(data) - except asyncio.TimeoutError: + except TimeoutError: # 超时检查,继续循环 continue except websockets.exceptions.ConnectionClosed: @@ -362,12 +362,14 @@ async def upload_asset(self, file_url: str | None) -> str: b64_str = file_url.removeprefix("base64://") bytes_data = base64.b64decode(b64_str) - elif file_url.startswith("file://") or os.path.exists(file_url): + elif file_url.startswith("file://") or await asyncio.to_thread( + os.path.exists, file_url + ): file_url = file_url.removeprefix("file:///") file_url = file_url.removeprefix("file://") try: - target_path = Path(file_url).resolve() + target_path = await asyncio.to_thread(Path(file_url).resolve) except Exception as exp: logger.error(f'[KOOK] 获取文件 "{file_url}" 绝对路径失败: "{exp}"') raise FileNotFoundError( diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index be1c81c26e..6b500dc5ca 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -429,7 +429,7 @@ async def _download_file_resource_to_temp( suffix = Path(file_name).suffix if file_name else default_suffix temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) temp_path = ( temp_dir / f"lark_{message_type}_{file_name}_{uuid4().hex[:4]}{suffix}" ) diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 92e3a32b9e..a513f45005 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -1,8 +1,10 @@ +import asyncio import base64 import json import os import uuid from io import BytesIO +from pathlib import Path import lark_oapi as lark from lark_oapi.api.im.v1 import ( @@ -136,7 +138,7 @@ async def _upload_lark_file( Returns: 成功返回file_key,失败返回None """ - if not path or not os.path.exists(path): + if not path or not await asyncio.to_thread(os.path.exists, path): logger.error(f"[Lark] 文件不存在: {path}") return None @@ -145,36 +147,32 @@ async def _upload_lark_file( return None try: - with open(path, "rb") as file_obj: - body_builder = ( - CreateFileRequestBody.builder() - .file_type(file_type) - .file_name(os.path.basename(path)) - .file(file_obj) - ) - if duration is not None: - body_builder.duration(duration) + file_obj = BytesIO(await asyncio.to_thread(Path(path).read_bytes)) + body_builder = ( + CreateFileRequestBody.builder() + .file_type(file_type) + .file_name(os.path.basename(path)) + .file(file_obj) + ) + if duration is not None: + body_builder.duration(duration) - request = ( - CreateFileRequest.builder() - .request_body(body_builder.build()) - .build() - ) - response = await lark_client.im.v1.file.acreate(request) + request = ( + CreateFileRequest.builder().request_body(body_builder.build()).build() + ) + response = await lark_client.im.v1.file.acreate(request) - if not response.success(): - logger.error( - f"[Lark] 无法上传文件({response.code}): {response.msg}" - ) - return None + if not response.success(): + logger.error(f"[Lark] 无法上传文件({response.code}): {response.msg}") + return None - if response.data is None: - logger.error("[Lark] 上传文件成功但未返回数据(data is None)") - return None + if response.data is None: + logger.error("[Lark] 上传文件成功但未返回数据(data is None)") + return None - file_key = response.data.file_key - logger.debug(f"[Lark] 文件上传成功: {file_key}") - return file_key + file_key = response.data.file_key + logger.debug(f"[Lark] 文件上传成功: {file_key}") + return file_key except Exception as e: logger.error(f"[Lark] 无法打开或上传文件: {e}") @@ -207,8 +205,9 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l temp_dir, f"lark_image_{uuid.uuid4().hex[:8]}.jpg", ) - with open(file_path, "wb") as f: - f.write(BytesIO(image_data).getvalue()) + await asyncio.to_thread( + Path(file_path).write_bytes, BytesIO(image_data).getvalue() + ) else: file_path = comp.file if comp.file else "" @@ -217,7 +216,9 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l logger.error("[Lark] 图片路径为空,无法上传") continue try: - image_file = open(file_path, "rb") + image_file = BytesIO( + await asyncio.to_thread(Path(file_path).read_bytes) + ) except Exception as e: logger.error(f"[Lark] 无法打开图片文件: {e}") continue @@ -412,7 +413,9 @@ async def _send_audio_message( logger.error(f"[Lark] 无法获取音频文件路径: {e}") return - if not original_audio_path or not os.path.exists(original_audio_path): + if not original_audio_path or not await asyncio.to_thread( + os.path.exists, original_audio_path + ): logger.error(f"[Lark] 音频文件不存在: {original_audio_path}") return @@ -442,7 +445,9 @@ async def _send_audio_message( ) # 清理转换后的临时音频文件 - if converted_audio_path and os.path.exists(converted_audio_path): + if converted_audio_path and await asyncio.to_thread( + os.path.exists, converted_audio_path + ): try: os.remove(converted_audio_path) logger.debug(f"[Lark] 已删除转换后的音频文件: {converted_audio_path}") @@ -485,7 +490,9 @@ async def _send_media_message( logger.error(f"[Lark] 无法获取视频文件路径: {e}") return - if not original_video_path or not os.path.exists(original_video_path): + if not original_video_path or not await asyncio.to_thread( + os.path.exists, original_video_path + ): logger.error(f"[Lark] 视频文件不存在: {original_video_path}") return @@ -515,7 +522,9 @@ async def _send_media_message( ) # 清理转换后的临时视频文件 - if converted_video_path and os.path.exists(converted_video_path): + if converted_video_path and await asyncio.to_thread( + os.path.exists, converted_video_path + ): try: os.remove(converted_video_path) logger.debug(f"[Lark] 已删除转换后的视频文件: {converted_video_path}") diff --git a/astrbot/core/platform/sources/line/line_event.py b/astrbot/core/platform/sources/line/line_event.py index 8b82ad1820..a16bdd18f8 100644 --- a/astrbot/core/platform/sources/line/line_event.py +++ b/astrbot/core/platform/sources/line/line_event.py @@ -161,7 +161,7 @@ async def _resolve_video_preview_url(segment: Video) -> str: try: video_path = await segment.convert_to_file_path() temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) thumb_path = temp_dir / f"line_video_preview_{uuid.uuid4().hex}.jpg" process = await asyncio.create_subprocess_exec( @@ -201,8 +201,8 @@ async def _resolve_file_url(segment: File) -> str: async def _resolve_file_size(segment: File) -> int: try: file_path = await segment.get_file(allow_return_url=False) - if file_path and os.path.exists(file_path): - return int(os.path.getsize(file_path)) + if file_path and await asyncio.to_thread(os.path.exists, file_path): + return int(await asyncio.to_thread(os.path.getsize, file_path)) except Exception as e: logger.debug("[LINE] resolve file size failed: %s", e) return 0 diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index fd61c3e506..e1169decb4 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -499,7 +499,8 @@ async def _upload_comp(comp) -> object | None: # 清理临时文件 if local_path and isinstance(local_path, str): data_temp = get_astrbot_temp_path() - if local_path.startswith(data_temp) and os.path.exists( + if local_path.startswith(data_temp) and await asyncio.to_thread( + os.path.exists, local_path, ): try: diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 3e5eb9a90e..64728a5616 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -3,6 +3,7 @@ import random import uuid from collections.abc import Awaitable, Callable +from pathlib import Path from typing import Any, NoReturn try: @@ -555,22 +556,19 @@ async def upload_file( form.add_field("folderId", str(folder_id)) try: - f = open(file_path, "rb") + file_bytes = await asyncio.to_thread(Path(file_path).read_bytes) except FileNotFoundError as e: logger.error(f"[Misskey API] 本地文件不存在: {file_path}") raise APIError(f"File not found: {file_path}") from e - try: - form.add_field("file", f, filename=filename) - async with self.session.post(url, data=form) as resp: - result = await self._process_response(resp, "drive/files/create") - file_id = FileIDExtractor.extract_file_id(result) - logger.debug( - f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}", - ) - return {"id": file_id, "raw": result} - finally: - f.close() + form.add_field("file", file_bytes, filename=filename) + async with self.session.post(url, data=form) as resp: + result = await self._process_response(resp, "drive/files/create") + file_id = FileIDExtractor.extract_file_id(result) + logger.debug( + f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}", + ) + return {"id": file_id, "raw": result} except aiohttp.ClientError as e: logger.error(f"[Misskey API] 文件上传网络错误: {e}") raise APIConnectionError(f"Upload failed: {e}") from e diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 868ec8a657..55050a8219 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -339,7 +339,7 @@ async def upload_group_and_c2c_record( payload = {"file_type": file_type, "srv_send_msg": srv_send_msg} # 处理文件数据 - if os.path.exists(file_source): + if await asyncio.to_thread(os.path.exists, file_source): # 读取本地文件 async with aiofiles.open(file_source, "rb") as f: file_content = await f.read() @@ -421,15 +421,15 @@ async def _parse_to_qqofficial(message: MessageChain): plain_text += i.text elif isinstance(i, Image) and not image_base64: if i.file and i.file.startswith("file:///"): - image_base64 = file_to_base64(i.file[8:]) + image_base64 = await file_to_base64(i.file[8:]) image_file_path = i.file[8:] elif i.file and i.file.startswith("http"): image_file_path = await download_image_by_url(i.file) - image_base64 = file_to_base64(image_file_path) + image_base64 = await file_to_base64(image_file_path) elif i.file and i.file.startswith("base64://"): image_base64 = i.file elif i.file: - image_base64 = file_to_base64(i.file) + image_base64 = await file_to_base64(i.file) else: raise ValueError("Unsupported image file format") image_base64 = image_base64.removeprefix("base64://") diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 2dd72bd0ca..76e3f9d985 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -1,9 +1,8 @@ import asyncio import os import re -import sys import uuid -from typing import cast +from typing import cast, override from apscheduler.schedulers.asyncio import AsyncIOScheduler from telegram import BotCommand, Update @@ -33,11 +32,6 @@ from .tg_event import TelegramPlatformEvent -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - @register_platform_adapter("telegram", "telegram 适配器") class TelegramPlatformAdapter(Platform): diff --git a/astrbot/core/platform/sources/webchat/message_parts_helper.py b/astrbot/core/platform/sources/webchat/message_parts_helper.py index 43072ec1c8..3a1371e723 100644 --- a/astrbot/core/platform/sources/webchat/message_parts_helper.py +++ b/astrbot/core/platform/sources/webchat/message_parts_helper.py @@ -1,3 +1,4 @@ +import asyncio import json import mimetypes import shutil @@ -139,13 +140,15 @@ async def parse_webchat_message_parts( continue file_path = Path(str(path)) - if verify_media_path_exists and not file_path.exists(): + if verify_media_path_exists and not await asyncio.to_thread(file_path.exists): if strict: raise ValueError(f"file not found: {file_path!s}") continue file_path_str = ( - str(file_path.resolve()) if verify_media_path_exists else str(file_path) + str(await asyncio.to_thread(file_path.resolve)) + if verify_media_path_exists + else str(file_path) ) has_content = True if part_type == "image": @@ -366,7 +369,7 @@ async def message_chain_to_storage_message_parts( attachments_dir: str | Path, ) -> list[dict]: target_dir = Path(attachments_dir) - target_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(target_dir.mkdir, parents=True, exist_ok=True) parts: list[dict] = [] for comp in message_chain.chain: @@ -442,7 +445,9 @@ async def _copy_file_to_attachment_part( display_name: str | None = None, ) -> dict | None: src_path = Path(file_path) - if not src_path.exists() or not src_path.is_file(): + if not await asyncio.to_thread(src_path.exists) or not await asyncio.to_thread( + src_path.is_file + ): return None suffix = src_path.suffix diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index b7da864aae..aacb0e12da 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -1,8 +1,10 @@ +import asyncio import base64 import json import os import shutil import uuid +from pathlib import Path from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain @@ -80,8 +82,9 @@ async def _send( filename = f"{str(uuid.uuid4())}.jpg" path = os.path.join(attachments_dir, filename) image_base64 = await comp.convert_to_base64() - with open(path, "wb") as f: - f.write(base64.b64decode(image_base64)) + await asyncio.to_thread( + Path(path).write_bytes, base64.b64decode(image_base64) + ) data = f"[IMAGE]{filename}" await web_chat_back_queue.put( { @@ -96,8 +99,9 @@ async def _send( filename = f"{str(uuid.uuid4())}.wav" path = os.path.join(attachments_dir, filename) record_base64 = await comp.convert_to_base64() - with open(path, "wb") as f: - f.write(base64.b64decode(record_base64)) + await asyncio.to_thread( + Path(path).write_bytes, base64.b64decode(record_base64) + ) data = f"[RECORD]{filename}" await web_chat_back_queue.put( { diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 6647db89f0..b77a0da3e4 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -1,9 +1,9 @@ import asyncio import os -import sys import uuid from collections.abc import Awaitable, Callable -from typing import Any, cast +from pathlib import Path +from typing import Any, cast, override import quart from requests import Response @@ -33,11 +33,6 @@ from .wecom_kf import WeChatKF from .wecom_kf_message import WeChatKFMessage -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: @@ -346,8 +341,7 @@ async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: ) temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr") - with open(path, "wb") as f: - f.write(resp.content) + await asyncio.to_thread(Path(path).write_bytes, resp.content) try: path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav") @@ -402,8 +396,7 @@ async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: ) temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"weixinkefu_{media_id}.jpg") - with open(path, "wb") as f: - f.write(resp.content) + await asyncio.to_thread(Path(path).write_bytes, resp.content) abm.message = [Image(file=path, url=path)] elif msgtype == "voice": media_id = msg.get("voice", {}).get("media_id", "") @@ -415,8 +408,7 @@ async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr") - with open(path, "wb") as f: - f.write(resp.content) + await asyncio.to_thread(Path(path).write_bytes, resp.content) try: path_wav = os.path.join(temp_dir, f"weixinkefu_{media_id}.wav") diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 7aee26e47f..83a91a872b 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -12,6 +12,13 @@ from .wecom_kf_message import WeChatKFMessage +def _upload_media_from_path( + client: WeChatClient, media_type: str, file_path: str +) -> dict: + with open(file_path, "rb") as f: + return client.media.upload(media_type, f) + + class WecomPlatformEvent(AstrMessageEvent): def __init__( self, @@ -100,45 +107,52 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, Image): img_path = await comp.convert_to_file_path() - with open(img_path, "rb") as f: + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "image", + img_path, + ) + except Exception as e: + logger.error(f"微信客服上传图片失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传图片失败: {e}"), + ) + return + logger.debug(f"微信客服上传图片返回: {response}") + kf_message_api.send_image( + user_id, + self.get_self_id(), + response["media_id"], + ) + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + record_path_amr = await convert_audio_to_amr(record_path) + + try: try: - response = self.client.media.upload("image", f) + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "voice", + record_path_amr, + ) except Exception as e: - logger.error(f"微信客服上传图片失败: {e}") + logger.error(f"微信客服上传语音失败: {e}") await self.send( - MessageChain().message(f"微信客服上传图片失败: {e}"), + MessageChain().message(f"微信客服上传语音失败: {e}"), ) return - logger.debug(f"微信客服上传图片返回: {response}") - kf_message_api.send_image( + logger.info(f"微信客服上传语音返回: {response}") + kf_message_api.send_voice( user_id, self.get_self_id(), response["media_id"], ) - elif isinstance(comp, Record): - record_path = await comp.convert_to_file_path() - record_path_amr = await convert_audio_to_amr(record_path) - - try: - with open(record_path_amr, "rb") as f: - try: - response = self.client.media.upload("voice", f) - except Exception as e: - logger.error(f"微信客服上传语音失败: {e}") - await self.send( - MessageChain().message( - f"微信客服上传语音失败: {e}" - ), - ) - return - logger.info(f"微信客服上传语音返回: {response}") - kf_message_api.send_voice( - user_id, - self.get_self_id(), - response["media_id"], - ) finally: - if record_path_amr != record_path and os.path.exists( + if record_path_amr != record_path and await asyncio.to_thread( + os.path.exists, record_path_amr, ): try: @@ -148,39 +162,47 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, File): file_path = await comp.get_file() - with open(file_path, "rb") as f: - try: - response = self.client.media.upload("file", f) - except Exception as e: - logger.error(f"微信客服上传文件失败: {e}") - await self.send( - MessageChain().message(f"微信客服上传文件失败: {e}"), - ) - return - logger.debug(f"微信客服上传文件返回: {response}") - kf_message_api.send_file( - user_id, - self.get_self_id(), - response["media_id"], + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "file", + file_path, + ) + except Exception as e: + logger.error(f"微信客服上传文件失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传文件失败: {e}"), ) + return + logger.debug(f"微信客服上传文件返回: {response}") + kf_message_api.send_file( + user_id, + self.get_self_id(), + response["media_id"], + ) elif isinstance(comp, Video): video_path = await comp.convert_to_file_path() - with open(video_path, "rb") as f: - try: - response = self.client.media.upload("video", f) - except Exception as e: - logger.error(f"微信客服上传视频失败: {e}") - await self.send( - MessageChain().message(f"微信客服上传视频失败: {e}"), - ) - return - logger.debug(f"微信客服上传视频返回: {response}") - kf_message_api.send_video( - user_id, - self.get_self_id(), - response["media_id"], + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "video", + video_path, ) + except Exception as e: + logger.error(f"微信客服上传视频失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传视频失败: {e}"), + ) + return + logger.debug(f"微信客服上传视频返回: {response}") + kf_message_api.send_video( + user_id, + self.get_self_id(), + response["media_id"], + ) else: logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") else: @@ -199,45 +221,52 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, Image): img_path = await comp.convert_to_file_path() - with open(img_path, "rb") as f: + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "image", + img_path, + ) + except Exception as e: + logger.error(f"企业微信上传图片失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传图片失败: {e}"), + ) + return + logger.debug(f"企业微信上传图片返回: {response}") + self.client.message.send_image( + message_obj.self_id, + message_obj.session_id, + response["media_id"], + ) + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + record_path_amr = await convert_audio_to_amr(record_path) + + try: try: - response = self.client.media.upload("image", f) + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "voice", + record_path_amr, + ) except Exception as e: - logger.error(f"企业微信上传图片失败: {e}") + logger.error(f"企业微信上传语音失败: {e}") await self.send( - MessageChain().message(f"企业微信上传图片失败: {e}"), + MessageChain().message(f"企业微信上传语音失败: {e}"), ) return - logger.debug(f"企业微信上传图片返回: {response}") - self.client.message.send_image( + logger.info(f"企业微信上传语音返回: {response}") + self.client.message.send_voice( message_obj.self_id, message_obj.session_id, response["media_id"], ) - elif isinstance(comp, Record): - record_path = await comp.convert_to_file_path() - record_path_amr = await convert_audio_to_amr(record_path) - - try: - with open(record_path_amr, "rb") as f: - try: - response = self.client.media.upload("voice", f) - except Exception as e: - logger.error(f"企业微信上传语音失败: {e}") - await self.send( - MessageChain().message( - f"企业微信上传语音失败: {e}" - ), - ) - return - logger.info(f"企业微信上传语音返回: {response}") - self.client.message.send_voice( - message_obj.self_id, - message_obj.session_id, - response["media_id"], - ) finally: - if record_path_amr != record_path and os.path.exists( + if record_path_amr != record_path and await asyncio.to_thread( + os.path.exists, record_path_amr, ): try: @@ -247,39 +276,47 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, File): file_path = await comp.get_file() - with open(file_path, "rb") as f: - try: - response = self.client.media.upload("file", f) - except Exception as e: - logger.error(f"企业微信上传文件失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传文件失败: {e}"), - ) - return - logger.debug(f"企业微信上传文件返回: {response}") - self.client.message.send_file( - message_obj.self_id, - message_obj.session_id, - response["media_id"], + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "file", + file_path, ) + except Exception as e: + logger.error(f"企业微信上传文件失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传文件失败: {e}"), + ) + return + logger.debug(f"企业微信上传文件返回: {response}") + self.client.message.send_file( + message_obj.self_id, + message_obj.session_id, + response["media_id"], + ) elif isinstance(comp, Video): video_path = await comp.convert_to_file_path() - with open(video_path, "rb") as f: - try: - response = self.client.media.upload("video", f) - except Exception as e: - logger.error(f"企业微信上传视频失败: {e}") - await self.send( - MessageChain().message(f"企业微信上传视频失败: {e}"), - ) - return - logger.debug(f"企业微信上传视频返回: {response}") - self.client.message.send_video( - message_obj.self_id, - message_obj.session_id, - response["media_id"], + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "video", + video_path, + ) + except Exception as e: + logger.error(f"企业微信上传视频失败: {e}") + await self.send( + MessageChain().message(f"企业微信上传视频失败: {e}"), ) + return + logger.debug(f"企业微信上传视频返回: {response}") + self.client.message.send_video( + message_obj.self_id, + message_obj.session_id, + response["media_id"], + ) else: logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py index f7cbe380d4..6dbfda7b41 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -2,7 +2,6 @@ 提供常量定义、工具函数和辅助方法 """ -import asyncio import base64 import hashlib import secrets @@ -174,7 +173,7 @@ async def process_encrypted_image( response.raise_for_status() encrypted_data = await response.read() logger.info("图片下载成功,大小: %d 字节", len(encrypted_data)) - except (aiohttp.ClientError, asyncio.TimeoutError) as e: + except (TimeoutError, aiohttp.ClientError) as e: error_msg = f"下载图片失败: {e!s}" logger.error(error_msg) return False, error_msg diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py index 6f42f264b9..c305411d4e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_webhook.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import base64 import hashlib import mimetypes @@ -103,7 +104,9 @@ async def send_image_base64(self, image_base64: str) -> None: async def upload_media( self, file_path: Path, media_type: Literal["file", "voice"] ) -> str: - if not file_path.exists() or not file_path.is_file(): + if not await asyncio.to_thread(file_path.exists) or not await asyncio.to_thread( + file_path.is_file + ): raise WecomAIBotWebhookError(f"文件不存在: {file_path}") content_type = ( @@ -112,7 +115,7 @@ async def upload_media( form = aiohttp.FormData() form.add_field( "media", - file_path.read_bytes(), + await asyncio.to_thread(file_path.read_bytes), filename=file_path.name, content_type=content_type, ) diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index c01355974a..59f8ebd8c7 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -1,10 +1,10 @@ import asyncio import os -import sys import time import uuid from collections.abc import Callable, Coroutine -from typing import Any, cast +from pathlib import Path +from typing import Any, cast, override import quart from requests import Response @@ -32,11 +32,6 @@ from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent -if sys.version_info >= (3, 12): - from typing import override -else: - from typing_extensions import override - class WeixinOfficialAccountServer: def __init__( @@ -379,7 +374,7 @@ async def callback(msg: BaseMessage): ) # wait for 180s logger.debug(f"Got future result: {result}") return result - except asyncio.TimeoutError: + except TimeoutError: logger.info(f"callback 处理消息超时: message_id={msg.id}") return create_reply("处理消息超时,请稍后再试。", msg) except Exception as e: @@ -468,8 +463,7 @@ async def convert_message( ) temp_dir = get_astrbot_temp_path() path = os.path.join(temp_dir, f"weixin_offacc_{msg.media_id}.amr") - with open(path, "wb") as f: - f.write(resp.content) + await asyncio.to_thread(Path(path).write_bytes, resp.content) try: path_wav = os.path.join( diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index ae536593c5..0797e4dee1 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -12,6 +12,13 @@ from astrbot.core.utils.media_utils import convert_audio_to_amr +def _upload_media_from_path( + client: WeChatClient, media_type: str, file_path: str +) -> dict: + with open(file_path, "rb") as f: + return client.media.upload(media_type, f) + + class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): def __init__( self, @@ -101,24 +108,63 @@ async def send(self, message: MessageChain) -> None: elif isinstance(comp, Image): img_path = await comp.convert_to_file_path() - with open(img_path, "rb") as f: + try: + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "image", + img_path, + ) + except Exception as e: + logger.error(f"微信公众平台上传图片失败: {e}") + await self.send( + MessageChain().message(f"微信公众平台上传图片失败: {e}"), + ) + return + logger.debug(f"微信公众平台上传图片返回: {response}") + + if active_send_mode: + self.client.message.send_image( + message_obj.sender.user_id, + response["media_id"], + ) + else: + reply = ImageReply( + media_id=response["media_id"], + message=cast(dict, self.message_obj.raw_message)["message"], + ) + xml = reply.render() + future = cast(dict, self.message_obj.raw_message)["future"] + assert isinstance(future, asyncio.Future) + future.set_result(xml) + + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + record_path_amr = await convert_audio_to_amr(record_path) + + try: try: - response = self.client.media.upload("image", f) + response = await asyncio.to_thread( + _upload_media_from_path, + self.client, + "voice", + record_path_amr, + ) except Exception as e: - logger.error(f"微信公众平台上传图片失败: {e}") + logger.error(f"微信公众平台上传语音失败: {e}") await self.send( - MessageChain().message(f"微信公众平台上传图片失败: {e}"), + MessageChain().message(f"微信公众平台上传语音失败: {e}"), ) return - logger.debug(f"微信公众平台上传图片返回: {response}") + logger.info(f"微信公众平台上传语音返回: {response}") if active_send_mode: - self.client.message.send_image( + self.client.message.send_voice( message_obj.sender.user_id, response["media_id"], ) else: - reply = ImageReply( + reply = VoiceReply( media_id=response["media_id"], message=cast(dict, self.message_obj.raw_message)["message"], ) @@ -126,44 +172,9 @@ async def send(self, message: MessageChain) -> None: future = cast(dict, self.message_obj.raw_message)["future"] assert isinstance(future, asyncio.Future) future.set_result(xml) - - elif isinstance(comp, Record): - record_path = await comp.convert_to_file_path() - record_path_amr = await convert_audio_to_amr(record_path) - - try: - with open(record_path_amr, "rb") as f: - try: - response = self.client.media.upload("voice", f) - except Exception as e: - logger.error(f"微信公众平台上传语音失败: {e}") - await self.send( - MessageChain().message( - f"微信公众平台上传语音失败: {e}" - ), - ) - return - logger.info(f"微信公众平台上传语音返回: {response}") - - if active_send_mode: - self.client.message.send_voice( - message_obj.sender.user_id, - response["media_id"], - ) - else: - reply = VoiceReply( - media_id=response["media_id"], - message=cast(dict, self.message_obj.raw_message)[ - "message" - ], - ) - xml = reply.render() - future = cast(dict, self.message_obj.raw_message)["future"] - assert isinstance(future, asyncio.Future) - future.set_result(xml) finally: - if record_path_amr != record_path and os.path.exists( - record_path_amr + if record_path_amr != record_path and await asyncio.to_thread( + os.path.exists, record_path_amr ): try: os.remove(record_path_amr) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 20c5a7947d..aea04645d0 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -1,9 +1,11 @@ from __future__ import annotations +import asyncio import base64 import enum import json from dataclasses import dataclass, field +from pathlib import Path from typing import Any from anthropic.types import Message as AnthropicMessage @@ -218,9 +220,10 @@ async def _encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + image_bs64 = base64.b64encode( + await asyncio.to_thread(Path(image_url).read_bytes) + ).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 return "" diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 068c63c5ad..22e9a0766c 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -8,6 +8,7 @@ import urllib.parse from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping from dataclasses import dataclass +from pathlib import Path from types import MappingProxyType from typing import Any @@ -198,7 +199,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return True, "" return False, f"HTTP {response.status}: {response.reason}" - except asyncio.TimeoutError: + except TimeoutError: return False, f"连接超时: {timeout}秒" except Exception as e: return False, f"{e!s}" @@ -373,15 +374,24 @@ async def init_mcp_clients( data_dir = get_astrbot_data_path() mcp_json_file = os.path.join(data_dir, "mcp_server.json") - if not os.path.exists(mcp_json_file): + if not await asyncio.to_thread(os.path.exists, mcp_json_file): # 配置文件不存在错误处理 - with open(mcp_json_file, "w", encoding="utf-8") as f: - json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4) + config_text = json.dumps(DEFAULT_MCP_CONFIG, ensure_ascii=False, indent=4) + await asyncio.to_thread( + Path(mcp_json_file).write_text, + config_text, + encoding="utf-8", + ) logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}") return MCPInitSummary(total=0, success=0, failed=[]) - with open(mcp_json_file, encoding="utf-8") as f: - mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"] + mcp_json_content = await asyncio.to_thread( + Path(mcp_json_file).read_text, + encoding="utf-8", + ) + mcp_server_json_obj: dict[str, dict] = json.loads(mcp_json_content)[ + "mcpServers" + ] init_timeout = self._init_timeout_default timeout_display = f"{init_timeout:g}" @@ -451,7 +461,7 @@ async def _start_mcp_server( cfg: dict, *, shutdown_event: asyncio.Event | None = None, - timeout: float, + timeout_seconds: float, ) -> None: """Initialize MCP server with timeout and register task/event together. @@ -461,7 +471,7 @@ async def _start_mcp_server( async with self._runtime_lock: if name in self._mcp_server_runtime or name in self._mcp_starting: logger.warning( - f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout:g})。" + f"MCP 服务 {name} 已在运行,忽略本次启用请求(timeout={timeout_seconds:g})。" ) self._log_safe_mcp_debug_config(cfg) return @@ -474,11 +484,11 @@ async def _start_mcp_server( try: mcp_client = await asyncio.wait_for( self._init_mcp_client(name, cfg), - timeout=timeout, + timeout=timeout_seconds, ) - except asyncio.TimeoutError as exc: + except TimeoutError as exc: raise MCPInitTimeoutError( - f"MCP 服务 {name} 初始化超时({timeout:g} 秒)" + f"MCP 服务 {name} 初始化超时({timeout_seconds:g} 秒)" ) from exc except Exception: logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True) @@ -511,7 +521,7 @@ async def lifecycle() -> None: async def _shutdown_runtimes( self, runtimes: list[_MCPServerRuntime], - timeout: float, + timeout_seconds: float, *, strict: bool = True, ) -> list[str]: @@ -530,9 +540,9 @@ async def _shutdown_runtimes( try: results = await asyncio.wait_for( asyncio.gather(*lifecycle_tasks, return_exceptions=True), - timeout=timeout, + timeout=timeout_seconds, ) - except asyncio.TimeoutError: + except TimeoutError: pending_names = [ runtime.name for runtime in runtimes @@ -543,10 +553,10 @@ async def _shutdown_runtimes( task.cancel() await asyncio.gather(*lifecycle_tasks, return_exceptions=True) if strict: - raise MCPShutdownTimeoutError(pending_names, timeout) + raise MCPShutdownTimeoutError(pending_names, timeout_seconds) logger.warning( "MCP 服务关闭超时(%s 秒),以下服务未完全关闭:%s", - f"{timeout:g}", + f"{timeout_seconds:g}", ", ".join(pending_names), ) return pending_names @@ -657,7 +667,8 @@ async def enable_mcp_server( name: str, config: dict, shutdown_event: asyncio.Event | None = None, - timeout: float | int | str | None = None, + timeout_seconds: float | int | str | None = None, + **kwargs: Any, ) -> None: """Enable a new MCP server and initialize it. @@ -665,18 +676,22 @@ async def enable_mcp_server( name: The name of the MCP server. config: Configuration for the MCP server. shutdown_event: Event to signal when the MCP client should shut down. - timeout: Timeout in seconds for initialization. + timeout_seconds: Timeout in seconds for initialization. Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout). Raises: MCPInitTimeoutError: If initialization does not complete within timeout. Exception: If there is an error during initialization. """ - if timeout is None: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = legacy_timeout + + if timeout_seconds is None: timeout_value = self._enable_timeout_default else: timeout_value = _resolve_timeout( - timeout=timeout, + timeout=timeout_seconds, env_name=ENABLE_MCP_TIMEOUT_ENV, default=self._enable_timeout_default, ) @@ -684,36 +699,45 @@ async def enable_mcp_server( name=name, cfg=config, shutdown_event=shutdown_event, - timeout=timeout_value, + timeout_seconds=timeout_value, ) async def disable_mcp_server( self, name: str | None = None, - timeout: float = 10, + timeout_seconds: float = 10, + **kwargs: Any, ) -> None: """Disable an MCP server by its name. Args: name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled. - timeout (int): Timeout. + timeout_seconds (int): Timeout. Raises: MCPShutdownTimeoutError: If shutdown does not complete within timeout. Only raised when disabling a specific server (name is not None). """ + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = float(legacy_timeout) + if name: async with self._runtime_lock: runtime = self._mcp_server_runtime.get(name) if runtime is None: return - await self._shutdown_runtimes([runtime], timeout, strict=True) + await self._shutdown_runtimes( + [runtime], timeout_seconds=timeout_seconds, strict=True + ) else: async with self._runtime_lock: runtimes = list(self._mcp_server_runtime.values()) - await self._shutdown_runtimes(runtimes, timeout, strict=False) + await self._shutdown_runtimes( + runtimes, timeout_seconds=timeout_seconds, strict=False + ) def _warn_on_timeout_mismatch( self, diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 901efd0052..08c5254858 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,7 +2,8 @@ import asyncio import os from collections.abc import AsyncGenerator -from typing import TypeAlias, Union +from pathlib import Path +from typing import Any from astrbot.core.agent.message import ContentPart, Message from astrbot.core.agent.tool import ToolSet @@ -15,13 +16,9 @@ from astrbot.core.provider.register import provider_cls_map from astrbot.core.utils.astrbot_path import get_astrbot_path -Providers: TypeAlias = Union[ - "Provider", - "STTProvider", - "TTSProvider", - "EmbeddingProvider", - "RerankProvider", -] +type Providers = ( + "Provider" | "STTProvider" | "TTSProvider" | "EmbeddingProvider" | "RerankProvider" +) class AbstractProvider(abc.ABC): @@ -188,10 +185,13 @@ def _ensure_message_to_dicts( return dicts - async def test(self, timeout: float = 45.0) -> None: + async def test(self, timeout_seconds: float = 45.0, **kwargs: Any) -> None: + legacy_timeout = kwargs.pop("timeout", None) + if legacy_timeout is not None: + timeout_seconds = float(legacy_timeout) await asyncio.wait_for( self.text_chat(prompt="REPLY `PONG` ONLY"), - timeout=timeout, + timeout=timeout_seconds, ) @@ -268,8 +268,9 @@ async def get_audio_stream( # 调用原有的 get_audio 方法获取音频文件路径 audio_path = await self.get_audio(accumulated_text) # 读取音频文件内容 - with open(audio_path, "rb") as f: - audio_data = f.read() + audio_data = await asyncio.to_thread( + Path(audio_path).read_bytes + ) await audio_queue.put((accumulated_text, audio_data)) except Exception: # 出错时也要发送 None 结束标记 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index ec3c395a46..7f7a51859f 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -1,6 +1,8 @@ +import asyncio import base64 import json from collections.abc import AsyncGenerator +from pathlib import Path import anthropic import httpx @@ -637,11 +639,10 @@ async def encode_image_bs64(self, image_url: str) -> tuple[str, str]: except Exception: mime_type = "image/jpeg" return f"data:{mime_type};base64,{raw_base64}", mime_type - with open(image_url, "rb") as f: - image_bytes = f.read() - mime_type = self._detect_image_mime_type(image_bytes) - image_bs64 = base64.b64encode(image_bytes).decode("utf-8") - return f"data:{mime_type};base64,{image_bs64}", mime_type + image_bytes = await asyncio.to_thread(Path(image_url).read_bytes) + mime_type = self._detect_image_mime_type(image_bytes) + image_bs64 = base64.b64encode(image_bytes).decode("utf-8") + return f"data:{mime_type};base64,{image_bs64}", mime_type return "", "image/jpeg" def get_current_key(self) -> str: diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index 9b6816859f..bd12f37b91 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -3,6 +3,7 @@ import logging import os import uuid +from pathlib import Path import aiohttp import dashscope @@ -59,8 +60,7 @@ async def get_audio(self, text: str) -> str: ) path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}{ext}") - with open(path, "wb") as f: - f.write(audio_bytes) + await asyncio.to_thread(Path(path).write_bytes, audio_bytes) return path def _call_qwen_tts(self, model: str, text: str): @@ -129,7 +129,7 @@ async def _download_audio_from_url(self, url: str) -> bytes | None: ) as response, ): return await response.read() - except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e: + except (TimeoutError, aiohttp.ClientError, OSError) as e: logging.exception(f"Failed to download audio from URL {url}: {e}") return None diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 503bd275b4..147c925ecf 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -1,126 +1,129 @@ -import asyncio -import os -import subprocess -import uuid - -import edge_tts - -from astrbot.core import logger -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..entities import ProviderType -from ..provider import TTSProvider -from ..register import register_provider_adapter - -""" -edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 -``` -pip install edge_tts -``` -Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot -""" - - -@register_provider_adapter( - "edge_tts", - "Microsoft Edge TTS", - provider_type=ProviderType.TEXT_TO_SPEECH, -) -class ProviderEdgeTTS(TTSProvider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - super().__init__(provider_config, provider_settings) - - # 设置默认语音,如果没有指定则使用中文小萱 - self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") - self.rate = provider_config.get("rate") - self.volume = provider_config.get("volume") - self.pitch = provider_config.get("pitch") - self.timeout = provider_config.get("timeout", 30) - - self.proxy = os.getenv("https_proxy", None) - - self.set_model("edge_tts") - - async def get_audio(self, text: str) -> str: - temp_dir = get_astrbot_temp_path() - mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") - wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") - - # 构建 Edge TTS 参数 - kwargs = {"text": text, "voice": self.voice} - if self.rate: - kwargs["rate"] = self.rate - if self.volume: - kwargs["volume"] = self.volume - if self.pitch: - kwargs["pitch"] = self.pitch - - try: - communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs) - await communicate.save(mp3_path) - - try: - from pyffmpeg import FFmpeg - - ff = FFmpeg() - ff.convert(input_file=mp3_path, output_file=wav_path) - except Exception as e: - logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") - # use ffmpeg command line - - # 使用ffmpeg将MP3转换为标准WAV格式 - p = await asyncio.create_subprocess_exec( - "ffmpeg", - "-y", # 覆盖输出文件 - "-i", - mp3_path, # 输入文件 - "-acodec", - "pcm_s16le", # 16位PCM编码 - "-ar", - "24000", # 采样率24kHz (适合微信语音) - "-ac", - "1", # 单声道 - "-af", - "apad=pad_dur=2", # 确保输出时长准确 - "-fflags", - "+genpts", # 强制生成时间戳 - "-hide_banner", # 隐藏版本信息 - wav_path, # 输出文件 - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - # 等待进程完成并获取输出 - stdout, stderr = await p.communicate() - logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}") - logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}") - logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}") - - os.remove(mp3_path) - if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0: - return wav_path - logger.error("生成的WAV文件不存在或为空") - raise RuntimeError("生成的WAV文件不存在或为空") - - except subprocess.CalledProcessError as e: - logger.error( - f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}", - ) - try: - if os.path.exists(mp3_path): - os.remove(mp3_path) - except Exception: - pass - raise RuntimeError(f"FFmpeg 转换失败: {e!s}") - - except Exception as e: - logger.error(f"音频生成失败: {e!s}") - try: - if os.path.exists(mp3_path): - os.remove(mp3_path) - except Exception: - pass - raise RuntimeError(f"音频生成失败: {e!s}") +import asyncio +import os +import subprocess +import uuid + +import edge_tts + +from astrbot.core import logger +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + +""" +edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 +``` +pip install edge_tts +``` +Windows 如果提示找不到指定文件,以管理员身份运行命令行窗口,然后再次运行 AstrBot +""" + + +@register_provider_adapter( + "edge_tts", + "Microsoft Edge TTS", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderEdgeTTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + + # 设置默认语音,如果没有指定则使用中文小萱 + self.voice = provider_config.get("edge-tts-voice", "zh-CN-XiaoxiaoNeural") + self.rate = provider_config.get("rate") + self.volume = provider_config.get("volume") + self.pitch = provider_config.get("pitch") + self.timeout = provider_config.get("timeout", 30) + + self.proxy = os.getenv("https_proxy", None) + + self.set_model("edge_tts") + + async def get_audio(self, text: str) -> str: + temp_dir = get_astrbot_temp_path() + mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") + wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") + + # 构建 Edge TTS 参数 + kwargs = {"text": text, "voice": self.voice} + if self.rate: + kwargs["rate"] = self.rate + if self.volume: + kwargs["volume"] = self.volume + if self.pitch: + kwargs["pitch"] = self.pitch + + try: + communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs) + await communicate.save(mp3_path) + + try: + from pyffmpeg import FFmpeg + + ff = FFmpeg() + ff.convert(input_file=mp3_path, output_file=wav_path) + except Exception as e: + logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") + # use ffmpeg command line + + # 使用ffmpeg将MP3转换为标准WAV格式 + p = await asyncio.create_subprocess_exec( + "ffmpeg", + "-y", # 覆盖输出文件 + "-i", + mp3_path, # 输入文件 + "-acodec", + "pcm_s16le", # 16位PCM编码 + "-ar", + "24000", # 采样率24kHz (适合微信语音) + "-ac", + "1", # 单声道 + "-af", + "apad=pad_dur=2", # 确保输出时长准确 + "-fflags", + "+genpts", # 强制生成时间戳 + "-hide_banner", # 隐藏版本信息 + wav_path, # 输出文件 + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + # 等待进程完成并获取输出 + stdout, stderr = await p.communicate() + logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}") + logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}") + logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}") + + os.remove(mp3_path) + if ( + await asyncio.to_thread(os.path.exists, wav_path) + and await asyncio.to_thread(os.path.getsize, wav_path) > 0 + ): + return wav_path + logger.error("生成的WAV文件不存在或为空") + raise RuntimeError("生成的WAV文件不存在或为空") + + except subprocess.CalledProcessError as e: + logger.error( + f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}", + ) + try: + if await asyncio.to_thread(os.path.exists, mp3_path): + os.remove(mp3_path) + except Exception: + pass + raise RuntimeError(f"FFmpeg 转换失败: {e!s}") + + except Exception as e: + logger.error(f"音频生成失败: {e!s}") + try: + if await asyncio.to_thread(os.path.exists, mp3_path): + os.remove(mp3_path) + except Exception: + pass + raise RuntimeError(f"音频生成失败: {e!s}") diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index 35945b7b6f..c1b62ef399 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -1,6 +1,8 @@ +import asyncio import os import re import uuid +from pathlib import Path from typing import Annotated, Literal import ormsgpack @@ -159,9 +161,10 @@ async def get_audio(self, text: str) -> str: if response.status_code == 200 and response.headers.get( "content-type", "" ).startswith("audio/"): - with open(path, "wb") as f: - async for chunk in response.aiter_bytes(): - f.write(chunk) + audio_data = bytearray() + async for chunk in response.aiter_bytes(): + audio_data.extend(chunk) + await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data)) return path error_bytes = await response.aread() error_text = error_bytes.decode("utf-8", errors="replace")[:1024] diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 9557f3dbcd..25e5c6e3a7 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -4,6 +4,7 @@ import logging import random from collections.abc import AsyncGenerator +from pathlib import Path from typing import cast from google import genai @@ -924,9 +925,10 @@ async def encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + image_bs64 = base64.b64encode( + await asyncio.to_thread(Path(image_url).read_bytes) + ).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 async def terminate(self) -> None: if self.client: diff --git a/astrbot/core/provider/sources/genie_tts.py b/astrbot/core/provider/sources/genie_tts.py index 8f9b6d91d7..62b4b3f81d 100644 --- a/astrbot/core/provider/sources/genie_tts.py +++ b/astrbot/core/provider/sources/genie_tts.py @@ -1,6 +1,7 @@ import asyncio import os import uuid +from pathlib import Path from astrbot.core import logger from astrbot.core.provider.entities import ProviderType @@ -72,7 +73,7 @@ def _generate(save_path: str) -> None: try: await loop.run_in_executor(None, _generate, path) - if os.path.exists(path): + if await asyncio.to_thread(os.path.exists, path): return path raise RuntimeError("Genie TTS did not save to file.") @@ -109,9 +110,8 @@ def _generate(save_path: str, t: str) -> None: await loop.run_in_executor(None, _generate, path, text) - if os.path.exists(path): - with open(path, "rb") as f: - audio_data = f.read() + if await asyncio.to_thread(os.path.exists, path): + audio_data = await asyncio.to_thread(Path(path).read_bytes) # Put (text, bytes) into queue so frontend can display text await audio_queue.put((text, audio_data)) diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index fc8bccea84..a9ebfe9a66 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -1,6 +1,7 @@ import asyncio import os import uuid +from pathlib import Path import aiohttp @@ -129,8 +130,7 @@ async def get_audio(self, text: str) -> str: result = await self._make_request(endpoint, params) if isinstance(result, bytes): - with open(path, "wb") as f: - f.write(result) + await asyncio.to_thread(Path(path).write_bytes, result) return path raise Exception(f"[GSV TTS] 合成失败,输入文本:{text},错误信息:{result}") diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py index 425e801f46..f92485b722 100644 --- a/astrbot/core/provider/sources/gsvi_tts_source.py +++ b/astrbot/core/provider/sources/gsvi_tts_source.py @@ -1,59 +1,62 @@ -import os -import urllib.parse -import uuid - -import aiohttp - -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..entities import ProviderType -from ..provider import TTSProvider -from ..register import register_provider_adapter - - -@register_provider_adapter( - "gsvi_tts_api", - "GSVI TTS API", - provider_type=ProviderType.TEXT_TO_SPEECH, -) -class ProviderGSVITTS(TTSProvider): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - super().__init__(provider_config, provider_settings) - self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000") - self.api_base = self.api_base.removesuffix("/") - self.character = provider_config.get("character") - self.emotion = provider_config.get("emotion") - - async def get_audio(self, text: str) -> str: - temp_dir = get_astrbot_temp_path() - path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") - params = {"text": text} - - if self.character: - params["character"] = self.character - if self.emotion: - params["emotion"] = self.emotion - - query_parts = [] - for key, value in params.items(): - encoded_value = urllib.parse.quote(str(value)) - query_parts.append(f"{key}={encoded_value}") - - url = f"{self.api_base}/tts?{'&'.join(query_parts)}" - - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - with open(path, "wb") as f: - f.write(await response.read()) - else: - error_text = await response.text() - raise Exception( - f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}", - ) - - return path +import asyncio +import os +import urllib.parse +import uuid +from pathlib import Path + +import aiohttp + +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..entities import ProviderType +from ..provider import TTSProvider +from ..register import register_provider_adapter + + +@register_provider_adapter( + "gsvi_tts_api", + "GSVI TTS API", + provider_type=ProviderType.TEXT_TO_SPEECH, +) +class ProviderGSVITTS(TTSProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + self.api_base = provider_config.get("api_base", "http://127.0.0.1:5000") + self.api_base = self.api_base.removesuffix("/") + self.character = provider_config.get("character") + self.emotion = provider_config.get("emotion") + + async def get_audio(self, text: str) -> str: + temp_dir = get_astrbot_temp_path() + path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") + params = {"text": text} + + if self.character: + params["character"] = self.character + if self.emotion: + params["emotion"] = self.emotion + + query_parts = [] + for key, value in params.items(): + encoded_value = urllib.parse.quote(str(value)) + query_parts.append(f"{key}={encoded_value}") + + url = f"{self.api_base}/tts?{'&'.join(query_parts)}" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + await asyncio.to_thread( + Path(path).write_bytes, await response.read() + ) + else: + error_text = await response.text() + raise Exception( + f"GSVI TTS API 请求失败,状态码: {response.status},错误: {error_text}", + ) + + return path diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py index 69860111cf..ad2e345367 100644 --- a/astrbot/core/provider/sources/minimax_tts_api_source.py +++ b/astrbot/core/provider/sources/minimax_tts_api_source.py @@ -1,7 +1,9 @@ +import asyncio import json import os import uuid from collections.abc import AsyncIterator +from pathlib import Path import aiohttp @@ -155,8 +157,7 @@ async def get_audio(self, text: str) -> str: audio = await self._audio_play(audio_stream) # 结果保存至文件 - with open(path, "wb") as file: - file.write(audio) + await asyncio.to_thread(Path(path).write_bytes, audio) return path diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index adee24073d..3f0c007b36 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -5,6 +5,7 @@ import random import re from collections.abc import AsyncGenerator +from pathlib import Path from typing import Any import httpx @@ -949,9 +950,10 @@ async def encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + image_bs64 = base64.b64encode( + await asyncio.to_thread(Path(image_url).read_bytes) + ).decode("utf-8") + return "data:image/jpeg;base64," + image_bs64 async def terminate(self): if self.client: diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index 217b189251..35ac1d5a8c 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -1,5 +1,7 @@ +import asyncio import os import uuid +from pathlib import Path import httpx from openai import NOT_GIVEN, AsyncOpenAI @@ -54,9 +56,10 @@ async def get_audio(self, text: str) -> str: response_format="wav", input=text, ) as response: - with open(path, "wb") as f: - async for chunk in response.iter_bytes(chunk_size=1024): - f.write(chunk) + audio_data = bytearray() + async for chunk in response.iter_bytes(chunk_size=1024): + audio_data.extend(chunk) + await asyncio.to_thread(Path(path).write_bytes, bytes(audio_data)) return path async def terminate(self): diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index af6c0f631e..b776657968 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -53,14 +53,12 @@ async def initialize(self) -> None: async def get_timestamped_path(self) -> str: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) return str(temp_dir / timestamp) async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" - with open(file_path, "rb") as f: - file_header = f.read(8) - + file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8] if silk_header in file_header: return True return False @@ -76,7 +74,7 @@ async def get_text(self, audio_url: str) -> str: await download_file(audio_url, path) audio_url = path - if not os.path.isfile(audio_url): + if not await asyncio.to_thread(os.path.isfile, audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith((".amr", ".silk")) or is_tencent: diff --git a/astrbot/core/provider/sources/volcengine_tts.py b/astrbot/core/provider/sources/volcengine_tts.py index 349815907d..5082200717 100644 --- a/astrbot/core/provider/sources/volcengine_tts.py +++ b/astrbot/core/provider/sources/volcengine_tts.py @@ -4,6 +4,7 @@ import os import traceback import uuid +from pathlib import Path import aiohttp @@ -100,10 +101,9 @@ async def get_audio(self, text: str) -> str: f"volcengine_tts_{uuid.uuid4()}.mp3", ) - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, - lambda: open(file_path, "wb").write(audio_data), + await asyncio.to_thread( + Path(file_path).write_bytes, + audio_data, ) return file_path diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 386da063db..00c87075db 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -1,5 +1,7 @@ +import asyncio import os import uuid +from pathlib import Path from openai import NOT_GIVEN, AsyncOpenAI @@ -44,8 +46,7 @@ async def _get_audio_format(self, file_path) -> str | None: amr_header = b"#!AMR" try: - with open(file_path, "rb") as f: - file_header = f.read(8) + file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8] except FileNotFoundError: return None @@ -73,7 +74,7 @@ async def get_text(self, audio_url: str) -> str: await download_file(audio_url, path) audio_url = path - if not os.path.exists(audio_url): + if not await asyncio.to_thread(os.path.exists, audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: @@ -100,13 +101,14 @@ async def get_text(self, audio_url: str) -> str: audio_url = output_path + audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes) result = await self.client.audio.transcriptions.create( model=self.model_name, - file=("audio.wav", open(audio_url, "rb")), + file=("audio.wav", audio_bytes), ) # remove temp file - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(audio_url) except Exception as e: diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index 678deb9481..d85c84f9be 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -1,6 +1,7 @@ import asyncio import os import uuid +from pathlib import Path from typing import cast import whisper @@ -42,9 +43,7 @@ async def initialize(self) -> None: async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" - with open(file_path, "rb") as f: - file_header = f.read(8) - + file_header = (await asyncio.to_thread(Path(file_path).read_bytes))[:8] if silk_header in file_header: return True return False @@ -66,7 +65,7 @@ async def get_text(self, audio_url: str) -> str: await download_file(audio_url, path) audio_url = path - if not os.path.exists(audio_url): + if not await asyncio.to_thread(os.path.exists, audio_url): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py index 0a22e456ed..7b6068dd85 100644 --- a/astrbot/core/provider/sources/xinference_stt_provider.py +++ b/astrbot/core/provider/sources/xinference_stt_provider.py @@ -1,5 +1,7 @@ +import asyncio import os import uuid +from pathlib import Path import aiohttp from xinference_client.client.restful.async_restful_client import ( @@ -102,9 +104,8 @@ async def get_text(self, audio_url: str) -> str: f"Failed to download audio from {audio_url}, status: {resp.status}", ) return "" - elif os.path.exists(audio_url): - with open(audio_url, "rb") as f: - audio_bytes = f.read() + elif await asyncio.to_thread(os.path.exists, audio_url): + audio_bytes = await asyncio.to_thread(Path(audio_url).read_bytes) else: logger.error(f"File not found: {audio_url}") return "" @@ -143,8 +144,7 @@ async def get_text(self, audio_url: str) -> str: ) temp_files.extend([input_path, output_path]) - with open(input_path, "wb") as f: - f.write(audio_bytes) + await asyncio.to_thread(Path(input_path).write_bytes, audio_bytes) if conversion_type == "silk": logger.info("Converting silk to wav ...") @@ -153,8 +153,7 @@ async def get_text(self, audio_url: str) -> str: logger.info("Converting amr to wav ...") await convert_to_pcm_wav(input_path, output_path) - with open(output_path, "rb") as f: - audio_bytes = f.read() + audio_bytes = await asyncio.to_thread(Path(output_path).read_bytes) # 4. Transcribe # 官方asyncCLient的客户端似乎实现有点问题,这里直接用aiohttp实现openai标准兼容请求,提交issue等待官方修复后再改回来 @@ -199,7 +198,7 @@ async def get_text(self, audio_url: str) -> str: # 5. Cleanup for temp_file in temp_files: try: - if os.path.exists(temp_file): + if await asyncio.to_thread(os.path.exists, temp_file): os.remove(temp_file) logger.debug(f"Removed temporary file: {temp_file}") except Exception as e: diff --git a/astrbot/core/skills/neo_skill_sync.py b/astrbot/core/skills/neo_skill_sync.py index 5fe2b7832d..2bb4c50f8b 100644 --- a/astrbot/core/skills/neo_skill_sync.py +++ b/astrbot/core/skills/neo_skill_sync.py @@ -5,7 +5,7 @@ import os import re from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -19,7 +19,7 @@ def _now_iso() -> str: - return datetime.now(timezone.utc).isoformat() + return datetime.now(UTC).isoformat() def _to_jsonable(model_like: Any) -> dict[str, Any]: diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py index d15876526d..a24ddac9ea 100644 --- a/astrbot/core/skills/skill_manager.py +++ b/astrbot/core/skills/skill_manager.py @@ -7,7 +7,7 @@ import tempfile import zipfile from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path, PurePosixPath from astrbot.core.utils.astrbot_path import ( @@ -175,7 +175,7 @@ def _load_sandbox_skills_cache(self) -> dict: def _save_sandbox_skills_cache(self, cache: dict) -> None: cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION - cache["updated_at"] = datetime.now(timezone.utc).isoformat() + cache["updated_at"] = datetime.now(UTC).isoformat() with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f: json.dump(cache, f, ensure_ascii=False, indent=2) diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index d28ac726ae..f6afc08e1c 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -3,7 +3,7 @@ import enum from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeVar, overload +from typing import Any, Literal, TypeVar, overload from .filter import HandlerFilter from .star import star_map @@ -11,7 +11,7 @@ T = TypeVar("T", bound="StarHandlerMetadata") -class StarHandlerRegistry(Generic[T]): +class StarHandlerRegistry[T: "StarHandlerMetadata"]: def __init__(self) -> None: self.star_handlers_map: dict[str, StarHandlerMetadata] = {} self._handlers: list[StarHandlerMetadata] = [] @@ -227,7 +227,7 @@ class EventType(enum.Enum): @dataclass -class StarHandlerMetadata(Generic[H]): +class StarHandlerMetadata[H: Callable[..., Any]]: """描述一个 Star 所注册的某一个 Handler。""" event_type: EventType diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 68c58fdae5..c5fa63bee7 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -8,6 +8,7 @@ import os import sys import traceback +from pathlib import Path from types import ModuleType import yaml @@ -188,7 +189,7 @@ async def _check_plugin_dept_update( 如果 target_plugin 为 None,则检查所有插件的依赖 """ plugin_dir = self.plugin_store_path - if not os.path.exists(plugin_dir): + if not await asyncio.to_thread(os.path.exists, plugin_dir): return False to_update = [] if target_plugin: @@ -198,7 +199,9 @@ async def _check_plugin_dept_update( to_update.append(p.root_dir_name) for p in to_update: plugin_path = os.path.join(plugin_dir, p) - if os.path.exists(os.path.join(plugin_path, "requirements.txt")): + if await asyncio.to_thread( + os.path.exists, os.path.join(plugin_path, "requirements.txt") + ): pth = os.path.join(plugin_path, "requirements.txt") logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}") try: @@ -217,7 +220,7 @@ async def _import_plugin_with_dependency_recovery( try: return __import__(path, fromlist=[module_str]) except (ModuleNotFoundError, ImportError) as import_exc: - if os.path.exists(requirements_path): + if await asyncio.to_thread(os.path.exists, requirements_path): try: logger.info( f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}" @@ -651,16 +654,19 @@ async def load( plugin_dir_path, self.conf_schema_fname, ) - if os.path.exists(plugin_schema_path): + if await asyncio.to_thread(os.path.exists, plugin_schema_path): # 加载插件配置 - with open(plugin_schema_path, encoding="utf-8") as f: - plugin_config = AstrBotConfig( - config_path=os.path.join( - self.plugin_config_path, - f"{root_dir_name}_config.json", - ), - schema=json.loads(f.read()), - ) + plugin_schema_text = await asyncio.to_thread( + Path(plugin_schema_path).read_text, + encoding="utf-8", + ) + plugin_config = AstrBotConfig( + config_path=os.path.join( + self.plugin_config_path, + f"{root_dir_name}_config.json", + ), + schema=json.loads(plugin_schema_text), + ) logo_path = os.path.join(plugin_dir_path, self.logo_fname) if path in star_map: @@ -836,7 +842,7 @@ async def load( metadata.activated = False # Plugin logo path - if os.path.exists(logo_path): + if await asyncio.to_thread(os.path.exists, logo_path): metadata.logo_path = logo_path assert metadata.module_path, f"插件 {metadata.name} 模块路径为空" @@ -955,7 +961,7 @@ async def _cleanup_failed_plugin_install( except Exception: logger.warning(traceback.format_exc()) - if os.path.exists(plugin_path): + if await asyncio.to_thread(os.path.exists, plugin_path): try: remove_dir(plugin_path) logger.warning(f"已清理安装失败的插件目录: {plugin_path}") @@ -968,7 +974,7 @@ async def _cleanup_failed_plugin_install( self.plugin_config_path, f"{dir_name}_config.json", ) - if os.path.exists(plugin_config_path): + if await asyncio.to_thread(os.path.exists, plugin_config_path): try: os.remove(plugin_config_path) logger.warning(f"已清理安装失败插件配置: {plugin_config_path}") @@ -1100,13 +1106,14 @@ async def install_plugin( # Extract README.md content if exists readme_content = None readme_path = os.path.join(plugin_path, "README.md") - if not os.path.exists(readme_path): + if not await asyncio.to_thread(os.path.exists, readme_path): readme_path = os.path.join(plugin_path, "readme.md") - if os.path.exists(readme_path): + if await asyncio.to_thread(os.path.exists, readme_path): try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() + readme_content = await asyncio.to_thread( + Path(readme_path).read_text, encoding="utf-8" + ) except Exception as e: logger.warning( f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}", @@ -1211,7 +1218,7 @@ async def uninstall_failed_plugin( self._cleanup_plugin_state(dir_name) plugin_path = os.path.join(self.plugin_store_path, dir_name) - if os.path.exists(plugin_path): + if await asyncio.to_thread(os.path.exists, plugin_path): try: remove_dir(plugin_path) except Exception as e: @@ -1498,13 +1505,14 @@ async def install_plugin_from_file( # Extract README.md content if exists readme_content = None readme_path = os.path.join(desti_dir, "README.md") - if not os.path.exists(readme_path): + if not await asyncio.to_thread(os.path.exists, readme_path): readme_path = os.path.join(desti_dir, "readme.md") - if os.path.exists(readme_path): + if await asyncio.to_thread(os.path.exists, readme_path): try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() + readme_content = await asyncio.to_thread( + Path(readme_path).read_text, encoding="utf-8" + ) except Exception as e: logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}") diff --git a/astrbot/core/utils/datetime_utils.py b/astrbot/core/utils/datetime_utils.py index 97b8196dde..431c9cd50c 100644 --- a/astrbot/core/utils/datetime_utils.py +++ b/astrbot/core/utils/datetime_utils.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime def normalize_datetime_utc(dt: datetime | None) -> datetime | None: @@ -9,8 +9,8 @@ def normalize_datetime_utc(dt: datetime | None) -> datetime | None: if dt is None: return None if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None: - return dt.replace(tzinfo=timezone.utc) - return dt.astimezone(timezone.utc) + return dt.replace(tzinfo=UTC) + return dt.astimezone(UTC) def to_utc_isoformat(dt: datetime | None) -> str | None: diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index b565926749..d169dd32b6 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -1,3 +1,4 @@ +import asyncio import base64 import logging import os @@ -58,8 +59,7 @@ def save_temp_img(img: Image.Image | bytes) -> str: if isinstance(img, Image.Image): img.save(p) else: - with open(p, "wb") as f: - f.write(img) + Path(p).write_bytes(img) return p @@ -83,15 +83,13 @@ async def download_image_by_url( async with session.post(url, json=post_data) as resp: if not path: return save_temp_img(await resp.read()) - with open(path, "wb") as f: - f.write(await resp.read()) + await asyncio.to_thread(Path(path).write_bytes, await resp.read()) return path else: async with session.get(url) as resp: if not path: return save_temp_img(await resp.read()) - with open(path, "wb") as f: - f.write(await resp.read()) + await asyncio.to_thread(Path(path).write_bytes, await resp.read()) return path except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): # 关闭SSL验证(仅在证书验证失败时作为fallback) @@ -109,15 +107,13 @@ async def download_image_by_url( async with session.post(url, json=post_data, ssl=ssl_context) as resp: if not path: return save_temp_img(await resp.read()) - with open(path, "wb") as f: - f.write(await resp.read()) + await asyncio.to_thread(Path(path).write_bytes, await resp.read()) return path else: async with session.get(url, ssl=ssl_context) as resp: if not path: return save_temp_img(await resp.read()) - with open(path, "wb") as f: - f.write(await resp.read()) + await asyncio.to_thread(Path(path).write_bytes, await resp.read()) return path except Exception as e: raise e @@ -142,12 +138,13 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - with open(path, "wb") as f: + file_obj = await asyncio.to_thread(Path(path).open, "wb") + try: while True: chunk = await resp.content.read(8192) if not chunk: break - f.write(chunk) + await asyncio.to_thread(file_obj.write, chunk) downloaded_size += len(chunk) if show_progress: elapsed_time = ( @@ -160,6 +157,8 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end="", ) + finally: + await asyncio.to_thread(file_obj.close) except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError): # 关闭SSL验证(仅在证书验证失败时作为fallback) logger.warning( @@ -181,12 +180,13 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non start_time = time.time() if show_progress: print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}") - with open(path, "wb") as f: + file_obj = await asyncio.to_thread(Path(path).open, "wb") + try: while True: chunk = await resp.content.read(8192) if not chunk: break - f.write(chunk) + await asyncio.to_thread(file_obj.write, chunk) downloaded_size += len(chunk) if show_progress: elapsed_time = time.time() - start_time @@ -195,14 +195,15 @@ async def download_file(url: str, path: str, show_progress: bool = False) -> Non f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end="", ) + finally: + await asyncio.to_thread(file_obj.close) if show_progress: print() -def file_to_base64(file_path: str) -> str: - with open(file_path, "rb") as f: - data_bytes = f.read() - base64_str = base64.b64encode(data_bytes).decode() +async def file_to_base64(file_path: str) -> str: + data_bytes = await asyncio.to_thread(Path(file_path).read_bytes) + base64_str = base64.b64encode(data_bytes).decode() return "base64://" + base64_str @@ -221,17 +222,18 @@ def get_local_ip_addresses(): async def get_dashboard_version(): # First check user data directory (manually updated / downloaded dashboard). dist_dir = os.path.join(get_astrbot_data_path(), "dist") - if not os.path.exists(dist_dir): + if not await asyncio.to_thread(os.path.exists, dist_dir): # Fall back to the dist bundled inside the installed wheel. _bundled = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist" - if _bundled.exists(): + if await asyncio.to_thread(_bundled.exists): dist_dir = str(_bundled) - if os.path.exists(dist_dir): + if await asyncio.to_thread(os.path.exists, dist_dir): version_file = os.path.join(dist_dir, "assets", "version") - if os.path.exists(version_file): - with open(version_file, encoding="utf-8") as f: - v = f.read().strip() - return v + if await asyncio.to_thread(os.path.exists, version_file): + v = ( + await asyncio.to_thread(Path(version_file).read_text, encoding="utf-8") + ).strip() + return v return None @@ -244,9 +246,12 @@ async def download_dashboard( ) -> None: """下载管理面板文件""" if path is None: - zip_path = Path(get_astrbot_data_path()).absolute() / "dashboard.zip" + zip_path = ( + await asyncio.to_thread(Path(get_astrbot_data_path()).absolute) + / "dashboard.zip" + ) else: - zip_path = Path(path).absolute() + zip_path = await asyncio.to_thread(Path(path).absolute) if latest or len(str(version)) != 40: ver_name = "latest" if latest else version diff --git a/astrbot/core/utils/media_utils.py b/astrbot/core/utils/media_utils.py index 8d833514fb..7ecebcad43 100644 --- a/astrbot/core/utils/media_utils.py +++ b/astrbot/core/utils/media_utils.py @@ -108,7 +108,7 @@ async def convert_audio_to_opus(audio_path: str, output_path: str | None = None) if process.returncode != 0: # 清理可能已生成但无效的临时文件 - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(output_path) logger.debug( @@ -183,7 +183,7 @@ async def convert_video_format( if process.returncode != 0: # 清理可能已生成但无效的临时文件 - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(output_path) logger.debug( @@ -231,7 +231,7 @@ async def convert_audio_format( if output_path is None: temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) output_path = str(temp_dir / f"media_audio_{uuid.uuid4().hex}.{output_format}") args = ["ffmpeg", "-y", "-i", audio_path] @@ -249,7 +249,7 @@ async def convert_audio_format( ) _, stderr = await process.communicate() if process.returncode != 0: - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(output_path) except OSError as e: @@ -287,7 +287,7 @@ async def extract_video_cover( """从视频中提取封面图(JPG)。""" if output_path is None: temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(temp_dir.mkdir, parents=True, exist_ok=True) output_path = str(temp_dir / f"media_cover_{uuid.uuid4().hex}.jpg") try: @@ -306,7 +306,7 @@ async def extract_video_cover( ) _, stderr = await process.communicate() if process.returncode != 0: - if output_path and os.path.exists(output_path): + if output_path and await asyncio.to_thread(os.path.exists, output_path): try: os.remove(output_path) except OSError as e: diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index b327a61843..a6c62c4952 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -71,11 +71,11 @@ def keep(self, timeout: float = 0, reset_timeout=False) -> None: asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep - async def _holding(self, event: asyncio.Event, timeout: float) -> None: + async def _holding(self, event: asyncio.Event, timeout_seconds: float) -> None: """等待事件结束或超时""" try: - await asyncio.wait_for(event.wait(), timeout) - except asyncio.TimeoutError: + await asyncio.wait_for(event.wait(), timeout_seconds) + except TimeoutError: if not self.future.done(): self.future.set_exception(TimeoutError("等待超时")) except asyncio.CancelledError: @@ -124,14 +124,14 @@ def __init__( async def register_wait( self, handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]], - timeout: int = 30, + timeout_seconds: int = 30, ) -> Any: """等待外部输入并处理""" self.handler = handler USER_SESSIONS[self.session_id] = self # 开始一个会话保持事件 - self.session_controller.keep(timeout, reset_timeout=True) + self.session_controller.keep(timeout_seconds, reset_timeout=True) try: return await self.session_controller.future diff --git a/astrbot/core/utils/temp_dir_cleaner.py b/astrbot/core/utils/temp_dir_cleaner.py index c0c0600982..668ee45135 100644 --- a/astrbot/core/utils/temp_dir_cleaner.py +++ b/astrbot/core/utils/temp_dir_cleaner.py @@ -141,7 +141,7 @@ async def run(self) -> None: self._stop_event.wait(), timeout=self.CHECK_INTERVAL_SECONDS, ) - except asyncio.TimeoutError: + except TimeoutError: continue logger.info("TempDirCleaner stopped.") diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py index f342484bdb..1abd6d1c0a 100644 --- a/astrbot/core/utils/tencent_record_helper.py +++ b/astrbot/core/utils/tencent_record_helper.py @@ -5,6 +5,7 @@ import tempfile import wave from io import BytesIO +from pathlib import Path from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_temp_path @@ -13,19 +14,18 @@ async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str: import pysilk - with open(silk_path, "rb") as f: - input_data = f.read() - if input_data.startswith(b"\x02"): - input_data = input_data[1:] - input_io = BytesIO(input_data) - output_io = BytesIO() - pysilk.decode(input_io, output_io, 24000) - output_io.seek(0) - with wave.open(output_path, "wb") as wav: - wav.setnchannels(1) - wav.setsampwidth(2) - wav.setframerate(24000) - wav.writeframes(output_io.read()) + input_data = await asyncio.to_thread(Path(silk_path).read_bytes) + if input_data.startswith(b"\x02"): + input_data = input_data[1:] + input_io = BytesIO(input_data) + output_io = BytesIO() + pysilk.decode(input_io, output_io, 24000) + output_io.seek(0) + with wave.open(output_path, "wb") as wav: + wav.setnchannels(1) + wav.setsampwidth(2) + wav.setframerate(24000) + wav.writeframes(output_io.read()) return output_path @@ -97,7 +97,10 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str: logger.debug(f"[FFmpeg] stderr: {stderr.decode().strip()}") logger.info(f"[FFmpeg] return code: {p.returncode}") - if os.path.exists(output_path) and os.path.getsize(output_path) > 0: + if ( + await asyncio.to_thread(os.path.exists, output_path) + and await asyncio.to_thread(os.path.getsize, output_path) > 0 + ): return output_path raise RuntimeError("生成的WAV文件不存在或为空") @@ -156,13 +159,12 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]: tencent=True, ) - with open(silk_path, "rb") as f: - silk_bytes = await asyncio.to_thread(f.read) - silk_b64 = base64.b64encode(silk_bytes).decode("utf-8") + silk_bytes = await asyncio.to_thread(Path(silk_path).read_bytes) + silk_b64 = base64.b64encode(silk_bytes).decode("utf-8") return silk_b64, duration # 已是秒 finally: - if os.path.exists(wav_path) and wav_path != audio_path: + if await asyncio.to_thread(os.path.exists, wav_path) and wav_path != audio_path: os.remove(wav_path) - if os.path.exists(silk_path): + if await asyncio.to_thread(os.path.exists, silk_path): os.remove(silk_path) diff --git a/astrbot/dashboard/routes/api_key.py b/astrbot/dashboard/routes/api_key.py index 4b957fe8ea..6d89de910c 100644 --- a/astrbot/dashboard/routes/api_key.py +++ b/astrbot/dashboard/routes/api_key.py @@ -1,6 +1,6 @@ import hashlib import secrets -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from quart import g, request @@ -59,7 +59,7 @@ def _serialize_api_key(key) -> dict: "expires_at": ApiKeyRoute._serialize_datetime(key.expires_at), "revoked_at": ApiKeyRoute._serialize_datetime(key.revoked_at), "is_revoked": key.revoked_at is not None, - "is_expired": bool(expires_at and expires_at < datetime.now(timezone.utc)), + "is_expired": bool(expires_at and expires_at < datetime.now(UTC)), } async def list_api_keys(self): @@ -98,9 +98,7 @@ async def create_api_key(self): return ( Response().error("expires_in_days must be greater than 0").__dict__ ) - expires_at = datetime.now(timezone.utc) + timedelta( - days=expires_in_days_int - ) + expires_at = datetime.now(UTC) + timedelta(days=expires_in_days_int) raw_key = f"abk_{secrets.token_urlsafe(32)}" key_hash = self._hash_key(raw_key) diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index 40db1f60bd..f9bdc51d8f 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -82,7 +82,7 @@ async def edit_account(self): def generate_jwt(self, username): payload = { "username": username, - "exp": datetime.datetime.utcnow() + datetime.timedelta(days=7), + "exp": datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=7), } jwt_token = self.config["dashboard"].get("jwt_secret", None) if not jwt_token: diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index 952806beb7..674bbbfdda 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -32,6 +32,18 @@ UPLOAD_EXPIRE_SECONDS = 3600 # 上传会话过期时间(1小时) +def _merge_backup_chunks(output_path: str, chunk_dir: str, total: int) -> None: + with open(output_path, "wb") as outfile: + for i in range(total): + chunk_path = os.path.join(chunk_dir, f"{i}.part") + with open(chunk_path, "rb") as chunk_file: + while True: + data_block = chunk_file.read(8192) + if not data_block: + break + outfile.write(data_block) + + def secure_filename(filename: str) -> str: """清洗文件名,移除路径遍历字符和危险字符 @@ -240,7 +252,7 @@ async def _cleanup_upload_session(self, upload_id: str) -> None: if upload_id in self.upload_sessions: session = self.upload_sessions[upload_id] chunk_dir = session.get("chunk_dir") - if chunk_dir and os.path.exists(chunk_dir): + if chunk_dir and await asyncio.to_thread(os.path.exists, chunk_dir): try: shutil.rmtree(chunk_dir) except Exception as e: @@ -283,7 +295,9 @@ async def list_backups(self): page_size = request.args.get("page_size", 20, type=int) # 确保备份目录存在 - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread( + Path(self.backup_dir).mkdir, parents=True, exist_ok=True + ) # 获取所有备份文件 backup_files = [] @@ -293,7 +307,7 @@ async def list_backups(self): continue file_path = os.path.join(self.backup_dir, filename) - if not os.path.isfile(file_path): + if not await asyncio.to_thread(os.path.isfile, file_path): continue # 读取 manifest.json 获取备份信息 @@ -403,7 +417,7 @@ async def _background_export_task(self, task_id: str) -> None: result={ "filename": os.path.basename(zip_path), "path": zip_path, - "size": os.path.getsize(zip_path), + "size": await asyncio.to_thread(os.path.getsize, zip_path), }, ) except Exception as e: @@ -437,7 +451,9 @@ async def upload_backup(self): unique_filename = generate_unique_filename(safe_filename) # 保存上传的文件 - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread( + Path(self.backup_dir).mkdir, parents=True, exist_ok=True + ) zip_path = os.path.join(self.backup_dir, unique_filename) await file.save(zip_path) @@ -451,7 +467,7 @@ async def upload_backup(self): { "filename": unique_filename, "original_filename": file.filename, - "size": os.path.getsize(zip_path), + "size": await asyncio.to_thread(os.path.getsize, zip_path), } ) .__dict__ @@ -499,7 +515,7 @@ async def upload_init(self): # 创建分片存储目录 chunk_dir = os.path.join(self.chunks_dir, upload_id) - Path(chunk_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(Path(chunk_dir).mkdir, parents=True, exist_ok=True) # 清洗文件名 safe_filename = secure_filename(filename) @@ -685,22 +701,20 @@ async def upload_complete(self): chunk_dir = session["chunk_dir"] filename = session["filename"] - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + await asyncio.to_thread( + Path(self.backup_dir).mkdir, parents=True, exist_ok=True + ) output_path = os.path.join(self.backup_dir, filename) try: - with open(output_path, "wb") as outfile: - for i in range(total): - chunk_path = os.path.join(chunk_dir, f"{i}.part") - with open(chunk_path, "rb") as chunk_file: - # 分块读取,避免内存溢出 - while True: - data_block = chunk_file.read(8192) - if not data_block: - break - outfile.write(data_block) - - file_size = os.path.getsize(output_path) + await asyncio.to_thread( + _merge_backup_chunks, + output_path, + chunk_dir, + total, + ) + + file_size = await asyncio.to_thread(os.path.getsize, output_path) # 标记备份为上传来源(修改 manifest.json 中的 origin 字段) self._mark_backup_as_uploaded(output_path) @@ -725,7 +739,7 @@ async def upload_complete(self): ) except Exception as e: # 如果合并失败,删除不完整的文件 - if os.path.exists(output_path): + if await asyncio.to_thread(os.path.exists, output_path): os.remove(output_path) raise e @@ -787,7 +801,7 @@ async def check_backup(self): return Response().error("无效的文件名").__dict__ zip_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(zip_path): + if not await asyncio.to_thread(os.path.exists, zip_path): return Response().error(f"备份文件不存在: {filename}").__dict__ # 获取知识库管理器(用于构造 importer) @@ -841,7 +855,7 @@ async def import_backup(self): return Response().error("无效的文件名").__dict__ zip_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(zip_path): + if not await asyncio.to_thread(os.path.exists, zip_path): return Response().error(f"备份文件不存在: {filename}").__dict__ # 生成任务ID @@ -988,7 +1002,7 @@ async def download_backup(self): return Response().error("无效的文件名").__dict__ file_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(file_path): + if not await asyncio.to_thread(os.path.exists, file_path): return Response().error("备份文件不存在").__dict__ return await send_file( @@ -1019,7 +1033,7 @@ async def delete_backup(self): return Response().error("无效的文件名").__dict__ file_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(file_path): + if not await asyncio.to_thread(os.path.exists, file_path): return Response().error("备份文件不存在").__dict__ os.remove(file_path) @@ -1067,12 +1081,12 @@ async def rename_backup(self): # 检查原文件是否存在 old_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(old_path): + if not await asyncio.to_thread(os.path.exists, old_path): return Response().error("备份文件不存在").__dict__ # 检查新文件名是否已存在 new_path = os.path.join(self.backup_dir, new_filename) - if os.path.exists(new_path): + if await asyncio.to_thread(os.path.exists, new_path): return Response().error(f"文件名 '{new_filename}' 已存在").__dict__ # 执行重命名 diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index a914f3cbf0..f76aa7f9f1 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -80,17 +80,23 @@ async def get_file(self): try: file_path = os.path.join(self.attachments_dir, os.path.basename(filename)) - real_file_path = os.path.realpath(file_path) - real_imgs_dir = os.path.realpath(self.attachments_dir) + real_file_path = await asyncio.to_thread(os.path.realpath, file_path) + real_imgs_dir = await asyncio.to_thread( + os.path.realpath, self.attachments_dir + ) - if not os.path.exists(real_file_path): + if not await asyncio.to_thread(os.path.exists, real_file_path): # try legacy file_path = os.path.join( self.legacy_img_dir, os.path.basename(filename) ) - if os.path.exists(file_path): - real_file_path = os.path.realpath(file_path) - real_imgs_dir = os.path.realpath(self.legacy_img_dir) + if await asyncio.to_thread(os.path.exists, file_path): + real_file_path = await asyncio.to_thread( + os.path.realpath, file_path + ) + real_imgs_dir = await asyncio.to_thread( + os.path.realpath, self.legacy_img_dir + ) if not real_file_path.startswith(real_imgs_dir): return Response().error("Invalid file path").__dict__ @@ -117,7 +123,7 @@ async def get_attachment(self): return Response().error("Attachment not found").__dict__ file_path = attachment.path - real_file_path = os.path.realpath(file_path) + real_file_path = await asyncio.to_thread(os.path.realpath, file_path) return await send_file(real_file_path, mimetype=attachment.mime_type) @@ -344,7 +350,7 @@ async def stream(): while True: try: result = await asyncio.wait_for(back_queue.get(), timeout=1) - except asyncio.TimeoutError: + except TimeoutError: continue except asyncio.CancelledError: logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") @@ -652,7 +658,7 @@ async def _delete_attachments(self, attachment_ids: list[str]) -> None: try: attachments = await self.db.get_attachments(attachment_ids) for attachment in attachments: - if not os.path.exists(attachment.path): + if not await asyncio.to_thread(os.path.exists, attachment.path): continue try: os.remove(attachment.path) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 823d0fb9dd..1ed80a218d 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1103,7 +1103,9 @@ async def upload_config_file(self): if not files: return Response().error("No files uploaded").__dict__ - storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) + storage_root_path = await asyncio.to_thread( + Path(get_astrbot_plugin_data_path()).resolve, strict=False + ) plugin_root_path = (storage_root_path / name).resolve(strict=False) try: plugin_root_path.relative_to(storage_root_path) @@ -1179,7 +1181,9 @@ async def delete_config_file(self): if not md: return Response().error(f"Plugin {name} not found").__dict__ - storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) + storage_root_path = await asyncio.to_thread( + Path(get_astrbot_plugin_data_path()).resolve, strict=False + ) plugin_root_path = (storage_root_path / name).resolve(strict=False) try: plugin_root_path.relative_to(storage_root_path) @@ -1207,7 +1211,9 @@ async def get_config_file_list(self): if not meta or meta.get("type") != "file": return Response().error("Config item not found or not file type").__dict__ - storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) + storage_root_path = await asyncio.to_thread( + Path(get_astrbot_plugin_data_path()).resolve, strict=False + ) plugin_root_path = (storage_root_path / name).resolve(strict=False) try: plugin_root_path.relative_to(storage_root_path) @@ -1375,7 +1381,7 @@ async def _register_platform_logo(self, platform, platform_default_tmpl) -> None logo_file_path = os.path.join(plugin_dir, platform.logo_path) # 检查文件是否存在并注册令牌 - if os.path.exists(logo_file_path): + if await asyncio.to_thread(os.path.exists, logo_file_path): logo_token = await file_token_service.register_file( logo_file_path, timeout=3600, diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index f0ac5d43d0..d06414c016 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -729,7 +729,7 @@ async def upload_document(self): ) finally: # 清理临时文件 - if os.path.exists(temp_file_path): + if await asyncio.to_thread(os.path.exists, temp_file_path): os.remove(temp_file_path) # 获取知识库 diff --git a/astrbot/dashboard/routes/live_chat.py b/astrbot/dashboard/routes/live_chat.py index 8d0af938d0..58398d24c7 100644 --- a/astrbot/dashboard/routes/live_chat.py +++ b/astrbot/dashboard/routes/live_chat.py @@ -86,7 +86,7 @@ async def end_speaking(self, stamp: str) -> tuple[str | None, float]: self.temp_audio_path = audio_path logger.info( - f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {os.path.getsize(audio_path)} bytes" + f"[Live Chat] 音频文件已保存: {audio_path}, 大小: {await asyncio.to_thread(os.path.getsize, audio_path)} bytes" ) return audio_path, time.time() - start_time @@ -491,7 +491,7 @@ async def _handle_chat_message( try: result = await asyncio.wait_for(back_queue.get(), timeout=1) - except asyncio.TimeoutError: + except TimeoutError: continue if not result: @@ -790,7 +790,7 @@ async def _process_audio( try: result = await asyncio.wait_for(back_queue.get(), timeout=0.5) - except asyncio.TimeoutError: + except TimeoutError: continue if not result: diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index 9a736b1763..763d05db03 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -369,7 +369,7 @@ async def _handle_chat_ws_send(self, post_data: dict) -> None: while True: try: result = await asyncio.wait_for(back_queue.get(), timeout=1) - except asyncio.TimeoutError: + except TimeoutError: continue if not result: diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index bb7769926a..f3d1d69eed 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -6,6 +6,7 @@ import traceback from dataclasses import dataclass from datetime import datetime +from pathlib import Path import aiohttp import certifi @@ -738,19 +739,20 @@ async def get_plugin_readme(self): plugin_obj.root_dir_name, ) - if not os.path.isdir(plugin_dir): + if not await asyncio.to_thread(os.path.isdir, plugin_dir): logger.warning(f"无法找到插件目录: {plugin_dir}") return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ readme_path = os.path.join(plugin_dir, "README.md") - if not os.path.isfile(readme_path): + if not await asyncio.to_thread(os.path.isfile, readme_path): logger.warning(f"插件 {plugin_name} 没有README文件") return Response().error(f"插件 {plugin_name} 没有README文件").__dict__ try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() + readme_content = await asyncio.to_thread( + Path(readme_path).read_text, encoding="utf-8" + ) return ( Response() @@ -799,7 +801,7 @@ async def get_plugin_changelog(self): plugin_obj.root_dir_name, ) - if not os.path.isdir(plugin_dir): + if not await asyncio.to_thread(os.path.isdir, plugin_dir): logger.warning(f"无法找到插件目录: {plugin_dir}") return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ @@ -807,10 +809,11 @@ async def get_plugin_changelog(self): changelog_names = ["CHANGELOG.md", "changelog.md", "CHANGELOG", "changelog"] for name in changelog_names: changelog_path = os.path.join(plugin_dir, name) - if os.path.isfile(changelog_path): + if await asyncio.to_thread(os.path.isfile, changelog_path): try: - with open(changelog_path, encoding="utf-8") as f: - changelog_content = f.read() + changelog_content = await asyncio.to_thread( + Path(changelog_path).read_text, encoding="utf-8" + ) return ( Response() .ok({"content": changelog_content}, "成功获取更新日志") diff --git a/astrbot/dashboard/routes/skills.py b/astrbot/dashboard/routes/skills.py index adad49615f..b003e2010d 100644 --- a/astrbot/dashboard/routes/skills.py +++ b/astrbot/dashboard/routes/skills.py @@ -1,3 +1,4 @@ +import asyncio import os import re import shutil @@ -182,7 +183,7 @@ async def upload_skill(self): logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ finally: - if temp_path and os.path.exists(temp_path): + if temp_path and await asyncio.to_thread(os.path.exists, temp_path): try: os.remove(temp_path) except Exception: diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 532238ac7a..238b6aa4cc 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -1,3 +1,4 @@ +import asyncio import os import re import threading @@ -214,13 +215,17 @@ async def get_changelog(self): changelog_path = os.path.join(changelogs_dir, filename) # 规范化路径,防止符号链接攻击 - changelog_path = os.path.realpath(changelog_path) - changelogs_dir = os.path.realpath(changelogs_dir) + changelog_path = await asyncio.to_thread(os.path.realpath, changelog_path) + changelogs_dir = await asyncio.to_thread(os.path.realpath, changelogs_dir) # 验证最终路径在预期的 changelogs 目录内(防止路径遍历) # 确保规范化后的路径以 changelogs_dir 开头,且是目录内的文件 - changelog_path_normalized = os.path.normpath(changelog_path) - changelogs_dir_normalized = os.path.normpath(changelogs_dir) + changelog_path_normalized = await asyncio.to_thread( + os.path.normpath, changelog_path + ) + changelogs_dir_normalized = await asyncio.to_thread( + os.path.normpath, changelogs_dir + ) # 检查路径是否在预期目录内(必须是目录的子文件,不能是目录本身) expected_prefix = changelogs_dir_normalized + os.sep @@ -230,21 +235,22 @@ async def get_changelog(self): ) return Response().error("Invalid version format").__dict__ - if not os.path.exists(changelog_path): + if not await asyncio.to_thread(os.path.exists, changelog_path): return ( Response() .error(f"Changelog for version {version} not found") .__dict__ ) - if not os.path.isfile(changelog_path): + if not await asyncio.to_thread(os.path.isfile, changelog_path): return ( Response() .error(f"Changelog for version {version} not found") .__dict__ ) - with open(changelog_path, encoding="utf-8") as f: - content = f.read() + content = await asyncio.to_thread( + Path(changelog_path).read_text, encoding="utf-8" + ) return Response().ok({"content": content, "version": version}).__dict__ except Exception as e: @@ -257,7 +263,7 @@ async def list_changelog_versions(self): project_path = get_astrbot_path() changelogs_dir = os.path.join(project_path, "changelogs") - if not os.path.exists(changelogs_dir): + if not await asyncio.to_thread(os.path.exists, changelogs_dir): return Response().ok({"versions": []}).__dict__ versions = [] diff --git a/main.py b/main.py index 36c46fca33..b8c42d78ea 100644 --- a/main.py +++ b/main.py @@ -69,13 +69,13 @@ async def check_dashboard_files(webui_dir: str | None = None): """下载管理面板文件""" # 指定webui目录 if webui_dir: - if os.path.exists(webui_dir): + if await asyncio.to_thread(os.path.exists, webui_dir): logger.info(f"使用指定的 WebUI 目录: {webui_dir}") return webui_dir logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。") data_dist_path = os.path.join(get_astrbot_data_path(), "dist") - if os.path.exists(data_dist_path): + if await asyncio.to_thread(os.path.exists, data_dist_path): v = await get_dashboard_version() if v is not None: # 存在文件 diff --git a/pyproject.toml b/pyproject.toml index e57a0216c3..1f91891975 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ astrbot = "astrbot.cli.__main__:cli" [tool.ruff] exclude = ["astrbot/core/utils/t2i/local_strategy.py", "astrbot/api/all.py", "tests"] line-length = 88 -target-version = "py310" +target-version = "py312" [tool.ruff.lint] select = [ @@ -99,13 +99,11 @@ ignore = [ "F403", "F405", "E501", - "ASYNC230", # TODO: handle ASYNC230 in AstrBot - "ASYNC240", # TODO: handle ASYNC240 in AstrBot ] [tool.pyright] typeCheckingMode = "basic" -pythonVersion = "3.10" +pythonVersion = "3.12" reportMissingTypeStubs = false reportMissingImports = false include = ["astrbot"] diff --git a/tests/test_skill_manager_sandbox_cache.py b/tests/test_skill_manager_sandbox_cache.py index 88923ec10b..5707148c6d 100644 --- a/tests/test_skill_manager_sandbox_cache.py +++ b/tests/test_skill_manager_sandbox_cache.py @@ -2,6 +2,8 @@ from pathlib import Path +import pytest + from astrbot.core.skills.skill_manager import SkillManager @@ -56,7 +58,7 @@ def test_list_skills_merges_local_and_sandbox_cache(monkeypatch, tmp_path: Path) assert by_name["custom-local"].description == "local description" assert by_name["custom-local"].path == "skills/custom-local/SKILL.md" assert by_name["python-sandbox"].description == "ship built-in" - assert by_name["python-sandbox"].path == "skills/python-sandbox/SKILL.md" + assert by_name["python-sandbox"].path == "/workspace/skills/python-sandbox/SKILL.md" def test_sandbox_cached_skill_respects_active_and_display_path( @@ -98,7 +100,8 @@ def test_sandbox_cached_skill_respects_active_and_display_path( assert len(all_skills) == 1 assert all_skills[0].path == "/app/skills/browser-automation/SKILL.md" - mgr.set_skill_active("browser-automation", False) + with pytest.raises(PermissionError): + mgr.set_skill_active("browser-automation", False) active_skills = mgr.list_skills(runtime="sandbox", active_only=True) - assert active_skills == [] - + assert len(active_skills) == 1 + assert active_skills[0].name == "browser-automation" diff --git a/tests/unit/test_io_file_to_base64.py b/tests/unit/test_io_file_to_base64.py new file mode 100644 index 0000000000..b490ffed86 --- /dev/null +++ b/tests/unit/test_io_file_to_base64.py @@ -0,0 +1,16 @@ +import base64 + +import pytest + +from astrbot.core.utils.io import file_to_base64 + + +@pytest.mark.asyncio +async def test_file_to_base64_reads_file_async(tmp_path): + sample_file = tmp_path / "sample.bin" + sample_file.write_bytes(b"astrbot") + + result = await file_to_base64(str(sample_file)) + + expected = "base64://" + base64.b64encode(b"astrbot").decode() + assert result == expected