Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions astrbot/api/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
CommandResult,
EventResultType,
)
from astrbot.core.platform import AstrMessageEvent

# star register
from astrbot.core.star.register import (
Expand All @@ -31,8 +30,9 @@
from astrbot.core.star.register import (
register_star as register, # 注册插件(Star)
)
from astrbot.core.star import Context, Star
from astrbot.core.star.base import Star
from astrbot.core.star.config import *
from astrbot.core.star.context import Context


# provider
Expand Down
4 changes: 3 additions & 1 deletion astrbot/api/star/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.base import Star
from astrbot.core.star.config import *
from astrbot.core.star.context import Context
from astrbot.core.star.register import (
register_star as register, # 注册插件(Star)
)
from astrbot.core.star.star_tools import StarTools

__all__ = ["Context", "Star", "StarTools", "register"]
1 change: 1 addition & 0 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ def _apply_sandbox_tools(
) -> None:
if req.func_tool is None:
req.func_tool = ToolSet()
req.system_prompt = req.system_prompt or ""
booter = config.sandbox_cfg.get("booter", "shipyard_neo")
if booter == "shipyard":
ep = config.sandbox_cfg.get("shipyard_endpoint", "")
Expand Down
21 changes: 20 additions & 1 deletion astrbot/core/cron/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
from .manager import CronJobManager
"""Cron package exports.

Keep `CronJobManager` import-compatible while avoiding hard import failure when
`apscheduler` is partially mocked in test environments.
"""

try:
from .manager import CronJobManager
except ModuleNotFoundError as exc:
if not (exc.name and exc.name.startswith("apscheduler")):
raise

_IMPORT_ERROR = exc

class CronJobManager:
def __init__(self, *args, **kwargs) -> None:
raise ModuleNotFoundError(
"CronJobManager requires a complete `apscheduler` installation."
) from _IMPORT_ERROR


__all__ = ["CronJobManager"]
167 changes: 131 additions & 36 deletions tests/test_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,24 @@
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.star import star_registry
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.zip_updator import ReleaseInfo
from astrbot.dashboard.server import AstrBotDashboard
from tests.fixtures.helpers import (
MockPluginBuilder,
MockPluginConfig,
create_mock_updater_install,
create_mock_updater_update,
)

RUN_ONLINE_UPDATE_CHECK = os.environ.get(
"ASTRBOT_RUN_ONLINE_UPDATE_CHECK", ""
).lower() in {
"1",
"true",
"yes",
}

@pytest_asyncio.fixture(scope="module")

@pytest_asyncio.fixture(scope="module", loop_scope="module")
async def core_lifecycle_td(tmp_path_factory):
"""Creates and initializes a core lifecycle instance with a temporary database."""
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_v3.db"
Expand Down Expand Up @@ -51,7 +59,7 @@ def app(core_lifecycle_td: AstrBotCoreLifecycle):
return server.app


@pytest_asyncio.fixture(scope="module")
@pytest_asyncio.fixture(scope="module", loop_scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Handles login and returns an authenticated header."""
test_client = app.test_client()
Expand Down Expand Up @@ -108,16 +116,14 @@ async def test_plugins(
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
):
"""测试插件 API 端点,使用 Mock 避免真实网络调用。"""
"""Test plugin APIs with mocked updater behavior."""
test_client = app.test_client()

# 已经安装的插件
response = await test_client.get("/api/plugin/get", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"

# 插件市场
response = await test_client.get(
"/api/plugin/market_list",
headers=authenticated_header,
Expand All @@ -126,45 +132,38 @@ async def test_plugins(
data = await response.get_json()
assert data["status"] == "ok"

# 使用 MockPluginBuilder 创建测试插件
plugin_store_path = core_lifecycle_td.plugin_manager.plugin_store_path
builder = MockPluginBuilder(plugin_store_path)

# 定义测试插件
test_plugin_name = "test_mock_plugin"
test_repo_url = f"https://github.com/test/{test_plugin_name}"

# 创建 Mock 函数
mock_install = create_mock_updater_install(
builder,
repo_to_plugin={test_repo_url: test_plugin_name},
)
mock_update = create_mock_updater_update(builder)

# 设置 Mock
monkeypatch.setattr(
core_lifecycle_td.plugin_manager.updator, "install", mock_install
)
monkeypatch.setattr(
core_lifecycle_td.plugin_manager.updator, "update", mock_update
)
monkeypatch.setattr(core_lifecycle_td.plugin_manager.updator, "update", mock_update)

try:
# 插件安装
response = await test_client.post(
"/api/plugin/install",
json={"url": test_repo_url},
headers=authenticated_header,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok", f"安装失败: {data.get('message', 'unknown error')}"
assert data["status"] == "ok", (
f"install failed: {data.get('message', 'unknown error')}"
)

# 验证插件已注册
exists = any(md.name == test_plugin_name for md in star_registry)
assert exists is True, f"插件 {test_plugin_name} 未成功载入"
assert exists is True, f"plugin {test_plugin_name} was not loaded"

# 插件更新
response = await test_client.post(
"/api/plugin/update",
json={"name": test_plugin_name},
Expand All @@ -174,11 +173,9 @@ async def test_plugins(
data = await response.get_json()
assert data["status"] == "ok"

# 验证更新标记文件
plugin_dir = builder.get_plugin_path(test_plugin_name)
assert (plugin_dir / ".updated").exists()

# 插件卸载
response = await test_client.post(
"/api/plugin/uninstall",
json={"name": test_plugin_name},
Expand All @@ -188,16 +185,14 @@ async def test_plugins(
data = await response.get_json()
assert data["status"] == "ok"

# 验证插件已卸载
exists = any(md.name == test_plugin_name for md in star_registry)
assert exists is False, f"插件 {test_plugin_name} 未成功卸载"
assert exists is False, f"plugin {test_plugin_name} was not unloaded"
exists = any(
test_plugin_name in md.handler_module_path for md in star_handlers_registry
)
assert exists is False, f"插件 {test_plugin_name} handler 未成功清理"
assert exists is False, f"plugin {test_plugin_name} handlers were not cleaned"

finally:
# 清理测试插件
builder.cleanup(test_plugin_name)


Expand Down Expand Up @@ -230,41 +225,141 @@ async def test_commands_api(app: Quart, authenticated_header: dict):


@pytest.mark.asyncio
async def test_check_update(
async def test_check_update_success_no_new_version(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
):
"""测试检查更新 API,使用 Mock 避免真实网络调用。"""
async def mock_get_dashboard_version():
return "v-test-dashboard"

async def mock_check_update(*args, **kwargs): # noqa: ARG001
return None

monkeypatch.setattr(
"astrbot.dashboard.routes.update.get_dashboard_version",
mock_get_dashboard_version,
)
monkeypatch.setattr(
core_lifecycle_td.astrbot_updator,
"check_update",
mock_check_update,
)
test_client = app.test_client()

# Mock 更新检查和网络请求
async def mock_check_update(*args, **kwargs):
"""Mock 更新检查,返回无新版本。"""
return None # None 表示没有新版本
response = await test_client.get("/api/update/check", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "success"
assert {
"version",
"has_new_version",
"dashboard_version",
"dashboard_has_new_version",
}.issubset(data["data"])
assert data["data"]["has_new_version"] is False
assert data["data"]["dashboard_version"] == "v-test-dashboard"

async def mock_get_dashboard_version(*args, **kwargs):
"""Mock Dashboard 版本获取。"""
from astrbot.core.config.default import VERSION

return f"v{VERSION}" # 返回当前版本
@pytest.mark.asyncio
async def test_check_update_success_has_new_version(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
):
async def mock_get_dashboard_version():
return "v-test-dashboard"

async def mock_check_update(*args, **kwargs): # noqa: ARG001
return ReleaseInfo(
version="v999.0.0",
published_at="2026-01-01",
body="test release",
)

monkeypatch.setattr(
"astrbot.dashboard.routes.update.get_dashboard_version",
mock_get_dashboard_version,
)
monkeypatch.setattr(
core_lifecycle_td.astrbot_updator,
"check_update",
mock_check_update,
)

test_client = app.test_client()
response = await test_client.get("/api/update/check", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "success"
assert {
"version",
"has_new_version",
"dashboard_version",
"dashboard_has_new_version",
}.issubset(data["data"])
assert data["data"]["has_new_version"] is True
assert data["data"]["dashboard_version"] == "v-test-dashboard"


@pytest.mark.asyncio
async def test_check_update_error_when_updator_raises(
app: Quart,
authenticated_header: dict,
core_lifecycle_td: AstrBotCoreLifecycle,
monkeypatch,
):
async def mock_get_dashboard_version():
return "v-test-dashboard"

async def mock_check_update(*args, **kwargs): # noqa: ARG001
raise RuntimeError("mock update check failure")

monkeypatch.setattr(
"astrbot.dashboard.routes.update.get_dashboard_version",
mock_get_dashboard_version,
)
monkeypatch.setattr(
core_lifecycle_td.astrbot_updator,
"check_update",
mock_check_update,
)

test_client = app.test_client()
response = await test_client.get("/api/update/check", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "success"
assert data["data"]["has_new_version"] is False
assert data["status"] == "error"
assert isinstance(data["message"], str)
assert data["message"]


@pytest.mark.asyncio
@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.skipif(
not RUN_ONLINE_UPDATE_CHECK,
reason="Set ASTRBOT_RUN_ONLINE_UPDATE_CHECK=1 to run online update check test.",
)
async def test_check_update_online_optional(app: Quart, authenticated_header: dict):
"""Optional online smoke test for the real update-check request path."""
test_client = app.test_client()
response = await test_client.get("/api/update/check", headers=authenticated_header)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] in {"success", "error"}
assert "message" in data
assert "data" in data

if data["status"] == "success":
assert {
"version",
"has_new_version",
"dashboard_version",
"dashboard_has_new_version",
}.issubset(data["data"])


@pytest.mark.asyncio
Expand Down
13 changes: 7 additions & 6 deletions tests/test_kb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from astrbot.dashboard.server import AstrBotDashboard


@pytest_asyncio.fixture(scope="module")
@pytest_asyncio.fixture(scope="module", loop_scope="module")
async def core_lifecycle_td(tmp_path_factory):
"""Creates and initializes a core lifecycle instance with a temporary database."""
tmp_db_path = tmp_path_factory.mktemp("data") / "test_data_kb.db"
Expand All @@ -24,7 +24,8 @@ async def core_lifecycle_td(tmp_path_factory):

# Mock kb_manager and kb_helper
kb_manager = MagicMock()
kb_helper = AsyncMock(spec=KBHelper)
kb_helper = MagicMock(spec=KBHelper)
kb_helper.upload_document = AsyncMock()

# Configure get_kb to be an async mock that returns kb_helper
kb_manager.get_kb = AsyncMock(return_value=kb_helper)
Expand Down Expand Up @@ -64,7 +65,7 @@ def app(core_lifecycle_td: AstrBotCoreLifecycle):
return server.app


@pytest_asyncio.fixture(scope="module")
@pytest_asyncio.fixture(scope="module", loop_scope="module")
async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle):
"""Handles login and returns an authenticated header."""
test_client = app.test_client()
Expand Down Expand Up @@ -129,11 +130,11 @@ async def test_import_documents(
assert result["failed_count"] == 0

# Verify kb_helper.upload_document was called correctly
kb_helper = await core_lifecycle_td.kb_manager.get_kb("test_kb_id")
assert kb_helper.upload_document.call_count == 2
kb_helper_mock = await core_lifecycle_td.kb_manager.get_kb("test_kb_id")
assert kb_helper_mock.upload_document.call_count == 2

# Check first call arguments
call_args_list = kb_helper.upload_document.call_args_list
call_args_list = kb_helper_mock.upload_document.call_args_list

# First document
args1, kwargs1 = call_args_list[0]
Expand Down
Loading