From 6968879a04624fcfe05888b62ed051fd3766ae02 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Tue, 7 Apr 2026 16:21:42 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E4=BF=AE=E5=A4=8D=E5=AD=A4?= =?UTF-8?q?=E5=84=BF=E5=B7=A5=E5=85=B7=E6=8A=A5=E9=94=99=EF=BC=8C=E4=B8=BA?= =?UTF-8?q?replyer=E7=AD=89tool=E6=B7=BB=E5=8A=A0=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E7=9A=84=E6=8E=A7=E5=88=B6=E5=8F=B0=E5=B1=95=E7=A4=BA=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_maisaka_monitor_protocol.py | 192 +++ src/chat/replyer/maisaka_generator.py | 77 +- src/chat/replyer/maisaka_generator_multi.py | 97 +- .../reply_generation_data_models.py | 106 +- src/maisaka/builtin_tool/reply.py | 16 +- src/maisaka/chat_loop_service.py | 19 +- src/maisaka/history_utils.py | 27 + src/maisaka/monitor_events.py | 435 ++----- src/maisaka/reasoning_engine.py | 101 +- src/memory_system/chat_history_summarizer.py | 1123 +++++++++++++++++ tests/test_maisaka_orphan_tool_results.py | 49 + 11 files changed, 1803 insertions(+), 439 deletions(-) create mode 100644 pytests/test_maisaka_monitor_protocol.py create mode 100644 src/memory_system/chat_history_summarizer.py create mode 100644 tests/test_maisaka_orphan_tool_results.py diff --git a/pytests/test_maisaka_monitor_protocol.py b/pytests/test_maisaka_monitor_protocol.py new file mode 100644 index 00000000..591c2062 --- /dev/null +++ b/pytests/test_maisaka_monitor_protocol.py @@ -0,0 +1,192 @@ +from types import SimpleNamespace +from typing import Any, Callable + +import pytest + +from src.chat.replyer import maisaka_generator as legacy_replyer_module +from src.chat.replyer import maisaka_generator_multi as multimodal_replyer_module +from src.common.data_models.reply_generation_data_models import ( + GenerationMetrics, + LLMCompletionResult, + ReplyGenerationResult, +) +from src.core.tooling import ToolExecutionResult, ToolInvocation +from 
src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext +from src.maisaka.builtin_tool import reply as reply_tool_module +from src.maisaka.monitor_events import emit_planner_finalized +from src.maisaka.reasoning_engine import MaisakaReasoningEngine + + +class _FakeLLMResult: + def __init__(self) -> None: + self.response = "测试回复" + self.reasoning = "先理解上下文,再给出自然回复。" + self.model_name = "fake-model" + self.tool_calls = [] + self.prompt_tokens = 12 + self.completion_tokens = 7 + self.total_tokens = 19 + + +class _FakeLegacyLLMServiceClient: + def __init__(self, *args: Any, **kwargs: Any) -> None: + del args + del kwargs + + async def generate_response(self, prompt: str) -> _FakeLLMResult: + assert prompt + return _FakeLLMResult() + + +class _FakeMultimodalLLMServiceClient: + def __init__(self, *args: Any, **kwargs: Any) -> None: + del args + del kwargs + + async def generate_response_with_messages(self, *, message_factory: Callable[[object], list[Any]]) -> _FakeLLMResult: + assert message_factory(object()) + return _FakeLLMResult() + + +@pytest.mark.asyncio +async def test_legacy_and_multimodal_replyer_monitor_detail_have_same_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(legacy_replyer_module, "LLMServiceClient", _FakeLegacyLLMServiceClient) + monkeypatch.setattr(multimodal_replyer_module, "LLMServiceClient", _FakeMultimodalLLMServiceClient) + monkeypatch.setattr(legacy_replyer_module, "load_prompt", lambda *args, **kwargs: "legacy prompt") + monkeypatch.setattr(multimodal_replyer_module, "load_prompt", lambda *args, **kwargs: "multi prompt") + + legacy_generator = legacy_replyer_module.MaisakaReplyGenerator(chat_stream=None, request_type="test_legacy") + multimodal_generator = multimodal_replyer_module.MaisakaReplyGenerator(chat_stream=None, request_type="test_multi") + + legacy_success, legacy_result = await legacy_generator.generate_reply_with_context( + stream_id="session-legacy", + chat_history=[], + reply_reason="测试原因", + ) + 
multimodal_success, multimodal_result = await multimodal_generator.generate_reply_with_context( + stream_id="session-multi", + chat_history=[], + reply_reason="测试原因", + ) + + assert legacy_success is True + assert multimodal_success is True + assert legacy_result.monitor_detail is not None + assert multimodal_result.monitor_detail is not None + assert set(legacy_result.monitor_detail.keys()) == set(multimodal_result.monitor_detail.keys()) + assert set(legacy_result.monitor_detail["metrics"].keys()) == set(multimodal_result.monitor_detail["metrics"].keys()) + assert legacy_result.monitor_detail["metrics"]["prompt_tokens"] == 12 + assert legacy_result.monitor_detail["metrics"]["completion_tokens"] == 7 + assert legacy_result.monitor_detail["metrics"]["total_tokens"] == 19 + + +@pytest.mark.asyncio +async def test_reply_tool_puts_monitor_detail_into_metadata(monkeypatch: pytest.MonkeyPatch) -> None: + fake_monitor_detail = { + "prompt_text": "reply prompt", + "reasoning_text": "reply reasoning", + "output_text": "reply output", + "metrics": {"model_name": "fake-model", "total_tokens": 10}, + } + fake_reply_result = ReplyGenerationResult( + success=True, + completion=LLMCompletionResult(response_text="测试回复"), + metrics=GenerationMetrics(overall_ms=11.5), + monitor_detail=fake_monitor_detail, + ) + + class _FakeReplyer: + async def generate_reply_with_context(self, **kwargs: Any) -> tuple[bool, ReplyGenerationResult]: + del kwargs + return True, fake_reply_result + + monkeypatch.setattr(reply_tool_module.replyer_manager, "get_replyer", lambda **kwargs: _FakeReplyer()) + monkeypatch.setattr(reply_tool_module, "render_cli_message", lambda text: text) + + target_message = SimpleNamespace( + message_id="msg-1", + message_info=SimpleNamespace( + user_info=SimpleNamespace( + user_cardname="测试用户", + user_nickname="测试用户", + user_id="user-1", + ) + ), + ) + runtime = SimpleNamespace( + _source_messages_by_id={"msg-1": target_message}, + log_prefix="[test]", + 
chat_stream=SimpleNamespace(platform=reply_tool_module.CLI_PLATFORM_NAME), + session_id="session-1", + _chat_history=[], + _clear_force_continue_until_reply=lambda: None, + run_sub_agent=None, + ) + engine = SimpleNamespace(_get_runtime_manager=lambda: None) + tool_ctx = BuiltinToolRuntimeContext(engine=engine, runtime=runtime) + invocation = ToolInvocation(tool_name="reply", arguments={"msg_id": "msg-1", "set_quote": True}) + + result = await reply_tool_module.handle_tool(tool_ctx, invocation) + + assert result.success is True + assert result.metadata["monitor_detail"] == fake_monitor_detail + + +@pytest.mark.asyncio +async def test_emit_planner_finalized_broadcasts_new_protocol(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, Any] = {} + + async def _fake_broadcast(event: str, data: dict[str, Any]) -> None: + captured["event"] = event + captured["data"] = data + + monkeypatch.setattr("src.maisaka.monitor_events._broadcast", _fake_broadcast) + + await emit_planner_finalized( + session_id="session-1", + cycle_id=3, + request_messages=[{"role": "user", "content": "你好"}], + selected_history_count=5, + tool_count=2, + planner_content="先查询再回复", + planner_tool_calls=[SimpleNamespace(call_id="call-1", func_name="reply", args={"msg_id": "m1"})], + prompt_tokens=100, + completion_tokens=30, + total_tokens=130, + duration_ms=88.5, + tools=[ + { + "tool_call_id": "call-1", + "tool_name": "reply", + "tool_args": {"msg_id": "m1"}, + "success": True, + "duration_ms": 22.0, + "summary": "- reply [成功]: 已回复", + "detail": {"output_text": "测试回复"}, + } + ], + time_records={"planner": 0.1, "tool_calls": 0.2}, + agent_state="stop", + ) + + assert captured["event"] == "planner.finalized" + payload = captured["data"] + assert payload["request"]["messages"][0]["content"] == "你好" + assert payload["request"]["tool_count"] == 2 + assert payload["planner"]["tool_calls"][0]["id"] == "call-1" + assert payload["tools"][0]["detail"]["output_text"] == "测试回复" + assert 
payload["final_state"]["agent_state"] == "stop" + + +def test_reasoning_engine_build_tool_monitor_result_keeps_non_reply_tool_without_detail() -> None: + engine = object.__new__(MaisakaReasoningEngine) + tool_call = SimpleNamespace(call_id="call-2", func_name="query_memory") + invocation = ToolInvocation(tool_name="query_memory", arguments={"query": "Alice"}) + result = ToolExecutionResult(tool_name="query_memory", success=True, content="查询成功") + + tool_result = engine._build_tool_monitor_result(tool_call, invocation, result, duration_ms=18.6) + + assert tool_result["tool_call_id"] == "call-2" + assert tool_result["tool_name"] == "query_memory" + assert tool_result["tool_args"] == {"query": "Alice"} + assert tool_result["detail"] is None diff --git a/src/chat/replyer/maisaka_generator.py b/src/chat/replyer/maisaka_generator.py index 41154ab0..c7ef7717 100644 --- a/src/chat/replyer/maisaka_generator.py +++ b/src/chat/replyer/maisaka_generator.py @@ -11,6 +11,7 @@ from src.common.data_models.reply_generation_data_models import ( GenerationMetrics, LLMCompletionResult, ReplyGenerationResult, + build_reply_monitor_detail, ) from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt @@ -18,10 +19,17 @@ from src.config.config import global_config from src.core.types import ActionInfo from src.services.llm_service import LLMServiceClient -from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, ReferenceMessage, SessionBackedMessage, ToolResultMessage -from .maisaka_expression_selector import maisaka_expression_selector +from src.maisaka.context_messages import ( + AssistantMessage, + LLMContextMessage, + ReferenceMessage, + SessionBackedMessage, + ToolResultMessage, +) from src.maisaka.message_adapter import parse_speaker_content +from .maisaka_expression_selector import maisaka_expression_selector + logger = get_logger("replyer") @@ -50,7 +58,7 @@ class MaisakaReplyGenerator: self._personality_prompt = 
self._build_personality_prompt() def _build_personality_prompt(self) -> str: - """构建 replyer 使用的人设描述。""" + """构建 replyer 使用的人设提示。""" try: bot_name = global_config.bot.nickname alias_names = global_config.bot.alias_names @@ -268,6 +276,11 @@ class MaisakaReplyGenerator: sub_agent_runner: Optional[Callable[[str], Awaitable[str]]] = None, ) -> Tuple[bool, ReplyGenerationResult]: """结合上下文生成 Maisaka 的最终可见回复。""" + + def finalize(success_value: bool) -> Tuple[bool, ReplyGenerationResult]: + result.monitor_detail = build_reply_monitor_detail(result) + return success_value, result + del available_actions del chosen_actions del extra_info @@ -278,14 +291,14 @@ class MaisakaReplyGenerator: del unknown_words result = ReplyGenerationResult() + overall_started_at = time.perf_counter() if chat_history is None: result.error_message = "聊天历史为空" - return False, result + return finalize(False) logger.info( f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} " - f"历史消息数={len(chat_history)} 目标消息编号=" - f"{reply_message.message_id if reply_message else None}" + f"历史消息数={len(chat_history)} 目标消息编号={reply_message.message_id if reply_message else None}" ) filtered_history = [ @@ -293,14 +306,12 @@ class MaisakaReplyGenerator: for message in chat_history if not isinstance(message, (ReferenceMessage, ToolResultMessage)) ] - logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}") - # Validate that express_model is properly initialized if self.express_model is None: logger.error("Maisaka 回复器的回复模型未初始化") result.error_message = "回复模型尚未初始化" - return False, result + return finalize(False) try: reply_context = await self._build_reply_context( @@ -312,9 +323,13 @@ class MaisakaReplyGenerator: ) except Exception as exc: import traceback + logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}") result.error_message = f"构建回复上下文失败: {exc}" - return False, result + result.metrics = GenerationMetrics( + overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2), + ) + 
return finalize(False) merged_expression_habits = expression_habits.strip() or reply_context.expression_habits result.selected_expression_ids = ( @@ -328,6 +343,7 @@ class MaisakaReplyGenerator: f"已选表达编号={result.selected_expression_ids!r}" ) + prompt_started_at = time.perf_counter() try: prompt = self._build_prompt( chat_history=filtered_history, @@ -337,26 +353,36 @@ class MaisakaReplyGenerator: ) except Exception as exc: import traceback + logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}") result.error_message = f"构建提示词失败: {exc}" - return False, result + result.metrics = GenerationMetrics( + overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2), + ) + return finalize(False) + prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2) result.completion.request_prompt = prompt + show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False)) + show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False)) - if global_config.debug.show_replyer_prompt: - logger.info(f"\nMaisaka 回复器提示词:\n{prompt}\n") + if show_replyer_prompt: + logger.info(f"\nMaisaka 回复器提示词:\n{prompt}\n") - started_at = time.perf_counter() + llm_started_at = time.perf_counter() try: generation_result = await self.express_model.generate_response(prompt) except Exception as exc: logger.exception("Maisaka 回复器调用失败") result.error_message = str(exc) result.metrics = GenerationMetrics( - overall_ms=round((time.perf_counter() - started_at) * 1000, 2), + prompt_ms=prompt_ms, + llm_ms=round((time.perf_counter() - llm_started_at) * 1000, 2), + overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2), ) - return False, result + return finalize(False) + llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2) response_text = (generation_result.response or "").strip() result.success = bool(response_text) result.completion = LLMCompletionResult( @@ -365,18 +391,27 @@ class MaisakaReplyGenerator: 
reasoning_text=generation_result.reasoning or "", model_name=generation_result.model_name or "", tool_calls=generation_result.tool_calls or [], + prompt_tokens=generation_result.prompt_tokens, + completion_tokens=generation_result.completion_tokens, + total_tokens=generation_result.total_tokens, ) result.metrics = GenerationMetrics( - overall_ms=round((time.perf_counter() - started_at) * 1000, 2), + prompt_ms=prompt_ms, + llm_ms=llm_ms, + overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2), + stage_logs=[ + f"prompt: {prompt_ms} ms", + f"llm: {llm_ms} ms", + ], ) - if global_config.debug.show_replyer_reasoning and result.completion.reasoning_text: - logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}") + if show_replyer_reasoning and result.completion.reasoning_text: + logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}") if not result.success: result.error_message = "回复器返回了空内容" logger.warning("Maisaka 回复器返回了空内容") - return False, result + return finalize(False) logger.info( f"Maisaka 回复器生成成功: 回复文本={response_text!r} " @@ -384,4 +419,4 @@ class MaisakaReplyGenerator: f"已选表达编号={result.selected_expression_ids!r}" ) result.text_fragments = [response_text] - return True, result + return finalize(True) diff --git a/src/chat/replyer/maisaka_generator_multi.py b/src/chat/replyer/maisaka_generator_multi.py index a5978648..3681b2b3 100644 --- a/src/chat/replyer/maisaka_generator_multi.py +++ b/src/chat/replyer/maisaka_generator_multi.py @@ -16,13 +16,19 @@ from src.common.data_models.reply_generation_data_models import ( GenerationMetrics, LLMCompletionResult, ReplyGenerationResult, + build_reply_monitor_detail, ) from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt from src.config.config import global_config from src.core.types import ActionInfo -from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart -from 
src.maisaka.monitor_events import emit_replier_request, emit_replier_response +from src.llm_models.payload_content.message import ( + ImageMessagePart, + Message, + MessageBuilder, + RoleType, + TextMessagePart, +) from src.services.llm_service import LLMServiceClient from src.maisaka.context_messages import ( @@ -32,10 +38,11 @@ from src.maisaka.context_messages import ( SessionBackedMessage, ToolResultMessage, ) -from .maisaka_expression_selector import maisaka_expression_selector from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content from src.maisaka.prompt_cli_renderer import PromptCLIVisualizer +from .maisaka_expression_selector import maisaka_expression_selector + logger = get_logger("replyer") @@ -177,7 +184,7 @@ class MaisakaReplyGenerator: return f"{system_prompt}\n\n" + "\n\n".join(sections) def _build_reply_instruction(self) -> str: - return "请自然地回复。请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。" + return "请自然地回复。不要输出多余说明、括号、at 或额外标记,只输出实际要发送的内容。" def _build_multimodal_user_message( self, @@ -342,6 +349,11 @@ class MaisakaReplyGenerator: selected_expression_ids: Optional[List[int]] = None, sub_agent_runner: Optional[Callable[[str], Awaitable[str]]] = None, ) -> Tuple[bool, ReplyGenerationResult]: + + def finalize(success_value: bool) -> Tuple[bool, ReplyGenerationResult]: + result.monitor_detail = build_reply_monitor_detail(result) + return success_value, result + del available_actions del chosen_actions del extra_info @@ -352,9 +364,10 @@ class MaisakaReplyGenerator: del unknown_words result = ReplyGenerationResult() + overall_started_at = time.perf_counter() if chat_history is None: result.error_message = "聊天历史为空" - return False, result + return finalize(False) logger.info( f"Maisaka 回复器开始生成: 流={stream_id} 原因={reply_reason!r} " @@ -370,7 +383,7 @@ class MaisakaReplyGenerator: if self.express_model is None: logger.error("回复模型未初始化") result.error_message = "回复模型尚未初始化" - return False, result + return finalize(False) try: 
reply_context = await self._build_reply_context( @@ -382,9 +395,13 @@ class MaisakaReplyGenerator: ) except Exception as exc: import traceback + logger.error(f"构建回复上下文失败: {exc}\n{traceback.format_exc()}") result.error_message = f"构建回复上下文失败: {exc}" - return False, result + result.metrics = GenerationMetrics( + overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2), + ) + return finalize(False) merged_expression_habits = expression_habits.strip() or reply_context.expression_habits result.selected_expression_ids = ( @@ -397,6 +414,7 @@ class MaisakaReplyGenerator: f"回复上下文完成: 流={stream_id} 已选表达={result.selected_expression_ids!r}" ) + prompt_started_at = time.perf_counter() try: request_messages = self._build_request_messages( chat_history=filtered_history, @@ -406,11 +424,18 @@ class MaisakaReplyGenerator: ) except Exception as exc: import traceback + logger.error(f"构建提示词失败: {exc}\n{traceback.format_exc()}") result.error_message = f"构建提示词失败: {exc}" - return False, result + result.metrics = GenerationMetrics( + overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2), + ) + return finalize(False) + prompt_ms = round((time.perf_counter() - prompt_started_at) * 1000, 2) prompt_preview = self._build_request_prompt_preview(request_messages) + show_replyer_prompt = bool(getattr(global_config.debug, "show_replyer_prompt", False)) + show_replyer_reasoning = bool(getattr(global_config.debug, "show_replyer_reasoning", False)) def message_factory(_client: object) -> List[Message]: return request_messages @@ -418,7 +443,7 @@ class MaisakaReplyGenerator: result.completion.request_prompt = prompt_preview preview_chat_id = self._resolve_session_id(stream_id) replyer_prompt_section: RenderableType | None = None - if global_config.debug.show_replyer_prompt: + if show_replyer_prompt: replyer_prompt_section = PromptCLIVisualizer.build_text_section( prompt_preview, category="replyer", @@ -428,15 +453,7 @@ class MaisakaReplyGenerator: 
folded=global_config.debug.fold_maisaka_thinking, ) - started_at = time.perf_counter() - - # 向监控前端广播回复器请求事件 - await emit_replier_request( - session_id=preview_chat_id, - messages=request_messages, - model_name=getattr(self.express_model, "model_name", ""), - ) - + llm_started_at = time.perf_counter() try: generation_result = await self.express_model.generate_response_with_messages( message_factory=message_factory @@ -445,10 +462,13 @@ class MaisakaReplyGenerator: logger.exception("Maisaka 回复器调用失败") result.error_message = str(exc) result.metrics = GenerationMetrics( - overall_ms=round((time.perf_counter() - started_at) * 1000, 2), + prompt_ms=prompt_ms, + llm_ms=round((time.perf_counter() - llm_started_at) * 1000, 2), + overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2), ) - return False, result + return finalize(False) + llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2) response_text = (generation_result.response or "").strip() result.success = bool(response_text) result.completion = LLMCompletionResult( @@ -457,36 +477,33 @@ class MaisakaReplyGenerator: reasoning_text=generation_result.reasoning or "", model_name=generation_result.model_name or "", tool_calls=generation_result.tool_calls or [], - ) - result.metrics = GenerationMetrics( - overall_ms=round((time.perf_counter() - started_at) * 1000, 2), - ) - - # 向监控前端广播回复器响应事件 - await emit_replier_response( - session_id=preview_chat_id, - content=response_text, - reasoning=generation_result.reasoning or "", - model_name=generation_result.model_name or "", prompt_tokens=generation_result.prompt_tokens, completion_tokens=generation_result.completion_tokens, total_tokens=generation_result.total_tokens, - duration_ms=result.metrics.overall_ms or 0.0, - success=result.success, + ) + result.metrics = GenerationMetrics( + prompt_ms=prompt_ms, + llm_ms=llm_ms, + overall_ms=round((time.perf_counter() - overall_started_at) * 1000, 2), + stage_logs=[ + f"prompt: {prompt_ms} ms", + f"llm: 
{llm_ms} ms", + ], ) - if global_config.debug.show_replyer_reasoning and result.completion.reasoning_text: - logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}") + if show_replyer_reasoning and result.completion.reasoning_text: + logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}") if not result.success: result.error_message = "回复器返回了空内容" logger.warning("Maisaka 回复器返回了空内容") - return False, result + return finalize(False) logger.info( - f"Maisaka 回复器生成成功: 文本={response_text!r} 总耗时ms={result.metrics.overall_ms} 已选表达={result.selected_expression_ids!r}" + f"Maisaka 回复器生成成功: 文本={response_text!r} " + f"总耗时ms={result.metrics.overall_ms} 已选表达={result.selected_expression_ids!r}" ) - if global_config.debug.show_replyer_prompt or global_config.debug.show_replyer_reasoning: + if show_replyer_prompt or show_replyer_reasoning: summary_lines = [ f"流ID: {preview_chat_id or 'unknown'}", f"耗时: {result.metrics.overall_ms} ms", @@ -497,7 +514,7 @@ class MaisakaReplyGenerator: renderables: List[RenderableType] = [Text("\n".join(summary_lines))] if replyer_prompt_section is not None: renderables.append(replyer_prompt_section) - if global_config.debug.show_replyer_reasoning and result.completion.reasoning_text: + if show_replyer_reasoning and result.completion.reasoning_text: renderables.append( Panel( Text(result.completion.reasoning_text), @@ -523,4 +540,4 @@ class MaisakaReplyGenerator: ) ) result.text_fragments = [response_text] - return True, result + return finalize(True) diff --git a/src/common/data_models/reply_generation_data_models.py b/src/common/data_models/reply_generation_data_models.py index 0f394094..152ea687 100644 --- a/src/common/data_models/reply_generation_data_models.py +++ b/src/common/data_models/reply_generation_data_models.py @@ -1,6 +1,6 @@ """回复生成结果相关数据模型。 -该模块用于描述新版本回复链中的三个层次: +该模块用于描述新版回复链中的三个层次: 1. LLM 原始完成结果。 2. 
生成过程中的耗时与调试信息。 @@ -23,13 +23,6 @@ class LLMCompletionResult(BaseDataModel): 该模型只描述模型调用本身的输入与输出,不承载回复切分、 消息序列拼装或表达方式选择等后处理结果。 - - Attributes: - request_prompt: 实际发送给模型的 Prompt 文本。 - response_text: 模型返回的主文本内容。 - reasoning_text: 模型返回的推理内容。 - model_name: 本次请求实际使用的模型名称。 - tool_calls: 模型返回的工具调用列表。 """ request_prompt: str = field( @@ -52,19 +45,23 @@ class LLMCompletionResult(BaseDataModel): default_factory=list, metadata={"description": "模型返回的工具调用列表。"}, ) + prompt_tokens: int = field( + default=0, + metadata={"description": "本次请求的输入 Token 数。"}, + ) + completion_tokens: int = field( + default=0, + metadata={"description": "本次请求的输出 Token 数。"}, + ) + total_tokens: int = field( + default=0, + metadata={"description": "本次请求的总 Token 数。"}, + ) @dataclass class GenerationMetrics(BaseDataModel): - """一次生成流程的耗时与调试指标。 - - Attributes: - prompt_ms: Prompt 构建耗时,单位为毫秒。 - llm_ms: LLM 调用耗时,单位为毫秒。 - overall_ms: 整个生成流程总耗时,单位为毫秒。 - stage_logs: 各阶段的简短耗时日志列表。 - extra: 额外指标字典,用于承载不适合单独升格为字段的监控信息。 - """ + """一次生成流程的耗时与调试指标。""" prompt_ms: Optional[float] = field( default=None, @@ -90,20 +87,7 @@ class GenerationMetrics(BaseDataModel): @dataclass class ReplyGenerationResult(BaseDataModel): - """回复链的最终结构化结果。 - - 该模型用于承接回复器和生成服务合并后的最终产物,供 HFC、 - BrainChat、发送服务和日志系统继续消费。 - - Attributes: - success: 本次回复生成是否成功。 - completion: LLM 原始完成结果。 - metrics: 本次生成的耗时与调试指标。 - selected_expression_ids: 本次选中的表达方式 ID 列表。 - text_fragments: 对模型输出进行切分、规范化后的文本片段列表。 - message_sequence: 最终可直接发送的消息序列。 - error_message: 失败时的错误描述;成功时为空。 - """ + """回复链的最终结构化结果。""" success: bool = field( default=False, @@ -133,10 +117,70 @@ class ReplyGenerationResult(BaseDataModel): default_factory=str, metadata={"description": "失败时的错误描述;成功时通常为空字符串。"}, ) + monitor_detail: Optional[Dict[str, Any]] = field( + default=None, + metadata={"description": "供监控层直接消费的通用 tool 展示详情。"}, + ) + + +def build_reply_monitor_detail(result: ReplyGenerationResult) -> Dict[str, Any]: + """构建 reply 工具统一监控详情结构。""" + + detail: Dict[str, Any] = {} + prompt_text = 
result.completion.request_prompt.strip() + reasoning_text = result.completion.reasoning_text.strip() + output_text = result.completion.response_text.strip() + + if prompt_text: + detail["prompt_text"] = prompt_text + if reasoning_text: + detail["reasoning_text"] = reasoning_text + if output_text: + detail["output_text"] = output_text + + metrics: Dict[str, Any] = {} + if result.completion.model_name.strip(): + metrics["model_name"] = result.completion.model_name.strip() + if result.completion.prompt_tokens > 0: + metrics["prompt_tokens"] = result.completion.prompt_tokens + if result.completion.completion_tokens > 0: + metrics["completion_tokens"] = result.completion.completion_tokens + if result.completion.total_tokens > 0: + metrics["total_tokens"] = result.completion.total_tokens + if result.metrics.prompt_ms is not None: + metrics["prompt_ms"] = result.metrics.prompt_ms + if result.metrics.llm_ms is not None: + metrics["llm_ms"] = result.metrics.llm_ms + if result.metrics.overall_ms is not None: + metrics["overall_ms"] = result.metrics.overall_ms + if metrics: + detail["metrics"] = metrics + + extra_sections: List[Dict[str, str]] = [] + if result.selected_expression_ids: + extra_sections.append({ + "title": "已选表达方式", + "content": ", ".join(str(item) for item in result.selected_expression_ids), + }) + if result.metrics.stage_logs: + extra_sections.append({ + "title": "阶段日志", + "content": "\n".join(result.metrics.stage_logs), + }) + if result.error_message.strip(): + extra_sections.append({ + "title": "错误信息", + "content": result.error_message.strip(), + }) + if extra_sections: + detail["extra_sections"] = extra_sections + + return detail __all__ = [ "GenerationMetrics", "LLMCompletionResult", "ReplyGenerationResult", + "build_reply_monitor_detail", ] diff --git a/src/maisaka/builtin_tool/reply.py b/src/maisaka/builtin_tool/reply.py index d8c2dd24..73eec23c 100644 --- a/src/maisaka/builtin_tool/reply.py +++ b/src/maisaka/builtin_tool/reply.py @@ -57,6 +57,15 @@ def 
get_tool_spec() -> ToolSpec: ) +def _build_monitor_metadata(reply_result: object) -> dict[str, object]: + """从 reply 结果中提取统一监控详情。""" + + monitor_detail = getattr(reply_result, "monitor_detail", None) + if isinstance(monitor_detail, dict): + return {"monitor_detail": monitor_detail} + return {} + + async def handle_tool( tool_ctx: BuiltinToolRuntimeContext, invocation: ToolInvocation, @@ -71,7 +80,7 @@ async def handle_tool( if not target_message_id: return tool_ctx.build_failure_result( invocation.tool_name, - "回复工具需要提供有效的 `msg_id` 参数。", + "reply 工具需要提供有效的 `msg_id` 参数。", ) target_message = tool_ctx.runtime._source_messages_by_id.get(target_message_id) @@ -129,6 +138,7 @@ async def handle_tool( "生成可见回复时发生异常。", ) + reply_metadata = _build_monitor_metadata(reply_result) reply_text = reply_result.completion.response_text.strip() if success else "" if not reply_text: logger.warning( @@ -138,6 +148,7 @@ async def handle_tool( return tool_ctx.build_failure_result( invocation.tool_name, "生成可见回复失败。", + metadata=reply_metadata, ) reply_segments = tool_ctx.post_process_reply_text(reply_text) @@ -170,6 +181,7 @@ async def handle_tool( return tool_ctx.build_failure_result( invocation.tool_name, "发送可见回复时发生异常。", + metadata=reply_metadata, ) if not sent: @@ -181,6 +193,7 @@ async def handle_tool( "set_quote": set_quote, "reply_segments": reply_segments, }, + metadata=reply_metadata, ) target_user_info = target_message.message_info.user_info @@ -202,4 +215,5 @@ async def handle_tool( "reply_segments": reply_segments, "target_user_name": target_user_name, }, + metadata=reply_metadata, ) diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py index 5cbee7b2..35e9f195 100644 --- a/src/maisaka/chat_loop_service.py +++ b/src/maisaka/chat_loop_service.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from datetime import datetime -from time import perf_counter from typing import Any, List, Optional, Sequence import asyncio @@ -11,7 +10,6 @@ import random from 
pydantic import BaseModel, Field as PydanticField from rich.console import RenderableType -from rich.panel import Panel from src.common.data_models.llm_service_data_models import LLMGenerationOptions from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt @@ -35,7 +33,7 @@ from src.services.llm_service import LLMServiceClient from .builtin_tool import get_builtin_tools from .context_messages import AssistantMessage, LLMContextMessage -from .history_utils import drop_leading_orphan_tool_results +from .history_utils import drop_orphan_tool_results from .prompt_cli_renderer import PromptCLIVisualizer @@ -45,8 +43,10 @@ class ChatResponse: content: Optional[str] tool_calls: List[ToolCall] + request_messages: List[Message] raw_message: AssistantMessage selected_history_count: int + tool_count: int prompt_tokens: int built_message_count: int completion_tokens: int @@ -742,7 +742,6 @@ class MaisakaChatLoopService: folded=global_config.debug.fold_maisaka_thinking, ) - request_started_at = perf_counter() logger.info( "规划器请求开始: " f"已选上下文消息数={len(selected_history)} " @@ -808,8 +807,10 @@ class MaisakaChatLoopService: return ChatResponse( content=final_response or None, tool_calls=final_tool_calls, + request_messages=list(built_messages), raw_message=raw_message, selected_history_count=len(selected_history), + tool_count=len(all_tools), prompt_tokens=prompt_tokens, built_message_count=len(built_messages), completion_tokens=completion_tokens, @@ -846,7 +847,7 @@ class MaisakaChatLoopService: selected_indices.reverse() selected_history = [chat_history[index] for index in selected_indices] selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history) - selected_history, _ = drop_leading_orphan_tool_results(selected_history) + selected_history, _ = drop_orphan_tool_results(selected_history) selection_reason = ( f"上下文裁剪:最近 {effective_context_size} 条 user/assistant 消息," f"实际发送 
{len(selected_history)} 条" @@ -890,7 +891,7 @@ class MaisakaChatLoopService: selected_indices.reverse() selected_history = [chat_history[index] for index in selected_indices] selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history) - selected_history, _ = drop_leading_orphan_tool_results(selected_history) + selected_history, _ = drop_orphan_tool_results(selected_history) return ( selected_history, ( @@ -935,10 +936,10 @@ class MaisakaChatLoopService: return filtered_history, hidden_assistant_count @staticmethod - def _drop_leading_orphan_tool_results( + def _drop_orphan_tool_results( selected_history: List[LLMContextMessage], ) -> List[LLMContextMessage]: - """移除窗口前缀中缺少对应 tool_call 的工具结果消息。""" + """移除窗口中缺少对应 tool_call 的工具结果消息。""" - normalized_history, _ = drop_leading_orphan_tool_results(selected_history) + normalized_history, _ = drop_orphan_tool_results(selected_history) return normalized_history diff --git a/src/maisaka/history_utils.py b/src/maisaka/history_utils.py index 664a4211..6b2d086c 100644 --- a/src/maisaka/history_utils.py +++ b/src/maisaka/history_utils.py @@ -78,3 +78,30 @@ def drop_leading_orphan_tool_results( if first_valid_index == 0: return chat_history, 0 return chat_history[first_valid_index:], first_valid_index + + +def drop_orphan_tool_results( + chat_history: list[LLMContextMessage], +) -> tuple[list[LLMContextMessage], int]: + """移除窗口任意位置中缺少对应 tool_call 的工具结果消息。""" + + if not chat_history: + return chat_history, 0 + + available_tool_call_ids = { + tool_call.call_id + for message in chat_history + if isinstance(message, AssistantMessage) + for tool_call in message.tool_calls + if tool_call.call_id + } + + filtered_history: list[LLMContextMessage] = [] + removed_count = 0 + for message in chat_history: + if isinstance(message, ToolResultMessage) and message.tool_call_id not in available_tool_call_ids: + removed_count += 1 + continue + filtered_history.append(message) + + return 
filtered_history, removed_count diff --git a/src/maisaka/monitor_events.py b/src/maisaka/monitor_events.py index c94baae6..c25d0f30 100644 --- a/src/maisaka/monitor_events.py +++ b/src/maisaka/monitor_events.py @@ -1,74 +1,50 @@ """MaiSaka 实时监控事件广播模块。 -通过统一 WebSocket 将 MaiSaka 推理引擎各阶段的状态实时推送给前端监控界面, -无需落盘 HTML/TXT 中间文件即可在 WebUI 中渲染完整的聊天流推理过程。 +通过统一 WebSocket 将 MaiSaka 推理引擎各阶段状态实时推送给前端监控界面。 """ -from typing import Any, Dict, List, Optional - +from datetime import datetime import time +from typing import Any, Dict, List, Optional from src.common.logger import get_logger logger = get_logger("maisaka_monitor") -# WebSocket 广播使用的业务域与主题 MONITOR_DOMAIN = "maisaka_monitor" MONITOR_TOPIC = "main" -def _serialize_message(message: Any) -> Dict[str, Any]: - """将单条 LLM 消息序列化为可通过 WebSocket 传输的字典。 +def _normalize_payload_value(value: Any) -> Any: + """将事件载荷中的任意值规范化为可序列化结构。""" - 对二进制数据(如图片)仅保留元信息,不传输原始字节以减小带宽占用。 - - Args: - message: 原始消息对象,可以是 dict 或带 role/content 属性的消息实例。 - - Returns: - Dict[str, Any]: 序列化后的消息字典。 - """ - if isinstance(message, dict): - serialized: Dict[str, Any] = { - "role": str(message.get("role", "unknown")), - "content": message.get("content"), - } - if message.get("tool_call_id"): - serialized["tool_call_id"] = message["tool_call_id"] - if message.get("tool_calls"): - serialized["tool_calls"] = _serialize_tool_calls_from_dicts(message["tool_calls"]) - return serialized - - raw_role = getattr(message, "role", "unknown") - role_str = raw_role.value if hasattr(raw_role, "value") else str(raw_role) # type: ignore[union-attr] - - serialized = { - "role": role_str, - "content": _extract_text_content(getattr(message, "content", None)), - } - - tool_call_id = getattr(message, "tool_call_id", None) - if tool_call_id: - serialized["tool_call_id"] = str(tool_call_id) - - tool_calls = getattr(message, "tool_calls", None) - if tool_calls: - serialized["tool_calls"] = _serialize_tool_calls_from_objects(tool_calls) - - return serialized + if value is None or 
isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, dict): + normalized_dict: Dict[str, Any] = {} + for key, item in value.items(): + normalized_dict[str(key)] = _normalize_payload_value(item) + return normalized_dict + if isinstance(value, (list, tuple, set)): + return [_normalize_payload_value(item) for item in value] + if hasattr(value, "model_dump"): + try: + return _normalize_payload_value(value.model_dump()) + except Exception: + return str(value) + if hasattr(value, "__dict__"): + try: + return _normalize_payload_value(dict(value.__dict__)) + except Exception: + return str(value) + return str(value) def _extract_text_content(content: Any) -> Optional[str]: - """从消息内容中提取纯文本表示。 + """从消息内容中提取纯文本表示。""" - 支持字符串、列表(多模态内容块)等格式,对图片仅保留占位信息。 - - Args: - content: 消息的原始 content 字段。 - - Returns: - Optional[str]: 提取后的文本内容。 - """ if content is None: return None if isinstance(content, str): @@ -91,23 +67,17 @@ def _extract_text_content(content: Any) -> Optional[str]: def _serialize_tool_calls_from_objects(tool_calls: List[Any]) -> List[Dict[str, Any]]: - """将工具调用对象列表序列化为字典列表。 + """将工具调用对象列表序列化为字典列表。""" - Args: - tool_calls: 工具调用对象列表(ToolCall 或类似结构)。 - - Returns: - List[Dict[str, Any]]: 序列化后的工具调用列表。 - """ result: List[Dict[str, Any]] = [] - for tc in tool_calls: + for tool_call in tool_calls: serialized: Dict[str, Any] = { - "id": getattr(tc, "id", None) or getattr(tc, "tool_call_id", ""), - "name": getattr(tc, "func_name", None) or getattr(tc, "name", "unknown"), + "id": getattr(tool_call, "id", None) or getattr(tool_call, "call_id", ""), + "name": getattr(tool_call, "func_name", None) or getattr(tool_call, "name", "unknown"), } - args = getattr(tc, "args", None) or getattr(tc, "arguments", None) + args = getattr(tool_call, "args", None) or getattr(tool_call, "arguments", None) if isinstance(args, dict): - serialized["arguments"] = args + serialized["arguments"] = 
_normalize_payload_value(args) elif isinstance(args, str): serialized["arguments_raw"] = args result.append(serialized) @@ -115,73 +85,101 @@ def _serialize_tool_calls_from_objects(tool_calls: List[Any]) -> List[Dict[str, def _serialize_tool_calls_from_dicts(tool_calls: List[Any]) -> List[Dict[str, Any]]: - """将工具调用字典列表标准化为可传输格式。 + """将工具调用字典列表标准化为可传输格式。""" - Args: - tool_calls: 工具调用字典列表。 - - Returns: - List[Dict[str, Any]]: 标准化后的工具调用列表。 - """ result: List[Dict[str, Any]] = [] - for tc in tool_calls: - if isinstance(tc, dict): + for tool_call in tool_calls: + if isinstance(tool_call, dict): result.append({ - "id": tc.get("id", ""), - "name": tc.get("name", tc.get("func_name", "unknown")), - "arguments": tc.get("arguments", tc.get("args", {})), - }) - else: - result.append({ - "id": getattr(tc, "id", ""), - "name": getattr(tc, "func_name", "unknown"), - "arguments": getattr(tc, "args", {}), + "id": str(tool_call.get("id", "")), + "name": str(tool_call.get("name", tool_call.get("func_name", "unknown"))), + "arguments": _normalize_payload_value(tool_call.get("arguments", tool_call.get("args", {}))), }) + continue + + result.append({ + "id": str(getattr(tool_call, "id", getattr(tool_call, "call_id", ""))), + "name": str(getattr(tool_call, "func_name", getattr(tool_call, "name", "unknown"))), + "arguments": _normalize_payload_value(getattr(tool_call, "args", getattr(tool_call, "arguments", {}))), + }) return result +def _serialize_message(message: Any) -> Dict[str, Any]: + """将单条消息序列化为可通过 WebSocket 传输的字典。""" + + if isinstance(message, dict): + serialized: Dict[str, Any] = { + "role": str(message.get("role", "unknown")), + "content": _extract_text_content(message.get("content")), + } + if message.get("tool_call_id"): + serialized["tool_call_id"] = str(message["tool_call_id"]) + if message.get("tool_calls"): + serialized["tool_calls"] = _serialize_tool_calls_from_dicts(message["tool_calls"]) + return serialized + + raw_role = getattr(message, "role", "unknown") + role_str 
= raw_role.value if hasattr(raw_role, "value") else str(raw_role) + + serialized = { + "role": role_str, + "content": _extract_text_content(getattr(message, "content", None)), + } + tool_call_id = getattr(message, "tool_call_id", None) + if tool_call_id: + serialized["tool_call_id"] = str(tool_call_id) + + tool_calls = getattr(message, "tool_calls", None) + if tool_calls: + serialized["tool_calls"] = _serialize_tool_calls_from_objects(tool_calls) + + return serialized + + def _serialize_messages(messages: List[Any]) -> List[Dict[str, Any]]: - """批量序列化消息列表。 + """批量序列化消息列表。""" - Args: - messages: 原始消息列表。 + return [_serialize_message(message) for message in messages] - Returns: - List[Dict[str, Any]]: 序列化后的消息字典列表。 - """ - return [_serialize_message(msg) for msg in messages] + +def _serialize_tool_results(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """标准化最终 planner 卡中的工具结果列表。""" + + serialized_tools: List[Dict[str, Any]] = [] + for tool in tools: + serialized_tool = { + "tool_call_id": str(tool.get("tool_call_id", "")), + "tool_name": str(tool.get("tool_name", "")), + "tool_args": _normalize_payload_value(tool.get("tool_args", {})), + "success": bool(tool.get("success", False)), + "duration_ms": float(tool.get("duration_ms", 0.0) or 0.0), + "summary": str(tool.get("summary", "")), + } + detail = tool.get("detail") + if detail is not None: + serialized_tool["detail"] = _normalize_payload_value(detail) + serialized_tools.append(serialized_tool) + return serialized_tools async def _broadcast(event: str, data: Dict[str, Any]) -> None: - """通过统一 WebSocket 管理器向所有订阅了 maisaka_monitor 主题的连接广播事件。 + """通过统一 WebSocket 管理器向监控主题广播事件。""" - 延迟导入 websocket_manager 以避免循环依赖。 - - Args: - event: 事件名称。 - data: 事件数据。 - """ try: from src.webui.routers.websocket.manager import websocket_manager subscription_key = f"{MONITOR_DOMAIN}:{MONITOR_TOPIC}" total_connections = len(websocket_manager.connections) subscriber_count = sum( - 1 for conn in websocket_manager.connections.values() - 
if subscription_key in conn.subscriptions + 1 + for connection in websocket_manager.connections.values() + if subscription_key in connection.subscriptions ) - - # 诊断:打印 manager 对象 id 和连接状态 logger.info( f"[诊断] _broadcast: manager_id={id(websocket_manager)} " f"总连接={total_connections} 订阅者={subscriber_count} event={event}" ) - if subscriber_count == 0 and total_connections > 0: - for cid, conn in websocket_manager.connections.items(): - logger.info( - f"[诊断] 连接={cid[:8]}… 订阅={conn.subscriptions}" - ) - await websocket_manager.broadcast_to_topic( domain=MONITOR_DOMAIN, topic=MONITOR_TOPIC, @@ -193,12 +191,8 @@ async def _broadcast(event: str, data: Dict[str, Any]) -> None: async def emit_session_start(session_id: str, session_name: str) -> None: - """广播会话开始事件。 + """广播会话开始事件。""" - Args: - session_id: 聊天流 ID。 - session_name: 聊天流显示名称。 - """ await _broadcast("session.start", { "session_id": session_id, "session_name": session_name, @@ -213,17 +207,8 @@ async def emit_message_ingested( message_id: str, timestamp: float, ) -> None: - """广播新消息注入事件。 + """广播新消息注入事件。""" - 当新的用户消息被纳入 MaiSaka 推理上下文时触发。 - - Args: - session_id: 聊天流 ID。 - speaker_name: 发言者名称。 - content: 消息文本内容。 - message_id: 消息 ID。 - timestamp: 消息时间戳。 - """ await _broadcast("message.ingested", { "session_id": session_id, "speaker_name": speaker_name, @@ -240,15 +225,8 @@ async def emit_cycle_start( max_rounds: int, history_count: int, ) -> None: - """广播推理循环开始事件。 + """广播推理循环开始事件。""" - Args: - session_id: 聊天流 ID。 - cycle_id: 循环编号。 - round_index: 当前回合索引(从 0 开始)。 - max_rounds: 最大回合数。 - history_count: 当前上下文消息数。 - """ await _broadcast("cycle.start", { "session_id": session_id, "cycle_id": cycle_id, @@ -270,19 +248,8 @@ async def emit_timing_gate_result( selected_history_count: int, duration_ms: float, ) -> None: - """广播 Timing Gate 子代理结果事件。 + """广播 Timing Gate 结果事件。""" - Args: - session_id: 聊天流 ID。 - cycle_id: 循环编号。 - action: 控制决策(continue/wait/no_reply)。 - content: Timing Gate 返回的文本内容。 - tool_calls: 工具调用列表。 - messages: 
发送给 Timing Gate 的消息列表。 - prompt_tokens: 输入 Token 数。 - selected_history_count: 已选上下文消息数。 - duration_ms: 执行耗时(毫秒)。 - """ await _broadcast("timing_gate.result", { "session_id": session_id, "cycle_id": cycle_id, @@ -297,177 +264,45 @@ async def emit_timing_gate_result( }) -async def emit_planner_request( +async def emit_planner_finalized( + *, session_id: str, cycle_id: int, - messages: List[Any], - tool_count: int, + request_messages: List[Any], selected_history_count: int, -) -> None: - """广播规划器请求开始事件。 - - 携带完整的消息列表,前端可以增量渲染新增消息。 - - Args: - session_id: 聊天流 ID。 - cycle_id: 循环编号。 - messages: 发送给规划器的完整消息列表。 - tool_count: 可用工具数量。 - selected_history_count: 已选上下文消息数。 - """ - await _broadcast("planner.request", { - "session_id": session_id, - "cycle_id": cycle_id, - "messages": _serialize_messages(messages), - "tool_count": tool_count, - "selected_history_count": selected_history_count, - "timestamp": time.time(), - }) - - -async def emit_planner_response( - session_id: str, - cycle_id: int, - content: Optional[str], - tool_calls: List[Any], + tool_count: int, + planner_content: Optional[str], + planner_tool_calls: List[Any], prompt_tokens: int, completion_tokens: int, total_tokens: int, duration_ms: float, -) -> None: - """广播规划器响应事件。 - - Args: - session_id: 聊天流 ID。 - cycle_id: 循环编号。 - content: 规划器返回的思考文本。 - tool_calls: 规划器返回的工具调用列表。 - prompt_tokens: 输入 Token 数。 - completion_tokens: 输出 Token 数。 - total_tokens: 总 Token 数。 - duration_ms: 执行耗时(毫秒)。 - """ - await _broadcast("planner.response", { - "session_id": session_id, - "cycle_id": cycle_id, - "content": content, - "tool_calls": _serialize_tool_calls_from_objects(tool_calls), - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "duration_ms": duration_ms, - "timestamp": time.time(), - }) - - -async def emit_tool_execution( - session_id: str, - cycle_id: int, - tool_name: str, - tool_args: Dict[str, Any], - result_summary: str, - success: bool, - duration_ms: float, 
-) -> None: - """广播工具执行结果事件。 - - Args: - session_id: 聊天流 ID。 - cycle_id: 循环编号。 - tool_name: 工具名称。 - tool_args: 工具参数。 - result_summary: 执行结果摘要。 - success: 是否成功。 - duration_ms: 执行耗时(毫秒)。 - """ - await _broadcast("tool.execution", { - "session_id": session_id, - "cycle_id": cycle_id, - "tool_name": tool_name, - "tool_args": tool_args, - "result_summary": result_summary, - "success": success, - "duration_ms": duration_ms, - "timestamp": time.time(), - }) - - -async def emit_cycle_end( - session_id: str, - cycle_id: int, + tools: List[Dict[str, Any]], time_records: Dict[str, float], agent_state: str, ) -> None: - """广播推理循环结束事件。 + """广播一轮 planner 结束后的最终聚合事件。""" - Args: - session_id: 聊天流 ID。 - cycle_id: 循环编号。 - time_records: 各阶段耗时记录。 - agent_state: 循环结束后的代理状态。 - """ - await _broadcast("cycle.end", { + await _broadcast("planner.finalized", { "session_id": session_id, "cycle_id": cycle_id, - "time_records": time_records, - "agent_state": agent_state, - "timestamp": time.time(), - }) - - -async def emit_replier_request( - session_id: str, - messages: List[Any], - model_name: str = "", -) -> None: - """广播回复器请求开始事件。 - - Args: - session_id: 聊天流 ID。 - messages: 发送给回复器的消息列表。 - model_name: 使用的模型名称。 - """ - await _broadcast("replier.request", { - "session_id": session_id, - "messages": _serialize_messages(messages), - "model_name": model_name, - "timestamp": time.time(), - }) - - -async def emit_replier_response( - session_id: str, - content: Optional[str], - reasoning: str, - model_name: str, - prompt_tokens: int, - completion_tokens: int, - total_tokens: int, - duration_ms: float, - success: bool, -) -> None: - """广播回复器响应事件。 - - Args: - session_id: 聊天流 ID。 - content: 回复器生成的文本。 - reasoning: 回复器的思考过程文本。 - model_name: 使用的模型名称。 - prompt_tokens: 输入 Token 数。 - completion_tokens: 输出 Token 数。 - total_tokens: 总 Token 数。 - duration_ms: 执行耗时(毫秒)。 - success: 是否生成成功。 - """ - await _broadcast("replier.response", { - "session_id": session_id, - "content": content, - "reasoning": reasoning, - 
"model_name": model_name, - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "duration_ms": duration_ms, - "success": success, "timestamp": time.time(), + "request": { + "messages": _serialize_messages(request_messages), + "selected_history_count": selected_history_count, + "tool_count": tool_count, + }, + "planner": { + "content": planner_content, + "tool_calls": _serialize_tool_calls_from_objects(planner_tool_calls), + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "duration_ms": duration_ms, + }, + "tools": _serialize_tool_results(tools), + "final_state": { + "time_records": _normalize_payload_value(time_records), + "agent_state": agent_state, + }, }) diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index 6269b8ec..af4525df 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -36,12 +36,10 @@ from .context_messages import ( ) from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text, drop_leading_orphan_tool_results from .monitor_events import ( - emit_cycle_end, emit_cycle_start, emit_message_ingested, - emit_planner_response, + emit_planner_finalized, emit_timing_gate_result, - emit_tool_execution, ) from .planner_message_utils import build_planner_user_prefix_from_session_message @@ -279,6 +277,7 @@ class MaisakaReasoningEngine: ChatResponse( content=reason, tool_calls=[], + request_messages=[], raw_message=AssistantMessage( content="", timestamp=datetime.now(), @@ -288,6 +287,7 @@ class MaisakaReasoningEngine: sum(1 for message in self._runtime._chat_history if message.count_in_context), self._runtime._max_context_size, ), + tool_count=0, prompt_tokens=0, built_message_count=0, completion_tokens=0, @@ -346,6 +346,9 @@ class MaisakaReasoningEngine: history_count=len(self._runtime._chat_history), ) planner_started_at = 0.0 + planner_duration_ms 
= 0.0 + response: Optional[ChatResponse] = None + tool_monitor_results: list[dict[str, Any]] = [] try: visual_refresh_started_at = time.time() refreshed_message_count = await self._refresh_chat_history_visual_placeholders() @@ -403,17 +406,6 @@ class MaisakaReasoningEngine: f"回合={round_index + 1} " f"耗时={cycle_detail.time_records['planner']:.3f} 秒" ) - await emit_planner_response( - session_id=self._runtime.session_id, - cycle_id=cycle_detail.cycle_id, - content=response.content, - tool_calls=response.tool_calls, - prompt_tokens=response.prompt_tokens, - completion_tokens=response.completion_tokens, - total_tokens=response.total_tokens, - duration_ms=planner_duration_ms, - ) - reasoning_content = response.content or "" if self._should_replace_reasoning(reasoning_content): response.content = "我应该根据我上面思考的内容进行反思,重新思考我下一步的行动,我需要分析当前场景,对话,以及我可以使用的工具,然后先输出想法再使用工具" @@ -423,10 +415,11 @@ class MaisakaReasoningEngine: self._last_reasoning_content = reasoning_content self._runtime._chat_history.append(response.raw_message) tool_result_summaries: list[str] = [] + tool_monitor_results = [] if response.tool_calls: tool_started_at = time.time() - should_pause, tool_result_summaries = await self._handle_tool_calls( + should_pause, tool_result_summaries, tool_monitor_results = await self._handle_tool_calls( response.tool_calls, response.content or "", anchor_message, @@ -463,13 +456,24 @@ class MaisakaReasoningEngine: ) break finally: - self._end_cycle(cycle_detail) - await emit_cycle_end( - session_id=self._runtime.session_id, - cycle_id=cycle_detail.cycle_id, - time_records=dict(cycle_detail.time_records), - agent_state=self._runtime._agent_state, - ) + completed_cycle = self._end_cycle(cycle_detail) + if response is not None: + await emit_planner_finalized( + session_id=self._runtime.session_id, + cycle_id=cycle_detail.cycle_id, + request_messages=response.request_messages, + selected_history_count=response.selected_history_count, + tool_count=response.tool_count, + 
planner_content=response.content, + planner_tool_calls=response.tool_calls, + prompt_tokens=response.prompt_tokens, + completion_tokens=response.completion_tokens, + total_tokens=response.total_tokens, + duration_ms=planner_duration_ms, + tools=tool_monitor_results, + time_records=dict(completed_cycle.time_records), + agent_state=self._runtime._agent_state, + ) finally: if self._runtime._agent_state == self._runtime._STATE_RUNNING: self._runtime._agent_state = self._runtime._STATE_STOP @@ -683,7 +687,7 @@ class MaisakaReasoningEngine: def _drop_leading_orphan_tool_results( chat_history: list[LLMContextMessage], ) -> tuple[list[LLMContextMessage], int]: - """清理历史前缀中缺少对应 assistant tool_call 的工具结果消息。""" + """清理历史窗口中缺少对应 assistant tool_call 的工具结果消息。""" return drop_leading_orphan_tool_results(chat_history) @@ -1039,12 +1043,38 @@ class MaisakaReasoningEngine: normalized_content = self._truncate_tool_record_text(history_content, max_length=200) return f"- {tool_call.func_name} {summary_prefix}: {normalized_content}" + def _build_tool_monitor_result( + self, + tool_call: ToolCall, + invocation: ToolInvocation, + result: ToolExecutionResult, + duration_ms: float, + ) -> dict[str, Any]: + """构建 planner.finalized 中单个工具的监控结果。""" + + monitor_detail = result.metadata.get("monitor_detail") + normalized_detail = None + if monitor_detail is not None: + normalized_detail = self._normalize_tool_record_value(monitor_detail) + + return { + "tool_call_id": tool_call.call_id, + "tool_name": tool_call.func_name, + "tool_args": self._normalize_tool_record_value( + invocation.arguments if isinstance(invocation.arguments, dict) else {} + ), + "success": result.success, + "duration_ms": round(duration_ms, 2), + "summary": self._build_tool_result_summary(tool_call, result), + "detail": normalized_detail, + } + async def _handle_tool_calls( self, tool_calls: list[ToolCall], latest_thought: str, anchor_message: SessionMessage, - ) -> tuple[bool, list[str]]: + ) -> tuple[bool, list[str], 
list[dict[str, Any]]]: """执行一批统一工具调用。 Args: @@ -1057,6 +1087,7 @@ class MaisakaReasoningEngine: """ tool_result_summaries: list[str] = [] + tool_monitor_results: list[dict[str, Any]] = [] if self._runtime._tool_registry is None: for tool_call in tool_calls: @@ -1069,7 +1100,10 @@ class MaisakaReasoningEngine: await self._store_tool_execution_record(invocation, result, None) self._append_tool_execution_result(tool_call, result) tool_result_summaries.append(self._build_tool_result_summary(tool_call, result)) - return False, tool_result_summaries + tool_monitor_results.append( + self._build_tool_monitor_result(tool_call, invocation, result, duration_ms=0.0) + ) + return False, tool_result_summaries, tool_monitor_results execution_context = self._build_tool_execution_context(latest_thought, anchor_message) tool_spec_map = { @@ -1088,24 +1122,17 @@ class MaisakaReasoningEngine: ) self._append_tool_execution_result(tool_call, result) tool_result_summaries.append(self._build_tool_result_summary(tool_call, result)) + tool_monitor_results.append( + self._build_tool_monitor_result(tool_call, invocation, result, tool_duration_ms) + ) # 向监控前端广播工具执行结果 - cycle_id = self._runtime._current_cycle_detail.cycle_id if self._runtime._current_cycle_detail else 0 - await emit_tool_execution( - session_id=self._runtime.session_id, - cycle_id=cycle_id, - tool_name=tool_call.func_name, - tool_args=invocation.arguments if isinstance(invocation.arguments, dict) else {}, - result_summary=result.content[:500] if result.content else (result.error_message or "")[:500], - success=result.success, - duration_ms=tool_duration_ms, - ) if not result.success and tool_call.func_name == "reply": logger.warning(f"{self._runtime.log_prefix} 回复工具未生成可见消息,将继续下一轮循环") if bool(result.metadata.get("pause_execution", False)): - return True, tool_result_summaries + return True, tool_result_summaries, tool_monitor_results - return False, tool_result_summaries + return False, tool_result_summaries, 
tool_monitor_results diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py new file mode 100644 index 00000000..94f4390f --- /dev/null +++ b/src/memory_system/chat_history_summarizer.py @@ -0,0 +1,1123 @@ +""" +聊天内容概括器 +用于累积、打包和压缩聊天记录 +""" + +import asyncio +import json +import time +import re +import difflib +import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Set +from dataclasses import dataclass, field +from json_repair import repair_json + +from src.chat.message_receive.message import SessionMessage +from src.common.logger import get_logger +from src.config.config import global_config +from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.services.llm_service import LLMServiceClient +from src.services import message_service as message_api +from src.chat.utils.utils import is_bot_self +from src.person_info.person_info import Person +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.prompt.prompt_manager import prompt_manager + +logger = get_logger("chat_history_summarizer") + +HIPPO_CACHE_DIR = Path(__file__).resolve().parents[2] / "data" / "hippo_memorizer" + + +@dataclass +class MessageBatch: + """消息批次(用于触发话题检查的原始消息累积)""" + + messages: List[SessionMessage] + start_time: float + end_time: float + + +@dataclass +class TopicCacheItem: + """ + 话题缓存项 + + Attributes: + topic: 话题标题(一句话描述时间、人物、事件和主题) + messages: 与该话题相关的消息字符串列表(已经通过 build 函数转成可读文本) + participants: 涉及到的发言人昵称集合 + no_update_checks: 连续多少次“检查”没有新增内容 + """ + + topic: str + messages: List[str] = field(default_factory=list) + participants: Set[str] = field(default_factory=set) + no_update_checks: int = 0 + + +class ChatHistorySummarizer: + """聊天内容概括器""" + + def __init__(self, session_id: str, check_interval: int = 60): + """ + 初始化聊天内容概括器 + + Args: + session_id: 会话ID + check_interval: 定期检查间隔(秒),默认60秒 + """ + self.session_id = session_id + 
self._chat_display_name = self._get_chat_display_name() + self.log_prefix = f"[{self._chat_display_name}]" + + # 记录时间点,用于计算新消息 + self.last_check_time = time.time() + + # 记录上一次话题检查的时间,用于判断是否需要触发检查 + self.last_topic_check_time = time.time() + + # 当前累积的消息批次 + self.current_batch: Optional[MessageBatch] = None + + # 话题缓存:topic_str -> TopicCacheItem + # 在内存中维护,并通过本地文件实时持久化 + self.topic_cache: Dict[str, TopicCacheItem] = {} + self._safe_chat_id = self._sanitize_chat_id(self.session_id) + self._topic_cache_file = HIPPO_CACHE_DIR / f"{self._safe_chat_id}.json" + # 注意:批次加载需要异步查询消息,所以在 start() 中调用 + + # LLM请求器,用于压缩聊天内容 + self.summarizer_llm = LLMServiceClient( + task_name="utils", request_type="chat_history_summarizer" + ) + + # 后台循环相关 + self.check_interval = check_interval # 检查间隔(秒) + self._periodic_task: Optional[asyncio.Task] = None + self._running = False + + def _get_chat_display_name(self) -> str: + """获取聊天显示名称""" + try: + chat_name = _chat_manager.get_session_name(self.session_id) + if chat_name: + return chat_name + # 如果获取失败,使用简化的chat_id显示 + if len(self.session_id) > 20: + return f"{self.session_id[:8]}..." + return self.session_id + except Exception: + # 如果获取失败,使用简化的chat_id显示 + if len(self.session_id) > 20: + return f"{self.session_id[:8]}..." 
+ return self.session_id + + def _sanitize_chat_id(self, chat_id: str) -> str: + """用于生成可作为文件名的 chat_id""" + return re.sub(r"[^a-zA-Z0-9_.-]", "_", chat_id) + + def _load_topic_cache_from_disk(self): + """在启动时加载本地话题缓存(同步部分),支持重启后继续""" + try: + if not self._topic_cache_file.exists(): + return + + with self._topic_cache_file.open("r", encoding="utf-8") as f: + data = json.load(f) + + self.last_topic_check_time = data.get("last_topic_check_time", self.last_topic_check_time) + topics_data = data.get("topics", {}) + loaded_count = 0 + for topic, payload in topics_data.items(): + self.topic_cache[topic] = TopicCacheItem( + topic=topic, + messages=payload.get("messages", []), + participants=set(payload.get("participants", [])), + no_update_checks=payload.get("no_update_checks", 0), + ) + loaded_count += 1 + + if loaded_count: + logger.info(f"{self.log_prefix} 已加载 {loaded_count} 个话题缓存,继续追踪") + except Exception as e: + logger.error(f"{self.log_prefix} 加载话题缓存失败: {e}") + + async def _load_batch_from_disk(self): + """在启动时加载聊天批次,支持重启后继续""" + try: + if not self._topic_cache_file.exists(): + return + + with self._topic_cache_file.open("r", encoding="utf-8") as f: + data = json.load(f) + + batch_data = data.get("current_batch") + if not batch_data: + return + + start_time = batch_data.get("start_time") + end_time = batch_data.get("end_time") + if not start_time or not end_time: + return + + # 根据时间范围重新查询消息 + messages = message_api.get_messages_by_time_in_chat( + chat_id=self.session_id, + start_time=start_time, + end_time=end_time, + limit=0, + limit_mode="latest", + filter_mai=False, + filter_command=False, + ) + + if messages: + self.current_batch = MessageBatch( + messages=messages, + start_time=start_time, + end_time=end_time, + ) + logger.info(f"{self.log_prefix} 已恢复聊天批次,包含 {len(messages)} 条消息") + except Exception as e: + logger.error(f"{self.log_prefix} 加载聊天批次失败: {e}") + + def _persist_topic_cache(self): + """实时持久化话题缓存和聊天批次,避免重启后丢失""" + try: + # 如果既没有话题缓存也没有批次,删除缓存文件 + if not 
self.topic_cache and not self.current_batch: + if self._topic_cache_file.exists(): + self._topic_cache_file.unlink() + return + + HIPPO_CACHE_DIR.mkdir(parents=True, exist_ok=True) + data = { + "chat_id": self.session_id, + "last_topic_check_time": self.last_topic_check_time, + "topics": { + topic: { + "messages": item.messages, + "participants": list(item.participants), + "no_update_checks": item.no_update_checks, + } + for topic, item in self.topic_cache.items() + }, + } + + # 保存当前批次的时间范围(如果有) + if self.current_batch: + data["current_batch"] = { + "start_time": self.current_batch.start_time, + "end_time": self.current_batch.end_time, + } + + with self._topic_cache_file.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + except Exception as e: + logger.error(f"{self.log_prefix} 持久化话题缓存失败: {e}") + + async def process(self, current_time: Optional[float] = None): + """ + 处理聊天内容概括 + + Args: + current_time: 当前时间戳,如果为None则使用time.time() + """ + if current_time is None: + current_time = time.time() + + try: + # 获取从上次检查时间到当前时间的新消息 + new_messages = message_api.get_messages_by_time_in_chat( + chat_id=self.session_id, + start_time=self.last_check_time, + end_time=current_time, + limit=0, + limit_mode="latest", + filter_mai=False, # 不过滤bot消息,因为需要检查bot是否发言 + filter_command=False, + ) + + if not new_messages: + # 没有新消息,检查是否需要进行“话题检查” + if self.current_batch and self.current_batch.messages: + await self._check_and_run_topic_check(current_time) + self.last_check_time = current_time + return + + logger.debug( + f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}" + ) + + # 有新消息,更新最后检查时间 + self.last_check_time = current_time + + # 如果有当前批次,添加新消息 + if self.current_batch: + before_count = len(self.current_batch.messages) + self.current_batch.messages.extend(new_messages) + self.current_batch.end_time = current_time + logger.info( + f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息" + 
) + # 更新批次后持久化 + self._persist_topic_cache() + else: + # 创建新批次 + self.current_batch = MessageBatch( + messages=new_messages, + start_time=new_messages[0].timestamp.timestamp() if new_messages else current_time, + end_time=current_time, + ) + logger.debug(f"{self.log_prefix} 新建聊天检查批次: {len(new_messages)} 条消息") + # 创建批次后持久化 + self._persist_topic_cache() + + # 检查是否需要触发“话题检查” + await self._check_and_run_topic_check(current_time) + + except Exception as e: + logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}") + import traceback + + traceback.print_exc() + + async def _check_and_run_topic_check(self, current_time: float): + """ + 检查是否需要进行一次“话题检查” + + 触发条件: + - 当前批次消息数 >= 100,或者 + - 距离上一次检查的时间 > 3600 秒(1小时) + """ + if not self.current_batch or not self.current_batch.messages: + return + + messages = self.current_batch.messages + message_count = len(messages) + time_since_last_check = current_time - self.last_topic_check_time + + # 格式化时间差显示 + if time_since_last_check < 60: + time_str = f"{time_since_last_check:.1f}秒" + elif time_since_last_check < 3600: + time_str = f"{time_since_last_check / 60:.1f}分钟" + else: + time_str = f"{time_since_last_check / 3600:.1f}小时" + + logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}") + + # 检查"话题检查"触发条件 + should_check = False + + # 从配置中获取阈值 + message_threshold = global_config.memory.chat_history_topic_check_message_threshold + time_threshold_hours = global_config.memory.chat_history_topic_check_time_hours + min_messages = global_config.memory.chat_history_topic_check_min_messages + time_threshold_seconds = time_threshold_hours * 3600 + + # 条件1: 消息数量达到阈值,触发一次检查 + if message_count >= message_threshold: + should_check = True + logger.info( + f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: {message_threshold}条)" + ) + + # 条件2: 距离上一次检查超过时间阈值且消息数量达到最小阈值,触发一次检查 + elif time_since_last_check > time_threshold_seconds and message_count >= min_messages: + should_check = True + logger.info( + 
    async def _run_topic_check_and_update_cache(self, messages: List[SessionMessage]):
        """Run one "topic check" over a batch and merge results into the topic cache.

        Steps:
        1. Drop the batch outright unless the bot itself spoke in it;
        2. Number the messages and render them as strings for the LLM prompt;
        3. Ask the LLM (given the list of known topic titles) to identify one
           or more topics and the message indices belonging to each, reusing
           an existing title when the topic is already known;
        4. Parse the returned JSON list of {topic, message_indices};
        5. Update the local topic cache and, per the staleness / size rules,
           trigger "pack and store" for qualifying topics.
        """
        if not messages:
            return

        start_time = messages[0].timestamp.timestamp()
        end_time = messages[-1].timestamp.timestamp()

        logger.info(
            f"{self.log_prefix} 开始话题检查 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
        )

        # 1. Only this batch is checked for bot participation (no look-back):
        #    we record only conversation fragments the bot took part in, so a
        #    batch without any bot message is not worth remembering.
        has_bot_message = any(
            is_bot_self(msg.platform, msg.message_info.user_info.user_id) for msg in messages
        )

        if not has_bot_message:
            logger.info(
                f"{self.log_prefix} 当前批次内无 Bot 发言,丢弃本次检查 | 时间范围: {start_time:.2f} - {end_time:.2f}"
            )
            return

        # 2. Build the numbered message strings and per-message participants.
        numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
            self._build_numbered_messages_for_llm(messages)
        )

        # 3. Ask the LLM for topic -> message indices (retry up to 3 times on failure).
        existing_topics = list(self.topic_cache.keys())
        max_retries = 3
        attempt = 0
        success = False
        topic_to_indices: Dict[str, List[int]] = {}

        while attempt < max_retries:
            attempt += 1
            success, topic_to_indices = await self._analyze_topics_with_llm(
                numbered_lines=numbered_lines,
                existing_topics=existing_topics,
            )

            if success and topic_to_indices:
                if attempt > 1:
                    logger.info(
                        f"{self.log_prefix} 话题识别在第 {attempt} 次重试后成功 | 话题数: {len(topic_to_indices)}"
                    )
                break

            logger.warning(
                f"{self.log_prefix} 话题识别失败或无有效话题,第 {attempt} 次尝试失败"
                + ("" if attempt >= max_retries else ",准备重试")
            )

        if not success or not topic_to_indices:
            logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
            # A failed analysis still counts as "a check", but no_update_checks
            # is deliberately left untouched so topics are not penalised for it.
            return

        # 3.5. Map new topic titles onto sufficiently similar historical ones
        #      (similarity >= 90% reuses the historical title).
        topic_mapping = self._build_topic_mapping(topic_to_indices, similarity_threshold=0.9)

        # Apply the mapping: similar new titles are replaced by historical titles.
        if topic_mapping:
            new_topic_to_indices: Dict[str, List[int]] = {}
            for new_topic, indices in topic_to_indices.items():
                # Does this new topic map onto a historical one?
                if new_topic in topic_mapping:
                    historical_topic = topic_mapping[new_topic]
                    # If the historical title already collected indices, merge.
                    if historical_topic in new_topic_to_indices:
                        # Merge and de-duplicate (NOTE: set() discards order).
                        combined_indices = list(set(new_topic_to_indices[historical_topic] + indices))
                        new_topic_to_indices[historical_topic] = combined_indices
                    else:
                        new_topic_to_indices[historical_topic] = indices
                else:
                    # No mapping needed; keep as-is.
                    new_topic_to_indices[new_topic] = indices
            topic_to_indices = new_topic_to_indices

        # 4. Record which topics gained content during this check.
        updated_topics: Set[str] = set()

        for topic, indices in topic_to_indices.items():
            if not indices:
                continue

            item = self.topic_cache.get(topic)
            if not item:
                # First sighting of this topic.
                item = TopicCacheItem(topic=topic)
                self.topic_cache[topic] = item

            # Collect the (un-numbered) message texts belonging to this topic.
            topic_msg_texts: List[str] = []
            new_participants: Set[str] = set()
            for idx in indices:
                msg_text = index_to_msg_text.get(idx)
                if not msg_text:
                    continue
                topic_msg_texts.append(msg_text)
                new_participants.update(index_to_participants.get(idx, set()))

            if not topic_msg_texts:
                continue

            # Store all of this check's messages for the topic as one merged string.
            merged_text = "\n".join(topic_msg_texts)
            item.messages.append(merged_text)
            item.participants.update(new_participants)
            # The topic was updated this check: reset its staleness counter.
            item.no_update_checks = 0
            updated_topics.add(topic)

        # 5. Historical topics with no new content this check get their counter bumped.
        for topic, item in list(self.topic_cache.items()):
            if topic not in updated_topics:
                item.no_update_checks += 1

        # 6. Decide which topics should be packed and stored.
        no_update_checks_threshold = global_config.memory.chat_history_finalize_no_update_checks
        message_count_threshold = global_config.memory.chat_history_finalize_message_count

        topics_to_finalize: List[str] = []
        for topic, item in self.topic_cache.items():
            if item.no_update_checks >= no_update_checks_threshold:
                logger.info(
                    f"{self.log_prefix} 话题[{topic}] 连续 {no_update_checks_threshold} 次检查无新增内容,触发打包存储"
                )
                topics_to_finalize.append(topic)
                continue
            if len(item.messages) > message_count_threshold:
                logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 {message_count_threshold},触发打包存储")
                topics_to_finalize.append(topic)

        for topic in topics_to_finalize:
            item = self.topic_cache.get(topic)
            if not item:
                continue
            try:
                await self._finalize_and_store_topic(
                    topic=topic,
                    item=item,
                    # Best effort: the time range covers the most recent check interval.
                    start_time=start_time,
                    end_time=end_time,
                )
            finally:
                # Always evict the topic, success or not, to avoid duplicates.
                self.topic_cache.pop(topic, None)
_find_most_similar_topic( + self, new_topic: str, existing_topics: List[str], similarity_threshold: float = 0.9 + ) -> Optional[tuple[str, float]]: + """ + 查找与给定新话题最相似的历史话题 + + Args: + new_topic: 新话题标题 + existing_topics: 历史话题标题列表 + similarity_threshold: 相似度阈值,默认0.9(90%) + + Returns: + Optional[tuple[str, float]]: 如果找到相似度>=阈值的历史话题,返回(历史话题标题, 相似度), + 否则返回None + """ + if not existing_topics: + return None + + best_match = None + best_similarity = 0.0 + + for existing_topic in existing_topics: + similarity = difflib.SequenceMatcher(None, new_topic, existing_topic).ratio() + if similarity > best_similarity: + best_similarity = similarity + best_match = existing_topic + + # 如果相似度达到阈值,返回匹配结果 + if best_match and best_similarity >= similarity_threshold: + return (best_match, best_similarity) + + return None + + def _build_topic_mapping( + self, topic_to_indices: Dict[str, List[int]], similarity_threshold: float = 0.9 + ) -> Dict[str, str]: + """ + 构建新话题到历史话题的映射(如果相似度>=阈值) + + Args: + topic_to_indices: 新话题到消息索引的映射 + similarity_threshold: 相似度阈值,默认0.9(90%) + + Returns: + Dict[str, str]: 新话题 -> 历史话题的映射字典 + """ + existing_topics_list = list(self.topic_cache.keys()) + topic_mapping: Dict[str, str] = {} + + for new_topic in topic_to_indices.keys(): + # 如果新话题已经在历史话题中,不需要检查 + if new_topic in existing_topics_list: + continue + + # 查找最相似的历史话题 + result = self._find_most_similar_topic(new_topic, existing_topics_list, similarity_threshold) + if result: + historical_topic, similarity = result + topic_mapping[new_topic] = historical_topic + logger.info( + f"{self.log_prefix} 话题相似度检查: '{new_topic}' 与历史话题 '{historical_topic}' 相似度 {similarity:.2%},使用历史标题" + ) + + return topic_mapping + + def _build_numbered_messages_for_llm( + self, messages: List[SessionMessage] + ) -> tuple[List[str], Dict[int, str], Dict[int, str], Dict[int, Set[str]]]: + """ + 将消息转为带编号的字符串,供 LLM 选择使用。 + + 返回: + numbered_lines: ["1. xxx", "2. yyy", ...] # 带编号,用于 LLM 选择 + index_to_msg_str: idx -> "idx. 
    def _build_numbered_messages_for_llm(
        self, messages: List[SessionMessage]
    ) -> tuple[List[str], Dict[int, str], Dict[int, str], Dict[int, Set[str]]]:
        """Render messages as numbered strings for the LLM to select from.

        Returns:
            numbered_lines: ["1. xxx", "2. yyy", ...] -- numbered, shown to the LLM
            index_to_msg_str: idx -> "idx. xxx" -- numbered, for LLM selection
            index_to_msg_text: idx -> "xxx" -- un-numbered, used for final storage
            index_to_participants: idx -> {nickname1, nickname2, ...}
        """
        numbered_lines: List[str] = []
        index_to_msg_str: Dict[int, str] = {}
        index_to_msg_text: Dict[int, str] = {}  # message text without the index prefix
        index_to_participants: Dict[int, Set[str]] = {}

        for idx, msg in enumerate(messages, start=1):
            # Render a human-readable line via the message API.
            try:
                text = message_api.build_readable_messages(
                    messages=[msg],
                    replace_bot_name=True,
                    timestamp_mode="normal_no_YMD",
                    read_mark=0.0,
                    truncate=False,
                    show_actions=False,
                ).strip()
            except Exception:
                # Fall back to the message's raw plain text if rendering fails.
                text = getattr(msg, "processed_plain_text", "") or ""

            # Resolve the speaker's display name, best effort; any lookup
            # failure simply yields an empty participant set for this message.
            participants: Set[str] = set()
            try:
                platform = msg.platform
                user_id = msg.message_info.user_info.user_id
                if platform and user_id:
                    person = Person(platform=platform, user_id=user_id)
                    if person.person_name:
                        participants.add(person.person_name)
            except Exception:
                pass

            # Numbered line for the LLM to reference by index.
            line = f"{idx}. {text}"
            numbered_lines.append(line)
            index_to_msg_str[idx] = line
            # Un-numbered text kept for final storage.
            index_to_msg_text[idx] = text
            index_to_participants[idx] = participants

        return numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants
    async def _analyze_topics_with_llm(
        self,
        numbered_lines: List[str],
        existing_topics: List[str],
    ) -> tuple[bool, Dict[str, List[int]]]:
        """Ask the LLM to identify topics and pick message indices for each.

        Contract given to the model:
        - describe each topic in one clear sentence (time, people, event, theme);
        - one or more topics are allowed;
        - when a topic matches one in the historical topic list, reuse that
          exact title string;
        - respond with JSON of the form:
          [
            {"topic": "topic title string", "message_indices": [1, 2, 5]},
            ...
          ]

        Returns:
            (ok, topic -> positive 1-based message indices). ok is False when
            the call failed or the response could not be parsed into a JSON list.
        """
        if not numbered_lines:
            return False, {}

        history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
        messages_block = "\n".join(numbered_lines)

        prompt_template = prompt_manager.get_prompt("hippo_topic_analysis")
        prompt_template.add_context("history_topics_block", history_topics_block)
        prompt_template.add_context("messages_block", messages_block)
        prompt = await prompt_manager.render_prompt(prompt_template)

        try:
            generation_result = await self.summarizer_llm.generate_response(
                prompt=prompt,
                options=LLMGenerationOptions(temperature=0.3),
            )
            response = generation_result.response

            logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
            logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")

            # Prefer an explicit ```json fenced block when present.
            json_str = None
            json_pattern = r"```json\s*(.*?)\s*```"
            matches = re.findall(json_pattern, response, re.DOTALL)

            if matches:
                # A fenced JSON block was found; use the first match.
                json_str = matches[0].strip()
            else:
                # Otherwise take the span between the first '[' and last ']'.
                start_idx = response.find("[")
                end_idx = response.rfind("]")
                if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
                    json_str = response[start_idx : end_idx + 1].strip()
                else:
                    # Last resort: strip any fence markers and try the whole response.
                    json_str = response.strip()
                    json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
                    json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
                    json_str = json_str.strip()

            # Repair possibly-malformed JSON before parsing.
            if json_str:
                try:
                    repaired_json = repair_json(json_str)
                    result = json.loads(repaired_json) if isinstance(repaired_json, str) else repaired_json
                except Exception as repair_error:
                    # repair_json failed; fall back to a direct parse.
                    logger.warning(f"{self.log_prefix} JSON修复失败,尝试直接解析: {repair_error}")
                    result = json.loads(json_str)
            else:
                raise ValueError("无法从响应中提取JSON内容")

            if not isinstance(result, list):
                logger.error(f"{self.log_prefix} 话题识别返回的 JSON 不是列表: {result}")
                return False, {}

            topic_to_indices: Dict[str, List[int]] = {}
            for item in result:
                if not isinstance(item, dict):
                    continue
                topic = item.get("topic")
                # Accept either key name; models sometimes emit "messages".
                indices = item.get("message_indices") or item.get("messages") or []
                if not topic or not isinstance(topic, str):
                    continue
                if isinstance(indices, list):
                    # Keep only values coercible to positive integers.
                    valid_indices: List[int] = []
                    for v in indices:
                        try:
                            iv = int(v)
                            if iv > 0:
                                valid_indices.append(iv)
                        except (TypeError, ValueError):
                            continue
                    if valid_indices:
                        topic_to_indices[topic] = valid_indices

            return True, topic_to_indices

        except Exception as e:
            logger.error(f"{self.log_prefix} 话题识别 LLM 调用或解析失败: {e}")
            logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
            return False, {}
