feat:新增记忆测试、检索工具与服务

新增完整的长期记忆支持及测试:引入中文记忆检索提示词、query_long_term_memory 检索工具、记忆服务与记忆流程服务,以及 WebUI 的记忆路由。新增大规模测试套件(包括单元测试与基准/在线测试),覆盖聊天历史摘要、知识获取器、事件(episode)生成、写回机制以及用户画像检索等功能。

更新多个模块以集成记忆检索能力(包括 knowledge fetcher、chat summarizer、memory_retrieval、person_info、config/legacy 迁移以及 WebUI 路由),并移除遗留的 lpmm 知识模块。这些变更完成了记忆运行时的接入,同时为基准测试提供嵌入适配器的 mock,并支持新测试与工具所需的导入与 episode 处理流程。
This commit is contained in:
DawnARC
2026-03-18 21:35:17 +08:00
parent 999e7246e2
commit bd84e500e1
39 changed files with 5849 additions and 764 deletions

View File

@@ -0,0 +1,26 @@
你是一个专门获取长期记忆的助手。你的名字是{bot_name}。现在是{time_now}。
群里正在进行的聊天内容:
{chat_history}
现在,{sender}发送了内容:{target_message},你想要回复ta。
请仔细分析聊天内容,考虑以下几点:
1. 内容中是否包含需要查询历史知识或长期记忆的问题
2. 是否有明确的知识获取指令
如果需要使用长期记忆工具,请直接调用函数 `search_long_term_memory`;如果不需要任何工具,直接输出 `No tool needed`。
工具模式说明:
- `mode="search"`:普通长期记忆检索,适合查具体事实、偏好、历史对话内容
- `mode="time"`:按时间范围检索,必须同时提供 `time_expression`
- `mode="episode"`:按事件/情节检索,适合查“那次经历”“那件事的经过”
- `mode="aggregate"`:综合检索,适合“整体回忆一下”“把相关线索综合找出来”
优先规则:
- 问“某段时间发生了什么”:优先 `time`
- 问“某次事件/某段经历”:优先 `episode`
- 问“整体情况/最近发生过什么”:优先 `aggregate`
- 问单点事实:优先 `search`
`time_expression` 可用表达:
- `今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
- 或绝对时间:`2026/03/18`、`2026/03/18 09:30`

View File

@@ -0,0 +1,34 @@
你的名字是{bot_name}。现在是{time_now}。
你正在参与聊天,你需要搜集信息来帮助你进行回复。
重要,这是当前聊天记录:
{chat_history}
聊天记录结束
已收集的信息:
{collected_info}
- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
- 思考完毕后,使用工具
**工具说明:**
- 如果涉及过往事件、历史对话、用户长期偏好或某段时间发生的事件,可以使用长期记忆查询工具
- 如果遇到不熟悉的词语、缩写、黑话或网络用语可以使用query_words工具查询其含义
- 你必须使用tool如果需要查询你必须给出使用什么工具进行查询
- 当你决定结束查询时必须调用return_information工具返回总结信息并结束查询
长期记忆工具 `search_long_term_memory` 支持以下模式:
- `mode="search"`:普通事实/偏好/历史内容检索。适合问“她喜欢什么”“我们之前讨论过什么”。
- `mode="time"`按时间范围检索。适合问“昨天发生了什么”“最近7天有哪些相关记忆”。
- `mode="episode"`:按事件/情节检索。适合问“那次灯塔停电的经过是什么”“关于某次经历还有什么”。
- `mode="aggregate"`:综合检索。适合问“帮我整体回忆一下这个人最近的情况”“把相关线索综合找出来”。
模式选择建议:
- 问单点事实、偏好、人设、具体信息:优先 `search`
- 问某段时间发生了什么:优先 `time`
- 问某次事件、某段经历、某个剧情片段:优先 `episode`
- 问整体回忆、综合找线索、总结最近发生的事:优先 `aggregate`
时间模式要求:
- 使用 `mode="time"` 时,必须填写 `time_expression`
- 可用时间表达包括:`今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
- 也可以使用绝对时间:`2026/03/18`、`2026/03/18 09:30`

View File

@@ -0,0 +1,148 @@
from types import SimpleNamespace
import pytest
from src.memory_system import chat_history_summarizer as summarizer_module
def _build_summarizer() -> summarizer_module.ChatHistorySummarizer:
summarizer = summarizer_module.ChatHistorySummarizer.__new__(summarizer_module.ChatHistorySummarizer)
summarizer.session_id = "session-1"
summarizer.log_prefix = "[session-1]"
return summarizer
@pytest.mark.asyncio
async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch):
calls = []
summarizer = _build_summarizer()
async def fake_ingest_summary(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
monkeypatch.setattr(
summarizer_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
)
monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8)))
monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)
await summarizer._import_to_long_term_memory(
record_id=1,
theme="旅行计划",
summary="我们讨论了春游安排",
participants=["Alice", "Bob"],
start_time=1.0,
end_time=2.0,
original_text="long text",
)
assert len(calls) == 1
payload = calls[0]
assert payload["external_id"] == "chat_history:1"
assert payload["chat_id"] == "session-1"
assert payload["participants"] == ["Alice", "Bob"]
assert payload["respect_filter"] is True
assert payload["user_id"] == "user-1"
assert payload["group_id"] == ""
assert "主题:旅行计划" in payload["text"]
assert "概括:我们讨论了春游安排" in payload["text"]
@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch):
summarizer = _build_summarizer()
fallback_calls = []
async def fake_fallback(**kwargs):
fallback_calls.append(kwargs)
summarizer._fallback_import_to_long_term_memory = fake_fallback
monkeypatch.setattr(
summarizer_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
)
await summarizer._import_to_long_term_memory(
record_id=2,
theme="",
summary="",
participants=[],
start_time=10.0,
end_time=20.0,
original_text="raw chat",
)
assert len(fallback_calls) == 1
assert fallback_calls[0]["record_id"] == 2
assert fallback_calls[0]["original_text"] == "raw chat"
@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch):
summarizer = _build_summarizer()
fallback_calls = []
async def fake_ingest_summary(**kwargs):
return SimpleNamespace(success=False, detail="boom", stored_ids=[])
async def fake_fallback(**kwargs):
fallback_calls.append(kwargs)
summarizer._fallback_import_to_long_term_memory = fake_fallback
monkeypatch.setattr(
summarizer_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")),
)
monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)
await summarizer._import_to_long_term_memory(
record_id=3,
theme="电影",
summary="聊了电影推荐",
participants=["Alice"],
start_time=3.0,
end_time=4.0,
original_text="raw",
)
assert len(fallback_calls) == 1
assert fallback_calls[0]["theme"] == "电影"
@pytest.mark.asyncio
async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch):
calls = []
summarizer = _build_summarizer()
async def fake_ingest_summary(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[])
monkeypatch.setattr(
summarizer_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")),
)
monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12)))
monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)
await summarizer._fallback_import_to_long_term_memory(
record_id=4,
theme="工作",
participants=["Alice"],
start_time=5.0,
end_time=6.0,
original_text="a" * 128,
)
assert len(calls) == 1
metadata = calls[0]["metadata"]
assert metadata["generate_from_chat"] is True
assert metadata["context_length"] == 12
assert calls[0]["respect_filter"] is True

View File

@@ -0,0 +1,127 @@
from types import SimpleNamespace
import pytest
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.services.memory_service import MemoryHit, MemorySearchResult
def test_knowledge_fetcher_resolves_private_memory_context(monkeypatch):
monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object())
monkeypatch.setattr(
knowledge_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
)
monkeypatch.setattr(
knowledge_module,
"resolve_person_id_for_memory",
lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}",
)
fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
assert fetcher._resolve_private_memory_context() == {
"chat_id": "stream-1",
"person_id": "Alice:qq:42",
"user_id": "42",
"group_id": "",
}
@pytest.mark.asyncio
async def test_knowledge_fetcher_memory_get_knowledge_uses_memory_service(monkeypatch):
monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object())
monkeypatch.setattr(
knowledge_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
)
monkeypatch.setattr(
knowledge_module,
"resolve_person_id_for_memory",
lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}",
)
calls = []
async def fake_search(query: str, **kwargs):
calls.append((query, kwargs))
return MemorySearchResult(summary="", hits=[MemoryHit(content="她喜欢猫", source="person_fact:qq:42")], filtered=False)
monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)
fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
result = await fetcher._memory_get_knowledge("她喜欢什么")
assert "1. 她喜欢猫" in result
assert calls == [
(
"她喜欢什么",
{
"limit": 5,
"mode": "search",
"chat_id": "stream-1",
"person_id": "Alice:qq:42",
"user_id": "42",
"group_id": "",
"respect_filter": True,
},
)
]
@pytest.mark.asyncio
async def test_knowledge_fetcher_falls_back_to_chat_scope_when_person_scope_misses(monkeypatch):
monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object())
monkeypatch.setattr(
knowledge_module,
"_chat_manager",
SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
)
monkeypatch.setattr(
knowledge_module,
"resolve_person_id_for_memory",
lambda *, person_name, platform, user_id: "person-1",
)
calls = []
async def fake_search(query: str, **kwargs):
calls.append((query, kwargs))
if kwargs.get("person_id"):
return MemorySearchResult(summary="", hits=[], filtered=False)
return MemorySearchResult(summary="", hits=[MemoryHit(content="她计划去杭州音乐节", source="chat_summary:stream-1")], filtered=False)
monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)
fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
result = await fetcher._memory_get_knowledge("Alice 最近在忙什么")
assert "杭州音乐节" in result
assert calls == [
(
"Alice 最近在忙什么",
{
"limit": 5,
"mode": "search",
"chat_id": "stream-1",
"person_id": "person-1",
"user_id": "42",
"group_id": "",
"respect_filter": True,
},
),
(
"Alice 最近在忙什么",
{
"limit": 5,
"mode": "search",
"chat_id": "stream-1",
"person_id": "",
"user_id": "42",
"group_id": "",
"respect_filter": True,
},
),
]

View File

@@ -0,0 +1,35 @@
from src.config.legacy_migration import try_migrate_legacy_bot_config_dict
def test_legacy_learning_list_with_numeric_fourth_column_is_migrated():
payload = {
"expression": {
"learning_list": [
["qq:123456:group", "enable", "disable", "0.5"],
["", "disable", "enable", "0.1"],
]
}
}
result = try_migrate_legacy_bot_config_dict(payload)
assert result.migrated is True
assert "expression.learning_list" in result.reason
assert result.data["expression"]["learning_list"] == [
{
"platform": "qq",
"item_id": "123456",
"rule_type": "group",
"use_expression": True,
"enable_learning": False,
"enable_jargon_learning": False,
},
{
"platform": "",
"item_id": "",
"rule_type": "group",
"use_expression": False,
"enable_learning": True,
"enable_jargon_learning": False,
},
]

View File

