Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 256 additions & 1 deletion tests/fixtures/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
提供统一的测试辅助工具,减少测试代码重复。
"""

from typing import Any
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
from unittest.mock import AsyncMock, MagicMock

from astrbot.core.message.components import BaseMessageComponent
Expand Down Expand Up @@ -330,3 +333,255 @@ def create_mock_llm_response(
tools_call_ids=tools_call_ids or [],
usage=TokenUsage(input_other=10, output=5),
)


# ============================================================
# 测试插件辅助函数
# ============================================================


@dataclass
class MockPluginConfig:
"""测试插件配置。
用于创建和管理测试用的模拟插件。
Attributes:
name: 插件名称
author: 作者
description: 描述
version: 版本
repo: 仓库 URL
main_code: main.py 的代码内容
requirements: 依赖列表
has_readme: 是否创建 README.md
readme_content: README.md 内容
"""

name: str = "test_plugin"
author: str = "Test Author"
description: str = "A test plugin for unit testing"
version: str = "1.0.0"
repo: str = "https://github.com/test/test_plugin"
main_code: str = ""
requirements: list[str] = field(default_factory=list)
has_readme: bool = True
readme_content: str = "# Test Plugin\n\nThis is a test plugin."


# 默认的插件主代码模板
DEFAULT_PLUGIN_MAIN_TEMPLATE = '''
from astrbot.api import star
class Main(star.Star):
"""测试插件主类。"""
def __init__(self, context):
super().__init__(context)
self.name = "{plugin_name}"
async def initialize(self):
"""初始化插件。"""
pass
async def terminate(self):
"""终止插件。"""
pass
'''


class MockPluginBuilder:
"""测试插件构建器。
用于创建、管理和清理测试用的模拟插件。支持任意插件的模拟创建。
Example:
# 创建一个简单的测试插件
builder = MockPluginBuilder(plugin_store_path)
plugin_dir = builder.create("my_test_plugin")
# 创建自定义配置的插件
config = MockPluginConfig(
name="custom_plugin",
version="2.0.0",
main_code="print('hello')",
)
plugin_dir = builder.create(config)
# 清理插件
builder.cleanup("my_test_plugin")
"""

def __init__(self, plugin_store_path: str | Path):
"""初始化构建器。
Args:
plugin_store_path: 插件存储路径 (通常是 data/plugins)
"""
self.plugin_store_path = Path(plugin_store_path)
self._created_plugins: set[str] = set()

def create(
self,
plugin_config: str | MockPluginConfig | None = None,
**kwargs,
) -> Path:
"""创建模拟插件。
Args:
plugin_config: 插件名称字符串、MockPluginConfig 对象或 None
**kwargs: 如果 plugin_config 是字符串或 None,这些参数用于构建 MockPluginConfig
Returns:
Path: 创建的插件目录路径
"""
# 处理不同类型的输入
if plugin_config is None:
config = MockPluginConfig(**kwargs)
elif isinstance(plugin_config, str):
config = MockPluginConfig(name=plugin_config, **kwargs)
elif isinstance(plugin_config, MockPluginConfig):
config = plugin_config
else:
raise TypeError(f"Invalid plugin_config type: {type(plugin_config)}")

# 创建插件目录
plugin_dir = self.plugin_store_path / config.name
plugin_dir.mkdir(parents=True, exist_ok=True)

# 创建 metadata.yaml
metadata_content = "\n".join(
[
f"name: {config.name}",
f"author: {config.author}",
f"desc: {config.description}",
f"version: {config.version}",
f"repo: {config.repo}",
]
)
(plugin_dir / "metadata.yaml").write_text(
metadata_content + "\n", encoding="utf-8"
)

# 创建 main.py
main_code = config.main_code or DEFAULT_PLUGIN_MAIN_TEMPLATE.format(
plugin_name=config.name
)
(plugin_dir / "main.py").write_text(main_code, encoding="utf-8")

# 创建 requirements.txt(如果有依赖)
if config.requirements:
(plugin_dir / "requirements.txt").write_text(
"\n".join(config.requirements) + "\n", encoding="utf-8"
)

# 创建 README.md(如果需要)
if config.has_readme:
(plugin_dir / "README.md").write_text(
config.readme_content, encoding="utf-8"
)

# 记录创建的插件
self._created_plugins.add(config.name)

return plugin_dir

def cleanup(self, plugin_name: str | None = None) -> None:
"""清理插件。
Args:
plugin_name: 要清理的插件名称,如果为 None 则清理所有由本构建器创建的插件
"""
if plugin_name:
plugins_to_clean = {plugin_name}
else:
plugins_to_clean = self._created_plugins.copy()

for name in plugins_to_clean:
plugin_dir = self.plugin_store_path / name
if plugin_dir.exists():
shutil.rmtree(plugin_dir)
self._created_plugins.discard(name)

def cleanup_all(self) -> None:
"""清理所有由本构建器创建的插件。"""
self.cleanup(None)

def get_plugin_path(self, plugin_name: str) -> Path:
"""获取插件路径。
Args:
plugin_name: 插件名称
Returns:
Path: 插件目录路径
"""
return self.plugin_store_path / plugin_name

@property
def created_plugins(self) -> set[str]:
"""获取已创建的插件名称集合。"""
return self._created_plugins.copy()


def create_mock_updater_install(
plugin_builder: MockPluginBuilder,
repo_to_plugin: dict[str, str] | None = None,
) -> Callable:
"""创建模拟的 updater.install 方法。
Args:
plugin_builder: MockPluginBuilder 实例
repo_to_plugin: 仓库 URL 到插件名称的映射,格式: {"https://github.com/user/repo": "plugin_name"}
Returns:
Callable: 异步函数,可用于 monkeypatch.setattr
"""

async def mock_install(repo_url: str, proxy: str = "") -> str:
"""Mock updater.install 方法。"""
# 查找插件名称
plugin_name = None
if repo_to_plugin:
plugin_name = repo_to_plugin.get(repo_url)

# 如果没有映射,尝试从 URL 提取插件名
if not plugin_name:
# 从 https://github.com/user/plugin_name 提取 plugin_name
parts = repo_url.rstrip("/").split("/")
plugin_name = parts[-1] if parts else "unknown_plugin"

# 创建插件目录
config = MockPluginConfig(name=plugin_name, repo=repo_url)
plugin_dir = plugin_builder.create(config)
return str(plugin_dir)

return mock_install


def create_mock_updater_update(
plugin_builder: MockPluginBuilder,
update_callback: Callable | None = None,
) -> Callable:
"""创建模拟的 updater.update 方法。
Args:
plugin_builder: MockPluginBuilder 实例
update_callback: 更新回调函数,接收 plugin 参数
Returns:
Callable: 异步函数,可用于 monkeypatch.setattr
"""

async def mock_update(plugin, proxy: str = "") -> None:
"""Mock updater.update 方法。"""
plugin_dir = plugin_builder.get_plugin_path(plugin.name)

# 创建更新标记文件
(plugin_dir / ".updated").write_text("ok", encoding="utf-8")

# 调用回调
if update_callback:
update_callback(plugin)

return mock_update
21 changes: 11 additions & 10 deletions tests/test_api_key_open_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ async def fake_chat(post_data: dict | None = None):
"/api/v1/chat",
json={
"message": "hello",
"username": "alice",
"username": "alice_auto_session",
"enable_streaming": False,
},
headers={"X-API-Key": raw_key},
Expand All @@ -200,16 +200,16 @@ async def fake_chat(post_data: dict | None = None):
created_session_id = send_data["data"]["session_id"]
assert isinstance(created_session_id, str)
uuid.UUID(created_session_id)
assert send_data["data"]["creator"] == "alice"
assert send_data["data"]["creator"] == "alice_auto_session"
created_session = await core_lifecycle_td.db.get_platform_session_by_id(
created_session_id
)
assert created_session is not None
assert created_session.creator == "alice"
assert created_session.creator == "alice_auto_session"
assert created_session.platform_id == "webchat"

await core_lifecycle_td.db.create_platform_session(
creator="bob",
creator="bob_auto_session",
platform_id="webchat",
session_id="open_api_existing_bob_session",
is_group=0,
Expand Down Expand Up @@ -251,14 +251,15 @@ async def test_open_chat_sessions_pagination(

create_res = await test_client.post(
"/api/apikey/create",
json={"name": "chat-scope-key", "scopes": ["chat"]},
json={"name": "chat-scope-key-pagination", "scopes": ["chat"]},
headers=authenticated_header,
)
create_data = await create_res.get_json()
assert create_data["status"] == "ok"
raw_key = create_data["data"]["api_key"]

creator = "alice"
# Use unique session IDs to avoid conflicts with other tests
creator = "alice_pagination"
for idx in range(3):
await core_lifecycle_td.db.create_platform_session(
creator=creator,
Expand All @@ -268,15 +269,15 @@ async def test_open_chat_sessions_pagination(
is_group=0,
)
await core_lifecycle_td.db.create_platform_session(
creator="bob",
creator="bob_pagination",
platform_id="webchat",
session_id="open_api_paginated_bob",
display_name="Open API Session Bob",
is_group=0,
)

page_1_res = await test_client.get(
"/api/v1/chat/sessions?page=1&page_size=2&username=alice",
"/api/v1/chat/sessions?page=1&page_size=2&username=alice_pagination",
headers={"X-API-Key": raw_key},
)
assert page_1_res.status_code == 200
Expand All @@ -286,10 +287,10 @@ async def test_open_chat_sessions_pagination(
assert page_1_data["data"]["page_size"] == 2
assert page_1_data["data"]["total"] == 3
assert len(page_1_data["data"]["sessions"]) == 2
assert all(item["creator"] == "alice" for item in page_1_data["data"]["sessions"])
assert all(item["creator"] == "alice_pagination" for item in page_1_data["data"]["sessions"])

page_2_res = await test_client.get(
"/api/v1/chat/sessions?page=2&page_size=2&username=alice",
"/api/v1/chat/sessions?page=2&page_size=2&username=alice_pagination",
headers={"X-API-Key": raw_key},
)
assert page_2_res.status_code == 200
Expand Down
Loading