完成后,调用方会从缓存中删除该话题。 + """ + if not item.messages: + logger.info(f"{self.log_prefix} 话题[{topic}] 无消息内容,跳过打包") + return + + original_text = "\n".join(item.messages) + + logger.info( + f"{self.log_prefix} 开始将聊天记录构建成记忆:[{topic}] | 消息数: {len(item.messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}" + ) + + # 使用 LLM 进行总结(基于话题名),带重试机制 + max_retries = 3 + attempt = 0 + success = False + keywords = [] + summary = "" + + while attempt < max_retries: + attempt += 1 + success, keywords, summary = await self._compress_with_llm(original_text, topic) + + if success and keywords and summary: + # 成功获取到有效的 keywords 和 summary + if attempt > 1: + logger.info(f"{self.log_prefix} 话题[{topic}] LLM 概括在第 {attempt} 次重试后成功") + break + + if attempt < max_retries: + logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败(第 {attempt} 次尝试),准备重试") + else: + logger.error(f"{self.log_prefix} 话题[{topic}] LLM 概括连续 {max_retries} 次失败,放弃存储") + + if not success or not keywords or not summary: + logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败,不写入数据库") + return + + participants = list(item.participants) + + await self._store_to_database( + start_time=start_time, + end_time=end_time, + original_text=original_text, + participants=participants, + theme=topic, # 主题直接使用话题名 + keywords=keywords, + summary=summary, + ) + + logger.info( + f"{self.log_prefix} 话题[{topic}] 成功打包并存储 | 消息数: {len(item.messages)} | 参与者数: {len(participants)}" + ) + + async def _compress_with_llm(self, original_text: str, topic: str) -> tuple[bool, List[str], str]: + """ + 使用LLM压缩聊天内容(用于单个话题的最终总结) + + Args: + original_text: 聊天记录原文 + topic: 话题名称 + + Returns: + tuple[bool, List[str], str]: (是否成功, 关键词列表, 概括) + """ + prompt_template = prompt_manager.get_prompt("hippo_topic_summary") + prompt_template.add_context("topic", topic) + prompt_template.add_context("original_text", original_text) + prompt = await prompt_manager.render_prompt(prompt_template) + + try: + generation_result = await self.summarizer_llm.generate_response(prompt=prompt) 
+ response = generation_result.response + + # 解析JSON响应 + json_str = response.strip() + json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE) + json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE) + json_str = json_str.strip() + + # 查找JSON对象的开始与结束 + start_idx = json_str.find("{") + if start_idx == -1: + raise ValueError("未找到JSON对象开始标记") + + end_idx = json_str.rfind("}") + if end_idx == -1 or end_idx <= start_idx: + logger.warning(f"{self.log_prefix} JSON缺少结束标记,尝试自动修复") + extracted_json = json_str[start_idx:] + else: + extracted_json = json_str[start_idx : end_idx + 1] + + def _parse_with_quote_fix(payload: str) -> Dict[str, Any]: + fixed_chars: List[str] = [] + in_string = False + escape_next = False + i = 0 + while i < len(payload): + char = payload[i] + if escape_next: + fixed_chars.append(char) + escape_next = False + elif char == "\\": + fixed_chars.append(char) + escape_next = True + elif char == '"' and not escape_next: + fixed_chars.append(char) + in_string = not in_string + elif in_string and char in {"“", "”"}: + # 在字符串值内部,将中文引号替换为转义的英文引号 + fixed_chars.append('\\"') + else: + fixed_chars.append(char) + i += 1 + + repaired = "".join(fixed_chars) + return json.loads(repaired) + + try: + result = json.loads(extracted_json) + except json.JSONDecodeError: + try: + repaired_json = repair_json(extracted_json) + if isinstance(repaired_json, str): + result = json.loads(repaired_json) + else: + result = repaired_json + except Exception as repair_error: + logger.warning(f"{self.log_prefix} repair_json 失败,使用引号修复: {repair_error}") + result = _parse_with_quote_fix(extracted_json) + + keywords = result.get("keywords", []) + summary = result.get("summary", "") + + # 检查必需字段是否为空 + if not keywords or not summary: + logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少必需字段,原文\n{response}") + # 返回失败,和模型出错一样,让上层进行重试 + return False, [], "" + + # 确保keywords是列表 + if isinstance(keywords, str): + keywords = [keywords] + + return True, keywords, summary + + except 
Exception as e: + logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}") + logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") + # 返回失败标志和默认值 + return False, [], "压缩失败,无法生成概括" + + async def _store_to_database( + self, + start_time: float, + end_time: float, + original_text: str, + participants: List[str], + theme: str, + keywords: List[str], + summary: str, + ): + """存储到数据库""" + try: + from src.common.database.database_model import ChatHistory + from src.services import database_service as database_api + + # 准备数据 + data = { + "session_id": self.session_id, + "start_timestamp": datetime.fromtimestamp(start_time), + "end_timestamp": datetime.fromtimestamp(end_time), + "original_messages": original_text, + "participants": json.dumps(participants, ensure_ascii=False), + "theme": theme, + "keywords": json.dumps(keywords, ensure_ascii=False), + "summary": summary, + "query_count": 0, + "query_forget_count": 0, + } + + saved_record = await database_api.db_save( + ChatHistory, + data=data, + ) + + if saved_record: + logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库") + else: + logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败") + + if saved_record and saved_record.get("id") is not None: + await self._import_to_long_term_memory( + record_id=int(saved_record["id"]), + theme=theme, + summary=summary, + participants=participants, + start_time=start_time, + end_time=end_time, + original_text=original_text, + ) + + except Exception as e: + logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}") + import traceback + + traceback.print_exc() + raise + + async def _import_to_long_term_memory( + self, + record_id: int, + theme: str, + summary: str, + participants: List[str], + start_time: float, + end_time: float, + original_text: str, + ): + """ + 将聊天历史总结导入到统一长期记忆 + + Args: + record_id: chat_history 主键 + theme: 话题主题 + summary: 概括内容 + participants: 参与者列表 + start_time: 开始时间 + end_time: 结束时间 + original_text: 原始文本(可能很长,需要截断) + """ + try: + from 
    async def _import_to_long_term_memory(
        self,
        record_id: int,
        theme: str,
        summary: str,
        participants: List[str],
        start_time: float,
        end_time: float,
        original_text: str,
    ):
        """Feed one chat-history summary into unified long-term memory.

        Args:
            record_id: chat_history primary key
            theme: topic theme
            summary: summarised content
            participants: participant display names
            start_time: interval start (unix seconds)
            end_time: interval end (unix seconds)
            original_text: raw text (may be long; only its length is stored in metadata)
        """
        try:
            from src.services.memory_service import memory_service
            session = _chat_manager.get_session_by_session_id(self.session_id)
            session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else ""
            session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else ""

            # Compose the ingestible text from theme, summary and participants.
            content_parts = []
            if theme:
                content_parts.append(f"主题:{theme}")
            if summary:
                content_parts.append(f"概括:{summary}")
            if participants:
                participants_text = "、".join(participants)
                content_parts.append(f"参与者:{participants_text}")
            content_to_import = "\n".join(content_parts)

            if not content_to_import.strip():
                # Nothing usable to ingest; let the plugin side regenerate the
                # summary from chat context instead.
                logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,改用插件侧 generate_from_chat 兜底")
                await self._fallback_import_to_long_term_memory(
                    record_id=record_id,
                    theme=theme,
                    participants=participants,
                    start_time=start_time,
                    end_time=end_time,
                    original_text=original_text,
                )
                return

            result = await memory_service.ingest_summary(
                external_id=f"chat_history:{record_id}",
                chat_id=self.session_id,
                text=content_to_import,
                participants=participants,
                time_start=start_time,
                time_end=end_time,
                tags=[theme] if theme else [],
                metadata={"theme": theme, "original_text_length": len(original_text or "")},
                respect_filter=True,
                user_id=session_user_id,
                group_id=session_group_id,
            )
            if result.success:
                if result.detail == "chat_filtered":
                    # Skipped on purpose by the chat filter policy; not an error.
                    logger.debug(f"{self.log_prefix} 聊天历史总结被聊天过滤策略跳过 | 话题: {theme}")
                else:
                    logger.info(f"{self.log_prefix} 成功将聊天历史总结导入到长期记忆 | 话题: {theme}")
            else:
                # Primary ingest failed; fall back to plugin-side regeneration.
                logger.warning(f"{self.log_prefix} 将聊天历史总结导入到长期记忆失败,尝试插件侧兜底 | 话题: {theme} | 错误: {result.detail}")
                await self._fallback_import_to_long_term_memory(
                    record_id=record_id,
                    theme=theme,
                    participants=participants,
                    start_time=start_time,
                    end_time=end_time,
                    original_text=original_text,
                )

        except Exception as e:
            logger.error(f"{self.log_prefix} 导入聊天历史总结到长期记忆时出错: {e}", exc_info=True)
+ *, + record_id: int, + theme: str, + participants: List[str], + start_time: float, + end_time: float, + original_text: str, + ) -> None: + try: + from src.services.memory_service import memory_service + session = _chat_manager.get_session_by_session_id(self.session_id) + session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else "" + session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else "" + + result = await memory_service.ingest_summary( + external_id=f"chat_history:{record_id}", + chat_id=self.session_id, + text="", + participants=participants, + time_start=start_time, + time_end=end_time, + tags=[theme] if theme else [], + metadata={ + "theme": theme, + "original_text_length": len(original_text or ""), + "generate_from_chat": True, + "context_length": global_config.memory.chat_history_topic_check_message_threshold, + }, + respect_filter=True, + user_id=session_user_id, + group_id=session_group_id, + ) + if result.success: + if result.detail == "chat_filtered": + logger.debug(f"{self.log_prefix} 插件侧 generate_from_chat 兜底被聊天过滤策略跳过 | 话题: {theme}") + else: + logger.info(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入成功 | 话题: {theme}") + else: + logger.warning(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入失败 | 话题: {theme} | 错误: {result.detail}") + except Exception as exc: + logger.error(f"{self.log_prefix} 插件侧兜底导入长期记忆失败: {exc}", exc_info=True) + + async def start(self): + """启动后台定期检查循环""" + if self._running: + logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动") + return + + # 加载聊天批次(如果有) + await self._load_batch_from_disk() + + self._running = True + self._periodic_task = asyncio.create_task(self._periodic_check_loop()) + logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}秒") + + async def stop(self): + """停止后台定期检查循环""" + self._running = False + if self._periodic_task: + self._periodic_task.cancel() + try: + await self._periodic_task + except asyncio.CancelledError: + pass + 
self._periodic_task = None + logger.info(f"{self.log_prefix} 已停止后台定期检查循环") + + async def _periodic_check_loop(self): + """后台定期检查循环""" + try: + while self._running: + # 执行一次检查 + await self.process() + + # 等待指定间隔后再次检查 + await asyncio.sleep(self.check_interval) + except asyncio.CancelledError: + logger.info(f"{self.log_prefix} 后台检查循环被取消") + raise + except Exception as e: + logger.error(f"{self.log_prefix} 后台检查循环出错: {e}") + import traceback + + traceback.print_exc() + self._running = False diff --git a/tests/test_maisaka_orphan_tool_results.py b/tests/test_maisaka_orphan_tool_results.py new file mode 100644 index 00000000..f7edcb50 --- /dev/null +++ b/tests/test_maisaka_orphan_tool_results.py @@ -0,0 +1,49 @@ +from datetime import datetime + +from src.common.data_models.message_component_data_model import MessageSequence, TextComponent +from src.llm_models.payload_content.tool_option import ToolCall +from src.maisaka.chat_loop_service import MaisakaChatLoopService +from src.maisaka.context_messages import AssistantMessage, SessionBackedMessage, ToolResultMessage + + +def _build_user_message(text: str) -> SessionBackedMessage: + return SessionBackedMessage( + raw_message=MessageSequence([TextComponent(text)]), + visible_text=text, + timestamp=datetime.now(), + ) + + +def test_select_llm_context_messages_drops_orphan_tool_results_anywhere() -> None: + assistant_message = AssistantMessage( + content="", + timestamp=datetime.now(), + tool_calls=[ToolCall(call_id="call_1", func_name="wait", args={"seconds": 30})], + ) + orphan_tool_message = ToolResultMessage( + content="当前对话循环已暂停,等待新消息到来。", + timestamp=datetime.now(), + tool_call_id="orphan_call", + ) + matched_tool_message = ToolResultMessage( + content="等待 30 秒。", + timestamp=datetime.now(), + tool_call_id="call_1", + tool_name="wait", + ) + chat_history = [ + _build_user_message("第一条消息"), + orphan_tool_message, + assistant_message, + matched_tool_message, + _build_user_message("第二条消息"), + ] + + selected_history, _ = 
def test_select_llm_context_messages_drops_orphan_tool_results_anywhere() -> None:
    """An orphan tool result (no matching assistant tool_call id) must be
    dropped from the selected context, while the assistant message and its
    matched tool result survive — regardless of position in the history."""
    assistant_message = AssistantMessage(
        content="",
        timestamp=datetime.now(),
        tool_calls=[ToolCall(call_id="call_1", func_name="wait", args={"seconds": 30})],
    )
    # Tool result whose call_id matches no assistant tool_call: the orphan.
    orphan_tool_message = ToolResultMessage(
        content="当前对话循环已暂停,等待新消息到来。",
        timestamp=datetime.now(),
        tool_call_id="orphan_call",
    )
    # Tool result properly paired with the assistant's call_1.
    matched_tool_message = ToolResultMessage(
        content="等待 30 秒。",
        timestamp=datetime.now(),
        tool_call_id="call_1",
        tool_name="wait",
    )
    history = [
        _build_user_message("第一条消息"),
        orphan_tool_message,
        assistant_message,
        matched_tool_message,
        _build_user_message("第二条消息"),
    ]

    selected_history, _ = MaisakaChatLoopService.select_llm_context_messages(
        history,
        max_context_size=8,
    )

    assert orphan_tool_message not in selected_history
    assert assistant_message in selected_history
    assert matched_tool_message in selected_history