@@ -0,0 +1,691 @@
from __future__ import annotations
import asyncio
import inspect
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
import numpy as np
import pytest
import pytest_asyncio
from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import MemorySearchResult, memory_service
DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json"
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json"
def _load_benchmark_fixture() -> Dict[str, Any]:
return json.loads(DATA_FILE.read_text(encoding="utf-8"))
class _FakeEmbeddingAdapter:
def __init__(self, dimension: int = 32) -> None:
self.dimension = dimension
async def _detect_dimension(self) -> int:
return self.dimension
async def encode(self, texts, dimensions=None):
dim = int(dimensions or self.dimension)
if isinstance(texts, str):
sequence = [texts]
single = True
else:
sequence = list(texts)
single = False
rows = []
for text in sequence:
vec = np.zeros(dim, dtype=np.float32)
for ch in str(text or ""):
code = ord(ch)
vec[code % dim] += 1.0
vec[(code * 7) % dim] += 0.5
if not vec.any():
vec[0] = 1.0
norm = np.linalg.norm(vec)
if norm > 0:
vec = vec / norm
rows.append(vec)
payload = np.vstack(rows)
return payload[0] if single else payload
class _KnownPerson:
def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None:
self.person_id = person_id
self.is_known = person_id in reverse_registry
self.person_name = reverse_registry.get(person_id, "")
self._registry = registry
class _KernelBackedRuntimeManager:
is_running = True
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke_plugin(
self,
*,
method: str,
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None,
timeout_ms: int,
):
del method, plugin_id, timeout_ms
payload = args or {}
if component_name == "search_memory":
return await self.kernel.search_memory(
KernelSearchRequest(
query=str(payload.get("query", "") or ""),
limit=int(payload.get("limit", 5) or 5),
mode=str(payload.get("mode", "hybrid") or "hybrid"),
chat_id=str(payload.get("chat_id", "") or ""),
person_id=str(payload.get("person_id", "") or ""),
time_start=payload.get("time_start"),
time_end=payload.get("time_end"),
respect_filter=bool(payload.get("respect_filter", True)),
user_id=str(payload.get("user_id", "") or ""),
group_id=str(payload.get("group_id", "") or ""),
)
)
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]:
for _ in range(max_rounds):
detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
task = detail.get("task") or {}
status = str(task.get("status", "") or "")
if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
return detail
await asyncio.sleep(max(0.01, float(sleep_seconds)))
raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
def _join_hit_content(search_result: MemorySearchResult) -> str:
return "\n".join(hit.content for hit in search_result.hits)
def _keyword_hits(text: str, keywords: List[str]) -> int:
haystack = str(text or "")
return sum(1 for keyword in keywords if keyword in haystack)
def _keyword_recall(text: str, keywords: List[str]) -> float:
if not keywords:
return 1.0
return _keyword_hits(text, keywords) / float(len(keywords))
def _hit_blob(hit) -> str:
meta = hit.metadata if isinstance(hit.metadata, dict) else {}
return "\n".join(
[
str(hit.content or ""),
str(hit.title or ""),
str(hit.source or ""),
json.dumps(meta, ensure_ascii=False),
]
)
def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int:
for index, hit in enumerate(search_result.hits[:5], start=1):
if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))):
return index
return 0
def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str:
return "\n".join(
(
f"{item.get('title', '')}\n"
f"{item.get('summary', '')}\n"
f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n"
f"{json.dumps(item.get('participants', []), ensure_ascii=False)}"
)
for item in items
)
def _episode_blob_from_hits(search_result: MemorySearchResult) -> str:
chunks = []
for hit in search_result.hits:
meta = hit.metadata if isinstance(hit.metadata, dict) else {}
chunks.append(
"\n".join(
[
str(hit.title or ""),
str(hit.content or ""),
json.dumps(meta.get("keywords", []) or [], ensure_ascii=False),
json.dumps(meta.get("participants", []) or [], ensure_ascii=False),
]
)
)
return "\n".join(chunks)
async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
episode_source = f"chat_summary:{session_id}"
payload = await memory_service.episode_admin(
action="query",
source=episode_source,
limit=20,
)
items = payload.get("items") or []
blob = _episode_blob_from_items(items)
reports: List[Dict[str, Any]] = []
success_rate = 0.0
keyword_recall = 0.0
for case in episode_cases:
recall = _keyword_recall(blob, case["expected_keywords"])
success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
success_rate += 1.0 if success else 0.0
keyword_recall += recall
reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"episode_count": len(items),
"top_episode": items[0] if items else None,
}
)
total = max(1, len(episode_cases))
return {
"success_rate": round(success_rate / total, 4),
"keyword_recall": round(keyword_recall / total, 4),
"episode_count": len(items),
"reports": reports,
}
async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
reports: List[Dict[str, Any]] = []
success_rate = 0.0
keyword_recall = 0.0
episode_source = f"chat_summary:{session_id}"
for case in episode_cases:
payload = await memory_service.episode_admin(
action="query",
source=episode_source,
query=case["query"],
limit=5,
)
items = payload.get("items") or []
blob = "\n".join(
f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}"
for item in items
)
recall = _keyword_recall(blob, case["expected_keywords"])
success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
success_rate += 1.0 if success else 0.0
keyword_recall += recall
reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"episode_count": len(items),
"top_episode": items[0] if items else None,
}
)
total = max(1, len(episode_cases))
return {
"success_rate": round(success_rate / total, 4),
"keyword_recall": round(keyword_recall / total, 4),
"reports": reports,
}
async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
reports: List[Dict[str, Any]] = []
success_rate = 0.0
keyword_recall = 0.0
for case in episode_cases:
result = await memory_service.search(
case["query"],
mode="episode",
chat_id=session_id,
respect_filter=False,
limit=5,
)
blob = _episode_blob_from_hits(result)
recall = _keyword_recall(blob, case["expected_keywords"])
success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0))
success_rate += 1.0 if success else 0.0
keyword_recall += recall
reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"episode_count": len(result.hits),
"top_episode": result.hits[0].to_dict() if result.hits else None,
}
)
total = max(1, len(episode_cases))
return {
"success_rate": round(success_rate / total, 4),
"keyword_recall": round(keyword_recall / total, 4),
"reports": reports,
}
async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]:
search_case = dataset["search_cases"][0]
episode_case = dataset["episode_cases"][0]
aggregate_case = dataset["knowledge_fetcher_cases"][0]
tool_cases = [
{
"name": "search",
"kwargs": {
"query": "蓝漆铁盒 北塔木梯",
"mode": "search",
"chat_id": session_id,
"limit": 5,
},
"expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"],
"minimum_keyword_recall": 0.67,
},
{
"name": "time",
"kwargs": {
"query": "蓝漆铁盒 北塔",
"mode": "time",
"chat_id": session_id,
"limit": 5,
"time_expression": "最近7天",
},
"expected_keywords": ["蓝漆铁盒", "北塔木梯"],
"minimum_keyword_recall": 0.67,
},
{
"name": "episode",
"kwargs": {
"query": episode_case["query"],
"mode": "episode",
"chat_id": session_id,
"limit": 5,
},
"expected_keywords": episode_case["expected_keywords"],
"minimum_keyword_recall": 0.67,
},
{
"name": "aggregate",
"kwargs": {
"query": aggregate_case["query"],
"mode": "aggregate",
"chat_id": session_id,
"limit": 5,
},
"expected_keywords": aggregate_case["expected_keywords"],
"minimum_keyword_recall": 0.67,
},
]
reports: List[Dict[str, Any]] = []
success_rate = 0.0
keyword_recall = 0.0
for case in tool_cases:
text = await query_long_term_memory(**case["kwargs"])
recall = _keyword_recall(text, case["expected_keywords"])
success = (
"失败" not in text
and "无法解析" not in text
and "未找到" not in text
and recall >= float(case["minimum_keyword_recall"])
)
success_rate += 1.0 if success else 0.0
keyword_recall += recall
reports.append(
{
"name": case["name"],
"success": success,
"keyword_recall": recall,
"preview": text[:320],
}
)
total = max(1, len(tool_cases))
return {
"success_rate": round(success_rate / total, 4),
"keyword_recall": round(keyword_recall / total, 4),
"reports": reports,
}
@pytest_asyncio.fixture
async def benchmark_env(monkeypatch, tmp_path):
dataset = _load_benchmark_fixture()
session_cfg = dataset["session"]
session = SimpleNamespace(
session_id=session_cfg["session_id"],
platform=session_cfg["platform"],
user_id=session_cfg["user_id"],
group_id=session_cfg["group_id"],
)
fake_chat_manager = SimpleNamespace(
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
)
registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
reverse_registry = {value: key for key, value in registry.items()}
monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())
async def fake_self_check(**kwargs):
return {"ok": True, "message": "ok", "encoded_dimension": 32}
monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None)
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
monkeypatch.setattr(
person_info_module,
"Person",
lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
)
data_dir = (tmp_path / "a_memorix_benchmark_data").resolve()
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str(data_dir)},
"advanced": {"enable_auto_save": False},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager)
await kernel.initialize()
try:
yield {
"dataset": dataset,
"kernel": kernel,
"session": session,
"person_registry": registry,
}
finally:
await kernel.shutdown()
@pytest.mark.asyncio
async def test_long_novel_memory_benchmark(benchmark_env):
dataset = benchmark_env["dataset"]
session_id = benchmark_env["session"].session_id
created = await memory_service.import_admin(
action="create_paste",
name="long_novel_memory_benchmark.json",
input_mode="json",
llm_enabled=False,
content=json.dumps(dataset["import_payload"], ensure_ascii=False),
)
assert created["success"] is True
import_detail = await _wait_for_import_task(created["task"]["task_id"])
assert import_detail["task"]["status"] == "completed"
for record in dataset["chat_history_records"]:
summarizer = summarizer_module.ChatHistorySummarizer(session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
for payload in dataset["person_writebacks"]:
await person_info_module.store_person_memory_from_answer(
payload["person_name"],
payload["memory_content"],
session_id,
)
await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)
search_case_reports: List[Dict[str, Any]] = []
search_accuracy_at_1 = 0.0
search_recall_at_5 = 0.0
search_precision_at_5 = 0.0
search_mrr = 0.0
search_keyword_recall = 0.0
for case in dataset["search_cases"]:
result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
joined = _join_hit_content(result)
rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
relevant_hits = sum(
1
for hit in result.hits[:5]
if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
)
keyword_recall = _keyword_recall(joined, case["expected_keywords"])
search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
search_recall_at_5 += 1.0 if rank > 0 else 0.0
search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
search_keyword_recall += keyword_recall
search_case_reports.append(
{
"query": case["query"],
"rank_of_first_relevant": rank,
"relevant_hits_top5": relevant_hits,
"keyword_recall_top5": keyword_recall,
"top_hit": result.hits[0].to_dict() if result.hits else None,
}
)
search_total = max(1, len(dataset["search_cases"]))
writeback_reports: List[Dict[str, Any]] = []
writeback_success_rate = 0.0
writeback_keyword_recall = 0.0
for payload in dataset["person_writebacks"]:
query = " ".join(payload["expected_keywords"])
result = await memory_service.search(
query,
mode="search",
chat_id=session_id,
person_id=payload["person_id"],
respect_filter=False,
limit=5,
)
joined = _join_hit_content(result)
recall = _keyword_recall(joined, payload["expected_keywords"])
success = bool(result.hits) and recall >= 0.67
writeback_success_rate += 1.0 if success else 0.0
writeback_keyword_recall += recall
writeback_reports.append(
{
"person_id": payload["person_id"],
"success": success,
"keyword_recall": recall,
"hit_count": len(result.hits),
}
)
writeback_total = max(1, len(dataset["person_writebacks"]))
knowledge_reports: List[Dict[str, Any]] = []
knowledge_success_rate = 0.0
knowledge_keyword_recall = 0.0
fetcher = knowledge_module.KnowledgeFetcher(
private_name=dataset["session"]["display_name"],
stream_id=session_id,
)
for case in dataset["knowledge_fetcher_cases"]:
knowledge_text, _ = await fetcher.fetch(case["query"], [])
recall = _keyword_recall(knowledge_text, case["expected_keywords"])
success = recall >= float(case.get("minimum_keyword_recall", 1.0))
knowledge_success_rate += 1.0 if success else 0.0
knowledge_keyword_recall += recall
knowledge_reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"preview": knowledge_text[:300],
}
)
knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))
profile_reports: List[Dict[str, Any]] = []
profile_success_rate = 0.0
profile_keyword_recall = 0.0
profile_evidence_rate = 0.0
for case in dataset["profile_cases"]:
profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
recall = _keyword_recall(profile.summary, case["expected_keywords"])
has_evidence = bool(profile.evidence)
success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
profile_success_rate += 1.0 if success else 0.0
profile_keyword_recall += recall
profile_evidence_rate += 1.0 if has_evidence else 0.0
profile_reports.append(
{
"person_id": case["person_id"],
"success": success,
"keyword_recall": recall,
"evidence_count": len(profile.evidence),
"summary_preview": profile.summary[:240],
}
)
profile_total = max(1, len(dataset["profile_cases"]))
episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_rebuild = await memory_service.episode_admin(
action="rebuild",
source=f"chat_summary:{session_id}",
)
episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)
report = {
"dataset": dataset["meta"],
"import": {
"task_id": created["task"]["task_id"],
"status": import_detail["task"]["status"],
"paragraph_count": len(dataset["import_payload"]["paragraphs"]),
},
"metrics": {
"search": {
"accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
"recall_at_5": round(search_recall_at_5 / search_total, 4),
"precision_at_5": round(search_precision_at_5 / search_total, 4),
"mrr": round(search_mrr / search_total, 4),
"keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
},
"writeback": {
"success_rate": round(writeback_success_rate / writeback_total, 4),
"keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
},
"knowledge_fetcher": {
"success_rate": round(knowledge_success_rate / knowledge_total, 4),
"keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
},
"profile": {
"success_rate": round(profile_success_rate / profile_total, 4),
"keyword_recall": round(profile_keyword_recall / profile_total, 4),
"evidence_rate": round(profile_evidence_rate / profile_total, 4),
},
"tool_modes": {
"success_rate": tool_modes["success_rate"],
"keyword_recall": tool_modes["keyword_recall"],
},
"episode_generation_auto": {
"success_rate": episode_generation_auto["success_rate"],
"keyword_recall": episode_generation_auto["keyword_recall"],
"episode_count": episode_generation_auto["episode_count"],
},
"episode_generation_after_rebuild": {
"success_rate": episode_generation_after_rebuild["success_rate"],
"keyword_recall": episode_generation_after_rebuild["keyword_recall"],
"episode_count": episode_generation_after_rebuild["episode_count"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
"episode_admin_query_auto": {
"success_rate": episode_admin_query_auto["success_rate"],
"keyword_recall": episode_admin_query_auto["keyword_recall"],
},
"episode_admin_query_after_rebuild": {
"success_rate": episode_admin_query_after_rebuild["success_rate"],
"keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
"episode_search_mode_auto": {
"success_rate": episode_search_mode_auto["success_rate"],
"keyword_recall": episode_search_mode_auto["keyword_recall"],
},
"episode_search_mode_after_rebuild": {
"success_rate": episode_search_mode_after_rebuild["success_rate"],
"keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
},
"cases": {
"search": search_case_reports,
"writeback": writeback_reports,
"knowledge_fetcher": knowledge_reports,
"profile": profile_reports,
"tool_modes": tool_modes["reports"],
"episode_generation_auto": episode_generation_auto["reports"],
"episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
"episode_admin_query_auto": episode_admin_query_auto["reports"],
"episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
"episode_search_mode_auto": episode_search_mode_auto["reports"],
"episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
},
}
REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))
assert report["import"]["status"] == "completed"
assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35
assert report["metrics"]["search"]["recall_at_5"] >= 0.6
assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8
assert report["metrics"]["writeback"]["success_rate"] >= 0.66
assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66
assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75
assert report["metrics"]["profile"]["success_rate"] >= 0.66
assert report["metrics"]["profile"]["evidence_rate"] >= 1.0
assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75
assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True
assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"]

View File

@@ -0,0 +1,343 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
import pytest_asyncio
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
from pytests.test_long_novel_memory_benchmark import (
_evaluate_episode_admin_query,
_evaluate_episode_generation,
_evaluate_episode_search_mode,
_evaluate_tool_modes,
_KernelBackedRuntimeManager,
_KnownPerson,
_first_relevant_rank,
_hit_blob,
_join_hit_content,
_keyword_hits,
_keyword_recall,
_load_benchmark_fixture,
_wait_for_import_task,
)
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
pytestmark = pytest.mark.skipif(
os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
reason="需要显式开启真实 external embedding benchmark",
)
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json"
@pytest_asyncio.fixture
async def benchmark_live_env(monkeypatch, tmp_path):
dataset = _load_benchmark_fixture()
session_cfg = dataset["session"]
session = SimpleNamespace(
session_id=session_cfg["session_id"],
platform=session_cfg["platform"],
user_id=session_cfg["user_id"],
group_id=session_cfg["group_id"],
)
fake_chat_manager = SimpleNamespace(
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
)
registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
reverse_registry = {value: key for key, value in registry.items()}
monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None)
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
monkeypatch.setattr(
person_info_module,
"Person",
lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
)
data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve()
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str(data_dir)},
"advanced": {"enable_auto_save": False},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager)
await kernel.initialize()
try:
yield {
"dataset": dataset,
"kernel": kernel,
"session": session,
}
finally:
await kernel.shutdown()
@pytest.mark.asyncio
async def test_long_novel_memory_benchmark_live(benchmark_live_env):
dataset = benchmark_live_env["dataset"]
session_id = benchmark_live_env["session"].session_id
self_check = await memory_service.runtime_admin(action="refresh_self_check")
assert self_check["success"] is True
assert self_check["report"]["ok"] is True
created = await memory_service.import_admin(
action="create_paste",
name="long_novel_memory_benchmark.live.json",
input_mode="json",
llm_enabled=False,
content=json.dumps(dataset["import_payload"], ensure_ascii=False),
)
assert created["success"] is True
import_detail = await _wait_for_import_task(
created["task"]["task_id"],
max_rounds=2400,
sleep_seconds=0.25,
)
assert import_detail["task"]["status"] == "completed"
for record in dataset["chat_history_records"]:
summarizer = summarizer_module.ChatHistorySummarizer(session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
for payload in dataset["person_writebacks"]:
await person_info_module.store_person_memory_from_answer(
payload["person_name"],
payload["memory_content"],
session_id,
)
await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)
search_case_reports: List[Dict[str, Any]] = []
search_accuracy_at_1 = 0.0
search_recall_at_5 = 0.0
search_precision_at_5 = 0.0
search_mrr = 0.0
search_keyword_recall = 0.0
for case in dataset["search_cases"]:
result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
joined = _join_hit_content(result)
rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
relevant_hits = sum(
1
for hit in result.hits[:5]
if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
)
keyword_recall = _keyword_recall(joined, case["expected_keywords"])
search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
search_recall_at_5 += 1.0 if rank > 0 else 0.0
search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
search_keyword_recall += keyword_recall
search_case_reports.append(
{
"query": case["query"],
"rank_of_first_relevant": rank,
"relevant_hits_top5": relevant_hits,
"keyword_recall_top5": keyword_recall,
"top_hit": result.hits[0].to_dict() if result.hits else None,
}
)
search_total = max(1, len(dataset["search_cases"]))
writeback_reports: List[Dict[str, Any]] = []
writeback_success_rate = 0.0
writeback_keyword_recall = 0.0
for payload in dataset["person_writebacks"]:
query = " ".join(payload["expected_keywords"])
result = await memory_service.search(
query,
mode="search",
chat_id=session_id,
person_id=payload["person_id"],
respect_filter=False,
limit=5,
)
joined = _join_hit_content(result)
recall = _keyword_recall(joined, payload["expected_keywords"])
success = bool(result.hits) and recall >= 0.67
writeback_success_rate += 1.0 if success else 0.0
writeback_keyword_recall += recall
writeback_reports.append(
{
"person_id": payload["person_id"],
"success": success,
"keyword_recall": recall,
"hit_count": len(result.hits),
}
)
writeback_total = max(1, len(dataset["person_writebacks"]))
knowledge_reports: List[Dict[str, Any]] = []
knowledge_success_rate = 0.0
knowledge_keyword_recall = 0.0
fetcher = knowledge_module.KnowledgeFetcher(
private_name=dataset["session"]["display_name"],
stream_id=session_id,
)
for case in dataset["knowledge_fetcher_cases"]:
knowledge_text, _ = await fetcher.fetch(case["query"], [])
recall = _keyword_recall(knowledge_text, case["expected_keywords"])
success = recall >= float(case.get("minimum_keyword_recall", 1.0))
knowledge_success_rate += 1.0 if success else 0.0
knowledge_keyword_recall += recall
knowledge_reports.append(
{
"query": case["query"],
"success": success,
"keyword_recall": recall,
"preview": knowledge_text[:300],
}
)
knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))
profile_reports: List[Dict[str, Any]] = []
profile_success_rate = 0.0
profile_keyword_recall = 0.0
profile_evidence_rate = 0.0
for case in dataset["profile_cases"]:
profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
recall = _keyword_recall(profile.summary, case["expected_keywords"])
has_evidence = bool(profile.evidence)
success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
profile_success_rate += 1.0 if success else 0.0
profile_keyword_recall += recall
profile_evidence_rate += 1.0 if has_evidence else 0.0
profile_reports.append(
{
"person_id": case["person_id"],
"success": success,
"keyword_recall": recall,
"evidence_count": len(profile.evidence),
"summary_preview": profile.summary[:240],
}
)
profile_total = max(1, len(dataset["profile_cases"]))
episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_rebuild = await memory_service.episode_admin(
action="rebuild",
source=f"chat_summary:{session_id}",
)
episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)
report = {
"dataset": dataset["meta"],
"runtime_self_check": self_check["report"],
"import": {
"task_id": created["task"]["task_id"],
"status": import_detail["task"]["status"],
"paragraph_count": len(dataset["import_payload"]["paragraphs"]),
},
"metrics": {
"search": {
"accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
"recall_at_5": round(search_recall_at_5 / search_total, 4),
"precision_at_5": round(search_precision_at_5 / search_total, 4),
"mrr": round(search_mrr / search_total, 4),
"keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
},
"writeback": {
"success_rate": round(writeback_success_rate / writeback_total, 4),
"keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
},
"knowledge_fetcher": {
"success_rate": round(knowledge_success_rate / knowledge_total, 4),
"keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
},
"profile": {
"success_rate": round(profile_success_rate / profile_total, 4),
"keyword_recall": round(profile_keyword_recall / profile_total, 4),
"evidence_rate": round(profile_evidence_rate / profile_total, 4),
},
"tool_modes": {
"success_rate": tool_modes["success_rate"],
"keyword_recall": tool_modes["keyword_recall"],
},
"episode_generation_auto": {
"success_rate": episode_generation_auto["success_rate"],
"keyword_recall": episode_generation_auto["keyword_recall"],
"episode_count": episode_generation_auto["episode_count"],
},
"episode_generation_after_rebuild": {
"success_rate": episode_generation_after_rebuild["success_rate"],
"keyword_recall": episode_generation_after_rebuild["keyword_recall"],
"episode_count": episode_generation_after_rebuild["episode_count"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
"episode_admin_query_auto": {
"success_rate": episode_admin_query_auto["success_rate"],
"keyword_recall": episode_admin_query_auto["keyword_recall"],
},
"episode_admin_query_after_rebuild": {
"success_rate": episode_admin_query_after_rebuild["success_rate"],
"keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
"episode_search_mode_auto": {
"success_rate": episode_search_mode_auto["success_rate"],
"keyword_recall": episode_search_mode_auto["keyword_recall"],
},
"episode_search_mode_after_rebuild": {
"success_rate": episode_search_mode_after_rebuild["success_rate"],
"keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
"rebuild_success": bool(episode_rebuild.get("success", False)),
},
},
"cases": {
"search": search_case_reports,
"writeback": writeback_reports,
"knowledge_fetcher": knowledge_reports,
"profile": profile_reports,
"tool_modes": tool_modes["reports"],
"episode_generation_auto": episode_generation_auto["reports"],
"episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
"episode_admin_query_auto": episode_admin_query_auto["reports"],
"episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
"episode_search_mode_auto": episode_search_mode_auto["reports"],
"episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
},
}
REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))
assert report["import"]["status"] == "completed"
assert report["runtime_self_check"]["ok"] is True

View File

@@ -0,0 +1,138 @@
from types import SimpleNamespace
import pytest
from src.services import memory_flow_service as memory_flow_module
@pytest.mark.asyncio
async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch):
starts: list[str] = []
summarizers: list[object] = []
class FakeSummarizer:
def __init__(self, session_id: str):
self.session_id = session_id
summarizers.append(self)
async def start(self):
starts.append(self.session_id)
async def stop(self):
starts.append(f"stop:{self.session_id}")
monkeypatch.setattr(
memory_flow_module,
"global_config",
SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
)
monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)
manager = memory_flow_module.LongTermMemorySessionManager()
message = SimpleNamespace(session_id="session-1")
await manager.on_message(message)
await manager.on_message(message)
assert len(summarizers) == 1
assert starts == ["session-1"]
@pytest.mark.asyncio
async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch):
stopped: list[str] = []
class FakeSummarizer:
def __init__(self, session_id: str):
self.session_id = session_id
async def start(self):
return None
async def stop(self):
stopped.append(self.session_id)
monkeypatch.setattr(
memory_flow_module,
"global_config",
SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
)
monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)
manager = memory_flow_module.LongTermMemorySessionManager()
await manager.on_message(SimpleNamespace(session_id="session-a"))
await manager.on_message(SimpleNamespace(session_id="session-b"))
await manager.shutdown()
assert stopped == ["session-a", "session-b"]
def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
raw = '["他喜欢猫", "他喜欢猫", "", "", "他会弹吉他"]'
result = memory_flow_module.PersonFactWritebackService._parse_fact_list(raw)
assert result == ["他喜欢猫", "他会弹吉他"]
def test_person_fact_looks_ephemeral_detects_short_chitchat():
assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("哈哈")
assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("好的?")
assert not memory_flow_module.PersonFactWritebackService._looks_ephemeral("她最近在学法语和钢琴")
def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
class FakePerson:
def __init__(self, person_id: str):
self.person_id = person_id
self.is_known = True
service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService)
monkeypatch.setattr(memory_flow_module, "is_bot_self", lambda platform, user_id: False)
monkeypatch.setattr(memory_flow_module, "get_person_id", lambda platform, user_id: f"{platform}:{user_id}")
monkeypatch.setattr(memory_flow_module, "Person", FakePerson)
message = SimpleNamespace(session=SimpleNamespace(platform="qq", user_id="123", group_id=""))
person = service._resolve_target_person(message)
assert person is not None
assert person.person_id == "qq:123"
@pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
events: list[tuple[str, str]] = []
class FakeSessionManager:
async def on_message(self, message):
events.append(("incoming", message.session_id))
async def shutdown(self):
events.append(("shutdown", "session"))
class FakeFactWriteback:
async def start(self):
events.append(("start", "fact"))
async def enqueue(self, message):
events.append(("sent", message.session_id))
async def shutdown(self):
events.append(("shutdown", "fact"))
service = memory_flow_module.MemoryAutomationService()
service.session_manager = FakeSessionManager()
service.fact_writeback = FakeFactWriteback()
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
await service.shutdown()
assert events == [
("start", "fact"),
("incoming", "session-1"),
("sent", "session-1"),
("shutdown", "session"),
("shutdown", "fact"),
]

