diff --git a/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py b/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py deleted file mode 100644 index 0f084ece..00000000 --- a/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py +++ /dev/null @@ -1,148 +0,0 @@ -from types import SimpleNamespace - -import pytest - -from src.memory_system import chat_history_summarizer as summarizer_module - - -def _build_summarizer() -> summarizer_module.ChatHistorySummarizer: - summarizer = summarizer_module.ChatHistorySummarizer.__new__(summarizer_module.ChatHistorySummarizer) - summarizer.session_id = "session-1" - summarizer.log_prefix = "[session-1]" - return summarizer - - -@pytest.mark.asyncio -async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch): - calls = [] - summarizer = _build_summarizer() - - async def fake_ingest_summary(**kwargs): - calls.append(kwargs) - return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) - - monkeypatch.setattr( - summarizer_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")), - ) - monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8))) - monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) - - await summarizer._import_to_long_term_memory( - record_id=1, - theme="旅行计划", - summary="我们讨论了春游安排", - participants=["Alice", "Bob"], - start_time=1.0, - end_time=2.0, - original_text="long text", - ) - - assert len(calls) == 1 - payload = calls[0] - assert payload["external_id"] == "chat_history:1" - assert payload["chat_id"] == "session-1" - assert payload["participants"] == ["Alice", "Bob"] - assert payload["respect_filter"] is True - assert payload["user_id"] == "user-1" - assert payload["group_id"] == "" - assert "主题:旅行计划" in payload["text"] - assert "概括:我们讨论了春游安排" in payload["text"] - - -@pytest.mark.asyncio -async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch): - summarizer = _build_summarizer() - fallback_calls = [] - - async def fake_fallback(**kwargs): - fallback_calls.append(kwargs) - - summarizer._fallback_import_to_long_term_memory = fake_fallback - monkeypatch.setattr( - summarizer_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")), - ) - - await summarizer._import_to_long_term_memory( - record_id=2, - theme="", - summary="", - participants=[], - start_time=10.0, - end_time=20.0, - original_text="raw chat", - ) - - assert len(fallback_calls) == 1 - assert fallback_calls[0]["record_id"] == 2 - assert fallback_calls[0]["original_text"] == "raw chat" - - -@pytest.mark.asyncio -async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch): - summarizer = _build_summarizer() - fallback_calls = [] - - async def fake_ingest_summary(**kwargs): - return SimpleNamespace(success=False, detail="boom", stored_ids=[]) - - async def fake_fallback(**kwargs): - fallback_calls.append(kwargs) - - summarizer._fallback_import_to_long_term_memory = fake_fallback - monkeypatch.setattr( - summarizer_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")), - ) - monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) - - await summarizer._import_to_long_term_memory( - record_id=3, - theme="电影", - summary="聊了电影推荐", - participants=["Alice"], - start_time=3.0, - end_time=4.0, - original_text="raw", - ) - - assert len(fallback_calls) == 1 - assert fallback_calls[0]["theme"] == "电影" - - -@pytest.mark.asyncio -async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch): - calls = [] - summarizer = _build_summarizer() - - async def fake_ingest_summary(**kwargs): - calls.append(kwargs) - return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[]) - - monkeypatch.setattr( - summarizer_module, - "_chat_manager", - SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")), - ) - monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12))) - monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) - - await summarizer._fallback_import_to_long_term_memory( - record_id=4, - theme="工作", - participants=["Alice"], - start_time=5.0, - end_time=6.0, - original_text="a" * 128, - ) - - assert len(calls) == 1 - metadata = calls[0]["metadata"] - assert metadata["generate_from_chat"] is True - assert metadata["context_length"] == 12 - assert calls[0]["respect_filter"] is True - diff --git a/pytests/A_memorix_test/test_group_chat_stream_memory_benchmark.py b/pytests/A_memorix_test/test_group_chat_stream_memory_benchmark.py deleted file mode 100644 index b20f4e5d..00000000 --- a/pytests/A_memorix_test/test_group_chat_stream_memory_benchmark.py +++ /dev/null @@ -1,1187 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import json -import os -import re -import sys -import tempfile -import types -import typing -from datetime import datetime -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict, List - -import numpy as np - -PROJECT_ROOT = Path(__file__).resolve().parents[2] -PLUGINS_ROOT = PROJECT_ROOT / "plugins" -SDK_ROOT = PROJECT_ROOT / "packages" / "maibot-plugin-sdk" - -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) -if str(PLUGINS_ROOT) not in sys.path: - sys.path.insert(0, str(PLUGINS_ROOT)) -if str(SDK_ROOT) not in sys.path: - sys.path.insert(0, str(SDK_ROOT)) - -if "maibot_sdk" not in sys.modules: - maibot_sdk = types.ModuleType("maibot_sdk") - maibot_sdk_types = types.ModuleType("maibot_sdk.types") - - class _FakeMaiBotPlugin: - def __init__(self, *args, **kwargs) -> None: - del args, kwargs - - def _fake_tool(*decorator_args, **decorator_kwargs): - del decorator_args, decorator_kwargs - - def wrapper(func): - return func - - return wrapper - - class _FakeToolParameterInfo: - def __init__(self, name: str, param_type: str, description: str, required: bool) -> None: - self.name = name - self.param_type = param_type - self.description = description - self.required = required - - class _FakeToolParamType: - STRING = "string" - INTEGER = "integer" - FLOAT = "float" - BOOLEAN = "boolean" - - maibot_sdk.MaiBotPlugin = _FakeMaiBotPlugin - maibot_sdk.Tool = _fake_tool - maibot_sdk_types.ToolParameterInfo = _FakeToolParameterInfo - maibot_sdk_types.ToolParamType = _FakeToolParamType - sys.modules["maibot_sdk"] = maibot_sdk - sys.modules["maibot_sdk.types"] = maibot_sdk_types - -try: - import aiohttp # type: ignore -except Exception: - aiohttp = types.ModuleType("aiohttp") - - class _FakeAioHttpClientError(Exception): - pass - - aiohttp.ClientError = _FakeAioHttpClientError - sys.modules["aiohttp"] = aiohttp - -try: - import openai # type: ignore -except Exception: - openai = types.ModuleType("openai") - - class _FakeOpenAIConnectionError(Exception): - pass - - class _FakeOpenAITimeoutError(Exception): - pass - - openai.APIConnectionError = _FakeOpenAIConnectionError - openai.APITimeoutError = _FakeOpenAITimeoutError - sys.modules["openai"] = openai - -if "eval_type_backport" not in sys.modules: - eval_type_backport_module = types.ModuleType("eval_type_backport") - - def _rewrite_union(expr: str) -> str: - clean = str(expr or "").strip() - if "|" not in clean: - return clean - parts = [part.strip() for part in clean.split("|")] - if len(parts) == 2 and "None" in parts: - target = parts[0] if parts[1] == "None" else parts[1] - return f"typing.Optional[{target}]" - return f"typing.Union[{', '.join(parts)}]" - - def _eval_type_backport(value, globalns=None, localns=None, try_default=False): - del try_default - expr = getattr(value, "__forward_arg__", value) - if not isinstance(expr, str): - expr = str(expr) - gns = dict(globalns or {}) - lns = dict(localns or {}) - builtin_aliases = { - "dict": typing.Dict, - "list": typing.List, - "set": typing.Set, - "tuple": typing.Tuple, - } - for namespace in (gns, lns): - namespace.setdefault("typing", typing) - namespace.setdefault("Any", typing.Any) - namespace.setdefault("Optional", typing.Optional) - namespace.setdefault("Union", typing.Union) - namespace.setdefault("Literal", getattr(typing, "Literal", None)) - namespace.update({key: value for key, value in builtin_aliases.items() if key not in namespace}) - - try: - return eval(expr, gns, lns) - except TypeError: - return eval(_rewrite_union(expr), gns, lns) - - eval_type_backport_module.eval_type_backport = _eval_type_backport - sys.modules["eval_type_backport"] = eval_type_backport_module - -from A_memorix.core.runtime import sdk_memory_kernel as kernel_module -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import MemorySearchResult, memory_service - - -def _resolve_benchmark_paths() -> tuple[Path, Path]: - configured = str(os.environ.get("A_MEMORIX_BENCHMARK_DATA_FILE", "") or "").strip() - if configured: - data_file = Path(configured).expanduser() - if not data_file.is_absolute(): - data_file = (Path.cwd() / data_file).resolve() - else: - data_file = Path(__file__).parent / "data" / "benchmarks" / "group_chat_stream_memory_benchmark.json" - - report_file = Path(__file__).parent / "data" / "benchmarks" / "results" / f"{data_file.stem}_report.json" - return data_file, report_file - - -DATA_FILE, REPORT_FILE = _resolve_benchmark_paths() - - -def _load_benchmark_fixture() -> Dict[str, Any]: - return json.loads(DATA_FILE.read_text(encoding="utf-8")) - - -class _PatchManager: - def __init__(self) -> None: - self._records: List[tuple[Any, str, Any, bool]] = [] - - def setattr(self, target: Any, name: str, value: Any) -> None: - existed = hasattr(target, name) - original = getattr(target, name) if existed else None - self._records.append((target, name, original, existed)) - setattr(target, name, value) - - def undo(self) -> None: - while self._records: - target, name, original, existed = self._records.pop() - if existed: - setattr(target, name, original) - else: - delattr(target, name) - - -class _FakeEmbeddingAdapter: - def __init__(self, dimension: int = 32) -> None: - self.dimension = dimension - - async def _detect_dimension(self) -> int: - return self.dimension - - async def encode(self, texts, dimensions=None): - dim = int(dimensions or self.dimension) - if isinstance(texts, str): - sequence = [texts] - single = True - else: - sequence = list(texts) - single = False - - rows = [] - for text in sequence: - vec = np.zeros(dim, dtype=np.float32) - for ch in str(text or ""): - code = ord(ch) - vec[code % dim] += 1.0 - vec[(code * 7) % dim] += 0.5 - if not vec.any(): - vec[0] = 1.0 - norm = np.linalg.norm(vec) - if norm > 0: - vec = vec / norm - rows.append(vec) - payload = np.vstack(rows) - return payload[0] if single else payload - - -class _KnownPerson: - def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None: - self.person_id = person_id - self.is_known = person_id in reverse_registry - self.person_name = reverse_registry.get(person_id, "") - self._registry = registry - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000): - del timeout_ms - payload = args or {} - if component_name == "search_memory": - return await self.kernel.search_memory( - KernelSearchRequest( - query=str(payload.get("query", "") or ""), - limit=int(payload.get("limit", 5) or 5), - mode=str(payload.get("mode", "hybrid") or "hybrid"), - chat_id=str(payload.get("chat_id", "") or ""), - person_id=str(payload.get("person_id", "") or ""), - time_start=payload.get("time_start"), - time_end=payload.get("time_end"), - respect_filter=bool(payload.get("respect_filter", True)), - user_id=str(payload.get("user_id", "") or ""), - group_id=str(payload.get("group_id", "") or ""), - ) - ) - - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]: - for _ in range(max_rounds): - detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) - task = detail.get("task") or {} - status = str(task.get("status", "") or "") - if status in {"completed", "completed_with_errors", "failed", "cancelled"}: - return detail - await asyncio.sleep(max(0.01, float(sleep_seconds))) - raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") - - -def _join_hit_content(search_result: MemorySearchResult) -> str: - return "\n".join(hit.content for hit in search_result.hits) - - -def _keyword_hits(text: str, keywords: List[str]) -> int: - haystack = str(text or "") - return sum(1 for keyword in keywords if keyword in haystack) - - -def _keyword_recall(text: str, keywords: List[str]) -> float: - if not keywords: - return 1.0 - return _keyword_hits(text, keywords) / float(len(keywords)) - - -def _hit_blob(hit) -> str: - meta = hit.metadata if isinstance(hit.metadata, dict) else {} - return "\n".join( - [ - str(hit.content or ""), - str(hit.title or ""), - str(hit.source or ""), - json.dumps(meta, ensure_ascii=False), - ] - ) - - -def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int: - for index, hit in enumerate(search_result.hits[:5], start=1): - if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))): - return index - return 0 - - -def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str: - return "\n".join( - ( - f"{item.get('title', '')}\n" - f"{item.get('summary', '')}\n" - f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n" - f"{json.dumps(item.get('participants', []), ensure_ascii=False)}" - ) - for item in items - ) - - -def _episode_blob_from_hits(search_result: MemorySearchResult) -> str: - chunks = [] - for hit in search_result.hits: - meta = hit.metadata if isinstance(hit.metadata, dict) else {} - chunks.append( - "\n".join( - [ - str(hit.title or ""), - str(hit.content or ""), - json.dumps(meta.get("keywords", []) or [], ensure_ascii=False), - json.dumps(meta.get("participants", []) or [], ensure_ascii=False), - ] - ) - ) - return "\n".join(chunks) - - -async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - episode_source = f"chat_summary:{session_id}" - payload = await memory_service.episode_admin(action="query", source=episode_source, limit=20) - items = payload.get("items") or [] - blob = _episode_blob_from_items(items) - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in episode_cases: - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(items), - "top_episode": items[0] if items else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "episode_count": len(items), - "reports": reports, - } - - -async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - episode_source = f"chat_summary:{session_id}" - - for case in episode_cases: - payload = await memory_service.episode_admin( - action="query", - source=episode_source, - query=case["query"], - limit=5, - ) - items = payload.get("items") or [] - blob = "\n".join( - f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}" - for item in items - ) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(items), - "top_episode": items[0] if items else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in episode_cases: - result = await memory_service.search( - case["query"], - mode="episode", - chat_id=session_id, - respect_filter=False, - limit=5, - ) - blob = _episode_blob_from_hits(result) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(result.hits), - "top_episode": result.hits[0].to_dict() if result.hits else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_time_cases(*, session_id: str, time_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in time_cases: - result = await memory_service.search( - case["query"], - mode="time", - chat_id=session_id, - time_start=case["time_expression"], - time_end=case["time_expression"], - respect_filter=False, - limit=5, - ) - blob = "\n".join(_hit_blob(hit) for hit in result.hits) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(result.hits) and recall >= 0.67 - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "time_expression": case["time_expression"], - "success": success, - "keyword_recall": recall, - "hit_count": len(result.hits), - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - - total = max(1, len(time_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]: - search_case = dataset["search_cases"][0] - time_case = dataset["time_cases"][0] - episode_case = dataset["episode_cases"][0] - aggregate_case = dataset["knowledge_fetcher_cases"][0] - tool_cases = [ - { - "name": "search", - "kwargs": { - "query": search_case["query"], - "mode": "search", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": search_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "time", - "kwargs": { - "query": time_case["query"], - "mode": "time", - "chat_id": session_id, - "limit": 5, - "time_expression": time_case["time_expression"], - }, - "expected_keywords": time_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "episode", - "kwargs": { - "query": episode_case["query"], - "mode": "episode", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": episode_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "aggregate", - "kwargs": { - "query": aggregate_case["query"], - "mode": "aggregate", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": aggregate_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - ] - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in tool_cases: - text = await query_long_term_memory(**case["kwargs"]) - recall = _keyword_recall(text, case["expected_keywords"]) - success = ( - "失败" not in text - and "无法解析" not in text - and "未找到" not in text - and recall >= float(case["minimum_keyword_recall"]) - ) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "name": case["name"], - "success": success, - "keyword_recall": recall, - "preview": text[:320], - } - ) - - total = max(1, len(tool_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -def _parse_fixture_message( - *, - line: str, - session_id: str, - platform: str, - group_id: str, - group_name: str, - speaker_to_user_id: Dict[str, str], - seq: int, -): - match = re.match(r"^\[(?P[^\]]+)\]\s*(?P[^:]+):(?P.*)$", str(line or "").strip()) - if not match: - raise ValueError(f"无法解析 fixture 消息: {line}") - - dt = datetime.strptime(match.group("ts"), "%Y-%m-%d %H:%M") - speaker = match.group("speaker").strip() - text = match.group("text").strip() - user_id = speaker_to_user_id.setdefault(speaker, f"user-{len(speaker_to_user_id) + 1}") - - return SimpleNamespace( - message_id=f"{session_id}-{seq}", - timestamp=dt, - platform=platform, - session_id=session_id, - reply_to=None, - processed_plain_text=text, - display_message=f"{speaker}:{text}", - message_info=SimpleNamespace( - user_info=SimpleNamespace( - user_id=user_id, - user_nickname=speaker, - user_cardname=speaker, - ), - group_info=SimpleNamespace(group_id=group_id, group_name=group_name), - additional_config={}, - ), - ) - - -async def _evaluate_runtime_trigger_flow(*, dataset: Dict[str, Any], session: Any) -> Dict[str, Any]: - streams = dataset["runtime_trigger_streams"] - reports: List[Dict[str, Any]] = [] - positive_successes = 0.0 - negative_successes = 0.0 - - speaker_to_user_id = {"Mai": "bot-mai"} - for payload in dataset["person_writebacks"]: - speaker_to_user_id.setdefault(payload["person_name"], payload["person_id"]) - - original_get_messages = summarizer_module.message_api.get_messages_by_time_in_chat - original_build_readable = summarizer_module.message_api.build_readable_messages - original_is_bot_self = summarizer_module.is_bot_self - original_person = summarizer_module.Person - original_analyze = summarizer_module.ChatHistorySummarizer._analyze_topics_with_llm - - try: - summarizer_module.message_api.build_readable_messages = ( - lambda messages, **kwargs: "\n".join(getattr(msg, "display_message", getattr(msg, "processed_plain_text", "")) for msg in messages) - ) - summarizer_module.is_bot_self = lambda platform, user_id: str(user_id or "") == "bot-mai" - summarizer_module.Person = lambda platform="", user_id="", person_id="": SimpleNamespace( - person_name=next( - ( - name - for name, mapped in speaker_to_user_id.items() - if mapped == (str(person_id or "").strip() or str(user_id or "").strip()) - ), - "", - ) - ) - - async def fake_analyze(self, numbered_lines, existing_topics): - del existing_topics - topic = str(getattr(self, "_fixture_topic", "") or "群聊话题") - return True, {topic: list(range(1, len(numbered_lines) + 1))} - - summarizer_module.ChatHistorySummarizer._analyze_topics_with_llm = fake_analyze - - for index, stream in enumerate(streams, start=1): - messages = [ - _parse_fixture_message( - line=line, - session_id=session.session_id, - platform=session.platform, - group_id=session.group_id, - group_name=dataset["session"]["display_name"], - speaker_to_user_id=speaker_to_user_id, - seq=1000 * index + seq, - ) - for seq, line in enumerate(stream["messages"], start=1) - ] - - def fake_get_messages_by_time_in_chat( - *, - chat_id: str, - start_time: float, - end_time: float, - limit: int = 0, - limit_mode: str = "latest", - filter_mai: bool = False, - filter_command: bool = False, - ): - del limit, limit_mode, filter_mai, filter_command - if chat_id != session.session_id: - return [] - return [ - msg - for msg in messages - if start_time <= msg.timestamp.timestamp() <= end_time - ] - - summarizer_module.message_api.get_messages_by_time_in_chat = fake_get_messages_by_time_in_chat - - summarizer = summarizer_module.ChatHistorySummarizer(session.session_id) - summarizer._fixture_topic = stream["topic"] - summarizer._persist_topic_cache = lambda: None - summarizer.last_check_time = float(stream["start_time"]) - 60.0 - summarizer.last_topic_check_time = float(stream["end_time"]) - float(stream["elapsed_since_last_check_hours"]) * 3600.0 - - await summarizer.process(current_time=float(stream["end_time"])) - - topic_cache_keys = list(summarizer.topic_cache.keys()) - expected_topic_present = stream["topic"] in summarizer.topic_cache - topic_cache_updated = bool(topic_cache_keys) - batch_cleared = summarizer.current_batch is None - - if stream["bot_participated"]: - success = expected_topic_present and topic_cache_updated and batch_cleared - positive_successes += 1.0 if success else 0.0 - else: - success = (not topic_cache_updated) and batch_cleared - negative_successes += 1.0 if success else 0.0 - - reports.append( - { - "stream_id": stream["stream_id"], - "bot_participated": stream["bot_participated"], - "success": success, - "topic_cache_keys": topic_cache_keys, - "current_batch_cleared": batch_cleared, - "expected_check_outcome": stream["expected_check_outcome"], - } - ) - finally: - summarizer_module.message_api.get_messages_by_time_in_chat = original_get_messages - summarizer_module.message_api.build_readable_messages = original_build_readable - summarizer_module.is_bot_self = original_is_bot_self - summarizer_module.Person = original_person - summarizer_module.ChatHistorySummarizer._analyze_topics_with_llm = original_analyze - - positive_total = max(1, sum(1 for item in streams if item["bot_participated"])) - negative_total = max(1, sum(1 for item in streams if not item["bot_participated"])) - return { - "positive_trigger_rate": round(positive_successes / positive_total, 4), - "negative_discard_rate": round(negative_successes / negative_total, 4), - "reports": reports, - } - - -def _average(values: List[float]) -> float: - if not values: - return 0.0 - return round(sum(float(v) for v in values) / float(len(values)), 4) - - -def _target_score(actual: float, target: float) -> float: - clean_target = float(target or 0.0) - if clean_target <= 0: - return 1.0 - return min(float(actual) / clean_target, 1.0) - - -def _build_final_score(*, metrics: Dict[str, Any], targets: Dict[str, Any]) -> Dict[str, Any]: - episode_summary = metrics["episode_summary_after_rebuild"] - category_scores = { - "search": _average( - [ - _target_score(metrics["search"]["accuracy_at_1"], targets["search"]["accuracy_at_1"]), - _target_score(metrics["search"]["recall_at_5"], targets["search"]["recall_at_5"]), - _target_score(metrics["search"]["keyword_recall_at_5"], targets["search"]["keyword_recall_at_5"]), - ] - ), - "writeback": _average( - [ - _target_score(metrics["writeback"]["success_rate"], targets["writeback"]["success_rate"]), - _target_score(metrics["writeback"]["keyword_recall"], targets["writeback"]["keyword_recall"]), - ] - ), - "knowledge_fetcher": _average( - [ - _target_score(metrics["knowledge_fetcher"]["success_rate"], targets["knowledge_fetcher"]["success_rate"]), - _target_score(metrics["knowledge_fetcher"]["keyword_recall"], targets["knowledge_fetcher"]["keyword_recall"]), - ] - ), - "profile": _average( - [ - _target_score(metrics["profile"]["success_rate"], targets["profile"]["success_rate"]), - _target_score(metrics["profile"]["evidence_rate"], targets["profile"]["evidence_rate"]), - ] - ), - "episode": _average( - [ - _target_score(episode_summary["success_rate"], targets["episode"]["success_rate"]), - _target_score(episode_summary["keyword_recall"], targets["episode"]["keyword_recall"]), - ] - ), - "negative_control": _target_score( - metrics["negative_control"]["zero_hit_rate"], targets["negative_control"]["zero_hit_rate"] - ), - "runtime_trigger": _average( - [ - _target_score( - metrics["runtime_trigger"]["positive_trigger_rate"], - targets["runtime_trigger"]["positive_trigger_rate"], - ), - _target_score( - metrics["runtime_trigger"]["negative_discard_rate"], - targets["runtime_trigger"]["negative_discard_rate"], - ), - ] - ), - } - overall_ratio = _average(list(category_scores.values())) - return { - "overall_ratio": overall_ratio, - "overall_score": round(overall_ratio * 100.0, 2), - "category_scores": {key: round(value * 100.0, 2) for key, value in category_scores.items()}, - } - - -async def _build_group_chat_benchmark_env(patch_manager: _PatchManager, tmp_path: Path): - dataset = _load_benchmark_fixture() - session_cfg = dataset["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]} - reverse_registry = {value: key for key, value in registry.items()} - - patch_manager.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) - - async def fake_self_check(**kwargs): - return {"ok": True, "message": "ok", "encoded_dimension": 32} - - patch_manager.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) - patch_manager.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - patch_manager.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - patch_manager.setattr(person_info_module, "_chat_manager", fake_chat_manager) - patch_manager.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: registry.get(str(person_name or "").strip(), ""), - ) - patch_manager.setattr( - person_info_module, - "Person", - lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry), - ) - - data_dir = (tmp_path / "a_memorix_group_chat_benchmark_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - patch_manager.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - return { - "dataset": dataset, - "kernel": kernel, - "session": session, - "person_registry": registry, - } - - -async def _run_group_chat_stream_memory_benchmark(tmp_path: Path): - patch_manager = _PatchManager() - group_chat_benchmark_env = await _build_group_chat_benchmark_env(patch_manager, tmp_path) - dataset = group_chat_benchmark_env["dataset"] - session_id = group_chat_benchmark_env["session"].session_id - - try: - created = await memory_service.import_admin( - action="create_paste", - name="group_chat_stream_memory_benchmark.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(dataset["import_payload"], ensure_ascii=False), - ) - assert created["success"] is True - - import_detail = await _wait_for_import_task(created["task"]["task_id"]) - assert import_detail["task"]["status"] == "completed" - - for record in dataset["chat_history_records"]: - summarizer = summarizer_module.ChatHistorySummarizer(session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - for payload in dataset["person_writebacks"]: - await person_info_module.store_person_memory_from_answer( - payload["person_name"], - payload["memory_content"], - session_id, - ) - - await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2) - - search_case_reports: List[Dict[str, Any]] = [] - search_accuracy_at_1 = 0.0 - search_recall_at_5 = 0.0 - search_precision_at_5 = 0.0 - search_mrr = 0.0 - search_keyword_recall = 0.0 - - for case in dataset["search_cases"]: - result = await memory_service.search( - case["query"], - mode="search", - chat_id=session_id, - respect_filter=False, - limit=5, - ) - joined = "\n".join(_hit_blob(hit) for hit in result.hits) - rank = _first_relevant_rank( - result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])) - ) - relevant_hits = sum( - 1 - for hit in result.hits[:5] - if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max( - 1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))) - ) - ) - keyword_recall = _keyword_recall(joined, case["expected_keywords"]) - search_accuracy_at_1 += 1.0 if rank == 1 else 0.0 - search_recall_at_5 += 1.0 if rank > 0 else 0.0 - search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits)))) - search_mrr += 1.0 / float(rank) if rank > 0 else 0.0 - search_keyword_recall += keyword_recall - search_case_reports.append( - { - "query": case["query"], - "rank_of_first_relevant": rank, - "relevant_hits_top5": relevant_hits, - "keyword_recall_top5": keyword_recall, - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - - search_total = max(1, len(dataset["search_cases"])) - - writeback_reports: List[Dict[str, Any]] = [] - writeback_success_rate = 0.0 - writeback_keyword_recall = 0.0 - for payload in dataset["person_writebacks"]: - query = " ".join(payload["expected_keywords"]) - result = await memory_service.search( - query, - mode="search", - chat_id=session_id, - person_id=payload["person_id"], - respect_filter=False, - limit=5, - ) - joined = _join_hit_content(result) - recall = _keyword_recall(joined, payload["expected_keywords"]) - success = bool(result.hits) and recall >= 0.67 - writeback_success_rate += 1.0 if success else 0.0 - writeback_keyword_recall += recall - writeback_reports.append( - { - "person_id": payload["person_id"], - "success": success, - "keyword_recall": recall, - "hit_count": len(result.hits), - } - ) - writeback_total = max(1, len(dataset["person_writebacks"])) - - knowledge_reports: List[Dict[str, Any]] = [] - knowledge_success_rate = 0.0 - knowledge_keyword_recall = 0.0 - fetcher = knowledge_module.KnowledgeFetcher( - private_name=dataset["session"]["display_name"], - stream_id=session_id, - ) - for case in dataset["knowledge_fetcher_cases"]: - knowledge_text, _ = await fetcher.fetch(case["query"], []) - recall = _keyword_recall(knowledge_text, case["expected_keywords"]) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) - knowledge_success_rate += 1.0 if success else 0.0 - knowledge_keyword_recall += recall - knowledge_reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "preview": knowledge_text[:300], - } - ) - knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"])) - - profile_reports: List[Dict[str, Any]] = [] - profile_success_rate = 0.0 - profile_keyword_recall = 0.0 - profile_evidence_rate = 0.0 - for case in dataset["profile_cases"]: - profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id) - recall = _keyword_recall(profile.summary, case["expected_keywords"]) - has_evidence = bool(profile.evidence) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence - profile_success_rate += 1.0 if success else 0.0 - profile_keyword_recall += recall - profile_evidence_rate += 1.0 if has_evidence else 0.0 - profile_reports.append( - { - "person_id": case["person_id"], - "success": success, - "keyword_recall": recall, - "evidence_count": len(profile.evidence), - "summary_preview": profile.summary[:240], - } - ) - profile_total = max(1, len(dataset["profile_cases"])) - - time_mode = await _evaluate_time_cases(session_id=session_id, time_cases=dataset["time_cases"]) - episode_generation_auto = await _evaluate_episode_generation( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_admin_query_auto = await _evaluate_episode_admin_query( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_search_mode_auto = await _evaluate_episode_search_mode( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_rebuild = await memory_service.episode_admin(action="rebuild", source=f"chat_summary:{session_id}") - episode_generation_after_rebuild = await _evaluate_episode_generation( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_admin_query_after_rebuild = await _evaluate_episode_admin_query( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - episode_search_mode_after_rebuild = await _evaluate_episode_search_mode( - session_id=session_id, episode_cases=dataset["episode_cases"] - ) - tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset) - - negative_reports: List[Dict[str, Any]] = [] - negative_zero_hits = 0.0 - for case in dataset["negative_control_cases"]: - result = await memory_service.search( - case["query"], - mode="search", - chat_id=session_id, - respect_filter=False, - limit=5, - ) - success = len(result.hits) == 0 - negative_zero_hits += 1.0 if success else 0.0 - negative_reports.append( - { - "query": case["query"], - "success": success, - "hit_count": len(result.hits), - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - negative_total = max(1, len(dataset["negative_control_cases"])) - - runtime_trigger = await _evaluate_runtime_trigger_flow( - dataset=dataset, - session=group_chat_benchmark_env["session"], - ) - - episode_summary_after_rebuild = { - "success_rate": _average( - [ - episode_generation_after_rebuild["success_rate"], - episode_admin_query_after_rebuild["success_rate"], - episode_search_mode_after_rebuild["success_rate"], - ] - ), - "keyword_recall": _average( - [ - episode_generation_after_rebuild["keyword_recall"], - episode_admin_query_after_rebuild["keyword_recall"], - episode_search_mode_after_rebuild["keyword_recall"], - ] - ), - } - - metrics = { - "search": { - "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4), - "recall_at_5": round(search_recall_at_5 / search_total, 4), - "precision_at_5": round(search_precision_at_5 / search_total, 4), - "mrr": round(search_mrr / search_total, 4), - "keyword_recall_at_5": round(search_keyword_recall / search_total, 4), - }, - "time_mode": { - "success_rate": time_mode["success_rate"], - "keyword_recall": time_mode["keyword_recall"], - }, - "writeback": { - "success_rate": round(writeback_success_rate / writeback_total, 4), - "keyword_recall": round(writeback_keyword_recall / writeback_total, 4), - }, - "knowledge_fetcher": { - "success_rate": round(knowledge_success_rate / knowledge_total, 4), - "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4), - }, - "profile": { - "success_rate": round(profile_success_rate / profile_total, 4), - "keyword_recall": round(profile_keyword_recall / profile_total, 4), - "evidence_rate": round(profile_evidence_rate / profile_total, 4), - }, - "tool_modes": { - "success_rate": tool_modes["success_rate"], - "keyword_recall": tool_modes["keyword_recall"], - }, - "episode_generation_auto": { - "success_rate": episode_generation_auto["success_rate"], - "keyword_recall": episode_generation_auto["keyword_recall"], - "episode_count": episode_generation_auto["episode_count"], - }, - "episode_generation_after_rebuild": { - "success_rate": episode_generation_after_rebuild["success_rate"], - "keyword_recall": episode_generation_after_rebuild["keyword_recall"], - "episode_count": episode_generation_after_rebuild["episode_count"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_admin_query_auto": { - "success_rate": episode_admin_query_auto["success_rate"], - "keyword_recall": episode_admin_query_auto["keyword_recall"], - }, - "episode_admin_query_after_rebuild": { - "success_rate": episode_admin_query_after_rebuild["success_rate"], - "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_search_mode_auto": { - "success_rate": episode_search_mode_auto["success_rate"], - "keyword_recall": episode_search_mode_auto["keyword_recall"], - }, - "episode_search_mode_after_rebuild": { - "success_rate": episode_search_mode_after_rebuild["success_rate"], - "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_summary_after_rebuild": episode_summary_after_rebuild, - "negative_control": { - "zero_hit_rate": round(negative_zero_hits / negative_total, 4), - }, - "runtime_trigger": { - "positive_trigger_rate": runtime_trigger["positive_trigger_rate"], - "negative_discard_rate": runtime_trigger["negative_discard_rate"], - }, - } - final_score = _build_final_score(metrics=metrics, targets=dataset["meta"]["quantitative_targets"]) - - report = { - "dataset": dataset["meta"], - "import": { - "task_id": created["task"]["task_id"], - "status": import_detail["task"]["status"], - "paragraph_count": len(dataset["import_payload"]["paragraphs"]), - "relation_count": len(dataset["import_payload"].get("relations") or []), - }, - "metrics": metrics, - "score": final_score, - "cases": { - "search": search_case_reports, - "time_mode": time_mode["reports"], - "writeback": writeback_reports, - "knowledge_fetcher": knowledge_reports, - "profile": profile_reports, - "tool_modes": tool_modes["reports"], - "episode_generation_auto": episode_generation_auto["reports"], - "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], - "episode_admin_query_auto": episode_admin_query_auto["reports"], - "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], - "episode_search_mode_auto": episode_search_mode_auto["reports"], - "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], - "negative_control": negative_reports, - "runtime_trigger": runtime_trigger["reports"], - }, - } - - REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) - REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") - print(json.dumps({"metrics": report["metrics"], "score": report["score"]}, ensure_ascii=False, indent=2)) - - assert report["import"]["status"] == "completed" - assert report["score"]["overall_score"] >= 0.0 - return report - finally: - await group_chat_benchmark_env["kernel"].shutdown() - patch_manager.undo() - - -def run_group_chat_stream_memory_benchmark() -> Dict[str, Any]: - with tempfile.TemporaryDirectory(prefix="a_memorix_group_chat_benchmark_") as tmp_dir: - return asyncio.run(_run_group_chat_stream_memory_benchmark(Path(tmp_dir))) - - -if __name__ == "__main__": - report = run_group_chat_stream_memory_benchmark() - print(json.dumps({"final_score": report["score"], "metrics": report["metrics"]}, ensure_ascii=False, indent=2)) diff --git a/pytests/A_memorix_test/test_long_novel_memory_benchmark.py b/pytests/A_memorix_test/test_long_novel_memory_benchmark.py deleted file mode 100644 index 819ecd97..00000000 --- a/pytests/A_memorix_test/test_long_novel_memory_benchmark.py +++ /dev/null @@ -1,687 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import json -from datetime import datetime -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict, List - -import numpy as np -import pytest -import pytest_asyncio - -from A_memorix.core.runtime import sdk_memory_kernel as kernel_module -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import MemorySearchResult, memory_service - - -DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json" -REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json" - - -def _load_benchmark_fixture() -> Dict[str, Any]: - return json.loads(DATA_FILE.read_text(encoding="utf-8")) - - -class _FakeEmbeddingAdapter: - def __init__(self, dimension: int = 32) -> None: - self.dimension = dimension - - async def _detect_dimension(self) -> int: - return self.dimension - - async def encode(self, texts, dimensions=None): - dim = int(dimensions or self.dimension) - if isinstance(texts, str): - sequence = [texts] - single = True - else: - sequence = list(texts) - single = False - - rows = [] - for text in sequence: - vec = np.zeros(dim, dtype=np.float32) - for ch in str(text or ""): - code = ord(ch) - vec[code % dim] += 1.0 - vec[(code * 7) % dim] += 0.5 - if not vec.any(): - vec[0] = 1.0 - norm = np.linalg.norm(vec) - if norm > 0: - vec = vec / norm - rows.append(vec) - payload = np.vstack(rows) - return payload[0] if single else payload - - -class _KnownPerson: - def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None: - self.person_id = person_id - self.is_known = person_id in reverse_registry - self.person_name = reverse_registry.get(person_id, "") - self._registry = registry - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000): - del timeout_ms - payload = args or {} - if component_name == "search_memory": - return await self.kernel.search_memory( - KernelSearchRequest( - query=str(payload.get("query", "") or ""), - limit=int(payload.get("limit", 5) or 5), - mode=str(payload.get("mode", "hybrid") or "hybrid"), - chat_id=str(payload.get("chat_id", "") or ""), - person_id=str(payload.get("person_id", "") or ""), - time_start=payload.get("time_start"), - time_end=payload.get("time_end"), - respect_filter=bool(payload.get("respect_filter", True)), - user_id=str(payload.get("user_id", "") or ""), - group_id=str(payload.get("group_id", "") or ""), - ) - ) - - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]: - for _ in range(max_rounds): - detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) - task = detail.get("task") or {} - status = str(task.get("status", "") or "") - if status in {"completed", "completed_with_errors", "failed", "cancelled"}: - return detail - await asyncio.sleep(max(0.01, float(sleep_seconds))) - raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") - - -def _join_hit_content(search_result: MemorySearchResult) -> str: - return "\n".join(hit.content for hit in search_result.hits) - - -def _keyword_hits(text: str, keywords: List[str]) -> int: - haystack = str(text or "") - return sum(1 for keyword in keywords if keyword in haystack) - - -def _keyword_recall(text: str, keywords: List[str]) -> float: - if not keywords: - return 1.0 - return _keyword_hits(text, keywords) / float(len(keywords)) - - -def _hit_blob(hit) -> str: - meta = hit.metadata if isinstance(hit.metadata, dict) else {} - return "\n".join( - [ - str(hit.content or ""), - str(hit.title or ""), - str(hit.source or ""), - json.dumps(meta, ensure_ascii=False), - ] - ) - - -def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int: - for index, hit in enumerate(search_result.hits[:5], start=1): - if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))): - return index - return 0 - - -def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str: - return "\n".join( - ( - f"{item.get('title', '')}\n" - f"{item.get('summary', '')}\n" - f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n" - f"{json.dumps(item.get('participants', []), ensure_ascii=False)}" - ) - for item in items - ) - - -def _episode_blob_from_hits(search_result: MemorySearchResult) -> str: - chunks = [] - for hit in search_result.hits: - meta = hit.metadata if isinstance(hit.metadata, dict) else {} - chunks.append( - "\n".join( - [ - str(hit.title or ""), - str(hit.content or ""), - json.dumps(meta.get("keywords", []) or [], ensure_ascii=False), - json.dumps(meta.get("participants", []) or [], ensure_ascii=False), - ] - ) - ) - return "\n".join(chunks) - - -async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - episode_source = f"chat_summary:{session_id}" - payload = await memory_service.episode_admin( - action="query", - source=episode_source, - limit=20, - ) - items = payload.get("items") or [] - blob = _episode_blob_from_items(items) - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in episode_cases: - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(items), - "top_episode": items[0] if items else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "episode_count": len(items), - "reports": reports, - } - - -async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - episode_source = f"chat_summary:{session_id}" - - for case in episode_cases: - payload = await memory_service.episode_admin( - action="query", - source=episode_source, - query=case["query"], - limit=5, - ) - items = payload.get("items") or [] - blob = "\n".join( - f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}" - for item in items - ) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(items), - "top_episode": items[0] if items else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in episode_cases: - result = await memory_service.search( - case["query"], - mode="episode", - chat_id=session_id, - respect_filter=False, - limit=5, - ) - blob = _episode_blob_from_hits(result) - recall = _keyword_recall(blob, case["expected_keywords"]) - success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0)) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "episode_count": len(result.hits), - "top_episode": result.hits[0].to_dict() if result.hits else None, - } - ) - - total = max(1, len(episode_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]: - search_case = dataset["search_cases"][0] - episode_case = dataset["episode_cases"][0] - aggregate_case = dataset["knowledge_fetcher_cases"][0] - first_record = (dataset.get("chat_history_records") or [{}])[0] - reference_ts = first_record.get("end_time") or first_record.get("start_time") or 0 - if reference_ts: - time_expression = datetime.fromtimestamp(float(reference_ts)).strftime("%Y/%m/%d") - else: - time_expression = "最近7天" - tool_cases = [ - { - "name": "search", - "kwargs": { - "query": "蓝漆铁盒 北塔木梯", - "mode": "search", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "time", - "kwargs": { - "query": "蓝漆铁盒 北塔", - "mode": "time", - "chat_id": session_id, - "limit": 5, - "time_expression": time_expression, - }, - "expected_keywords": ["蓝漆铁盒", "北塔木梯"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "episode", - "kwargs": { - "query": episode_case["query"], - "mode": "episode", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": episode_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - { - "name": "aggregate", - "kwargs": { - "query": aggregate_case["query"], - "mode": "aggregate", - "chat_id": session_id, - "limit": 5, - }, - "expected_keywords": aggregate_case["expected_keywords"], - "minimum_keyword_recall": 0.67, - }, - ] - reports: List[Dict[str, Any]] = [] - success_rate = 0.0 - keyword_recall = 0.0 - - for case in tool_cases: - text = await query_long_term_memory(**case["kwargs"]) - recall = _keyword_recall(text, case["expected_keywords"]) - success = ( - "失败" not in text - and "无法解析" not in text - and "未找到" not in text - and recall >= float(case["minimum_keyword_recall"]) - ) - success_rate += 1.0 if success else 0.0 - keyword_recall += recall - reports.append( - { - "name": case["name"], - "success": success, - "keyword_recall": recall, - "preview": text[:320], - } - ) - - total = max(1, len(tool_cases)) - return { - "success_rate": round(success_rate / total, 4), - "keyword_recall": round(keyword_recall / total, 4), - "reports": reports, - } - - -@pytest_asyncio.fixture -async def benchmark_env(monkeypatch, tmp_path): - dataset = _load_benchmark_fixture() - session_cfg = dataset["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]} - reverse_registry = {value: key for key, value in registry.items()} - - monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) - - async def fake_self_check(**kwargs): - return {"ok": True, "message": "ok", "encoded_dimension": 32} - - monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) - monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), "")) - monkeypatch.setattr( - person_info_module, - "Person", - lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry), - ) - - data_dir = (tmp_path / "a_memorix_benchmark_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - try: - yield { - "dataset": dataset, - "kernel": kernel, - "session": session, - "person_registry": registry, - } - finally: - await kernel.shutdown() - - -@pytest.mark.asyncio -async def test_long_novel_memory_benchmark(benchmark_env): - dataset = benchmark_env["dataset"] - session_id = benchmark_env["session"].session_id - - created = await memory_service.import_admin( - action="create_paste", - name="long_novel_memory_benchmark.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(dataset["import_payload"], ensure_ascii=False), - ) - assert created["success"] is True - - import_detail = await _wait_for_import_task(created["task"]["task_id"]) - assert import_detail["task"]["status"] == "completed" - - for record in dataset["chat_history_records"]: - summarizer = summarizer_module.ChatHistorySummarizer(session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - for payload in dataset["person_writebacks"]: - await person_info_module.store_person_memory_from_answer( - payload["person_name"], - payload["memory_content"], - session_id, - ) - - await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2) - - search_case_reports: List[Dict[str, Any]] = [] - search_accuracy_at_1 = 0.0 - search_recall_at_5 = 0.0 - search_precision_at_5 = 0.0 - search_mrr = 0.0 - search_keyword_recall = 0.0 - - for case in dataset["search_cases"]: - result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5) - joined = _join_hit_content(result) - rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"]))) - relevant_hits = sum( - 1 - for hit in result.hits[:5] - if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"])))) - ) - keyword_recall = _keyword_recall(joined, case["expected_keywords"]) - search_accuracy_at_1 += 1.0 if rank == 1 else 0.0 - search_recall_at_5 += 1.0 if rank > 0 else 0.0 - search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits)))) - search_mrr += 1.0 / float(rank) if rank > 0 else 0.0 - search_keyword_recall += keyword_recall - search_case_reports.append( - { - "query": case["query"], - "rank_of_first_relevant": rank, - "relevant_hits_top5": relevant_hits, - "keyword_recall_top5": keyword_recall, - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - - search_total = max(1, len(dataset["search_cases"])) - - writeback_reports: List[Dict[str, Any]] = [] - writeback_success_rate = 0.0 - writeback_keyword_recall = 0.0 - for payload in dataset["person_writebacks"]: - query = " ".join(payload["expected_keywords"]) - result = await memory_service.search( - query, - mode="search", - chat_id=session_id, - person_id=payload["person_id"], - respect_filter=False, - limit=5, - ) - joined = _join_hit_content(result) - recall = _keyword_recall(joined, payload["expected_keywords"]) - success = bool(result.hits) and recall >= 0.67 - writeback_success_rate += 1.0 if success else 0.0 - writeback_keyword_recall += recall - writeback_reports.append( - { - "person_id": payload["person_id"], - "success": success, - "keyword_recall": recall, - "hit_count": len(result.hits), - } - ) - writeback_total = max(1, len(dataset["person_writebacks"])) - - knowledge_reports: List[Dict[str, Any]] = [] - knowledge_success_rate = 0.0 - knowledge_keyword_recall = 0.0 - fetcher = knowledge_module.KnowledgeFetcher( - private_name=dataset["session"]["display_name"], - stream_id=session_id, - ) - for case in dataset["knowledge_fetcher_cases"]: - knowledge_text, _ = await fetcher.fetch(case["query"], []) - recall = _keyword_recall(knowledge_text, case["expected_keywords"]) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) - knowledge_success_rate += 1.0 if success else 0.0 - knowledge_keyword_recall += recall - knowledge_reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "preview": knowledge_text[:300], - } - ) - knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"])) - - profile_reports: List[Dict[str, Any]] = [] - profile_success_rate = 0.0 - profile_keyword_recall = 0.0 - profile_evidence_rate = 0.0 - for case in dataset["profile_cases"]: - profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id) - recall = _keyword_recall(profile.summary, case["expected_keywords"]) - has_evidence = bool(profile.evidence) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence - profile_success_rate += 1.0 if success else 0.0 - profile_keyword_recall += recall - profile_evidence_rate += 1.0 if has_evidence else 0.0 - profile_reports.append( - { - "person_id": case["person_id"], - "success": success, - "keyword_recall": recall, - "evidence_count": len(profile.evidence), - "summary_preview": profile.summary[:240], - } - ) - profile_total = max(1, len(dataset["profile_cases"])) - - episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_rebuild = await memory_service.episode_admin( - action="rebuild", - source=f"chat_summary:{session_id}", - ) - episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) - tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset) - - report = { - "dataset": dataset["meta"], - "import": { - "task_id": created["task"]["task_id"], - "status": import_detail["task"]["status"], - "paragraph_count": len(dataset["import_payload"]["paragraphs"]), - }, - "metrics": { - "search": { - "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4), - "recall_at_5": round(search_recall_at_5 / search_total, 4), - "precision_at_5": round(search_precision_at_5 / search_total, 4), - "mrr": round(search_mrr / search_total, 4), - "keyword_recall_at_5": round(search_keyword_recall / search_total, 4), - }, - "writeback": { - "success_rate": round(writeback_success_rate / writeback_total, 4), - "keyword_recall": round(writeback_keyword_recall / writeback_total, 4), - }, - "knowledge_fetcher": { - "success_rate": round(knowledge_success_rate / knowledge_total, 4), - "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4), - }, - "profile": { - "success_rate": round(profile_success_rate / profile_total, 4), - "keyword_recall": round(profile_keyword_recall / profile_total, 4), - "evidence_rate": round(profile_evidence_rate / profile_total, 4), - }, - "tool_modes": { - "success_rate": tool_modes["success_rate"], - "keyword_recall": tool_modes["keyword_recall"], - }, - "episode_generation_auto": { - "success_rate": episode_generation_auto["success_rate"], - "keyword_recall": episode_generation_auto["keyword_recall"], - "episode_count": episode_generation_auto["episode_count"], - }, - "episode_generation_after_rebuild": { - "success_rate": episode_generation_after_rebuild["success_rate"], - "keyword_recall": episode_generation_after_rebuild["keyword_recall"], - "episode_count": episode_generation_after_rebuild["episode_count"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_admin_query_auto": { - "success_rate": episode_admin_query_auto["success_rate"], - "keyword_recall": episode_admin_query_auto["keyword_recall"], - }, - "episode_admin_query_after_rebuild": { - "success_rate": episode_admin_query_after_rebuild["success_rate"], - "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_search_mode_auto": { - "success_rate": episode_search_mode_auto["success_rate"], - "keyword_recall": episode_search_mode_auto["keyword_recall"], - }, - "episode_search_mode_after_rebuild": { - "success_rate": episode_search_mode_after_rebuild["success_rate"], - "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - }, - "cases": { - "search": search_case_reports, - "writeback": writeback_reports, - "knowledge_fetcher": knowledge_reports, - "profile": profile_reports, - "tool_modes": tool_modes["reports"], - "episode_generation_auto": episode_generation_auto["reports"], - "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], - "episode_admin_query_auto": episode_admin_query_auto["reports"], - "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], - "episode_search_mode_auto": episode_search_mode_auto["reports"], - "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], - }, - } - - REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) - REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") - print(json.dumps(report["metrics"], ensure_ascii=False, indent=2)) - - assert report["import"]["status"] == "completed" - assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35 - assert report["metrics"]["search"]["recall_at_5"] >= 0.6 - assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8 - assert report["metrics"]["writeback"]["success_rate"] >= 0.66 - assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66 - assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75 - assert report["metrics"]["profile"]["success_rate"] >= 0.66 - assert report["metrics"]["profile"]["evidence_rate"] >= 1.0 - assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75 - assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True - assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"] diff --git a/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py b/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py deleted file mode 100644 index bd0560dc..00000000 --- a/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py +++ /dev/null @@ -1,342 +0,0 @@ -from __future__ import annotations - -import json -import os -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict, List - -import pytest -import pytest_asyncio - -from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel -from pytests.A_memorix_test.test_long_novel_memory_benchmark import ( - _evaluate_episode_admin_query, - _evaluate_episode_generation, - _evaluate_episode_search_mode, - _evaluate_tool_modes, - _KernelBackedRuntimeManager, - _KnownPerson, - _first_relevant_rank, - _hit_blob, - _join_hit_content, - _keyword_hits, - _keyword_recall, - _load_benchmark_fixture, - _wait_for_import_task, -) -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import memory_service - - -pytestmark = pytest.mark.skipif( - os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1", - reason="需要显式开启真实 external embedding benchmark", -) - -REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json" - - -@pytest_asyncio.fixture -async def benchmark_live_env(monkeypatch, tmp_path): - dataset = _load_benchmark_fixture() - session_cfg = dataset["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]} - reverse_registry = {value: key for key, value in registry.items()} - - monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), "")) - monkeypatch.setattr( - person_info_module, - "Person", - lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry), - ) - - data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - try: - yield { - "dataset": dataset, - "kernel": kernel, - "session": session, - } - finally: - await kernel.shutdown() - - -@pytest.mark.asyncio -async def test_long_novel_memory_benchmark_live(benchmark_live_env): - dataset = benchmark_live_env["dataset"] - session_id = benchmark_live_env["session"].session_id - - self_check = await memory_service.runtime_admin(action="refresh_self_check") - assert self_check["success"] is True - assert self_check["report"]["ok"] is True - - created = await memory_service.import_admin( - action="create_paste", - name="long_novel_memory_benchmark.live.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(dataset["import_payload"], ensure_ascii=False), - ) - assert created["success"] is True - - import_detail = await _wait_for_import_task( - created["task"]["task_id"], - max_rounds=2400, - sleep_seconds=0.25, - ) - assert import_detail["task"]["status"] == "completed" - - for record in dataset["chat_history_records"]: - summarizer = summarizer_module.ChatHistorySummarizer(session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - for payload in dataset["person_writebacks"]: - await person_info_module.store_person_memory_from_answer( - payload["person_name"], - payload["memory_content"], - session_id, - ) - - await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2) - - search_case_reports: List[Dict[str, Any]] = [] - search_accuracy_at_1 = 0.0 - search_recall_at_5 = 0.0 - search_precision_at_5 = 0.0 - search_mrr = 0.0 - search_keyword_recall = 0.0 - for case in dataset["search_cases"]: - result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5) - joined = _join_hit_content(result) - rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"]))) - relevant_hits = sum( - 1 - for hit in result.hits[:5] - if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"])))) - ) - keyword_recall = _keyword_recall(joined, case["expected_keywords"]) - search_accuracy_at_1 += 1.0 if rank == 1 else 0.0 - search_recall_at_5 += 1.0 if rank > 0 else 0.0 - search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits)))) - search_mrr += 1.0 / float(rank) if rank > 0 else 0.0 - search_keyword_recall += keyword_recall - search_case_reports.append( - { - "query": case["query"], - "rank_of_first_relevant": rank, - "relevant_hits_top5": relevant_hits, - "keyword_recall_top5": keyword_recall, - "top_hit": result.hits[0].to_dict() if result.hits else None, - } - ) - search_total = max(1, len(dataset["search_cases"])) - - writeback_reports: List[Dict[str, Any]] = [] - writeback_success_rate = 0.0 - writeback_keyword_recall = 0.0 - for payload in dataset["person_writebacks"]: - query = " ".join(payload["expected_keywords"]) - result = await memory_service.search( - query, - mode="search", - chat_id=session_id, - person_id=payload["person_id"], - respect_filter=False, - limit=5, - ) - joined = _join_hit_content(result) - recall = _keyword_recall(joined, payload["expected_keywords"]) - success = bool(result.hits) and recall >= 0.67 - writeback_success_rate += 1.0 if success else 0.0 - writeback_keyword_recall += recall - writeback_reports.append( - { - "person_id": payload["person_id"], - "success": success, - "keyword_recall": recall, - "hit_count": len(result.hits), - } - ) - writeback_total = max(1, len(dataset["person_writebacks"])) - - knowledge_reports: List[Dict[str, Any]] = [] - knowledge_success_rate = 0.0 - knowledge_keyword_recall = 0.0 - fetcher = knowledge_module.KnowledgeFetcher( - private_name=dataset["session"]["display_name"], - stream_id=session_id, - ) - for case in dataset["knowledge_fetcher_cases"]: - knowledge_text, _ = await fetcher.fetch(case["query"], []) - recall = _keyword_recall(knowledge_text, case["expected_keywords"]) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) - knowledge_success_rate += 1.0 if success else 0.0 - knowledge_keyword_recall += recall - knowledge_reports.append( - { - "query": case["query"], - "success": success, - "keyword_recall": recall, - "preview": knowledge_text[:300], - } - ) - knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"])) - - profile_reports: List[Dict[str, Any]] = [] - profile_success_rate = 0.0 - profile_keyword_recall = 0.0 - profile_evidence_rate = 0.0 - for case in dataset["profile_cases"]: - profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id) - recall = _keyword_recall(profile.summary, case["expected_keywords"]) - has_evidence = bool(profile.evidence) - success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence - profile_success_rate += 1.0 if success else 0.0 - profile_keyword_recall += recall - profile_evidence_rate += 1.0 if has_evidence else 0.0 - profile_reports.append( - { - "person_id": case["person_id"], - "success": success, - "keyword_recall": recall, - "evidence_count": len(profile.evidence), - "summary_preview": profile.summary[:240], - } - ) - profile_total = max(1, len(dataset["profile_cases"])) - - episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_rebuild = await memory_service.episode_admin( - action="rebuild", - source=f"chat_summary:{session_id}", - ) - episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) - episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) - tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset) - - report = { - "dataset": dataset["meta"], - "runtime_self_check": self_check["report"], - "import": { - "task_id": created["task"]["task_id"], - "status": import_detail["task"]["status"], - "paragraph_count": len(dataset["import_payload"]["paragraphs"]), - }, - "metrics": { - "search": { - "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4), - "recall_at_5": round(search_recall_at_5 / search_total, 4), - "precision_at_5": round(search_precision_at_5 / search_total, 4), - "mrr": round(search_mrr / search_total, 4), - "keyword_recall_at_5": round(search_keyword_recall / search_total, 4), - }, - "writeback": { - "success_rate": round(writeback_success_rate / writeback_total, 4), - "keyword_recall": round(writeback_keyword_recall / writeback_total, 4), - }, - "knowledge_fetcher": { - "success_rate": round(knowledge_success_rate / knowledge_total, 4), - "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4), - }, - "profile": { - "success_rate": round(profile_success_rate / profile_total, 4), - "keyword_recall": round(profile_keyword_recall / profile_total, 4), - "evidence_rate": round(profile_evidence_rate / profile_total, 4), - }, - "tool_modes": { - "success_rate": tool_modes["success_rate"], - "keyword_recall": tool_modes["keyword_recall"], - }, - "episode_generation_auto": { - "success_rate": episode_generation_auto["success_rate"], - "keyword_recall": episode_generation_auto["keyword_recall"], - "episode_count": episode_generation_auto["episode_count"], - }, - "episode_generation_after_rebuild": { - "success_rate": episode_generation_after_rebuild["success_rate"], - "keyword_recall": episode_generation_after_rebuild["keyword_recall"], - "episode_count": episode_generation_after_rebuild["episode_count"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_admin_query_auto": { - "success_rate": episode_admin_query_auto["success_rate"], - "keyword_recall": episode_admin_query_auto["keyword_recall"], - }, - "episode_admin_query_after_rebuild": { - "success_rate": episode_admin_query_after_rebuild["success_rate"], - "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - "episode_search_mode_auto": { - "success_rate": episode_search_mode_auto["success_rate"], - "keyword_recall": episode_search_mode_auto["keyword_recall"], - }, - "episode_search_mode_after_rebuild": { - "success_rate": episode_search_mode_after_rebuild["success_rate"], - "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], - "rebuild_success": bool(episode_rebuild.get("success", False)), - }, - }, - "cases": { - "search": search_case_reports, - "writeback": writeback_reports, - "knowledge_fetcher": knowledge_reports, - "profile": profile_reports, - "tool_modes": tool_modes["reports"], - "episode_generation_auto": episode_generation_auto["reports"], - "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], - "episode_admin_query_auto": episode_admin_query_auto["reports"], - "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], - "episode_search_mode_auto": episode_search_mode_auto["reports"], - "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], - }, - } - - REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) - REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") - print(json.dumps(report["metrics"], ensure_ascii=False, indent=2)) - - assert report["import"]["status"] == "completed" - assert report["runtime_self_check"]["ok"] is True diff --git a/pytests/A_memorix_test/test_memory_flow_service.py b/pytests/A_memorix_test/test_memory_flow_service.py index 2d35e837..bdc89386 100644 --- a/pytests/A_memorix_test/test_memory_flow_service.py +++ b/pytests/A_memorix_test/test_memory_flow_service.py @@ -5,68 +5,6 @@ import pytest from src.services import memory_flow_service as memory_flow_module -@pytest.mark.asyncio -async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch): - starts: list[str] = [] - summarizers: list[object] = [] - - class FakeSummarizer: - def __init__(self, session_id: str): - self.session_id = session_id - summarizers.append(self) - - async def start(self): - starts.append(self.session_id) - - async def stop(self): - starts.append(f"stop:{self.session_id}") - - monkeypatch.setattr( - memory_flow_module, - "global_config", - SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)), - ) - monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer) - - manager = memory_flow_module.LongTermMemorySessionManager() - message = SimpleNamespace(session_id="session-1") - - await manager.on_message(message) - await manager.on_message(message) - - assert len(summarizers) == 1 - assert starts == ["session-1"] - - -@pytest.mark.asyncio -async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch): - stopped: list[str] = [] - - class FakeSummarizer: - def __init__(self, session_id: str): - self.session_id = session_id - - async def start(self): - return None - - async def stop(self): - stopped.append(self.session_id) - - monkeypatch.setattr( - memory_flow_module, - "global_config", - SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)), - ) - monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer) - - manager = memory_flow_module.LongTermMemorySessionManager() - await manager.on_message(SimpleNamespace(session_id="session-a")) - await manager.on_message(SimpleNamespace(session_id="session-b")) - await manager.shutdown() - - assert stopped == ["session-a", "session-b"] - - def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items(): raw = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]' @@ -101,16 +39,9 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch): @pytest.mark.asyncio -async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch): +async def test_memory_automation_service_auto_starts_and_delegates(): events: list[tuple[str, str]] = [] - class FakeSessionManager: - async def on_message(self, message): - events.append(("incoming", message.session_id)) - - async def shutdown(self): - events.append(("shutdown", "session")) - class FakeFactWriteback: async def start(self): events.append(("start", "fact")) @@ -122,7 +53,6 @@ async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch): events.append(("shutdown", "fact")) service = memory_flow_module.MemoryAutomationService() - service.session_manager = FakeSessionManager() service.fact_writeback = FakeFactWriteback() await service.on_incoming_message(SimpleNamespace(session_id="session-1")) @@ -131,8 +61,6 @@ async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch): assert events == [ ("start", "fact"), - ("incoming", "session-1"), ("sent", "session-1"), - ("shutdown", "session"), ("shutdown", "fact"), ] diff --git a/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py b/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py deleted file mode 100644 index 81132f96..00000000 --- a/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py +++ /dev/null @@ -1,324 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import json -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict - -import numpy as np -import pytest -import pytest_asyncio - -from A_memorix.core.runtime import sdk_memory_kernel as kernel_module -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import memory_service - - -DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json" - - -def _load_dialogue_fixture() -> Dict[str, Any]: - return json.loads(DATA_FILE.read_text(encoding="utf-8")) - - -class _FakeEmbeddingAdapter: - def __init__(self, dimension: int = 16) -> None: - self.dimension = dimension - - async def _detect_dimension(self) -> int: - return self.dimension - - async def encode(self, texts, dimensions=None): - dim = int(dimensions or self.dimension) - if isinstance(texts, str): - sequence = [texts] - single = True - else: - sequence = list(texts) - single = False - - rows = [] - for text in sequence: - vec = np.zeros(dim, dtype=np.float32) - for ch in str(text or ""): - vec[ord(ch) % dim] += 1.0 - if not vec.any(): - vec[0] = 1.0 - norm = np.linalg.norm(vec) - if norm > 0: - vec = vec / norm - rows.append(vec) - payload = np.vstack(rows) - return payload[0] if single else payload - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000): - del timeout_ms - payload = args or {} - if component_name == "search_memory": - return await self.kernel.search_memory( - KernelSearchRequest( - query=str(payload.get("query", "") or ""), - limit=int(payload.get("limit", 5) or 5), - mode=str(payload.get("mode", "hybrid") or "hybrid"), - chat_id=str(payload.get("chat_id", "") or ""), - person_id=str(payload.get("person_id", "") or ""), - time_start=payload.get("time_start"), - time_end=payload.get("time_end"), - respect_filter=bool(payload.get("respect_filter", True)), - user_id=str(payload.get("user_id", "") or ""), - group_id=str(payload.get("group_id", "") or ""), - ) - ) - - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]: - for _ in range(max_rounds): - detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) - task = detail.get("task") or {} - status = str(task.get("status", "") or "") - if status in {"completed", "completed_with_errors", "failed", "cancelled"}: - return detail - await asyncio.sleep(0.05) - raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") - - -def _join_hit_content(search_result) -> str: - return "\n".join(hit.content for hit in search_result.hits) - - -@pytest_asyncio.fixture -async def real_dialogue_env(monkeypatch, tmp_path): - scenario = _load_dialogue_fixture() - session_cfg = scenario["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) - - async def fake_self_check(**kwargs): - return {"ok": True, "message": "ok"} - - monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) - monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) - - data_dir = (tmp_path / "a_memorix_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - try: - yield { - "scenario": scenario, - "kernel": kernel, - "session": session, - } - finally: - await kernel.shutdown() - - -@pytest.mark.asyncio -async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env): - scenario = real_dialogue_env["scenario"] - - created = await memory_service.import_admin( - action="create_paste", - name="private_alice.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(scenario["import_payload"], ensure_ascii=False), - ) - - assert created["success"] is True - detail = await _wait_for_import_task(created["task"]["task_id"]) - assert detail["task"]["status"] == "completed" - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - respect_filter=False, - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env): - scenario = real_dialogue_env["scenario"] - record = scenario["chat_history_record"] - - summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - chat_id=real_dialogue_env["session"].session_id, - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch): - scenario = real_dialogue_env["scenario"] - - class _KnownPerson: - def __init__(self, person_id: str) -> None: - self.person_id = person_id - self.is_known = True - self.person_name = scenario["person"]["person_name"] - - monkeypatch.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: scenario["person"]["person_id"], - ) - monkeypatch.setattr(person_info_module, "Person", _KnownPerson) - - await person_info_module.store_person_memory_from_answer( - scenario["person"]["person_name"], - scenario["person_fact"]["memory_content"], - real_dialogue_env["session"].session_id, - ) - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - chat_id=real_dialogue_env["session"].session_id, - person_id=scenario["person"]["person_id"], - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env): - scenario = real_dialogue_env["scenario"] - - await memory_service.ingest_text( - external_id="fixture:knowledge_fetcher", - source_type="dialogue_note", - text=scenario["person_fact"]["memory_content"], - chat_id=real_dialogue_env["session"].session_id, - person_ids=[scenario["person"]["person_id"]], - participants=[scenario["person"]["person_name"]], - respect_filter=False, - ) - - fetcher = knowledge_module.KnowledgeFetcher( - private_name=scenario["session"]["display_name"], - stream_id=real_dialogue_env["session"].session_id, - ) - knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], []) - - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in knowledge_text - - -@pytest.mark.asyncio -async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch): - scenario = real_dialogue_env["scenario"] - - class _KnownPerson: - def __init__(self, person_id: str) -> None: - self.person_id = person_id - self.is_known = True - self.person_name = scenario["person"]["person_name"] - - monkeypatch.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: scenario["person"]["person_id"], - ) - monkeypatch.setattr(person_info_module, "Person", _KnownPerson) - - await person_info_module.store_person_memory_from_answer( - scenario["person"]["person_name"], - scenario["person_fact"]["memory_content"], - real_dialogue_env["session"].session_id, - ) - - profile = await memory_service.get_person_profile( - scenario["person"]["person_id"], - chat_id=real_dialogue_env["session"].session_id, - ) - - assert profile.evidence - assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"]) - - -@pytest.mark.asyncio -async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env): - scenario = real_dialogue_env["scenario"] - record = scenario["chat_history_record"] - - summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - episodes = await memory_service.episode_admin( - action="query", - source=scenario["expectations"]["episode_source"], - limit=5, - ) - - assert episodes["success"] is True - assert int(episodes["count"]) >= 1 diff --git a/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py b/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py deleted file mode 100644 index 5dadca3d..00000000 --- a/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py +++ /dev/null @@ -1,301 +0,0 @@ -from __future__ import annotations - -import asyncio -import inspect -import json -import os -from pathlib import Path -from types import SimpleNamespace -from typing import Any, Dict - -import pytest -import pytest_asyncio - -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel -from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module -from src.memory_system import chat_history_summarizer as summarizer_module -from src.person_info import person_info as person_info_module -from src.services import memory_service as memory_service_module -from src.services.memory_service import memory_service - - -pytestmark = pytest.mark.skipif( - os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1", - reason="需要显式开启真实 embedding / self-check 集成测试", -) - -DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json" - - -def _load_dialogue_fixture() -> Dict[str, Any]: - return json.loads(DATA_FILE.read_text(encoding="utf-8")) - - -class _KernelBackedRuntimeManager: - def __init__(self, kernel: SDKMemoryKernel) -> None: - self.kernel = kernel - - async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000): - del timeout_ms - payload = args or {} - if component_name == "search_memory": - return await self.kernel.search_memory( - KernelSearchRequest( - query=str(payload.get("query", "") or ""), - limit=int(payload.get("limit", 5) or 5), - mode=str(payload.get("mode", "hybrid") or "hybrid"), - chat_id=str(payload.get("chat_id", "") or ""), - person_id=str(payload.get("person_id", "") or ""), - time_start=payload.get("time_start"), - time_end=payload.get("time_end"), - respect_filter=bool(payload.get("respect_filter", True)), - user_id=str(payload.get("user_id", "") or ""), - group_id=str(payload.get("group_id", "") or ""), - ) - ) - - handler = getattr(self.kernel, component_name) - result = handler(**payload) - return await result if inspect.isawaitable(result) else result - - -async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]: - deadline = asyncio.get_running_loop().time() + max(1.0, float(timeout_seconds)) - while asyncio.get_running_loop().time() < deadline: - detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) - task = detail.get("task") or {} - status = str(task.get("status", "") or "") - if status in {"completed", "completed_with_errors", "failed", "cancelled"}: - return detail - await asyncio.sleep(0.2) - raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") - - -def _join_hit_content(search_result) -> str: - return "\n".join(hit.content for hit in search_result.hits) - - -@pytest_asyncio.fixture -async def live_dialogue_env(monkeypatch, tmp_path): - scenario = _load_dialogue_fixture() - session_cfg = scenario["session"] - session = SimpleNamespace( - session_id=session_cfg["session_id"], - platform=session_cfg["platform"], - user_id=session_cfg["user_id"], - group_id=session_cfg["group_id"], - ) - fake_chat_manager = SimpleNamespace( - get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, - get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, - ) - - monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) - monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) - - data_dir = (tmp_path / "a_memorix_data").resolve() - kernel = SDKMemoryKernel( - plugin_root=tmp_path / "plugin_root", - config={ - "storage": {"data_dir": str(data_dir)}, - "advanced": {"enable_auto_save": False}, - "memory": {"base_decay_interval_hours": 24}, - "person_profile": {"refresh_interval_minutes": 5}, - }, - ) - manager = _KernelBackedRuntimeManager(kernel) - monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager) - - await kernel.initialize() - try: - yield { - "scenario": scenario, - "kernel": kernel, - "session": session, - } - finally: - await kernel.shutdown() - - -@pytest.mark.asyncio -async def test_live_runtime_self_check_passes(live_dialogue_env): - report = await memory_service.runtime_admin(action="refresh_self_check") - - assert report["success"] is True - assert report["report"]["ok"] is True - assert report["report"]["encoded_dimension"] > 0 - - -@pytest.mark.asyncio -async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env): - scenario = live_dialogue_env["scenario"] - - created = await memory_service.import_admin( - action="create_paste", - name="private_alice.json", - input_mode="json", - llm_enabled=False, - content=json.dumps(scenario["import_payload"], ensure_ascii=False), - ) - - assert created["success"] is True - detail = await _wait_for_import_task(created["task"]["task_id"]) - assert detail["task"]["status"] == "completed" - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - respect_filter=False, - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env): - scenario = live_dialogue_env["scenario"] - record = scenario["chat_history_record"] - - summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - chat_id=live_dialogue_env["session"].session_id, - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch): - scenario = live_dialogue_env["scenario"] - - class _KnownPerson: - def __init__(self, person_id: str) -> None: - self.person_id = person_id - self.is_known = True - self.person_name = scenario["person"]["person_name"] - - monkeypatch.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: scenario["person"]["person_id"], - ) - monkeypatch.setattr(person_info_module, "Person", _KnownPerson) - - await person_info_module.store_person_memory_from_answer( - scenario["person"]["person_name"], - scenario["person_fact"]["memory_content"], - live_dialogue_env["session"].session_id, - ) - - search = await memory_service.search( - scenario["search_queries"]["direct"], - mode="search", - chat_id=live_dialogue_env["session"].session_id, - person_id=scenario["person"]["person_id"], - ) - - assert search.hits - joined = _join_hit_content(search) - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in joined - - -@pytest.mark.asyncio -async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env): - scenario = live_dialogue_env["scenario"] - - await memory_service.ingest_text( - external_id="fixture:knowledge_fetcher", - source_type="dialogue_note", - text=scenario["person_fact"]["memory_content"], - chat_id=live_dialogue_env["session"].session_id, - person_ids=[scenario["person"]["person_id"]], - participants=[scenario["person"]["person_name"]], - respect_filter=False, - ) - - fetcher = knowledge_module.KnowledgeFetcher( - private_name=scenario["session"]["display_name"], - stream_id=live_dialogue_env["session"].session_id, - ) - knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], []) - - for keyword in scenario["expectations"]["search_keywords"]: - assert keyword in knowledge_text - - -@pytest.mark.asyncio -async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch): - scenario = live_dialogue_env["scenario"] - - class _KnownPerson: - def __init__(self, person_id: str) -> None: - self.person_id = person_id - self.is_known = True - self.person_name = scenario["person"]["person_name"] - - monkeypatch.setattr( - person_info_module, - "get_person_id_by_person_name", - lambda person_name: scenario["person"]["person_id"], - ) - monkeypatch.setattr(person_info_module, "Person", _KnownPerson) - - await person_info_module.store_person_memory_from_answer( - scenario["person"]["person_name"], - scenario["person_fact"]["memory_content"], - live_dialogue_env["session"].session_id, - ) - - profile = await memory_service.get_person_profile( - scenario["person"]["person_id"], - chat_id=live_dialogue_env["session"].session_id, - ) - - assert profile.evidence - assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"]) - - -@pytest.mark.asyncio -async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env): - scenario = live_dialogue_env["scenario"] - record = scenario["chat_history_record"] - - summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id) - await summarizer._import_to_long_term_memory( - record_id=record["record_id"], - theme=record["theme"], - summary=record["summary"], - participants=record["participants"], - start_time=record["start_time"], - end_time=record["end_time"], - original_text=record["original_text"], - ) - - episodes = await memory_service.episode_admin( - action="query", - source=scenario["expectations"]["episode_source"], - limit=5, - ) - - assert episodes["success"] is True - assert int(episodes["count"]) >= 1 diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 7fed79c3..a7176994 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -612,13 +612,6 @@ class ChatBot: scope=scope, ) # 确保会话存在 - try: - from src.services.memory_flow_service import memory_automation_service - - await memory_automation_service.on_incoming_message(message) - except Exception as exc: - logger.warning(f"[{session_id}] 长期记忆自动摘要注册失败: {exc}") - # message.update_chat_stream(chat) # 命令处理 - 使用新插件系统检查并处理命令。 diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py index dc4bdbb2..41b356df 100644 --- a/src/common/logger_color_and_mapping.py +++ b/src/common/logger_color_and_mapping.py @@ -159,7 +159,6 @@ MODULE_ALIASES = { "planner": "规划器", "config": "配置", "main": "主程序", - "chat_history_summarizer": "聊天概括器", "plugin_runtime.integration": "IPC插件系统", "plugin_runtime.host.supervisor": "插件监督器", "plugin_runtime.host.runner_manager": "插件监督器", diff --git a/src/config/config.py b/src/config/config.py index c1b6051b..8de2a29b 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -55,7 +55,7 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute() MMC_VERSION: str = "1.0.0" -CONFIG_VERSION: str = "8.7.1" +CONFIG_VERSION: str = "8.8.0" MODEL_CONFIG_VERSION: str = "1.14.0" logger = get_logger("config") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index c9dc617b..73645170 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -414,95 +414,6 @@ class MemoryConfig(ConfigBase): ) """Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数""" - long_term_auto_summary_enabled: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "book-open", - }, - ) - """是否自动启动聊天总结并导入长期记忆""" - - person_fact_writeback_enabled: bool = Field( - default=True, - json_schema_extra={ - "x-widget": "switch", - "x-icon": "user-round-pen", - }, - ) - """是否在发送回复后自动提取并写回人物事实到长期记忆""" - chat_history_topic_check_message_threshold: int = Field( - default=80, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "hash", - }, - ) - """聊天历史话题检查的消息数量阈值,当累积消息数达到此值时触发话题检查""" - - chat_history_topic_check_time_hours: float = Field( - default=8.0, - json_schema_extra={ - "x-widget": "input", - "x-icon": "clock", - }, - ) - """聊天历史话题检查的时间阈值(小时),当距离上次检查超过此时间且消息数达到最小阈值时触发话题检查""" - - chat_history_topic_check_min_messages: int = Field( - default=20, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "hash", - }, - ) - """聊天历史话题检查的时间触发模式下的最小消息数阈值""" - - chat_history_finalize_no_update_checks: int = Field( - default=3, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "check-circle", - }, - ) - """聊天历史话题打包存储的连续无更新检查次数阈值,当话题连续N次检查无新增内容时触发打包存储""" - - chat_history_finalize_message_count: int = Field( - default=5, - ge=1, - json_schema_extra={ - "x-widget": "input", - "x-icon": "package", - }, - ) - """聊天历史话题打包存储的消息条数阈值,当话题的消息条数超过此值时触发打包存储""" - - def model_post_init(self, context: Optional[dict] = None) -> None: - """验证配置值""" - if self.chat_history_topic_check_message_threshold < 1: - raise ValueError( - f"chat_history_topic_check_message_threshold 必须至少为1,当前值: {self.chat_history_topic_check_message_threshold}" - ) - if self.chat_history_topic_check_time_hours <= 0: - raise ValueError( - f"chat_history_topic_check_time_hours 必须大于0,当前值: {self.chat_history_topic_check_time_hours}" - ) - if self.chat_history_topic_check_min_messages < 1: - raise ValueError( - f"chat_history_topic_check_min_messages 必须至少为1,当前值: {self.chat_history_topic_check_min_messages}" - ) - if self.chat_history_finalize_no_update_checks < 1: - raise ValueError( - f"chat_history_finalize_no_update_checks 必须至少为1,当前值: {self.chat_history_finalize_no_update_checks}" - ) - if self.chat_history_finalize_message_count < 1: - raise ValueError( - f"chat_history_finalize_message_count 必须至少为1,当前值: {self.chat_history_finalize_message_count}" - ) - return super().model_post_init(context) class LearningItem(ConfigBase): diff --git a/src/maisaka/builtin_tool/send_emoji.py b/src/maisaka/builtin_tool/send_emoji.py index 27353ec3..b02b75f7 100644 --- a/src/maisaka/builtin_tool/send_emoji.py +++ b/src/maisaka/builtin_tool/send_emoji.py @@ -53,40 +53,16 @@ def get_tool_spec() -> ToolSpec: return ToolSpec( name="send_emoji", brief_description="发送一个合适的表情包来辅助表达情绪。", - detailed_description="参数说明:\n- emotion:string,可选。希望表达的情绪,例如 happy、sad、angry 等。", + detailed_description="无需参数,直接发送一个合适的表情包。", parameters_schema={ "type": "object", - "properties": { - "emotion": { - "type": "string", - "description": "希望表达的情绪,例如 happy、sad、angry 等。", - }, - }, + "properties": {}, }, provider_name="maisaka_builtin", provider_type="builtin", ) -def _normalize_candidate_emotions(emoji: MaiEmoji) -> list[str]: - """清洗候选表情上的情绪标签。""" - - raw_emotions = getattr(emoji, "emotion", None) - if isinstance(raw_emotions, list) and raw_emotions: - return [str(item).strip() for item in raw_emotions if str(item).strip()] - - description = str(getattr(emoji, "description", "") or "").strip() - if not description: - return [] - - normalized_description = ( - description.replace(",", ",") - .replace("、", ",") - .replace(";", ",") - ) - return [item.strip() for item in normalized_description.split(",") if item.strip()] - - async def _load_emoji_bytes(emoji: MaiEmoji) -> bytes: """读取单个表情包图片字节。""" @@ -232,18 +208,6 @@ async def _build_emoji_candidate_message(emojis: list[MaiEmoji]) -> SessionBacke ) -def _build_emoji_candidate_summary(emojis: list[MaiEmoji]) -> str: - """构建供监控展示使用的候选表情摘要。""" - - summary_lines: list[str] = [] - for index, emoji in enumerate(emojis, start=1): - description = emoji.description.strip() or "(无描述)" - emotions = "、".join(_normalize_candidate_emotions(emoji)) or "无" - summary_lines.append(f"{index}. 描述:{description}") - summary_lines.append(f" 情绪:{emotions}") - return "\n".join(summary_lines).strip() - - def _build_send_emoji_monitor_detail( *, request_messages: Optional[list[dict[str, Any]]] = None, @@ -252,7 +216,7 @@ def _build_send_emoji_monitor_detail( metrics: Optional[Dict[str, Any]] = None, extra_sections: Optional[list[dict[str, str]]] = None, ) -> Dict[str, Any]: - """构建 emotion tool 统一监控详情。""" + """构建 send_emoji 工具统一监控详情。""" detail: Dict[str, Any] = {} if isinstance(request_messages, list) and request_messages: @@ -281,7 +245,6 @@ def _build_send_emoji_monitor_detail( def _build_send_emoji_monitor_metadata( selection_metadata: Dict[str, Any], *, - requested_emotion: str, send_result: Optional[Any] = None, error_message: str = "", ) -> Dict[str, Any]: @@ -293,7 +256,6 @@ def _build_send_emoji_monitor_metadata( if send_result is not None: result_lines = [ - f"请求情绪:{requested_emotion or '未指定'}", f"命中情绪:{send_result.matched_emotion or '未命中'}", f"表情描述:{send_result.description or '无描述'}", f"情绪标签:{'、'.join(send_result.emotions) if send_result.emotions else '无'}", @@ -306,10 +268,7 @@ def _build_send_emoji_monitor_metadata( elif error_message.strip(): extra_sections.append({ "title": "表情发送结果", - "content": ( - f"请求情绪:{requested_emotion or '未指定'}\n" - f"发送结果:{error_message.strip()}" - ), + "content": f"发送结果:{error_message.strip()}", }) if extra_sections: @@ -322,7 +281,6 @@ def _build_send_emoji_monitor_metadata( async def _select_emoji_with_sub_agent( tool_ctx: BuiltinToolRuntimeContext, - requested_emotion: str, reasoning: str, context_texts: list[str], sample_size: int, @@ -347,14 +305,12 @@ async def _select_emoji_with_sub_agent( f"一共 {len(sampled_emojis)} 个位置。\n" f"每张小图左上角都有一个较大的序号,范围是 1 到 {len(sampled_emojis)}。\n" f"你的任务是根据上下文和当前语气,从这 {len(sampled_emojis)} 张图里选出最合适的一张表情包。\n" - "如果提供了 requested_emotion,请优先考虑与其接近的候选;如果没有完全匹配,则选择最符合上下文语气的候选。\n" "你必须返回一个 JSON 对象(json object),不要输出任何 JSON 之外的内容。\n" '返回格式固定为:{"emoji_index":1,"reason":"简短理由"}' ) prompt_message = ReferenceMessage( content=( f"[选择任务]\n" - f"requested_emotion: {requested_emotion or '未指定'}\n" f"候选总数: {len(sampled_emojis)}\n" f"拼图布局: {grid_rows}x{grid_columns}\n" "请只输出 JSON。" @@ -439,7 +395,6 @@ async def handle_tool( """执行 send_emoji 内置工具。""" del context - emotion = str(invocation.arguments.get("emotion") or "").strip() context_texts = [ message.processed_plain_text.strip() for message in tool_ctx.runtime._chat_history[-5:] @@ -450,23 +405,20 @@ async def handle_tool( "message": "", "description": "", "emotion": [], - "requested_emotion": emotion, "matched_emotion": "", "reason": "", } selection_metadata: Dict[str, Any] = {"reason": "", "monitor_detail": {}} - logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具,请求情绪={emotion!r}") + logger.info(f"{tool_ctx.runtime.log_prefix} 触发表情包发送工具") try: send_result = await send_emoji_for_maisaka( stream_id=tool_ctx.runtime.session_id, - requested_emotion=emotion, reasoning=tool_ctx.engine.last_reasoning_content, context_texts=context_texts, - emoji_selector=lambda requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent( + emoji_selector=lambda _requested_emotion, reasoning, context_texts, sample_size: _select_emoji_with_sub_agent( tool_ctx, - requested_emotion, reasoning, list(context_texts or []), sample_size, @@ -482,7 +434,6 @@ async def handle_tool( structured_content=structured_result, metadata=_build_send_emoji_monitor_metadata( selection_metadata, - requested_emotion=emotion, error_message=structured_result["message"], ), ) @@ -493,7 +444,7 @@ async def handle_tool( logger.info( f"{tool_ctx.runtime.log_prefix} 表情包发送成功 " f"描述={send_result.description!r} 情绪标签={send_result.emotions} " - f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}" + f"命中情绪={send_result.matched_emotion!r}" ) if send_result.sent_message is not None: tool_ctx.append_sent_message_to_chat_history(send_result.sent_message) @@ -509,7 +460,6 @@ async def handle_tool( structured_content=structured_result, metadata=_build_send_emoji_monitor_metadata( selection_metadata, - requested_emotion=emotion, send_result=send_result, ), ) @@ -521,7 +471,7 @@ async def handle_tool( logger.warning( f"{tool_ctx.runtime.log_prefix} 表情包发送失败 " - f"请求情绪={emotion!r} 错误信息={send_result.message}" + f"错误信息={send_result.message}" ) return tool_ctx.build_failure_result( invocation.tool_name, @@ -529,7 +479,6 @@ async def handle_tool( structured_content=structured_result, metadata=_build_send_emoji_monitor_metadata( selection_metadata, - requested_emotion=emotion, send_result=send_result, ), ) diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py index 5a6b1712..e45fa2d6 100644 --- a/src/maisaka/chat_loop_service.py +++ b/src/maisaka/chat_loop_service.py @@ -528,14 +528,12 @@ class MaisakaChatLoopService: prompt_section: RenderableType | None = None if global_config.debug.show_maisaka_thinking: - image_display_mode: str = "path_link" if global_config.maisaka.show_image_path else "legacy" prompt_section = PromptCLIVisualizer.build_prompt_section( built_messages, category="planner" if request_kind != "timing_gate" else "timing_gate", chat_id=self._session_id, request_kind=request_kind, selection_reason=selection_reason, - image_display_mode=image_display_mode, folded=global_config.debug.fold_maisaka_thinking, tool_definitions=list(all_tools), ) diff --git a/src/maisaka/display/prompt_cli_renderer.py b/src/maisaka/display/prompt_cli_renderer.py index 9de08cec..a770261e 100644 --- a/src/maisaka/display/prompt_cli_renderer.py +++ b/src/maisaka/display/prompt_cli_renderer.py @@ -799,7 +799,7 @@ class PromptCLIVisualizer: chat_id: str, request_kind: str, selection_reason: str, - image_display_mode: Literal["legacy", "path_link"], + image_display_mode: Literal["legacy", "path_link"] = "path_link", tool_definitions: list[dict[str, Any]] | None = None, ) -> RenderableType: """构建用于查看完整 prompt 的折叠入口内容。""" @@ -864,7 +864,7 @@ class PromptCLIVisualizer: chat_id: str, request_kind: str, selection_reason: str, - image_display_mode: Literal["legacy", "path_link"], + image_display_mode: Literal["legacy", "path_link"] = "path_link", folded: bool, tool_definitions: list[dict[str, Any]] | None = None, ) -> Panel: @@ -878,14 +878,10 @@ class PromptCLIVisualizer: chat_id=chat_id, request_kind=request_kind, selection_reason=selection_reason, - image_display_mode=image_display_mode, tool_definitions=tool_definitions, ) else: - ordered_panels = cls.build_prompt_panels( - messages, - image_display_mode=image_display_mode, - ) + ordered_panels = cls.build_prompt_panels(messages) prompt_renderable = Group(*ordered_panels) return Panel( @@ -1102,11 +1098,9 @@ class PromptCLIVisualizer: cls, messages: list[Any], *, - image_display_mode: Literal["legacy", "path_link"], + image_display_mode: Literal["legacy", "path_link"] = "path_link", ) -> List[Panel]: """构建完整 prompt 可视化面板。""" - if image_display_mode not in {mode.value for mode in PromptImageDisplayMode}: - image_display_mode = PromptImageDisplayMode.LEGACY settings = PromptImageDisplaySettings( display_mode=PromptImageDisplayMode(image_display_mode), ) diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index 78d0a4b2..5de81cf2 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -1151,7 +1151,6 @@ class MaisakaHeartFlowChatting: chat_id=self.session_id, request_kind=labels["request_kind"], selection_reason=subtitle, - image_display_mode="path_link" if global_config.maisaka.show_image_path else "legacy", ), title=labels["prompt_title"], border_style=border_style, diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py deleted file mode 100644 index 94f4390f..00000000 --- a/src/memory_system/chat_history_summarizer.py +++ /dev/null @@ -1,1123 +0,0 @@ -""" -聊天内容概括器 -用于累积、打包和压缩聊天记录 -""" - -import asyncio -import json -import time -import re -import difflib -import datetime -from pathlib import Path -from typing import Any, Dict, List, Optional, Set -from dataclasses import dataclass, field -from json_repair import repair_json - -from src.chat.message_receive.message import SessionMessage -from src.common.logger import get_logger -from src.config.config import global_config -from src.common.data_models.llm_service_data_models import LLMGenerationOptions -from src.services.llm_service import LLMServiceClient -from src.services import message_service as message_api -from src.chat.utils.utils import is_bot_self -from src.person_info.person_info import Person -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.prompt.prompt_manager import prompt_manager - -logger = get_logger("chat_history_summarizer") - -HIPPO_CACHE_DIR = Path(__file__).resolve().parents[2] / "data" / "hippo_memorizer" - - -@dataclass -class MessageBatch: - """消息批次(用于触发话题检查的原始消息累积)""" - - messages: List[SessionMessage] - start_time: float - end_time: float - - -@dataclass -class TopicCacheItem: - """ - 话题缓存项 - - Attributes: - topic: 话题标题(一句话描述时间、人物、事件和主题) - messages: 与该话题相关的消息字符串列表(已经通过 build 函数转成可读文本) - participants: 涉及到的发言人昵称集合 - no_update_checks: 连续多少次“检查”没有新增内容 - """ - - topic: str - messages: List[str] = field(default_factory=list) - participants: Set[str] = field(default_factory=set) - no_update_checks: int = 0 - - -class ChatHistorySummarizer: - """聊天内容概括器""" - - def __init__(self, session_id: str, check_interval: int = 60): - """ - 初始化聊天内容概括器 - - Args: - session_id: 会话ID - check_interval: 定期检查间隔(秒),默认60秒 - """ - self.session_id = session_id - self._chat_display_name = self._get_chat_display_name() - self.log_prefix = f"[{self._chat_display_name}]" - - # 记录时间点,用于计算新消息 - self.last_check_time = time.time() - - # 记录上一次话题检查的时间,用于判断是否需要触发检查 - self.last_topic_check_time = time.time() - - # 当前累积的消息批次 - self.current_batch: Optional[MessageBatch] = None - - # 话题缓存:topic_str -> TopicCacheItem - # 在内存中维护,并通过本地文件实时持久化 - self.topic_cache: Dict[str, TopicCacheItem] = {} - self._safe_chat_id = self._sanitize_chat_id(self.session_id) - self._topic_cache_file = HIPPO_CACHE_DIR / f"{self._safe_chat_id}.json" - # 注意:批次加载需要异步查询消息,所以在 start() 中调用 - - # LLM请求器,用于压缩聊天内容 - self.summarizer_llm = LLMServiceClient( - task_name="utils", request_type="chat_history_summarizer" - ) - - # 后台循环相关 - self.check_interval = check_interval # 检查间隔(秒) - self._periodic_task: Optional[asyncio.Task] = None - self._running = False - - def _get_chat_display_name(self) -> str: - """获取聊天显示名称""" - try: - chat_name = _chat_manager.get_session_name(self.session_id) - if chat_name: - return chat_name - # 如果获取失败,使用简化的chat_id显示 - if len(self.session_id) > 20: - return f"{self.session_id[:8]}..." - return self.session_id - except Exception: - # 如果获取失败,使用简化的chat_id显示 - if len(self.session_id) > 20: - return f"{self.session_id[:8]}..." - return self.session_id - - def _sanitize_chat_id(self, chat_id: str) -> str: - """用于生成可作为文件名的 chat_id""" - return re.sub(r"[^a-zA-Z0-9_.-]", "_", chat_id) - - def _load_topic_cache_from_disk(self): - """在启动时加载本地话题缓存(同步部分),支持重启后继续""" - try: - if not self._topic_cache_file.exists(): - return - - with self._topic_cache_file.open("r", encoding="utf-8") as f: - data = json.load(f) - - self.last_topic_check_time = data.get("last_topic_check_time", self.last_topic_check_time) - topics_data = data.get("topics", {}) - loaded_count = 0 - for topic, payload in topics_data.items(): - self.topic_cache[topic] = TopicCacheItem( - topic=topic, - messages=payload.get("messages", []), - participants=set(payload.get("participants", [])), - no_update_checks=payload.get("no_update_checks", 0), - ) - loaded_count += 1 - - if loaded_count: - logger.info(f"{self.log_prefix} 已加载 {loaded_count} 个话题缓存,继续追踪") - except Exception as e: - logger.error(f"{self.log_prefix} 加载话题缓存失败: {e}") - - async def _load_batch_from_disk(self): - """在启动时加载聊天批次,支持重启后继续""" - try: - if not self._topic_cache_file.exists(): - return - - with self._topic_cache_file.open("r", encoding="utf-8") as f: - data = json.load(f) - - batch_data = data.get("current_batch") - if not batch_data: - return - - start_time = batch_data.get("start_time") - end_time = batch_data.get("end_time") - if not start_time or not end_time: - return - - # 根据时间范围重新查询消息 - messages = message_api.get_messages_by_time_in_chat( - chat_id=self.session_id, - start_time=start_time, - end_time=end_time, - limit=0, - limit_mode="latest", - filter_mai=False, - filter_command=False, - ) - - if messages: - self.current_batch = MessageBatch( - messages=messages, - start_time=start_time, - end_time=end_time, - ) - logger.info(f"{self.log_prefix} 已恢复聊天批次,包含 {len(messages)} 条消息") - except Exception as e: - logger.error(f"{self.log_prefix} 加载聊天批次失败: {e}") - - def _persist_topic_cache(self): - """实时持久化话题缓存和聊天批次,避免重启后丢失""" - try: - # 如果既没有话题缓存也没有批次,删除缓存文件 - if not self.topic_cache and not self.current_batch: - if self._topic_cache_file.exists(): - self._topic_cache_file.unlink() - return - - HIPPO_CACHE_DIR.mkdir(parents=True, exist_ok=True) - data = { - "chat_id": self.session_id, - "last_topic_check_time": self.last_topic_check_time, - "topics": { - topic: { - "messages": item.messages, - "participants": list(item.participants), - "no_update_checks": item.no_update_checks, - } - for topic, item in self.topic_cache.items() - }, - } - - # 保存当前批次的时间范围(如果有) - if self.current_batch: - data["current_batch"] = { - "start_time": self.current_batch.start_time, - "end_time": self.current_batch.end_time, - } - - with self._topic_cache_file.open("w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) - except Exception as e: - logger.error(f"{self.log_prefix} 持久化话题缓存失败: {e}") - - async def process(self, current_time: Optional[float] = None): - """ - 处理聊天内容概括 - - Args: - current_time: 当前时间戳,如果为None则使用time.time() - """ - if current_time is None: - current_time = time.time() - - try: - # 获取从上次检查时间到当前时间的新消息 - new_messages = message_api.get_messages_by_time_in_chat( - chat_id=self.session_id, - start_time=self.last_check_time, - end_time=current_time, - limit=0, - limit_mode="latest", - filter_mai=False, # 不过滤bot消息,因为需要检查bot是否发言 - filter_command=False, - ) - - if not new_messages: - # 没有新消息,检查是否需要进行“话题检查” - if self.current_batch and self.current_batch.messages: - await self._check_and_run_topic_check(current_time) - self.last_check_time = current_time - return - - logger.debug( - f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}" - ) - - # 有新消息,更新最后检查时间 - self.last_check_time = current_time - - # 如果有当前批次,添加新消息 - if self.current_batch: - before_count = len(self.current_batch.messages) - self.current_batch.messages.extend(new_messages) - self.current_batch.end_time = current_time - logger.info( - f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息" - ) - # 更新批次后持久化 - self._persist_topic_cache() - else: - # 创建新批次 - self.current_batch = MessageBatch( - messages=new_messages, - start_time=new_messages[0].timestamp.timestamp() if new_messages else current_time, - end_time=current_time, - ) - logger.debug(f"{self.log_prefix} 新建聊天检查批次: {len(new_messages)} 条消息") - # 创建批次后持久化 - self._persist_topic_cache() - - # 检查是否需要触发“话题检查” - await self._check_and_run_topic_check(current_time) - - except Exception as e: - logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}") - import traceback - - traceback.print_exc() - - async def _check_and_run_topic_check(self, current_time: float): - """ - 检查是否需要进行一次“话题检查” - - 触发条件: - - 当前批次消息数 >= 100,或者 - - 距离上一次检查的时间 > 3600 秒(1小时) - """ - if not self.current_batch or not self.current_batch.messages: - return - - messages = self.current_batch.messages - message_count = len(messages) - time_since_last_check = current_time - self.last_topic_check_time - - # 格式化时间差显示 - if time_since_last_check < 60: - time_str = f"{time_since_last_check:.1f}秒" - elif time_since_last_check < 3600: - time_str = f"{time_since_last_check / 60:.1f}分钟" - else: - time_str = f"{time_since_last_check / 3600:.1f}小时" - - logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}") - - # 检查"话题检查"触发条件 - should_check = False - - # 从配置中获取阈值 - message_threshold = global_config.memory.chat_history_topic_check_message_threshold - time_threshold_hours = global_config.memory.chat_history_topic_check_time_hours - min_messages = global_config.memory.chat_history_topic_check_min_messages - time_threshold_seconds = time_threshold_hours * 3600 - - # 条件1: 消息数量达到阈值,触发一次检查 - if message_count >= message_threshold: - should_check = True - logger.info( - f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: {message_threshold}条)" - ) - - # 条件2: 距离上一次检查超过时间阈值且消息数量达到最小阈值,触发一次检查 - elif time_since_last_check > time_threshold_seconds and message_count >= min_messages: - should_check = True - logger.info( - f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: {time_threshold_hours}小时)且消息数量达到 {message_count} 条(阈值: {min_messages}条)" - ) - - if should_check: - await self._run_topic_check_and_update_cache(messages) - # 本批次已经被处理为话题信息,可以清空 - self.current_batch = None - # 更新上一次检查时间,并持久化 - self.last_topic_check_time = current_time - self._persist_topic_cache() - - async def _run_topic_check_and_update_cache(self, messages: List[SessionMessage]): - """ - 执行一次“话题检查”: - 1. 首先确认这段消息里是否有 Bot 发言,没有则直接丢弃本次批次; - 2. 将消息编号并转成字符串,构造 LLM Prompt; - 3. 把历史话题标题列表放入 Prompt,要求 LLM: - - 识别当前聊天中的话题(1 个或多个); - - 为每个话题选出相关消息编号; - - 若话题属于历史话题,则沿用原话题标题; - 4. LLM 返回 JSON:多个 {topic, message_indices}; - 5. 更新本地话题缓存,并根据规则触发“话题打包存储”。 - """ - if not messages: - return - - start_time = messages[0].timestamp.timestamp() - end_time = messages[-1].timestamp.timestamp() - - logger.info( - f"{self.log_prefix} 开始话题检查 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}" - ) - - # 1. 检查当前批次内是否有 bot 发言(只检查当前批次,不往前推) - # 原因:我们要记录的是 bot 参与过的对话片段,如果当前批次内 bot 没有发言, - # 说明 bot 没有参与这段对话,不应该记录 - has_bot_message = any( - is_bot_self(msg.platform, msg.message_info.user_info.user_id) for msg in messages - ) - - if not has_bot_message: - logger.info( - f"{self.log_prefix} 当前批次内无 Bot 发言,丢弃本次检查 | 时间范围: {start_time:.2f} - {end_time:.2f}" - ) - return - - # 2. 构造编号后的消息字符串和参与者信息 - numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = ( - self._build_numbered_messages_for_llm(messages) - ) - - # 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次) - existing_topics = list(self.topic_cache.keys()) - max_retries = 3 - attempt = 0 - success = False - topic_to_indices: Dict[str, List[int]] = {} - - while attempt < max_retries: - attempt += 1 - success, topic_to_indices = await self._analyze_topics_with_llm( - numbered_lines=numbered_lines, - existing_topics=existing_topics, - ) - - if success and topic_to_indices: - if attempt > 1: - logger.info( - f"{self.log_prefix} 话题识别在第 {attempt} 次重试后成功 | 话题数: {len(topic_to_indices)}" - ) - break - - logger.warning( - f"{self.log_prefix} 话题识别失败或无有效话题,第 {attempt} 次尝试失败" - + ("" if attempt >= max_retries else ",准备重试") - ) - - if not success or not topic_to_indices: - logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃") - # 即使识别失败,也认为是一次"检查",但不更新 no_update_checks(保持原状) - return - - # 3.5. 检查新话题是否与历史话题相似(相似度>=90%则使用历史标题) - topic_mapping = self._build_topic_mapping(topic_to_indices, similarity_threshold=0.9) - - # 应用话题映射:将相似的新话题标题替换为历史话题标题 - if topic_mapping: - new_topic_to_indices: Dict[str, List[int]] = {} - for new_topic, indices in topic_to_indices.items(): - # 如果这个新话题需要映射到历史话题 - if new_topic in topic_mapping: - historical_topic = topic_mapping[new_topic] - # 如果历史话题已经存在,合并消息索引 - if historical_topic in new_topic_to_indices: - # 合并索引并去重 - combined_indices = list(set(new_topic_to_indices[historical_topic] + indices)) - new_topic_to_indices[historical_topic] = combined_indices - else: - new_topic_to_indices[historical_topic] = indices - else: - # 不需要映射,保持原样 - new_topic_to_indices[new_topic] = indices - topic_to_indices = new_topic_to_indices - - # 4. 统计哪些话题在本次检查中有新增内容 - updated_topics: Set[str] = set() - - for topic, indices in topic_to_indices.items(): - if not indices: - continue - - item = self.topic_cache.get(topic) - if not item: - # 新话题 - item = TopicCacheItem(topic=topic) - self.topic_cache[topic] = item - - # 收集属于该话题的消息文本(不带编号) - topic_msg_texts: List[str] = [] - new_participants: Set[str] = set() - for idx in indices: - msg_text = index_to_msg_text.get(idx) - if not msg_text: - continue - topic_msg_texts.append(msg_text) - new_participants.update(index_to_participants.get(idx, set())) - - if not topic_msg_texts: - continue - - # 将本次检查中属于该话题的所有消息合并为一个字符串(不带编号) - merged_text = "\n".join(topic_msg_texts) - item.messages.append(merged_text) - item.participants.update(new_participants) - # 本次检查中该话题有更新,重置计数 - item.no_update_checks = 0 - updated_topics.add(topic) - - # 5. 对于本次没有更新的历史话题,no_update_checks + 1 - for topic, item in list(self.topic_cache.items()): - if topic not in updated_topics: - item.no_update_checks += 1 - - # 6. 检查是否有话题需要打包存储 - # 从配置中获取阈值 - no_update_checks_threshold = global_config.memory.chat_history_finalize_no_update_checks - message_count_threshold = global_config.memory.chat_history_finalize_message_count - - topics_to_finalize: List[str] = [] - for topic, item in self.topic_cache.items(): - if item.no_update_checks >= no_update_checks_threshold: - logger.info( - f"{self.log_prefix} 话题[{topic}] 连续 {no_update_checks_threshold} 次检查无新增内容,触发打包存储" - ) - topics_to_finalize.append(topic) - continue - if len(item.messages) > message_count_threshold: - logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 {message_count_threshold},触发打包存储") - topics_to_finalize.append(topic) - - for topic in topics_to_finalize: - item = self.topic_cache.get(topic) - if not item: - continue - try: - await self._finalize_and_store_topic( - topic=topic, - item=item, - # 这里的时间范围尽量覆盖最近一次检查的区间 - start_time=start_time, - end_time=end_time, - ) - finally: - # 无论成功与否,都从缓存中删除,避免重复 - self.topic_cache.pop(topic, None) - - def _find_most_similar_topic( - self, new_topic: str, existing_topics: List[str], similarity_threshold: float = 0.9 - ) -> Optional[tuple[str, float]]: - """ - 查找与给定新话题最相似的历史话题 - - Args: - new_topic: 新话题标题 - existing_topics: 历史话题标题列表 - similarity_threshold: 相似度阈值,默认0.9(90%) - - Returns: - Optional[tuple[str, float]]: 如果找到相似度>=阈值的历史话题,返回(历史话题标题, 相似度), - 否则返回None - """ - if not existing_topics: - return None - - best_match = None - best_similarity = 0.0 - - for existing_topic in existing_topics: - similarity = difflib.SequenceMatcher(None, new_topic, existing_topic).ratio() - if similarity > best_similarity: - best_similarity = similarity - best_match = existing_topic - - # 如果相似度达到阈值,返回匹配结果 - if best_match and best_similarity >= similarity_threshold: - return (best_match, best_similarity) - - return None - - def _build_topic_mapping( - self, topic_to_indices: Dict[str, List[int]], similarity_threshold: float = 0.9 - ) -> Dict[str, str]: - """ - 构建新话题到历史话题的映射(如果相似度>=阈值) - - Args: - topic_to_indices: 新话题到消息索引的映射 - similarity_threshold: 相似度阈值,默认0.9(90%) - - Returns: - Dict[str, str]: 新话题 -> 历史话题的映射字典 - """ - existing_topics_list = list(self.topic_cache.keys()) - topic_mapping: Dict[str, str] = {} - - for new_topic in topic_to_indices.keys(): - # 如果新话题已经在历史话题中,不需要检查 - if new_topic in existing_topics_list: - continue - - # 查找最相似的历史话题 - result = self._find_most_similar_topic(new_topic, existing_topics_list, similarity_threshold) - if result: - historical_topic, similarity = result - topic_mapping[new_topic] = historical_topic - logger.info( - f"{self.log_prefix} 话题相似度检查: '{new_topic}' 与历史话题 '{historical_topic}' 相似度 {similarity:.2%},使用历史标题" - ) - - return topic_mapping - - def _build_numbered_messages_for_llm( - self, messages: List[SessionMessage] - ) -> tuple[List[str], Dict[int, str], Dict[int, str], Dict[int, Set[str]]]: - """ - 将消息转为带编号的字符串,供 LLM 选择使用。 - - 返回: - numbered_lines: ["1. xxx", "2. yyy", ...] # 带编号,用于 LLM 选择 - index_to_msg_str: idx -> "idx. xxx" # 带编号,用于 LLM 选择 - index_to_msg_text: idx -> "xxx" # 不带编号,用于最终存储 - index_to_participants: idx -> {nickname1, nickname2, ...} - """ - numbered_lines: List[str] = [] - index_to_msg_str: Dict[int, str] = {} - index_to_msg_text: Dict[int, str] = {} # 不带编号的消息文本 - index_to_participants: Dict[int, Set[str]] = {} - - for idx, msg in enumerate(messages, start=1): - # 使用 build_readable_messages 生成可读文本 - try: - text = message_api.build_readable_messages( - messages=[msg], - replace_bot_name=True, - timestamp_mode="normal_no_YMD", - read_mark=0.0, - truncate=False, - show_actions=False, - ).strip() - except Exception: - # 回退到简单文本 - text = getattr(msg, "processed_plain_text", "") or "" - - # 获取发言人昵称 - participants: Set[str] = set() - try: - platform = msg.platform - user_id = msg.message_info.user_info.user_id - if platform and user_id: - person = Person(platform=platform, user_id=user_id) - if person.person_name: - participants.add(person.person_name) - except Exception: - pass - - # 带编号的字符串(用于 LLM 选择) - line = f"{idx}. {text}" - numbered_lines.append(line) - index_to_msg_str[idx] = line - # 不带编号的文本(用于最终存储) - index_to_msg_text[idx] = text - index_to_participants[idx] = participants - - return numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants - - async def _analyze_topics_with_llm( - self, - numbered_lines: List[str], - existing_topics: List[str], - ) -> tuple[bool, Dict[str, List[int]]]: - """ - 使用 LLM 识别本次检查中的话题,并为每个话题选择相关消息编号。 - - 要求: - - 话题用一句话清晰描述正在发生的事件,包括时间、人物、主要事件和主题; - - 可以有 1 个或多个话题; - - 若某个话题与历史话题列表中的某个话题是同一件事,请直接使用历史话题的字符串; - - 输出 JSON,格式: - [ - { - "topic": "话题标题字符串", - "message_indices": [1, 2, 5] - }, - ... - ] - """ - if not numbered_lines: - return False, {} - - history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)" - messages_block = "\n".join(numbered_lines) - - prompt_template = prompt_manager.get_prompt("hippo_topic_analysis") - prompt_template.add_context("history_topics_block", history_topics_block) - prompt_template.add_context("messages_block", messages_block) - prompt = await prompt_manager.render_prompt(prompt_template) - - try: - generation_result = await self.summarizer_llm.generate_response( - prompt=prompt, - options=LLMGenerationOptions(temperature=0.3), - ) - response = generation_result.response - - logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}") - logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}") - - # 尝试从响应中提取JSON代码块 - json_str = None - json_pattern = r"```json\s*(.*?)\s*```" - matches = re.findall(json_pattern, response, re.DOTALL) - - if matches: - # 找到JSON代码块,使用第一个匹配 - json_str = matches[0].strip() - else: - # 如果没有找到代码块,尝试查找JSON数组的开始和结束位置 - # 查找第一个 [ 和最后一个 ] - start_idx = response.find("[") - end_idx = response.rfind("]") - if start_idx != -1 and end_idx != -1 and end_idx > start_idx: - json_str = response[start_idx : end_idx + 1].strip() - else: - # 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记) - json_str = response.strip() - json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE) - json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE) - json_str = json_str.strip() - - # 使用json_repair修复可能的JSON错误 - if json_str: - try: - repaired_json = repair_json(json_str) - result = json.loads(repaired_json) if isinstance(repaired_json, str) else repaired_json - except Exception as repair_error: - # 如果repair失败,尝试直接解析 - logger.warning(f"{self.log_prefix} JSON修复失败,尝试直接解析: {repair_error}") - result = json.loads(json_str) - else: - raise ValueError("无法从响应中提取JSON内容") - - if not isinstance(result, list): - logger.error(f"{self.log_prefix} 话题识别返回的 JSON 不是列表: {result}") - return False, {} - - topic_to_indices: Dict[str, List[int]] = {} - for item in result: - if not isinstance(item, dict): - continue - topic = item.get("topic") - indices = item.get("message_indices") or item.get("messages") or [] - if not topic or not isinstance(topic, str): - continue - if isinstance(indices, list): - valid_indices: List[int] = [] - for v in indices: - try: - iv = int(v) - if iv > 0: - valid_indices.append(iv) - except (TypeError, ValueError): - continue - if valid_indices: - topic_to_indices[topic] = valid_indices - - return True, topic_to_indices - - except Exception as e: - logger.error(f"{self.log_prefix} 话题识别 LLM 调用或解析失败: {e}") - logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") - return False, {} - - async def _finalize_and_store_topic( - self, - topic: str, - item: TopicCacheItem, - start_time: float, - end_time: float, - ): - """ - 对某个话题进行最终打包存储: - 1. 将 messages(list[str]) 拼接为 original_text; - 2. 使用 LLM 对 original_text 进行总结,得到 summary 和 keywords,theme 直接使用话题字符串; - 3. 写入数据库 ChatHistory; - 4. 完成后,调用方会从缓存中删除该话题。 - """ - if not item.messages: - logger.info(f"{self.log_prefix} 话题[{topic}] 无消息内容,跳过打包") - return - - original_text = "\n".join(item.messages) - - logger.info( - f"{self.log_prefix} 开始将聊天记录构建成记忆:[{topic}] | 消息数: {len(item.messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}" - ) - - # 使用 LLM 进行总结(基于话题名),带重试机制 - max_retries = 3 - attempt = 0 - success = False - keywords = [] - summary = "" - - while attempt < max_retries: - attempt += 1 - success, keywords, summary = await self._compress_with_llm(original_text, topic) - - if success and keywords and summary: - # 成功获取到有效的 keywords 和 summary - if attempt > 1: - logger.info(f"{self.log_prefix} 话题[{topic}] LLM 概括在第 {attempt} 次重试后成功") - break - - if attempt < max_retries: - logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败(第 {attempt} 次尝试),准备重试") - else: - logger.error(f"{self.log_prefix} 话题[{topic}] LLM 概括连续 {max_retries} 次失败,放弃存储") - - if not success or not keywords or not summary: - logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败,不写入数据库") - return - - participants = list(item.participants) - - await self._store_to_database( - start_time=start_time, - end_time=end_time, - original_text=original_text, - participants=participants, - theme=topic, # 主题直接使用话题名 - keywords=keywords, - summary=summary, - ) - - logger.info( - f"{self.log_prefix} 话题[{topic}] 成功打包并存储 | 消息数: {len(item.messages)} | 参与者数: {len(participants)}" - ) - - async def _compress_with_llm(self, original_text: str, topic: str) -> tuple[bool, List[str], str]: - """ - 使用LLM压缩聊天内容(用于单个话题的最终总结) - - Args: - original_text: 聊天记录原文 - topic: 话题名称 - - Returns: - tuple[bool, List[str], str]: (是否成功, 关键词列表, 概括) - """ - prompt_template = prompt_manager.get_prompt("hippo_topic_summary") - prompt_template.add_context("topic", topic) - prompt_template.add_context("original_text", original_text) - prompt = await prompt_manager.render_prompt(prompt_template) - - try: - generation_result = await self.summarizer_llm.generate_response(prompt=prompt) - response = generation_result.response - - # 解析JSON响应 - json_str = response.strip() - json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE) - json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE) - json_str = json_str.strip() - - # 查找JSON对象的开始与结束 - start_idx = json_str.find("{") - if start_idx == -1: - raise ValueError("未找到JSON对象开始标记") - - end_idx = json_str.rfind("}") - if end_idx == -1 or end_idx <= start_idx: - logger.warning(f"{self.log_prefix} JSON缺少结束标记,尝试自动修复") - extracted_json = json_str[start_idx:] - else: - extracted_json = json_str[start_idx : end_idx + 1] - - def _parse_with_quote_fix(payload: str) -> Dict[str, Any]: - fixed_chars: List[str] = [] - in_string = False - escape_next = False - i = 0 - while i < len(payload): - char = payload[i] - if escape_next: - fixed_chars.append(char) - escape_next = False - elif char == "\\": - fixed_chars.append(char) - escape_next = True - elif char == '"' and not escape_next: - fixed_chars.append(char) - in_string = not in_string - elif in_string and char in {"“", "”"}: - # 在字符串值内部,将中文引号替换为转义的英文引号 - fixed_chars.append('\\"') - else: - fixed_chars.append(char) - i += 1 - - repaired = "".join(fixed_chars) - return json.loads(repaired) - - try: - result = json.loads(extracted_json) - except json.JSONDecodeError: - try: - repaired_json = repair_json(extracted_json) - if isinstance(repaired_json, str): - result = json.loads(repaired_json) - else: - result = repaired_json - except Exception as repair_error: - logger.warning(f"{self.log_prefix} repair_json 失败,使用引号修复: {repair_error}") - result = _parse_with_quote_fix(extracted_json) - - keywords = result.get("keywords", []) - summary = result.get("summary", "") - - # 检查必需字段是否为空 - if not keywords or not summary: - logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少必需字段,原文\n{response}") - # 返回失败,和模型出错一样,让上层进行重试 - return False, [], "" - - # 确保keywords是列表 - if isinstance(keywords, str): - keywords = [keywords] - - return True, keywords, summary - - except Exception as e: - logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}") - logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") - # 返回失败标志和默认值 - return False, [], "压缩失败,无法生成概括" - - async def _store_to_database( - self, - start_time: float, - end_time: float, - original_text: str, - participants: List[str], - theme: str, - keywords: List[str], - summary: str, - ): - """存储到数据库""" - try: - from src.common.database.database_model import ChatHistory - from src.services import database_service as database_api - - # 准备数据 - data = { - "session_id": self.session_id, - "start_timestamp": datetime.fromtimestamp(start_time), - "end_timestamp": datetime.fromtimestamp(end_time), - "original_messages": original_text, - "participants": json.dumps(participants, ensure_ascii=False), - "theme": theme, - "keywords": json.dumps(keywords, ensure_ascii=False), - "summary": summary, - "query_count": 0, - "query_forget_count": 0, - } - - saved_record = await database_api.db_save( - ChatHistory, - data=data, - ) - - if saved_record: - logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库") - else: - logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败") - - if saved_record and saved_record.get("id") is not None: - await self._import_to_long_term_memory( - record_id=int(saved_record["id"]), - theme=theme, - summary=summary, - participants=participants, - start_time=start_time, - end_time=end_time, - original_text=original_text, - ) - - except Exception as e: - logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}") - import traceback - - traceback.print_exc() - raise - - async def _import_to_long_term_memory( - self, - record_id: int, - theme: str, - summary: str, - participants: List[str], - start_time: float, - end_time: float, - original_text: str, - ): - """ - 将聊天历史总结导入到统一长期记忆 - - Args: - record_id: chat_history 主键 - theme: 话题主题 - summary: 概括内容 - participants: 参与者列表 - start_time: 开始时间 - end_time: 结束时间 - original_text: 原始文本(可能很长,需要截断) - """ - try: - from src.services.memory_service import memory_service - session = _chat_manager.get_session_by_session_id(self.session_id) - session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else "" - session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else "" - - content_parts = [] - if theme: - content_parts.append(f"主题:{theme}") - if summary: - content_parts.append(f"概括:{summary}") - if participants: - participants_text = "、".join(participants) - content_parts.append(f"参与者:{participants_text}") - content_to_import = "\n".join(content_parts) - - if not content_to_import.strip(): - logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,改用插件侧 generate_from_chat 兜底") - await self._fallback_import_to_long_term_memory( - record_id=record_id, - theme=theme, - participants=participants, - start_time=start_time, - end_time=end_time, - original_text=original_text, - ) - return - - result = await memory_service.ingest_summary( - external_id=f"chat_history:{record_id}", - chat_id=self.session_id, - text=content_to_import, - participants=participants, - time_start=start_time, - time_end=end_time, - tags=[theme] if theme else [], - metadata={"theme": theme, "original_text_length": len(original_text or "")}, - respect_filter=True, - user_id=session_user_id, - group_id=session_group_id, - ) - if result.success: - if result.detail == "chat_filtered": - logger.debug(f"{self.log_prefix} 聊天历史总结被聊天过滤策略跳过 | 话题: {theme}") - else: - logger.info(f"{self.log_prefix} 成功将聊天历史总结导入到长期记忆 | 话题: {theme}") - else: - logger.warning(f"{self.log_prefix} 将聊天历史总结导入到长期记忆失败,尝试插件侧兜底 | 话题: {theme} | 错误: {result.detail}") - await self._fallback_import_to_long_term_memory( - record_id=record_id, - theme=theme, - participants=participants, - start_time=start_time, - end_time=end_time, - original_text=original_text, - ) - - except Exception as e: - logger.error(f"{self.log_prefix} 导入聊天历史总结到长期记忆时出错: {e}", exc_info=True) - - async def _fallback_import_to_long_term_memory( - self, - *, - record_id: int, - theme: str, - participants: List[str], - start_time: float, - end_time: float, - original_text: str, - ) -> None: - try: - from src.services.memory_service import memory_service - session = _chat_manager.get_session_by_session_id(self.session_id) - session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else "" - session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else "" - - result = await memory_service.ingest_summary( - external_id=f"chat_history:{record_id}", - chat_id=self.session_id, - text="", - participants=participants, - time_start=start_time, - time_end=end_time, - tags=[theme] if theme else [], - metadata={ - "theme": theme, - "original_text_length": len(original_text or ""), - "generate_from_chat": True, - "context_length": global_config.memory.chat_history_topic_check_message_threshold, - }, - respect_filter=True, - user_id=session_user_id, - group_id=session_group_id, - ) - if result.success: - if result.detail == "chat_filtered": - logger.debug(f"{self.log_prefix} 插件侧 generate_from_chat 兜底被聊天过滤策略跳过 | 话题: {theme}") - else: - logger.info(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入成功 | 话题: {theme}") - else: - logger.warning(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入失败 | 话题: {theme} | 错误: {result.detail}") - except Exception as exc: - logger.error(f"{self.log_prefix} 插件侧兜底导入长期记忆失败: {exc}", exc_info=True) - - async def start(self): - """启动后台定期检查循环""" - if self._running: - logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动") - return - - # 加载聊天批次(如果有) - await self._load_batch_from_disk() - - self._running = True - self._periodic_task = asyncio.create_task(self._periodic_check_loop()) - logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}秒") - - async def stop(self): - """停止后台定期检查循环""" - self._running = False - if self._periodic_task: - self._periodic_task.cancel() - try: - await self._periodic_task - except asyncio.CancelledError: - pass - self._periodic_task = None - logger.info(f"{self.log_prefix} 已停止后台定期检查循环") - - async def _periodic_check_loop(self): - """后台定期检查循环""" - try: - while self._running: - # 执行一次检查 - await self.process() - - # 等待指定间隔后再次检查 - await asyncio.sleep(self.check_interval) - except asyncio.CancelledError: - logger.info(f"{self.log_prefix} 后台检查循环被取消") - raise - except Exception as e: - logger.error(f"{self.log_prefix} 后台检查循环出错: {e}") - import traceback - - traceback.print_exc() - self._running = False diff --git a/src/services/memory_flow_service.py b/src/services/memory_flow_service.py index 75ff0ca9..8b7d8aa4 100644 --- a/src/services/memory_flow_service.py +++ b/src/services/memory_flow_service.py @@ -2,54 +2,20 @@ from __future__ import annotations import asyncio import json -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from json_repair import repair_json from src.chat.utils.utils import is_bot_self -from src.common.message_repository import find_messages from src.common.logger import get_logger +from src.common.message_repository import find_messages from src.config.config import global_config -from src.memory_system.chat_history_summarizer import ChatHistorySummarizer from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer from src.services.llm_service import LLMServiceClient logger = get_logger("memory_flow_service") -class LongTermMemorySessionManager: - def __init__(self) -> None: - self._lock = asyncio.Lock() - self._summarizers: Dict[str, ChatHistorySummarizer] = {} - - async def on_message(self, message: Any) -> None: - if not bool(getattr(global_config.memory, "long_term_auto_summary_enabled", True)): - return - session_id = str(getattr(message, "session_id", "") or "").strip() - if not session_id: - return - - created = False - async with self._lock: - summarizer = self._summarizers.get(session_id) - if summarizer is None: - summarizer = ChatHistorySummarizer(session_id=session_id) - self._summarizers[session_id] = summarizer - created = True - if created: - await summarizer.start() - - async def shutdown(self) -> None: - async with self._lock: - items = list(self._summarizers.items()) - self._summarizers.clear() - for session_id, summarizer in items: - try: - await summarizer.stop() - except Exception as exc: - logger.warning("停止聊天总结器失败: session=%s err=%s", session_id, exc) - - class PersonFactWritebackService: def __init__(self) -> None: self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256) @@ -123,7 +89,11 @@ class PersonFactWritebackService: if not session_id: return - person_name = str(getattr(target_person, "person_name", "") or getattr(target_person, "nickname", "") or "").strip() + person_name = str( + getattr(target_person, "person_name", "") + or getattr(target_person, "nickname", "") + or "" + ).strip() if not person_name: return @@ -242,7 +212,6 @@ class PersonFactWritebackService: class MemoryAutomationService: def __init__(self) -> None: - self.session_manager = LongTermMemorySessionManager() self.fact_writeback = PersonFactWritebackService() self._started = False @@ -255,14 +224,13 @@ class MemoryAutomationService: async def shutdown(self) -> None: if not self._started: return - await self.session_manager.shutdown() await self.fact_writeback.shutdown() self._started = False async def on_incoming_message(self, message: Any) -> None: + del message if not self._started: await self.start() - await self.session_manager.on_message(message) async def on_message_sent(self, message: Any) -> None: if not self._started: