diff --git a/prompts/zh-CN/memory_get_knowledge.prompt b/prompts/zh-CN/memory_get_knowledge.prompt
new file mode 100644
index 00000000..aa9e8967
--- /dev/null
+++ b/prompts/zh-CN/memory_get_knowledge.prompt
@@ -0,0 +1,26 @@
+你是一个专门获取长期记忆的助手。你的名字是{bot_name}。现在是{time_now}。
+群里正在进行的聊天内容:
+{chat_history}
+
+现在,{sender}发送了内容:{target_message},你想要回复ta。
+请仔细分析聊天内容,考虑以下几点:
+1. 内容中是否包含需要查询历史知识或长期记忆的问题
+2. 是否有明确的知识获取指令
+
+如果需要使用长期记忆工具,请直接调用函数 `search_long_term_memory`;如果不需要任何工具,直接输出 `No tool needed`。
+
+工具模式说明:
+- `mode="search"`:普通长期记忆检索,适合查具体事实、偏好、历史对话内容
+- `mode="time"`:按时间范围检索,必须同时提供 `time_expression`
+- `mode="episode"`:按事件/情节检索,适合查“那次经历”“那件事的经过”
+- `mode="aggregate"`:综合检索,适合“整体回忆一下”“把相关线索综合找出来”
+
+优先规则:
+- 问“某段时间发生了什么”:优先 `time`
+- 问“某次事件/某段经历”:优先 `episode`
+- 问“整体情况/最近发生过什么”:优先 `aggregate`
+- 问单点事实:优先 `search`
+
+`time_expression` 可用表达:
+- `今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
+- 或绝对时间:`2026/03/18`、`2026/03/18 09:30`
diff --git a/prompts/zh-CN/memory_retrieval_react_prompt_head_memory.prompt b/prompts/zh-CN/memory_retrieval_react_prompt_head_memory.prompt
new file mode 100644
index 00000000..91ea6eab
--- /dev/null
+++ b/prompts/zh-CN/memory_retrieval_react_prompt_head_memory.prompt
@@ -0,0 +1,34 @@
+你的名字是{bot_name}。现在是{time_now}。
+你正在参与聊天,你需要搜集信息来帮助你进行回复。
+重要,这是当前聊天记录:
+{chat_history}
+聊天记录结束
+
+已收集的信息:
+{collected_info}
+
+- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
+- 思考完毕后,使用工具
+
+**工具说明:**
+- 如果涉及过往事件、历史对话、用户长期偏好或某段时间发生的事件,可以使用长期记忆查询工具
+- 如果遇到不熟悉的词语、缩写、黑话或网络用语,可以使用query_words工具查询其含义
+- 你必须使用tool,如果需要查询你必须给出使用什么工具进行查询
+- 当你决定结束查询时,必须调用return_information工具返回总结信息并结束查询
+
+长期记忆工具 `search_long_term_memory` 支持以下模式:
+- `mode="search"`:普通事实/偏好/历史内容检索。适合问“她喜欢什么”“我们之前讨论过什么”。
+- `mode="time"`:按时间范围检索。适合问“昨天发生了什么”“最近7天有哪些相关记忆”。
+- `mode="episode"`:按事件/情节检索。适合问“那次灯塔停电的经过是什么”“关于某次经历还有什么”。
+- `mode="aggregate"`:综合检索。适合问“帮我整体回忆一下这个人最近的情况”“把相关线索综合找出来”。
+
+模式选择建议:
+- 问单点事实、偏好、人设、具体信息:优先 `search`
+- 问某段时间发生了什么:优先 `time`
+- 问某次事件、某段经历、某个剧情片段:优先 `episode`
+- 问整体回忆、综合找线索、总结最近发生的事:优先 `aggregate`
+
+时间模式要求:
+- 使用 `mode="time"` 时,必须填写 `time_expression`
+- 可用时间表达包括:`今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
+- 也可以使用绝对时间:`2026/03/18`、`2026/03/18 09:30`
diff --git a/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py b/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py
new file mode 100644
index 00000000..0f084ece
--- /dev/null
+++ b/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py
@@ -0,0 +1,148 @@
+from types import SimpleNamespace
+
+import pytest
+
+from src.memory_system import chat_history_summarizer as summarizer_module
+
+
+def _build_summarizer() -> summarizer_module.ChatHistorySummarizer:
+    summarizer = summarizer_module.ChatHistorySummarizer.__new__(summarizer_module.ChatHistorySummarizer)
+    summarizer.session_id = "session-1"
+    summarizer.log_prefix = "[session-1]"
+    return summarizer
+
+
+@pytest.mark.asyncio
+async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch):
+    calls = []
+    summarizer = _build_summarizer()
+
+    async def fake_ingest_summary(**kwargs):
+        calls.append(kwargs)
+        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
+
+    monkeypatch.setattr(
+        summarizer_module,
+        "_chat_manager",
+        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
+    )
+    monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8)))
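+    # Patch the service-level ingest hook so the test can capture the exact payload.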
monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) + + await summarizer._import_to_long_term_memory( + record_id=1, + theme="旅行计划", + summary="我们讨论了春游安排", + participants=["Alice", "Bob"], + start_time=1.0, + end_time=2.0, + original_text="long text", + ) + + assert len(calls) == 1 + payload = calls[0] + assert payload["external_id"] == "chat_history:1" + assert payload["chat_id"] == "session-1" + assert payload["participants"] == ["Alice", "Bob"] + assert payload["respect_filter"] is True + assert payload["user_id"] == "user-1" + assert payload["group_id"] == "" + assert "主题:旅行计划" in payload["text"] + assert "概括:我们讨论了春游安排" in payload["text"] + + +@pytest.mark.asyncio +async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch): + summarizer = _build_summarizer() + fallback_calls = [] + + async def fake_fallback(**kwargs): + fallback_calls.append(kwargs) + + summarizer._fallback_import_to_long_term_memory = fake_fallback + monkeypatch.setattr( + summarizer_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")), + ) + + await summarizer._import_to_long_term_memory( + record_id=2, + theme="", + summary="", + participants=[], + start_time=10.0, + end_time=20.0, + original_text="raw chat", + ) + + assert len(fallback_calls) == 1 + assert fallback_calls[0]["record_id"] == 2 + assert fallback_calls[0]["original_text"] == "raw chat" + + +@pytest.mark.asyncio +async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch): + summarizer = _build_summarizer() + fallback_calls = [] + + async def fake_ingest_summary(**kwargs): + return SimpleNamespace(success=False, detail="boom", stored_ids=[]) + + async def fake_fallback(**kwargs): + fallback_calls.append(kwargs) + + summarizer._fallback_import_to_long_term_memory = fake_fallback + monkeypatch.setattr( + summarizer_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")), + ) + monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) + + await summarizer._import_to_long_term_memory( + record_id=3, + theme="电影", + summary="聊了电影推荐", + participants=["Alice"], + start_time=3.0, + end_time=4.0, + original_text="raw", + ) + + assert len(fallback_calls) == 1 + assert fallback_calls[0]["theme"] == "电影" + + +@pytest.mark.asyncio +async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch): + calls = [] + summarizer = _build_summarizer() + + async def fake_ingest_summary(**kwargs): + calls.append(kwargs) + return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[]) + + monkeypatch.setattr( + summarizer_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")), + ) + monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12))) + monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) + + await summarizer._fallback_import_to_long_term_memory( + record_id=4, + theme="工作", + participants=["Alice"], + start_time=5.0, + end_time=6.0, + original_text="a" * 128, + ) + + assert len(calls) == 1 + metadata = calls[0]["metadata"] + assert metadata["generate_from_chat"] is True + assert 
metadata["context_length"] == 12 + assert calls[0]["respect_filter"] is True + diff --git a/pytests/A_memorix_test/test_knowledge_fetcher.py b/pytests/A_memorix_test/test_knowledge_fetcher.py new file mode 100644 index 00000000..4fb4e564 --- /dev/null +++ b/pytests/A_memorix_test/test_knowledge_fetcher.py @@ -0,0 +1,127 @@ +from types import SimpleNamespace + +import pytest + +from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module +from src.services.memory_service import MemoryHit, MemorySearchResult + + +def test_knowledge_fetcher_resolves_private_memory_context(monkeypatch): + monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object()) + monkeypatch.setattr( + knowledge_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")), + ) + monkeypatch.setattr( + knowledge_module, + "resolve_person_id_for_memory", + lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}", + ) + + fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1") + + assert fetcher._resolve_private_memory_context() == { + "chat_id": "stream-1", + "person_id": "Alice:qq:42", + "user_id": "42", + "group_id": "", + } + + +@pytest.mark.asyncio +async def test_knowledge_fetcher_memory_get_knowledge_uses_memory_service(monkeypatch): + monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object()) + monkeypatch.setattr( + knowledge_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")), + ) + monkeypatch.setattr( + knowledge_module, + "resolve_person_id_for_memory", + lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}", + ) + + calls = [] + + async def fake_search(query: str, **kwargs): + calls.append((query, kwargs)) + return MemorySearchResult(summary="", hits=[MemoryHit(content="她喜欢猫", source="person_fact:qq:42")], filtered=False) + + monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search) + + fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1") + result = await fetcher._memory_get_knowledge("她喜欢什么") + + assert "1. 
她喜欢猫" in result + assert calls == [ + ( + "她喜欢什么", + { + "limit": 5, + "mode": "search", + "chat_id": "stream-1", + "person_id": "Alice:qq:42", + "user_id": "42", + "group_id": "", + "respect_filter": True, + }, + ) + ] + + +@pytest.mark.asyncio +async def test_knowledge_fetcher_falls_back_to_chat_scope_when_person_scope_misses(monkeypatch): + monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object()) + monkeypatch.setattr( + knowledge_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")), + ) + monkeypatch.setattr( + knowledge_module, + "resolve_person_id_for_memory", + lambda *, person_name, platform, user_id: "person-1", + ) + + calls = [] + + async def fake_search(query: str, **kwargs): + calls.append((query, kwargs)) + if kwargs.get("person_id"): + return MemorySearchResult(summary="", hits=[], filtered=False) + return MemorySearchResult(summary="", hits=[MemoryHit(content="她计划去杭州音乐节", source="chat_summary:stream-1")], filtered=False) + + monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search) + + fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1") + result = await fetcher._memory_get_knowledge("Alice 最近在忙什么") + + assert "杭州音乐节" in result + assert calls == [ + ( + "Alice 最近在忙什么", + { + "limit": 5, + "mode": "search", + "chat_id": "stream-1", + "person_id": "person-1", + "user_id": "42", + "group_id": "", + "respect_filter": True, + }, + ), + ( + "Alice 最近在忙什么", + { + "limit": 5, + "mode": "search", + "chat_id": "stream-1", + "person_id": "", + "user_id": "42", + "group_id": "", + "respect_filter": True, + }, + ), + ] diff --git a/pytests/A_memorix_test/test_legacy_config_migration.py b/pytests/A_memorix_test/test_legacy_config_migration.py new file mode 100644 index 00000000..c382e4f3 --- /dev/null +++ b/pytests/A_memorix_test/test_legacy_config_migration.py @@ -0,0 +1,35 @@ +from src.config.legacy_migration import try_migrate_legacy_bot_config_dict + + +def test_legacy_learning_list_with_numeric_fourth_column_is_migrated(): + payload = { + "expression": { + "learning_list": [ + ["qq:123456:group", "enable", "disable", "0.5"], + ["", "disable", "enable", "0.1"], + ] + } + } + + result = try_migrate_legacy_bot_config_dict(payload) + + assert result.migrated is True + assert "expression.learning_list" in result.reason + assert result.data["expression"]["learning_list"] == [ + { + "platform": "qq", + "item_id": "123456", + "rule_type": "group", + "use_expression": True, + "enable_learning": False, + "enable_jargon_learning": False, + }, + { + "platform": "", + "item_id": "", + "rule_type": "group", + "use_expression": False, + "enable_learning": True, + "enable_jargon_learning": False, + }, + ] diff --git a/pytests/A_memorix_test/test_long_novel_memory_benchmark.py b/pytests/A_memorix_test/test_long_novel_memory_benchmark.py new file mode 100644 index 00000000..3c3e4090 --- /dev/null +++ b/pytests/A_memorix_test/test_long_novel_memory_benchmark.py @@ -0,0 +1,691 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List + +import numpy as np +import pytest +import pytest_asyncio + +from A_memorix.core.runtime import sdk_memory_kernel as kernel_module +from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel +from src.chat.brain_chat.PFC import 
+from __future__ import annotations
+
+import asyncio
+import inspect
+import json
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any, Dict, List
+
+import numpy as np
+import pytest
+import pytest_asyncio
+
+from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
+from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
+from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
+from src.memory_system import chat_history_summarizer as summarizer_module
+from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory
+from src.person_info import person_info as person_info_module
+from src.services import memory_service as memory_service_module
+from src.services.memory_service import MemorySearchResult, memory_service
+
+
+DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json"
+REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json"
+
+
+def _load_benchmark_fixture() -> Dict[str, Any]:
+    return json.loads(DATA_FILE.read_text(encoding="utf-8"))
+
+
+class _FakeEmbeddingAdapter:
+    def __init__(self, dimension: int = 32) -> None:
+        self.dimension = dimension
+
+    async def _detect_dimension(self) -> int:
+        return self.dimension
+
+    async def encode(self, texts, dimensions=None):
+        dim = int(dimensions or self.dimension)
+        if isinstance(texts, str):
+            sequence = [texts]
+            single = True
+        else:
+            sequence = list(texts)
+            single = False
+
+        rows = []
+        for text in sequence:
+            vec = np.zeros(dim, dtype=np.float32)
+            for ch in str(text or ""):
+                code = ord(ch)
+                vec[code % dim] += 1.0
+                vec[(code * 7) % dim] += 0.5
+            if not vec.any():
+                vec[0] = 1.0
+            norm = np.linalg.norm(vec)
+            if norm > 0:
+                vec = vec / norm
+            rows.append(vec)
+        payload = np.vstack(rows)
+        return payload[0] if single else payload
+
+
+class _KnownPerson:
+    def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None:
+        self.person_id = person_id
+        self.is_known = person_id in reverse_registry
+        self.person_name = reverse_registry.get(person_id, "")
+        self._registry = registry
+
+
+class _KernelBackedRuntimeManager:
+    is_running = True
+
+    def __init__(self, kernel: SDKMemoryKernel) -> None:
+        self.kernel = kernel
+
+    async def invoke_plugin(
+        self,
+        *,
+        method: str,
+        plugin_id: str,
+        component_name: str,
+        args: Dict[str, Any] | None,
+        timeout_ms: int,
+    ):
+        del method, plugin_id, timeout_ms
+        payload = args or {}
+        if component_name == "search_memory":
+            return await self.kernel.search_memory(
+                KernelSearchRequest(
+                    query=str(payload.get("query", "") or ""),
+                    limit=int(payload.get("limit", 5) or 5),
+                    mode=str(payload.get("mode", "hybrid") or "hybrid"),
+                    chat_id=str(payload.get("chat_id", "") or ""),
+                    person_id=str(payload.get("person_id", "") or ""),
+                    time_start=payload.get("time_start"),
+                    time_end=payload.get("time_end"),
+                    respect_filter=bool(payload.get("respect_filter", True)),
+                    user_id=str(payload.get("user_id", "") or ""),
+                    group_id=str(payload.get("group_id", "") or ""),
+                )
+            )
+
+        handler = getattr(self.kernel, component_name)
+        result = handler(**payload)
+        return await result if inspect.isawaitable(result) else result
+
+
+async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]:
+    for _ in range(max_rounds):
+        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
+        task = detail.get("task") or {}
+        status = str(task.get("status", "") or "")
+        if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
+            return detail
+        await asyncio.sleep(max(0.01, float(sleep_seconds)))
+    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
+
+
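+# Keyword-overlap helpers used to score retrieval quality in both benchmark variants.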
+def _join_hit_content(search_result: MemorySearchResult) -> str:
+    return "\n".join(hit.content for hit in search_result.hits)
+
+
+def _keyword_hits(text: str, keywords: List[str]) -> int:
+    haystack = str(text or "")
+    return sum(1 for keyword in keywords if keyword in haystack)
+
+
+def _keyword_recall(text: str, keywords: List[str]) -> float:
+    if not keywords:
+        return 1.0
+    return _keyword_hits(text, keywords) / float(len(keywords))
+
+
+def _hit_blob(hit) -> str:
+    meta = hit.metadata if isinstance(hit.metadata, dict) else {}
+    return "\n".join(
+        [
+            str(hit.content or ""),
+            str(hit.title or ""),
+            str(hit.source or ""),
+            json.dumps(meta, ensure_ascii=False),
+        ]
+    )
+
+
+def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int:
+    for index, hit in enumerate(search_result.hits[:5], start=1):
+        if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))):
+            return index
+    return 0
+
+
+def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str:
+    return "\n".join(
+        (
+            f"{item.get('title', '')}\n"
+            f"{item.get('summary', '')}\n"
+            f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n"
+            f"{json.dumps(item.get('participants', []), ensure_ascii=False)}"
+        )
+        for item in items
+    )
+
+
+def _episode_blob_from_hits(search_result: MemorySearchResult) -> str:
+    chunks = []
+    for hit in search_result.hits:
+        meta = hit.metadata if isinstance(hit.metadata, dict) else {}
+        chunks.append(
+            "\n".join(
+                [
+                    str(hit.title or ""),
+                    str(hit.content or ""),
+                    json.dumps(meta.get("keywords", []) or [], ensure_ascii=False),
+                    json.dumps(meta.get("participants", []) or [], ensure_ascii=False),
+                ]
+            )
+        )
+    return "\n".join(chunks)
+
+
+async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
+    episode_source = f"chat_summary:{session_id}"
+    payload = await memory_service.episode_admin(
+        action="query",
+        source=episode_source,
+        limit=20,
+    )
+    items = payload.get("items") or []
+    blob = _episode_blob_from_items(items)
+    reports: List[Dict[str, Any]] = []
+    success_rate = 0.0
+    keyword_recall = 0.0
+
+    for case in episode_cases:
+        recall = _keyword_recall(blob, case["expected_keywords"])
+        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
+        success_rate += 1.0 if success else 0.0
+        keyword_recall += recall
+        reports.append(
+            {
+                "query": case["query"],
+                "success": success,
+                "keyword_recall": recall,
+                "episode_count": len(items),
+                "top_episode": items[0] if items else None,
+            }
+        )
+
+    total = max(1, len(episode_cases))
+    return {
+        "success_rate": round(success_rate / total, 4),
+        "keyword_recall": round(keyword_recall / total, 4),
+        "episode_count": len(items),
+        "reports": reports,
+    }
+
+
+async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
+    reports: List[Dict[str, Any]] = []
+    success_rate = 0.0
+    keyword_recall = 0.0
+    episode_source = f"chat_summary:{session_id}"
+
+    for case in episode_cases:
+        payload = await memory_service.episode_admin(
+            action="query",
+            source=episode_source,
+            query=case["query"],
+            limit=5,
+        )
+        items = payload.get("items") or []
+        blob = "\n".join(
+            f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}"
+            for item in items
+        )
+        recall = _keyword_recall(blob, case["expected_keywords"])
+        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
+        success_rate += 1.0 if success else 0.0
+        keyword_recall += recall
+        reports.append(
+            {
+                "query": case["query"],
+                "success": success,
+                "keyword_recall": recall,
+                "episode_count": len(items),
+                "top_episode": items[0] if items else None,
+            }
+        )
+
+    total = max(1, len(episode_cases))
+    return {
+        "success_rate": round(success_rate / total, 4),
+        "keyword_recall": round(keyword_recall / total, 4),
+        "reports": reports,
+    }
+
+
+async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
+    reports: List[Dict[str, Any]] = []
+    success_rate = 0.0
+    keyword_recall = 0.0
+
+    for case in episode_cases:
+        result = await memory_service.search(
+            case["query"],
+            mode="episode",
+            chat_id=session_id,
+            respect_filter=False,
+            limit=5,
+        )
+        blob = _episode_blob_from_hits(result)
+        recall = _keyword_recall(blob, case["expected_keywords"])
+        success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0))
+        success_rate += 1.0 if success else 0.0
+        keyword_recall += recall
+        reports.append(
+            {
+                "query": case["query"],
+                "success": success,
+                "keyword_recall": recall,
+                "episode_count": len(result.hits),
+                "top_episode": result.hits[0].to_dict() if result.hits else None,
+            }
+        )
+
+    total = max(1, len(episode_cases))
+    return {
+        "success_rate": round(success_rate / total, 4),
+        "keyword_recall": round(keyword_recall / total, 4),
+        "reports": reports,
+    }
+
+
+async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]:
+    search_case = dataset["search_cases"][0]
+    episode_case = dataset["episode_cases"][0]
+    aggregate_case = dataset["knowledge_fetcher_cases"][0]
+    tool_cases = [
+        {
+            "name": "search",
+            "kwargs": {
+                "query": "蓝漆铁盒 北塔木梯",
+                "mode": "search",
+                "chat_id": session_id,
+                "limit": 5,
+            },
+            "expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"],
+            "minimum_keyword_recall": 0.67,
+        },
+        {
+            "name": "time",
+            "kwargs": {
+                "query": "蓝漆铁盒 北塔",
+                "mode": "time",
+                "chat_id": session_id,
+                "limit": 5,
+                "time_expression": "最近7天",
+            },
+            "expected_keywords": ["蓝漆铁盒", "北塔木梯"],
+            "minimum_keyword_recall": 0.67,
+        },
+        {
+            "name": "episode",
+            "kwargs": {
+                "query": episode_case["query"],
+                "mode": "episode",
+                "chat_id": session_id,
+                "limit": 5,
+            },
+            "expected_keywords": episode_case["expected_keywords"],
+            "minimum_keyword_recall": 0.67,
+        },
+        {
+            "name": "aggregate",
+            "kwargs": {
+                "query": aggregate_case["query"],
+                "mode": "aggregate",
+                "chat_id": session_id,
+                "limit": 5,
+            },
+            "expected_keywords": aggregate_case["expected_keywords"],
+            "minimum_keyword_recall": 0.67,
+        },
+    ]
+    reports: List[Dict[str, Any]] = []
+    success_rate = 0.0
+    keyword_recall = 0.0
+
+    for case in tool_cases:
+        text = await query_long_term_memory(**case["kwargs"])
+        recall = _keyword_recall(text, case["expected_keywords"])
+        success = (
+            "失败" not in text
+            and "无法解析" not in text
+            and "未找到" not in text
+            and recall >= float(case["minimum_keyword_recall"])
+        )
+        success_rate += 1.0 if success else 0.0
+        keyword_recall += recall
+        reports.append(
+            {
+                "name": case["name"],
+                "success": success,
+                "keyword_recall": recall,
+                "preview": text[:320],
+            }
+        )
+
+    total = max(1, len(tool_cases))
+    return {
+        "success_rate": round(success_rate / total, 4),
+        "keyword_recall": round(keyword_recall / total, 4),
+        "reports": reports,
+    }
+
+
+@pytest_asyncio.fixture
+async def benchmark_env(monkeypatch, tmp_path):
+    dataset = _load_benchmark_fixture()
+    session_cfg = dataset["session"]
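+    # Fake session and chat manager so summarizer, fetcher and person lookups resolve offline.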
platform=session_cfg["platform"], + user_id=session_cfg["user_id"], + group_id=session_cfg["group_id"], + ) + fake_chat_manager = SimpleNamespace( + get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, + get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, + ) + + registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]} + reverse_registry = {value: key for key, value in registry.items()} + + monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) + + async def fake_self_check(**kwargs): + return {"ok": True, "message": "ok", "encoded_dimension": 32} + + monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None) + monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), "")) + monkeypatch.setattr( + person_info_module, + "Person", + lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry), + ) + + data_dir = (tmp_path / "a_memorix_benchmark_data").resolve() + kernel = SDKMemoryKernel( + plugin_root=tmp_path / "plugin_root", + config={ + "storage": {"data_dir": str(data_dir)}, + "advanced": {"enable_auto_save": False}, + "memory": {"base_decay_interval_hours": 24}, + "person_profile": {"refresh_interval_minutes": 5}, + }, + ) + manager = _KernelBackedRuntimeManager(kernel) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager) + + await kernel.initialize() + try: + yield { + "dataset": dataset, + "kernel": kernel, + "session": session, + "person_registry": registry, + } + finally: + await kernel.shutdown() + + +@pytest.mark.asyncio +async def test_long_novel_memory_benchmark(benchmark_env): + dataset = benchmark_env["dataset"] + session_id = benchmark_env["session"].session_id + + created = await memory_service.import_admin( + action="create_paste", + name="long_novel_memory_benchmark.json", + input_mode="json", + llm_enabled=False, + content=json.dumps(dataset["import_payload"], ensure_ascii=False), + ) + assert created["success"] is True + + import_detail = await _wait_for_import_task(created["task"]["task_id"]) + assert import_detail["task"]["status"] == "completed" + + for record in dataset["chat_history_records"]: + summarizer = summarizer_module.ChatHistorySummarizer(session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + for payload in dataset["person_writebacks"]: + await person_info_module.store_person_memory_from_answer( + payload["person_name"], + payload["memory_content"], + session_id, + ) + + await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2) + + search_case_reports: List[Dict[str, Any]] = [] + search_accuracy_at_1 = 0.0 + search_recall_at_5 = 0.0 + search_precision_at_5 = 0.0 + search_mrr = 0.0 + 
+    search_case_reports: List[Dict[str, Any]] = []
+    search_accuracy_at_1 = 0.0
+    search_recall_at_5 = 0.0
+    search_precision_at_5 = 0.0
+    search_mrr = 0.0
+    search_keyword_recall = 0.0
+
+    for case in dataset["search_cases"]:
+        result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
+        joined = _join_hit_content(result)
+        rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
+        relevant_hits = sum(
+            1
+            for hit in result.hits[:5]
+            if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
+        )
+        keyword_recall = _keyword_recall(joined, case["expected_keywords"])
+        search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
+        search_recall_at_5 += 1.0 if rank > 0 else 0.0
+        search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
+        search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
+        search_keyword_recall += keyword_recall
+        search_case_reports.append(
+            {
+                "query": case["query"],
+                "rank_of_first_relevant": rank,
+                "relevant_hits_top5": relevant_hits,
+                "keyword_recall_top5": keyword_recall,
+                "top_hit": result.hits[0].to_dict() if result.hits else None,
+            }
+        )
+
+    search_total = max(1, len(dataset["search_cases"]))
+
+    writeback_reports: List[Dict[str, Any]] = []
+    writeback_success_rate = 0.0
+    writeback_keyword_recall = 0.0
+    for payload in dataset["person_writebacks"]:
+        query = " ".join(payload["expected_keywords"])
+        result = await memory_service.search(
+            query,
+            mode="search",
+            chat_id=session_id,
+            person_id=payload["person_id"],
+            respect_filter=False,
+            limit=5,
+        )
+        joined = _join_hit_content(result)
+        recall = _keyword_recall(joined, payload["expected_keywords"])
+        success = bool(result.hits) and recall >= 0.67
+        writeback_success_rate += 1.0 if success else 0.0
+        writeback_keyword_recall += recall
+        writeback_reports.append(
+            {
+                "person_id": payload["person_id"],
+                "success": success,
+                "keyword_recall": recall,
+                "hit_count": len(result.hits),
+            }
+        )
+    writeback_total = max(1, len(dataset["person_writebacks"]))
+
+    knowledge_reports: List[Dict[str, Any]] = []
+    knowledge_success_rate = 0.0
+    knowledge_keyword_recall = 0.0
+    fetcher = knowledge_module.KnowledgeFetcher(
+        private_name=dataset["session"]["display_name"],
+        stream_id=session_id,
+    )
+    for case in dataset["knowledge_fetcher_cases"]:
+        knowledge_text, _ = await fetcher.fetch(case["query"], [])
+        recall = _keyword_recall(knowledge_text, case["expected_keywords"])
+        success = recall >= float(case.get("minimum_keyword_recall", 1.0))
+        knowledge_success_rate += 1.0 if success else 0.0
+        knowledge_keyword_recall += recall
+        knowledge_reports.append(
+            {
+                "query": case["query"],
+                "success": success,
+                "keyword_recall": recall,
+                "preview": knowledge_text[:300],
+            }
+        )
+    knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))
+
+    profile_reports: List[Dict[str, Any]] = []
+    profile_success_rate = 0.0
+    profile_keyword_recall = 0.0
+    profile_evidence_rate = 0.0
+    for case in dataset["profile_cases"]:
+        profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
+        recall = _keyword_recall(profile.summary, case["expected_keywords"])
+        has_evidence = bool(profile.evidence)
+        success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
+        profile_success_rate += 1.0 if success else 0.0
+        profile_keyword_recall += recall
+        profile_evidence_rate += 1.0 if has_evidence else 0.0
"evidence_count": len(profile.evidence), + "summary_preview": profile.summary[:240], + } + ) + profile_total = max(1, len(dataset["profile_cases"])) + + episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_rebuild = await memory_service.episode_admin( + action="rebuild", + source=f"chat_summary:{session_id}", + ) + episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) + tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset) + + report = { + "dataset": dataset["meta"], + "import": { + "task_id": created["task"]["task_id"], + "status": import_detail["task"]["status"], + "paragraph_count": len(dataset["import_payload"]["paragraphs"]), + }, + "metrics": { + "search": { + "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4), + "recall_at_5": round(search_recall_at_5 / search_total, 4), + "precision_at_5": round(search_precision_at_5 / search_total, 4), + "mrr": round(search_mrr / search_total, 4), + "keyword_recall_at_5": round(search_keyword_recall / search_total, 4), + }, + "writeback": { + "success_rate": round(writeback_success_rate / writeback_total, 4), + "keyword_recall": round(writeback_keyword_recall / writeback_total, 4), + }, + "knowledge_fetcher": { + "success_rate": round(knowledge_success_rate / knowledge_total, 4), + "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4), + }, + "profile": { + "success_rate": round(profile_success_rate / profile_total, 4), + "keyword_recall": round(profile_keyword_recall / profile_total, 4), + "evidence_rate": round(profile_evidence_rate / profile_total, 4), + }, + "tool_modes": { + "success_rate": tool_modes["success_rate"], + "keyword_recall": tool_modes["keyword_recall"], + }, + "episode_generation_auto": { + "success_rate": episode_generation_auto["success_rate"], + "keyword_recall": episode_generation_auto["keyword_recall"], + "episode_count": episode_generation_auto["episode_count"], + }, + "episode_generation_after_rebuild": { + "success_rate": episode_generation_after_rebuild["success_rate"], + "keyword_recall": episode_generation_after_rebuild["keyword_recall"], + "episode_count": episode_generation_after_rebuild["episode_count"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + "episode_admin_query_auto": { + "success_rate": episode_admin_query_auto["success_rate"], + "keyword_recall": episode_admin_query_auto["keyword_recall"], + }, + "episode_admin_query_after_rebuild": { + "success_rate": episode_admin_query_after_rebuild["success_rate"], + "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + "episode_search_mode_auto": { + "success_rate": episode_search_mode_auto["success_rate"], + "keyword_recall": episode_search_mode_auto["keyword_recall"], + }, + "episode_search_mode_after_rebuild": { + 
"success_rate": episode_search_mode_after_rebuild["success_rate"], + "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + }, + "cases": { + "search": search_case_reports, + "writeback": writeback_reports, + "knowledge_fetcher": knowledge_reports, + "profile": profile_reports, + "tool_modes": tool_modes["reports"], + "episode_generation_auto": episode_generation_auto["reports"], + "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], + "episode_admin_query_auto": episode_admin_query_auto["reports"], + "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], + "episode_search_mode_auto": episode_search_mode_auto["reports"], + "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], + }, + } + + REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) + REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps(report["metrics"], ensure_ascii=False, indent=2)) + + assert report["import"]["status"] == "completed" + assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35 + assert report["metrics"]["search"]["recall_at_5"] >= 0.6 + assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8 + assert report["metrics"]["writeback"]["success_rate"] >= 0.66 + assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66 + assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75 + assert report["metrics"]["profile"]["success_rate"] >= 0.66 + assert report["metrics"]["profile"]["evidence_rate"] >= 1.0 + assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75 + assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True + assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"] diff --git a/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py b/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py new file mode 100644 index 00000000..1dad0795 --- /dev/null +++ b/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest +import pytest_asyncio + +from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel +from pytests.test_long_novel_memory_benchmark import ( + _evaluate_episode_admin_query, + _evaluate_episode_generation, + _evaluate_episode_search_mode, + _evaluate_tool_modes, + _KernelBackedRuntimeManager, + _KnownPerson, + _first_relevant_rank, + _hit_blob, + _join_hit_content, + _keyword_hits, + _keyword_recall, + _load_benchmark_fixture, + _wait_for_import_task, +) +from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module +from src.memory_system import chat_history_summarizer as summarizer_module +from src.person_info import person_info as person_info_module +from src.services import memory_service as memory_service_module +from src.services.memory_service import memory_service + + +pytestmark = pytest.mark.skipif( + os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1", + reason="需要显式开启真实 external embedding benchmark", +) + +REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json" + + +@pytest_asyncio.fixture +async def 
+@pytest_asyncio.fixture
+async def benchmark_live_env(monkeypatch, tmp_path):
+    dataset = _load_benchmark_fixture()
+    session_cfg = dataset["session"]
+    session = SimpleNamespace(
+        session_id=session_cfg["session_id"],
+        platform=session_cfg["platform"],
+        user_id=session_cfg["user_id"],
+        group_id=session_cfg["group_id"],
+    )
+    fake_chat_manager = SimpleNamespace(
+        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
+        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
+    )
+
+    registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
+    reverse_registry = {value: key for key, value in registry.items()}
+
+    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None)
+    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
+    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
+    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
+    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
+    monkeypatch.setattr(
+        person_info_module,
+        "Person",
+        lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
+    )
+
+    data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve()
+    kernel = SDKMemoryKernel(
+        plugin_root=tmp_path / "plugin_root",
+        config={
+            "storage": {"data_dir": str(data_dir)},
+            "advanced": {"enable_auto_save": False},
+            "memory": {"base_decay_interval_hours": 24},
+            "person_profile": {"refresh_interval_minutes": 5},
+        },
+    )
+    manager = _KernelBackedRuntimeManager(kernel)
+    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager)
+
+    await kernel.initialize()
+    try:
+        yield {
+            "dataset": dataset,
+            "kernel": kernel,
+            "session": session,
+        }
+    finally:
+        await kernel.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_long_novel_memory_benchmark_live(benchmark_live_env):
+    dataset = benchmark_live_env["dataset"]
+    session_id = benchmark_live_env["session"].session_id
+
+    self_check = await memory_service.runtime_admin(action="refresh_self_check")
+    assert self_check["success"] is True
+    assert self_check["report"]["ok"] is True
+
+    created = await memory_service.import_admin(
+        action="create_paste",
+        name="long_novel_memory_benchmark.live.json",
+        input_mode="json",
+        llm_enabled=False,
+        content=json.dumps(dataset["import_payload"], ensure_ascii=False),
+    )
+    assert created["success"] is True
+
+    import_detail = await _wait_for_import_task(
+        created["task"]["task_id"],
+        max_rounds=2400,
+        sleep_seconds=0.25,
+    )
+    assert import_detail["task"]["status"] == "completed"
+
+    for record in dataset["chat_history_records"]:
+        summarizer = summarizer_module.ChatHistorySummarizer(session_id)
+        await summarizer._import_to_long_term_memory(
+            record_id=record["record_id"],
+            theme=record["theme"],
+            summary=record["summary"],
+            participants=record["participants"],
+            start_time=record["start_time"],
+            end_time=record["end_time"],
+            original_text=record["original_text"],
+        )
+
+    for payload in dataset["person_writebacks"]:
+        await person_info_module.store_person_memory_from_answer(
+            payload["person_name"],
+            payload["memory_content"],
+            session_id,
+        )
+
+    await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)
+
+    search_case_reports: List[Dict[str, Any]] = []
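+    # Same IR metrics as the offline run, now computed against live embeddings.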
+    search_accuracy_at_1 = 0.0
+    search_recall_at_5 = 0.0
+    search_precision_at_5 = 0.0
+    search_mrr = 0.0
+    search_keyword_recall = 0.0
+    for case in dataset["search_cases"]:
+        result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
+        joined = _join_hit_content(result)
+        rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
+        relevant_hits = sum(
+            1
+            for hit in result.hits[:5]
+            if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
+        )
+        keyword_recall = _keyword_recall(joined, case["expected_keywords"])
+        search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
+        search_recall_at_5 += 1.0 if rank > 0 else 0.0
+        search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
+        search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
+        search_keyword_recall += keyword_recall
+        search_case_reports.append(
+            {
+                "query": case["query"],
+                "rank_of_first_relevant": rank,
+                "relevant_hits_top5": relevant_hits,
+                "keyword_recall_top5": keyword_recall,
+                "top_hit": result.hits[0].to_dict() if result.hits else None,
+            }
+        )
+    search_total = max(1, len(dataset["search_cases"]))
+
+    writeback_reports: List[Dict[str, Any]] = []
+    writeback_success_rate = 0.0
+    writeback_keyword_recall = 0.0
+    for payload in dataset["person_writebacks"]:
+        query = " ".join(payload["expected_keywords"])
+        result = await memory_service.search(
+            query,
+            mode="search",
+            chat_id=session_id,
+            person_id=payload["person_id"],
+            respect_filter=False,
+            limit=5,
+        )
+        joined = _join_hit_content(result)
+        recall = _keyword_recall(joined, payload["expected_keywords"])
+        success = bool(result.hits) and recall >= 0.67
+        writeback_success_rate += 1.0 if success else 0.0
+        writeback_keyword_recall += recall
+        writeback_reports.append(
+            {
+                "person_id": payload["person_id"],
+                "success": success,
+                "keyword_recall": recall,
+                "hit_count": len(result.hits),
+            }
+        )
+    writeback_total = max(1, len(dataset["person_writebacks"]))
+
+    knowledge_reports: List[Dict[str, Any]] = []
+    knowledge_success_rate = 0.0
+    knowledge_keyword_recall = 0.0
+    fetcher = knowledge_module.KnowledgeFetcher(
+        private_name=dataset["session"]["display_name"],
+        stream_id=session_id,
+    )
+    for case in dataset["knowledge_fetcher_cases"]:
+        knowledge_text, _ = await fetcher.fetch(case["query"], [])
+        recall = _keyword_recall(knowledge_text, case["expected_keywords"])
+        success = recall >= float(case.get("minimum_keyword_recall", 1.0))
+        knowledge_success_rate += 1.0 if success else 0.0
+        knowledge_keyword_recall += recall
+        knowledge_reports.append(
+            {
+                "query": case["query"],
+                "success": success,
+                "keyword_recall": recall,
+                "preview": knowledge_text[:300],
+            }
+        )
+    knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))
+
+    profile_reports: List[Dict[str, Any]] = []
+    profile_success_rate = 0.0
+    profile_keyword_recall = 0.0
+    profile_evidence_rate = 0.0
+    for case in dataset["profile_cases"]:
+        profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
+        recall = _keyword_recall(profile.summary, case["expected_keywords"])
+        has_evidence = bool(profile.evidence)
+        success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
+        profile_success_rate += 1.0 if success else 0.0
+        profile_keyword_recall += recall
+        profile_evidence_rate += 1.0 if has_evidence else 0.0
+        profile_reports.append(
+            {
+                "person_id": case["person_id"],
+                "success": success,
+                "keyword_recall": recall,
+                "evidence_count": len(profile.evidence),
+                "summary_preview": profile.summary[:240],
+            }
+        )
+    profile_total = max(1, len(dataset["profile_cases"]))
+
+    episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
+    episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
+    episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
+    episode_rebuild = await memory_service.episode_admin(
+        action="rebuild",
+        source=f"chat_summary:{session_id}",
+    )
+    episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
+    episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
+    episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
+    tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)
+
+    report = {
+        "dataset": dataset["meta"],
+        "runtime_self_check": self_check["report"],
+        "import": {
+            "task_id": created["task"]["task_id"],
+            "status": import_detail["task"]["status"],
+            "paragraph_count": len(dataset["import_payload"]["paragraphs"]),
+        },
+        "metrics": {
+            "search": {
+                "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
+                "recall_at_5": round(search_recall_at_5 / search_total, 4),
+                "precision_at_5": round(search_precision_at_5 / search_total, 4),
+                "mrr": round(search_mrr / search_total, 4),
+                "keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
+            },
+            "writeback": {
+                "success_rate": round(writeback_success_rate / writeback_total, 4),
+                "keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
+            },
+            "knowledge_fetcher": {
+                "success_rate": round(knowledge_success_rate / knowledge_total, 4),
+                "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
+            },
+            "profile": {
+                "success_rate": round(profile_success_rate / profile_total, 4),
+                "keyword_recall": round(profile_keyword_recall / profile_total, 4),
+                "evidence_rate": round(profile_evidence_rate / profile_total, 4),
+            },
+            "tool_modes": {
+                "success_rate": tool_modes["success_rate"],
+                "keyword_recall": tool_modes["keyword_recall"],
+            },
+            "episode_generation_auto": {
+                "success_rate": episode_generation_auto["success_rate"],
+                "keyword_recall": episode_generation_auto["keyword_recall"],
+                "episode_count": episode_generation_auto["episode_count"],
+            },
+            "episode_generation_after_rebuild": {
+                "success_rate": episode_generation_after_rebuild["success_rate"],
+                "keyword_recall": episode_generation_after_rebuild["keyword_recall"],
+                "episode_count": episode_generation_after_rebuild["episode_count"],
+                "rebuild_success": bool(episode_rebuild.get("success", False)),
+            },
+            "episode_admin_query_auto": {
+                "success_rate": episode_admin_query_auto["success_rate"],
+                "keyword_recall": episode_admin_query_auto["keyword_recall"],
+            },
+            "episode_admin_query_after_rebuild": {
+                "success_rate": episode_admin_query_after_rebuild["success_rate"],
+                "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
+                "rebuild_success": bool(episode_rebuild.get("success", False)),
+            },
+            "episode_search_mode_auto": {
episode_search_mode_auto["success_rate"], + "keyword_recall": episode_search_mode_auto["keyword_recall"], + }, + "episode_search_mode_after_rebuild": { + "success_rate": episode_search_mode_after_rebuild["success_rate"], + "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + }, + "cases": { + "search": search_case_reports, + "writeback": writeback_reports, + "knowledge_fetcher": knowledge_reports, + "profile": profile_reports, + "tool_modes": tool_modes["reports"], + "episode_generation_auto": episode_generation_auto["reports"], + "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], + "episode_admin_query_auto": episode_admin_query_auto["reports"], + "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], + "episode_search_mode_auto": episode_search_mode_auto["reports"], + "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], + }, + } + + REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) + REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps(report["metrics"], ensure_ascii=False, indent=2)) + + assert report["import"]["status"] == "completed" + assert report["runtime_self_check"]["ok"] is True diff --git a/pytests/A_memorix_test/test_memory_flow_service.py b/pytests/A_memorix_test/test_memory_flow_service.py new file mode 100644 index 00000000..2d35e837 --- /dev/null +++ b/pytests/A_memorix_test/test_memory_flow_service.py @@ -0,0 +1,138 @@ +from types import SimpleNamespace + +import pytest + +from src.services import memory_flow_service as memory_flow_module + + +@pytest.mark.asyncio +async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch): + starts: list[str] = [] + summarizers: list[object] = [] + + class FakeSummarizer: + def __init__(self, session_id: str): + self.session_id = session_id + summarizers.append(self) + + async def start(self): + starts.append(self.session_id) + + async def stop(self): + starts.append(f"stop:{self.session_id}") + + monkeypatch.setattr( + memory_flow_module, + "global_config", + SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)), + ) + monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer) + + manager = memory_flow_module.LongTermMemorySessionManager() + message = SimpleNamespace(session_id="session-1") + + await manager.on_message(message) + await manager.on_message(message) + + assert len(summarizers) == 1 + assert starts == ["session-1"] + + +@pytest.mark.asyncio +async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch): + stopped: list[str] = [] + + class FakeSummarizer: + def __init__(self, session_id: str): + self.session_id = session_id + + async def start(self): + return None + + async def stop(self): + stopped.append(self.session_id) + + monkeypatch.setattr( + memory_flow_module, + "global_config", + SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)), + ) + monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer) + + manager = memory_flow_module.LongTermMemorySessionManager() + await manager.on_message(SimpleNamespace(session_id="session-a")) + await manager.on_message(SimpleNamespace(session_id="session-b")) + await manager.shutdown() + + assert stopped == ["session-a", "session-b"] + + +def 
+def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
+    raw = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]'
+
+    result = memory_flow_module.PersonFactWritebackService._parse_fact_list(raw)
+
+    assert result == ["他喜欢猫", "他会弹吉他"]
+
+
+def test_person_fact_looks_ephemeral_detects_short_chitchat():
+    assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("哈哈")
+    assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("好的?")
+    assert not memory_flow_module.PersonFactWritebackService._looks_ephemeral("她最近在学法语和钢琴")
+
+
+def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
+    class FakePerson:
+        def __init__(self, person_id: str):
+            self.person_id = person_id
+            self.is_known = True
+
+    service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService)
+    monkeypatch.setattr(memory_flow_module, "is_bot_self", lambda platform, user_id: False)
+    monkeypatch.setattr(memory_flow_module, "get_person_id", lambda platform, user_id: f"{platform}:{user_id}")
+    monkeypatch.setattr(memory_flow_module, "Person", FakePerson)
+
+    message = SimpleNamespace(session=SimpleNamespace(platform="qq", user_id="123", group_id=""))
+
+    person = service._resolve_target_person(message)
+
+    assert person is not None
+    assert person.person_id == "qq:123"
+
+
+@pytest.mark.asyncio
+async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
+    events: list[tuple[str, str]] = []
+
+    class FakeSessionManager:
+        async def on_message(self, message):
+            events.append(("incoming", message.session_id))
+
+        async def shutdown(self):
+            events.append(("shutdown", "session"))
+
+    class FakeFactWriteback:
+        async def start(self):
+            events.append(("start", "fact"))
+
+        async def enqueue(self, message):
+            events.append(("sent", message.session_id))
+
+        async def shutdown(self):
+            events.append(("shutdown", "fact"))
+
+    service = memory_flow_module.MemoryAutomationService()
+    service.session_manager = FakeSessionManager()
+    service.fact_writeback = FakeFactWriteback()
+
+    await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
+    await service.on_message_sent(SimpleNamespace(session_id="session-1"))
+    await service.shutdown()
+
+    assert events == [
+        ("start", "fact"),
+        ("incoming", "session-1"),
+        ("sent", "session-1"),
+        ("shutdown", "session"),
+        ("shutdown", "fact"),
+    ]
diff --git a/pytests/A_memorix_test/test_memory_service.py b/pytests/A_memorix_test/test_memory_service.py
new file mode 100644
index 00000000..bac85afc
--- /dev/null
+++ b/pytests/A_memorix_test/test_memory_service.py
@@ -0,0 +1,281 @@
+import pytest
+
+from src.services.memory_service import MemorySearchResult, MemoryService
+
+
+def test_coerce_write_result_treats_skipped_payload_as_success():
+    result = MemoryService._coerce_write_result({"skipped_ids": ["p1"], "detail": "chat_filtered"})
+
+    assert result.success is True
+    assert result.stored_ids == []
+    assert result.skipped_ids == ["p1"]
+    assert result.detail == "chat_filtered"
+
+
+@pytest.mark.asyncio
+async def test_graph_admin_invokes_plugin(monkeypatch):
+    service = MemoryService()
+    calls = []
+
+    async def fake_invoke(component_name, args=None, **kwargs):
+        calls.append((component_name, args, kwargs))
+        return {"success": True, "nodes": [], "edges": []}
+
+    monkeypatch.setattr(service, "_invoke", fake_invoke)
+
+    result = await service.graph_admin(action="get_graph", limit=12)
+
+    assert result["success"] is True
"get_graph", "limit": 12}, {"timeout_ms": 30000})] + + +@pytest.mark.asyncio +async def test_get_recycle_bin_uses_maintain_memory_tool(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args)) + return {"success": True, "items": [{"hash": "abc"}], "count": 1} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.get_recycle_bin(limit=5) + + assert result == {"success": True, "items": [{"hash": "abc"}], "count": 1} + assert calls == [("maintain_memory", {"action": "recycle_bin", "limit": 5})] + + +@pytest.mark.asyncio +async def test_search_respects_filter_by_default(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args)) + return {"summary": "ok", "hits": [], "filtered": True} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.search( + "mai", + chat_id="stream-1", + person_id="person-1", + user_id="user-1", + group_id="", + ) + + assert isinstance(result, MemorySearchResult) + assert result.filtered is True + assert calls == [ + ( + "search_memory", + { + "query": "mai", + "limit": 5, + "mode": "hybrid", + "chat_id": "stream-1", + "person_id": "person-1", + "time_start": None, + "time_end": None, + "respect_filter": True, + "user_id": "user-1", + "group_id": "", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_ingest_summary_can_bypass_filter(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args)) + return {"success": True, "stored_ids": ["p1"], "detail": ""} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.ingest_summary( + external_id="chat_history:1", + chat_id="stream-1", + text="summary", + respect_filter=False, + user_id="user-1", + ) + + assert result.success is True + assert calls == [ + ( + "ingest_summary", + { + "external_id": "chat_history:1", + "chat_id": "stream-1", + "text": "summary", + "participants": [], + "time_start": None, + "time_end": None, + "tags": [], + "metadata": {}, + "respect_filter": False, + "user_id": "user-1", + "group_id": "", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_v5_admin_invokes_plugin(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args, kwargs)) + return {"success": True, "count": 1} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.v5_admin(action="status", target="mai", limit=5) + + assert result["success"] is True + assert calls == [("memory_v5_admin", {"action": "status", "target": "mai", "limit": 5}, {"timeout_ms": 30000})] + + +@pytest.mark.asyncio +async def test_delete_admin_uses_long_timeout(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args, kwargs)) + return {"success": True, "operation_id": "del-1"} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.delete_admin(action="execute", mode="relation", selector={"query": "mai"}) + + assert result["success"] is True + assert calls == [ + ( + "memory_delete_admin", + {"action": "execute", "mode": "relation", "selector": {"query": "mai"}}, + {"timeout_ms": 120000}, + ) + ] + + +@pytest.mark.asyncio +async def 
test_search_returns_empty_when_query_and_time_missing_async(): + service = MemoryService() + + result = await service.search("", time_start=None, time_end=None) + + assert isinstance(result, MemorySearchResult) + assert result.summary == "" + assert result.hits == [] + assert result.filtered is False + + +@pytest.mark.asyncio +async def test_search_accepts_string_time_bounds(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args)) + return {"summary": "ok", "hits": [], "filtered": False} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.search( + "广播站", + mode="time", + time_start="2026/03/18", + time_end="2026/03/18 09:30", + ) + + assert isinstance(result, MemorySearchResult) + assert calls == [ + ( + "search_memory", + { + "query": "广播站", + "limit": 5, + "mode": "time", + "chat_id": "", + "person_id": "", + "time_start": "2026/03/18", + "time_end": "2026/03/18 09:30", + "respect_filter": True, + "user_id": "", + "group_id": "", + }, + ) + ] + + +def test_coerce_search_result_preserves_aggregate_source_branches(): + result = MemoryService._coerce_search_result( + { + "hits": [ + { + "content": "广播站值夜班", + "type": "paragraph", + "metadata": {"event_time_start": 1.0}, + "source_branches": ["search", "time"], + "rank": 1, + } + ] + } + ) + + assert result.hits[0].metadata["source_branches"] == ["search", "time"] + assert result.hits[0].metadata["rank"] == 1 + + +@pytest.mark.asyncio +async def test_import_admin_uses_long_timeout(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args, kwargs)) + return {"success": True, "task_id": "import-1"} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.import_admin(action="create_lpmm_openie", alias="lpmm") + + assert result["success"] is True + assert calls == [ + ( + "memory_import_admin", + {"action": "create_lpmm_openie", "alias": "lpmm"}, + {"timeout_ms": 120000}, + ) + ] + + +@pytest.mark.asyncio +async def test_tuning_admin_uses_long_timeout(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args, kwargs)) + return {"success": True, "task_id": "tuning-1"} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.tuning_admin(action="create_task", payload={"query": "mai"}) + + assert result["success"] is True + assert calls == [ + ( + "memory_tuning_admin", + {"action": "create_task", "payload": {"query": "mai"}}, + {"timeout_ms": 120000}, + ) + ] diff --git a/pytests/A_memorix_test/test_person_memory_writeback.py b/pytests/A_memorix_test/test_person_memory_writeback.py new file mode 100644 index 00000000..f177405a --- /dev/null +++ b/pytests/A_memorix_test/test_person_memory_writeback.py @@ -0,0 +1,81 @@ +from types import SimpleNamespace + +import pytest + +from src.person_info import person_info as person_info_module + + +@pytest.mark.asyncio +async def test_store_person_memory_from_answer_writes_person_fact(monkeypatch): + calls = [] + + class FakePerson: + def __init__(self, person_id: str): + self.person_id = person_id + self.person_name = "Alice" + self.is_known = True + + async def fake_ingest_text(**kwargs): + calls.append(kwargs) + return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) + + session = SimpleNamespace(platform="qq", 
user_id="10001", group_id="", session_id="session-1") + monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session)) + monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1") + monkeypatch.setattr(person_info_module, "Person", FakePerson) + monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text) + + await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1") + + assert len(calls) == 1 + payload = calls[0] + assert payload["external_id"].startswith("person_fact:person-1:") + assert payload["source_type"] == "person_fact" + assert payload["chat_id"] == "session-1" + assert payload["person_ids"] == ["person-1"] + assert payload["participants"] == ["Alice"] + assert payload["respect_filter"] is True + assert payload["user_id"] == "10001" + assert payload["group_id"] == "" + assert payload["metadata"]["person_id"] == "person-1" + + +@pytest.mark.asyncio +async def test_store_person_memory_from_answer_skips_unknown_person(monkeypatch): + calls = [] + + class FakePerson: + def __init__(self, person_id: str): + self.person_id = person_id + self.person_name = "Unknown" + self.is_known = False + + async def fake_ingest_text(**kwargs): + calls.append(kwargs) + return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) + + session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1") + monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session)) + monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1") + monkeypatch.setattr(person_info_module, "Person", FakePerson) + monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text) + + await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1") + + assert calls == [] + + +@pytest.mark.asyncio +async def test_store_person_memory_from_answer_skips_empty_content(monkeypatch): + calls = [] + + async def fake_ingest_text(**kwargs): + calls.append(kwargs) + return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) + + monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text) + + await person_info_module.store_person_memory_from_answer("Alice", " ", "session-1") + + assert calls == [] + diff --git a/pytests/A_memorix_test/test_query_long_term_memory_tool.py b/pytests/A_memorix_test/test_query_long_term_memory_tool.py new file mode 100644 index 00000000..23310e1f --- /dev/null +++ b/pytests/A_memorix_test/test_query_long_term_memory_tool.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from src.memory_system.retrieval_tools import query_long_term_memory as tool_module +from src.memory_system.retrieval_tools import init_all_tools +from src.memory_system.retrieval_tools.query_long_term_memory import ( + _resolve_time_expression, + query_long_term_memory, + register_tool, +) +from src.memory_system.retrieval_tools.tool_registry import get_tool_registry +from src.services.memory_service import MemoryHit, MemorySearchResult + + +def test_resolve_time_expression_supports_relative_and_absolute_inputs(): + now = datetime(2026, 3, 18, 15, 30) + + start_ts, end_ts, start_text, end_text = _resolve_time_expression("今天", now=now) + assert 
datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0) + assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59) + assert start_text == "2026/03/18 00:00" + assert end_text == "2026/03/18 23:59" + + start_ts, end_ts, start_text, end_text = _resolve_time_expression("最近7天", now=now) + assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 12, 0, 0) + assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59) + assert start_text == "2026/03/12 00:00" + assert end_text == "2026/03/18 23:59" + + start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18", now=now) + assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0) + assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59) + assert start_text == "2026/03/18 00:00" + assert end_text == "2026/03/18 23:59" + + start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18 09:30", now=now) + assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 9, 30) + assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 9, 30) + assert start_text == "2026/03/18 09:30" + assert end_text == "2026/03/18 09:30" + + +def test_register_tool_exposes_mode_and_time_expression(): + register_tool() + tool = get_tool_registry().get_tool("search_long_term_memory") + + assert tool is not None + params = {item["name"]: item for item in tool.parameters} + assert "mode" in params + assert params["mode"]["enum"] == ["search", "time", "episode", "aggregate"] + assert "time_expression" in params + assert params["query"]["required"] is False + + +def test_init_all_tools_registers_long_term_memory_tool(): + init_all_tools() + + tool = get_tool_registry().get_tool("search_long_term_memory") + assert tool is not None + + +@pytest.mark.asyncio +async def test_query_long_term_memory_search_mode_maps_to_hybrid(monkeypatch): + captured = {} + + async def fake_search(query, **kwargs): + captured["query"] = query + captured["kwargs"] = kwargs + return MemorySearchResult( + hits=[MemoryHit(content="Alice 喜欢猫", score=0.9, hit_type="paragraph")], + ) + + monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search)) + + text = await query_long_term_memory("Alice 喜欢什么", chat_id="stream-1", person_id="person-1") + + assert "Alice 喜欢猫" in text + assert captured == { + "query": "Alice 喜欢什么", + "kwargs": { + "limit": 5, + "mode": "hybrid", + "chat_id": "stream-1", + "person_id": "person-1", + "time_start": None, + "time_end": None, + }, + } + + +@pytest.mark.asyncio +async def test_query_long_term_memory_time_mode_parses_expression(monkeypatch): + captured = {} + + async def fake_search(query, **kwargs): + captured["query"] = query + captured["kwargs"] = kwargs + return MemorySearchResult( + hits=[ + MemoryHit( + content="昨天晚上广播站停播了十分钟。", + score=0.8, + hit_type="paragraph", + metadata={"event_time_start": 1773797400.0}, + ) + ] + ) + + monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search)) + monkeypatch.setattr( + tool_module, + "_resolve_time_expression", + lambda expression, now=None: (1773795600.0, 1773881940.0, "2026/03/17 00:00", "2026/03/17 23:59"), + ) + + text = await query_long_term_memory( + query="广播站", + mode="time", + time_expression="昨天", + chat_id="stream-1", + ) + + assert "指定时间范围" in text + assert "广播站停播" in text + assert captured == { + "query": "广播站", + "kwargs": { + "limit": 5, + "mode": "time", + "chat_id": "stream-1", + "person_id": "", + "time_start": 1773795600.0, + "time_end": 
1773881940.0, + }, + } + + +@pytest.mark.asyncio +async def test_query_long_term_memory_episode_and_aggregate_format_output(monkeypatch): + responses = { + "episode": MemorySearchResult( + hits=[ + MemoryHit( + content="苏弦在灯塔拆开了那封冬信。", + title="冬信重见天日", + hit_type="episode", + metadata={"participants": ["苏弦"], "keywords": ["冬信", "灯塔"]}, + ) + ] + ), + "aggregate": MemorySearchResult( + hits=[ + MemoryHit( + content="唐未在广播站值夜班时带着黑狗墨点。", + hit_type="paragraph", + metadata={"source_branches": ["search", "time"]}, + ) + ] + ), + } + + async def fake_search(query, **kwargs): + return responses[kwargs["mode"]] + + monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search)) + + episode_text = await query_long_term_memory("那封冬信后来怎么样了", mode="episode") + aggregate_text = await query_long_term_memory("唐未最近有什么线索", mode="aggregate") + + assert "事件《冬信重见天日》" in episode_text + assert "参与者:苏弦" in episode_text + assert "[search,time][paragraph]" in aggregate_text + + +@pytest.mark.asyncio +async def test_query_long_term_memory_invalid_time_expression_returns_retryable_message(): + text = await query_long_term_memory(query="广播站", mode="time", time_expression="明年春分后第三周") + + assert "无法解析" in text + assert "最近7天" in text diff --git a/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py b/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py new file mode 100644 index 00000000..71d94a7b --- /dev/null +++ b/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict + +import numpy as np +import pytest +import pytest_asyncio + +from A_memorix.core.runtime import sdk_memory_kernel as kernel_module +from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel +from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module +from src.memory_system import chat_history_summarizer as summarizer_module +from src.person_info import person_info as person_info_module +from src.services import memory_service as memory_service_module +from src.services.memory_service import memory_service + + +DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json" + + +def _load_dialogue_fixture() -> Dict[str, Any]: + return json.loads(DATA_FILE.read_text(encoding="utf-8")) + + +class _FakeEmbeddingAdapter: + def __init__(self, dimension: int = 16) -> None: + self.dimension = dimension + + async def _detect_dimension(self) -> int: + return self.dimension + + async def encode(self, texts, dimensions=None): + dim = int(dimensions or self.dimension) + if isinstance(texts, str): + sequence = [texts] + single = True + else: + sequence = list(texts) + single = False + + rows = [] + for text in sequence: + vec = np.zeros(dim, dtype=np.float32) + for ch in str(text or ""): + vec[ord(ch) % dim] += 1.0 + if not vec.any(): + vec[0] = 1.0 + norm = np.linalg.norm(vec) + if norm > 0: + vec = vec / norm + rows.append(vec) + payload = np.vstack(rows) + return payload[0] if single else payload + + +class _KernelBackedRuntimeManager: + is_running = True + + def __init__(self, kernel: SDKMemoryKernel) -> None: + self.kernel = kernel + + async def invoke_plugin( + self, + *, + method: str, + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None, + timeout_ms: int, + ): + del method, plugin_id, timeout_ms 
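+        # dispatch sketch: search_memory is special-cased into a typed KernelSearchRequest below; any other component name resolves to the kernel method of the same name, which is awaited only when it returns a coroutine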
+ payload = args or {} + if component_name == "search_memory": + return await self.kernel.search_memory( + KernelSearchRequest( + query=str(payload.get("query", "") or ""), + limit=int(payload.get("limit", 5) or 5), + mode=str(payload.get("mode", "hybrid") or "hybrid"), + chat_id=str(payload.get("chat_id", "") or ""), + person_id=str(payload.get("person_id", "") or ""), + time_start=payload.get("time_start"), + time_end=payload.get("time_end"), + respect_filter=bool(payload.get("respect_filter", True)), + user_id=str(payload.get("user_id", "") or ""), + group_id=str(payload.get("group_id", "") or ""), + ) + ) + + handler = getattr(self.kernel, component_name) + result = handler(**payload) + return await result if inspect.isawaitable(result) else result + + +async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]: + for _ in range(max_rounds): + detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) + task = detail.get("task") or {} + status = str(task.get("status", "") or "") + if status in {"completed", "completed_with_errors", "failed", "cancelled"}: + return detail + await asyncio.sleep(0.05) + raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") + + +def _join_hit_content(search_result) -> str: + return "\n".join(hit.content for hit in search_result.hits) + + +@pytest_asyncio.fixture +async def real_dialogue_env(monkeypatch, tmp_path): + scenario = _load_dialogue_fixture() + session_cfg = scenario["session"] + session = SimpleNamespace( + session_id=session_cfg["session_id"], + platform=session_cfg["platform"], + user_id=session_cfg["user_id"], + group_id=session_cfg["group_id"], + ) + fake_chat_manager = SimpleNamespace( + get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, + get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, + ) + + monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) + + async def fake_self_check(**kwargs): + return {"ok": True, "message": "ok"} + + monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None) + monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) + + data_dir = (tmp_path / "a_memorix_data").resolve() + kernel = SDKMemoryKernel( + plugin_root=tmp_path / "plugin_root", + config={ + "storage": {"data_dir": str(data_dir)}, + "advanced": {"enable_auto_save": False}, + "memory": {"base_decay_interval_hours": 24}, + "person_profile": {"refresh_interval_minutes": 5}, + }, + ) + manager = _KernelBackedRuntimeManager(kernel) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager) + + await kernel.initialize() + try: + yield { + "scenario": scenario, + "kernel": kernel, + "session": session, + } + finally: + await kernel.shutdown() + + +@pytest.mark.asyncio +async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env): + scenario = real_dialogue_env["scenario"] + + created = await memory_service.import_admin( + action="create_paste", + name="private_alice.json", + input_mode="json", + llm_enabled=False, + content=json.dumps(scenario["import_payload"], ensure_ascii=False), + ) + + assert 
created["success"] is True + detail = await _wait_for_import_task(created["task"]["task_id"]) + assert detail["task"]["status"] == "completed" + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + respect_filter=False, + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env): + scenario = real_dialogue_env["scenario"] + record = scenario["chat_history_record"] + + summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + chat_id=real_dialogue_env["session"].session_id, + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch): + scenario = real_dialogue_env["scenario"] + + class _KnownPerson: + def __init__(self, person_id: str) -> None: + self.person_id = person_id + self.is_known = True + self.person_name = scenario["person"]["person_name"] + + monkeypatch.setattr( + person_info_module, + "get_person_id_by_person_name", + lambda person_name: scenario["person"]["person_id"], + ) + monkeypatch.setattr(person_info_module, "Person", _KnownPerson) + + await person_info_module.store_person_memory_from_answer( + scenario["person"]["person_name"], + scenario["person_fact"]["memory_content"], + real_dialogue_env["session"].session_id, + ) + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + chat_id=real_dialogue_env["session"].session_id, + person_id=scenario["person"]["person_id"], + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env): + scenario = real_dialogue_env["scenario"] + + await memory_service.ingest_text( + external_id="fixture:knowledge_fetcher", + source_type="dialogue_note", + text=scenario["person_fact"]["memory_content"], + chat_id=real_dialogue_env["session"].session_id, + person_ids=[scenario["person"]["person_id"]], + participants=[scenario["person"]["person_name"]], + respect_filter=False, + ) + + fetcher = knowledge_module.KnowledgeFetcher( + private_name=scenario["session"]["display_name"], + stream_id=real_dialogue_env["session"].session_id, + ) + knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], []) + + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in knowledge_text + + +@pytest.mark.asyncio +async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch): + scenario = real_dialogue_env["scenario"] + + class _KnownPerson: + def __init__(self, person_id: str) -> None: + self.person_id = person_id + 
self.is_known = True + self.person_name = scenario["person"]["person_name"] + + monkeypatch.setattr( + person_info_module, + "get_person_id_by_person_name", + lambda person_name: scenario["person"]["person_id"], + ) + monkeypatch.setattr(person_info_module, "Person", _KnownPerson) + + await person_info_module.store_person_memory_from_answer( + scenario["person"]["person_name"], + scenario["person_fact"]["memory_content"], + real_dialogue_env["session"].session_id, + ) + + profile = await memory_service.get_person_profile( + scenario["person"]["person_id"], + chat_id=real_dialogue_env["session"].session_id, + ) + + assert profile.evidence + assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"]) + + +@pytest.mark.asyncio +async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env): + scenario = real_dialogue_env["scenario"] + record = scenario["chat_history_record"] + + summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + episodes = await memory_service.episode_admin( + action="query", + source=scenario["expectations"]["episode_source"], + limit=5, + ) + + assert episodes["success"] is True + assert int(episodes["count"]) >= 1 diff --git a/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py b/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py new file mode 100644 index 00000000..808d4c23 --- /dev/null +++ b/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +import os +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict + +import pytest +import pytest_asyncio + +from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel +from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module +from src.memory_system import chat_history_summarizer as summarizer_module +from src.person_info import person_info as person_info_module +from src.services import memory_service as memory_service_module +from src.services.memory_service import memory_service + + +pytestmark = pytest.mark.skipif( + os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1", + reason="需要显式开启真实 embedding / self-check 集成测试", +) + +DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json" + + +def _load_dialogue_fixture() -> Dict[str, Any]: + return json.loads(DATA_FILE.read_text(encoding="utf-8")) + + +class _KernelBackedRuntimeManager: + is_running = True + + def __init__(self, kernel: SDKMemoryKernel) -> None: + self.kernel = kernel + + async def invoke_plugin( + self, + *, + method: str, + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None, + timeout_ms: int, + ): + del method, plugin_id, timeout_ms + payload = args or {} + if component_name == "search_memory": + return await self.kernel.search_memory( + KernelSearchRequest( + query=str(payload.get("query", "") or ""), + limit=int(payload.get("limit", 5) or 5), + mode=str(payload.get("mode", "hybrid") or "hybrid"), + chat_id=str(payload.get("chat_id", "") or ""), + person_id=str(payload.get("person_id", "") or ""), + 
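# time bounds are passed through untouched (None, epoch floats, or preformatted strings, mirroring what memory_service.search forwards) +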
time_start=payload.get("time_start"), + time_end=payload.get("time_end"), + respect_filter=bool(payload.get("respect_filter", True)), + user_id=str(payload.get("user_id", "") or ""), + group_id=str(payload.get("group_id", "") or ""), + ) + ) + + handler = getattr(self.kernel, component_name) + result = handler(**payload) + return await result if inspect.isawaitable(result) else result + + +async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]: + deadline = asyncio.get_running_loop().time() + max(1.0, float(timeout_seconds)) + while asyncio.get_running_loop().time() < deadline: + detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) + task = detail.get("task") or {} + status = str(task.get("status", "") or "") + if status in {"completed", "completed_with_errors", "failed", "cancelled"}: + return detail + await asyncio.sleep(0.2) + raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") + + +def _join_hit_content(search_result) -> str: + return "\n".join(hit.content for hit in search_result.hits) + + +@pytest_asyncio.fixture +async def live_dialogue_env(monkeypatch, tmp_path): + scenario = _load_dialogue_fixture() + session_cfg = scenario["session"] + session = SimpleNamespace( + session_id=session_cfg["session_id"], + platform=session_cfg["platform"], + user_id=session_cfg["user_id"], + group_id=session_cfg["group_id"], + ) + fake_chat_manager = SimpleNamespace( + get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, + get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, + ) + + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None) + monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) + + data_dir = (tmp_path / "a_memorix_data").resolve() + kernel = SDKMemoryKernel( + plugin_root=tmp_path / "plugin_root", + config={ + "storage": {"data_dir": str(data_dir)}, + "advanced": {"enable_auto_save": False}, + "memory": {"base_decay_interval_hours": 24}, + "person_profile": {"refresh_interval_minutes": 5}, + }, + ) + manager = _KernelBackedRuntimeManager(kernel) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager) + + await kernel.initialize() + try: + yield { + "scenario": scenario, + "kernel": kernel, + "session": session, + } + finally: + await kernel.shutdown() + + +@pytest.mark.asyncio +async def test_live_runtime_self_check_passes(live_dialogue_env): + report = await memory_service.runtime_admin(action="refresh_self_check") + + assert report["success"] is True + assert report["report"]["ok"] is True + assert report["report"]["encoded_dimension"] > 0 + + +@pytest.mark.asyncio +async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env): + scenario = live_dialogue_env["scenario"] + + created = await memory_service.import_admin( + action="create_paste", + name="private_alice.json", + input_mode="json", + llm_enabled=False, + content=json.dumps(scenario["import_payload"], ensure_ascii=False), + ) + + assert created["success"] is True + detail = await _wait_for_import_task(created["task"]["task_id"]) + assert detail["task"]["status"] == "completed" + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + 
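# respect_filter=False bypasses the per-chat filter so the freshly imported fixture data stays visible to this search +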
respect_filter=False, + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env): + scenario = live_dialogue_env["scenario"] + record = scenario["chat_history_record"] + + summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + chat_id=live_dialogue_env["session"].session_id, + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch): + scenario = live_dialogue_env["scenario"] + + class _KnownPerson: + def __init__(self, person_id: str) -> None: + self.person_id = person_id + self.is_known = True + self.person_name = scenario["person"]["person_name"] + + monkeypatch.setattr( + person_info_module, + "get_person_id_by_person_name", + lambda person_name: scenario["person"]["person_id"], + ) + monkeypatch.setattr(person_info_module, "Person", _KnownPerson) + + await person_info_module.store_person_memory_from_answer( + scenario["person"]["person_name"], + scenario["person_fact"]["memory_content"], + live_dialogue_env["session"].session_id, + ) + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + chat_id=live_dialogue_env["session"].session_id, + person_id=scenario["person"]["person_id"], + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env): + scenario = live_dialogue_env["scenario"] + + await memory_service.ingest_text( + external_id="fixture:knowledge_fetcher", + source_type="dialogue_note", + text=scenario["person_fact"]["memory_content"], + chat_id=live_dialogue_env["session"].session_id, + person_ids=[scenario["person"]["person_id"]], + participants=[scenario["person"]["person_name"]], + respect_filter=False, + ) + + fetcher = knowledge_module.KnowledgeFetcher( + private_name=scenario["session"]["display_name"], + stream_id=live_dialogue_env["session"].session_id, + ) + knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], []) + + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in knowledge_text + + +@pytest.mark.asyncio +async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch): + scenario = live_dialogue_env["scenario"] + + class _KnownPerson: + def __init__(self, person_id: str) -> None: + self.person_id = person_id + self.is_known = True + self.person_name = scenario["person"]["person_name"] + + monkeypatch.setattr( + person_info_module, + "get_person_id_by_person_name", + lambda person_name: scenario["person"]["person_id"], + ) + monkeypatch.setattr(person_info_module, "Person", _KnownPerson) + + await 
person_info_module.store_person_memory_from_answer( + scenario["person"]["person_name"], + scenario["person_fact"]["memory_content"], + live_dialogue_env["session"].session_id, + ) + + profile = await memory_service.get_person_profile( + scenario["person"]["person_id"], + chat_id=live_dialogue_env["session"].session_id, + ) + + assert profile.evidence + assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"]) + + +@pytest.mark.asyncio +async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env): + scenario = live_dialogue_env["scenario"] + record = scenario["chat_history_record"] + + summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + episodes = await memory_service.episode_admin( + action="query", + source=scenario["expectations"]["episode_source"], + limit=5, + ) + + assert episodes["success"] is True + assert int(episodes["count"]) >= 1 diff --git a/pytests/webui/test_memory_routes.py b/pytests/webui/test_memory_routes.py new file mode 100644 index 00000000..d66a8333 --- /dev/null +++ b/pytests/webui/test_memory_routes.py @@ -0,0 +1,279 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient +import pytest + +from src.services.memory_service import MemorySearchResult +from src.webui.dependencies import require_auth +from src.webui.routers import memory as memory_router_module +from src.webui.routers.memory import compat_router, router + + +@pytest.fixture +def client() -> TestClient: + app = FastAPI() + app.dependency_overrides[require_auth] = lambda: "ok" + app.include_router(router) + app.include_router(compat_router) + return TestClient(app) + + +def test_webui_memory_graph_route(client: TestClient, monkeypatch): + async def fake_graph_admin(*, action: str, **kwargs): + assert action == "get_graph" + return {"success": True, "nodes": [], "edges": [], "total_nodes": 0, "limit": kwargs.get("limit")} + + monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin) + + response = client.get("/api/webui/memory/graph", params={"limit": 77}) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["limit"] == 77 + + +def test_compat_aggregate_route(client: TestClient, monkeypatch): + async def fake_search(query: str, **kwargs): + assert kwargs["mode"] == "aggregate" + assert kwargs["respect_filter"] is False + return MemorySearchResult(summary=f"summary:{query}", hits=[]) + + monkeypatch.setattr(memory_router_module.memory_service, "search", fake_search) + + response = client.get("/api/query/aggregate", params={"query": "mai"}) + + assert response.status_code == 200 + assert response.json() == {"success": True, "summary": "summary:mai", "hits": [], "filtered": False} + + +def test_auto_save_routes(client: TestClient, monkeypatch): + async def fake_runtime_admin(*, action: str, **kwargs): + if action == "get_config": + return {"success": True, "auto_save": True} + if action == "set_auto_save": + return {"success": True, "auto_save": kwargs["enabled"]} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "runtime_admin", fake_runtime_admin) + + get_response = 
client.get("/api/config/auto_save") + post_response = client.post("/api/config/auto_save", json={"enabled": False}) + + assert get_response.status_code == 200 + assert get_response.json() == {"success": True, "auto_save": True} + assert post_response.status_code == 200 + assert post_response.json() == {"success": True, "auto_save": False} + + +def test_recycle_bin_route(client: TestClient, monkeypatch): + async def fake_get_recycle_bin(*, limit: int): + return {"success": True, "items": [{"hash": "deadbeef"}], "count": 1, "limit": limit} + + monkeypatch.setattr(memory_router_module.memory_service, "get_recycle_bin", fake_get_recycle_bin) + + response = client.get("/api/memory/recycle_bin", params={"limit": 10}) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["count"] == 1 + assert response.json()["limit"] == 10 + + +def test_import_guide_route(client: TestClient, monkeypatch): + async def fake_import_admin(*, action: str, **kwargs): + assert kwargs == {} + if action == "get_guide": + return {"success": True} + if action == "get_settings": + return {"success": True, "settings": {"path_aliases": {"raw": "/tmp/raw"}}} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin) + + response = client.get("/api/webui/memory/import/guide") + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["source"] == "local" + assert "长期记忆导入说明" in response.json()["content"] + + +def test_import_upload_route(client: TestClient, monkeypatch, tmp_path): + monkeypatch.setattr(memory_router_module, "STAGING_ROOT", tmp_path) + + async def fake_import_admin(*, action: str, **kwargs): + assert action == "create_upload" + staged_files = kwargs["staged_files"] + assert len(staged_files) == 1 + assert staged_files[0]["filename"] == "demo.txt" + assert memory_router_module.Path(staged_files[0]["staged_path"]).exists() + return {"success": True, "task_id": "task-1"} + + monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin) + + response = client.post( + "/api/import/upload", + data={"payload_json": "{\"source\": \"upload\"}"}, + files=[("files", ("demo.txt", b"hello world", "text/plain"))], + ) + + assert response.status_code == 200 + assert response.json() == {"success": True, "task_id": "task-1"} + assert list(tmp_path.iterdir()) == [] + + +def test_v5_status_route(client: TestClient, monkeypatch): + async def fake_v5_admin(*, action: str, **kwargs): + assert action == "status" + assert kwargs["target"] == "mai" + return {"success": True, "active_count": 1, "inactive_count": 2, "deleted_count": 3} + + monkeypatch.setattr(memory_router_module.memory_service, "v5_admin", fake_v5_admin) + + response = client.get("/api/webui/memory/v5/status", params={"target": "mai"}) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["deleted_count"] == 3 + + +def test_delete_preview_route(client: TestClient, monkeypatch): + async def fake_delete_admin(*, action: str, **kwargs): + assert action == "preview" + assert kwargs["mode"] == "paragraph" + assert kwargs["selector"] == {"query": "demo"} + return {"success": True, "counts": {"paragraphs": 1}, "dry_run": True} + + monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) + + response = client.post( + "/api/webui/memory/delete/preview", + json={"mode": "paragraph", "selector": 
{"query": "demo"}}, + ) + + assert response.status_code == 200 + assert response.json() == {"success": True, "counts": {"paragraphs": 1}, "dry_run": True} + + +def test_episode_process_pending_route(client: TestClient, monkeypatch): + async def fake_episode_admin(*, action: str, **kwargs): + assert action == "process_pending" + assert kwargs == {"limit": 7, "max_retry": 4} + return {"success": True, "processed": 3} + + monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin) + + response = client.post("/api/webui/memory/episodes/process-pending", json={"limit": 7, "max_retry": 4}) + + assert response.status_code == 200 + assert response.json() == {"success": True, "processed": 3} + + +def test_import_list_route_includes_settings(client: TestClient, monkeypatch): + calls = [] + + async def fake_import_admin(*, action: str, **kwargs): + calls.append((action, kwargs)) + if action == "list": + return {"success": True, "items": [{"task_id": "task-1"}]} + if action == "get_settings": + return {"success": True, "settings": {"path_aliases": {"lpmm": "/tmp/lpmm"}}} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin) + + response = client.get("/api/webui/memory/import/tasks", params={"limit": 9}) + + assert response.status_code == 200 + assert response.json()["items"] == [{"task_id": "task-1"}] + assert response.json()["settings"] == {"path_aliases": {"lpmm": "/tmp/lpmm"}} + assert calls == [("list", {"limit": 9}), ("get_settings", {})] + + +def test_tuning_profile_route_backfills_settings(client: TestClient, monkeypatch): + calls = [] + + async def fake_tuning_admin(*, action: str, **kwargs): + calls.append((action, kwargs)) + if action == "get_profile": + return {"success": True, "profile": {"retrieval": {"top_k": 8}}} + if action == "get_settings": + return {"success": True, "settings": {"profiles": ["default"]}} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin) + + response = client.get("/api/webui/memory/retrieval_tuning/profile") + + assert response.status_code == 200 + assert response.json()["profile"] == {"retrieval": {"top_k": 8}} + assert response.json()["settings"] == {"profiles": ["default"]} + assert calls == [("get_profile", {}), ("get_settings", {})] + + +def test_tuning_report_route_flattens_report_payload(client: TestClient, monkeypatch): + async def fake_tuning_admin(*, action: str, **kwargs): + assert action == "get_report" + assert kwargs == {"task_id": "task-1", "format": "json"} + return { + "success": True, + "report": {"format": "json", "content": "{\"ok\": true}", "path": "/tmp/report.json"}, + } + + monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin) + + response = client.get("/api/webui/memory/retrieval_tuning/tasks/task-1/report", params={"format": "json"}) + + assert response.status_code == 200 + assert response.json() == { + "success": True, + "format": "json", + "content": "{\"ok\": true}", + "path": "/tmp/report.json", + "error": "", + } + + +def test_delete_execute_route(client: TestClient, monkeypatch): + async def fake_delete_admin(*, action: str, **kwargs): + assert action == "execute" + assert kwargs["mode"] == "source" + assert kwargs["selector"] == {"source": "chat_summary:stream-1"} + assert kwargs["reason"] == "cleanup" + assert kwargs["requested_by"] == "tester" + return {"success": True, "operation_id": "del-1"} + + 
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) + + response = client.post( + "/api/webui/memory/delete/execute", + json={ + "mode": "source", + "selector": {"source": "chat_summary:stream-1"}, + "reason": "cleanup", + "requested_by": "tester", + }, + ) + + assert response.status_code == 200 + assert response.json() == {"success": True, "operation_id": "del-1"} + + +def test_delete_operation_routes(client: TestClient, monkeypatch): + async def fake_delete_admin(*, action: str, **kwargs): + if action == "list_operations": + assert kwargs == {"limit": 5, "mode": "paragraph"} + return {"success": True, "items": [{"operation_id": "del-1"}], "count": 1} + if action == "get_operation": + assert kwargs == {"operation_id": "del-1"} + return {"success": True, "operation": {"operation_id": "del-1", "mode": "paragraph"}} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) + + list_response = client.get("/api/webui/memory/delete/operations", params={"limit": 5, "mode": "paragraph"}) + get_response = client.get("/api/webui/memory/delete/operations/del-1") + + assert list_response.status_code == 200 + assert list_response.json()["count"] == 1 + assert get_response.status_code == 200 + assert get_response.json()["operation"]["operation_id"] == "del-1" diff --git a/src/bw_learner/jargon_explainer_old.py b/src/bw_learner/jargon_explainer_old.py index 4d144b2c..94031b4a 100644 --- a/src/bw_learner/jargon_explainer_old.py +++ b/src/bw_learner/jargon_explainer_old.py @@ -7,7 +7,7 @@ from src.common.database.database_model import Jargon from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config from src.prompt.prompt_manager import prompt_manager -from src.bw_learner.jargon_miner_old import search_jargon +from src.bw_learner.jargon_explainer import search_jargon from src.bw_learner.learner_utils_old import ( is_bot_message, contains_bot_self_name, diff --git a/src/bw_learner/learner_utils_old.py b/src/bw_learner/learner_utils_old.py index 3f21c55d..6095ef48 100644 --- a/src/bw_learner/learner_utils_old.py +++ b/src/bw_learner/learner_utils_old.py @@ -196,6 +196,32 @@ def contains_bot_self_name(content: str) -> bool: return any(name in target for name in candidates) +def is_bot_message(msg: Any) -> bool: + """判断消息是否来自机器人自身。""" + if msg is None: + return False + + bot_config = getattr(global_config, "bot", None) + if not bot_config: + return False + + user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip() + if not user_id: + return False + + known_accounts = { + str(getattr(bot_config, "qq_account", "") or "").strip(), + str(getattr(bot_config, "telegram_account", "") or "").strip(), + } + + for platform in getattr(bot_config, "platforms", []) or []: + account = str(getattr(platform, "account", "") or getattr(platform, "id", "") or "").strip() + if account: + known_accounts.add(account) + + return user_id in {account for account in known_accounts if account} + + # def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]: # """ # 构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出 diff --git a/src/chat/brain_chat/PFC/conversation.py b/src/chat/brain_chat/PFC/conversation.py index 1e1e89b1..ab5a7b3d 100644 --- a/src/chat/brain_chat/PFC/conversation.py +++ b/src/chat/brain_chat/PFC/conversation.py @@ -55,7 +55,7 @@ class Conversation: self.action_planner = 
ActionPlanner(self.stream_id, self.private_name) self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name) self.reply_generator = ReplyGenerator(self.stream_id, self.private_name) - self.knowledge_fetcher = KnowledgeFetcher(self.private_name) + self.knowledge_fetcher = KnowledgeFetcher(self.private_name, self.stream_id) self.waiter = Waiter(self.stream_id, self.private_name) self.direct_sender = DirectMessageSender(self.private_name) diff --git a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py index 67509bd5..4d47f609 100644 --- a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py +++ b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py @@ -1,11 +1,14 @@ -from typing import List, Tuple, Dict, Any +from typing import Any, Dict, List, Tuple + +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.common.logger import get_logger # NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned # from src.plugins.memory_system.Hippocampus import HippocampusManager -from src.llm_models.utils_model import LLMRequest from src.config.config import model_config -from src.chat.knowledge import qa_manager +from src.llm_models.utils_model import LLMRequest +from src.person_info.person_info import resolve_person_id_for_memory +from src.services.memory_service import memory_service logger = get_logger("knowledge_fetcher") @@ -13,11 +16,39 @@ logger = get_logger("knowledge_fetcher") class KnowledgeFetcher: """知识调取器""" - def __init__(self, private_name: str): + def __init__(self, private_name: str, stream_id: str): self.llm = LLMRequest(model_set=model_config.model_task_config.utils) self.private_name = private_name + self.stream_id = stream_id - def _lpmm_get_knowledge(self, query: str) -> str: + def _resolve_private_memory_context(self) -> Dict[str, str]: + session = _chat_manager.get_session_by_session_id(self.stream_id) + if session is None: + return {"chat_id": self.stream_id} + + group_id = str(getattr(session, "group_id", "") or "").strip() + user_id = str(getattr(session, "user_id", "") or "").strip() + platform = str(getattr(session, "platform", "") or "").strip() + + person_id = "" + if not group_id: + try: + person_id = resolve_person_id_for_memory( + person_name=self.private_name, + platform=platform, + user_id=user_id, + ) + except Exception as exc: + logger.debug(f"[私聊][{self.private_name}]解析人物ID失败: {exc}") + + return { + "chat_id": self.stream_id, + "person_id": person_id, + "user_id": user_id, + "group_id": group_id, + } + + async def _memory_get_knowledge(self, query: str) -> str: """获取相关知识 Args: @@ -27,13 +58,32 @@ class KnowledgeFetcher: str: 构造好的,带相关度的知识 """ - logger.debug(f"[私聊][{self.private_name}]正在从LPMM知识库中获取知识") + logger.debug(f"[私聊][{self.private_name}]正在从长期记忆中获取知识") try: - knowledge_info = qa_manager.get_knowledge(query) - logger.debug(f"[私聊][{self.private_name}]LPMM知识库查询结果: {knowledge_info:150}") - return knowledge_info + context = self._resolve_private_memory_context() + search_kwargs = { + "limit": 5, + "mode": "search", + "chat_id": context.get("chat_id", ""), + "person_id": context.get("person_id", ""), + "user_id": context.get("user_id", ""), + "group_id": context.get("group_id", ""), + "respect_filter": True, + } + result = await memory_service.search(query, **search_kwargs) + if not result.filtered and not result.hits and search_kwargs["person_id"]: + fallback_kwargs = dict(search_kwargs) + fallback_kwargs["person_id"] = "" + 
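# fallback pass: when the person-scoped search returns nothing, drop the person filter and retry within the same session +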
logger.debug(f"[私聊][{self.private_name}]人物过滤未命中,退回仅按会话检索长期记忆") + result = await memory_service.search(query, **fallback_kwargs) + knowledge_info = result.to_text(limit=5) + if result.filtered: + logger.debug(f"[私聊][{self.private_name}]长期记忆查询被聊天过滤策略跳过") + else: + logger.debug(f"[私聊][{self.private_name}]长期记忆查询结果: {knowledge_info[:150]}") + return knowledge_info or "未找到匹配的知识" except Exception as e: - logger.error(f"[私聊][{self.private_name}]LPMM知识库搜索工具执行失败: {str(e)}") + logger.error(f"[私聊][{self.private_name}]长期记忆搜索工具执行失败: {str(e)}") return "未找到匹配的知识" async def fetch(self, query: str, chat_history: List[Dict[str, Any]]) -> Tuple[str, str]: @@ -72,7 +122,7 @@ class KnowledgeFetcher: # sources_text = ",".join(sources) knowledge_text += "\n现在有以下**知识**可供参考:\n " - knowledge_text += self._lpmm_get_knowledge(query) + knowledge_text += await self._memory_get_knowledge(query) knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n" return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配" diff --git a/src/chat/knowledge/__init__.py b/src/chat/knowledge/__init__.py deleted file mode 100644 index 57e94472..00000000 --- a/src/chat/knowledge/__init__.py +++ /dev/null @@ -1,90 +0,0 @@ -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.qa_manager import QAManager -from src.chat.knowledge.kg_manager import KGManager -from src.chat.knowledge.global_logger import logger -from src.config.config import global_config -import os - -INVALID_ENTITY = [ - "", - "你", - "他", - "她", - "它", - "我们", - "你们", - "他们", - "她们", - "它们", -] - -RAG_GRAPH_NAMESPACE = "rag-graph" -RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" -RAG_PG_HASH_NAMESPACE = "rag-pg-hash" - - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) -DATA_PATH = os.path.join(ROOT_PATH, "data") - - -qa_manager = None -inspire_manager = None - - -def get_qa_manager(): - return qa_manager - - -def lpmm_start_up(): # sourcery skip: extract-duplicate-method - # 检查LPMM知识库是否启用 - if global_config.lpmm_knowledge.enable: - logger.info("正在初始化Mai-LPMM") - logger.info("创建LLM客户端") - - # 初始化Embedding库 - embed_manager = EmbeddingManager( - max_workers=global_config.lpmm_knowledge.max_embedding_workers, - chunk_size=global_config.lpmm_knowledge.embedding_chunk_size, - ) - logger.info("正在从文件加载Embedding库") - try: - embed_manager.load_from_file() - except Exception as e: - logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}") - # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") - logger.info("Embedding库加载完成") - # 初始化KG - kg_manager = KGManager() - logger.info("正在从文件加载KG") - try: - kg_manager.load_from_file() - except Exception as e: - logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}") - # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") - logger.info("KG加载完成") - - logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") - logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}") - - # 数据比对:Embedding库与KG的段落hash集合 - for pg_hash in kg_manager.stored_paragraph_hashes: - # 使用与EmbeddingStore中一致的命名空间格式 - key = f"paragraph-{pg_hash}" - if key not in embed_manager.stored_pg_hashes: - logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") - global qa_manager - # 问答系统(用于知识库) - qa_manager = QAManager( - embed_manager, - kg_manager, - ) - - # # 记忆激活(用于记忆库) - # global inspire_manager - # inspire_manager = MemoryActiveManager( - # embed_manager, - # llm_client_list[global_config["embedding"]["provider"]], - # ) - else: - logger.info("LPMM知识库已禁用,跳过初始化") - # 创建空的占位符对象,避免导入错误 diff --git a/src/chat/knowledge/lpmm_ops.py 
b/src/chat/knowledge/lpmm_ops.py deleted file mode 100644 index acaac4ca..00000000 --- a/src/chat/knowledge/lpmm_ops.py +++ /dev/null @@ -1,380 +0,0 @@ -import asyncio -import os -from functools import partial -from typing import List, Callable, Any -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.kg_manager import KGManager -from src.chat.knowledge.qa_manager import QAManager -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.knowledge import get_qa_manager, lpmm_start_up - -logger = get_logger("LPMM-Plugin-API") - - -class LPMMOperations: - """ - LPMM 内部操作接口。 - 封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。 - """ - - def __init__(self): - self._initialized = False - - async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any: - """ - 在线程池中执行可取消的同步操作。 - 当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。 - 注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。 - - Args: - func: 要执行的同步函数 - *args: 函数的位置参数 - **kwargs: 函数的关键字参数 - - Returns: - 函数的返回值 - - Raises: - asyncio.CancelledError: 当任务被取消时 - """ - loop = asyncio.get_event_loop() - # 在线程池中执行,当协程被取消时会立即响应 - # 虽然线程池中的操作可能仍在运行,但协程不会阻塞 - return await loop.run_in_executor(None, func, *args, **kwargs) - - async def _get_managers(self) -> tuple[EmbeddingManager, KGManager, QAManager]: - """获取并确保 LPMM 管理器已初始化""" - qa_mgr = get_qa_manager() - if qa_mgr is None: - # 如果全局没初始化,尝试初始化 - if not global_config.lpmm_knowledge.enable: - logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。") - - lpmm_start_up() - qa_mgr = get_qa_manager() - - if qa_mgr is None: - raise RuntimeError("无法获取 LPMM QAManager,请检查 LPMM 是否已正确安装和配置。") - - return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr - - async def add_content(self, text: str, auto_split: bool = True) -> dict: - """ - 向知识库添加新内容。 - - Args: - text: 原始文本。 - auto_split: 是否自动按双换行符分割段落。 - - True: 自动分割(默认),支持多段文本(用双换行分隔) - - False: 不分割,将整个文本作为完整一段处理 - - Returns: - dict: {"status": "success/error", "count": 导入段落数, "message": "描述"} - """ - try: - embed_mgr, kg_mgr, _ = await self._get_managers() - - # 1. 分段处理 - if auto_split: - # 自动按双换行符分割 - paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] - else: - # 不分割,作为完整一段 - text_stripped = text.strip() - if not text_stripped: - return {"status": "error", "message": "文本内容为空"} - paragraphs = [text_stripped] - - if not paragraphs: - return {"status": "error", "message": "文本内容为空"} - - # 2. 实体与三元组抽取 (内部调用大模型) - from src.chat.knowledge.ie_process import IEProcess - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config - - llm_ner = LLMRequest( - model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract" - ) - llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build") - ie_process = IEProcess(llm_ner, llm_rdf) - - logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...") - extracted_docs = await ie_process.process_paragraphs(paragraphs) - - # 3. 构造并导入数据 - # 这里我们手动实现导入逻辑,不依赖外部脚本 - # a. 准备段落 - raw_paragraphs = {doc["idx"]: doc["passage"] for doc in extracted_docs} - # b. 准备三元组 - triple_list_data = {doc["idx"]: doc["extracted_triples"] for doc in extracted_docs} - - # 向量化并入库 - # 注意:此处模仿 import_openie.py 的核心逻辑 - # 1. 
先进行去重检查,只处理新段落 - # store_new_data_set 期望的格式:raw_paragraphs 的键是段落hash(不带前缀),值是段落文本 - new_raw_paragraphs = {} - new_triple_list_data = {} - - for pg_hash, passage in raw_paragraphs.items(): - key = f"paragraph-{pg_hash}" - if key not in embed_mgr.stored_pg_hashes: - new_raw_paragraphs[pg_hash] = passage - new_triple_list_data[pg_hash] = triple_list_data[pg_hash] - - if not new_raw_paragraphs: - return {"status": "success", "count": 0, "message": "内容已存在,无需重复导入"} - - # 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入 - # store_new_data_set 会自动处理嵌入生成和存储 - # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 - await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data) - - # 3. 构建知识图谱(只需要三元组数据和embedding_manager) - await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr) - - # 4. 持久化 - await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index) - await self._run_cancellable_executor(embed_mgr.save_to_file) - await self._run_cancellable_executor(kg_mgr.save_to_file) - - return { - "status": "success", - "count": len(new_raw_paragraphs), - "message": f"成功导入 {len(new_raw_paragraphs)} 条知识", - } - - except asyncio.CancelledError: - logger.warning("[Plugin API] 导入操作被用户中断") - return {"status": "cancelled", "message": "导入操作已被用户中断"} - except Exception as e: - logger.error(f"[Plugin API] 导入知识失败: {e}", exc_info=True) - return {"status": "error", "message": str(e)} - - async def search(self, query: str, top_k: int = 3) -> List[str]: - """ - 检索知识库。 - - Args: - query: 查询问题。 - top_k: 返回最相关的条目数。 - - Returns: - List[str]: 相关文段列表。 - """ - try: - _, _, qa_mgr = await self._get_managers() - # 直接调用 QAManager 的检索接口 - knowledge = qa_mgr.get_knowledge(query, top_k=top_k) - # 返回通常是拼接好的字符串,这里我们可以尝试按其内部规则切分回列表,或者直接返回 - return [knowledge] if knowledge else [] - except Exception as e: - logger.error(f"[Plugin API] 检索知识失败: {e}") - return [] - - async def delete(self, keyword: str, exact_match: bool = False) -> dict: - """ - 根据关键词或完整文段删除知识库内容。 - - Args: - keyword: 匹配关键词或完整文段。 - exact_match: 是否使用完整文段匹配(True=完全匹配,False=关键词模糊匹配)。 - - Returns: - dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"} - """ - try: - embed_mgr, kg_mgr, _ = await self._get_managers() - - # 1. 查找匹配的段落 - to_delete_keys = [] - to_delete_hashes = [] - - for key, item in embed_mgr.paragraphs_embedding_store.store.items(): - if exact_match: - # 完整文段匹配 - if item.str.strip() == keyword.strip(): - to_delete_keys.append(key) - to_delete_hashes.append(key.replace("paragraph-", "", 1)) - else: - # 关键词模糊匹配 - if keyword in item.str: - to_delete_keys.append(key) - to_delete_hashes.append(key.replace("paragraph-", "", 1)) - - if not to_delete_keys: - match_type = "完整文段" if exact_match else "关键词" - return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"} - - # 2. 执行删除 - # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 - - # a. 从向量库删除 - deleted_count, _ = await self._run_cancellable_executor( - embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys - ) - embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys()) - - # b. 从知识图谱删除 - # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 - # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs - delete_func = partial( - kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True - ) - await self._run_cancellable_executor(delete_func) - - # 3. 
持久化 - await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index) - await self._run_cancellable_executor(embed_mgr.save_to_file) - await self._run_cancellable_executor(kg_mgr.save_to_file) - - match_type = "完整文段" if exact_match else "关键词" - return { - "status": "success", - "deleted_count": deleted_count, - "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)", - } - - except asyncio.CancelledError: - logger.warning("[Plugin API] 删除操作被用户中断") - return {"status": "cancelled", "message": "删除操作已被用户中断"} - except Exception as e: - logger.error(f"[Plugin API] 删除知识失败: {e}", exc_info=True) - return {"status": "error", "message": str(e)} - - async def clear_all(self) -> dict: - """ - 清空整个LPMM知识库(删除所有段落、实体、关系和知识图谱数据)。 - - Returns: - dict: {"status": "success/error", "message": "描述", "stats": {...}} - """ - try: - embed_mgr, kg_mgr, _ = await self._get_managers() - - # 记录清空前的统计信息 - before_stats = { - "paragraphs": len(embed_mgr.paragraphs_embedding_store.store), - "entities": len(embed_mgr.entities_embedding_store.store), - "relations": len(embed_mgr.relation_embedding_store.store), - "kg_nodes": len(kg_mgr.graph.get_node_list()), - "kg_edges": len(kg_mgr.graph.get_edge_list()), - } - - # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 - - # 1. 清空所有向量库 - # 获取所有keys - para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys()) - ent_keys = list(embed_mgr.entities_embedding_store.store.keys()) - rel_keys = list(embed_mgr.relation_embedding_store.store.keys()) - - # 删除所有段落向量 - para_deleted, _ = await self._run_cancellable_executor( - embed_mgr.paragraphs_embedding_store.delete_items, para_keys - ) - embed_mgr.stored_pg_hashes.clear() - - # 删除所有实体向量 - if ent_keys: - ent_deleted, _ = await self._run_cancellable_executor( - embed_mgr.entities_embedding_store.delete_items, ent_keys - ) - else: - ent_deleted = 0 - - # 删除所有关系向量 - if rel_keys: - rel_deleted, _ = await self._run_cancellable_executor( - embed_mgr.relation_embedding_store.delete_items, rel_keys - ) - else: - rel_deleted = 0 - - # 2. 
清空所有 embedding store 的索引和映射 - # 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件 - def _clear_embedding_indices(): - # 清空段落索引 - embed_mgr.paragraphs_embedding_store.faiss_index = None - embed_mgr.paragraphs_embedding_store.idx2hash = None - embed_mgr.paragraphs_embedding_store.dirty = False - # 删除旧的索引文件 - if os.path.exists(embed_mgr.paragraphs_embedding_store.index_file_path): - os.remove(embed_mgr.paragraphs_embedding_store.index_file_path) - if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path): - os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path) - - # 清空实体索引 - embed_mgr.entities_embedding_store.faiss_index = None - embed_mgr.entities_embedding_store.idx2hash = None - embed_mgr.entities_embedding_store.dirty = False - # 删除旧的索引文件 - if os.path.exists(embed_mgr.entities_embedding_store.index_file_path): - os.remove(embed_mgr.entities_embedding_store.index_file_path) - if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path): - os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path) - - # 清空关系索引 - embed_mgr.relation_embedding_store.faiss_index = None - embed_mgr.relation_embedding_store.idx2hash = None - embed_mgr.relation_embedding_store.dirty = False - # 删除旧的索引文件 - if os.path.exists(embed_mgr.relation_embedding_store.index_file_path): - os.remove(embed_mgr.relation_embedding_store.index_file_path) - if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path): - os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path) - - await self._run_cancellable_executor(_clear_embedding_indices) - - # 3. 清空知识图谱 - # 获取所有段落hash - all_pg_hashes = list(kg_mgr.stored_paragraph_hashes) - if all_pg_hashes: - # 删除所有段落节点(这会自动清理相关的边和孤立实体) - # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 - # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs - delete_func = partial( - kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True - ) - await self._run_cancellable_executor(delete_func) - - # 完全清空KG:创建新的空图(无论是否有段落hash都要执行) - from quick_algo import di_graph - - kg_mgr.graph = di_graph.DiGraph() - kg_mgr.stored_paragraph_hashes.clear() - kg_mgr.ent_appear_cnt.clear() - - # 4. 
保存所有数据(此时所有store都是空的,索引也是None) - # 注意:即使store为空,save_to_file也会保存空的DataFrame,这是正确的 - await self._run_cancellable_executor(embed_mgr.save_to_file) - await self._run_cancellable_executor(kg_mgr.save_to_file) - - after_stats = { - "paragraphs": len(embed_mgr.paragraphs_embedding_store.store), - "entities": len(embed_mgr.entities_embedding_store.store), - "relations": len(embed_mgr.relation_embedding_store.store), - "kg_nodes": len(kg_mgr.graph.get_node_list()), - "kg_edges": len(kg_mgr.graph.get_edge_list()), - } - - return { - "status": "success", - "message": f"已成功清空LPMM知识库(删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)", - "stats": { - "before": before_stats, - "after": after_stats, - }, - } - - except asyncio.CancelledError: - logger.warning("[Plugin API] 清空操作被用户中断") - return {"status": "cancelled", "message": "清空操作已被用户中断"} - except Exception as e: - logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True) - return {"status": "error", "message": str(e)} - - -# 内部使用的单例 -lpmm_ops = LPMMOperations() diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 60586406..df7d28fc 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -360,6 +360,12 @@ class ChatBot: user_id = user_info.user_id group_id = group_info.group_id if group_info else None _ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在 + try: + from src.services.memory_flow_service import memory_automation_service + + await memory_automation_service.on_incoming_message(message) + except Exception as exc: + logger.warning(f"[长期记忆自动总结] 注册会话总结器失败: {exc}") # message.update_chat_stream(chat) diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 894af238..369c0c51 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -383,6 +383,13 @@ class UniversalMessageSender: with get_db_session() as db_session: db_session.add(message.to_db_instance()) + try: + from src.services.memory_flow_service import memory_automation_service + + await memory_automation_service.on_message_sent(message) + except Exception as exc: + logger.warning(f"[{chat_id}] 长期记忆人物事实写回注册失败: {exc}") + return sent_msg except Exception as e: diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 74b324be..003009b8 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -1,7 +1,6 @@ import traceback import time import asyncio -import importlib import random import re @@ -36,6 +35,7 @@ from src.services import llm_service as llm_api from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt +from src.memory_system.retrieval_tools import get_tool_registry from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon from src.chat.utils.common_utils import TempMethodsExpression @@ -1164,29 +1164,14 @@ class DefaultReplyer: async def get_prompt_info(self, message: str, sender: str, target: str): related_info = "" start_time = time.time() - try: - knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge") - except ImportError: - logger.debug("LPMM知识库工具模块不存在,跳过获取知识库内容") - return "" - - search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None) + search_knowledge_tool = 
get_tool_registry().get_tool("search_long_term_memory") if search_knowledge_tool is None: - logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识库内容") + logger.debug("长期记忆检索工具未注册,跳过获取知识内容") return "" - logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - # 从LPMM知识库获取知识 + logger.debug(f"获取长期记忆内容,元消息:{message[:30]}...,消息长度: {len(message)}") try: - # 检查LPMM知识库是否启用 - if not global_config.lpmm_knowledge.enable: - logger.debug("LPMM知识库未启用,跳过获取知识库内容") - return "" - - if global_config.lpmm_knowledge.lpmm_mode == "agent": - return "" - - template_prompt = prompt_manager.get_prompt("lpmm_get_knowledge") + template_prompt = prompt_manager.get_prompt("memory_get_knowledge") template_prompt.add_context("bot_name", global_config.bot.nickname) template_prompt.add_context("time_now", lambda _: time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) template_prompt.add_context("chat_history", message) @@ -1202,24 +1187,31 @@ class DefaultReplyer: # logger.info(f"工具调用提示词: {prompt}") # logger.info(f"工具调用: {tool_calls}") - if tool_calls: - result = await self.tool_executor.execute_tool_call(tool_calls[0]) - end_time = time.time() - if not result or not result.get("content"): - logger.debug("从LPMM知识库获取知识失败,返回空知识...") - return "" - found_knowledge_from_lpmm = result.get("content", "") - logger.info( - f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}" - ) - related_info += found_knowledge_from_lpmm - logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") - logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") - - return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" - else: - logger.debug("模型认为不需要使用LPMM知识库") + if not tool_calls: + logger.debug("模型认为不需要使用长期记忆") return "" + + related_chunks: List[str] = [] + for tool_call in tool_calls: + if tool_call.func_name != "search_long_term_memory": + continue + tool_args = dict(tool_call.args or {}) + tool_args.setdefault("chat_id", self.chat_stream.session_id) + result_text = await search_knowledge_tool.execute(**tool_args) + if result_text and "未找到" not in result_text: + related_chunks.append(result_text) + + if not related_chunks: + logger.debug("长期记忆未返回有效信息") + return "" + + related_info = "\n".join(related_chunks) + end_time = time.time() + logger.info(f"从长期记忆获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") + logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") + logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") + + return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") return "" diff --git a/src/config/config.py b/src/config/config.py index a3b81d2d..fcda4d01 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -55,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config" BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() MMC_VERSION: str = "1.0.0" -CONFIG_VERSION: str = "8.1.0" +CONFIG_VERSION: str = "8.1.1" MODEL_CONFIG_VERSION: str = "1.12.0" logger = get_logger("config") diff --git a/src/config/legacy_migration.py b/src/config/legacy_migration.py index 7baaa03e..7b400f82 100644 --- a/src/config/legacy_migration.py +++ b/src/config/legacy_migration.py @@ -94,6 +94,11 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool: ["", "enable", "enable", "enable"], ["qq:1919810:group", "enable", "enable", 
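
# 最小示意(含假设):一条“更早期格式”规则迁移后的大致形态。
# 第一列按 platform:item_id:rule_type 拆分属于假设;第 4 列的数值权重
# 按下方新增的兼容逻辑丢弃,enable_jargon_learning 保守置 False。
old_row = ["qq:1919810:group", "enable", "enable", "0.5"]
migrated_rule = {
    "platform": "qq",        # 假设:由第一列拆出
    "item_id": "1919810",    # 假设:同上
    "rule_type": "group",
    "use_expression": True,  # "enable" -> True
    "enable_learning": True,
    "enable_jargon_learning": False,  # "0.5" 为废弃数值,保守置 False
}
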
"enable"], ] + 兼容旧旧格式: + learning_list = [ + ["qq:1919810:group", "enable", "enable", "0.5"], + ["", "disable", "disable", "0.1"], + ] 新: [[expression.learning_list]] platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true @@ -117,6 +122,16 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool: use_expression = _parse_enable_disable(r[1]) enable_learning = _parse_enable_disable(r[2]) enable_jargon_learning = _parse_enable_disable(r[3]) + if enable_jargon_learning is None: + # 更早期的配置在第 4 列记录的是一个已废弃的数值权重/阈值, + # 当前 schema 已没有对应字段。这里按保守策略兼容迁移: + # 丢弃旧数值,并将 enable_jargon_learning 置为 False。 + try: + float(str(r[3])) + except (TypeError, ValueError): + pass + else: + enable_jargon_learning = False if use_expression is None or enable_learning is None or enable_jargon_learning is None: return False diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 2de01030..0b681748 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -416,6 +416,24 @@ class MemoryConfig(ConfigBase): ) """_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索""" + long_term_auto_summary_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "book-open", + }, + ) + """是否自动启动聊天总结并导入长期记忆""" + + person_fact_writeback_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "user-round-pen", + }, + ) + """是否在发送回复后自动提取并写回人物事实到长期记忆""" + chat_history_topic_check_message_threshold: int = Field( default=80, ge=1, diff --git a/src/main.py b/src/main.py index 91da2d83..059aee62 100644 --- a/src/main.py +++ b/src/main.py @@ -6,7 +6,6 @@ import time from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask from src.chat.emoji_system.emoji_manager import emoji_manager -from src.chat.knowledge import lpmm_start_up from src.chat.message_receive.bot import chat_bot from src.chat.message_receive.chat_manager import chat_manager from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask @@ -19,6 +18,7 @@ from src.config.config import config_manager, global_config from src.manager.async_task_manager import async_task_manager from src.plugin_runtime.integration import get_plugin_runtime_manager from src.prompt.prompt_manager import prompt_manager +from src.services.memory_flow_service import memory_automation_service # from src.api.main import start_api_server @@ -88,9 +88,6 @@ class MainSystem: # start_api_server() # logger.info("API服务器启动成功") - # 启动LPMM - lpmm_start_up() - # 启动插件运行时(内置插件 + 第三方插件双子进程) await get_plugin_runtime_manager().start() @@ -103,6 +100,7 @@ class MainSystem: asyncio.create_task(chat_manager.regularly_save_sessions()) logger.info(t("startup.chat_manager_initialized")) + await memory_automation_service.start() # await asyncio.sleep(0.5) #防止logger输出飞了 @@ -164,6 +162,10 @@ async def main(): system.schedule_tasks(), ) finally: + await memory_automation_service.shutdown() + await get_plugin_runtime_manager().bridge_event("on_stop") + await get_plugin_runtime_manager().stop() + await async_task_manager.stop_and_wait_all_tasks() await config_manager.stop_file_watcher() diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py index b984c66d..cedf971f 100644 --- a/src/memory_system/chat_history_summarizer.py +++ b/src/memory_system/chat_history_summarizer.py @@ -931,12 +931,14 @@ class ChatHistorySummarizer: else: logger.warning(f"{self.log_prefix} 
存储聊天历史记录到数据库失败") - # 同时导入到LPMM知识库 - if global_config.lpmm_knowledge.enable: - await self._import_to_lpmm_knowledge( + if saved_record and saved_record.get("id") is not None: + await self._import_to_long_term_memory( + record_id=int(saved_record["id"]), theme=theme, summary=summary, participants=participants, + start_time=start_time, + end_time=end_time, original_text=original_text, ) @@ -947,76 +949,131 @@ class ChatHistorySummarizer: traceback.print_exc() raise - async def _import_to_lpmm_knowledge( + async def _import_to_long_term_memory( self, + record_id: int, theme: str, summary: str, participants: List[str], + start_time: float, + end_time: float, original_text: str, ): """ - 将聊天历史总结导入到LPMM知识库 + 将聊天历史总结导入到统一长期记忆 Args: + record_id: chat_history 主键 theme: 话题主题 summary: 概括内容 participants: 参与者列表 + start_time: 开始时间 + end_time: 结束时间 original_text: 原始文本(可能很长,需要截断) """ try: - from src.chat.knowledge.lpmm_ops import lpmm_ops + from src.services.memory_service import memory_service + session = _chat_manager.get_session_by_session_id(self.session_id) + session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else "" + session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else "" - # 构造要导入的文本内容 - # 格式:主题 + 概括 + 参与者信息 + 原始内容摘要 - # 注意:使用单换行符连接,确保整个内容作为一段导入,不被LPMM分段 content_parts = [] - - # 1. 话题主题 - # if theme: - # content_parts.append(f"话题:{theme}") - - # 2. 概括内容 + if theme: + content_parts.append(f"主题:{theme}") if summary: content_parts.append(f"概括:{summary}") - - # 3. 参与者信息 if participants: participants_text = "、".join(participants) content_parts.append(f"参与者:{participants_text}") - - # 4. 原始文本摘要(如果原始文本太长,只取前500字) - # if original_text: - # # 截断原始文本,避免过长 - # max_original_length = 500 - # if len(original_text) > max_original_length: - # truncated_text = original_text[:max_original_length] + "..." 
- # content_parts.append(f"原始内容摘要:{truncated_text}") - # else: - # content_parts.append(f"原始内容:{original_text}") - - # 将所有部分合并为一个完整段落(使用单换行符,避免被LPMM分段) - # LPMM使用 \n\n 作为段落分隔符,所以这里使用 \n 确保不会被分段 content_to_import = "\n".join(content_parts) if not content_to_import.strip(): - logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,跳过导入知识库") + logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,改用插件侧 generate_from_chat 兜底") + await self._fallback_import_to_long_term_memory( + record_id=record_id, + theme=theme, + participants=participants, + start_time=start_time, + end_time=end_time, + original_text=original_text, + ) return - # 调用lpmm_ops导入 - result = await lpmm_ops.add_content(text=content_to_import, auto_split=False) - - if result["status"] == "success": - logger.info( - f"{self.log_prefix} 成功将聊天历史总结导入到LPMM知识库 | 话题: {theme} | 新增段落数: {result.get('count', 0)}" - ) + result = await memory_service.ingest_summary( + external_id=f"chat_history:{record_id}", + chat_id=self.session_id, + text=content_to_import, + participants=participants, + time_start=start_time, + time_end=end_time, + tags=[theme] if theme else [], + metadata={"theme": theme, "original_text_length": len(original_text or "")}, + respect_filter=True, + user_id=session_user_id, + group_id=session_group_id, + ) + if result.success: + if result.detail == "chat_filtered": + logger.debug(f"{self.log_prefix} 聊天历史总结被聊天过滤策略跳过 | 话题: {theme}") + else: + logger.info(f"{self.log_prefix} 成功将聊天历史总结导入到长期记忆 | 话题: {theme}") else: - logger.warning( - f"{self.log_prefix} 将聊天历史总结导入到LPMM知识库失败 | 话题: {theme} | 错误: {result.get('message', '未知错误')}" + logger.warning(f"{self.log_prefix} 将聊天历史总结导入到长期记忆失败,尝试插件侧兜底 | 话题: {theme} | 错误: {result.detail}") + await self._fallback_import_to_long_term_memory( + record_id=record_id, + theme=theme, + participants=participants, + start_time=start_time, + end_time=end_time, + original_text=original_text, ) except Exception as e: - # 导入失败不应该影响数据库存储,只记录错误 - logger.error(f"{self.log_prefix} 导入聊天历史总结到LPMM知识库时出错: {e}", exc_info=True) + logger.error(f"{self.log_prefix} 导入聊天历史总结到长期记忆时出错: {e}", exc_info=True) + + async def _fallback_import_to_long_term_memory( + self, + *, + record_id: int, + theme: str, + participants: List[str], + start_time: float, + end_time: float, + original_text: str, + ) -> None: + try: + from src.services.memory_service import memory_service + session = _chat_manager.get_session_by_session_id(self.session_id) + session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else "" + session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else "" + + result = await memory_service.ingest_summary( + external_id=f"chat_history:{record_id}", + chat_id=self.session_id, + text="", + participants=participants, + time_start=start_time, + time_end=end_time, + tags=[theme] if theme else [], + metadata={ + "theme": theme, + "original_text_length": len(original_text or ""), + "generate_from_chat": True, + "context_length": global_config.memory.chat_history_topic_check_message_threshold, + }, + respect_filter=True, + user_id=session_user_id, + group_id=session_group_id, + ) + if result.success: + if result.detail == "chat_filtered": + logger.debug(f"{self.log_prefix} 插件侧 generate_from_chat 兜底被聊天过滤策略跳过 | 话题: {theme}") + else: + logger.info(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入成功 | 话题: {theme}") + else: + logger.warning(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入失败 | 话题: {theme} | 错误: {result.detail}") + except Exception as exc: + logger.error(f"{self.log_prefix} 插件侧兜底导入长期记忆失败: 
{exc}", exc_info=True) async def start(self): """启动后台定期检查循环""" diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 4193a16a..49e5ca02 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -237,8 +237,8 @@ async def _react_agent_solve_question( if first_head_prompt is None: # 第一次构建,使用初始的collected_info(即initial_info) initial_collected_info = initial_info or "" - # 使用 LPMM 知识库检索 prompt - first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_lpmm") + # 使用统一长期记忆检索 prompt + first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_memory") first_head_prompt_template.add_context("bot_name", bot_name) first_head_prompt_template.add_context("time_now", time_now) first_head_prompt_template.add_context("chat_history", chat_history) diff --git a/src/memory_system/retrieval_tools/__init__.py b/src/memory_system/retrieval_tools/__init__.py index 9f2673b2..ba5f731f 100644 --- a/src/memory_system/retrieval_tools/__init__.py +++ b/src/memory_system/retrieval_tools/__init__.py @@ -10,21 +10,17 @@ from .tool_registry import ( get_tool_registry, ) -# 导入所有工具的注册函数 -from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge -from .query_words import register_tool as register_query_words -from .return_information import register_tool as register_return_information -from src.config.config import global_config - def init_all_tools(): """初始化并注册所有记忆检索工具""" + # 延迟导入,避免在仅使用部分工具或单元测试阶段触发不必要的依赖链。 + from .query_long_term_memory import register_tool as register_long_term_memory + from .query_words import register_tool as register_query_words + from .return_information import register_tool as register_return_information + register_query_words() register_return_information() - - # LPMM知识库检索工具 - if global_config.lpmm_knowledge.lpmm_mode == "agent": - register_lpmm_knowledge() + register_long_term_memory() __all__ = [ diff --git a/src/memory_system/retrieval_tools/query_long_term_memory.py b/src/memory_system/retrieval_tools/query_long_term_memory.py new file mode 100644 index 00000000..57202f34 --- /dev/null +++ b/src/memory_system/retrieval_tools/query_long_term_memory.py @@ -0,0 +1,304 @@ +"""通过统一长期记忆服务查询信息。""" + +from __future__ import annotations + +import re +from calendar import monthrange +from datetime import datetime, timedelta +from typing import Iterable, Literal, Tuple + +from src.common.logger import get_logger +from src.services.memory_service import MemoryHit, MemorySearchResult, memory_service + +from .tool_registry import register_memory_retrieval_tool + +logger = get_logger("memory_retrieval_tools") + +_SUPPORTED_MODES = {"search", "time", "episode", "aggregate"} +_RELATIVE_DAYS_RE = re.compile(r"^最近\s*(\d+)\s*天$") +_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$") +_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}\s+\d{2}:\d{2}$") +_TIME_EXPRESSION_HELP = ( + "请改用更具体的时间表达,例如:今天、昨天、前天、本周、上周、本月、上月、最近7天、" + "2026/03/18、2026/03/18 09:30。" +) + + +def _format_query_datetime(dt: datetime) -> str: + return dt.strftime("%Y/%m/%d %H:%M") + + +def _resolve_time_expression( + expression: str, + *, + now: datetime | None = None, +) -> Tuple[float, float, str, str]: + clean = str(expression or "").strip() + if not clean: + raise ValueError(f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}") + + current = now or datetime.now() + day_start = current.replace(hour=0, minute=0, second=0, microsecond=0) + + if clean == "今天": + start = day_start + end = 
day_start.replace(hour=23, minute=59) + elif clean == "昨天": + start = day_start - timedelta(days=1) + end = start.replace(hour=23, minute=59) + elif clean == "前天": + start = day_start - timedelta(days=2) + end = start.replace(hour=23, minute=59) + elif clean == "本周": + start = day_start - timedelta(days=day_start.weekday()) + end = start + timedelta(days=6, hours=23, minutes=59) + elif clean == "上周": + this_week_start = day_start - timedelta(days=day_start.weekday()) + start = this_week_start - timedelta(days=7) + end = start + timedelta(days=6, hours=23, minutes=59) + elif clean == "本月": + start = day_start.replace(day=1) + last_day = monthrange(start.year, start.month)[1] + end = start.replace(day=last_day, hour=23, minute=59) + elif clean == "上月": + year = day_start.year + month = day_start.month - 1 + if month == 0: + year -= 1 + month = 12 + start = day_start.replace(year=year, month=month, day=1) + last_day = monthrange(year, month)[1] + end = start.replace(day=last_day, hour=23, minute=59) + else: + relative_match = _RELATIVE_DAYS_RE.fullmatch(clean) + if relative_match: + days = max(1, int(relative_match.group(1))) + start = day_start - timedelta(days=max(0, days - 1)) + end = day_start.replace(hour=23, minute=59) + elif _DATE_RE.fullmatch(clean): + start = datetime.strptime(clean, "%Y/%m/%d") + end = start.replace(hour=23, minute=59) + elif _MINUTE_RE.fullmatch(clean): + start = datetime.strptime(clean, "%Y/%m/%d %H:%M") + end = start + else: + raise ValueError(f"时间表达“{clean}”无法解析。{_TIME_EXPRESSION_HELP}") + + return start.timestamp(), end.timestamp(), _format_query_datetime(start), _format_query_datetime(end) + + +def _extract_time_label(metadata: dict) -> str: + if not isinstance(metadata, dict): + return "" + start = metadata.get("event_time_start") + end = metadata.get("event_time_end") + event_time = metadata.get("event_time") + + def _fmt(value: object) -> str: + if value in {None, ""}: + return "" + try: + return datetime.fromtimestamp(float(value)).strftime("%Y/%m/%d %H:%M") + except Exception: + return str(value) + + start_text = _fmt(start or event_time) + end_text = _fmt(end) + if start_text and end_text: + return f"{start_text} - {end_text}" + return start_text or end_text + + +def _truncate(text: str, limit: int = 160) -> str: + compact = str(text or "").strip().replace("\n", " ") + if len(compact) <= limit: + return compact + return compact[:limit] + "..." + + +def _format_search_lines(hits: Iterable[MemoryHit], *, limit: int, include_time: bool = False) -> str: + lines = [] + for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1): + time_label = _extract_time_label(item.metadata) if include_time else "" + prefix = f"[{time_label}] " if time_label else "" + lines.append(f"{index}. 
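
# 最小示意(固定 now 便于核对;_resolve_time_expression 定义见上方):
from datetime import datetime

_now = datetime(2026, 3, 18, 15, 0)
_, _, start_text, end_text = _resolve_time_expression("昨天", now=_now)
assert (start_text, end_text) == ("2026/03/17 00:00", "2026/03/17 23:59")

_, _, start_text, end_text = _resolve_time_expression("最近7天", now=_now)
assert (start_text, end_text) == ("2026/03/12 00:00", "2026/03/18 23:59")
# 无法解析的表达会抛出 ValueError,并附带 _TIME_EXPRESSION_HELP 的提示文案。
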
{prefix}{_truncate(item.content)}") + return "\n".join(lines) + + +def _format_episode_lines(hits: Iterable[MemoryHit], *, limit: int) -> str: + lines = [] + for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1): + metadata = item.metadata if isinstance(item.metadata, dict) else {} + title = str(item.title or "").strip() or "未命名事件" + summary = _truncate(item.content, limit=180) + participants = [str(x).strip() for x in (metadata.get("participants") or []) if str(x).strip()] + keywords = [str(x).strip() for x in (metadata.get("keywords") or []) if str(x).strip()] + extras = [] + if participants: + extras.append(f"参与者:{'、'.join(participants[:4])}") + if keywords: + extras.append(f"关键词:{'、'.join(keywords[:6])}") + time_label = _extract_time_label(metadata) + if time_label: + extras.append(f"时间:{time_label}") + suffix = f"({';'.join(extras)})" if extras else "" + lines.append(f"{index}. 事件《{title}》:{summary}{suffix}") + return "\n".join(lines) + + +def _format_aggregate_lines(hits: Iterable[MemoryHit], *, limit: int) -> str: + lines = [] + for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1): + metadata = item.metadata if isinstance(item.metadata, dict) else {} + source_branches = [str(x).strip() for x in (metadata.get("source_branches") or []) if str(x).strip()] + branch_text = f"[{','.join(source_branches)}]" if source_branches else "" + item_type = str(item.hit_type or "").strip().lower() or "memory" + if item_type == "episode": + title = str(item.title or "").strip() or "未命名事件" + lines.append(f"{index}. {branch_text}[episode] 《{title}》:{_truncate(item.content, 160)}") + else: + lines.append(f"{index}. {branch_text}[{item_type}] {_truncate(item.content, 160)}") + return "\n".join(lines) + + +def _format_tool_result( + *, + result: MemorySearchResult, + mode: Literal["search", "time", "episode", "aggregate"], + limit: int, + query: str, + time_range_text: str = "", +) -> str: + if not result.hits: + if mode == "time": + return f"在指定时间范围内未找到相关的长期记忆{time_range_text}" + if mode == "episode": + return f"未找到与“{query}”相关的事件或情节记忆" + if mode == "aggregate": + return f"未找到可用于综合回忆的长期记忆线索{f'(query:{query})' if query else ''}" + return f"在长期记忆中未找到与“{query}”相关的信息" + + if mode == "episode": + text = _format_episode_lines(result.hits, limit=limit) + return f"你从长期记忆的事件/情节中找到以下信息:\n{text}" + + if mode == "aggregate": + text = _format_aggregate_lines(result.hits, limit=limit) + return f"你从长期记忆中综合找到了以下线索:\n{text}" + + if mode == "time": + text = _format_search_lines(result.hits, limit=limit, include_time=True) + return f"你从指定时间范围内的长期记忆中找到以下信息{time_range_text}:\n{text}" + + text = _format_search_lines(result.hits, limit=limit) + return f"你从长期记忆中找到以下信息:\n{text}" + + +async def query_long_term_memory( + query: str = "", + limit: int = 5, + chat_id: str = "", + person_id: str = "", + mode: str = "search", + time_expression: str = "", +) -> str: + content = str(query or "").strip() + safe_limit = max(1, int(limit or 5)) + normalized_mode = str(mode or "search").strip().lower() or "search" + if normalized_mode not in _SUPPORTED_MODES: + return f"不支持的长期记忆检索模式:{normalized_mode}。可用模式:search、time、episode、aggregate。" + + if normalized_mode == "search" and not content: + return "查询关键词为空,请提供你想查找的长期记忆内容。" + if normalized_mode == "time" and not str(time_expression or "").strip(): + return f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}" + if normalized_mode in {"episode", "aggregate"} and not content and not str(time_expression or "").strip(): + return f"{normalized_mode} 模式至少需要提供 
query 或 time_expression。" + + time_start = None + time_end = None + time_range_text = "" + if str(time_expression or "").strip(): + try: + time_start, time_end, time_start_text, time_end_text = _resolve_time_expression(time_expression) + except ValueError as exc: + return str(exc) + time_range_text = f"(时间范围:{time_start_text} 至 {time_end_text})" + + backend_mode = "hybrid" if normalized_mode == "search" else normalized_mode + + try: + result = await memory_service.search( + content, + limit=safe_limit, + mode=backend_mode, + chat_id=str(chat_id or "").strip(), + person_id=str(person_id or "").strip(), + time_start=time_start, + time_end=time_end, + ) + text = _format_tool_result( + result=result, + mode=normalized_mode, # type: ignore[arg-type] + limit=safe_limit, + query=content, + time_range_text=time_range_text, + ) + logger.debug(f"长期记忆查询结果({normalized_mode}): {text}") + return text + except Exception as exc: + logger.error(f"长期记忆查询失败: {exc}") + return f"长期记忆查询失败:{exc}" + + +def register_tool(): + register_memory_retrieval_tool( + name="search_long_term_memory", + description=( + "从长期记忆中检索信息。支持 search(普通事实检索)、time(按时间范围检索)、" + "episode(按事件/情节检索)、aggregate(综合检索)四种模式。" + ), + parameters=[ + { + "name": "query", + "type": "string", + "description": "需要查询的问题。search 模式建议用自然语言问句;time/episode/aggregate 模式也可用关键词短语。", + "required": False, + }, + { + "name": "mode", + "type": "string", + "description": "检索模式:search(普通长期记忆)、time(按时间窗口)、episode(事件/情节)、aggregate(综合检索)。", + "required": False, + "enum": ["search", "time", "episode", "aggregate"], + }, + { + "name": "limit", + "type": "integer", + "description": "希望返回的相关知识条数,默认为5", + "required": False, + }, + { + "name": "chat_id", + "type": "string", + "description": "当前聊天流ID,可选。提供后优先检索当前聊天上下文相关的长期记忆。", + "required": False, + }, + { + "name": "person_id", + "type": "string", + "description": "相关人物ID,可选。提供后优先检索该人物相关的长期记忆。", + "required": False, + }, + { + "name": "time_expression", + "type": "string", + "description": ( + "时间表达,可选。time 模式必填;episode/aggregate 模式可选。支持:今天、昨天、前天、本周、上周、本月、上月、" + "最近N天,以及 YYYY/MM/DD、YYYY/MM/DD HH:mm。" + ), + "required": False, + }, + ], + execute_func=query_long_term_memory, + ) diff --git a/src/memory_system/retrieval_tools/query_lpmm_knowledge.py b/src/memory_system/retrieval_tools/query_lpmm_knowledge.py deleted file mode 100644 index eed01af1..00000000 --- a/src/memory_system/retrieval_tools/query_lpmm_knowledge.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -通过LPMM知识库查询信息 - 工具实现 -""" - -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.knowledge import get_qa_manager -from .tool_registry import register_memory_retrieval_tool - -logger = get_logger("memory_retrieval_tools") - - -async def query_lpmm_knowledge(query: str, limit: int = 5) -> str: - """在LPMM知识库中查询相关信息 - - Args: - query: 查询关键词 - - Returns: - str: 查询结果 - """ - try: - content = str(query).strip() - if not content: - return "查询关键词为空" - - try: - limit_value = int(limit) - except (TypeError, ValueError): - limit_value = 5 - limit_value = max(1, limit_value) - - if not global_config.lpmm_knowledge.enable: - logger.debug("LPMM知识库未启用") - return "LPMM知识库未启用" - - qa_manager = get_qa_manager() - if qa_manager is None: - logger.debug("LPMM知识库未初始化,跳过查询") - return "LPMM知识库未初始化" - - knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value) - logger.debug(f"LPMM知识库查询结果: {knowledge_info}") - - if knowledge_info: - return f"你从LPMM知识库中找到以下信息:\n{knowledge_info}" - - return f"在LPMM知识库中未找到与“{content}”相关的信息" - - except 
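
# 最小示意(需在异步上下文、且记忆后端可用时执行;参数均为示意值):
async def _demo_query_modes() -> None:
    # time 模式缺少 time_expression 时直接返回帮助文案,不触发后端检索
    hint = await query_long_term_memory(mode="time")
    assert "time_expression" in hint

    # 普通检索:search 模式在后端被映射为 mode="hybrid"
    text = await query_long_term_memory(
        query="他最近在忙什么",
        mode="search",
        limit=3,
        chat_id="session-1",
    )
    print(text)
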
Exception as e: - logger.error(f"LPMM知识库查询失败: {e}") - return f"LPMM知识库查询失败:{str(e)}" - - -def register_tool(): - """注册LPMM知识库查询工具""" - register_memory_retrieval_tool( - name="lpmm_search_knowledge", - description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索", - parameters=[ - { - "name": "query", - "type": "string", - "description": "需要查询的问题,使用一句疑问句提问,例如:什么是AI?", - "required": True, - }, - { - "name": "limit", - "type": "integer", - "description": "希望返回的相关知识条数,默认为5", - "required": False, - }, - ], - execute_func=query_lpmm_knowledge, - ) diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 799f56a0..960de4aa 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -6,9 +6,10 @@ import random import math from json_repair import repair_json -from typing import Union, Optional, Dict +from typing import Union, Optional, Dict, List from datetime import datetime +from sqlalchemy import or_ from sqlmodel import col, select from src.common.logger import get_logger @@ -17,6 +18,7 @@ from src.common.database.database_model import PersonInfo from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.services.memory_service import memory_service logger = get_logger("person_info") @@ -37,16 +39,60 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str: def get_person_id_by_person_name(person_name: str) -> str: """根据用户名获取用户ID""" + clean_name = str(person_name or "").strip() + if not clean_name: + return "" try: with get_db_session() as session: - statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1) + statement = ( + select(PersonInfo) + .where( + or_( + col(PersonInfo.person_name) == clean_name, + col(PersonInfo.user_nickname) == clean_name, + ) + ) + .limit(1) + ) + record = session.exec(statement).first() + if record and record.person_id: + return record.person_id + + statement = ( + select(PersonInfo) + .where(PersonInfo.group_cardname.contains(clean_name)) + .limit(1) + ) record = session.exec(statement).first() return record.person_id if record else "" except Exception as e: - logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") + logger.error(f"根据用户名 {clean_name} 获取用户ID时出错: {e}") return "" +def resolve_person_id_for_memory( + *, + person_name: str = "", + platform: str = "", + user_id: Optional[Union[int, str]] = None, +) -> str: + """统一人物记忆链路中的 person_id 解析。 + + 优先使用已知的人物名称/别名,其次退回到平台 + user_id 的稳定 ID。 + """ + name_token = str(person_name or "").strip() + if name_token: + resolved = get_person_id_by_person_name(name_token) + if resolved: + return resolved + + platform_token = str(platform or "").strip() + user_token = str(user_id or "").strip() + if platform_token and user_token: + return get_person_id(platform_token, user_token) + return "" + + def is_person_known( person_id: Optional[str] = None, user_id: Optional[str] = None, @@ -537,79 +583,79 @@ class Person: async def build_relationship(self, chat_content: str = "", info_type=""): if not self.is_known: return "" - # 构建points文本 - nickname_str = "" if self.person_name != self.nickname: nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})" - relation_info = "" + async def _select_traits(query_text: str, traits: List[str], limit: int = 3) -> List[str]: + clean_traits = [trait.strip() for trait in traits if isinstance(trait, str) and trait.strip()] + if not clean_traits: + return [] + if not query_text: + 
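
# 最小示意(参数均为示意值):resolve_person_id_for_memory 先按名称/昵称/群名片
# 解析,失败再回退平台 + user_id 的稳定 ID,两条路径都失败时返回空串。
person_id = resolve_person_id_for_memory(person_name="Alice", platform="qq", user_id="10001")
if not person_id:
    ...  # 由调用方决定是否放弃本次人物记忆写入
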
return clean_traits[:limit] - points_text = "" - category_list = self.get_all_category() + numbered_traits = "\n".join(f"{index}. {trait}" for index, trait in enumerate(clean_traits, start=1)) + prompt = f"""当前关注内容: +{query_text} - if chat_content: - prompt = f"""当前聊天内容: -{chat_content} +候选人物信息: +{numbered_traits} -分类列表: -{category_list} -**要求**:请你根据当前聊天内容,从以下分类中选择一个与聊天内容相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹: -例如: -<分类1><分类2><分类3>...... -如果没有相关的分类,请输出""" +请从候选人物信息中选择与当前关注内容最相关的编号,并用<>包裹输出,不要输出其他内容。 +例如: +<1><3> +如果都不相关,请输出""" - response, _ = await relation_selection_model.generate_response_async(prompt) - # print(prompt) - # print(response) - category_list = extract_categories_from_response(response) - if "none" not in category_list: - for category in category_list: - random_memory = self.get_random_memory_by_category(category, 2) - if random_memory: - random_memory_str = "\n".join( - [get_memory_content_from_memory(memory) for memory in random_memory] - ) - points_text = f"有关 {category} 的内容:{random_memory_str}" - break - elif info_type: - prompt = f"""你需要获取用户{self.person_name}的 **{info_type}** 信息。 + try: + response, _ = await relation_selection_model.generate_response_async(prompt) + selected_traits: List[str] = [] + for raw_index in extract_categories_from_response(response): + if raw_index == "none": + return [] + try: + trait_index = int(raw_index) - 1 + except ValueError: + continue + if 0 <= trait_index < len(clean_traits): + trait = clean_traits[trait_index] + if trait not in selected_traits: + selected_traits.append(trait) + if selected_traits: + return selected_traits[:limit] + except Exception as e: + logger.debug(f"筛选人物画像信息失败,使用默认画像摘要: {e}") -现有信息类别列表: -{category_list} -**要求**:请你根据**{info_type}**,从以下分类中选择一个与**{info_type}**相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹: -例如: -<分类1><分类2><分类3>...... 
-如果没有相关的分类,请输出""" - response, _ = await relation_selection_model.generate_response_async(prompt) - # print(prompt) - # print(response) - category_list = extract_categories_from_response(response) - if "none" not in category_list: - for category in category_list: - random_memory = self.get_random_memory_by_category(category, 3) - if random_memory: - random_memory_str = "\n".join( - [get_memory_content_from_memory(memory) for memory in random_memory] - ) - points_text = f"有关 {category} 的内容:{random_memory_str}" - break - else: - for category in category_list: - random_memory = self.get_random_memory_by_category(category, 1)[0] - if random_memory: - points_text = f"有关 {category} 的内容:{get_memory_content_from_memory(random_memory)}" - break + return clean_traits[:limit] + + profile = await memory_service.get_person_profile(self.person_id, limit=8) + relation_parts: List[str] = [] + if profile.summary.strip(): + relation_parts.append(profile.summary.strip()) + + query_text = str(chat_content or info_type or "").strip() + selected_traits = await _select_traits(query_text, profile.traits, limit=3) + if not selected_traits and not query_text: + selected_traits = [trait for trait in profile.traits if trait][:2] + + for trait in selected_traits: + clean_trait = str(trait).strip() + if clean_trait and clean_trait not in relation_parts: + relation_parts.append(clean_trait) + + for evidence in profile.evidence: + content = str(evidence.get("content", "") or "").strip() + if content and content not in relation_parts: + relation_parts.append(content) + if len(relation_parts) >= 4: + break points_info = "" - if points_text: - points_info = f"你还记得有关{self.person_name}的内容:{points_text}" + if relation_parts: + points_info = f"你还记得有关{self.person_name}的内容:{';'.join(relation_parts[:3])}" if not (nickname_str or points_info): return "" - relation_info = f"{self.person_name}:{nickname_str}{points_info}" - - return relation_info + return f"{self.person_name}:{nickname_str}{points_info}" class PersonInfoManager: @@ -776,7 +822,7 @@ person_info_manager = PersonInfoManager() async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None: - """将人物信息存入person_info的memory_points + """将人物事实写入统一长期记忆 Args: person_name: 人物名称 @@ -784,6 +830,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: 聊天ID """ try: + content = str(memory_content or "").strip() + if not content: + logger.debug("人物记忆内容为空,跳过写入") + return + # 从 chat_id 获取 session session = _chat_manager.get_session_by_session_id(chat_id) if not session: @@ -794,16 +845,14 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, # 尝试从person_name查找person_id # 首先尝试通过person_name查找 - person_id = get_person_id_by_person_name(person_name) - + person_id = resolve_person_id_for_memory( + person_name=person_name, + platform=platform, + user_id=session.user_id, + ) if not person_id: - # 如果通过person_name找不到,尝试从 session 获取 user_id - if platform and session.user_id: - user_id = session.user_id - person_id = get_person_id(platform, user_id) - else: - logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") - return + logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") + return # 创建或获取Person对象 person = Person(person_id=person_id) @@ -812,39 +861,34 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆") return - # 
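
# 最小示意(需异步上下文):build_relationship 现在改从统一长期记忆读取人物画像。
from src.services.memory_service import memory_service

async def _demo_person_profile(person_id: str) -> None:
    profile = await memory_service.get_person_profile(person_id, limit=8)
    # summary 为整体印象;traits 为候选特质(交给 LLM 按当前话题二次筛选);
    # evidence 为带 content 字段的佐证条目
    print(profile.summary, profile.traits[:3], len(profile.evidence))
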
确定记忆分类(可以根据memory_content判断,这里使用通用分类) - category = "其他" # 默认分类,可以根据需要调整 + memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16] + result = await memory_service.ingest_text( + external_id=f"person_fact:{person_id}:{memory_hash}", + source_type="person_fact", + text=content, + chat_id=chat_id, + person_ids=[person_id], + participants=[person.person_name or person_name], + timestamp=time.time(), + tags=["person_fact"], + metadata={ + "person_id": person_id, + "person_name": person.person_name or person_name, + "platform": platform, + "source": "person_info.store_person_memory_from_answer", + }, + respect_filter=True, + user_id=str(session.user_id or "").strip(), + group_id=str(session.group_id or "").strip(), + ) - # 记忆点格式:category:content:weight - weight = "1.0" # 默认权重 - memory_point = f"{category}:{memory_content}:{weight}" - - # 添加到memory_points - if not person.memory_points: - person.memory_points = [] - - # 检查是否已存在相似的记忆点(避免重复) - is_duplicate = False - for existing_point in person.memory_points: - if existing_point and isinstance(existing_point, str): - parts = existing_point.split(":", 2) - if len(parts) >= 2: - existing_content = parts[1].strip() - # 简单相似度检查(如果内容相同或非常相似,则跳过) - if ( - existing_content == memory_content - or memory_content in existing_content - or existing_content in memory_content - ): - is_duplicate = True - break - - if not is_duplicate: - person.memory_points.append(memory_point) - person.sync_to_database() - logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}") + if result.success: + if result.detail == "chat_filtered": + logger.debug(f"人物长期记忆被聊天过滤策略跳过: {person_name} (person_id: {person_id})") + else: + logger.info(f"成功写入人物长期记忆: {person_name} (person_id: {person_id})") else: - logger.debug(f"记忆点已存在,跳过: {memory_point}") + logger.warning(f"写入人物长期记忆失败: {person_name} (person_id: {person_id}) | {result.detail}") except Exception as e: logger.error(f"存储人物记忆失败: {e}") diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py index c4ae0a56..06ddf5de 100644 --- a/src/plugin_runtime/capabilities/data.py +++ b/src/plugin_runtime/capabilities/data.py @@ -672,12 +672,10 @@ class RuntimeDataCapabilityMixin: limit_value = 5 try: - from src.chat.knowledge import qa_manager + from src.services.memory_service import memory_service - if qa_manager is None: - return {"success": True, "content": "LPMM知识库已禁用"} - - knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value) + result = await memory_service.search(query, limit=limit_value) + knowledge_info = result.to_text(limit=limit_value) content = f"你知道这些知识: {knowledge_info}" if knowledge_info else f"你不太了解有关{query}的知识" return {"success": True, "content": content} except Exception as e: diff --git a/src/services/memory_flow_service.py b/src/services/memory_flow_service.py new file mode 100644 index 00000000..96062eb6 --- /dev/null +++ b/src/services/memory_flow_service.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import asyncio +import json +from typing import Any, Dict, List, Optional + +from json_repair import repair_json + +from src.chat.utils.utils import is_bot_self +from src.common.message_repository import find_messages +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.memory_system.chat_history_summarizer import ChatHistorySummarizer +from src.person_info.person_info import Person, get_person_id, 
store_person_memory_from_answer + +logger = get_logger("memory_flow_service") + + +class LongTermMemorySessionManager: + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._summarizers: Dict[str, ChatHistorySummarizer] = {} + + async def on_message(self, message: Any) -> None: + if not bool(getattr(global_config.memory, "long_term_auto_summary_enabled", True)): + return + session_id = str(getattr(message, "session_id", "") or "").strip() + if not session_id: + return + + created = False + async with self._lock: + summarizer = self._summarizers.get(session_id) + if summarizer is None: + summarizer = ChatHistorySummarizer(session_id=session_id) + self._summarizers[session_id] = summarizer + created = True + if created: + await summarizer.start() + + async def shutdown(self) -> None: + async with self._lock: + items = list(self._summarizers.items()) + self._summarizers.clear() + for session_id, summarizer in items: + try: + await summarizer.stop() + except Exception as exc: + logger.warning("停止聊天总结器失败: session=%s err=%s", session_id, exc) + + +class PersonFactWritebackService: + def __init__(self) -> None: + self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256) + self._worker_task: Optional[asyncio.Task] = None + self._stopping = False + self._extractor = LLMRequest( + model_set=model_config.model_task_config.utils, + request_type="person_fact_writeback", + ) + + async def start(self) -> None: + if self._worker_task is not None and not self._worker_task.done(): + return + self._stopping = False + self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_person_fact_writeback") + + async def shutdown(self) -> None: + self._stopping = True + worker = self._worker_task + self._worker_task = None + if worker is None: + return + worker.cancel() + try: + await worker + except asyncio.CancelledError: + pass + except Exception as exc: + logger.warning("关闭人物事实写回 worker 失败: %s", exc) + + async def enqueue(self, message: Any) -> None: + if not bool(getattr(global_config.memory, "person_fact_writeback_enabled", True)): + return + if self._stopping: + return + try: + self._queue.put_nowait(message) + except asyncio.QueueFull: + logger.warning("人物事实写回队列已满,跳过本次回复") + + async def _worker_loop(self) -> None: + try: + while not self._stopping: + message = await self._queue.get() + try: + await self._handle_message(message) + except Exception as exc: + logger.warning("人物事实写回处理失败: %s", exc, exc_info=True) + finally: + self._queue.task_done() + except asyncio.CancelledError: + raise + + async def _handle_message(self, message: Any) -> None: + reply_text = str(getattr(message, "processed_plain_text", "") or "").strip() + if not reply_text: + return + if self._looks_ephemeral(reply_text): + return + + target_person = self._resolve_target_person(message) + if target_person is None or not target_person.is_known: + return + + facts = await self._extract_facts(target_person, reply_text) + if not facts: + return + + session_id = str( + getattr(message, "session_id", "") + or getattr(getattr(message, "session", None), "session_id", "") + or "" + ).strip() + if not session_id: + return + + person_name = str(getattr(target_person, "person_name", "") or getattr(target_person, "nickname", "") or "").strip() + if not person_name: + return + + for fact in facts: + await store_person_memory_from_answer(person_name, fact, session_id) + + def _resolve_target_person(self, message: Any) -> Optional[Person]: + session = getattr(message, "session", None) + session_platform = str(getattr(session, 
"platform", "") or getattr(message, "platform", "") or "").strip() + session_user_id = str(getattr(session, "user_id", "") or "").strip() + group_id = str(getattr(session, "group_id", "") or "").strip() + + if session_platform and session_user_id and not group_id: + if is_bot_self(session_platform, session_user_id): + return None + person_id = get_person_id(session_platform, session_user_id) + person = Person(person_id=person_id) + return person if person.is_known else None + + reply_to = str(getattr(message, "reply_to", "") or "").strip() + if not reply_to: + return None + try: + replies = find_messages(message_id=reply_to, limit=1) + except Exception as exc: + logger.debug("查询 reply_to 目标失败: %s", exc) + return None + if not replies: + return None + reply_message = replies[0] + reply_platform = str(getattr(reply_message, "platform", "") or session_platform or "").strip() + reply_user_info = getattr(getattr(reply_message, "message_info", None), "user_info", None) + reply_user_id = str(getattr(reply_user_info, "user_id", "") or "").strip() + if not reply_platform or not reply_user_id or is_bot_self(reply_platform, reply_user_id): + return None + person_id = get_person_id(reply_platform, reply_user_id) + person = Person(person_id=person_id) + return person if person.is_known else None + + async def _extract_facts(self, person: Person, reply_text: str) -> List[str]: + person_name = str(getattr(person, "person_name", "") or getattr(person, "nickname", "") or person.person_id) + prompt = f"""你要从一条机器人刚刚发送的回复中,提取“关于{person_name}的稳定事实”。 + +目标人物:{person_name} +机器人回复: +{reply_text} + +请只提取满足以下条件的事实: +1. 明确是关于目标人物本人的信息。 +2. 具有相对稳定性,可以作为长期记忆保存。 +3. 用简洁中文陈述句表达。 + +不要提取: +- 机器人的情绪、计划、临时动作、客套话 +- 只适用于当前时刻的短期安排 +- 不确定、猜测、反问 +- 与目标人物无关的信息 + +严格输出 JSON 数组,例如: +["他喜欢深夜打游戏", "他养了一只猫"] +如果没有可写入的事实,输出 []""" + try: + response, _ = await self._extractor.generate_response_async(prompt) + except Exception as exc: + logger.debug("人物事实提取模型调用失败: %s", exc) + return [] + return self._parse_fact_list(response) + + @staticmethod + def _parse_fact_list(raw: str) -> List[str]: + text = str(raw or "").strip() + if not text: + return [] + try: + repaired = repair_json(text) + payload = json.loads(repaired) if isinstance(repaired, str) else repaired + except Exception: + payload = None + if not isinstance(payload, list): + return [] + + items: List[str] = [] + seen = set() + for item in payload: + fact = str(item or "").strip().strip("- ") + if not fact or len(fact) < 4: + continue + if fact in seen: + continue + seen.add(fact) + items.append(fact) + return items[:5] + + @staticmethod + def _looks_ephemeral(text: str) -> bool: + content = str(text or "").strip() + if not content: + return True + ephemeral_markers = ( + "哈哈", + "好的", + "收到", + "嗯嗯", + "晚安", + "早安", + "拜拜", + "谢谢", + "在吗", + "?", + ) + if len(content) <= 8 and any(marker in content for marker in ephemeral_markers): + return True + return False + + +class MemoryAutomationService: + def __init__(self) -> None: + self.session_manager = LongTermMemorySessionManager() + self.fact_writeback = PersonFactWritebackService() + self._started = False + + async def start(self) -> None: + if self._started: + return + await self.fact_writeback.start() + self._started = True + + async def shutdown(self) -> None: + if not self._started: + return + await self.session_manager.shutdown() + await self.fact_writeback.shutdown() + self._started = False + + async def on_incoming_message(self, message: Any) -> None: + if not self._started: + await self.start() + await 
self.session_manager.on_message(message) + + async def on_message_sent(self, message: Any) -> None: + if not self._started: + await self.start() + await self.fact_writeback.enqueue(message) + + +memory_automation_service = MemoryAutomationService() diff --git a/src/services/memory_service.py b/src/services/memory_service.py new file mode 100644 index 00000000..6cbecd63 --- /dev/null +++ b/src/services/memory_service.py @@ -0,0 +1,428 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger +from src.plugin_runtime.integration import get_plugin_runtime_manager + + +logger = get_logger("memory_service") + +PLUGIN_ID = "A_Memorix" + + +@dataclass +class MemoryHit: + content: str + score: float = 0.0 + hit_type: str = "" + source: str = "" + hash_value: str = "" + metadata: Dict[str, Any] = field(default_factory=dict) + episode_id: str = "" + title: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "content": self.content, + "score": self.score, + "type": self.hit_type, + "source": self.source, + "hash": self.hash_value, + "metadata": self.metadata, + "episode_id": self.episode_id, + "title": self.title, + } + + +@dataclass +class MemorySearchResult: + summary: str = "" + hits: List[MemoryHit] = field(default_factory=list) + filtered: bool = False + + def to_text(self, limit: int = 5) -> str: + if not self.hits: + return "" + lines = [] + for index, item in enumerate(self.hits[: max(1, int(limit))], start=1): + content = item.content.strip().replace("\n", " ") + if len(content) > 160: + content = content[:160] + "..." + lines.append(f"{index}. {content}") + return "\n".join(lines) + + def to_dict(self) -> Dict[str, Any]: + return { + "summary": self.summary, + "hits": [item.to_dict() for item in self.hits], + "filtered": self.filtered, + } + + +@dataclass +class MemoryWriteResult: + success: bool + stored_ids: List[str] = field(default_factory=list) + skipped_ids: List[str] = field(default_factory=list) + detail: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "success": self.success, + "stored_ids": self.stored_ids, + "skipped_ids": self.skipped_ids, + "detail": self.detail, + } + + +@dataclass +class PersonProfileResult: + summary: str = "" + traits: List[str] = field(default_factory=list) + evidence: List[Dict[str, Any]] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return {"summary": self.summary, "traits": self.traits, "evidence": self.evidence} + + +class MemoryService: + async def _invoke(self, component_name: str, args: Optional[Dict[str, Any]] = None, *, timeout_ms: int = 30000) -> Any: + runtime = get_plugin_runtime_manager() + if not runtime.is_running: + raise RuntimeError("plugin_runtime 未启动") + return await runtime.invoke_plugin( + method="plugin.invoke_tool", + plugin_id=PLUGIN_ID, + component_name=component_name, + args=args or {}, + timeout_ms=max(1000, int(timeout_ms or 30000)), + ) + + async def _invoke_admin( + self, + component_name: str, + *, + action: str, + timeout_ms: int = 30000, + **kwargs, + ) -> Dict[str, Any]: + payload = await self._invoke(component_name, {"action": action, **kwargs}, timeout_ms=timeout_ms) + return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"} + + @staticmethod + def _coerce_write_result(payload: Any) -> MemoryWriteResult: + if not isinstance(payload, dict): + return MemoryWriteResult(success=False, detail="invalid_payload") + 
stored_ids = [str(item) for item in (payload.get("stored_ids") or []) if str(item).strip()] + skipped_ids = [str(item) for item in (payload.get("skipped_ids") or []) if str(item).strip()] + detail = str(payload.get("detail") or payload.get("reason") or "") + if stored_ids or skipped_ids: + success = True + elif "success" in payload: + success = bool(payload.get("success")) + else: + success = not bool(detail) + return MemoryWriteResult( + success=success, + stored_ids=stored_ids, + skipped_ids=skipped_ids, + detail=detail, + ) + + @staticmethod + def _coerce_search_result(payload: Any) -> MemorySearchResult: + if not isinstance(payload, dict): + return MemorySearchResult() + hits: List[MemoryHit] = [] + for item in payload.get("hits", []) or []: + if not isinstance(item, dict): + continue + metadata = item.get("metadata", {}) or {} + if not isinstance(metadata, dict): + metadata = {} + if "source_branches" in item and "source_branches" not in metadata: + metadata["source_branches"] = item.get("source_branches") or [] + if "rank" in item and "rank" not in metadata: + metadata["rank"] = item.get("rank") + hits.append( + MemoryHit( + content=str(item.get("content", "") or ""), + score=float(item.get("score", 0.0) or 0.0), + hit_type=str(item.get("type", "") or ""), + source=str(item.get("source", "") or ""), + hash_value=str(item.get("hash", "") or ""), + metadata=metadata, + episode_id=str(item.get("episode_id", "") or ""), + title=str(item.get("title", "") or ""), + ) + ) + return MemorySearchResult( + summary=str(payload.get("summary", "") or ""), + hits=hits, + filtered=bool(payload.get("filtered", False)), + ) + + @staticmethod + def _coerce_profile_result(payload: Any) -> PersonProfileResult: + if not isinstance(payload, dict): + return PersonProfileResult() + return PersonProfileResult( + summary=str(payload.get("summary", "") or ""), + traits=[str(item) for item in (payload.get("traits") or []) if str(item).strip()], + evidence=[item for item in (payload.get("evidence") or []) if isinstance(item, dict)], + ) + + async def search( + self, + query: str, + *, + limit: int = 5, + mode: str = "hybrid", + chat_id: str = "", + person_id: str = "", + time_start: str | float | None = None, + time_end: str | float | None = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = "", + ) -> MemorySearchResult: + clean_query = str(query or "").strip() + normalized_time_start = None if time_start in {None, ""} else time_start + normalized_time_end = None if time_end in {None, ""} else time_end + if not clean_query and normalized_time_start is None and normalized_time_end is None: + return MemorySearchResult() + try: + payload = await self._invoke( + "search_memory", + { + "query": clean_query, + "limit": max(1, int(limit)), + "mode": mode, + "chat_id": chat_id, + "person_id": person_id, + "time_start": normalized_time_start, + "time_end": normalized_time_end, + "respect_filter": bool(respect_filter), + "user_id": str(user_id or "").strip(), + "group_id": str(group_id or "").strip(), + }, + ) + return self._coerce_search_result(payload) + except Exception as exc: + logger.warning("长期记忆搜索失败: %s", exc) + return MemorySearchResult() + + async def ingest_summary( + self, + *, + external_id: str, + chat_id: str, + text: str, + participants: Optional[List[str]] = None, + time_start: float | None = None, + time_end: float | None = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = 
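
# 最小示意(需插件运行时在线;调用失败时返回空的 MemorySearchResult,不抛异常):
async def _demo_search() -> str:
    result = await memory_service.search("咖啡偏好", limit=3, chat_id="session-1")  # 示意参数
    if result.filtered:
        return ""
    return result.to_text(limit=3)  # 形如 "1. ...\n2. ...",单条截断到 160 字
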
"", + ) -> MemoryWriteResult: + try: + payload = await self._invoke( + "ingest_summary", + { + "external_id": external_id, + "chat_id": chat_id, + "text": text, + "participants": participants or [], + "time_start": time_start, + "time_end": time_end, + "tags": tags or [], + "metadata": metadata or {}, + "respect_filter": bool(respect_filter), + "user_id": str(user_id or "").strip(), + "group_id": str(group_id or "").strip(), + }, + ) + return self._coerce_write_result(payload) + except Exception as exc: + logger.warning("长期记忆写入摘要失败: %s", exc) + return MemoryWriteResult(success=False, detail=str(exc)) + + async def ingest_text( + self, + *, + external_id: str, + source_type: str, + text: str, + chat_id: str = "", + person_ids: Optional[List[str]] = None, + participants: Optional[List[str]] = None, + timestamp: float | None = None, + time_start: float | None = None, + time_end: float | None = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + entities: Optional[List[str]] = None, + relations: Optional[List[Dict[str, Any]]] = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = "", + ) -> MemoryWriteResult: + try: + payload = await self._invoke( + "ingest_text", + { + "external_id": external_id, + "source_type": source_type, + "text": text, + "chat_id": chat_id, + "person_ids": person_ids or [], + "participants": participants or [], + "timestamp": timestamp, + "time_start": time_start, + "time_end": time_end, + "tags": tags or [], + "metadata": metadata or {}, + "entities": entities or [], + "relations": relations or [], + "respect_filter": bool(respect_filter), + "user_id": str(user_id or "").strip(), + "group_id": str(group_id or "").strip(), + }, + ) + return self._coerce_write_result(payload) + except Exception as exc: + logger.warning("长期记忆写入文本失败: %s", exc) + return MemoryWriteResult(success=False, detail=str(exc)) + + async def get_person_profile(self, person_id: str, *, chat_id: str = "", limit: int = 10) -> PersonProfileResult: + clean_person_id = str(person_id or "").strip() + if not clean_person_id: + return PersonProfileResult() + try: + payload = await self._invoke( + "get_person_profile", + {"person_id": clean_person_id, "chat_id": chat_id, "limit": max(1, int(limit))}, + ) + return self._coerce_profile_result(payload) + except Exception as exc: + logger.warning("获取人物画像失败: %s", exc) + return PersonProfileResult() + + async def maintain_memory( + self, + *, + action: str, + target: str = "", + hours: float | None = None, + reason: str = "", + limit: int = 50, + ) -> MemoryWriteResult: + try: + payload = await self._invoke( + "maintain_memory", + {"action": action, "target": target, "hours": hours, "reason": reason, "limit": limit}, + ) + if not isinstance(payload, dict): + return MemoryWriteResult(success=False, detail="invalid_payload") + return MemoryWriteResult(success=bool(payload.get("success")), detail=str(payload.get("detail", "") or "")) + except Exception as exc: + logger.warning("记忆维护失败: %s", exc) + return MemoryWriteResult(success=False, detail=str(exc)) + + async def memory_stats(self) -> Dict[str, Any]: + try: + payload = await self._invoke("memory_stats", {}) + return payload if isinstance(payload, dict) else {} + except Exception as exc: + logger.warning("获取记忆统计失败: %s", exc) + return {} + + async def graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_graph_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("图谱管理调用失败: 
%s", exc) + return {"success": False, "error": str(exc)} + + async def source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_source_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("来源管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_episode_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("Episode 管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_profile_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("画像管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("运行时管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def import_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_import_admin", action=action, timeout_ms=timeout_ms, **kwargs) + except Exception as exc: + logger.warning("导入管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def tuning_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_tuning_admin", action=action, timeout_ms=timeout_ms, **kwargs) + except Exception as exc: + logger.warning("调优管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def v5_admin(self, *, action: str, timeout_ms: int = 30000, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_v5_admin", action=action, timeout_ms=timeout_ms, **kwargs) + except Exception as exc: + logger.warning("V5 记忆管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def delete_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_delete_admin", action=action, timeout_ms=timeout_ms, **kwargs) + except Exception as exc: + logger.warning("删除管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def get_recycle_bin(self, *, limit: int = 50) -> Dict[str, Any]: + try: + payload = await self._invoke("maintain_memory", {"action": "recycle_bin", "limit": max(1, int(limit or 50))}) + return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"} + except Exception as exc: + logger.warning("获取回收站失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def restore_memory(self, *, target: str) -> MemoryWriteResult: + return await self.maintain_memory(action="restore", target=target) + + async def reinforce_memory(self, *, target: str) -> MemoryWriteResult: + return await self.maintain_memory(action="reinforce", target=target) + + async def freeze_memory(self, *, target: str) -> MemoryWriteResult: + return await self.maintain_memory(action="freeze", target=target) + + async def protect_memory(self, *, target: str, hours: float | None = None) -> MemoryWriteResult: + return await self.maintain_memory(action="protect", target=target, hours=hours) + + +memory_service = MemoryService() diff --git 
a/src/webui/routers/__init__.py b/src/webui/routers/__init__.py index 65d63d02..687915d4 100644 --- a/src/webui/routers/__init__.py +++ b/src/webui/routers/__init__.py @@ -17,14 +17,14 @@ def get_all_routers() -> List[APIRouter]: from src.webui.api.planner import router as planner_router from src.webui.api.replier import router as replier_router from src.webui.routers.chat import router as chat_router - from src.webui.routers.knowledge import router as knowledge_router + from src.webui.routers.memory import compat_router as memory_compat_router from src.webui.routers.websocket.logs import router as logs_router from src.webui.routes import router as main_router return [ main_router, + memory_compat_router, logs_router, - knowledge_router, chat_router, planner_router, replier_router, diff --git a/src/webui/routers/memory.py b/src/webui/routers/memory.py new file mode 100644 index 00000000..d741affc --- /dev/null +++ b/src/webui/routers/memory.py @@ -0,0 +1,1395 @@ +from __future__ import annotations + +import json +import shutil +import uuid +from pathlib import Path +from typing import Any, Optional + +from fastapi import APIRouter, Body, Depends, File, Form, Query, UploadFile +from pydantic import BaseModel, Field + +from src.services.memory_service import MemorySearchResult, memory_service +from src.webui.dependencies import require_auth + + +router = APIRouter(prefix="/api/webui/memory", tags=["memory"], dependencies=[Depends(require_auth)]) +compat_router = APIRouter(prefix="/api", tags=["memory-compat"], dependencies=[Depends(require_auth)]) +STAGING_ROOT = Path(__file__).resolve().parents[3] / "data" / "memory_upload_staging" + + +class NodeRequest(BaseModel): + name: str = Field(..., min_length=1) + + +class NodeRenameRequest(BaseModel): + old_name: str = Field(..., min_length=1) + new_name: str = Field(..., min_length=1) + + +class EdgeCreateRequest(BaseModel): + subject: str = Field(..., min_length=1) + predicate: str = Field(..., min_length=1) + object: str = Field(..., min_length=1) + confidence: float = Field(1.0, ge=0.0) + + +class EdgeDeleteRequest(BaseModel): + hash: str = "" + subject: str = "" + object: str = "" + + +class EdgeWeightRequest(BaseModel): + hash: str = "" + subject: str = "" + object: str = "" + weight: float = Field(..., ge=0.0) + + +class SourceDeleteRequest(BaseModel): + source: str = Field(..., min_length=1) + + +class SourceBatchDeleteRequest(BaseModel): + sources: list[str] = Field(default_factory=list) + + +class EpisodeRebuildRequest(BaseModel): + source: str = "" + sources: list[str] = Field(default_factory=list) + all: bool = False + + +class EpisodeProcessPendingRequest(BaseModel): + limit: int = Field(20, ge=1, le=200) + max_retry: int = Field(3, ge=1, le=20) + + +class ProfileOverrideRequest(BaseModel): + person_id: str = Field(..., min_length=1) + override_text: str = "" + updated_by: str = "" + source: str = "webui" + + +class MaintainRequest(BaseModel): + target: str = Field(..., min_length=1) + hours: Optional[float] = None + + +class AutoSaveRequest(BaseModel): + enabled: bool + + +class TuningApplyProfileRequest(BaseModel): + profile: dict[str, Any] = Field(default_factory=dict) + reason: str = "manual" + + +class V5ActionRequest(BaseModel): + target: str = Field(..., min_length=1) + strength: Optional[float] = Field(default=None, ge=0.0) + reason: str = "" + updated_by: str = "webui" + + +class DeleteActionRequest(BaseModel): + mode: str = Field(..., min_length=1) + selector: dict[str, Any] | str = Field(default_factory=dict) + reason: str 
= "" + requested_by: str = "webui" + + +class DeleteRestoreRequest(BaseModel): + operation_id: str = "" + mode: str = "" + selector: dict[str, Any] | str = Field(default_factory=dict) + reason: str = "" + requested_by: str = "webui" + + +class DeletePurgeRequest(BaseModel): + grace_hours: Optional[float] = Field(default=None, ge=0.0) + limit: int = Field(1000, ge=1, le=5000) + + +def _build_import_guide_markdown(settings: dict[str, Any]) -> str: + path_aliases = settings.get("path_aliases") if isinstance(settings.get("path_aliases"), dict) else {} + alias_lines = [ + f"- `{name}` -> `{path}`" + for name, path in sorted(path_aliases.items()) + if str(name).strip() and str(path).strip() + ] + if not alias_lines: + alias_lines = ["- 当前未配置路径别名"] + return "\n".join( + [ + "# 长期记忆导入说明", + "", + "支持的导入方式:", + "- 上传文件:适合零散文档、日志、聊天导出文本。", + "- 粘贴文本:适合一次性导入少量整理好的内容。", + "- Raw Scan:扫描白名单目录内的原始文本文件。", + "- LPMM OpenIE / Convert:处理既有 LPMM 数据。", + "- Temporal Backfill:补回已有数据中的时间信息。", + "- MaiBot Migration:从宿主数据库迁移历史聊天记忆。", + "", + "当前路径别名:", + *alias_lines, + "", + "执行建议:", + "- 首次导入先小批量试跑,确认切分和抽取结果正常。", + "- 大批量导入时优先关注任务状态、失败块与重试结果。", + "- 若路径解析失败,请先检查路径别名与相对路径是否仍然有效。", + ] + ) + + +def _unwrap_payload(payload: dict[str, Any] | None) -> dict[str, Any]: + raw = payload if isinstance(payload, dict) else {} + nested = raw.get("payload") + if isinstance(nested, dict): + return dict(nested) + return dict(raw) + + +async def _graph_get(limit: int) -> dict: + return await memory_service.graph_admin(action="get_graph", limit=limit) + + +async def _graph_create_node(payload: NodeRequest) -> dict: + return await memory_service.graph_admin(action="create_node", name=payload.name) + + +async def _graph_delete_node(payload: NodeRequest) -> dict: + return await memory_service.graph_admin(action="delete_node", name=payload.name) + + +async def _graph_rename_node(payload: NodeRenameRequest) -> dict: + return await memory_service.graph_admin(action="rename_node", old_name=payload.old_name, new_name=payload.new_name) + + +async def _graph_create_edge(payload: EdgeCreateRequest) -> dict: + return await memory_service.graph_admin( + action="create_edge", + subject=payload.subject, + predicate=payload.predicate, + object=payload.object, + confidence=payload.confidence, + ) + + +async def _graph_delete_edge(payload: EdgeDeleteRequest) -> dict: + return await memory_service.graph_admin( + action="delete_edge", + hash=payload.hash, + subject=payload.subject, + object=payload.object, + ) + + +async def _graph_update_edge_weight(payload: EdgeWeightRequest) -> dict: + return await memory_service.graph_admin( + action="update_edge_weight", + hash=payload.hash, + subject=payload.subject, + object=payload.object, + weight=payload.weight, + ) + + +async def _source_list() -> dict: + return await memory_service.source_admin(action="list") + + +async def _source_delete(payload: SourceDeleteRequest) -> dict: + return await memory_service.source_admin(action="delete", source=payload.source) + + +async def _source_batch_delete(payload: SourceBatchDeleteRequest) -> dict: + return await memory_service.source_admin(action="batch_delete", sources=payload.sources) + + +async def _query_aggregate( + query: str, + *, + limit: int, + chat_id: str, + person_id: str, + time_start: float | None, + time_end: float | None, +) -> dict: + result: MemorySearchResult = await memory_service.search( + query, + limit=limit, + mode="aggregate", + chat_id=chat_id, + person_id=person_id, + time_start=time_start, + time_end=time_end, + respect_filter=False, + 
) + return {"success": True, **result.to_dict()} + + +async def _episode_list( + *, + query: str, + limit: int, + source: str, + person_id: str, + time_start: float | None, + time_end: float | None, +) -> dict: + return await memory_service.episode_admin( + action="list", + query=query, + limit=limit, + source=source, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +async def _episode_get(episode_id: str) -> dict: + return await memory_service.episode_admin(action="get", episode_id=episode_id) + + +async def _episode_rebuild(payload: EpisodeRebuildRequest) -> dict: + return await memory_service.episode_admin( + action="rebuild", + source=payload.source, + sources=payload.sources, + all=payload.all, + ) + + +async def _episode_status(limit: int) -> dict: + return await memory_service.episode_admin(action="status", limit=limit) + + +async def _episode_process_pending(payload: EpisodeProcessPendingRequest) -> dict: + return await memory_service.episode_admin( + action="process_pending", + limit=payload.limit, + max_retry=payload.max_retry, + ) + + +async def _profile_query(*, person_id: str, person_keyword: str, limit: int, force_refresh: bool) -> dict: + return await memory_service.profile_admin( + action="query", + person_id=person_id, + person_keyword=person_keyword, + limit=limit, + force_refresh=force_refresh, + ) + + +async def _profile_list(limit: int) -> dict: + return await memory_service.profile_admin(action="list", limit=limit) + + +async def _profile_set_override(payload: ProfileOverrideRequest) -> dict: + return await memory_service.profile_admin( + action="set_override", + person_id=payload.person_id, + override_text=payload.override_text, + updated_by=payload.updated_by, + source=payload.source, + ) + + +async def _profile_delete_override(person_id: str) -> dict: + return await memory_service.profile_admin(action="delete_override", person_id=person_id) + + +async def _runtime_save() -> dict: + return await memory_service.runtime_admin(action="save") + + +async def _runtime_config() -> dict: + return await memory_service.runtime_admin(action="get_config") + + +async def _runtime_self_check(refresh: bool) -> dict: + return await memory_service.runtime_admin(action="refresh_self_check" if refresh else "self_check") + + +async def _runtime_auto_save(enabled: bool | None = None) -> dict: + if enabled is None: + config = await memory_service.runtime_admin(action="get_config") + return {"success": bool(config.get("success", False)), "auto_save": bool(config.get("auto_save", False))} + return await memory_service.runtime_admin(action="set_auto_save", enabled=enabled) + + +async def _maintenance_recycle_bin(limit: int) -> dict: + return await memory_service.get_recycle_bin(limit=limit) + + +async def _maintenance_restore(payload: MaintainRequest) -> dict: + return (await memory_service.restore_memory(target=payload.target)).to_dict() + + +async def _maintenance_reinforce(payload: MaintainRequest) -> dict: + return (await memory_service.reinforce_memory(target=payload.target)).to_dict() + + +async def _maintenance_freeze(payload: MaintainRequest) -> dict: + return (await memory_service.freeze_memory(target=payload.target)).to_dict() + + +async def _maintenance_protect(payload: MaintainRequest) -> dict: + return (await memory_service.protect_memory(target=payload.target, hours=payload.hours)).to_dict() + + +async def _v5_status(target: str, limit: int) -> dict: + return await memory_service.v5_admin(action="status", target=target, limit=limit) + + +async def 
_v5_recycle_bin(limit: int) -> dict: + return await memory_service.v5_admin(action="recycle_bin", limit=limit) + + +async def _v5_action(action: str, payload: V5ActionRequest) -> dict: + kwargs: dict[str, Any] = { + "target": payload.target, + "reason": payload.reason, + "updated_by": payload.updated_by, + } + if payload.strength is not None: + kwargs["strength"] = payload.strength + return await memory_service.v5_admin(action=action, **kwargs) + + +async def _delete_preview(payload: DeleteActionRequest) -> dict: + return await memory_service.delete_admin(action="preview", mode=payload.mode, selector=payload.selector) + + +async def _delete_execute(payload: DeleteActionRequest) -> dict: + return await memory_service.delete_admin( + action="execute", + mode=payload.mode, + selector=payload.selector, + reason=payload.reason, + requested_by=payload.requested_by, + ) + + +async def _delete_restore(payload: DeleteRestoreRequest) -> dict: + return await memory_service.delete_admin( + action="restore", + mode=payload.mode, + selector=payload.selector, + operation_id=payload.operation_id, + reason=payload.reason, + requested_by=payload.requested_by, + ) + + +async def _delete_list(limit: int, mode: str) -> dict: + return await memory_service.delete_admin(action="list_operations", limit=limit, mode=mode) + + +async def _delete_get(operation_id: str) -> dict: + return await memory_service.delete_admin(action="get_operation", operation_id=operation_id) + + +async def _delete_purge(payload: DeletePurgeRequest) -> dict: + return await memory_service.delete_admin( + action="purge", + grace_hours=payload.grace_hours, + limit=payload.limit, + ) + + +async def _import_settings() -> dict: + return await memory_service.import_admin(action="get_settings") + + +async def _import_path_aliases() -> dict: + return await memory_service.import_admin(action="get_path_aliases") + + +async def _import_guide() -> dict: + payload = await memory_service.import_admin(action="get_guide") + if not isinstance(payload, dict): + payload = {"success": False, "error": "invalid_payload"} + if isinstance(payload.get("content"), str): + return payload + + settings = payload.get("settings") if isinstance(payload.get("settings"), dict) else None + if settings is None: + settings_payload = await memory_service.import_admin(action="get_settings") + settings = settings_payload.get("settings") if isinstance(settings_payload.get("settings"), dict) else {} + + return { + "success": True, + "source": "local", + "path": "generated://memory_import_guide", + "content": _build_import_guide_markdown(settings or {}), + "settings": settings or {}, + } + + +async def _import_resolve_path(payload: dict[str, Any]) -> dict: + return await memory_service.import_admin(action="resolve_path", **_unwrap_payload(payload)) + + +async def _import_create(action: str, payload: dict[str, Any]) -> dict: + return await memory_service.import_admin(action=action, **_unwrap_payload(payload)) + + +async def _import_list(limit: int) -> dict: + listing = await memory_service.import_admin(action="list", limit=limit) + if not isinstance(listing, dict): + listing = {"success": False, "items": []} + settings_payload = await memory_service.import_admin(action="get_settings") + settings = settings_payload.get("settings") if isinstance(settings_payload.get("settings"), dict) else {} + listing.setdefault("success", True) + listing.setdefault("items", []) + listing["settings"] = settings + return listing + + +async def _import_get(task_id: str, include_chunks: bool) -> dict: + 
return await memory_service.import_admin(action="get", task_id=task_id, include_chunks=include_chunks) + + +async def _import_chunks(task_id: str, file_id: str, offset: int, limit: int) -> dict: + return await memory_service.import_admin( + action="get_chunks", + task_id=task_id, + file_id=file_id, + offset=offset, + limit=limit, + ) + + +async def _import_cancel(task_id: str) -> dict: + return await memory_service.import_admin(action="cancel", task_id=task_id) + + +async def _import_retry(task_id: str, payload: dict[str, Any]) -> dict: + raw = _unwrap_payload(payload) + overrides = raw.get("overrides") if isinstance(raw.get("overrides"), dict) else raw + return await memory_service.import_admin(action="retry_failed", task_id=task_id, overrides=overrides) + + +async def _tuning_settings() -> dict: + return await memory_service.tuning_admin(action="get_settings") + + +async def _tuning_profile() -> dict: + profile = await memory_service.tuning_admin(action="get_profile") + if not isinstance(profile, dict): + profile = {"success": False, "profile": {}} + if not isinstance(profile.get("settings"), dict): + settings = await memory_service.tuning_admin(action="get_settings") + profile["settings"] = settings.get("settings") if isinstance(settings.get("settings"), dict) else {} + return profile + + +async def _tuning_apply_profile(payload: TuningApplyProfileRequest) -> dict: + return await memory_service.tuning_admin(action="apply_profile", profile=payload.profile, reason=payload.reason) + + +async def _tuning_rollback_profile() -> dict: + return await memory_service.tuning_admin(action="rollback_profile") + + +async def _tuning_export_profile() -> dict: + return await memory_service.tuning_admin(action="export_profile") + + +async def _tuning_create_task(payload: dict[str, Any]) -> dict: + return await memory_service.tuning_admin(action="create_task", payload=_unwrap_payload(payload)) + + +async def _tuning_list_tasks(limit: int) -> dict: + return await memory_service.tuning_admin(action="list_tasks", limit=limit) + + +async def _tuning_get_task(task_id: str, include_rounds: bool) -> dict: + return await memory_service.tuning_admin(action="get_task", task_id=task_id, include_rounds=include_rounds) + + +async def _tuning_get_rounds(task_id: str, offset: int, limit: int) -> dict: + return await memory_service.tuning_admin(action="get_rounds", task_id=task_id, offset=offset, limit=limit) + + +async def _tuning_cancel(task_id: str) -> dict: + return await memory_service.tuning_admin(action="cancel", task_id=task_id) + + +async def _tuning_apply_best(task_id: str) -> dict: + return await memory_service.tuning_admin(action="apply_best", task_id=task_id) + + +async def _tuning_report(task_id: str, fmt: str) -> dict: + payload = await memory_service.tuning_admin(action="get_report", task_id=task_id, format=fmt) + report = payload.get("report") if isinstance(payload.get("report"), dict) else {} + return { + "success": bool(payload.get("success", False)), + "format": report.get("format", fmt), + "content": report.get("content", ""), + "path": report.get("path", ""), + "error": payload.get("error", ""), + } + + +async def _stage_upload_files(files: list[UploadFile]) -> tuple[Path, list[dict[str, Any]]]: + STAGING_ROOT.mkdir(parents=True, exist_ok=True) + staging_dir = STAGING_ROOT / uuid.uuid4().hex + staging_dir.mkdir(parents=True, exist_ok=True) + staged_files: list[dict[str, Any]] = [] + for index, upload in enumerate(files): + filename = Path(upload.filename or f"upload_{index}.txt").name + target = 
staging_dir / f"{index:03d}_{filename}" + content = await upload.read() + target.write_bytes(content) + staged_files.append( + { + "filename": filename, + "staged_path": str(target.resolve()), + "size": len(content), + } + ) + return staging_dir, staged_files + + +@router.get("/graph") +async def get_memory_graph(limit: int = Query(200, ge=1, le=5000)): + return await _graph_get(limit) + + +@router.post("/graph/node") +async def create_memory_node(payload: NodeRequest): + return await _graph_create_node(payload) + + +@router.delete("/graph/node") +async def delete_memory_node(payload: NodeRequest): + return await _graph_delete_node(payload) + + +@router.post("/graph/node/rename") +async def rename_memory_node(payload: NodeRenameRequest): + return await _graph_rename_node(payload) + + +@router.post("/graph/edge") +async def create_memory_edge(payload: EdgeCreateRequest): + return await _graph_create_edge(payload) + + +@router.delete("/graph/edge") +async def delete_memory_edge(payload: EdgeDeleteRequest): + return await _graph_delete_edge(payload) + + +@router.post("/graph/edge/weight") +async def update_memory_edge_weight(payload: EdgeWeightRequest): + return await _graph_update_edge_weight(payload) + + +@router.get("/sources") +async def list_memory_sources(): + return await _source_list() + + +@router.post("/sources/delete") +async def delete_memory_source(payload: SourceDeleteRequest): + return await _source_delete(payload) + + +@router.post("/sources/batch-delete") +async def batch_delete_memory_sources(payload: SourceBatchDeleteRequest): + return await _source_batch_delete(payload) + + +@router.get("/query/aggregate") +async def query_memory_aggregate( + query: str = Query(""), + limit: int = Query(20, ge=1, le=200), + chat_id: str = Query(""), + person_id: str = Query(""), + time_start: float | None = Query(None), + time_end: float | None = Query(None), +): + return await _query_aggregate( + query, + limit=limit, + chat_id=chat_id, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +@router.get("/episodes") +async def list_memory_episodes( + query: str = Query(""), + limit: int = Query(20, ge=1, le=200), + source: str = Query(""), + person_id: str = Query(""), + time_start: float | None = Query(None), + time_end: float | None = Query(None), +): + return await _episode_list( + query=query, + limit=limit, + source=source, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +@router.get("/episodes/{episode_id}") +async def get_memory_episode(episode_id: str): + return await _episode_get(episode_id) + + +@router.post("/episodes/rebuild") +async def rebuild_memory_episodes(payload: EpisodeRebuildRequest): + return await _episode_rebuild(payload) + + +@router.get("/episodes/status") +async def get_memory_episode_status(limit: int = Query(20, ge=1, le=200)): + return await _episode_status(limit) + + +@router.post("/episodes/process-pending") +async def process_memory_episode_pending(payload: EpisodeProcessPendingRequest): + return await _episode_process_pending(payload) + + +@router.get("/profiles/query") +async def query_memory_profile( + person_id: str = Query(""), + person_keyword: str = Query(""), + limit: int = Query(12, ge=1, le=100), + force_refresh: bool = Query(False), +): + return await _profile_query( + person_id=person_id, + person_keyword=person_keyword, + limit=limit, + force_refresh=force_refresh, + ) + + +@router.get("/profiles") +async def list_memory_profiles(limit: int = Query(50, ge=1, le=200)): + return await 
_profile_list(limit) + + +@router.post("/profiles/override") +async def set_memory_profile_override(payload: ProfileOverrideRequest): + return await _profile_set_override(payload) + + +@router.delete("/profiles/override/{person_id}") +async def delete_memory_profile_override(person_id: str): + return await _profile_delete_override(person_id) + + +@router.post("/runtime/save") +async def save_memory_runtime(): + return await _runtime_save() + + +@router.get("/runtime/config") +async def get_memory_runtime_config(): + return await _runtime_config() + + +@router.get("/runtime/self-check") +async def get_memory_runtime_self_check(): + return await _runtime_self_check(False) + + +@router.post("/runtime/self-check/refresh") +async def refresh_memory_runtime_self_check(): + return await _runtime_self_check(True) + + +@router.get("/runtime/auto-save") +async def get_memory_runtime_auto_save(): + return await _runtime_auto_save(None) + + +@router.post("/runtime/auto-save") +async def set_memory_runtime_auto_save(payload: AutoSaveRequest): + return await _runtime_auto_save(payload.enabled) + + +@router.get("/maintenance/recycle-bin") +async def get_memory_recycle_bin(limit: int = Query(50, ge=1, le=200)): + return await _maintenance_recycle_bin(limit) + + +@router.post("/maintenance/restore") +async def restore_memory_relation(payload: MaintainRequest): + return await _maintenance_restore(payload) + + +@router.post("/maintenance/reinforce") +async def reinforce_memory_relation(payload: MaintainRequest): + return await _maintenance_reinforce(payload) + + +@router.post("/maintenance/freeze") +async def freeze_memory_relation(payload: MaintainRequest): + return await _maintenance_freeze(payload) + + +@router.post("/maintenance/protect") +async def protect_memory_relation(payload: MaintainRequest): + return await _maintenance_protect(payload) + + +@router.get("/v5/status") +async def get_memory_v5_status( + target: str = Query(""), + limit: int = Query(50, ge=1, le=200), +): + return await _v5_status(target, limit) + + +@router.get("/v5/recycle-bin") +async def get_memory_v5_recycle_bin(limit: int = Query(50, ge=1, le=200)): + return await _v5_recycle_bin(limit) + + +@router.post("/v5/reinforce") +async def reinforce_memory_v5(payload: V5ActionRequest): + return await _v5_action("reinforce", payload) + + +@router.post("/v5/weaken") +async def weaken_memory_v5(payload: V5ActionRequest): + return await _v5_action("weaken", payload) + + +@router.post("/v5/remember-forever") +async def remember_forever_memory_v5(payload: V5ActionRequest): + return await _v5_action("remember_forever", payload) + + +@router.post("/v5/forget") +async def forget_memory_v5(payload: V5ActionRequest): + return await _v5_action("forget", payload) + + +@router.post("/v5/restore") +async def restore_memory_v5(payload: V5ActionRequest): + return await _v5_action("restore", payload) + + +@router.post("/delete/preview") +async def preview_memory_delete(payload: DeleteActionRequest): + return await _delete_preview(payload) + + +@router.post("/delete/execute") +async def execute_memory_delete(payload: DeleteActionRequest): + return await _delete_execute(payload) + + +@router.post("/delete/restore") +async def restore_memory_delete(payload: DeleteRestoreRequest): + return await _delete_restore(payload) + + +@router.get("/delete/operations") +async def list_memory_delete_operations( + limit: int = Query(50, ge=1, le=200), + mode: str = Query(""), +): + return await _delete_list(limit, mode) + + 
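The delete endpoints above are thin pass-throughs to `memory_service.delete_admin`, so the preview/execute/restore flow can be exercised from a small script. A minimal sketch with `httpx`, assuming the WebUI listens on `localhost:8000` and `require_auth` accepts a bearer token (host, port, token scheme, and the selector shape are all assumptions, not fixed by this diff):

```python
import asyncio

import httpx

BASE = "http://localhost:8000/api/webui/memory"  # assumed host/port
HEADERS = {"Authorization": "Bearer <token>"}  # assumed scheme for require_auth


async def main() -> None:
    async with httpx.AsyncClient(headers=HEADERS, timeout=30.0) as client:
        # Dry-run first: preview reports what would be removed without touching data.
        preview = await client.post(
            f"{BASE}/delete/preview",
            json={"mode": "source", "selector": {"source": "chat_history:1"}},
        )
        print(preview.json())

        # Only call /delete/execute once the preview looks right; the operation id
        # it returns can later be passed to /delete/restore.

        # Audit trail: recent delete operations.
        ops = await client.get(f"{BASE}/delete/operations", params={"limit": 10})
        print(ops.json())


asyncio.run(main())
```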
+@router.get("/delete/operations/{operation_id}") +async def get_memory_delete_operation(operation_id: str): + return await _delete_get(operation_id) + + +@router.post("/delete/purge") +async def purge_memory_delete(payload: DeletePurgeRequest): + return await _delete_purge(payload) + + +@router.get("/import/settings") +async def get_memory_import_settings(): + return await _import_settings() + + +@router.get("/import/path-aliases") +async def get_memory_import_path_aliases(): + return await _import_path_aliases() + + +@router.get("/import/guide") +async def get_memory_import_guide(): + return await _import_guide() + + +@router.post("/import/resolve-path") +async def resolve_memory_import_path(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_resolve_path(payload) + + +@router.post("/import/upload") +async def create_memory_import_upload( + files: list[UploadFile] = File(...), + payload_json: str = Form("{}"), +): + staging_dir, staged_files = await _stage_upload_files(files) + try: + try: + payload = json.loads(payload_json or "{}") + except Exception: + payload = {} + if not isinstance(payload, dict): + payload = {} + payload["staged_files"] = staged_files + return await _import_create("create_upload", payload) + finally: + shutil.rmtree(staging_dir, ignore_errors=True) + + +@router.post("/import/paste") +async def create_memory_import_paste(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_paste", payload) + + +@router.post("/import/raw-scan") +async def create_memory_import_raw_scan(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_raw_scan", payload) + + +@router.post("/import/lpmm-openie") +async def create_memory_import_lpmm_openie(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_openie", payload) + + +@router.post("/import/lpmm-convert") +async def create_memory_import_lpmm_convert(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_convert", payload) + + +@router.post("/import/temporal-backfill") +async def create_memory_import_temporal_backfill(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_temporal_backfill", payload) + + +@router.post("/import/maibot-migration") +async def create_memory_import_maibot_migration(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_maibot_migration", payload) + + +@router.get("/import/tasks") +async def list_memory_import_tasks(limit: int = Query(50, ge=1, le=200)): + return await _import_list(limit) + + +@router.get("/import/tasks/{task_id}") +async def get_memory_import_task(task_id: str, include_chunks: bool = Query(False)): + return await _import_get(task_id, include_chunks) + + +@router.get("/import/tasks/{task_id}/chunks/{file_id}") +async def get_memory_import_chunks( + task_id: str, + file_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _import_chunks(task_id, file_id, offset, limit) + + +@router.post("/import/tasks/{task_id}/cancel") +async def cancel_memory_import_task(task_id: str): + return await _import_cancel(task_id) + + +@router.post("/import/tasks/{task_id}/retry") +async def retry_memory_import_task(task_id: str, payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_retry(task_id, payload) + + +@router.get("/retrieval_tuning/settings") +async def 
get_memory_tuning_settings(): + return await _tuning_settings() + + +@router.get("/retrieval_tuning/profile") +async def get_memory_tuning_profile(): + return await _tuning_profile() + + +@router.post("/retrieval_tuning/profile/apply") +async def apply_memory_tuning_profile(payload: TuningApplyProfileRequest): + return await _tuning_apply_profile(payload) + + +@router.post("/retrieval_tuning/profile/rollback") +async def rollback_memory_tuning_profile(): + return await _tuning_rollback_profile() + + +@router.get("/retrieval_tuning/profile/export") +async def export_memory_tuning_profile(): + return await _tuning_export_profile() + + +@router.post("/retrieval_tuning/tasks") +async def create_memory_tuning_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _tuning_create_task(payload) + + +@router.get("/retrieval_tuning/tasks") +async def list_memory_tuning_tasks(limit: int = Query(50, ge=1, le=200)): + return await _tuning_list_tasks(limit) + + +@router.get("/retrieval_tuning/tasks/{task_id}") +async def get_memory_tuning_task(task_id: str, include_rounds: bool = Query(False)): + return await _tuning_get_task(task_id, include_rounds) + + +@router.get("/retrieval_tuning/tasks/{task_id}/rounds") +async def get_memory_tuning_rounds( + task_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _tuning_get_rounds(task_id, offset, limit) + + +@router.post("/retrieval_tuning/tasks/{task_id}/cancel") +async def cancel_memory_tuning_task(task_id: str): + return await _tuning_cancel(task_id) + + +@router.post("/retrieval_tuning/tasks/{task_id}/apply-best") +async def apply_best_memory_tuning_profile(task_id: str): + return await _tuning_apply_best(task_id) + + +@router.get("/retrieval_tuning/tasks/{task_id}/report") +async def get_memory_tuning_report(task_id: str, format: str = Query("md")): + return await _tuning_report(task_id, format) + + +@compat_router.get("/graph") +async def compat_get_graph(limit: int = Query(200, ge=1, le=5000)): + return await _graph_get(limit) + + +@compat_router.post("/node") +async def compat_create_node(payload: NodeRequest): + return await _graph_create_node(payload) + + +@compat_router.delete("/node") +async def compat_delete_node(payload: NodeRequest): + return await _graph_delete_node(payload) + + +@compat_router.post("/node/rename") +async def compat_rename_node(payload: NodeRenameRequest): + return await _graph_rename_node(payload) + + +@compat_router.post("/edge") +async def compat_create_edge(payload: EdgeCreateRequest): + return await _graph_create_edge(payload) + + +@compat_router.delete("/edge") +async def compat_delete_edge(payload: EdgeDeleteRequest): + return await _graph_delete_edge(payload) + + +@compat_router.post("/edge/weight") +async def compat_update_edge_weight(payload: EdgeWeightRequest): + return await _graph_update_edge_weight(payload) + + +@compat_router.get("/source/list") +async def compat_list_sources(): + return await _source_list() + + +@compat_router.post("/source/delete") +async def compat_delete_source(payload: SourceDeleteRequest): + return await _source_delete(payload) + + +@compat_router.post("/source/batch_delete") +async def compat_batch_delete_sources(payload: SourceBatchDeleteRequest): + return await _source_batch_delete(payload) + + +@compat_router.get("/query/aggregate") +async def compat_query_aggregate( + query: str = Query(""), + limit: int = Query(20, ge=1, le=200), + chat_id: str = Query(""), + person_id: str = Query(""), + time_start: float | None = 
Query(None), + time_end: float | None = Query(None), +): + return await _query_aggregate( + query, + limit=limit, + chat_id=chat_id, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +@compat_router.get("/episodes") +async def compat_list_episodes( + query: str = Query(""), + limit: int = Query(20, ge=1, le=200), + source: str = Query(""), + person_id: str = Query(""), + time_start: float | None = Query(None), + time_end: float | None = Query(None), +): + return await _episode_list( + query=query, + limit=limit, + source=source, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +# Keep the static /episodes/status route ahead of /episodes/{episode_id} on the compat surface as well, so the legacy path is not shadowed by the path parameter. +@compat_router.get("/episodes/status") +async def compat_episode_status(limit: int = Query(20, ge=1, le=200)): + return await _episode_status(limit) + + +@compat_router.get("/episodes/{episode_id}") +async def compat_get_episode(episode_id: str): + return await _episode_get(episode_id) + + +@compat_router.post("/episodes/rebuild") +async def compat_rebuild_episodes(payload: EpisodeRebuildRequest): + return await _episode_rebuild(payload) + + +@compat_router.post("/episodes/process_pending") +async def compat_process_episode_pending(payload: EpisodeProcessPendingRequest): + return await _episode_process_pending(payload) + + +@compat_router.get("/person_profile/query") +async def compat_profile_query( + person_id: str = Query(""), + person_keyword: str = Query(""), + limit: int = Query(12, ge=1, le=100), + force_refresh: bool = Query(False), +): + return await _profile_query( + person_id=person_id, + person_keyword=person_keyword, + limit=limit, + force_refresh=force_refresh, + ) + + +@compat_router.get("/person_profile/list") +async def compat_profile_list(limit: int = Query(50, ge=1, le=200)): + return await _profile_list(limit) + + +@compat_router.post("/person_profile/override") +async def compat_set_profile_override(payload: ProfileOverrideRequest): + return await _profile_set_override(payload) + + +@compat_router.delete("/person_profile/override/{person_id}") +async def compat_delete_profile_override(person_id: str): + return await _profile_delete_override(person_id) + + +@compat_router.post("/save") +async def compat_runtime_save(): + return await _runtime_save() + + +@compat_router.get("/config") +async def compat_runtime_config(): + return await _runtime_config() + + +@compat_router.get("/runtime/self_check") +async def compat_runtime_self_check(): + return await _runtime_self_check(False) + + +@compat_router.post("/runtime/self_check/refresh") +async def compat_refresh_runtime_self_check(): + return await _runtime_self_check(True) + + +@compat_router.get("/config/auto_save") +async def compat_runtime_auto_save(): + return await _runtime_auto_save(None) + + +@compat_router.post("/config/auto_save") +async def compat_set_runtime_auto_save(payload: AutoSaveRequest): + return await _runtime_auto_save(payload.enabled) + + +@compat_router.get("/memory/recycle_bin") +async def compat_get_recycle_bin(limit: int = Query(50, ge=1, le=200)): + return await _maintenance_recycle_bin(limit) + + +@compat_router.post("/memory/restore") +async def compat_restore_memory(payload: MaintainRequest): + return await _maintenance_restore(payload) + + +@compat_router.post("/memory/reinforce") +async def compat_reinforce_memory(payload: MaintainRequest): + return await _maintenance_reinforce(payload) + + +@compat_router.post("/memory/freeze") +async def compat_freeze_memory(payload: MaintainRequest): + return await _maintenance_freeze(payload) + +
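Because every compat endpoint delegates to the same module-level helper as its `/api/webui/memory/...` counterpart, the two surfaces should stay behavior-identical. A sketch of how that could be checked with FastAPI's `TestClient`, with `require_auth` overridden for the test app (the override and the equality assertion are illustrative, not part of this diff):

```python
from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.webui.dependencies import require_auth
from src.webui.routers.memory import compat_router, router

app = FastAPI()
app.dependency_overrides[require_auth] = lambda: None  # skip auth in the test app
app.include_router(router)
app.include_router(compat_router)

client = TestClient(app)


def test_legacy_and_new_source_listing_match():
    # Both routes call the shared _source_list() helper, so their payloads
    # should be identical even when the plugin runtime is unavailable.
    new = client.get("/api/webui/memory/sources")
    old = client.get("/api/source/list")
    assert new.status_code == old.status_code == 200
    assert new.json() == old.json()
```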
+@compat_router.post("/memory/protect") +async def compat_protect_memory(payload: MaintainRequest): + return await _maintenance_protect(payload) + + +@compat_router.get("/import/settings") +async def compat_import_settings(): + return await _import_settings() + + +@compat_router.get("/import/path_aliases") +async def compat_import_path_aliases(): + return await _import_path_aliases() + + +@compat_router.get("/import/guide") +async def compat_import_guide(): + return await _import_guide() + + +@compat_router.post("/import/resolve_path") +async def compat_import_resolve_path(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_resolve_path(payload) + + +@compat_router.post("/import/upload") +async def compat_import_upload( + files: list[UploadFile] = File(...), + payload_json: str = Form("{}"), +): + return await create_memory_import_upload(files=files, payload_json=payload_json) + + +@compat_router.post("/import/tasks/upload") +async def compat_import_upload_task( + files: list[UploadFile] = File(...), + payload_json: str = Form("{}"), +): + return await create_memory_import_upload(files=files, payload_json=payload_json) + + +@compat_router.post("/import/paste") +async def compat_import_paste(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_paste", payload) + + +@compat_router.post("/import/tasks/paste") +async def compat_import_paste_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_paste", payload) + + +@compat_router.post("/import/raw_scan") +async def compat_import_raw_scan(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_raw_scan", payload) + + +@compat_router.post("/import/tasks/raw_scan") +async def compat_import_raw_scan_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_raw_scan", payload) + + +@compat_router.post("/import/lpmm_openie") +async def compat_import_lpmm_openie(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_openie", payload) + + +@compat_router.post("/import/tasks/lpmm_openie") +async def compat_import_lpmm_openie_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_openie", payload) + + +@compat_router.post("/import/lpmm_convert") +async def compat_import_lpmm_convert(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_convert", payload) + + +@compat_router.post("/import/tasks/lpmm_convert") +async def compat_import_lpmm_convert_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_convert", payload) + + +@compat_router.post("/import/temporal_backfill") +async def compat_import_temporal_backfill(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_temporal_backfill", payload) + + +@compat_router.post("/import/tasks/temporal_backfill") +async def compat_import_temporal_backfill_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_temporal_backfill", payload) + + +@compat_router.post("/import/maibot_migration") +async def compat_import_maibot_migration(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_maibot_migration", payload) + + +@compat_router.post("/import/tasks/maibot_migration") +async def compat_import_maibot_migration_task(payload: 
dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_maibot_migration", payload) + + +@compat_router.get("/import/tasks") +async def compat_import_list(limit: int = Query(50, ge=1, le=200)): + return await _import_list(limit) + + +@compat_router.get("/import/tasks/{task_id}") +async def compat_import_get(task_id: str, include_chunks: bool = Query(False)): + return await _import_get(task_id, include_chunks) + + +@compat_router.get("/import/tasks/{task_id}/chunks/{file_id}") +async def compat_import_chunks( + task_id: str, + file_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _import_chunks(task_id, file_id, offset, limit) + + +@compat_router.get("/import/tasks/{task_id}/files/{file_id}/chunks") +async def compat_import_file_chunks( + task_id: str, + file_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _import_chunks(task_id, file_id, offset, limit) + + +@compat_router.post("/import/tasks/{task_id}/cancel") +async def compat_import_cancel(task_id: str): + return await _import_cancel(task_id) + + +@compat_router.post("/import/tasks/{task_id}/retry") +async def compat_import_retry(task_id: str, payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_retry(task_id, payload) + + +@compat_router.post("/import/tasks/{task_id}/retry_failed") +async def compat_import_retry_failed(task_id: str, payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_retry(task_id, payload) + + +@compat_router.get("/retrieval_tuning/settings") +async def compat_tuning_settings(): + return await _tuning_settings() + + +@compat_router.get("/retrieval_tuning/profile") +async def compat_tuning_profile(): + return await _tuning_profile() + + +@compat_router.post("/retrieval_tuning/profile/apply") +async def compat_apply_tuning_profile(payload: TuningApplyProfileRequest): + return await _tuning_apply_profile(payload) + + +@compat_router.post("/retrieval_tuning/profile/rollback") +async def compat_rollback_tuning_profile(): + return await _tuning_rollback_profile() + + +@compat_router.get("/retrieval_tuning/profile/export") +async def compat_export_tuning_profile(): + return await _tuning_export_profile() + + +@compat_router.get("/retrieval_tuning/profile/export_toml") +async def compat_export_tuning_profile_toml(): + return await _tuning_export_profile() + + +@compat_router.post("/retrieval_tuning/tasks") +async def compat_create_tuning_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _tuning_create_task(payload) + + +@compat_router.get("/retrieval_tuning/tasks") +async def compat_list_tuning_tasks(limit: int = Query(50, ge=1, le=200)): + return await _tuning_list_tasks(limit) + + +@compat_router.get("/retrieval_tuning/tasks/{task_id}") +async def compat_get_tuning_task(task_id: str, include_rounds: bool = Query(False)): + return await _tuning_get_task(task_id, include_rounds) + + +@compat_router.get("/retrieval_tuning/tasks/{task_id}/rounds") +async def compat_get_tuning_rounds( + task_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _tuning_get_rounds(task_id, offset, limit) + + +@compat_router.post("/retrieval_tuning/tasks/{task_id}/cancel") +async def compat_cancel_tuning_task(task_id: str): + return await _tuning_cancel(task_id) + + +@compat_router.post("/retrieval_tuning/tasks/{task_id}/apply_best") +async def compat_apply_best_tuning_profile(task_id: str): + return 
await _tuning_apply_best(task_id) + + +@compat_router.post("/retrieval_tuning/tasks/{task_id}/apply-best") +async def compat_apply_best_tuning_profile_kebab(task_id: str): + return await _tuning_apply_best(task_id) + + +@compat_router.get("/retrieval_tuning/tasks/{task_id}/report") +async def compat_get_tuning_report(task_id: str, format: str = Query("md")): + return await _tuning_report(task_id, format) diff --git a/src/webui/routes.py b/src/webui/routes.py index c1a7e446..5e33e78b 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -16,6 +16,7 @@ from src.webui.routers.config import router as config_router from src.webui.routers.emoji import router as emoji_router from src.webui.routers.expression import router as expression_router from src.webui.routers.jargon import router as jargon_router +from src.webui.routers.memory import router as memory_router from src.webui.routers.model import router as model_router from src.webui.routers.person import router as person_router from src.webui.routers.plugin import get_progress_router @@ -49,6 +50,8 @@ router.include_router(get_progress_router()) router.include_router(system_router) # 注册模型列表获取路由 router.include_router(model_router) +# 注册长期记忆管理路由 +router.include_router(memory_router) # 注册 WebSocket 认证路由 router.include_router(ws_auth_router)
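End to end, the routers and automation hooks registered above all sit on top of the `memory_service` facade. A minimal sketch of driving it directly from async code (ids, timestamps, and query text are illustrative; each method degrades to an empty or failed result instead of raising):

```python
import asyncio

from src.services.memory_service import memory_service


async def demo() -> None:
    # Write a conversation summary into long-term memory.
    write = await memory_service.ingest_summary(
        external_id="chat_history:42",  # illustrative id
        chat_id="session-1",
        text="主题:旅行计划\n概括:我们讨论了春游安排",
        participants=["Alice", "Bob"],
        time_start=1700000000.0,
        time_end=1700000600.0,
    )
    if not write.success:
        print("ingest failed:", write.detail)

    # Later, retrieve it back; on plugin-runtime errors this returns an
    # empty MemorySearchResult rather than propagating the exception.
    result = await memory_service.search("春游安排", limit=3, chat_id="session-1")
    print(result.to_text())


asyncio.run(demo())
```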