feat: Add NapCat adapter plugin and enhance message handling
- Introduced a built-in NapCat adapter plugin for MVP message forwarding. - Implemented core functionalities for connecting to NapCat/OneBot v11 WebSocket service. - Added message serialization capabilities for WebUI chat routes. - Enhanced the RegisterPluginPayload to include optional adapter declarations. - Implemented methods for handling external messages and adapter declarations in the PluginRunner. - Improved the send_service to inherit platform IO route metadata for outgoing messages.
This commit is contained in:
@@ -1,31 +1,37 @@
|
||||
from rich.traceback import install
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.message_server.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.utils.utils import calculate_typing_time, truncate_message
|
||||
from src.common.data_models.message_component_data_model import ReplyComponent
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_server.api import get_global_api
|
||||
from src.webui.routers.chat.serializers import serialize_message_sequence
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("sender")
|
||||
|
||||
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
|
||||
_webui_chat_broadcaster = None
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
|
||||
|
||||
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
|
||||
# TODO: 重构完成后完成webui相关
|
||||
def get_webui_chat_broadcaster():
|
||||
"""获取 WebUI 聊天室广播器"""
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
|
||||
"""获取 WebUI 聊天室广播器。
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
|
||||
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
|
||||
"""
|
||||
global _webui_chat_broadcaster
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
@@ -38,102 +44,36 @@ def get_webui_chat_broadcaster():
|
||||
|
||||
|
||||
def is_webui_virtual_group(group_id: str) -> bool:
|
||||
"""检查是否是 WebUI 虚拟群"""
|
||||
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||
|
||||
|
||||
def parse_message_segments(segment) -> list:
|
||||
"""解析消息段,转换为 WebUI 可用的格式
|
||||
|
||||
参考 NapCat 适配器的消息解析逻辑
|
||||
"""检查是否是 WebUI 虚拟群。
|
||||
|
||||
Args:
|
||||
segment: Seg 消息段对象
|
||||
group_id: 待判断的群 ID。
|
||||
|
||||
Returns:
|
||||
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
||||
bool: 若群 ID 属于 WebUI 虚拟群则返回 ``True``。
|
||||
"""
|
||||
|
||||
result = []
|
||||
|
||||
if segment is None:
|
||||
return result
|
||||
|
||||
if segment.type == "seglist":
|
||||
# 处理消息段列表
|
||||
if segment.data:
|
||||
for seg in segment.data:
|
||||
result.extend(parse_message_segments(seg))
|
||||
elif segment.type == "text":
|
||||
# 文本消息
|
||||
if segment.data:
|
||||
result.append({"type": "text", "data": segment.data})
|
||||
elif segment.type == "image":
|
||||
# 图片消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
|
||||
elif segment.type == "emoji":
|
||||
# 表情包消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
|
||||
elif segment.type == "imageurl":
|
||||
# 图片链接消息
|
||||
if segment.data:
|
||||
result.append({"type": "image", "data": segment.data})
|
||||
elif segment.type == "face":
|
||||
# 原生表情
|
||||
result.append({"type": "face", "data": segment.data})
|
||||
elif segment.type == "voice":
|
||||
# 语音消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
|
||||
elif segment.type == "voiceurl":
|
||||
# 语音链接
|
||||
if segment.data:
|
||||
result.append({"type": "voice", "data": segment.data})
|
||||
elif segment.type == "video":
|
||||
# 视频消息(base64)
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
|
||||
elif segment.type == "videourl":
|
||||
# 视频链接
|
||||
if segment.data:
|
||||
result.append({"type": "video", "data": segment.data})
|
||||
elif segment.type == "music":
|
||||
# 音乐消息
|
||||
result.append({"type": "music", "data": segment.data})
|
||||
elif segment.type == "file":
|
||||
# 文件消息
|
||||
result.append({"type": "file", "data": segment.data})
|
||||
elif segment.type == "reply":
|
||||
# 回复消息
|
||||
result.append({"type": "reply", "data": segment.data})
|
||||
elif segment.type == "forward":
|
||||
# 转发消息
|
||||
forward_items = []
|
||||
if segment.data:
|
||||
for item in segment.data:
|
||||
forward_items.append(
|
||||
{
|
||||
"content": parse_message_segments(item.get("message_segment", {}))
|
||||
if isinstance(item, dict)
|
||||
else []
|
||||
}
|
||||
)
|
||||
result.append({"type": "forward", "data": forward_items})
|
||||
else:
|
||||
# 未知类型,尝试作为文本处理
|
||||
if segment.data:
|
||||
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
|
||||
|
||||
return result
|
||||
return bool(group_id) and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||
|
||||
|
||||
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"""执行统一的消息发送流程。
|
||||
|
||||
发送顺序为:
|
||||
1. WebUI 特殊链路
|
||||
2. Platform IO 适配器链路
|
||||
3. 旧版 ``maim_message`` / API Server 链路
|
||||
|
||||
Args:
|
||||
message: 待发送的内部会话消息。
|
||||
show_log: 是否输出发送成功日志。
|
||||
|
||||
Returns:
|
||||
bool: 是否最终发送成功。
|
||||
"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||
platform = message.platform
|
||||
group_id = message.session.group_id
|
||||
group_info = message.message_info.group_info
|
||||
group_id = group_info.group_id if group_info is not None else ""
|
||||
|
||||
try:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
@@ -146,7 +86,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
from src.config.config import global_config
|
||||
|
||||
# 解析消息段,获取富文本内容
|
||||
message_segments = parse_message_segments(message.message_segment)
|
||||
message_segments = serialize_message_sequence(message.raw_message)
|
||||
|
||||
# 判断消息类型
|
||||
# 如果只有一个文本段,使用简单的 text 类型
|
||||
@@ -184,8 +124,38 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
|
||||
return True
|
||||
|
||||
try:
|
||||
from src.platform_io import DeliveryStatus
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
receipt = await get_plugin_runtime_manager().try_send_message_via_platform_io(message)
|
||||
if receipt is not None:
|
||||
if receipt.status == DeliveryStatus.SENT:
|
||||
if show_log:
|
||||
logger.info(
|
||||
f"已通过 Platform IO 将消息 '{message_preview}' 发往平台'{platform}' "
|
||||
f"(driver: {receipt.driver_id or 'unknown'})"
|
||||
)
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
f"Platform IO 发送失败: platform={platform} driver={receipt.driver_id} "
|
||||
f"status={receipt.status} error={receipt.error}"
|
||||
)
|
||||
return False
|
||||
except Exception as exc:
|
||||
logger.warning(f"检查 Platform IO 出站链路时出现异常,将回退旧发送链: {exc}")
|
||||
|
||||
# Fallback 逻辑: 尝试通过 API Server 发送
|
||||
async def send_with_new_api(legacy_exception=None):
|
||||
async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool:
|
||||
"""通过 API Server 回退链路发送消息。
|
||||
|
||||
Args:
|
||||
legacy_exception: 旧发送链已经抛出的异常;若回退也失败,则重新抛出。
|
||||
|
||||
Returns:
|
||||
bool: 回退链路是否发送成功。
|
||||
"""
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -289,7 +259,8 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
class UniversalMessageSender:
|
||||
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""初始化统一消息发送器。"""
|
||||
pass
|
||||
|
||||
async def send_message(
|
||||
@@ -300,7 +271,7 @@ class UniversalMessageSender:
|
||||
reply_message_id: Optional[str] = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
):
|
||||
) -> bool:
|
||||
"""
|
||||
处理、发送并存储一条消息。
|
||||
|
||||
|
||||
@@ -1129,7 +1129,10 @@ class DefaultReplyer:
|
||||
user_id=bot_user_id,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
additional_config={},
|
||||
additional_config={
|
||||
"platform_io_target_group_id": self.chat_stream.group_id,
|
||||
"platform_io_target_user_id": self.chat_stream.user_id,
|
||||
},
|
||||
),
|
||||
message_segment=message_segment,
|
||||
)
|
||||
|
||||
@@ -970,7 +970,9 @@ class PrivateReplyer:
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
group_info=None,
|
||||
additional_config={},
|
||||
additional_config={
|
||||
"platform_io_target_user_id": self.chat_stream.user_id,
|
||||
},
|
||||
),
|
||||
message_segment=message_segment,
|
||||
)
|
||||
|
||||
@@ -1,34 +1,51 @@
|
||||
"""提供 Platform IO 的 plugin 传输驱动骨架。"""
|
||||
"""提供 Platform IO 的插件适配器驱动实现。"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol
|
||||
|
||||
from src.platform_io.drivers.base import PlatformIODriver
|
||||
from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey
|
||||
from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
|
||||
|
||||
class _AdapterSupervisorProtocol(Protocol):
|
||||
"""适配器驱动依赖的 Supervisor 最小协议。"""
|
||||
|
||||
async def invoke_adapter(
|
||||
self,
|
||||
plugin_id: str,
|
||||
method_name: str,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""调用适配器插件专用方法。"""
|
||||
|
||||
|
||||
class PluginPlatformDriver(PlatformIODriver):
|
||||
"""面向 ``MessageGateway`` 插件链路的 Platform IO 驱动骨架。"""
|
||||
"""面向适配器插件链路的 Platform IO 驱动。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
driver_id: str,
|
||||
platform: str,
|
||||
supervisor: _AdapterSupervisorProtocol,
|
||||
send_method: str = "send_to_platform",
|
||||
account_id: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
plugin_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""初始化一个 plugin 驱动描述对象。
|
||||
"""初始化一个插件适配器驱动。
|
||||
|
||||
Args:
|
||||
driver_id: Broker 内的唯一驱动 ID。
|
||||
platform: 该 plugin 适配器链路负责的平台。
|
||||
platform: 该适配器负责的平台名称。
|
||||
supervisor: 持有该适配器插件的 Supervisor。
|
||||
send_method: 出站发送时要调用的插件方法名。
|
||||
account_id: 可选的账号 ID 或 self ID。
|
||||
scope: 可选的额外路由作用域。
|
||||
plugin_id: 拥有该适配器实现的插件 ID,可为空。
|
||||
plugin_id: 拥有该适配器实现的插件 ID。
|
||||
metadata: 可选的额外驱动元数据。
|
||||
"""
|
||||
descriptor = DriverDescriptor(
|
||||
@@ -41,6 +58,8 @@ class PluginPlatformDriver(PlatformIODriver):
|
||||
metadata=metadata or {},
|
||||
)
|
||||
super().__init__(descriptor)
|
||||
self._supervisor = supervisor
|
||||
self._send_method = send_method
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
@@ -48,7 +67,7 @@ class PluginPlatformDriver(PlatformIODriver):
|
||||
route_key: RouteKey,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> DeliveryReceipt:
|
||||
"""通过 plugin 传输路径发送消息。
|
||||
"""通过适配器插件发送消息。
|
||||
|
||||
Args:
|
||||
message: 要投递的内部会话消息。
|
||||
@@ -57,8 +76,119 @@ class PluginPlatformDriver(PlatformIODriver):
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 由驱动返回的规范化回执。
|
||||
|
||||
Raises:
|
||||
NotImplementedError: 当前仍处于骨架阶段,尚未真正接入 MessageGateway。
|
||||
"""
|
||||
raise NotImplementedError("PluginPlatformDriver 仅完成地基实现,尚未接入 MessageGateway")
|
||||
from src.plugin_runtime.host.message_utils import PluginMessageUtils
|
||||
|
||||
plugin_id = self.descriptor.plugin_id or ""
|
||||
if not plugin_id:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error="插件适配器驱动缺少 plugin_id",
|
||||
)
|
||||
|
||||
try:
|
||||
message_dict = PluginMessageUtils._session_message_to_dict(message)
|
||||
response = await self._supervisor.invoke_adapter(
|
||||
plugin_id=plugin_id,
|
||||
method_name=self._send_method,
|
||||
args={
|
||||
"message": message_dict,
|
||||
"route": {
|
||||
"platform": route_key.platform,
|
||||
"account_id": route_key.account_id,
|
||||
"scope": route_key.scope,
|
||||
},
|
||||
"metadata": metadata or {},
|
||||
},
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=message.message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
return self._build_receipt(message.message_id, route_key, response)
|
||||
|
||||
def _build_receipt(self, internal_message_id: str, route_key: RouteKey, response: Any) -> DeliveryReceipt:
|
||||
"""将适配器调用响应归一化为出站回执。
|
||||
|
||||
Args:
|
||||
internal_message_id: 内部消息 ID。
|
||||
route_key: 本次投递的路由键。
|
||||
response: Supervisor 返回的 RPC 响应对象。
|
||||
|
||||
Returns:
|
||||
DeliveryReceipt: 标准化后的出站回执。
|
||||
"""
|
||||
if getattr(response, "error", None):
|
||||
error = response.error.get("message", "适配器发送失败")
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=error,
|
||||
)
|
||||
|
||||
payload = getattr(response, "payload", {})
|
||||
invoke_success = bool(payload.get("success", False)) if isinstance(payload, dict) else False
|
||||
if not invoke_success:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(payload.get("result", "适配器发送失败")) if isinstance(payload, dict) else "适配器发送失败",
|
||||
)
|
||||
|
||||
result = payload.get("result") if isinstance(payload, dict) else None
|
||||
if isinstance(result, dict):
|
||||
if result.get("success") is False:
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.FAILED,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
error=str(result.get("error", "适配器发送失败")),
|
||||
metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
|
||||
)
|
||||
external_message_id = str(result.get("external_message_id") or result.get("message_id") or "") or None
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
external_message_id=external_message_id,
|
||||
metadata=result.get("metadata", {}) if isinstance(result.get("metadata"), dict) else {},
|
||||
)
|
||||
|
||||
if isinstance(result, str) and result.strip():
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
external_message_id=result.strip(),
|
||||
)
|
||||
|
||||
return DeliveryReceipt(
|
||||
internal_message_id=internal_message_id,
|
||||
route_key=route_key,
|
||||
status=DeliveryStatus.SENT,
|
||||
driver_id=self.driver_id,
|
||||
driver_kind=self.descriptor.kind,
|
||||
)
|
||||
|
||||
@@ -3,9 +3,11 @@ Message Gateway 模块
|
||||
适配器专用,用于将其他平台的消息转换为系统内部的消息格式,并将系统消息转换为其他平台的格式。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.platform_io import DeliveryStatus, get_platform_io_manager
|
||||
|
||||
from .message_utils import PluginMessageUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -17,25 +19,53 @@ logger = get_logger("plugin_runtime.host.message_gateway")
|
||||
|
||||
|
||||
class MessageGateway:
|
||||
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||
self._component_registry = component_registry
|
||||
"""Host 侧消息网关包装器。"""
|
||||
|
||||
async def receive_external_message(self, external_message: Dict[str, Any]):
|
||||
"""
|
||||
接收外部消息,转换为系统内部格式,并返回转换结果
|
||||
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||
"""初始化消息网关。
|
||||
|
||||
Args:
|
||||
external_message: 外部消息的字典格式数据
|
||||
component_registry: 组件注册表。
|
||||
"""
|
||||
self._component_registry = component_registry
|
||||
|
||||
def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage":
|
||||
"""将标准消息字典转换为 ``SessionMessage``。
|
||||
|
||||
Args:
|
||||
external_message: 外部消息的字典格式数据。
|
||||
|
||||
Returns:
|
||||
转换后的 SessionMessage 对象
|
||||
SessionMessage: 转换后的内部消息对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 消息字典不合法时抛出。
|
||||
"""
|
||||
return PluginMessageUtils._build_session_message_from_dict(external_message)
|
||||
|
||||
def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]:
|
||||
"""将 ``SessionMessage`` 转换为标准消息字典。
|
||||
|
||||
Args:
|
||||
internal_message: 内部消息对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 供适配器插件消费的标准消息字典。
|
||||
"""
|
||||
return dict(PluginMessageUtils._session_message_to_dict(internal_message))
|
||||
|
||||
async def receive_external_message(self, external_message: Dict[str, Any]) -> None:
|
||||
"""接收外部消息并送入主消息链。
|
||||
|
||||
Args:
|
||||
external_message: 外部消息的字典格式数据。
|
||||
"""
|
||||
# 使用递归函数将外部消息字典转换为 SessionMessage
|
||||
try:
|
||||
session_message = PluginMessageUtils._build_session_message_from_dict(external_message)
|
||||
session_message = self.build_session_message(external_message)
|
||||
except Exception as e:
|
||||
logger.error(f"转换外部消息失败: {e}")
|
||||
return
|
||||
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
|
||||
await chat_bot.receive_message(session_message)
|
||||
@@ -48,46 +78,32 @@ class MessageGateway:
|
||||
enabled_only: bool = True,
|
||||
save_to_db: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
接收系统内部消息,转换为外部格式,并返回转换结果
|
||||
"""将内部消息通过 Platform IO 发送到外部平台。
|
||||
|
||||
Args:
|
||||
internal_message: 系统内部的 SessionMessage 对象
|
||||
internal_message: 系统内部的 ``SessionMessage`` 对象。
|
||||
supervisor: 当前持有该消息网关的 Supervisor。
|
||||
enabled_only: 兼容旧签名的保留参数,当前由 Platform IO 统一裁决。
|
||||
save_to_db: 发送成功后是否写入数据库。
|
||||
|
||||
Returns:
|
||||
转换是否成功
|
||||
bool: 是否发送成功。
|
||||
"""
|
||||
try:
|
||||
# 将 SessionMessage 转换为字典格式
|
||||
message_dict = PluginMessageUtils._session_message_to_dict(internal_message)
|
||||
except Exception as e:
|
||||
logger.error(f"转换内部消息失败:{e}")
|
||||
return False
|
||||
gateway_entry = self._component_registry.get_message_gateways(
|
||||
internal_message.platform,
|
||||
enabled_only=enabled_only,
|
||||
session_id=internal_message.session_id,
|
||||
)
|
||||
if not gateway_entry:
|
||||
logger.warning(f"未找到适配平台 {internal_message.platform} 的消息网关组件,无法发送消息到外部平台")
|
||||
return False
|
||||
args = {"platform": internal_message.platform, "message": message_dict}
|
||||
try:
|
||||
resp_envelope = await supervisor.invoke_plugin(
|
||||
"plugin.emit_event", gateway_entry.plugin_id, gateway_entry.name, args
|
||||
)
|
||||
logger.debug("信息发送成功")
|
||||
except Exception as e:
|
||||
logger.error(f"调用消息网关组件失败:{e}")
|
||||
del enabled_only
|
||||
del supervisor
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
if not platform_io_manager.is_started:
|
||||
logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息")
|
||||
return False
|
||||
|
||||
# 更新为实际id(如果组件返回了新的id)
|
||||
actual_message_id = resp_envelope.payload.get("message_id")
|
||||
try:
|
||||
actual_message_id = str(actual_message_id)
|
||||
except Exception:
|
||||
actual_message_id = None
|
||||
internal_message.message_id = actual_message_id or internal_message.message_id
|
||||
route_key = platform_io_manager.build_route_key_from_message(internal_message)
|
||||
receipt = await platform_io_manager.send_message(internal_message, route_key)
|
||||
if receipt.status != DeliveryStatus.SENT:
|
||||
logger.warning(f"通过适配器链路发送消息失败: {receipt.error or receipt.status}")
|
||||
return False
|
||||
|
||||
internal_message.message_id = receipt.external_message_id or internal_message.message_id
|
||||
if save_to_db:
|
||||
try:
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
|
||||
@@ -209,6 +209,9 @@ class PluginMessageUtils:
|
||||
session_message.is_notify = message_dict.get("is_notify", False)
|
||||
if not isinstance(session_message.is_notify, bool):
|
||||
session_message.is_notify = False
|
||||
session_message.session_id = message_dict.get("session_id", "")
|
||||
if not isinstance(session_message.session_id, str):
|
||||
session_message.session_id = ""
|
||||
session_message.reply_to = message_dict.get("reply_to")
|
||||
if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
|
||||
session_message.reply_to = None
|
||||
|
||||
@@ -8,13 +8,20 @@ import sys
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
||||
from src.platform_io.drivers import PluginPlatformDriver
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
from src.platform_io.routing import RouteBindingConflictError
|
||||
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
AdapterDeclarationPayload,
|
||||
BootstrapPluginPayload,
|
||||
ConfigUpdatedPayload,
|
||||
Envelope,
|
||||
HealthPayload,
|
||||
PROTOCOL_VERSION,
|
||||
ReceiveExternalMessagePayload,
|
||||
ReceiveExternalMessageResultPayload,
|
||||
RegisterPluginPayload,
|
||||
ReloadPluginResultPayload,
|
||||
RunnerReadyPayload,
|
||||
@@ -86,6 +93,7 @@ class PluginRunnerSupervisor:
|
||||
|
||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||
self._registered_plugins: Dict[str, RegisterPluginPayload] = {}
|
||||
self._registered_adapters: Dict[str, AdapterDeclarationPayload] = {}
|
||||
self._runner_ready_events: asyncio.Event = asyncio.Event()
|
||||
self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload()
|
||||
self._health_task: Optional[asyncio.Task[None]] = None
|
||||
@@ -257,6 +265,32 @@ class PluginRunnerSupervisor:
|
||||
timeout_ms,
|
||||
)
|
||||
|
||||
async def invoke_adapter(
|
||||
self,
|
||||
plugin_id: str,
|
||||
method_name: str,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""调用适配器插件的专用方法。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标适配器插件 ID。
|
||||
method_name: 要调用的插件方法名,例如 ``send_to_platform``。
|
||||
args: 传递给插件方法的关键字参数。
|
||||
timeout_ms: RPC 超时时间,单位毫秒。
|
||||
|
||||
Returns:
|
||||
Envelope: Runner 返回的响应信封。
|
||||
"""
|
||||
return await self.invoke_plugin(
|
||||
method="plugin.invoke_adapter",
|
||||
plugin_id=plugin_id,
|
||||
component_name=method_name,
|
||||
args=args,
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool:
|
||||
"""按插件 ID 触发精确重载。
|
||||
|
||||
@@ -384,6 +418,7 @@ class PluginRunnerSupervisor:
|
||||
def _register_internal_methods(self) -> None:
|
||||
"""注册 Host 侧内部 RPC 方法。"""
|
||||
self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request)
|
||||
self._rpc_server.register_method("host.receive_external_message", self._handle_receive_external_message)
|
||||
self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin)
|
||||
self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin)
|
||||
self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin)
|
||||
@@ -427,6 +462,17 @@ class PluginRunnerSupervisor:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
self._component_registry.remove_components_by_plugin(payload.plugin_id)
|
||||
if payload.plugin_id in self._registered_adapters:
|
||||
await self._unregister_adapter_driver(payload.plugin_id)
|
||||
|
||||
try:
|
||||
if payload.adapter is not None:
|
||||
await self._register_adapter_driver(payload.plugin_id, payload.adapter)
|
||||
except RouteBindingConflictError as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, str(exc))
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc))
|
||||
|
||||
registered_count = self._component_registry.register_plugin_components(
|
||||
payload.plugin_id,
|
||||
[component.model_dump() for component in payload.components],
|
||||
@@ -438,6 +484,7 @@ class PluginRunnerSupervisor:
|
||||
"accepted": True,
|
||||
"plugin_id": payload.plugin_id,
|
||||
"registered_components": registered_count,
|
||||
"adapter_registered": payload.adapter is not None,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -458,6 +505,7 @@ class PluginRunnerSupervisor:
|
||||
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
|
||||
self._authorization.revoke_permission_token(payload.plugin_id)
|
||||
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
|
||||
await self._unregister_adapter_driver(payload.plugin_id)
|
||||
|
||||
return envelope.make_response(
|
||||
payload={
|
||||
@@ -469,6 +517,221 @@ class PluginRunnerSupervisor:
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_adapter_driver_id(plugin_id: str) -> str:
|
||||
"""构造适配器驱动 ID。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
|
||||
Returns:
|
||||
str: 对应 Platform IO 中的驱动 ID。
|
||||
"""
|
||||
return f"adapter:{plugin_id}"
|
||||
|
||||
async def _register_adapter_driver(self, plugin_id: str, adapter: AdapterDeclarationPayload) -> None:
|
||||
"""将适配器插件注册到 Platform IO。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
adapter: 经过校验的适配器声明。
|
||||
|
||||
Raises:
|
||||
ValueError: 适配器路由冲突或驱动注册失败时抛出。
|
||||
"""
|
||||
await self._unregister_adapter_driver(plugin_id)
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
driver = PluginPlatformDriver(
|
||||
driver_id=self._build_adapter_driver_id(plugin_id),
|
||||
platform=adapter.platform,
|
||||
account_id=adapter.account_id or None,
|
||||
scope=adapter.scope or None,
|
||||
plugin_id=plugin_id,
|
||||
send_method=adapter.send_method,
|
||||
supervisor=self,
|
||||
metadata={
|
||||
"protocol": adapter.protocol,
|
||||
**adapter.metadata,
|
||||
},
|
||||
)
|
||||
binding = RouteBinding(
|
||||
route_key=driver.descriptor.route_key,
|
||||
driver_id=driver.driver_id,
|
||||
driver_kind=DriverKind.PLUGIN,
|
||||
metadata={
|
||||
"plugin_id": plugin_id,
|
||||
"protocol": adapter.protocol,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
if platform_io_manager.is_started:
|
||||
await platform_io_manager.add_driver(driver)
|
||||
else:
|
||||
platform_io_manager.register_driver(driver)
|
||||
platform_io_manager.bind_route(binding)
|
||||
except Exception:
|
||||
with contextlib.suppress(Exception):
|
||||
if platform_io_manager.is_started:
|
||||
await platform_io_manager.remove_driver(driver.driver_id)
|
||||
else:
|
||||
platform_io_manager.unregister_driver(driver.driver_id)
|
||||
raise
|
||||
|
||||
self._registered_adapters[plugin_id] = adapter
|
||||
|
||||
async def _unregister_adapter_driver(self, plugin_id: str) -> None:
|
||||
"""从 Platform IO 注销一个适配器驱动。
|
||||
|
||||
Args:
|
||||
plugin_id: 适配器插件 ID。
|
||||
"""
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
driver_id = self._build_adapter_driver_id(plugin_id)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
if platform_io_manager.is_started:
|
||||
await platform_io_manager.remove_driver(driver_id)
|
||||
else:
|
||||
platform_io_manager.unregister_driver(driver_id)
|
||||
|
||||
self._registered_adapters.pop(plugin_id, None)
|
||||
|
||||
async def _unregister_all_adapter_drivers(self) -> None:
|
||||
"""注销当前 Supervisor 管理的全部适配器驱动。"""
|
||||
plugin_ids = list(self._registered_adapters.keys())
|
||||
for plugin_id in plugin_ids:
|
||||
await self._unregister_adapter_driver(plugin_id)
|
||||
|
||||
@staticmethod
|
||||
def _attach_inbound_route_metadata(
|
||||
session_message: "SessionMessage",
|
||||
route_key: RouteKey,
|
||||
route_metadata: Dict[str, Any],
|
||||
) -> None:
|
||||
"""将入站路由信息写回消息的 ``additional_config``。
|
||||
|
||||
Args:
|
||||
session_message: 已构造好的内部消息对象。
|
||||
route_key: Host 为该消息解析出的标准路由键。
|
||||
route_metadata: 适配器通过 RPC 补充的原始路由辅助元数据。
|
||||
"""
|
||||
additional_config = session_message.message_info.additional_config
|
||||
if not isinstance(additional_config, dict):
|
||||
additional_config = {}
|
||||
session_message.message_info.additional_config = additional_config
|
||||
|
||||
for key, value in route_metadata.items():
|
||||
if value is None:
|
||||
continue
|
||||
normalized_value = str(value).strip()
|
||||
if normalized_value:
|
||||
additional_config[key] = value
|
||||
|
||||
if route_key.account_id:
|
||||
additional_config.setdefault("platform_io_account_id", route_key.account_id)
|
||||
if route_key.scope:
|
||||
additional_config.setdefault("platform_io_scope", route_key.scope)
|
||||
|
||||
def _build_inbound_route_key(
|
||||
self,
|
||||
adapter: AdapterDeclarationPayload,
|
||||
message: Dict[str, Any],
|
||||
route_metadata: Dict[str, Any],
|
||||
) -> RouteKey:
|
||||
"""为适配器入站消息构造归一路由键。
|
||||
|
||||
Args:
|
||||
adapter: 当前适配器声明。
|
||||
message: 标准消息字典。
|
||||
route_metadata: 插件补充的路由辅助元数据。
|
||||
|
||||
Returns:
|
||||
RouteKey: 供 Platform IO 使用的规范化路由键。
|
||||
|
||||
Raises:
|
||||
ValueError: 消息平台字段与适配器平台声明不一致时抛出。
|
||||
"""
|
||||
message_platform = str(message.get("platform") or adapter.platform).strip()
|
||||
if message_platform != adapter.platform:
|
||||
raise ValueError(
|
||||
f"外部消息平台 {message_platform} 与适配器 {adapter.platform} 不一致"
|
||||
)
|
||||
|
||||
try:
|
||||
route_key = RouteKeyFactory.from_message_dict(message)
|
||||
except Exception:
|
||||
route_key = RouteKey(platform=message_platform)
|
||||
|
||||
route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata)
|
||||
account_id = route_key.account_id or route_account_id or adapter.account_id or None
|
||||
scope = route_key.scope or route_scope or adapter.scope or None
|
||||
return RouteKey(
|
||||
platform=message_platform,
|
||||
account_id=account_id,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
async def _handle_receive_external_message(self, envelope: Envelope) -> Envelope:
|
||||
"""处理适配器插件上报的外部入站消息。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: 注入结果响应。
|
||||
"""
|
||||
try:
|
||||
payload = ReceiveExternalMessagePayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
adapter = self._registered_adapters.get(envelope.plugin_id)
|
||||
if adapter is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"插件 {envelope.plugin_id} 未声明为适配器,不能注入外部消息",
|
||||
)
|
||||
|
||||
try:
|
||||
route_key = self._build_inbound_route_key(
|
||||
adapter=adapter,
|
||||
message=payload.message,
|
||||
route_metadata=payload.route_metadata,
|
||||
)
|
||||
session_message = self._message_gateway.build_session_message(payload.message)
|
||||
self._attach_inbound_route_metadata(session_message, route_key, payload.route_metadata)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
accepted = await platform_io_manager.accept_inbound(
|
||||
InboundMessageEnvelope(
|
||||
route_key=route_key,
|
||||
driver_id=self._build_adapter_driver_id(envelope.plugin_id),
|
||||
driver_kind=DriverKind.PLUGIN,
|
||||
external_message_id=payload.external_message_id or str(payload.message.get("message_id") or "") or None,
|
||||
dedupe_key=payload.dedupe_key or None,
|
||||
session_message=session_message,
|
||||
payload=payload.message,
|
||||
metadata={
|
||||
"plugin_id": envelope.plugin_id,
|
||||
"protocol": adapter.protocol,
|
||||
**payload.route_metadata,
|
||||
},
|
||||
)
|
||||
)
|
||||
response = ReceiveExternalMessageResultPayload(
|
||||
accepted=accepted,
|
||||
route_key={
|
||||
"platform": route_key.platform,
|
||||
"account_id": route_key.account_id,
|
||||
"scope": route_key.scope,
|
||||
},
|
||||
)
|
||||
return envelope.make_response(payload=response.model_dump())
|
||||
|
||||
async def _handle_runner_ready(self, envelope: Envelope) -> Envelope:
|
||||
"""处理 Runner 就绪通知。
|
||||
|
||||
@@ -595,6 +858,7 @@ class PluginRunnerSupervisor:
|
||||
await self._stderr_drain_task
|
||||
self._stderr_drain_task = None
|
||||
|
||||
await self._unregister_all_adapter_drivers()
|
||||
self._clear_runner_state()
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
@@ -671,6 +935,7 @@ class PluginRunnerSupervisor:
|
||||
self._authorization.clear()
|
||||
self._component_registry.clear()
|
||||
self._registered_plugins.clear()
|
||||
self._registered_adapters.clear()
|
||||
self._runner_ready_events = asyncio.Event()
|
||||
self._runner_ready_payloads = RunnerReadyPayload()
|
||||
|
||||
|
||||
@@ -12,11 +12,13 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Ite
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.file_watcher import FileChange, FileWatcher
|
||||
from src.platform_io import DeliveryReceipt, InboundMessageEnvelope, get_platform_io_manager
|
||||
from src.plugin_runtime.capabilities import (
|
||||
RuntimeComponentCapabilityMixin,
|
||||
RuntimeCoreCapabilityMixin,
|
||||
@@ -57,6 +59,7 @@ class PluginRuntimeManager(
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化插件运行时管理器。"""
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
self._builtin_supervisor: Optional[PluginSupervisor] = None
|
||||
@@ -66,6 +69,22 @@ class PluginRuntimeManager(
|
||||
self._plugin_source_watcher_subscription_id: Optional[str] = None
|
||||
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
|
||||
|
||||
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
|
||||
"""接收 Platform IO 审核后的入站消息并送入主消息链。
|
||||
|
||||
Args:
|
||||
envelope: Platform IO 产出的入站封装。
|
||||
"""
|
||||
session_message = envelope.session_message
|
||||
if session_message is None and envelope.payload is not None:
|
||||
session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload))
|
||||
if session_message is None:
|
||||
raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload")
|
||||
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
|
||||
await chat_bot.receive_message(session_message)
|
||||
|
||||
# ─── 插件目录 ─────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
@@ -110,6 +129,8 @@ class PluginRuntimeManager(
|
||||
logger.info("未找到任何插件目录,跳过插件运行时启动")
|
||||
return
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
|
||||
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
|
||||
socket_path_base = _cfg.ipc_socket_path or None
|
||||
|
||||
@@ -134,6 +155,9 @@ class PluginRuntimeManager(
|
||||
|
||||
started_supervisors: List[PluginSupervisor] = []
|
||||
try:
|
||||
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
|
||||
await platform_io_manager.start()
|
||||
|
||||
if self._builtin_supervisor:
|
||||
await self._builtin_supervisor.start()
|
||||
started_supervisors.append(self._builtin_supervisor)
|
||||
@@ -147,6 +171,11 @@ class PluginRuntimeManager(
|
||||
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
|
||||
await self._stop_plugin_file_watcher()
|
||||
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
|
||||
platform_io_manager.clear_inbound_dispatcher()
|
||||
try:
|
||||
await platform_io_manager.stop()
|
||||
except Exception as platform_io_exc:
|
||||
logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
self._third_party_supervisor = None
|
||||
@@ -156,6 +185,7 @@ class PluginRuntimeManager(
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
await self._stop_plugin_file_watcher()
|
||||
|
||||
coroutines: List[Coroutine[Any, Any, None]] = []
|
||||
@@ -164,11 +194,23 @@ class PluginRuntimeManager(
|
||||
if self._third_party_supervisor:
|
||||
coroutines.append(self._third_party_supervisor.stop())
|
||||
|
||||
stop_errors: List[str] = []
|
||||
try:
|
||||
await asyncio.gather(*coroutines, return_exceptions=True)
|
||||
logger.info("插件运行时已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"插件运行时停止失败: {e}", exc_info=True)
|
||||
results = await asyncio.gather(*coroutines, return_exceptions=True)
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
stop_errors.append(str(result))
|
||||
|
||||
platform_io_manager.clear_inbound_dispatcher()
|
||||
try:
|
||||
await platform_io_manager.stop()
|
||||
except Exception as exc:
|
||||
stop_errors.append(f"Platform IO: {exc}")
|
||||
|
||||
if stop_errors:
|
||||
logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}")
|
||||
else:
|
||||
logger.info("插件运行时已停止")
|
||||
finally:
|
||||
self._started = False
|
||||
self._builtin_supervisor = None
|
||||
@@ -176,6 +218,7 @@ class PluginRuntimeManager(
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""返回插件运行时是否处于启动状态。"""
|
||||
return self._started
|
||||
|
||||
@property
|
||||
@@ -303,6 +346,37 @@ class PluginRuntimeManager(
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
async def try_send_message_via_platform_io(
|
||||
self,
|
||||
message: "SessionMessage",
|
||||
) -> Optional[DeliveryReceipt]:
|
||||
"""尝试通过 Platform IO 中间层发送消息。
|
||||
|
||||
Args:
|
||||
message: 待发送的内部会话消息。
|
||||
|
||||
Returns:
|
||||
Optional[DeliveryReceipt]: 若当前消息存在 active 路由,则返回实际发送
|
||||
结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。
|
||||
"""
|
||||
if not self._started:
|
||||
return None
|
||||
|
||||
platform_io_manager = get_platform_io_manager()
|
||||
if not platform_io_manager.is_started:
|
||||
return None
|
||||
|
||||
try:
|
||||
route_key = platform_io_manager.build_route_key_from_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}")
|
||||
return None
|
||||
|
||||
if platform_io_manager.resolve_driver(route_key) is None:
|
||||
return None
|
||||
|
||||
return await platform_io_manager.send_message(message, route_key)
|
||||
|
||||
def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]:
|
||||
"""返回当前持有指定插件的所有 Supervisor。
|
||||
|
||||
@@ -426,6 +500,11 @@ class PluginRuntimeManager(
|
||||
"""为指定插件生成配置文件变更回调。"""
|
||||
|
||||
async def _callback(changes: Sequence[FileChange]) -> None:
|
||||
"""将 watcher 事件转发到指定插件的配置处理逻辑。
|
||||
|
||||
Args:
|
||||
changes: 当前批次收集到的文件变更列表。
|
||||
"""
|
||||
await self._handle_plugin_config_changes(plugin_id, changes)
|
||||
|
||||
return _callback
|
||||
@@ -542,6 +621,11 @@ class PluginRuntimeManager(
|
||||
# ─── 能力实现注册 ──────────────────────────────────────────
|
||||
|
||||
def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None:
|
||||
"""向指定 Supervisor 注册主程序能力实现。
|
||||
|
||||
Args:
|
||||
supervisor: 需要注册能力实现的目标 Supervisor。
|
||||
"""
|
||||
register_capability_impls(self, supervisor)
|
||||
|
||||
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import logging as stdlib_logging
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ====== 协议常量 ======
|
||||
PROTOCOL_VERSION = "1.0.0"
|
||||
@@ -156,6 +156,8 @@ class RegisterPluginPayload(BaseModel):
|
||||
"""插件版本"""
|
||||
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
||||
"""组件列表"""
|
||||
adapter: Optional["AdapterDeclarationPayload"] = Field(default=None, description="可选的适配器声明")
|
||||
"""可选的适配器声明"""
|
||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||
"""所需能力列表"""
|
||||
|
||||
@@ -285,6 +287,48 @@ class ReloadPluginResultPayload(BaseModel):
|
||||
"""重载失败的插件及原因"""
|
||||
|
||||
|
||||
class AdapterDeclarationPayload(BaseModel):
|
||||
"""适配器插件声明载荷。"""
|
||||
|
||||
platform: str = Field(description="适配器负责的平台名称,例如 qq")
|
||||
"""适配器负责的平台名称,例如 qq"""
|
||||
protocol: str = Field(default="", description="接入协议或实现名称,例如 napcat")
|
||||
"""接入协议或实现名称,例如 napcat"""
|
||||
account_id: str = Field(default="", description="可选的账号 ID 或 self_id")
|
||||
"""可选的账号 ID 或 self_id"""
|
||||
scope: str = Field(default="", description="可选的路由作用域")
|
||||
"""可选的路由作用域"""
|
||||
send_method: str = Field(default="send_to_platform", description="Host 出站调用的插件方法名")
|
||||
"""Host 出站调用的插件方法名"""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="适配器附加元数据")
|
||||
"""适配器附加元数据"""
|
||||
|
||||
|
||||
class ReceiveExternalMessagePayload(BaseModel):
|
||||
"""适配器插件向 Host 注入外部消息的请求载荷。"""
|
||||
|
||||
message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典")
|
||||
"""符合 MessageDict 结构的标准消息字典"""
|
||||
route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据")
|
||||
"""可选的路由辅助元数据"""
|
||||
external_message_id: str = Field(default="", description="可选的外部平台消息 ID")
|
||||
"""可选的外部平台消息 ID"""
|
||||
dedupe_key: str = Field(default="", description="可选的显式去重键")
|
||||
"""可选的显式去重键"""
|
||||
|
||||
|
||||
class ReceiveExternalMessageResultPayload(BaseModel):
|
||||
"""外部消息注入结果载荷。"""
|
||||
|
||||
accepted: bool = Field(description="Host 是否接受了本次消息注入")
|
||||
"""Host 是否接受了本次消息注入"""
|
||||
route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键")
|
||||
"""本次消息使用的归一路由键"""
|
||||
|
||||
|
||||
RegisterPluginPayload.model_rebuild()
|
||||
|
||||
|
||||
# ====== 日志传输 ======
|
||||
|
||||
|
||||
|
||||
@@ -9,9 +9,8 @@
|
||||
6. 转发插件的能力调用到 Host
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -26,6 +25,7 @@ import tomllib
|
||||
from src.common.logger import get_console_handler, get_logger, initialize_logging
|
||||
from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
AdapterDeclarationPayload,
|
||||
BootstrapPluginPayload,
|
||||
ComponentDeclaration,
|
||||
Envelope,
|
||||
@@ -227,7 +227,7 @@ class PluginRunner:
|
||||
plugin_id: str = "",
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
"""桥接 PluginContext.call_capability → RPCClient.send_request。
|
||||
"""桥接 PluginContext 的原始 RPC 调用到 Host。
|
||||
|
||||
无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id
|
||||
始终绑定为当前插件实例,避免伪造其他插件身份申请能力。
|
||||
@@ -237,17 +237,13 @@ class PluginRunner:
|
||||
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份"
|
||||
)
|
||||
resp = await rpc_client.send_request(
|
||||
method="cap.call",
|
||||
method=method,
|
||||
plugin_id=bound_plugin_id,
|
||||
payload={
|
||||
"capability": method,
|
||||
"args": payload or {},
|
||||
},
|
||||
payload=payload or {},
|
||||
)
|
||||
# 从响应信封中提取业务结果
|
||||
if resp.error:
|
||||
raise RuntimeError(resp.error.get("message", "能力调用失败"))
|
||||
return resp.payload.get("result")
|
||||
return resp.payload
|
||||
|
||||
ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
|
||||
cast(_ContextAwarePlugin, instance)._set_context(ctx)
|
||||
@@ -286,6 +282,7 @@ class PluginRunner:
|
||||
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_adapter", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
|
||||
@@ -306,12 +303,14 @@ class PluginRunner:
|
||||
)
|
||||
|
||||
try:
|
||||
await self._rpc_client.send_request(
|
||||
response = await self._rpc_client.send_request(
|
||||
"plugin.bootstrap",
|
||||
plugin_id=meta.plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
timeout_ms=10000,
|
||||
)
|
||||
if response.error:
|
||||
raise RuntimeError(response.error.get("message", "插件 bootstrap 失败"))
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {meta.plugin_id} bootstrap 失败: {e}")
|
||||
@@ -321,6 +320,29 @@ class PluginRunner:
|
||||
"""撤销 bootstrap 期间为插件签发的能力令牌。"""
|
||||
await self._bootstrap_plugin(meta, capabilities_required=[])
|
||||
|
||||
def _collect_adapter_declaration(self, meta: PluginMeta) -> Optional[AdapterDeclarationPayload]:
|
||||
"""从插件实例中提取适配器声明。
|
||||
|
||||
Args:
|
||||
meta: 待提取声明的插件元数据。
|
||||
|
||||
Returns:
|
||||
Optional[AdapterDeclarationPayload]: 若插件声明了适配器角色,则返回
|
||||
经过校验的适配器声明;否则返回 ``None``。
|
||||
|
||||
Raises:
|
||||
ValueError: 插件导出的适配器声明结构非法时抛出。
|
||||
"""
|
||||
instance = meta.instance
|
||||
if not hasattr(instance, "get_adapter_info"):
|
||||
return None
|
||||
|
||||
adapter_info = instance.get_adapter_info()
|
||||
if adapter_info is None:
|
||||
return None
|
||||
|
||||
return AdapterDeclarationPayload.model_validate(adapter_info)
|
||||
|
||||
async def _register_plugin(self, meta: PluginMeta) -> bool:
|
||||
"""向 Host 注册单个插件。
|
||||
|
||||
@@ -346,20 +368,29 @@ class PluginRunner:
|
||||
for comp_info in instance.get_components()
|
||||
)
|
||||
|
||||
try:
|
||||
adapter = self._collect_adapter_declaration(meta)
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {meta.plugin_id} 适配器声明非法: {exc}", exc_info=True)
|
||||
return False
|
||||
|
||||
reg_payload = RegisterPluginPayload(
|
||||
plugin_id=meta.plugin_id,
|
||||
plugin_version=meta.version,
|
||||
components=components,
|
||||
adapter=adapter,
|
||||
capabilities_required=meta.capabilities_required,
|
||||
)
|
||||
|
||||
try:
|
||||
_resp = await self._rpc_client.send_request(
|
||||
response = await self._rpc_client.send_request(
|
||||
"plugin.register_components",
|
||||
plugin_id=meta.plugin_id,
|
||||
payload=reg_payload.model_dump(),
|
||||
timeout_ms=10000,
|
||||
)
|
||||
if response.error:
|
||||
raise RuntimeError(response.error.get("message", "插件注册失败"))
|
||||
logger.info(f"插件 {meta.plugin_id} 注册完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
30
src/plugins/built_in/napcat_adapter/_manifest.json
Normal file
30
src/plugins/built_in/napcat_adapter/_manifest.json
Normal file
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "napcat_adapter_builtin",
|
||||
"version": "0.1.0",
|
||||
"description": "Built-in NapCat adapter plugin for MVP message forwarding.",
|
||||
"author": {
|
||||
"name": "OpenAI Codex"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"host_application": {
|
||||
"min_version": "1.0.0"
|
||||
},
|
||||
"keywords": [
|
||||
"adapter",
|
||||
"built-in",
|
||||
"napcat",
|
||||
"onebot",
|
||||
"qq"
|
||||
],
|
||||
"categories": [
|
||||
"Adapter",
|
||||
"Built-in"
|
||||
],
|
||||
"default_locale": "en-US",
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "adapter"
|
||||
},
|
||||
"capabilities": []
|
||||
}
|
||||
690
src/plugins/built_in/napcat_adapter/plugin.py
Normal file
690
src/plugins/built_in/napcat_adapter/plugin.py
Normal file
@@ -0,0 +1,690 @@
|
||||
"""内置 NapCat 适配器插件。
|
||||
|
||||
当前实现是一个 MVP 版本,目标仅限于跑通基础消息收发链路:
|
||||
1. 作为客户端连接 NapCat / OneBot v11 WebSocket 服务。
|
||||
2. 将入站消息事件转换为 Host 侧的 ``MessageDict``。
|
||||
3. 将 Host 出站消息转换为 OneBot 动作并发送。
|
||||
|
||||
当前范围刻意收敛为:
|
||||
- 单连接
|
||||
- 文本、@、reply 基础转发
|
||||
- 暂不处理 ``notice`` / ``meta_event``
|
||||
- 暂不支持图片、语音、文件等复杂媒体
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
|
||||
from maibot_sdk import Adapter, MaiBotPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiohttp import ClientWebSocketResponse as AiohttpClientWebSocketResponse
|
||||
|
||||
try:
|
||||
from aiohttp import ClientSession, ClientTimeout, ClientWebSocketResponse, WSMsgType
|
||||
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
ClientSession = cast(Any, None)
|
||||
ClientTimeout = cast(Any, None)
|
||||
ClientWebSocketResponse = cast(Any, None)
|
||||
WSMsgType = cast(Any, None)
|
||||
AIOHTTP_AVAILABLE = False
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
AiohttpClientWebSocketResponse = Any
|
||||
|
||||
|
||||
@Adapter(platform="qq", protocol="napcat", send_method="send_to_platform")
|
||||
class NapCatAdapterPlugin(MaiBotPlugin):
|
||||
"""NapCat 适配器 MVP 实现。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化 NapCat 适配器插件实例。"""
|
||||
super().__init__()
|
||||
self._plugin_config: Dict[str, Any] = {}
|
||||
self._connection_task: Optional[asyncio.Task[None]] = None
|
||||
self._pending_actions: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
|
||||
self._background_tasks: set[asyncio.Task[Any]] = set()
|
||||
self._send_lock = asyncio.Lock()
|
||||
self._ws: Optional[AiohttpClientWebSocketResponse] = None
|
||||
|
||||
def set_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""设置插件配置内容。
|
||||
|
||||
Args:
|
||||
config: Runner 注入的 ``config.toml`` 解析结果。
|
||||
"""
|
||||
self._plugin_config = config if isinstance(config, dict) else {}
|
||||
|
||||
async def on_load(self) -> None:
|
||||
"""在插件加载时根据配置决定是否启动连接。"""
|
||||
await self._restart_connection_if_needed()
|
||||
|
||||
async def on_unload(self) -> None:
|
||||
"""在插件卸载时关闭连接并清理后台任务。"""
|
||||
await self._stop_connection()
|
||||
await self._cancel_background_tasks()
|
||||
|
||||
async def on_config_update(self, new_config: Dict[str, Any], version: str) -> None:
|
||||
"""在配置更新后重载连接状态。
|
||||
|
||||
Args:
|
||||
new_config: 最新的插件配置。
|
||||
version: 配置版本号。
|
||||
"""
|
||||
del version
|
||||
self.set_plugin_config(new_config)
|
||||
await self._restart_connection_if_needed()
|
||||
|
||||
async def send_to_platform(
|
||||
self,
|
||||
message: Dict[str, Any],
|
||||
route: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""将 Host 出站消息发送到 NapCat。
|
||||
|
||||
Args:
|
||||
message: Host 侧标准 ``MessageDict``。
|
||||
route: Platform IO 生成的路由信息。
|
||||
metadata: Platform IO 附带的投递元数据。
|
||||
**kwargs: 预留的扩展参数。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 标准化后的发送结果。
|
||||
"""
|
||||
del metadata
|
||||
del kwargs
|
||||
|
||||
ws = self._ws
|
||||
if ws is None or ws.closed:
|
||||
return {"success": False, "error": "NapCat is not connected"}
|
||||
|
||||
try:
|
||||
action_name, params = self._build_outbound_action(message, route or {})
|
||||
response = await self._call_action(action_name, params)
|
||||
except Exception as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
if str(response.get("status", "")).lower() != "ok":
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(response.get("wording") or response.get("message") or "NapCat send failed"),
|
||||
"metadata": {"retcode": response.get("retcode")},
|
||||
}
|
||||
|
||||
response_data = response.get("data", {})
|
||||
external_message_id = ""
|
||||
if isinstance(response_data, dict):
|
||||
external_message_id = str(response_data.get("message_id") or "")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"external_message_id": external_message_id or None,
|
||||
"metadata": {"action": action_name},
|
||||
}
|
||||
|
||||
async def _restart_connection_if_needed(self) -> None:
|
||||
"""根据当前配置重启连接循环。"""
|
||||
await self._stop_connection()
|
||||
if not self._should_connect():
|
||||
self.ctx.logger.info("NapCat 适配器保持空闲状态,因为插件或配置未启用")
|
||||
return
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
self.ctx.logger.error("NapCat 适配器依赖 aiohttp,但当前环境未安装该依赖")
|
||||
return
|
||||
self._connection_task = asyncio.create_task(self._connection_loop(), name="napcat_adapter.connection")
|
||||
|
||||
async def _stop_connection(self) -> None:
|
||||
"""停止当前连接并让所有等待中的动作失败返回。"""
|
||||
connection_task = self._connection_task
|
||||
self._connection_task = None
|
||||
|
||||
ws = self._ws
|
||||
if ws is not None and not ws.closed:
|
||||
with contextlib.suppress(Exception):
|
||||
await ws.close()
|
||||
self._ws = None
|
||||
|
||||
if connection_task is not None:
|
||||
connection_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await connection_task
|
||||
|
||||
self._fail_pending_actions("NapCat connection closed")
|
||||
|
||||
async def _cancel_background_tasks(self) -> None:
|
||||
"""取消所有仍在运行的入站后台任务。"""
|
||||
background_tasks = list(self._background_tasks)
|
||||
for task in background_tasks:
|
||||
task.cancel()
|
||||
if background_tasks:
|
||||
with contextlib.suppress(Exception):
|
||||
await asyncio.gather(*background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def _connection_loop(self) -> None:
|
||||
"""维护单个 WebSocket 连接,并在断开后按配置重连。"""
|
||||
assert ClientSession is not None
|
||||
assert ClientTimeout is not None
|
||||
|
||||
while self._should_connect():
|
||||
ws_url = self._get_string(self._connection_config(), "ws_url")
|
||||
if not ws_url:
|
||||
self.ctx.logger.warning("NapCat 适配器已启用,但 connection.ws_url 为空")
|
||||
return
|
||||
|
||||
headers = self._build_headers()
|
||||
timeout = ClientTimeout(total=None, connect=10)
|
||||
heartbeat = self._get_positive_float(self._connection_config(), "heartbeat_sec", 30.0)
|
||||
|
||||
try:
|
||||
async with ClientSession(headers=headers, timeout=timeout) as session:
|
||||
async with session.ws_connect(ws_url, heartbeat=heartbeat or None) as ws:
|
||||
self._ws = ws
|
||||
self.ctx.logger.info(f"NapCat 适配器已连接: {ws_url}")
|
||||
await self._receive_loop(ws)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
self.ctx.logger.warning(f"NapCat 适配器连接失败: {exc}")
|
||||
finally:
|
||||
self._ws = None
|
||||
self._fail_pending_actions("NapCat connection interrupted")
|
||||
|
||||
if not self._should_connect():
|
||||
break
|
||||
|
||||
await asyncio.sleep(self._get_positive_float(self._connection_config(), "reconnect_delay_sec", 5.0))
|
||||
|
||||
async def _receive_loop(self, ws: AiohttpClientWebSocketResponse) -> None:
|
||||
"""持续消费 WebSocket 消息并分发处理。
|
||||
|
||||
Args:
|
||||
ws: 当前活跃的 WebSocket 连接对象。
|
||||
"""
|
||||
assert WSMsgType is not None
|
||||
|
||||
async for ws_message in ws:
|
||||
if ws_message.type != WSMsgType.TEXT:
|
||||
if ws_message.type in {WSMsgType.CLOSE, WSMsgType.CLOSED, WSMsgType.ERROR}:
|
||||
break
|
||||
continue
|
||||
|
||||
payload = self._parse_json_message(ws_message.data)
|
||||
if payload is None:
|
||||
continue
|
||||
|
||||
if echo_id := str(payload.get("echo") or "").strip():
|
||||
self._resolve_pending_action(echo_id, payload)
|
||||
continue
|
||||
|
||||
if str(payload.get("post_type") or "").strip() != "message":
|
||||
continue
|
||||
|
||||
task = asyncio.create_task(self._handle_inbound_message(payload), name="napcat_adapter.inbound")
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _handle_inbound_message(self, payload: Dict[str, Any]) -> None:
|
||||
"""处理单条 NapCat 入站消息并注入 Host。
|
||||
|
||||
Args:
|
||||
payload: NapCat / OneBot 推送的原始事件数据。
|
||||
"""
|
||||
self_id = str(payload.get("self_id") or "").strip()
|
||||
sender = payload.get("sender", {})
|
||||
if not isinstance(sender, dict):
|
||||
sender = {}
|
||||
|
||||
sender_user_id = str(payload.get("user_id") or sender.get("user_id") or "").strip()
|
||||
if not sender_user_id:
|
||||
return
|
||||
|
||||
if self_id and sender_user_id == self_id and self._get_bool(self._filters_config(), "ignore_self_message", True):
|
||||
return
|
||||
|
||||
message_dict = self._build_inbound_message_dict(payload, self_id, sender_user_id, sender)
|
||||
route_metadata: Dict[str, Any] = {}
|
||||
if self_id:
|
||||
route_metadata["self_id"] = self_id
|
||||
if connection_id := self._get_string(self._connection_config(), "connection_id"):
|
||||
route_metadata["connection_id"] = connection_id
|
||||
|
||||
external_message_id = str(payload.get("message_id") or "").strip()
|
||||
accepted = await self.ctx.adapter.receive_external_message(
|
||||
message_dict,
|
||||
route_metadata=route_metadata,
|
||||
external_message_id=external_message_id,
|
||||
dedupe_key=external_message_id,
|
||||
)
|
||||
if not accepted:
|
||||
self.ctx.logger.debug(f"Host 丢弃了 NapCat 入站消息: {external_message_id or '无消息 ID'}")
|
||||
|
||||
def _build_inbound_message_dict(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
self_id: str,
|
||||
sender_user_id: str,
|
||||
sender: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""构造 Host 侧可接受的 ``MessageDict``。
|
||||
|
||||
Args:
|
||||
payload: NapCat 原始消息事件。
|
||||
self_id: 当前机器人账号 ID。
|
||||
sender_user_id: 发送者用户 ID。
|
||||
sender: 发送者信息字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 规范化后的 ``MessageDict``。
|
||||
"""
|
||||
message_type = str(payload.get("message_type") or "").strip() or "private"
|
||||
group_id = str(payload.get("group_id") or "").strip()
|
||||
group_name = str(payload.get("group_name") or "").strip() or (f"group_{group_id}" if group_id else "")
|
||||
user_nickname = str(sender.get("nickname") or sender.get("card") or sender_user_id).strip() or sender_user_id
|
||||
user_cardname = str(sender.get("card") or "").strip() or None
|
||||
|
||||
raw_message, is_at = self._convert_inbound_segments(payload.get("message"), self_id)
|
||||
raw_message_text = str(payload.get("raw_message") or "").strip()
|
||||
if not raw_message:
|
||||
raw_message = [{"type": "text", "data": raw_message_text or "[unsupported]"}]
|
||||
|
||||
plain_text = self._build_plain_text(raw_message, raw_message_text)
|
||||
timestamp_seconds = payload.get("time")
|
||||
if not isinstance(timestamp_seconds, (int, float)):
|
||||
timestamp_seconds = time.time()
|
||||
|
||||
additional_config: Dict[str, Any] = {"self_id": self_id, "napcat_message_type": message_type}
|
||||
if group_id:
|
||||
additional_config["platform_io_target_group_id"] = group_id
|
||||
else:
|
||||
additional_config["platform_io_target_user_id"] = sender_user_id
|
||||
|
||||
message_info: Dict[str, Any] = {
|
||||
"user_info": {
|
||||
"user_id": sender_user_id,
|
||||
"user_nickname": user_nickname,
|
||||
"user_cardname": user_cardname,
|
||||
},
|
||||
"additional_config": additional_config,
|
||||
}
|
||||
if group_id:
|
||||
message_info["group_info"] = {"group_id": group_id, "group_name": group_name}
|
||||
|
||||
message_id = str(payload.get("message_id") or f"napcat-{uuid4().hex}").strip()
|
||||
return {
|
||||
"message_id": message_id,
|
||||
"timestamp": str(float(timestamp_seconds)),
|
||||
"platform": "qq",
|
||||
"message_info": message_info,
|
||||
"raw_message": raw_message,
|
||||
"is_mentioned": is_at,
|
||||
"is_at": is_at,
|
||||
"is_emoji": False,
|
||||
"is_picture": False,
|
||||
"is_command": plain_text.startswith("/"),
|
||||
"is_notify": False,
|
||||
"session_id": "",
|
||||
"processed_plain_text": plain_text,
|
||||
"display_message": plain_text,
|
||||
}
|
||||
|
||||
def _convert_inbound_segments(self, message_payload: Any, self_id: str) -> tuple[List[Dict[str, Any]], bool]:
|
||||
"""将 OneBot 消息段转换为 Host 消息段结构。
|
||||
|
||||
Args:
|
||||
message_payload: OneBot 原始 ``message`` 字段。
|
||||
self_id: 当前机器人账号 ID。
|
||||
|
||||
Returns:
|
||||
tuple[List[Dict[str, Any]], bool]: 转换后的消息段列表,以及是否 @ 到当前机器人。
|
||||
"""
|
||||
if isinstance(message_payload, str):
|
||||
normalized_text = message_payload.strip()
|
||||
return ([{"type": "text", "data": normalized_text}] if normalized_text else []), False
|
||||
|
||||
if not isinstance(message_payload, list):
|
||||
return [], False
|
||||
|
||||
converted_segments: List[Dict[str, Any]] = []
|
||||
is_at = False
|
||||
placeholder_texts = {
|
||||
"face": "[face]",
|
||||
"file": "[file]",
|
||||
"image": "[image]",
|
||||
"json": "[json]",
|
||||
"record": "[voice]",
|
||||
"video": "[video]",
|
||||
"xml": "[xml]",
|
||||
}
|
||||
|
||||
for segment in message_payload:
|
||||
if not isinstance(segment, dict):
|
||||
continue
|
||||
|
||||
segment_type = str(segment.get("type") or "").strip()
|
||||
segment_data = segment.get("data", {})
|
||||
if not isinstance(segment_data, dict):
|
||||
segment_data = {}
|
||||
|
||||
if segment_type == "text":
|
||||
if text_value := str(segment_data.get("text") or ""):
|
||||
converted_segments.append({"type": "text", "data": text_value})
|
||||
continue
|
||||
|
||||
if segment_type == "at":
|
||||
if target_user_id := str(segment_data.get("qq") or "").strip():
|
||||
converted_segments.append(
|
||||
{
|
||||
"type": "at",
|
||||
"data": {
|
||||
"target_user_id": target_user_id,
|
||||
"target_user_nickname": None,
|
||||
"target_user_cardname": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
if self_id and target_user_id == self_id:
|
||||
is_at = True
|
||||
continue
|
||||
|
||||
if segment_type == "reply":
|
||||
if target_message_id := str(segment_data.get("id") or "").strip():
|
||||
converted_segments.append({"type": "reply", "data": target_message_id})
|
||||
continue
|
||||
|
||||
if placeholder := placeholder_texts.get(segment_type):
|
||||
converted_segments.append({"type": "text", "data": placeholder})
|
||||
|
||||
return converted_segments, is_at
|
||||
|
||||
def _build_outbound_action(
|
||||
self,
|
||||
message: Dict[str, Any],
|
||||
route: Dict[str, Any],
|
||||
) -> tuple[str, Dict[str, Any]]:
|
||||
"""为 Host 出站消息构造 OneBot 动作。
|
||||
|
||||
Args:
|
||||
message: Host 侧标准 ``MessageDict``。
|
||||
route: Platform IO 路由信息。
|
||||
|
||||
Returns:
|
||||
tuple[str, Dict[str, Any]]: 动作名称与参数字典。
|
||||
"""
|
||||
message_info = message.get("message_info", {})
|
||||
if not isinstance(message_info, dict):
|
||||
message_info = {}
|
||||
|
||||
group_info = message_info.get("group_info", {})
|
||||
if not isinstance(group_info, dict):
|
||||
group_info = {}
|
||||
|
||||
additional_config = message_info.get("additional_config", {})
|
||||
if not isinstance(additional_config, dict):
|
||||
additional_config = {}
|
||||
|
||||
raw_message = message.get("raw_message", [])
|
||||
segments = self._convert_outbound_segments(raw_message)
|
||||
|
||||
if target_group_id := str(
|
||||
group_info.get("group_id") or additional_config.get("platform_io_target_group_id") or ""
|
||||
).strip():
|
||||
return "send_group_msg", {"group_id": target_group_id, "message": segments}
|
||||
|
||||
if not (
|
||||
target_user_id := str(
|
||||
additional_config.get("platform_io_target_user_id")
|
||||
or additional_config.get("target_user_id")
|
||||
or route.get("target_user_id")
|
||||
or ""
|
||||
).strip()
|
||||
):
|
||||
raise ValueError("Outbound private message is missing target_user_id")
|
||||
|
||||
return "send_private_msg", {"message": segments, "user_id": target_user_id}
|
||||
|
||||
def _convert_outbound_segments(self, raw_message: Any) -> List[Dict[str, Any]]:
|
||||
"""将 Host 消息段转换为 OneBot 消息段。
|
||||
|
||||
Args:
|
||||
raw_message: Host 侧 ``raw_message`` 字段。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: OneBot 消息段列表。
|
||||
"""
|
||||
if not isinstance(raw_message, list):
|
||||
return [{"type": "text", "data": {"text": ""}}]
|
||||
|
||||
outbound_segments: List[Dict[str, Any]] = []
|
||||
for item in raw_message:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
item_type = str(item.get("type") or "").strip()
|
||||
item_data = item.get("data")
|
||||
|
||||
if item_type == "text":
|
||||
text_value = str(item_data or "")
|
||||
outbound_segments.append({"type": "text", "data": {"text": text_value}})
|
||||
continue
|
||||
|
||||
if item_type == "at" and isinstance(item_data, dict):
|
||||
if target_user_id := str(item_data.get("target_user_id") or "").strip():
|
||||
outbound_segments.append({"type": "at", "data": {"qq": target_user_id}})
|
||||
continue
|
||||
|
||||
if item_type == "reply":
|
||||
if target_message_id := str(item_data or "").strip():
|
||||
outbound_segments.append({"type": "reply", "data": {"id": target_message_id}})
|
||||
continue
|
||||
|
||||
fallback_text = f"[unsupported:{item_type or 'unknown'}]"
|
||||
outbound_segments.append({"type": "text", "data": {"text": fallback_text}})
|
||||
|
||||
if not outbound_segments:
|
||||
outbound_segments.append({"type": "text", "data": {"text": ""}})
|
||||
return outbound_segments
|
||||
|
||||
async def _call_action(self, action_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送 OneBot 动作并等待对应的 echo 响应。
|
||||
|
||||
Args:
|
||||
action_name: OneBot 动作名称。
|
||||
params: 动作参数。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: NapCat 返回的原始响应字典。
|
||||
"""
|
||||
ws = self._ws
|
||||
if ws is None or ws.closed:
|
||||
raise RuntimeError("NapCat is not connected")
|
||||
|
||||
echo_id = uuid4().hex
|
||||
loop = asyncio.get_running_loop()
|
||||
response_future: asyncio.Future[Dict[str, Any]] = loop.create_future()
|
||||
self._pending_actions[echo_id] = response_future
|
||||
|
||||
request_payload = {"action": action_name, "params": params, "echo": echo_id}
|
||||
try:
|
||||
async with self._send_lock:
|
||||
await ws.send_str(json.dumps(request_payload, ensure_ascii=False))
|
||||
timeout_seconds = self._get_positive_float(self._connection_config(), "action_timeout_sec", 15.0)
|
||||
return await asyncio.wait_for(response_future, timeout=timeout_seconds)
|
||||
finally:
|
||||
self._pending_actions.pop(echo_id, None)
|
||||
|
||||
def _resolve_pending_action(self, echo_id: str, payload: Dict[str, Any]) -> None:
|
||||
"""解析等待中的动作响应。
|
||||
|
||||
Args:
|
||||
echo_id: 动作请求对应的 echo 标识。
|
||||
payload: NapCat 返回的响应载荷。
|
||||
"""
|
||||
response_future = self._pending_actions.get(echo_id)
|
||||
if response_future is None or response_future.done():
|
||||
return
|
||||
response_future.set_result(payload)
|
||||
|
||||
def _fail_pending_actions(self, error_message: str) -> None:
|
||||
"""让所有等待中的动作以异常方式结束。
|
||||
|
||||
Args:
|
||||
error_message: 写入异常中的错误信息。
|
||||
"""
|
||||
for response_future in self._pending_actions.values():
|
||||
if not response_future.done():
|
||||
response_future.set_exception(RuntimeError(error_message))
|
||||
self._pending_actions.clear()
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构造连接 NapCat 所需的请求头。
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: WebSocket 握手请求头。
|
||||
"""
|
||||
access_token = self._get_string(self._connection_config(), "access_token")
|
||||
return {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||
|
||||
def _parse_json_message(self, data: Any) -> Optional[Dict[str, Any]]:
|
||||
"""解析 WebSocket 文本消息中的 JSON 数据。
|
||||
|
||||
Args:
|
||||
data: WebSocket 收到的原始文本数据。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 成功时返回字典,失败时返回 ``None``。
|
||||
"""
|
||||
try:
|
||||
payload = json.loads(str(data))
|
||||
except Exception as exc:
|
||||
self.ctx.logger.warning(f"NapCat 适配器解析 JSON 载荷失败: {exc}")
|
||||
return None
|
||||
|
||||
return payload if isinstance(payload, dict) else None
|
||||
|
||||
def _build_plain_text(self, raw_message: List[Dict[str, Any]], fallback_text: str) -> str:
|
||||
"""从标准消息段中提取可展示的纯文本。
|
||||
|
||||
Args:
|
||||
raw_message: 标准化后的消息段列表。
|
||||
fallback_text: 当无法拼出文本时使用的回退文本。
|
||||
|
||||
Returns:
|
||||
str: 用于 Host 展示和命令判断的纯文本内容。
|
||||
"""
|
||||
plain_text_parts: List[str] = []
|
||||
for item in raw_message:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = str(item.get("type") or "").strip()
|
||||
item_data = item.get("data")
|
||||
if item_type == "text":
|
||||
plain_text_parts.append(str(item_data or ""))
|
||||
elif item_type == "at" and isinstance(item_data, dict):
|
||||
plain_text_parts.append(f"@{item_data.get('target_user_id') or ''}")
|
||||
elif item_type == "reply":
|
||||
plain_text_parts.append("[reply]")
|
||||
|
||||
plain_text = "".join(part for part in plain_text_parts if part).strip()
|
||||
return plain_text or fallback_text or "[unsupported]"
|
||||
|
||||
def _plugin_section(self) -> Dict[str, Any]:
|
||||
"""读取插件配置中的 ``plugin`` 段。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: ``plugin`` 配置字典。
|
||||
"""
|
||||
plugin_section = self._plugin_config.get("plugin", {})
|
||||
return plugin_section if isinstance(plugin_section, dict) else {}
|
||||
|
||||
def _connection_config(self) -> Dict[str, Any]:
|
||||
"""读取插件配置中的 ``connection`` 段。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: ``connection`` 配置字典。
|
||||
"""
|
||||
connection_config = self._plugin_config.get("connection", {})
|
||||
return connection_config if isinstance(connection_config, dict) else {}
|
||||
|
||||
def _filters_config(self) -> Dict[str, Any]:
|
||||
"""读取插件配置中的 ``filters`` 段。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: ``filters`` 配置字典。
|
||||
"""
|
||||
filters_config = self._plugin_config.get("filters", {})
|
||||
return filters_config if isinstance(filters_config, dict) else {}
|
||||
|
||||
def _should_connect(self) -> bool:
|
||||
"""判断当前配置下是否应当启动连接。
|
||||
|
||||
Returns:
|
||||
bool: 若启用了插件连接则返回 ``True``。
|
||||
"""
|
||||
return self._get_bool(self._plugin_section(), "enabled", False)
|
||||
|
||||
@staticmethod
|
||||
def _get_bool(mapping: Dict[str, Any], key: str, default: bool) -> bool:
|
||||
"""安全读取布尔配置值。
|
||||
|
||||
Args:
|
||||
mapping: 待读取的配置字典。
|
||||
key: 目标键名。
|
||||
default: 读取失败时的默认值。
|
||||
|
||||
Returns:
|
||||
bool: 解析后的布尔值。
|
||||
"""
|
||||
value = mapping.get(key, default)
|
||||
return value if isinstance(value, bool) else default
|
||||
|
||||
@staticmethod
|
||||
def _get_positive_float(mapping: Dict[str, Any], key: str, default: float) -> float:
|
||||
"""安全读取正浮点数配置值。
|
||||
|
||||
Args:
|
||||
mapping: 待读取的配置字典。
|
||||
key: 目标键名。
|
||||
default: 读取失败时的默认值。
|
||||
|
||||
Returns:
|
||||
float: 合法的正浮点数;否则返回默认值。
|
||||
"""
|
||||
value = mapping.get(key, default)
|
||||
if isinstance(value, (int, float)) and float(value) > 0:
|
||||
return float(value)
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def _get_string(mapping: Dict[str, Any], key: str) -> str:
|
||||
"""安全读取字符串配置值。
|
||||
|
||||
Args:
|
||||
mapping: 待读取的配置字典。
|
||||
key: 目标键名。
|
||||
|
||||
Returns:
|
||||
str: 去除首尾空白后的字符串值。
|
||||
"""
|
||||
value = mapping.get(key)
|
||||
return "" if value is None else str(value).strip()
|
||||
|
||||
|
||||
def create_plugin() -> NapCatAdapterPlugin:
|
||||
"""创建插件实例。
|
||||
|
||||
Returns:
|
||||
NapCatAdapterPlugin: NapCat 内置适配器插件实例。
|
||||
"""
|
||||
return NapCatAdapterPlugin()
|
||||
@@ -4,7 +4,7 @@
|
||||
提供发送各种类型消息的核心功能。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import time
|
||||
import traceback
|
||||
@@ -19,6 +19,7 @@ from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.common.data_models.message_component_data_model import DictComponent, MessageSequence
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
@@ -31,6 +32,50 @@ logger = get_logger("send_service")
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object]:
|
||||
"""从目标会话上下文继承 Platform IO 路由元数据。
|
||||
|
||||
Args:
|
||||
target_stream: 当前消息要发送到的会话对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, object]: 可安全透传到出站消息 ``additional_config`` 中的
|
||||
路由辅助字段。
|
||||
"""
|
||||
inherited_metadata: Dict[str, object] = {}
|
||||
|
||||
context = getattr(target_stream, "context", None)
|
||||
context_message = getattr(context, "message", None)
|
||||
if context_message is None:
|
||||
return inherited_metadata
|
||||
|
||||
additional_config = getattr(context_message.message_info, "additional_config", {})
|
||||
if not isinstance(additional_config, dict):
|
||||
return inherited_metadata
|
||||
|
||||
for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS):
|
||||
value = additional_config.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
normalized_value = str(value).strip()
|
||||
if normalized_value:
|
||||
inherited_metadata[key] = value
|
||||
|
||||
target_group_id = getattr(target_stream, "group_id", None)
|
||||
if target_group_id is not None:
|
||||
normalized_group_id = str(target_group_id).strip()
|
||||
if normalized_group_id:
|
||||
inherited_metadata["platform_io_target_group_id"] = normalized_group_id
|
||||
|
||||
target_user_id = getattr(target_stream, "user_id", None)
|
||||
if target_user_id is not None:
|
||||
normalized_user_id = str(target_user_id).strip()
|
||||
if normalized_user_id:
|
||||
inherited_metadata["platform_io_target_user_id"] = normalized_user_id
|
||||
|
||||
return inherited_metadata
|
||||
|
||||
|
||||
async def _send_to_target(
|
||||
message_segment: Seg,
|
||||
stream_id: str,
|
||||
@@ -42,7 +87,22 @@ async def _send_to_target(
|
||||
show_log: bool = True,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> bool:
|
||||
"""向指定目标发送消息的内部实现"""
|
||||
"""向指定目标发送消息。
|
||||
|
||||
Args:
|
||||
message_segment: 待发送的消息段。
|
||||
stream_id: 目标会话 ID。
|
||||
display_message: 用于界面展示的文本内容。
|
||||
typing: 是否显示输入中状态。
|
||||
set_reply: 是否在发送时附带引用回复。
|
||||
reply_message: 被回复的消息对象。
|
||||
storage_message: 是否将发送结果写入消息存储。
|
||||
show_log: 是否输出发送日志。
|
||||
selected_expressions: 可选的表情候选索引列表。
|
||||
|
||||
Returns:
|
||||
bool: 发送成功返回 ``True``,否则返回 ``False``。
|
||||
"""
|
||||
try:
|
||||
if set_reply and not reply_message:
|
||||
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
|
||||
@@ -80,7 +140,7 @@ async def _send_to_target(
|
||||
platform=target_stream.platform,
|
||||
)
|
||||
|
||||
additional_config: dict[str, object] = {}
|
||||
additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream)
|
||||
if selected_expressions is not None:
|
||||
additional_config["selected_expressions"] = selected_expressions
|
||||
bot_user_id = get_bot_account(target_stream.platform)
|
||||
|
||||
175
src/webui/routers/chat/serializers.py
Normal file
175
src/webui/routers/chat/serializers.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""提供 WebUI 聊天路由使用的消息序列化能力。"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import base64
|
||||
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
AtComponent,
|
||||
DictComponent,
|
||||
EmojiComponent,
|
||||
ForwardComponent,
|
||||
ForwardNodeComponent,
|
||||
ImageComponent,
|
||||
MessageSequence,
|
||||
ReplyComponent,
|
||||
StandardMessageComponents,
|
||||
TextComponent,
|
||||
VoiceComponent,
|
||||
)
|
||||
|
||||
|
||||
def serialize_message_sequence(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
|
||||
"""将内部统一消息组件序列转换为 WebUI 富文本消息段。
|
||||
|
||||
Args:
|
||||
message_sequence: 内部统一消息组件序列。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 可直接广播给 WebUI 前端的消息段列表。
|
||||
"""
|
||||
serialized_segments: List[Dict[str, Any]] = []
|
||||
for component in message_sequence.components:
|
||||
serialized_segment = serialize_message_component(component)
|
||||
if serialized_segment is not None:
|
||||
serialized_segments.append(serialized_segment)
|
||||
return serialized_segments
|
||||
|
||||
|
||||
def serialize_message_component(component: StandardMessageComponents) -> Optional[Dict[str, Any]]:
|
||||
"""将单个内部消息组件转换为 WebUI 消息段。
|
||||
|
||||
Args:
|
||||
component: 待序列化的内部消息组件。
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 序列化后的 WebUI 消息段;若组件不应展示则返回 ``None``。
|
||||
"""
|
||||
if isinstance(component, TextComponent):
|
||||
return {"type": "text", "data": component.text}
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
return _serialize_binary_component(
|
||||
segment_type="image",
|
||||
mime_type="image/png",
|
||||
binary_data=component.binary_data,
|
||||
fallback_text=component.content,
|
||||
)
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
return _serialize_binary_component(
|
||||
segment_type="emoji",
|
||||
mime_type="image/gif",
|
||||
binary_data=component.binary_data,
|
||||
fallback_text=component.content,
|
||||
)
|
||||
|
||||
if isinstance(component, VoiceComponent):
|
||||
return _serialize_binary_component(
|
||||
segment_type="voice",
|
||||
mime_type="audio/wav",
|
||||
binary_data=component.binary_data,
|
||||
fallback_text=component.content,
|
||||
)
|
||||
|
||||
if isinstance(component, AtComponent):
|
||||
return {
|
||||
"type": "at",
|
||||
"data": {
|
||||
"target_user_id": component.target_user_id,
|
||||
"target_user_nickname": component.target_user_nickname,
|
||||
"target_user_cardname": component.target_user_cardname,
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(component, ReplyComponent):
|
||||
return {
|
||||
"type": "reply",
|
||||
"data": {
|
||||
"target_message_id": component.target_message_id,
|
||||
"target_message_content": component.target_message_content,
|
||||
"target_message_sender_id": component.target_message_sender_id,
|
||||
"target_message_sender_nickname": component.target_message_sender_nickname,
|
||||
"target_message_sender_cardname": component.target_message_sender_cardname,
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(component, ForwardNodeComponent):
|
||||
return {
|
||||
"type": "forward",
|
||||
"data": [_serialize_forward_component(item) for item in component.forward_components],
|
||||
}
|
||||
|
||||
if isinstance(component, DictComponent):
|
||||
return _serialize_dict_component(component.data)
|
||||
|
||||
return {"type": "unknown", "data": str(component)}
|
||||
|
||||
|
||||
def _serialize_binary_component(
|
||||
segment_type: str,
|
||||
mime_type: str,
|
||||
binary_data: bytes,
|
||||
fallback_text: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""序列化带二进制负载的消息组件。
|
||||
|
||||
Args:
|
||||
segment_type: WebUI 消息段类型。
|
||||
mime_type: 对应的数据 MIME 类型。
|
||||
binary_data: 组件二进制数据。
|
||||
fallback_text: 二进制缺失时可退化展示的文本。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的 WebUI 消息段。
|
||||
"""
|
||||
if binary_data:
|
||||
encoded_payload = base64.b64encode(binary_data).decode()
|
||||
return {"type": segment_type, "data": f"data:{mime_type};base64,{encoded_payload}"}
|
||||
|
||||
if fallback_text:
|
||||
return {"type": "text", "data": fallback_text}
|
||||
|
||||
return {"type": "unknown", "original_type": segment_type, "data": ""}
|
||||
|
||||
|
||||
def _serialize_forward_component(component: ForwardComponent) -> Dict[str, Any]:
|
||||
"""序列化单个转发节点。
|
||||
|
||||
Args:
|
||||
component: 待序列化的转发节点组件。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: WebUI 可消费的转发节点字典。
|
||||
"""
|
||||
return {
|
||||
"message_id": component.message_id,
|
||||
"user_id": component.user_id,
|
||||
"user_nickname": component.user_nickname,
|
||||
"user_cardname": component.user_cardname,
|
||||
"content": serialize_message_sequence(MessageSequence(component.content)),
|
||||
}
|
||||
|
||||
|
||||
def _serialize_dict_component(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""最佳努力地序列化非标准字典组件。
|
||||
|
||||
Args:
|
||||
data: 原始字典组件内容。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的 WebUI 消息段。
|
||||
"""
|
||||
raw_type = str(data.get("type") or "dict").strip()
|
||||
raw_payload = data.get("data", data)
|
||||
|
||||
if raw_type in {"text", "image", "emoji", "voice", "video", "file", "music", "face"}:
|
||||
return {"type": raw_type, "data": raw_payload}
|
||||
|
||||
if raw_type == "reply":
|
||||
return {"type": "reply", "data": raw_payload}
|
||||
|
||||
if raw_type == "forward" and isinstance(raw_payload, list):
|
||||
return {"type": "forward", "data": raw_payload}
|
||||
|
||||
return {"type": "unknown", "original_type": raw_type, "data": raw_payload}
|
||||
Reference in New Issue
Block a user