diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 894af238..17d5d6d5 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -1,31 +1,37 @@ -from rich.traceback import install -from typing import Optional +from typing import Any, Optional, Tuple import asyncio +import traceback +from rich.traceback import install -from src.common.message_server.api import get_global_api -from src.common.logger import get_logger -from src.common.database.database import get_db_session from src.chat.message_receive.message import SessionMessage +from src.chat.utils.utils import calculate_typing_time, truncate_message from src.common.data_models.message_component_data_model import ReplyComponent -from src.chat.utils.utils import truncate_message -from src.chat.utils.utils import calculate_typing_time +from src.common.database.database import get_db_session +from src.common.logger import get_logger +from src.common.message_server.api import get_global_api +from src.webui.routers.chat.serializers import serialize_message_sequence install(extra_lines=3) logger = get_logger("sender") # WebUI 聊天室的消息广播器(延迟导入避免循环依赖) -_webui_chat_broadcaster = None +_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None # 虚拟群 ID 前缀(与 chat_routes.py 保持一致) VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_" # TODO: 重构完成后完成webui相关 -def get_webui_chat_broadcaster(): - """获取 WebUI 聊天室广播器""" +def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]: + """获取 WebUI 聊天室广播器。 + + Returns: + Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组; + 若 WebUI 相关模块不可用,则元素会退化为 ``None``。 + """ global _webui_chat_broadcaster if _webui_chat_broadcaster is None: try: @@ -38,102 +44,36 @@ def get_webui_chat_broadcaster(): def is_webui_virtual_group(group_id: str) -> bool: - """检查是否是 WebUI 虚拟群""" - return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX) - - -def parse_message_segments(segment) -> list: - """解析消息段,转换为 WebUI 可用的格式 - - 参考 NapCat 适配器的消息解析逻辑 + """检查是否是 WebUI 虚拟群。 Args: - segment: Seg 消息段对象 + group_id: 待判断的群 ID。 Returns: - list: 消息段列表,每个元素为 {"type": "...", "data": ...} + bool: 若群 ID 属于 WebUI 虚拟群则返回 ``True``。 """ - - result = [] - - if segment is None: - return result - - if segment.type == "seglist": - # 处理消息段列表 - if segment.data: - for seg in segment.data: - result.extend(parse_message_segments(seg)) - elif segment.type == "text": - # 文本消息 - if segment.data: - result.append({"type": "text", "data": segment.data}) - elif segment.type == "image": - # 图片消息(base64) - if segment.data: - result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"}) - elif segment.type == "emoji": - # 表情包消息(base64) - if segment.data: - result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"}) - elif segment.type == "imageurl": - # 图片链接消息 - if segment.data: - result.append({"type": "image", "data": segment.data}) - elif segment.type == "face": - # 原生表情 - result.append({"type": "face", "data": segment.data}) - elif segment.type == "voice": - # 语音消息(base64) - if segment.data: - result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"}) - elif segment.type == "voiceurl": - # 语音链接 - if segment.data: - result.append({"type": "voice", "data": segment.data}) - elif segment.type == "video": - # 视频消息(base64) - if segment.data: - result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"}) - elif segment.type == "videourl": - # 视频链接 - if segment.data: - result.append({"type": "video", "data": segment.data}) - elif segment.type == "music": - # 音乐消息 - result.append({"type": "music", "data": segment.data}) - elif segment.type == "file": - # 文件消息 - result.append({"type": "file", "data": segment.data}) - elif segment.type == "reply": - # 回复消息 - result.append({"type": "reply", "data": segment.data}) - elif segment.type == "forward": - # 转发消息 - forward_items = [] - if segment.data: - for item in segment.data: - forward_items.append( - { - "content": parse_message_segments(item.get("message_segment", {})) - if isinstance(item, dict) - else [] - } - ) - result.append({"type": "forward", "data": forward_items}) - else: - # 未知类型,尝试作为文本处理 - if segment.data: - result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)}) - - return result + return bool(group_id) and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX) -async def _send_message(message: MessageSending, show_log=True) -> bool: - """合并后的消息发送函数,包含WS发送和日志记录""" +async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: + """执行统一的消息发送流程。 + + 发送顺序为: + 1. WebUI 特殊链路 + 2. Platform IO 适配器链路 + 3. 旧版 ``maim_message`` / API Server 链路 + + Args: + message: 待发送的内部会话消息。 + show_log: 是否输出发送成功日志。 + + Returns: + bool: 是否最终发送成功。 + """ message_preview = truncate_message(message.processed_plain_text, max_length=200) platform = message.platform - group_id = message.session.group_id + group_info = message.message_info.group_info + group_id = group_info.group_id if group_info is not None else "" try: # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息 @@ -146,7 +86,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: from src.config.config import global_config # 解析消息段,获取富文本内容 - message_segments = parse_message_segments(message.message_segment) + message_segments = serialize_message_sequence(message.raw_message) # 判断消息类型 # 如果只有一个文本段,使用简单的 text 类型 @@ -184,8 +124,38 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室") return True + try: + from src.platform_io import DeliveryStatus + from src.plugin_runtime.integration import get_plugin_runtime_manager + + receipt = await get_plugin_runtime_manager().try_send_message_via_platform_io(message) + if receipt is not None: + if receipt.status == DeliveryStatus.SENT: + if show_log: + logger.info( + f"已通过 Platform IO 将消息 '{message_preview}' 发往平台'{platform}' " + f"(driver: {receipt.driver_id or 'unknown'})" + ) + return True + + logger.warning( + f"Platform IO 发送失败: platform={platform} driver={receipt.driver_id} " + f"status={receipt.status} error={receipt.error}" + ) + return False + except Exception as exc: + logger.warning(f"检查 Platform IO 出站链路时出现异常,将回退旧发送链: {exc}") + # Fallback 逻辑: 尝试通过 API Server 发送 - async def send_with_new_api(legacy_exception=None): + async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool: + """通过 API Server 回退链路发送消息。 + + Args: + legacy_exception: 旧发送链已经抛出的异常;若回退也失败,则重新抛出。 + + Returns: + bool: 回退链路是否发送成功。 + """ try: from src.config.config import global_config @@ -289,7 +259,8 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: class UniversalMessageSender: """管理消息的注册、即时处理、发送和存储,并跟踪思考状态。""" - def __init__(self): + def __init__(self) -> None: + """初始化统一消息发送器。""" pass async def send_message( @@ -300,7 +271,7 @@ class UniversalMessageSender: reply_message_id: Optional[str] = None, storage_message: bool = True, show_log: bool = True, - ): + ) -> bool: """ 处理、发送并存储一条消息。 diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index e10aa147..75563df7 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -1129,7 +1129,10 @@ class DefaultReplyer: user_id=bot_user_id, user_nickname=global_config.bot.nickname, ), - additional_config={}, + additional_config={ + "platform_io_target_group_id": self.chat_stream.group_id, + "platform_io_target_user_id": self.chat_stream.user_id, + }, ), message_segment=message_segment, ) diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index ccd8e1e4..f642dd69 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -970,7 +970,9 @@ class PrivateReplyer: user_nickname=global_config.bot.nickname, ), group_info=None, - additional_config={}, + additional_config={ + "platform_io_target_user_id": self.chat_stream.user_id, + }, ), message_segment=message_segment, ) diff --git a/src/platform_io/drivers/plugin_driver.py b/src/platform_io/drivers/plugin_driver.py index 9c139309..dff980f8 100644 --- a/src/platform_io/drivers/plugin_driver.py +++ b/src/platform_io/drivers/plugin_driver.py @@ -1,34 +1,51 @@ -"""提供 Platform IO 的 plugin 传输驱动骨架。""" +"""提供 Platform IO 的插件适配器驱动实现。""" -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol from src.platform_io.drivers.base import PlatformIODriver -from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage +class _AdapterSupervisorProtocol(Protocol): + """适配器驱动依赖的 Supervisor 最小协议。""" + + async def invoke_adapter( + self, + plugin_id: str, + method_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Any: + """调用适配器插件专用方法。""" + + class PluginPlatformDriver(PlatformIODriver): - """面向 ``MessageGateway`` 插件链路的 Platform IO 驱动骨架。""" + """面向适配器插件链路的 Platform IO 驱动。""" def __init__( self, driver_id: str, platform: str, + supervisor: _AdapterSupervisorProtocol, + send_method: str = "send_to_platform", account_id: Optional[str] = None, scope: Optional[str] = None, plugin_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> None: - """初始化一个 plugin 驱动描述对象。 + """初始化一个插件适配器驱动。 Args: driver_id: Broker 内的唯一驱动 ID。 - platform: 该 plugin 适配器链路负责的平台。 + platform: 该适配器负责的平台名称。 + supervisor: 持有该适配器插件的 Supervisor。 + send_method: 出站发送时要调用的插件方法名。 account_id: 可选的账号 ID 或 self ID。 scope: 可选的额外路由作用域。 - plugin_id: 拥有该适配器实现的插件 ID,可为空。 + plugin_id: 拥有该适配器实现的插件 ID。 metadata: 可选的额外驱动元数据。 """ descriptor = DriverDescriptor( @@ -41,6 +58,8 @@ class PluginPlatformDriver(PlatformIODriver): metadata=metadata or {}, ) super().__init__(descriptor) + self._supervisor = supervisor + self._send_method = send_method async def send_message( self, @@ -48,7 +67,7 @@ class PluginPlatformDriver(PlatformIODriver): route_key: RouteKey, metadata: Optional[Dict[str, Any]] = None, ) -> DeliveryReceipt: - """通过 plugin 传输路径发送消息。 + """通过适配器插件发送消息。 Args: message: 要投递的内部会话消息。 @@ -57,8 +76,119 @@ class PluginPlatformDriver(PlatformIODriver): Returns: DeliveryReceipt: 由驱动返回的规范化回执。 - - Raises: - NotImplementedError: 当前仍处于骨架阶段,尚未真正接入 MessageGateway。 """ - raise NotImplementedError("PluginPlatformDriver 仅完成地基实现,尚未接入 MessageGateway") + from src.plugin_runtime.host.message_utils import PluginMessageUtils + + plugin_id = self.descriptor.plugin_id or "" + if not plugin_id: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error="插件适配器驱动缺少 plugin_id", + ) + + try: + message_dict = PluginMessageUtils._session_message_to_dict(message) + response = await self._supervisor.invoke_adapter( + plugin_id=plugin_id, + method_name=self._send_method, + args={ + "message": message_dict, + "route": { + "platform": route_key.platform, + "account_id": route_key.account_id, + "scope": route_key.scope, + }, + "metadata": metadata or {}, + }, + timeout_ms=30000, + ) + except Exception as exc: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(exc), + ) + + return self._build_receipt(message.message_id, route_key, response) + + def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt: + """将适配器调用响应归一化为出站回执。 + + Args: + internal_message_id: 内部消息 ID。 + route_key: 本次投递的路由键。 + response: Supervisor 返回的 RPC 响应对象。 + + Returns: + DeliveryReceipt: 标准化后的出站回执。 + """ + if getattr(response, "error", None): + error = response.error.get("message", "适配器发送失败") + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=error, + ) + + payload = getattr(response, "payload", {}) + invoke_success = bool(payload.get("success", False)) if isinstance(payload, dict) else False + if not invoke_success: + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(payload.get("result", "适配器发送失败")) if isinstance(payload, dict) else "适配器发送失败", + ) + + result = payload.get("result") if isinstance(payload, dict) else None + if isinstance(result, dict): + if result.get("success") is False: + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(result.get("error", "适配器发送失败")), + metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {}, + ) + external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + external_message_id=external_message_id, + metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {}, + ) + + if isinstance(result, str) and result.strip(): + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + external_message_id=result.strip(), + ) + + return DeliveryReceipt( + internal_message_id=internal_message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) diff --git a/src/plugin_runtime/host/message_gateway.py b/src/plugin_runtime/host/message_gateway.py index 43777286..9e8e9be6 100644 --- a/src/plugin_runtime/host/message_gateway.py +++ b/src/plugin_runtime/host/message_gateway.py @@ -3,9 +3,11 @@ Message Gateway 模块 适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。 """ -from typing import Dict, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict from src.common.logger import get_logger +from src.platform_io import DeliveryStatus, get_platform_io_manager + from .message_utils import PluginMessageUtils if TYPE_CHECKING: @@ -17,25 +19,53 @@ logger = get_logger("plugin_runtime.host.message_gateway") class MessageGateway: - def __init__(self, component_registry: "ComponentRegistry") -> None: - self._component_registry = component_registry + """Host 侧消息网关包装器。""" - async def receive_external_message(self, external_message: Dict[str, Any]): - """ - 接收外部消息,转换为系统内部格式,并返回转换结果 + def __init__(self, component_registry: "ComponentRegistry") -> None: + """初始化消息网关。 Args: - external_message: 外部消息的字典格式数据 + component_registry: 组件注册表。 + """ + self._component_registry = component_registry + + def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage": + """将标准消息字典转换为 ``SessionMessage``。 + + Args: + external_message: 外部消息的字典格式数据。 Returns: - 转换后的 SessionMessage 对象 + SessionMessage: 转换后的内部消息对象。 + + Raises: + ValueError: 消息字典不合法时抛出。 + """ + return PluginMessageUtils._build_session_message_from_dict(external_message) + + def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]: + """将 ``SessionMessage`` 转换为标准消息字典。 + + Args: + internal_message: 内部消息对象。 + + Returns: + Dict[str, Any]: 供适配器插件消费的标准消息字典。 + """ + return dict(PluginMessageUtils._session_message_to_dict(internal_message)) + + async def receive_external_message(self, external_message: Dict[str, Any]) -> None: + """接收外部消息并送入主消息链。 + + Args: + external_message: 外部消息的字典格式数据。 """ - # 使用递归函数将外部消息字典转换为 SessionMessage try: - session_message = PluginMessageUtils._build_session_message_from_dict(external_message) + session_message = self.build_session_message(external_message) except Exception as e: logger.error(f"转换外部消息失败: {e}") return + from src.chat.message_receive.bot import chat_bot await chat_bot.receive_message(session_message) @@ -48,46 +78,32 @@ class MessageGateway: enabled_only: bool = True, save_to_db: bool = True, ) -> bool: - """ - 接收系统内部消息,转换为外部格式,并返回转换结果 + """将内部消息通过 Platform IO 发送到外部平台。 Args: - internal_message: 系统内部的 SessionMessage 对象 + internal_message: 系统内部的 ``SessionMessage`` 对象。 + supervisor: 当前持有该消息网关的 Supervisor。 + enabled_only: 兼容旧签名的保留参数,当前由 Platform IO 统一裁决。 + save_to_db: 发送成功后是否写入数据库。 Returns: - 转换是否成功 + bool: 是否发送成功。 """ - try: - # 将 SessionMessage 转换为字典格式 - message_dict = PluginMessageUtils._session_message_to_dict(internal_message) - except Exception as e: - logger.error(f"转换内部消息失败:{e}") - return False - gateway_entry = self._component_registry.get_message_gateways( - internal_message.platform, - enabled_only=enabled_only, - session_id=internal_message.session_id, - ) - if not gateway_entry: - logger.warning(f"未找到适配平台 {internal_message.platform} 的消息网关组件,无法发送消息到外部平台") - return False - args = {"platform": internal_message.platform, "message": message_dict} - try: - resp_envelope = await supervisor.invoke_plugin( - "plugin.emit_event", gateway_entry.plugin_id, gateway_entry.name, args - ) - logger.debug("信息发送成功") - except Exception as e: - logger.error(f"调用消息网关组件失败:{e}") + del enabled_only + del supervisor + + platform_io_manager = get_platform_io_manager() + if not platform_io_manager.is_started: + logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息") return False - # 更新为实际id(如果组件返回了新的id) - actual_message_id = resp_envelope.payload.get("message_id") - try: - actual_message_id = str(actual_message_id) - except Exception: - actual_message_id = None - internal_message.message_id = actual_message_id or internal_message.message_id + route_key = platform_io_manager.build_route_key_from_message(internal_message) + receipt = await platform_io_manager.send_message(internal_message, route_key) + if receipt.status != DeliveryStatus.SENT: + logger.warning(f"通过适配器链路发送消息失败: {receipt.error or receipt.status}") + return False + + internal_message.message_id = receipt.external_message_id or internal_message.message_id if save_to_db: try: from src.common.utils.utils_message import MessageUtils diff --git a/src/plugin_runtime/host/message_utils.py b/src/plugin_runtime/host/message_utils.py index 428e3c48..aaebb529 100644 --- a/src/plugin_runtime/host/message_utils.py +++ b/src/plugin_runtime/host/message_utils.py @@ -209,6 +209,9 @@ class PluginMessageUtils: session_message.is_notify = message_dict.get("is_notify", False) if not isinstance(session_message.is_notify, bool): session_message.is_notify = False + session_message.session_id = message_dict.get("session_id", "") + if not isinstance(session_message.session_id, str): + session_message.session_id = "" session_message.reply_to = message_dict.get("reply_to") if session_message.reply_to is not None and not isinstance(session_message.reply_to, str): session_message.reply_to = None diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 5ae3bdee..cdf3d4ee 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -8,13 +8,20 @@ import sys from src.common.logger import get_logger from src.config.config import global_config +from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager +from src.platform_io.drivers import PluginPlatformDriver +from src.platform_io.route_key_factory import RouteKeyFactory +from src.platform_io.routing import RouteBindingConflictError from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( + AdapterDeclarationPayload, BootstrapPluginPayload, ConfigUpdatedPayload, Envelope, HealthPayload, PROTOCOL_VERSION, + ReceiveExternalMessagePayload, + ReceiveExternalMessageResultPayload, RegisterPluginPayload, ReloadPluginResultPayload, RunnerReadyPayload, @@ -86,6 +93,7 @@ class PluginRunnerSupervisor: self._runner_process: Optional[asyncio.subprocess.Process] = None self._registered_plugins: Dict[str, RegisterPluginPayload] = {} + self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {} self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() self._health_task: Optional[asyncio.Task[None]] = None @@ -257,6 +265,32 @@ class PluginRunnerSupervisor: timeout_ms, ) + async def invoke_adapter( + self, + plugin_id: str, + method_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Envelope: + """调用适配器插件的专用方法。 + + Args: + plugin_id: 目标适配器插件 ID。 + method_name: 要调用的插件方法名,例如 ``send_to_platform``。 + args: 传递给插件方法的关键字参数。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: Runner 返回的响应信封。 + """ + return await self.invoke_plugin( + method="plugin.invoke_adapter", + plugin_id=plugin_id, + component_name=method_name, + args=args, + timeout_ms=timeout_ms, + ) + async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool: """按插件 ID 触发精确重载。 @@ -384,6 +418,7 @@ class PluginRunnerSupervisor: def _register_internal_methods(self) -> None: """注册 Host 侧内部 RPC 方法。""" self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request) + self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message) self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin) self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin) self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin) @@ -427,6 +462,17 @@ class PluginRunnerSupervisor: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) self._component_registry.remove_components_by_plugin(payload.plugin_id) + if payload.plugin_id in self._registered_adapters: + await self._unregister_adapter_driver(payload.plugin_id) + + try: + if payload.adapter is not None: + await self._register_adapter_driver(payload.plugin_id, payload.adapter) + except RouteBindingConflictError as exc: + return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc)) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc)) + registered_count = self._component_registry.register_plugin_components( payload.plugin_id, [component.model_dump() for component in payload.components], @@ -438,6 +484,7 @@ class PluginRunnerSupervisor: "accepted": True, "plugin_id": payload.plugin_id, "registered_components": registered_count, + "adapter_registered": payload.adapter is not None, } ) @@ -458,6 +505,7 @@ class PluginRunnerSupervisor: removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) self._authorization.revoke_permission_token(payload.plugin_id) removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None + await self._unregister_adapter_driver(payload.plugin_id) return envelope.make_response( payload={ @@ -469,6 +517,221 @@ class PluginRunnerSupervisor: } ) + @staticmethod + def _build_adapter_driver_id(plugin_id: str) -> str: + """构造适配器驱动 ID。 + + Args: + plugin_id: 适配器插件 ID。 + + Returns: + str: 对应 Platform IO 中的驱动 ID。 + """ + return f"adapter:{plugin_id}" + + async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None: + """将适配器插件注册到 Platform IO。 + + Args: + plugin_id: 适配器插件 ID。 + adapter: 经过校验的适配器声明。 + + Raises: + ValueError: 适配器路由冲突或驱动注册失败时抛出。 + """ + await self._unregister_adapter_driver(plugin_id) + + platform_io_manager = get_platform_io_manager() + driver = PluginPlatformDriver( + driver_id=self._build_adapter_driver_id(plugin_id), + platform=adapter.platform, + account_id=adapter.account_id or None, + scope=adapter.scope or None, + plugin_id=plugin_id, + send_method=adapter.send_method, + supervisor=self, + metadata={ + "protocol": adapter.protocol, + **adapter.metadata, + }, + ) + binding = RouteBinding( + route_key=driver.descriptor.route_key, + driver_id=driver.driver_id, + driver_kind=DriverKind.PLUGIN, + metadata={ + "plugin_id": plugin_id, + "protocol": adapter.protocol, + }, + ) + + try: + if platform_io_manager.is_started: + await platform_io_manager.add_driver(driver) + else: + platform_io_manager.register_driver(driver) + platform_io_manager.bind_route(binding) + except Exception: + with contextlib.suppress(Exception): + if platform_io_manager.is_started: + await platform_io_manager.remove_driver(driver.driver_id) + else: + platform_io_manager.unregister_driver(driver.driver_id) + raise + + self._registered_adapters[plugin_id] = adapter + + async def _unregister_adapter_driver(self, plugin_id: str) -> None: + """从 Platform IO 注销一个适配器驱动。 + + Args: + plugin_id: 适配器插件 ID。 + """ + platform_io_manager = get_platform_io_manager() + driver_id = self._build_adapter_driver_id(plugin_id) + + with contextlib.suppress(Exception): + if platform_io_manager.is_started: + await platform_io_manager.remove_driver(driver_id) + else: + platform_io_manager.unregister_driver(driver_id) + + self._registered_adapters.pop(plugin_id, None) + + async def _unregister_all_adapter_drivers(self) -> None: + """注销当前 Supervisor 管理的全部适配器驱动。""" + plugin_ids = list(self._registered_adapters.keys()) + for plugin_id in plugin_ids: + await self._unregister_adapter_driver(plugin_id) + + @staticmethod + def _attach_inbound_route_metadata( + session_message: "SessionMessage", + route_key: RouteKey, + route_metadata: Dict[str, Any], + ) -> None: + """将入站路由信息写回消息的 ``additional_config``。 + + Args: + session_message: 已构造好的内部消息对象。 + route_key: Host 为该消息解析出的标准路由键。 + route_metadata: 适配器通过 RPC 补充的原始路由辅助元数据。 + """ + additional_config = session_message.message_info.additional_config + if not isinstance(additional_config, dict): + additional_config = {} + session_message.message_info.additional_config = additional_config + + for key, value in route_metadata.items(): + if value is None: + continue + normalized_value = str(value).strip() + if normalized_value: + additional_config[key] = value + + if route_key.account_id: + additional_config.setdefault("platform_io_account_id", route_key.account_id) + if route_key.scope: + additional_config.setdefault("platform_io_scope", route_key.scope) + + def _build_inbound_route_key( + self, + adapter: AdapterDeclarationPayload, + message: Dict[str, Any], + route_metadata: Dict[str, Any], + ) -> RouteKey: + """为适配器入站消息构造归一路由键。 + + Args: + adapter: 当前适配器声明。 + message: 标准消息字典。 + route_metadata: 插件补充的路由辅助元数据。 + + Returns: + RouteKey: 供 Platform IO 使用的规范化路由键。 + + Raises: + ValueError: 消息平台字段与适配器平台声明不一致时抛出。 + """ + message_platform = str(message.get("platform") or adapter.platform).strip() + if message_platform != adapter.platform: + raise ValueError( + f"外部消息平台 {message_platform} 与适配器 {adapter.platform} 不一致" + ) + + try: + route_key = RouteKeyFactory.from_message_dict(message) + except Exception: + route_key = RouteKey(platform=message_platform) + + route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata) + account_id = route_key.account_id or route_account_id or adapter.account_id or None + scope = route_key.scope or route_scope or adapter.scope or None + return RouteKey( + platform=message_platform, + account_id=account_id, + scope=scope, + ) + + async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope: + """处理适配器插件上报的外部入站消息。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 注入结果响应。 + """ + try: + payload = ReceiveExternalMessagePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + adapter = self._registered_adapters.get(envelope.plugin_id) + if adapter is None: + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {envelope.plugin_id} 未声明为适配器,不能注入外部消息", + ) + + try: + route_key = self._build_inbound_route_key( + adapter=adapter, + message=payload.message, + route_metadata=payload.route_metadata, + ) + session_message = self._message_gateway.build_session_message(payload.message) + self._attach_inbound_route_metadata(session_message, route_key, payload.route_metadata) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + platform_io_manager = get_platform_io_manager() + accepted = await platform_io_manager.accept_inbound( + InboundMessageEnvelope( + route_key=route_key, + driver_id=self._build_adapter_driver_id(envelope.plugin_id), + driver_kind=DriverKind.PLUGIN, + external_message_id=payload.external_message_id or str(payload.message.get("message_id") or "") or None, + dedupe_key=payload.dedupe_key or None, + session_message=session_message, + payload=payload.message, + metadata={ + "plugin_id": envelope.plugin_id, + "protocol": adapter.protocol, + **payload.route_metadata, + }, + ) + ) + response = ReceiveExternalMessageResultPayload( + accepted=accepted, + route_key={ + "platform": route_key.platform, + "account_id": route_key.account_id, + "scope": route_key.scope, + }, + ) + return envelope.make_response(payload=response.model_dump()) + async def _handle_runner_ready(self, envelope: Envelope) -> Envelope: """处理 Runner 就绪通知。 @@ -595,6 +858,7 @@ class PluginRunnerSupervisor: await self._stderr_drain_task self._stderr_drain_task = None + await self._unregister_all_adapter_drivers() self._clear_runner_state() async def _health_check_loop(self) -> None: @@ -671,6 +935,7 @@ class PluginRunnerSupervisor: self._authorization.clear() self._component_registry.clear() self._registered_plugins.clear() + self._registered_adapters.clear() self._runner_ready_events = asyncio.Event() self._runner_ready_payloads = RunnerReadyPayload() diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 730da3e1..30a3c150 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -12,11 +12,13 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Ite import asyncio import json + import tomlkit from src.common.logger import get_logger from src.config.config import global_config from src.config.file_watcher import FileChange, FileWatcher +from src.platform_io import DeliveryReceipt, InboundMessageEnvelope, get_platform_io_manager from src.plugin_runtime.capabilities import ( RuntimeComponentCapabilityMixin, RuntimeCoreCapabilityMixin, @@ -57,6 +59,7 @@ class PluginRuntimeManager( """ def __init__(self) -> None: + """初始化插件运行时管理器。""" from src.plugin_runtime.host.supervisor import PluginSupervisor self._builtin_supervisor: Optional[PluginSupervisor] = None @@ -66,6 +69,22 @@ class PluginRuntimeManager( self._plugin_source_watcher_subscription_id: Optional[str] = None self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {} + async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None: + """接收 Platform IO 审核后的入站消息并送入主消息链。 + + Args: + envelope: Platform IO 产出的入站封装。 + """ + session_message = envelope.session_message + if session_message is None and envelope.payload is not None: + session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload)) + if session_message is None: + raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload") + + from src.chat.message_receive.bot import chat_bot + + await chat_bot.receive_message(session_message) + # ─── 插件目录 ───────────────────────────────────────────── @staticmethod @@ -110,6 +129,8 @@ class PluginRuntimeManager( logger.info("未找到任何插件目录,跳过插件运行时启动") return + platform_io_manager = get_platform_io_manager() + # 从配置读取自定义 IPC socket 路径(留空则自动生成) socket_path_base = _cfg.ipc_socket_path or None @@ -134,6 +155,9 @@ class PluginRuntimeManager( started_supervisors: List[PluginSupervisor] = [] try: + platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound) + await platform_io_manager.start() + if self._builtin_supervisor: await self._builtin_supervisor.start() started_supervisors.append(self._builtin_supervisor) @@ -147,6 +171,11 @@ class PluginRuntimeManager( logger.error(f"插件运行时启动失败: {e}", exc_info=True) await self._stop_plugin_file_watcher() await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True) + platform_io_manager.clear_inbound_dispatcher() + try: + await platform_io_manager.stop() + except Exception as platform_io_exc: + logger.warning(f"Platform IO 停止失败: {platform_io_exc}") self._started = False self._builtin_supervisor = None self._third_party_supervisor = None @@ -156,6 +185,7 @@ class PluginRuntimeManager( if not self._started: return + platform_io_manager = get_platform_io_manager() await self._stop_plugin_file_watcher() coroutines: List[Coroutine[Any, Any, None]] = [] @@ -164,11 +194,23 @@ class PluginRuntimeManager( if self._third_party_supervisor: coroutines.append(self._third_party_supervisor.stop()) + stop_errors: List[str] = [] try: - await asyncio.gather(*coroutines, return_exceptions=True) - logger.info("插件运行时已停止") - except Exception as e: - logger.error(f"插件运行时停止失败: {e}", exc_info=True) + results = await asyncio.gather(*coroutines, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + stop_errors.append(str(result)) + + platform_io_manager.clear_inbound_dispatcher() + try: + await platform_io_manager.stop() + except Exception as exc: + stop_errors.append(f"Platform IO: {exc}") + + if stop_errors: + logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}") + else: + logger.info("插件运行时已停止") finally: self._started = False self._builtin_supervisor = None @@ -176,6 +218,7 @@ class PluginRuntimeManager( @property def is_running(self) -> bool: + """返回插件运行时是否处于启动状态。""" return self._started @property @@ -303,6 +346,37 @@ class PluginRuntimeManager( timeout_ms=timeout_ms, ) + async def try_send_message_via_platform_io( + self, + message: "SessionMessage", + ) -> Optional[DeliveryReceipt]: + """尝试通过 Platform IO 中间层发送消息。 + + Args: + message: 待发送的内部会话消息。 + + Returns: + Optional[DeliveryReceipt]: 若当前消息存在 active 路由,则返回实际发送 + 结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。 + """ + if not self._started: + return None + + platform_io_manager = get_platform_io_manager() + if not platform_io_manager.is_started: + return None + + try: + route_key = platform_io_manager.build_route_key_from_message(message) + except Exception as exc: + logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}") + return None + + if platform_io_manager.resolve_driver(route_key) is None: + return None + + return await platform_io_manager.send_message(message, route_key) + def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]: """返回当前持有指定插件的所有 Supervisor。 @@ -426,6 +500,11 @@ class PluginRuntimeManager( """为指定插件生成配置文件变更回调。""" async def _callback(changes: Sequence[FileChange]) -> None: + """将 watcher 事件转发到指定插件的配置处理逻辑。 + + Args: + changes: 当前批次收集到的文件变更列表。 + """ await self._handle_plugin_config_changes(plugin_id, changes) return _callback @@ -542,6 +621,11 @@ class PluginRuntimeManager( # ─── 能力实现注册 ────────────────────────────────────────── def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None: + """向指定 Supervisor 注册主程序能力实现。 + + Args: + supervisor: 需要注册能力实现的目标 Supervisor。 + """ register_capability_impls(self, supervisor) diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 6f95f97f..0dfc6656 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -7,11 +7,11 @@ from enum import Enum from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field - import logging as stdlib_logging import time +from pydantic import BaseModel, Field + # ====== 协议常量 ====== PROTOCOL_VERSION = "1.0.0" @@ -156,6 +156,8 @@ class RegisterPluginPayload(BaseModel): """插件版本""" components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表") """组件列表""" + adapter: Optional["AdapterDeclarationPayload"] = Field(default=None, description="可选的适配器声明") + """可选的适配器声明""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") """所需能力列表""" @@ -285,6 +287,48 @@ class ReloadPluginResultPayload(BaseModel): """重载失败的插件及原因""" +class AdapterDeclarationPayload(BaseModel): + """适配器插件声明载荷。""" + + platform: str = Field(description="适配器负责的平台名称,例如 qq") + """适配器负责的平台名称,例如 qq""" + protocol: str = Field(default="", description="接入协议或实现名称,例如 napcat") + """接入协议或实现名称,例如 napcat""" + account_id: str = Field(default="", description="可选的账号 ID 或 self_id") + """可选的账号 ID 或 self_id""" + scope: str = Field(default="", description="可选的路由作用域") + """可选的路由作用域""" + send_method: str = Field(default="send_to_platform", description="Host 出站调用的插件方法名") + """Host 出站调用的插件方法名""" + metadata: Dict[str, Any] = Field(default_factory=dict, description="适配器附加元数据") + """适配器附加元数据""" + + +class ReceiveExternalMessagePayload(BaseModel): + """适配器插件向 Host 注入外部消息的请求载荷。""" + + message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典") + """符合 MessageDict 结构的标准消息字典""" + route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据") + """可选的路由辅助元数据""" + external_message_id: str = Field(default="", description="可选的外部平台消息 ID") + """可选的外部平台消息 ID""" + dedupe_key: str = Field(default="", description="可选的显式去重键") + """可选的显式去重键""" + + +class ReceiveExternalMessageResultPayload(BaseModel): + """外部消息注入结果载荷。""" + + accepted: bool = Field(description="Host 是否接受了本次消息注入") + """Host 是否接受了本次消息注入""" + route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键") + """本次消息使用的归一路由键""" + + +RegisterPluginPayload.model_rebuild() + + # ====== 日志传输 ====== diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index bf36a05c..3ffb6b4b 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -9,9 +9,8 @@ 6. 转发插件的能力调用到 Host """ -from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast - from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast import asyncio import contextlib @@ -26,6 +25,7 @@ import tomllib from src.common.logger import get_console_handler, get_logger, initialize_logging from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( + AdapterDeclarationPayload, BootstrapPluginPayload, ComponentDeclaration, Envelope, @@ -227,7 +227,7 @@ class PluginRunner: plugin_id: str = "", payload: Optional[Dict[str, Any]] = None, ) -> Any: - """桥接 PluginContext.call_capability → RPCClient.send_request。 + """桥接 PluginContext 的原始 RPC 调用到 Host。 无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id 始终绑定为当前插件实例,避免伪造其他插件身份申请能力。 @@ -237,17 +237,13 @@ class PluginRunner: f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份" ) resp = await rpc_client.send_request( - method="cap.call", + method=method, plugin_id=bound_plugin_id, - payload={ - "capability": method, - "args": payload or {}, - }, + payload=payload or {}, ) - # 从响应信封中提取业务结果 if resp.error: raise RuntimeError(resp.error.get("message", "能力调用失败")) - return resp.payload.get("result") + return resp.payload ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call) cast(_ContextAwarePlugin, instance)._set_context(ctx) @@ -286,6 +282,7 @@ class PluginRunner: self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_adapter", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step) @@ -306,12 +303,14 @@ class PluginRunner: ) try: - await self._rpc_client.send_request( + response = await self._rpc_client.send_request( "plugin.bootstrap", plugin_id=meta.plugin_id, payload=payload.model_dump(), timeout_ms=10000, ) + if response.error: + raise RuntimeError(response.error.get("message", "插件 bootstrap 失败")) return True except Exception as e: logger.error(f"插件 {meta.plugin_id} bootstrap 失败: {e}") @@ -321,6 +320,29 @@ class PluginRunner: """撤销 bootstrap 期间为插件签发的能力令牌。""" await self._bootstrap_plugin(meta, capabilities_required=[]) + def _collect_adapter_declaration(self, meta: PluginMeta) -> Optional[AdapterDeclarationPayload]: + """从插件实例中提取适配器声明。 + + Args: + meta: 待提取声明的插件元数据。 + + Returns: + Optional[AdapterDeclarationPayload]: 若插件声明了适配器角色,则返回 + 经过校验的适配器声明;否则返回 ``None``。 + + Raises: + ValueError: 插件导出的适配器声明结构非法时抛出。 + """ + instance = meta.instance + if not hasattr(instance, "get_adapter_info"): + return None + + adapter_info = instance.get_adapter_info() + if adapter_info is None: + return None + + return AdapterDeclarationPayload.model_validate(adapter_info) + async def _register_plugin(self, meta: PluginMeta) -> bool: """向 Host 注册单个插件。 @@ -346,20 +368,29 @@ class PluginRunner: for comp_info in instance.get_components() ) + try: + adapter = self._collect_adapter_declaration(meta) + except Exception as exc: + logger.error(f"插件 {meta.plugin_id} 适配器声明非法: {exc}", exc_info=True) + return False + reg_payload = RegisterPluginPayload( plugin_id=meta.plugin_id, plugin_version=meta.version, components=components, + adapter=adapter, capabilities_required=meta.capabilities_required, ) try: - _resp = await self._rpc_client.send_request( + response = await self._rpc_client.send_request( "plugin.register_components", plugin_id=meta.plugin_id, payload=reg_payload.model_dump(), timeout_ms=10000, ) + if response.error: + raise RuntimeError(response.error.get("message", "插件注册失败")) logger.info(f"插件 {meta.plugin_id} 注册完成") return True except Exception as e: diff --git a/src/plugins/built_in/napcat_adapter/_manifest.json b/src/plugins/built_in/napcat_adapter/_manifest.json new file mode 100644 index 00000000..6f7e68fd --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/_manifest.json @@ -0,0 +1,30 @@ +{ + "manifest_version": 1, + "name": "napcat_adapter_builtin", + "version": "0.1.0", + "description": "Built-in NapCat adapter plugin for MVP message forwarding.", + "author": { + "name": "OpenAI Codex" + }, + "license": "GPL-v3.0-or-later", + "host_application": { + "min_version": "1.0.0" + }, + "keywords": [ + "adapter", + "built-in", + "napcat", + "onebot", + "qq" + ], + "categories": [ + "Adapter", + "Built-in" + ], + "default_locale": "en-US", + "plugin_info": { + "is_built_in": true, + "plugin_type": "adapter" + }, + "capabilities": [] +} diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py new file mode 100644 index 00000000..3eff518d --- /dev/null +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -0,0 +1,690 @@ +"""内置 NapCat 适配器插件。 + +当前实现是一个 MVP 版本,目标仅限于跑通基础消息收发链路: +1. 作为客户端连接 NapCat / OneBot v11 WebSocket 服务。 +2. 将入站消息事件转换为 Host 侧的 ``MessageDict``。 +3. 将 Host 出站消息转换为 OneBot 动作并发送。 + +当前范围刻意收敛为: +- 单连接 +- 文本、@、reply 基础转发 +- 暂不处理 ``notice`` / ``meta_event`` +- 暂不支持图片、语音、文件等复杂媒体 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from uuid import uuid4 + +import asyncio +import contextlib +import json +import time + +from maibot_sdk import Adapter, MaiBotPlugin + +if TYPE_CHECKING: + from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse + +try: + from aiohttp import ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType + + AIOHTTP_AVAILABLE = True +except ImportError: + ClientSession = cast(Any, None) + ClientTimeout = cast(Any, None) + ClientWebSocketResponse = cast(Any, None) + WSMsgType = cast(Any, None) + AIOHTTP_AVAILABLE = False + +if not TYPE_CHECKING: + AiohttpClientWebSocketResponse = Any + + +@Adapter(platform="qq", protocol="napcat", send_method="send_to_platform") +class NapCatAdapterPlugin(MaiBotPlugin): + """NapCat 适配器 MVP 实现。""" + + def __init__(self) -> None: + """初始化 NapCat 适配器插件实例。""" + super().__init__() + self._plugin_config: Dict[str, Any] = {} + self._connection_task: Optional[asyncio.Task[None]] = None + self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {} + self._background_tasks: set[asyncio.Task[Any]] = set() + self._send_lock = asyncio.Lock() + self._ws: Optional[AiohttpClientWebSocketResponse] = None + + def set_plugin_config(self, config: Dict[str, Any]) -> None: + """设置插件配置内容。 + + Args: + config: Runner 注入的 ``config.toml`` 解析结果。 + """ + self._plugin_config = config if isinstance(config, dict) else {} + + async def on_load(self) -> None: + """在插件加载时根据配置决定是否启动连接。""" + await self._restart_connection_if_needed() + + async def on_unload(self) -> None: + """在插件卸载时关闭连接并清理后台任务。""" + await self._stop_connection() + await self._cancel_background_tasks() + + async def on_config_update(self, new_config: Dict[str, Any], version: str) -> None: + """在配置更新后重载连接状态。 + + Args: + new_config: 最新的插件配置。 + version: 配置版本号。 + """ + del version + self.set_plugin_config(new_config) + await self._restart_connection_if_needed() + + async def send_to_platform( + self, + message: Dict[str, Any], + route: Optional[Dict[str, Any]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """将 Host 出站消息发送到 NapCat。 + + Args: + message: Host 侧标准 ``MessageDict``。 + route: Platform IO 生成的路由信息。 + metadata: Platform IO 附带的投递元数据。 + **kwargs: 预留的扩展参数。 + + Returns: + Dict[str, Any]: 标准化后的发送结果。 + """ + del metadata + del kwargs + + ws = self._ws + if ws is None or ws.closed: + return {"success": False, "error": "NapCat is not connected"} + + try: + action_name, params = self._build_outbound_action(message, route or {}) + response = await self._call_action(action_name, params) + except Exception as exc: + return {"success": False, "error": str(exc)} + + if str(response.get("status", "")).lower() != "ok": + return { + "success": False, + "error": str(response.get("wording") or response.get("message") or "NapCat send failed"), + "metadata": {"retcode": response.get("retcode")}, + } + + response_data = response.get("data", {}) + external_message_id = "" + if isinstance(response_data, dict): + external_message_id = str(response_data.get("message_id") or "") + + return { + "success": True, + "external_message_id": external_message_id or None, + "metadata": {"action": action_name}, + } + + async def _restart_connection_if_needed(self) -> None: + """根据当前配置重启连接循环。""" + await self._stop_connection() + if not self._should_connect(): + self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用") + return + if not AIOHTTP_AVAILABLE: + self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖") + return + self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection") + + async def _stop_connection(self) -> None: + """停止当前连接并让所有等待中的动作失败返回。""" + connection_task = self._connection_task + self._connection_task = None + + ws = self._ws + if ws is not None and not ws.closed: + with contextlib.suppress(Exception): + await ws.close() + self._ws = None + + if connection_task is not None: + connection_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await connection_task + + self._fail_pending_actions("NapCat connection closed") + + async def _cancel_background_tasks(self) -> None: + """取消所有仍在运行的入站后台任务。""" + background_tasks = list(self._background_tasks) + for task in background_tasks: + task.cancel() + if background_tasks: + with contextlib.suppress(Exception): + await asyncio.gather(*background_tasks, return_exceptions=True) + self._background_tasks.clear() + + async def _connection_loop(self) -> None: + """维护单个 WebSocket 连接,并在断开后按配置重连。""" + assert ClientSession is not None + assert ClientTimeout is not None + + while self._should_connect(): + ws_url = self._get_string(self._connection_config(), "ws_url") + if not ws_url: + self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空") + return + + headers = self._build_headers() + timeout = ClientTimeout(total=None, connect=10) + heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", 30.0) + + try: + async with ClientSession(headers=headers, timeout=timeout) as session: + async with session.ws_connect(ws_url, heartbeat=heartbeat or None) as ws: + self._ws = ws + self.ctx.logger.info(f"NapCat 适配器已连接: {ws_url}") + await self._receive_loop(ws) + except asyncio.CancelledError: + raise + except Exception as exc: + self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}") + finally: + self._ws = None + self._fail_pending_actions("NapCat connection interrupted") + + if not self._should_connect(): + break + + await asyncio.sleep(self._get_positive_float(self._connection_config(), "reconnect_delay_sec", 5.0)) + + async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None: + """持续消费 WebSocket 消息并分发处理。 + + Args: + ws: 当前活跃的 WebSocket 连接对象。 + """ + assert WSMsgType is not None + + async for ws_message in ws: + if ws_message.type != WSMsgType.TEXT: + if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}: + break + continue + + payload = self._parse_json_message(ws_message.data) + if payload is None: + continue + + if echo_id := str(payload.get("echo") or "").strip(): + self._resolve_pending_action(echo_id, payload) + continue + + if str(payload.get("post_type") or "").strip() != "message": + continue + + task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound") + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None: + """处理单条 NapCat 入站消息并注入 Host。 + + Args: + payload: NapCat / OneBot 推送的原始事件数据。 + """ + self_id = str(payload.get("self_id") or "").strip() + sender = payload.get("sender", {}) + if not isinstance(sender, dict): + sender = {} + + sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip() + if not sender_user_id: + return + + if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True): + return + + message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender) + route_metadata: Dict[str, Any] = {} + if self_id: + route_metadata["self_id"] = self_id + if connection_id := self._get_string(self._connection_config(), "connection_id"): + route_metadata["connection_id"] = connection_id + + external_message_id = str(payload.get("message_id") or "").strip() + accepted = await self.ctx.adapter.receive_external_message( + message_dict, + route_metadata=route_metadata, + external_message_id=external_message_id, + dedupe_key=external_message_id, + ) + if not accepted: + self.ctx.logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}") + + def _build_inbound_message_dict( + self, + payload: Dict[str, Any], + self_id: str, + sender_user_id: str, + sender: Dict[str, Any], + ) -> Dict[str, Any]: + """构造 Host 侧可接受的 ``MessageDict``。 + + Args: + payload: NapCat 原始消息事件。 + self_id: 当前机器人账号 ID。 + sender_user_id: 发送者用户 ID。 + sender: 发送者信息字典。 + + Returns: + Dict[str, Any]: 规范化后的 ``MessageDict``。 + """ + message_type = str(payload.get("message_type") or "").strip() or "private" + group_id = str(payload.get("group_id") or "").strip() + group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "") + user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id + user_cardname = str(sender.get("card") or "").strip() or None + + raw_message, is_at = self._convert_inbound_segments(payload.get("message"), self_id) + raw_message_text = str(payload.get("raw_message") or "").strip() + if not raw_message: + raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}] + + plain_text = self._build_plain_text(raw_message, raw_message_text) + timestamp_seconds = payload.get("time") + if not isinstance(timestamp_seconds, (int, float)): + timestamp_seconds = time.time() + + additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type} + if group_id: + additional_config["platform_io_target_group_id"] = group_id + else: + additional_config["platform_io_target_user_id"] = sender_user_id + + message_info: Dict[str, Any] = { + "user_info": { + "user_id": sender_user_id, + "user_nickname": user_nickname, + "user_cardname": user_cardname, + }, + "additional_config": additional_config, + } + if group_id: + message_info["group_info"] = {"group_id": group_id, "group_name": group_name} + + message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip() + return { + "message_id": message_id, + "timestamp": str(float(timestamp_seconds)), + "platform": "qq", + "message_info": message_info, + "raw_message": raw_message, + "is_mentioned": is_at, + "is_at": is_at, + "is_emoji": False, + "is_picture": False, + "is_command": plain_text.startswith("/"), + "is_notify": False, + "session_id": "", + "processed_plain_text": plain_text, + "display_message": plain_text, + } + + def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> tuple[List[Dict[str, Any]], bool]: + """将 OneBot 消息段转换为 Host 消息段结构。 + + Args: + message_payload: OneBot 原始 ``message`` 字段。 + self_id: 当前机器人账号 ID。 + + Returns: + tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。 + """ + if isinstance(message_payload, str): + normalized_text = message_payload.strip() + return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False + + if not isinstance(message_payload, list): + return [], False + + converted_segments: List[Dict[str, Any]] = [] + is_at = False + placeholder_texts = { + "face": "[face]", + "file": "[file]", + "image": "[image]", + "json": "[json]", + "record": "[voice]", + "video": "[video]", + "xml": "[xml]", + } + + for segment in message_payload: + if not isinstance(segment, dict): + continue + + segment_type = str(segment.get("type") or "").strip() + segment_data = segment.get("data", {}) + if not isinstance(segment_data, dict): + segment_data = {} + + if segment_type == "text": + if text_value := str(segment_data.get("text") or ""): + converted_segments.append({"type": "text", "data": text_value}) + continue + + if segment_type == "at": + if target_user_id := str(segment_data.get("qq") or "").strip(): + converted_segments.append( + { + "type": "at", + "data": { + "target_user_id": target_user_id, + "target_user_nickname": None, + "target_user_cardname": None, + }, + } + ) + if self_id and target_user_id == self_id: + is_at = True + continue + + if segment_type == "reply": + if target_message_id := str(segment_data.get("id") or "").strip(): + converted_segments.append({"type": "reply", "data": target_message_id}) + continue + + if placeholder := placeholder_texts.get(segment_type): + converted_segments.append({"type": "text", "data": placeholder}) + + return converted_segments, is_at + + def _build_outbound_action( + self, + message: Dict[str, Any], + route: Dict[str, Any], + ) -> tuple[str, Dict[str, Any]]: + """为 Host 出站消息构造 OneBot 动作。 + + Args: + message: Host 侧标准 ``MessageDict``。 + route: Platform IO 路由信息。 + + Returns: + tuple[str, Dict[str, Any]]: 动作名称与参数字典。 + """ + message_info = message.get("message_info", {}) + if not isinstance(message_info, dict): + message_info = {} + + group_info = message_info.get("group_info", {}) + if not isinstance(group_info, dict): + group_info = {} + + additional_config = message_info.get("additional_config", {}) + if not isinstance(additional_config, dict): + additional_config = {} + + raw_message = message.get("raw_message", []) + segments = self._convert_outbound_segments(raw_message) + + if target_group_id := str( + group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or "" + ).strip(): + return "send_group_msg", {"group_id": target_group_id, "message": segments} + + if not ( + target_user_id := str( + additional_config.get("platform_io_target_user_id") + or additional_config.get("target_user_id") + or route.get("target_user_id") + or "" + ).strip() + ): + raise ValueError("Outbound private message is missing target_user_id") + + return "send_private_msg", {"message": segments, "user_id": target_user_id} + + def _convert_outbound_segments(self, raw_message: Any) -> List[Dict[str, Any]]: + """将 Host 消息段转换为 OneBot 消息段。 + + Args: + raw_message: Host 侧 ``raw_message`` 字段。 + + Returns: + List[Dict[str, Any]]: OneBot 消息段列表。 + """ + if not isinstance(raw_message, list): + return [{"type": "text", "data": {"text": ""}}] + + outbound_segments: List[Dict[str, Any]] = [] + for item in raw_message: + if not isinstance(item, dict): + continue + + item_type = str(item.get("type") or "").strip() + item_data = item.get("data") + + if item_type == "text": + text_value = str(item_data or "") + outbound_segments.append({"type": "text", "data": {"text": text_value}}) + continue + + if item_type == "at" and isinstance(item_data, dict): + if target_user_id := str(item_data.get("target_user_id") or "").strip(): + outbound_segments.append({"type": "at", "data": {"qq": target_user_id}}) + continue + + if item_type == "reply": + if target_message_id := str(item_data or "").strip(): + outbound_segments.append({"type": "reply", "data": {"id": target_message_id}}) + continue + + fallback_text = f"[unsupported:{item_type or 'unknown'}]" + outbound_segments.append({"type": "text", "data": {"text": fallback_text}}) + + if not outbound_segments: + outbound_segments.append({"type": "text", "data": {"text": ""}}) + return outbound_segments + + async def _call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]: + """发送 OneBot 动作并等待对应的 echo 响应。 + + Args: + action_name: OneBot 动作名称。 + params: 动作参数。 + + Returns: + Dict[str, Any]: NapCat 返回的原始响应字典。 + """ + ws = self._ws + if ws is None or ws.closed: + raise RuntimeError("NapCat is not connected") + + echo_id = uuid4().hex + loop = asyncio.get_running_loop() + response_future: asyncio.Future[Dict[str, Any]] = loop.create_future() + self._pending_actions[echo_id] = response_future + + request_payload = {"action": action_name, "params": params, "echo": echo_id} + try: + async with self._send_lock: + await ws.send_str(json.dumps(request_payload, ensure_ascii=False)) + timeout_seconds = self._get_positive_float(self._connection_config(), "action_timeout_sec", 15.0) + return await asyncio.wait_for(response_future, timeout=timeout_seconds) + finally: + self._pending_actions.pop(echo_id, None) + + def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None: + """解析等待中的动作响应。 + + Args: + echo_id: 动作请求对应的 echo 标识。 + payload: NapCat 返回的响应载荷。 + """ + response_future = self._pending_actions.get(echo_id) + if response_future is None or response_future.done(): + return + response_future.set_result(payload) + + def _fail_pending_actions(self, error_message: str) -> None: + """让所有等待中的动作以异常方式结束。 + + Args: + error_message: 写入异常中的错误信息。 + """ + for response_future in self._pending_actions.values(): + if not response_future.done(): + response_future.set_exception(RuntimeError(error_message)) + self._pending_actions.clear() + + def _build_headers(self) -> Dict[str, str]: + """构造连接 NapCat 所需的请求头。 + + Returns: + Dict[str, str]: WebSocket 握手请求头。 + """ + access_token = self._get_string(self._connection_config(), "access_token") + return {"Authorization": f"Bearer {access_token}"} if access_token else {} + + def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]: + """解析 WebSocket 文本消息中的 JSON 数据。 + + Args: + data: WebSocket 收到的原始文本数据。 + + Returns: + Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。 + """ + try: + payload = json.loads(str(data)) + except Exception as exc: + self.ctx.logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}") + return None + + return payload if isinstance(payload, dict) else None + + def _build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str: + """从标准消息段中提取可展示的纯文本。 + + Args: + raw_message: 标准化后的消息段列表。 + fallback_text: 当无法拼出文本时使用的回退文本。 + + Returns: + str: 用于 Host 展示和命令判断的纯文本内容。 + """ + plain_text_parts: List[str] = [] + for item in raw_message: + if not isinstance(item, dict): + continue + item_type = str(item.get("type") or "").strip() + item_data = item.get("data") + if item_type == "text": + plain_text_parts.append(str(item_data or "")) + elif item_type == "at" and isinstance(item_data, dict): + plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}") + elif item_type == "reply": + plain_text_parts.append("[reply]") + + plain_text = "".join(part for part in plain_text_parts if part).strip() + return plain_text or fallback_text or "[unsupported]" + + def _plugin_section(self) -> Dict[str, Any]: + """读取插件配置中的 ``plugin`` 段。 + + Returns: + Dict[str, Any]: ``plugin`` 配置字典。 + """ + plugin_section = self._plugin_config.get("plugin", {}) + return plugin_section if isinstance(plugin_section, dict) else {} + + def _connection_config(self) -> Dict[str, Any]: + """读取插件配置中的 ``connection`` 段。 + + Returns: + Dict[str, Any]: ``connection`` 配置字典。 + """ + connection_config = self._plugin_config.get("connection", {}) + return connection_config if isinstance(connection_config, dict) else {} + + def _filters_config(self) -> Dict[str, Any]: + """读取插件配置中的 ``filters`` 段。 + + Returns: + Dict[str, Any]: ``filters`` 配置字典。 + """ + filters_config = self._plugin_config.get("filters", {}) + return filters_config if isinstance(filters_config, dict) else {} + + def _should_connect(self) -> bool: + """判断当前配置下是否应当启动连接。 + + Returns: + bool: 若启用了插件连接则返回 ``True``。 + """ + return self._get_bool(self._plugin_section(), "enabled", False) + + @staticmethod + def _get_bool(mapping: Dict[str, Any], key: str, default: bool) -> bool: + """安全读取布尔配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + + Returns: + bool: 解析后的布尔值。 + """ + value = mapping.get(key, default) + return value if isinstance(value, bool) else default + + @staticmethod + def _get_positive_float(mapping: Dict[str, Any], key: str, default: float) -> float: + """安全读取正浮点数配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + default: 读取失败时的默认值。 + + Returns: + float: 合法的正浮点数;否则返回默认值。 + """ + value = mapping.get(key, default) + if isinstance(value, (int, float)) and float(value) > 0: + return float(value) + return default + + @staticmethod + def _get_string(mapping: Dict[str, Any], key: str) -> str: + """安全读取字符串配置值。 + + Args: + mapping: 待读取的配置字典。 + key: 目标键名。 + + Returns: + str: 去除首尾空白后的字符串值。 + """ + value = mapping.get(key) + return "" if value is None else str(value).strip() + + +def create_plugin() -> NapCatAdapterPlugin: + """创建插件实例。 + + Returns: + NapCatAdapterPlugin: NapCat 内置适配器插件实例。 + """ + return NapCatAdapterPlugin() diff --git a/src/services/send_service.py b/src/services/send_service.py index 7af55716..6ca7d005 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -4,7 +4,7 @@ 提供发送各种类型消息的核心功能。 """ -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional import time import traceback @@ -19,6 +19,7 @@ from src.common.data_models.mai_message_data_model import MaiMessage from src.common.data_models.message_component_data_model import DictComponent, MessageSequence from src.common.logger import get_logger from src.config.config import global_config +from src.platform_io.route_key_factory import RouteKeyFactory if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage @@ -31,6 +32,50 @@ logger = get_logger("send_service") # ============================================================================= +def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object]: + """从目标会话上下文继承 Platform IO 路由元数据。 + + Args: + target_stream: 当前消息要发送到的会话对象。 + + Returns: + Dict[str, object]: 可安全透传到出站消息 ``additional_config`` 中的 + 路由辅助字段。 + """ + inherited_metadata: Dict[str, object] = {} + + context = getattr(target_stream, "context", None) + context_message = getattr(context, "message", None) + if context_message is None: + return inherited_metadata + + additional_config = getattr(context_message.message_info, "additional_config", {}) + if not isinstance(additional_config, dict): + return inherited_metadata + + for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS): + value = additional_config.get(key) + if value is None: + continue + normalized_value = str(value).strip() + if normalized_value: + inherited_metadata[key] = value + + target_group_id = getattr(target_stream, "group_id", None) + if target_group_id is not None: + normalized_group_id = str(target_group_id).strip() + if normalized_group_id: + inherited_metadata["platform_io_target_group_id"] = normalized_group_id + + target_user_id = getattr(target_stream, "user_id", None) + if target_user_id is not None: + normalized_user_id = str(target_user_id).strip() + if normalized_user_id: + inherited_metadata["platform_io_target_user_id"] = normalized_user_id + + return inherited_metadata + + async def _send_to_target( message_segment: Seg, stream_id: str, @@ -42,7 +87,22 @@ async def _send_to_target( show_log: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定目标发送消息的内部实现""" + """向指定目标发送消息。 + + Args: + message_segment: 待发送的消息段。 + stream_id: 目标会话 ID。 + display_message: 用于界面展示的文本内容。 + typing: 是否显示输入中状态。 + set_reply: 是否在发送时附带引用回复。 + reply_message: 被回复的消息对象。 + storage_message: 是否将发送结果写入消息存储。 + show_log: 是否输出发送日志。 + selected_expressions: 可选的表情候选索引列表。 + + Returns: + bool: 发送成功返回 ``True``,否则返回 ``False``。 + """ try: if set_reply and not reply_message: logger.warning("[SendService] 使用引用回复,但未提供回复消息") @@ -80,7 +140,7 @@ async def _send_to_target( platform=target_stream.platform, ) - additional_config: dict[str, object] = {} + additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream) if selected_expressions is not None: additional_config["selected_expressions"] = selected_expressions bot_user_id = get_bot_account(target_stream.platform) diff --git a/src/webui/routers/chat/serializers.py b/src/webui/routers/chat/serializers.py new file mode 100644 index 00000000..32104f88 --- /dev/null +++ b/src/webui/routers/chat/serializers.py @@ -0,0 +1,175 @@ +"""提供 WebUI 聊天路由使用的消息序列化能力。""" + +from typing import Any, Dict, List, Optional + +import base64 + +from src.common.data_models.message_component_data_model import ( + AtComponent, + DictComponent, + EmojiComponent, + ForwardComponent, + ForwardNodeComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + StandardMessageComponents, + TextComponent, + VoiceComponent, +) + + +def serialize_message_sequence(message_sequence: MessageSequence) -> List[Dict[str, Any]]: + """将内部统一消息组件序列转换为 WebUI 富文本消息段。 + + Args: + message_sequence: 内部统一消息组件序列。 + + Returns: + List[Dict[str, Any]]: 可直接广播给 WebUI 前端的消息段列表。 + """ + serialized_segments: List[Dict[str, Any]] = [] + for component in message_sequence.components: + serialized_segment = serialize_message_component(component) + if serialized_segment is not None: + serialized_segments.append(serialized_segment) + return serialized_segments + + +def serialize_message_component(component: StandardMessageComponents) -> Optional[Dict[str, Any]]: + """将单个内部消息组件转换为 WebUI 消息段。 + + Args: + component: 待序列化的内部消息组件。 + + Returns: + Optional[Dict[str, Any]]: 序列化后的 WebUI 消息段;若组件不应展示则返回 ``None``。 + """ + if isinstance(component, TextComponent): + return {"type": "text", "data": component.text} + + if isinstance(component, ImageComponent): + return _serialize_binary_component( + segment_type="image", + mime_type="image/png", + binary_data=component.binary_data, + fallback_text=component.content, + ) + + if isinstance(component, EmojiComponent): + return _serialize_binary_component( + segment_type="emoji", + mime_type="image/gif", + binary_data=component.binary_data, + fallback_text=component.content, + ) + + if isinstance(component, VoiceComponent): + return _serialize_binary_component( + segment_type="voice", + mime_type="audio/wav", + binary_data=component.binary_data, + fallback_text=component.content, + ) + + if isinstance(component, AtComponent): + return { + "type": "at", + "data": { + "target_user_id": component.target_user_id, + "target_user_nickname": component.target_user_nickname, + "target_user_cardname": component.target_user_cardname, + }, + } + + if isinstance(component, ReplyComponent): + return { + "type": "reply", + "data": { + "target_message_id": component.target_message_id, + "target_message_content": component.target_message_content, + "target_message_sender_id": component.target_message_sender_id, + "target_message_sender_nickname": component.target_message_sender_nickname, + "target_message_sender_cardname": component.target_message_sender_cardname, + }, + } + + if isinstance(component, ForwardNodeComponent): + return { + "type": "forward", + "data": [_serialize_forward_component(item) for item in component.forward_components], + } + + if isinstance(component, DictComponent): + return _serialize_dict_component(component.data) + + return {"type": "unknown", "data": str(component)} + + +def _serialize_binary_component( + segment_type: str, + mime_type: str, + binary_data: bytes, + fallback_text: str, +) -> Dict[str, Any]: + """序列化带二进制负载的消息组件。 + + Args: + segment_type: WebUI 消息段类型。 + mime_type: 对应的数据 MIME 类型。 + binary_data: 组件二进制数据。 + fallback_text: 二进制缺失时可退化展示的文本。 + + Returns: + Dict[str, Any]: 序列化后的 WebUI 消息段。 + """ + if binary_data: + encoded_payload = base64.b64encode(binary_data).decode() + return {"type": segment_type, "data": f"data:{mime_type};base64,{encoded_payload}"} + + if fallback_text: + return {"type": "text", "data": fallback_text} + + return {"type": "unknown", "original_type": segment_type, "data": ""} + + +def _serialize_forward_component(component: ForwardComponent) -> Dict[str, Any]: + """序列化单个转发节点。 + + Args: + component: 待序列化的转发节点组件。 + + Returns: + Dict[str, Any]: WebUI 可消费的转发节点字典。 + """ + return { + "message_id": component.message_id, + "user_id": component.user_id, + "user_nickname": component.user_nickname, + "user_cardname": component.user_cardname, + "content": serialize_message_sequence(MessageSequence(component.content)), + } + + +def _serialize_dict_component(data: Dict[str, Any]) -> Dict[str, Any]: + """最佳努力地序列化非标准字典组件。 + + Args: + data: 原始字典组件内容。 + + Returns: + Dict[str, Any]: 序列化后的 WebUI 消息段。 + """ + raw_type = str(data.get("type") or "dict").strip() + raw_payload = data.get("data", data) + + if raw_type in {"text", "image", "emoji", "voice", "video", "file", "music", "face"}: + return {"type": raw_type, "data": raw_payload} + + if raw_type == "reply": + return {"type": "reply", "data": raw_payload} + + if raw_type == "forward" and isinstance(raw_payload, list): + return {"type": "forward", "data": raw_payload} + + return {"type": "unknown", "original_type": raw_type, "data": raw_payload}