From 01ba4f55c2ccbc93d2de18a0038addf0c0b77a6c Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Mon, 13 Apr 2026 18:57:50 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E4=B8=8D=E5=86=8D?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E7=9A=84=E8=81=8A=E5=A4=A9=E6=80=BB=E7=BB=93?= =?UTF-8?q?=EF=BC=8C=E7=A7=BB=E9=99=A4=E8=B7=AF=E5=BE=84=E6=98=BE=E7=A4=BA?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E9=A1=B9=EF=BC=8C=E4=BF=AE=E5=A4=8D=E5=9B=9E?= =?UTF-8?q?=E5=A4=8D=E5=99=A8=E9=94=99=E5=88=A4=E7=9A=84=E4=B8=80=E4=B8=AA?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...t_chat_history_summarizer_memory_import.py | 148 -- ...test_group_chat_stream_memory_benchmark.py | 1187 ----------------- .../test_long_novel_memory_benchmark.py | 687 ---------- .../test_long_novel_memory_benchmark_live.py | 342 ----- .../test_memory_flow_service.py | 74 +- ...real_dialogue_business_flow_integration.py | 324 ----- .../test_real_dialogue_business_flow_live.py | 301 ----- src/chat/message_receive/bot.py | 7 - src/common/logger_color_and_mapping.py | 1 - src/config/config.py | 2 +- src/config/official_configs.py | 89 -- src/maisaka/builtin_tool/send_emoji.py | 67 +- src/maisaka/chat_loop_service.py | 2 - src/maisaka/display/prompt_cli_renderer.py | 14 +- src/maisaka/runtime.py | 1 - src/memory_system/chat_history_summarizer.py | 1123 ---------------- src/services/memory_flow_service.py | 48 +- 17 files changed, 22 insertions(+), 4395 deletions(-) delete mode 100644 pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py delete mode 100644 pytests/A_memorix_test/test_group_chat_stream_memory_benchmark.py delete mode 100644 pytests/A_memorix_test/test_long_novel_memory_benchmark.py delete mode 100644 pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py delete mode 100644 pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py delete mode 100644 pytests/A_memorix_test/test_real_dialogue_business_flow_live.py delete mode 100644 src/memory_system/chat_history_summarizer.py diff --git a/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py b/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py deleted file mode 100644 index 0f084ece..00000000 --- a/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py +++ /dev/null @@ -1,148 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from src.memory_system import chat_history_summarizer as summarizer_module - - -def _build_summarizer() -> summarizer_module.ChatHistorySummarizer: - summarizer = summarizer_module.ChatHistorySummarizer.__new__(summarizer_module.ChatHistorySummarizer) - summarizer.session_id = "session-1" - summarizer.log_prefix = "[session-1]" - return summarizer - - -@pytest.mark.asyncio -async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch): - calls = [] - summarizer = _build_summarizer() - - async def fake_ingest_summary(**kwargs): - calls.append(kwargs) - return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) - - monkeypatch.setattr( - summarizer_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")), - ) - monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8))) - monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) - - await summarizer._import_to_long_term_memory( - record_id=1, - theme="旅行计划", - summary="我们讨论了春游安排", - participants=["Alice", "Bob"], - start_time=1.0, - end_time=2.0, - original_text="long text", - ) - - assert len(calls) == 1 - payload = calls[0] - assert payload["external_id"] == "chat_history:1" - assert payload["chat_id"] == "session-1" - assert payload["participants"] == ["Alice", "Bob"] - assert payload["respect_filter"] is True - assert payload["user_id"] == "user-1" - assert payload["group_id"] == "" - assert "主题:旅行计划" in payload["text"] - assert "概括:我们讨论了春游安排" in payload["text"] - - -@pytest.mark.asyncio -async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch): - summarizer = _build_summarizer() - fallback_calls = [] - - async def fake_fallback(**kwargs): - fallback_calls.append(kwargs) - - summarizer._fallback_import_to_long_term_memory = fake_fallback - monkeypatch.setattr( - summarizer_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")), - ) - - await summarizer._import_to_long_term_memory( - record_id=2, - theme="", - summary="", - participants=[], - start_time=10.0, - end_time=20.0, - original_text="raw chat", - ) - - assert len(fallback_calls) == 1 - assert fallback_calls[0]["record_id"] == 2 - assert fallback_calls[0]["original_text"] == "raw chat" - - -@pytest.mark.asyncio -async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch): - summarizer = _build_summarizer() - fallback_calls = [] - - async def fake_ingest_summary(**kwargs): - return SimpleNamespace(success=False, detail="boom", stored_ids=[]) - - async def fake_fallback(**kwargs): - fallback_calls.append(kwargs) - - summarizer._fallback_import_to_long_term_memory = fake_fallback - monkeypatch.setattr( - summarizer_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")), - ) - monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) - - await summarizer._import_to_long_term_memory( - record_id=3, - theme="电影", - summary="聊了电影推荐", - participants=["Alice"], - start_time=3.0, - end_time=4.0, - original_text="raw", - ) - - assert len(fallback_calls) == 1 - assert fallback_calls[0]["theme"] == "电影" - - -@pytest.mark.asyncio -async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch): - calls = [] - summarizer = _build_summarizer() - - async def fake_ingest_summary(**kwargs): - calls.append(kwargs) - return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[]) - - monkeypatch.setattr( - summarizer_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")), - ) - monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12))) - monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) - - await summarizer._fallback_import_to_long_term_memory( - record_id=4, - theme="工作", - participants=["Alice"], - start_time=5.0, - end_time=6.0, - original_text="a" * 128, - ) - - assert len(calls) == 1 - metadata = calls[0]["metadata"] - assert metadata["generate_from_chat"] is True - assert metadata["context_length"] == 12 - assert calls[0]["respect_filter"] is True - diff --git a/pytests/A_memorix_test/test_group_chat_stream_memory_benchmark.py b/pytests/A_memorix_test/test_group_chat_stream_memory_benchmark.py deleted file mode 100644 index b20f4e5d..00000000 --- a/pytests/A_memorix_test/test_group_chat_stream_memory_benchmark.py +++ /dev/null @@ -1,1187 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import json -import os -import re -import sys -import tempfile -import types -import typing -from datetime import datetime -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict, List - -import numpy as np - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -PLUGINS_ROOT = PROJECT_ROOT / "plugins" -SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk" - -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) -if str(PLUGINS_ROOT) not in sys.path: - sys.path.insert(0, str(PLUGINS_ROOT)) -if str(SDK_ROOT) not in sys.path: - sys.path.insert(0, str(SDK_ROOT)) - -if "maibot_sdk" not in sys.modules: - maibot_sdk = types.ModuleType("maibot_sdk") - maibot_sdk_types = types.ModuleType("maibot_sdk.types") - - class _FakeMaiBotPlugin: - def __init__(self, *args, **kwargs) -> None: - del args, kwargs - - def _fake_tool(*decorator_args, **decorator_kwargs): - del decorator_args, decorator_kwargs - - def wrapper(func): - return func - - return wrapper - - class _FakeToolParameterInfo: - def __init__(self, name: str, param_type: str, description: str, required: bool) -> None: - self.name = name - self.param_type = param_type - self.description = description - self.required = required - - class _FakeToolParamType: - STRING = "string" - INTEGER = "integer" - FLOAT = "float" - BOOLEAN = "boolean" - - maibot_sdk.MaiBotPlugin = _FakeMaiBotPlugin - maibot_sdk.Tool = _fake_tool - maibot_sdk_types.ToolParameterInfo = _FakeToolParameterInfo - maibot_sdk_types.ToolParamType = _FakeToolParamType - sys.modules["maibot_sdk"] = maibot_sdk - sys.modules["maibot_sdk.types"] = maibot_sdk_types - -try: - import aiohttp # type: ignore -except Exception: - aiohttp = types.ModuleType("aiohttp") - - class _FakeAioHttpClientError(Exception): - pass - - aiohttp.ClientError = _FakeAioHttpClientError - sys.modules["aiohttp"] = aiohttp - -try: - import openai # type: ignore -except Exception: - openai = types.ModuleType("openai") - - class _FakeOpenAIConnectionError(Exception): - pass - - class _FakeOpenAITimeoutError(Exception): - pass - - openai.APIConnectionError = _FakeOpenAIConnectionError - openai.APITimeoutError = _FakeOpenAITimeoutError - sys.modules["openai"] = openai - -if "eval_type_backport" not in sys.modules: - eval_type_backport_module = types.ModuleType("eval_type_backport") - - def _rewrite_union(expr: str) -> str: - clean = str(expr or "").strip() - if "|" not in clean: - return clean - parts = [part.strip() for part in clean.split("|")] - if len(parts) == 2 and "None" in parts: - target = parts[0] if parts[1] == "None" else parts[1] - return f"typing.Optional[{target}]" - return f"typing.Union[{', '.join(parts)}]" - - def _eval_type_backport(value, globalns=None, localns=None, try_default=False): - del try_default - expr = getattr(value, "__forward_arg__", value) - if not isinstance(expr, str): - expr = str(expr) - gns = dict(globalns or {}) - lns = dict(localns or {}) - builtin_aliases = { - "dict": typing.Dict, - "list": typing.List, - "set": typing.Set, - "tuple": typing.Tuple, - } - for namespace in (gns, lns): - namespace.setdefault("typing", typing) - namespace.setdefault("Any", typing.Any) - namespace.setdefault("Optional", typing.Optional) - namespace.setdefault("Union", typing.Union) - namespace.setdefault("Literal", getattr(typing, "Literal", None)) - namespace.update({key: value for key, value in builtin_aliases.items() if key not in namespace}) - - try: - return eval(expr, gns, lns) - except TypeError: - return eval(_rewrite_union(expr), gns, lns) - - eval_type_backport_module.eval_type_backport = _eval_type_backport - sys.modules["eval_type_backport"] = eval_type_backport_module - -from A_memorix.core.runtime import sdk_memory_kernel as kernel_module -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import MemorySearchResult, memory_service - - -def _resolve_benchmark_paths() -> tuple[Path, Path]: - configured = str(os.environ.get("A_MEMORIX_BENCHMARK_DATA_FILE", "") or "").strip() - if configured: - data_file = Path(configured).expanduser() - if not data_file.is_absolute(): - data_file = (Path.cwd() / data_file).resolve() - else: - data_file = Path(__file__).parent / "data" / "benchmarks" / "group_chat_stream_memory_benchmark.json" - - report_file = Path(__file__).parent / "data" / "benchmarks" / "results" / f"{data_file.stem}_report.json" - return data_file, report_file - - -DATA_FILE, REPORT_FILE = _resolve_benchmark_paths() - - -def _load_benchmark_fixture() -> Dict[str, Any]: - return json.loads(DATA_FILE.read_text(encoding="utf-8")) - - -class _PatchManager: - def __init__(self) -> None: - self._records: List[tuple[Any, str, Any, bool]] = [] - - def setattr(self, target: Any, name: str, value: Any) -> None: - existed = hasattr(target, name) - original = getattr(target, name) if existed else None - self._records.append((target, name, original, existed)) - setattr(target, name, value) - - def undo(self) -> None: - while self._records: - target, name, original, existed = self._records.pop() - if existed: - setattr(target, name, original) - else: - delattr(target, name) - - -class _FakeEmbeddingAdapter: - def __init__(self, dimension: int = 32) -> None: - self.dimension = dimension - - async def _detect_dimension(self) -> int: - return self.dimension - - async def encode(self, texts, dimensions=None): - dim = int(dimensions or self.dimension) - if isinstance(texts, str): - sequence = [texts] - single = True - else: - sequence = list(texts) - single = False - - rows = [] - for text in sequence: - vec = np.zeros(dim, dtype=np.float32) - for ch in str(text or ""): - code = ord(ch) - vec[code % dim] += 1.0 - vec[(code * 7) % dim] += 0.5 - if not vec.any(): - vec[0] = 1.0 - norm = np.linalg.norm(vec) - if norm > 0: - vec = vec / norm - rows.append(vec) - payload = np.vstack(rows) - return payload[0] if single else payload - - -class _KnownPerson: - def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None: - self.person_id = person_id - self.is_known = person_id in reverse_registry - self.person_name = reverse_registry.get(person_id, "") - self._registry = registry - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000): - del timeout_ms - payload = args or {} - if component_name == "search_memory": - return await self.kernel.search_memory( - KernelSearchRequest( - query=str(payload.get("query", "") or ""), - limit=int(payload.get("limit", 5) or 5), - mode=str(payload.get("mode", "hybrid") or "hybrid"), - chat_id=str(payload.get("chat_id", "") or ""), - person_id=str(payload.get("person_id", "") or ""), - time_start=payload.get("time_start"), - time_end=payload.get("time_end"), - respect_filter=bool(payload.get("respect_filter", True)), - user_id=str(payload.get("user_id", "") or ""), - group_id=str(payload.get("group_id", "") or ""), - ) - ) - - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]: - for _ in range(max_rounds): - detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) - task = detail.get("task") or {} - status = str(task.get("status", "") or "") - if status in {"completed", "completed_with_errors", "failed", "cancelled"}: - return detail - await asyncio.sleep(max(0.01, float(sleep_seconds))) - raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") - - -def _join_hit_content(search_result: MemorySearchResult) -> str: - return "\n".join(hit.content for hit in search_result.hits) - - -def _keyword_hits(text: str, keywords: List[str]) -> int: - haystack = str(text or "") - return sum(1 for keyword in keywords if keyword in haystack) - - -def _keyword_recall(text: str, keywords: List[str]) -> float: - if not keywords: - return 1.0 - return _keyword_hits(text, keywords) / float(len(keywords)) - - -def _hit_blob(hit) -> str: - meta = hit.metadata if isinstance(hit.metadata, dict) else {} - return "\n".join( - [ - str(hit.content or ""), - str(hit.title or ""), - str(hit.source or ""), - json.dumps(meta, ensure_ascii=False), - ] - ) - - -def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int: - for index, hit in enumerate(search_result.hits[:5], start=1): - if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))): - return index - return 0 - - -def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str: - return "\n".join( - ( - f"{item.get('title', '')}\n" - f"{item.get('summary', '')}\n" - f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n" - f"{json.dumps(item.get('participants', []), ensure_ascii=False)}" - ) - for item in items - ) - - -def _episode_blob_from_hits(search_result: MemorySearchResult) -> str: - chunks = [] - for hit in search_result.hits: - meta = hit.metadata if isinstance(hit.metadata, dict) else {} - chunks.append( - "\n".join( - [ - str(hit.title or ""), - str(hit.content or ""), - json.dumps(meta.get("keywords", []) or [], ensure_ascii=False), - json.dumps(meta.get("participants", []) or [], ensure_ascii=False), - ] - ) - ) - return "\n".join(chunks) - - -async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - episode_source = f"chat_summary:{session_id}" - payload = await memory_service.episode_admin(action="query", source=episode_source, limit=20) - items = payload.get("items") or [] - blob = _episode_blob_from_items(items) - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in episode_cases: - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(items), - "top_episode": items[0] if items else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "episode_count": len(items), - "reports": reports, - } - - -async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - episode_source = f"chat_summary:{session_id}" - - for case in episode_cases: - payload = await memory_service.episode_admin( - action="query", - source=episode_source, - query=case["query"], - limit=5, - ) - items = payload.get("items") or [] - blob = "\n".join( - f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}" - for item in items - ) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(items), - "top_episode": items[0] if items else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in episode_cases: - result = await memory_service.search( - case["query"], - mode="episode", - chat_id=session_id, - respect_filter=False, - limit=5, - ) - blob = _episode_blob_from_hits(result) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(result.hits), - "top_episode": result.hits[0].to_dict() if result.hits else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_time_cases(*, session_id: str, time_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in time_cases: - result = await memory_service.search( - case["query"], - mode="time", - chat_id=session_id, - time_start=case["time_expression"], - time_end=case["time_expression"], - respect_filter=False, - limit=5, - ) - blob = "\n".join(_hit_blob(hit) for hit in result.hits) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(result.hits) and recall >= 0.67 - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "time_expression": case["time_expression"], - "success": success, - "keyword_recall": recall, - "hit_count": len(result.hits), - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - - total = max(1, len(time_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]: - search_case = dataset["search_cases"][0] - time_case = dataset["time_cases"][0] - episode_case = dataset["episode_cases"][0] - aggregate_case = dataset["knowledge_fetcher_cases"][0] - tool_cases = [ - { - "name": "search", - "kwargs": { - "query": search_case["query"], - "mode": "search", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": search_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "time", - "kwargs": { - "query": time_case["query"], - "mode": "time", - "chat_id": session_id, - "limit": 5, - "time_expression": time_case["time_expression"], - }, - "expected_keywords": time_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "episode", - "kwargs": { - "query": episode_case["query"], - "mode": "episode", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": episode_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "aggregate", - "kwargs": { - "query": aggregate_case["query"], - "mode": "aggregate", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": aggregate_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - ] - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in tool_cases: - text = await query_long_term_memory(**case["kwargs"]) - recall = _keyword_recall(text, case["expected_keywords"]) - success = ( - "失败" not in text - and "无法解析" not in text - and "未找到" not in text - and recall >= float(case["minimum_keyword_recall"]) - ) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "name": case["name"], - "success": success, - "keyword_recall": recall, - "preview": text[:320], - } - ) - - total = max(1, len(tool_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -def _parse_fixture_message( - *, - line: str, - session_id: str, - platform: str, - group_id: str, - group_name: str, - speaker_to_user_id: Dict[str, str], - seq: int, -): - match = re.match(r"^\[(?P[^\]]+)\]\s*(?P[^:]+):(?P.*)$", str(line or "").strip()) - if not match: - raise ValueError(f"无法解析 fixture 消息: {line}") - - dt = datetime.strptime(match.group("ts"), "%Y-%m-%d %H:%M") - speaker = match.group("speaker").strip() - text = match.group("text").strip() - user_id = speaker_to_user_id.setdefault(speaker, f"user-{len(speaker_to_user_id) + 1}") - - return SimpleNamespace( - message_id=f"{session_id}-{seq}", - timestamp=dt, - platform=platform, - session_id=session_id, - reply_to=None, - processed_plain_text=text, - display_message=f"{speaker}:{text}", - message_info=SimpleNamespace( - user_info=SimpleNamespace( - user_id=user_id, - user_nickname=speaker, - user_cardname=speaker, - ), - group_info=SimpleNamespace(group_id=group_id, group_name=group_name), - additional_config={}, - ), - ) - - -async def _evaluate_runtime_trigger_flow(*, dataset: Dict[str, Any], session: Any) -> Dict[str, Any]: - streams = dataset["runtime_trigger_streams"] - reports: List[Dict[str, Any]] = [] - positive_successes = 0.0 - negative_successes = 0.0 - - speaker_to_user_id = {"Mai": "bot-mai"} - for payload in dataset["person_writebacks"]: - speaker_to_user_id.setdefault(payload["person_name"], payload["person_id"]) - - original_get_messages = summarizer_module.message_api.get_messages_by_time_in_chat - original_build_readable = summarizer_module.message_api.build_readable_messages - original_is_bot_self = summarizer_module.is_bot_self - original_person = summarizer_module.Person - original_analyze = summarizer_module.ChatHistorySummarizer._analyze_topics_with_llm - - try: - summarizer_module.message_api.build_readable_messages = ( - lambda messages, **kwargs: "\n".join(getattr(msg, "display_message", getattr(msg, "processed_plain_text", "")) for msg in messages) - ) - summarizer_module.is_bot_self = lambda platform, user_id: str(user_id or "") == "bot-mai" - summarizer_module.Person = lambda platform="", user_id="", person_id="": SimpleNamespace( - person_name=next( - ( - name - for name, mapped in speaker_to_user_id.items() - if mapped == (str(person_id or "").strip() or str(user_id or "").strip()) - ), - "", - ) - ) - - async def fake_analyze(self, numbered_lines, existing_topics): - del existing_topics - topic = str(getattr(self, "_fixture_topic", "") or "群聊话题") - return True, {topic: list(range(1, len(numbered_lines) + 1))} - - summarizer_module.ChatHistorySummarizer._analyze_topics_with_llm = fake_analyze - - for index, stream in enumerate(streams, start=1): - messages = [ - _parse_fixture_message( - line=line, - session_id=session.session_id, - platform=session.platform, - group_id=session.group_id, - group_name=dataset["session"]["display_name"], - speaker_to_user_id=speaker_to_user_id, - seq=1000 * index + seq, - ) - for seq, line in enumerate(stream["messages"], start=1) - ] - - def fake_get_messages_by_time_in_chat( - *, - chat_id: str, - start_time: float, - end_time: float, - limit: int = 0, - limit_mode: str = "latest", - filter_mai: bool = False, - filter_command: bool = False, - ): - del limit, limit_mode, filter_mai, filter_command - if chat_id != session.session_id: - return [] - return [ - msg - for msg in messages - if start_time <= msg.timestamp.timestamp() <= end_time - ] - - summarizer_module.message_api.get_messages_by_time_in_chat = fake_get_messages_by_time_in_chat - - summarizer = summarizer_module.ChatHistorySummarizer(session.session_id) - summarizer._fixture_topic = stream["topic"] - summarizer._persist_topic_cache = lambda: None - summarizer.last_check_time = float(stream["start_time"]) - 60.0 - summarizer.last_topic_check_time = float(stream["end_time"]) - float(stream["elapsed_since_last_check_hours"]) * 3600.0 - - await summarizer.process(current_time=float(stream["end_time"])) - - topic_cache_keys = list(summarizer.topic_cache.keys()) - expected_topic_present = stream["topic"] in summarizer.topic_cache - topic_cache_updated = bool(topic_cache_keys) - batch_cleared = summarizer.current_batch is None - - if stream["bot_participated"]: - success = expected_topic_present and topic_cache_updated and batch_cleared - positive_successes += 1.0 if success else 0.0 - else: - success = (not topic_cache_updated) and batch_cleared - negative_successes += 1.0 if success else 0.0 - - reports.append( - { - "stream_id": stream["stream_id"], - "bot_participated": stream["bot_participated"], - "success": success, - "topic_cache_keys": topic_cache_keys, - "current_batch_cleared": batch_cleared, - "expected_check_outcome": stream["expected_check_outcome"], - } - ) - finally: - summarizer_module.message_api.get_messages_by_time_in_chat = original_get_messages - summarizer_module.message_api.build_readable_messages = original_build_readable - summarizer_module.is_bot_self = original_is_bot_self - summarizer_module.Person = original_person - summarizer_module.ChatHistorySummarizer._analyze_topics_with_llm = original_analyze - - positive_total = max(1, sum(1 for item in streams if item["bot_participated"])) - negative_total = max(1, sum(1 for item in streams if not item["bot_participated"])) - return { - "positive_trigger_rate": round(positive_successes / positive_total, 4), - "negative_discard_rate": round(negative_successes / negative_total, 4), - "reports": reports, - } - - -def _average(values: List[float]) -> float: - if not values: - return 0.0 - return round(sum(float(v) for v in values) / float(len(values)), 4) - - -def _target_score(actual: float, target: float) -> float: - clean_target = float(target or 0.0) - if clean_target <= 0: - return 1.0 - return min(float(actual) / clean_target, 1.0) - - -def _build_final_score(*, metrics: Dict[str, Any], targets: Dict[str, Any]) -> Dict[str, Any]: - episode_summary = metrics["episode_summary_after_rebuild"] - category_scores = { - "search": _average( - [ - _target_score(metrics["search"]["accuracy_at_1"], targets["search"]["accuracy_at_1"]), - _target_score(metrics["search"]["recall_at_5"], targets["search"]["recall_at_5"]), - _target_score(metrics["search"]["keyword_recall_at_5"], targets["search"]["keyword_recall_at_5"]), - ] - ), - "writeback": _average( - [ - _target_score(metrics["writeback"]["success_rate"], targets["writeback"]["success_rate"]), - _target_score(metrics["writeback"]["keyword_recall"], targets["writeback"]["keyword_recall"]), - ] - ), - "knowledge_fetcher": _average( - [ - _target_score(metrics["knowledge_fetcher"]["success_rate"], targets["knowledge_fetcher"]["success_rate"]), - _target_score(metrics["knowledge_fetcher"]["keyword_recall"], targets["knowledge_fetcher"]["keyword_recall"]), - ] - ), - "profile": _average( - [ - _target_score(metrics["profile"]["success_rate"], targets["profile"]["success_rate"]), - _target_score(metrics["profile"]["evidence_rate"], targets["profile"]["evidence_rate"]), - ] - ), - "episode": _average( - [ - _target_score(episode_summary["success_rate"], targets["episode"]["success_rate"]), - _target_score(episode_summary["keyword_recall"], targets["episode"]["keyword_recall"]), - ] - ), - "negative_control": _target_score( - metrics["negative_control"]["zero_hit_rate"], targets["negative_control"]["zero_hit_rate"] - ), - "runtime_trigger": _average( - [ - _target_score( - metrics["runtime_trigger"]["positive_trigger_rate"], - targets["runtime_trigger"]["positive_trigger_rate"], - ), - _target_score( - metrics["runtime_trigger"]["negative_discard_rate"], - targets["runtime_trigger"]["negative_discard_rate"], - ), - ] - ), - } - overall_ratio = _average(list(category_scores.values())) - return { - "overall_ratio": overall_ratio, - "overall_score": round(overall_ratio * 100.0, 2), - "category_scores": {key: round(value * 100.0, 2) for key, value in category_scores.items()}, - } - - -async def _build_group_chat_benchmark_env(patch_manager: _PatchManager, tmp_path: Path): - dataset = _load_benchmark_fixture() - session_cfg = dataset["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]} - reverse_registry = {value: key for key, value in registry.items()} - - patch_manager.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) - - async def fake_self_check(**kwargs): - return {"ok": True, "message": "ok", "encoded_dimension": 32} - - patch_manager.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) - patch_manager.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - patch_manager.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - patch_manager.setattr(person_info_module, "_chat_manager", fake_chat_manager) - patch_manager.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: registry.get(str(person_name or "").strip(), ""), - ) - patch_manager.setattr( - person_info_module, - "Person", - lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry), - ) - - data_dir = (tmp_path / "a_memorix_group_chat_benchmark_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - patch_manager.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - return { - "dataset": dataset, - "kernel": kernel, - "session": session, - "person_registry": registry, - } - - -async def _run_group_chat_stream_memory_benchmark(tmp_path: Path): - patch_manager = _PatchManager() - group_chat_benchmark_env = await _build_group_chat_benchmark_env(patch_manager, tmp_path) - dataset = group_chat_benchmark_env["dataset"] - session_id = group_chat_benchmark_env["session"].session_id - - try: - created = await memory_service.import_admin( - action="create_paste", - name="group_chat_stream_memory_benchmark.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(dataset["import_payload"], ensure_ascii=False), - ) - assert created["success"] is True - - import_detail = await _wait_for_import_task(created["task"]["task_id"]) - assert import_detail["task"]["status"] == "completed" - - for record in dataset["chat_history_records"]: - summarizer = summarizer_module.ChatHistorySummarizer(session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - for payload in dataset["person_writebacks"]: - await person_info_module.store_person_memory_from_answer( - payload["person_name"], - payload["memory_content"], - session_id, - ) - - await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2) - - search_case_reports: List[Dict[str, Any]] = [] - search_accuracy_at_1 = 0.0 - search_recall_at_5 = 0.0 - search_precision_at_5 = 0.0 - search_mrr = 0.0 - search_keyword_recall = 0.0 - - for case in dataset["search_cases"]: - result = await memory_service.search( - case["query"], - mode="search", - chat_id=session_id, - respect_filter=False, - limit=5, - ) - joined = "\n".join(_hit_blob(hit) for hit in result.hits) - rank = _first_relevant_rank( - result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])) - ) - relevant_hits = sum( - 1 - for hit in result.hits[:5] - if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max( - 1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))) - ) - ) - keyword_recall = _keyword_recall(joined, case["expected_keywords"]) - search_accuracy_at_1 += 1.0 if rank == 1 else 0.0 - search_recall_at_5 += 1.0 if rank > 0 else 0.0 - search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits)))) - search_mrr += 1.0 / float(rank) if rank > 0 else 0.0 - search_keyword_recall += keyword_recall - search_case_reports.append( - { - "query": case["query"], - "rank_of_first_relevant": rank, - "relevant_hits_top5": relevant_hits, - "keyword_recall_top5": keyword_recall, - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - - search_total = max(1, len(dataset["search_cases"])) - - writeback_reports: List[Dict[str, Any]] = [] - writeback_success_rate = 0.0 - writeback_keyword_recall = 0.0 - for payload in dataset["person_writebacks"]: - query = " ".join(payload["expected_keywords"]) - result = await memory_service.search( - query, - mode="search", - chat_id=session_id, - person_id=payload["person_id"], - respect_filter=False, - limit=5, - ) - joined = _join_hit_content(result) - recall = _keyword_recall(joined, payload["expected_keywords"]) - success = bool(result.hits) and recall >= 0.67 - writeback_success_rate += 1.0 if success else 0.0 - writeback_keyword_recall += recall - writeback_reports.append( - { - "person_id": payload["person_id"], - "success": success, - "keyword_recall": recall, - "hit_count": len(result.hits), - } - ) - writeback_total = max(1, len(dataset["person_writebacks"])) - - knowledge_reports: List[Dict[str, Any]] = [] - knowledge_success_rate = 0.0 - knowledge_keyword_recall = 0.0 - fetcher = knowledge_module.KnowledgeFetcher( - private_name=dataset["session"]["display_name"], - stream_id=session_id, - ) - for case in dataset["knowledge_fetcher_cases"]: - knowledge_text, _ = await fetcher.fetch(case["query"], []) - recall = _keyword_recall(knowledge_text, case["expected_keywords"]) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) - knowledge_success_rate += 1.0 if success else 0.0 - knowledge_keyword_recall += recall - knowledge_reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "preview": knowledge_text[:300], - } - ) - knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"])) - - profile_reports: List[Dict[str, Any]] = [] - profile_success_rate = 0.0 - profile_keyword_recall = 0.0 - profile_evidence_rate = 0.0 - for case in dataset["profile_cases"]: - profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id) - recall = _keyword_recall(profile.summary, case["expected_keywords"]) - has_evidence = bool(profile.evidence) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence - profile_success_rate += 1.0 if success else 0.0 - profile_keyword_recall += recall - profile_evidence_rate += 1.0 if has_evidence else 0.0 - profile_reports.append( - { - "person_id": case["person_id"], - "success": success, - "keyword_recall": recall, - "evidence_count": len(profile.evidence), - "summary_preview": profile.summary[:240], - } - ) - profile_total = max(1, len(dataset["profile_cases"])) - - time_mode = await _evaluate_time_cases(session_id=session_id, time_cases=dataset["time_cases"]) - episode_generation_auto = await _evaluate_episode_generation( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_admin_query_auto = await _evaluate_episode_admin_query( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_search_mode_auto = await _evaluate_episode_search_mode( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_rebuild = await memory_service.episode_admin(action="rebuild", source=f"chat_summary:{session_id}") - episode_generation_after_rebuild = await _evaluate_episode_generation( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_admin_query_after_rebuild = await _evaluate_episode_admin_query( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_search_mode_after_rebuild = await _evaluate_episode_search_mode( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset) - - negative_reports: List[Dict[str, Any]] = [] - negative_zero_hits = 0.0 - for case in dataset["negative_control_cases"]: - result = await memory_service.search( - case["query"], - mode="search", - chat_id=session_id, - respect_filter=False, - limit=5, - ) - success = len(result.hits) == 0 - negative_zero_hits += 1.0 if success else 0.0 - negative_reports.append( - { - "query": case["query"], - "success": success, - "hit_count": len(result.hits), - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - negative_total = max(1, len(dataset["negative_control_cases"])) - - runtime_trigger = await _evaluate_runtime_trigger_flow( - dataset=dataset, - session=group_chat_benchmark_env["session"], - ) - - episode_summary_after_rebuild = { - "success_rate": _average( - [ - episode_generation_after_rebuild["success_rate"], - episode_admin_query_after_rebuild["success_rate"], - episode_search_mode_after_rebuild["success_rate"], - ] - ), - "keyword_recall": _average( - [ - episode_generation_after_rebuild["keyword_recall"], - episode_admin_query_after_rebuild["keyword_recall"], - episode_search_mode_after_rebuild["keyword_recall"], - ] - ), - } - - metrics = { - "search": { - "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4), - "recall_at_5": round(search_recall_at_5 / search_total, 4), - "precision_at_5": round(search_precision_at_5 / search_total, 4), - "mrr": round(search_mrr / search_total, 4), - "keyword_recall_at_5": round(search_keyword_recall / search_total, 4), - }, - "time_mode": { - "success_rate": time_mode["success_rate"], - "keyword_recall": time_mode["keyword_recall"], - }, - "writeback": { - "success_rate": round(writeback_success_rate / writeback_total, 4), - "keyword_recall": round(writeback_keyword_recall / writeback_total, 4), - }, - "knowledge_fetcher": { - "success_rate": round(knowledge_success_rate / knowledge_total, 4), - "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4), - }, - "profile": { - "success_rate": round(profile_success_rate / profile_total, 4), - "keyword_recall": round(profile_keyword_recall / profile_total, 4), - "evidence_rate": round(profile_evidence_rate / profile_total, 4), - }, - "tool_modes": { - "success_rate": tool_modes["success_rate"], - "keyword_recall": tool_modes["keyword_recall"], - }, - "episode_generation_auto": { - "success_rate": episode_generation_auto["success_rate"], - "keyword_recall": episode_generation_auto["keyword_recall"], - "episode_count": episode_generation_auto["episode_count"], - }, - "episode_generation_after_rebuild": { - "success_rate": episode_generation_after_rebuild["success_rate"], - "keyword_recall": episode_generation_after_rebuild["keyword_recall"], - "episode_count": episode_generation_after_rebuild["episode_count"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_admin_query_auto": { - "success_rate": episode_admin_query_auto["success_rate"], - "keyword_recall": episode_admin_query_auto["keyword_recall"], - }, - "episode_admin_query_after_rebuild": { - "success_rate": episode_admin_query_after_rebuild["success_rate"], - "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_search_mode_auto": { - "success_rate": episode_search_mode_auto["success_rate"], - "keyword_recall": episode_search_mode_auto["keyword_recall"], - }, - "episode_search_mode_after_rebuild": { - "success_rate": episode_search_mode_after_rebuild["success_rate"], - "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_summary_after_rebuild": episode_summary_after_rebuild, - "negative_control": { - "zero_hit_rate": round(negative_zero_hits / negative_total, 4), - }, - "runtime_trigger": { - "positive_trigger_rate": runtime_trigger["positive_trigger_rate"], - "negative_discard_rate": runtime_trigger["negative_discard_rate"], - }, - } - final_score = _build_final_score(metrics=metrics, targets=dataset["meta"]["quantitative_targets"]) - - report = { - "dataset": dataset["meta"], - "import": { - "task_id": created["task"]["task_id"], - "status": import_detail["task"]["status"], - "paragraph_count": len(dataset["import_payload"]["paragraphs"]), - "relation_count": len(dataset["import_payload"].get("relations") or []), - }, - "metrics": metrics, - "score": final_score, - "cases": { - "search": search_case_reports, - "time_mode": time_mode["reports"], - "writeback": writeback_reports, - "knowledge_fetcher": knowledge_reports, - "profile": profile_reports, - "tool_modes": tool_modes["reports"], - "episode_generation_auto": episode_generation_auto["reports"], - "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], - "episode_admin_query_auto": episode_admin_query_auto["reports"], - "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], - "episode_search_mode_auto": episode_search_mode_auto["reports"], - "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], - "negative_control": negative_reports, - "runtime_trigger": runtime_trigger["reports"], - }, - } - - REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) - REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") - print(json.dumps({"metrics": report["metrics"], "score": report["score"]}, ensure_ascii=False, indent=2)) - - assert report["import"]["status"] == "completed" - assert report["score"]["overall_score"] >= 0.0 - return report - finally: - await group_chat_benchmark_env["kernel"].shutdown() - patch_manager.undo() - - -def run_group_chat_stream_memory_benchmark() -> Dict[str, Any]: - with tempfile.TemporaryDirectory(prefix="a_memorix_group_chat_benchmark_") as tmp_dir: - return asyncio.run(_run_group_chat_stream_memory_benchmark(Path(tmp_dir))) - - -if __name__ == "__main__": - report = run_group_chat_stream_memory_benchmark() - print(json.dumps({"final_score": report["score"], "metrics": report["metrics"]}, ensure_ascii=False, indent=2)) diff --git a/pytests/A_memorix_test/test_long_novel_memory_benchmark.py b/pytests/A_memorix_test/test_long_novel_memory_benchmark.py deleted file mode 100644 index 819ecd97..00000000 --- a/pytests/A_memorix_test/test_long_novel_memory_benchmark.py +++ /dev/null @@ -1,687 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import json -from datetime import datetime -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict, List - -import numpy as np -import pytest -import pytest_asyncio - -from A_memorix.core.runtime import sdk_memory_kernel as kernel_module -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import MemorySearchResult, memory_service - - -DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json" -REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json" - - -def _load_benchmark_fixture() -> Dict[str, Any]: - return json.loads(DATA_FILE.read_text(encoding="utf-8")) - - -class _FakeEmbeddingAdapter: - def __init__(self, dimension: int = 32) -> None: - self.dimension = dimension - - async def _detect_dimension(self) -> int: - return self.dimension - - async def encode(self, texts, dimensions=None): - dim = int(dimensions or self.dimension) - if isinstance(texts, str): - sequence = [texts] - single = True - else: - sequence = list(texts) - single = False - - rows = [] - for text in sequence: - vec = np.zeros(dim, dtype=np.float32) - for ch in str(text or ""): - code = ord(ch) - vec[code % dim] += 1.0 - vec[(code * 7) % dim] += 0.5 - if not vec.any(): - vec[0] = 1.0 - norm = np.linalg.norm(vec) - if norm > 0: - vec = vec / norm - rows.append(vec) - payload = np.vstack(rows) - return payload[0] if single else payload - - -class _KnownPerson: - def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None: - self.person_id = person_id - self.is_known = person_id in reverse_registry - self.person_name = reverse_registry.get(person_id, "") - self._registry = registry - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000): - del timeout_ms - payload = args or {} - if component_name == "search_memory": - return await self.kernel.search_memory( - KernelSearchRequest( - query=str(payload.get("query", "") or ""), - limit=int(payload.get("limit", 5) or 5), - mode=str(payload.get("mode", "hybrid") or "hybrid"), - chat_id=str(payload.get("chat_id", "") or ""), - person_id=str(payload.get("person_id", "") or ""), - time_start=payload.get("time_start"), - time_end=payload.get("time_end"), - respect_filter=bool(payload.get("respect_filter", True)), - user_id=str(payload.get("user_id", "") or ""), - group_id=str(payload.get("group_id", "") or ""), - ) - ) - - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]: - for _ in range(max_rounds): - detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) - task = detail.get("task") or {} - status = str(task.get("status", "") or "") - if status in {"completed", "completed_with_errors", "failed", "cancelled"}: - return detail - await asyncio.sleep(max(0.01, float(sleep_seconds))) - raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") - - -def _join_hit_content(search_result: MemorySearchResult) -> str: - return "\n".join(hit.content for hit in search_result.hits) - - -def _keyword_hits(text: str, keywords: List[str]) -> int: - haystack = str(text or "") - return sum(1 for keyword in keywords if keyword in haystack) - - -def _keyword_recall(text: str, keywords: List[str]) -> float: - if not keywords: - return 1.0 - return _keyword_hits(text, keywords) / float(len(keywords)) - - -def _hit_blob(hit) -> str: - meta = hit.metadata if isinstance(hit.metadata, dict) else {} - return "\n".join( - [ - str(hit.content or ""), - str(hit.title or ""), - str(hit.source or ""), - json.dumps(meta, ensure_ascii=False), - ] - ) - - -def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int: - for index, hit in enumerate(search_result.hits[:5], start=1): - if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))): - return index - return 0 - - -def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str: - return "\n".join( - ( - f"{item.get('title', '')}\n" - f"{item.get('summary', '')}\n" - f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n" - f"{json.dumps(item.get('participants', []), ensure_ascii=False)}" - ) - for item in items - ) - - -def _episode_blob_from_hits(search_result: MemorySearchResult) -> str: - chunks = [] - for hit in search_result.hits: - meta = hit.metadata if isinstance(hit.metadata, dict) else {} - chunks.append( - "\n".join( - [ - str(hit.title or ""), - str(hit.content or ""), - json.dumps(meta.get("keywords", []) or [], ensure_ascii=False), - json.dumps(meta.get("participants", []) or [], ensure_ascii=False), - ] - ) - ) - return "\n".join(chunks) - - -async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - episode_source = f"chat_summary:{session_id}" - payload = await memory_service.episode_admin( - action="query", - source=episode_source, - limit=20, - ) - items = payload.get("items") or [] - blob = _episode_blob_from_items(items) - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in episode_cases: - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(items), - "top_episode": items[0] if items else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "episode_count": len(items), - "reports": reports, - } - - -async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - episode_source = f"chat_summary:{session_id}" - - for case in episode_cases: - payload = await memory_service.episode_admin( - action="query", - source=episode_source, - query=case["query"], - limit=5, - ) - items = payload.get("items") or [] - blob = "\n".join( - f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}" - for item in items - ) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(items), - "top_episode": items[0] if items else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in episode_cases: - result = await memory_service.search( - case["query"], - mode="episode", - chat_id=session_id, - respect_filter=False, - limit=5, - ) - blob = _episode_blob_from_hits(result) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(result.hits), - "top_episode": result.hits[0].to_dict() if result.hits else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]: - search_case = dataset["search_cases"][0] - episode_case = dataset["episode_cases"][0] - aggregate_case = dataset["knowledge_fetcher_cases"][0] - first_record = (dataset.get("chat_history_records") or [{}])[0] - reference_ts = first_record.get("end_time") or first_record.get("start_time") or 0 - if reference_ts: - time_expression = datetime.fromtimestamp(float(reference_ts)).strftime("%Y/%m/%d") - else: - time_expression = "最近7天" - tool_cases = [ - { - "name": "search", - "kwargs": { - "query": "蓝漆铁盒 北塔木梯", - "mode": "search", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "time", - "kwargs": { - "query": "蓝漆铁盒 北塔", - "mode": "time", - "chat_id": session_id, - "limit": 5, - "time_expression": time_expression, - }, - "expected_keywords": ["蓝漆铁盒", "北塔木梯"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "episode", - "kwargs": { - "query": episode_case["query"], - "mode": "episode", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": episode_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "aggregate", - "kwargs": { - "query": aggregate_case["query"], - "mode": "aggregate", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": aggregate_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - ] - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in tool_cases: - text = await query_long_term_memory(**case["kwargs"]) - recall = _keyword_recall(text, case["expected_keywords"]) - success = ( - "失败" not in text - and "无法解析" not in text - and "未找到" not in text - and recall >= float(case["minimum_keyword_recall"]) - ) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "name": case["name"], - "success": success, - "keyword_recall": recall, - "preview": text[:320], - } - ) - - total = max(1, len(tool_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -@pytest_asyncio.fixture -async def benchmark_env(monkeypatch, tmp_path): - dataset = _load_benchmark_fixture() - session_cfg = dataset["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]} - reverse_registry = {value: key for key, value in registry.items()} - - monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) - - async def fake_self_check(**kwargs): - return {"ok": True, "message": "ok", "encoded_dimension": 32} - - monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) - monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), "")) - monkeypatch.setattr( - person_info_module, - "Person", - lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry), - ) - - data_dir = (tmp_path / "a_memorix_benchmark_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - try: - yield { - "dataset": dataset, - "kernel": kernel, - "session": session, - "person_registry": registry, - } - finally: - await kernel.shutdown() - - -@pytest.mark.asyncio -async def test_long_novel_memory_benchmark(benchmark_env): - dataset = benchmark_env["dataset"] - session_id = benchmark_env["session"].session_id - - created = await memory_service.import_admin( - action="create_paste", - name="long_novel_memory_benchmark.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(dataset["import_payload"], ensure_ascii=False), - ) - assert created["success"] is True - - import_detail = await _wait_for_import_task(created["task"]["task_id"]) - assert import_detail["task"]["status"] == "completed" - - for record in dataset["chat_history_records"]: - summarizer = summarizer_module.ChatHistorySummarizer(session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - for payload in dataset["person_writebacks"]: - await person_info_module.store_person_memory_from_answer( - payload["person_name"], - payload["memory_content"], - session_id, - ) - - await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2) - - search_case_reports: List[Dict[str, Any]] = [] - search_accuracy_at_1 = 0.0 - search_recall_at_5 = 0.0 - search_precision_at_5 = 0.0 - search_mrr = 0.0 - search_keyword_recall = 0.0 - - for case in dataset["search_cases"]: - result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5) - joined = _join_hit_content(result) - rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"]))) - relevant_hits = sum( - 1 - for hit in result.hits[:5] - if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"])))) - ) - keyword_recall = _keyword_recall(joined, case["expected_keywords"]) - search_accuracy_at_1 += 1.0 if rank == 1 else 0.0 - search_recall_at_5 += 1.0 if rank > 0 else 0.0 - search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits)))) - search_mrr += 1.0 / float(rank) if rank > 0 else 0.0 - search_keyword_recall += keyword_recall - search_case_reports.append( - { - "query": case["query"], - "rank_of_first_relevant": rank, - "relevant_hits_top5": relevant_hits, - "keyword_recall_top5": keyword_recall, - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - - search_total = max(1, len(dataset["search_cases"])) - - writeback_reports: List[Dict[str, Any]] = [] - writeback_success_rate = 0.0 - writeback_keyword_recall = 0.0 - for payload in dataset["person_writebacks"]: - query = " ".join(payload["expected_keywords"]) - result = await memory_service.search( - query, - mode="search", - chat_id=session_id, - person_id=payload["person_id"], - respect_filter=False, - limit=5, - ) - joined = _join_hit_content(result) - recall = _keyword_recall(joined, payload["expected_keywords"]) - success = bool(result.hits) and recall >= 0.67 - writeback_success_rate += 1.0 if success else 0.0 - writeback_keyword_recall += recall - writeback_reports.append( - { - "person_id": payload["person_id"], - "success": success, - "keyword_recall": recall, - "hit_count": len(result.hits), - } - ) - writeback_total = max(1, len(dataset["person_writebacks"])) - - knowledge_reports: List[Dict[str, Any]] = [] - knowledge_success_rate = 0.0 - knowledge_keyword_recall = 0.0 - fetcher = knowledge_module.KnowledgeFetcher( - private_name=dataset["session"]["display_name"], - stream_id=session_id, - ) - for case in dataset["knowledge_fetcher_cases"]: - knowledge_text, _ = await fetcher.fetch(case["query"], []) - recall = _keyword_recall(knowledge_text, case["expected_keywords"]) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) - knowledge_success_rate += 1.0 if success else 0.0 - knowledge_keyword_recall += recall - knowledge_reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "preview": knowledge_text[:300], - } - ) - knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"])) - - profile_reports: List[Dict[str, Any]] = [] - profile_success_rate = 0.0 - profile_keyword_recall = 0.0 - profile_evidence_rate = 0.0 - for case in dataset["profile_cases"]: - profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id) - recall = _keyword_recall(profile.summary, case["expected_keywords"]) - has_evidence = bool(profile.evidence) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence - profile_success_rate += 1.0 if success else 0.0 - profile_keyword_recall += recall - profile_evidence_rate += 1.0 if has_evidence else 0.0 - profile_reports.append( - { - "person_id": case["person_id"], - "success": success, - "keyword_recall": recall, - "evidence_count": len(profile.evidence), - "summary_preview": profile.summary[:240], - } - ) - profile_total = max(1, len(dataset["profile_cases"])) - - episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_rebuild = await memory_service.episode_admin( - action="rebuild", - source=f"chat_summary:{session_id}", - ) - episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) - tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset) - - report = { - "dataset": dataset["meta"], - "import": { - "task_id": created["task"]["task_id"], - "status": import_detail["task"]["status"], - "paragraph_count": len(dataset["import_payload"]["paragraphs"]), - }, - "metrics": { - "search": { - "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4), - "recall_at_5": round(search_recall_at_5 / search_total, 4), - "precision_at_5": round(search_precision_at_5 / search_total, 4), - "mrr": round(search_mrr / search_total, 4), - "keyword_recall_at_5": round(search_keyword_recall / search_total, 4), - }, - "writeback": { - "success_rate": round(writeback_success_rate / writeback_total, 4), - "keyword_recall": round(writeback_keyword_recall / writeback_total, 4), - }, - "knowledge_fetcher": { - "success_rate": round(knowledge_success_rate / knowledge_total, 4), - "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4), - }, - "profile": { - "success_rate": round(profile_success_rate / profile_total, 4), - "keyword_recall": round(profile_keyword_recall / profile_total, 4), - "evidence_rate": round(profile_evidence_rate / profile_total, 4), - }, - "tool_modes": { - "success_rate": tool_modes["success_rate"], - "keyword_recall": tool_modes["keyword_recall"], - }, - "episode_generation_auto": { - "success_rate": episode_generation_auto["success_rate"], - "keyword_recall": episode_generation_auto["keyword_recall"], - "episode_count": episode_generation_auto["episode_count"], - }, - "episode_generation_after_rebuild": { - "success_rate": episode_generation_after_rebuild["success_rate"], - "keyword_recall": episode_generation_after_rebuild["keyword_recall"], - "episode_count": episode_generation_after_rebuild["episode_count"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_admin_query_auto": { - "success_rate": episode_admin_query_auto["success_rate"], - "keyword_recall": episode_admin_query_auto["keyword_recall"], - }, - "episode_admin_query_after_rebuild": { - "success_rate": episode_admin_query_after_rebuild["success_rate"], - "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_search_mode_auto": { - "success_rate": episode_search_mode_auto["success_rate"], - "keyword_recall": episode_search_mode_auto["keyword_recall"], - }, - "episode_search_mode_after_rebuild": { - "success_rate": episode_search_mode_after_rebuild["success_rate"], - "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - }, - "cases": { - "search": search_case_reports, - "writeback": writeback_reports, - "knowledge_fetcher": knowledge_reports, - "profile": profile_reports, - "tool_modes": tool_modes["reports"], - "episode_generation_auto": episode_generation_auto["reports"], - "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], - "episode_admin_query_auto": episode_admin_query_auto["reports"], - "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], - "episode_search_mode_auto": episode_search_mode_auto["reports"], - "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], - }, - } - - REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) - REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") - print(json.dumps(report["metrics"], ensure_ascii=False, indent=2)) - - assert report["import"]["status"] == "completed" - assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35 - assert report["metrics"]["search"]["recall_at_5"] >= 0.6 - assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8 - assert report["metrics"]["writeback"]["success_rate"] >= 0.66 - assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66 - assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75 - assert report["metrics"]["profile"]["success_rate"] >= 0.66 - assert report["metrics"]["profile"]["evidence_rate"] >= 1.0 - assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75 - assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True - assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"] diff --git a/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py b/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py deleted file mode 100644 index bd0560dc..00000000 --- a/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py +++ /dev/null @@ -1,342 +0,0 @@ -from __future__ import annotations - -import json -import os -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict, List - -import pytest -import pytest_asyncio - -from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel -from pytests.A_memorix_test.test_long_novel_memory_benchmark import ( - _evaluate_episode_admin_query, - _evaluate_episode_generation, - _evaluate_episode_search_mode, - _evaluate_tool_modes, - _KernelBackedRuntimeManager, - _KnownPerson, - _first_relevant_rank, - _hit_blob, - _join_hit_content, - _keyword_hits, - _keyword_recall, - _load_benchmark_fixture, - _wait_for_import_task, -) -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import memory_service - - -pytestmark = pytest.mark.skipif( - os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1", - reason="需要显式开启真实 external embedding benchmark", -) - -REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json" - - -@pytest_asyncio.fixture -async def benchmark_live_env(monkeypatch, tmp_path): - dataset = _load_benchmark_fixture() - session_cfg = dataset["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]} - reverse_registry = {value: key for key, value in registry.items()} - - monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), "")) - monkeypatch.setattr( - person_info_module, - "Person", - lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry), - ) - - data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - try: - yield { - "dataset": dataset, - "kernel": kernel, - "session": session, - } - finally: - await kernel.shutdown() - - -@pytest.mark.asyncio -async def test_long_novel_memory_benchmark_live(benchmark_live_env): - dataset = benchmark_live_env["dataset"] - session_id = benchmark_live_env["session"].session_id - - self_check = await memory_service.runtime_admin(action="refresh_self_check") - assert self_check["success"] is True - assert self_check["report"]["ok"] is True - - created = await memory_service.import_admin( - action="create_paste", - name="long_novel_memory_benchmark.live.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(dataset["import_payload"], ensure_ascii=False), - ) - assert created["success"] is True - - import_detail = await _wait_for_import_task( - created["task"]["task_id"], - max_rounds=2400, - sleep_seconds=0.25, - ) - assert import_detail["task"]["status"] == "completed" - - for record in dataset["chat_history_records"]: - summarizer = summarizer_module.ChatHistorySummarizer(session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - for payload in dataset["person_writebacks"]: - await person_info_module.store_person_memory_from_answer( - payload["person_name"], - payload["memory_content"], - session_id, - ) - - await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2) - - search_case_reports: List[Dict[str, Any]] = [] - search_accuracy_at_1 = 0.0 - search_recall_at_5 = 0.0 - search_precision_at_5 = 0.0 - search_mrr = 0.0 - search_keyword_recall = 0.0 - for case in dataset["search_cases"]: - result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5) - joined = _join_hit_content(result) - rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"]))) - relevant_hits = sum( - 1 - for hit in result.hits[:5] - if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"])))) - ) - keyword_recall = _keyword_recall(joined, case["expected_keywords"]) - search_accuracy_at_1 += 1.0 if rank == 1 else 0.0 - search_recall_at_5 += 1.0 if rank > 0 else 0.0 - search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits)))) - search_mrr += 1.0 / float(rank) if rank > 0 else 0.0 - search_keyword_recall += keyword_recall - search_case_reports.append( - { - "query": case["query"], - "rank_of_first_relevant": rank, - "relevant_hits_top5": relevant_hits, - "keyword_recall_top5": keyword_recall, - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - search_total = max(1, len(dataset["search_cases"])) - - writeback_reports: List[Dict[str, Any]] = [] - writeback_success_rate = 0.0 - writeback_keyword_recall = 0.0 - for payload in dataset["person_writebacks"]: - query = " ".join(payload["expected_keywords"]) - result = await memory_service.search( - query, - mode="search", - chat_id=session_id, - person_id=payload["person_id"], - respect_filter=False, - limit=5, - ) - joined = _join_hit_content(result) - recall = _keyword_recall(joined, payload["expected_keywords"]) - success = bool(result.hits) and recall >= 0.67 - writeback_success_rate += 1.0 if success else 0.0 - writeback_keyword_recall += recall - writeback_reports.append( - { - "person_id": payload["person_id"], - "success": success, - "keyword_recall": recall, - "hit_count": len(result.hits), - } - ) - writeback_total = max(1, len(dataset["person_writebacks"])) - - knowledge_reports: List[Dict[str, Any]] = [] - knowledge_success_rate = 0.0 - knowledge_keyword_recall = 0.0 - fetcher = knowledge_module.KnowledgeFetcher( - private_name=dataset["session"]["display_name"], - stream_id=session_id, - ) - for case in dataset["knowledge_fetcher_cases"]: - knowledge_text, _ = await fetcher.fetch(case["query"], []) - recall = _keyword_recall(knowledge_text, case["expected_keywords"]) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) - knowledge_success_rate += 1.0 if success else 0.0 - knowledge_keyword_recall += recall - knowledge_reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "preview": knowledge_text[:300], - } - ) - knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"])) - - profile_reports: List[Dict[str, Any]] = [] - profile_success_rate = 0.0 - profile_keyword_recall = 0.0 - profile_evidence_rate = 0.0 - for case in dataset["profile_cases"]: - profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id) - recall = _keyword_recall(profile.summary, case["expected_keywords"]) - has_evidence = bool(profile.evidence) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence - profile_success_rate += 1.0 if success else 0.0 - profile_keyword_recall += recall - profile_evidence_rate += 1.0 if has_evidence else 0.0 - profile_reports.append( - { - "person_id": case["person_id"], - "success": success, - "keyword_recall": recall, - "evidence_count": len(profile.evidence), - "summary_preview": profile.summary[:240], - } - ) - profile_total = max(1, len(dataset["profile_cases"])) - - episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_rebuild = await memory_service.episode_admin( - action="rebuild", - source=f"chat_summary:{session_id}", - ) - episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) - tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset) - - report = { - "dataset": dataset["meta"], - "runtime_self_check": self_check["report"], - "import": { - "task_id": created["task"]["task_id"], - "status": import_detail["task"]["status"], - "paragraph_count": len(dataset["import_payload"]["paragraphs"]), - }, - "metrics": { - "search": { - "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4), - "recall_at_5": round(search_recall_at_5 / search_total, 4), - "precision_at_5": round(search_precision_at_5 / search_total, 4), - "mrr": round(search_mrr / search_total, 4), - "keyword_recall_at_5": round(search_keyword_recall / search_total, 4), - }, - "writeback": { - "success_rate": round(writeback_success_rate / writeback_total, 4), - "keyword_recall": round(writeback_keyword_recall / writeback_total, 4), - }, - "knowledge_fetcher": { - "success_rate": round(knowledge_success_rate / knowledge_total, 4), - "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4), - }, - "profile": { - "success_rate": round(profile_success_rate / profile_total, 4), - "keyword_recall": round(profile_keyword_recall / profile_total, 4), - "evidence_rate": round(profile_evidence_rate / profile_total, 4), - }, - "tool_modes": { - "success_rate": tool_modes["success_rate"], - "keyword_recall": tool_modes["keyword_recall"], - }, - "episode_generation_auto": { - "success_rate": episode_generation_auto["success_rate"], - "keyword_recall": episode_generation_auto["keyword_recall"], - "episode_count": episode_generation_auto["episode_count"], - }, - "episode_generation_after_rebuild": { - "success_rate": episode_generation_after_rebuild["success_rate"], - "keyword_recall": episode_generation_after_rebuild["keyword_recall"], - "episode_count": episode_generation_after_rebuild["episode_count"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_admin_query_auto": { - "success_rate": episode_admin_query_auto["success_rate"], - "keyword_recall": episode_admin_query_auto["keyword_recall"], - }, - "episode_admin_query_after_rebuild": { - "success_rate": episode_admin_query_after_rebuild["success_rate"], - "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_search_mode_auto": { - "success_rate": episode_search_mode_auto["success_rate"], - "keyword_recall": episode_search_mode_auto["keyword_recall"], - }, - "episode_search_mode_after_rebuild": { - "success_rate": episode_search_mode_after_rebuild["success_rate"], - "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - }, - "cases": { - "search": search_case_reports, - "writeback": writeback_reports, - "knowledge_fetcher": knowledge_reports, - "profile": profile_reports, - "tool_modes": tool_modes["reports"], - "episode_generation_auto": episode_generation_auto["reports"], - "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], - "episode_admin_query_auto": episode_admin_query_auto["reports"], - "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], - "episode_search_mode_auto": episode_search_mode_auto["reports"], - "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], - }, - } - - REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) - REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") - print(json.dumps(report["metrics"], ensure_ascii=False, indent=2)) - - assert report["import"]["status"] == "completed" - assert report["runtime_self_check"]["ok"] is True diff --git a/pytests/A_memorix_test/test_memory_flow_service.py b/pytests/A_memorix_test/test_memory_flow_service.py index 2d35e837..bdc89386 100644 --- a/pytests/A_memorix_test/test_memory_flow_service.py +++ b/pytests/A_memorix_test/test_memory_flow_service.py @@ -5,68 +5,6 @@ import pytest from src.services import memory_flow_service as memory_flow_module -@pytest.mark.asyncio -async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch): - starts: list[str] = [] - summarizers: list[object] = [] - - class FakeSummarizer: - def __init__(self, session_id: str): - self.session_id = session_id - summarizers.append(self) - - async def start(self): - starts.append(self.session_id) - - async def stop(self): - starts.append(f"stop:{self.session_id}") - - monkeypatch.setattr( - memory_flow_module, - "global_config", - SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)), - ) - monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer) - - manager = memory_flow_module.LongTermMemorySessionManager() - message = SimpleNamespace(session_id="session-1") - - await manager.on_message(message) - await manager.on_message(message) - - assert len(summarizers) == 1 - assert starts == ["session-1"] - - -@pytest.mark.asyncio -async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch): - stopped: list[str] = [] - - class FakeSummarizer: - def __init__(self, session_id: str): - self.session_id = session_id - - async def start(self): - return None - - async def stop(self): - stopped.append(self.session_id) - - monkeypatch.setattr( - memory_flow_module, - "global_config", - SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)), - ) - monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer) - - manager = memory_flow_module.LongTermMemorySessionManager() - await manager.on_message(SimpleNamespace(session_id="session-a")) - await manager.on_message(SimpleNamespace(session_id="session-b")) - await manager.shutdown() - - assert stopped == ["session-a", "session-b"] - - def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items(): raw = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]' @@ -101,16 +39,9 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch): @pytest.mark.asyncio -async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch): +async def test_memory_automation_service_auto_starts_and_delegates(): events: list[tuple[str, str]] = [] - class FakeSessionManager: - async def on_message(self, message): - events.append(("incoming", message.session_id)) - - async def shutdown(self): - events.append(("shutdown", "session")) - class FakeFactWriteback: async def start(self): events.append(("start", "fact")) @@ -122,7 +53,6 @@ async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch): events.append(("shutdown", "fact")) service = memory_flow_module.MemoryAutomationService() - service.session_manager = FakeSessionManager() service.fact_writeback = FakeFactWriteback() await service.on_incoming_message(SimpleNamespace(session_id="session-1")) @@ -131,8 +61,6 @@ async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch): assert events == [ ("start", "fact"), - ("incoming", "session-1"), ("sent", "session-1"), - ("shutdown", "session"), ("shutdown", "fact"), ] diff --git a/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py b/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py deleted file mode 100644 index 81132f96..00000000 --- a/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py +++ /dev/null @@ -1,324 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import json -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict - -import numpy as np -import pytest -import pytest_asyncio - -from A_memorix.core.runtime import sdk_memory_kernel as kernel_module -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import memory_service - - -DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json" - - -def _load_dialogue_fixture() -> Dict[str, Any]: - return json.loads(DATA_FILE.read_text(encoding="utf-8")) - - -class _FakeEmbeddingAdapter: - def __init__(self, dimension: int = 16) -> None: - self.dimension = dimension - - async def _detect_dimension(self) -> int: - return self.dimension - - async def encode(self, texts, dimensions=None): - dim = int(dimensions or self.dimension) - if isinstance(texts, str): - sequence = [texts] - single = True - else: - sequence = list(texts) - single = False - - rows = [] - for text in sequence: - vec = np.zeros(dim, dtype=np.float32) - for ch in str(text or ""): - vec[ord(ch) % dim] += 1.0 - if not vec.any(): - vec[0] = 1.0 - norm = np.linalg.norm(vec) - if norm > 0: - vec = vec / norm - rows.append(vec) - payload = np.vstack(rows) - return payload[0] if single else payload - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000): - del timeout_ms - payload = args or {} - if component_name == "search_memory": - return await self.kernel.search_memory( - KernelSearchRequest( - query=str(payload.get("query", "") or ""), - limit=int(payload.get("limit", 5) or 5), - mode=str(payload.get("mode", "hybrid") or "hybrid"), - chat_id=str(payload.get("chat_id", "") or ""), - person_id=str(payload.get("person_id", "") or ""), - time_start=payload.get("time_start"), - time_end=payload.get("time_end"), - respect_filter=bool(payload.get("respect_filter", True)), - user_id=str(payload.get("user_id", "") or ""), - group_id=str(payload.get("group_id", "") or ""), - ) - ) - - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]: - for _ in range(max_rounds): - detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) - task = detail.get("task") or {} - status = str(task.get("status", "") or "") - if status in {"completed", "completed_with_errors", "failed", "cancelled"}: - return detail - await asyncio.sleep(0.05) - raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") - - -def _join_hit_content(search_result) -> str: - return "\n".join(hit.content for hit in search_result.hits) - - -@pytest_asyncio.fixture -async def real_dialogue_env(monkeypatch, tmp_path): - scenario = _load_dialogue_fixture() - session_cfg = scenario["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) - - async def fake_self_check(**kwargs): - return {"ok": True, "message": "ok"} - - monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) - monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) - - data_dir = (tmp_path / "a_memorix_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - try: - yield { - "scenario": scenario, - "kernel": kernel, - "session": session, - } - finally: - await kernel.shutdown() - - -@pytest.mark.asyncio -async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env): - scenario = real_dialogue_env["scenario"] - - created = await memory_service.import_admin( - action="create_paste", - name="private_alice.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(scenario["import_payload"], ensure_ascii=False), - ) - - assert created["success"] is True - detail = await _wait_for_import_task(created["task"]["task_id"]) - assert detail["task"]["status"] == "completed" - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - respect_filter=False, - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env): - scenario = real_dialogue_env["scenario"] - record = scenario["chat_history_record"] - - summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - chat_id=real_dialogue_env["session"].session_id, - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch): - scenario = real_dialogue_env["scenario"] - - class _KnownPerson: - def __init__(self, person_id: str) -> None: - self.person_id = person_id - self.is_known = True - self.person_name = scenario["person"]["person_name"] - - monkeypatch.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: scenario["person"]["person_id"], - ) - monkeypatch.setattr(person_info_module, "Person", _KnownPerson) - - await person_info_module.store_person_memory_from_answer( - scenario["person"]["person_name"], - scenario["person_fact"]["memory_content"], - real_dialogue_env["session"].session_id, - ) - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - chat_id=real_dialogue_env["session"].session_id, - person_id=scenario["person"]["person_id"], - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env): - scenario = real_dialogue_env["scenario"] - - await memory_service.ingest_text( - external_id="fixture:knowledge_fetcher", - source_type="dialogue_note", - text=scenario["person_fact"]["memory_content"], - chat_id=real_dialogue_env["session"].session_id, - person_ids=[scenario["person"]["person_id"]], - participants=[scenario["person"]["person_name"]], - respect_filter=False, - ) - - fetcher = knowledge_module.KnowledgeFetcher( - private_name=scenario["session"]["display_name"], - stream_id=real_dialogue_env["session"].session_id, - ) - knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], []) - - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in knowledge_text - - -@pytest.mark.asyncio -async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch): - scenario = real_dialogue_env["scenario"] - - class _KnownPerson: - def __init__(self, person_id: str) -> None: - self.person_id = person_id - self.is_known = True - self.person_name = scenario["person"]["person_name"] - - monkeypatch.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: scenario["person"]["person_id"], - ) - monkeypatch.setattr(person_info_module, "Person", _KnownPerson) - - await person_info_module.store_person_memory_from_answer( - scenario["person"]["person_name"], - scenario["person_fact"]["memory_content"], - real_dialogue_env["session"].session_id, - ) - - profile = await memory_service.get_person_profile( - scenario["person"]["person_id"], - chat_id=real_dialogue_env["session"].session_id, - ) - - assert profile.evidence - assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"]) - - -@pytest.mark.asyncio -async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env): - scenario = real_dialogue_env["scenario"] - record = scenario["chat_history_record"] - - summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - episodes = await memory_service.episode_admin( - action="query", - source=scenario["expectations"]["episode_source"], - limit=5, - ) - - assert episodes["success"] is True - assert int(episodes["count"]) >= 1 diff --git a/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py b/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py deleted file mode 100644 index 5dadca3d..00000000 --- a/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py +++ /dev/null @@ -1,301 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import json -import os -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict - -import pytest -import pytest_asyncio - -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import memory_service - - -pytestmark = pytest.mark.skipif( - os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1", - reason="需要显式开启真实 embedding / self-check 集成测试", -) - -DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json" - - -def _load_dialogue_fixture() -> Dict[str, Any]: - return json.loads(DATA_FILE.read_text(encoding="utf-8")) - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000): - del timeout_ms - payload = args or {} - if component_name == "search_memory": - return await self.kernel.search_memory( - KernelSearchRequest( - query=str(payload.get("query", "") or ""), - limit=int(payload.get("limit", 5) or 5), - mode=str(payload.get("mode", "hybrid") or "hybrid"), - chat_id=str(payload.get("chat_id", "") or ""), - person_id=str(payload.get("person_id", "") or ""), - time_start=payload.get("time_start"), - time_end=payload.get("time_end"), - respect_filter=bool(payload.get("respect_filter", True)), - user_id=str(payload.get("user_id", "") or ""), - group_id=str(payload.get("group_id", "") or ""), - ) - ) - - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]: - deadline = asyncio.get_running_loop().time() + max(1.0, float(timeout_seconds)) - while asyncio.get_running_loop().time() < deadline: - detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) - task = detail.get("task") or {} - status = str(task.get("status", "") or "") - if status in {"completed", "completed_with_errors", "failed", "cancelled"}: - return detail - await asyncio.sleep(0.2) - raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") - - -def _join_hit_content(search_result) -> str: - return "\n".join(hit.content for hit in search_result.hits) - - -@pytest_asyncio.fixture -async def live_dialogue_env(monkeypatch, tmp_path): - scenario = _load_dialogue_fixture() - session_cfg = scenario["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) - - data_dir = (tmp_path / "a_memorix_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - try: - yield { - "scenario": scenario, - "kernel": kernel, - "session": session, - } - finally: - await kernel.shutdown() - - -@pytest.mark.asyncio -async def test_live_runtime_self_check_passes(live_dialogue_env): - report = await memory_service.runtime_admin(action="refresh_self_check") - - assert report["success"] is True - assert report["report"]["ok"] is True - assert report["report"]["encoded_dimension"] > 0 - - -@pytest.mark.asyncio -async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env): - scenario = live_dialogue_env["scenario"] - - created = await memory_service.import_admin( - action="create_paste", - name="private_alice.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(scenario["import_payload"], ensure_ascii=False), - ) - - assert created["success"] is True - detail = await _wait_for_import_task(created["task"]["task_id"]) - assert detail["task"]["status"] == "completed" - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - respect_filter=False, - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env): - scenario = live_dialogue_env["scenario"] - record = scenario["chat_history_record"] - - summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - chat_id=live_dialogue_env["session"].session_id, - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch): - scenario = live_dialogue_env["scenario"] - - class _KnownPerson: - def __init__(self, person_id: str) -> None: - self.person_id = person_id - self.is_known = True - self.person_name = scenario["person"]["person_name"] - - monkeypatch.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: scenario["person"]["person_id"], - ) - monkeypatch.setattr(person_info_module, "Person", _KnownPerson) - - await person_info_module.store_person_memory_from_answer( - scenario["person"]["person_name"], - scenario["person_fact"]["memory_content"], - live_dialogue_env["session"].session_id, - ) - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - chat_id=live_dialogue_env["session"].session_id, - person_id=scenario["person"]["person_id"], - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env): - scenario = live_dialogue_env["scenario"] - - await memory_service.ingest_text( - external_id="fixture:knowledge_fetcher", - source_type="dialogue_note", - text=scenario["person_fact"]["memory_content"], - chat_id=live_dialogue_env["session"].session_id, - person_ids=[scenario["person"]["person_id"]], - participants=[scenario["person"]["person_name"]], - respect_filter=False, - ) - - fetcher = knowledge_module.KnowledgeFetcher( - private_name=scenario["session"]["display_name"], - stream_id=live_dialogue_env["session"].session_id, - ) - knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], []) - - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in knowledge_text - - -@pytest.mark.asyncio -async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch): - scenario = live_dialogue_env["scenario"] - - class _KnownPerson: - def __init__(self, person_id: str) -> None: - self.person_id = person_id - self.is_known = True - self.person_name = scenario["person"]["person_name"] - - monkeypatch.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: scenario["person"]["person_id"], - ) - monkeypatch.setattr(person_info_module, "Person", _KnownPerson) - - await person_info_module.store_person_memory_from_answer( - scenario["person"]["person_name"], - scenario["person_fact"]["memory_content"], - live_dialogue_env["session"].session_id, - ) - - profile = await memory_service.get_person_profile( - scenario["person"]["person_id"], - chat_id=live_dialogue_env["session"].session_id, - ) - - assert profile.evidence - assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"]) - - -@pytest.mark.asyncio -async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env): - scenario = live_dialogue_env["scenario"] - record = scenario["chat_history_record"] - - summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - episodes = await memory_service.episode_admin( - action="query", - source=scenario["expectations"]["episode_source"], - limit=5, - ) - - assert episodes["success"] is True - assert int(episodes["count"]) >= 1 diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 7fed79c3..a7176994 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -612,13 +612,6 @@ class ChatBot: scope=scope, ) # 确保会话存在 - try: - from src.services.memory_flow_service import memory_automation_service - - await memory_automation_service.on_incoming_message(message) - except Exception as exc: - logger.warning(f"[{session_id}] 长期记忆自动摘要注册失败: {exc}") - # message.update_chat_stream(chat) # 命令处理 - 使用新插件系统检查并处理命令。 diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py index dc4bdbb2..41b356df 100644 --- a/src/common/logger_color_and_mapping.py +++ b/src/common/logger_color_and_mapping.py @@ -159,7 +159,6 @@ MODULE_ALIASES = { "planner": "规划器", "config": "配置", "main": "主程序", - "chat_history_summarizer": "聊天概括器", "plugin_runtime.integration": "IPC插件系统", "plugin_runtime.host.supervisor": "插件监督器", "plugin_runtime.host.runner_manager": "插件监督器", diff --git a/src/config/config.py b/src/config/config.py index c1b6051b..8de2a29b 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -55,7 +55,7 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute() MMC_VERSION: str = "1.0.0" -CONFIG_VERSION: str = "8.7.1" +CONFIG_VERSION: str = "8.8.0" MODEL_CONFIG_VERSION: str = "1.14.0" logger = get_logger("config") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index c9dc617b..73645170 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -414,95 +414,6 @@ class MemoryConfig(ConfigBase): ) """Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数""" - long_term_auto_summary_enabled: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "book-open", - }, - ) - """是否自动启动聊天总结并导入长期记忆""" - - person_fact_writeback_enabled: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "user-round-pen", - }, - ) - """是否在发送回复后自动提取并写回人物事实到长期记忆""" - chat_history_topic_check_message_threshold: int = Field( - default=80, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "hash", - }, - ) - """聊天历史话题检查的消息数量阈值,当累积消息数达到此值时触发话题检查""" - - chat_history_topic_check_time_hours: float = Field( - default=8.0, - json_schema_extra={ - "x-widget": "input", - "x-icon": "clock", - }, - ) - """聊天历史话题检查的时间阈值(小时),当距离上次检查超过此时间且消息数达到最小阈值时触发话题检查""" - - chat_history_topic_check_min_messages: int = Field( - default=20, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "hash", - }, - ) - """聊天历史话题检查的时间触发模式下的最小消息数阈值""" - - chat_history_finalize_no_update_checks: int = Field( - default=3, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "check-circle", - }, - ) - """聊天历史话题打包存储的连续无更新检查次数阈值,当话题连续N次检查无新增内容时触发打包存储""" - - chat_history_finalize_message_count: int = Field( - default=5, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "package", - }, - ) - """聊天历史话题打包存储的消息条数阈值,当话题的消息条数超过此值时触发打包存储""" - - def model_post_init(self, context: Optional[dict] = None) -> None: - """验证配置值""" - if self.chat_history_topic_check_message_threshold < 1: - raise ValueError( - f"chat_history_topic_check_message_threshold 必须至少为1,当前值: {self.chat_history_topic_check_message_threshold}" - ) - if self.chat_history_topic_check_time_hours <= 0: - raise ValueError( - f"chat_history_topic_check_time_hours 必须大于0,当前值: {self.chat_history_topic_check_time_hours}" - ) - if self.chat_history_topic_check_min_messages < 1: - raise ValueError( - f"chat_history_topic_check_min_messages 必须至少为1,当前值: {self.chat_history_topic_check_min_messages}" - ) - if self.chat_history_finalize_no_update_checks < 1: - raise ValueError( - f"chat_history_finalize_no_update_checks 必须至少为1,当前值: {self.chat_history_finalize_no_update_checks}" - ) - if self.chat_history_finalize_message_count < 1: - raise ValueError( - f"chat_history_finalize_message_count 必须至少为1,当前值: {self.chat_history_finalize_message_count}" - ) - return super().model_post_init(context) class LearningItem(ConfigBase): diff --git a/src/maisaka/builtin_tool/send_emoji.py b/src/maisaka/builtin_tool/send_emoji.py index 27353ec3..b02b75f7 100644 --- a/src/maisaka/builtin_tool/send_emoji.py +++ b/src/maisaka/builtin_tool/send_emoji.py @@ -53,40 +53,16 @@ def get_tool_spec() -> ToolSpec: return ToolSpec( name="send_emoji", brief_description="发送一个合适的表情包来辅助表达情绪。", - detailed_description="参数说明:\n- emotion:string,可选。希望表达的情绪,例如 happy、sad、angry 等。", + detailed_description="无需参数,直接发送一个合适的表情包。", parameters_schema={ "type": "object", - "properties": { - "emotion": { - "type": "string", - "description": "希望表达的情绪,例如 happy、sad、angry 等。", - }, - }, + "properties": {}, }, provider_name="maisaka_builtin", provider_type="builtin", ) -def _normalize_candidate_emotions(emoji: MaiEmoji) -> list[str]: - """清洗候选表情上的情绪标签。""" - - raw_emotions = getattr(emoji, "emotion", None) - if isinstance(raw_emotions, list) and raw_emotions: - return [str(item).strip() for item in raw_emotions if str(item).strip()] - - description = str(getattr(emoji, "description", "") or "").strip() - if not description: - return [] - - normalized_description = ( - description.replace(",", ",") - .replace("、", ",") - .replace(";", ",") - ) - return [item.strip() for item in normalized_description.split(",") if item.strip()] - - async def _load_emoji_bytes(emoji: MaiEmoji) -> bytes: """读取单个表情包图片字节。""" @@ -232,18 +208,6 @@ async def _build_emoji_candidate_message(emojis: list[MaiEmoji]) -> SessionBacke ) -def _build_emoji_candidate_summary(emojis: list[MaiEmoji]) -> str: - """构建供监控展示使用的候选表情摘要。""" - - summary_lines: list[str] = [] - for index, emoji in enumerate(emojis, start=1): - description = emoji.description.strip() or "(无描述)" - emotions = "、".join(_normalize_candidate_emotions(emoji)) or "无" - summary_lines.append(f"{index}. 描述:{description}") - summary_lines.append(f" 情绪:{emotions}") - return "\n".join(summary_lines).strip() - - def _build_send_emoji_monitor_detail( *, request_messages: Optional[list[dict[str, Any]]] = None, @@ -252,7 +216,7 @@ def _build_send_emoji_monitor_detail( metrics: Optional[Dict[str, Any]] = None, extra_sections: Optional[list[dict[str, str]]] = None, ) -> Dict[str, Any]: - """构建 emotion tool 统一监控详情。""" + """构建 send_emoji 工具统一监控详情。""" detail: Dict[str, Any] = {} if isinstance(request_messages, list) and request_messages: @@ -281,7 +245,6 @@ def _build_send_emoji_monitor_detail( def _build_send_emoji_monitor_metadata( selection_metadata: Dict[str, Any], *, - requested_emotion: str, send_result: Optional[Any] = None, error_message: str = "", ) -> Dict[str, Any]: @@ -293,7 +256,6 @@ def _build_send_emoji_monitor_metadata( if send_result is not None: result_lines = [ - f"请求情绪:{requested_emotion or '未指定'}", f"命中情绪:{send_result.matched_emotion or '未命中'}", f"表情描述:{send_result.description or '无描述'}", f"情绪标签:{'、'.join(send_result.emotions) if send_result.emotions else '无'}", @@ -306,10 +268,7 @@ def _build_send_emoji_monitor_metadata( elif error_message.strip(): extra_sections.append({ "title": "表情发送结果", - "content": ( - f"请求情绪:{requested_emotion or '未指定'}\n" - f"发送结果:{error_message.strip()}" - ), + "content": f"发送结果:{error_message.strip()}", }) if extra_sections: @@ -322,7 +281,6 @@ def _build_send_emoji_monitor_metadata( async def _select_emoji_with_sub_agent( tool_ctx: BuiltinToolRuntimeContext, - requested_emotion: str, reasoning: str, context_texts: list[str], sample_size: int, @@ -347,14 +305,12 @@ async def _select_emoji_with_sub_agent( f"一共 {len(sampled_emojis)} 个位置。\n" f"每张小图左上角都有一个较大的序号,范围是 1 到 {len(sampled_emojis)}。\n" f"你的任务是根据上下文和当前语气,从这 {len(sampled_emojis)} 张图里选出最合适的一张表情包。\n" - "如果提供了 requested_emotion,请优先考虑与其接近的候选;如果没有完全匹配,则选择最符合上下文语气的候选。\n" "你必须返回一个 JSON 对象(json object),不要输出任何 JSON 之外的内容。\n" '返回格式固定为:{"emoji_index":1,"reason":"简短理由"}' ) prompt_message = ReferenceMessage( content=( f"[选择任务]\n" - f"requested_emotion: {requested_emotion or '未指定'}\n" f"候选总数: {len(sampled_emojis)}\n" f"拼图布局: {grid_rows}x{grid_columns}\n" "请只输出 JSON。" @@ -439,7 +395,6 @@ async def handle_tool( """执行 send_emoji 内置工具。""" del context - emotion = str(invocation.arguments.get("emotion") or "").strip() context_texts = [ message.processed_plain_text.strip() for message in tool_ctx.runtime._chat_history[-5:] @@ -450,23 +405,20 @@ async def handle_tool( "message": "", "description": "", "emotion": [], - "requested_emotion": emotion, "matched_emotion": "", "reason": "", } selection_metadata: Dict[str, Any] = {"reason": "", "monitor_detail": {}} - logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具,请求情绪={emotion!r}") + logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具") try: send_result = await send_emoji_for_maisaka( stream_id=tool_ctx.runtime.session_id, - requested_emotion=emotion, reasoning=tool_ctx.engine.last_reasoning_content, context_texts=context_texts, - emoji_selector=lambda requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent( + emoji_selector=lambda _requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent( tool_ctx, - requested_emotion, reasoning, list(context_texts or []), sample_size, @@ -482,7 +434,6 @@ async def handle_tool( structured_content=structured_result, metadata=_build_send_emoji_monitor_metadata( selection_metadata, - requested_emotion=emotion, error_message=structured_result["message"], ), ) @@ -493,7 +444,7 @@ async def handle_tool( logger.info( f"{tool_ctx.runtime.log_prefix} 表情包发送成功 " f"描述={send_result.description!r} 情绪标签={send_result.emotions} " - f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}" + f"命中情绪={send_result.matched_emotion!r}" ) if send_result.sent_message is not None: tool_ctx.append_sent_message_to_chat_history(send_result.sent_message) @@ -509,7 +460,6 @@ async def handle_tool( structured_content=structured_result, metadata=_build_send_emoji_monitor_metadata( selection_metadata, - requested_emotion=emotion, send_result=send_result, ), ) @@ -521,7 +471,7 @@ async def handle_tool( logger.warning( f"{tool_ctx.runtime.log_prefix} 表情包发送失败 " - f"请求情绪={emotion!r} 错误信息={send_result.message}" + f"错误信息={send_result.message}" ) return tool_ctx.build_failure_result( invocation.tool_name, @@ -529,7 +479,6 @@ async def handle_tool( structured_content=structured_result, metadata=_build_send_emoji_monitor_metadata( selection_metadata, - requested_emotion=emotion, send_result=send_result, ), ) diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py index 5a6b1712..e45fa2d6 100644 --- a/src/maisaka/chat_loop_service.py +++ b/src/maisaka/chat_loop_service.py @@ -528,14 +528,12 @@ class MaisakaChatLoopService: prompt_section: RenderableType | None = None if global_config.debug.show_maisaka_thinking: - image_display_mode: str = "path_link" if global_config.maisaka.show_image_path else "legacy" prompt_section = PromptCLIVisualizer.build_prompt_section( built_messages, category="planner" if request_kind != "timing_gate" else "timing_gate", chat_id=self._session_id, request_kind=request_kind, selection_reason=selection_reason, - image_display_mode=image_display_mode, folded=global_config.debug.fold_maisaka_thinking, tool_definitions=list(all_tools), ) diff --git a/src/maisaka/display/prompt_cli_renderer.py b/src/maisaka/display/prompt_cli_renderer.py index 9de08cec..a770261e 100644 --- a/src/maisaka/display/prompt_cli_renderer.py +++ b/src/maisaka/display/prompt_cli_renderer.py @@ -799,7 +799,7 @@ class PromptCLIVisualizer: chat_id: str, request_kind: str, selection_reason: str, - image_display_mode: Literal["legacy", "path_link"], + image_display_mode: Literal["legacy", "path_link"] = "path_link", tool_definitions: list[dict[str, Any]] | None = None, ) -> RenderableType: """构建用于查看完整 prompt 的折叠入口内容。""" @@ -864,7 +864,7 @@ class PromptCLIVisualizer: chat_id: str, request_kind: str, selection_reason: str, - image_display_mode: Literal["legacy", "path_link"], + image_display_mode: Literal["legacy", "path_link"] = "path_link", folded: bool, tool_definitions: list[dict[str, Any]] | None = None, ) -> Panel: @@ -878,14 +878,10 @@ class PromptCLIVisualizer: chat_id=chat_id, request_kind=request_kind, selection_reason=selection_reason, - image_display_mode=image_display_mode, tool_definitions=tool_definitions, ) else: - ordered_panels = cls.build_prompt_panels( - messages, - image_display_mode=image_display_mode, - ) + ordered_panels = cls.build_prompt_panels(messages) prompt_renderable = Group(*ordered_panels) return Panel( @@ -1102,11 +1098,9 @@ class PromptCLIVisualizer: cls, messages: list[Any], *, - image_display_mode: Literal["legacy", "path_link"], + image_display_mode: Literal["legacy", "path_link"] = "path_link", ) -> List[Panel]: """构建完整 prompt 可视化面板。""" - if image_display_mode not in {mode.value for mode in PromptImageDisplayMode}: - image_display_mode = PromptImageDisplayMode.LEGACY settings = PromptImageDisplaySettings( display_mode=PromptImageDisplayMode(image_display_mode), ) diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index 78d0a4b2..5de81cf2 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -1151,7 +1151,6 @@ class MaisakaHeartFlowChatting: chat_id=self.session_id, request_kind=labels["request_kind"], selection_reason=subtitle, - image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy", ), title=labels["prompt_title"], border_style=border_style, diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py deleted file mode 100644 index 94f4390f..00000000 --- a/src/memory_system/chat_history_summarizer.py +++ /dev/null @@ -1,1123 +0,0 @@ -""" -聊天内容概括器 -用于累积、打包和压缩聊天记录 -""" - -import asyncio -import json -import time -import re -import difflib -import datetime -from pathlib import Path -from typing import Any, Dict, List, Optional, Set -from dataclasses import dataclass, field -from json_repair import repair_json - -from src.chat.message_receive.message import SessionMessage -from src.common.logger import get_logger -from src.config.config import global_config -from src.common.data_models.llm_service_data_models import LLMGenerationOptions -from src.services.llm_service import LLMServiceClient -from src.services import message_service as message_api -from src.chat.utils.utils import is_bot_self -from src.person_info.person_info import Person -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.prompt.prompt_manager import prompt_manager - -logger = get_logger("chat_history_summarizer") - -HIPPO_CACHE_DIR = Path(__file__).resolve().parents[2] / "data" / "hippo_memorizer" - - -@dataclass -class MessageBatch: - """消息批次(用于触发话题检查的原始消息累积)""" - - messages: List[SessionMessage] - start_time: float - end_time: float - - -@dataclass -class TopicCacheItem: - """ - 话题缓存项 - - Attributes: - topic: 话题标题(一句话描述时间、人物、事件和主题) - messages: 与该话题相关的消息字符串列表(已经通过 build 函数转成可读文本) - participants: 涉及到的发言人昵称集合 - no_update_checks: 连续多少次“检查”没有新增内容 - """ - - topic: str - messages: List[str] = field(default_factory=list) - participants: Set[str] = field(default_factory=set) - no_update_checks: int = 0 - - -class ChatHistorySummarizer: - """聊天内容概括器""" - - def __init__(self, session_id: str, check_interval: int = 60): - """ - 初始化聊天内容概括器 - - Args: - session_id: 会话ID - check_interval: 定期检查间隔(秒),默认60秒 - """ - self.session_id = session_id - self._chat_display_name = self._get_chat_display_name() - self.log_prefix = f"[{self._chat_display_name}]" - - # 记录时间点,用于计算新消息 - self.last_check_time = time.time() - - # 记录上一次话题检查的时间,用于判断是否需要触发检查 - self.last_topic_check_time = time.time() - - # 当前累积的消息批次 - self.current_batch: Optional[MessageBatch] = None - - # 话题缓存:topic_str -> TopicCacheItem - # 在内存中维护,并通过本地文件实时持久化 - self.topic_cache: Dict[str, TopicCacheItem] = {} - self._safe_chat_id = self._sanitize_chat_id(self.session_id) - self._topic_cache_file = HIPPO_CACHE_DIR / f"{self._safe_chat_id}.json" - # 注意:批次加载需要异步查询消息,所以在 start() 中调用 - - # LLM请求器,用于压缩聊天内容 - self.summarizer_llm = LLMServiceClient( - task_name="utils", request_type="chat_history_summarizer" - ) - - # 后台循环相关 - self.check_interval = check_interval # 检查间隔(秒) - self._periodic_task: Optional[asyncio.Task] = None - self._running = False - - def _get_chat_display_name(self) -> str: - """获取聊天显示名称""" - try: - chat_name = _chat_manager.get_session_name(self.session_id) - if chat_name: - return chat_name - # 如果获取失败,使用简化的chat_id显示 - if len(self.session_id) > 20: - return f"{self.session_id[:8]}..." - return self.session_id - except Exception: - # 如果获取失败,使用简化的chat_id显示 - if len(self.session_id) > 20: - return f"{self.session_id[:8]}..." - return self.session_id - - def _sanitize_chat_id(self, chat_id: str) -> str: - """用于生成可作为文件名的 chat_id""" - return re.sub(r"[^a-zA-Z0-9_.-]", "_", chat_id) - - def _load_topic_cache_from_disk(self): - """在启动时加载本地话题缓存(同步部分),支持重启后继续""" - try: - if not self._topic_cache_file.exists(): - return - - with self._topic_cache_file.open("r", encoding="utf-8") as f: - data = json.load(f) - - self.last_topic_check_time = data.get("last_topic_check_time", self.last_topic_check_time) - topics_data = data.get("topics", {}) - loaded_count = 0 - for topic, payload in topics_data.items(): - self.topic_cache[topic] = TopicCacheItem( - topic=topic, - messages=payload.get("messages", []), - participants=set(payload.get("participants", [])), - no_update_checks=payload.get("no_update_checks", 0), - ) - loaded_count += 1 - - if loaded_count: - logger.info(f"{self.log_prefix} 已加载 {loaded_count} 个话题缓存,继续追踪") - except Exception as e: - logger.error(f"{self.log_prefix} 加载话题缓存失败: {e}") - - async def _load_batch_from_disk(self): - """在启动时加载聊天批次,支持重启后继续""" - try: - if not self._topic_cache_file.exists(): - return - - with self._topic_cache_file.open("r", encoding="utf-8") as f: - data = json.load(f) - - batch_data = data.get("current_batch") - if not batch_data: - return - - start_time = batch_data.get("start_time") - end_time = batch_data.get("end_time") - if not start_time or not end_time: - return - - # 根据时间范围重新查询消息 - messages = message_api.get_messages_by_time_in_chat( - chat_id=self.session_id, - start_time=start_time, - end_time=end_time, - limit=0, - limit_mode="latest", - filter_mai=False, - filter_command=False, - ) - - if messages: - self.current_batch = MessageBatch( - messages=messages, - start_time=start_time, - end_time=end_time, - ) - logger.info(f"{self.log_prefix} 已恢复聊天批次,包含 {len(messages)} 条消息") - except Exception as e: - logger.error(f"{self.log_prefix} 加载聊天批次失败: {e}") - - def _persist_topic_cache(self): - """实时持久化话题缓存和聊天批次,避免重启后丢失""" - try: - # 如果既没有话题缓存也没有批次,删除缓存文件 - if not self.topic_cache and not self.current_batch: - if self._topic_cache_file.exists(): - self._topic_cache_file.unlink() - return - - HIPPO_CACHE_DIR.mkdir(parents=True, exist_ok=True) - data = { - "chat_id": self.session_id, - "last_topic_check_time": self.last_topic_check_time, - "topics": { - topic: { - "messages": item.messages, - "participants": list(item.participants), - "no_update_checks": item.no_update_checks, - } - for topic, item in self.topic_cache.items() - }, - } - - # 保存当前批次的时间范围(如果有) - if self.current_batch: - data["current_batch"] = { - "start_time": self.current_batch.start_time, - "end_time": self.current_batch.end_time, - } - - with self._topic_cache_file.open("w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) - except Exception as e: - logger.error(f"{self.log_prefix} 持久化话题缓存失败: {e}") - - async def process(self, current_time: Optional[float] = None): - """ - 处理聊天内容概括 - - Args: - current_time: 当前时间戳,如果为None则使用time.time() - """ - if current_time is None: - current_time = time.time() - - try: - # 获取从上次检查时间到当前时间的新消息 - new_messages = message_api.get_messages_by_time_in_chat( - chat_id=self.session_id, - start_time=self.last_check_time, - end_time=current_time, - limit=0, - limit_mode="latest", - filter_mai=False, # 不过滤bot消息,因为需要检查bot是否发言 - filter_command=False, - ) - - if not new_messages: - # 没有新消息,检查是否需要进行“话题检查” - if self.current_batch and self.current_batch.messages: - await self._check_and_run_topic_check(current_time) - self.last_check_time = current_time - return - - logger.debug( - f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}" - ) - - # 有新消息,更新最后检查时间 - self.last_check_time = current_time - - # 如果有当前批次,添加新消息 - if self.current_batch: - before_count = len(self.current_batch.messages) - self.current_batch.messages.extend(new_messages) - self.current_batch.end_time = current_time - logger.info( - f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息" - ) - # 更新批次后持久化 - self._persist_topic_cache() - else: - # 创建新批次 - self.current_batch = MessageBatch( - messages=new_messages, - start_time=new_messages[0].timestamp.timestamp() if new_messages else current_time, - end_time=current_time, - ) - logger.debug(f"{self.log_prefix} 新建聊天检查批次: {len(new_messages)} 条消息") - # 创建批次后持久化 - self._persist_topic_cache() - - # 检查是否需要触发“话题检查” - await self._check_and_run_topic_check(current_time) - - except Exception as e: - logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}") - import traceback - - traceback.print_exc() - - async def _check_and_run_topic_check(self, current_time: float): - """ - 检查是否需要进行一次“话题检查” - - 触发条件: - - 当前批次消息数 >= 100,或者 - - 距离上一次检查的时间 > 3600 秒(1小时) - """ - if not self.current_batch or not self.current_batch.messages: - return - - messages = self.current_batch.messages - message_count = len(messages) - time_since_last_check = current_time - self.last_topic_check_time - - # 格式化时间差显示 - if time_since_last_check < 60: - time_str = f"{time_since_last_check:.1f}秒" - elif time_since_last_check < 3600: - time_str = f"{time_since_last_check / 60:.1f}分钟" - else: - time_str = f"{time_since_last_check / 3600:.1f}小时" - - logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}") - - # 检查"话题检查"触发条件 - should_check = False - - # 从配置中获取阈值 - message_threshold = global_config.memory.chat_history_topic_check_message_threshold - time_threshold_hours = global_config.memory.chat_history_topic_check_time_hours - min_messages = global_config.memory.chat_history_topic_check_min_messages - time_threshold_seconds = time_threshold_hours * 3600 - - # 条件1: 消息数量达到阈值,触发一次检查 - if message_count >= message_threshold: - should_check = True - logger.info( - f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: {message_threshold}条)" - ) - - # 条件2: 距离上一次检查超过时间阈值且消息数量达到最小阈值,触发一次检查 - elif time_since_last_check > time_threshold_seconds and message_count >= min_messages: - should_check = True - logger.info( - f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: {time_threshold_hours}小时)且消息数量达到 {message_count} 条(阈值: {min_messages}条)" - ) - - if should_check: - await self._run_topic_check_and_update_cache(messages) - # 本批次已经被处理为话题信息,可以清空 - self.current_batch = None - # 更新上一次检查时间,并持久化 - self.last_topic_check_time = current_time - self._persist_topic_cache() - - async def _run_topic_check_and_update_cache(self, messages: List[SessionMessage]): - """ - 执行一次“话题检查”: - 1. 首先确认这段消息里是否有 Bot 发言,没有则直接丢弃本次批次; - 2. 将消息编号并转成字符串,构造 LLM Prompt; - 3. 把历史话题标题列表放入 Prompt,要求 LLM: - - 识别当前聊天中的话题(1 个或多个); - - 为每个话题选出相关消息编号; - - 若话题属于历史话题,则沿用原话题标题; - 4. LLM 返回 JSON:多个 {topic, message_indices}; - 5. 更新本地话题缓存,并根据规则触发“话题打包存储”。 - """ - if not messages: - return - - start_time = messages[0].timestamp.timestamp() - end_time = messages[-1].timestamp.timestamp() - - logger.info( - f"{self.log_prefix} 开始话题检查 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}" - ) - - # 1. 检查当前批次内是否有 bot 发言(只检查当前批次,不往前推) - # 原因:我们要记录的是 bot 参与过的对话片段,如果当前批次内 bot 没有发言, - # 说明 bot 没有参与这段对话,不应该记录 - has_bot_message = any( - is_bot_self(msg.platform, msg.message_info.user_info.user_id) for msg in messages - ) - - if not has_bot_message: - logger.info( - f"{self.log_prefix} 当前批次内无 Bot 发言,丢弃本次检查 | 时间范围: {start_time:.2f} - {end_time:.2f}" - ) - return - - # 2. 构造编号后的消息字符串和参与者信息 - numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = ( - self._build_numbered_messages_for_llm(messages) - ) - - # 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次) - existing_topics = list(self.topic_cache.keys()) - max_retries = 3 - attempt = 0 - success = False - topic_to_indices: Dict[str, List[int]] = {} - - while attempt < max_retries: - attempt += 1 - success, topic_to_indices = await self._analyze_topics_with_llm( - numbered_lines=numbered_lines, - existing_topics=existing_topics, - ) - - if success and topic_to_indices: - if attempt > 1: - logger.info( - f"{self.log_prefix} 话题识别在第 {attempt} 次重试后成功 | 话题数: {len(topic_to_indices)}" - ) - break - - logger.warning( - f"{self.log_prefix} 话题识别失败或无有效话题,第 {attempt} 次尝试失败" - + ("" if attempt >= max_retries else ",准备重试") - ) - - if not success or not topic_to_indices: - logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃") - # 即使识别失败,也认为是一次"检查",但不更新 no_update_checks(保持原状) - return - - # 3.5. 检查新话题是否与历史话题相似(相似度>=90%则使用历史标题) - topic_mapping = self._build_topic_mapping(topic_to_indices, similarity_threshold=0.9) - - # 应用话题映射:将相似的新话题标题替换为历史话题标题 - if topic_mapping: - new_topic_to_indices: Dict[str, List[int]] = {} - for new_topic, indices in topic_to_indices.items(): - # 如果这个新话题需要映射到历史话题 - if new_topic in topic_mapping: - historical_topic = topic_mapping[new_topic] - # 如果历史话题已经存在,合并消息索引 - if historical_topic in new_topic_to_indices: - # 合并索引并去重 - combined_indices = list(set(new_topic_to_indices[historical_topic] + indices)) - new_topic_to_indices[historical_topic] = combined_indices - else: - new_topic_to_indices[historical_topic] = indices - else: - # 不需要映射,保持原样 - new_topic_to_indices[new_topic] = indices - topic_to_indices = new_topic_to_indices - - # 4. 统计哪些话题在本次检查中有新增内容 - updated_topics: Set[str] = set() - - for topic, indices in topic_to_indices.items(): - if not indices: - continue - - item = self.topic_cache.get(topic) - if not item: - # 新话题 - item = TopicCacheItem(topic=topic) - self.topic_cache[topic] = item - - # 收集属于该话题的消息文本(不带编号) - topic_msg_texts: List[str] = [] - new_participants: Set[str] = set() - for idx in indices: - msg_text = index_to_msg_text.get(idx) - if not msg_text: - continue - topic_msg_texts.append(msg_text) - new_participants.update(index_to_participants.get(idx, set())) - - if not topic_msg_texts: - continue - - # 将本次检查中属于该话题的所有消息合并为一个字符串(不带编号) - merged_text = "\n".join(topic_msg_texts) - item.messages.append(merged_text) - item.participants.update(new_participants) - # 本次检查中该话题有更新,重置计数 - item.no_update_checks = 0 - updated_topics.add(topic) - - # 5. 对于本次没有更新的历史话题,no_update_checks + 1 - for topic, item in list(self.topic_cache.items()): - if topic not in updated_topics: - item.no_update_checks += 1 - - # 6. 检查是否有话题需要打包存储 - # 从配置中获取阈值 - no_update_checks_threshold = global_config.memory.chat_history_finalize_no_update_checks - message_count_threshold = global_config.memory.chat_history_finalize_message_count - - topics_to_finalize: List[str] = [] - for topic, item in self.topic_cache.items(): - if item.no_update_checks >= no_update_checks_threshold: - logger.info( - f"{self.log_prefix} 话题[{topic}] 连续 {no_update_checks_threshold} 次检查无新增内容,触发打包存储" - ) - topics_to_finalize.append(topic) - continue - if len(item.messages) > message_count_threshold: - logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 {message_count_threshold},触发打包存储") - topics_to_finalize.append(topic) - - for topic in topics_to_finalize: - item = self.topic_cache.get(topic) - if not item: - continue - try: - await self._finalize_and_store_topic( - topic=topic, - item=item, - # 这里的时间范围尽量覆盖最近一次检查的区间 - start_time=start_time, - end_time=end_time, - ) - finally: - # 无论成功与否,都从缓存中删除,避免重复 - self.topic_cache.pop(topic, None) - - def _find_most_similar_topic( - self, new_topic: str, existing_topics: List[str], similarity_threshold: float = 0.9 - ) -> Optional[tuple[str, float]]: - """ - 查找与给定新话题最相似的历史话题 - - Args: - new_topic: 新话题标题 - existing_topics: 历史话题标题列表 - similarity_threshold: 相似度阈值,默认0.9(90%) - - Returns: - Optional[tuple[str, float]]: 如果找到相似度>=阈值的历史话题,返回(历史话题标题, 相似度), - 否则返回None - """ - if not existing_topics: - return None - - best_match = None - best_similarity = 0.0 - - for existing_topic in existing_topics: - similarity = difflib.SequenceMatcher(None, new_topic, existing_topic).ratio() - if similarity > best_similarity: - best_similarity = similarity - best_match = existing_topic - - # 如果相似度达到阈值,返回匹配结果 - if best_match and best_similarity >= similarity_threshold: - return (best_match, best_similarity) - - return None - - def _build_topic_mapping( - self, topic_to_indices: Dict[str, List[int]], similarity_threshold: float = 0.9 - ) -> Dict[str, str]: - """ - 构建新话题到历史话题的映射(如果相似度>=阈值) - - Args: - topic_to_indices: 新话题到消息索引的映射 - similarity_threshold: 相似度阈值,默认0.9(90%) - - Returns: - Dict[str, str]: 新话题 -> 历史话题的映射字典 - """ - existing_topics_list = list(self.topic_cache.keys()) - topic_mapping: Dict[str, str] = {} - - for new_topic in topic_to_indices.keys(): - # 如果新话题已经在历史话题中,不需要检查 - if new_topic in existing_topics_list: - continue - - # 查找最相似的历史话题 - result = self._find_most_similar_topic(new_topic, existing_topics_list, similarity_threshold) - if result: - historical_topic, similarity = result - topic_mapping[new_topic] = historical_topic - logger.info( - f"{self.log_prefix} 话题相似度检查: '{new_topic}' 与历史话题 '{historical_topic}' 相似度 {similarity:.2%},使用历史标题" - ) - - return topic_mapping - - def _build_numbered_messages_for_llm( - self, messages: List[SessionMessage] - ) -> tuple[List[str], Dict[int, str], Dict[int, str], Dict[int, Set[str]]]: - """ - 将消息转为带编号的字符串,供 LLM 选择使用。 - - 返回: - numbered_lines: ["1. xxx", "2. yyy", ...] # 带编号,用于 LLM 选择 - index_to_msg_str: idx -> "idx. xxx" # 带编号,用于 LLM 选择 - index_to_msg_text: idx -> "xxx" # 不带编号,用于最终存储 - index_to_participants: idx -> {nickname1, nickname2, ...} - """ - numbered_lines: List[str] = [] - index_to_msg_str: Dict[int, str] = {} - index_to_msg_text: Dict[int, str] = {} # 不带编号的消息文本 - index_to_participants: Dict[int, Set[str]] = {} - - for idx, msg in enumerate(messages, start=1): - # 使用 build_readable_messages 生成可读文本 - try: - text = message_api.build_readable_messages( - messages=[msg], - replace_bot_name=True, - timestamp_mode="normal_no_YMD", - read_mark=0.0, - truncate=False, - show_actions=False, - ).strip() - except Exception: - # 回退到简单文本 - text = getattr(msg, "processed_plain_text", "") or "" - - # 获取发言人昵称 - participants: Set[str] = set() - try: - platform = msg.platform - user_id = msg.message_info.user_info.user_id - if platform and user_id: - person = Person(platform=platform, user_id=user_id) - if person.person_name: - participants.add(person.person_name) - except Exception: - pass - - # 带编号的字符串(用于 LLM 选择) - line = f"{idx}. {text}" - numbered_lines.append(line) - index_to_msg_str[idx] = line - # 不带编号的文本(用于最终存储) - index_to_msg_text[idx] = text - index_to_participants[idx] = participants - - return numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants - - async def _analyze_topics_with_llm( - self, - numbered_lines: List[str], - existing_topics: List[str], - ) -> tuple[bool, Dict[str, List[int]]]: - """ - 使用 LLM 识别本次检查中的话题,并为每个话题选择相关消息编号。 - - 要求: - - 话题用一句话清晰描述正在发生的事件,包括时间、人物、主要事件和主题; - - 可以有 1 个或多个话题; - - 若某个话题与历史话题列表中的某个话题是同一件事,请直接使用历史话题的字符串; - - 输出 JSON,格式: - [ - { - "topic": "话题标题字符串", - "message_indices": [1, 2, 5] - }, - ... - ] - """ - if not numbered_lines: - return False, {} - - history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)" - messages_block = "\n".join(numbered_lines) - - prompt_template = prompt_manager.get_prompt("hippo_topic_analysis") - prompt_template.add_context("history_topics_block", history_topics_block) - prompt_template.add_context("messages_block", messages_block) - prompt = await prompt_manager.render_prompt(prompt_template) - - try: - generation_result = await self.summarizer_llm.generate_response( - prompt=prompt, - options=LLMGenerationOptions(temperature=0.3), - ) - response = generation_result.response - - logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}") - logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}") - - # 尝试从响应中提取JSON代码块 - json_str = None - json_pattern = r"```json\s*(.*?)\s*```" - matches = re.findall(json_pattern, response, re.DOTALL) - - if matches: - # 找到JSON代码块,使用第一个匹配 - json_str = matches[0].strip() - else: - # 如果没有找到代码块,尝试查找JSON数组的开始和结束位置 - # 查找第一个 [ 和最后一个 ] - start_idx = response.find("[") - end_idx = response.rfind("]") - if start_idx != -1 and end_idx != -1 and end_idx > start_idx: - json_str = response[start_idx : end_idx + 1].strip() - else: - # 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记) - json_str = response.strip() - json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE) - json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE) - json_str = json_str.strip() - - # 使用json_repair修复可能的JSON错误 - if json_str: - try: - repaired_json = repair_json(json_str) - result = json.loads(repaired_json) if isinstance(repaired_json, str) else repaired_json - except Exception as repair_error: - # 如果repair失败,尝试直接解析 - logger.warning(f"{self.log_prefix} JSON修复失败,尝试直接解析: {repair_error}") - result = json.loads(json_str) - else: - raise ValueError("无法从响应中提取JSON内容") - - if not isinstance(result, list): - logger.error(f"{self.log_prefix} 话题识别返回的 JSON 不是列表: {result}") - return False, {} - - topic_to_indices: Dict[str, List[int]] = {} - for item in result: - if not isinstance(item, dict): - continue - topic = item.get("topic") - indices = item.get("message_indices") or item.get("messages") or [] - if not topic or not isinstance(topic, str): - continue - if isinstance(indices, list): - valid_indices: List[int] = [] - for v in indices: - try: - iv = int(v) - if iv > 0: - valid_indices.append(iv) - except (TypeError, ValueError): - continue - if valid_indices: - topic_to_indices[topic] = valid_indices - - return True, topic_to_indices - - except Exception as e: - logger.error(f"{self.log_prefix} 话题识别 LLM 调用或解析失败: {e}") - logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") - return False, {} - - async def _finalize_and_store_topic( - self, - topic: str, - item: TopicCacheItem, - start_time: float, - end_time: float, - ): - """ - 对某个话题进行最终打包存储: - 1. 将 messages(list[str]) 拼接为 original_text; - 2. 使用 LLM 对 original_text 进行总结,得到 summary 和 keywords,theme 直接使用话题字符串; - 3. 写入数据库 ChatHistory; - 4. 完成后,调用方会从缓存中删除该话题。 - """ - if not item.messages: - logger.info(f"{self.log_prefix} 话题[{topic}] 无消息内容,跳过打包") - return - - original_text = "\n".join(item.messages) - - logger.info( - f"{self.log_prefix} 开始将聊天记录构建成记忆:[{topic}] | 消息数: {len(item.messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}" - ) - - # 使用 LLM 进行总结(基于话题名),带重试机制 - max_retries = 3 - attempt = 0 - success = False - keywords = [] - summary = "" - - while attempt < max_retries: - attempt += 1 - success, keywords, summary = await self._compress_with_llm(original_text, topic) - - if success and keywords and summary: - # 成功获取到有效的 keywords 和 summary - if attempt > 1: - logger.info(f"{self.log_prefix} 话题[{topic}] LLM 概括在第 {attempt} 次重试后成功") - break - - if attempt < max_retries: - logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败(第 {attempt} 次尝试),准备重试") - else: - logger.error(f"{self.log_prefix} 话题[{topic}] LLM 概括连续 {max_retries} 次失败,放弃存储") - - if not success or not keywords or not summary: - logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败,不写入数据库") - return - - participants = list(item.participants) - - await self._store_to_database( - start_time=start_time, - end_time=end_time, - original_text=original_text, - participants=participants, - theme=topic, # 主题直接使用话题名 - keywords=keywords, - summary=summary, - ) - - logger.info( - f"{self.log_prefix} 话题[{topic}] 成功打包并存储 | 消息数: {len(item.messages)} | 参与者数: {len(participants)}" - ) - - async def _compress_with_llm(self, original_text: str, topic: str) -> tuple[bool, List[str], str]: - """ - 使用LLM压缩聊天内容(用于单个话题的最终总结) - - Args: - original_text: 聊天记录原文 - topic: 话题名称 - - Returns: - tuple[bool, List[str], str]: (是否成功, 关键词列表, 概括) - """ - prompt_template = prompt_manager.get_prompt("hippo_topic_summary") - prompt_template.add_context("topic", topic) - prompt_template.add_context("original_text", original_text) - prompt = await prompt_manager.render_prompt(prompt_template) - - try: - generation_result = await self.summarizer_llm.generate_response(prompt=prompt) - response = generation_result.response - - # 解析JSON响应 - json_str = response.strip() - json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE) - json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE) - json_str = json_str.strip() - - # 查找JSON对象的开始与结束 - start_idx = json_str.find("{") - if start_idx == -1: - raise ValueError("未找到JSON对象开始标记") - - end_idx = json_str.rfind("}") - if end_idx == -1 or end_idx <= start_idx: - logger.warning(f"{self.log_prefix} JSON缺少结束标记,尝试自动修复") - extracted_json = json_str[start_idx:] - else: - extracted_json = json_str[start_idx : end_idx + 1] - - def _parse_with_quote_fix(payload: str) -> Dict[str, Any]: - fixed_chars: List[str] = [] - in_string = False - escape_next = False - i = 0 - while i < len(payload): - char = payload[i] - if escape_next: - fixed_chars.append(char) - escape_next = False - elif char == "\\": - fixed_chars.append(char) - escape_next = True - elif char == '"' and not escape_next: - fixed_chars.append(char) - in_string = not in_string - elif in_string and char in {"“", "”"}: - # 在字符串值内部,将中文引号替换为转义的英文引号 - fixed_chars.append('\\"') - else: - fixed_chars.append(char) - i += 1 - - repaired = "".join(fixed_chars) - return json.loads(repaired) - - try: - result = json.loads(extracted_json) - except json.JSONDecodeError: - try: - repaired_json = repair_json(extracted_json) - if isinstance(repaired_json, str): - result = json.loads(repaired_json) - else: - result = repaired_json - except Exception as repair_error: - logger.warning(f"{self.log_prefix} repair_json 失败,使用引号修复: {repair_error}") - result = _parse_with_quote_fix(extracted_json) - - keywords = result.get("keywords", []) - summary = result.get("summary", "") - - # 检查必需字段是否为空 - if not keywords or not summary: - logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少必需字段,原文\n{response}") - # 返回失败,和模型出错一样,让上层进行重试 - return False, [], "" - - # 确保keywords是列表 - if isinstance(keywords, str): - keywords = [keywords] - - return True, keywords, summary - - except Exception as e: - logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}") - logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") - # 返回失败标志和默认值 - return False, [], "压缩失败,无法生成概括" - - async def _store_to_database( - self, - start_time: float, - end_time: float, - original_text: str, - participants: List[str], - theme: str, - keywords: List[str], - summary: str, - ): - """存储到数据库""" - try: - from src.common.database.database_model import ChatHistory - from src.services import database_service as database_api - - # 准备数据 - data = { - "session_id": self.session_id, - "start_timestamp": datetime.fromtimestamp(start_time), - "end_timestamp": datetime.fromtimestamp(end_time), - "original_messages": original_text, - "participants": json.dumps(participants, ensure_ascii=False), - "theme": theme, - "keywords": json.dumps(keywords, ensure_ascii=False), - "summary": summary, - "query_count": 0, - "query_forget_count": 0, - } - - saved_record = await database_api.db_save( - ChatHistory, - data=data, - ) - - if saved_record: - logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库") - else: - logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败") - - if saved_record and saved_record.get("id") is not None: - await self._import_to_long_term_memory( - record_id=int(saved_record["id"]), - theme=theme, - summary=summary, - participants=participants, - start_time=start_time, - end_time=end_time, - original_text=original_text, - ) - - except Exception as e: - logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}") - import traceback - - traceback.print_exc() - raise - - async def _import_to_long_term_memory( - self, - record_id: int, - theme: str, - summary: str, - participants: List[str], - start_time: float, - end_time: float, - original_text: str, - ): - """ - 将聊天历史总结导入到统一长期记忆 - - Args: - record_id: chat_history 主键 - theme: 话题主题 - summary: 概括内容 - participants: 参与者列表 - start_time: 开始时间 - end_time: 结束时间 - original_text: 原始文本(可能很长,需要截断) - """ - try: - from src.services.memory_service import memory_service - session = _chat_manager.get_session_by_session_id(self.session_id) - session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else "" - session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else "" - - content_parts = [] - if theme: - content_parts.append(f"主题:{theme}") - if summary: - content_parts.append(f"概括:{summary}") - if participants: - participants_text = "、".join(participants) - content_parts.append(f"参与者:{participants_text}") - content_to_import = "\n".join(content_parts) - - if not content_to_import.strip(): - logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,改用插件侧 generate_from_chat 兜底") - await self._fallback_import_to_long_term_memory( - record_id=record_id, - theme=theme, - participants=participants, - start_time=start_time, - end_time=end_time, - original_text=original_text, - ) - return - - result = await memory_service.ingest_summary( - external_id=f"chat_history:{record_id}", - chat_id=self.session_id, - text=content_to_import, - participants=participants, - time_start=start_time, - time_end=end_time, - tags=[theme] if theme else [], - metadata={"theme": theme, "original_text_length": len(original_text or "")}, - respect_filter=True, - user_id=session_user_id, - group_id=session_group_id, - ) - if result.success: - if result.detail == "chat_filtered": - logger.debug(f"{self.log_prefix} 聊天历史总结被聊天过滤策略跳过 | 话题: {theme}") - else: - logger.info(f"{self.log_prefix} 成功将聊天历史总结导入到长期记忆 | 话题: {theme}") - else: - logger.warning(f"{self.log_prefix} 将聊天历史总结导入到长期记忆失败,尝试插件侧兜底 | 话题: {theme} | 错误: {result.detail}") - await self._fallback_import_to_long_term_memory( - record_id=record_id, - theme=theme, - participants=participants, - start_time=start_time, - end_time=end_time, - original_text=original_text, - ) - - except Exception as e: - logger.error(f"{self.log_prefix} 导入聊天历史总结到长期记忆时出错: {e}", exc_info=True) - - async def _fallback_import_to_long_term_memory( - self, - *, - record_id: int, - theme: str, - participants: List[str], - start_time: float, - end_time: float, - original_text: str, - ) -> None: - try: - from src.services.memory_service import memory_service - session = _chat_manager.get_session_by_session_id(self.session_id) - session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else "" - session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else "" - - result = await memory_service.ingest_summary( - external_id=f"chat_history:{record_id}", - chat_id=self.session_id, - text="", - participants=participants, - time_start=start_time, - time_end=end_time, - tags=[theme] if theme else [], - metadata={ - "theme": theme, - "original_text_length": len(original_text or ""), - "generate_from_chat": True, - "context_length": global_config.memory.chat_history_topic_check_message_threshold, - }, - respect_filter=True, - user_id=session_user_id, - group_id=session_group_id, - ) - if result.success: - if result.detail == "chat_filtered": - logger.debug(f"{self.log_prefix} 插件侧 generate_from_chat 兜底被聊天过滤策略跳过 | 话题: {theme}") - else: - logger.info(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入成功 | 话题: {theme}") - else: - logger.warning(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入失败 | 话题: {theme} | 错误: {result.detail}") - except Exception as exc: - logger.error(f"{self.log_prefix} 插件侧兜底导入长期记忆失败: {exc}", exc_info=True) - - async def start(self): - """启动后台定期检查循环""" - if self._running: - logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动") - return - - # 加载聊天批次(如果有) - await self._load_batch_from_disk() - - self._running = True - self._periodic_task = asyncio.create_task(self._periodic_check_loop()) - logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}秒") - - async def stop(self): - """停止后台定期检查循环""" - self._running = False - if self._periodic_task: - self._periodic_task.cancel() - try: - await self._periodic_task - except asyncio.CancelledError: - pass - self._periodic_task = None - logger.info(f"{self.log_prefix} 已停止后台定期检查循环") - - async def _periodic_check_loop(self): - """后台定期检查循环""" - try: - while self._running: - # 执行一次检查 - await self.process() - - # 等待指定间隔后再次检查 - await asyncio.sleep(self.check_interval) - except asyncio.CancelledError: - logger.info(f"{self.log_prefix} 后台检查循环被取消") - raise - except Exception as e: - logger.error(f"{self.log_prefix} 后台检查循环出错: {e}") - import traceback - - traceback.print_exc() - self._running = False diff --git a/src/services/memory_flow_service.py b/src/services/memory_flow_service.py index 75ff0ca9..8b7d8aa4 100644 --- a/src/services/memory_flow_service.py +++ b/src/services/memory_flow_service.py @@ -2,54 +2,20 @@ from __future__ import annotations import asyncio import json -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from json_repair import repair_json from src.chat.utils.utils import is_bot_self -from src.common.message_repository import find_messages from src.common.logger import get_logger +from src.common.message_repository import find_messages from src.config.config import global_config -from src.memory_system.chat_history_summarizer import ChatHistorySummarizer from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer from src.services.llm_service import LLMServiceClient logger = get_logger("memory_flow_service") -class LongTermMemorySessionManager: - def __init__(self) -> None: - self._lock = asyncio.Lock() - self._summarizers: Dict[str, ChatHistorySummarizer] = {} - - async def on_message(self, message: Any) -> None: - if not bool(getattr(global_config.memory, "long_term_auto_summary_enabled", True)): - return - session_id = str(getattr(message, "session_id", "") or "").strip() - if not session_id: - return - - created = False - async with self._lock: - summarizer = self._summarizers.get(session_id) - if summarizer is None: - summarizer = ChatHistorySummarizer(session_id=session_id) - self._summarizers[session_id] = summarizer - created = True - if created: - await summarizer.start() - - async def shutdown(self) -> None: - async with self._lock: - items = list(self._summarizers.items()) - self._summarizers.clear() - for session_id, summarizer in items: - try: - await summarizer.stop() - except Exception as exc: - logger.warning("停止聊天总结器失败: session=%s err=%s", session_id, exc) - - class PersonFactWritebackService: def __init__(self) -> None: self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256) @@ -123,7 +89,11 @@ class PersonFactWritebackService: if not session_id: return - person_name = str(getattr(target_person, "person_name", "") or getattr(target_person, "nickname", "") or "").strip() + person_name = str( + getattr(target_person, "person_name", "") + or getattr(target_person, "nickname", "") + or "" + ).strip() if not person_name: return @@ -242,7 +212,6 @@ class PersonFactWritebackService: class MemoryAutomationService: def __init__(self) -> None: - self.session_manager = LongTermMemorySessionManager() self.fact_writeback = PersonFactWritebackService() self._started = False @@ -255,14 +224,13 @@ class MemoryAutomationService: async def shutdown(self) -> None: if not self._started: return - await self.session_manager.shutdown() await self.fact_writeback.shutdown() self._started = False async def on_incoming_message(self, message: Any) -> None: + del message if not self._started: await self.start() - await self.session_manager.on_message(message) async def on_message_sent(self, message: Any) -> None: if not self._started: From 8f2337fe9941519e5fb8a146cfcc9f9c81e6c3d6 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Mon, 13 Apr 2026 18:58:26 +0800 Subject: [PATCH 2/6] =?UTF-8?q?feat=EF=BC=9A=E6=96=B0=E5=A2=9E=E4=B8=80?= =?UTF-8?q?=E4=B8=AAreplyer=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/maisaka_generator_base.py | 30 ++++++++++++++++------ src/maisaka/builtin_tool/reply.py | 10 +++++++- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/chat/replyer/maisaka_generator_base.py b/src/chat/replyer/maisaka_generator_base.py index ae3ab645..18e90fd4 100644 --- a/src/chat/replyer/maisaka_generator_base.py +++ b/src/chat/replyer/maisaka_generator_base.py @@ -1,9 +1,9 @@ -import time from dataclasses import dataclass, field from datetime import datetime from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple import random +import time from rich.console import Group, RenderableType from rich.panel import Panel @@ -110,11 +110,15 @@ class BaseMaisakaReplyGenerator: return "" def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str: - speaker_name, body = parse_speaker_content(message.processed_plain_text.strip()) - bot_nickname = global_config.bot.nickname.strip() or "Bot" - if speaker_name == bot_nickname: - return self._normalize_content(body.strip()) - return "" + # 只能根据结构化来源字段判断是否为 bot 自身写回的历史消息, + # 不能依赖昵称/群名片等可控文本,避免误判和提示注入。 + if message.source_kind != "guided_reply": + return "" + + plain_text = message.processed_plain_text.strip() + _, body = parse_speaker_content(plain_text) + normalized_body = body.strip() or plain_text + return self._normalize_content(normalized_body) if normalized_body else "" def _build_target_message_block(self, reply_message: Optional[SessionMessage]) -> str: if reply_message is None: @@ -210,6 +214,7 @@ class BaseMaisakaReplyGenerator: self, reply_message: Optional[SessionMessage], reply_reason: str, + reference_info: str = "", expression_habits: str = "", stream_id: Optional[str] = None, ) -> str: @@ -234,8 +239,13 @@ class BaseMaisakaReplyGenerator: sections.append(expression_habits.strip()) if target_message_block: sections.append(target_message_block) + reply_reference_lines: List[str] = [] if reply_reason.strip(): - sections.append(f"【回复信息参考】\n{reply_reason}") + reply_reference_lines.append(f"【最新推理】\n{reply_reason.strip()}") + if reference_info.strip(): + reply_reference_lines.append(f"【参考信息】\n{reference_info.strip()}") + if reply_reference_lines: + sections.append("【回复信息参考】\n" + "\n\n".join(reply_reference_lines)) if not sections: return system_prompt return f"{system_prompt}\n\n" + "\n\n".join(sections) @@ -308,6 +318,7 @@ class BaseMaisakaReplyGenerator: chat_history: List[LLMContextMessage], reply_message: Optional[SessionMessage], reply_reason: str, + reference_info: str = "", expression_habits: str = "", stream_id: Optional[str] = None, enable_visual_message: bool = False, @@ -316,6 +327,7 @@ class BaseMaisakaReplyGenerator: system_prompt = self._build_system_prompt( reply_message=reply_message, reply_reason=reply_reason, + reference_info=reference_info, expression_habits=expression_habits, stream_id=stream_id, ) @@ -377,6 +389,7 @@ class BaseMaisakaReplyGenerator: self, extra_info: str = "", reply_reason: str = "", + reference_info: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, chosen_actions: Optional[List[object]] = None, from_plugin: bool = True, @@ -461,6 +474,7 @@ class BaseMaisakaReplyGenerator: chat_history=filtered_history, reply_message=reply_message, reply_reason=reply_reason or "", + reference_info=reference_info or "", expression_habits=merged_expression_habits, stream_id=stream_id, ) @@ -486,6 +500,7 @@ class BaseMaisakaReplyGenerator: chat_history=filtered_history, reply_message=reply_message, reply_reason=reply_reason or "", + reference_info=reference_info or "", expression_habits=merged_expression_habits, stream_id=stream_id, enable_visual_message=self._resolve_enable_visual_message(model_info), @@ -504,7 +519,6 @@ class BaseMaisakaReplyGenerator: chat_id=preview_chat_id, request_kind="replyer", selection_reason=f"ID: {preview_chat_id}", - image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy", ), title="Reply Prompt", border_style="bright_yellow", diff --git a/src/maisaka/builtin_tool/reply.py b/src/maisaka/builtin_tool/reply.py index 00c392b9..debee914 100644 --- a/src/maisaka/builtin_tool/reply.py +++ b/src/maisaka/builtin_tool/reply.py @@ -36,7 +36,8 @@ def get_tool_spec() -> ToolSpec: detailed_description=( "参数说明:\n" "- msg_id:string,必填。要回复的目标用户消息编号。\n" - "- set_quote:boolean,可选。以引用回复的方式发送,默认 true。" + "- set_quote:boolean,可选。以引用回复的方式发送,默认 true。\n" + "- reference_info:string,可选。上文中有助于回复的所有参考信息,使用平文本格式。" ), parameters_schema={ "type": "object", @@ -50,6 +51,11 @@ def get_tool_spec() -> ToolSpec: "description": "以引用回复的方式发送这条回复,不用每句都引用。", "default": True, }, + "reference_info": { + "type": "string", + "description": "有助于回复的信息,之前搜集得到的事实性信息,记忆等,使用平文本格式。", + "default": True, + }, }, "required": ["msg_id"], }, @@ -75,6 +81,7 @@ async def handle_tool( """执行 reply 内置工具。""" latest_thought = context.reasoning if context is not None else invocation.reasoning + reference_info = str(invocation.arguments.get("reference_info") or "").strip() target_message_id = str(invocation.arguments.get("msg_id") or "").strip() set_quote = bool(invocation.arguments.get("set_quote", True)) @@ -117,6 +124,7 @@ async def handle_tool( try: success, reply_result = await replyer.generate_reply_with_context( reply_reason=latest_thought, + reference_info=reference_info, stream_id=tool_ctx.runtime.session_id, reply_message=target_message, chat_history=tool_ctx.runtime._chat_history, From 2471a2c4a4659a23df53e7a96655f5276511a453 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Mon, 13 Apr 2026 19:54:38 +0800 Subject: [PATCH 3/6] =?UTF-8?q?fix=EF=BC=9A=E5=B7=A5=E5=85=B7=E8=B0=83?= =?UTF-8?q?=E7=94=A8=E5=AD=98=E5=82=A8=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llm_models/model_client/openai_client.py | 12 ++++++++++-- src/services/database_service.py | 14 ++++++++------ src/services/memory_flow_service.py | 5 ----- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 2a3bf8a4..8a02e37c 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -6,6 +6,7 @@ import json import re from dataclasses import dataclass, field from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast +from uuid import uuid4 from json_repair import repair_json from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream @@ -119,6 +120,13 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple """OpenAI 非流式响应解析函数类型。""" +def _build_fallback_tool_call_id(prefix: str) -> str: + """为缺失原始调用 ID 的工具调用生成唯一兜底标识。""" + + normalized_prefix = str(prefix).strip() or "tool_call" + return f"{normalized_prefix}_{uuid4().hex}" + + def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode: """将配置中的推理解析模式收敛为枚举值。 @@ -609,7 +617,7 @@ def _extract_xml_tool_calls( arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {} tool_calls.append( ToolCall( - call_id=f"xml_tool_call_{len(tool_calls) + 1}", + call_id=_build_fallback_tool_call_id("xml_tool_call"), func_name=function_name, args=arguments, ) @@ -855,7 +863,7 @@ class _OpenAIStreamAccumulator: if raw_arguments else None ) - call_id = state.call_id or f"tool_call_{index}" + call_id = state.call_id or _build_fallback_tool_call_id(f"tool_call_{index}") response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments)) response.raw_data = {"model": self.model_name} if self.model_name else None diff --git a/src/services/database_service.py b/src/services/database_service.py index 5e41f2c6..57215878 100644 --- a/src/services/database_service.py +++ b/src/services/database_service.py @@ -4,16 +4,18 @@ import json import time import traceback from datetime import datetime -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast -from sqlalchemy import delete, func, select -from sqlmodel import SQLModel +from sqlalchemy import delete, func +from sqlmodel import SQLModel, select -from src.chat.message_receive.chat_manager import BotChatSession from src.common.database.database import get_db_session from src.common.database.database_model import ToolRecord from src.common.logger import get_logger +if TYPE_CHECKING: + from src.chat.message_receive.chat_manager import BotChatSession + logger = get_logger("database_service") @@ -158,7 +160,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any] async def store_tool_info( - chat_stream: BotChatSession, + chat_stream: "BotChatSession", builtin_prompt: Optional[str] = None, display_prompt: str = "", tool_id: str = "", @@ -191,7 +193,7 @@ async def store_tool_info( async def store_action_info( - chat_stream: BotChatSession, + chat_stream: "BotChatSession", builtin_prompt: Optional[str] = None, display_prompt: str = "", thinking_id: str = "", diff --git a/src/services/memory_flow_service.py b/src/services/memory_flow_service.py index 8b7d8aa4..969b8a23 100644 --- a/src/services/memory_flow_service.py +++ b/src/services/memory_flow_service.py @@ -227,11 +227,6 @@ class MemoryAutomationService: await self.fact_writeback.shutdown() self._started = False - async def on_incoming_message(self, message: Any) -> None: - del message - if not self._started: - await self.start() - async def on_message_sent(self, message: Any) -> None: if not self._started: await self.start() From 4729f5acdb135589375a5ee898d5d299c764e80c Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Tue, 14 Apr 2026 23:53:38 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E5=B0=8F=E6=94=B9=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_maisaka_tool_logging.py | 23 -- pytests/test_mute_plugin_sdk.py | 339 ------------------ pytests/utils_test/test_bot_identity_utils.py | 227 ------------ src/maisaka/chat_loop_service.py | 12 +- 4 files changed, 6 insertions(+), 595 deletions(-) delete mode 100644 pytests/test_maisaka_tool_logging.py delete mode 100644 pytests/test_mute_plugin_sdk.py delete mode 100644 pytests/utils_test/test_bot_identity_utils.py diff --git a/pytests/test_maisaka_tool_logging.py b/pytests/test_maisaka_tool_logging.py deleted file mode 100644 index 0216eb83..00000000 --- a/pytests/test_maisaka_tool_logging.py +++ /dev/null @@ -1,23 +0,0 @@ -from src.maisaka.chat_loop_service import MaisakaChatLoopService - - -def test_build_tool_names_log_text_supports_openai_function_schema() -> None: - tool_definitions = [ - { - "type": "function", - "function": { - "name": "mute_user", - "description": "禁言指定用户", - "parameters": { - "type": "object", - "properties": {}, - }, - }, - }, - { - "name": "reply", - "description": "发送回复", - }, - ] - - assert MaisakaChatLoopService._build_tool_names_log_text(tool_definitions) == "mute_user、reply" diff --git a/pytests/test_mute_plugin_sdk.py b/pytests/test_mute_plugin_sdk.py deleted file mode 100644 index c811cc51..00000000 --- a/pytests/test_mute_plugin_sdk.py +++ /dev/null @@ -1,339 +0,0 @@ -"""MutePlugin SDK 回归测试。""" - -from __future__ import annotations - -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict, List - -import pytest - -from maibot_sdk.context import PluginContext -from maibot_sdk.plugin import MaiBotPlugin - -from plugins.MutePlugin.plugin import create_plugin -from src.core.tooling import ToolExecutionContext, ToolInvocation -from src.plugin_runtime.component_query import ComponentQueryService -from src.plugin_runtime.runner.manifest_validator import ManifestValidator - - -def _build_plugin() -> MaiBotPlugin: - """构造已注入默认配置的插件实例。""" - - plugin = create_plugin() - plugin.set_plugin_config(plugin.get_default_config()) - return plugin - - -def test_mute_plugin_manifest_is_valid_v2() -> None: - """MutePlugin 的 manifest 应符合当前运行时要求。""" - - validator = ManifestValidator(host_version="1.0.0", sdk_version="2.3.0") - manifest = validator.load_from_plugin_path(Path("plugins/MutePlugin")) - - assert manifest is not None - assert manifest.id == "sengokucola.mute-plugin" - assert manifest.manifest_version == 2 - - -def test_create_plugin_returns_sdk_plugin() -> None: - """插件入口应返回 SDK 插件实例。""" - - plugin = create_plugin() - - assert isinstance(plugin, MaiBotPlugin) - - -@pytest.mark.asyncio -async def test_mute_command_calls_napcat_group_ban_api() -> None: - """手动禁言命令应通过 NapCat Adapter 新 API 执行。""" - - plugin = _build_plugin() - plugin.set_plugin_config( - { - **plugin.get_default_config(), - "components": { - "enable_smart_mute": True, - "enable_mute_command": True, - }, - } - ) - - capability_calls: List[Dict[str, Any]] = [] - - async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]: - assert method == "cap.call" - assert payload is not None - capability_calls.append(payload) - - capability = payload["capability"] - if capability == "person.get_id_by_name": - return {"success": True, "person_id": "person-1"} - if capability == "person.get_value": - return {"success": True, "value": "123456"} - if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info": - return {"success": True, "result": {"role": "member"}} - if capability == "api.call": - return {"success": True, "result": {"status": "ok", "retcode": 0}} - if capability == "send.text": - return {"success": True} - raise AssertionError(f"unexpected capability: {capability}") - - plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call)) - - success, message, intercept = await plugin.handle_mute_command( - stream_id="group-10001", - group_id="10001", - user_id="42", - matched_groups={ - "target": "张三", - "duration": "120", - "reason": "刷屏", - }, - ) - - assert success is True - assert message == "成功禁言 张三" - assert intercept is True - - api_call = next( - call - for call in capability_calls - if call["capability"] == "api.call" - and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban" - ) - assert api_call["args"]["version"] == "1" - assert api_call["args"]["args"] == { - "group_id": "10001", - "user_id": "123456", - "duration": 120, - } - - -@pytest.mark.asyncio -async def test_mute_tool_requires_target_person_name() -> None: - """禁言工具在缺少目标时应直接失败并提示。""" - - plugin = _build_plugin() - capability_calls: List[Dict[str, Any]] = [] - - async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]: - assert method == "cap.call" - assert payload is not None - capability_calls.append(payload) - return {"success": True} - - plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call)) - - success, message = await plugin.handle_mute_tool( - stream_id="group-10001", - group_id="10001", - target="", - duration="60", - reason="测试", - ) - - assert success is False - assert message == "禁言目标不能为空" - assert capability_calls[-1]["capability"] == "send.text" - assert capability_calls[-1]["args"]["text"] == "没有指定禁言对象哦" - - -@pytest.mark.asyncio -async def test_mute_tool_can_unwrap_nested_person_user_id_response() -> None: - """禁言工具应能兼容解包多层 capability 返回结果。""" - - plugin = _build_plugin() - capability_calls: List[Dict[str, Any]] = [] - - async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]: - assert method == "cap.call" - assert payload is not None - capability_calls.append(payload) - - capability = payload["capability"] - if capability == "person.get_id_by_name": - return {"success": True, "result": {"success": True, "person_id": "person-1"}} - if capability == "person.get_value": - return {"success": True, "result": {"success": True, "value": "123456"}} - if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info": - return {"success": True, "result": {"role": "member"}} - if capability == "api.call": - return {"success": True, "result": {"status": "ok"}} - if capability == "send.text": - return {"success": True} - raise AssertionError(f"unexpected capability: {capability}") - - plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call)) - - success, message = await plugin.handle_mute_tool( - stream_id="group-10001", - group_id="10001", - target="张三", - duration=60, - reason="测试", - ) - - assert success is True - assert message == "成功禁言 张三" - - api_call = next( - call - for call in capability_calls - if call["capability"] == "api.call" - and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban" - ) - assert api_call["args"]["args"]["user_id"] == "123456" - - -@pytest.mark.asyncio -async def test_mute_tool_rejects_owner_before_group_ban_call() -> None: - """禁言工具应在检测到群主时提前返回明确提示。""" - - plugin = _build_plugin() - capability_calls: List[Dict[str, Any]] = [] - - async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]: - assert method == "cap.call" - assert payload is not None - capability_calls.append(payload) - - capability = payload["capability"] - if capability == "person.get_id_by_name": - return {"success": True, "person_id": "person-1"} - if capability == "person.get_value": - return {"success": True, "value": "123456"} - if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info": - return {"success": True, "result": {"role": "owner"}} - if capability == "send.text": - return {"success": True} - raise AssertionError(f"unexpected capability: {capability}") - - plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call)) - - success, message = await plugin.handle_mute_tool( - stream_id="group-10001", - group_id="10001", - target="张三", - duration=60, - reason="测试", - ) - - assert success is False - assert message == "张三 是群主,不能被禁言" - assert not any( - call["capability"] == "api.call" and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban" - for call in capability_calls - ) - - -@pytest.mark.asyncio -async def test_mute_tool_maps_cannot_ban_owner_error_message() -> None: - """NapCat 返回 cannot ban owner 时应转成明确中文提示。""" - - plugin = _build_plugin() - capability_calls: List[Dict[str, Any]] = [] - - async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]: - assert method == "cap.call" - assert payload is not None - capability_calls.append(payload) - - capability = payload["capability"] - if capability == "person.get_id_by_name": - return {"success": True, "person_id": "person-1"} - if capability == "person.get_value": - return {"success": True, "value": "123456"} - if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info": - return {"success": True, "result": {"role": "member"}} - if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban": - return {"success": False, "error": "NapCat 动作返回失败: action=set_group_ban message=cannot ban owner"} - if capability == "send.text": - return {"success": True} - raise AssertionError(f"unexpected capability: {capability}") - - plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call)) - - success, message = await plugin.handle_mute_tool( - stream_id="group-10001", - group_id="10001", - target="张三", - duration=60, - reason="测试", - ) - - assert success is False - assert message == "张三 是群主,不能被禁言" - - -@pytest.mark.asyncio -async def test_mute_tool_accepts_nested_ok_api_result() -> None: - """嵌套的 success/result/status=ok 返回值也应判定为成功。""" - - plugin = _build_plugin() - - async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]: - assert method == "cap.call" - assert payload is not None - - capability = payload["capability"] - if capability == "person.get_id_by_name": - return {"success": True, "person_id": "person-1"} - if capability == "person.get_value": - return {"success": True, "value": "123456"} - if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info": - return {"success": True, "result": {"role": "member"}} - if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban": - return { - "success": True, - "result": { - "status": "ok", - "retcode": 0, - "data": None, - "message": "", - "wording": "", - }, - } - if capability == "send.text": - return {"success": True} - raise AssertionError(f"unexpected capability: {capability}") - - plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call)) - - success, message = await plugin.handle_mute_tool( - stream_id="group-10001", - group_id="10001", - target="张三", - duration=60, - reason="测试", - ) - - assert success is True - assert message == "成功禁言 张三" - - -def test_tool_invocation_payload_injects_group_and_user_context() -> None: - """插件工具执行时应自动补齐群聊上下文字段。""" - - entry = SimpleNamespace(invoke_method="plugin.invoke_tool") - anchor_message = SimpleNamespace( - message_info=SimpleNamespace( - group_info=SimpleNamespace(group_id="10001"), - user_info=SimpleNamespace(user_id="20002"), - ) - ) - invocation = ToolInvocation(tool_name="mute", arguments={"target": "张三"}, stream_id="session-1") - context = ToolExecutionContext( - session_id="session-1", - stream_id="session-1", - reasoning="test", - metadata={"anchor_message": anchor_message}, - ) - - payload = ComponentQueryService._build_tool_invocation_payload(entry, invocation, context) - - assert payload["target"] == "张三" - assert payload["stream_id"] == "session-1" - assert payload["chat_id"] == "session-1" - assert payload["group_id"] == "10001" - assert payload["user_id"] == "20002" diff --git a/pytests/utils_test/test_bot_identity_utils.py b/pytests/utils_test/test_bot_identity_utils.py deleted file mode 100644 index c345174b..00000000 --- a/pytests/utils_test/test_bot_identity_utils.py +++ /dev/null @@ -1,227 +0,0 @@ -from pathlib import Path -from types import ModuleType, SimpleNamespace - -import importlib.util -import sys - - -class DummyLogger: - def __init__(self) -> None: - self.warning_messages: list[str] = [] - - def debug(self, _msg: str) -> None: - return - - def info(self, _msg: str) -> None: - return - - def warning(self, msg: str) -> None: - self.warning_messages.append(msg) - - def error(self, _msg: str) -> None: - return - - -def load_utils_module(monkeypatch, qq_account=123456, platforms=None): - logger = DummyLogger() - configured_platforms = platforms or [] - - def _stub_module(name: str) -> ModuleType: - module = ModuleType(name) - monkeypatch.setitem(sys.modules, name, module) - return module - - for package_name in [ - "src", - "src.chat", - "src.chat.message_receive", - "src.chat.utils", - "src.common", - "src.config", - "src.llm_models", - "src.person_info", - ]: - if package_name not in sys.modules: - package_module = ModuleType(package_name) - package_module.__path__ = [] - monkeypatch.setitem(sys.modules, package_name, package_module) - - jieba_module = ModuleType("jieba") - jieba_module.cut = lambda text: list(text) - monkeypatch.setitem(sys.modules, "jieba", jieba_module) - - logger_module = _stub_module("src.common.logger") - logger_module.get_logger = lambda _name: logger - - config_module = _stub_module("src.config.config") - config_module.global_config = SimpleNamespace( - bot=SimpleNamespace( - qq_account=qq_account, - platforms=configured_platforms, - nickname="MaiBot", - alias_names=[], - ), - chat=SimpleNamespace( - at_bot_inevitable_reply=1, - mentioned_bot_reply=1, - ), - ) - config_module.model_config = SimpleNamespace() - - message_module = _stub_module("src.chat.message_receive.message") - - class SessionMessage: - pass - - message_module.SessionMessage = SessionMessage - - chat_manager_module = _stub_module("src.chat.message_receive.chat_manager") - chat_manager_module.chat_manager = SimpleNamespace(get_session_by_session_id=lambda _chat_id: None) - - llm_module = _stub_module("src.llm_models.utils_model") - - class LLMRequest: - def __init__(self, *args, **kwargs) -> None: - del args, kwargs - - llm_module.LLMRequest = LLMRequest - - person_module = _stub_module("src.person_info.person_info") - - class Person: - pass - - person_module.Person = Person - - typo_generator_module = _stub_module("src.chat.utils.typo_generator") - - class ChineseTypoGenerator: - def __init__(self, *args, **kwargs) -> None: - del args, kwargs - - def create_typo_sentence(self, sentence: str): - return sentence, "" - - typo_generator_module.ChineseTypoGenerator = ChineseTypoGenerator - - file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "utils" / "utils.py" - spec = importlib.util.spec_from_file_location("src.chat.utils.utils", file_path) - utils_module = importlib.util.module_from_spec(spec) - utils_module.__package__ = "src.chat.utils" - monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module) - assert spec.loader is not None - spec.loader.exec_module(utils_module) - return utils_module, logger - - -def test_platform_specific_bot_accounts(monkeypatch): - utils_module, _logger = load_utils_module( - monkeypatch, - qq_account=123456, - platforms=[" TG : tg_bot ", "discord: disc_bot"], - ) - - assert utils_module.get_bot_account("qq") == "123456" - assert utils_module.get_bot_account("webui") == "123456" - assert utils_module.get_bot_account("telegram") == "tg_bot" - assert utils_module.get_bot_account("tg") == "tg_bot" - assert utils_module.get_bot_account("discord") == "disc_bot" - - assert utils_module.is_bot_self("qq", "123456") - assert utils_module.is_bot_self("webui", "123456") - assert utils_module.is_bot_self("telegram", "tg_bot") - assert utils_module.is_bot_self(" TG ", "tg_bot") - - -def test_get_all_bot_accounts_includes_runtime_aliases(monkeypatch): - utils_module, _logger = load_utils_module( - monkeypatch, - qq_account=123456, - platforms=["TG:tg_bot", "discord:disc_bot"], - ) - - assert utils_module.get_all_bot_accounts() == { - "qq": "123456", - "webui": "123456", - "telegram": "tg_bot", - "tg": "tg_bot", - "discord": "disc_bot", - } - - -def test_get_all_bot_accounts_keeps_canonical_qq_identity(monkeypatch): - utils_module, _logger = load_utils_module( - monkeypatch, - qq_account=123456, - platforms=["qq:999999", "webui:888888", "TG:tg_bot"], - ) - - assert utils_module.get_all_bot_accounts()["qq"] == "123456" - assert utils_module.get_all_bot_accounts()["webui"] == "123456" - - -def test_unknown_platform_no_longer_falls_back_to_qq(monkeypatch): - utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[]) - - assert utils_module.is_bot_self("unknown_platform", "123456") is False - assert logger.warning_messages - assert "unknown_platform" in logger.warning_messages[-1] - - -def test_unknown_platform_warns_only_once(monkeypatch): - utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[]) - - assert utils_module.is_bot_self("unknown_platform", "first") is False - assert utils_module.is_bot_self(" unknown_platform ", "second") is False - assert len(logger.warning_messages) == 1 - - -def test_unconfigured_qq_account_disables_qq_and_webui_identity(monkeypatch): - utils_module, _logger = load_utils_module(monkeypatch, qq_account=0, platforms=["telegram:tg_bot"]) - - assert utils_module.get_bot_account("qq") == "" - assert utils_module.get_bot_account("webui") == "" - assert utils_module.is_bot_self("qq", "0") is False - assert utils_module.is_bot_self("webui", "0") is False - - -def test_is_mentioned_bot_in_message_uses_platform_account(monkeypatch): - utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=["TG:tg_bot"]) - - message = SimpleNamespace( - processed_plain_text="@tg_bot 你好", - platform="telegram", - is_mentioned=False, - message_segment=None, - message_info=SimpleNamespace( - additional_config={}, - user_info=SimpleNamespace(user_id="user_1"), - ), - ) - - is_mentioned, is_at, reply_probability = utils_module.is_mentioned_bot_in_message(message) - - assert is_mentioned is True - assert is_at is True - assert reply_probability == 1.0 - - -def test_is_mentioned_bot_in_message_normalizes_qq_platform(monkeypatch): - utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[]) - - message = SimpleNamespace( - processed_plain_text="@ 你好", - platform=" QQ ", - is_mentioned=False, - message_segment=None, - message_info=SimpleNamespace( - additional_config={}, - user_info=SimpleNamespace(user_id="user_1"), - ), - ) - - is_mentioned, is_at, reply_probability = utils_module.is_mentioned_bot_in_message(message) - - assert is_mentioned is True - assert is_at is True - assert reply_probability == 1.0 diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py index e45fa2d6..f12fb049 100644 --- a/src/maisaka/chat_loop_service.py +++ b/src/maisaka/chat_loop_service.py @@ -627,18 +627,18 @@ class MaisakaChatLoopService: break if not selected_indices: - return [], f"没有选择到上下文消息,实际发送 {effective_context_size} 条 user/assistant 消息" + return [], "实际发送 0 条消息(tool 0 条,普通消息 0 条)" selected_indices.reverse() selected_history = [filtered_history[index] for index in selected_indices] - selected_history, hidden_assistant_count = MaisakaChatLoopService._hide_early_assistant_messages(selected_history) + selected_history, _ = MaisakaChatLoopService._hide_early_assistant_messages(selected_history) selected_history, _ = drop_orphan_tool_results(selected_history) + tool_message_count = sum(1 for message in selected_history if isinstance(message, ToolResultMessage)) + normal_message_count = len(selected_history) - tool_message_count selection_reason = ( - f"上下文裁剪:最近 {effective_context_size} 条 user/assistant 消息," - f"实际发送 {len(selected_history)} 条" + f"实际发送 {len(selected_history)} 条消息" + f"|消息 {normal_message_count} 条|tool {tool_message_count} 条" ) - if hidden_assistant_count > 0: - selection_reason += f",已隐藏最早 {hidden_assistant_count} 条 assistant 消息" return ( selected_history, selection_reason, From 6297c500119bd14719a6651e77d7ac0edc6777a0 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 15 Apr 2026 11:46:22 +0800 Subject: [PATCH 5/6] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E5=A4=8D=E9=9D=9E?= =?UTF-8?q?=E5=A4=9A=E6=A8=A1=E6=80=81=E6=A8=A1=E5=9E=8B=E6=84=8F=E5=A4=96?= =?UTF-8?q?=E4=BC=A0=E5=85=A5=E5=9B=BE=E7=89=87=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/maisaka_generator_base.py | 36 ++---------- src/maisaka/chat_loop_service.py | 38 ++++++++++++- src/maisaka/context_messages.py | 64 ++++++++++++++++++---- src/maisaka/reasoning_engine.py | 42 +------------- src/maisaka/visual_mode_utils.py | 43 +++++++++++++++ 5 files changed, 139 insertions(+), 84 deletions(-) create mode 100644 src/maisaka/visual_mode_utils.py diff --git a/src/chat/replyer/maisaka_generator_base.py b/src/chat/replyer/maisaka_generator_base.py index 18e90fd4..812b82d5 100644 --- a/src/chat/replyer/maisaka_generator_base.py +++ b/src/chat/replyer/maisaka_generator_base.py @@ -13,7 +13,6 @@ from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.message import SessionMessage from src.chat.utils.utils import get_chat_type_and_target_info from src.cli.console import console -from src.common.data_models.message_component_data_model import MessageSequence, TextComponent from src.common.data_models.reply_generation_data_models import ( GenerationMetrics, LLMCompletionResult, @@ -32,9 +31,10 @@ from src.maisaka.context_messages import ( ReferenceMessage, SessionBackedMessage, ToolResultMessage, + build_llm_message_from_context, ) from src.maisaka.display.prompt_cli_renderer import PromptCLIVisualizer -from src.maisaka.message_adapter import clone_message_sequence, parse_speaker_content +from src.maisaka.message_adapter import parse_speaker_content from src.plugin_runtime.hook_payloads import serialize_prompt_messages from .maisaka_expression_selector import maisaka_expression_selector @@ -253,28 +253,6 @@ class BaseMaisakaReplyGenerator: def _build_reply_instruction(self) -> str: return "请自然地回复。不要输出多余说明、括号、@ 或额外标记,只输出实际要发送的内容。" - def _build_visual_user_message( - self, - message: SessionBackedMessage, - enable_visual_message: bool, - ) -> Optional[Message]: - if not enable_visual_message: - return None - - raw_message = clone_message_sequence(message.raw_message) - if not raw_message.components: - raw_message = MessageSequence([TextComponent(message.processed_plain_text)]) - - visual_message = SessionBackedMessage( - raw_message=raw_message, - visible_text=message.processed_plain_text, - timestamp=message.timestamp, - message_id=message.message_id, - original_message=message.original_message, - source_kind=message.source_kind, - ) - return visual_message.to_llm_message() - def _build_history_messages( self, chat_history: List[LLMContextMessage], @@ -294,12 +272,10 @@ class BaseMaisakaReplyGenerator: ) continue - visual_message = self._build_visual_user_message(message, enable_visual_message) - if visual_message is not None: - messages.append(visual_message) - continue - - llm_message = message.to_llm_message() + llm_message = build_llm_message_from_context( + message, + enable_visual_message=enable_visual_message, + ) if llm_message is not None: messages.append(llm_message) continue diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py index f12fb049..f4a8a7db 100644 --- a/src/maisaka/chat_loop_service.py +++ b/src/maisaka/chat_loop_service.py @@ -30,9 +30,15 @@ from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistr from src.services.llm_service import LLMServiceClient from .builtin_tool import get_builtin_tools -from .context_messages import AssistantMessage, LLMContextMessage, ToolResultMessage +from .context_messages import ( + AssistantMessage, + LLMContextMessage, + ToolResultMessage, + build_llm_message_from_context, +) from .history_utils import drop_orphan_tool_results from .display.prompt_cli_renderer import PromptCLIVisualizer +from .visual_mode_utils import resolve_enable_visual_planner TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"} @@ -395,6 +401,7 @@ class MaisakaChatLoopService: self, selected_history: List[LLMContextMessage], *, + enable_visual_message: bool, injected_user_messages: Sequence[str] | None = None, system_prompt: Optional[str] = None, ) -> List[Message]: @@ -413,7 +420,10 @@ class MaisakaChatLoopService: messages.append(system_msg.build()) for msg in selected_history: - llm_message = msg.to_llm_message() + llm_message = build_llm_message_from_context( + msg, + enable_visual_message=enable_visual_message, + ) if llm_message is not None: messages.append(llm_message) @@ -475,12 +485,15 @@ class MaisakaChatLoopService: if not self._prompts_loaded: await self.ensure_chat_prompt_loaded() + enable_visual_message = self._resolve_enable_visual_message(request_kind) selected_history, selection_reason = self.select_llm_context_messages( chat_history, request_kind=request_kind, + enable_visual_message=enable_visual_message, ) built_messages = self._build_request_messages( selected_history, + enable_visual_message=enable_visual_message, injected_user_messages=injected_user_messages, ) @@ -602,6 +615,7 @@ class MaisakaChatLoopService: def select_llm_context_messages( chat_history: List[LLMContextMessage], *, + enable_visual_message: Optional[bool] = None, request_kind: str = "planner", max_context_size: Optional[int] = None, ) -> tuple[List[LLMContextMessage], str]: @@ -615,9 +629,21 @@ class MaisakaChatLoopService: selected_indices: List[int] = [] counted_message_count = 0 + active_enable_visual_message = ( + enable_visual_message + if enable_visual_message is not None + else MaisakaChatLoopService._resolve_enable_visual_message(request_kind) + ) + for index in range(len(filtered_history) - 1, -1, -1): message = filtered_history[index] - if message.to_llm_message() is None: + if ( + build_llm_message_from_context( + message, + enable_visual_message=active_enable_visual_message, + ) + is None + ): continue selected_indices.append(index) @@ -683,6 +709,12 @@ class MaisakaChatLoopService: return filtered_history + @staticmethod + def _resolve_enable_visual_message(request_kind: str) -> bool: + if request_kind in {"planner", "timing_gate"}: + return resolve_enable_visual_planner() + return True + @staticmethod def _hide_early_assistant_messages( selected_history: List[LLMContextMessage], diff --git a/src/maisaka/context_messages.py b/src/maisaka/context_messages.py index c96e9993..cefa7dc4 100644 --- a/src/maisaka/context_messages.py +++ b/src/maisaka/context_messages.py @@ -40,10 +40,15 @@ def _guess_image_format(image_bytes: bytes) -> Optional[str]: return None -def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent) -> bool: +def _append_emoji_component( + builder: MessageBuilder, + component: EmojiComponent, + *, + enable_visual_message: bool, +) -> bool: """将表情组件追加到 LLM 消息构建器。""" image_format = _guess_image_format(component.binary_data) - if image_format and component.binary_data: + if enable_visual_message and image_format and component.binary_data: builder.add_text_content("[消息类型]表情包") builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8")) return True @@ -56,10 +61,15 @@ def _append_emoji_component(builder: MessageBuilder, component: EmojiComponent) return True -def _append_image_component(builder: MessageBuilder, component: ImageComponent) -> bool: +def _append_image_component( + builder: MessageBuilder, + component: ImageComponent, + *, + enable_visual_message: bool, +) -> bool: """将图片组件追加到 LLM 消息构建器。""" image_format = _guess_image_format(component.binary_data) - if image_format and component.binary_data: + if enable_visual_message and image_format and component.binary_data: builder.add_text_content("[消息类型]图片") builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8")) return True @@ -216,6 +226,7 @@ def _build_message_from_sequence( message_sequence: MessageSequence, fallback_text: str, *, + enable_visual_message: bool = True, tool_call_id: Optional[str] = None, tool_name: Optional[str] = None, tool_calls: Optional[list[ToolCall]] = None, @@ -238,11 +249,25 @@ def _build_message_from_sequence( continue if isinstance(component, EmojiComponent): - has_content = _append_emoji_component(builder, component) or has_content + has_content = ( + _append_emoji_component( + builder, + component, + enable_visual_message=enable_visual_message, + ) + or has_content + ) continue if isinstance(component, ImageComponent): - has_content = _append_image_component(builder, component) or has_content + has_content = ( + _append_image_component( + builder, + component, + enable_visual_message=enable_visual_message, + ) + or has_content + ) continue if isinstance(component, AtComponent): @@ -297,7 +322,7 @@ class LLMContextMessage(ABC): return self.__class__.__name__ @abstractmethod - def to_llm_message(self) -> Optional[Message]: + def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]: """转换为统一 LLM 消息。""" def consume_once(self) -> bool: @@ -328,11 +353,12 @@ class SessionBackedMessage(LLMContextMessage): def source(self) -> str: return self.source_kind - def to_llm_message(self) -> Optional[Message]: + def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]: return _build_message_from_sequence( RoleType.User, self.raw_message, self.processed_plain_text, + enable_visual_message=enable_visual_message, ) @classmethod @@ -366,7 +392,8 @@ class ComplexSessionMessage(SessionBackedMessage): def source(self) -> str: return f"{self.source_kind}:{self.complex_message_type}" - def to_llm_message(self) -> Optional[Message]: + def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]: + del enable_visual_message message_sequence = MessageSequence([TextComponent(self.prompt_text)]) return _build_message_from_sequence( RoleType.User, @@ -426,7 +453,8 @@ class ReferenceMessage(LLMContextMessage): def source(self) -> str: return self.reference_type.value - def to_llm_message(self) -> Optional[Message]: + def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]: + del enable_visual_message message_sequence = MessageSequence([TextComponent(self.processed_plain_text)]) return _build_message_from_sequence(RoleType.User, message_sequence, self.processed_plain_text) @@ -463,7 +491,8 @@ class AssistantMessage(LLMContextMessage): def source(self) -> str: return self.source_kind - def to_llm_message(self) -> Optional[Message]: + def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]: + del enable_visual_message message_sequence = MessageSequence([]) if self.content: message_sequence.text(self.content) @@ -501,7 +530,8 @@ class ToolResultMessage(LLMContextMessage): def source(self) -> str: return self.tool_name or "tool" - def to_llm_message(self) -> Optional[Message]: + def to_llm_message(self, enable_visual_message: bool = True) -> Optional[Message]: + del enable_visual_message message_sequence = MessageSequence([TextComponent(self.content)]) return _build_message_from_sequence( RoleType.Tool, @@ -510,3 +540,13 @@ class ToolResultMessage(LLMContextMessage): tool_call_id=self.tool_call_id, tool_name=self.tool_name, ) + + +def build_llm_message_from_context( + context_message: LLMContextMessage, + *, + enable_visual_message: bool = True, +) -> Optional[Message]: + """将 Maisaka 内部上下文消息转换为发给 LLM 的统一消息。""" + + return context_message.to_llm_message(enable_visual_message=enable_visual_message) diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index b30008ab..0632a77f 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -14,7 +14,7 @@ from src.chat.message_receive.message import SessionMessage from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt -from src.config.config import config_manager, global_config +from src.config.config import global_config from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec from src.llm_models.exceptions import ReqAbortException from src.llm_models.payload_content.tool_option import ToolCall @@ -43,6 +43,7 @@ from .monitor_events import ( emit_timing_gate_result, ) from .planner_message_utils import build_planner_user_prefix_from_session_message +from .visual_mode_utils import resolve_enable_visual_planner if TYPE_CHECKING: from .runtime import MaisakaHeartFlowChatting @@ -738,47 +739,10 @@ class MaisakaReasoningEngine: planner_prefix: str, ) -> MessageSequence: message_sequence = build_prefixed_message_sequence(message.raw_message, planner_prefix) - if self._resolve_enable_visual_planner(): + if resolve_enable_visual_planner(): await self._hydrate_visual_components(message_sequence.components) return message_sequence - @staticmethod - def _resolve_enable_visual_planner() -> bool: - planner_mode = global_config.visual.planner_mode - planner_task_config = config_manager.get_model_config().model_task_config.planner - models_by_name = {model.name: model for model in config_manager.get_model_config().models} - - if planner_mode == "text": - return False - - planner_models: list[str] = list(planner_task_config.model_list) - missing_models = [model_name for model_name in planner_models if model_name not in models_by_name] - non_visual_models = [ - model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual - ] - - if planner_mode == "multimodal": - if missing_models: - raise ValueError( - "planner_mode=multimodal,但 planner 任务存在未定义的模型:" - f"{', '.join(missing_models)}" - ) - if non_visual_models: - raise ValueError( - "planner_mode=multimodal,但 planner 任务存在未开启 visual 的模型:" - f"{', '.join(non_visual_models)}" - ) - return True - - if missing_models: - logger.warning( - "planner_mode=auto 时发现 planner 任务存在未定义模型:" - f"{', '.join(missing_models)},将退化为纯文本 planner" - ) - return False - - return bool(planner_models) and not non_visual_models - async def _hydrate_visual_components(self, planner_components: list[object]) -> None: """在 Maisaka 真正需要图片或表情时,按需回填二进制数据。""" load_tasks: list[asyncio.Task[None]] = [] diff --git a/src/maisaka/visual_mode_utils.py b/src/maisaka/visual_mode_utils.py new file mode 100644 index 00000000..d9c15a6e --- /dev/null +++ b/src/maisaka/visual_mode_utils.py @@ -0,0 +1,43 @@ +from src.common.logger import get_logger +from src.config.config import config_manager, global_config + +logger = get_logger("maisaka_visual_mode") + + +def resolve_enable_visual_planner() -> bool: + """根据 planner 配置解析当前是否应启用视觉消息。""" + + planner_mode = global_config.visual.planner_mode + planner_task_config = config_manager.get_model_config().model_task_config.planner + models_by_name = {model.name: model for model in config_manager.get_model_config().models} + + if planner_mode == "text": + return False + + planner_models: list[str] = list(planner_task_config.model_list) + missing_models = [model_name for model_name in planner_models if model_name not in models_by_name] + non_visual_models = [ + model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual + ] + + if planner_mode == "multimodal": + if missing_models: + raise ValueError( + "planner_mode=multimodal,但 planner 任务存在未定义的模型:" + f"{', '.join(missing_models)}" + ) + if non_visual_models: + raise ValueError( + "planner_mode=multimodal,但 planner 任务存在未开启 visual 的模型:" + f"{', '.join(non_visual_models)}" + ) + return True + + if missing_models: + logger.warning( + "planner_mode=auto 时发现 planner 任务存在未定义模型:" + f"{', '.join(missing_models)},将退化为纯文本 planner" + ) + return False + + return bool(planner_models) and not non_visual_models From 3391723cf4c0c3f338b6ea06e2730bb54553788b Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 15 Apr 2026 19:27:37 +0800 Subject: [PATCH 6/6] =?UTF-8?q?feat=EF=BC=9A=E5=8F=91=E9=80=81=E7=9A=84?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E6=B3=A8=E5=85=A5=E5=9B=9Emaisaka=E5=8E=86?= =?UTF-8?q?=E5=8F=B2=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/emoji_system/maisaka_tool.py | 2 + src/maisaka/builtin_tool/context.py | 36 +------------ src/maisaka/builtin_tool/reply.py | 7 +-- src/maisaka/builtin_tool/send_emoji.py | 4 +- src/maisaka/runtime.py | 37 +++++++++++++ src/plugin_runtime/capabilities/core.py | 20 +++++++ src/services/send_service.py | 72 ++++++++++++++++++++++++- 7 files changed, 134 insertions(+), 44 deletions(-) diff --git a/src/emoji_system/maisaka_tool.py b/src/emoji_system/maisaka_tool.py index bbdc072d..d9a9cca7 100644 --- a/src/emoji_system/maisaka_tool.py +++ b/src/emoji_system/maisaka_tool.py @@ -335,6 +335,8 @@ async def send_emoji_for_maisaka( storage_message=True, set_reply=False, reply_message=None, + sync_to_maisaka_history=True, + maisaka_source_kind="guided_reply", ) sent = sent_message is not None except Exception as exc: diff --git a/src/maisaka/builtin_tool/context.py b/src/maisaka/builtin_tool/context.py index e221ba11..4cf37986 100644 --- a/src/maisaka/builtin_tool/context.py +++ b/src/maisaka/builtin_tool/context.py @@ -4,7 +4,7 @@ from __future__ import annotations from base64 import b64decode from datetime import datetime -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from src.chat.utils.utils import process_llm_response from src.common.data_models.message_component_data_model import EmojiComponent, MessageSequence, TextComponent @@ -12,13 +12,10 @@ from src.config.config import global_config from src.core.tooling import ToolExecutionResult from ..context_messages import SessionBackedMessage -from ..history_utils import build_prefixed_message_sequence, build_session_message_visible_text from ..message_adapter import format_speaker_content from ..planner_message_utils import build_planner_prefix, build_session_backed_text_message if TYPE_CHECKING: - from src.chat.message_receive.message import SessionMessage - from ..reasoning_engine import MaisakaReasoningEngine from ..runtime import MaisakaHeartFlowChatting @@ -139,37 +136,6 @@ class BuiltinToolRuntimeContext: return self.engine._get_runtime_manager() - @staticmethod - def _build_visible_text_from_sent_message(message: "SessionMessage") -> str: - """将已发送消息转换为 Maisaka 可见文本。""" - - return build_session_message_visible_text(message) - - def append_sent_message_to_chat_history( - self, - message: "SessionMessage", - *, - source_kind: str = "guided_reply", - ) -> None: - """将真实已发送消息同步到 Maisaka 历史。""" - - user_info = message.message_info.user_info - speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id - planner_prefix = build_planner_prefix( - timestamp=message.timestamp, - user_name=speaker_name, - group_card=user_info.user_cardname or "", - message_id=message.message_id, - include_message_id=not message.is_notify and bool(message.message_id), - ) - history_message = SessionBackedMessage.from_session_message( - message, - raw_message=build_prefixed_message_sequence(message.raw_message, planner_prefix), - visible_text=self._build_visible_text_from_sent_message(message), - source_kind=source_kind, - ) - self.runtime._chat_history.append(history_message) - def append_guided_reply_to_chat_history(self, reply_text: str) -> None: """将引导回复写回 Maisaka 历史。""" diff --git a/src/maisaka/builtin_tool/reply.py b/src/maisaka/builtin_tool/reply.py index debee914..fa182401 100644 --- a/src/maisaka/builtin_tool/reply.py +++ b/src/maisaka/builtin_tool/reply.py @@ -160,7 +160,6 @@ async def handle_tool( combined_reply_text = "".join(reply_segments) try: sent = False - sent_messages = [] if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME: for segment in reply_segments: render_cli_message(segment) @@ -174,11 +173,12 @@ async def handle_tool( reply_message=target_message if set_quote and index == 0 else None, selected_expressions=reply_result.selected_expression_ids or None, typing=index > 0, + sync_to_maisaka_history=True, + maisaka_source_kind="guided_reply", ) sent = sent_message is not None if not sent: break - sent_messages.append(sent_message) except Exception: logger.exception( f"{tool_ctx.runtime.log_prefix} 发送文字消息时发生异常,目标消息编号={target_message_id}" @@ -206,9 +206,6 @@ async def handle_tool( if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME: tool_ctx.append_guided_reply_to_chat_history(combined_reply_text) - else: - for sent_message in sent_messages: - tool_ctx.append_sent_message_to_chat_history(sent_message) tool_ctx.runtime._record_reply_sent() return tool_ctx.build_success_result( invocation.tool_name, diff --git a/src/maisaka/builtin_tool/send_emoji.py b/src/maisaka/builtin_tool/send_emoji.py index b02b75f7..b1853452 100644 --- a/src/maisaka/builtin_tool/send_emoji.py +++ b/src/maisaka/builtin_tool/send_emoji.py @@ -446,9 +446,7 @@ async def handle_tool( f"描述={send_result.description!r} 情绪标签={send_result.emotions} " f"命中情绪={send_result.matched_emotion!r}" ) - if send_result.sent_message is not None: - tool_ctx.append_sent_message_to_chat_history(send_result.sent_message) - else: + if send_result.sent_message is None: tool_ctx.append_sent_emoji_to_chat_history( emoji_base64=send_result.emoji_base64, success_message=_EMOJI_SUCCESS_MESSAGE, diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index 5de81cf2..37c79180 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -183,6 +183,43 @@ class MaisakaHeartFlowChatting: self._talk_frequency_adjust = max(0.01, float(frequency)) self._schedule_message_turn() + def append_sent_message_to_chat_history( + self, + message: SessionMessage, + *, + source_kind: str = "guided_reply", + ) -> bool: + """将一条已发送成功的消息同步到 Maisaka 内部历史。""" + + try: + from .context_messages import SessionBackedMessage + from .history_utils import build_prefixed_message_sequence, build_session_message_visible_text + from .planner_message_utils import build_planner_prefix + + user_info = message.message_info.user_info + speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id + planner_prefix = build_planner_prefix( + timestamp=message.timestamp, + user_name=speaker_name, + group_card=user_info.user_cardname or "", + message_id=message.message_id, + include_message_id=not message.is_notify and bool(message.message_id), + ) + history_message = SessionBackedMessage.from_session_message( + message, + raw_message=build_prefixed_message_sequence(message.raw_message, planner_prefix), + visible_text=build_session_message_visible_text(message), + source_kind=source_kind, + ) + self._chat_history.append(history_message) + return True + except Exception as exc: + logger.warning( + f"{self.log_prefix} 同步已发送消息到 Maisaka 历史失败: " + f"message_id={message.message_id} error={exc}" + ) + return False + async def register_message(self, message: SessionMessage) -> None: """缓存一条新消息并唤醒主循环。""" if self._running: diff --git a/src/plugin_runtime/capabilities/core.py b/src/plugin_runtime/capabilities/core.py index 843b8ce0..25f2a6b9 100644 --- a/src/plugin_runtime/capabilities/core.py +++ b/src/plugin_runtime/capabilities/core.py @@ -75,6 +75,8 @@ class RuntimeCoreCapabilityMixin: text = str(args.get("text", "")) stream_id = str(args.get("stream_id", "")) + sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False)) + maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send") if not text or not stream_id: return {"success": False, "error": "缺少必要参数 text 或 stream_id"} @@ -85,6 +87,8 @@ class RuntimeCoreCapabilityMixin: typing=bool(args.get("typing", False)), set_reply=bool(args.get("set_reply", False)), storage_message=bool(args.get("storage_message", True)), + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) return {"success": result} except Exception as exc: @@ -107,6 +111,8 @@ class RuntimeCoreCapabilityMixin: emoji_base64 = str(args.get("emoji_base64", "")) stream_id = str(args.get("stream_id", "")) + sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False)) + maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send") if not emoji_base64 or not stream_id: return {"success": False, "error": "缺少必要参数 emoji_base64 或 stream_id"} @@ -115,6 +121,8 @@ class RuntimeCoreCapabilityMixin: emoji_base64=emoji_base64, stream_id=stream_id, storage_message=bool(args.get("storage_message", True)), + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) return {"success": result} except Exception as exc: @@ -137,6 +145,8 @@ class RuntimeCoreCapabilityMixin: image_base64 = str(args.get("image_base64", "")) stream_id = str(args.get("stream_id", "")) + sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False)) + maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send") if not image_base64 or not stream_id: return {"success": False, "error": "缺少必要参数 image_base64 或 stream_id"} @@ -145,6 +155,8 @@ class RuntimeCoreCapabilityMixin: image_base64=image_base64, stream_id=stream_id, storage_message=bool(args.get("storage_message", True)), + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) return {"success": result} except Exception as exc: @@ -167,6 +179,8 @@ class RuntimeCoreCapabilityMixin: command = str(args.get("command", "")) stream_id = str(args.get("stream_id", "")) + sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False)) + maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send") if not command or not stream_id: return {"success": False, "error": "缺少必要参数 command 或 stream_id"} @@ -177,6 +191,8 @@ class RuntimeCoreCapabilityMixin: stream_id=stream_id, storage_message=bool(args.get("storage_message", True)), display_message=str(args.get("display_message", "")), + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) return {"success": result} except Exception as exc: @@ -202,6 +218,8 @@ class RuntimeCoreCapabilityMixin: if content is None: content = args.get("data", "") stream_id = str(args.get("stream_id", "")) + sync_to_maisaka_history = bool(args.get("sync_to_maisaka_history", False)) + maisaka_source_kind = str(args.get("maisaka_source_kind", "plugin_send") or "plugin_send") if not message_type or not stream_id: return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"} @@ -213,6 +231,8 @@ class RuntimeCoreCapabilityMixin: display_message=str(args.get("display_message", "")), typing=bool(args.get("typing", False)), storage_message=bool(args.get("storage_message", True)), + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) return {"success": result} except Exception as exc: diff --git a/src/services/send_service.py b/src/services/send_service.py index 22d30e8f..34b3f6d9 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -707,6 +707,28 @@ async def _notify_memory_automation_on_message_sent(message: SessionMessage) -> logger.warning(f"[{session_id}] 长期记忆人物事实写回注册失败: {exc}") +def _sync_sent_message_to_maisaka_history( + message: SessionMessage, + *, + source_kind: str, +) -> None: + """将已发送成功的消息同步到当前会话对应的 Maisaka 历史。""" + + session_id = str(message.session_id or "").strip() + if not session_id: + return + + try: + from src.chat.heart_flow.heartflow_manager import heartflow_manager + + runtime = heartflow_manager.heartflow_chat_list.get(session_id) + if runtime is None: + return + runtime.append_sent_message_to_chat_history(message, source_kind=source_kind) + except Exception as exc: + logger.warning(f"[SendService] 同步消息到 Maisaka 历史失败: session_id={session_id} error={exc}") + + def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None: """输出 Platform IO 批量发送失败详情。 @@ -837,13 +859,15 @@ async def send_session_message_with_message( reply_message_id: Optional[str] = None, storage_message: bool = True, show_log: bool = True, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> Optional[SessionMessage]: """统一发送一条内部消息,并返回最终发送成功的消息对象。""" if not message.message_id: logger.error("[SendService] 消息缺少 message_id,无法发送") raise ValueError("消息缺少 message_id,无法发送") - return await _send_via_platform_io( + sent_message = await _send_via_platform_io( message, typing=typing, set_reply=set_reply, @@ -851,6 +875,12 @@ async def send_session_message_with_message( storage_message=storage_message, show_log=show_log, ) + if sent_message is not None and sync_to_maisaka_history: + _sync_sent_message_to_maisaka_history( + sent_message, + source_kind=str(maisaka_source_kind or "outbound_send"), + ) + return sent_message async def send_session_message( @@ -861,6 +891,8 @@ async def send_session_message( reply_message_id: Optional[str] = None, storage_message: bool = True, show_log: bool = True, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> bool: """统一发送一条内部消息。 @@ -893,6 +925,8 @@ async def send_session_message( reply_message_id=reply_message_id, storage_message=storage_message, show_log=show_log, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) is not None ) @@ -908,6 +942,8 @@ async def _send_to_target( storage_message: bool = True, show_log: bool = True, selected_expressions: Optional[List[int]] = None, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> bool: """向指定目标构建并发送消息,并返回是否发送成功。""" return ( @@ -921,6 +957,8 @@ async def _send_to_target( storage_message=storage_message, show_log=show_log, selected_expressions=selected_expressions, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) is not None ) @@ -936,6 +974,8 @@ async def _send_to_target_with_message( storage_message: bool = True, show_log: bool = True, selected_expressions: Optional[List[int]] = None, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> Optional[SessionMessage]: """向指定目标构建并发送消息。 @@ -998,6 +1038,8 @@ async def _send_to_target_with_message( reply_message_id=reply_message.message_id if reply_message is not None else None, storage_message=storage_message, show_log=show_log, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) if sent_message is not None: logger.debug(f"[SendService] 成功发送消息到 {stream_id}") @@ -1019,6 +1061,8 @@ async def text_to_stream_with_message( reply_message: Optional[MaiMessage] = None, storage_message: bool = True, selected_expressions: Optional[List[int]] = None, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> Optional[SessionMessage]: """向指定流发送文本消息,并返回发送成功后的消息对象。""" return await _send_to_target_with_message( @@ -1030,6 +1074,8 @@ async def text_to_stream_with_message( reply_message=reply_message, storage_message=storage_message, selected_expressions=selected_expressions, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) @@ -1041,6 +1087,8 @@ async def text_to_stream( reply_message: Optional[MaiMessage] = None, storage_message: bool = True, selected_expressions: Optional[List[int]] = None, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> bool: """向指定流发送文本消息。 @@ -1065,6 +1113,8 @@ async def text_to_stream( reply_message=reply_message, storage_message=storage_message, selected_expressions=selected_expressions, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) is not None ) @@ -1076,6 +1126,8 @@ async def emoji_to_stream_with_message( storage_message: bool = True, set_reply: bool = False, reply_message: Optional[MaiMessage] = None, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> Optional[SessionMessage]: """向指定流发送表情消息,并返回发送成功后的消息对象。""" return await _send_to_target_with_message( @@ -1086,6 +1138,8 @@ async def emoji_to_stream_with_message( storage_message=storage_message, set_reply=set_reply, reply_message=reply_message, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) @@ -1095,6 +1149,8 @@ async def emoji_to_stream( storage_message: bool = True, set_reply: bool = False, reply_message: Optional[MaiMessage] = None, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> bool: """向指定流发送表情消息。 @@ -1115,6 +1171,8 @@ async def emoji_to_stream( storage_message=storage_message, set_reply=set_reply, reply_message=reply_message, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) is not None ) @@ -1126,6 +1184,8 @@ async def image_to_stream( storage_message: bool = True, set_reply: bool = False, reply_message: Optional[MaiMessage] = None, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> bool: """向指定流发送图片消息。 @@ -1147,6 +1207,8 @@ async def image_to_stream( storage_message=storage_message, set_reply=set_reply, reply_message=reply_message, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) @@ -1160,6 +1222,8 @@ async def custom_to_stream( set_reply: bool = False, storage_message: bool = True, show_log: bool = True, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> bool: """向指定流发送自定义类型消息。 @@ -1186,6 +1250,8 @@ async def custom_to_stream( set_reply=set_reply, storage_message=storage_message, show_log=show_log, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, ) @@ -1198,6 +1264,8 @@ async def custom_reply_set_to_stream( set_reply: bool = False, storage_message: bool = True, show_log: bool = True, + sync_to_maisaka_history: bool = False, + maisaka_source_kind: str = "outbound_send", ) -> bool: """向指定流发送消息组件序列。 @@ -1223,4 +1291,6 @@ async def custom_reply_set_to_stream( set_reply=set_reply, storage_message=storage_message, show_log=show_log, + sync_to_maisaka_history=sync_to_maisaka_history, + maisaka_source_kind=maisaka_source_kind, )