View File

@@ -0,0 +1,281 @@
import pytest
from src.services.memory_service import MemorySearchResult, MemoryService
def test_coerce_write_result_treats_skipped_payload_as_success():
result = MemoryService._coerce_write_result({"skipped_ids": ["p1"], "detail": "chat_filtered"})
assert result.success is True
assert result.stored_ids == []
assert result.skipped_ids == ["p1"]
assert result.detail == "chat_filtered"
@pytest.mark.asyncio
async def test_graph_admin_invokes_plugin(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "nodes": [], "edges": []}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.graph_admin(action="get_graph", limit=12)
assert result["success"] is True
assert calls == [("memory_graph_admin", {"action": "get_graph", "limit": 12}, {"timeout_ms": 30000})]
@pytest.mark.asyncio
async def test_get_recycle_bin_uses_maintain_memory_tool(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args))
return {"success": True, "items": [{"hash": "abc"}], "count": 1}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.get_recycle_bin(limit=5)
assert result == {"success": True, "items": [{"hash": "abc"}], "count": 1}
assert calls == [("maintain_memory", {"action": "recycle_bin", "limit": 5})]
@pytest.mark.asyncio
async def test_search_respects_filter_by_default(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args))
return {"summary": "ok", "hits": [], "filtered": True}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.search(
"mai",
chat_id="stream-1",
person_id="person-1",
user_id="user-1",
group_id="",
)
assert isinstance(result, MemorySearchResult)
assert result.filtered is True
assert calls == [
(
"search_memory",
{
"query": "mai",
"limit": 5,
"mode": "hybrid",
"chat_id": "stream-1",
"person_id": "person-1",
"time_start": None,
"time_end": None,
"respect_filter": True,
"user_id": "user-1",
"group_id": "",
},
)
]
@pytest.mark.asyncio
async def test_ingest_summary_can_bypass_filter(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args))
return {"success": True, "stored_ids": ["p1"], "detail": ""}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.ingest_summary(
external_id="chat_history:1",
chat_id="stream-1",
text="summary",
respect_filter=False,
user_id="user-1",
)
assert result.success is True
assert calls == [
(
"ingest_summary",
{
"external_id": "chat_history:1",
"chat_id": "stream-1",
"text": "summary",
"participants": [],
"time_start": None,
"time_end": None,
"tags": [],
"metadata": {},
"respect_filter": False,
"user_id": "user-1",
"group_id": "",
},
)
]
@pytest.mark.asyncio
async def test_v5_admin_invokes_plugin(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "count": 1}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.v5_admin(action="status", target="mai", limit=5)
assert result["success"] is True
assert calls == [("memory_v5_admin", {"action": "status", "target": "mai", "limit": 5}, {"timeout_ms": 30000})]
@pytest.mark.asyncio
async def test_delete_admin_uses_long_timeout(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "operation_id": "del-1"}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.delete_admin(action="execute", mode="relation", selector={"query": "mai"})
assert result["success"] is True
assert calls == [
(
"memory_delete_admin",
{"action": "execute", "mode": "relation", "selector": {"query": "mai"}},
{"timeout_ms": 120000},
)
]
@pytest.mark.asyncio
async def test_search_returns_empty_when_query_and_time_missing_async():
service = MemoryService()
result = await service.search("", time_start=None, time_end=None)
assert isinstance(result, MemorySearchResult)
assert result.summary == ""
assert result.hits == []
assert result.filtered is False
@pytest.mark.asyncio
async def test_search_accepts_string_time_bounds(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args))
return {"summary": "ok", "hits": [], "filtered": False}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.search(
"广播站",
mode="time",
time_start="2026/03/18",
time_end="2026/03/18 09:30",
)
assert isinstance(result, MemorySearchResult)
assert calls == [
(
"search_memory",
{
"query": "广播站",
"limit": 5,
"mode": "time",
"chat_id": "",
"person_id": "",
"time_start": "2026/03/18",
"time_end": "2026/03/18 09:30",
"respect_filter": True,
"user_id": "",
"group_id": "",
},
)
]
def test_coerce_search_result_preserves_aggregate_source_branches():
result = MemoryService._coerce_search_result(
{
"hits": [
{
"content": "广播站值夜班",
"type": "paragraph",
"metadata": {"event_time_start": 1.0},
"source_branches": ["search", "time"],
"rank": 1,
}
]
}
)
assert result.hits[0].metadata["source_branches"] == ["search", "time"]
assert result.hits[0].metadata["rank"] == 1
@pytest.mark.asyncio
async def test_import_admin_uses_long_timeout(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "task_id": "import-1"}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.import_admin(action="create_lpmm_openie", alias="lpmm")
assert result["success"] is True
assert calls == [
(
"memory_import_admin",
{"action": "create_lpmm_openie", "alias": "lpmm"},
{"timeout_ms": 120000},
)
]
@pytest.mark.asyncio
async def test_tuning_admin_uses_long_timeout(monkeypatch):
service = MemoryService()
calls = []
async def fake_invoke(component_name, args=None, **kwargs):
calls.append((component_name, args, kwargs))
return {"success": True, "task_id": "tuning-1"}
monkeypatch.setattr(service, "_invoke", fake_invoke)
result = await service.tuning_admin(action="create_task", payload={"query": "mai"})
assert result["success"] is True
assert calls == [
(
"memory_tuning_admin",
{"action": "create_task", "payload": {"query": "mai"}},
{"timeout_ms": 120000},
)
]

View File

@@ -0,0 +1,81 @@
from types import SimpleNamespace
import pytest
from src.person_info import person_info as person_info_module
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_writes_person_fact(monkeypatch):
calls = []
class FakePerson:
def __init__(self, person_id: str):
self.person_id = person_id
self.person_name = "Alice"
self.is_known = True
async def fake_ingest_text(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session))
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
monkeypatch.setattr(person_info_module, "Person", FakePerson)
monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)
await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")
assert len(calls) == 1
payload = calls[0]
assert payload["external_id"].startswith("person_fact:person-1:")
assert payload["source_type"] == "person_fact"
assert payload["chat_id"] == "session-1"
assert payload["person_ids"] == ["person-1"]
assert payload["participants"] == ["Alice"]
assert payload["respect_filter"] is True
assert payload["user_id"] == "10001"
assert payload["group_id"] == ""
assert payload["metadata"]["person_id"] == "person-1"
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_unknown_person(monkeypatch):
calls = []
class FakePerson:
def __init__(self, person_id: str):
self.person_id = person_id
self.person_name = "Unknown"
self.is_known = False
async def fake_ingest_text(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session))
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
monkeypatch.setattr(person_info_module, "Person", FakePerson)
monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)
await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")
assert calls == []
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_empty_content(monkeypatch):
calls = []
async def fake_ingest_text(**kwargs):
calls.append(kwargs)
return SimpleNamespace(success=True, detail="", stored_ids=["p1"])
monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)
await person_info_module.store_person_memory_from_answer("Alice", " ", "session-1")
assert calls == []

View File

@@ -0,0 +1,184 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
import pytest
from src.memory_system.retrieval_tools import query_long_term_memory as tool_module
from src.memory_system.retrieval_tools import init_all_tools
from src.memory_system.retrieval_tools.query_long_term_memory import (
_resolve_time_expression,
query_long_term_memory,
register_tool,
)
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
from src.services.memory_service import MemoryHit, MemorySearchResult
def test_resolve_time_expression_supports_relative_and_absolute_inputs():
now = datetime(2026, 3, 18, 15, 30)
start_ts, end_ts, start_text, end_text = _resolve_time_expression("今天", now=now)
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0)
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
assert start_text == "2026/03/18 00:00"
assert end_text == "2026/03/18 23:59"
start_ts, end_ts, start_text, end_text = _resolve_time_expression("最近7天", now=now)
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 12, 0, 0)
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
assert start_text == "2026/03/12 00:00"
assert end_text == "2026/03/18 23:59"
start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18", now=now)
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0)
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
assert start_text == "2026/03/18 00:00"
assert end_text == "2026/03/18 23:59"
start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18 09:30", now=now)
assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 9, 30)
assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 9, 30)
assert start_text == "2026/03/18 09:30"
assert end_text == "2026/03/18 09:30"
def test_register_tool_exposes_mode_and_time_expression():
register_tool()
tool = get_tool_registry().get_tool("search_long_term_memory")
assert tool is not None
params = {item["name"]: item for item in tool.parameters}
assert "mode" in params
assert params["mode"]["enum"] == ["search", "time", "episode", "aggregate"]
assert "time_expression" in params
assert params["query"]["required"] is False
def test_init_all_tools_registers_long_term_memory_tool():
init_all_tools()
tool = get_tool_registry().get_tool("search_long_term_memory")
assert tool is not None
@pytest.mark.asyncio
async def test_query_long_term_memory_search_mode_maps_to_hybrid(monkeypatch):
captured = {}
async def fake_search(query, **kwargs):
captured["query"] = query
captured["kwargs"] = kwargs
return MemorySearchResult(
hits=[MemoryHit(content="Alice 喜欢猫", score=0.9, hit_type="paragraph")],
)
monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
text = await query_long_term_memory("Alice 喜欢什么", chat_id="stream-1", person_id="person-1")
assert "Alice 喜欢猫" in text
assert captured == {
"query": "Alice 喜欢什么",
"kwargs": {
"limit": 5,
"mode": "hybrid",
"chat_id": "stream-1",
"person_id": "person-1",
"time_start": None,
"time_end": None,
},
}
@pytest.mark.asyncio
async def test_query_long_term_memory_time_mode_parses_expression(monkeypatch):
captured = {}
async def fake_search(query, **kwargs):
captured["query"] = query
captured["kwargs"] = kwargs
return MemorySearchResult(
hits=[
MemoryHit(
content="昨天晚上广播站停播了十分钟。",
score=0.8,
hit_type="paragraph",
metadata={"event_time_start": 1773797400.0},
)
]
)
monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
monkeypatch.setattr(
tool_module,
"_resolve_time_expression",
lambda expression, now=None: (1773795600.0, 1773881940.0, "2026/03/17 00:00", "2026/03/17 23:59"),
)
text = await query_long_term_memory(
query="广播站",
mode="time",
time_expression="昨天",
chat_id="stream-1",
)
assert "指定时间范围" in text
assert "广播站停播" in text
assert captured == {
"query": "广播站",
"kwargs": {
"limit": 5,
"mode": "time",
"chat_id": "stream-1",
"person_id": "",
"time_start": 1773795600.0,
"time_end": 1773881940.0,
},
}
@pytest.mark.asyncio
async def test_query_long_term_memory_episode_and_aggregate_format_output(monkeypatch):
responses = {
"episode": MemorySearchResult(
hits=[
MemoryHit(
content="苏弦在灯塔拆开了那封冬信。",
title="冬信重见天日",
hit_type="episode",
metadata={"participants": ["苏弦"], "keywords": ["冬信", "灯塔"]},
)
]
),
"aggregate": MemorySearchResult(
hits=[
MemoryHit(
content="唐未在广播站值夜班时带着黑狗墨点。",
hit_type="paragraph",
metadata={"source_branches": ["search", "time"]},
)
]
),
}
async def fake_search(query, **kwargs):
return responses[kwargs["mode"]]
monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
episode_text = await query_long_term_memory("那封冬信后来怎么样了", mode="episode")
aggregate_text = await query_long_term_memory("唐未最近有什么线索", mode="aggregate")
assert "事件《冬信重见天日》" in episode_text
assert "参与者:苏弦" in episode_text
assert "[search,time][paragraph]" in aggregate_text
@pytest.mark.asyncio
async def test_query_long_term_memory_invalid_time_expression_returns_retryable_message():
text = await query_long_term_memory(query="广播站", mode="time", time_expression="明年春分后第三周")
assert "无法解析" in text
assert "最近7天" in text

View File

@@ -0,0 +1,335 @@
from __future__ import annotations
import asyncio
import inspect
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import numpy as np
import pytest
import pytest_asyncio
from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"
def _load_dialogue_fixture() -> Dict[str, Any]:
return json.loads(DATA_FILE.read_text(encoding="utf-8"))
class _FakeEmbeddingAdapter:
def __init__(self, dimension: int = 16) -> None:
self.dimension = dimension
async def _detect_dimension(self) -> int:
return self.dimension
async def encode(self, texts, dimensions=None):
dim = int(dimensions or self.dimension)
if isinstance(texts, str):
sequence = [texts]
single = True
else:
sequence = list(texts)
single = False
rows = []
for text in sequence:
vec = np.zeros(dim, dtype=np.float32)
for ch in str(text or ""):
vec[ord(ch) % dim] += 1.0
if not vec.any():
vec[0] = 1.0
norm = np.linalg.norm(vec)
if norm > 0:
vec = vec / norm
rows.append(vec)
payload = np.vstack(rows)
return payload[0] if single else payload
class _KernelBackedRuntimeManager:
is_running = True
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke_plugin(
self,
*,
method: str,
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None,
timeout_ms: int,
):
del method, plugin_id, timeout_ms
payload = args or {}
if component_name == "search_memory":
return await self.kernel.search_memory(
KernelSearchRequest(
query=str(payload.get("query", "") or ""),
limit=int(payload.get("limit", 5) or 5),
mode=str(payload.get("mode", "hybrid") or "hybrid"),
chat_id=str(payload.get("chat_id", "") or ""),
person_id=str(payload.get("person_id", "") or ""),
time_start=payload.get("time_start"),
time_end=payload.get("time_end"),
respect_filter=bool(payload.get("respect_filter", True)),
user_id=str(payload.get("user_id", "") or ""),
group_id=str(payload.get("group_id", "") or ""),
)
)
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]:
for _ in range(max_rounds):
detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
task = detail.get("task") or {}
status = str(task.get("status", "") or "")
if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
return detail
await asyncio.sleep(0.05)
raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
def _join_hit_content(search_result) -> str:
return "\n".join(hit.content for hit in search_result.hits)
@pytest_asyncio.fixture
async def real_dialogue_env(monkeypatch, tmp_path):
scenario = _load_dialogue_fixture()
session_cfg = scenario["session"]
session = SimpleNamespace(
session_id=session_cfg["session_id"],
platform=session_cfg["platform"],
user_id=session_cfg["user_id"],
group_id=session_cfg["group_id"],
)
fake_chat_manager = SimpleNamespace(
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
)
monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())
async def fake_self_check(**kwargs):
return {"ok": True, "message": "ok"}
monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None)
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
data_dir = (tmp_path / "a_memorix_data").resolve()
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str(data_dir)},
"advanced": {"enable_auto_save": False},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager)
await kernel.initialize()
try:
yield {
"scenario": scenario,
"kernel": kernel,
"session": session,
}
finally:
await kernel.shutdown()
@pytest.mark.asyncio
async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env):
scenario = real_dialogue_env["scenario"]
created = await memory_service.import_admin(
action="create_paste",
name="private_alice.json",
input_mode="json",
llm_enabled=False,
content=json.dumps(scenario["import_payload"], ensure_ascii=False),
)
assert created["success"] is True
detail = await _wait_for_import_task(created["task"]["task_id"])
assert detail["task"]["status"] == "completed"
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
respect_filter=False,
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env):
scenario = real_dialogue_env["scenario"]
record = scenario["chat_history_record"]
summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
chat_id=real_dialogue_env["session"].session_id,
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch):
scenario = real_dialogue_env["scenario"]
class _KnownPerson:
def __init__(self, person_id: str) -> None:
self.person_id = person_id
self.is_known = True
self.person_name = scenario["person"]["person_name"]
monkeypatch.setattr(
person_info_module,
"get_person_id_by_person_name",
lambda person_name: scenario["person"]["person_id"],
)
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
await person_info_module.store_person_memory_from_answer(
scenario["person"]["person_name"],
scenario["person_fact"]["memory_content"],
real_dialogue_env["session"].session_id,
)
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
chat_id=real_dialogue_env["session"].session_id,
person_id=scenario["person"]["person_id"],
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env):
scenario = real_dialogue_env["scenario"]
await memory_service.ingest_text(
external_id="fixture:knowledge_fetcher",
source_type="dialogue_note",
text=scenario["person_fact"]["memory_content"],
chat_id=real_dialogue_env["session"].session_id,
person_ids=[scenario["person"]["person_id"]],
participants=[scenario["person"]["person_name"]],
respect_filter=False,
)
fetcher = knowledge_module.KnowledgeFetcher(
private_name=scenario["session"]["display_name"],
stream_id=real_dialogue_env["session"].session_id,
)
knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in knowledge_text
@pytest.mark.asyncio
async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch):
scenario = real_dialogue_env["scenario"]
class _KnownPerson:
def __init__(self, person_id: str) -> None:
self.person_id = person_id
self.is_known = True
self.person_name = scenario["person"]["person_name"]
monkeypatch.setattr(
person_info_module,
"get_person_id_by_person_name",
lambda person_name: scenario["person"]["person_id"],
)
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
await person_info_module.store_person_memory_from_answer(
scenario["person"]["person_name"],
scenario["person_fact"]["memory_content"],
real_dialogue_env["session"].session_id,
)
profile = await memory_service.get_person_profile(
scenario["person"]["person_id"],
chat_id=real_dialogue_env["session"].session_id,
)
assert profile.evidence
assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])
@pytest.mark.asyncio
async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env):
scenario = real_dialogue_env["scenario"]
record = scenario["chat_history_record"]
summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
episodes = await memory_service.episode_admin(
action="query",
source=scenario["expectations"]["episode_source"],
limit=5,
)
assert episodes["success"] is True
assert int(episodes["count"]) >= 1

