From d07915eea04b6b60b6f15b462ac93528e3338869 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 11:38:46 +0800 Subject: [PATCH] Refactor message sending architecture and implement legacy driver support - Removed UniversalMessageSender from group_generator.py and private_generator.py. - Updated PlatformIOManager to manage legacy send drivers and ensure send pipeline readiness. - Enhanced LegacyPlatformDriver to utilize prepared messages for sending. - Refactored send_service to unify message sending logic and integrate with Platform IO. - Added regression tests for Platform IO legacy driver and send service functionality. --- pytests/test_platform_io_legacy_driver.py | 124 ++++ pytests/test_send_service.py | 141 ++++ src/chat/brain_chat/PFC/message_sender.py | 84 +-- .../message_receive/uni_message_sender.py | 63 +- src/chat/replyer/group_generator.py | 8 +- src/chat/replyer/private_generator.py | 8 +- src/common/message_server/server.py | 2 +- src/platform_io/drivers/legacy_driver.py | 51 +- src/platform_io/manager.py | 77 ++- src/plugin_runtime/integration.py | 2 +- src/services/send_service.py | 636 ++++++++++++++---- 11 files changed, 967 insertions(+), 229 deletions(-) create mode 100644 pytests/test_platform_io_legacy_driver.py create mode 100644 pytests/test_send_service.py diff --git a/pytests/test_platform_io_legacy_driver.py b/pytests/test_platform_io_legacy_driver.py new file mode 100644 index 00000000..2e94c1fc --- /dev/null +++ b/pytests/test_platform_io_legacy_driver.py @@ -0,0 +1,124 @@ +"""Platform IO legacy driver 回归测试。""" + +from typing import Any, Dict, Optional + +import pytest + +from src.chat.utils import utils as chat_utils +from src.chat.message_receive import uni_message_sender +from src.platform_io.drivers.base import PlatformIODriver +from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver +from src.platform_io.manager import PlatformIOManager +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteBinding, RouteKey + + +class _PluginDriver(PlatformIODriver): + """测试用插件发送驱动。""" + + def __init__(self, driver_id: str, platform: str) -> None: + """初始化测试驱动。 + + Args: + driver_id: 驱动 ID。 + platform: 负责的平台名称。 + """ + super().__init__( + DriverDescriptor( + driver_id=driver_id, + kind=DriverKind.PLUGIN, + platform=platform, + plugin_id="test.plugin", + ) + ) + + async def send_message( + self, + message: Any, + route_key: RouteKey, + metadata: Optional[Dict[str, Any]] = None, + ) -> DeliveryReceipt: + """返回一个固定成功回执。 + + Args: + message: 待发送消息。 + route_key: 当前路由键。 + metadata: 发送元数据。 + + Returns: + DeliveryReceipt: 固定成功回执。 + """ + del metadata + return DeliveryReceipt( + internal_message_id=str(message.message_id), + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) + + +@pytest.mark.asyncio +async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """没有显式发送路由时,应由 Platform IO 回退到 legacy driver。""" + manager = PlatformIOManager() + monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"}) + + try: + await manager.ensure_send_pipeline_ready() + + fallback_drivers = manager.resolve_drivers(RouteKey(platform="qq")) + assert [driver.driver_id for driver in fallback_drivers] == ["legacy.send.qq"] + + plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq") + await manager.add_driver(plugin_driver) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=plugin_driver.driver_id, + driver_kind=plugin_driver.descriptor.kind, + ) + ) + + explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq")) + assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender"] + finally: + await manager.stop() + + +@pytest.mark.asyncio +async def test_legacy_platform_driver_uses_prepared_universal_sender( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """legacy driver 应复用已预处理消息的旧链发送函数。""" + calls: list[dict[str, Any]] = [] + + async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool: + """记录 legacy driver 调用。""" + calls.append({"message": message, "show_log": show_log}) + return True + + monkeypatch.setattr( + uni_message_sender, + "send_prepared_message_to_platform", + _fake_send_prepared_message_to_platform, + ) + + driver = LegacyPlatformDriver( + driver_id="legacy.send.qq", + platform="qq", + account_id="bot-qq", + ) + message = type("FakeMessage", (), {"message_id": "message-1"})() + receipt = await driver.send_message( + message=message, + route_key=RouteKey(platform="qq"), + metadata={"show_log": False}, + ) + + assert len(calls) == 1 + assert calls[0]["message"] is message + assert calls[0]["show_log"] is False + assert receipt.status == DeliveryStatus.SENT + assert receipt.driver_id == "legacy.send.qq" diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py new file mode 100644 index 00000000..4ddd4fa1 --- /dev/null +++ b/pytests/test_send_service.py @@ -0,0 +1,141 @@ +"""发送服务回归测试。""" + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from src.chat.message_receive.chat_manager import BotChatSession +from src.services import send_service + + +class _FakePlatformIOManager: + """用于测试的 Platform IO 管理器假对象。""" + + def __init__(self, delivery_batch: Any) -> None: + """初始化假 Platform IO 管理器。 + + Args: + delivery_batch: 发送时返回的批量回执。 + """ + self._delivery_batch = delivery_batch + self.ensure_calls = 0 + self.sent_messages: List[Dict[str, Any]] = [] + + async def ensure_send_pipeline_ready(self) -> None: + """记录发送管线准备调用次数。""" + self.ensure_calls += 1 + + def build_route_key_from_message(self, message: Any) -> Any: + """根据消息构造假的路由键。 + + Args: + message: 待发送的内部消息对象。 + + Returns: + Any: 简化后的路由键对象。 + """ + del message + return SimpleNamespace(platform="qq") + + async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any: + """记录发送请求并返回预设回执。 + + Args: + message: 待发送的内部消息对象。 + route_key: 本次发送使用的路由键。 + metadata: 发送元数据。 + + Returns: + Any: 预设的批量发送回执。 + """ + self.sent_messages.append( + { + "message": message, + "route_key": route_key, + "metadata": metadata, + } + ) + return self._delivery_batch + + +def _build_target_stream() -> BotChatSession: + """构造一个最小可用的目标会话对象。 + + Returns: + BotChatSession: 测试用会话对象。 + """ + return BotChatSession( + session_id="test-session", + platform="qq", + user_id="target-user", + group_id=None, + ) + + +@pytest.mark.asyncio +async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None: + """send service 应将发送职责统一交给 Platform IO。""" + fake_manager = _FakePlatformIOManager( + delivery_batch=SimpleNamespace( + has_success=True, + sent_receipts=[SimpleNamespace(driver_id="plugin.qq.sender")], + failed_receipts=[], + route_key=SimpleNamespace(platform="qq"), + ) + ) + stored_messages: List[Any] = [] + + monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + ) + monkeypatch.setattr( + send_service.MessageUtils, + "store_message_to_db", + lambda message: stored_messages.append(message), + ) + + result = await send_service.text_to_stream(text="你好", stream_id="test-session") + + assert result is True + assert fake_manager.ensure_calls == 1 + assert len(fake_manager.sent_messages) == 1 + assert fake_manager.sent_messages[0]["metadata"] == {"show_log": False} + assert len(stored_messages) == 1 + + +@pytest.mark.asyncio +async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None: + """Platform IO 批量发送全部失败时,应直接向上返回失败。""" + fake_manager = _FakePlatformIOManager( + delivery_batch=SimpleNamespace( + has_success=False, + sent_receipts=[], + failed_receipts=[ + SimpleNamespace( + driver_id="plugin.qq.sender", + status="failed", + error="network error", + ) + ], + route_key=SimpleNamespace(platform="qq"), + ) + ) + + monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + ) + + result = await send_service.text_to_stream(text="发送失败", stream_id="test-session") + + assert result is False + assert fake_manager.ensure_calls == 1 + assert len(fake_manager.sent_messages) == 1 diff --git a/src/chat/brain_chat/PFC/message_sender.py b/src/chat/brain_chat/PFC/message_sender.py index ec5fb5ba..b9da905c 100644 --- a/src/chat/brain_chat/PFC/message_sender.py +++ b/src/chat/brain_chat/PFC/message_sender.py @@ -1,27 +1,28 @@ -import time +"""PFC 侧消息发送封装。""" + from typing import Optional -from maim_message import Seg from rich.traceback import install -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.message import MessageSending -from src.chat.message_receive.uni_message_sender import UniversalMessageSender -from src.chat.utils.utils import get_bot_account +from src.common.data_models.mai_message_data_model import MaiMessage from src.common.logger import get_logger -from src.config.config import global_config +from src.services import send_service as send_api install(extra_lines=3) - logger = get_logger("message_sender") class DirectMessageSender: - """直接消息发送器""" + """直接消息发送器。""" - def __init__(self, private_name: str): + def __init__(self, private_name: str) -> None: + """初始化直接消息发送器。 + + Args: + private_name: 当前私聊实例的名称。 + """ self.private_name = private_name async def send_message( @@ -30,58 +31,31 @@ class DirectMessageSender: content: str, reply_to_message: Optional[MaiMessage] = None, ) -> None: - """发送消息到聊天流 + """发送文本消息到聊天流。 Args: - chat_stream: 聊天会话 - content: 消息内容 - reply_to_message: 要回复的消息(可选) + chat_stream: 目标聊天会话。 + content: 待发送的文本内容。 + reply_to_message: 可选的引用回复锚点消息。 + + Raises: + RuntimeError: 当消息发送失败时抛出。 """ try: - # 创建消息内容 - segments = Seg(type="seglist", data=[Seg(type="text", data=content)]) - - # 获取麦麦的信息 - bot_user_id = get_bot_account(chat_stream.platform) - if not bot_user_id: - logger.error(f"[私聊][{self.private_name}]平台 {chat_stream.platform} 未配置机器人账号,无法发送消息") - raise RuntimeError(f"平台 {chat_stream.platform} 未配置机器人账号") - bot_user_info = UserInfo( - user_id=bot_user_id, - user_nickname=global_config.bot.nickname, + sent = await send_api.text_to_stream( + text=content, + stream_id=chat_stream.session_id, + set_reply=reply_to_message is not None, + reply_message=reply_to_message, + storage_message=True, ) - # 用当前时间作为message_id,和之前那套sender一样 - message_id = f"dm{round(time.time(), 2)}" - - # 构建发送者信息(私聊时为接收者) - sender_info = None - if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info: - sender_info = reply_to_message.message_info.user_info - - # 构建消息对象 - message = MessageSending( - message_id=message_id, - session=chat_stream, - bot_user_info=bot_user_info, - sender_info=sender_info, - message_segment=segments, - reply=reply_to_message, - is_head=True, - is_emoji=False, - thinking_start_time=time.time(), - ) - - # 发送消息 - message_sender = UniversalMessageSender() - sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True) - if sent: logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}") - else: - logger.error(f"[私聊][{self.private_name}]PFC消息发送失败") - raise RuntimeError("消息发送失败") + return - except Exception as e: - logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}") + logger.error(f"[私聊][{self.private_name}]PFC消息发送失败") + raise RuntimeError("消息发送失败") + except Exception as exc: + logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {exc}") raise diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index df74e459..cf42e092 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -60,8 +60,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: 发送顺序为: 1. WebUI 特殊链路 - 2. Platform IO 适配器链路 - 3. 旧版 ``maim_message`` / API Server 链路 + 2. 旧版 ``maim_message`` / API Server 链路 Args: message: 待发送的内部会话消息。 @@ -124,32 +123,6 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室") return True - try: - from src.plugin_runtime.integration import get_plugin_runtime_manager - - delivery_batch = await get_plugin_runtime_manager().try_send_message_via_platform_io(message) - if delivery_batch is not None: - if delivery_batch.has_success: - successful_driver_ids = [ - receipt.driver_id or "unknown" - for receipt in delivery_batch.sent_receipts - ] - if show_log: - logger.info( - f"已通过 Platform IO 将消息 '{message_preview}' 发往平台'{platform}' " - f"(drivers: {', '.join(successful_driver_ids)})" - ) - return True - - failed_details = "; ".join( - f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}" - for receipt in delivery_batch.failed_receipts - ) or "未命中任何发送路由" - logger.warning(f"Platform IO 发送失败: platform={platform} {failed_details}") - return False - except Exception as exc: - logger.warning(f"检查 Platform IO 出站链路时出现异常,将回退旧发送链: {exc}") - # Fallback 逻辑: 尝试通过 API Server 发送 async def send_with_new_api(legacy_exception: Optional[Exception] = None) -> bool: """通过 API Server 回退链路发送消息。 @@ -260,8 +233,21 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool: raise e # 重新抛出其他异常 +async def send_prepared_message_to_platform(message: SessionMessage, show_log: bool = True) -> bool: + """发送一条已完成预处理的消息到底层平台。 + + Args: + message: 已经完成回复组件注入、文本处理等预处理的消息对象。 + show_log: 是否输出发送成功日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ + return await _send_message(message, show_log=show_log) + + class UniversalMessageSender: - """管理消息的注册、即时处理、发送和存储,并跟踪思考状态。""" + """旧链与 WebUI 的底层发送器。""" def __init__(self) -> None: """初始化统一消息发送器。""" @@ -276,17 +262,18 @@ class UniversalMessageSender: storage_message: bool = True, show_log: bool = True, ) -> bool: - """ - 处理、发送并存储一条消息。 + """通过旧链或 WebUI 发送并存储一条消息。 - 参数: - message: MessageSession 对象,待发送的消息。 + Args: + message: 待发送的内部消息对象。 typing: 是否模拟打字等待。 - set_reply: 是否构建回复引用消息。 + set_reply: 是否构建引用回复消息。 + reply_message_id: 被引用消息的 ID。 + storage_message: 是否在发送成功后写入数据库。 + show_log: 是否输出发送日志。 - - 用法: - - typing=True 时,发送前会有打字等待。 + Returns: + bool: 发送成功时返回 ``True``。 """ if not message.message_id: logger.error("消息缺少 message_id,无法发送") @@ -339,7 +326,7 @@ class UniversalMessageSender: ) await asyncio.sleep(typing_time) - sent_msg = await _send_message(message, show_log=show_log) + sent_msg = await send_prepared_message_to_platform(message, show_log=show_log) if not sent_msg: return False diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 75563df7..4ffa14a7 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -17,7 +17,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser from src.common.data_models.mai_message_data_model import MaiMessage from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self from src.prompt.prompt_manager import prompt_manager @@ -51,10 +50,15 @@ class DefaultReplyer: chat_stream: BotChatSession, request_type: str = "replyer", ): + """初始化群聊回复器。 + + Args: + chat_stream: 当前绑定的聊天会话。 + request_type: LLM 请求类型标识。 + """ self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) - self.heart_fc_sender = UniversalMessageSender() from src.chat.tool_executor import ToolExecutor diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index f642dd69..c125a42f 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -16,7 +16,6 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo as MaimUser from src.common.data_models.mai_message_data_model import MaiMessage from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self from src.prompt.prompt_manager import prompt_manager @@ -47,10 +46,15 @@ class PrivateReplyer: chat_stream: BotChatSession, request_type: str = "replyer", ): + """初始化私聊回复器。 + + Args: + chat_stream: 当前绑定的聊天会话。 + request_type: LLM 请求类型标识。 + """ self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) - self.heart_fc_sender = UniversalMessageSender() # self.memory_activator = MemoryActivator() from src.chat.tool_executor import ToolExecutor diff --git a/src/common/message_server/server.py b/src/common/message_server/server.py index 77a931e5..e75da4e7 100644 --- a/src/common/message_server/server.py +++ b/src/common/message_server/server.py @@ -21,7 +21,7 @@ class Server: self._server: Optional[UvicornServer] = None self.set_address(host, port) - def register_router(self, router: APIRouter, prefix: str = ""): + def register_router(self, router: APIRouter, prefix: str = ""): """注册路由 APIRouter 用于对相关的路由端点进行分组和模块化管理: diff --git a/src/platform_io/drivers/legacy_driver.py b/src/platform_io/drivers/legacy_driver.py index bd74d8c7..ef90c772 100644 --- a/src/platform_io/drivers/legacy_driver.py +++ b/src/platform_io/drivers/legacy_driver.py @@ -1,16 +1,16 @@ -"""提供 Platform IO 的 legacy 传输驱动骨架。""" +"""提供 Platform IO 的 legacy 传输驱动实现。""" from typing import TYPE_CHECKING, Any, Dict, Optional from src.platform_io.drivers.base import PlatformIODriver -from src.platform_io.types import DeliveryReceipt, DriverDescriptor, DriverKind, RouteKey +from src.platform_io.types import DeliveryReceipt, DeliveryStatus, DriverDescriptor, DriverKind, RouteKey if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage class LegacyPlatformDriver(PlatformIODriver): - """面向 ``maim_message`` 旧链路的 Platform IO 驱动骨架。""" + """面向 ``UniversalMessageSender`` 旧链的 Platform IO 驱动。""" def __init__( self, @@ -25,7 +25,7 @@ class LegacyPlatformDriver(PlatformIODriver): Args: driver_id: Broker 内的唯一驱动 ID。 platform: 该 legacy 适配器链路负责的平台。 - account_id: 可选的账号 ID 或 self ID。 + account_id: 可选的账号 ID。 scope: 可选的额外路由作用域。 metadata: 可选的额外驱动元数据。 """ @@ -45,7 +45,7 @@ class LegacyPlatformDriver(PlatformIODriver): route_key: RouteKey, metadata: Optional[Dict[str, Any]] = None, ) -> DeliveryReceipt: - """通过 legacy 传输路径发送消息。 + """通过旧链发送一条已经过预处理的消息。 Args: message: 要投递的内部会话消息。 @@ -53,9 +53,40 @@ class LegacyPlatformDriver(PlatformIODriver): metadata: 本次出站投递可选的 Broker 侧元数据。 Returns: - DeliveryReceipt: 由驱动返回的规范化回执。 - - Raises: - NotImplementedError: 当前仍处于骨架阶段,尚未真正接入旧发送链。 + DeliveryReceipt: 规范化后的发送回执。 """ - raise NotImplementedError("LegacyPlatformDriver 仅完成地基实现,尚未接入旧发送链") + from src.chat.message_receive.uni_message_sender import send_prepared_message_to_platform + + show_log = False + if isinstance(metadata, dict): + show_log = bool(metadata.get("show_log", False)) + + try: + sent = await send_prepared_message_to_platform(message, show_log=show_log) + except Exception as exc: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error=str(exc), + ) + + if not sent: + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.FAILED, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + error="旧链发送失败", + ) + + return DeliveryReceipt( + internal_message_id=message.message_id, + route_key=route_key, + status=DeliveryStatus.SENT, + driver_id=self.driver_id, + driver_kind=self.descriptor.kind, + ) diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py index c96a9ddd..cb5996b4 100644 --- a/src/platform_io/manager.py +++ b/src/platform_io/manager.py @@ -36,6 +36,7 @@ class PlatformIOManager: self._driver_registry = DriverRegistry() self._send_route_table = RouteTable() self._receive_route_table = RouteTable() + self._legacy_send_drivers: Dict[str, PlatformIODriver] = {} self._deduplicator = MessageDeduplicator() self._outbound_tracker = OutboundTracker() self._inbound_dispatcher: Optional[InboundDispatcher] = None @@ -75,6 +76,16 @@ class PlatformIOManager: self._started = True + async def ensure_send_pipeline_ready(self) -> None: + """确保出站发送管线已准备就绪。 + + 该方法会先同步 legacy fallback driver,再在需要时启动 Broker。 + send service 应只调用这一层准备入口,而不是自行判断旧链或插件链。 + """ + await self._sync_legacy_send_drivers() + if not self._started: + await self.start() + async def stop(self) -> None: """停止 Broker,并按逆序停止全部已注册驱动。 @@ -272,8 +283,60 @@ class PlatformIOManager: removed_driver.clear_inbound_handler() self._send_route_table.remove_bindings_by_driver(driver_id) self._receive_route_table.remove_bindings_by_driver(driver_id) + self._legacy_send_drivers = { + platform: driver + for platform, driver in self._legacy_send_drivers.items() + if driver.driver_id != driver_id + } return removed_driver + async def _sync_legacy_send_drivers(self) -> None: + """根据当前配置同步 legacy fallback driver。""" + from src.chat.utils.utils import get_all_bot_accounts + from src.platform_io.drivers.legacy_driver import LegacyPlatformDriver + + desired_accounts = get_all_bot_accounts() + desired_platforms = set(desired_accounts.keys()) + current_platforms = set(self._legacy_send_drivers.keys()) + + for platform in sorted(current_platforms - desired_platforms): + await self._remove_legacy_send_driver(platform) + + for platform, account_id in desired_accounts.items(): + existing_driver = self._legacy_send_drivers.get(platform) + if existing_driver is not None and existing_driver.descriptor.account_id == account_id: + continue + + if existing_driver is not None: + await self._remove_legacy_send_driver(platform) + + driver = LegacyPlatformDriver( + driver_id=f"legacy.send.{platform}", + platform=platform, + account_id=account_id, + ) + if self._started: + await self.add_driver(driver) + else: + self.register_driver(driver) + self._legacy_send_drivers[platform] = driver + + async def _remove_legacy_send_driver(self, platform: str) -> None: + """移除指定平台的 legacy fallback driver。 + + Args: + platform: 要移除的目标平台。 + """ + driver = self._legacy_send_drivers.get(platform) + if driver is None: + return + + if self._started: + await self.remove_driver(driver.driver_id) + else: + self.unregister_driver(driver.driver_id) + self._legacy_send_drivers.pop(platform, None) + def bind_send_route(self, binding: RouteBinding) -> None: """为某个路由键绑定发送驱动。 @@ -353,7 +416,19 @@ class PlatformIOManager: driver = self._driver_registry.get(binding.driver_id) if driver is not None: drivers.append(driver) - return drivers + if drivers: + return drivers + + fallback_driver = self._legacy_send_drivers.get(route_key.platform) + if fallback_driver is None: + return [] + + descriptor = fallback_driver.descriptor + if descriptor.account_id is not None and route_key.account_id not in (None, descriptor.account_id): + return [] + if descriptor.scope is not None and route_key.scope not in (None, descriptor.scope): + return [] + return [fallback_driver] def resolve_driver(self, route_key: RouteKey) -> Optional[PlatformIODriver]: """兼容旧接口,返回首个命中的发送驱动。""" diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index b74b2d46..ff51f419 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -157,7 +157,7 @@ class PluginRuntimeManager( started_supervisors: List[PluginSupervisor] = [] try: platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound) - await platform_io_manager.start() + await platform_io_manager.ensure_send_pipeline_ready() if self._builtin_supervisor: await self._builtin_supervisor.start() diff --git a/src/services/send_service.py b/src/services/send_service.py index 6ca7d005..7903cdeb 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -1,39 +1,51 @@ """ -发送服务模块 +发送服务模块。 -提供发送各种类型消息的核心功能。 +统一封装内部模块的出站消息发送逻辑: + +1. 内部模块统一调用本模块。 +2. send service 只负责构造和预处理消息。 +3. 具体走插件链还是 legacy 旧链,由 Platform IO 内部统一决策。 """ -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import Any, Dict, List, Optional +from maim_message import Seg + +import asyncio +import base64 +import hashlib import time import traceback +from datetime import datetime -from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo - +from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.message import SessionMessage -from src.chat.message_receive.uni_message_sender import UniversalMessageSender -from src.chat.utils.utils import get_bot_account -from src.common.data_models.mai_message_data_model import MaiMessage -from src.common.data_models.message_component_data_model import DictComponent, MessageSequence +from src.chat.utils.utils import calculate_typing_time, get_bot_account +from src.common.data_models.mai_message_data_model import GroupInfo, MaiMessage, MessageInfo, UserInfo +from src.common.data_models.message_component_data_model import ( + AtComponent, + DictComponent, + EmojiComponent, + ImageComponent, + MessageSequence, + ReplyComponent, + StandardMessageComponents, + TextComponent, + VoiceComponent, +) from src.common.logger import get_logger +from src.common.utils.utils_message import MessageUtils from src.config.config import global_config +from src.platform_io import DeliveryBatch, get_platform_io_manager from src.platform_io.route_key_factory import RouteKeyFactory -if TYPE_CHECKING: - from src.chat.message_receive.message import SessionMessage - logger = get_logger("send_service") -# ============================================================================= -# 内部实现函数 -# ============================================================================= - - -def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object]: - """从目标会话上下文继承 Platform IO 路由元数据。 +def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[str, object]: + """从目标会话继承 Platform IO 路由元数据。 Args: target_stream: 当前消息要发送到的会话对象。 @@ -44,12 +56,11 @@ def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object] """ inherited_metadata: Dict[str, object] = {} - context = getattr(target_stream, "context", None) - context_message = getattr(context, "message", None) + context_message = target_stream.context.message if target_stream.context else None if context_message is None: return inherited_metadata - additional_config = getattr(context_message.message_info, "additional_config", {}) + additional_config = context_message.message_info.additional_config if not isinstance(additional_config, dict): return inherited_metadata @@ -61,33 +72,412 @@ def _inherit_platform_io_route_metadata(target_stream: Any) -> Dict[str, object] if normalized_value: inherited_metadata[key] = value - target_group_id = getattr(target_stream, "group_id", None) - if target_group_id is not None: - normalized_group_id = str(target_group_id).strip() + if target_stream.group_id: + normalized_group_id = str(target_stream.group_id).strip() if normalized_group_id: inherited_metadata["platform_io_target_group_id"] = normalized_group_id - target_user_id = getattr(target_stream, "user_id", None) - if target_user_id is not None: - normalized_user_id = str(target_user_id).strip() + if target_stream.user_id: + normalized_user_id = str(target_stream.user_id).strip() if normalized_user_id: inherited_metadata["platform_io_target_user_id"] = normalized_user_id return inherited_metadata +def _build_component_from_seg(message_segment: Seg) -> StandardMessageComponents: + """将单个消息段转换为内部消息组件。 + + Args: + message_segment: 待转换的消息段。 + + Returns: + StandardMessageComponents: 转换后的内部消息组件。 + """ + segment_type = str(message_segment.type or "").strip().lower() + segment_data = message_segment.data + + if segment_type == "text": + return TextComponent(text=str(segment_data or "")) + + if segment_type == "image": + image_binary = base64.b64decode(str(segment_data or "")) + return ImageComponent( + binary_hash=hashlib.sha256(image_binary).hexdigest(), + binary_data=image_binary, + ) + + if segment_type == "emoji": + emoji_binary = base64.b64decode(str(segment_data or "")) + return EmojiComponent( + binary_hash=hashlib.sha256(emoji_binary).hexdigest(), + binary_data=emoji_binary, + ) + + if segment_type == "voice": + voice_binary = base64.b64decode(str(segment_data or "")) + return VoiceComponent( + binary_hash=hashlib.sha256(voice_binary).hexdigest(), + binary_data=voice_binary, + ) + + if segment_type == "at": + return AtComponent(target_user_id=str(segment_data or "")) + + if segment_type == "reply": + return ReplyComponent(target_message_id=str(segment_data or "")) + + if segment_type == "dict" and isinstance(segment_data, dict): + return DictComponent(data=segment_data) + + return DictComponent(data={"type": segment_type, "data": segment_data}) + + +def _build_message_sequence_from_seg(message_segment: Seg) -> MessageSequence: + """将消息段转换为内部消息组件序列。 + + Args: + message_segment: 待转换的消息段。 + + Returns: + MessageSequence: 转换后的消息组件序列。 + """ + if str(message_segment.type or "").strip().lower() == "seglist": + raw_segments = message_segment.data + if not isinstance(raw_segments, list): + raise ValueError("seglist 类型的消息段数据必须是列表") + components = [ + _build_component_from_seg(item) + for item in raw_segments + if isinstance(item, Seg) + ] + return MessageSequence(components=components) + + return MessageSequence(components=[_build_component_from_seg(message_segment)]) + + +def _build_processed_plain_text(message: SessionMessage) -> str: + """为出站消息构造轻量纯文本摘要。 + + Args: + message: 待发送的内部消息对象。 + + Returns: + str: 适用于日志与打字时长估算的纯文本摘要。 + """ + processed_parts: List[str] = [] + for component in message.raw_message.components: + if isinstance(component, TextComponent): + processed_parts.append(component.text) + continue + + if isinstance(component, ImageComponent): + processed_parts.append(component.content or "[图片]") + continue + + if isinstance(component, EmojiComponent): + processed_parts.append(component.content or "[表情]") + continue + + if isinstance(component, VoiceComponent): + processed_parts.append(component.content or "[语音]") + continue + + if isinstance(component, AtComponent): + at_target = component.target_user_cardname or component.target_user_nickname or component.target_user_id + processed_parts.append(f"@{at_target}") + continue + + if isinstance(component, ReplyComponent): + processed_parts.append(component.target_message_content or "[回复消息]") + continue + + if isinstance(component, DictComponent): + raw_type = component.data.get("type") if isinstance(component.data, dict) else None + if isinstance(raw_type, str) and raw_type.strip(): + processed_parts.append(f"[{raw_type.strip()}消息]") + else: + processed_parts.append("[自定义消息]") + continue + + return " ".join(part for part in processed_parts if part) + + +def _build_outbound_session_message( + message_segment: Seg, + stream_id: str, + display_message: str = "", + reply_message: Optional[MaiMessage] = None, + selected_expressions: Optional[List[int]] = None, +) -> Optional[SessionMessage]: + """根据目标会话构建待发送的内部消息对象。 + + Args: + message_segment: 待发送的消息段。 + stream_id: 目标会话 ID。 + display_message: 用于界面展示的文本内容。 + reply_message: 被回复的锚点消息。 + selected_expressions: 可选的表情候选索引列表。 + + Returns: + Optional[SessionMessage]: 构建成功时返回内部消息对象;若目标会话或 + 机器人账号不存在,则返回 ``None``。 + """ + target_stream = _chat_manager.get_session_by_session_id(stream_id) + if target_stream is None: + logger.error(f"[SendService] 未找到聊天流: {stream_id}") + return None + + bot_user_id = get_bot_account(target_stream.platform) + if not bot_user_id: + logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息") + return None + + current_time = time.time() + message_id = f"send_api_{int(current_time * 1000)}" + anchor_message = reply_message.deepcopy() if reply_message is not None else None + + group_info: Optional[GroupInfo] = None + if target_stream.group_id: + group_name = "" + if ( + target_stream.context + and target_stream.context.message + and target_stream.context.message.message_info.group_info + ): + group_name = target_stream.context.message.message_info.group_info.group_name + group_info = GroupInfo( + group_id=target_stream.group_id, + group_name=group_name, + ) + + additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream) + if selected_expressions is not None: + additional_config["selected_expressions"] = selected_expressions + + outbound_message = SessionMessage( + message_id=message_id, + timestamp=datetime.fromtimestamp(current_time), + platform=target_stream.platform, + ) + outbound_message.message_info = MessageInfo( + user_info=UserInfo( + user_id=bot_user_id, + user_nickname=global_config.bot.nickname, + ), + group_info=group_info, + additional_config=additional_config, + ) + outbound_message.raw_message = _build_message_sequence_from_seg(message_segment) + outbound_message.session_id = target_stream.session_id + outbound_message.display_message = display_message + outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None + outbound_message.is_emoji = message_segment.type == "emoji" + outbound_message.is_picture = message_segment.type == "image" + outbound_message.is_command = message_segment.type == "command" + outbound_message.initialized = True + return outbound_message + + +def _ensure_reply_component(message: SessionMessage, reply_message_id: str) -> None: + """为消息补充回复组件。 + + Args: + message: 待发送的内部消息对象。 + reply_message_id: 被引用消息的 ID。 + """ + if message.raw_message.components: + first_component = message.raw_message.components[0] + if isinstance(first_component, ReplyComponent) and first_component.target_message_id == reply_message_id: + return + + message.raw_message.components.insert(0, ReplyComponent(target_message_id=reply_message_id)) + + +async def _prepare_message_for_platform_io( + message: SessionMessage, + *, + typing: bool, + set_reply: bool, + reply_message_id: Optional[str], +) -> None: + """为 Platform IO 发送链预处理消息。 + + Args: + message: 待发送的内部消息对象。 + typing: 是否模拟打字等待。 + set_reply: 是否构建引用回复组件。 + reply_message_id: 被引用消息的 ID。 + + Raises: + ValueError: 当要求设置引用回复但缺少 ``reply_message_id`` 时抛出。 + """ + if set_reply: + if not reply_message_id: + raise ValueError("set_reply=True 时必须提供 reply_message_id") + _ensure_reply_component(message, reply_message_id) + + message.processed_plain_text = _build_processed_plain_text(message) + if typing: + typing_time = calculate_typing_time( + input_string=message.processed_plain_text or "", + is_emoji=message.is_emoji, + ) + await asyncio.sleep(typing_time) + + +def _store_sent_message(message: SessionMessage) -> None: + """将已成功发送的消息写入数据库。 + + Args: + message: 已成功发送的内部消息对象。 + """ + MessageUtils.store_message_to_db(message) + + +def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None: + """输出 Platform IO 批量发送失败详情。 + + Args: + delivery_batch: Platform IO 返回的批量回执。 + """ + failed_details = "; ".join( + f"driver={receipt.driver_id} status={receipt.status} error={receipt.error}" + for receipt in delivery_batch.failed_receipts + ) or "未命中任何发送路由" + logger.warning( + "[SendService] Platform IO 发送失败: platform=%s %s", + delivery_batch.route_key.platform, + failed_details, + ) + + +async def _send_via_platform_io( + message: SessionMessage, + *, + typing: bool, + set_reply: bool, + reply_message_id: Optional[str], + storage_message: bool, + show_log: bool, +) -> bool: + """通过 Platform IO 发送消息。 + + Args: + message: 待发送的内部消息对象。 + typing: 是否模拟打字等待。 + set_reply: 是否设置引用回复。 + reply_message_id: 被引用消息的 ID。 + storage_message: 发送成功后是否写入数据库。 + show_log: 是否输出发送成功日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ + platform_io_manager = get_platform_io_manager() + try: + await platform_io_manager.ensure_send_pipeline_ready() + except Exception as exc: + logger.error(f"[SendService] 准备 Platform IO 发送管线失败: {exc}") + logger.debug(traceback.format_exc()) + return False + + try: + route_key = platform_io_manager.build_route_key_from_message(message) + except Exception as exc: + logger.warning(f"[SendService] 根据消息构造 Platform IO 路由键失败: {exc}") + return False + + try: + await _prepare_message_for_platform_io( + message, + typing=typing, + set_reply=set_reply, + reply_message_id=reply_message_id, + ) + delivery_batch = await platform_io_manager.send_message( + message, + route_key, + metadata={"show_log": False}, + ) + except Exception as exc: + logger.error(f"[SendService] Platform IO 发送异常: {exc}") + logger.debug(traceback.format_exc()) + return False + + if delivery_batch.has_success: + if storage_message: + _store_sent_message(message) + if show_log: + successful_driver_ids = [ + receipt.driver_id or "unknown" + for receipt in delivery_batch.sent_receipts + ] + logger.info( + "[SendService] 已通过 Platform IO 将消息发往平台 '%s' (drivers: %s)", + route_key.platform, + ", ".join(successful_driver_ids), + ) + return True + + _log_platform_io_failures(delivery_batch) + return False + + +async def send_session_message( + message: SessionMessage, + *, + typing: bool = False, + set_reply: bool = False, + reply_message_id: Optional[str] = None, + storage_message: bool = True, + show_log: bool = True, +) -> bool: + """统一发送一条内部消息。 + + 该方法是内部模块的统一发送入口: + + 1. 构造并维护内部消息对象。 + 2. 由 Platform IO 统一决定走插件链还是 legacy 旧链。 + 3. send service 不再自行判断底层发送路径。 + + Args: + message: 待发送的内部消息对象。 + typing: 是否模拟打字等待。 + set_reply: 是否设置引用回复。 + reply_message_id: 被引用消息的 ID。 + storage_message: 发送成功后是否写入数据库。 + show_log: 是否输出发送日志。 + + Returns: + bool: 发送成功时返回 ``True``,否则返回 ``False``。 + """ + if not message.message_id: + logger.error("[SendService] 消息缺少 message_id,无法发送") + raise ValueError("消息缺少 message_id,无法发送") + + return await _send_via_platform_io( + message, + typing=typing, + set_reply=set_reply, + reply_message_id=reply_message_id, + storage_message=storage_message, + show_log=show_log, + ) + + async def _send_to_target( message_segment: Seg, stream_id: str, display_message: str = "", typing: bool = False, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, storage_message: bool = True, show_log: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定目标发送消息。 + """向指定目标构建并发送消息。 Args: message_segment: 待发送的消息段。 @@ -104,110 +494,66 @@ async def _send_to_target( bool: 发送成功返回 ``True``,否则返回 ``False``。 """ try: - if set_reply and not reply_message: + if set_reply and reply_message is None: logger.warning("[SendService] 使用引用回复,但未提供回复消息") return False if show_log: logger.debug(f"[SendService] 发送{message_segment.type}消息到 {stream_id}") - target_stream = _chat_manager.get_session_by_session_id(stream_id) - if not target_stream: - logger.error(f"[SendService] 未找到聊天流: {stream_id}") - return False - - message_sender = UniversalMessageSender() - - current_time = time.time() - message_id = f"send_api_{int(current_time * 1000)}" - - anchor_message: Optional[MaiMessage] = None - if reply_message: - anchor_message = reply_message.deepcopy() - if anchor_message: - logger.debug( - f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}" - ) - - group_info = None - if target_stream.group_id: - group_name = "" - if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info: - group_name = target_stream.context.message.message_info.group_info.group_name - group_info = MaimGroupInfo( - group_id=target_stream.group_id, - group_name=group_name, - platform=target_stream.platform, - ) - - additional_config: Dict[str, object] = _inherit_platform_io_route_metadata(target_stream) - if selected_expressions is not None: - additional_config["selected_expressions"] = selected_expressions - bot_user_id = get_bot_account(target_stream.platform) - if not bot_user_id: - logger.error(f"[SendService] 平台 {target_stream.platform} 未配置机器人账号,无法发送消息") - return False - - maim_message = MessageBase( - message_info=BaseMessageInfo( - platform=target_stream.platform, - message_id=message_id, - time=current_time, - user_info=MaimUserInfo( - user_id=bot_user_id, - user_nickname=global_config.bot.nickname, - platform=target_stream.platform, - ), - group_info=group_info, - additional_config=additional_config, - ), + outbound_message = _build_outbound_session_message( message_segment=message_segment, + stream_id=stream_id, + display_message=display_message, + reply_message=reply_message, + selected_expressions=selected_expressions, ) - bot_message = SessionMessage.from_maim_message(maim_message) - bot_message.session_id = target_stream.session_id - bot_message.display_message = display_message - bot_message.reply_to = anchor_message.message_id if anchor_message else None - bot_message.is_emoji = message_segment.type == "emoji" - bot_message.is_picture = message_segment.type == "image" - bot_message.is_command = message_segment.type == "command" + if outbound_message is None: + return False - sent_msg = await message_sender.send_message( - bot_message, + sent = await send_session_message( + outbound_message, typing=typing, set_reply=set_reply, - reply_message_id=anchor_message.message_id if anchor_message else None, + reply_message_id=reply_message.message_id if reply_message is not None else None, storage_message=storage_message, show_log=show_log, ) - - if sent_msg: + if sent: logger.debug(f"[SendService] 成功发送消息到 {stream_id}") return True - else: - logger.error("[SendService] 发送消息失败") - return False - except Exception as e: - logger.error(f"[SendService] 发送消息时出错: {e}") + logger.error("[SendService] 发送消息失败") + return False + except Exception as exc: + logger.error(f"[SendService] 发送消息时出错: {exc}") traceback.print_exc() return False -# ============================================================================= -# 公共函数 - 预定义类型的发送函数 -# ============================================================================= - - async def text_to_stream( text: str, stream_id: str, typing: bool = False, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, storage_message: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: - """向指定流发送文本消息""" + """向指定流发送文本消息。 + + Args: + text: 要发送的文本内容。 + stream_id: 目标会话 ID。 + typing: 是否显示输入中状态。 + set_reply: 是否附带引用回复。 + reply_message: 被回复的消息对象。 + storage_message: 是否在发送成功后写入数据库。 + selected_expressions: 可选的表情候选索引列表。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( message_segment=Seg(type="text", data=text), stream_id=stream_id, @@ -225,9 +571,20 @@ async def emoji_to_stream( stream_id: str, storage_message: bool = True, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, ) -> bool: - """向指定流发送表情包""" + """向指定流发送表情消息。 + + Args: + emoji_base64: 表情图片的 Base64 内容。 + stream_id: 目标会话 ID。 + storage_message: 是否在发送成功后写入数据库。 + set_reply: 是否附带引用回复。 + reply_message: 被回复的消息对象。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( message_segment=Seg(type="emoji", data=emoji_base64), stream_id=stream_id, @@ -244,9 +601,20 @@ async def image_to_stream( stream_id: str, storage_message: bool = True, set_reply: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, ) -> bool: - """向指定流发送图片""" + """向指定流发送图片消息。 + + Args: + image_base64: 图片的 Base64 内容。 + stream_id: 目标会话 ID。 + storage_message: 是否在发送成功后写入数据库。 + set_reply: 是否附带引用回复。 + reply_message: 被回复的消息对象。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( message_segment=Seg(type="image", data=image_base64), stream_id=stream_id, @@ -260,18 +628,33 @@ async def image_to_stream( async def custom_to_stream( message_type: str, - content: str | Dict, + content: str | Dict[str, Any], stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, ) -> bool: - """向指定流发送自定义类型消息""" + """向指定流发送自定义类型消息。 + + Args: + message_type: 自定义消息类型。 + content: 自定义消息内容。 + stream_id: 目标会话 ID。 + display_message: 用于展示的文本内容。 + typing: 是否显示输入中状态。 + reply_message: 被回复的消息对象。 + set_reply: 是否附带引用回复。 + storage_message: 是否在发送成功后写入数据库。 + show_log: 是否输出发送日志。 + + Returns: + bool: 发送成功时返回 ``True``。 + """ return await _send_to_target( - message_segment=Seg(type=message_type, data=content), # type: ignore + message_segment=Seg(type=message_type, data=content), # type: ignore[arg-type] stream_id=stream_id, display_message=display_message, typing=typing, @@ -287,18 +670,33 @@ async def custom_reply_set_to_stream( stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional["SessionMessage"] = None, + reply_message: Optional[MaiMessage] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, ) -> bool: - """向指定流发送消息组件序列。""" - flag: bool = True + """向指定流发送消息组件序列。 + + Args: + reply_set: 待发送的消息组件序列。 + stream_id: 目标会话 ID。 + display_message: 用于展示的文本内容。 + typing: 是否显示输入中状态。 + reply_message: 被回复的消息对象。 + set_reply: 是否附带引用回复。 + storage_message: 是否在发送成功后写入数据库。 + show_log: 是否输出发送日志。 + + Returns: + bool: 全部组件发送成功时返回 ``True``。 + """ + success = True for component in reply_set.components: if isinstance(component, DictComponent): - message_seg = Seg(type="dict", data=component.data) # type: ignore + message_seg = Seg(type="dict", data=component.data) # type: ignore[arg-type] else: message_seg = await component.to_seg() + status = await _send_to_target( message_segment=message_seg, stream_id=stream_id, @@ -310,8 +708,8 @@ async def custom_reply_set_to_stream( show_log=show_log, ) if not status: - flag = False + success = False logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}") set_reply = False - return flag + return success