From b5408b455069ae00157564c321066f0ab0188bf8 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sun, 29 Mar 2026 19:57:34 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E4=BF=AE=E5=A4=8D=E7=A7=81?= =?UTF-8?q?=E8=81=8A=E5=9B=9E=E5=A4=8D=E9=97=AE=E9=A2=98=EF=BC=8C=E4=BF=AE?= =?UTF-8?q?=E5=A4=8Dwait=E5=8A=A8=E4=BD=9C=EF=BC=8C=E8=A1=A5=E4=B8=8A?= =?UTF-8?q?=E5=9B=9E=E5=A4=8D=E5=90=8E=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_send_service.py | 111 +++++++++++------ .../data_models/mai_message_data_model.py | 91 ++++++++------ src/maisaka/reasoning_engine.py | 112 ++++++++++++++---- src/maisaka/runtime.py | 34 +++--- 4 files changed, 241 insertions(+), 107 deletions(-) diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py index 16aad080..44f77090 100644 --- a/pytests/test_send_service.py +++ b/pytests/test_send_service.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List import pytest from src.chat.message_receive.chat_manager import BotChatSession +from src.common.data_models.message_component_data_model import MessageSequence, TextComponent from src.services import send_service @@ -13,42 +14,18 @@ class _FakePlatformIOManager: """用于测试的 Platform IO 管理器假对象。""" def __init__(self, delivery_batch: Any) -> None: - """初始化假 Platform IO 管理器。 - - Args: - delivery_batch: 发送时返回的批量回执。 - """ self._delivery_batch = delivery_batch self.ensure_calls = 0 self.sent_messages: List[Dict[str, Any]] = [] async def ensure_send_pipeline_ready(self) -> None: - """记录发送管线准备调用次数。""" self.ensure_calls += 1 def build_route_key_from_message(self, message: Any) -> Any: - """根据消息构造假的路由键。 - - Args: - message: 待发送的内部消息对象。 - - Returns: - Any: 简化后的路由键对象。 - """ del message return SimpleNamespace(platform="qq") async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any: - """记录发送请求并返回预设回执。 - - Args: - message: 待发送的内部消息对象。 - route_key: 本次发送使用的路由键。 - metadata: 发送元数据。 - - Returns: - Any: 预设的批量发送回执。 - """ self.sent_messages.append( { "message": message, @@ -59,12 +36,7 @@ class _FakePlatformIOManager: return self._delivery_batch -def _build_target_stream() -> BotChatSession: - """构造一个最小可用的目标会话对象。 - - Returns: - BotChatSession: 测试用会话对象。 - """ +def _build_private_stream() -> BotChatSession: return BotChatSession( session_id="test-session", platform="qq", @@ -73,14 +45,21 @@ def _build_target_stream() -> BotChatSession: ) +def _build_group_stream() -> BotChatSession: + return BotChatSession( + session_id="group-session", + platform="qq", + user_id="target-user", + group_id="target-group", + ) + + def test_inherit_platform_io_route_metadata_falls_back_to_bot_account( monkeypatch: pytest.MonkeyPatch, ) -> None: - """没有上下文消息时,也应回填当前平台账号用于账号级路由命中。""" - monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "") - metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream()) + metadata = send_service._inherit_platform_io_route_metadata(_build_private_stream()) assert metadata["platform_io_account_id"] == "bot-qq" assert metadata["platform_io_target_user_id"] == "target-user" @@ -88,7 +67,6 @@ def test_inherit_platform_io_route_metadata_falls_back_to_bot_account( @pytest.mark.asyncio async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None: - """send service 应将发送职责统一交给 Platform IO。""" fake_manager = _FakePlatformIOManager( delivery_batch=SimpleNamespace( has_success=True, @@ -104,7 +82,7 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke monkeypatch.setattr( send_service._chat_manager, "get_session_by_session_id", - lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, ) monkeypatch.setattr( send_service.MessageUtils, @@ -123,7 +101,6 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke @pytest.mark.asyncio async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None: - """Platform IO 批量发送全部失败时,应直接向上返回失败。""" fake_manager = _FakePlatformIOManager( delivery_batch=SimpleNamespace( has_success=False, @@ -144,7 +121,7 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: monkeypatch.setattr( send_service._chat_manager, "get_session_by_session_id", - lambda stream_id: _build_target_stream() if stream_id == "test-session" else None, + lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, ) result = await send_service.text_to_stream(text="发送失败", stream_id="test-session") @@ -152,3 +129,63 @@ async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: assert result is False assert fake_manager.ensure_calls == 1 assert len(fake_manager.sent_messages) == 1 + + +@pytest.mark.asyncio +async def test_private_outbound_message_preserves_bot_sender_and_receiver_user( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, + ) + + outbound_message = send_service._build_outbound_session_message( + message_sequence=MessageSequence(components=[TextComponent(text="你好")]), + stream_id="test-session", + display_message="你好", + ) + + assert outbound_message is not None + maim_message = await outbound_message.to_maim_message() + + assert maim_message.message_info.user_info is not None + assert maim_message.message_info.user_info.user_id == "bot-qq" + assert maim_message.message_info.group_info is None + assert maim_message.message_info.sender_info is not None + assert maim_message.message_info.sender_info.user_info is not None + assert maim_message.message_info.sender_info.user_info.user_id == "bot-qq" + assert maim_message.message_info.receiver_info is not None + assert maim_message.message_info.receiver_info.user_info is not None + assert maim_message.message_info.receiver_info.user_info.user_id == "target-user" + + +@pytest.mark.asyncio +async def test_group_outbound_message_preserves_bot_sender_and_target_group( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_group_stream() if stream_id == "group-session" else None, + ) + + outbound_message = send_service._build_outbound_session_message( + message_sequence=MessageSequence(components=[TextComponent(text="大家好")]), + stream_id="group-session", + display_message="大家好", + ) + + assert outbound_message is not None + maim_message = await outbound_message.to_maim_message() + + assert maim_message.message_info.user_info is not None + assert maim_message.message_info.user_info.user_id == "bot-qq" + assert maim_message.message_info.group_info is not None + assert maim_message.message_info.group_info.group_id == "target-group" + assert maim_message.message_info.receiver_info is not None + assert maim_message.message_info.receiver_info.group_info is not None + assert maim_message.message_info.receiver_info.group_info.group_id == "target-group" diff --git a/src/common/data_models/mai_message_data_model.py b/src/common/data_models/mai_message_data_model.py index 4396201a..814f642b 100644 --- a/src/common/data_models/mai_message_data_model.py +++ b/src/common/data_models/mai_message_data_model.py @@ -1,15 +1,17 @@ +import json from dataclasses import dataclass, field -from maim_message import ( - MessageBase, - UserInfo as MaimUserInfo, - GroupInfo as MaimGroupInfo, - BaseMessageInfo as MaimBaseMessageInfo, - Seg, -) +from datetime import datetime from typing import Optional -import json -from datetime import datetime +from maim_message import ( + BaseMessageInfo as MaimBaseMessageInfo, + GroupInfo as MaimGroupInfo, + MessageBase, + ReceiverInfo as MaimReceiverInfo, + Seg, + SenderInfo as MaimSenderInfo, + UserInfo as MaimUserInfo, +) from src.common.database.database_model import Messages from src.common.data_models.message_component_data_model import MessageSequence @@ -41,34 +43,24 @@ class MessageInfo: class MaiMessage(BaseDatabaseDataModel[Messages]): def __init__(self, message_id: str, timestamp: datetime, platform: str): self.message_id: str = message_id - self.timestamp: datetime = timestamp # 时间戳 - self.initialized = False # 用于标记是否已初始化其他属性 + self.timestamp: datetime = timestamp + self.initialized = False self.platform: str = platform - # 定义其他属性 - self.message_info: MessageInfo # 初始化后赋值 + self.message_info: MessageInfo self.is_mentioned: bool = False - """机器人被提及标记,若被at,则提及也被标记""" self.is_at: bool = False - """机器人被at标记""" self.is_emoji: bool = False - """消息为纯表情包,在计算打字时长时候会被特殊处理""" self.is_picture: bool = False - """消息为纯图片,在计算打字时长时候会被特殊处理""" self.is_command: bool = False - """消息为命令消息,打字时长必定为0""" self.is_notify: bool = False - """消息为通知消息""" self.session_id: str self.reply_to: Optional[str] = None self.processed_plain_text: Optional[str] = None - """处理过后的纯文本内容""" self.display_message: Optional[str] = None - """最后显示给大模型的消息内容""" self.raw_message: MessageSequence - """原始消息数据""" @classmethod def from_db_instance(cls, db_record: "Messages"): @@ -79,12 +71,12 @@ class MaiMessage(BaseDatabaseDataModel[Messages]): group_info = GroupInfo(db_record.group_id, db_record.group_name) else: group_info = None + obj.message_info = MessageInfo( user_info=user_info, group_info=group_info, additional_config=json.loads(db_record.additional_config) if db_record.additional_config else {}, ) - obj.is_mentioned = db_record.is_mentioned obj.is_at = db_record.is_at obj.is_emoji = db_record.is_emoji @@ -127,18 +119,22 @@ class MaiMessage(BaseDatabaseDataModel[Messages]): @classmethod def from_maim_message(cls, message: MessageBase): - """从 maim_message.MessageBase 创建 MaiMessage 实例,解析消息内容并提取相关信息""" + """从 maim_message.MessageBase 创建 MaiMessage。""" msg_info = message.message_info assert msg_info, "MessageBase 的 message_info 不能为空" + platform = msg_info.platform assert isinstance(platform, str) + msg_id = str(msg_info.message_id) timestamp = msg_info.time assert isinstance(msg_id, str) assert msg_id assert timestamp + obj = cls(message_id=msg_id, timestamp=datetime.fromtimestamp(timestamp), platform=platform) obj.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(message) + usr_info = msg_info.user_info assert usr_info assert isinstance(usr_info.user_id, str) @@ -148,40 +144,69 @@ class MaiMessage(BaseDatabaseDataModel[Messages]): user_nickname=usr_info.user_nickname, user_cardname=usr_info.user_cardname, ) - if grp_info := msg_info.group_info: + + if msg_info.group_info: + grp_info = msg_info.group_info assert isinstance(grp_info.group_id, str) assert isinstance(grp_info.group_name, str) group_info = GroupInfo(group_id=grp_info.group_id, group_name=grp_info.group_name) else: group_info = None + add_cfg = msg_info.additional_config or {} obj.message_info = MessageInfo(user_info=user_info, group_info=group_info, additional_config=add_cfg) return obj async def to_maim_message(self) -> MessageBase: - """ - 从 MaiMessage 实例转换为 maim_message.MessageBase,构建消息内容并设置相关信息 - """ - maim_user_info = MaimUserInfo( + """将 MaiMessage 转换为 maim_message.MessageBase。""" + sender_user_info = MaimUserInfo( user_id=self.message_info.user_info.user_id, user_nickname=self.message_info.user_info.user_nickname, user_cardname=self.message_info.user_info.user_cardname, platform=self.platform, ) - maim_group_info = None + + sender_group_info = None if self.message_info.group_info: - maim_group_info = MaimGroupInfo( + sender_group_info = MaimGroupInfo( group_id=self.message_info.group_info.group_id, group_name=self.message_info.group_info.group_name, platform=self.platform, ) + + sender_info = MaimSenderInfo( + group_info=sender_group_info, + user_info=sender_user_info, + ) + + receiver_group_info = sender_group_info + receiver_user_info = None + additional_config = self.message_info.additional_config or {} + target_user_id = str(additional_config.get("platform_io_target_user_id") or "").strip() + if receiver_group_info is None and target_user_id: + receiver_user_info = MaimUserInfo( + user_id=target_user_id, + user_nickname=None, + user_cardname=None, + platform=self.platform, + ) + + receiver_info = None + if receiver_group_info or receiver_user_info: + receiver_info = MaimReceiverInfo( + group_info=receiver_group_info, + user_info=receiver_user_info, + ) + maim_msg_info = MaimBaseMessageInfo( platform=self.platform, message_id=self.message_id, time=self.timestamp.timestamp(), - group_info=maim_group_info, - user_info=maim_user_info, + group_info=receiver_group_info, + user_info=sender_user_info, additional_config=self.message_info.additional_config, + sender_info=sender_info, + receiver_info=receiver_info, ) msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message) return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments)) diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index ff0f0e2c..bef58b59 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -14,7 +14,7 @@ from sqlmodel import select from src.chat.heart_flow.heartFC_utils import CycleDetail from src.chat.message_receive.message import SessionMessage from src.chat.replyer.replyer_manager import replyer_manager -from src.chat.utils.utils import get_bot_account +from src.chat.utils.utils import get_bot_account, process_llm_response from src.common.database.database import get_db_session from src.common.database.database_model import Jargon from src.common.data_models.mai_message_data_model import UserInfo @@ -58,14 +58,26 @@ class MaisakaReasoningEngine: try: while self._runtime._running: cached_messages = await self._runtime._internal_turn_queue.get() - if not cached_messages: + timeout_triggered = cached_messages is None + if not timeout_triggered and not cached_messages: self._runtime._internal_turn_queue.task_done() continue self._runtime._agent_state = self._runtime._STATE_RUNNING - await self._ingest_messages(cached_messages) - - anchor_message = cached_messages[-1] + if cached_messages: + await self._ingest_messages(cached_messages) + anchor_message = cached_messages[-1] + else: + anchor_message = self._get_timeout_anchor_message() + if anchor_message is None: + logger.warning( + f"{self._runtime.log_prefix} wait 超时后缺少可复用的锚点消息,跳过本轮继续思考" + ) + self._runtime._internal_turn_queue.task_done() + continue + logger.info(f"{self._runtime.log_prefix} wait 超时后开始新一轮思考") + self._runtime._chat_history.append(self._build_wait_timeout_message(anchor_message)) + self._trim_chat_history() try: for round_index in range(self._runtime._max_internal_rounds): cycle_detail = self._start_cycle() @@ -126,6 +138,24 @@ class MaisakaReasoningEngine: logger.exception("%s Maisaka internal loop crashed", self._runtime.log_prefix) raise + def _get_timeout_anchor_message(self) -> Optional[SessionMessage]: + """在 wait 超时后复用最近一条真实用户消息作为锚点。""" + if self._runtime.message_cache: + return self._runtime.message_cache[-1] + return None + + def _build_wait_timeout_message(self, anchor_message: SessionMessage) -> SessionMessage: + """构造 wait 超时后的工具结果消息,用于触发下一轮思考。""" + return build_message( + role="tool", + content="wait 已超时,期间没有收到新的用户输入。请基于现有上下文继续下一轮思考。", + source="tool", + platform=anchor_message.platform, + session_id=self._runtime.session_id, + group_info=self._runtime._build_group_info(anchor_message), + user_info=UserInfo(user_id="maisaka_tool", user_nickname="tool", user_cardname=None), + ) + async def _ingest_messages(self, messages: list[SessionMessage]) -> None: """处理传入消息列表,将其转换为历史消息并加入聊天历史缓存。""" for message in messages: @@ -417,6 +447,19 @@ class MaisakaReasoningEngine: logger.info(f"{self._runtime.log_prefix} reasoning similarity: {similarity:.2f}") return similarity > 0.9 + @staticmethod + def _post_process_reply_text(reply_text: str) -> list[str]: + """沿用旧回复链的文本后处理,执行分段与错别字注入。""" + processed_segments: list[str] = [] + for segment in process_llm_response(reply_text): + normalized_segment = segment.strip() + if normalized_segment: + processed_segments.append(normalized_segment) + + if processed_segments: + return processed_segments + return [reply_text.strip()] + async def _handle_tool_calls( self, tool_calls: list[ToolCall], @@ -426,8 +469,10 @@ class MaisakaReasoningEngine: for tool_call in tool_calls: if tool_call.func_name == "reply": reply_sent = await self._handle_reply(tool_call, latest_thought, anchor_message) - if reply_sent: - return True + if not reply_sent: + logger.warning( + f"{self._runtime.log_prefix} reply tool did not produce a visible message, continuing loop" + ) continue if tool_call.func_name == "no_reply": @@ -634,19 +679,31 @@ class MaisakaReasoningEngine: ) return False + reply_segments = self._post_process_reply_text(reply_text) + combined_reply_text = "".join(reply_segments) + logger.info( + f"{self._runtime.log_prefix} reply post process finished: " + f"target_msg_id={target_message_id} segment_count={len(reply_segments)} " + f"segments={reply_segments!r}" + ) + logger.info( f"{self._runtime.log_prefix} sending guided reply: " - f"target_msg_id={target_message_id} quote={quote_reply} reply_text={reply_text!r}" + f"target_msg_id={target_message_id} quote={quote_reply} reply_segments={reply_segments!r}" ) try: - sent = await send_service.text_to_stream( - text=reply_text, - stream_id=self._runtime.session_id, - set_reply=quote_reply, - reply_message=target_message if quote_reply else None, - selected_expressions=reply_result.selected_expression_ids or None, - typing=False, - ) + sent = False + for index, segment in enumerate(reply_segments): + sent = await send_service.text_to_stream( + text=segment, + stream_id=self._runtime.session_id, + set_reply=quote_reply if index == 0 else False, + reply_message=target_message if quote_reply and index == 0 else None, + selected_expressions=reply_result.selected_expression_ids or None, + typing=index > 0, + ) + if not sent: + break except Exception: logger.exception( f"{self._runtime.log_prefix} send_service.text_to_stream crashed " @@ -675,11 +732,12 @@ class MaisakaReasoningEngine: if self._runtime.chat_stream is not None: await database_api.store_tool_info( chat_stream=self._runtime.chat_stream, - display_prompt=f"你对{target_user_name}进行了回复:{reply_text}", + display_prompt=f"你对{target_user_name}进行了回复:{combined_reply_text}", tool_data={ "msg_id": target_message_id, "quote": quote_reply, - "reply_text": reply_text, + "reply_text": combined_reply_text, + "reply_segments": reply_segments, }, tool_name="reply", tool_reasoning=latest_thought, @@ -693,17 +751,25 @@ class MaisakaReasoningEngine: user_cardname=None, ) history_message = build_message( - role="assistant", - content=reply_text, + role="user", + content="", source="guided_reply", platform=target_platform, session_id=self._runtime.session_id, group_info=self._runtime._build_group_info(target_message), user_info=bot_user_info, ) - structured_visible_text = f"{self._build_planner_user_prefix(history_message)}{reply_text}" - history_message.display_message = structured_visible_text - history_message.processed_plain_text = structured_visible_text + history_message.raw_message = MessageSequence( + [TextComponent(f"{self._build_planner_user_prefix(history_message)}{combined_reply_text}")] + ) + visible_reply_text = format_speaker_content( + bot_name, + combined_reply_text, + history_message.timestamp, + history_message.message_id, + ) + history_message.display_message = visible_reply_text + history_message.processed_plain_text = visible_reply_text self._runtime._chat_history.append(history_message) return True diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index c8db017e..90b4b961 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -46,7 +46,7 @@ class MaisakaHeartFlowChatting: # Keep all original messages for batching and later learning. self.message_cache: list[SessionMessage] = [] self._last_processed_index = 0 - self._internal_turn_queue: asyncio.Queue[list[SessionMessage]] = asyncio.Queue() + self._internal_turn_queue: asyncio.Queue[Optional[list[SessionMessage]]] = asyncio.Queue() self._mcp_manager: Optional[MCPManager] = None self._current_cycle_detail: Optional[CycleDetail] = None @@ -139,22 +139,28 @@ class MaisakaHeartFlowChatting: while self._running: if not self._has_pending_messages(): if self._agent_state == self._STATE_WAIT: - message_arrived = await self._wait_for_trigger() + trigger_reason = await self._wait_for_trigger() else: self._new_message_event.clear() await self._new_message_event.wait() - message_arrived = self._running + trigger_reason: Literal["message", "timeout", "stop"] = "message" if self._running else "stop" else: - message_arrived = True + trigger_reason = "message" if not self._running: return - if not message_arrived: + if trigger_reason == "stop": self._agent_state = self._STATE_STOP continue self._new_message_event.clear() + if trigger_reason == "timeout": + # wait 超时后继续下一轮内部思考,但不要重复注入旧消息。 + logger.info(f"{self.log_prefix} wait 超时后投递继续思考触发") + await self._internal_turn_queue.put(None) + continue + while self._has_pending_messages(): cached_messages = self._collect_pending_messages() if not cached_messages: @@ -190,31 +196,31 @@ class MaisakaHeartFlowChatting: ) return unique_messages - async def _wait_for_trigger(self) -> bool: - """Return True on new message, False on timeout.""" + async def _wait_for_trigger(self) -> Literal["message", "timeout", "stop"]: + """等待 wait 状态的触发结果。""" if self._agent_state != self._STATE_WAIT: await self._new_message_event.wait() - return True + return "message" if self._wait_until is None: await self._new_message_event.wait() - return True + return "message" timeout = self._wait_until - time.time() if timeout <= 0: logger.info(f"{self.log_prefix} Maisaka wait timed out") - self._enter_stop_state() + self._agent_state = self._STATE_RUNNING self._wait_until = None - return False + return "timeout" try: await asyncio.wait_for(self._new_message_event.wait(), timeout=timeout) - return True + return "message" except asyncio.TimeoutError: logger.info(f"{self.log_prefix} Maisaka wait timed out") - self._enter_stop_state() + self._agent_state = self._STATE_RUNNING self._wait_until = None - return False + return "timeout" def _enter_wait_state(self, seconds: Optional[float] = None) -> None: """Enter wait state."""