View File

@@ -0,0 +1,312 @@
from __future__ import annotations
import asyncio
import inspect
import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import pytest
import pytest_asyncio
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
pytestmark = pytest.mark.skipif(
os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
reason="需要显式开启真实 embedding / self-check 集成测试",
)
DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"
def _load_dialogue_fixture() -> Dict[str, Any]:
return json.loads(DATA_FILE.read_text(encoding="utf-8"))
class _KernelBackedRuntimeManager:
is_running = True
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke_plugin(
self,
*,
method: str,
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None,
timeout_ms: int,
):
del method, plugin_id, timeout_ms
payload = args or {}
if component_name == "search_memory":
return await self.kernel.search_memory(
KernelSearchRequest(
query=str(payload.get("query", "") or ""),
limit=int(payload.get("limit", 5) or 5),
mode=str(payload.get("mode", "hybrid") or "hybrid"),
chat_id=str(payload.get("chat_id", "") or ""),
person_id=str(payload.get("person_id", "") or ""),
time_start=payload.get("time_start"),
time_end=payload.get("time_end"),
respect_filter=bool(payload.get("respect_filter", True)),
user_id=str(payload.get("user_id", "") or ""),
group_id=str(payload.get("group_id", "") or ""),
)
)
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]:
deadline = asyncio.get_running_loop().time() + max(1.0, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
task = detail.get("task") or {}
status = str(task.get("status", "") or "")
if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
return detail
await asyncio.sleep(0.2)
raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
def _join_hit_content(search_result) -> str:
return "\n".join(hit.content for hit in search_result.hits)
@pytest_asyncio.fixture
async def live_dialogue_env(monkeypatch, tmp_path):
scenario = _load_dialogue_fixture()
session_cfg = scenario["session"]
session = SimpleNamespace(
session_id=session_cfg["session_id"],
platform=session_cfg["platform"],
user_id=session_cfg["user_id"],
group_id=session_cfg["group_id"],
)
fake_chat_manager = SimpleNamespace(
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
)
monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None)
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
data_dir = (tmp_path / "a_memorix_data").resolve()
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str(data_dir)},
"advanced": {"enable_auto_save": False},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager)
await kernel.initialize()
try:
yield {
"scenario": scenario,
"kernel": kernel,
"session": session,
}
finally:
await kernel.shutdown()
@pytest.mark.asyncio
async def test_live_runtime_self_check_passes(live_dialogue_env):
report = await memory_service.runtime_admin(action="refresh_self_check")
assert report["success"] is True
assert report["report"]["ok"] is True
assert report["report"]["encoded_dimension"] > 0
@pytest.mark.asyncio
async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env):
scenario = live_dialogue_env["scenario"]
created = await memory_service.import_admin(
action="create_paste",
name="private_alice.json",
input_mode="json",
llm_enabled=False,
content=json.dumps(scenario["import_payload"], ensure_ascii=False),
)
assert created["success"] is True
detail = await _wait_for_import_task(created["task"]["task_id"])
assert detail["task"]["status"] == "completed"
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
respect_filter=False,
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env):
scenario = live_dialogue_env["scenario"]
record = scenario["chat_history_record"]
summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
chat_id=live_dialogue_env["session"].session_id,
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch):
scenario = live_dialogue_env["scenario"]
class _KnownPerson:
def __init__(self, person_id: str) -> None:
self.person_id = person_id
self.is_known = True
self.person_name = scenario["person"]["person_name"]
monkeypatch.setattr(
person_info_module,
"get_person_id_by_person_name",
lambda person_name: scenario["person"]["person_id"],
)
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
await person_info_module.store_person_memory_from_answer(
scenario["person"]["person_name"],
scenario["person_fact"]["memory_content"],
live_dialogue_env["session"].session_id,
)
search = await memory_service.search(
scenario["search_queries"]["direct"],
mode="search",
chat_id=live_dialogue_env["session"].session_id,
person_id=scenario["person"]["person_id"],
)
assert search.hits
joined = _join_hit_content(search)
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in joined
@pytest.mark.asyncio
async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env):
scenario = live_dialogue_env["scenario"]
await memory_service.ingest_text(
external_id="fixture:knowledge_fetcher",
source_type="dialogue_note",
text=scenario["person_fact"]["memory_content"],
chat_id=live_dialogue_env["session"].session_id,
person_ids=[scenario["person"]["person_id"]],
participants=[scenario["person"]["person_name"]],
respect_filter=False,
)
fetcher = knowledge_module.KnowledgeFetcher(
private_name=scenario["session"]["display_name"],
stream_id=live_dialogue_env["session"].session_id,
)
knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])
for keyword in scenario["expectations"]["search_keywords"]:
assert keyword in knowledge_text
@pytest.mark.asyncio
async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch):
scenario = live_dialogue_env["scenario"]
class _KnownPerson:
def __init__(self, person_id: str) -> None:
self.person_id = person_id
self.is_known = True
self.person_name = scenario["person"]["person_name"]
monkeypatch.setattr(
person_info_module,
"get_person_id_by_person_name",
lambda person_name: scenario["person"]["person_id"],
)
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
await person_info_module.store_person_memory_from_answer(
scenario["person"]["person_name"],
scenario["person_fact"]["memory_content"],
live_dialogue_env["session"].session_id,
)
profile = await memory_service.get_person_profile(
scenario["person"]["person_id"],
chat_id=live_dialogue_env["session"].session_id,
)
assert profile.evidence
assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])
@pytest.mark.asyncio
async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env):
scenario = live_dialogue_env["scenario"]
record = scenario["chat_history_record"]
summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
await summarizer._import_to_long_term_memory(
record_id=record["record_id"],
theme=record["theme"],
summary=record["summary"],
participants=record["participants"],
start_time=record["start_time"],
end_time=record["end_time"],
original_text=record["original_text"],
)
episodes = await memory_service.episode_admin(
action="query",
source=scenario["expectations"]["episode_source"],
limit=5,
)
assert episodes["success"] is True
assert int(episodes["count"]) >= 1

View File

@@ -0,0 +1,279 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest
from src.services.memory_service import MemorySearchResult
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module
from src.webui.routers.memory import compat_router, router
@pytest.fixture
def client() -> TestClient:
app = FastAPI()
app.dependency_overrides[require_auth] = lambda: "ok"
app.include_router(router)
app.include_router(compat_router)
return TestClient(app)
def test_webui_memory_graph_route(client: TestClient, monkeypatch):
async def fake_graph_admin(*, action: str, **kwargs):
assert action == "get_graph"
return {"success": True, "nodes": [], "edges": [], "total_nodes": 0, "limit": kwargs.get("limit")}
monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)
response = client.get("/api/webui/memory/graph", params={"limit": 77})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["limit"] == 77
def test_compat_aggregate_route(client: TestClient, monkeypatch):
async def fake_search(query: str, **kwargs):
assert kwargs["mode"] == "aggregate"
assert kwargs["respect_filter"] is False
return MemorySearchResult(summary=f"summary:{query}", hits=[])
monkeypatch.setattr(memory_router_module.memory_service, "search", fake_search)
response = client.get("/api/query/aggregate", params={"query": "mai"})
assert response.status_code == 200
assert response.json() == {"success": True, "summary": "summary:mai", "hits": [], "filtered": False}
def test_auto_save_routes(client: TestClient, monkeypatch):
async def fake_runtime_admin(*, action: str, **kwargs):
if action == "get_config":
return {"success": True, "auto_save": True}
if action == "set_auto_save":
return {"success": True, "auto_save": kwargs["enabled"]}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "runtime_admin", fake_runtime_admin)
get_response = client.get("/api/config/auto_save")
post_response = client.post("/api/config/auto_save", json={"enabled": False})
assert get_response.status_code == 200
assert get_response.json() == {"success": True, "auto_save": True}
assert post_response.status_code == 200
assert post_response.json() == {"success": True, "auto_save": False}
def test_recycle_bin_route(client: TestClient, monkeypatch):
async def fake_get_recycle_bin(*, limit: int):
return {"success": True, "items": [{"hash": "deadbeef"}], "count": 1, "limit": limit}
monkeypatch.setattr(memory_router_module.memory_service, "get_recycle_bin", fake_get_recycle_bin)
response = client.get("/api/memory/recycle_bin", params={"limit": 10})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["count"] == 1
assert response.json()["limit"] == 10
def test_import_guide_route(client: TestClient, monkeypatch):
async def fake_import_admin(*, action: str, **kwargs):
assert kwargs == {}
if action == "get_guide":
return {"success": True}
if action == "get_settings":
return {"success": True, "settings": {"path_aliases": {"raw": "/tmp/raw"}}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
response = client.get("/api/webui/memory/import/guide")
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["source"] == "local"
assert "长期记忆导入说明" in response.json()["content"]
def test_import_upload_route(client: TestClient, monkeypatch, tmp_path):
monkeypatch.setattr(memory_router_module, "STAGING_ROOT", tmp_path)
async def fake_import_admin(*, action: str, **kwargs):
assert action == "create_upload"
staged_files = kwargs["staged_files"]
assert len(staged_files) == 1
assert staged_files[0]["filename"] == "demo.txt"
assert memory_router_module.Path(staged_files[0]["staged_path"]).exists()
return {"success": True, "task_id": "task-1"}
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
response = client.post(
"/api/import/upload",
data={"payload_json": "{\"source\": \"upload\"}"},
files=[("files", ("demo.txt", b"hello world", "text/plain"))],
)
assert response.status_code == 200
assert response.json() == {"success": True, "task_id": "task-1"}
assert list(tmp_path.iterdir()) == []
def test_v5_status_route(client: TestClient, monkeypatch):
async def fake_v5_admin(*, action: str, **kwargs):
assert action == "status"
assert kwargs["target"] == "mai"
return {"success": True, "active_count": 1, "inactive_count": 2, "deleted_count": 3}
monkeypatch.setattr(memory_router_module.memory_service, "v5_admin", fake_v5_admin)
response = client.get("/api/webui/memory/v5/status", params={"target": "mai"})
assert response.status_code == 200
assert response.json()["success"] is True
assert response.json()["deleted_count"] == 3
def test_delete_preview_route(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "preview"
assert kwargs["mode"] == "paragraph"
assert kwargs["selector"] == {"query": "demo"}
return {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/preview",
json={"mode": "paragraph", "selector": {"query": "demo"}},
)
assert response.status_code == 200
assert response.json() == {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}
def test_episode_process_pending_route(client: TestClient, monkeypatch):
async def fake_episode_admin(*, action: str, **kwargs):
assert action == "process_pending"
assert kwargs == {"limit": 7, "max_retry": 4}
return {"success": True, "processed": 3}
monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)
response = client.post("/api/webui/memory/episodes/process-pending", json={"limit": 7, "max_retry": 4})
assert response.status_code == 200
assert response.json() == {"success": True, "processed": 3}
def test_import_list_route_includes_settings(client: TestClient, monkeypatch):
calls = []
async def fake_import_admin(*, action: str, **kwargs):
calls.append((action, kwargs))
if action == "list":
return {"success": True, "items": [{"task_id": "task-1"}]}
if action == "get_settings":
return {"success": True, "settings": {"path_aliases": {"lpmm": "/tmp/lpmm"}}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)
response = client.get("/api/webui/memory/import/tasks", params={"limit": 9})
assert response.status_code == 200
assert response.json()["items"] == [{"task_id": "task-1"}]
assert response.json()["settings"] == {"path_aliases": {"lpmm": "/tmp/lpmm"}}
assert calls == [("list", {"limit": 9}), ("get_settings", {})]
def test_tuning_profile_route_backfills_settings(client: TestClient, monkeypatch):
calls = []
async def fake_tuning_admin(*, action: str, **kwargs):
calls.append((action, kwargs))
if action == "get_profile":
return {"success": True, "profile": {"retrieval": {"top_k": 8}}}
if action == "get_settings":
return {"success": True, "settings": {"profiles": ["default"]}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)
response = client.get("/api/webui/memory/retrieval_tuning/profile")
assert response.status_code == 200
assert response.json()["profile"] == {"retrieval": {"top_k": 8}}
assert response.json()["settings"] == {"profiles": ["default"]}
assert calls == [("get_profile", {}), ("get_settings", {})]
def test_tuning_report_route_flattens_report_payload(client: TestClient, monkeypatch):
async def fake_tuning_admin(*, action: str, **kwargs):
assert action == "get_report"
assert kwargs == {"task_id": "task-1", "format": "json"}
return {
"success": True,
"report": {"format": "json", "content": "{\"ok\": true}", "path": "/tmp/report.json"},
}
monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)
response = client.get("/api/webui/memory/retrieval_tuning/tasks/task-1/report", params={"format": "json"})
assert response.status_code == 200
assert response.json() == {
"success": True,
"format": "json",
"content": "{\"ok\": true}",
"path": "/tmp/report.json",
"error": "",
}
def test_delete_execute_route(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
assert action == "execute"
assert kwargs["mode"] == "source"
assert kwargs["selector"] == {"source": "chat_summary:stream-1"}
assert kwargs["reason"] == "cleanup"
assert kwargs["requested_by"] == "tester"
return {"success": True, "operation_id": "del-1"}
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
response = client.post(
"/api/webui/memory/delete/execute",
json={
"mode": "source",
"selector": {"source": "chat_summary:stream-1"},
"reason": "cleanup",
"requested_by": "tester",
},
)
assert response.status_code == 200
assert response.json() == {"success": True, "operation_id": "del-1"}
def test_delete_operation_routes(client: TestClient, monkeypatch):
async def fake_delete_admin(*, action: str, **kwargs):
if action == "list_operations":
assert kwargs == {"limit": 5, "mode": "paragraph"}
return {"success": True, "items": [{"operation_id": "del-1"}], "count": 1}
if action == "get_operation":
assert kwargs == {"operation_id": "del-1"}
return {"success": True, "operation": {"operation_id": "del-1", "mode": "paragraph"}}
raise AssertionError(action)
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)
list_response = client.get("/api/webui/memory/delete/operations", params={"limit": 5, "mode": "paragraph"})
get_response = client.get("/api/webui/memory/delete/operations/del-1")
assert list_response.status_code == 200
assert list_response.json()["count"] == 1
assert get_response.status_code == 200
assert get_response.json()["operation"]["operation_id"] == "del-1"

View File

@@ -7,7 +7,7 @@ from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.jargon_miner_old import search_jargon
from src.bw_learner.jargon_explainer import search_jargon
from src.bw_learner.learner_utils_old import (
is_bot_message,
contains_bot_self_name,

View File

@@ -196,6 +196,32 @@ def contains_bot_self_name(content: str) -> bool:
return any(name in target for name in candidates)
def is_bot_message(msg: Any) -> bool:
"""判断消息是否来自机器人自身。"""
if msg is None:
return False
bot_config = getattr(global_config, "bot", None)
if not bot_config:
return False
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
if not user_id:
return False
known_accounts = {
str(getattr(bot_config, "qq_account", "") or "").strip(),
str(getattr(bot_config, "telegram_account", "") or "").strip(),
}
for platform in getattr(bot_config, "platforms", []) or []:
account = str(getattr(platform, "account", "") or getattr(platform, "id", "") or "").strip()
if account:
known_accounts.add(account)
return user_id in {account for account in known_accounts if account}
# def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
# """
# 构建包含中心消息上下文的段落前3条+后3条使用标准的 readable builder 输出

View File

@@ -55,7 +55,7 @@ class Conversation:
self.action_planner = ActionPlanner(self.stream_id, self.private_name)
self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name)
self.reply_generator = ReplyGenerator(self.stream_id, self.private_name)
self.knowledge_fetcher = KnowledgeFetcher(self.private_name)
self.knowledge_fetcher = KnowledgeFetcher(self.private_name, self.stream_id)
self.waiter = Waiter(self.stream_id, self.private_name)
self.direct_sender = DirectMessageSender(self.private_name)

