From 5910274d8ba169d8cd8f76201fb201b2b8b628c2 Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Mon, 23 Feb 2026 08:40:42 +0800 Subject: [PATCH 1/4] test: refactor and enhance P1 platform adapter tests - Refactor Telegram adapter tests to use shared mocks - Refactor Discord adapter tests to use shared mocks - Refactor Aiocqhttp adapter tests to use shared mocks - Fix sender name handling and command registration Co-Authored-By: Claude Sonnet 4.6 --- .../aiocqhttp/aiocqhttp_message_event.py | 20 +- .../aiocqhttp/aiocqhttp_platform_adapter.py | 90 +- .../discord/discord_platform_adapter.py | 118 +- .../platform/sources/telegram/tg_adapter.py | 105 +- .../platform/sources/telegram/tg_event.py | 14 +- tests/unit/test_aiocqhttp_adapter.py | 846 +++++++ tests/unit/test_discord_adapter.py | 1100 +++++++++ tests/unit/test_telegram_adapter.py | 2053 +++++++++++++++++ 8 files changed, 4214 insertions(+), 132 deletions(-) create mode 100644 tests/unit/test_aiocqhttp_adapter.py create mode 100644 tests/unit/test_discord_adapter.py create mode 100644 tests/unit/test_telegram_adapter.py diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 99ea727315..06b9f3ad72 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -2,10 +2,9 @@ import re from collections.abc import AsyncGenerator -from aiocqhttp import CQHttp, Event +import aiocqhttp -from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import ( +from astrbot.core.message.components import ( BaseMessageComponent, File, Image, @@ -15,7 +14,8 @@ Record, Video, ) -from astrbot.api.platform import Group, MessageMember +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform import AstrMessageEvent, Group, MessageMember class AiocqhttpMessageEvent(AstrMessageEvent): @@ -25,7 +25,7 @@ def __init__( message_obj, platform_meta, session_id, - bot: CQHttp, + bot: aiocqhttp.CQHttp, ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot @@ -67,8 +67,8 @@ async def _parse_onebot_json(message_chain: MessageChain): @classmethod async def _dispatch_send( cls, - bot: CQHttp, - event: Event | None, + bot: aiocqhttp.CQHttp, + event: aiocqhttp.Event | None, is_group: bool, session_id: str | None, messages: list[dict], @@ -82,7 +82,7 @@ async def _dispatch_send( await bot.send_group_msg(group_id=session_id_int, message=messages) elif not is_group and isinstance(session_id_int, int): await bot.send_private_msg(user_id=session_id_int, message=messages) - elif isinstance(event, Event): # 最后兜底 + elif isinstance(event, aiocqhttp.Event): # 最后兜底 await bot.send(event=event, message=messages) else: raise ValueError( @@ -92,9 +92,9 @@ async def _dispatch_send( @classmethod async def send_message( cls, - bot: CQHttp, + bot: aiocqhttp.CQHttp, message_chain: MessageChain, - event: Event | None = None, + event: aiocqhttp.Event | None = None, is_group: bool = False, session_id: str | None = None, ) -> None: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index fb6c997848..c552faad0a 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -1,28 +1,37 @@ import asyncio +import importlib import itertools import logging import time import uuid -from collections.abc import Awaitable +from collections.abc import Awaitable, Callable from typing import Any, cast -from aiocqhttp import CQHttp, Event +import aiocqhttp from aiocqhttp.exceptions import ActionFailed -from astrbot.api import logger -from astrbot.api.event import MessageChain -from astrbot.api.message_components import * -from astrbot.api.platform import ( +from astrbot import logger +from astrbot.core.message.components import ( + At, + ComponentTypes, + File, + Image, + Plain, + Poke, + Reply, +) +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform import ( AstrBotMessage, + Group, MessageMember, MessageType, Platform, PlatformMetadata, ) from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.platform.register import register_platform_adapter -from ...register import register_platform_adapter -from .aiocqhttp_message_event import * from .aiocqhttp_message_event import AiocqhttpMessageEvent @@ -37,6 +46,7 @@ def __init__( platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue, + bot_factory: Callable[..., Any] | None = None, ) -> None: super().__init__(platform_config, event_queue) @@ -51,17 +61,10 @@ def __init__( support_streaming_message=False, ) - self.bot = CQHttp( - use_ws_reverse=True, - import_name="aiocqhttp", - api_timeout_sec=180, - access_token=platform_config.get( - "ws_reverse_token", - ), # 以防旧版本配置不存在 - ) + self.bot = self._create_bot(platform_config, bot_factory=bot_factory) @self.bot.on_request() - async def request(event: Event) -> None: + async def request(event: aiocqhttp.Event) -> None: try: abm = await self.convert_message(event) if not abm: @@ -72,7 +75,7 @@ async def request(event: Event) -> None: return @self.bot.on_notice() - async def notice(event: Event) -> None: + async def notice(event: aiocqhttp.Event) -> None: try: abm = await self.convert_message(event) if abm: @@ -82,7 +85,7 @@ async def notice(event: Event) -> None: return @self.bot.on_message("group") - async def group(event: Event) -> None: + async def group(event: aiocqhttp.Event) -> None: try: abm = await self.convert_message(event) if abm: @@ -92,7 +95,7 @@ async def group(event: Event) -> None: return @self.bot.on_message("private") - async def private(event: Event) -> None: + async def private(event: aiocqhttp.Event) -> None: try: abm = await self.convert_message(event) if abm: @@ -105,6 +108,29 @@ async def private(event: Event) -> None: def on_websocket_connection(_) -> None: logger.info("aiocqhttp(OneBot v11) 适配器已连接。") + @staticmethod + def _create_bot( + platform_config: dict, + bot_factory: Callable[..., Any] | None = None, + ) -> aiocqhttp.CQHttp: + if bot_factory is None: + # Resolve aiocqhttp at runtime so tests that swap sys.modules later + # still affect bot creation even if this module was imported earlier. + aiocqhttp_module = importlib.import_module("aiocqhttp") + bot_factory = aiocqhttp_module.CQHttp + + return cast( + aiocqhttp.CQHttp, + bot_factory( + use_ws_reverse=True, + import_name="aiocqhttp", + api_timeout_sec=180, + access_token=platform_config.get( + "ws_reverse_token", + ), # 以防旧版本配置不存在 + ), + ) + async def send_by_session( self, session: MessageSesion, @@ -124,7 +150,7 @@ async def send_by_session( ) await super().send_by_session(session, message_chain) - async def convert_message(self, event: Event) -> AstrBotMessage | None: + async def convert_message(self, event: aiocqhttp.Event) -> AstrBotMessage | None: logger.debug(f"[aiocqhttp] RawMessage {event}") if event["post_type"] == "message": @@ -139,7 +165,9 @@ async def convert_message(self, event: Event) -> AstrBotMessage | None: return abm - async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage: + async def _convert_handle_request_event( + self, event: aiocqhttp.Event + ) -> AstrBotMessage: """OneBot V11 请求类事件""" abm = AstrBotMessage() abm.self_id = str(event.self_id) @@ -164,7 +192,9 @@ async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage: abm.raw_message = event return abm - async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage: + async def _convert_handle_notice_event( + self, event: aiocqhttp.Event + ) -> AstrBotMessage: """OneBot V11 通知类事件""" abm = AstrBotMessage() abm.self_id = str(event.self_id) @@ -196,7 +226,7 @@ async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage: async def _convert_handle_message_event( self, - event: Event, + event: aiocqhttp.Event, get_reply=True, ) -> AstrBotMessage: """OneBot V11 消息类事件 @@ -309,7 +339,7 @@ async def _convert_handle_message_event( ) # 添加必要的 post_type 字段,防止 Event.from_payload 报错 reply_event_data["post_type"] = "message" - new_event = Event.from_payload(reply_event_data) + new_event = aiocqhttp.Event.from_payload(reply_event_data) if not new_event: logger.error( f"无法从回复消息数据构造 Event 对象: {reply_event_data}", @@ -401,6 +431,14 @@ async def _convert_handle_message_event( f"不支持的消息段类型,已忽略: {t}, data={m['data']}" ) continue + if ( + t == "image" + and not m["data"].get("file") + and m["data"].get("url") + ): + a = Image(file=m["data"]["url"], url=m["data"]["url"]) + abm.message.append(a) + continue a = ComponentTypes[t](**m["data"]) abm.message.append(a) except Exception as e: @@ -456,5 +494,5 @@ async def handle_msg(self, message: AstrBotMessage) -> None: self.commit_event(message_event) - def get_client(self) -> CQHttp: + def get_client(self) -> aiocqhttp.CQHttp: return self.bot diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 7657962a11..fa480fc1cd 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -366,6 +366,7 @@ async def _collect_and_register_commands(self) -> None: """收集所有指令并注册到Discord""" logger.info("[Discord] 开始收集并注册斜杠指令...") registered_commands = [] + registered_command_names: set[str] = set() for handler_md in star_handlers_registry: if not star_map[handler_md.handler_module_path].activated: @@ -373,35 +374,40 @@ async def _collect_and_register_commands(self) -> None: if not handler_md.enabled: continue for event_filter in handler_md.event_filters: - cmd_info = self._extract_command_info(event_filter, handler_md) - if not cmd_info: - continue - - cmd_name, description, cmd_filter_instance = cmd_info - - # 创建动态回调 - callback = self._create_dynamic_callback(cmd_name) - - # 创建一个通用的参数选项来接收所有文本输入 - options = [ - discord.Option( - name="params", - description="指令的所有参数", - type=discord.SlashCommandOptionType.string, - required=False, - ), - ] - - # 创建SlashCommand - slash_command = discord.SlashCommand( - name=cmd_name, - description=description, - func=callback, - options=options, - guild_ids=[self.guild_id] if self.guild_id else None, - ) - self.client.add_application_command(slash_command) - registered_commands.append(cmd_name) + cmd_infos = self._extract_command_infos(event_filter, handler_md) + for cmd_name, description in cmd_infos: + if cmd_name in registered_command_names: + logger.warning( + "[Discord] Duplicate slash command '%s' from %s ignored.", + cmd_name, + handler_md.handler_module_path, + ) + continue + + # 创建动态回调 + callback = self._create_dynamic_callback(cmd_name) + + # 创建一个通用的参数选项来接收所有文本输入 + options = [ + discord.Option( + name="params", + description="指令的所有参数", + type=discord.SlashCommandOptionType.string, + required=False, + ), + ] + + # 创建SlashCommand + slash_command = discord.SlashCommand( + name=cmd_name, + description=description, + func=callback, + options=options, + guild_ids=[self.guild_id] if self.guild_id else None, + ) + self.client.add_application_command(slash_command) + registered_command_names.add(cmd_name) + registered_commands.append(cmd_name) if registered_commands: logger.info( @@ -478,11 +484,23 @@ async def dynamic_callback( def _extract_command_info( event_filter: Any, handler_metadata: StarHandlerMetadata, - ) -> tuple[str, str, CommandFilter | None] | None: + ) -> tuple[str, str] | None: + infos = DiscordPlatformAdapter._extract_command_infos( + event_filter, + handler_metadata, + ) + if not infos: + return None + return infos[0] + + @staticmethod + def _extract_command_infos( + event_filter: Any, + handler_metadata: StarHandlerMetadata, + ) -> list[tuple[str, str]]: """从事件过滤器中提取指令信息""" - cmd_name = None - # is_group = False - cmd_filter_instance = None + primary_name = None + alias_names: list[str] = [] if isinstance(event_filter, CommandFilter): # 暂不支持子指令注册为斜杠指令 @@ -490,24 +508,32 @@ def _extract_command_info( event_filter.parent_command_names and event_filter.parent_command_names != [""] ): - return None - cmd_name = event_filter.command_name - cmd_filter_instance = event_filter + return [] + primary_name = event_filter.command_name + alias_names = sorted(getattr(event_filter, "alias", set()) or set()) elif isinstance(event_filter, CommandGroupFilter): # 暂不支持指令组直接注册为斜杠指令,因为它们没有 handle 方法 - return None - - if not cmd_name: - return None + return [] - # Discord 斜杠指令名称规范 - if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name): - logger.debug(f"[Discord] 跳过不符合规范的指令: {cmd_name}") - return None + if not primary_name: + return [] - description = handler_metadata.desc or f"指令: {cmd_name}" + description = handler_metadata.desc or f"指令: {primary_name}" if len(description) > 100: description = f"{description[:97]}..." - return cmd_name, description, cmd_filter_instance + results: list[tuple[str, str]] = [] + seen: set[str] = set() + for cmd_name in [primary_name, *alias_names]: + if not cmd_name or cmd_name in seen: + continue + seen.add(cmd_name) + # Discord slash command names allow lowercase letters, numbers, underscores and hyphens. + if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name): + logger.warning( + f"[Discord] Skipped invalid command name (must match ^[a-z0-9_-]{{1,32}}$): {cmd_name}" + ) + continue + results.append((cmd_name, description)) + return results diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 03ef26c1ec..e2000c13fb 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -2,25 +2,25 @@ import re import sys import uuid +from typing import Any -from apscheduler.schedulers.asyncio import AsyncIOScheduler +import apscheduler.schedulers.asyncio as _apscheduler_asyncio_import +import telegram.ext as _telegram_ext_import from telegram import BotCommand, Update from telegram.constants import ChatType -from telegram.ext import ApplicationBuilder, ContextTypes, ExtBot, filters -from telegram.ext import MessageHandler as TelegramMessageHandler -import astrbot.api.message_components as Comp -from astrbot.api import logger -from astrbot.api.event import MessageChain -from astrbot.api.platform import ( +import astrbot.core.message.components as Comp +from astrbot import logger +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform import ( AstrBotMessage, MessageMember, MessageType, Platform, PlatformMetadata, - register_platform_adapter, ) from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.platform.register import register_platform_adapter from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter from astrbot.core.star.star import star_map @@ -28,6 +28,12 @@ from .tg_event import TelegramPlatformEvent +telegram_ext = sys.modules.get("telegram.ext", _telegram_ext_import) +apscheduler_asyncio = sys.modules.get( + "apscheduler.schedulers.asyncio", + _apscheduler_asyncio_import, +) + if sys.version_info >= (3, 12): from typing import override else: @@ -73,21 +79,21 @@ def __init__( self.last_command_hash = None self.application = ( - ApplicationBuilder() + telegram_ext.ApplicationBuilder() .token(self.config["telegram_token"]) .base_url(base_url) .base_file_url(file_base_url) .build() ) - message_handler = TelegramMessageHandler( - filters=filters.ALL, # receive all messages + message_handler = telegram_ext.MessageHandler( + filters=telegram_ext.filters.ALL, # receive all messages callback=self.message_handler, ) self.application.add_handler(message_handler) self.client = self.application.bot logger.debug(f"Telegram base url: {self.client.base_url}") - self.scheduler = AsyncIOScheduler() + self.scheduler = apscheduler_asyncio.AsyncIOScheduler() # Media group handling # Cache structure: {media_group_id: {"created_at": datetime, "items": [(update, context), ...]}} @@ -149,15 +155,20 @@ async def register_commands(self) -> None: try: commands = self.collect_commands() - if commands: - current_hash = hash( - tuple((cmd.command, cmd.description) for cmd in commands), + current_hash = hash( + tuple((cmd.command, cmd.description) for cmd in commands), + ) + if current_hash == self.last_command_hash: + return + self.last_command_hash = current_hash + if not commands: + logger.info( + "[Telegram] No commands collected. Keep existing Telegram commands unchanged." ) - if current_hash == self.last_command_hash: - return - self.last_command_hash = current_hash - await self.client.delete_my_commands() - await self.client.set_my_commands(commands) + return + + await self.client.delete_my_commands() + await self.client.set_my_commands(commands) except Exception as e: logger.error(f"向 Telegram 注册指令时发生错误: {e!s}") @@ -179,14 +190,17 @@ def collect_commands(self) -> list[BotCommand]: handler_metadata, skip_commands, ) - if cmd_info_list: - for cmd_name, description in cmd_info_list: - if cmd_name in command_dict: - logger.warning( - f"命令名 '{cmd_name}' 重复注册,将使用首次注册的定义: " - f"'{command_dict[cmd_name]}'" - ) - command_dict.setdefault(cmd_name, description) + if not cmd_info_list: + continue + + for cmd_name, description in cmd_info_list: + if cmd_name in command_dict: + logger.warning( + "[Telegram] Duplicate command name '%s' will use first registered definition: '%s'", + cmd_name, + command_dict[cmd_name], + ) + command_dict.setdefault(cmd_name, description) commands_a = sorted(command_dict.keys()) return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a] @@ -197,8 +211,8 @@ def _extract_command_info( handler_metadata, skip_commands: set, ) -> list[tuple[str, str]] | None: - """从事件过滤器中提取指令信息,包括所有别名""" - cmd_names = [] + """从事件过滤器中提取指令信息""" + cmd_names: list[str] = [] is_group = False if isinstance(event_filter, CommandFilter) and event_filter.command_name: if ( @@ -206,9 +220,8 @@ def _extract_command_info( and event_filter.parent_command_names != [""] ): return None - # 收集主命令名和所有别名 cmd_names = [event_filter.command_name] - if event_filter.alias: + if getattr(event_filter, "alias", None): cmd_names.extend(event_filter.alias) elif isinstance(event_filter, CommandGroupFilter): if event_filter.parent_group: @@ -216,16 +229,20 @@ def _extract_command_info( cmd_names = [event_filter.group_name] is_group = True - result = [] + result: list[tuple[str, str]] = [] for cmd_name in cmd_names: if not cmd_name or cmd_name in skip_commands: continue - if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32: + # Telegram command names must start with a letter and contain only lowercase letters, numbers, and underscores + if not re.match(r"^[a-z][a-z0-9_]{0,31}$", cmd_name): + logger.warning( + f"[Telegram] Skipped invalid command name (must start with letter): {cmd_name}" + ) continue - - # Build description. description = handler_metadata.desc or ( - f"Command group: {cmd_name}" if is_group else f"Command: {cmd_name}" + f"指令组: {cmd_name} (包含多个子指令)" + if is_group + else f"指令: {cmd_name}" ) if len(description) > 30: description = description[:30] + "..." @@ -233,7 +250,7 @@ def _extract_command_info( return result if result else None - async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + async def start(self, update: Update, context: Any) -> None: if not update.effective_chat: logger.warning( "Received a start command without an effective chat, skipping /start reply.", @@ -244,9 +261,7 @@ async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non text=self.config["start_message"], ) - async def message_handler( - self, update: Update, context: ContextTypes.DEFAULT_TYPE - ) -> None: + async def message_handler(self, update: Update, context: Any) -> None: logger.debug(f"Telegram message: {update.message}") # Handle media group messages @@ -262,7 +277,7 @@ async def message_handler( async def convert_message( self, update: Update, - context: ContextTypes.DEFAULT_TYPE, + context: Any, get_reply=True, ) -> AstrBotMessage | None: """转换 Telegram 的消息对象为 AstrBotMessage 对象。 @@ -429,9 +444,7 @@ async def convert_message( return message - async def handle_media_group_message( - self, update: Update, context: ContextTypes.DEFAULT_TYPE - ): + async def handle_media_group_message(self, update: Update, context: Any): """Handle messages that are part of a media group (album). Caches incoming messages and schedules delayed processing to collect all @@ -546,7 +559,7 @@ async def handle_msg(self, message: AstrBotMessage) -> None: ) self.commit_event(message_event) - def get_client(self) -> ExtBot: + def get_client(self): return self.client async def terminate(self) -> None: diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index d7e3f16780..24f204ec29 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -10,8 +10,7 @@ from telegram.ext import ExtBot from astrbot import logger -from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import ( +from astrbot.core.message.components import ( At, File, Image, @@ -19,7 +18,13 @@ Record, Reply, ) -from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.platform import ( + AstrBotMessage, + AstrMessageEvent, + MessageType, + PlatformMetadata, +) class TelegramPlatformEvent(AstrMessageEvent): @@ -155,7 +160,8 @@ async def _send_voice_with_fallback( except BadRequest as e: # python-telegram-bot raises BadRequest for Voice_messages_forbidden; # distinguish the voice-privacy case via the API error message. - if "Voice_messages_forbidden" not in e.message: + err_msg = getattr(e, "message", str(e)) + if "Voice_messages_forbidden" not in err_msg: raise logger.warning( "User privacy settings prevent receiving voice messages, falling back to sending an audio file. " diff --git a/tests/unit/test_aiocqhttp_adapter.py b/tests/unit/test_aiocqhttp_adapter.py new file mode 100644 index 0000000000..8581365549 --- /dev/null +++ b/tests/unit/test_aiocqhttp_adapter.py @@ -0,0 +1,846 @@ +"""Unit tests for aiocqhttp platform adapter. + +Tests cover: +- AiocqhttpAdapter class initialization and methods +- AiocqhttpMessageEvent class and message handling +- Message conversion for different event types +- Group and private message processing + +Note: Uses shared mock fixtures from tests/fixtures/mocks/ +""" + +import asyncio +import importlib.util +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# 导入共享的辅助函数 +from tests.fixtures.helpers import NoopAwaitable, make_platform_config + +# 导入共享的 mock fixture +from tests.fixtures.mocks import mock_aiocqhttp_modules # noqa: F401 + + +def load_module_from_file(module_name: str, file_path: Path): + """Load a Python module directly from a file path.""" + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +# Get the path to the aiocqhttp source files +AIOCQHTTP_DIR = ( + Path(__file__).parent.parent.parent + / "astrbot" + / "core" + / "platform" + / "sources" + / "aiocqhttp" +) + + +# ============================================================================ +# Fixtures (使用 conftest.py 中的 event_queue 和 platform_settings) +# ============================================================================ + + +@pytest.fixture +def platform_config(): + """Create a platform configuration for testing.""" + return make_platform_config("aiocqhttp") + + +@pytest.fixture +def mock_bot(): + """Create a mock CQHttp bot instance.""" + bot = MagicMock() + bot.send = AsyncMock() + bot.call_action = AsyncMock() + bot.on_request = MagicMock() + bot.on_notice = MagicMock() + bot.on_message = MagicMock() + bot.on_websocket_connection = MagicMock() + bot.run_task = MagicMock(return_value=NoopAwaitable()) + return bot + + +@pytest.fixture +def mock_event_group(): + """Create a mock group message event.""" + event = MagicMock() + event.__getitem__ = lambda self, key: { + "post_type": "message", + "message_type": "group", + "message": [{"type": "text", "data": {"text": "Hello World"}}], + }.get(key) + event.self_id = 12345678 + event.user_id = 98765432 + event.group_id = 11111111 + event.message_id = "msg_123" + event.sender = {"user_id": 98765432, "nickname": "TestUser", "card": ""} + event.message = [{"type": "text", "data": {"text": "Hello World"}}] + event.get = lambda key, default=None: { + "group_name": "TestGroup", + }.get(key, default) + return event + + +@pytest.fixture +def mock_event_private(): + """Create a mock private message event.""" + event = MagicMock() + event.__getitem__ = lambda self, key: { + "post_type": "message", + "message_type": "private", + "message": [{"type": "text", "data": {"text": "Private Hello"}}], + }.get(key) + event.self_id = 12345678 + event.user_id = 98765432 + event.message_id = "msg_456" + event.sender = {"user_id": 98765432, "nickname": "TestUser"} + event.message = [{"type": "text", "data": {"text": "Private Hello"}}] + event.get = lambda key, default=None: None + return event + + +@pytest.fixture +def mock_event_notice(): + """Create a mock notice event.""" + event = MagicMock() + event.__getitem__ = lambda self, key: { + "post_type": "notice", + "sub_type": "poke", + "target_id": 12345678, + }.get(key) + event.self_id = 12345678 + event.user_id = 98765432 + event.group_id = 11111111 + event.get = lambda key, default=None: { + "group_id": 11111111, + "sub_type": "poke", + "target_id": 12345678, + }.get(key, default) + return event + + +@pytest.fixture +def mock_event_request(): + """Create a mock request event.""" + event = MagicMock() + event.__getitem__ = lambda self, key: {"post_type": "request"}.get(key) + event.self_id = 12345678 + event.user_id = 98765432 + event.group_id = 11111111 + event.get = lambda key, default=None: {"group_id": 11111111}.get(key, default) + return event + + +# ============================================================================ +# AiocqhttpAdapter Tests +# ============================================================================ + + +class TestAiocqhttpAdapterInit: + """Tests for AiocqhttpAdapter initialization.""" + + def test_init_basic(self, event_queue, platform_config, platform_settings): + """Test basic adapter initialization.""" + with patch("aiocqhttp.CQHttp"): + # Import after patching + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + assert adapter.config == platform_config + assert adapter.settings == platform_settings + assert adapter.host == platform_config["ws_reverse_host"] + assert adapter.port == platform_config["ws_reverse_port"] + assert adapter.metadata.name == "aiocqhttp" + assert adapter.metadata.id == "test_aiocqhttp" + + def test_init_metadata(self, event_queue, platform_config, platform_settings): + """Test adapter metadata is correctly set.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + assert adapter.metadata.name == "aiocqhttp" + assert "OneBot" in adapter.metadata.description + assert adapter.metadata.support_streaming_message is False + + +class TestAiocqhttpAdapterConvertMessage: + """Tests for message conversion.""" + + @pytest.mark.asyncio + async def test_convert_group_message( + self, + event_queue, + platform_config, + platform_settings, + mock_event_group, + ): + """Test converting a group message event.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + result = await adapter._convert_handle_message_event(mock_event_group) + + assert result is not None + assert result.self_id == "12345678" + assert result.sender.user_id == "98765432" + assert result.message_str == "Hello World" + assert len(result.message) == 1 + + @pytest.mark.asyncio + async def test_convert_private_message( + self, + event_queue, + platform_config, + platform_settings, + mock_event_private, + ): + """Test converting a private message event.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.message_type import MessageType + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + result = await adapter._convert_handle_message_event(mock_event_private) + + assert result is not None + assert result.type == MessageType.FRIEND_MESSAGE + assert result.sender.user_id == "98765432" + + @pytest.mark.asyncio + async def test_convert_notice_event( + self, + event_queue, + platform_config, + platform_settings, + mock_event_notice, + ): + """Test converting a notice event.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + result = await adapter._convert_handle_notice_event(mock_event_notice) + + assert result is not None + assert result.raw_message == mock_event_notice + + @pytest.mark.asyncio + async def test_convert_request_event( + self, + event_queue, + platform_config, + platform_settings, + mock_event_request, + ): + """Test converting a request event.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + result = await adapter._convert_handle_request_event(mock_event_request) + + assert result is not None + assert result.raw_message == mock_event_request + + @pytest.mark.asyncio + async def test_convert_message_invalid_format( + self, event_queue, platform_config, platform_settings + ): + """Test converting a message with invalid format raises error.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + # Create event with non-list message + event = MagicMock() + event.self_id = 12345678 + event.user_id = 98765432 + event.group_id = 11111111 + event.message_id = "msg_123" + event.sender = {"user_id": 98765432, "nickname": "TestUser"} + event.message = "not a list" # Invalid format + event.__getitem__ = lambda self, key: { + "message_type": "group", + }.get(key) + event.get = lambda key, default=None: None + + with pytest.raises(ValueError) as exc_info: + await adapter._convert_handle_message_event(event) + + assert "无法识别的消息类型" in str(exc_info.value) + + +class TestAiocqhttpAdapterMessageComponents: + """Tests for different message component types.""" + + @pytest.mark.asyncio + async def test_convert_at_message( + self, event_queue, platform_config, platform_settings + ): + """Test converting a message with @ mention.""" + with patch("aiocqhttp.CQHttp") as mock_cqhttp: + mock_bot_instance = MagicMock() + mock_bot_instance.call_action = AsyncMock( + return_value={"card": "AtUser", "nickname": "AtUserNick"} + ) + mock_cqhttp.return_value = mock_bot_instance + + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + event = MagicMock() + event.self_id = 12345678 + event.user_id = 98765432 + event.group_id = 11111111 + event.message_id = "msg_123" + event.sender = {"user_id": 98765432, "nickname": "TestUser", "card": ""} + event.message = [ + {"type": "at", "data": {"qq": "88888888"}}, + {"type": "text", "data": {"text": "Hello"}}, + ] + event.__getitem__ = lambda self, key: { + "message_type": "group", + }.get(key) + event.get = lambda key, default=None: None + + result = await adapter._convert_handle_message_event(event) + + assert result is not None + # Should have At component and text + assert len(result.message) >= 1 + + @pytest.mark.asyncio + async def test_convert_image_message( + self, event_queue, platform_config, platform_settings + ): + """Test converting a message with image.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + event = MagicMock() + event.self_id = 12345678 + event.user_id = 98765432 + event.group_id = 11111111 + event.message_id = "msg_123" + event.sender = {"user_id": 98765432, "nickname": "TestUser"} + event.message = [ + {"type": "image", "data": {"url": "http://example.com/image.jpg"}}, + ] + event.__getitem__ = lambda self, key: { + "message_type": "group", + }.get(key) + event.get = lambda key, default=None: None + + result = await adapter._convert_handle_message_event(event) + + assert result is not None + assert len(result.message) == 1 + + @pytest.mark.asyncio + async def test_convert_empty_text_skipped( + self, event_queue, platform_config, platform_settings + ): + """Test that empty text segments are skipped.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + event = MagicMock() + event.self_id = 12345678 + event.user_id = 98765432 + event.group_id = 11111111 + event.message_id = "msg_123" + event.sender = {"user_id": 98765432, "nickname": "TestUser"} + event.message = [ + {"type": "text", "data": {"text": " "}}, # Empty/whitespace only + {"type": "text", "data": {"text": "Hello"}}, + ] + event.__getitem__ = lambda self, key: { + "message_type": "group", + }.get(key) + event.get = lambda key, default=None: None + + result = await adapter._convert_handle_message_event(event) + + assert result is not None + assert result.message_str == "Hello" + + +class TestAiocqhttpAdapterRun: + """Tests for run method.""" + + def test_run_with_config(self, event_queue, platform_config, platform_settings): + """Test run method with configured host and port.""" + mock_bot_instance = MagicMock() + mock_bot_instance.run_task = MagicMock(return_value=NoopAwaitable()) + + with patch("aiocqhttp.CQHttp", return_value=mock_bot_instance): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + result = adapter.run() + + assert result is not None + mock_bot_instance.run_task.assert_called_once() + + def test_run_with_default_values(self, event_queue, platform_settings): + """Test run method uses default values when not configured.""" + mock_bot_instance = MagicMock() + mock_bot_instance.run_task = MagicMock(return_value=NoopAwaitable()) + + with patch("aiocqhttp.CQHttp", return_value=mock_bot_instance): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + config = {"id": "test", "ws_reverse_host": None, "ws_reverse_port": None} + adapter = AiocqhttpAdapter(config, platform_settings, event_queue) + + adapter.run() + + assert adapter.host == "0.0.0.0" + assert adapter.port == 6199 + + +class TestAiocqhttpAdapterTerminate: + """Tests for terminate method.""" + + @pytest.mark.asyncio + async def test_terminate_sets_shutdown_event( + self, event_queue, platform_config, platform_settings + ): + """Test terminate method sets shutdown event.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + adapter.shutdown_event = asyncio.Event() + + await adapter.terminate() + + assert adapter.shutdown_event.is_set() + + +class TestAiocqhttpAdapterHandleMsg: + """Tests for handle_msg method.""" + + @pytest.mark.asyncio + async def test_handle_msg_creates_event( + self, event_queue, platform_config, platform_settings + ): + """Test handle_msg creates AiocqhttpMessageEvent and commits it.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.astrbot_message import AstrBotMessage + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + message = AstrBotMessage() + message.message_str = "Test message" + message.session_id = "test_session" + + await adapter.handle_msg(message) + + # Check that event was committed to queue + assert event_queue.qsize() == 1 + + +class TestAiocqhttpAdapterMeta: + """Tests for meta method.""" + + def test_meta_returns_metadata( + self, event_queue, platform_config, platform_settings + ): + """Test meta method returns PlatformMetadata.""" + with patch("aiocqhttp.CQHttp"): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + meta = adapter.meta() + + assert meta.name == "aiocqhttp" + assert meta.id == "test_aiocqhttp" + + +class TestAiocqhttpAdapterGetClient: + """Tests for get_client method.""" + + def test_get_client_returns_bot( + self, event_queue, platform_config, platform_settings + ): + """Test get_client returns the bot instance.""" + mock_bot_instance = MagicMock() + + with patch("aiocqhttp.CQHttp", return_value=mock_bot_instance): + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( + AiocqhttpAdapter, + ) + + adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) + + result = adapter.get_client() + + assert result == mock_bot_instance + + +# ============================================================================ +# AiocqhttpMessageEvent Tests +# ============================================================================ + + +class TestAiocqhttpMessageEventInit: + """Tests for AiocqhttpMessageEvent initialization.""" + + def test_init_basic(self): + """Test basic event initialization.""" + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + message_obj = MagicMock() + message_obj.raw_message = None + platform_meta = MagicMock() + bot = MagicMock() + + event = AiocqhttpMessageEvent( + message_str="Test message", + message_obj=message_obj, + platform_meta=platform_meta, + session_id="test_session", + bot=bot, + ) + + assert event.message_str == "Test message" + assert event.bot == bot + assert event.session_id == "test_session" + + +class TestAiocqhttpMessageEventFromSegmentToDict: + """Tests for _from_segment_to_dict method.""" + + @pytest.mark.asyncio + async def test_from_segment_plain(self): + """Test converting Plain segment to dict.""" + from astrbot.core.message.components import Plain + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + plain = Plain(text="Hello") + result = await AiocqhttpMessageEvent._from_segment_to_dict(plain) + + # Plain component type is "text" in toDict() + assert result["type"] == "text" + assert result["data"]["text"] == "Hello" + + +class TestAiocqhttpMessageEventParseOnebotJson: + """Tests for _parse_onebot_json method.""" + + @pytest.mark.asyncio + async def test_parse_empty_chain(self): + """Test parsing empty message chain.""" + from astrbot.core.message.message_event_result import MessageChain + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + chain = MessageChain(chain=[]) + result = await AiocqhttpMessageEvent._parse_onebot_json(chain) + + assert result == [] + + @pytest.mark.asyncio + async def test_parse_plain_text(self): + """Test parsing plain text message chain.""" + from astrbot.core.message.components import Plain + from astrbot.core.message.message_event_result import MessageChain + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + chain = MessageChain(chain=[Plain(text="Hello World")]) + result = await AiocqhttpMessageEvent._parse_onebot_json(chain) + + assert len(result) == 1 + # Plain component type is "text" in toDict() + assert result[0]["type"] == "text" + + +class TestAiocqhttpMessageEventSend: + """Tests for send method.""" + + @pytest.mark.asyncio + async def test_send_group_message(self): + """Test sending group message.""" + from astrbot.core.message.components import Plain + from astrbot.core.message.message_event_result import MessageChain + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + bot = MagicMock() + bot.send_group_msg = AsyncMock() + + message_obj = MagicMock() + message_obj.raw_message = None + message_obj.group = MagicMock() + message_obj.group.group_id = "11111111" + + platform_meta = MagicMock() + + event = AiocqhttpMessageEvent( + message_str="Test", + message_obj=message_obj, + platform_meta=platform_meta, + session_id="11111111", + bot=bot, + ) + + # Mock get_group_id to return group_id + event.get_group_id = MagicMock(return_value="11111111") + event.get_sender_id = MagicMock(return_value="98765432") + + with patch.object( + AiocqhttpMessageEvent, + "send_message", + new_callable=AsyncMock, + ) as mock_send: + with patch( + "astrbot.core.platform.astr_message_event.AstrMessageEvent.send", + new_callable=AsyncMock, + ): + chain = MessageChain(chain=[Plain(text="Hello")]) + await event.send(chain) + + mock_send.assert_called_once() + + +class TestAiocqhttpMessageEventDispatchSend: + """Tests for _dispatch_send method.""" + + @pytest.mark.asyncio + async def test_dispatch_send_group(self): + """Test dispatching send to group.""" + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + bot = MagicMock() + bot.send_group_msg = AsyncMock() + + await AiocqhttpMessageEvent._dispatch_send( + bot=bot, + event=None, + is_group=True, + session_id="11111111", + messages=[{"type": "text", "data": {"text": "Hello"}}], + ) + + bot.send_group_msg.assert_called_once() + + @pytest.mark.asyncio + async def test_dispatch_send_private(self): + """Test dispatching send to private chat.""" + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + bot = MagicMock() + bot.send_private_msg = AsyncMock() + + await AiocqhttpMessageEvent._dispatch_send( + bot=bot, + event=None, + is_group=False, + session_id="98765432", + messages=[{"type": "text", "data": {"text": "Hello"}}], + ) + + bot.send_private_msg.assert_called_once() + + @pytest.mark.asyncio + async def test_dispatch_send_invalid_session(self): + """Test dispatching send with invalid session raises error.""" + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + bot = MagicMock() + + with pytest.raises(ValueError) as exc_info: + await AiocqhttpMessageEvent._dispatch_send( + bot=bot, + event=None, + is_group=True, + session_id="invalid", + messages=[{"type": "text", "data": {"text": "Hello"}}], + ) + + assert "无法发送消息" in str(exc_info.value) + + +class TestAiocqhttpMessageEventGetGroup: + """Tests for get_group method.""" + + @pytest.mark.asyncio + async def test_get_group_success(self): + """Test getting group info successfully.""" + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + bot = MagicMock() + bot.call_action = AsyncMock( + side_effect=[ + {"group_name": "TestGroup"}, # get_group_info + [ # get_group_member_list + {"user_id": "111", "role": "owner", "nickname": "Owner"}, + {"user_id": "222", "role": "admin", "nickname": "Admin1"}, + {"user_id": "333", "role": "member", "nickname": "Member1"}, + ], + ] + ) + + message_obj = MagicMock() + message_obj.raw_message = None + platform_meta = MagicMock() + + event = AiocqhttpMessageEvent( + message_str="Test", + message_obj=message_obj, + platform_meta=platform_meta, + session_id="11111111", + bot=bot, + ) + + group = await event.get_group(group_id="11111111") + + assert group is not None + assert group.group_id == "11111111" + assert group.group_name == "TestGroup" + assert group.group_owner == "111" + assert group.group_admins is not None + assert "222" in group.group_admins + + @pytest.mark.asyncio + async def test_get_group_no_group_id(self): + """Test get_group returns None when no group_id available.""" + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + bot = MagicMock() + bot.call_action = AsyncMock() + + message_obj = MagicMock() + message_obj.raw_message = None + platform_meta = MagicMock() + + event = AiocqhttpMessageEvent( + message_str="Test", + message_obj=message_obj, + platform_meta=platform_meta, + session_id="private_session", + bot=bot, + ) + + # Mock get_group_id to return None + event.get_group_id = MagicMock(return_value=None) + + result = await event.get_group() + + assert result is None + + +class TestAiocqhttpMessageEventSendStreaming: + """Tests for send_streaming method.""" + + @pytest.mark.asyncio + async def test_send_streaming_without_fallback(self): + """Test streaming send without fallback mode.""" + from astrbot.core.message.components import Plain + from astrbot.core.message.message_event_result import MessageChain + from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import ( + AiocqhttpMessageEvent, + ) + + bot = MagicMock() + + message_obj = MagicMock() + message_obj.raw_message = None + platform_meta = MagicMock() + + event = AiocqhttpMessageEvent( + message_str="Test", + message_obj=message_obj, + platform_meta=platform_meta, + session_id="test_session", + bot=bot, + ) + + async def mock_generator(): + yield MessageChain(chain=[Plain(text="Hello")]) + yield MessageChain(chain=[Plain(text=" World")]) + + with patch.object(event, "send", new_callable=AsyncMock) as mock_send: + with patch( + "astrbot.core.platform.astr_message_event.AstrMessageEvent.send_streaming", + new_callable=AsyncMock, + ): + await event.send_streaming(mock_generator(), use_fallback=False) + + # Should call send with combined message + mock_send.assert_called() diff --git a/tests/unit/test_discord_adapter.py b/tests/unit/test_discord_adapter.py new file mode 100644 index 0000000000..ca1ca02fd2 --- /dev/null +++ b/tests/unit/test_discord_adapter.py @@ -0,0 +1,1100 @@ +"""Unit tests for Discord platform adapter. + +Tests cover: +- DiscordPlatformAdapter class initialization and methods +- DiscordPlatformEvent class and message handling +- DiscordBotClient class +- Message conversion for different message types +- Slash command handling +- Component interactions + +Note: Uses shared mock fixtures from tests/fixtures/mocks/ +""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# 导入共享的辅助函数 +from tests.fixtures.helpers import make_platform_config + +# 导入共享的 mock fixture +from tests.fixtures.mocks import mock_discord_modules # noqa: F401 + +# ============================================================================ +# Fixtures (使用 conftest.py 中的 event_queue 和 platform_settings) +# ============================================================================ + + +@pytest.fixture +def platform_config(): + """Create a platform configuration for testing.""" + return make_platform_config("discord") + + +@pytest.fixture +def mock_discord_client(): + """Create a mock Discord client instance.""" + client = MagicMock() + client.user = MagicMock() + client.user.id = 123456789 + client.user.display_name = "TestBot" + client.user.name = "TestBot" + client.get_channel = MagicMock() + client.fetch_channel = AsyncMock() + client.get_message = MagicMock() + client.start = AsyncMock() + client.close = AsyncMock() + client.is_closed = MagicMock(return_value=False) + client.add_application_command = MagicMock() + client.sync_commands = AsyncMock() + client.change_presence = AsyncMock() + return client + + +@pytest.fixture +def mock_discord_message(): + """Create a mock Discord message for testing.""" + + def _create_message( + content: str = "Hello World", + author_id: int = 987654321, + author_name: str = "TestUser", + channel_id: int = 111222333, + guild_id: int | None = 444555666, + mentions: list | None = None, + role_mentions: list | None = None, + attachments: list | None = None, + ): + message = MagicMock() + message.id = 12345678 + message.content = content + message.clean_content = content + + # Author mock + message.author = MagicMock() + message.author.id = author_id + message.author.display_name = author_name + message.author.name = author_name + message.author.bot = False + + # Channel mock + message.channel = MagicMock() + message.channel.id = channel_id + + # Guild mock + if guild_id: + message.guild = MagicMock() + message.guild.id = guild_id + message.guild.get_member = MagicMock(return_value=None) + else: + message.guild = None + + # Mentions + message.mentions = mentions or [] + message.role_mentions = role_mentions or [] + + # Attachments + message.attachments = attachments or [] + + return message + + return _create_message + + +@pytest.fixture +def mock_discord_channel(): + """Create a mock Discord channel for testing.""" + + def _create_channel( + channel_id: int = 111222333, + is_dm: bool = False, + is_messageable: bool = True, + ): + channel = MagicMock() + channel.id = channel_id + channel.send = AsyncMock() + + if is_dm: + # DMChannel mock + channel.guild = None + else: + # GuildChannel mock + channel.guild = MagicMock() + channel.guild.id = 444555666 + + return channel + + return _create_channel + + +@pytest.fixture +def mock_interaction(): + """Create a mock Discord interaction for testing.""" + + def _create_interaction( + interaction_type: int = 2, # application_command + command_name: str = "help", + custom_id: str | None = None, + user_id: int = 987654321, + channel_id: int = 111222333, + guild_id: int | None = 444555666, + ): + interaction = MagicMock() + interaction.id = 12345678 + interaction.type = interaction_type + interaction.user = MagicMock() + interaction.user.id = user_id + interaction.user.display_name = "TestUser" + interaction.channel_id = channel_id + interaction.guild_id = guild_id + + # Interaction data + interaction.data = {"name": command_name} + if custom_id: + interaction.data["custom_id"] = custom_id + interaction.data["component_type"] = 2 + + # Context mock + interaction.defer = AsyncMock() + interaction.followup = MagicMock() + interaction.followup.send = AsyncMock() + + return interaction + + return _create_interaction + + +def create_mock_discord_attachment( + url: str = "https://cdn.discord.com/test.png", + filename: str = "test.png", + content_type: str = "image/png", +): + """Create a mock Discord attachment.""" + attachment = MagicMock() + attachment.url = url + attachment.filename = filename + attachment.content_type = content_type + return attachment + + +# ============================================================================ +# DiscordPlatformAdapter Initialization Tests +# ============================================================================ + + +class TestDiscordAdapterInit: + """Tests for DiscordPlatformAdapter initialization.""" + + def test_init_basic(self, event_queue, platform_config, platform_settings): + """Test basic adapter initialization.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + assert adapter.config == platform_config + assert adapter.settings == platform_settings + assert adapter.enable_command_register is True + assert adapter.client_self_id is None + assert adapter.registered_handlers == [] + + def test_init_with_custom_settings( + self, event_queue, platform_config, platform_settings + ): + """Test adapter initialization with custom settings.""" + platform_config["discord_command_register"] = False + platform_config["discord_guild_id_for_debug"] = "123456789" + platform_config["discord_activity_name"] = "Custom Activity" + + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + assert adapter.enable_command_register is False + assert adapter.guild_id == "123456789" + assert adapter.activity_name == "Custom Activity" + + def test_init_shutdown_event(self, event_queue, platform_config, platform_settings): + """Test shutdown event is initialized.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + assert hasattr(adapter, "shutdown_event") + assert isinstance(adapter.shutdown_event, asyncio.Event) + assert not adapter.shutdown_event.is_set() + + +# ============================================================================ +# DiscordPlatformAdapter Metadata Tests +# ============================================================================ + + +class TestDiscordAdapterMetadata: + """Tests for DiscordPlatformAdapter metadata.""" + + def test_meta_returns_correct_metadata( + self, event_queue, platform_config, platform_settings + ): + """Test meta() returns correct PlatformMetadata.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + meta = adapter.meta() + + assert meta.name == "discord" + assert "discord" in meta.description.lower() + assert meta.id == "test_discord" + assert meta.support_streaming_message is False + + def test_meta_with_missing_id(self, event_queue, platform_settings): + """Test meta() handles missing id in config.""" + config = { + "discord_token": "test_token", + } + + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter(config, platform_settings, event_queue) + meta = adapter.meta() + + # Should use None or default when id is not configured + assert meta.name == "discord" + + +# ============================================================================ +# DiscordPlatformAdapter Message Type Tests +# ============================================================================ + + +class TestDiscordAdapterGetMessageType: + """Tests for _get_message_type method.""" + + def test_get_message_type_dm_channel( + self, event_queue, platform_config, platform_settings + ): + """Test message type detection for DM channel.""" + from astrbot.core.platform.message_type import MessageType + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + # Create DM channel mock - DMChannel has guild = None + dm_channel = MagicMock() + dm_channel.guild = None + + result = adapter._get_message_type(dm_channel) + + assert result == MessageType.FRIEND_MESSAGE + + def test_get_message_type_guild_channel( + self, event_queue, platform_config, platform_settings + ): + """Test message type detection for guild channel.""" + from astrbot.core.platform.message_type import MessageType + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + # Create guild channel mock - guild channel has guild with id + # Important: guild must not be None and must evaluate to True + # We need to create a real object, not MagicMock, for the guild attribute + # because the code checks `getattr(channel, "guild", None) is None` + class MockGuild: + def __init__(self): + self.id = 123456789 + + class MockGuildChannel: + def __init__(self): + self.guild = MockGuild() + + guild_channel = MockGuildChannel() + + result = adapter._get_message_type(guild_channel) + + assert result == MessageType.GROUP_MESSAGE + + def test_get_message_type_with_guild_id_override( + self, event_queue, platform_config, platform_settings + ): + """Test message type with guild_id override.""" + from astrbot.core.platform.message_type import MessageType + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + # Even with DM channel, guild_id should override to GROUP_MESSAGE + dm_channel = MagicMock() + dm_channel.guild = None + + result = adapter._get_message_type(dm_channel, guild_id=123456789) + + assert result == MessageType.GROUP_MESSAGE + + +# ============================================================================ +# DiscordPlatformAdapter Message Conversion Tests +# ============================================================================ + + +class TestDiscordAdapterConvertMessage: + """Tests for message conversion.""" + + @pytest.mark.asyncio + async def test_convert_text_message( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + mock_discord_message, + ): + """Test converting a text message.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + adapter.client_self_id = "123456789" + + message = mock_discord_message( + content="Hello World", + author_id=987654321, + author_name="TestUser", + channel_id=111222333, + guild_id=444555666, + ) + + data = {"message": message, "bot_id": "123456789"} + + result = await adapter.convert_message(data) + + assert result is not None + assert result.message_str == "Hello World" + assert result.sender.user_id == "987654321" + assert result.sender.nickname == "TestUser" + assert result.session_id == "111222333" + # Note: type depends on channel.guild attribute + + @pytest.mark.asyncio + async def test_convert_message_with_mention( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + mock_discord_message, + ): + """Test converting a message with bot mention.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + adapter.client_self_id = "123456789" + + # Create message with mention + bot_user = MagicMock() + bot_user.id = 123456789 + mock_discord_client.user = bot_user + + message = mock_discord_message( + content="<@123456789> Hello Bot", + author_id=987654321, + channel_id=111222333, + ) + + data = {"message": message, "bot_id": "123456789"} + + result = await adapter.convert_message(data) + + # Mention should be stripped + assert result.message_str == "Hello Bot" + + @pytest.mark.asyncio + async def test_convert_message_with_image_attachment( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + mock_discord_message, + ): + """Test converting a message with image attachment.""" + from astrbot.api.message_components import Image + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + adapter.client_self_id = "123456789" + + attachment = create_mock_discord_attachment( + url="https://cdn.discord.com/test.png", + filename="test.png", + content_type="image/png", + ) + + message = mock_discord_message( + content="Check this image", + attachments=[attachment], + ) + + data = {"message": message, "bot_id": "123456789"} + + result = await adapter.convert_message(data) + + assert result.message_str == "Check this image" + # Should have Plain text and Image in message chain + assert len(result.message) == 2 + assert any(isinstance(comp, Image) for comp in result.message) + + @pytest.mark.asyncio + async def test_convert_message_with_file_attachment( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + mock_discord_message, + ): + """Test converting a message with file attachment.""" + from astrbot.api.message_components import File + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + adapter.client_self_id = "123456789" + + attachment = create_mock_discord_attachment( + url="https://cdn.discord.com/document.pdf", + filename="document.pdf", + content_type="application/pdf", + ) + + message = mock_discord_message( + content="Here is a file", + attachments=[attachment], + ) + + data = {"message": message, "bot_id": "123456789"} + + result = await adapter.convert_message(data) + + assert result.message_str == "Here is a file" + # Should have Plain text and File in message chain + assert len(result.message) == 2 + assert any(isinstance(comp, File) for comp in result.message) + + +# ============================================================================ +# DiscordPlatformAdapter Send by Session Tests +# ============================================================================ + + +class TestDiscordAdapterSendBySession: + """Tests for send_by_session method.""" + + @pytest.mark.asyncio + async def test_send_by_session_client_not_ready( + self, + event_queue, + platform_config, + platform_settings, + ): + """Test send_by_session when client is not ready.""" + from astrbot.api.event import MessageChain + from astrbot.core.platform.astr_message_event import MessageSesion + from astrbot.core.platform.message_type import MessageType + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = MagicMock() + adapter.client.user = None # Client not ready + + session = MessageSesion( + platform_name="discord", + message_type=MessageType.GROUP_MESSAGE, + session_id="111222333", + ) + message_chain = MessageChain() + + # Should return early without error + await adapter.send_by_session(session, message_chain) + + @pytest.mark.asyncio + async def test_send_by_session_invalid_channel_id( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + ): + """Test send_by_session with invalid channel ID format.""" + from astrbot.api.event import MessageChain + from astrbot.api.message_components import Plain + from astrbot.core.platform.astr_message_event import MessageSesion + from astrbot.core.platform.message_type import MessageType + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + adapter.client_self_id = "123456789" + + session = MessageSesion( + platform_name="discord", + message_type=MessageType.GROUP_MESSAGE, + session_id="invalid_id", + ) + message_chain = MessageChain([Plain(text="Test message")]) + + # Should handle invalid ID gracefully + await adapter.send_by_session(session, message_chain) + + +# ============================================================================ +# DiscordPlatformAdapter Run and Terminate Tests +# ============================================================================ + + +class TestDiscordAdapterRunTerminate: + """Tests for run and terminate methods.""" + + @pytest.mark.asyncio + async def test_run_without_token( + self, + event_queue, + platform_settings, + ): + """Test run method returns early without token.""" + config = { + "id": "test_discord", + "discord_token": "", # Empty token + } + + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter(config, platform_settings, event_queue) + + # Should return early without error + await adapter.run() + + @pytest.mark.asyncio + async def test_terminate_clears_commands( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + ): + """Test terminate method clears slash commands when enabled.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + adapter._polling_task = None + + await adapter.terminate() + + # sync_commands should be called with empty list + mock_discord_client.sync_commands.assert_called_once() + + +# ============================================================================ +# DiscordPlatformAdapter Handle Message Tests +# ============================================================================ + + +class TestDiscordAdapterHandleMessage: + """Tests for handle_msg method.""" + + @pytest.mark.asyncio + async def test_handle_message_sets_wake_on_mention( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + mock_discord_message, + ): + """Test handle_msg sets is_wake when bot is mentioned.""" + from astrbot.api.message_components import Plain + from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + + # Create bot user for mention check + bot_user = MagicMock() + bot_user.id = 123456789 + mock_discord_client.user = bot_user + + # Create message with bot mention + message = mock_discord_message(content="Hello Bot") + message.mentions = [bot_user] + + abm = AstrBotMessage() + abm.type = MessageType.GROUP_MESSAGE + abm.message_str = "Hello Bot" + abm.message = [Plain(text="Hello Bot")] # Required attribute + abm.sender = MessageMember(user_id="987654321", nickname="TestUser") + abm.raw_message = message + abm.session_id = "111222333" + + await adapter.handle_msg(abm) + + # Event should be committed to queue + assert not event_queue.empty() + + @pytest.mark.asyncio + async def test_handle_slash_command( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + mock_interaction, + ): + """Test handle_msg processes slash command correctly.""" + from astrbot.api.message_components import Plain + from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + adapter.client_self_id = "123456789" + + interaction = mock_interaction(interaction_type=2, command_name="help") + + webhook = MagicMock() + + abm = AstrBotMessage() + abm.type = MessageType.GROUP_MESSAGE + abm.message_str = "/help" + abm.message = [Plain(text="/help")] # Required attribute + abm.sender = MessageMember(user_id="987654321", nickname="TestUser") + abm.raw_message = interaction + abm.session_id = "111222333" + + await adapter.handle_msg(abm, followup_webhook=webhook) + + # Event should be committed with is_wake=True for slash commands + assert not event_queue.empty() + + +# ============================================================================ +# DiscordPlatformAdapter Command Registration Tests +# ============================================================================ + + +class TestDiscordAdapterCommandRegistration: + """Tests for slash command collection and registration.""" + + def test_extract_command_infos_includes_aliases(self): + """Test _extract_command_infos expands command aliases.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + from astrbot.core.star.filter.command import CommandFilter + + handler_md = SimpleNamespace(desc="test command") + infos = DiscordPlatformAdapter._extract_command_infos( + CommandFilter("ping", alias={"p"}), + handler_md, + ) + + assert sorted(name for name, _ in infos) == ["p", "ping"] + + def test_extract_command_infos_allows_hyphenated_names(self): + """Test Discord slash command names may include hyphens.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + from astrbot.core.star.filter.command import CommandFilter + + handler_md = SimpleNamespace(desc="hyphen command") + infos = DiscordPlatformAdapter._extract_command_infos( + CommandFilter("user-info"), + handler_md, + ) + + assert infos == [("user-info", "hyphen command")] + + @pytest.mark.asyncio + async def test_collect_commands_warns_on_duplicates( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + ): + """Test duplicate slash commands are warned and ignored.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + from astrbot.core.star.filter.command import CommandFilter + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + + handler_a = SimpleNamespace( + handler_module_path="plugin.discord.a", + enabled=True, + desc="first", + event_filters=[CommandFilter("ping")], + ) + handler_b = SimpleNamespace( + handler_module_path="plugin.discord.b", + enabled=True, + desc="second", + event_filters=[CommandFilter("ping")], + ) + + with ( + pytest.MonkeyPatch.context() as monkeypatch, + patch( + "astrbot.core.platform.sources.discord.discord_platform_adapter.logger" + ) as mock_logger, + ): + monkeypatch.setattr( + "astrbot.core.platform.sources.discord.discord_platform_adapter.star_handlers_registry", + [handler_a, handler_b], + ) + monkeypatch.setattr( + "astrbot.core.platform.sources.discord.discord_platform_adapter.star_map", + { + "plugin.discord.a": SimpleNamespace(activated=True), + "plugin.discord.b": SimpleNamespace(activated=True), + }, + ) + await adapter._collect_and_register_commands() + + assert mock_discord_client.add_application_command.call_count == 1 + mock_logger.warning.assert_called_once() + + @pytest.mark.asyncio + async def test_collect_commands_registers_aliases( + self, + event_queue, + platform_config, + platform_settings, + mock_discord_client, + ): + """Test slash command aliases are also registered.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + from astrbot.core.star.filter.command import CommandFilter + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_discord_client + + handler = SimpleNamespace( + handler_module_path="plugin.discord.alias", + enabled=True, + desc="alias command", + event_filters=[CommandFilter("hello", alias={"hi"})], + ) + + with ( + pytest.MonkeyPatch.context() as monkeypatch, + patch( + "astrbot.core.platform.sources.discord.discord_platform_adapter.discord.SlashCommand", + side_effect=lambda **kwargs: SimpleNamespace(name=kwargs["name"]), + ), + ): + monkeypatch.setattr( + "astrbot.core.platform.sources.discord.discord_platform_adapter.star_handlers_registry", + [handler], + ) + monkeypatch.setattr( + "astrbot.core.platform.sources.discord.discord_platform_adapter.star_map", + {"plugin.discord.alias": SimpleNamespace(activated=True)}, + ) + await adapter._collect_and_register_commands() + + assert mock_discord_client.add_application_command.call_count == 2 + called_names = sorted( + call.args[0].name + for call in mock_discord_client.add_application_command.call_args_list + ) + assert called_names == ["hello", "hi"] + + +# ============================================================================ +# Edge Cases and Error Handling Tests +# ============================================================================ + + +class TestDiscordAdapterEdgeCases: + """Tests for edge cases and error handling.""" + + def test_get_channel_id_returns_string( + self, event_queue, platform_config, platform_settings + ): + """Test _get_channel_id returns string representation.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + channel = MagicMock() + channel.id = 123456789 + + result = adapter._get_channel_id(channel) + + assert result == "123456789" + assert isinstance(result, str) + + +# ============================================================================ +# DiscordPlatformEvent Helper Method Tests (without full initialization) +# ============================================================================ + + +class TestDiscordPlatformEventHelpers: + """Tests for DiscordPlatformEvent helper methods that don't require full init.""" + + def test_is_slash_command_check_logic(self): + """Test the is_slash_command logic without full event initialization.""" + # This tests the logic pattern used in is_slash_command + interaction = MagicMock() + interaction.type = 2 # application_command + + # Simulate the check logic + result = hasattr(interaction, "type") and interaction.type == 2 + assert result is True + + # Test with non-slash command type + interaction.type = 3 # component + result = hasattr(interaction, "type") and interaction.type == 2 + assert result is False + + def test_is_button_interaction_check_logic(self): + """Test the is_button_interaction logic without full event initialization.""" + interaction = MagicMock() + interaction.type = 3 # component + + # Simulate the check logic + result = hasattr(interaction, "type") and interaction.type == 3 + assert result is True + + # Test with non-component type + interaction.type = 2 # application_command + result = hasattr(interaction, "type") and interaction.type == 3 + assert result is False + + +# ============================================================================ +# DiscordBotClient Method Tests +# ============================================================================ + + +class TestDiscordBotClientMethods: + """Tests for DiscordBotClient methods without full initialization.""" + + def test_extract_interaction_content_logic(self): + """Test the _extract_interaction_content logic pattern.""" + # Test slash command pattern + interaction_type = 2 # application_command + interaction_data = { + "name": "help", + "options": [{"name": "topic", "value": "commands"}], + } + + if interaction_type == 2: + command_name = interaction_data.get("name", "") + if options := interaction_data.get("options", []): + params = " ".join( + [f"{opt['name']}:{opt.get('value', '')}" for opt in options] + ) + result = f"/{command_name} {params}" + else: + result = f"/{command_name}" + + assert result == "/help topic:commands" + + # Test component pattern + interaction_type = 3 # component + interaction_data = { + "custom_id": "btn_confirm", + "component_type": 2, + } + + if interaction_type == 3: + custom_id = interaction_data.get("custom_id", "") + component_type = interaction_data.get("component_type", "") + result = f"component:{custom_id}:{component_type}" + + assert result == "component:btn_confirm:2" + + +# ============================================================================ +# Discord Components Data Structure Tests +# ============================================================================ + + +class TestDiscordComponentsData: + """Tests for Discord component data structures.""" + + def test_discord_embed_data_structure(self): + """Test DiscordEmbed data structure.""" + embed_data = { + "title": "Test Title", + "description": "Test Description", + "color": 0x3498DB, + "url": "https://example.com", + "thumbnail": "https://example.com/thumb.png", + "image": "https://example.com/image.png", + "footer": "Test Footer", + "fields": [{"name": "Field 1", "value": "Value 1", "inline": True}], + } + + assert embed_data["title"] == "Test Title" + assert embed_data["description"] == "Test Description" + assert embed_data["color"] == 0x3498DB + assert embed_data["url"] == "https://example.com" + assert embed_data["thumbnail"] == "https://example.com/thumb.png" + assert embed_data["image"] == "https://example.com/image.png" + assert embed_data["footer"] == "Test Footer" + assert len(embed_data["fields"]) == 1 + + def test_discord_button_data_structure(self): + """Test DiscordButton data structure.""" + button_data = { + "label": "Click Me", + "custom_id": "btn_click", + "style": "primary", + "emoji": "👋", + "disabled": False, + "url": None, + } + + assert button_data["label"] == "Click Me" + assert button_data["custom_id"] == "btn_click" + assert button_data["style"] == "primary" + assert button_data["emoji"] == "👋" + assert button_data["disabled"] is False + assert button_data["url"] is None + + def test_discord_button_url_data_structure(self): + """Test DiscordButton with URL data structure.""" + button_data = { + "label": "Visit Website", + "url": "https://example.com", + "style": "link", + "custom_id": None, + } + + assert button_data["url"] == "https://example.com" + assert button_data["custom_id"] is None + + def test_discord_reference_data_structure(self): + """Test DiscordReference data structure.""" + ref_data = { + "message_id": "123456789", + "channel_id": "987654321", + } + + assert ref_data["message_id"] == "123456789" + assert ref_data["channel_id"] == "987654321" + + +# ============================================================================ +# Register Handler Tests +# ============================================================================ + + +class TestDiscordAdapterRegisterHandler: + """Tests for register_handler method.""" + + def test_register_handler(self, event_queue, platform_config, platform_settings): + """Test register_handler adds handler to list.""" + from astrbot.core.platform.sources.discord.discord_platform_adapter import ( + DiscordPlatformAdapter, + ) + + adapter = DiscordPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + handler_info = {"command": "test", "handler": MagicMock()} + adapter.register_handler(handler_info) + + assert len(adapter.registered_handlers) == 1 + assert adapter.registered_handlers[0] == handler_info diff --git a/tests/unit/test_telegram_adapter.py b/tests/unit/test_telegram_adapter.py new file mode 100644 index 0000000000..648338e029 --- /dev/null +++ b/tests/unit/test_telegram_adapter.py @@ -0,0 +1,2053 @@ +"""Unit tests for Telegram platform adapter. + +Tests cover: +- TelegramPlatformAdapter class initialization and methods +- TelegramPlatformEvent class and message handling +- Message conversion for different message types +- Media group message handling +- Command registration + +Note: Uses shared mock fixtures from tests/fixtures/mocks/ +""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# 导入共享的辅助函数 +from tests.fixtures.helpers import ( + NoopAwaitable, + create_mock_file, + create_mock_update, + make_platform_config, +) + +# 导入共享的 mock fixture +from tests.fixtures.mocks import mock_telegram_modules # noqa: F401 + +# ============================================================================ +# Fixtures (使用 conftest.py 中的 event_queue 和 platform_settings) +# ============================================================================ + + +@pytest.fixture +def platform_config(): + """Create a platform configuration for testing.""" + return make_platform_config("telegram") + + +@pytest.fixture +def mock_bot(): + """Create a mock Telegram bot instance.""" + bot = MagicMock() + bot.username = "test_bot" + bot.id = 12345678 + bot.base_url = "https://api.telegram.org/bottest_token_123/" + bot.send_message = AsyncMock() + bot.send_photo = AsyncMock() + bot.send_document = AsyncMock() + bot.send_voice = AsyncMock() + bot.send_chat_action = AsyncMock() + bot.delete_my_commands = AsyncMock() + bot.set_my_commands = AsyncMock() + bot.set_message_reaction = AsyncMock() + bot.edit_message_text = AsyncMock() + return bot + + +@pytest.fixture +def mock_application(): + """Create a mock Telegram Application instance.""" + app = MagicMock() + app.bot = MagicMock() + app.bot.username = "test_bot" + app.bot.base_url = "https://api.telegram.org/bottest_token_123/" + app.initialize = AsyncMock() + app.start = AsyncMock() + app.stop = AsyncMock() + app.add_handler = MagicMock() + app.updater = MagicMock() + app.updater.start_polling = MagicMock(return_value=NoopAwaitable()) + app.updater.stop = AsyncMock() + return app + + +@pytest.fixture +def mock_scheduler(): + """Create a mock APScheduler instance.""" + scheduler = MagicMock() + scheduler.add_job = MagicMock() + scheduler.start = MagicMock() + scheduler.running = True + scheduler.shutdown = MagicMock() + return scheduler + + +# ============================================================================ +# TelegramPlatformAdapter Initialization Tests +# ============================================================================ + + +class TestTelegramAdapterInit: + """Tests for TelegramPlatformAdapter initialization.""" + + def test_init_basic( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + ): + """Test basic adapter initialization.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + assert adapter.config == platform_config + assert adapter.settings == platform_settings + assert adapter.base_url == platform_config["telegram_api_base_url"] + assert adapter.enable_command_register is True + + def test_init_with_default_urls( + self, + event_queue, + platform_settings, + mock_application, + mock_scheduler, + ): + """Test adapter uses default URLs when not configured.""" + config = { + "id": "test_telegram", + "telegram_token": "test_token", + "telegram_api_base_url": None, + "telegram_file_base_url": None, + } + + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter(config, platform_settings, event_queue) + + assert adapter.base_url == "https://api.telegram.org/bot" + + def test_init_media_group_settings( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + ): + """Test media group settings are correctly initialized.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + assert adapter.media_group_timeout == 2.5 + assert adapter.media_group_max_wait == 10.0 + assert adapter.media_group_cache == {} + + +# ============================================================================ +# TelegramPlatformAdapter Metadata Tests +# ============================================================================ + + +class TestTelegramAdapterMetadata: + """Tests for TelegramPlatformAdapter metadata.""" + + def test_meta_returns_correct_metadata( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + ): + """Test meta() returns correct PlatformMetadata.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + meta = adapter.meta() + + assert meta.name == "telegram" + assert "telegram" in meta.description.lower() + assert meta.id == "test_telegram" + + def test_meta_with_missing_id_uses_default( + self, + event_queue, + platform_settings, + mock_application, + mock_scheduler, + ): + """Test meta() uses 'telegram' as default id when not configured.""" + config = { + "telegram_token": "test_token", + } + + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter(config, platform_settings, event_queue) + meta = adapter.meta() + + assert meta.id == "telegram" + + +# ============================================================================ +# TelegramPlatformAdapter Message Conversion Tests +# ============================================================================ + + +class TestTelegramAdapterConvertMessage: + """Tests for message conversion.""" + + @pytest.mark.asyncio + async def test_convert_text_message_private( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a private text message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.message_type import MessageType + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + update = create_mock_update( + message_text="Hello World", + chat_type="private", + chat_id=123456789, + user_id=987654321, + username="test_user", + ) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + assert result.session_id == "123456789" + assert result.type == MessageType.FRIEND_MESSAGE + assert result.sender.user_id == "987654321" + assert result.sender.nickname == "test_user" + assert result.message_str == "Hello World" + + @pytest.mark.asyncio + async def test_convert_text_message_group( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a group text message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.message_type import MessageType + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + update = create_mock_update( + message_text="Hello Group", + chat_type="group", + chat_id=111111111, + user_id=987654321, + username="test_user", + ) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + assert result.type == MessageType.GROUP_MESSAGE + assert result.group_id == "111111111" + + @pytest.mark.asyncio + async def test_convert_topic_group_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a topic (forum) group message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.message_type import MessageType + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + update = create_mock_update( + message_text="Hello Topic", + chat_type="supergroup", + chat_id=111111111, + user_id=987654321, + username="test_user", + message_thread_id=222, + is_topic_message=True, + ) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + assert result.type == MessageType.GROUP_MESSAGE + assert result.group_id == "111111111#222" + assert result.session_id == "111111111#222" + + @pytest.mark.asyncio + async def test_convert_photo_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a photo message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + # Create mock photo + mock_photo = MagicMock() + mock_photo.get_file = AsyncMock( + return_value=create_mock_file("https://example.com/photo.jpg") + ) + + update = create_mock_update( + message_text=None, + photo=[mock_photo], # Photo is a list, last one is largest + caption="Photo caption", + ) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + assert result.message_str == "Photo caption" + assert len(result.message) >= 1 # Should have at least Image component + + @pytest.mark.asyncio + async def test_convert_video_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a video message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + # Create mock video + mock_video = MagicMock() + mock_video.file_name = "test_video.mp4" + mock_video.get_file = AsyncMock( + return_value=create_mock_file("https://example.com/video.mp4") + ) + + update = create_mock_update(message_text=None, video=mock_video) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + assert len(result.message) >= 1 + + @pytest.mark.asyncio + async def test_convert_document_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a document message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + # Create mock document + mock_document = MagicMock() + mock_document.file_name = "test_document.pdf" + mock_document.get_file = AsyncMock( + return_value=create_mock_file("https://example.com/document.pdf") + ) + + update = create_mock_update(message_text=None, document=mock_document) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + assert len(result.message) >= 1 + + @pytest.mark.asyncio + async def test_convert_voice_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a voice message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + # Create mock voice + mock_voice = MagicMock() + mock_voice.get_file = AsyncMock( + return_value=create_mock_file("https://example.com/voice.ogg") + ) + + update = create_mock_update(message_text=None, voice=mock_voice) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + assert len(result.message) >= 1 + + @pytest.mark.asyncio + async def test_convert_sticker_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a sticker message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + # Create mock sticker + mock_sticker = MagicMock() + mock_sticker.emoji = "👍" + mock_sticker.get_file = AsyncMock( + return_value=create_mock_file("https://example.com/sticker.webp") + ) + + update = create_mock_update(message_text=None, sticker=mock_sticker) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + assert "Sticker: 👍" in result.message_str + + @pytest.mark.asyncio + async def test_convert_message_without_from_user( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a message without from_user returns None.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + update = create_mock_update() + update.message.from_user = None + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is None + + @pytest.mark.asyncio + async def test_convert_message_without_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting an update without message returns None.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + update = MagicMock() + update.message = None + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is None + + @pytest.mark.asyncio + async def test_convert_command_with_bot_mention( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a command with bot mention in group.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + update = create_mock_update( + message_text="/help@test_bot arg1", + chat_type="group", + ) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + # Should strip the bot mention from command + assert "@test_bot" not in result.message_str + + +# ============================================================================ +# TelegramPlatformAdapter Media Group Tests +# ============================================================================ + + +class TestTelegramAdapterMediaGroup: + """Tests for media group message handling.""" + + @pytest.mark.asyncio + async def test_handle_media_group_creates_cache( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test that media group message creates cache entry.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + + # Create a real scheduler mock that tracks add_job calls + scheduler = MagicMock() + scheduler.add_job = MagicMock() + scheduler.running = True + scheduler.shutdown = MagicMock() + mock_scheduler_class.return_value = scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + adapter.scheduler = scheduler + + update = create_mock_update( + message_text="Media item", + media_group_id="group_123", + ) + + context = MagicMock() + context.bot = mock_bot + + await adapter.handle_media_group_message(update, context) + + assert "group_123" in adapter.media_group_cache + assert len(adapter.media_group_cache["group_123"]["items"]) == 1 + scheduler.add_job.assert_called() + + @pytest.mark.asyncio + async def test_handle_media_group_accumulates_items( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test that multiple media group messages accumulate in cache.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + + scheduler = MagicMock() + scheduler.add_job = MagicMock() + scheduler.running = True + mock_scheduler_class.return_value = scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + adapter.scheduler = scheduler + + context = MagicMock() + context.bot = mock_bot + + # Send multiple messages with same media_group_id + for i in range(3): + update = create_mock_update( + message_text=f"Media item {i}", + media_group_id="group_456", + message_id=i + 1, + ) + await adapter.handle_media_group_message(update, context) + + assert len(adapter.media_group_cache["group_456"]["items"]) == 3 + + @pytest.mark.asyncio + async def test_handle_media_group_without_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test handling media group without message returns early.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + update = MagicMock() + update.message = None + + context = MagicMock() + + # Should not raise exception + await adapter.handle_media_group_message(update, context) + + assert len(adapter.media_group_cache) == 0 + + +# ============================================================================ +# TelegramPlatformAdapter Command Registration Tests +# ============================================================================ + + +class TestTelegramAdapterCommandRegistration: + """Tests for command registration.""" + + @pytest.mark.asyncio + async def test_register_commands_success( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test successful command registration.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + adapter.collect_commands = MagicMock( + return_value=[ + SimpleNamespace(command="help", description="help command"), + ] + ) + + await adapter.register_commands() + + mock_bot.delete_my_commands.assert_called_once() + mock_bot.set_my_commands.assert_called_once() + + @pytest.mark.asyncio + async def test_register_commands_empty_does_not_clear_existing( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test empty command list keeps existing Telegram commands.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + adapter.collect_commands = MagicMock(return_value=[]) + + await adapter.register_commands() + + mock_bot.delete_my_commands.assert_not_called() + mock_bot.set_my_commands.assert_not_called() + + def test_collect_commands_empty( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + ): + """Test collecting commands when no handlers are registered.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + patch( + "astrbot.core.platform.sources.telegram.tg_adapter.star_handlers_registry", + [], + ), + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + + commands = adapter.collect_commands() + + assert commands == [] + + def test_collect_commands_includes_aliases( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + ): + """Test collecting commands includes command/group aliases.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + patch( + "astrbot.core.platform.sources.telegram.tg_adapter.BotCommand", + side_effect=lambda cmd, desc: SimpleNamespace( + command=cmd, + description=desc, + ), + ), + ): + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + from astrbot.core.star.filter.command import CommandFilter + from astrbot.core.star.filter.command_group import CommandGroupFilter + + handler = SimpleNamespace( + handler_module_path="plugin.telegram.alias", + enabled=True, + desc="alias command", + event_filters=[ + CommandFilter("help", alias={"h"}), + CommandGroupFilter("admin", alias={"adm"}), + ], + ) + + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + with ( + patch( + "astrbot.core.platform.sources.telegram.tg_adapter.star_handlers_registry", + [handler], + ), + patch( + "astrbot.core.platform.sources.telegram.tg_adapter.star_map", + {"plugin.telegram.alias": SimpleNamespace(activated=True)}, + ), + ): + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + commands = adapter.collect_commands() + + names = sorted(cmd.command for cmd in commands) + assert names == ["admin", "h", "help"] + + def test_collect_commands_warns_on_duplicates( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + ): + """Test duplicate command names log warning and keep first one.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + patch( + "astrbot.core.platform.sources.telegram.tg_adapter.BotCommand", + side_effect=lambda cmd, desc: SimpleNamespace( + command=cmd, + description=desc, + ), + ), + patch( + "astrbot.core.platform.sources.telegram.tg_adapter.logger" + ) as mock_logger, + ): + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + from astrbot.core.star.filter.command import CommandFilter + + handler_a = SimpleNamespace( + handler_module_path="plugin.telegram.a", + enabled=True, + desc="first", + event_filters=[CommandFilter("help")], + ) + handler_b = SimpleNamespace( + handler_module_path="plugin.telegram.b", + enabled=True, + desc="second", + event_filters=[CommandFilter("help")], + ) + + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + with ( + patch( + "astrbot.core.platform.sources.telegram.tg_adapter.star_handlers_registry", + [handler_a, handler_b], + ), + patch( + "astrbot.core.platform.sources.telegram.tg_adapter.star_map", + { + "plugin.telegram.a": SimpleNamespace(activated=True), + "plugin.telegram.b": SimpleNamespace(activated=True), + }, + ), + ): + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + commands = adapter.collect_commands() + + assert [cmd.command for cmd in commands] == ["help"] + mock_logger.warning.assert_called_once() + + +# ============================================================================ +# TelegramPlatformAdapter Run Tests +# ============================================================================ + + +class TestTelegramAdapterRun: + """Tests for run method.""" + + @pytest.mark.asyncio + async def test_run_initializes_application( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + ): + """Test run method initializes the application.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_application.bot + adapter.register_commands = AsyncMock() + + # Start run in background and cancel after short time + task = asyncio.create_task(adapter.run()) + + # Give it a moment to start + await asyncio.sleep(0.1) + + # Cancel the task + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + mock_application.initialize.assert_called_once() + mock_application.start.assert_called_once() + + +# ============================================================================ +# TelegramPlatformAdapter Terminate Tests +# ============================================================================ + + +class TestTelegramAdapterTerminate: + """Tests for terminate method.""" + + @pytest.mark.asyncio + async def test_terminate_stops_application( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test terminate method stops the application.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + await adapter.terminate() + + mock_application.stop.assert_called_once() + + @pytest.mark.asyncio + async def test_terminate_shuts_down_scheduler( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test terminate method shuts down the scheduler.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + adapter.scheduler = mock_scheduler + + await adapter.terminate() + + mock_scheduler.shutdown.assert_called_once() + + +# ============================================================================ +# TelegramPlatformAdapter send_by_session Tests +# ============================================================================ + + +class TestTelegramAdapterSendBySession: + """Tests for send_by_session method.""" + + @pytest.mark.asyncio + async def test_send_by_session_calls_send_with_client( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test send_by_session calls send_with_client.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.astr_message_event import MessageSesion + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + session = MagicMock(spec=MessageSesion) + session.session_id = "123456789" + + message_chain = MagicMock() + message_chain.chain = [] + + with patch( + "astrbot.core.platform.sources.telegram.tg_adapter.TelegramPlatformEvent.send_with_client", + new_callable=AsyncMock, + ) as mock_send: + await adapter.send_by_session(session, message_chain) + + mock_send.assert_called_once_with(mock_bot, message_chain, "123456789") + + +# ============================================================================ +# TelegramPlatformEvent Tests +# ============================================================================ + + +class TestTelegramPlatformEvent: + """Tests for TelegramPlatformEvent class.""" + + def test_split_message_short_text(self): + """Test _split_message returns single chunk for short text.""" + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + text = "Short message" + result = TelegramPlatformEvent._split_message(text) + + assert len(result) == 1 + assert result[0] == text + + def test_split_message_long_text(self): + """Test _split_message splits long text into chunks.""" + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + # Create text longer than MAX_MESSAGE_LENGTH + text = "A" * 5000 + result = TelegramPlatformEvent._split_message(text) + + # Should be split into multiple chunks + assert len(result) > 1 + # Each chunk should be <= MAX_MESSAGE_LENGTH + for chunk in result: + assert len(chunk) <= TelegramPlatformEvent.MAX_MESSAGE_LENGTH + + def test_split_message_respects_paragraph_breaks(self): + """Test _split_message prefers paragraph breaks for splitting.""" + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + # Create text with paragraph breaks + para1 = "A" * 3000 + para2 = "B" * 3000 + text = f"{para1}\n\n{para2}" + + result = TelegramPlatformEvent._split_message(text) + + # Should split at paragraph break + assert len(result) >= 2 + + def test_get_chat_action_for_chain_voice(self): + """Test _get_chat_action_for_chain returns UPLOAD_VOICE for Record.""" + from astrbot.api.message_components import Record + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + chain = [Record(file="test.ogg", url="test.ogg")] + result = TelegramPlatformEvent._get_chat_action_for_chain(chain) + + assert result == "upload_voice" + + def test_get_chat_action_for_chain_image(self): + """Test _get_chat_action_for_chain returns UPLOAD_PHOTO for Image.""" + from astrbot.api.message_components import Image + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + chain = [Image(file="test.jpg", url="test.jpg")] + result = TelegramPlatformEvent._get_chat_action_for_chain(chain) + + assert result == "upload_photo" + + def test_get_chat_action_for_chain_file(self): + """Test _get_chat_action_for_chain returns UPLOAD_DOCUMENT for File.""" + from astrbot.api.message_components import File + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + chain = [File(file="test.pdf", name="test.pdf")] + result = TelegramPlatformEvent._get_chat_action_for_chain(chain) + + assert result == "upload_document" + + def test_get_chat_action_for_chain_plain(self): + """Test _get_chat_action_for_chain returns TYPING for Plain text.""" + from astrbot.api.message_components import Plain + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + chain = [Plain("Hello")] + result = TelegramPlatformEvent._get_chat_action_for_chain(chain) + + assert result == "typing" + + +class TestTelegramPlatformEventSend: + """Tests for TelegramPlatformEvent send methods.""" + + @pytest.fixture + def event_setup(self, mock_bot): + """Create a basic event setup for testing.""" + from astrbot.api.platform import AstrBotMessage, PlatformMetadata + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + message_obj = AstrBotMessage() + message_obj.session_id = "123456789" + message_obj.message_id = "1" + message_obj.group_id = None + + platform_meta = PlatformMetadata(name="telegram", description="test", id="test") + + event = TelegramPlatformEvent( + message_str="Test message", + message_obj=message_obj, + platform_meta=platform_meta, + session_id="123456789", + client=mock_bot, + ) + + return event, mock_bot + + @pytest.mark.asyncio + async def test_send_typing(self, event_setup): + """Test send_typing method.""" + event, mock_bot = event_setup + + await event.send_typing() + + mock_bot.send_chat_action.assert_called() + + @pytest.mark.asyncio + async def test_react_with_emoji(self, event_setup): + """Test react method with regular emoji.""" + event, mock_bot = event_setup + + await event.react("👍") + + mock_bot.set_message_reaction.assert_called_once() + + @pytest.mark.asyncio + async def test_react_with_custom_emoji(self, event_setup): + """Test react method with custom emoji ID.""" + event, mock_bot = event_setup + + await event.react("123456789") # Custom emoji ID + + mock_bot.set_message_reaction.assert_called_once() + + @pytest.mark.asyncio + async def test_react_clear(self, event_setup): + """Test react method clears reaction when None is passed.""" + event, mock_bot = event_setup + + await event.react(None) + + mock_bot.set_message_reaction.assert_called_once() + call_args = mock_bot.set_message_reaction.call_args + assert call_args[1]["reaction"] == [] + + +class TestTelegramPlatformEventSendWithClient: + """Tests for send_with_client class method.""" + + @pytest.mark.asyncio + async def test_send_with_client_plain_text(self, mock_bot): + """Test send_with_client sends plain text message.""" + from astrbot.api.event import MessageChain + from astrbot.api.message_components import Plain + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + message = MessageChain() + message.chain = [Plain("Hello World")] + + await TelegramPlatformEvent.send_with_client(mock_bot, message, "123456789") + + mock_bot.send_message.assert_called() + + @pytest.mark.asyncio + async def test_send_with_client_with_reply(self, mock_bot): + """Test send_with_client sends message with reply.""" + from astrbot.api.event import MessageChain + from astrbot.api.message_components import Plain, Reply + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + message = MessageChain() + reply = MagicMock() + reply.id = "123" + message.chain = [ + Reply( + id="123", + chain=[], + sender_id="1", + sender_nickname="test", + time=0, + message_str="", + text="", + qq="1", + ), + Plain("Reply text"), + ] + + await TelegramPlatformEvent.send_with_client(mock_bot, message, "123456789") + + mock_bot.send_message.assert_called() + + @pytest.mark.asyncio + async def test_send_with_client_to_topic_group(self, mock_bot): + """Test send_with_client handles topic group (with # in username).""" + from astrbot.api.event import MessageChain + from astrbot.api.message_components import Plain + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + message = MessageChain() + message.chain = [Plain("Topic message")] + + # Topic group format: chat_id#thread_id + await TelegramPlatformEvent.send_with_client(mock_bot, message, "123456789#222") + + mock_bot.send_chat_action.assert_called() + + +# ============================================================================ +# TelegramPlatformEvent Voice Fallback Tests +# ============================================================================ + + +class TestTelegramPlatformEventVoiceFallback: + """Tests for voice message fallback functionality.""" + + @pytest.mark.asyncio + async def test_send_voice_with_fallback_success(self, mock_bot): + """Test _send_voice_with_fallback sends voice normally.""" + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + payload = {"chat_id": "123456789"} + + await TelegramPlatformEvent._send_voice_with_fallback( + mock_bot, + "voice.ogg", + payload, + ) + + mock_bot.send_voice.assert_called_once() + + @pytest.mark.asyncio + async def test_send_voice_with_fallback_to_document(self, mock_bot): + """Test _send_voice_with_fallback falls back to document on Voice_messages_forbidden.""" + from telegram.error import BadRequest + + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + # Create a BadRequest with Voice_messages_forbidden message + error = BadRequest("Voice_messages_forbidden") + mock_bot.send_voice = AsyncMock(side_effect=error) + + payload = {"chat_id": "123456789"} + + await TelegramPlatformEvent._send_voice_with_fallback( + mock_bot, + "voice.ogg", + payload, + caption="Voice caption", + ) + + mock_bot.send_document.assert_called_once() + + @pytest.mark.asyncio + async def test_send_voice_with_fallback_reraises_other_errors(self, mock_bot): + """Test _send_voice_with_fallback re-raises non-voice-forbidden errors.""" + from telegram.error import BadRequest + + from astrbot.core.platform.sources.telegram.tg_event import ( + TelegramPlatformEvent, + ) + + # Create a BadRequest with different message + error = BadRequest("Some other error") + mock_bot.send_voice = AsyncMock(side_effect=error) + + payload = {"chat_id": "123456789"} + + with pytest.raises(BadRequest): + await TelegramPlatformEvent._send_voice_with_fallback( + mock_bot, + "voice.ogg", + payload, + ) + + +# ============================================================================ +# Integration-style Tests +# ============================================================================ + + +class TestTelegramAdapterIntegration: + """Integration-style tests for complete message flows.""" + + @pytest.mark.asyncio + async def test_message_handler_processes_text_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test message_handler processes a text message end-to-end.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + update = create_mock_update( + message_text="Hello bot!", + chat_type="private", + ) + + context = MagicMock() + context.bot = mock_bot + + await adapter.message_handler(update, context) + + # Check that an event was committed to the queue + assert not event_queue.empty() + + @pytest.mark.asyncio + async def test_start_command_sends_welcome_message( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test /start command sends welcome message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + update = create_mock_update( + message_text="/start", + chat_type="private", + ) + + context = MagicMock() + context.bot = mock_bot + + # convert_message should return None for /start + result = await adapter.convert_message(update, context) + + assert result is None + mock_bot.send_message.assert_called() + + +# ============================================================================ +# Edge Cases and Error Handling Tests +# ============================================================================ + + +class TestTelegramAdapterEdgeCases: + """Tests for edge cases and error handling.""" + + @pytest.mark.asyncio + async def test_convert_message_with_reply( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test converting a message that replies to another message.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + # Create a reply message + reply_message = MagicMock() + reply_message.message_id = 100 + reply_message.chat = MagicMock() + reply_message.chat.id = 123456789 + reply_message.chat.type = "private" + reply_message.from_user = MagicMock() + reply_message.from_user.id = 111111111 + reply_message.from_user.username = "reply_user" + reply_message.text = "Original message" + reply_message.message_thread_id = None + reply_message.is_topic_message = False + reply_message.media_group_id = None + reply_message.photo = None + reply_message.video = None + reply_message.document = None + reply_message.voice = None + reply_message.sticker = None + reply_message.reply_to_message = None + reply_message.caption = None + reply_message.entities = None + reply_message.caption_entities = None + + update = create_mock_update( + message_text="Reply text", + reply_to_message=reply_message, + ) + + context = MagicMock() + context.bot = mock_bot + + result = await adapter.convert_message(update, context) + + assert result is not None + # Should have Reply component in message + assert len(result.message) >= 1 + + @pytest.mark.asyncio + async def test_process_media_group_empty_cache( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test process_media_group handles missing cache entry gracefully.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + # Should not raise exception for non-existent media group + await adapter.process_media_group("non_existent_group") + + assert True # Just verify no exception + + @pytest.mark.asyncio + async def test_register_commands_handles_exception( + self, + event_queue, + platform_config, + platform_settings, + mock_application, + mock_scheduler, + mock_bot, + ): + """Test register_commands handles exceptions gracefully.""" + with ( + patch("telegram.ext.ApplicationBuilder") as mock_builder_class, + patch( + "apscheduler.schedulers.asyncio.AsyncIOScheduler" + ) as mock_scheduler_class, + ): + mock_builder = MagicMock() + mock_builder.token.return_value = mock_builder + mock_builder.base_url.return_value = mock_builder + mock_builder.base_file_url.return_value = mock_builder + mock_builder.build.return_value = mock_application + mock_builder_class.return_value = mock_builder + mock_scheduler_class.return_value = mock_scheduler + + from astrbot.core.platform.sources.telegram.tg_adapter import ( + TelegramPlatformAdapter, + ) + + adapter = TelegramPlatformAdapter( + platform_config, platform_settings, event_queue + ) + adapter.client = mock_bot + + # Make delete_my_commands raise an exception + mock_bot.delete_my_commands = AsyncMock( + side_effect=Exception("Network error") + ) + + # Should not raise exception + await adapter.register_commands() + + assert True # Just verify no exception From b50e335ff9e3bff81c57c1351d681a101e3f24c0 Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Thu, 26 Feb 2026 11:45:13 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=20Discord=20?= =?UTF-8?q?=E5=92=8C=20Telegram=20=E9=80=82=E9=85=8D=E5=99=A8=E7=9A=84?= =?UTF-8?q?=E6=8C=87=E4=BB=A4=E6=B3=A8=E5=86=8C=E5=8F=8A=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=A4=96=E9=83=A8=E4=BE=9D=E8=B5=96=E7=9A=84=E6=8F=90=E5=89=8D?= =?UTF-8?q?=20mock=E6=9D=A5=E8=AE=A9test=E8=83=BD=E9=A1=BA=E5=88=A9pass?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../discord/discord_platform_adapter.py | 6 +- .../platform/sources/telegram/tg_adapter.py | 13 +- tests/conftest.py | 136 +++++++++++++++++- tests/unit/test_aiocqhttp_adapter.py | 2 + tests/unit/test_telegram_adapter.py | 5 +- 5 files changed, 149 insertions(+), 13 deletions(-) diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index fa480fc1cd..be70f4e92f 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -365,7 +365,6 @@ def register_handler(self, handler_info) -> None: async def _collect_and_register_commands(self) -> None: """收集所有指令并注册到Discord""" logger.info("[Discord] 开始收集并注册斜杠指令...") - registered_commands = [] registered_command_names: set[str] = set() for handler_md in star_handlers_registry: @@ -407,11 +406,10 @@ async def _collect_and_register_commands(self) -> None: ) self.client.add_application_command(slash_command) registered_command_names.add(cmd_name) - registered_commands.append(cmd_name) - if registered_commands: + if registered_command_names: logger.info( - f"[Discord] 准备同步 {len(registered_commands)} 个指令: {', '.join(registered_commands)}", + f"[Discord] 准备同步 {len(registered_command_names)} 个指令: {', '.join(sorted(registered_command_names))}", ) else: logger.info("[Discord] 没有发现可注册的指令。") diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index e2000c13fb..d13f0bcf69 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -2,7 +2,6 @@ import re import sys import uuid -from typing import Any import apscheduler.schedulers.asyncio as _apscheduler_asyncio_import import telegram.ext as _telegram_ext_import @@ -227,6 +226,8 @@ def _extract_command_info( if event_filter.parent_group: return None cmd_names = [event_filter.group_name] + if getattr(event_filter, "alias", None): + cmd_names.extend(event_filter.alias) is_group = True result: list[tuple[str, str]] = [] @@ -250,7 +251,7 @@ def _extract_command_info( return result if result else None - async def start(self, update: Update, context: Any) -> None: + async def start(self, update: Update, context: telegram_ext.CallbackContext) -> None: if not update.effective_chat: logger.warning( "Received a start command without an effective chat, skipping /start reply.", @@ -261,7 +262,7 @@ async def start(self, update: Update, context: Any) -> None: text=self.config["start_message"], ) - async def message_handler(self, update: Update, context: Any) -> None: + async def message_handler(self, update: Update, context: telegram_ext.CallbackContext) -> None: logger.debug(f"Telegram message: {update.message}") # Handle media group messages @@ -277,7 +278,7 @@ async def message_handler(self, update: Update, context: Any) -> None: async def convert_message( self, update: Update, - context: Any, + context: telegram_ext.CallbackContext, get_reply=True, ) -> AstrBotMessage | None: """转换 Telegram 的消息对象为 AstrBotMessage 对象。 @@ -444,7 +445,7 @@ async def convert_message( return message - async def handle_media_group_message(self, update: Update, context: Any): + async def handle_media_group_message(self, update: Update, context: telegram_ext.CallbackContext): """Handle messages that are part of a media group (album). Caches incoming messages and schedules delayed processing to collect all @@ -559,7 +560,7 @@ async def handle_msg(self, message: AstrBotMessage) -> None: ) self.commit_event(message_event) - def get_client(self): + def get_client(self) -> telegram_ext.ExtBot: return self.client async def terminate(self) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index b9807c1ded..3da25b9c9c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -90,7 +90,14 @@ def pytest_addoption(parser): def pytest_configure(config): - """注册自定义标记。""" + """注册自定义标记并提前 mock 外部依赖。 + + 在测试收集阶段提前 mock 外部依赖(telegram, discord, apscheduler 等), + 避免 import 时触发循环导入问题。 + """ + # 提前 mock 外部依赖,解决循环导入问题 + _mock_external_dependencies() + config.addinivalue_line("markers", "unit: 单元测试") config.addinivalue_line("markers", "integration: 集成测试") config.addinivalue_line("markers", "slow: 慢速测试") @@ -101,6 +108,133 @@ def pytest_configure(config): config.addinivalue_line("markers", "tier_d: D-tier tests (extended / integration)") +def _mock_external_dependencies(): + """提前 mock 外部依赖模块,避免循环导入。 + + 某些外部依赖(如 telegram, discord, apscheduler)在导入时会触发 + astrbot 内部的循环导入链。通过提前 mock 这些模块,可以避免问题。 + """ + # 检查是否已经 mock 过 + if hasattr(sys, "_astrbot_external_deps_mocked"): + return + + # Mock astrbot.core.star 相关模块(解决循环导入) + _mock_star_modules() + + # Mock telegram + if "telegram" not in sys.modules: + mock_telegram = MagicMock() + mock_telegram.BotCommand = MagicMock + mock_telegram.Update = MagicMock + mock_telegram.constants = MagicMock() + mock_telegram.constants.ChatType = MagicMock() + mock_telegram.constants.ChatType.PRIVATE = "private" + mock_telegram.constants.ChatAction = MagicMock() + mock_telegram.error = MagicMock() + mock_telegram.error.BadRequest = Exception + mock_telegram.ReactionTypeCustomEmoji = MagicMock + mock_telegram.ReactionTypeEmoji = MagicMock + + mock_telegram_ext = MagicMock() + mock_telegram_ext.ApplicationBuilder = MagicMock + mock_telegram_ext.CallbackContext = MagicMock + mock_telegram_ext.ContextTypes = MagicMock + mock_telegram_ext.ExtBot = MagicMock + mock_telegram_ext.filters = MagicMock() + mock_telegram_ext.filters.ALL = MagicMock() + mock_telegram_ext.MessageHandler = MagicMock + + sys.modules["telegram"] = mock_telegram + sys.modules["telegram.constants"] = mock_telegram.constants + sys.modules["telegram.error"] = mock_telegram.error + sys.modules["telegram.ext"] = mock_telegram_ext + + # Mock telegramify_markdown + if "telegramify_markdown" not in sys.modules: + mock_telegramify = MagicMock() + mock_telegramify.markdownify = lambda text, **kwargs: text + sys.modules["telegramify_markdown"] = mock_telegramify + + # Mock apscheduler + if "apscheduler" not in sys.modules: + mock_apscheduler = MagicMock() + mock_apscheduler.schedulers = MagicMock() + mock_apscheduler.schedulers.asyncio = MagicMock() + mock_apscheduler.schedulers.asyncio.AsyncIOScheduler = MagicMock + mock_apscheduler.schedulers.background = MagicMock() + mock_apscheduler.schedulers.background.BackgroundScheduler = MagicMock + + sys.modules["apscheduler"] = mock_apscheduler + sys.modules["apscheduler.schedulers"] = mock_apscheduler.schedulers + sys.modules["apscheduler.schedulers.asyncio"] = mock_apscheduler.schedulers.asyncio + sys.modules["apscheduler.schedulers.background"] = mock_apscheduler.schedulers.background + + # Mock discord (py-cord) + if "discord" not in sys.modules: + mock_discord = MagicMock() + mock_discord.Client = MagicMock + mock_discord.Intents = MagicMock + mock_discord.Message = MagicMock + mock_discord.ApplicationContext = MagicMock + mock_discord.Option = MagicMock + mock_discord.SlashCommand = MagicMock + mock_discord.SlashCommandOptionType = MagicMock() + mock_discord.SlashCommandOptionType.string = 3 + + sys.modules["discord"] = mock_discord + + # 标记已 mock + sys._astrbot_external_deps_mocked = True + + +def _mock_star_modules(): + """Mock astrbot.core.star 相关模块,解决循环导入问题。""" + # 创建 mock star_map 和 star_handlers_registry + mock_star_map: dict = {} + mock_star_handlers_registry: list = [] + + # Mock astrbot.core.star.star + mock_star_module = MagicMock() + mock_star_module.star_map = mock_star_map + mock_star_module.star_registry = [] + mock_star_module.StarMetadata = MagicMock + + # Mock astrbot.core.star.star_handler + mock_star_handler_module = MagicMock() + mock_star_handler_module.star_handlers_registry = mock_star_handlers_registry + + # Mock astrbot.core.star.filter.command + # 创建真正的类,而不是函数,以便 isinstance 可以正常工作 + class MockCommandFilter: + """Mock CommandFilter 类。""" + + def __init__(self, command_name: str, alias: set | None = None, **kwargs): + self.command_name = command_name + self.alias = alias or set() + self.parent_command_names = kwargs.get("parent_command_names", []) + + mock_command_filter_module = MagicMock() + mock_command_filter_module.CommandFilter = MockCommandFilter + + # Mock astrbot.core.star.filter.command_group + class MockCommandGroupFilter: + """Mock CommandGroupFilter 类。""" + + def __init__(self, group_name: str, alias: set | None = None, **kwargs): + self.group_name = group_name + self.alias = alias or set() + self.parent_group = kwargs.get("parent_group", None) + + mock_command_group_filter_module = MagicMock() + mock_command_group_filter_module.CommandGroupFilter = MockCommandGroupFilter + + # 注册到 sys.modules + sys.modules["astrbot.core.star.star"] = mock_star_module + sys.modules["astrbot.core.star.star_handler"] = mock_star_handler_module + sys.modules["astrbot.core.star.filter.command"] = mock_command_filter_module + sys.modules["astrbot.core.star.filter.command_group"] = mock_command_group_filter_module + + # ============================================================ # 临时目录和文件 Fixtures # ============================================================ diff --git a/tests/unit/test_aiocqhttp_adapter.py b/tests/unit/test_aiocqhttp_adapter.py index 8581365549..9783ed30f4 100644 --- a/tests/unit/test_aiocqhttp_adapter.py +++ b/tests/unit/test_aiocqhttp_adapter.py @@ -481,10 +481,12 @@ async def test_handle_msg_creates_event( from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import ( AiocqhttpAdapter, ) + from astrbot.core.platform import MessageType adapter = AiocqhttpAdapter(platform_config, platform_settings, event_queue) message = AstrBotMessage() + message.type = MessageType.FRIEND_MESSAGE message.message_str = "Test message" message.session_id = "test_session" diff --git a/tests/unit/test_telegram_adapter.py b/tests/unit/test_telegram_adapter.py index 648338e029..056d0ccdd2 100644 --- a/tests/unit/test_telegram_adapter.py +++ b/tests/unit/test_telegram_adapter.py @@ -1201,7 +1201,7 @@ def test_collect_commands_includes_aliases( commands = adapter.collect_commands() names = sorted(cmd.command for cmd in commands) - assert names == ["admin", "h", "help"] + assert names == ["adm", "admin", "h", "help"] def test_collect_commands_warns_on_duplicates( self, @@ -1586,12 +1586,13 @@ class TestTelegramPlatformEventSend: @pytest.fixture def event_setup(self, mock_bot): """Create a basic event setup for testing.""" - from astrbot.api.platform import AstrBotMessage, PlatformMetadata + from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata from astrbot.core.platform.sources.telegram.tg_event import ( TelegramPlatformEvent, ) message_obj = AstrBotMessage() + message_obj.type = MessageType.FRIEND_MESSAGE message_obj.session_id = "123456789" message_obj.message_id = "1" message_obj.group_id = None From 38cc5973a16c11f7db4162af90847951d44011cd Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Thu, 26 Feb 2026 11:56:51 +0800 Subject: [PATCH 3/4] =?UTF-8?q?fix:=20=E4=BF=AE=E6=AD=A3=20Discord=20?= =?UTF-8?q?=E9=80=82=E9=85=8D=E5=99=A8=E4=B8=AD=E6=8C=87=E4=BB=A4=E6=B3=A8?= =?UTF-8?q?=E5=86=8C=E7=9A=84=E5=8F=98=E9=87=8F=E5=91=BD=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sources/discord/discord_platform_adapter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index be70f4e92f..9a153933e3 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -365,7 +365,7 @@ def register_handler(self, handler_info) -> None: async def _collect_and_register_commands(self) -> None: """收集所有指令并注册到Discord""" logger.info("[Discord] 开始收集并注册斜杠指令...") - registered_command_names: set[str] = set() + registered_commands: set[str] = set() for handler_md in star_handlers_registry: if not star_map[handler_md.handler_module_path].activated: @@ -375,7 +375,7 @@ async def _collect_and_register_commands(self) -> None: for event_filter in handler_md.event_filters: cmd_infos = self._extract_command_infos(event_filter, handler_md) for cmd_name, description in cmd_infos: - if cmd_name in registered_command_names: + if cmd_name in registered_commands: logger.warning( "[Discord] Duplicate slash command '%s' from %s ignored.", cmd_name, @@ -405,11 +405,11 @@ async def _collect_and_register_commands(self) -> None: guild_ids=[self.guild_id] if self.guild_id else None, ) self.client.add_application_command(slash_command) - registered_command_names.add(cmd_name) + registered_commands.add(cmd_name) - if registered_command_names: + if registered_commands: logger.info( - f"[Discord] 准备同步 {len(registered_command_names)} 个指令: {', '.join(sorted(registered_command_names))}", + f"[Discord] 准备同步 {len(registered_commands)} 个指令: {', '.join(sorted(registered_commands))}", ) else: logger.info("[Discord] 没有发现可注册的指令。") From eae10c462c6be8e7f7ccc9a4bef8929cca289629 Mon Sep 17 00:00:00 2001 From: whatevertogo Date: Thu, 26 Feb 2026 12:18:23 +0800 Subject: [PATCH 4/4] =?UTF-8?q?refactor:=20=E6=A0=BC=E5=BC=8F=E5=8C=96=20T?= =?UTF-8?q?elegramPlatformAdapter=20=E4=B8=AD=E7=9A=84=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/platform/sources/telegram/tg_adapter.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 15d3402f4b..29816b948c 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -256,7 +256,9 @@ def _extract_command_info( return result if result else None - async def start(self, update: Update, context: telegram_ext.CallbackContext) -> None: + async def start( + self, update: Update, context: telegram_ext.CallbackContext + ) -> None: if not update.effective_chat: logger.warning( "Received a start command without an effective chat, skipping /start reply.", @@ -267,7 +269,9 @@ async def start(self, update: Update, context: telegram_ext.CallbackContext) -> text=self.config["start_message"], ) - async def message_handler(self, update: Update, context: telegram_ext.CallbackContext) -> None: + async def message_handler( + self, update: Update, context: telegram_ext.CallbackContext + ) -> None: logger.debug(f"Telegram message: {update.message}") # Handle media group messages @@ -461,7 +465,9 @@ async def convert_message( return message - async def handle_media_group_message(self, update: Update, context: telegram_ext.CallbackContext): + async def handle_media_group_message( + self, update: Update, context: telegram_ext.CallbackContext + ): """Handle messages that are part of a media group (album). Caches incoming messages and schedules delayed processing to collect all