diff --git a/tests/fixtures/helpers.py b/tests/fixtures/helpers.py index 8f64ab6c9..26edb761c 100644 --- a/tests/fixtures/helpers.py +++ b/tests/fixtures/helpers.py @@ -3,7 +3,10 @@ 提供统一的测试辅助工具,减少测试代码重复。 """ -from typing import Any +import shutil +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable from unittest.mock import AsyncMock, MagicMock from astrbot.core.message.components import BaseMessageComponent @@ -330,3 +333,255 @@ def create_mock_llm_response( tools_call_ids=tools_call_ids or [], usage=TokenUsage(input_other=10, output=5), ) + + +# ============================================================ +# 测试插件辅助函数 +# ============================================================ + + +@dataclass +class MockPluginConfig: + """测试插件配置。 + + 用于创建和管理测试用的模拟插件。 + + Attributes: + name: 插件名称 + author: 作者 + description: 描述 + version: 版本 + repo: 仓库 URL + main_code: main.py 的代码内容 + requirements: 依赖列表 + has_readme: 是否创建 README.md + readme_content: README.md 内容 + """ + + name: str = "test_plugin" + author: str = "Test Author" + description: str = "A test plugin for unit testing" + version: str = "1.0.0" + repo: str = "https://github.com/test/test_plugin" + main_code: str = "" + requirements: list[str] = field(default_factory=list) + has_readme: bool = True + readme_content: str = "# Test Plugin\n\nThis is a test plugin." + + +# 默认的插件主代码模板 +DEFAULT_PLUGIN_MAIN_TEMPLATE = ''' +from astrbot.api import star + +class Main(star.Star): + """测试插件主类。""" + + def __init__(self, context): + super().__init__(context) + self.name = "{plugin_name}" + + async def initialize(self): + """初始化插件。""" + pass + + async def terminate(self): + """终止插件。""" + pass +''' + + +class MockPluginBuilder: + """测试插件构建器。 + + 用于创建、管理和清理测试用的模拟插件。支持任意插件的模拟创建。 + + Example: + # 创建一个简单的测试插件 + builder = MockPluginBuilder(plugin_store_path) + plugin_dir = builder.create("my_test_plugin") + + # 创建自定义配置的插件 + config = MockPluginConfig( + name="custom_plugin", + version="2.0.0", + main_code="print('hello')", + ) + plugin_dir = builder.create(config) + + # 清理插件 + builder.cleanup("my_test_plugin") + """ + + def __init__(self, plugin_store_path: str | Path): + """初始化构建器。 + + Args: + plugin_store_path: 插件存储路径 (通常是 data/plugins) + """ + self.plugin_store_path = Path(plugin_store_path) + self._created_plugins: set[str] = set() + + def create( + self, + plugin_config: str | MockPluginConfig | None = None, + **kwargs, + ) -> Path: + """创建模拟插件。 + + Args: + plugin_config: 插件名称字符串、MockPluginConfig 对象或 None + **kwargs: 如果 plugin_config 是字符串或 None,这些参数用于构建 MockPluginConfig + + Returns: + Path: 创建的插件目录路径 + """ + # 处理不同类型的输入 + if plugin_config is None: + config = MockPluginConfig(**kwargs) + elif isinstance(plugin_config, str): + config = MockPluginConfig(name=plugin_config, **kwargs) + elif isinstance(plugin_config, MockPluginConfig): + config = plugin_config + else: + raise TypeError(f"Invalid plugin_config type: {type(plugin_config)}") + + # 创建插件目录 + plugin_dir = self.plugin_store_path / config.name + plugin_dir.mkdir(parents=True, exist_ok=True) + + # 创建 metadata.yaml + metadata_content = "\n".join( + [ + f"name: {config.name}", + f"author: {config.author}", + f"desc: {config.description}", + f"version: {config.version}", + f"repo: {config.repo}", + ] + ) + (plugin_dir / "metadata.yaml").write_text( + metadata_content + "\n", encoding="utf-8" + ) + + # 创建 main.py + main_code = config.main_code or DEFAULT_PLUGIN_MAIN_TEMPLATE.format( + plugin_name=config.name + ) + (plugin_dir / "main.py").write_text(main_code, encoding="utf-8") + + # 创建 requirements.txt(如果有依赖) + if config.requirements: + (plugin_dir / "requirements.txt").write_text( + "\n".join(config.requirements) + "\n", encoding="utf-8" + ) + + # 创建 README.md(如果需要) + if config.has_readme: + (plugin_dir / "README.md").write_text( + config.readme_content, encoding="utf-8" + ) + + # 记录创建的插件 + self._created_plugins.add(config.name) + + return plugin_dir + + def cleanup(self, plugin_name: str | None = None) -> None: + """清理插件。 + + Args: + plugin_name: 要清理的插件名称,如果为 None 则清理所有由本构建器创建的插件 + """ + if plugin_name: + plugins_to_clean = {plugin_name} + else: + plugins_to_clean = self._created_plugins.copy() + + for name in plugins_to_clean: + plugin_dir = self.plugin_store_path / name + if plugin_dir.exists(): + shutil.rmtree(plugin_dir) + self._created_plugins.discard(name) + + def cleanup_all(self) -> None: + """清理所有由本构建器创建的插件。""" + self.cleanup(None) + + def get_plugin_path(self, plugin_name: str) -> Path: + """获取插件路径。 + + Args: + plugin_name: 插件名称 + + Returns: + Path: 插件目录路径 + """ + return self.plugin_store_path / plugin_name + + @property + def created_plugins(self) -> set[str]: + """获取已创建的插件名称集合。""" + return self._created_plugins.copy() + + +def create_mock_updater_install( + plugin_builder: MockPluginBuilder, + repo_to_plugin: dict[str, str] | None = None, +) -> Callable: + """创建模拟的 updater.install 方法。 + + Args: + plugin_builder: MockPluginBuilder 实例 + repo_to_plugin: 仓库 URL 到插件名称的映射,格式: {"https://github.com/user/repo": "plugin_name"} + + Returns: + Callable: 异步函数,可用于 monkeypatch.setattr + """ + + async def mock_install(repo_url: str, proxy: str = "") -> str: + """Mock updater.install 方法。""" + # 查找插件名称 + plugin_name = None + if repo_to_plugin: + plugin_name = repo_to_plugin.get(repo_url) + + # 如果没有映射,尝试从 URL 提取插件名 + if not plugin_name: + # 从 https://github.com/user/plugin_name 提取 plugin_name + parts = repo_url.rstrip("/").split("/") + plugin_name = parts[-1] if parts else "unknown_plugin" + + # 创建插件目录 + config = MockPluginConfig(name=plugin_name, repo=repo_url) + plugin_dir = plugin_builder.create(config) + return str(plugin_dir) + + return mock_install + + +def create_mock_updater_update( + plugin_builder: MockPluginBuilder, + update_callback: Callable | None = None, +) -> Callable: + """创建模拟的 updater.update 方法。 + + Args: + plugin_builder: MockPluginBuilder 实例 + update_callback: 更新回调函数,接收 plugin 参数 + + Returns: + Callable: 异步函数,可用于 monkeypatch.setattr + """ + + async def mock_update(plugin, proxy: str = "") -> None: + """Mock updater.update 方法。""" + plugin_dir = plugin_builder.get_plugin_path(plugin.name) + + # 创建更新标记文件 + (plugin_dir / ".updated").write_text("ok", encoding="utf-8") + + # 调用回调 + if update_callback: + update_callback(plugin) + + return mock_update diff --git a/tests/test_api_key_open_api.py b/tests/test_api_key_open_api.py index 3d1ea0a0f..4bc5fd4d5 100644 --- a/tests/test_api_key_open_api.py +++ b/tests/test_api_key_open_api.py @@ -186,7 +186,7 @@ async def fake_chat(post_data: dict | None = None): "/api/v1/chat", json={ "message": "hello", - "username": "alice", + "username": "alice_auto_session", "enable_streaming": False, }, headers={"X-API-Key": raw_key}, @@ -200,16 +200,16 @@ async def fake_chat(post_data: dict | None = None): created_session_id = send_data["data"]["session_id"] assert isinstance(created_session_id, str) uuid.UUID(created_session_id) - assert send_data["data"]["creator"] == "alice" + assert send_data["data"]["creator"] == "alice_auto_session" created_session = await core_lifecycle_td.db.get_platform_session_by_id( created_session_id ) assert created_session is not None - assert created_session.creator == "alice" + assert created_session.creator == "alice_auto_session" assert created_session.platform_id == "webchat" await core_lifecycle_td.db.create_platform_session( - creator="bob", + creator="bob_auto_session", platform_id="webchat", session_id="open_api_existing_bob_session", is_group=0, @@ -251,14 +251,15 @@ async def test_open_chat_sessions_pagination( create_res = await test_client.post( "/api/apikey/create", - json={"name": "chat-scope-key", "scopes": ["chat"]}, + json={"name": "chat-scope-key-pagination", "scopes": ["chat"]}, headers=authenticated_header, ) create_data = await create_res.get_json() assert create_data["status"] == "ok" raw_key = create_data["data"]["api_key"] - creator = "alice" + # Use unique session IDs to avoid conflicts with other tests + creator = "alice_pagination" for idx in range(3): await core_lifecycle_td.db.create_platform_session( creator=creator, @@ -268,7 +269,7 @@ async def test_open_chat_sessions_pagination( is_group=0, ) await core_lifecycle_td.db.create_platform_session( - creator="bob", + creator="bob_pagination", platform_id="webchat", session_id="open_api_paginated_bob", display_name="Open API Session Bob", @@ -276,7 +277,7 @@ async def test_open_chat_sessions_pagination( ) page_1_res = await test_client.get( - "/api/v1/chat/sessions?page=1&page_size=2&username=alice", + "/api/v1/chat/sessions?page=1&page_size=2&username=alice_pagination", headers={"X-API-Key": raw_key}, ) assert page_1_res.status_code == 200 @@ -286,10 +287,10 @@ async def test_open_chat_sessions_pagination( assert page_1_data["data"]["page_size"] == 2 assert page_1_data["data"]["total"] == 3 assert len(page_1_data["data"]["sessions"]) == 2 - assert all(item["creator"] == "alice" for item in page_1_data["data"]["sessions"]) + assert all(item["creator"] == "alice_pagination" for item in page_1_data["data"]["sessions"]) page_2_res = await test_client.get( - "/api/v1/chat/sessions?page=2&page_size=2&username=alice", + "/api/v1/chat/sessions?page=2&page_size=2&username=alice_pagination", headers={"X-API-Key": raw_key}, ) assert page_2_res.status_code == 200 diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 969f0da6d..4bf0673e8 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -1,5 +1,6 @@ import asyncio import os +from pathlib import Path import pytest import pytest_asyncio @@ -11,6 +12,12 @@ from astrbot.core.star.star import star_registry from astrbot.core.star.star_handler import star_handlers_registry from astrbot.dashboard.server import AstrBotDashboard +from tests.fixtures.helpers import ( + MockPluginBuilder, + MockPluginConfig, + create_mock_updater_install, + create_mock_updater_update, +) @pytest_asyncio.fixture(scope="module") @@ -94,8 +101,15 @@ async def test_get_stat(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_plugins(app: Quart, authenticated_header: dict): +async def test_plugins( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + """测试插件 API 端点,使用 Mock 避免真实网络调用。""" test_client = app.test_client() + # 已经安装的插件 response = await test_client.get("/api/plugin/get", headers=authenticated_header) assert response.status_code == 200 @@ -111,53 +125,79 @@ async def test_plugins(app: Quart, authenticated_header: dict): data = await response.get_json() assert data["status"] == "ok" - # 插件安装 - response = await test_client.post( - "/api/plugin/install", - json={"url": "https://github.com/Soulter/astrbot_plugin_essential"}, - headers=authenticated_header, - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - exists = False - for md in star_registry: - if md.name == "astrbot_plugin_essential": - exists = True - break - assert exists is True, "插件 astrbot_plugin_essential 未成功载入" - - # 插件更新 - response = await test_client.post( - "/api/plugin/update", - json={"name": "astrbot_plugin_essential"}, - headers=authenticated_header, + # 使用 MockPluginBuilder 创建测试插件 + plugin_store_path = core_lifecycle_td.plugin_manager.plugin_store_path + builder = MockPluginBuilder(plugin_store_path) + + # 定义测试插件 + test_plugin_name = "test_mock_plugin" + test_repo_url = f"https://github.com/test/{test_plugin_name}" + + # 创建 Mock 函数 + mock_install = create_mock_updater_install( + builder, + repo_to_plugin={test_repo_url: test_plugin_name}, ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" + mock_update = create_mock_updater_update(builder) - # 插件卸载 - response = await test_client.post( - "/api/plugin/uninstall", - json={"name": "astrbot_plugin_essential"}, - headers=authenticated_header, + # 设置 Mock + monkeypatch.setattr( + core_lifecycle_td.plugin_manager.updator, "install", mock_install + ) + monkeypatch.setattr( + core_lifecycle_td.plugin_manager.updator, "update", mock_update ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - exists = False - for md in star_registry: - if md.name == "astrbot_plugin_essential": - exists = True - break - assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" - exists = False - for md in star_handlers_registry: - if "astrbot_plugin_essential" in md.handler_module_path: - exists = True - break - assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" + + try: + # 插件安装 + response = await test_client.post( + "/api/plugin/install", + json={"url": test_repo_url}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok", f"安装失败: {data.get('message', 'unknown error')}" + + # 验证插件已注册 + exists = any(md.name == test_plugin_name for md in star_registry) + assert exists is True, f"插件 {test_plugin_name} 未成功载入" + + # 插件更新 + response = await test_client.post( + "/api/plugin/update", + json={"name": test_plugin_name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + # 验证更新标记文件 + plugin_dir = builder.get_plugin_path(test_plugin_name) + assert (plugin_dir / ".updated").exists() + + # 插件卸载 + response = await test_client.post( + "/api/plugin/uninstall", + json={"name": test_plugin_name}, + headers=authenticated_header, + ) + assert response.status_code == 200 + data = await response.get_json() + assert data["status"] == "ok" + + # 验证插件已卸载 + exists = any(md.name == test_plugin_name for md in star_registry) + assert exists is False, f"插件 {test_plugin_name} 未成功卸载" + exists = any( + test_plugin_name in md.handler_module_path for md in star_handlers_registry + ) + assert exists is False, f"插件 {test_plugin_name} handler 未成功清理" + + finally: + # 清理测试插件 + builder.cleanup(test_plugin_name) @pytest.mark.asyncio @@ -189,12 +229,41 @@ async def test_commands_api(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_check_update(app: Quart, authenticated_header: dict): +async def test_check_update( + app: Quart, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, +): + """测试检查更新 API,使用 Mock 避免真实网络调用。""" test_client = app.test_client() + + # Mock 更新检查和网络请求 + async def mock_check_update(*args, **kwargs): + """Mock 更新检查,返回无新版本。""" + return None # None 表示没有新版本 + + async def mock_get_dashboard_version(*args, **kwargs): + """Mock Dashboard 版本获取。""" + from astrbot.core.config.default import VERSION + + return f"v{VERSION}" # 返回当前版本 + + monkeypatch.setattr( + core_lifecycle_td.astrbot_updator, + "check_update", + mock_check_update, + ) + monkeypatch.setattr( + "astrbot.dashboard.routes.update.get_dashboard_version", + mock_get_dashboard_version, + ) + response = await test_client.get("/api/update/check", headers=authenticated_header) assert response.status_code == 200 data = await response.get_json() assert data["status"] == "success" + assert data["data"]["has_new_version"] is False @pytest.mark.asyncio diff --git a/tests/test_main.py b/tests/test_main.py index 0453a51ee..2f879ee43 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -16,6 +16,16 @@ def __init__(self, major, minor): self.major = major self.minor = minor + def __eq__(self, other): + if isinstance(other, tuple): + return (self.major, self.minor) == other[:2] + return (self.major, self.minor) == (other.major, other.minor) + + def __ge__(self, other): + if isinstance(other, tuple): + return (self.major, self.minor) >= other[:2] + return (self.major, self.minor) >= (other.major, other.minor) + def test_check_env(monkeypatch): version_info_correct = _version_info(3, 10) @@ -23,15 +33,51 @@ def test_check_env(monkeypatch): monkeypatch.setattr(sys, "version_info", version_info_correct) with mock.patch("os.makedirs") as mock_makedirs: check_env() - mock_makedirs.assert_any_call("data/config", exist_ok=True) - mock_makedirs.assert_any_call("data/plugins", exist_ok=True) - mock_makedirs.assert_any_call("data/temp", exist_ok=True) + # Check that makedirs was called with paths containing expected dirs + called_paths = [call[0][0] for call in mock_makedirs.call_args_list] + # Use os.path.join for cross-platform path matching + assert any(p.rstrip(os.sep).endswith(os.path.join("data", "config")) for p in called_paths) + assert any(p.rstrip(os.sep).endswith(os.path.join("data", "plugins")) for p in called_paths) + assert any(p.rstrip(os.sep).endswith(os.path.join("data", "temp")) for p in called_paths) monkeypatch.setattr(sys, "version_info", version_info_wrong) with pytest.raises(SystemExit): check_env() +def test_version_info_comparisons(): + """Test _version_info comparison operators with tuples and other instances.""" + v3_10 = _version_info(3, 10) + v3_9 = _version_info(3, 9) + v3_11 = _version_info(3, 11) + + # Test __eq__ with tuples + assert v3_10 == (3, 10) + assert v3_10 != (3, 9) + assert v3_9 == (3, 9) + + # Test __ge__ with tuples + assert v3_10 >= (3, 10) + assert v3_10 >= (3, 9) + assert not (v3_9 >= (3, 10)) + assert v3_11 >= (3, 10) + + # Test __eq__ with other _version_info instances + assert v3_10 == _version_info(3, 10) + assert v3_10 != v3_9 + assert v3_10 == v3_10 # Same instance + + assert v3_10 != v3_11 + + # Test __ge__ with other _version_info instances + assert v3_10 >= v3_10 + assert v3_10 >= v3_9 + assert not (v3_9 >= v3_10) + assert v3_11 >= v3_10 + + assert v3_11 >= v3_11 # Same instance + + @pytest.mark.asyncio async def test_check_dashboard_files_not_exists(monkeypatch): """Tests dashboard download when files do not exist.""" diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 1e4cd866a..b91e25c01 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -1,65 +1,164 @@ -import os +import sys from asyncio import Queue +from pathlib import Path from unittest.mock import MagicMock import pytest +import pytest_asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.star.context import Context -from astrbot.core.star.star import star_registry +from astrbot.core.star.star import star_map, star_registry from astrbot.core.star.star_handler import star_handlers_registry from astrbot.core.star.star_manager import PluginManager -@pytest.fixture -def plugin_manager_pm(tmp_path): - """Provides a fully isolated PluginManager instance for testing. - - Uses a temporary directory for plugins. - - Uses a temporary database. - - Creates a fresh context for each test. - """ - # Create temporary resources - temp_plugins_path = tmp_path / "plugins" - temp_plugins_path.mkdir() - temp_db_path = tmp_path / "test_db.db" +def _clear_module_cache() -> None: + """Clear module cache for data module tree to ensure test isolation.""" + modules_to_remove = [ + key for key in sys.modules if key == "data" or key.startswith("data.") + ] + for key in modules_to_remove: + del sys.modules[key] + + +def _clear_registry(plugin_name: str) -> None: + """Clear plugin from global registries.""" + # Clear star_registry (list) + star_registry[:] = [md for md in star_registry if md.name != plugin_name] + # Clear star_map (dict) + keys_to_remove = [ + key for key, md in star_map.items() if md.name == plugin_name + ] + for key in keys_to_remove: + del star_map[key] + # Clear star_handlers_registry (StarHandlerRegistry) + for handler in list(star_handlers_registry): + if plugin_name in (handler.handler_module_path or ""): + star_handlers_registry.remove(handler) + +TEST_PLUGIN_REPO = "https://github.com/Soulter/helloworld" +TEST_PLUGIN_DIR = "helloworld" +TEST_PLUGIN_NAME = "helloworld" + + +def _write_local_test_plugin(plugin_dir: Path, repo_url: str) -> None: + plugin_dir.mkdir(parents=True, exist_ok=True) + (plugin_dir / "metadata.yaml").write_text( + "\n".join( + [ + f"name: {TEST_PLUGIN_NAME}", + "author: AstrBot Team", + "desc: Local test plugin", + "version: 1.0.0", + f"repo: {repo_url}", + ], + ) + + "\n", + encoding="utf-8", + ) + (plugin_dir / "main.py").write_text( + "\n".join( + [ + "from astrbot.api import star", + "", + "class Main(star.Star):", + " pass", + "", + ], + ), + encoding="utf-8", + ) + + +@pytest_asyncio.fixture +async def plugin_manager_pm(tmp_path, monkeypatch): + """Provides a fully isolated PluginManager instance for testing.""" + # Clear module cache before setup to ensure isolation + _clear_module_cache() + + test_root = tmp_path / "astrbot_root" + data_dir = test_root / "data" + plugin_dir = data_dir / "plugins" + config_dir = data_dir / "config" + temp_dir = data_dir / "temp" + for path in (plugin_dir, config_dir, temp_dir): + path.mkdir(parents=True, exist_ok=True) + + # Ensure `import data.plugins..main` resolves to this temp root. + (data_dir / "__init__.py").write_text("", encoding="utf-8") + (plugin_dir / "__init__.py").write_text("", encoding="utf-8") + + # Use monkeypatch for both env var and sys.path to ensure proper cleanup + monkeypatch.setenv("ASTRBOT_ROOT", str(test_root)) + monkeypatch.syspath_prepend(str(test_root)) # Create fresh, isolated instances for the context event_queue = Queue() config = AstrBotConfig() - db = SQLiteDatabase(str(temp_db_path)) - - # Set the plugin store path in the config to the temporary directory - config.plugin_store_path = str(temp_plugins_path) + db = SQLiteDatabase(str(data_dir / "test_db.db")) + config.plugin_store_path = str(plugin_dir) - # Mock dependencies for the context provider_manager = MagicMock() platform_manager = MagicMock() conversation_manager = MagicMock() message_history_manager = MagicMock() persona_manager = MagicMock() + persona_manager.personas_v3 = [] astrbot_config_mgr = MagicMock() knowledge_base_manager = MagicMock() + cron_manager = MagicMock() star_context = Context( - event_queue, - config, - db, - provider_manager, - platform_manager, - conversation_manager, - message_history_manager, - persona_manager, - astrbot_config_mgr, + event_queue=event_queue, + config=config, + db=db, + provider_manager=provider_manager, + platform_manager=platform_manager, + conversation_manager=conversation_manager, + message_history_manager=message_history_manager, + persona_manager=persona_manager, + astrbot_config_mgr=astrbot_config_mgr, knowledge_base_manager=knowledge_base_manager, + cron_manager=cron_manager, + subagent_orchestrator=None, ) - # Create the PluginManager instance manager = PluginManager(star_context, config) - return manager + try: + yield manager + finally: + # Cleanup global registries and module cache + _clear_registry(TEST_PLUGIN_NAME) + _clear_module_cache() + await db.engine.dispose() + + +@pytest.fixture +def local_updator(plugin_manager_pm: PluginManager, monkeypatch): + plugin_path = Path(plugin_manager_pm.plugin_store_path) / TEST_PLUGIN_DIR + async def mock_install(repo_url: str, proxy=""): # noqa: ARG001 + if repo_url != TEST_PLUGIN_REPO: + raise Exception("Repo not found") + _write_local_test_plugin(plugin_path, repo_url) + return str(plugin_path) -def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): + async def mock_update(plugin, proxy=""): # noqa: ARG001 + if plugin.name != TEST_PLUGIN_NAME: + raise Exception("Plugin not found") + if not plugin_path.exists(): + raise Exception("Plugin path missing") + (plugin_path / ".updated").write_text("ok", encoding="utf-8") + + monkeypatch.setattr(plugin_manager_pm.updator, "install", mock_install) + monkeypatch.setattr(plugin_manager_pm.updator, "update", mock_update) + return plugin_path + + +@pytest.mark.asyncio +async def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): assert plugin_manager_pm is not None assert plugin_manager_pm.context is not None assert plugin_manager_pm.config is not None @@ -73,73 +172,59 @@ async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_install_plugin(plugin_manager_pm: PluginManager): - """Tests successful plugin installation in an isolated environment.""" - test_repo = "https://github.com/Soulter/astrbot_plugin_essential" - plugin_info = await plugin_manager_pm.install_plugin(test_repo) - plugin_path = os.path.join( - plugin_manager_pm.plugin_store_path, - "astrbot_plugin_essential", - ) - +async def test_install_plugin(plugin_manager_pm: PluginManager, local_updator: Path): + """Tests successful plugin installation without external network.""" + plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) assert plugin_info is not None - assert os.path.exists(plugin_path) - assert any(md.name == "astrbot_plugin_essential" for md in star_registry), ( - "Plugin 'astrbot_plugin_essential' was not loaded into star_registry." - ) + assert plugin_info["name"] == TEST_PLUGIN_NAME + assert local_updator.exists() + assert any(md.name == TEST_PLUGIN_NAME for md in star_registry) @pytest.mark.asyncio -async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_install_nonexistent_plugin( + plugin_manager_pm: PluginManager, local_updator +): """Tests that installing a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.install_plugin( - "https://github.com/Soulter/non_existent_repo", + "https://github.com/Soulter/non_existent_repo" ) @pytest.mark.asyncio -async def test_update_plugin(plugin_manager_pm: PluginManager): - """Tests updating an existing plugin in an isolated environment.""" - # First, install the plugin - test_repo = "https://github.com/Soulter/astrbot_plugin_essential" - await plugin_manager_pm.install_plugin(test_repo) - - # Then, update it - await plugin_manager_pm.update_plugin("astrbot_plugin_essential") +async def test_update_plugin(plugin_manager_pm: PluginManager, local_updator: Path): + """Tests updating an existing plugin without external network.""" + plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) + assert plugin_info is not None + plugin_name = plugin_info["name"] + await plugin_manager_pm.update_plugin(plugin_name) + assert (local_updator / ".updated").exists() @pytest.mark.asyncio -async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_update_nonexistent_plugin( + plugin_manager_pm: PluginManager, local_updator +): """Tests that updating a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.update_plugin("non_existent_plugin") @pytest.mark.asyncio -async def test_uninstall_plugin(plugin_manager_pm: PluginManager): - """Tests successful plugin uninstallation in an isolated environment.""" - # First, install the plugin - test_repo = "https://github.com/Soulter/astrbot_plugin_essential" - await plugin_manager_pm.install_plugin(test_repo) - plugin_path = os.path.join( - plugin_manager_pm.plugin_store_path, - "astrbot_plugin_essential", - ) - assert os.path.exists(plugin_path) # Pre-condition +async def test_uninstall_plugin(plugin_manager_pm: PluginManager, local_updator: Path): + """Tests successful plugin uninstallation.""" + plugin_info = await plugin_manager_pm.install_plugin(TEST_PLUGIN_REPO) + assert plugin_info is not None + plugin_name = plugin_info["name"] + assert local_updator.exists() - # Then, uninstall it - await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential") + await plugin_manager_pm.uninstall_plugin(plugin_name) - assert not os.path.exists(plugin_path) - assert not any(md.name == "astrbot_plugin_essential" for md in star_registry), ( - "Plugin 'astrbot_plugin_essential' was not unloaded from star_registry." - ) + assert not local_updator.exists() + assert not any(md.name == TEST_PLUGIN_NAME for md in star_registry) assert not any( - "astrbot_plugin_essential" in md.handler_module_path - for md in star_handlers_registry - ), ( - "Plugin 'astrbot_plugin_essential' handler was not unloaded from star_handlers_registry." + TEST_PLUGIN_NAME in md.handler_module_path for md in star_handlers_registry ) diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 4474e1599..36870e617 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -101,10 +101,16 @@ def test_pipeline_import_is_stable_with_mocked_apscheduler() -> None: "mock_apscheduler.schedulers = MagicMock();" "mock_apscheduler.schedulers.asyncio = MagicMock();" "mock_apscheduler.schedulers.background = MagicMock();" + "mock_apscheduler.triggers = MagicMock();" + "mock_apscheduler.triggers.cron = MagicMock();" + "mock_apscheduler.triggers.date = MagicMock();" "sys.modules['apscheduler'] = mock_apscheduler;" "sys.modules['apscheduler.schedulers'] = mock_apscheduler.schedulers;" "sys.modules['apscheduler.schedulers.asyncio'] = mock_apscheduler.schedulers.asyncio;" "sys.modules['apscheduler.schedulers.background'] = mock_apscheduler.schedulers.background;" + "sys.modules['apscheduler.triggers'] = mock_apscheduler.triggers;" + "sys.modules['apscheduler.triggers.cron'] = mock_apscheduler.triggers.cron;" + "sys.modules['apscheduler.triggers.date'] = mock_apscheduler.triggers.date;" "import astrbot.core.pipeline as pipeline;" "assert pipeline.ProcessStage is not None;" "assert pipeline.RespondStage is not None" diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 0b5190407..c738cfc80 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -447,7 +447,8 @@ async def test_stop_signal_returns_aborted_and_persists_partial_message( final_resp = runner.get_final_llm_resp() assert final_resp is not None assert final_resp.role == "assistant" - assert final_resp.completion_text == "partial " + # When interrupted, the runner replaces completion_text with a system message + assert "interrupted" in final_resp.completion_text.lower() assert runner.run_context.messages[-1].role == "assistant"