View File

@@ -1,11 +1,14 @@
from typing import List, Tuple, Dict, Any
from typing import Any, Dict, List, Tuple
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.logger import get_logger
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
# from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.chat.knowledge import qa_manager
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import resolve_person_id_for_memory
from src.services.memory_service import memory_service
logger = get_logger("knowledge_fetcher")
@@ -13,11 +16,39 @@ logger = get_logger("knowledge_fetcher")
class KnowledgeFetcher:
"""知识调取器"""
def __init__(self, private_name: str):
def __init__(self, private_name: str, stream_id: str):
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
self.private_name = private_name
self.stream_id = stream_id
def _lpmm_get_knowledge(self, query: str) -> str:
def _resolve_private_memory_context(self) -> Dict[str, str]:
session = _chat_manager.get_session_by_session_id(self.stream_id)
if session is None:
return {"chat_id": self.stream_id}
group_id = str(getattr(session, "group_id", "") or "").strip()
user_id = str(getattr(session, "user_id", "") or "").strip()
platform = str(getattr(session, "platform", "") or "").strip()
person_id = ""
if not group_id:
try:
person_id = resolve_person_id_for_memory(
person_name=self.private_name,
platform=platform,
user_id=user_id,
)
except Exception as exc:
logger.debug(f"[私聊][{self.private_name}]解析人物ID失败: {exc}")
return {
"chat_id": self.stream_id,
"person_id": person_id,
"user_id": user_id,
"group_id": group_id,
}
async def _memory_get_knowledge(self, query: str) -> str:
"""获取相关知识
Args:
@@ -27,13 +58,32 @@ class KnowledgeFetcher:
str: 构造好的,带相关度的知识
"""
logger.debug(f"[私聊][{self.private_name}]正在从LPMM知识库中获取知识")
logger.debug(f"[私聊][{self.private_name}]正在从长期记忆中获取知识")
try:
knowledge_info = qa_manager.get_knowledge(query)
logger.debug(f"[私聊][{self.private_name}]LPMM知识库查询结果: {knowledge_info:150}")
return knowledge_info
context = self._resolve_private_memory_context()
search_kwargs = {
"limit": 5,
"mode": "search",
"chat_id": context.get("chat_id", ""),
"person_id": context.get("person_id", ""),
"user_id": context.get("user_id", ""),
"group_id": context.get("group_id", ""),
"respect_filter": True,
}
result = await memory_service.search(query, **search_kwargs)
if not result.filtered and not result.hits and search_kwargs["person_id"]:
fallback_kwargs = dict(search_kwargs)
fallback_kwargs["person_id"] = ""
logger.debug(f"[私聊][{self.private_name}]人物过滤未命中,退回仅按会话检索长期记忆")
result = await memory_service.search(query, **fallback_kwargs)
knowledge_info = result.to_text(limit=5)
if result.filtered:
logger.debug(f"[私聊][{self.private_name}]长期记忆查询被聊天过滤策略跳过")
else:
logger.debug(f"[私聊][{self.private_name}]长期记忆查询结果: {knowledge_info[:150]}")
return knowledge_info or "未找到匹配的知识"
except Exception as e:
logger.error(f"[私聊][{self.private_name}]LPMM知识库搜索工具执行失败: {str(e)}")
logger.error(f"[私聊][{self.private_name}]长期记忆搜索工具执行失败: {str(e)}")
return "未找到匹配的知识"
async def fetch(self, query: str, chat_history: List[Dict[str, Any]]) -> Tuple[str, str]:
@@ -72,7 +122,7 @@ class KnowledgeFetcher:
# sources_text = "".join(sources)
knowledge_text += "\n现在有以下**知识**可供参考:\n "
knowledge_text += self._lpmm_get_knowledge(query)
knowledge_text += await self._memory_get_knowledge(query)
knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n"
return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配"

View File

@@ -1,90 +0,0 @@
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger
from src.config.config import global_config
import os
INVALID_ENTITY = [
"",
"",
"",
"",
"",
"我们",
"你们",
"他们",
"她们",
"它们",
]
RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None
inspire_manager = None
def get_qa_manager():
return qa_manager
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable:
logger.info("正在初始化Mai-LPMM")
logger.info("创建LLM客户端")
# 初始化Embedding库
embed_manager = EmbeddingManager(
max_workers=global_config.lpmm_knowledge.max_embedding_workers,
chunk_size=global_config.lpmm_knowledge.embedding_chunk_size,
)
logger.info("正在从文件加载Embedding库")
try:
embed_manager.load_from_file()
except Exception as e:
logger.warning(f"此消息不会影响正常使用从文件加载Embedding库时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("Embedding库加载完成")
# 初始化KG
kg_manager = KGManager()
logger.info("正在从文件加载KG")
try:
kg_manager.load_from_file()
except Exception as e:
logger.warning(f"此消息不会影响正常使用从文件加载KG时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("KG加载完成")
logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}")
logger.info(f"KG边数量{len(kg_manager.graph.get_edge_list())}")
# 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes:
# 使用与EmbeddingStore中一致的命名空间格式
key = f"paragraph-{pg_hash}"
if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}")
global qa_manager
# 问答系统(用于知识库)
qa_manager = QAManager(
embed_manager,
kg_manager,
)
# # 记忆激活(用于记忆库)
# global inspire_manager
# inspire_manager = MemoryActiveManager(
# embed_manager,
# llm_client_list[global_config["embedding"]["provider"]],
# )
else:
logger.info("LPMM知识库已禁用跳过初始化")
# 创建空的占位符对象,避免导入错误

View File

@@ -1,380 +0,0 @@
import asyncio
import os
from functools import partial
from typing import List, Callable, Any
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.qa_manager import QAManager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import get_qa_manager, lpmm_start_up
logger = get_logger("LPMM-Plugin-API")
class LPMMOperations:
"""
LPMM 内部操作接口。
封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。
"""
def __init__(self):
self._initialized = False
async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
"""
在线程池中执行可取消的同步操作。
当任务被取消时(如 Ctrl+C会立即响应并抛出 CancelledError。
注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。
Args:
func: 要执行的同步函数
*args: 函数的位置参数
**kwargs: 函数的关键字参数
Returns:
函数的返回值
Raises:
asyncio.CancelledError: 当任务被取消时
"""
loop = asyncio.get_event_loop()
# 在线程池中执行,当协程被取消时会立即响应
# 虽然线程池中的操作可能仍在运行,但协程不会阻塞
return await loop.run_in_executor(None, func, *args, **kwargs)
async def _get_managers(self) -> tuple[EmbeddingManager, KGManager, QAManager]:
"""获取并确保 LPMM 管理器已初始化"""
qa_mgr = get_qa_manager()
if qa_mgr is None:
# 如果全局没初始化,尝试初始化
if not global_config.lpmm_knowledge.enable:
logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
lpmm_start_up()
qa_mgr = get_qa_manager()
if qa_mgr is None:
raise RuntimeError("无法获取 LPMM QAManager请检查 LPMM 是否已正确安装和配置。")
return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr
async def add_content(self, text: str, auto_split: bool = True) -> dict:
"""
向知识库添加新内容。
Args:
text: 原始文本。
auto_split: 是否自动按双换行符分割段落。
- True: 自动分割(默认),支持多段文本(用双换行分隔)
- False: 不分割,将整个文本作为完整一段处理
Returns:
dict: {"status": "success/error", "count": 导入段落数, "message": "描述"}
"""
try:
embed_mgr, kg_mgr, _ = await self._get_managers()
# 1. 分段处理
if auto_split:
# 自动按双换行符分割
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
else:
# 不分割,作为完整一段
text_stripped = text.strip()
if not text_stripped:
return {"status": "error", "message": "文本内容为空"}
paragraphs = [text_stripped]
if not paragraphs:
return {"status": "error", "message": "文本内容为空"}
# 2. 实体与三元组抽取 (内部调用大模型)
from src.chat.knowledge.ie_process import IEProcess
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm_ner = LLMRequest(
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
)
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
ie_process = IEProcess(llm_ner, llm_rdf)
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
extracted_docs = await ie_process.process_paragraphs(paragraphs)
# 3. 构造并导入数据
# 这里我们手动实现导入逻辑,不依赖外部脚本
# a. 准备段落
raw_paragraphs = {doc["idx"]: doc["passage"] for doc in extracted_docs}
# b. 准备三元组
triple_list_data = {doc["idx"]: doc["extracted_triples"] for doc in extracted_docs}
# 向量化并入库
# 注意:此处模仿 import_openie.py 的核心逻辑
# 1. 先进行去重检查,只处理新段落
# store_new_data_set 期望的格式raw_paragraphs 的键是段落hash不带前缀值是段落文本
new_raw_paragraphs = {}
new_triple_list_data = {}
for pg_hash, passage in raw_paragraphs.items():
key = f"paragraph-{pg_hash}"
if key not in embed_mgr.stored_pg_hashes:
new_raw_paragraphs[pg_hash] = passage
new_triple_list_data[pg_hash] = triple_list_data[pg_hash]
if not new_raw_paragraphs:
return {"status": "success", "count": 0, "message": "内容已存在,无需重复导入"}
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
# store_new_data_set 会自动处理嵌入生成和存储
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
# 3. 构建知识图谱只需要三元组数据和embedding_manager
await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
# 4. 持久化
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
return {
"status": "success",
"count": len(new_raw_paragraphs),
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
}
except asyncio.CancelledError:
logger.warning("[Plugin API] 导入操作被用户中断")
return {"status": "cancelled", "message": "导入操作已被用户中断"}
except Exception as e:
logger.error(f"[Plugin API] 导入知识失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)}
async def search(self, query: str, top_k: int = 3) -> List[str]:
"""
检索知识库。
Args:
query: 查询问题。
top_k: 返回最相关的条目数。
Returns:
List[str]: 相关文段列表。
"""
try:
_, _, qa_mgr = await self._get_managers()
# 直接调用 QAManager 的检索接口
knowledge = qa_mgr.get_knowledge(query, top_k=top_k)
# 返回通常是拼接好的字符串,这里我们可以尝试按其内部规则切分回列表,或者直接返回
return [knowledge] if knowledge else []
except Exception as e:
logger.error(f"[Plugin API] 检索知识失败: {e}")
return []
async def delete(self, keyword: str, exact_match: bool = False) -> dict:
"""
根据关键词或完整文段删除知识库内容。
Args:
keyword: 匹配关键词或完整文段。
exact_match: 是否使用完整文段匹配True=完全匹配False=关键词模糊匹配)。
Returns:
dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"}
"""
try:
embed_mgr, kg_mgr, _ = await self._get_managers()
# 1. 查找匹配的段落
to_delete_keys = []
to_delete_hashes = []
for key, item in embed_mgr.paragraphs_embedding_store.store.items():
if exact_match:
# 完整文段匹配
if item.str.strip() == keyword.strip():
to_delete_keys.append(key)
to_delete_hashes.append(key.replace("paragraph-", "", 1))
else:
# 关键词模糊匹配
if keyword in item.str:
to_delete_keys.append(key)
to_delete_hashes.append(key.replace("paragraph-", "", 1))
if not to_delete_keys:
match_type = "完整文段" if exact_match else "关键词"
return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
# 2. 执行删除
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
# a. 从向量库删除
deleted_count, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
)
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
# b. 从知识图谱删除
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial(
kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
)
await self._run_cancellable_executor(delete_func)
# 3. 持久化
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
match_type = "完整文段" if exact_match else "关键词"
return {
"status": "success",
"deleted_count": deleted_count,
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
}
except asyncio.CancelledError:
logger.warning("[Plugin API] 删除操作被用户中断")
return {"status": "cancelled", "message": "删除操作已被用户中断"}
except Exception as e:
logger.error(f"[Plugin API] 删除知识失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)}
async def clear_all(self) -> dict:
"""
清空整个LPMM知识库删除所有段落、实体、关系和知识图谱数据
Returns:
dict: {"status": "success/error", "message": "描述", "stats": {...}}
"""
try:
embed_mgr, kg_mgr, _ = await self._get_managers()
# 记录清空前的统计信息
before_stats = {
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
"entities": len(embed_mgr.entities_embedding_store.store),
"relations": len(embed_mgr.relation_embedding_store.store),
"kg_nodes": len(kg_mgr.graph.get_node_list()),
"kg_edges": len(kg_mgr.graph.get_edge_list()),
}
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
# 1. 清空所有向量库
# 获取所有keys
para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
# 删除所有段落向量
para_deleted, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items, para_keys
)
embed_mgr.stored_pg_hashes.clear()
# 删除所有实体向量
if ent_keys:
ent_deleted, _ = await self._run_cancellable_executor(
embed_mgr.entities_embedding_store.delete_items, ent_keys
)
else:
ent_deleted = 0
# 删除所有关系向量
if rel_keys:
rel_deleted, _ = await self._run_cancellable_executor(
embed_mgr.relation_embedding_store.delete_items, rel_keys
)
else:
rel_deleted = 0
# 2. 清空所有 embedding store 的索引和映射
# 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件
def _clear_embedding_indices():
# 清空段落索引
embed_mgr.paragraphs_embedding_store.faiss_index = None
embed_mgr.paragraphs_embedding_store.idx2hash = None
embed_mgr.paragraphs_embedding_store.dirty = False
# 删除旧的索引文件
if os.path.exists(embed_mgr.paragraphs_embedding_store.index_file_path):
os.remove(embed_mgr.paragraphs_embedding_store.index_file_path)
if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path)
# 清空实体索引
embed_mgr.entities_embedding_store.faiss_index = None
embed_mgr.entities_embedding_store.idx2hash = None
embed_mgr.entities_embedding_store.dirty = False
# 删除旧的索引文件
if os.path.exists(embed_mgr.entities_embedding_store.index_file_path):
os.remove(embed_mgr.entities_embedding_store.index_file_path)
if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path)
# 清空关系索引
embed_mgr.relation_embedding_store.faiss_index = None
embed_mgr.relation_embedding_store.idx2hash = None
embed_mgr.relation_embedding_store.dirty = False
# 删除旧的索引文件
if os.path.exists(embed_mgr.relation_embedding_store.index_file_path):
os.remove(embed_mgr.relation_embedding_store.index_file_path)
if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path)
await self._run_cancellable_executor(_clear_embedding_indices)
# 3. 清空知识图谱
# 获取所有段落hash
all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
if all_pg_hashes:
# 删除所有段落节点(这会自动清理相关的边和孤立实体)
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial(
kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
)
await self._run_cancellable_executor(delete_func)
# 完全清空KG创建新的空图无论是否有段落hash都要执行
from quick_algo import di_graph
kg_mgr.graph = di_graph.DiGraph()
kg_mgr.stored_paragraph_hashes.clear()
kg_mgr.ent_appear_cnt.clear()
# 4. 保存所有数据此时所有store都是空的索引也是None
# 注意即使store为空save_to_file也会保存空的DataFrame这是正确的
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
after_stats = {
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
"entities": len(embed_mgr.entities_embedding_store.store),
"relations": len(embed_mgr.relation_embedding_store.store),
"kg_nodes": len(kg_mgr.graph.get_node_list()),
"kg_edges": len(kg_mgr.graph.get_edge_list()),
}
return {
"status": "success",
"message": f"已成功清空LPMM知识库删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
"stats": {
"before": before_stats,
"after": after_stats,
},
}
except asyncio.CancelledError:
logger.warning("[Plugin API] 清空操作被用户中断")
return {"status": "cancelled", "message": "清空操作已被用户中断"}
except Exception as e:
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)}
# 内部使用的单例
lpmm_ops = LPMMOperations()

View File

@@ -360,6 +360,12 @@ class ChatBot:
user_id = user_info.user_id
group_id = group_info.group_id if group_info else None
_ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
try:
from src.services.memory_flow_service import memory_automation_service
await memory_automation_service.on_incoming_message(message)
except Exception as exc:
logger.warning(f"[长期记忆自动总结] 注册会话总结器失败: {exc}")
# message.update_chat_stream(chat)

View File

@@ -383,6 +383,13 @@ class UniversalMessageSender:
with get_db_session() as db_session:
db_session.add(message.to_db_instance())
try:
from src.services.memory_flow_service import memory_automation_service
await memory_automation_service.on_message_sent(message)
except Exception as exc:
logger.warning(f"[{chat_id}] 长期记忆人物事实写回注册失败: {exc}")
return sent_msg
except Exception as e:

View File

