Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import re
from collections.abc import AsyncGenerator

from aiocqhttp import CQHttp, Event
import aiocqhttp

from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
from astrbot.core.message.components import (
BaseMessageComponent,
File,
Image,
Expand All @@ -15,7 +14,8 @@
Record,
Video,
)
from astrbot.api.platform import Group, MessageMember
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import AstrMessageEvent, Group, MessageMember


class AiocqhttpMessageEvent(AstrMessageEvent):
Expand All @@ -25,7 +25,7 @@ def __init__(
message_obj,
platform_meta,
session_id,
bot: CQHttp,
bot: aiocqhttp.CQHttp,
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
Expand Down Expand Up @@ -80,8 +80,8 @@ async def _parse_onebot_json(message_chain: MessageChain):
@classmethod
async def _dispatch_send(
cls,
bot: CQHttp,
event: Event | None,
bot: aiocqhttp.CQHttp,
event: aiocqhttp.Event | None,
is_group: bool,
session_id: str | None,
messages: list[dict],
Expand All @@ -95,7 +95,7 @@ async def _dispatch_send(
await bot.send_group_msg(group_id=session_id_int, message=messages)
elif not is_group and isinstance(session_id_int, int):
await bot.send_private_msg(user_id=session_id_int, message=messages)
elif isinstance(event, Event): # 最后兜底
elif isinstance(event, aiocqhttp.Event): # 最后兜底
await bot.send(event=event, message=messages)
else:
raise ValueError(
Expand All @@ -105,9 +105,9 @@ async def _dispatch_send(
@classmethod
async def send_message(
cls,
bot: CQHttp,
bot: aiocqhttp.CQHttp,
message_chain: MessageChain,
event: Event | None = None,
event: aiocqhttp.Event | None = None,
is_group: bool = False,
session_id: str | None = None,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,38 @@
import asyncio
import importlib
import inspect
import itertools
import logging
import time
import uuid
from collections.abc import Awaitable
from collections.abc import Awaitable, Callable
from typing import Any, cast

from aiocqhttp import CQHttp, Event
import aiocqhttp
from aiocqhttp.exceptions import ActionFailed

from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import *
from astrbot.api.platform import (
from astrbot import logger
from astrbot.core.message.components import (
At,
ComponentTypes,
File,
Image,
Plain,
Poke,
Reply,
)
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import (
AstrBotMessage,
Group,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.register import register_platform_adapter

from ...register import register_platform_adapter
from .aiocqhttp_message_event import *
from .aiocqhttp_message_event import AiocqhttpMessageEvent


Expand All @@ -38,6 +47,7 @@ def __init__(
platform_config: dict,
platform_settings: dict,
event_queue: asyncio.Queue,
bot_factory: Callable[..., Any] | None = None,
) -> None:
super().__init__(platform_config, event_queue)

Expand All @@ -52,17 +62,10 @@ def __init__(
support_streaming_message=False,
)

self.bot = CQHttp(
use_ws_reverse=True,
import_name="aiocqhttp",
api_timeout_sec=180,
access_token=platform_config.get(
"ws_reverse_token",
), # 以防旧版本配置不存在
)
self.bot = self._create_bot(platform_config, bot_factory=bot_factory)

@self.bot.on_request()
async def request(event: Event) -> None:
async def request(event: aiocqhttp.Event) -> None:
try:
abm = await self.convert_message(event)
if not abm:
Expand All @@ -73,7 +76,7 @@ async def request(event: Event) -> None:
return

@self.bot.on_notice()
async def notice(event: Event) -> None:
async def notice(event: aiocqhttp.Event) -> None:
try:
abm = await self.convert_message(event)
if abm:
Expand All @@ -83,7 +86,7 @@ async def notice(event: Event) -> None:
return

@self.bot.on_message("group")
async def group(event: Event) -> None:
async def group(event: aiocqhttp.Event) -> None:
try:
abm = await self.convert_message(event)
if abm:
Expand All @@ -93,7 +96,7 @@ async def group(event: Event) -> None:
return

@self.bot.on_message("private")
async def private(event: Event) -> None:
async def private(event: aiocqhttp.Event) -> None:
try:
abm = await self.convert_message(event)
if abm:
Expand All @@ -106,6 +109,29 @@ async def private(event: Event) -> None:
def on_websocket_connection(_) -> None:
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")

@staticmethod
def _create_bot(
platform_config: dict,
bot_factory: Callable[..., Any] | None = None,
) -> aiocqhttp.CQHttp:
if bot_factory is None:
# Resolve aiocqhttp at runtime so tests that swap sys.modules later
# still affect bot creation even if this module was imported earlier.
aiocqhttp_module = importlib.import_module("aiocqhttp")
bot_factory = aiocqhttp_module.CQHttp

return cast(
aiocqhttp.CQHttp,
bot_factory(
use_ws_reverse=True,
import_name="aiocqhttp",
api_timeout_sec=180,
access_token=platform_config.get(
"ws_reverse_token",
), # 以防旧版本配置不存在
),
)

async def send_by_session(
self,
session: MessageSesion,
Expand All @@ -125,7 +151,7 @@ async def send_by_session(
)
await super().send_by_session(session, message_chain)

async def convert_message(self, event: Event) -> AstrBotMessage | None:
async def convert_message(self, event: aiocqhttp.Event) -> AstrBotMessage | None:
logger.debug(f"[aiocqhttp] RawMessage {event}")

if event["post_type"] == "message":
Expand All @@ -140,7 +166,9 @@ async def convert_message(self, event: Event) -> AstrBotMessage | None:

return abm

async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage:
async def _convert_handle_request_event(
self, event: aiocqhttp.Event
) -> AstrBotMessage:
"""OneBot V11 请求类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
Expand All @@ -165,7 +193,9 @@ async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage:
abm.raw_message = event
return abm

async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage:
async def _convert_handle_notice_event(
self, event: aiocqhttp.Event
) -> AstrBotMessage:
"""OneBot V11 通知类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
Expand Down Expand Up @@ -197,7 +227,7 @@ async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage:

async def _convert_handle_message_event(
self,
event: Event,
event: aiocqhttp.Event,
get_reply=True,
) -> AstrBotMessage:
"""OneBot V11 消息类事件
Expand Down Expand Up @@ -310,7 +340,7 @@ async def _convert_handle_message_event(
)
# 添加必要的 post_type 字段,防止 Event.from_payload 报错
reply_event_data["post_type"] = "message"
new_event = Event.from_payload(reply_event_data)
new_event = aiocqhttp.Event.from_payload(reply_event_data)
if not new_event:
logger.error(
f"无法从回复消息数据构造 Event 对象: {reply_event_data}",
Expand Down Expand Up @@ -402,6 +432,14 @@ async def _convert_handle_message_event(
f"不支持的消息段类型,已忽略: {t}, data={m['data']}"
)
continue
if (
t == "image"
and not m["data"].get("file")
and m["data"].get("url")
):
a = Image(file=m["data"]["url"], url=m["data"]["url"])
abm.message.append(a)
continue
a = ComponentTypes[t](**m["data"])
abm.message.append(a)
except Exception as e:
Expand Down Expand Up @@ -492,5 +530,5 @@ async def handle_msg(self, message: AstrBotMessage) -> None:

self.commit_event(message_event)

def get_client(self) -> CQHttp:
def get_client(self) -> aiocqhttp.CQHttp:
return self.bot
Loading