diff --git a/astrbot/api/all.py b/astrbot/api/all.py index df3e1170fb..fe226b5afc 100644 --- a/astrbot/api/all.py +++ b/astrbot/api/all.py @@ -10,7 +10,6 @@ CommandResult, EventResultType, ) -from astrbot.core.platform import AstrMessageEvent # star register from astrbot.core.star.register import ( @@ -31,8 +30,9 @@ from astrbot.core.star.register import ( register_star as register, # 注册插件(Star) ) -from astrbot.core.star import Context, Star +from astrbot.core.star.base import Star from astrbot.core.star.config import * +from astrbot.core.star.context import Context # provider diff --git a/astrbot/api/star/__init__.py b/astrbot/api/star/__init__.py index 63db07a727..914e2ab301 100644 --- a/astrbot/api/star/__init__.py +++ b/astrbot/api/star/__init__.py @@ -1,7 +1,9 @@ -from astrbot.core.star import Context, Star, StarTools +from astrbot.core.star.base import Star from astrbot.core.star.config import * +from astrbot.core.star.context import Context from astrbot.core.star.register import ( register_star as register, # 注册插件(Star) ) +from astrbot.core.star.star_tools import StarTools __all__ = ["Context", "Star", "StarTools", "register"] diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 0f51a29c05..8dec64a859 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -846,6 +846,7 @@ def _apply_sandbox_tools( ) -> None: if req.func_tool is None: req.func_tool = ToolSet() + req.system_prompt = req.system_prompt or "" booter = config.sandbox_cfg.get("booter", "shipyard_neo") if booter == "shipyard": ep = config.sandbox_cfg.get("shipyard_endpoint", "") diff --git a/astrbot/core/cron/__init__.py b/astrbot/core/cron/__init__.py index b685075411..94a0771ff9 100644 --- a/astrbot/core/cron/__init__.py +++ b/astrbot/core/cron/__init__.py @@ -1,3 +1,22 @@ -from .manager import CronJobManager +"""Cron package exports. + +Keep `CronJobManager` import-compatible while avoiding hard import failure when +`apscheduler` is partially mocked in test environments. +""" + +try: + from .manager import CronJobManager +except ModuleNotFoundError as exc: + if not (exc.name and exc.name.startswith("apscheduler")): + raise + + _IMPORT_ERROR = exc + + class CronJobManager: + def __init__(self, *args, **kwargs) -> None: + raise ModuleNotFoundError( + "CronJobManager requires a complete `apscheduler` installation." + ) from _IMPORT_ERROR + __all__ = ["CronJobManager"] diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index cbaec66c7c..7c8b871537 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -12,16 +12,24 @@ from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.star.star import star_registry from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.core.zip_updator import ReleaseInfo from astrbot.dashboard.server import AstrBotDashboard from tests.fixtures.helpers import ( MockPluginBuilder, - MockPluginConfig, create_mock_updater_install, create_mock_updater_update, ) +RUN_ONLINE_UPDATE_CHECK = os.environ.get( + "ASTRBOT_RUN_ONLINE_UPDATE_CHECK", "" +).lower() in { + "1", + "true", + "yes", +} -@pytest_asyncio.fixture(scope="module") + +@pytest_asyncio.fixture(scope="module", loop_scope="module") async def core_lifecycle_td(tmp_path_factory): """Creates and initializes a core lifecycle instance with a temporary database.""" tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db" @@ -51,7 +59,7 @@ def app(core_lifecycle_td: AstrBotCoreLifecycle): return server.app -@pytest_asyncio.fixture(scope="module") +@pytest_asyncio.fixture(scope="module", loop_scope="module") async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): """Handles login and returns an authenticated header.""" test_client = app.test_client() @@ -108,16 +116,14 @@ async def test_plugins( core_lifecycle_td: AstrBotCoreLifecycle, monkeypatch, ): - """测试插件 API 端点,使用 Mock 避免真实网络调用。""" + """Test plugin APIs with mocked updater behavior.""" test_client = app.test_client() - # 已经安装的插件 response = await test_client.get("/api/plugin/get", headers=authenticated_header) assert response.status_code == 200 data = await response.get_json() assert data["status"] == "ok" - # 插件市场 response = await test_client.get( "/api/plugin/market_list", headers=authenticated_header, @@ -126,31 +132,24 @@ async def test_plugins( data = await response.get_json() assert data["status"] == "ok" - # 使用 MockPluginBuilder 创建测试插件 plugin_store_path = core_lifecycle_td.plugin_manager.plugin_store_path builder = MockPluginBuilder(plugin_store_path) - # 定义测试插件 test_plugin_name = "test_mock_plugin" test_repo_url = f"https://github.com/test/{test_plugin_name}" - # 创建 Mock 函数 mock_install = create_mock_updater_install( builder, repo_to_plugin={test_repo_url: test_plugin_name}, ) mock_update = create_mock_updater_update(builder) - # 设置 Mock monkeypatch.setattr( core_lifecycle_td.plugin_manager.updator, "install", mock_install ) - monkeypatch.setattr( - core_lifecycle_td.plugin_manager.updator, "update", mock_update - ) + monkeypatch.setattr(core_lifecycle_td.plugin_manager.updator, "update", mock_update) try: - # 插件安装 response = await test_client.post( "/api/plugin/install", json={"url": test_repo_url}, @@ -158,13 +157,13 @@ async def test_plugins( ) assert response.status_code == 200 data = await response.get_json() - assert data["status"] == "ok", f"安装失败: {data.get('message', 'unknown error')}" + assert data["status"] == "ok", ( + f"install failed: {data.get('message', 'unknown error')}" + ) - # 验证插件已注册 exists = any(md.name == test_plugin_name for md in star_registry) - assert exists is True, f"插件 {test_plugin_name} 未成功载入" + assert exists is True, f"plugin {test_plugin_name} was not loaded" - # 插件更新 response = await test_client.post( "/api/plugin/update", json={"name": test_plugin_name}, @@ -174,11 +173,9 @@ async def test_plugins( data = await response.get_json() assert data["status"] == "ok" - # 验证更新标记文件 plugin_dir = builder.get_plugin_path(test_plugin_name) assert (plugin_dir / ".updated").exists() - # 插件卸载 response = await test_client.post( "/api/plugin/uninstall", json={"name": test_plugin_name}, @@ -188,16 +185,14 @@ async def test_plugins( data = await response.get_json() assert data["status"] == "ok" - # 验证插件已卸载 exists = any(md.name == test_plugin_name for md in star_registry) - assert exists is False, f"插件 {test_plugin_name} 未成功卸载" + assert exists is False, f"plugin {test_plugin_name} was not unloaded" exists = any( test_plugin_name in md.handler_module_path for md in star_handlers_registry ) - assert exists is False, f"插件 {test_plugin_name} handler 未成功清理" + assert exists is False, f"plugin {test_plugin_name} handlers were not cleaned" finally: - # 清理测试插件 builder.cleanup(test_plugin_name) @@ -230,41 +225,141 @@ async def test_commands_api(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_check_update( +async def test_check_update_success_no_new_version( app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle, monkeypatch, ): - """测试检查更新 API,使用 Mock 避免真实网络调用。""" + async def mock_get_dashboard_version(): + return "v-test-dashboard" + + async def mock_check_update(*args, **kwargs): # noqa: ARG001 + return None + + monkeypatch.setattr( + "astrbot.dashboard.routes.update.get_dashboard_version", + mock_get_dashboard_version, + ) + monkeypatch.setattr( + core_lifecycle_td.astrbot_updator, + "check_update", + mock_check_update, + ) test_client = app.test_client() - # Mock 更新检查和网络请求 - async def mock_check_update(*args, **kwargs): - """Mock 更新检查,返回无新版本。""" - return None # None 表示没有新版本 + response = await test_client.get("/api/update/check", headers=authenticated_header) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "success" + assert { + "version", + "has_new_version", + "dashboard_version", + "dashboard_has_new_version", + }.issubset(data["data"]) + assert data["data"]["has_new_version"] is False + assert data["data"]["dashboard_version"] == "v-test-dashboard" - async def mock_get_dashboard_version(*args, **kwargs): - """Mock Dashboard 版本获取。""" - from astrbot.core.config.default import VERSION - return f"v{VERSION}" # 返回当前版本 +@pytest.mark.asyncio +async def test_check_update_success_has_new_version( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + async def mock_get_dashboard_version(): + return "v-test-dashboard" + + async def mock_check_update(*args, **kwargs): # noqa: ARG001 + return ReleaseInfo( + version="v999.0.0", + published_at="2026-01-01", + body="test release", + ) + monkeypatch.setattr( + "astrbot.dashboard.routes.update.get_dashboard_version", + mock_get_dashboard_version, + ) monkeypatch.setattr( core_lifecycle_td.astrbot_updator, "check_update", mock_check_update, ) + + test_client = app.test_client() + response = await test_client.get("/api/update/check", headers=authenticated_header) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "success" + assert { + "version", + "has_new_version", + "dashboard_version", + "dashboard_has_new_version", + }.issubset(data["data"]) + assert data["data"]["has_new_version"] is True + assert data["data"]["dashboard_version"] == "v-test-dashboard" + + +@pytest.mark.asyncio +async def test_check_update_error_when_updator_raises( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + async def mock_get_dashboard_version(): + return "v-test-dashboard" + + async def mock_check_update(*args, **kwargs): # noqa: ARG001 + raise RuntimeError("mock update check failure") + monkeypatch.setattr( "astrbot.dashboard.routes.update.get_dashboard_version", mock_get_dashboard_version, ) + monkeypatch.setattr( + core_lifecycle_td.astrbot_updator, + "check_update", + mock_check_update, + ) + test_client = app.test_client() response = await test_client.get("/api/update/check", headers=authenticated_header) assert response.status_code == 200 data = await response.get_json() - assert data["status"] == "success" - assert data["data"]["has_new_version"] is False + assert data["status"] == "error" + assert isinstance(data["message"], str) + assert data["message"] + + +@pytest.mark.asyncio +@pytest.mark.integration +@pytest.mark.slow +@pytest.mark.skipif( + not RUN_ONLINE_UPDATE_CHECK, + reason="Set ASTRBOT_RUN_ONLINE_UPDATE_CHECK=1 to run online update check test.", +) +async def test_check_update_online_optional(app: Quart, authenticated_header: dict): + """Optional online smoke test for the real update-check request path.""" + test_client = app.test_client() + response = await test_client.get("/api/update/check", headers=authenticated_header) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] in {"success", "error"} + assert "message" in data + assert "data" in data + + if data["status"] == "success": + assert { + "version", + "has_new_version", + "dashboard_version", + "dashboard_has_new_version", + }.issubset(data["data"]) @pytest.mark.asyncio diff --git a/tests/test_kb_import.py b/tests/test_kb_import.py index 8ad40f5406..9e5e5995bb 100644 --- a/tests/test_kb_import.py +++ b/tests/test_kb_import.py @@ -13,7 +13,7 @@ from astrbot.dashboard.server import AstrBotDashboard -@pytest_asyncio.fixture(scope="module") +@pytest_asyncio.fixture(scope="module", loop_scope="module") async def core_lifecycle_td(tmp_path_factory): """Creates and initializes a core lifecycle instance with a temporary database.""" tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_kb.db" @@ -24,7 +24,8 @@ async def core_lifecycle_td(tmp_path_factory): # Mock kb_manager and kb_helper kb_manager = MagicMock() - kb_helper = AsyncMock(spec=KBHelper) + kb_helper = MagicMock(spec=KBHelper) + kb_helper.upload_document = AsyncMock() # Configure get_kb to be an async mock that returns kb_helper kb_manager.get_kb = AsyncMock(return_value=kb_helper) @@ -64,7 +65,7 @@ def app(core_lifecycle_td: AstrBotCoreLifecycle): return server.app -@pytest_asyncio.fixture(scope="module") +@pytest_asyncio.fixture(scope="module", loop_scope="module") async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): """Handles login and returns an authenticated header.""" test_client = app.test_client() @@ -129,11 +130,11 @@ async def test_import_documents( assert result["failed_count"] == 0 # Verify kb_helper.upload_document was called correctly - kb_helper = await core_lifecycle_td.kb_manager.get_kb("test_kb_id") - assert kb_helper.upload_document.call_count == 2 + kb_helper_mock = await core_lifecycle_td.kb_manager.get_kb("test_kb_id") + assert kb_helper_mock.upload_document.call_count == 2 # Check first call arguments - call_args_list = kb_helper.upload_document.call_args_list + call_args_list = kb_helper_mock.upload_document.call_args_list # First document args1, kwargs1 = call_args_list[0] diff --git a/tests/unit/test_api_compat_smoke.py b/tests/unit/test_api_compat_smoke.py new file mode 100644 index 0000000000..7057ec06f0 --- /dev/null +++ b/tests/unit/test_api_compat_smoke.py @@ -0,0 +1,86 @@ +"""Smoke tests for astrbot.api backward compatibility.""" + +import importlib +import sys + + +def test_api_exports_smoke(): + """astrbot.api should expose expected public symbols.""" + import astrbot.api as api + + for name in [ + "AstrBotConfig", + "BaseFunctionToolExecutor", + "FunctionTool", + "ToolSet", + "agent", + "llm_tool", + "logger", + "html_renderer", + "sp", + ]: + assert hasattr(api, name), f"Missing export: {name}" + + assert callable(api.agent) + assert callable(api.llm_tool) + + +def test_api_event_and_platform_map_to_core(): + """api facade classes should remain mapped to core implementations.""" + from astrbot.api import event as api_event + from astrbot.api import platform as api_platform + from astrbot.core.message.message_event_result import MessageChain + from astrbot.core.platform import ( + AstrBotMessage, + AstrMessageEvent, + MessageMember, + MessageType, + Platform, + PlatformMetadata, + ) + from astrbot.core.platform.register import register_platform_adapter + + assert api_event.AstrMessageEvent is AstrMessageEvent + assert api_event.MessageChain is MessageChain + + assert api_platform.AstrBotMessage is AstrBotMessage + assert api_platform.AstrMessageEvent is AstrMessageEvent + assert api_platform.MessageMember is MessageMember + assert api_platform.MessageType is MessageType + assert api_platform.Platform is Platform + assert api_platform.PlatformMetadata is PlatformMetadata + assert api_platform.register_platform_adapter is register_platform_adapter + + +def test_api_message_components_smoke(): + """message_components facade should stay import-compatible.""" + from astrbot.api.message_components import File, Image, Plain + + plain = Plain("hello") + image = Image(file="https://example.com/a.jpg", url="https://example.com/a.jpg") + file_seg = File(file="https://example.com/a.txt", name="a.txt") + + assert plain.text == "hello" + assert image.file == "https://example.com/a.jpg" + assert file_seg.name == "a.txt" + + +def test_api_eagerly_imports_star_register(monkeypatch): + """Importing astrbot.api should expose direct aliases from star.register.""" + monkeypatch.delitem(sys.modules, "astrbot.core.star.register", raising=False) + + api = importlib.import_module("astrbot.api") + importlib.reload(api) + register_mod = importlib.import_module("astrbot.core.star.register") + + assert "astrbot.core.star.register" in sys.modules + assert api.agent is register_mod.register_agent + assert api.llm_tool is register_mod.register_llm_tool + + +def test_api_agent_and_llm_tool_are_callable_aliases(): + """agent/llm_tool should remain callable after direct aliasing.""" + import astrbot.api as api + + assert callable(api.agent) + assert callable(api.llm_tool) diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 3ce974f419..866a646a80 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -523,6 +523,8 @@ async def test_ensure_skills(self, mock_event, mock_context): mock_skill = MagicMock() mock_skill.name = "test_skill" mock_skill.to_prompt.return_value = "Skill description" + mock_skill.description = "Skill description" + mock_skill.path = "data/skills/test_skill/SKILL.md" mock_context.persona_manager.personas_v3 = [] mock_context.persona_manager.resolve_selected_persona = AsyncMock( return_value=(None, None, None, False) diff --git a/tests/unit/test_fixture_plugin_usage.py b/tests/unit/test_fixture_plugin_usage.py new file mode 100644 index 0000000000..656e1562a3 --- /dev/null +++ b/tests/unit/test_fixture_plugin_usage.py @@ -0,0 +1,58 @@ +import subprocess +import sys +from pathlib import Path + +import pytest + +from tests.fixtures import get_fixture_path + + +def test_fixture_plugin_files_exist(): + plugin_file = get_fixture_path("plugins/fixture_plugin.py") + metadata_file = get_fixture_path("plugins/metadata.yaml") + + assert plugin_file.exists() + assert metadata_file.exists() + + +@pytest.mark.slow +def test_fixture_plugin_can_be_imported_in_isolated_process(): + plugin_file = get_fixture_path("plugins/fixture_plugin.py") + repo_root = Path(__file__).resolve().parents[2] + + script = "\n".join( + [ + "import importlib.util", + f'plugin_file = r"{plugin_file}"', + "spec = importlib.util.spec_from_file_location('fixture_test_plugin', plugin_file)", + "assert spec is not None", + "assert spec.loader is not None", + "module = importlib.util.module_from_spec(spec)", + "spec.loader.exec_module(module)", + "plugin_cls = getattr(module, 'TestPlugin', None)", + "assert plugin_cls is not None", + "assert hasattr(plugin_cls, 'test_command')", + "assert hasattr(plugin_cls, 'test_llm_tool')", + "assert hasattr(plugin_cls, 'test_regex_handler')", + ], + ) + + result = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + cwd=repo_root, + check=False, + ) + + if result.returncode != 0: + stderr_text = (result.stderr or "").strip() + if stderr_text: + raise AssertionError( + "Fixture plugin import failed with stderr output.\n" + f"stderr:\n{stderr_text}\n\nstdout:\n{result.stdout}" + ) + raise AssertionError( + "Fixture plugin import failed with non-zero return code " + f"{result.returncode}, but stderr is empty.\nstdout:\n{result.stdout}" + ) diff --git a/tests/unit/test_skipped_items_runtime.py b/tests/unit/test_skipped_items_runtime.py new file mode 100644 index 0000000000..667999671e --- /dev/null +++ b/tests/unit/test_skipped_items_runtime.py @@ -0,0 +1,773 @@ +"""Runtime coverage for scenarios previously represented by skipped adapter tests. + +These tests run in isolated Python subprocesses and install lightweight SDK stubs +so we can execute critical adapter paths without changing existing skipped tests. +""" + +from __future__ import annotations + +import subprocess +import sys +import textwrap +from pathlib import Path + + +def _run_python(code: str) -> subprocess.CompletedProcess[str]: + repo_root = Path(__file__).resolve().parents[2] + return subprocess.run( + [sys.executable, "-c", textwrap.dedent(code)], + cwd=repo_root, + capture_output=True, + text=True, + check=False, + ) + + +def _assert_ok(code: str) -> None: + proc = _run_python(code) + assert proc.returncode == 0, ( + f"Subprocess test failed.\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}\n" + ) + + +def test_platform_manager_cycle_and_helpers_work() -> None: + _assert_ok( + """ + import asyncio + + from astrbot.core.platform.manager import PlatformManager + + + class DummyConfig(dict): + def save_config(self): + self["_saved"] = True + + + cfg = DummyConfig({"platform": [], "platform_settings": {}}) + manager = PlatformManager(cfg, asyncio.Queue()) + assert manager._is_valid_platform_id("platform_1") + assert not manager._is_valid_platform_id("bad:id") + assert manager._sanitize_platform_id("bad:id!x") == ("bad_id_x", True) + assert manager._sanitize_platform_id("ok") == ("ok", False) + stats = manager.get_all_stats() + assert stats["summary"]["total"] == 0 + """ + ) + + +def test_slack_adapter_smoke_without_external_sdk() -> None: + _assert_ok( + """ + import asyncio + import types + import sys + + quart = types.ModuleType("quart") + + class Quart: + def __init__(self, *args, **kwargs): + pass + + def route(self, *args, **kwargs): + def deco(fn): + return fn + return deco + + async def run_task(self, *args, **kwargs): + return None + + class Response: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + quart.Quart = Quart + quart.Response = Response + quart.request = types.SimpleNamespace() + sys.modules["quart"] = quart + + slack_sdk = types.ModuleType("slack_sdk") + sys.modules["slack_sdk"] = slack_sdk + sys.modules["slack_sdk.socket_mode"] = types.ModuleType("slack_sdk.socket_mode") + + req_mod = types.ModuleType("slack_sdk.socket_mode.request") + class SocketModeRequest: + def __init__(self): + self.type = "events_api" + self.payload = {} + self.envelope_id = "env" + req_mod.SocketModeRequest = SocketModeRequest + sys.modules["slack_sdk.socket_mode.request"] = req_mod + + aiohttp_mod = types.ModuleType("slack_sdk.socket_mode.aiohttp") + class SocketModeClient: + def __init__(self, *args, **kwargs): + self.socket_mode_request_listeners = [] + async def connect(self): + return None + async def disconnect(self): + return None + async def close(self): + return None + async def send_socket_mode_response(self, response): + return None + aiohttp_mod.SocketModeClient = SocketModeClient + sys.modules["slack_sdk.socket_mode.aiohttp"] = aiohttp_mod + + async_client_mod = types.ModuleType("slack_sdk.socket_mode.async_client") + async_client_mod.AsyncBaseSocketModeClient = object + sys.modules["slack_sdk.socket_mode.async_client"] = async_client_mod + + resp_mod = types.ModuleType("slack_sdk.socket_mode.response") + class SocketModeResponse: + def __init__(self, envelope_id): + self.envelope_id = envelope_id + resp_mod.SocketModeResponse = SocketModeResponse + sys.modules["slack_sdk.socket_mode.response"] = resp_mod + + sys.modules["slack_sdk.web"] = types.ModuleType("slack_sdk.web") + web_async_mod = types.ModuleType("slack_sdk.web.async_client") + class AsyncWebClient: + def __init__(self, *args, **kwargs): + pass + async def auth_test(self): + return {"user_id": "U1"} + async def users_info(self, user): + return {"user": {"name": "user", "real_name": "User"}} + async def conversations_info(self, channel): + return {"channel": {"is_im": False, "name": "general"}} + async def chat_postMessage(self, **kwargs): + return {"ok": True} + web_async_mod.AsyncWebClient = AsyncWebClient + sys.modules["slack_sdk.web.async_client"] = web_async_mod + + from astrbot.core.platform.sources.slack.slack_adapter import SlackAdapter + + adapter = SlackAdapter( + { + "id": "slack_test", + "bot_token": "xoxb-test", + "app_token": "xapp-test", + "slack_connection_mode": "socket", + }, + {}, + asyncio.Queue(), + ) + assert adapter.meta().name == "slack" + + try: + SlackAdapter({"id": "bad"}, {}, asyncio.Queue()) + raise AssertionError("Expected ValueError for missing bot_token") + except ValueError: + pass + """ + ) + + +def test_wecom_adapter_smoke_without_external_sdk() -> None: + _assert_ok( + """ + import asyncio + import types + import sys + + optionaldict_mod = types.ModuleType("optionaldict") + class optionaldict(dict): + pass + optionaldict_mod.optionaldict = optionaldict + sys.modules["optionaldict"] = optionaldict_mod + + quart = types.ModuleType("quart") + class Quart: + def __init__(self, *args, **kwargs): + pass + def add_url_rule(self, *args, **kwargs): + return None + async def run_task(self, *args, **kwargs): + return None + async def shutdown(self): + return None + quart.Quart = Quart + quart.request = types.SimpleNamespace() + sys.modules["quart"] = quart + + wechatpy = types.ModuleType("wechatpy") + enterprise = types.ModuleType("wechatpy.enterprise") + crypto_mod = types.ModuleType("wechatpy.enterprise.crypto") + enterprise_messages = types.ModuleType("wechatpy.enterprise.messages") + exceptions_mod = types.ModuleType("wechatpy.exceptions") + messages_mod = types.ModuleType("wechatpy.messages") + client_mod = types.ModuleType("wechatpy.client") + client_api_mod = types.ModuleType("wechatpy.client.api") + client_base_mod = types.ModuleType("wechatpy.client.api.base") + + class BaseWeChatAPI: + def _post(self, *args, **kwargs): + return {} + def _get(self, *args, **kwargs): + return {} + client_base_mod.BaseWeChatAPI = BaseWeChatAPI + + class InvalidSignatureException(Exception): + pass + exceptions_mod.InvalidSignatureException = InvalidSignatureException + + class BaseMessage: + type = "text" + messages_mod.BaseMessage = BaseMessage + + class TextMessage(BaseMessage): + def __init__(self, content="hello"): + self.type = "text" + self.content = content + self.agent = "agent_1" + self.source = "user_1" + self.id = "msg_1" + self.time = 1700000000 + + class ImageMessage(BaseMessage): + def __init__(self): + self.type = "image" + self.image = "https://example.com/a.jpg" + self.agent = "agent_1" + self.source = "user_1" + self.id = "msg_2" + self.time = 1700000000 + + class VoiceMessage(BaseMessage): + def __init__(self): + self.type = "voice" + self.media_id = "media_1" + self.agent = "agent_1" + self.source = "user_1" + self.id = "msg_3" + self.time = 1700000000 + + enterprise_messages.TextMessage = TextMessage + enterprise_messages.ImageMessage = ImageMessage + enterprise_messages.VoiceMessage = VoiceMessage + + class WeChatCrypto: + def __init__(self, *args, **kwargs): + pass + def check_signature(self, *args, **kwargs): + return "ok" + def decrypt_message(self, *args, **kwargs): + return "" + crypto_mod.WeChatCrypto = WeChatCrypto + + class WeChatClient: + def __init__(self, *args, **kwargs): + self.message = types.SimpleNamespace( + send_text=lambda *a, **k: {"errcode": 0}, + send_image=lambda *a, **k: {"errcode": 0}, + send_voice=lambda *a, **k: {"errcode": 0}, + send_file=lambda *a, **k: {"errcode": 0}, + ) + self.media = types.SimpleNamespace( + download=lambda media_id: types.SimpleNamespace(content=b"voice"), + upload=lambda *a, **k: {"media_id": "m1"}, + ) + enterprise.WeChatClient = WeChatClient + enterprise.parse_message = lambda xml: TextMessage("xml") + + wechatpy.enterprise = enterprise + wechatpy.exceptions = exceptions_mod + wechatpy.messages = messages_mod + wechatpy.client = client_mod + client_mod.api = client_api_mod + client_api_mod.base = client_base_mod + + sys.modules["wechatpy"] = wechatpy + sys.modules["wechatpy.enterprise"] = enterprise + sys.modules["wechatpy.enterprise.crypto"] = crypto_mod + sys.modules["wechatpy.enterprise.messages"] = enterprise_messages + sys.modules["wechatpy.exceptions"] = exceptions_mod + sys.modules["wechatpy.messages"] = messages_mod + sys.modules["wechatpy.client"] = client_mod + sys.modules["wechatpy.client.api"] = client_api_mod + sys.modules["wechatpy.client.api.base"] = client_base_mod + + from astrbot.core.platform.sources.wecom.wecom_adapter import WecomPlatformAdapter + + queue = asyncio.Queue() + adapter = WecomPlatformAdapter( + { + "id": "wecom_test", + "corpid": "corp", + "secret": "sec", + "token": "token", + "encoding_aes_key": "x" * 43, + "port": "8080", + "callback_server_host": "0.0.0.0", + }, + {}, + queue, + ) + assert adapter.meta().name == "wecom" + asyncio.run(adapter.convert_message(TextMessage("hello"))) + assert queue.qsize() == 1 + """ + ) + + +def test_lark_adapter_smoke_without_external_sdk() -> None: + _assert_ok( + """ + import asyncio + import types + import sys + + lark = types.ModuleType("lark_oapi") + lark.FEISHU_DOMAIN = "https://open.feishu.cn" + lark.LogLevel = types.SimpleNamespace(ERROR="ERROR") + + class DispatcherBuilder: + def register_p2_im_message_receive_v1(self, callback): + return self + def build(self): + return object() + + class EventDispatcherHandler: + @staticmethod + def builder(*args, **kwargs): + return DispatcherBuilder() + lark.EventDispatcherHandler = EventDispatcherHandler + + class WSClient: + def __init__(self, *args, **kwargs): + pass + async def _connect(self): + return None + async def _disconnect(self): + return None + lark.ws = types.SimpleNamespace(Client=WSClient) + + class APIBuilder: + def app_id(self, *args, **kwargs): + return self + def app_secret(self, *args, **kwargs): + return self + def log_level(self, *args, **kwargs): + return self + def domain(self, *args, **kwargs): + return self + def build(self): + return types.SimpleNamespace(im=types.SimpleNamespace(v1=types.SimpleNamespace())) + + class Client: + @staticmethod + def builder(): + return APIBuilder() + lark.Client = Client + + lark.im = types.SimpleNamespace(v1=types.SimpleNamespace(P2ImMessageReceiveV1=object)) + + sys.modules["lark_oapi"] = lark + sys.modules["lark_oapi.api"] = types.ModuleType("lark_oapi.api") + sys.modules["lark_oapi.api.im"] = types.ModuleType("lark_oapi.api.im") + + v1_mod = types.ModuleType("lark_oapi.api.im.v1") + + class BuilderObj: + def __getattr__(self, name): + def method(*args, **kwargs): + return self + return method + def build(self): + return object() + + class Req: + @staticmethod + def builder(): + return BuilderObj() + + v1_mod.GetMessageRequest = Req + v1_mod.GetMessageResourceRequest = Req + v1_mod.CreateFileRequest = Req + v1_mod.CreateFileRequestBody = Req + v1_mod.CreateImageRequest = Req + v1_mod.CreateImageRequestBody = Req + v1_mod.CreateMessageReactionRequest = Req + v1_mod.CreateMessageReactionRequestBody = Req + v1_mod.ReplyMessageRequest = Req + v1_mod.ReplyMessageRequestBody = Req + v1_mod.CreateMessageRequest = Req + v1_mod.CreateMessageRequestBody = Req + v1_mod.Emoji = object + sys.modules["lark_oapi.api.im.v1"] = v1_mod + + proc_mod = types.ModuleType("lark_oapi.api.im.v1.processor") + class P2ImMessageReceiveV1Processor: + def __init__(self, cb): + self.cb = cb + def type(self): + return lambda data: data + def do(self, data): + return None + proc_mod.P2ImMessageReceiveV1Processor = P2ImMessageReceiveV1Processor + sys.modules["lark_oapi.api.im.v1.processor"] = proc_mod + + from astrbot.api.message_components import Plain + from astrbot.core.platform.sources.lark.lark_adapter import LarkPlatformAdapter + + adapter = LarkPlatformAdapter( + { + "id": "lark_test", + "app_id": "appid", + "app_secret": "secret", + "lark_connection_mode": "socket", + "lark_bot_name": "astrbot", + }, + {}, + asyncio.Queue(), + ) + assert adapter.meta().name == "lark" + assert adapter._build_message_str_from_components([Plain("hello")]) == "hello" + assert adapter._is_duplicate_event("event_1") is False + assert adapter._is_duplicate_event("event_1") is True + """ + ) + + +def test_dingtalk_adapter_smoke_without_external_sdk() -> None: + _assert_ok( + """ + import asyncio + import types + import sys + + dingtalk = types.ModuleType("dingtalk_stream") + + class EventHandler: + pass + + class EventMessage: + pass + + class AckMessage: + STATUS_OK = "OK" + + class Credential: + def __init__(self, *args, **kwargs): + pass + + class DingTalkStreamClient: + def __init__(self, *args, **kwargs): + self.websocket = None + def register_all_event_handler(self, *args, **kwargs): + return None + def register_callback_handler(self, *args, **kwargs): + return None + async def start(self): + return None + def get_access_token(self): + return "token" + + class ChatbotHandler: + pass + + class CallbackMessage: + pass + + class ChatbotMessage: + TOPIC = "/v1.0/chatbot/messages" + @staticmethod + def from_dict(data): + return types.SimpleNamespace( + create_at=0, + conversation_type="1", + sender_id="sender", + sender_nick="nick", + chatbot_user_id="bot", + message_id="msg", + at_users=[], + conversation_id="conv", + message_type="text", + text=types.SimpleNamespace(content="hello"), + sender_staff_id="staff", + robot_code="robot", + ) + + dingtalk.EventHandler = EventHandler + dingtalk.EventMessage = EventMessage + dingtalk.AckMessage = AckMessage + dingtalk.Credential = Credential + dingtalk.DingTalkStreamClient = DingTalkStreamClient + dingtalk.ChatbotHandler = ChatbotHandler + dingtalk.CallbackMessage = CallbackMessage + dingtalk.ChatbotMessage = ChatbotMessage + dingtalk.RichTextContent = object + + sys.modules["dingtalk_stream"] = dingtalk + + from astrbot.api.message_components import Plain + from astrbot.api.platform import MessageType + from astrbot.core.message.message_event_result import MessageChain + from astrbot.core.platform.astr_message_event import MessageSesion + from astrbot.core.platform.sources.dingtalk.dingtalk_adapter import ( + DingtalkPlatformAdapter, + ) + + adapter = DingtalkPlatformAdapter( + { + "id": "ding_test", + "client_id": "client", + "client_secret": "secret", + }, + {}, + asyncio.Queue(), + ) + assert adapter._id_to_sid("$:LWCP_v1:$abc") == "abc" + + called = {"ok": False} + + async def fake_send_by_session(session, chain): + called["ok"] = True + + adapter.send_by_session = fake_send_by_session + session = MessageSesion( + platform_name="dingtalk", + message_type=MessageType.FRIEND_MESSAGE, + session_id="user_1", + ) + asyncio.run(adapter.send_with_sesison(session, MessageChain([Plain("ping")]))) + assert called["ok"] is True + """ + ) + + +def test_other_adapters_runtime_imports() -> None: + _assert_ok( + """ + from astrbot.core.platform.sources.qqofficial_webhook.qo_webhook_server import ( + QQOfficialWebhook, + ) + from astrbot.core.platform.sources.wecom_ai_bot.wecomai_webhook import ( + WecomAIBotWebhookClient, + ) + from astrbot.core.platform.sources.line.line_adapter import LinePlatformAdapter + from astrbot.core.platform.sources.satori.satori_adapter import ( + SatoriPlatformAdapter, + ) + from astrbot.core.platform.sources.misskey.misskey_adapter import ( + MisskeyPlatformAdapter, + ) + + assert QQOfficialWebhook is not None + assert WecomAIBotWebhookClient is not None + assert LinePlatformAdapter is not None + assert SatoriPlatformAdapter is not None + assert MisskeyPlatformAdapter is not None + """ + ) + + +def test_line_satori_misskey_adapter_basic_init() -> None: + _assert_ok( + """ + import asyncio + + from astrbot.core.platform.sources.line.line_adapter import LinePlatformAdapter + from astrbot.core.platform.sources.misskey.misskey_adapter import ( + MisskeyPlatformAdapter, + ) + from astrbot.core.platform.sources.satori.satori_adapter import ( + SatoriPlatformAdapter, + ) + + queue = asyncio.Queue() + + line_adapter = LinePlatformAdapter( + { + "id": "line_test", + "channel_access_token": "token", + "channel_secret": "secret", + }, + {}, + queue, + ) + assert line_adapter.meta().name == "line" + + satori_adapter = SatoriPlatformAdapter( + {"id": "satori_test"}, + {}, + queue, + ) + assert satori_adapter.meta().name == "satori" + + misskey_adapter = MisskeyPlatformAdapter( + {"id": "misskey_test"}, + {}, + queue, + ) + assert misskey_adapter.meta().name == "misskey" + """ + ) + + +def test_wecom_ai_bot_webhook_client_basic() -> None: + _assert_ok( + """ + from astrbot.core.platform.sources.wecom_ai_bot.wecomai_webhook import ( + WecomAIBotWebhookClient, + ) + + client = WecomAIBotWebhookClient( + "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test_key" + ) + assert client._build_upload_url("file").startswith( + "https://qyapi.weixin.qq.com/cgi-bin/webhook/upload_media?" + ) + """ + ) + + +def test_weixin_official_account_adapter_with_stubbed_wechatpy() -> None: + _assert_ok( + """ + import asyncio + import types + import sys + + quart = types.ModuleType("quart") + class Quart: + def __init__(self, *args, **kwargs): + pass + def add_url_rule(self, *args, **kwargs): + return None + async def run_task(self, *args, **kwargs): + return None + async def shutdown(self): + return None + quart.Quart = Quart + quart.request = types.SimpleNamespace() + sys.modules["quart"] = quart + + wechatpy = types.ModuleType("wechatpy") + wechatpy.__path__ = [] + crypto_mod = types.ModuleType("wechatpy.crypto") + exceptions_mod = types.ModuleType("wechatpy.exceptions") + messages_mod = types.ModuleType("wechatpy.messages") + replies_mod = types.ModuleType("wechatpy.replies") + utils_mod = types.ModuleType("wechatpy.utils") + + class InvalidSignatureException(Exception): + pass + exceptions_mod.InvalidSignatureException = InvalidSignatureException + + class WeChatCrypto: + def __init__(self, *args, **kwargs): + pass + def check_signature(self, *args, **kwargs): + return "ok" + def decrypt_message(self, *args, **kwargs): + return "" + def encrypt_message(self, xml, nonce, ts): + return xml + crypto_mod.WeChatCrypto = WeChatCrypto + + class BaseMessage: + type = "text" + source = "user_1" + id = "msg_1" + time = 1700000000 + + class TextMessage(BaseMessage): + def __init__(self, content="hello"): + self.type = "text" + self.content = content + self.source = "user_1" + self.id = "msg_1" + self.time = 1700000000 + self.target = "bot_1" + + class ImageMessage(BaseMessage): + def __init__(self): + self.type = "image" + self.image = "https://example.com/a.jpg" + self.source = "user_1" + self.id = "msg_2" + self.time = 1700000000 + self.target = "bot_1" + + class VoiceMessage(BaseMessage): + def __init__(self): + self.type = "voice" + self.media_id = "media_1" + self.source = "user_1" + self.id = "msg_3" + self.time = 1700000000 + self.target = "bot_1" + + messages_mod.BaseMessage = BaseMessage + messages_mod.TextMessage = TextMessage + messages_mod.ImageMessage = ImageMessage + messages_mod.VoiceMessage = VoiceMessage + + class ImageReply: + def __init__(self, *args, **kwargs): + pass + def render(self): + return "image" + + class VoiceReply: + def __init__(self, *args, **kwargs): + pass + def render(self): + return "voice" + + replies_mod.ImageReply = ImageReply + replies_mod.VoiceReply = VoiceReply + + class WeChatClient: + def __init__(self, *args, **kwargs): + self.message = types.SimpleNamespace( + send_text=lambda *a, **k: {"errcode": 0}, + send_image=lambda *a, **k: {"errcode": 0}, + send_voice=lambda *a, **k: {"errcode": 0}, + send_file=lambda *a, **k: {"errcode": 0}, + ) + self.media = types.SimpleNamespace( + download=lambda media_id: types.SimpleNamespace(content=b"voice"), + upload=lambda *a, **k: {"media_id": "m1"}, + ) + wechatpy.WeChatClient = WeChatClient + wechatpy.create_reply = lambda text, msg: text + wechatpy.parse_message = lambda xml: TextMessage("xml") + + utils_mod.check_signature = lambda *args, **kwargs: True + + wechatpy.crypto = crypto_mod + wechatpy.exceptions = exceptions_mod + wechatpy.messages = messages_mod + wechatpy.replies = replies_mod + wechatpy.utils = utils_mod + sys.modules["wechatpy"] = wechatpy + sys.modules["wechatpy.crypto"] = crypto_mod + sys.modules["wechatpy.exceptions"] = exceptions_mod + sys.modules["wechatpy.messages"] = messages_mod + sys.modules["wechatpy.replies"] = replies_mod + sys.modules["wechatpy.utils"] = utils_mod + + from astrbot.core.platform.sources.weixin_official_account.weixin_offacc_adapter import ( + WeixinOfficialAccountPlatformAdapter, + ) + + queue = asyncio.Queue() + adapter = WeixinOfficialAccountPlatformAdapter( + { + "id": "wxoa_test", + "appid": "appid", + "secret": "secret", + "token": "token", + "encoding_aes_key": "x" * 43, + "port": "8081", + "callback_server_host": "0.0.0.0", + }, + {}, + queue, + ) + assert adapter.meta().name == "weixin_official_account" + """ + )