@@ -1,7 +1,6 @@
import traceback
import time
import asyncio
import importlib
import random
import re
@@ -36,6 +35,7 @@ from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
from src.memory_system.retrieval_tools import get_tool_registry
from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
from src.chat.utils.common_utils import TempMethodsExpression
@@ -1164,29 +1164,14 @@ class DefaultReplyer:
async def get_prompt_info(self, message: str, sender: str, target: str):
related_info = ""
start_time = time.time()
try:
knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge")
except ImportError:
logger.debug("LPMM知识库工具模块不存在跳过获取知识库内容")
return ""
search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None)
search_knowledge_tool = get_tool_registry().get_tool("search_long_term_memory")
if search_knowledge_tool is None:
logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识内容")
logger.debug("长期记忆检索工具未注册,跳过获取知识内容")
return ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 从LPMM知识库获取知识
logger.debug(f"获取长期记忆内容,元消息:{message[:30]}...,消息长度: {len(message)}")
try:
# 检查LPMM知识库是否启用
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return ""
if global_config.lpmm_knowledge.lpmm_mode == "agent":
return ""
template_prompt = prompt_manager.get_prompt("lpmm_get_knowledge")
template_prompt = prompt_manager.get_prompt("memory_get_knowledge")
template_prompt.add_context("bot_name", global_config.bot.nickname)
template_prompt.add_context("time_now", lambda _: time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
template_prompt.add_context("chat_history", message)
@@ -1202,24 +1187,31 @@ class DefaultReplyer:
# logger.info(f"工具调用提示词: {prompt}")
# logger.info(f"工具调用: {tool_calls}")
if tool_calls:
result = await self.tool_executor.execute_tool_call(tool_calls[0])
end_time = time.time()
if not result or not result.get("content"):
logger.debug("从LPMM知识库获取知识失败返回空知识...")
return ""
found_knowledge_from_lpmm = result.get("content", "")
logger.info(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
related_info += found_knowledge_from_lpmm
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return f"你有以下这些**知识**\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
else:
logger.debug("模型认为不需要使用LPMM知识库")
if not tool_calls:
logger.debug("模型认为不需要使用长期记忆")
return ""
related_chunks: List[str] = []
for tool_call in tool_calls:
if tool_call.func_name != "search_long_term_memory":
continue
tool_args = dict(tool_call.args or {})
tool_args.setdefault("chat_id", self.chat_stream.session_id)
result_text = await search_knowledge_tool.execute(**tool_args)
if result_text and "未找到" not in result_text:
related_chunks.append(result_text)
if not related_chunks:
logger.debug("长期记忆未返回有效信息")
return ""
related_info = "\n".join(related_chunks)
end_time = time.time()
logger.info(f"从长期记忆获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return f"你有以下这些**知识**\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return ""

View File

@@ -55,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
MMC_VERSION: str = "1.0.0"
CONFIG_VERSION: str = "8.1.0"
CONFIG_VERSION: str = "8.1.1"
MODEL_CONFIG_VERSION: str = "1.12.0"
logger = get_logger("config")

View File

@@ -94,6 +94,11 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
["", "enable", "enable", "enable"],
["qq:1919810:group", "enable", "enable", "enable"],
]
兼容旧旧格式:
learning_list = [
["qq:1919810:group", "enable", "enable", "0.5"],
["", "disable", "disable", "0.1"],
]
新:
[[expression.learning_list]]
platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true
@@ -117,6 +122,16 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
use_expression = _parse_enable_disable(r[1])
enable_learning = _parse_enable_disable(r[2])
enable_jargon_learning = _parse_enable_disable(r[3])
if enable_jargon_learning is None:
# 更早期的配置在第 4 列记录的是一个已废弃的数值权重/阈值,
# 当前 schema 已没有对应字段。这里按保守策略兼容迁移:
# 丢弃旧数值,并将 enable_jargon_learning 置为 False。
try:
float(str(r[3]))
except (TypeError, ValueError):
pass
else:
enable_jargon_learning = False
if use_expression is None or enable_learning is None or enable_jargon_learning is None:
return False

View File

@@ -416,6 +416,24 @@ class MemoryConfig(ConfigBase):
)
"""_wrap_全局记忆黑名单当启用全局记忆时不将特定聊天流纳入检索"""
long_term_auto_summary_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "book-open",
},
)
"""是否自动启动聊天总结并导入长期记忆"""
person_fact_writeback_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "user-round-pen",
},
)
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
chat_history_topic_check_message_threshold: int = Field(
default=80,
ge=1,

View File

@@ -6,7 +6,6 @@ import time
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.knowledge import lpmm_start_up
from src.chat.message_receive.bot import chat_bot
from src.chat.message_receive.chat_manager import chat_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
@@ -19,6 +18,7 @@ from src.config.config import config_manager, global_config
from src.manager.async_task_manager import async_task_manager
from src.plugin_runtime.integration import get_plugin_runtime_manager
from src.prompt.prompt_manager import prompt_manager
from src.services.memory_flow_service import memory_automation_service
# from src.api.main import start_api_server
@@ -88,9 +88,6 @@ class MainSystem:
# start_api_server()
# logger.info("API服务器启动成功")
# 启动LPMM
lpmm_start_up()
# 启动插件运行时(内置插件 + 第三方插件双子进程)
await get_plugin_runtime_manager().start()
@@ -103,6 +100,7 @@ class MainSystem:
asyncio.create_task(chat_manager.regularly_save_sessions())
logger.info(t("startup.chat_manager_initialized"))
await memory_automation_service.start()
# await asyncio.sleep(0.5) #防止logger输出飞了
@@ -164,6 +162,10 @@ async def main():
system.schedule_tasks(),
)
finally:
await memory_automation_service.shutdown()
await get_plugin_runtime_manager().bridge_event("on_stop")
await get_plugin_runtime_manager().stop()
await async_task_manager.stop_and_wait_all_tasks()
await config_manager.stop_file_watcher()

View File

@@ -931,12 +931,14 @@ class ChatHistorySummarizer:
else:
logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
# 同时导入到LPMM知识库
if global_config.lpmm_knowledge.enable:
await self._import_to_lpmm_knowledge(
if saved_record and saved_record.get("id") is not None:
await self._import_to_long_term_memory(
record_id=int(saved_record["id"]),
theme=theme,
summary=summary,
participants=participants,
start_time=start_time,
end_time=end_time,
original_text=original_text,
)
@@ -947,76 +949,131 @@ class ChatHistorySummarizer:
traceback.print_exc()
raise
async def _import_to_lpmm_knowledge(
async def _import_to_long_term_memory(
self,
record_id: int,
theme: str,
summary: str,
participants: List[str],
start_time: float,
end_time: float,
original_text: str,
):
"""
将聊天历史总结导入到LPMM知识库
将聊天历史总结导入到统一长期记忆
Args:
record_id: chat_history 主键
theme: 话题主题
summary: 概括内容
participants: 参与者列表
start_time: 开始时间
end_time: 结束时间
original_text: 原始文本(可能很长,需要截断)
"""
try:
from src.chat.knowledge.lpmm_ops import lpmm_ops
from src.services.memory_service import memory_service
session = _chat_manager.get_session_by_session_id(self.session_id)
session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else ""
session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else ""
# 构造要导入的文本内容
# 格式:主题 + 概括 + 参与者信息 + 原始内容摘要
# 注意使用单换行符连接确保整个内容作为一段导入不被LPMM分段
content_parts = []
# 1. 话题主题
# if theme:
# content_parts.append(f"话题:{theme}")
# 2. 概括内容
if theme:
content_parts.append(f"主题:{theme}")
if summary:
content_parts.append(f"概括:{summary}")
# 3. 参与者信息
if participants:
participants_text = "".join(participants)
content_parts.append(f"参与者:{participants_text}")
# 4. 原始文本摘要如果原始文本太长只取前500字
# if original_text:
# # 截断原始文本,避免过长
# max_original_length = 500
# if len(original_text) > max_original_length:
# truncated_text = original_text[:max_original_length] + "..."
# content_parts.append(f"原始内容摘要:{truncated_text}")
# else:
# content_parts.append(f"原始内容:{original_text}")
# 将所有部分合并为一个完整段落使用单换行符避免被LPMM分段
# LPMM使用 \n\n 作为段落分隔符,所以这里使用 \n 确保不会被分段
content_to_import = "\n".join(content_parts)
if not content_to_import.strip():
logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,跳过导入知识库")
logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,改用插件侧 generate_from_chat 兜底")
await self._fallback_import_to_long_term_memory(
record_id=record_id,
theme=theme,
participants=participants,
start_time=start_time,
end_time=end_time,
original_text=original_text,
)
return
# 调用lpmm_ops导入
result = await lpmm_ops.add_content(text=content_to_import, auto_split=False)
if result["status"] == "success":
logger.info(
f"{self.log_prefix} 成功将聊天历史总结导入到LPMM知识库 | 话题: {theme} | 新增段落数: {result.get('count', 0)}"
)
result = await memory_service.ingest_summary(
external_id=f"chat_history:{record_id}",
chat_id=self.session_id,
text=content_to_import,
participants=participants,
time_start=start_time,
time_end=end_time,
tags=[theme] if theme else [],
metadata={"theme": theme, "original_text_length": len(original_text or "")},
respect_filter=True,
user_id=session_user_id,
group_id=session_group_id,
)
if result.success:
if result.detail == "chat_filtered":
logger.debug(f"{self.log_prefix} 聊天历史总结被聊天过滤策略跳过 | 话题: {theme}")
else:
logger.info(f"{self.log_prefix} 成功将聊天历史总结导入到长期记忆 | 话题: {theme}")
else:
logger.warning(
f"{self.log_prefix} 将聊天历史总结导入到LPMM知识库失败 | 话题: {theme} | 错误: {result.get('message', '未知错误')}"
logger.warning(f"{self.log_prefix} 将聊天历史总结导入到长期记忆失败,尝试插件侧兜底 | 话题: {theme} | 错误: {result.detail}")
await self._fallback_import_to_long_term_memory(
record_id=record_id,
theme=theme,
participants=participants,
start_time=start_time,
end_time=end_time,
original_text=original_text,
)
except Exception as e:
# 导入失败不应该影响数据库存储,只记录错误
logger.error(f"{self.log_prefix} 导入聊天历史总结到LPMM知识库时出错: {e}", exc_info=True)
logger.error(f"{self.log_prefix} 导入聊天历史总结到长期记忆时出错: {e}", exc_info=True)
async def _fallback_import_to_long_term_memory(
self,
*,
record_id: int,
theme: str,
participants: List[str],
start_time: float,
end_time: float,
original_text: str,
) -> None:
try:
from src.services.memory_service import memory_service
session = _chat_manager.get_session_by_session_id(self.session_id)
session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else ""
session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else ""
result = await memory_service.ingest_summary(
external_id=f"chat_history:{record_id}",
chat_id=self.session_id,
text="",
participants=participants,
time_start=start_time,
time_end=end_time,
tags=[theme] if theme else [],
metadata={
"theme": theme,
"original_text_length": len(original_text or ""),
"generate_from_chat": True,
"context_length": global_config.memory.chat_history_topic_check_message_threshold,
},
respect_filter=True,
user_id=session_user_id,
group_id=session_group_id,
)
if result.success:
if result.detail == "chat_filtered":
logger.debug(f"{self.log_prefix} 插件侧 generate_from_chat 兜底被聊天过滤策略跳过 | 话题: {theme}")
else:
logger.info(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入成功 | 话题: {theme}")
else:
logger.warning(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入失败 | 话题: {theme} | 错误: {result.detail}")
except Exception as exc:
logger.error(f"{self.log_prefix} 插件侧兜底导入长期记忆失败: {exc}", exc_info=True)
async def start(self):
"""启动后台定期检查循环"""

View File

@@ -237,8 +237,8 @@ async def _react_agent_solve_question(
if first_head_prompt is None:
# 第一次构建使用初始的collected_info即initial_info
initial_collected_info = initial_info or ""
# 使用 LPMM 知识库检索 prompt
first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_lpmm")
# 使用统一长期记忆检索 prompt
first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_memory")
first_head_prompt_template.add_context("bot_name", bot_name)
first_head_prompt_template.add_context("time_now", time_now)
first_head_prompt_template.add_context("chat_history", chat_history)

View File

@@ -10,21 +10,17 @@ from .tool_registry import (
get_tool_registry,
)
# 导入所有工具的注册函数
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_words import register_tool as register_query_words
from .return_information import register_tool as register_return_information
from src.config.config import global_config
def init_all_tools():
"""初始化并注册所有记忆检索工具"""
# 延迟导入,避免在仅使用部分工具或单元测试阶段触发不必要的依赖链。
from .query_long_term_memory import register_tool as register_long_term_memory
from .query_words import register_tool as register_query_words
from .return_information import register_tool as register_return_information
register_query_words()
register_return_information()
# LPMM知识库检索工具
if global_config.lpmm_knowledge.lpmm_mode == "agent":
register_lpmm_knowledge()
register_long_term_memory()
__all__ = [

View File

@@ -0,0 +1,304 @@
"""通过统一长期记忆服务查询信息。"""
from __future__ import annotations
import re
from calendar import monthrange
from datetime import datetime, timedelta
from typing import Iterable, Literal, Tuple
from src.common.logger import get_logger
from src.services.memory_service import MemoryHit, MemorySearchResult, memory_service
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
_SUPPORTED_MODES = {"search", "time", "episode", "aggregate"}
_RELATIVE_DAYS_RE = re.compile(r"^最近\s*(\d+)\s*天$")
_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}\s+\d{2}:\d{2}$")
_TIME_EXPRESSION_HELP = (
"请改用更具体的时间表达例如今天、昨天、前天、本周、上周、本月、上月、最近7天、"
"2026/03/18、2026/03/18 09:30。"
)
def _format_query_datetime(dt: datetime) -> str:
return dt.strftime("%Y/%m/%d %H:%M")
def _resolve_time_expression(
expression: str,
*,
now: datetime | None = None,
) -> Tuple[float, float, str, str]:
clean = str(expression or "").strip()
if not clean:
raise ValueError(f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}")
current = now or datetime.now()
day_start = current.replace(hour=0, minute=0, second=0, microsecond=0)
if clean == "今天":
start = day_start
end = day_start.replace(hour=23, minute=59)
elif clean == "昨天":
start = day_start - timedelta(days=1)
end = start.replace(hour=23, minute=59)
elif clean == "前天":
start = day_start - timedelta(days=2)
end = start.replace(hour=23, minute=59)
elif clean == "本周":
start = day_start - timedelta(days=day_start.weekday())
end = start + timedelta(days=6, hours=23, minutes=59)
elif clean == "上周":
this_week_start = day_start - timedelta(days=day_start.weekday())
start = this_week_start - timedelta(days=7)
end = start + timedelta(days=6, hours=23, minutes=59)
elif clean == "本月":
start = day_start.replace(day=1)
last_day = monthrange(start.year, start.month)[1]
end = start.replace(day=last_day, hour=23, minute=59)
elif clean == "上月":
year = day_start.year
month = day_start.month - 1
if month == 0:
year -= 1
month = 12
start = day_start.replace(year=year, month=month, day=1)
last_day = monthrange(year, month)[1]
end = start.replace(day=last_day, hour=23, minute=59)
else:
relative_match = _RELATIVE_DAYS_RE.fullmatch(clean)
if relative_match:
days = max(1, int(relative_match.group(1)))
start = day_start - timedelta(days=max(0, days - 1))
end = day_start.replace(hour=23, minute=59)
elif _DATE_RE.fullmatch(clean):
start = datetime.strptime(clean, "%Y/%m/%d")
end = start.replace(hour=23, minute=59)
elif _MINUTE_RE.fullmatch(clean):
start = datetime.strptime(clean, "%Y/%m/%d %H:%M")
end = start
else:
raise ValueError(f"时间表达“{clean}”无法解析。{_TIME_EXPRESSION_HELP}")
return start.timestamp(), end.timestamp(), _format_query_datetime(start), _format_query_datetime(end)
def _extract_time_label(metadata: dict) -> str:
if not isinstance(metadata, dict):
return ""
start = metadata.get("event_time_start")
end = metadata.get("event_time_end")
event_time = metadata.get("event_time")
def _fmt(value: object) -> str:
if value in {None, ""}:
return ""
try:
return datetime.fromtimestamp(float(value)).strftime("%Y/%m/%d %H:%M")
except Exception:
return str(value)
start_text = _fmt(start or event_time)
end_text = _fmt(end)
if start_text and end_text:
return f"{start_text} - {end_text}"
return start_text or end_text
def _truncate(text: str, limit: int = 160) -> str:
compact = str(text or "").strip().replace("\n", " ")
if len(compact) <= limit:
return compact
return compact[:limit] + "..."
def _format_search_lines(hits: Iterable[MemoryHit], *, limit: int, include_time: bool = False) -> str:
lines = []
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
time_label = _extract_time_label(item.metadata) if include_time else ""
prefix = f"[{time_label}] " if time_label else ""
lines.append(f"{index}. {prefix}{_truncate(item.content)}")
return "\n".join(lines)
def _format_episode_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
lines = []
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
metadata = item.metadata if isinstance(item.metadata, dict) else {}
title = str(item.title or "").strip() or "未命名事件"
summary = _truncate(item.content, limit=180)
participants = [str(x).strip() for x in (metadata.get("participants") or []) if str(x).strip()]
keywords = [str(x).strip() for x in (metadata.get("keywords") or []) if str(x).strip()]
extras = []
if participants:
extras.append(f"参与者:{''.join(participants[:4])}")
if keywords:
extras.append(f"关键词:{''.join(keywords[:6])}")
time_label = _extract_time_label(metadata)
if time_label:
extras.append(f"时间:{time_label}")
suffix = f"{''.join(extras)}" if extras else ""
lines.append(f"{index}. 事件《{title}》:{summary}{suffix}")
return "\n".join(lines)
def _format_aggregate_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
lines = []
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
metadata = item.metadata if isinstance(item.metadata, dict) else {}
source_branches = [str(x).strip() for x in (metadata.get("source_branches") or []) if str(x).strip()]
branch_text = f"[{','.join(source_branches)}]" if source_branches else ""
item_type = str(item.hit_type or "").strip().lower() or "memory"
if item_type == "episode":
title = str(item.title or "").strip() or "未命名事件"
lines.append(f"{index}. {branch_text}[episode] 《{title}》:{_truncate(item.content, 160)}")
else:
lines.append(f"{index}. {branch_text}[{item_type}] {_truncate(item.content, 160)}")
return "\n".join(lines)
def _format_tool_result(
*,
result: MemorySearchResult,
mode: Literal["search", "time", "episode", "aggregate"],
limit: int,
query: str,
time_range_text: str = "",
) -> str:
if not result.hits:
if mode == "time":
return f"在指定时间范围内未找到相关的长期记忆{time_range_text}"
if mode == "episode":
return f"未找到与“{query}”相关的事件或情节记忆"
if mode == "aggregate":
return f"未找到可用于综合回忆的长期记忆线索{f'query{query}' if query else ''}"
return f"在长期记忆中未找到与“{query}”相关的信息"
if mode == "episode":
text = _format_episode_lines(result.hits, limit=limit)
return f"你从长期记忆的事件/情节中找到以下信息:\n{text}"
if mode == "aggregate":
text = _format_aggregate_lines(result.hits, limit=limit)
return f"你从长期记忆中综合找到了以下线索:\n{text}"
if mode == "time":
text = _format_search_lines(result.hits, limit=limit, include_time=True)
return f"你从指定时间范围内的长期记忆中找到以下信息{time_range_text}\n{text}"
text = _format_search_lines(result.hits, limit=limit)
return f"你从长期记忆中找到以下信息:\n{text}"
async def query_long_term_memory(
query: str = "",
limit: int = 5,
chat_id: str = "",
person_id: str = "",
mode: str = "search",
time_expression: str = "",
) -> str:
content = str(query or "").strip()
safe_limit = max(1, int(limit or 5))
normalized_mode = str(mode or "search").strip().lower() or "search"
if normalized_mode not in _SUPPORTED_MODES:
return f"不支持的长期记忆检索模式:{normalized_mode}。可用模式search、time、episode、aggregate。"
if normalized_mode == "search" and not content:
return "查询关键词为空,请提供你想查找的长期记忆内容。"
if normalized_mode == "time" and not str(time_expression or "").strip():
return f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}"
if normalized_mode in {"episode", "aggregate"} and not content and not str(time_expression or "").strip():
return f"{normalized_mode} 模式至少需要提供 query 或 time_expression。"
time_start = None
time_end = None
time_range_text = ""
if str(time_expression or "").strip():
try:
time_start, time_end, time_start_text, time_end_text = _resolve_time_expression(time_expression)
except ValueError as exc:
return str(exc)
time_range_text = f"(时间范围:{time_start_text}{time_end_text}"
backend_mode = "hybrid" if normalized_mode == "search" else normalized_mode
try:
result = await memory_service.search(
content,
limit=safe_limit,
mode=backend_mode,
chat_id=str(chat_id or "").strip(),
person_id=str(person_id or "").strip(),
time_start=time_start,
time_end=time_end,
)
text = _format_tool_result(
result=result,
mode=normalized_mode, # type: ignore[arg-type]
limit=safe_limit,
query=content,
time_range_text=time_range_text,
)
logger.debug(f"长期记忆查询结果({normalized_mode}): {text}")
return text
except Exception as exc:
logger.error(f"长期记忆查询失败: {exc}")
return f"长期记忆查询失败:{exc}"
def register_tool():
register_memory_retrieval_tool(
name="search_long_term_memory",
description=(
"从长期记忆中检索信息。支持 search普通事实检索、time按时间范围检索"
"episode按事件/情节检索、aggregate综合检索四种模式。"
),
parameters=[
{
"name": "query",
"type": "string",
"description": "需要查询的问题。search 模式建议用自然语言问句time/episode/aggregate 模式也可用关键词短语。",
"required": False,
},
{
"name": "mode",
"type": "string",
"description": "检索模式search普通长期记忆、time按时间窗口、episode事件/情节、aggregate综合检索",
"required": False,
"enum": ["search", "time", "episode", "aggregate"],
},
{
"name": "limit",
"type": "integer",
"description": "希望返回的相关知识条数默认为5",
"required": False,
},
{
"name": "chat_id",
"type": "string",
"description": "当前聊天流ID可选。提供后优先检索当前聊天上下文相关的长期记忆。",
"required": False,
},
{
"name": "person_id",
"type": "string",
"description": "相关人物ID可选。提供后优先检索该人物相关的长期记忆。",
"required": False,
},
{
"name": "time_expression",
"type": "string",
"description": (
"时间表达可选。time 模式必填episode/aggregate 模式可选。支持:今天、昨天、前天、本周、上周、本月、上月、"
"最近N天以及 YYYY/MM/DD、YYYY/MM/DD HH:mm。"
),
"required": False,
},
],
execute_func=query_long_term_memory,
)

View File

@@ -1,75 +0,0 @@
"""
通过LPMM知识库查询信息 - 工具实现
"""
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import get_qa_manager
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_lpmm_knowledge(query: str, limit: int = 5) -> str:
"""在LPMM知识库中查询相关信息
Args:
query: 查询关键词
Returns:
str: 查询结果
"""
try:
content = str(query).strip()
if not content:
return "查询关键词为空"
try:
limit_value = int(limit)
except (TypeError, ValueError):
limit_value = 5
limit_value = max(1, limit_value)
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用")
return "LPMM知识库未启用"
qa_manager = get_qa_manager()
if qa_manager is None:
logger.debug("LPMM知识库未初始化跳过查询")
return "LPMM知识库未初始化"
knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value)
logger.debug(f"LPMM知识库查询结果: {knowledge_info}")
if knowledge_info:
return f"你从LPMM知识库中找到以下信息\n{knowledge_info}"
return f"在LPMM知识库中未找到与“{content}”相关的信息"
except Exception as e:
logger.error(f"LPMM知识库查询失败: {e}")
return f"LPMM知识库查询失败{str(e)}"
def register_tool():
"""注册LPMM知识库查询工具"""
register_memory_retrieval_tool(
name="lpmm_search_knowledge",
description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索",
parameters=[
{
"name": "query",
"type": "string",
"description": "需要查询的问题使用一句疑问句提问例如什么是AI",
"required": True,
},
{
"name": "limit",
"type": "integer",
"description": "希望返回的相关知识条数默认为5",
"required": False,
},
],
execute_func=query_lpmm_knowledge,
)

View File

@@ -6,9 +6,10 @@ import random
import math
from json_repair import repair_json
from typing import Union, Optional, Dict
from typing import Union, Optional, Dict, List
from datetime import datetime
from sqlalchemy import or_
from sqlmodel import col, select
from src.common.logger import get_logger
@@ -17,6 +18,7 @@ from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.services.memory_service import memory_service
logger = get_logger("person_info")
@@ -37,16 +39,60 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID"""
clean_name = str(person_name or "").strip()
if not clean_name:
return ""
try:
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
statement = (
select(PersonInfo)
.where(
or_(
col(PersonInfo.person_name) == clean_name,
col(PersonInfo.user_nickname) == clean_name,
)
)
.limit(1)
)
record = session.exec(statement).first()
if record and record.person_id:
return record.person_id
statement = (
select(PersonInfo)
.where(PersonInfo.group_cardname.contains(clean_name))
.limit(1)
)
record = session.exec(statement).first()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
logger.error(f"根据用户名 {clean_name} 获取用户ID时出错: {e}")
return ""
def resolve_person_id_for_memory(
*,
person_name: str = "",
platform: str = "",
user_id: Optional[Union[int, str]] = None,
) -> str:
"""统一人物记忆链路中的 person_id 解析。
优先使用已知的人物名称/别名,其次退回到平台 + user_id 的稳定 ID。
"""
name_token = str(person_name or "").strip()
if name_token:
resolved = get_person_id_by_person_name(name_token)
if resolved:
return resolved
platform_token = str(platform or "").strip()
user_token = str(user_id or "").strip()
if platform_token and user_token:
return get_person_id(platform_token, user_token)
return ""
def is_person_known(
person_id: Optional[str] = None,
user_id: Optional[str] = None,
@@ -537,79 +583,79 @@ class Person:
async def build_relationship(self, chat_content: str = "", info_type=""):
if not self.is_known:
return ""
# 构建points文本
nickname_str = ""
if self.person_name != self.nickname:
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
relation_info = ""
async def _select_traits(query_text: str, traits: List[str], limit: int = 3) -> List[str]:
clean_traits = [trait.strip() for trait in traits if isinstance(trait, str) and trait.strip()]
if not clean_traits:
return []
if not query_text:
return clean_traits[:limit]
points_text = ""
category_list = self.get_all_category()
numbered_traits = "\n".join(f"{index}. {trait}" for index, trait in enumerate(clean_traits, start=1))
prompt = f"""当前关注内容:
{query_text}
if chat_content:
prompt = f"""当前聊天内容:
{chat_content}
候选人物信息:
{numbered_traits}
分类列表:
{category_list}
**要求**:请你根据当前聊天内容,从以下分类中选择一个与聊天内容相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
例如:
<分类1><分类2><分类3>......
如果没有相关的分类,请输出<none>"""
请从候选人物信息中选择与当前关注内容最相关的编号,并用<>包裹输出,不要输出其他内容。
例如:
<1><3>
如果都不相关,请输出<none>"""
response, _ = await relation_selection_model.generate_response_async(prompt)
# print(prompt)
# print(response)
category_list = extract_categories_from_response(response)
if "none" not in category_list:
for category in category_list:
random_memory = self.get_random_memory_by_category(category, 2)
if random_memory:
random_memory_str = "\n".join(
[get_memory_content_from_memory(memory) for memory in random_memory]
)
points_text = f"有关 {category} 的内容:{random_memory_str}"
break
elif info_type:
prompt = f"""你需要获取用户{self.person_name}的 **{info_type}** 信息。
try:
response, _ = await relation_selection_model.generate_response_async(prompt)
selected_traits: List[str] = []
for raw_index in extract_categories_from_response(response):
if raw_index == "none":
return []
try:
trait_index = int(raw_index) - 1
except ValueError:
continue
if 0 <= trait_index < len(clean_traits):
trait = clean_traits[trait_index]
if trait not in selected_traits:
selected_traits.append(trait)
if selected_traits:
return selected_traits[:limit]
except Exception as e:
logger.debug(f"筛选人物画像信息失败,使用默认画像摘要: {e}")
现有信息类别列表:
{category_list}
**要求**:请你根据**{info_type}**,从以下分类中选择一个与**{info_type}**相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
例如:
<分类1><分类2><分类3>......
如果没有相关的分类,请输出<none>"""
response, _ = await relation_selection_model.generate_response_async(prompt)
# print(prompt)
# print(response)
category_list = extract_categories_from_response(response)
if "none" not in category_list:
for category in category_list:
random_memory = self.get_random_memory_by_category(category, 3)
if random_memory:
random_memory_str = "\n".join(
[get_memory_content_from_memory(memory) for memory in random_memory]
)
points_text = f"有关 {category} 的内容:{random_memory_str}"
break
else:
for category in category_list:
random_memory = self.get_random_memory_by_category(category, 1)[0]
if random_memory:
points_text = f"有关 {category} 的内容:{get_memory_content_from_memory(random_memory)}"
break
return clean_traits[:limit]
profile = await memory_service.get_person_profile(self.person_id, limit=8)
relation_parts: List[str] = []
if profile.summary.strip():
relation_parts.append(profile.summary.strip())
query_text = str(chat_content or info_type or "").strip()
selected_traits = await _select_traits(query_text, profile.traits, limit=3)
if not selected_traits and not query_text:
selected_traits = [trait for trait in profile.traits if trait][:2]
for trait in selected_traits:
clean_trait = str(trait).strip()
if clean_trait and clean_trait not in relation_parts:
relation_parts.append(clean_trait)
for evidence in profile.evidence:
content = str(evidence.get("content", "") or "").strip()
if content and content not in relation_parts:
relation_parts.append(content)
if len(relation_parts) >= 4:
break
points_info = ""
if points_text:
points_info = f"你还记得有关{self.person_name}的内容:{points_text}"
if relation_parts:
points_info = f"你还记得有关{self.person_name}的内容:{''.join(relation_parts[:3])}"
if not (nickname_str or points_info):
return ""
relation_info = f"{self.person_name}:{nickname_str}{points_info}"
return relation_info
return f"{self.person_name}:{nickname_str}{points_info}"
class PersonInfoManager:
@@ -776,7 +822,7 @@ person_info_manager = PersonInfoManager()
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
"""将人物信息存入person_info的memory_points
"""将人物事实写入统一长期记忆
Args:
person_name: 人物名称
@@ -784,6 +830,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
chat_id: 聊天ID
"""
try:
content = str(memory_content or "").strip()
if not content:
logger.debug("人物记忆内容为空,跳过写入")
return
# 从 chat_id 获取 session
session = _chat_manager.get_session_by_session_id(chat_id)
if not session:
@@ -794,16 +845,14 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
# 尝试从person_name查找person_id
# 首先尝试通过person_name查找
person_id = get_person_id_by_person_name(person_name)
person_id = resolve_person_id_for_memory(
person_name=person_name,
platform=platform,
user_id=session.user_id,
)
if not person_id:
# 如果通过person_name找不到尝试从 session 获取 user_id
if platform and session.user_id:
user_id = session.user_id
person_id = get_person_id(platform, user_id)
else:
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
return
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
return
# 创建或获取Person对象
person = Person(person_id=person_id)
@@ -812,39 +861,34 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
return
# 确定记忆分类可以根据memory_content判断这里使用通用分类
category = "其他" # 默认分类,可以根据需要调整
memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16]
result = await memory_service.ingest_text(
external_id=f"person_fact:{person_id}:{memory_hash}",
source_type="person_fact",
text=content,
chat_id=chat_id,
person_ids=[person_id],
participants=[person.person_name or person_name],
timestamp=time.time(),
tags=["person_fact"],
metadata={
"person_id": person_id,
"person_name": person.person_name or person_name,
"platform": platform,
"source": "person_info.store_person_memory_from_answer",
},
respect_filter=True,
user_id=str(session.user_id or "").strip(),
group_id=str(session.group_id or "").strip(),
)
# 记忆点格式category:content:weight
weight = "1.0" # 默认权重
memory_point = f"{category}:{memory_content}:{weight}"
# 添加到memory_points
if not person.memory_points:
person.memory_points = []
# 检查是否已存在相似的记忆点(避免重复)
is_duplicate = False
for existing_point in person.memory_points:
if existing_point and isinstance(existing_point, str):
parts = existing_point.split(":", 2)
if len(parts) >= 2:
existing_content = parts[1].strip()
# 简单相似度检查(如果内容相同或非常相似,则跳过)
if (
existing_content == memory_content
or memory_content in existing_content
or existing_content in memory_content
):
is_duplicate = True
break
if not is_duplicate:
person.memory_points.append(memory_point)
person.sync_to_database()
logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
if result.success:
if result.detail == "chat_filtered":
logger.debug(f"人物长期记忆被聊天过滤策略跳过: {person_name} (person_id: {person_id})")
else:
logger.info(f"成功写入人物长期记忆: {person_name} (person_id: {person_id})")
else:
logger.debug(f"记忆点已存在,跳过: {memory_point}")
logger.warning(f"写入人物长期记忆失败: {person_name} (person_id: {person_id}) | {result.detail}")
except Exception as e:
logger.error(f"存储人物记忆失败: {e}")

View File

@@ -672,12 +672,10 @@ class RuntimeDataCapabilityMixin:
limit_value = 5
try:
from src.chat.knowledge import qa_manager
from src.services.memory_service import memory_service
if qa_manager is None:
return {"success": True, "content": "LPMM知识库已禁用"}
knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)
result = await memory_service.search(query, limit=limit_value)
knowledge_info = result.to_text(limit=limit_value)
content = f"你知道这些知识: {knowledge_info}" if knowledge_info else f"你不太了解有关{query}的知识"
return {"success": True, "content": content}
except Exception as e:

View File

@@ -0,0 +1,275 @@
from __future__ import annotations
import asyncio
import json
from typing import Any, Dict, List, Optional
from json_repair import repair_json
from src.chat.utils.utils import is_bot_self
from src.common.message_repository import find_messages
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
logger = get_logger("memory_flow_service")
class LongTermMemorySessionManager:
def __init__(self) -> None:
self._lock = asyncio.Lock()
self._summarizers: Dict[str, ChatHistorySummarizer] = {}
async def on_message(self, message: Any) -> None:
if not bool(getattr(global_config.memory, "long_term_auto_summary_enabled", True)):
return
session_id = str(getattr(message, "session_id", "") or "").strip()
if not session_id:
return
created = False
async with self._lock:
summarizer = self._summarizers.get(session_id)
if summarizer is None:
summarizer = ChatHistorySummarizer(session_id=session_id)
self._summarizers[session_id] = summarizer
created = True
if created:
await summarizer.start()
async def shutdown(self) -> None:
async with self._lock:
items = list(self._summarizers.items())
self._summarizers.clear()
for session_id, summarizer in items:
try:
await summarizer.stop()
except Exception as exc:
logger.warning("停止聊天总结器失败: session=%s err=%s", session_id, exc)
class PersonFactWritebackService:
def __init__(self) -> None:
self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
self._worker_task: Optional[asyncio.Task] = None
self._stopping = False
self._extractor = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="person_fact_writeback",
)
async def start(self) -> None:
if self._worker_task is not None and not self._worker_task.done():
return
self._stopping = False
self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_person_fact_writeback")
async def shutdown(self) -> None:
self._stopping = True
worker = self._worker_task
self._worker_task = None
if worker is None:
return
worker.cancel()
try:
await worker
except asyncio.CancelledError:
pass
except Exception as exc:
logger.warning("关闭人物事实写回 worker 失败: %s", exc)
async def enqueue(self, message: Any) -> None:
if not bool(getattr(global_config.memory, "person_fact_writeback_enabled", True)):
return
if self._stopping:
return
try:
self._queue.put_nowait(message)
except asyncio.QueueFull:
logger.warning("人物事实写回队列已满,跳过本次回复")
async def _worker_loop(self) -> None:
try:
while not self._stopping:
message = await self._queue.get()
try:
await self._handle_message(message)
except Exception as exc:
logger.warning("人物事实写回处理失败: %s", exc, exc_info=True)
finally:
self._queue.task_done()
except asyncio.CancelledError:
raise
async def _handle_message(self, message: Any) -> None:
reply_text = str(getattr(message, "processed_plain_text", "") or "").strip()
if not reply_text:
return
if self._looks_ephemeral(reply_text):
return
target_person = self._resolve_target_person(message)
if target_person is None or not target_person.is_known:
return
facts = await self._extract_facts(target_person, reply_text)
if not facts:
return
session_id = str(
getattr(message, "session_id", "")
or getattr(getattr(message, "session", None), "session_id", "")
or ""
).strip()
if not session_id:
return
person_name = str(getattr(target_person, "person_name", "") or getattr(target_person, "nickname", "") or "").strip()
if not person_name:
return
for fact in facts:
await store_person_memory_from_answer(person_name, fact, session_id)
def _resolve_target_person(self, message: Any) -> Optional[Person]:
session = getattr(message, "session", None)
session_platform = str(getattr(session, "platform", "") or getattr(message, "platform", "") or "").strip()
session_user_id = str(getattr(session, "user_id", "") or "").strip()
group_id = str(getattr(session, "group_id", "") or "").strip()
if session_platform and session_user_id and not group_id:
if is_bot_self(session_platform, session_user_id):
return None
person_id = get_person_id(session_platform, session_user_id)
person = Person(person_id=person_id)
return person if person.is_known else None
reply_to = str(getattr(message, "reply_to", "") or "").strip()
if not reply_to:
return None
try:
replies = find_messages(message_id=reply_to, limit=1)
except Exception as exc:
logger.debug("查询 reply_to 目标失败: %s", exc)
return None
if not replies:
return None
reply_message = replies[0]
reply_platform = str(getattr(reply_message, "platform", "") or session_platform or "").strip()
reply_user_info = getattr(getattr(reply_message, "message_info", None), "user_info", None)
reply_user_id = str(getattr(reply_user_info, "user_id", "") or "").strip()
if not reply_platform or not reply_user_id or is_bot_self(reply_platform, reply_user_id):
return None
person_id = get_person_id(reply_platform, reply_user_id)
person = Person(person_id=person_id)
return person if person.is_known else None
async def _extract_facts(self, person: Person, reply_text: str) -> List[str]:
person_name = str(getattr(person, "person_name", "") or getattr(person, "nickname", "") or person.person_id)
prompt = f"""你要从一条机器人刚刚发送的回复中,提取“关于{person_name}的稳定事实”。
目标人物:{person_name}
机器人回复:
{reply_text}
请只提取满足以下条件的事实:
1. 明确是关于目标人物本人的信息。
2. 具有相对稳定性,可以作为长期记忆保存。
3. 用简洁中文陈述句表达。
不要提取:
- 机器人的情绪、计划、临时动作、客套话
- 只适用于当前时刻的短期安排
- 不确定、猜测、反问
- 与目标人物无关的信息
严格输出 JSON 数组,例如:
["他喜欢深夜打游戏", "他养了一只猫"]
如果没有可写入的事实,输出 []"""
try:
response, _ = await self._extractor.generate_response_async(prompt)
except Exception as exc:
logger.debug("人物事实提取模型调用失败: %s", exc)
return []
return self._parse_fact_list(response)
@staticmethod
def _parse_fact_list(raw: str) -> List[str]:
text = str(raw or "").strip()
if not text:
return []
try:
repaired = repair_json(text)
payload = json.loads(repaired) if isinstance(repaired, str) else repaired
except Exception:
payload = None
if not isinstance(payload, list):
return []
items: List[str] = []
seen = set()
for item in payload:
fact = str(item or "").strip().strip("- ")
if not fact or len(fact) < 4:
continue
if fact in seen:
continue
seen.add(fact)
items.append(fact)
return items[:5]
@staticmethod
def _looks_ephemeral(text: str) -> bool:
content = str(text or "").strip()
if not content:
return True
ephemeral_markers = (
"哈哈",
"好的",
"收到",
"嗯嗯",
"晚安",
"早安",
"拜拜",
"谢谢",
"在吗",
"",
)
if len(content) <= 8 and any(marker in content for marker in ephemeral_markers):
return True
return False
class MemoryAutomationService:
def __init__(self) -> None:
self.session_manager = LongTermMemorySessionManager()
self.fact_writeback = PersonFactWritebackService()
self._started = False
async def start(self) -> None:
if self._started:
return
await self.fact_writeback.start()
self._started = True
async def shutdown(self) -> None:
if not self._started:
return
await self.session_manager.shutdown()
await self.fact_writeback.shutdown()
self._started = False
async def on_incoming_message(self, message: Any) -> None:
if not self._started:
await self.start()
await self.session_manager.on_message(message)
async def on_message_sent(self, message: Any) -> None:
if not self._started:
await self.start()
await self.fact_writeback.enqueue(message)
memory_automation_service = MemoryAutomationService()

View File

@@ -0,0 +1,428 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
from src.plugin_runtime.integration import get_plugin_runtime_manager
logger = get_logger("memory_service")
PLUGIN_ID = "A_Memorix"
@dataclass
class MemoryHit:
content: str
score: float = 0.0
hit_type: str = ""
source: str = ""
hash_value: str = ""
metadata: Dict[str, Any] = field(default_factory=dict)
episode_id: str = ""
title: str = ""
def to_dict(self) -> Dict[str, Any]:
return {
"content": self.content,
"score": self.score,
"type": self.hit_type,
"source": self.source,
"hash": self.hash_value,
"metadata": self.metadata,
"episode_id": self.episode_id,
"title": self.title,
}
@dataclass
class MemorySearchResult:
summary: str = ""
hits: List[MemoryHit] = field(default_factory=list)
filtered: bool = False
def to_text(self, limit: int = 5) -> str:
if not self.hits:
return ""
lines = []
for index, item in enumerate(self.hits[: max(1, int(limit))], start=1):
content = item.content.strip().replace("\n", " ")
if len(content) > 160:
content = content[:160] + "..."
lines.append(f"{index}. {content}")
return "\n".join(lines)
def to_dict(self) -> Dict[str, Any]:
return {
"summary": self.summary,
"hits": [item.to_dict() for item in self.hits],
"filtered": self.filtered,
}
@dataclass
class MemoryWriteResult:
success: bool
stored_ids: List[str] = field(default_factory=list)
skipped_ids: List[str] = field(default_factory=list)
detail: str = ""
def to_dict(self) -> Dict[str, Any]:
return {
"success": self.success,
"stored_ids": self.stored_ids,
"skipped_ids": self.skipped_ids,
"detail": self.detail,
}
@dataclass
class PersonProfileResult:
summary: str = ""
traits: List[str] = field(default_factory=list)
evidence: List[Dict[str, Any]] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
return {"summary": self.summary, "traits": self.traits, "evidence": self.evidence}
class MemoryService:
async def _invoke(self, component_name: str, args: Optional[Dict[str, Any]] = None, *, timeout_ms: int = 30000) -> Any:
runtime = get_plugin_runtime_manager()
if not runtime.is_running:
raise RuntimeError("plugin_runtime 未启动")
return await runtime.invoke_plugin(
method="plugin.invoke_tool",
plugin_id=PLUGIN_ID,
component_name=component_name,
args=args or {},
timeout_ms=max(1000, int(timeout_ms or 30000)),
)
async def _invoke_admin(
self,
component_name: str,
*,
action: str,
timeout_ms: int = 30000,
**kwargs,
) -> Dict[str, Any]:
payload = await self._invoke(component_name, {"action": action, **kwargs}, timeout_ms=timeout_ms)
return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"}
@staticmethod
def _coerce_write_result(payload: Any) -> MemoryWriteResult:
if not isinstance(payload, dict):
return MemoryWriteResult(success=False, detail="invalid_payload")
stored_ids = [str(item) for item in (payload.get("stored_ids") or []) if str(item).strip()]
skipped_ids = [str(item) for item in (payload.get("skipped_ids") or []) if str(item).strip()]
detail = str(payload.get("detail") or payload.get("reason") or "")
if stored_ids or skipped_ids:
success = True
elif "success" in payload:
success = bool(payload.get("success"))
else:
success = not bool(detail)
return MemoryWriteResult(
success=success,
stored_ids=stored_ids,
skipped_ids=skipped_ids,
detail=detail,
)
@staticmethod
def _coerce_search_result(payload: Any) -> MemorySearchResult:
if not isinstance(payload, dict):
return MemorySearchResult()
hits: List[MemoryHit] = []
for item in payload.get("hits", []) or []:
if not isinstance(item, dict):
continue
metadata = item.get("metadata", {}) or {}
if not isinstance(metadata, dict):
metadata = {}
if "source_branches" in item and "source_branches" not in metadata:
metadata["source_branches"] = item.get("source_branches") or []
if "rank" in item and "rank" not in metadata:
metadata["rank"] = item.get("rank")
hits.append(
MemoryHit(
content=str(item.get("content", "") or ""),
score=float(item.get("score", 0.0) or 0.0),
hit_type=str(item.get("type", "") or ""),
source=str(item.get("source", "") or ""),
hash_value=str(item.get("hash", "") or ""),
metadata=metadata,
episode_id=str(item.get("episode_id", "") or ""),
title=str(item.get("title", "") or ""),
)
)
return MemorySearchResult(
summary=str(payload.get("summary", "") or ""),
hits=hits,
filtered=bool(payload.get("filtered", False)),
)
@staticmethod
def _coerce_profile_result(payload: Any) -> PersonProfileResult:
if not isinstance(payload, dict):
return PersonProfileResult()
return PersonProfileResult(
summary=str(payload.get("summary", "") or ""),
traits=[str(item) for item in (payload.get("traits") or []) if str(item).strip()],
evidence=[item for item in (payload.get("evidence") or []) if isinstance(item, dict)],
)
async def search(
self,
query: str,
*,
limit: int = 5,
mode: str = "hybrid",
chat_id: str = "",
person_id: str = "",
time_start: str | float | None = None,
time_end: str | float | None = None,
respect_filter: bool = True,
user_id: str = "",
group_id: str = "",
) -> MemorySearchResult:
clean_query = str(query or "").strip()
normalized_time_start = None if time_start in {None, ""} else time_start
normalized_time_end = None if time_end in {None, ""} else time_end
if not clean_query and normalized_time_start is None and normalized_time_end is None:
return MemorySearchResult()
try:
payload = await self._invoke(
"search_memory",
{
"query": clean_query,
"limit": max(1, int(limit)),
"mode": mode,
"chat_id": chat_id,
"person_id": person_id,
"time_start": normalized_time_start,
"time_end": normalized_time_end,
"respect_filter": bool(respect_filter),
"user_id": str(user_id or "").strip(),
"group_id": str(group_id or "").strip(),
},
)
return self._coerce_search_result(payload)
except Exception as exc:
logger.warning("长期记忆搜索失败: %s", exc)
return MemorySearchResult()
async def ingest_summary(
self,
*,
external_id: str,
chat_id: str,
text: str,
participants: Optional[List[str]] = None,
time_start: float | None = None,
time_end: float | None = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
respect_filter: bool = True,
user_id: str = "",
group_id: str = "",
) -> MemoryWriteResult:
try:
payload = await self._invoke(
"ingest_summary",
{
"external_id": external_id,
"chat_id": chat_id,
"text": text,
"participants": participants or [],
"time_start": time_start,
"time_end": time_end,
"tags": tags or [],
"metadata": metadata or {},
"respect_filter": bool(respect_filter),
"user_id": str(user_id or "").strip(),
"group_id": str(group_id or "").strip(),
},
)
return self._coerce_write_result(payload)
except Exception as exc:
logger.warning("长期记忆写入摘要失败: %s", exc)
return MemoryWriteResult(success=False, detail=str(exc))
async def ingest_text(
self,
*,
external_id: str,
source_type: str,
text: str,
chat_id: str = "",
person_ids: Optional[List[str]] = None,
participants: Optional[List[str]] = None,
timestamp: float | None = None,
time_start: float | None = None,
time_end: float | None = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
entities: Optional[List[str]] = None,
relations: Optional[List[Dict[str, Any]]] = None,
respect_filter: bool = True,
user_id: str = "",
group_id: str = "",
) -> MemoryWriteResult:
try:
payload = await self._invoke(
"ingest_text",
{
"external_id": external_id,
"source_type": source_type,
"text": text,
"chat_id": chat_id,
"person_ids": person_ids or [],
"participants": participants or [],
"timestamp": timestamp,
"time_start": time_start,
"time_end": time_end,
"tags": tags or [],
"metadata": metadata or {},
"entities": entities or [],
"relations": relations or [],
"respect_filter": bool(respect_filter),
"user_id": str(user_id or "").strip(),
"group_id": str(group_id or "").strip(),
},
)
return self._coerce_write_result(payload)
except Exception as exc:
logger.warning("长期记忆写入文本失败: %s", exc)
return MemoryWriteResult(success=False, detail=str(exc))
async def get_person_profile(self, person_id: str, *, chat_id: str = "", limit: int = 10) -> PersonProfileResult:
clean_person_id = str(person_id or "").strip()
if not clean_person_id:
return PersonProfileResult()
try:
payload = await self._invoke(
"get_person_profile",
{"person_id": clean_person_id, "chat_id": chat_id, "limit": max(1, int(limit))},
)
return self._coerce_profile_result(payload)
except Exception as exc:
logger.warning("获取人物画像失败: %s", exc)
return PersonProfileResult()
async def maintain_memory(
self,
*,
action: str,
target: str = "",
hours: float | None = None,
reason: str = "",
limit: int = 50,
) -> MemoryWriteResult:
try:
payload = await self._invoke(
"maintain_memory",
{"action": action, "target": target, "hours": hours, "reason": reason, "limit": limit},
)
if not isinstance(payload, dict):
return MemoryWriteResult(success=False, detail="invalid_payload")
return MemoryWriteResult(success=bool(payload.get("success")), detail=str(payload.get("detail", "") or ""))
except Exception as exc:
logger.warning("记忆维护失败: %s", exc)
return MemoryWriteResult(success=False, detail=str(exc))
async def memory_stats(self) -> Dict[str, Any]:
try:
payload = await self._invoke("memory_stats", {})
return payload if isinstance(payload, dict) else {}
except Exception as exc:
logger.warning("获取记忆统计失败: %s", exc)
return {}
async def graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
try:
return await self._invoke_admin("memory_graph_admin", action=action, **kwargs)
except Exception as exc:
logger.warning("图谱管理调用失败: %s", exc)
return {"success": False, "error": str(exc)}
async def source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
try:
return await self._invoke_admin("memory_source_admin", action=action, **kwargs)
except Exception as exc:
logger.warning("来源管理调用失败: %s", exc)
return {"success": False, "error": str(exc)}
async def episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
try:
return await self._invoke_admin("memory_episode_admin", action=action, **kwargs)
except Exception as exc:
logger.warning("Episode 管理调用失败: %s", exc)
return {"success": False, "error": str(exc)}
async def profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
try:
return await self._invoke_admin("memory_profile_admin", action=action, **kwargs)
except Exception as exc:
logger.warning("画像管理调用失败: %s", exc)
return {"success": False, "error": str(exc)}
async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
try:
return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs)
except Exception as exc:
logger.warning("运行时管理调用失败: %s", exc)
return {"success": False, "error": str(exc)}
async def import_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
try:
return await self._invoke_admin("memory_import_admin", action=action, timeout_ms=timeout_ms, **kwargs)
except Exception as exc:
logger.warning("导入管理调用失败: %s", exc)
return {"success": False, "error": str(exc)}
async def tuning_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
try:
return await self._invoke_admin("memory_tuning_admin", action=action, timeout_ms=timeout_ms, **kwargs)
except Exception as exc:
logger.warning("调优管理调用失败: %s", exc)
return {"success": False, "error": str(exc)}
async def v5_admin(self, *, action: str, timeout_ms: int = 30000, **kwargs) -> Dict[str, Any]:
try:
return await self._invoke_admin("memory_v5_admin", action=action, timeout_ms=timeout_ms, **kwargs)
except Exception as exc:
logger.warning("V5 记忆管理调用失败: %s", exc)
return {"success": False, "error": str(exc)}
async def delete_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
try:
return await self._invoke_admin("memory_delete_admin", action=action, timeout_ms=timeout_ms, **kwargs)
except Exception as exc:
logger.warning("删除管理调用失败: %s", exc)
return {"success": False, "error": str(exc)}
async def get_recycle_bin(self, *, limit: int = 50) -> Dict[str, Any]:
try:
payload = await self._invoke("maintain_memory", {"action": "recycle_bin", "limit": max(1, int(limit or 50))})
return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"}
except Exception as exc:
logger.warning("获取回收站失败: %s", exc)
return {"success": False, "error": str(exc)}
async def restore_memory(self, *, target: str) -> MemoryWriteResult:
return await self.maintain_memory(action="restore", target=target)
async def reinforce_memory(self, *, target: str) -> MemoryWriteResult:
return await self.maintain_memory(action="reinforce", target=target)
async def freeze_memory(self, *, target: str) -> MemoryWriteResult:
return await self.maintain_memory(action="freeze", target=target)
async def protect_memory(self, *, target: str, hours: float | None = None) -> MemoryWriteResult:
return await self.maintain_memory(action="protect", target=target, hours=hours)
memory_service = MemoryService()

View File

@@ -17,14 +17,14 @@ def get_all_routers() -> List[APIRouter]:
from src.webui.api.planner import router as planner_router
from src.webui.api.replier import router as replier_router
from src.webui.routers.chat import router as chat_router
from src.webui.routers.knowledge import router as knowledge_router
from src.webui.routers.memory import compat_router as memory_compat_router
from src.webui.routers.websocket.logs import router as logs_router
from src.webui.routes import router as main_router
return [
main_router,
memory_compat_router,
logs_router,
knowledge_router,
chat_router,
planner_router,
replier_router,

1395
src/webui/routers/memory.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -16,6 +16,7 @@ from src.webui.routers.config import router as config_router
from src.webui.routers.emoji import router as emoji_router
from src.webui.routers.expression import router as expression_router
from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.memory import router as memory_router
from src.webui.routers.model import router as model_router
from src.webui.routers.person import router as person_router
from src.webui.routers.plugin import get_progress_router
@@ -49,6 +50,8 @@ router.include_router(get_progress_router())
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册长期记忆管理路由
router.include_router(memory_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)