feat: add memory tests, retrieval tool, and services
Add full long-term memory support with tests: introduce the Chinese memory-retrieval prompts, the query_long_term_memory retrieval tool, the memory service and memory-flow service, and the WebUI memory routes. Add a large test suite (unit tests plus benchmark/online tests) covering chat-history summarization, the knowledge fetcher, episode generation, the write-back mechanism, and person-profile retrieval. Update several modules to integrate memory retrieval (knowledge fetcher, chat summarizer, memory_retrieval, person_info, config/legacy migration, and the WebUI routes) and remove the legacy lpmm knowledge module. Together these changes wire up the memory runtime, provide a mock embedding adapter for benchmarks, and supply the imports and episode-processing flow required by the new tests and tools.
26
prompts/zh-CN/memory_get_knowledge.prompt
Normal file
@@ -0,0 +1,26 @@
你是一个专门获取长期记忆的助手。你的名字是{bot_name}。现在是{time_now}。
群里正在进行的聊天内容:
{chat_history}

现在,{sender}发送了内容:{target_message},你想要回复ta。
请仔细分析聊天内容,考虑以下几点:
1. 内容中是否包含需要查询历史知识或长期记忆的问题
2. 是否有明确的知识获取指令

如果需要使用长期记忆工具,请直接调用函数 `search_long_term_memory`;如果不需要任何工具,直接输出 `No tool needed`。

工具模式说明:
- `mode="search"`:普通长期记忆检索,适合查具体事实、偏好、历史对话内容
- `mode="time"`:按时间范围检索,必须同时提供 `time_expression`
- `mode="episode"`:按事件/情节检索,适合查“那次经历”“那件事的经过”
- `mode="aggregate"`:综合检索,适合“整体回忆一下”“把相关线索综合找出来”

优先规则:
- 问“某段时间发生了什么”:优先 `time`
- 问“某次事件/某段经历”:优先 `episode`
- 问“整体情况/最近发生过什么”:优先 `aggregate`
- 问单点事实:优先 `search`

`time_expression` 可用表达:
- `今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
- 或绝对时间:`2026/03/18`、`2026/03/18 09:30`
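The mode-priority rules in this prompt can be sketched as a small dispatcher. This is an illustrative assumption only — in production the LLM itself chooses the mode when calling `search_long_term_memory`; the keyword heuristics below are not the actual model behavior:

```python
# Sketch of the prompt's mode-priority rules. The keyword matching here is
# a hypothetical stand-in for the LLM's own judgment, not project code.

TIME_EXPRESSIONS = ("今天", "昨天", "前天", "本周", "上周", "本月", "上月", "最近7天")

def choose_mode(question: str) -> dict:
    """Map a question to a search_long_term_memory call payload."""
    for expr in TIME_EXPRESSIONS:
        if expr in question:
            # time mode must carry a time_expression, per the prompt
            return {"mode": "time", "time_expression": expr}
    if "经过" in question or "那次" in question:
        return {"mode": "episode"}
    if "整体" in question or "最近发生" in question:
        return {"mode": "aggregate"}
    # single-point facts and preferences default to plain search
    return {"mode": "search"}
```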
@@ -0,0 +1,34 @@
你的名字是{bot_name}。现在是{time_now}。
你正在参与聊天,你需要搜集信息来帮助你进行回复。
重要,这是当前聊天记录:
{chat_history}
聊天记录结束

已收集的信息:
{collected_info}

- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
- 思考完毕后,使用工具

**工具说明:**
- 如果涉及过往事件、历史对话、用户长期偏好或某段时间发生的事件,可以使用长期记忆查询工具
- 如果遇到不熟悉的词语、缩写、黑话或网络用语,可以使用query_words工具查询其含义
- 你必须使用tool,如果需要查询你必须给出使用什么工具进行查询
- 当你决定结束查询时,必须调用return_information工具返回总结信息并结束查询

长期记忆工具 `search_long_term_memory` 支持以下模式:
- `mode="search"`:普通事实/偏好/历史内容检索。适合问“她喜欢什么”“我们之前讨论过什么”。
- `mode="time"`:按时间范围检索。适合问“昨天发生了什么”“最近7天有哪些相关记忆”。
- `mode="episode"`:按事件/情节检索。适合问“那次灯塔停电的经过是什么”“关于某次经历还有什么”。
- `mode="aggregate"`:综合检索。适合问“帮我整体回忆一下这个人最近的情况”“把相关线索综合找出来”。

模式选择建议:
- 问单点事实、偏好、人设、具体信息:优先 `search`
- 问某段时间发生了什么:优先 `time`
- 问某次事件、某段经历、某个剧情片段:优先 `episode`
- 问整体回忆、综合找线索、总结最近发生的事:优先 `aggregate`

时间模式要求:
- 使用 `mode="time"` 时,必须填写 `time_expression`
- 可用时间表达包括:`今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天`
- 也可以使用绝对时间:`2026/03/18`、`2026/03/18 09:30`
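The time expressions listed above have to be resolved to a concrete range somewhere in the memory service. A minimal sketch of that resolution, assuming relative expressions are anchored at {time_now} and absolute dates default to a one-day window (the real parser may support more forms and different boundary conventions):

```python
from datetime import datetime, timedelta

def resolve_time_expression(expr: str, now: datetime) -> tuple[datetime, datetime]:
    """Resolve a time_expression into a [start, end) range.

    Illustrative sketch covering a subset of the expressions listed in the
    prompt; not the production parser.
    """
    day = now.replace(hour=0, minute=0, second=0, microsecond=0)
    if expr == "今天":
        return day, day + timedelta(days=1)
    if expr == "昨天":
        return day - timedelta(days=1), day
    if expr == "前天":
        return day - timedelta(days=2), day - timedelta(days=1)
    if expr == "最近7天":
        return now - timedelta(days=7), now
    # absolute forms: 2026/03/18 or 2026/03/18 09:30
    for fmt in ("%Y/%m/%d %H:%M", "%Y/%m/%d"):
        try:
            start = datetime.strptime(expr, fmt)
            return start, start + timedelta(days=1)
        except ValueError:
            continue
    raise ValueError(f"unsupported time expression: {expr}")
```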
@@ -0,0 +1,148 @@
from types import SimpleNamespace

import pytest

from src.memory_system import chat_history_summarizer as summarizer_module


def _build_summarizer() -> summarizer_module.ChatHistorySummarizer:
    summarizer = summarizer_module.ChatHistorySummarizer.__new__(summarizer_module.ChatHistorySummarizer)
    summarizer.session_id = "session-1"
    summarizer.log_prefix = "[session-1]"
    return summarizer


@pytest.mark.asyncio
async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch):
    calls = []
    summarizer = _build_summarizer()

    async def fake_ingest_summary(**kwargs):
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
    )
    monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8)))
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)

    await summarizer._import_to_long_term_memory(
        record_id=1,
        theme="旅行计划",
        summary="我们讨论了春游安排",
        participants=["Alice", "Bob"],
        start_time=1.0,
        end_time=2.0,
        original_text="long text",
    )

    assert len(calls) == 1
    payload = calls[0]
    assert payload["external_id"] == "chat_history:1"
    assert payload["chat_id"] == "session-1"
    assert payload["participants"] == ["Alice", "Bob"]
    assert payload["respect_filter"] is True
    assert payload["user_id"] == "user-1"
    assert payload["group_id"] == ""
    assert "主题:旅行计划" in payload["text"]
    assert "概括:我们讨论了春游安排" in payload["text"]


@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch):
    summarizer = _build_summarizer()
    fallback_calls = []

    async def fake_fallback(**kwargs):
        fallback_calls.append(kwargs)

    summarizer._fallback_import_to_long_term_memory = fake_fallback
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
    )

    await summarizer._import_to_long_term_memory(
        record_id=2,
        theme="",
        summary="",
        participants=[],
        start_time=10.0,
        end_time=20.0,
        original_text="raw chat",
    )

    assert len(fallback_calls) == 1
    assert fallback_calls[0]["record_id"] == 2
    assert fallback_calls[0]["original_text"] == "raw chat"


@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch):
    summarizer = _build_summarizer()
    fallback_calls = []

    async def fake_ingest_summary(**kwargs):
        return SimpleNamespace(success=False, detail="boom", stored_ids=[])

    async def fake_fallback(**kwargs):
        fallback_calls.append(kwargs)

    summarizer._fallback_import_to_long_term_memory = fake_fallback
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")),
    )
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)

    await summarizer._import_to_long_term_memory(
        record_id=3,
        theme="电影",
        summary="聊了电影推荐",
        participants=["Alice"],
        start_time=3.0,
        end_time=4.0,
        original_text="raw",
    )

    assert len(fallback_calls) == 1
    assert fallback_calls[0]["theme"] == "电影"


@pytest.mark.asyncio
async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch):
    calls = []
    summarizer = _build_summarizer()

    async def fake_ingest_summary(**kwargs):
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[])

    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")),
    )
    monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12)))
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)

    await summarizer._fallback_import_to_long_term_memory(
        record_id=4,
        theme="工作",
        participants=["Alice"],
        start_time=5.0,
        end_time=6.0,
        original_text="a" * 128,
    )

    assert len(calls) == 1
    metadata = calls[0]["metadata"]
    assert metadata["generate_from_chat"] is True
    assert metadata["context_length"] == 12
    assert calls[0]["respect_filter"] is True
127
pytests/A_memorix_test/test_knowledge_fetcher.py
Normal file
@@ -0,0 +1,127 @@
from types import SimpleNamespace

import pytest

from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.services.memory_service import MemoryHit, MemorySearchResult


def test_knowledge_fetcher_resolves_private_memory_context(monkeypatch):
    monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object())
    monkeypatch.setattr(
        knowledge_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
    )
    monkeypatch.setattr(
        knowledge_module,
        "resolve_person_id_for_memory",
        lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}",
    )

    fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")

    assert fetcher._resolve_private_memory_context() == {
        "chat_id": "stream-1",
        "person_id": "Alice:qq:42",
        "user_id": "42",
        "group_id": "",
    }


@pytest.mark.asyncio
async def test_knowledge_fetcher_memory_get_knowledge_uses_memory_service(monkeypatch):
    monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object())
    monkeypatch.setattr(
        knowledge_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
    )
    monkeypatch.setattr(
        knowledge_module,
        "resolve_person_id_for_memory",
        lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}",
    )

    calls = []

    async def fake_search(query: str, **kwargs):
        calls.append((query, kwargs))
        return MemorySearchResult(summary="", hits=[MemoryHit(content="她喜欢猫", source="person_fact:qq:42")], filtered=False)

    monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)

    fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
    result = await fetcher._memory_get_knowledge("她喜欢什么")

    assert "1. 她喜欢猫" in result
    assert calls == [
        (
            "她喜欢什么",
            {
                "limit": 5,
                "mode": "search",
                "chat_id": "stream-1",
                "person_id": "Alice:qq:42",
                "user_id": "42",
                "group_id": "",
                "respect_filter": True,
            },
        )
    ]


@pytest.mark.asyncio
async def test_knowledge_fetcher_falls_back_to_chat_scope_when_person_scope_misses(monkeypatch):
    monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object())
    monkeypatch.setattr(
        knowledge_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")),
    )
    monkeypatch.setattr(
        knowledge_module,
        "resolve_person_id_for_memory",
        lambda *, person_name, platform, user_id: "person-1",
    )

    calls = []

    async def fake_search(query: str, **kwargs):
        calls.append((query, kwargs))
        if kwargs.get("person_id"):
            return MemorySearchResult(summary="", hits=[], filtered=False)
        return MemorySearchResult(summary="", hits=[MemoryHit(content="她计划去杭州音乐节", source="chat_summary:stream-1")], filtered=False)

    monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)

    fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
    result = await fetcher._memory_get_knowledge("Alice 最近在忙什么")

    assert "杭州音乐节" in result
    assert calls == [
        (
            "Alice 最近在忙什么",
            {
                "limit": 5,
                "mode": "search",
                "chat_id": "stream-1",
                "person_id": "person-1",
                "user_id": "42",
                "group_id": "",
                "respect_filter": True,
            },
        ),
        (
            "Alice 最近在忙什么",
            {
                "limit": 5,
                "mode": "search",
                "chat_id": "stream-1",
                "person_id": "",
                "user_id": "42",
                "group_id": "",
                "respect_filter": True,
            },
        ),
    ]
35
pytests/A_memorix_test/test_legacy_config_migration.py
Normal file
@@ -0,0 +1,35 @@
from src.config.legacy_migration import try_migrate_legacy_bot_config_dict


def test_legacy_learning_list_with_numeric_fourth_column_is_migrated():
    payload = {
        "expression": {
            "learning_list": [
                ["qq:123456:group", "enable", "disable", "0.5"],
                ["", "disable", "enable", "0.1"],
            ]
        }
    }

    result = try_migrate_legacy_bot_config_dict(payload)

    assert result.migrated is True
    assert "expression.learning_list" in result.reason
    assert result.data["expression"]["learning_list"] == [
        {
            "platform": "qq",
            "item_id": "123456",
            "rule_type": "group",
            "use_expression": True,
            "enable_learning": False,
            "enable_jargon_learning": False,
        },
        {
            "platform": "",
            "item_id": "",
            "rule_type": "group",
            "use_expression": False,
            "enable_learning": True,
            "enable_jargon_learning": False,
        },
    ]
691
pytests/A_memorix_test/test_long_novel_memory_benchmark.py
Normal file
@@ -0,0 +1,691 @@
from __future__ import annotations

import asyncio
import inspect
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List

import numpy as np
import pytest
import pytest_asyncio

from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import MemorySearchResult, memory_service


DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json"
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json"


def _load_benchmark_fixture() -> Dict[str, Any]:
    return json.loads(DATA_FILE.read_text(encoding="utf-8"))


class _FakeEmbeddingAdapter:
    def __init__(self, dimension: int = 32) -> None:
        self.dimension = dimension

    async def _detect_dimension(self) -> int:
        return self.dimension

    async def encode(self, texts, dimensions=None):
        dim = int(dimensions or self.dimension)
        if isinstance(texts, str):
            sequence = [texts]
            single = True
        else:
            sequence = list(texts)
            single = False

        rows = []
        for text in sequence:
            vec = np.zeros(dim, dtype=np.float32)
            for ch in str(text or ""):
                code = ord(ch)
                vec[code % dim] += 1.0
                vec[(code * 7) % dim] += 0.5
            if not vec.any():
                vec[0] = 1.0
            norm = np.linalg.norm(vec)
            if norm > 0:
                vec = vec / norm
            rows.append(vec)
        payload = np.vstack(rows)
        return payload[0] if single else payload


class _KnownPerson:
    def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None:
        self.person_id = person_id
        self.is_known = person_id in reverse_registry
        self.person_name = reverse_registry.get(person_id, "")
        self._registry = registry


class _KernelBackedRuntimeManager:
    is_running = True

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke_plugin(
        self,
        *,
        method: str,
        plugin_id: str,
        component_name: str,
        args: Dict[str, Any] | None,
        timeout_ms: int,
    ):
        del method, plugin_id, timeout_ms
        payload = args or {}
        if component_name == "search_memory":
            return await self.kernel.search_memory(
                KernelSearchRequest(
                    query=str(payload.get("query", "") or ""),
                    limit=int(payload.get("limit", 5) or 5),
                    mode=str(payload.get("mode", "hybrid") or "hybrid"),
                    chat_id=str(payload.get("chat_id", "") or ""),
                    person_id=str(payload.get("person_id", "") or ""),
                    time_start=payload.get("time_start"),
                    time_end=payload.get("time_end"),
                    respect_filter=bool(payload.get("respect_filter", True)),
                    user_id=str(payload.get("user_id", "") or ""),
                    group_id=str(payload.get("group_id", "") or ""),
                )
            )

        handler = getattr(self.kernel, component_name)
        result = handler(**payload)
        return await result if inspect.isawaitable(result) else result


async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]:
    for _ in range(max_rounds):
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        task = detail.get("task") or {}
        status = str(task.get("status", "") or "")
        if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
            return detail
        await asyncio.sleep(max(0.01, float(sleep_seconds)))
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")


def _join_hit_content(search_result: MemorySearchResult) -> str:
    return "\n".join(hit.content for hit in search_result.hits)


def _keyword_hits(text: str, keywords: List[str]) -> int:
    haystack = str(text or "")
    return sum(1 for keyword in keywords if keyword in haystack)


def _keyword_recall(text: str, keywords: List[str]) -> float:
    if not keywords:
        return 1.0
    return _keyword_hits(text, keywords) / float(len(keywords))


def _hit_blob(hit) -> str:
    meta = hit.metadata if isinstance(hit.metadata, dict) else {}
    return "\n".join(
        [
            str(hit.content or ""),
            str(hit.title or ""),
            str(hit.source or ""),
            json.dumps(meta, ensure_ascii=False),
        ]
    )


def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int:
    for index, hit in enumerate(search_result.hits[:5], start=1):
        if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))):
            return index
    return 0


def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str:
    return "\n".join(
        (
            f"{item.get('title', '')}\n"
            f"{item.get('summary', '')}\n"
            f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n"
            f"{json.dumps(item.get('participants', []), ensure_ascii=False)}"
        )
        for item in items
    )


def _episode_blob_from_hits(search_result: MemorySearchResult) -> str:
    chunks = []
    for hit in search_result.hits:
        meta = hit.metadata if isinstance(hit.metadata, dict) else {}
        chunks.append(
            "\n".join(
                [
                    str(hit.title or ""),
                    str(hit.content or ""),
                    json.dumps(meta.get("keywords", []) or [], ensure_ascii=False),
                    json.dumps(meta.get("participants", []) or [], ensure_ascii=False),
                ]
            )
        )
    return "\n".join(chunks)


async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    episode_source = f"chat_summary:{session_id}"
    payload = await memory_service.episode_admin(
        action="query",
        source=episode_source,
        limit=20,
    )
    items = payload.get("items") or []
    blob = _episode_blob_from_items(items)
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0

    for case in episode_cases:
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(items),
                "top_episode": items[0] if items else None,
            }
        )

    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "episode_count": len(items),
        "reports": reports,
    }


async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0
    episode_source = f"chat_summary:{session_id}"

    for case in episode_cases:
        payload = await memory_service.episode_admin(
            action="query",
            source=episode_source,
            query=case["query"],
            limit=5,
        )
        items = payload.get("items") or []
        blob = "\n".join(
            f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}"
            for item in items
        )
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(items),
                "top_episode": items[0] if items else None,
            }
        )

    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }


async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0

    for case in episode_cases:
        result = await memory_service.search(
            case["query"],
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
            limit=5,
        )
        blob = _episode_blob_from_hits(result)
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(result.hits),
                "top_episode": result.hits[0].to_dict() if result.hits else None,
            }
        )

    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }


async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]:
    search_case = dataset["search_cases"][0]
    episode_case = dataset["episode_cases"][0]
    aggregate_case = dataset["knowledge_fetcher_cases"][0]
    tool_cases = [
        {
            "name": "search",
            "kwargs": {
                "query": "蓝漆铁盒 北塔木梯",
                "mode": "search",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "time",
            "kwargs": {
                "query": "蓝漆铁盒 北塔",
                "mode": "time",
                "chat_id": session_id,
                "limit": 5,
                "time_expression": "最近7天",
            },
            "expected_keywords": ["蓝漆铁盒", "北塔木梯"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "episode",
            "kwargs": {
                "query": episode_case["query"],
                "mode": "episode",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": episode_case["expected_keywords"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "aggregate",
            "kwargs": {
                "query": aggregate_case["query"],
                "mode": "aggregate",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": aggregate_case["expected_keywords"],
            "minimum_keyword_recall": 0.67,
        },
    ]
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0

    for case in tool_cases:
        text = await query_long_term_memory(**case["kwargs"])
        recall = _keyword_recall(text, case["expected_keywords"])
        success = (
            "失败" not in text
            and "无法解析" not in text
            and "未找到" not in text
            and recall >= float(case["minimum_keyword_recall"])
        )
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "name": case["name"],
                "success": success,
                "keyword_recall": recall,
                "preview": text[:320],
            }
        )

    total = max(1, len(tool_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }


@pytest_asyncio.fixture
async def benchmark_env(monkeypatch, tmp_path):
    dataset = _load_benchmark_fixture()
    session_cfg = dataset["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )

    registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
    reverse_registry = {value: key for key, value in registry.items()}

    monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())

    async def fake_self_check(**kwargs):
        return {"ok": True, "message": "ok", "encoded_dimension": 32}

    monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None)
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
    monkeypatch.setattr(
        person_info_module,
        "Person",
        lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
    )

    data_dir = (tmp_path / "a_memorix_benchmark_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager)

    await kernel.initialize()
    try:
        yield {
            "dataset": dataset,
            "kernel": kernel,
            "session": session,
            "person_registry": registry,
        }
    finally:
        await kernel.shutdown()
|
||||||
|
|
||||||
|
|


@pytest.mark.asyncio
async def test_long_novel_memory_benchmark(benchmark_env):
    dataset = benchmark_env["dataset"]
    session_id = benchmark_env["session"].session_id

    created = await memory_service.import_admin(
        action="create_paste",
        name="long_novel_memory_benchmark.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(dataset["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True

    import_detail = await _wait_for_import_task(created["task"]["task_id"])
    assert import_detail["task"]["status"] == "completed"

    for record in dataset["chat_history_records"]:
        summarizer = summarizer_module.ChatHistorySummarizer(session_id)
        await summarizer._import_to_long_term_memory(
            record_id=record["record_id"],
            theme=record["theme"],
            summary=record["summary"],
            participants=record["participants"],
            start_time=record["start_time"],
            end_time=record["end_time"],
            original_text=record["original_text"],
        )

    for payload in dataset["person_writebacks"]:
        await person_info_module.store_person_memory_from_answer(
            payload["person_name"],
            payload["memory_content"],
            session_id,
        )

    await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)

    search_case_reports: List[Dict[str, Any]] = []
    search_accuracy_at_1 = 0.0
    search_recall_at_5 = 0.0
    search_precision_at_5 = 0.0
    search_mrr = 0.0
    search_keyword_recall = 0.0

    for case in dataset["search_cases"]:
        result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
        joined = _join_hit_content(result)
        rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
        relevant_hits = sum(
            1
            for hit in result.hits[:5]
            if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
        )
        keyword_recall = _keyword_recall(joined, case["expected_keywords"])
        search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
        search_recall_at_5 += 1.0 if rank > 0 else 0.0
        search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
        search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
        search_keyword_recall += keyword_recall
        search_case_reports.append(
            {
                "query": case["query"],
                "rank_of_first_relevant": rank,
                "relevant_hits_top5": relevant_hits,
                "keyword_recall_top5": keyword_recall,
                "top_hit": result.hits[0].to_dict() if result.hits else None,
            }
        )

    search_total = max(1, len(dataset["search_cases"]))

    writeback_reports: List[Dict[str, Any]] = []
    writeback_success_rate = 0.0
    writeback_keyword_recall = 0.0
    for payload in dataset["person_writebacks"]:
        query = " ".join(payload["expected_keywords"])
        result = await memory_service.search(
            query,
            mode="search",
            chat_id=session_id,
            person_id=payload["person_id"],
            respect_filter=False,
            limit=5,
        )
        joined = _join_hit_content(result)
        recall = _keyword_recall(joined, payload["expected_keywords"])
        success = bool(result.hits) and recall >= 0.67
        writeback_success_rate += 1.0 if success else 0.0
        writeback_keyword_recall += recall
        writeback_reports.append(
            {
                "person_id": payload["person_id"],
                "success": success,
                "keyword_recall": recall,
                "hit_count": len(result.hits),
            }
        )
    writeback_total = max(1, len(dataset["person_writebacks"]))

    knowledge_reports: List[Dict[str, Any]] = []
    knowledge_success_rate = 0.0
    knowledge_keyword_recall = 0.0
    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=dataset["session"]["display_name"],
        stream_id=session_id,
    )
    for case in dataset["knowledge_fetcher_cases"]:
        knowledge_text, _ = await fetcher.fetch(case["query"], [])
        recall = _keyword_recall(knowledge_text, case["expected_keywords"])
        success = recall >= float(case.get("minimum_keyword_recall", 1.0))
        knowledge_success_rate += 1.0 if success else 0.0
        knowledge_keyword_recall += recall
        knowledge_reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "preview": knowledge_text[:300],
            }
        )
    knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))

    profile_reports: List[Dict[str, Any]] = []
    profile_success_rate = 0.0
    profile_keyword_recall = 0.0
    profile_evidence_rate = 0.0
    for case in dataset["profile_cases"]:
        profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
        recall = _keyword_recall(profile.summary, case["expected_keywords"])
        has_evidence = bool(profile.evidence)
        success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
        profile_success_rate += 1.0 if success else 0.0
        profile_keyword_recall += recall
        profile_evidence_rate += 1.0 if has_evidence else 0.0
        profile_reports.append(
            {
                "person_id": case["person_id"],
                "success": success,
                "keyword_recall": recall,
                "evidence_count": len(profile.evidence),
                "summary_preview": profile.summary[:240],
            }
        )
    profile_total = max(1, len(dataset["profile_cases"]))

    episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_rebuild = await memory_service.episode_admin(
        action="rebuild",
        source=f"chat_summary:{session_id}",
    )
    episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)

    report = {
        "dataset": dataset["meta"],
        "import": {
            "task_id": created["task"]["task_id"],
            "status": import_detail["task"]["status"],
            "paragraph_count": len(dataset["import_payload"]["paragraphs"]),
        },
        "metrics": {
            "search": {
                "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
                "recall_at_5": round(search_recall_at_5 / search_total, 4),
                "precision_at_5": round(search_precision_at_5 / search_total, 4),
                "mrr": round(search_mrr / search_total, 4),
                "keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
            },
            "writeback": {
                "success_rate": round(writeback_success_rate / writeback_total, 4),
                "keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
            },
            "knowledge_fetcher": {
                "success_rate": round(knowledge_success_rate / knowledge_total, 4),
                "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
            },
            "profile": {
                "success_rate": round(profile_success_rate / profile_total, 4),
                "keyword_recall": round(profile_keyword_recall / profile_total, 4),
                "evidence_rate": round(profile_evidence_rate / profile_total, 4),
            },
            "tool_modes": {
                "success_rate": tool_modes["success_rate"],
                "keyword_recall": tool_modes["keyword_recall"],
            },
            "episode_generation_auto": {
                "success_rate": episode_generation_auto["success_rate"],
                "keyword_recall": episode_generation_auto["keyword_recall"],
                "episode_count": episode_generation_auto["episode_count"],
            },
            "episode_generation_after_rebuild": {
                "success_rate": episode_generation_after_rebuild["success_rate"],
                "keyword_recall": episode_generation_after_rebuild["keyword_recall"],
                "episode_count": episode_generation_after_rebuild["episode_count"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_admin_query_auto": {
                "success_rate": episode_admin_query_auto["success_rate"],
                "keyword_recall": episode_admin_query_auto["keyword_recall"],
            },
            "episode_admin_query_after_rebuild": {
                "success_rate": episode_admin_query_after_rebuild["success_rate"],
                "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_search_mode_auto": {
                "success_rate": episode_search_mode_auto["success_rate"],
                "keyword_recall": episode_search_mode_auto["keyword_recall"],
            },
            "episode_search_mode_after_rebuild": {
                "success_rate": episode_search_mode_after_rebuild["success_rate"],
                "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
        },
        "cases": {
            "search": search_case_reports,
            "writeback": writeback_reports,
            "knowledge_fetcher": knowledge_reports,
            "profile": profile_reports,
            "tool_modes": tool_modes["reports"],
            "episode_generation_auto": episode_generation_auto["reports"],
            "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
            "episode_admin_query_auto": episode_admin_query_auto["reports"],
            "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
            "episode_search_mode_auto": episode_search_mode_auto["reports"],
            "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
        },
    }

    REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
    REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))

    assert report["import"]["status"] == "completed"
    assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35
    assert report["metrics"]["search"]["recall_at_5"] >= 0.6
    assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8
    assert report["metrics"]["writeback"]["success_rate"] >= 0.66
    assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66
    assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75
    assert report["metrics"]["profile"]["success_rate"] >= 0.66
    assert report["metrics"]["profile"]["evidence_rate"] >= 1.0
    assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75
    assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True
    assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"]
343
pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py
Normal file
@@ -0,0 +1,343 @@
from __future__ import annotations

import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List

import pytest
import pytest_asyncio

from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
from pytests.test_long_novel_memory_benchmark import (
    _evaluate_episode_admin_query,
    _evaluate_episode_generation,
    _evaluate_episode_search_mode,
    _evaluate_tool_modes,
    _KernelBackedRuntimeManager,
    _KnownPerson,
    _first_relevant_rank,
    _hit_blob,
    _join_hit_content,
    _keyword_hits,
    _keyword_recall,
    _load_benchmark_fixture,
    _wait_for_import_task,
)
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service


pytestmark = pytest.mark.skipif(
    os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
    reason="需要显式开启真实 external embedding benchmark",
)

REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json"


@pytest_asyncio.fixture
async def benchmark_live_env(monkeypatch, tmp_path):
    dataset = _load_benchmark_fixture()
    session_cfg = dataset["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )

    registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
    reverse_registry = {value: key for key, value in registry.items()}

    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None)
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
    monkeypatch.setattr(
        person_info_module,
        "Person",
        lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
    )

    data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager)

    await kernel.initialize()
    try:
        yield {
            "dataset": dataset,
            "kernel": kernel,
            "session": session,
        }
    finally:
        await kernel.shutdown()


@pytest.mark.asyncio
async def test_long_novel_memory_benchmark_live(benchmark_live_env):
    dataset = benchmark_live_env["dataset"]
    session_id = benchmark_live_env["session"].session_id

    self_check = await memory_service.runtime_admin(action="refresh_self_check")
    assert self_check["success"] is True
    assert self_check["report"]["ok"] is True

    created = await memory_service.import_admin(
        action="create_paste",
        name="long_novel_memory_benchmark.live.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(dataset["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True

    import_detail = await _wait_for_import_task(
        created["task"]["task_id"],
        max_rounds=2400,
        sleep_seconds=0.25,
    )
    assert import_detail["task"]["status"] == "completed"

    for record in dataset["chat_history_records"]:
        summarizer = summarizer_module.ChatHistorySummarizer(session_id)
        await summarizer._import_to_long_term_memory(
            record_id=record["record_id"],
            theme=record["theme"],
            summary=record["summary"],
            participants=record["participants"],
            start_time=record["start_time"],
            end_time=record["end_time"],
            original_text=record["original_text"],
        )

    for payload in dataset["person_writebacks"]:
        await person_info_module.store_person_memory_from_answer(
            payload["person_name"],
            payload["memory_content"],
            session_id,
        )

    await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)

    search_case_reports: List[Dict[str, Any]] = []
    search_accuracy_at_1 = 0.0
    search_recall_at_5 = 0.0
    search_precision_at_5 = 0.0
    search_mrr = 0.0
    search_keyword_recall = 0.0
    for case in dataset["search_cases"]:
        result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
        joined = _join_hit_content(result)
        rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
        relevant_hits = sum(
            1
            for hit in result.hits[:5]
            if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
        )
        keyword_recall = _keyword_recall(joined, case["expected_keywords"])
        search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
        search_recall_at_5 += 1.0 if rank > 0 else 0.0
        search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
        search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
        search_keyword_recall += keyword_recall
        search_case_reports.append(
            {
                "query": case["query"],
                "rank_of_first_relevant": rank,
                "relevant_hits_top5": relevant_hits,
                "keyword_recall_top5": keyword_recall,
                "top_hit": result.hits[0].to_dict() if result.hits else None,
            }
        )
    search_total = max(1, len(dataset["search_cases"]))

    writeback_reports: List[Dict[str, Any]] = []
    writeback_success_rate = 0.0
    writeback_keyword_recall = 0.0
    for payload in dataset["person_writebacks"]:
        query = " ".join(payload["expected_keywords"])
        result = await memory_service.search(
            query,
            mode="search",
            chat_id=session_id,
            person_id=payload["person_id"],
            respect_filter=False,
            limit=5,
        )
        joined = _join_hit_content(result)
        recall = _keyword_recall(joined, payload["expected_keywords"])
        success = bool(result.hits) and recall >= 0.67
        writeback_success_rate += 1.0 if success else 0.0
        writeback_keyword_recall += recall
        writeback_reports.append(
            {
                "person_id": payload["person_id"],
                "success": success,
                "keyword_recall": recall,
                "hit_count": len(result.hits),
            }
        )
    writeback_total = max(1, len(dataset["person_writebacks"]))

    knowledge_reports: List[Dict[str, Any]] = []
    knowledge_success_rate = 0.0
    knowledge_keyword_recall = 0.0
    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=dataset["session"]["display_name"],
        stream_id=session_id,
    )
    for case in dataset["knowledge_fetcher_cases"]:
        knowledge_text, _ = await fetcher.fetch(case["query"], [])
        recall = _keyword_recall(knowledge_text, case["expected_keywords"])
        success = recall >= float(case.get("minimum_keyword_recall", 1.0))
        knowledge_success_rate += 1.0 if success else 0.0
        knowledge_keyword_recall += recall
        knowledge_reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "preview": knowledge_text[:300],
            }
        )
    knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))

    profile_reports: List[Dict[str, Any]] = []
    profile_success_rate = 0.0
    profile_keyword_recall = 0.0
    profile_evidence_rate = 0.0
    for case in dataset["profile_cases"]:
        profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
        recall = _keyword_recall(profile.summary, case["expected_keywords"])
        has_evidence = bool(profile.evidence)
        success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
        profile_success_rate += 1.0 if success else 0.0
        profile_keyword_recall += recall
        profile_evidence_rate += 1.0 if has_evidence else 0.0
        profile_reports.append(
            {
                "person_id": case["person_id"],
                "success": success,
                "keyword_recall": recall,
                "evidence_count": len(profile.evidence),
                "summary_preview": profile.summary[:240],
            }
        )
    profile_total = max(1, len(dataset["profile_cases"]))

    episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_rebuild = await memory_service.episode_admin(
        action="rebuild",
        source=f"chat_summary:{session_id}",
    )
    episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)

    report = {
        "dataset": dataset["meta"],
        "runtime_self_check": self_check["report"],
        "import": {
            "task_id": created["task"]["task_id"],
            "status": import_detail["task"]["status"],
            "paragraph_count": len(dataset["import_payload"]["paragraphs"]),
        },
        "metrics": {
            "search": {
                "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
                "recall_at_5": round(search_recall_at_5 / search_total, 4),
                "precision_at_5": round(search_precision_at_5 / search_total, 4),
                "mrr": round(search_mrr / search_total, 4),
                "keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
            },
            "writeback": {
                "success_rate": round(writeback_success_rate / writeback_total, 4),
                "keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
            },
            "knowledge_fetcher": {
                "success_rate": round(knowledge_success_rate / knowledge_total, 4),
                "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
            },
            "profile": {
                "success_rate": round(profile_success_rate / profile_total, 4),
                "keyword_recall": round(profile_keyword_recall / profile_total, 4),
                "evidence_rate": round(profile_evidence_rate / profile_total, 4),
            },
            "tool_modes": {
                "success_rate": tool_modes["success_rate"],
                "keyword_recall": tool_modes["keyword_recall"],
            },
            "episode_generation_auto": {
                "success_rate": episode_generation_auto["success_rate"],
                "keyword_recall": episode_generation_auto["keyword_recall"],
                "episode_count": episode_generation_auto["episode_count"],
            },
            "episode_generation_after_rebuild": {
                "success_rate": episode_generation_after_rebuild["success_rate"],
                "keyword_recall": episode_generation_after_rebuild["keyword_recall"],
                "episode_count": episode_generation_after_rebuild["episode_count"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_admin_query_auto": {
                "success_rate": episode_admin_query_auto["success_rate"],
                "keyword_recall": episode_admin_query_auto["keyword_recall"],
            },
            "episode_admin_query_after_rebuild": {
                "success_rate": episode_admin_query_after_rebuild["success_rate"],
                "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_search_mode_auto": {
                "success_rate": episode_search_mode_auto["success_rate"],
                "keyword_recall": episode_search_mode_auto["keyword_recall"],
            },
            "episode_search_mode_after_rebuild": {
                "success_rate": episode_search_mode_after_rebuild["success_rate"],
                "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
        },
        "cases": {
            "search": search_case_reports,
            "writeback": writeback_reports,
            "knowledge_fetcher": knowledge_reports,
            "profile": profile_reports,
            "tool_modes": tool_modes["reports"],
            "episode_generation_auto": episode_generation_auto["reports"],
            "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
            "episode_admin_query_auto": episode_admin_query_auto["reports"],
            "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
            "episode_search_mode_auto": episode_search_mode_auto["reports"],
            "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
        },
    }

    REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
    REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))

    assert report["import"]["status"] == "completed"
    assert report["runtime_self_check"]["ok"] is True
138
pytests/A_memorix_test/test_memory_flow_service.py
Normal file
@@ -0,0 +1,138 @@
from types import SimpleNamespace

import pytest

from src.services import memory_flow_service as memory_flow_module


@pytest.mark.asyncio
async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch):
    starts: list[str] = []
    summarizers: list[object] = []

    class FakeSummarizer:
        def __init__(self, session_id: str):
            self.session_id = session_id
            summarizers.append(self)

        async def start(self):
            starts.append(self.session_id)

        async def stop(self):
            starts.append(f"stop:{self.session_id}")

    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
    )
    monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)

    manager = memory_flow_module.LongTermMemorySessionManager()
    message = SimpleNamespace(session_id="session-1")

    await manager.on_message(message)
    await manager.on_message(message)

    assert len(summarizers) == 1
    assert starts == ["session-1"]


@pytest.mark.asyncio
async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch):
    stopped: list[str] = []

    class FakeSummarizer:
        def __init__(self, session_id: str):
            self.session_id = session_id

        async def start(self):
            return None

        async def stop(self):
            stopped.append(self.session_id)

    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
    )
    monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)

    manager = memory_flow_module.LongTermMemorySessionManager()
    await manager.on_message(SimpleNamespace(session_id="session-a"))
    await manager.on_message(SimpleNamespace(session_id="session-b"))
    await manager.shutdown()

    assert stopped == ["session-a", "session-b"]


def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
    raw = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]'

    result = memory_flow_module.PersonFactWritebackService._parse_fact_list(raw)

    assert result == ["他喜欢猫", "他会弹吉他"]


def test_person_fact_looks_ephemeral_detects_short_chitchat():
    assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("哈哈")
    assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("好的?")
    assert not memory_flow_module.PersonFactWritebackService._looks_ephemeral("她最近在学法语和钢琴")


def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
    class FakePerson:
        def __init__(self, person_id: str):
            self.person_id = person_id
            self.is_known = True

    service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService)
    monkeypatch.setattr(memory_flow_module, "is_bot_self", lambda platform, user_id: False)
    monkeypatch.setattr(memory_flow_module, "get_person_id", lambda platform, user_id: f"{platform}:{user_id}")
    monkeypatch.setattr(memory_flow_module, "Person", FakePerson)

    message = SimpleNamespace(session=SimpleNamespace(platform="qq", user_id="123", group_id=""))

    person = service._resolve_target_person(message)

    assert person is not None
    assert person.person_id == "qq:123"
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
|
||||||
|
events: list[tuple[str, str]] = []
|
||||||
|
|
||||||
|
class FakeSessionManager:
|
||||||
|
async def on_message(self, message):
|
||||||
|
events.append(("incoming", message.session_id))
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
events.append(("shutdown", "session"))
|
||||||
|
|
||||||
|
class FakeFactWriteback:
|
||||||
|
async def start(self):
|
||||||
|
events.append(("start", "fact"))
|
||||||
|
|
||||||
|
async def enqueue(self, message):
|
||||||
|
events.append(("sent", message.session_id))
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
events.append(("shutdown", "fact"))
|
||||||
|
|
||||||
|
service = memory_flow_module.MemoryAutomationService()
|
||||||
|
service.session_manager = FakeSessionManager()
|
||||||
|
service.fact_writeback = FakeFactWriteback()
|
||||||
|
|
||||||
|
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
|
||||||
|
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
|
||||||
|
await service.shutdown()
|
||||||
|
|
||||||
|
assert events == [
|
||||||
|
("start", "fact"),
|
||||||
|
("incoming", "session-1"),
|
||||||
|
("sent", "session-1"),
|
||||||
|
("shutdown", "session"),
|
||||||
|
("shutdown", "fact"),
|
||||||
|
]
|
||||||
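The two session-manager tests above fix the contract: at most one summarizer per `session_id`, started exactly once, and every summarizer stopped on shutdown. A minimal sketch consistent with those assertions (the class name, factory parameter, and dict-based bookkeeping here are illustrative assumptions, not the project's actual implementation):

```python
import asyncio


class SessionManagerSketch:
    """Keeps at most one summarizer per session_id and starts it once."""

    def __init__(self, summarizer_factory):
        self._summarizer_factory = summarizer_factory
        self._summarizers = {}  # session_id -> summarizer, insertion-ordered

    async def on_message(self, message):
        session_id = message.session_id
        if session_id not in self._summarizers:
            summarizer = self._summarizer_factory(session_id)
            self._summarizers[session_id] = summarizer
            await summarizer.start()  # started only on first sighting

    async def shutdown(self):
        # Stop summarizers in creation order, mirroring the test's expectation.
        for summarizer in self._summarizers.values():
            await summarizer.stop()
        self._summarizers.clear()
```

Because Python dicts preserve insertion order, shutdown stops sessions in the order they first appeared, which is what `test_long_term_memory_session_manager_shutdown_stops_all` asserts.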
281
pytests/A_memorix_test/test_memory_service.py
Normal file
@@ -0,0 +1,281 @@
import pytest

from src.services.memory_service import MemorySearchResult, MemoryService


def test_coerce_write_result_treats_skipped_payload_as_success():
    result = MemoryService._coerce_write_result({"skipped_ids": ["p1"], "detail": "chat_filtered"})

    assert result.success is True
    assert result.stored_ids == []
    assert result.skipped_ids == ["p1"]
    assert result.detail == "chat_filtered"


@pytest.mark.asyncio
async def test_graph_admin_invokes_plugin(monkeypatch):
    service = MemoryService()
    calls = []

    async def fake_invoke(component_name, args=None, **kwargs):
        calls.append((component_name, args, kwargs))
        return {"success": True, "nodes": [], "edges": []}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.graph_admin(action="get_graph", limit=12)

    assert result["success"] is True
    assert calls == [("memory_graph_admin", {"action": "get_graph", "limit": 12}, {"timeout_ms": 30000})]


@pytest.mark.asyncio
async def test_get_recycle_bin_uses_maintain_memory_tool(monkeypatch):
    service = MemoryService()
    calls = []

    async def fake_invoke(component_name, args=None, **kwargs):
        calls.append((component_name, args))
        return {"success": True, "items": [{"hash": "abc"}], "count": 1}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.get_recycle_bin(limit=5)

    assert result == {"success": True, "items": [{"hash": "abc"}], "count": 1}
    assert calls == [("maintain_memory", {"action": "recycle_bin", "limit": 5})]


@pytest.mark.asyncio
async def test_search_respects_filter_by_default(monkeypatch):
    service = MemoryService()
    calls = []

    async def fake_invoke(component_name, args=None, **kwargs):
        calls.append((component_name, args))
        return {"summary": "ok", "hits": [], "filtered": True}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.search(
        "mai",
        chat_id="stream-1",
        person_id="person-1",
        user_id="user-1",
        group_id="",
    )

    assert isinstance(result, MemorySearchResult)
    assert result.filtered is True
    assert calls == [
        (
            "search_memory",
            {
                "query": "mai",
                "limit": 5,
                "mode": "hybrid",
                "chat_id": "stream-1",
                "person_id": "person-1",
                "time_start": None,
                "time_end": None,
                "respect_filter": True,
                "user_id": "user-1",
                "group_id": "",
            },
        )
    ]


@pytest.mark.asyncio
async def test_ingest_summary_can_bypass_filter(monkeypatch):
    service = MemoryService()
    calls = []

    async def fake_invoke(component_name, args=None, **kwargs):
        calls.append((component_name, args))
        return {"success": True, "stored_ids": ["p1"], "detail": ""}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.ingest_summary(
        external_id="chat_history:1",
        chat_id="stream-1",
        text="summary",
        respect_filter=False,
        user_id="user-1",
    )

    assert result.success is True
    assert calls == [
        (
            "ingest_summary",
            {
                "external_id": "chat_history:1",
                "chat_id": "stream-1",
                "text": "summary",
                "participants": [],
                "time_start": None,
                "time_end": None,
                "tags": [],
                "metadata": {},
                "respect_filter": False,
                "user_id": "user-1",
                "group_id": "",
            },
        )
    ]


@pytest.mark.asyncio
async def test_v5_admin_invokes_plugin(monkeypatch):
    service = MemoryService()
    calls = []

    async def fake_invoke(component_name, args=None, **kwargs):
        calls.append((component_name, args, kwargs))
        return {"success": True, "count": 1}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.v5_admin(action="status", target="mai", limit=5)

    assert result["success"] is True
    assert calls == [("memory_v5_admin", {"action": "status", "target": "mai", "limit": 5}, {"timeout_ms": 30000})]


@pytest.mark.asyncio
async def test_delete_admin_uses_long_timeout(monkeypatch):
    service = MemoryService()
    calls = []

    async def fake_invoke(component_name, args=None, **kwargs):
        calls.append((component_name, args, kwargs))
        return {"success": True, "operation_id": "del-1"}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.delete_admin(action="execute", mode="relation", selector={"query": "mai"})

    assert result["success"] is True
    assert calls == [
        (
            "memory_delete_admin",
            {"action": "execute", "mode": "relation", "selector": {"query": "mai"}},
            {"timeout_ms": 120000},
        )
    ]


@pytest.mark.asyncio
async def test_search_returns_empty_when_query_and_time_missing_async():
    service = MemoryService()

    result = await service.search("", time_start=None, time_end=None)

    assert isinstance(result, MemorySearchResult)
    assert result.summary == ""
    assert result.hits == []
    assert result.filtered is False


@pytest.mark.asyncio
async def test_search_accepts_string_time_bounds(monkeypatch):
    service = MemoryService()
    calls = []

    async def fake_invoke(component_name, args=None, **kwargs):
        calls.append((component_name, args))
        return {"summary": "ok", "hits": [], "filtered": False}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.search(
        "广播站",
        mode="time",
        time_start="2026/03/18",
        time_end="2026/03/18 09:30",
    )

    assert isinstance(result, MemorySearchResult)
    assert calls == [
        (
            "search_memory",
            {
                "query": "广播站",
                "limit": 5,
                "mode": "time",
                "chat_id": "",
                "person_id": "",
                "time_start": "2026/03/18",
                "time_end": "2026/03/18 09:30",
                "respect_filter": True,
                "user_id": "",
                "group_id": "",
            },
        )
    ]


def test_coerce_search_result_preserves_aggregate_source_branches():
    result = MemoryService._coerce_search_result(
        {
            "hits": [
                {
                    "content": "广播站值夜班",
                    "type": "paragraph",
                    "metadata": {"event_time_start": 1.0},
                    "source_branches": ["search", "time"],
                    "rank": 1,
                }
            ]
        }
    )

    assert result.hits[0].metadata["source_branches"] == ["search", "time"]
    assert result.hits[0].metadata["rank"] == 1


@pytest.mark.asyncio
async def test_import_admin_uses_long_timeout(monkeypatch):
    service = MemoryService()
    calls = []

    async def fake_invoke(component_name, args=None, **kwargs):
        calls.append((component_name, args, kwargs))
        return {"success": True, "task_id": "import-1"}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.import_admin(action="create_lpmm_openie", alias="lpmm")

    assert result["success"] is True
    assert calls == [
        (
            "memory_import_admin",
            {"action": "create_lpmm_openie", "alias": "lpmm"},
            {"timeout_ms": 120000},
        )
    ]


@pytest.mark.asyncio
async def test_tuning_admin_uses_long_timeout(monkeypatch):
    service = MemoryService()
    calls = []

    async def fake_invoke(component_name, args=None, **kwargs):
        calls.append((component_name, args, kwargs))
        return {"success": True, "task_id": "tuning-1"}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.tuning_admin(action="create_task", payload={"query": "mai"})

    assert result["success"] is True
    assert calls == [
        (
            "memory_tuning_admin",
            {"action": "create_task", "payload": {"query": "mai"}},
            {"timeout_ms": 120000},
        )
    ]
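The `_coerce_write_result` test pins down one subtle behavior: a payload that contains only `skipped_ids` (for example a filtered chat) still counts as a successful write. A hedged sketch of that coercion logic (the dataclass and helper names are assumptions for illustration, not the project's actual code):

```python
from dataclasses import dataclass, field


@dataclass
class WriteResultSketch:
    success: bool = False
    stored_ids: list = field(default_factory=list)
    skipped_ids: list = field(default_factory=list)
    detail: str = ""


def coerce_write_result(payload: dict) -> WriteResultSketch:
    stored = list(payload.get("stored_ids") or [])
    skipped = list(payload.get("skipped_ids") or [])
    # A payload that only skipped items (e.g. filtered chats) still counts
    # as a successful write, matching the first test in this file.
    success = bool(payload.get("success", stored or skipped))
    return WriteResultSketch(
        success=success,
        stored_ids=stored,
        skipped_ids=skipped,
        detail=str(payload.get("detail", "") or ""),
    )
```

Treating "skipped" as success keeps callers from retrying writes that were intentionally filtered out.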
81
pytests/A_memorix_test/test_person_memory_writeback.py
Normal file
@@ -0,0 +1,81 @@
from types import SimpleNamespace

import pytest

from src.person_info import person_info as person_info_module


@pytest.mark.asyncio
async def test_store_person_memory_from_answer_writes_person_fact(monkeypatch):
    calls = []

    class FakePerson:
        def __init__(self, person_id: str):
            self.person_id = person_id
            self.person_name = "Alice"
            self.is_known = True

    async def fake_ingest_text(**kwargs):
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
    monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session))
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
    monkeypatch.setattr(person_info_module, "Person", FakePerson)
    monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)

    await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")

    assert len(calls) == 1
    payload = calls[0]
    assert payload["external_id"].startswith("person_fact:person-1:")
    assert payload["source_type"] == "person_fact"
    assert payload["chat_id"] == "session-1"
    assert payload["person_ids"] == ["person-1"]
    assert payload["participants"] == ["Alice"]
    assert payload["respect_filter"] is True
    assert payload["user_id"] == "10001"
    assert payload["group_id"] == ""
    assert payload["metadata"]["person_id"] == "person-1"


@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_unknown_person(monkeypatch):
    calls = []

    class FakePerson:
        def __init__(self, person_id: str):
            self.person_id = person_id
            self.person_name = "Unknown"
            self.is_known = False

    async def fake_ingest_text(**kwargs):
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
    monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session))
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
    monkeypatch.setattr(person_info_module, "Person", FakePerson)
    monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)

    await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")

    assert calls == []


@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_empty_content(monkeypatch):
    calls = []

    async def fake_ingest_text(**kwargs):
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)

    await person_info_module.store_person_memory_from_answer("Alice", " ", "session-1")

    assert calls == []
184
pytests/A_memorix_test/test_query_long_term_memory_tool.py
Normal file
@@ -0,0 +1,184 @@
from __future__ import annotations

from datetime import datetime
from types import SimpleNamespace

import pytest

from src.memory_system.retrieval_tools import query_long_term_memory as tool_module
from src.memory_system.retrieval_tools import init_all_tools
from src.memory_system.retrieval_tools.query_long_term_memory import (
    _resolve_time_expression,
    query_long_term_memory,
    register_tool,
)
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
from src.services.memory_service import MemoryHit, MemorySearchResult


def test_resolve_time_expression_supports_relative_and_absolute_inputs():
    now = datetime(2026, 3, 18, 15, 30)

    start_ts, end_ts, start_text, end_text = _resolve_time_expression("今天", now=now)
    assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0)
    assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
    assert start_text == "2026/03/18 00:00"
    assert end_text == "2026/03/18 23:59"

    start_ts, end_ts, start_text, end_text = _resolve_time_expression("最近7天", now=now)
    assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 12, 0, 0)
    assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
    assert start_text == "2026/03/12 00:00"
    assert end_text == "2026/03/18 23:59"

    start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18", now=now)
    assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0)
    assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59)
    assert start_text == "2026/03/18 00:00"
    assert end_text == "2026/03/18 23:59"

    start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18 09:30", now=now)
    assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 9, 30)
    assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 9, 30)
    assert start_text == "2026/03/18 09:30"
    assert end_text == "2026/03/18 09:30"


def test_register_tool_exposes_mode_and_time_expression():
    register_tool()
    tool = get_tool_registry().get_tool("search_long_term_memory")

    assert tool is not None
    params = {item["name"]: item for item in tool.parameters}
    assert "mode" in params
    assert params["mode"]["enum"] == ["search", "time", "episode", "aggregate"]
    assert "time_expression" in params
    assert params["query"]["required"] is False


def test_init_all_tools_registers_long_term_memory_tool():
    init_all_tools()

    tool = get_tool_registry().get_tool("search_long_term_memory")
    assert tool is not None


@pytest.mark.asyncio
async def test_query_long_term_memory_search_mode_maps_to_hybrid(monkeypatch):
    captured = {}

    async def fake_search(query, **kwargs):
        captured["query"] = query
        captured["kwargs"] = kwargs
        return MemorySearchResult(
            hits=[MemoryHit(content="Alice 喜欢猫", score=0.9, hit_type="paragraph")],
        )

    monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))

    text = await query_long_term_memory("Alice 喜欢什么", chat_id="stream-1", person_id="person-1")

    assert "Alice 喜欢猫" in text
    assert captured == {
        "query": "Alice 喜欢什么",
        "kwargs": {
            "limit": 5,
            "mode": "hybrid",
            "chat_id": "stream-1",
            "person_id": "person-1",
            "time_start": None,
            "time_end": None,
        },
    }


@pytest.mark.asyncio
async def test_query_long_term_memory_time_mode_parses_expression(monkeypatch):
    captured = {}

    async def fake_search(query, **kwargs):
        captured["query"] = query
        captured["kwargs"] = kwargs
        return MemorySearchResult(
            hits=[
                MemoryHit(
                    content="昨天晚上广播站停播了十分钟。",
                    score=0.8,
                    hit_type="paragraph",
                    metadata={"event_time_start": 1773797400.0},
                )
            ]
        )

    monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
    monkeypatch.setattr(
        tool_module,
        "_resolve_time_expression",
        lambda expression, now=None: (1773795600.0, 1773881940.0, "2026/03/17 00:00", "2026/03/17 23:59"),
    )

    text = await query_long_term_memory(
        query="广播站",
        mode="time",
        time_expression="昨天",
        chat_id="stream-1",
    )

    assert "指定时间范围" in text
    assert "广播站停播" in text
    assert captured == {
        "query": "广播站",
        "kwargs": {
            "limit": 5,
            "mode": "time",
            "chat_id": "stream-1",
            "person_id": "",
            "time_start": 1773795600.0,
            "time_end": 1773881940.0,
        },
    }


@pytest.mark.asyncio
async def test_query_long_term_memory_episode_and_aggregate_format_output(monkeypatch):
    responses = {
        "episode": MemorySearchResult(
            hits=[
                MemoryHit(
                    content="苏弦在灯塔拆开了那封冬信。",
                    title="冬信重见天日",
                    hit_type="episode",
                    metadata={"participants": ["苏弦"], "keywords": ["冬信", "灯塔"]},
                )
            ]
        ),
        "aggregate": MemorySearchResult(
            hits=[
                MemoryHit(
                    content="唐未在广播站值夜班时带着黑狗墨点。",
                    hit_type="paragraph",
                    metadata={"source_branches": ["search", "time"]},
                )
            ]
        ),
    }

    async def fake_search(query, **kwargs):
        return responses[kwargs["mode"]]

    monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))

    episode_text = await query_long_term_memory("那封冬信后来怎么样了", mode="episode")
    aggregate_text = await query_long_term_memory("唐未最近有什么线索", mode="aggregate")

    assert "事件《冬信重见天日》" in episode_text
    assert "参与者:苏弦" in episode_text
    assert "[search,time][paragraph]" in aggregate_text


@pytest.mark.asyncio
async def test_query_long_term_memory_invalid_time_expression_returns_retryable_message():
    text = await query_long_term_memory(query="广播站", mode="time", time_expression="明年春分后第三周")

    assert "无法解析" in text
    assert "最近7天" in text
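The `_resolve_time_expression` assertions above imply day-granular windows that end at 23:59 for date-level inputs, and a degenerate start==end window when a minute-level timestamp is given. A rough re-implementation under those assumptions (this is a sketch covering only a few expressions, not the project's actual parser; unknown expressions raise, which the tool can turn into its "无法解析" retry message):

```python
from datetime import datetime, timedelta

_FMT = "%Y/%m/%d %H:%M"


def resolve_time_expression(expression: str, now: datetime):
    """Map a relative Chinese expression or an absolute date to a time window."""
    expression = expression.strip()

    def day_bounds(day: datetime):
        start = day.replace(hour=0, minute=0, second=0, microsecond=0)
        end = day.replace(hour=23, minute=59, second=0, microsecond=0)
        return start, end

    if expression == "今天":
        start, end = day_bounds(now)
    elif expression == "昨天":
        start, end = day_bounds(now - timedelta(days=1))
    elif expression == "最近7天":
        # 7 calendar days including today, so the window opens 6 days back.
        start, _ = day_bounds(now - timedelta(days=6))
        _, end = day_bounds(now)
    else:
        try:
            point = datetime.strptime(expression, _FMT)
            start = end = point  # exact minute given: degenerate window
        except ValueError:
            # Date-only absolute input; unknown expressions raise ValueError here.
            day = datetime.strptime(expression, "%Y/%m/%d")
            start, end = day_bounds(day)
    return start.timestamp(), end.timestamp(), start.strftime(_FMT), end.strftime(_FMT)
```

Returning both epoch seconds and formatted strings matches the 4-tuple shape the tests unpack.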
@@ -0,0 +1,335 @@
from __future__ import annotations

import asyncio
import inspect
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict

import numpy as np
import pytest
import pytest_asyncio

from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service


DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"


def _load_dialogue_fixture() -> Dict[str, Any]:
    return json.loads(DATA_FILE.read_text(encoding="utf-8"))


class _FakeEmbeddingAdapter:
    def __init__(self, dimension: int = 16) -> None:
        self.dimension = dimension

    async def _detect_dimension(self) -> int:
        return self.dimension

    async def encode(self, texts, dimensions=None):
        dim = int(dimensions or self.dimension)
        if isinstance(texts, str):
            sequence = [texts]
            single = True
        else:
            sequence = list(texts)
            single = False

        rows = []
        for text in sequence:
            vec = np.zeros(dim, dtype=np.float32)
            for ch in str(text or ""):
                vec[ord(ch) % dim] += 1.0
            if not vec.any():
                vec[0] = 1.0
            norm = np.linalg.norm(vec)
            if norm > 0:
                vec = vec / norm
            rows.append(vec)
        payload = np.vstack(rows)
        return payload[0] if single else payload


class _KernelBackedRuntimeManager:
    is_running = True

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke_plugin(
        self,
        *,
        method: str,
        plugin_id: str,
        component_name: str,
        args: Dict[str, Any] | None,
        timeout_ms: int,
    ):
        del method, plugin_id, timeout_ms
        payload = args or {}
        if component_name == "search_memory":
            return await self.kernel.search_memory(
                KernelSearchRequest(
                    query=str(payload.get("query", "") or ""),
                    limit=int(payload.get("limit", 5) or 5),
                    mode=str(payload.get("mode", "hybrid") or "hybrid"),
                    chat_id=str(payload.get("chat_id", "") or ""),
                    person_id=str(payload.get("person_id", "") or ""),
                    time_start=payload.get("time_start"),
                    time_end=payload.get("time_end"),
                    respect_filter=bool(payload.get("respect_filter", True)),
                    user_id=str(payload.get("user_id", "") or ""),
                    group_id=str(payload.get("group_id", "") or ""),
                )
            )

        handler = getattr(self.kernel, component_name)
        result = handler(**payload)
        return await result if inspect.isawaitable(result) else result


async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]:
    for _ in range(max_rounds):
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        task = detail.get("task") or {}
        status = str(task.get("status", "") or "")
        if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
            return detail
        await asyncio.sleep(0.05)
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")


def _join_hit_content(search_result) -> str:
    return "\n".join(hit.content for hit in search_result.hits)


@pytest_asyncio.fixture
async def real_dialogue_env(monkeypatch, tmp_path):
    scenario = _load_dialogue_fixture()
    session_cfg = scenario["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )

    monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())

    async def fake_self_check(**kwargs):
        return {"ok": True, "message": "ok"}

    monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None)
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)

    data_dir = (tmp_path / "a_memorix_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager)

    await kernel.initialize()
    try:
        yield {
            "scenario": scenario,
            "kernel": kernel,
            "session": session,
        }
    finally:
        await kernel.shutdown()


@pytest.mark.asyncio
async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env):
    scenario = real_dialogue_env["scenario"]

    created = await memory_service.import_admin(
        action="create_paste",
        name="private_alice.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(scenario["import_payload"], ensure_ascii=False),
    )

    assert created["success"] is True
    detail = await _wait_for_import_task(created["task"]["task_id"])
    assert detail["task"]["status"] == "completed"

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        respect_filter=False,
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env):
    scenario = real_dialogue_env["scenario"]
    record = scenario["chat_history_record"]

    summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=real_dialogue_env["session"].session_id,
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch):
    scenario = real_dialogue_env["scenario"]

    class _KnownPerson:
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
|
||||||
|
person_info_module,
|
||||||
|
"get_person_id_by_person_name",
|
||||||
|
lambda person_name: scenario["person"]["person_id"],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
|
||||||
|
|
||||||
|
await person_info_module.store_person_memory_from_answer(
|
||||||
|
scenario["person"]["person_name"],
|
||||||
|
scenario["person_fact"]["memory_content"],
|
||||||
|
real_dialogue_env["session"].session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
search = await memory_service.search(
|
||||||
|
scenario["search_queries"]["direct"],
|
||||||
|
mode="search",
|
||||||
|
chat_id=real_dialogue_env["session"].session_id,
|
||||||
|
person_id=scenario["person"]["person_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert search.hits
|
||||||
|
joined = _join_hit_content(search)
|
||||||
|
for keyword in scenario["expectations"]["search_keywords"]:
|
||||||
|
assert keyword in joined
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env):
|
||||||
|
scenario = real_dialogue_env["scenario"]
|
||||||
|
|
||||||
|
await memory_service.ingest_text(
|
||||||
|
external_id="fixture:knowledge_fetcher",
|
||||||
|
source_type="dialogue_note",
|
||||||
|
text=scenario["person_fact"]["memory_content"],
|
||||||
|
chat_id=real_dialogue_env["session"].session_id,
|
||||||
|
person_ids=[scenario["person"]["person_id"]],
|
||||||
|
participants=[scenario["person"]["person_name"]],
|
||||||
|
respect_filter=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
fetcher = knowledge_module.KnowledgeFetcher(
|
||||||
|
private_name=scenario["session"]["display_name"],
|
||||||
|
stream_id=real_dialogue_env["session"].session_id,
|
||||||
|
)
|
||||||
|
knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])
|
||||||
|
|
||||||
|
for keyword in scenario["expectations"]["search_keywords"]:
|
||||||
|
assert keyword in knowledge_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch):
|
||||||
|
scenario = real_dialogue_env["scenario"]
|
||||||
|
|
||||||
|
class _KnownPerson:
|
||||||
|
def __init__(self, person_id: str) -> None:
|
||||||
|
self.person_id = person_id
|
||||||
|
self.is_known = True
|
||||||
|
self.person_name = scenario["person"]["person_name"]
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
person_info_module,
|
||||||
|
"get_person_id_by_person_name",
|
||||||
|
lambda person_name: scenario["person"]["person_id"],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
|
||||||
|
|
||||||
|
await person_info_module.store_person_memory_from_answer(
|
||||||
|
scenario["person"]["person_name"],
|
||||||
|
scenario["person_fact"]["memory_content"],
|
||||||
|
real_dialogue_env["session"].session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
profile = await memory_service.get_person_profile(
|
||||||
|
scenario["person"]["person_id"],
|
||||||
|
chat_id=real_dialogue_env["session"].session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert profile.evidence
|
||||||
|
assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env):
|
||||||
|
scenario = real_dialogue_env["scenario"]
|
||||||
|
record = scenario["chat_history_record"]
|
||||||
|
|
||||||
|
summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
|
||||||
|
await summarizer._import_to_long_term_memory(
|
||||||
|
record_id=record["record_id"],
|
||||||
|
theme=record["theme"],
|
||||||
|
summary=record["summary"],
|
||||||
|
participants=record["participants"],
|
||||||
|
start_time=record["start_time"],
|
||||||
|
end_time=record["end_time"],
|
||||||
|
original_text=record["original_text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
episodes = await memory_service.episode_admin(
|
||||||
|
action="query",
|
||||||
|
source=scenario["expectations"]["episode_source"],
|
||||||
|
limit=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert episodes["success"] is True
|
||||||
|
assert int(episodes["count"]) >= 1
|
||||||
312	pytests/A_memorix_test/test_real_dialogue_business_flow_live.py	Normal file
@@ -0,0 +1,312 @@
from __future__ import annotations

import asyncio
import inspect
import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict

import pytest
import pytest_asyncio

from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service


pytestmark = pytest.mark.skipif(
    os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
    reason="需要显式开启真实 embedding / self-check 集成测试",
)

DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"


def _load_dialogue_fixture() -> Dict[str, Any]:
    return json.loads(DATA_FILE.read_text(encoding="utf-8"))


class _KernelBackedRuntimeManager:
    is_running = True

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke_plugin(
        self,
        *,
        method: str,
        plugin_id: str,
        component_name: str,
        args: Dict[str, Any] | None,
        timeout_ms: int,
    ):
        del method, plugin_id, timeout_ms
        payload = args or {}
        if component_name == "search_memory":
            return await self.kernel.search_memory(
                KernelSearchRequest(
                    query=str(payload.get("query", "") or ""),
                    limit=int(payload.get("limit", 5) or 5),
                    mode=str(payload.get("mode", "hybrid") or "hybrid"),
                    chat_id=str(payload.get("chat_id", "") or ""),
                    person_id=str(payload.get("person_id", "") or ""),
                    time_start=payload.get("time_start"),
                    time_end=payload.get("time_end"),
                    respect_filter=bool(payload.get("respect_filter", True)),
                    user_id=str(payload.get("user_id", "") or ""),
                    group_id=str(payload.get("group_id", "") or ""),
                )
            )

        handler = getattr(self.kernel, component_name)
        result = handler(**payload)
        return await result if inspect.isawaitable(result) else result
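The fallback branch of `invoke_plugin` dispatches by attribute name and awaits the result only when it is awaitable, so the same call path serves both sync and async kernel methods. A minimal, self-contained sketch of that pattern (the `_Kernel` class and method names here are illustrative, not from the repo):

```python
import asyncio
import inspect


class _Kernel:
    def sync_status(self) -> str:
        return "ok"

    async def async_status(self) -> str:
        return "ok-async"


async def invoke(kernel: _Kernel, component_name: str):
    # Mirrors the runtime-manager dispatch: look the handler up by name,
    # call it, and await only when the returned value is awaitable.
    handler = getattr(kernel, component_name)
    result = handler()
    return await result if inspect.isawaitable(result) else result


sync_out = asyncio.run(invoke(_Kernel(), "sync_status"))
async_out = asyncio.run(invoke(_Kernel(), "async_status"))
```

Because `inspect.isawaitable` checks the call's result rather than the callable itself, the dispatcher needs no registry of which components are coroutines.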


async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]:
    deadline = asyncio.get_running_loop().time() + max(1.0, float(timeout_seconds))
    while asyncio.get_running_loop().time() < deadline:
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        task = detail.get("task") or {}
        status = str(task.get("status", "") or "")
        if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
            return detail
        await asyncio.sleep(0.2)
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
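The helper above polls until a terminal status or a deadline computed from the running loop's clock. The same loop shape can be exercised in isolation against a stub service (the `_FakeImportService` class and its `get_task` method are hypothetical stand-ins, not repo APIs):

```python
import asyncio
from typing import Any, Dict


class _FakeImportService:
    """Stand-in import backend that reports 'completed' after a few polls."""

    def __init__(self, polls_until_done: int) -> None:
        self._remaining = polls_until_done

    async def get_task(self, task_id: str) -> Dict[str, Any]:
        self._remaining -= 1
        status = "completed" if self._remaining <= 0 else "running"
        return {"task": {"task_id": task_id, "status": status}}


async def wait_for_task(service: _FakeImportService, task_id: str, *, timeout_seconds: float = 5.0) -> Dict[str, Any]:
    # Same shape as the test helper: poll until a terminal status or the deadline.
    deadline = asyncio.get_running_loop().time() + timeout_seconds
    while asyncio.get_running_loop().time() < deadline:
        detail = await service.get_task(task_id)
        if detail["task"]["status"] in {"completed", "failed", "cancelled"}:
            return detail
        await asyncio.sleep(0.01)
    raise AssertionError(f"task did not finish: {task_id}")


detail = asyncio.run(wait_for_task(_FakeImportService(polls_until_done=3), "task-1"))
```

Using the loop's monotonic clock for the deadline (rather than wall time) keeps the timeout stable even if the system clock is adjusted mid-test.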


def _join_hit_content(search_result) -> str:
    return "\n".join(hit.content for hit in search_result.hits)


@pytest_asyncio.fixture
async def live_dialogue_env(monkeypatch, tmp_path):
    scenario = _load_dialogue_fixture()
    session_cfg = scenario["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )

    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None)
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)

    data_dir = (tmp_path / "a_memorix_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager)

    await kernel.initialize()
    try:
        yield {
            "scenario": scenario,
            "kernel": kernel,
            "session": session,
        }
    finally:
        await kernel.shutdown()


@pytest.mark.asyncio
async def test_live_runtime_self_check_passes(live_dialogue_env):
    report = await memory_service.runtime_admin(action="refresh_self_check")

    assert report["success"] is True
    assert report["report"]["ok"] is True
    assert report["report"]["encoded_dimension"] > 0


@pytest.mark.asyncio
async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env):
    scenario = live_dialogue_env["scenario"]

    created = await memory_service.import_admin(
        action="create_paste",
        name="private_alice.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(scenario["import_payload"], ensure_ascii=False),
    )

    assert created["success"] is True
    detail = await _wait_for_import_task(created["task"]["task_id"])
    assert detail["task"]["status"] == "completed"

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        respect_filter=False,
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env):
    scenario = live_dialogue_env["scenario"]
    record = scenario["chat_history_record"]

    summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=live_dialogue_env["session"].session_id,
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch):
    scenario = live_dialogue_env["scenario"]

    class _KnownPerson:
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)

    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        live_dialogue_env["session"].session_id,
    )

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=live_dialogue_env["session"].session_id,
        person_id=scenario["person"]["person_id"],
    )

    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined


@pytest.mark.asyncio
async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env):
    scenario = live_dialogue_env["scenario"]

    await memory_service.ingest_text(
        external_id="fixture:knowledge_fetcher",
        source_type="dialogue_note",
        text=scenario["person_fact"]["memory_content"],
        chat_id=live_dialogue_env["session"].session_id,
        person_ids=[scenario["person"]["person_id"]],
        participants=[scenario["person"]["person_name"]],
        respect_filter=False,
    )

    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=scenario["session"]["display_name"],
        stream_id=live_dialogue_env["session"].session_id,
    )
    knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])

    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in knowledge_text


@pytest.mark.asyncio
async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch):
    scenario = live_dialogue_env["scenario"]

    class _KnownPerson:
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)

    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        live_dialogue_env["session"].session_id,
    )

    profile = await memory_service.get_person_profile(
        scenario["person"]["person_id"],
        chat_id=live_dialogue_env["session"].session_id,
    )

    assert profile.evidence
    assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])


@pytest.mark.asyncio
async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env):
    scenario = live_dialogue_env["scenario"]
    record = scenario["chat_history_record"]

    summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )

    episodes = await memory_service.episode_admin(
        action="query",
        source=scenario["expectations"]["episode_source"],
        limit=5,
    )

    assert episodes["success"] is True
    assert int(episodes["count"]) >= 1
279	pytests/webui/test_memory_routes.py	Normal file
@@ -0,0 +1,279 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest

from src.services.memory_service import MemorySearchResult
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module
from src.webui.routers.memory import compat_router, router


@pytest.fixture
def client() -> TestClient:
    app = FastAPI()
    app.dependency_overrides[require_auth] = lambda: "ok"
    app.include_router(router)
    app.include_router(compat_router)
    return TestClient(app)


def test_webui_memory_graph_route(client: TestClient, monkeypatch):
    async def fake_graph_admin(*, action: str, **kwargs):
        assert action == "get_graph"
        return {"success": True, "nodes": [], "edges": [], "total_nodes": 0, "limit": kwargs.get("limit")}

    monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin)

    response = client.get("/api/webui/memory/graph", params={"limit": 77})

    assert response.status_code == 200
    assert response.json()["success"] is True
    assert response.json()["limit"] == 77


def test_compat_aggregate_route(client: TestClient, monkeypatch):
    async def fake_search(query: str, **kwargs):
        assert kwargs["mode"] == "aggregate"
        assert kwargs["respect_filter"] is False
        return MemorySearchResult(summary=f"summary:{query}", hits=[])

    monkeypatch.setattr(memory_router_module.memory_service, "search", fake_search)

    response = client.get("/api/query/aggregate", params={"query": "mai"})

    assert response.status_code == 200
    assert response.json() == {"success": True, "summary": "summary:mai", "hits": [], "filtered": False}


def test_auto_save_routes(client: TestClient, monkeypatch):
    async def fake_runtime_admin(*, action: str, **kwargs):
        if action == "get_config":
            return {"success": True, "auto_save": True}
        if action == "set_auto_save":
            return {"success": True, "auto_save": kwargs["enabled"]}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "runtime_admin", fake_runtime_admin)

    get_response = client.get("/api/config/auto_save")
    post_response = client.post("/api/config/auto_save", json={"enabled": False})

    assert get_response.status_code == 200
    assert get_response.json() == {"success": True, "auto_save": True}
    assert post_response.status_code == 200
    assert post_response.json() == {"success": True, "auto_save": False}


def test_recycle_bin_route(client: TestClient, monkeypatch):
    async def fake_get_recycle_bin(*, limit: int):
        return {"success": True, "items": [{"hash": "deadbeef"}], "count": 1, "limit": limit}

    monkeypatch.setattr(memory_router_module.memory_service, "get_recycle_bin", fake_get_recycle_bin)

    response = client.get("/api/memory/recycle_bin", params={"limit": 10})

    assert response.status_code == 200
    assert response.json()["success"] is True
    assert response.json()["count"] == 1
    assert response.json()["limit"] == 10


def test_import_guide_route(client: TestClient, monkeypatch):
    async def fake_import_admin(*, action: str, **kwargs):
        assert kwargs == {}
        if action == "get_guide":
            return {"success": True}
        if action == "get_settings":
            return {"success": True, "settings": {"path_aliases": {"raw": "/tmp/raw"}}}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)

    response = client.get("/api/webui/memory/import/guide")

    assert response.status_code == 200
    assert response.json()["success"] is True
    assert response.json()["source"] == "local"
    assert "长期记忆导入说明" in response.json()["content"]


def test_import_upload_route(client: TestClient, monkeypatch, tmp_path):
    monkeypatch.setattr(memory_router_module, "STAGING_ROOT", tmp_path)

    async def fake_import_admin(*, action: str, **kwargs):
        assert action == "create_upload"
        staged_files = kwargs["staged_files"]
        assert len(staged_files) == 1
        assert staged_files[0]["filename"] == "demo.txt"
        assert memory_router_module.Path(staged_files[0]["staged_path"]).exists()
        return {"success": True, "task_id": "task-1"}

    monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)

    response = client.post(
        "/api/import/upload",
        data={"payload_json": "{\"source\": \"upload\"}"},
        files=[("files", ("demo.txt", b"hello world", "text/plain"))],
    )

    assert response.status_code == 200
    assert response.json() == {"success": True, "task_id": "task-1"}
    assert list(tmp_path.iterdir()) == []


def test_v5_status_route(client: TestClient, monkeypatch):
    async def fake_v5_admin(*, action: str, **kwargs):
        assert action == "status"
        assert kwargs["target"] == "mai"
        return {"success": True, "active_count": 1, "inactive_count": 2, "deleted_count": 3}

    monkeypatch.setattr(memory_router_module.memory_service, "v5_admin", fake_v5_admin)

    response = client.get("/api/webui/memory/v5/status", params={"target": "mai"})

    assert response.status_code == 200
    assert response.json()["success"] is True
    assert response.json()["deleted_count"] == 3


def test_delete_preview_route(client: TestClient, monkeypatch):
    async def fake_delete_admin(*, action: str, **kwargs):
        assert action == "preview"
        assert kwargs["mode"] == "paragraph"
        assert kwargs["selector"] == {"query": "demo"}
        return {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}

    monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)

    response = client.post(
        "/api/webui/memory/delete/preview",
        json={"mode": "paragraph", "selector": {"query": "demo"}},
    )

    assert response.status_code == 200
    assert response.json() == {"success": True, "counts": {"paragraphs": 1}, "dry_run": True}


def test_episode_process_pending_route(client: TestClient, monkeypatch):
    async def fake_episode_admin(*, action: str, **kwargs):
        assert action == "process_pending"
        assert kwargs == {"limit": 7, "max_retry": 4}
        return {"success": True, "processed": 3}

    monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin)

    response = client.post("/api/webui/memory/episodes/process-pending", json={"limit": 7, "max_retry": 4})

    assert response.status_code == 200
    assert response.json() == {"success": True, "processed": 3}


def test_import_list_route_includes_settings(client: TestClient, monkeypatch):
    calls = []

    async def fake_import_admin(*, action: str, **kwargs):
        calls.append((action, kwargs))
        if action == "list":
            return {"success": True, "items": [{"task_id": "task-1"}]}
        if action == "get_settings":
            return {"success": True, "settings": {"path_aliases": {"lpmm": "/tmp/lpmm"}}}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin)

    response = client.get("/api/webui/memory/import/tasks", params={"limit": 9})

    assert response.status_code == 200
    assert response.json()["items"] == [{"task_id": "task-1"}]
    assert response.json()["settings"] == {"path_aliases": {"lpmm": "/tmp/lpmm"}}
    assert calls == [("list", {"limit": 9}), ("get_settings", {})]


def test_tuning_profile_route_backfills_settings(client: TestClient, monkeypatch):
    calls = []

    async def fake_tuning_admin(*, action: str, **kwargs):
        calls.append((action, kwargs))
        if action == "get_profile":
            return {"success": True, "profile": {"retrieval": {"top_k": 8}}}
        if action == "get_settings":
            return {"success": True, "settings": {"profiles": ["default"]}}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)

    response = client.get("/api/webui/memory/retrieval_tuning/profile")

    assert response.status_code == 200
    assert response.json()["profile"] == {"retrieval": {"top_k": 8}}
    assert response.json()["settings"] == {"profiles": ["default"]}
    assert calls == [("get_profile", {}), ("get_settings", {})]


def test_tuning_report_route_flattens_report_payload(client: TestClient, monkeypatch):
    async def fake_tuning_admin(*, action: str, **kwargs):
        assert action == "get_report"
        assert kwargs == {"task_id": "task-1", "format": "json"}
        return {
            "success": True,
            "report": {"format": "json", "content": "{\"ok\": true}", "path": "/tmp/report.json"},
        }

    monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin)

    response = client.get("/api/webui/memory/retrieval_tuning/tasks/task-1/report", params={"format": "json"})

    assert response.status_code == 200
    assert response.json() == {
        "success": True,
        "format": "json",
        "content": "{\"ok\": true}",
        "path": "/tmp/report.json",
        "error": "",
    }


def test_delete_execute_route(client: TestClient, monkeypatch):
    async def fake_delete_admin(*, action: str, **kwargs):
        assert action == "execute"
        assert kwargs["mode"] == "source"
        assert kwargs["selector"] == {"source": "chat_summary:stream-1"}
        assert kwargs["reason"] == "cleanup"
        assert kwargs["requested_by"] == "tester"
        return {"success": True, "operation_id": "del-1"}

    monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)

    response = client.post(
        "/api/webui/memory/delete/execute",
        json={
            "mode": "source",
            "selector": {"source": "chat_summary:stream-1"},
            "reason": "cleanup",
            "requested_by": "tester",
        },
    )

    assert response.status_code == 200
    assert response.json() == {"success": True, "operation_id": "del-1"}


def test_delete_operation_routes(client: TestClient, monkeypatch):
    async def fake_delete_admin(*, action: str, **kwargs):
        if action == "list_operations":
            assert kwargs == {"limit": 5, "mode": "paragraph"}
            return {"success": True, "items": [{"operation_id": "del-1"}], "count": 1}
        if action == "get_operation":
            assert kwargs == {"operation_id": "del-1"}
            return {"success": True, "operation": {"operation_id": "del-1", "mode": "paragraph"}}
        raise AssertionError(action)

    monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin)

    list_response = client.get("/api/webui/memory/delete/operations", params={"limit": 5, "mode": "paragraph"})
    get_response = client.get("/api/webui/memory/delete/operations/del-1")

    assert list_response.status_code == 200
    assert list_response.json()["count"] == 1
    assert get_response.status_code == 200
    assert get_response.json()["operation"]["operation_id"] == "del-1"
@@ -7,7 +7,7 @@ from src.common.database.database_model import Jargon
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import model_config, global_config
 from src.prompt.prompt_manager import prompt_manager
-from src.bw_learner.jargon_miner_old import search_jargon
+from src.bw_learner.jargon_explainer import search_jargon
 from src.bw_learner.learner_utils_old import (
     is_bot_message,
     contains_bot_self_name,
@@ -196,6 +196,32 @@ def contains_bot_self_name(content: str) -> bool:
     return any(name in target for name in candidates)
 
 
+def is_bot_message(msg: Any) -> bool:
+    """判断消息是否来自机器人自身。"""
+    if msg is None:
+        return False
+
+    bot_config = getattr(global_config, "bot", None)
+    if not bot_config:
+        return False
+
+    user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
+    if not user_id:
+        return False
+
+    known_accounts = {
+        str(getattr(bot_config, "qq_account", "") or "").strip(),
+        str(getattr(bot_config, "telegram_account", "") or "").strip(),
+    }
+
+    for platform in getattr(bot_config, "platforms", []) or []:
+        account = str(getattr(platform, "account", "") or getattr(platform, "id", "") or "").strip()
+        if account:
+            known_accounts.add(account)
+
+    return user_id in {account for account in known_accounts if account}
+
+
 # def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
 #     """
 #     构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
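The new `is_bot_message` helper boils down to one rule: collect every account the bot is known under, then check the sender's `user_id` against that set. A minimal standalone sketch of that rule, with `SimpleNamespace` stubs standing in for the real `global_config.bot` objects (the field names mirror the diff; nothing here is the project's actual config API):

```python
from types import SimpleNamespace


def is_bot_message(msg, bot_config) -> bool:
    # Mirror of the helper's rule: gather all known bot accounts,
    # then membership-test the sender's user_id against them.
    if msg is None or not bot_config:
        return False
    user_id = str(getattr(msg, "user_id", "") or "").strip()
    if not user_id:
        return False
    known = {
        str(getattr(bot_config, "qq_account", "") or "").strip(),
        str(getattr(bot_config, "telegram_account", "") or "").strip(),
    }
    for platform in getattr(bot_config, "platforms", []) or []:
        account = str(getattr(platform, "account", "") or "").strip()
        if account:
            known.add(account)
    return user_id in {a for a in known if a}


bot = SimpleNamespace(qq_account="10001", telegram_account="", platforms=[SimpleNamespace(account="disc-7")])
print(is_bot_message(SimpleNamespace(user_id="10001"), bot))  # True: QQ account matches
print(is_bot_message(SimpleNamespace(user_id="disc-7"), bot))  # True: platform account matches
print(is_bot_message(SimpleNamespace(user_id="20002"), bot))  # False: unknown sender
```

Note how empty strings are filtered out of the set before the membership test, so an unset `telegram_account` can never accidentally match a sender with an empty id.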
@@ -55,7 +55,7 @@ class Conversation:
         self.action_planner = ActionPlanner(self.stream_id, self.private_name)
         self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name)
         self.reply_generator = ReplyGenerator(self.stream_id, self.private_name)
-        self.knowledge_fetcher = KnowledgeFetcher(self.private_name)
+        self.knowledge_fetcher = KnowledgeFetcher(self.private_name, self.stream_id)
         self.waiter = Waiter(self.stream_id, self.private_name)
         self.direct_sender = DirectMessageSender(self.private_name)
@@ -1,11 +1,14 @@
-from typing import List, Tuple, Dict, Any
+from typing import Any, Dict, List, Tuple
 
+from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
 from src.common.logger import get_logger
 
 # NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
 # from src.plugins.memory_system.Hippocampus import HippocampusManager
-from src.llm_models.utils_model import LLMRequest
 from src.config.config import model_config
-from src.chat.knowledge import qa_manager
+from src.llm_models.utils_model import LLMRequest
+from src.person_info.person_info import resolve_person_id_for_memory
+from src.services.memory_service import memory_service
 
 logger = get_logger("knowledge_fetcher")
@@ -13,11 +16,39 @@ logger = get_logger("knowledge_fetcher")
 class KnowledgeFetcher:
     """知识调取器"""
 
-    def __init__(self, private_name: str):
+    def __init__(self, private_name: str, stream_id: str):
         self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
         self.private_name = private_name
+        self.stream_id = stream_id
 
-    def _lpmm_get_knowledge(self, query: str) -> str:
+    def _resolve_private_memory_context(self) -> Dict[str, str]:
+        session = _chat_manager.get_session_by_session_id(self.stream_id)
+        if session is None:
+            return {"chat_id": self.stream_id}
+
+        group_id = str(getattr(session, "group_id", "") or "").strip()
+        user_id = str(getattr(session, "user_id", "") or "").strip()
+        platform = str(getattr(session, "platform", "") or "").strip()
+
+        person_id = ""
+        if not group_id:
+            try:
+                person_id = resolve_person_id_for_memory(
+                    person_name=self.private_name,
+                    platform=platform,
+                    user_id=user_id,
+                )
+            except Exception as exc:
+                logger.debug(f"[私聊][{self.private_name}]解析人物ID失败: {exc}")
+
+        return {
+            "chat_id": self.stream_id,
+            "person_id": person_id,
+            "user_id": user_id,
+            "group_id": group_id,
+        }
+
+    async def _memory_get_knowledge(self, query: str) -> str:
         """获取相关知识
 
         Args:
@@ -27,13 +58,32 @@ class KnowledgeFetcher:
             str: 构造好的,带相关度的知识
         """
 
-        logger.debug(f"[私聊][{self.private_name}]正在从LPMM知识库中获取知识")
+        logger.debug(f"[私聊][{self.private_name}]正在从长期记忆中获取知识")
         try:
-            knowledge_info = qa_manager.get_knowledge(query)
-            logger.debug(f"[私聊][{self.private_name}]LPMM知识库查询结果: {knowledge_info:150}")
-            return knowledge_info
+            context = self._resolve_private_memory_context()
+            search_kwargs = {
+                "limit": 5,
+                "mode": "search",
+                "chat_id": context.get("chat_id", ""),
+                "person_id": context.get("person_id", ""),
+                "user_id": context.get("user_id", ""),
+                "group_id": context.get("group_id", ""),
+                "respect_filter": True,
+            }
+            result = await memory_service.search(query, **search_kwargs)
+            if not result.filtered and not result.hits and search_kwargs["person_id"]:
+                fallback_kwargs = dict(search_kwargs)
+                fallback_kwargs["person_id"] = ""
+                logger.debug(f"[私聊][{self.private_name}]人物过滤未命中,退回仅按会话检索长期记忆")
+                result = await memory_service.search(query, **fallback_kwargs)
+            knowledge_info = result.to_text(limit=5)
+            if result.filtered:
+                logger.debug(f"[私聊][{self.private_name}]长期记忆查询被聊天过滤策略跳过")
+            else:
+                logger.debug(f"[私聊][{self.private_name}]长期记忆查询结果: {knowledge_info[:150]}")
+            return knowledge_info or "未找到匹配的知识"
         except Exception as e:
             logger.error(f"[私聊][{self.private_name}]长期记忆搜索工具执行失败: {str(e)}")
             return "未找到匹配的知识"
 
     async def fetch(self, query: str, chat_history: List[Dict[str, Any]]) -> Tuple[str, str]:
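The hunk above implements a two-pass retrieval: search scoped to both chat and person first, then retry without the person filter when that pass produces no hits and was not suppressed by a filter policy. A self-contained sketch of that fallback pattern, with a stub service and result type standing in for the real `memory_service` (the stub's behavior and field names are assumptions for illustration, not the project's actual API):

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class SearchResult:
    hits: list = field(default_factory=list)
    filtered: bool = False  # True when a chat filter policy suppressed the query


class StubMemoryService:
    """Stub memory store: only chat-scoped entries exist, so a
    person_id-scoped query comes back empty and triggers the fallback."""

    def __init__(self):
        self.store = [{"chat_id": "chat-1", "person_id": "", "text": "user likes tea"}]

    async def search(self, query, *, chat_id="", person_id="", **_):
        # The query string is ignored by this stub; only scoping matters here.
        hits = [e["text"] for e in self.store
                if e["chat_id"] == chat_id and (not person_id or e["person_id"] == person_id)]
        return SearchResult(hits=hits)


async def memory_get_knowledge(service, query, chat_id, person_id):
    # First pass: scoped to both the chat and the resolved person.
    kwargs = {"chat_id": chat_id, "person_id": person_id}
    result = await service.search(query, **kwargs)
    # Fallback pass: drop the person filter when it matched nothing
    # and the query was not suppressed by a filter policy.
    if not result.filtered and not result.hits and person_id:
        kwargs["person_id"] = ""
        result = await service.search(query, **kwargs)
    return result.hits


hits = asyncio.run(memory_get_knowledge(StubMemoryService(), "tea", "chat-1", "person-9"))
print(hits)  # ['user likes tea']
```

The design choice mirrored here is that a stale or wrong `person_id` degrades gracefully to chat-only recall instead of returning nothing.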
@@ -72,7 +122,7 @@ class KnowledgeFetcher:
         # sources_text = ",".join(sources)
 
         knowledge_text += "\n现在有以下**知识**可供参考:\n "
-        knowledge_text += self._lpmm_get_knowledge(query)
+        knowledge_text += await self._memory_get_knowledge(query)
         knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n"
 
         return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配"
@@ -1,90 +0,0 @@
-from src.chat.knowledge.embedding_store import EmbeddingManager
-from src.chat.knowledge.qa_manager import QAManager
-from src.chat.knowledge.kg_manager import KGManager
-from src.chat.knowledge.global_logger import logger
-from src.config.config import global_config
-import os
-
-INVALID_ENTITY = [
-    "",
-    "你",
-    "他",
-    "她",
-    "它",
-    "我们",
-    "你们",
-    "他们",
-    "她们",
-    "它们",
-]
-
-RAG_GRAPH_NAMESPACE = "rag-graph"
-RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
-RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
-
-
-ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
-DATA_PATH = os.path.join(ROOT_PATH, "data")
-
-
-qa_manager = None
-inspire_manager = None
-
-
-def get_qa_manager():
-    return qa_manager
-
-
-def lpmm_start_up():  # sourcery skip: extract-duplicate-method
-    # 检查LPMM知识库是否启用
-    if global_config.lpmm_knowledge.enable:
-        logger.info("正在初始化Mai-LPMM")
-        logger.info("创建LLM客户端")
-
-        # 初始化Embedding库
-        embed_manager = EmbeddingManager(
-            max_workers=global_config.lpmm_knowledge.max_embedding_workers,
-            chunk_size=global_config.lpmm_knowledge.embedding_chunk_size,
-        )
-        logger.info("正在从文件加载Embedding库")
-        try:
-            embed_manager.load_from_file()
-        except Exception as e:
-            logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}")
-            # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
-        logger.info("Embedding库加载完成")
-        # 初始化KG
-        kg_manager = KGManager()
-        logger.info("正在从文件加载KG")
-        try:
-            kg_manager.load_from_file()
-        except Exception as e:
-            logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}")
-            # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
-        logger.info("KG加载完成")
-
-        logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
-        logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
-
-        # 数据比对:Embedding库与KG的段落hash集合
-        for pg_hash in kg_manager.stored_paragraph_hashes:
-            # 使用与EmbeddingStore中一致的命名空间格式
-            key = f"paragraph-{pg_hash}"
-            if key not in embed_manager.stored_pg_hashes:
-                logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
-        global qa_manager
-        # 问答系统(用于知识库)
-        qa_manager = QAManager(
-            embed_manager,
-            kg_manager,
-        )
-
-        # # 记忆激活(用于记忆库)
-        # global inspire_manager
-        # inspire_manager = MemoryActiveManager(
-        #     embed_manager,
-        #     llm_client_list[global_config["embedding"]["provider"]],
-        # )
-    else:
-        logger.info("LPMM知识库已禁用,跳过初始化")
-        # 创建空的占位符对象,避免导入错误
@@ -1,380 +0,0 @@
-import asyncio
-import os
-from functools import partial
-from typing import List, Callable, Any
-from src.chat.knowledge.embedding_store import EmbeddingManager
-from src.chat.knowledge.kg_manager import KGManager
-from src.chat.knowledge.qa_manager import QAManager
-from src.common.logger import get_logger
-from src.config.config import global_config
-from src.chat.knowledge import get_qa_manager, lpmm_start_up
-
-logger = get_logger("LPMM-Plugin-API")
-
-
-class LPMMOperations:
-    """
-    LPMM 内部操作接口。
-    封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。
-    """
-
-    def __init__(self):
-        self._initialized = False
-
-    async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
-        """
-        在线程池中执行可取消的同步操作。
-        当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。
-        注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。
-
-        Args:
-            func: 要执行的同步函数
-            *args: 函数的位置参数
-            **kwargs: 函数的关键字参数
-
-        Returns:
-            函数的返回值
-
-        Raises:
-            asyncio.CancelledError: 当任务被取消时
-        """
-        loop = asyncio.get_event_loop()
-        # 在线程池中执行,当协程被取消时会立即响应
-        # 虽然线程池中的操作可能仍在运行,但协程不会阻塞
-        return await loop.run_in_executor(None, func, *args, **kwargs)
-
-    async def _get_managers(self) -> tuple[EmbeddingManager, KGManager, QAManager]:
-        """获取并确保 LPMM 管理器已初始化"""
-        qa_mgr = get_qa_manager()
-        if qa_mgr is None:
-            # 如果全局没初始化,尝试初始化
-            if not global_config.lpmm_knowledge.enable:
-                logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
-
-            lpmm_start_up()
-            qa_mgr = get_qa_manager()
-
-        if qa_mgr is None:
-            raise RuntimeError("无法获取 LPMM QAManager,请检查 LPMM 是否已正确安装和配置。")
-
-        return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr
-
-    async def add_content(self, text: str, auto_split: bool = True) -> dict:
-        """
-        向知识库添加新内容。
-
-        Args:
-            text: 原始文本。
-            auto_split: 是否自动按双换行符分割段落。
-                - True: 自动分割(默认),支持多段文本(用双换行分隔)
-                - False: 不分割,将整个文本作为完整一段处理
-
-        Returns:
-            dict: {"status": "success/error", "count": 导入段落数, "message": "描述"}
-        """
-        try:
-            embed_mgr, kg_mgr, _ = await self._get_managers()
-
-            # 1. 分段处理
-            if auto_split:
-                # 自动按双换行符分割
-                paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
-            else:
-                # 不分割,作为完整一段
-                text_stripped = text.strip()
-                if not text_stripped:
-                    return {"status": "error", "message": "文本内容为空"}
-                paragraphs = [text_stripped]
-
-            if not paragraphs:
-                return {"status": "error", "message": "文本内容为空"}
-
-            # 2. 实体与三元组抽取 (内部调用大模型)
-            from src.chat.knowledge.ie_process import IEProcess
-            from src.llm_models.utils_model import LLMRequest
-            from src.config.config import model_config
-
-            llm_ner = LLMRequest(
-                model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
-            )
-            llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
-            ie_process = IEProcess(llm_ner, llm_rdf)
-
-            logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
-            extracted_docs = await ie_process.process_paragraphs(paragraphs)
-
-            # 3. 构造并导入数据
-            # 这里我们手动实现导入逻辑,不依赖外部脚本
-            # a. 准备段落
-            raw_paragraphs = {doc["idx"]: doc["passage"] for doc in extracted_docs}
-            # b. 准备三元组
-            triple_list_data = {doc["idx"]: doc["extracted_triples"] for doc in extracted_docs}
-
-            # 向量化并入库
-            # 注意:此处模仿 import_openie.py 的核心逻辑
-            # 1. 先进行去重检查,只处理新段落
-            # store_new_data_set 期望的格式:raw_paragraphs 的键是段落hash(不带前缀),值是段落文本
-            new_raw_paragraphs = {}
-            new_triple_list_data = {}
-
-            for pg_hash, passage in raw_paragraphs.items():
-                key = f"paragraph-{pg_hash}"
-                if key not in embed_mgr.stored_pg_hashes:
-                    new_raw_paragraphs[pg_hash] = passage
-                    new_triple_list_data[pg_hash] = triple_list_data[pg_hash]
-
-            if not new_raw_paragraphs:
-                return {"status": "success", "count": 0, "message": "内容已存在,无需重复导入"}
-
-            # 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
-            # store_new_data_set 会自动处理嵌入生成和存储
-            # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
-            await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
-
-            # 3. 构建知识图谱(只需要三元组数据和embedding_manager)
-            await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
-
-            # 4. 持久化
-            await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
-            await self._run_cancellable_executor(embed_mgr.save_to_file)
-            await self._run_cancellable_executor(kg_mgr.save_to_file)
-
-            return {
-                "status": "success",
-                "count": len(new_raw_paragraphs),
-                "message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
-            }
-
-        except asyncio.CancelledError:
-            logger.warning("[Plugin API] 导入操作被用户中断")
-            return {"status": "cancelled", "message": "导入操作已被用户中断"}
-        except Exception as e:
-            logger.error(f"[Plugin API] 导入知识失败: {e}", exc_info=True)
-            return {"status": "error", "message": str(e)}
-
-    async def search(self, query: str, top_k: int = 3) -> List[str]:
-        """
-        检索知识库。
-
-        Args:
-            query: 查询问题。
-            top_k: 返回最相关的条目数。
-
-        Returns:
-            List[str]: 相关文段列表。
-        """
-        try:
-            _, _, qa_mgr = await self._get_managers()
-            # 直接调用 QAManager 的检索接口
-            knowledge = qa_mgr.get_knowledge(query, top_k=top_k)
-            # 返回通常是拼接好的字符串,这里我们可以尝试按其内部规则切分回列表,或者直接返回
-            return [knowledge] if knowledge else []
-        except Exception as e:
-            logger.error(f"[Plugin API] 检索知识失败: {e}")
-            return []
-
-    async def delete(self, keyword: str, exact_match: bool = False) -> dict:
-        """
-        根据关键词或完整文段删除知识库内容。
-
-        Args:
-            keyword: 匹配关键词或完整文段。
-            exact_match: 是否使用完整文段匹配(True=完全匹配,False=关键词模糊匹配)。
-
-        Returns:
-            dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"}
-        """
-        try:
-            embed_mgr, kg_mgr, _ = await self._get_managers()
-
-            # 1. 查找匹配的段落
-            to_delete_keys = []
-            to_delete_hashes = []
-
-            for key, item in embed_mgr.paragraphs_embedding_store.store.items():
-                if exact_match:
-                    # 完整文段匹配
-                    if item.str.strip() == keyword.strip():
-                        to_delete_keys.append(key)
-                        to_delete_hashes.append(key.replace("paragraph-", "", 1))
-                else:
-                    # 关键词模糊匹配
-                    if keyword in item.str:
-                        to_delete_keys.append(key)
-                        to_delete_hashes.append(key.replace("paragraph-", "", 1))
-
-            if not to_delete_keys:
-                match_type = "完整文段" if exact_match else "关键词"
-                return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
-
-            # 2. 执行删除
-            # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
-
-            # a. 从向量库删除
-            deleted_count, _ = await self._run_cancellable_executor(
-                embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
-            )
-            embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
-
-            # b. 从知识图谱删除
-            # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
-            # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
-            delete_func = partial(
-                kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
-            )
-            await self._run_cancellable_executor(delete_func)
-
-            # 3. 持久化
-            await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
-            await self._run_cancellable_executor(embed_mgr.save_to_file)
-            await self._run_cancellable_executor(kg_mgr.save_to_file)
-
-            match_type = "完整文段" if exact_match else "关键词"
-            return {
-                "status": "success",
-                "deleted_count": deleted_count,
-                "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
-            }
-
-        except asyncio.CancelledError:
-            logger.warning("[Plugin API] 删除操作被用户中断")
-            return {"status": "cancelled", "message": "删除操作已被用户中断"}
-        except Exception as e:
-            logger.error(f"[Plugin API] 删除知识失败: {e}", exc_info=True)
-            return {"status": "error", "message": str(e)}
-
-    async def clear_all(self) -> dict:
-        """
-        清空整个LPMM知识库(删除所有段落、实体、关系和知识图谱数据)。
-
-        Returns:
-            dict: {"status": "success/error", "message": "描述", "stats": {...}}
-        """
-        try:
-            embed_mgr, kg_mgr, _ = await self._get_managers()
-
-            # 记录清空前的统计信息
-            before_stats = {
-                "paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
-                "entities": len(embed_mgr.entities_embedding_store.store),
-                "relations": len(embed_mgr.relation_embedding_store.store),
-                "kg_nodes": len(kg_mgr.graph.get_node_list()),
-                "kg_edges": len(kg_mgr.graph.get_edge_list()),
-            }
-
-            # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
-
-            # 1. 清空所有向量库
-            # 获取所有keys
-            para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
-            ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
-            rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
-
-            # 删除所有段落向量
-            para_deleted, _ = await self._run_cancellable_executor(
-                embed_mgr.paragraphs_embedding_store.delete_items, para_keys
-            )
-            embed_mgr.stored_pg_hashes.clear()
-
-            # 删除所有实体向量
-            if ent_keys:
-                ent_deleted, _ = await self._run_cancellable_executor(
-                    embed_mgr.entities_embedding_store.delete_items, ent_keys
-                )
-            else:
-                ent_deleted = 0
-
-            # 删除所有关系向量
-            if rel_keys:
-                rel_deleted, _ = await self._run_cancellable_executor(
-                    embed_mgr.relation_embedding_store.delete_items, rel_keys
-                )
-            else:
-                rel_deleted = 0
-
-            # 2. 清空所有 embedding store 的索引和映射
-            # 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件
-            def _clear_embedding_indices():
-                # 清空段落索引
-                embed_mgr.paragraphs_embedding_store.faiss_index = None
-                embed_mgr.paragraphs_embedding_store.idx2hash = None
-                embed_mgr.paragraphs_embedding_store.dirty = False
-                # 删除旧的索引文件
-                if os.path.exists(embed_mgr.paragraphs_embedding_store.index_file_path):
-                    os.remove(embed_mgr.paragraphs_embedding_store.index_file_path)
-                if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path):
-                    os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path)
-
-                # 清空实体索引
-                embed_mgr.entities_embedding_store.faiss_index = None
-                embed_mgr.entities_embedding_store.idx2hash = None
-                embed_mgr.entities_embedding_store.dirty = False
-                # 删除旧的索引文件
-                if os.path.exists(embed_mgr.entities_embedding_store.index_file_path):
-                    os.remove(embed_mgr.entities_embedding_store.index_file_path)
-                if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path):
-                    os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path)
-
-                # 清空关系索引
-                embed_mgr.relation_embedding_store.faiss_index = None
-                embed_mgr.relation_embedding_store.idx2hash = None
-                embed_mgr.relation_embedding_store.dirty = False
-                # 删除旧的索引文件
-                if os.path.exists(embed_mgr.relation_embedding_store.index_file_path):
-                    os.remove(embed_mgr.relation_embedding_store.index_file_path)
-                if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path):
-                    os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path)
-
-            await self._run_cancellable_executor(_clear_embedding_indices)
-
-            # 3. 清空知识图谱
-            # 获取所有段落hash
-            all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
-            if all_pg_hashes:
-                # 删除所有段落节点(这会自动清理相关的边和孤立实体)
-                # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
-                # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
-                delete_func = partial(
-                    kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
-                )
-                await self._run_cancellable_executor(delete_func)
-
-            # 完全清空KG:创建新的空图(无论是否有段落hash都要执行)
-            from quick_algo import di_graph
-
-            kg_mgr.graph = di_graph.DiGraph()
-            kg_mgr.stored_paragraph_hashes.clear()
-            kg_mgr.ent_appear_cnt.clear()
-
-            # 4. 保存所有数据(此时所有store都是空的,索引也是None)
-            # 注意:即使store为空,save_to_file也会保存空的DataFrame,这是正确的
-            await self._run_cancellable_executor(embed_mgr.save_to_file)
-            await self._run_cancellable_executor(kg_mgr.save_to_file)
-
-            after_stats = {
-                "paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
-                "entities": len(embed_mgr.entities_embedding_store.store),
-                "relations": len(embed_mgr.relation_embedding_store.store),
-                "kg_nodes": len(kg_mgr.graph.get_node_list()),
-                "kg_edges": len(kg_mgr.graph.get_edge_list()),
-            }
-
-            return {
-                "status": "success",
-                "message": f"已成功清空LPMM知识库(删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
-                "stats": {
-                    "before": before_stats,
-                    "after": after_stats,
-                },
-            }
-
-        except asyncio.CancelledError:
-            logger.warning("[Plugin API] 清空操作被用户中断")
-            return {"status": "cancelled", "message": "清空操作已被用户中断"}
-        except Exception as e:
-            logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
-            return {"status": "error", "message": str(e)}
-
-
-# 内部使用的单例
-lpmm_ops = LPMMOperations()
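The removed module's comments call out an asyncio detail that is easy to trip over elsewhere in the codebase: `loop.run_in_executor` forwards positional arguments only, so keyword-only arguments must be bound ahead of time with `functools.partial`. A minimal self-contained sketch of that pattern (`delete_paragraphs` here is a stand-in with an invented signature, not the real `KGManager` API):

```python
import asyncio
from functools import partial


def delete_paragraphs(hashes, *, ent_hashes=None, remove_orphan_entities=False):
    # Stand-in for a keyword-only API such as kg_mgr.delete_paragraphs:
    # passing True positionally would land in ent_hashes by mistake.
    return (len(hashes), remove_orphan_entities)


async def main():
    loop = asyncio.get_running_loop()
    # run_in_executor accepts *args but no **kwargs, so bind the
    # keyword argument into the callable with functools.partial first.
    func = partial(delete_paragraphs, ["h1", "h2"], remove_orphan_entities=True)
    return await loop.run_in_executor(None, func)


print(asyncio.run(main()))  # (2, True)
```

Without `partial`, the only alternative would be a lambda closure; `partial` keeps the bound arguments inspectable and picklable.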
@@ -360,6 +360,12 @@ class ChatBot:
             user_id = user_info.user_id
             group_id = group_info.group_id if group_info else None
             _ = await chat_manager.get_or_create_session(platform, user_id, group_id)  # 确保会话存在
+            try:
+                from src.services.memory_flow_service import memory_automation_service
+
+                await memory_automation_service.on_incoming_message(message)
+            except Exception as exc:
+                logger.warning(f"[长期记忆自动总结] 注册会话总结器失败: {exc}")
 
             # message.update_chat_stream(chat)
@@ -383,6 +383,13 @@ class UniversalMessageSender:
                 with get_db_session() as db_session:
                     db_session.add(message.to_db_instance())
 
+                try:
+                    from src.services.memory_flow_service import memory_automation_service
+
+                    await memory_automation_service.on_message_sent(message)
+                except Exception as exc:
+                    logger.warning(f"[{chat_id}] 长期记忆人物事实写回注册失败: {exc}")
+
                 return sent_msg
 
             except Exception as e:
@@ -1,7 +1,6 @@
 import traceback
 import time
 import asyncio
-import importlib
 import random
 import re
 
@@ -36,6 +35,7 @@ from src.services import llm_service as llm_api
 
 from src.chat.logger.plan_reply_logger import PlanReplyLogger
 from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
+from src.memory_system.retrieval_tools import get_tool_registry
 from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
 from src.chat.utils.common_utils import TempMethodsExpression
 
@@ -1164,29 +1164,14 @@ class DefaultReplyer:
     async def get_prompt_info(self, message: str, sender: str, target: str):
         related_info = ""
         start_time = time.time()
-        try:
-            knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge")
-        except ImportError:
-            logger.debug("LPMM知识库工具模块不存在,跳过获取知识库内容")
-            return ""
-
-        search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None)
+        search_knowledge_tool = get_tool_registry().get_tool("search_long_term_memory")
         if search_knowledge_tool is None:
-            logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识库内容")
+            logger.debug("长期记忆检索工具未注册,跳过获取知识内容")
            return ""
 
-        logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
-        # 从LPMM知识库获取知识
+        logger.debug(f"获取长期记忆内容,元消息:{message[:30]}...,消息长度: {len(message)}")
         try:
-            # 检查LPMM知识库是否启用
-            if not global_config.lpmm_knowledge.enable:
-                logger.debug("LPMM知识库未启用,跳过获取知识库内容")
-                return ""
-
-            if global_config.lpmm_knowledge.lpmm_mode == "agent":
-                return ""
-
-            template_prompt = prompt_manager.get_prompt("lpmm_get_knowledge")
+            template_prompt = prompt_manager.get_prompt("memory_get_knowledge")
             template_prompt.add_context("bot_name", global_config.bot.nickname)
             template_prompt.add_context("time_now", lambda _: time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
             template_prompt.add_context("chat_history", message)
@@ -1202,24 +1187,31 @@ class DefaultReplyer:
             # logger.info(f"工具调用提示词: {prompt}")
             # logger.info(f"工具调用: {tool_calls}")
 
-            if tool_calls:
-                result = await self.tool_executor.execute_tool_call(tool_calls[0])
-                end_time = time.time()
-                if not result or not result.get("content"):
-                    logger.debug("从LPMM知识库获取知识失败,返回空知识...")
-                    return ""
-                found_knowledge_from_lpmm = result.get("content", "")
-                logger.info(
-                    f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
-                )
-                related_info += found_knowledge_from_lpmm
-                logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
-                logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
-
-                return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
-            else:
-                logger.debug("模型认为不需要使用LPMM知识库")
+            if not tool_calls:
+                logger.debug("模型认为不需要使用长期记忆")
                 return ""
 
+            related_chunks: List[str] = []
+            for tool_call in tool_calls:
+                if tool_call.func_name != "search_long_term_memory":
+                    continue
+                tool_args = dict(tool_call.args or {})
+                tool_args.setdefault("chat_id", self.chat_stream.session_id)
+                result_text = await search_knowledge_tool.execute(**tool_args)
+                if result_text and "未找到" not in result_text:
+                    related_chunks.append(result_text)
+
+            if not related_chunks:
+                logger.debug("长期记忆未返回有效信息")
+                return ""
+
+            related_info = "\n".join(related_chunks)
+            end_time = time.time()
+            logger.info(f"从长期记忆获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
+            logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
+            logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
+
+            return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
         except Exception as e:
             logger.error(f"获取知识库内容时发生异常: {str(e)}")
             return ""
@@ -55,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
 BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
 MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
 MMC_VERSION: str = "1.0.0"
-CONFIG_VERSION: str = "8.1.0"
+CONFIG_VERSION: str = "8.1.1"
 MODEL_CONFIG_VERSION: str = "1.12.0"
 
 logger = get_logger("config")
@@ -94,6 +94,11 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
         ["", "enable", "enable", "enable"],
         ["qq:1919810:group", "enable", "enable", "enable"],
     ]
+    兼容旧旧格式:
+    learning_list = [
+        ["qq:1919810:group", "enable", "enable", "0.5"],
+        ["", "disable", "disable", "0.1"],
+    ]
     新:
     [[expression.learning_list]]
     platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true
@@ -117,6 +122,16 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
         use_expression = _parse_enable_disable(r[1])
         enable_learning = _parse_enable_disable(r[2])
         enable_jargon_learning = _parse_enable_disable(r[3])
+        if enable_jargon_learning is None:
+            # 更早期的配置在第 4 列记录的是一个已废弃的数值权重/阈值,
+            # 当前 schema 已没有对应字段。这里按保守策略兼容迁移:
+            # 丢弃旧数值,并将 enable_jargon_learning 置为 False。
+            try:
+                float(str(r[3]))
+            except (TypeError, ValueError):
+                pass
+            else:
+                enable_jargon_learning = False
         if use_expression is None or enable_learning is None or enable_jargon_learning is None:
             return False
 
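The migration branch above can be exercised in isolation. The sketch below is a minimal stand-alone re-implementation of the row-migration logic for illustration; `_parse_enable_disable` and `migrate_row` here are simplified stand-ins, not the project's actual functions, and the returned dict shape is an assumption.

```python
from typing import Any, Optional

def _parse_enable_disable(value: Any) -> Optional[bool]:
    # Map legacy "enable"/"disable" strings to booleans; None means unparseable.
    text = str(value).strip().lower()
    if text == "enable":
        return True
    if text == "disable":
        return False
    return None

def migrate_row(row: list) -> Optional[dict]:
    # Migrate one legacy row [target, use, learn, jargon-or-legacy-weight].
    use_expression = _parse_enable_disable(row[1])
    enable_learning = _parse_enable_disable(row[2])
    enable_jargon = _parse_enable_disable(row[3])
    if enable_jargon is None:
        # The oldest format stored a deprecated numeric weight in column 4:
        # drop the number and conservatively disable jargon learning.
        try:
            float(str(row[3]))
        except (TypeError, ValueError):
            pass
        else:
            enable_jargon = False
    if use_expression is None or enable_learning is None or enable_jargon is None:
        return None
    return {
        "target": row[0],
        "use_expression": use_expression,
        "enable_learning": enable_learning,
        "enable_jargon_learning": enable_jargon,
    }

print(migrate_row(["qq:1919810:group", "enable", "enable", "0.5"]))
```

A row whose fourth column is a numeric weight migrates with jargon learning off, while a row with an unparseable second column is rejected outright.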
@@ -416,6 +416,24 @@ class MemoryConfig(ConfigBase):
     )
     """_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索"""
 
+    long_term_auto_summary_enabled: bool = Field(
+        default=True,
+        json_schema_extra={
+            "x-widget": "switch",
+            "x-icon": "book-open",
+        },
+    )
+    """是否自动启动聊天总结并导入长期记忆"""
+
+    person_fact_writeback_enabled: bool = Field(
+        default=True,
+        json_schema_extra={
+            "x-widget": "switch",
+            "x-icon": "user-round-pen",
+        },
+    )
+    """是否在发送回复后自动提取并写回人物事实到长期记忆"""
+
     chat_history_topic_check_message_threshold: int = Field(
         default=80,
         ge=1,
10
src/main.py
@@ -6,7 +6,6 @@ import time
 
 from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
 from src.chat.emoji_system.emoji_manager import emoji_manager
-from src.chat.knowledge import lpmm_start_up
 from src.chat.message_receive.bot import chat_bot
 from src.chat.message_receive.chat_manager import chat_manager
 from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
@@ -19,6 +18,7 @@ from src.config.config import config_manager, global_config
 from src.manager.async_task_manager import async_task_manager
 from src.plugin_runtime.integration import get_plugin_runtime_manager
 from src.prompt.prompt_manager import prompt_manager
+from src.services.memory_flow_service import memory_automation_service
 
 # from src.api.main import start_api_server
 
@@ -88,9 +88,6 @@ class MainSystem:
         # start_api_server()
         # logger.info("API服务器启动成功")
 
-        # 启动LPMM
-        lpmm_start_up()
-
         # 启动插件运行时(内置插件 + 第三方插件双子进程)
         await get_plugin_runtime_manager().start()
 
@@ -103,6 +100,7 @@ class MainSystem:
         asyncio.create_task(chat_manager.regularly_save_sessions())
 
         logger.info(t("startup.chat_manager_initialized"))
+        await memory_automation_service.start()
 
         # await asyncio.sleep(0.5) #防止logger输出飞了
 
@@ -164,6 +162,10 @@ async def main():
             system.schedule_tasks(),
         )
     finally:
+        await memory_automation_service.shutdown()
+        await get_plugin_runtime_manager().bridge_event("on_stop")
+        await get_plugin_runtime_manager().stop()
+        await async_task_manager.stop_and_wait_all_tasks()
         await config_manager.stop_file_watcher()
 
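The `main.py` changes above wire the memory service into the runtime lifecycle: `start()` runs after the chat manager is initialized, and `shutdown()` runs first in the `finally` block, before the plugin runtime and task manager are torn down. The sketch below illustrates that ordering with a minimal stand-in class; `MemoryAutomationService` here is hypothetical and not the project's real implementation.

```python
import asyncio

class MemoryAutomationService:
    # Minimal stand-in for memory_automation_service: start() would launch
    # background summarization loops, shutdown() cancels them.
    def __init__(self) -> None:
        self.running = False

    async def start(self) -> None:
        self.running = True

    async def shutdown(self) -> None:
        self.running = False

async def main() -> bool:
    svc = MemoryAutomationService()
    try:
        await svc.start()
        assert svc.running  # service is live while the bot runs
    finally:
        # Shut the memory service down before the rest of the runtime,
        # mirroring the ordering in the diff's finally block.
        await svc.shutdown()
    return svc.running

print(asyncio.run(main()))  # False
```

Putting the shutdown call first in `finally` guarantees no new memory writes are scheduled while the plugin runtime is stopping.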
@@ -931,12 +931,14 @@ class ChatHistorySummarizer:
         else:
             logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
 
-        # 同时导入到LPMM知识库
-        if global_config.lpmm_knowledge.enable:
-            await self._import_to_lpmm_knowledge(
+        if saved_record and saved_record.get("id") is not None:
+            await self._import_to_long_term_memory(
+                record_id=int(saved_record["id"]),
                 theme=theme,
                 summary=summary,
                 participants=participants,
+                start_time=start_time,
+                end_time=end_time,
                 original_text=original_text,
             )
 
@@ -947,76 +949,131 @@ class ChatHistorySummarizer:
             traceback.print_exc()
             raise
 
-    async def _import_to_lpmm_knowledge(
+    async def _import_to_long_term_memory(
         self,
+        record_id: int,
         theme: str,
         summary: str,
         participants: List[str],
+        start_time: float,
+        end_time: float,
         original_text: str,
     ):
         """
-        将聊天历史总结导入到LPMM知识库
+        将聊天历史总结导入到统一长期记忆
 
         Args:
+            record_id: chat_history 主键
             theme: 话题主题
             summary: 概括内容
             participants: 参与者列表
+            start_time: 开始时间
+            end_time: 结束时间
             original_text: 原始文本(可能很长,需要截断)
         """
         try:
-            from src.chat.knowledge.lpmm_ops import lpmm_ops
+            from src.services.memory_service import memory_service
+            session = _chat_manager.get_session_by_session_id(self.session_id)
+            session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else ""
+            session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else ""
 
-            # 构造要导入的文本内容
-            # 格式:主题 + 概括 + 参与者信息 + 原始内容摘要
-            # 注意:使用单换行符连接,确保整个内容作为一段导入,不被LPMM分段
             content_parts = []
-            # 1. 话题主题
-            # if theme:
-            #     content_parts.append(f"话题:{theme}")
-
-            # 2. 概括内容
+            if theme:
+                content_parts.append(f"主题:{theme}")
             if summary:
                 content_parts.append(f"概括:{summary}")
 
-            # 3. 参与者信息
             if participants:
                 participants_text = "、".join(participants)
                 content_parts.append(f"参与者:{participants_text}")
 
-            # 4. 原始文本摘要(如果原始文本太长,只取前500字)
-            # if original_text:
-            #     # 截断原始文本,避免过长
-            #     max_original_length = 500
-            #     if len(original_text) > max_original_length:
-            #         truncated_text = original_text[:max_original_length] + "..."
-            #         content_parts.append(f"原始内容摘要:{truncated_text}")
-            #     else:
-            #         content_parts.append(f"原始内容:{original_text}")
-
-            # 将所有部分合并为一个完整段落(使用单换行符,避免被LPMM分段)
-            # LPMM使用 \n\n 作为段落分隔符,所以这里使用 \n 确保不会被分段
             content_to_import = "\n".join(content_parts)
 
             if not content_to_import.strip():
-                logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,跳过导入知识库")
+                logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,改用插件侧 generate_from_chat 兜底")
+                await self._fallback_import_to_long_term_memory(
+                    record_id=record_id,
+                    theme=theme,
+                    participants=participants,
+                    start_time=start_time,
+                    end_time=end_time,
+                    original_text=original_text,
+                )
                 return
 
-            # 调用lpmm_ops导入
-            result = await lpmm_ops.add_content(text=content_to_import, auto_split=False)
-
-            if result["status"] == "success":
-                logger.info(
-                    f"{self.log_prefix} 成功将聊天历史总结导入到LPMM知识库 | 话题: {theme} | 新增段落数: {result.get('count', 0)}"
-                )
+            result = await memory_service.ingest_summary(
+                external_id=f"chat_history:{record_id}",
+                chat_id=self.session_id,
+                text=content_to_import,
+                participants=participants,
+                time_start=start_time,
+                time_end=end_time,
+                tags=[theme] if theme else [],
+                metadata={"theme": theme, "original_text_length": len(original_text or "")},
+                respect_filter=True,
+                user_id=session_user_id,
+                group_id=session_group_id,
+            )
+            if result.success:
+                if result.detail == "chat_filtered":
+                    logger.debug(f"{self.log_prefix} 聊天历史总结被聊天过滤策略跳过 | 话题: {theme}")
+                else:
+                    logger.info(f"{self.log_prefix} 成功将聊天历史总结导入到长期记忆 | 话题: {theme}")
             else:
-                logger.warning(
-                    f"{self.log_prefix} 将聊天历史总结导入到LPMM知识库失败 | 话题: {theme} | 错误: {result.get('message', '未知错误')}"
-                )
+                logger.warning(f"{self.log_prefix} 将聊天历史总结导入到长期记忆失败,尝试插件侧兜底 | 话题: {theme} | 错误: {result.detail}")
+                await self._fallback_import_to_long_term_memory(
+                    record_id=record_id,
+                    theme=theme,
+                    participants=participants,
+                    start_time=start_time,
+                    end_time=end_time,
+                    original_text=original_text,
+                )
 
         except Exception as e:
-            # 导入失败不应该影响数据库存储,只记录错误
-            logger.error(f"{self.log_prefix} 导入聊天历史总结到LPMM知识库时出错: {e}", exc_info=True)
+            logger.error(f"{self.log_prefix} 导入聊天历史总结到长期记忆时出错: {e}", exc_info=True)
+
+    async def _fallback_import_to_long_term_memory(
+        self,
+        *,
+        record_id: int,
+        theme: str,
+        participants: List[str],
+        start_time: float,
+        end_time: float,
+        original_text: str,
+    ) -> None:
+        try:
+            from src.services.memory_service import memory_service
+            session = _chat_manager.get_session_by_session_id(self.session_id)
+            session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else ""
+            session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else ""
+
+            result = await memory_service.ingest_summary(
+                external_id=f"chat_history:{record_id}",
+                chat_id=self.session_id,
+                text="",
+                participants=participants,
+                time_start=start_time,
+                time_end=end_time,
+                tags=[theme] if theme else [],
+                metadata={
+                    "theme": theme,
+                    "original_text_length": len(original_text or ""),
+                    "generate_from_chat": True,
+                    "context_length": global_config.memory.chat_history_topic_check_message_threshold,
+                },
+                respect_filter=True,
+                user_id=session_user_id,
+                group_id=session_group_id,
+            )
+            if result.success:
+                if result.detail == "chat_filtered":
+                    logger.debug(f"{self.log_prefix} 插件侧 generate_from_chat 兜底被聊天过滤策略跳过 | 话题: {theme}")
+                else:
+                    logger.info(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入成功 | 话题: {theme}")
+            else:
+                logger.warning(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入失败 | 话题: {theme} | 错误: {result.detail}")
+        except Exception as exc:
+            logger.error(f"{self.log_prefix} 插件侧兜底导入长期记忆失败: {exc}", exc_info=True)
 
     async def start(self):
         """启动后台定期检查循环"""
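The summarizer builds the ingestion payload by joining the theme, summary, and participants with single `\n` separators so the record is stored as one block. A minimal stand-alone sketch of that assembly step (the function name `build_summary_content` is hypothetical, introduced here only for illustration):

```python
def build_summary_content(theme: str, summary: str, participants: list) -> str:
    # Join the parts with a single "\n" so the whole record is ingested
    # as one block rather than being split into separate paragraphs.
    parts = []
    if theme:
        parts.append(f"主题:{theme}")
    if summary:
        parts.append(f"概括:{summary}")
    if participants:
        parts.append(f"参与者:{'、'.join(participants)}")
    return "\n".join(parts)

print(build_summary_content("聚餐", "大家约好周五聚餐", ["小明", "小红"]))
```

An empty result (all fields blank) is the case that triggers the `generate_from_chat` fallback path in the diff above.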
@@ -237,8 +237,8 @@ async def _react_agent_solve_question(
     if first_head_prompt is None:
         # 第一次构建,使用初始的collected_info(即initial_info)
         initial_collected_info = initial_info or ""
-        # 使用 LPMM 知识库检索 prompt
-        first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_lpmm")
+        # 使用统一长期记忆检索 prompt
+        first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_memory")
         first_head_prompt_template.add_context("bot_name", bot_name)
         first_head_prompt_template.add_context("time_now", time_now)
         first_head_prompt_template.add_context("chat_history", chat_history)
@@ -10,21 +10,17 @@ from .tool_registry import (
     get_tool_registry,
 )
 
-# 导入所有工具的注册函数
-from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
-from .query_words import register_tool as register_query_words
-from .return_information import register_tool as register_return_information
-from src.config.config import global_config
-
 
 def init_all_tools():
     """初始化并注册所有记忆检索工具"""
+    # 延迟导入,避免在仅使用部分工具或单元测试阶段触发不必要的依赖链。
+    from .query_long_term_memory import register_tool as register_long_term_memory
+    from .query_words import register_tool as register_query_words
+    from .return_information import register_tool as register_return_information
+
     register_query_words()
     register_return_information()
-    # LPMM知识库检索工具
-    if global_config.lpmm_knowledge.lpmm_mode == "agent":
-        register_lpmm_knowledge()
+    register_long_term_memory()
 
 
 __all__ = [
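The `__init__.py` change above moves tool registration from import time into `init_all_tools()`, so importing the package alone stays cheap. The sketch below shows the registry pattern in miniature; the registry dict, `get_tool`, and the echo tool body are simplified stand-ins for illustration, not the project's real `tool_registry` implementation.

```python
import asyncio

_REGISTRY: dict = {}

def register_memory_retrieval_tool(name, execute_func, description="", parameters=None):
    # Store the tool under its public name; parameters default to an empty list.
    _REGISTRY[name] = {
        "execute": execute_func,
        "description": description,
        "parameters": parameters or [],
    }

def get_tool(name):
    return _REGISTRY.get(name)

async def query_long_term_memory(query: str = "", **kwargs) -> str:
    # Stand-in for the real retrieval tool: just echo the query back.
    return f"echo:{query}"

def init_all_tools():
    # Registration happens inside init_all_tools rather than at module
    # import time, mirroring the lazy-import pattern in the diff.
    register_memory_retrieval_tool(
        "search_long_term_memory",
        query_long_term_memory,
        description="从长期记忆中检索信息",
    )

init_all_tools()
tool = get_tool("search_long_term_memory")
print(asyncio.run(tool["execute"](query="hi")))  # echo:hi
```

Callers such as `DefaultReplyer.get_prompt_info` then look the tool up by name and can treat a `None` result as "not registered".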
304
src/memory_system/retrieval_tools/query_long_term_memory.py
Normal file
304
src/memory_system/retrieval_tools/query_long_term_memory.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
"""通过统一长期记忆服务查询信息。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from calendar import monthrange
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Iterable, Literal, Tuple
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.services.memory_service import MemoryHit, MemorySearchResult, memory_service
|
||||||
|
|
||||||
|
from .tool_registry import register_memory_retrieval_tool
|
||||||
|
|
||||||
|
logger = get_logger("memory_retrieval_tools")
|
||||||
|
|
||||||
|
_SUPPORTED_MODES = {"search", "time", "episode", "aggregate"}
|
||||||
|
_RELATIVE_DAYS_RE = re.compile(r"^最近\s*(\d+)\s*天$")
|
||||||
|
_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
|
||||||
|
_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}\s+\d{2}:\d{2}$")
|
||||||
|
_TIME_EXPRESSION_HELP = (
|
||||||
|
"请改用更具体的时间表达,例如:今天、昨天、前天、本周、上周、本月、上月、最近7天、"
|
||||||
|
"2026/03/18、2026/03/18 09:30。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_query_datetime(dt: datetime) -> str:
|
||||||
|
return dt.strftime("%Y/%m/%d %H:%M")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_time_expression(
|
||||||
|
expression: str,
|
||||||
|
*,
|
||||||
|
now: datetime | None = None,
|
||||||
|
) -> Tuple[float, float, str, str]:
|
||||||
|
clean = str(expression or "").strip()
|
||||||
|
if not clean:
|
||||||
|
raise ValueError(f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}")
|
||||||
|
|
||||||
|
current = now or datetime.now()
|
||||||
|
day_start = current.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
|
if clean == "今天":
|
||||||
|
start = day_start
|
||||||
|
end = day_start.replace(hour=23, minute=59)
|
||||||
|
elif clean == "昨天":
|
||||||
|
start = day_start - timedelta(days=1)
|
||||||
|
end = start.replace(hour=23, minute=59)
|
||||||
|
elif clean == "前天":
|
||||||
|
start = day_start - timedelta(days=2)
|
||||||
|
end = start.replace(hour=23, minute=59)
|
||||||
|
elif clean == "本周":
|
||||||
|
start = day_start - timedelta(days=day_start.weekday())
|
||||||
|
end = start + timedelta(days=6, hours=23, minutes=59)
|
||||||
|
elif clean == "上周":
|
||||||
|
this_week_start = day_start - timedelta(days=day_start.weekday())
|
||||||
|
start = this_week_start - timedelta(days=7)
|
||||||
|
end = start + timedelta(days=6, hours=23, minutes=59)
|
||||||
|
elif clean == "本月":
|
||||||
|
start = day_start.replace(day=1)
|
||||||
|
last_day = monthrange(start.year, start.month)[1]
|
||||||
|
end = start.replace(day=last_day, hour=23, minute=59)
|
||||||
|
elif clean == "上月":
|
||||||
|
year = day_start.year
|
||||||
|
month = day_start.month - 1
|
||||||
|
if month == 0:
|
||||||
|
year -= 1
|
||||||
|
month = 12
|
||||||
|
start = day_start.replace(year=year, month=month, day=1)
|
||||||
|
last_day = monthrange(year, month)[1]
|
||||||
|
end = start.replace(day=last_day, hour=23, minute=59)
|
||||||
|
else:
|
||||||
|
relative_match = _RELATIVE_DAYS_RE.fullmatch(clean)
|
||||||
|
if relative_match:
|
||||||
|
days = max(1, int(relative_match.group(1)))
|
||||||
|
start = day_start - timedelta(days=max(0, days - 1))
|
||||||
|
end = day_start.replace(hour=23, minute=59)
|
||||||
|
elif _DATE_RE.fullmatch(clean):
|
||||||
|
start = datetime.strptime(clean, "%Y/%m/%d")
|
||||||
|
end = start.replace(hour=23, minute=59)
|
||||||
|
elif _MINUTE_RE.fullmatch(clean):
|
||||||
|
start = datetime.strptime(clean, "%Y/%m/%d %H:%M")
|
||||||
|
end = start
|
||||||
|
else:
|
||||||
|
raise ValueError(f"时间表达“{clean}”无法解析。{_TIME_EXPRESSION_HELP}")
|
||||||
|
|
||||||
|
return start.timestamp(), end.timestamp(), _format_query_datetime(start), _format_query_datetime(end)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_time_label(metadata: dict) -> str:
|
||||||
|
if not isinstance(metadata, dict):
|
||||||
|
return ""
|
||||||
|
start = metadata.get("event_time_start")
|
||||||
|
end = metadata.get("event_time_end")
|
||||||
|
event_time = metadata.get("event_time")
|
||||||
|
|
||||||
|
def _fmt(value: object) -> str:
|
||||||
|
if value in {None, ""}:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
return datetime.fromtimestamp(float(value)).strftime("%Y/%m/%d %H:%M")
|
||||||
|
except Exception:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
start_text = _fmt(start or event_time)
|
||||||
|
end_text = _fmt(end)
|
||||||
|
if start_text and end_text:
|
||||||
|
return f"{start_text} - {end_text}"
|
||||||
|
return start_text or end_text
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate(text: str, limit: int = 160) -> str:
|
||||||
|
compact = str(text or "").strip().replace("\n", " ")
|
||||||
|
if len(compact) <= limit:
|
||||||
|
return compact
|
||||||
|
return compact[:limit] + "..."
|
||||||
|
|
||||||
|
|
||||||
|
def _format_search_lines(hits: Iterable[MemoryHit], *, limit: int, include_time: bool = False) -> str:
|
||||||
|
lines = []
|
||||||
|
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||||
|
time_label = _extract_time_label(item.metadata) if include_time else ""
|
||||||
|
prefix = f"[{time_label}] " if time_label else ""
|
||||||
|
lines.append(f"{index}. {prefix}{_truncate(item.content)}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_episode_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
|
||||||
|
lines = []
|
||||||
|
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||||
|
metadata = item.metadata if isinstance(item.metadata, dict) else {}
|
||||||
|
title = str(item.title or "").strip() or "未命名事件"
|
||||||
|
summary = _truncate(item.content, limit=180)
|
||||||
|
participants = [str(x).strip() for x in (metadata.get("participants") or []) if str(x).strip()]
|
||||||
|
keywords = [str(x).strip() for x in (metadata.get("keywords") or []) if str(x).strip()]
|
||||||
|
extras = []
|
||||||
|
if participants:
|
||||||
|
extras.append(f"参与者:{'、'.join(participants[:4])}")
|
||||||
|
if keywords:
|
||||||
|
extras.append(f"关键词:{'、'.join(keywords[:6])}")
|
||||||
|
time_label = _extract_time_label(metadata)
|
||||||
|
if time_label:
|
||||||
|
extras.append(f"时间:{time_label}")
|
||||||
|
suffix = f"({';'.join(extras)})" if extras else ""
|
||||||
|
lines.append(f"{index}. 事件《{title}》:{summary}{suffix}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_aggregate_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
|
||||||
|
lines = []
|
||||||
|
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||||
|
metadata = item.metadata if isinstance(item.metadata, dict) else {}
|
||||||
|
source_branches = [str(x).strip() for x in (metadata.get("source_branches") or []) if str(x).strip()]
|
||||||
|
branch_text = f"[{','.join(source_branches)}]" if source_branches else ""
|
||||||
|
item_type = str(item.hit_type or "").strip().lower() or "memory"
|
||||||
|
if item_type == "episode":
|
||||||
|
title = str(item.title or "").strip() or "未命名事件"
|
||||||
|
lines.append(f"{index}. {branch_text}[episode] 《{title}》:{_truncate(item.content, 160)}")
|
||||||
|
else:
|
||||||
|
lines.append(f"{index}. {branch_text}[{item_type}] {_truncate(item.content, 160)}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_tool_result(
|
||||||
|
*,
|
||||||
|
result: MemorySearchResult,
|
||||||
|
mode: Literal["search", "time", "episode", "aggregate"],
|
||||||
|
limit: int,
|
||||||
|
query: str,
|
||||||
|
time_range_text: str = "",
|
||||||
|
) -> str:
|
||||||
|
if not result.hits:
|
||||||
|
if mode == "time":
|
||||||
|
return f"在指定时间范围内未找到相关的长期记忆{time_range_text}"
|
||||||
|
if mode == "episode":
|
||||||
|
return f"未找到与“{query}”相关的事件或情节记忆"
|
||||||
|
if mode == "aggregate":
|
||||||
|
return f"未找到可用于综合回忆的长期记忆线索{f'(query:{query})' if query else ''}"
|
||||||
|
return f"在长期记忆中未找到与“{query}”相关的信息"
|
||||||
|
|
||||||
|
if mode == "episode":
|
||||||
|
text = _format_episode_lines(result.hits, limit=limit)
|
||||||
|
return f"你从长期记忆的事件/情节中找到以下信息:\n{text}"
|
||||||
|
|
||||||
|
if mode == "aggregate":
|
||||||
|
text = _format_aggregate_lines(result.hits, limit=limit)
|
||||||
|
return f"你从长期记忆中综合找到了以下线索:\n{text}"
|
||||||
|
|
||||||
|
if mode == "time":
|
||||||
|
text = _format_search_lines(result.hits, limit=limit, include_time=True)
|
||||||
|
return f"你从指定时间范围内的长期记忆中找到以下信息{time_range_text}:\n{text}"
|
||||||
|
|
||||||
|
text = _format_search_lines(result.hits, limit=limit)
|
||||||
|
return f"你从长期记忆中找到以下信息:\n{text}"
|
||||||
|
|
||||||
|
|
||||||
|
async def query_long_term_memory(
|
||||||
|
query: str = "",
|
||||||
|
limit: int = 5,
|
||||||
|
chat_id: str = "",
|
||||||
|
person_id: str = "",
|
||||||
|
mode: str = "search",
|
||||||
|
time_expression: str = "",
|
||||||
|
) -> str:
|
||||||
|
content = str(query or "").strip()
|
||||||
|
safe_limit = max(1, int(limit or 5))
|
||||||
|
normalized_mode = str(mode or "search").strip().lower() or "search"
|
||||||
|
if normalized_mode not in _SUPPORTED_MODES:
|
||||||
|
return f"不支持的长期记忆检索模式:{normalized_mode}。可用模式:search、time、episode、aggregate。"
|
||||||
|
|
||||||
|
if normalized_mode == "search" and not content:
|
||||||
|
return "查询关键词为空,请提供你想查找的长期记忆内容。"
|
||||||
|
if normalized_mode == "time" and not str(time_expression or "").strip():
|
||||||
|
return f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}"
|
||||||
|
if normalized_mode in {"episode", "aggregate"} and not content and not str(time_expression or "").strip():
|
||||||
|
return f"{normalized_mode} 模式至少需要提供 query 或 time_expression。"
|
||||||
|
|
||||||
|
time_start = None
|
||||||
|
time_end = None
|
||||||
|
time_range_text = ""
|
||||||
|
if str(time_expression or "").strip():
|
||||||
|
try:
|
||||||
|
time_start, time_end, time_start_text, time_end_text = _resolve_time_expression(time_expression)
|
||||||
|
except ValueError as exc:
|
||||||
|
return str(exc)
|
||||||
|
time_range_text = f"(时间范围:{time_start_text} 至 {time_end_text})"
|
||||||
|
|
||||||
|
backend_mode = "hybrid" if normalized_mode == "search" else normalized_mode
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await memory_service.search(
|
||||||
|
content,
|
||||||
|
limit=safe_limit,
|
||||||
|
mode=backend_mode,
|
||||||
|
chat_id=str(chat_id or "").strip(),
|
||||||
|
person_id=str(person_id or "").strip(),
|
||||||
|
time_start=time_start,
|
||||||
|
time_end=time_end,
|
||||||
|
)
|
||||||
|
text = _format_tool_result(
|
||||||
|
result=result,
|
||||||
|
            mode=normalized_mode,  # type: ignore[arg-type]
            limit=safe_limit,
            query=content,
            time_range_text=time_range_text,
        )
        logger.debug(f"长期记忆查询结果({normalized_mode}): {text}")
        return text
    except Exception as exc:
        logger.error(f"长期记忆查询失败: {exc}")
        return f"长期记忆查询失败:{exc}"


def register_tool():
    register_memory_retrieval_tool(
        name="search_long_term_memory",
        description=(
            "从长期记忆中检索信息。支持 search(普通事实检索)、time(按时间范围检索)、"
            "episode(按事件/情节检索)、aggregate(综合检索)四种模式。"
        ),
        parameters=[
            {
                "name": "query",
                "type": "string",
                "description": "需要查询的问题。search 模式建议用自然语言问句;time/episode/aggregate 模式也可用关键词短语。",
                "required": False,
            },
            {
                "name": "mode",
                "type": "string",
                "description": "检索模式:search(普通长期记忆)、time(按时间窗口)、episode(事件/情节)、aggregate(综合检索)。",
                "required": False,
                "enum": ["search", "time", "episode", "aggregate"],
            },
            {
                "name": "limit",
                "type": "integer",
                "description": "希望返回的相关知识条数,默认为5",
                "required": False,
            },
            {
                "name": "chat_id",
                "type": "string",
                "description": "当前聊天流ID,可选。提供后优先检索当前聊天上下文相关的长期记忆。",
                "required": False,
            },
            {
                "name": "person_id",
                "type": "string",
                "description": "相关人物ID,可选。提供后优先检索该人物相关的长期记忆。",
                "required": False,
            },
            {
                "name": "time_expression",
                "type": "string",
                "description": (
                    "时间表达,可选。time 模式必填;episode/aggregate 模式可选。支持:今天、昨天、前天、本周、上周、本月、上月、"
                    "最近N天,以及 YYYY/MM/DD、YYYY/MM/DD HH:mm。"
                ),
                "required": False,
            },
        ],
        execute_func=query_long_term_memory,
    )
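The parameter schema registered above encodes the prompt's dispatch rules (valid `mode` values, `time` requiring a `time_expression`). A minimal stand-alone sketch of checking a call against those rules; the `validate_memory_tool_args` helper is hypothetical and not part of the real `tool_registry`:

```python
# Hypothetical validator mirroring the search_long_term_memory schema above.
VALID_MODES = {"search", "time", "episode", "aggregate"}


def validate_memory_tool_args(args: dict) -> list[str]:
    """Return human-readable problems; an empty list means the call is valid."""
    problems = []
    mode = args.get("mode", "search")
    if mode not in VALID_MODES:
        problems.append(f"unknown mode: {mode}")
    # The prompt's rule: time mode must carry a time_expression.
    if mode == "time" and not str(args.get("time_expression", "")).strip():
        problems.append("mode='time' requires time_expression")
    limit = args.get("limit", 5)
    if not isinstance(limit, int) or limit < 1:
        problems.append("limit must be a positive integer")
    return problems
```

For example, `validate_memory_tool_args({"mode": "time"})` flags the missing `time_expression`, while a plain `search` call with a query passes.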
@@ -1,75 +0,0 @@
"""
通过LPMM知识库查询信息 - 工具实现
"""

from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import get_qa_manager
from .tool_registry import register_memory_retrieval_tool

logger = get_logger("memory_retrieval_tools")


async def query_lpmm_knowledge(query: str, limit: int = 5) -> str:
    """在LPMM知识库中查询相关信息

    Args:
        query: 查询关键词

    Returns:
        str: 查询结果
    """
    try:
        content = str(query).strip()
        if not content:
            return "查询关键词为空"

        try:
            limit_value = int(limit)
        except (TypeError, ValueError):
            limit_value = 5
        limit_value = max(1, limit_value)

        if not global_config.lpmm_knowledge.enable:
            logger.debug("LPMM知识库未启用")
            return "LPMM知识库未启用"

        qa_manager = get_qa_manager()
        if qa_manager is None:
            logger.debug("LPMM知识库未初始化,跳过查询")
            return "LPMM知识库未初始化"

        knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value)
        logger.debug(f"LPMM知识库查询结果: {knowledge_info}")

        if knowledge_info:
            return f"你从LPMM知识库中找到以下信息:\n{knowledge_info}"

        return f"在LPMM知识库中未找到与“{content}”相关的信息"

    except Exception as e:
        logger.error(f"LPMM知识库查询失败: {e}")
        return f"LPMM知识库查询失败:{str(e)}"


def register_tool():
    """注册LPMM知识库查询工具"""
    register_memory_retrieval_tool(
        name="lpmm_search_knowledge",
        description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索",
        parameters=[
            {
                "name": "query",
                "type": "string",
                "description": "需要查询的问题,使用一句疑问句提问,例如:什么是AI?",
                "required": True,
            },
            {
                "name": "limit",
                "type": "integer",
                "description": "希望返回的相关知识条数,默认为5",
                "required": False,
            },
        ],
        execute_func=query_lpmm_knowledge,
    )
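The removed tool above coerces `limit` defensively (fall back on junk input, then clamp to a minimum), the same pattern the replacement tools keep. A minimal stand-alone version of that coercion; the `coerce_limit` name is hypothetical:

```python
def coerce_limit(limit, default: int = 5, minimum: int = 1) -> int:
    """Defensive limit parsing: fall back to `default` on non-integer
    input, then clamp to at least `minimum`."""
    try:
        value = int(limit)
    except (TypeError, ValueError):
        value = default
    return max(minimum, value)
```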
@@ -6,9 +6,10 @@ import random
 import math
 
 from json_repair import repair_json
-from typing import Union, Optional, Dict
+from typing import Union, Optional, Dict, List
 from datetime import datetime
 
+from sqlalchemy import or_
 from sqlmodel import col, select
 
 from src.common.logger import get_logger
@@ -17,6 +18,7 @@ from src.common.database.database_model import PersonInfo
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import global_config, model_config
 from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
+from src.services.memory_service import memory_service
 
 
 logger = get_logger("person_info")
@@ -37,16 +39,60 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
 
 def get_person_id_by_person_name(person_name: str) -> str:
     """根据用户名获取用户ID"""
+    clean_name = str(person_name or "").strip()
+    if not clean_name:
+        return ""
     try:
         with get_db_session() as session:
-            statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
+            statement = (
+                select(PersonInfo)
+                .where(
+                    or_(
+                        col(PersonInfo.person_name) == clean_name,
+                        col(PersonInfo.user_nickname) == clean_name,
+                    )
+                )
+                .limit(1)
+            )
+            record = session.exec(statement).first()
+            if record and record.person_id:
+                return record.person_id
+
+            statement = (
+                select(PersonInfo)
+                .where(PersonInfo.group_cardname.contains(clean_name))
+                .limit(1)
+            )
             record = session.exec(statement).first()
             return record.person_id if record else ""
     except Exception as e:
-        logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
+        logger.error(f"根据用户名 {clean_name} 获取用户ID时出错: {e}")
         return ""
 
 
+def resolve_person_id_for_memory(
+    *,
+    person_name: str = "",
+    platform: str = "",
+    user_id: Optional[Union[int, str]] = None,
+) -> str:
+    """统一人物记忆链路中的 person_id 解析。
+
+    优先使用已知的人物名称/别名,其次退回到平台 + user_id 的稳定 ID。
+    """
+    name_token = str(person_name or "").strip()
+    if name_token:
+        resolved = get_person_id_by_person_name(name_token)
+        if resolved:
+            return resolved
+
+    platform_token = str(platform or "").strip()
+    user_token = str(user_id or "").strip()
+    if platform_token and user_token:
+        return get_person_id(platform_token, user_token)
+    return ""
+
+
 def is_person_known(
     person_id: Optional[str] = None,
     user_id: Optional[str] = None,
@@ -537,79 +583,79 @@ class Person:
     async def build_relationship(self, chat_content: str = "", info_type=""):
         if not self.is_known:
             return ""
-        # 构建points文本
+
         nickname_str = ""
         if self.person_name != self.nickname:
             nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
 
-        relation_info = ""
-
-        points_text = ""
-        category_list = self.get_all_category()
-
-        if chat_content:
-            prompt = f"""当前聊天内容:
-{chat_content}
-
-分类列表:
-{category_list}
-**要求**:请你根据当前聊天内容,从以下分类中选择一个与聊天内容相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
-例如:
-<分类1><分类2><分类3>......
-如果没有相关的分类,请输出<none>"""
-
-            response, _ = await relation_selection_model.generate_response_async(prompt)
-            # print(prompt)
-            # print(response)
-            category_list = extract_categories_from_response(response)
-            if "none" not in category_list:
-                for category in category_list:
-                    random_memory = self.get_random_memory_by_category(category, 2)
-                    if random_memory:
-                        random_memory_str = "\n".join(
-                            [get_memory_content_from_memory(memory) for memory in random_memory]
-                        )
-                        points_text = f"有关 {category} 的内容:{random_memory_str}"
-                        break
-        elif info_type:
-            prompt = f"""你需要获取用户{self.person_name}的 **{info_type}** 信息。
-
-现有信息类别列表:
-{category_list}
-**要求**:请你根据**{info_type}**,从以下分类中选择一个与**{info_type}**相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
-例如:
-<分类1><分类2><分类3>......
-如果没有相关的分类,请输出<none>"""
-            response, _ = await relation_selection_model.generate_response_async(prompt)
-            # print(prompt)
-            # print(response)
-            category_list = extract_categories_from_response(response)
-            if "none" not in category_list:
-                for category in category_list:
-                    random_memory = self.get_random_memory_by_category(category, 3)
-                    if random_memory:
-                        random_memory_str = "\n".join(
-                            [get_memory_content_from_memory(memory) for memory in random_memory]
-                        )
-                        points_text = f"有关 {category} 的内容:{random_memory_str}"
-                        break
-        else:
-            for category in category_list:
-                random_memory = self.get_random_memory_by_category(category, 1)[0]
-                if random_memory:
-                    points_text = f"有关 {category} 的内容:{get_memory_content_from_memory(random_memory)}"
-                    break
-
-        points_info = ""
-        if points_text:
-            points_info = f"你还记得有关{self.person_name}的内容:{points_text}"
-
-        if not (nickname_str or points_info):
-            return ""
-        relation_info = f"{self.person_name}:{nickname_str}{points_info}"
-
-        return relation_info
+        async def _select_traits(query_text: str, traits: List[str], limit: int = 3) -> List[str]:
+            clean_traits = [trait.strip() for trait in traits if isinstance(trait, str) and trait.strip()]
+            if not clean_traits:
+                return []
+            if not query_text:
+                return clean_traits[:limit]
+
+            numbered_traits = "\n".join(f"{index}. {trait}" for index, trait in enumerate(clean_traits, start=1))
+            prompt = f"""当前关注内容:
+{query_text}
+
+候选人物信息:
+{numbered_traits}
+
+请从候选人物信息中选择与当前关注内容最相关的编号,并用<>包裹输出,不要输出其他内容。
+例如:
+<1><3>
+如果都不相关,请输出<none>"""
+
+            try:
+                response, _ = await relation_selection_model.generate_response_async(prompt)
+                selected_traits: List[str] = []
+                for raw_index in extract_categories_from_response(response):
+                    if raw_index == "none":
+                        return []
+                    try:
+                        trait_index = int(raw_index) - 1
+                    except ValueError:
+                        continue
+                    if 0 <= trait_index < len(clean_traits):
+                        trait = clean_traits[trait_index]
+                        if trait not in selected_traits:
+                            selected_traits.append(trait)
+                if selected_traits:
+                    return selected_traits[:limit]
+            except Exception as e:
+                logger.debug(f"筛选人物画像信息失败,使用默认画像摘要: {e}")
+
+            return clean_traits[:limit]
+
+        profile = await memory_service.get_person_profile(self.person_id, limit=8)
+        relation_parts: List[str] = []
+        if profile.summary.strip():
+            relation_parts.append(profile.summary.strip())
+
+        query_text = str(chat_content or info_type or "").strip()
+        selected_traits = await _select_traits(query_text, profile.traits, limit=3)
+        if not selected_traits and not query_text:
+            selected_traits = [trait for trait in profile.traits if trait][:2]
+
+        for trait in selected_traits:
+            clean_trait = str(trait).strip()
+            if clean_trait and clean_trait not in relation_parts:
+                relation_parts.append(clean_trait)
+
+        for evidence in profile.evidence:
+            content = str(evidence.get("content", "") or "").strip()
+            if content and content not in relation_parts:
+                relation_parts.append(content)
+            if len(relation_parts) >= 4:
+                break
+
+        points_info = ""
+        if relation_parts:
+            points_info = f"你还记得有关{self.person_name}的内容:{';'.join(relation_parts[:3])}"
+
+        if not (nickname_str or points_info):
+            return ""
+        return f"{self.person_name}:{nickname_str}{points_info}"
 
 
 class PersonInfoManager:
@@ -776,7 +822,7 @@ person_info_manager = PersonInfoManager()
 
 
 async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
-    """将人物信息存入person_info的memory_points
+    """将人物事实写入统一长期记忆
 
     Args:
         person_name: 人物名称
@@ -784,6 +830,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
         chat_id: 聊天ID
     """
     try:
+        content = str(memory_content or "").strip()
+        if not content:
+            logger.debug("人物记忆内容为空,跳过写入")
+            return
+
         # 从 chat_id 获取 session
         session = _chat_manager.get_session_by_session_id(chat_id)
         if not session:
@@ -794,16 +845,14 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
 
         # 尝试从person_name查找person_id
         # 首先尝试通过person_name查找
-        person_id = get_person_id_by_person_name(person_name)
+        person_id = resolve_person_id_for_memory(
+            person_name=person_name,
+            platform=platform,
+            user_id=session.user_id,
+        )
         if not person_id:
-            # 如果通过person_name找不到,尝试从 session 获取 user_id
-            if platform and session.user_id:
-                user_id = session.user_id
-                person_id = get_person_id(platform, user_id)
-            else:
-                logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
-                return
+            logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
+            return
 
         # 创建或获取Person对象
         person = Person(person_id=person_id)
@@ -812,39 +861,34 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
             logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
             return
 
-        # 确定记忆分类(可以根据memory_content判断,这里使用通用分类)
-        category = "其他"  # 默认分类,可以根据需要调整
-
-        # 记忆点格式:category:content:weight
-        weight = "1.0"  # 默认权重
-        memory_point = f"{category}:{memory_content}:{weight}"
-
-        # 添加到memory_points
-        if not person.memory_points:
-            person.memory_points = []
-
-        # 检查是否已存在相似的记忆点(避免重复)
-        is_duplicate = False
-        for existing_point in person.memory_points:
-            if existing_point and isinstance(existing_point, str):
-                parts = existing_point.split(":", 2)
-                if len(parts) >= 2:
-                    existing_content = parts[1].strip()
-                    # 简单相似度检查(如果内容相同或非常相似,则跳过)
-                    if (
-                        existing_content == memory_content
-                        or memory_content in existing_content
-                        or existing_content in memory_content
-                    ):
-                        is_duplicate = True
-                        break
-
-        if not is_duplicate:
-            person.memory_points.append(memory_point)
-            person.sync_to_database()
-            logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
-        else:
-            logger.debug(f"记忆点已存在,跳过: {memory_point}")
+        memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16]
+        result = await memory_service.ingest_text(
+            external_id=f"person_fact:{person_id}:{memory_hash}",
+            source_type="person_fact",
+            text=content,
+            chat_id=chat_id,
+            person_ids=[person_id],
+            participants=[person.person_name or person_name],
+            timestamp=time.time(),
+            tags=["person_fact"],
+            metadata={
+                "person_id": person_id,
+                "person_name": person.person_name or person_name,
+                "platform": platform,
+                "source": "person_info.store_person_memory_from_answer",
+            },
+            respect_filter=True,
+            user_id=str(session.user_id or "").strip(),
+            group_id=str(session.group_id or "").strip(),
+        )
+
+        if result.success:
+            if result.detail == "chat_filtered":
+                logger.debug(f"人物长期记忆被聊天过滤策略跳过: {person_name} (person_id: {person_id})")
+            else:
+                logger.info(f"成功写入人物长期记忆: {person_name} (person_id: {person_id})")
+        else:
+            logger.warning(f"写入人物长期记忆失败: {person_name} (person_id: {person_id}) | {result.detail}")
 
     except Exception as e:
         logger.error(f"存储人物记忆失败: {e}")
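The writeback path above derives a stable `external_id` by hashing `person_id` plus the fact text, so re-ingesting an identical fact maps to the same id and can be deduplicated upstream. A stand-alone sketch of that derivation; the `person_fact_external_id` helper name is hypothetical:

```python
import hashlib


def person_fact_external_id(person_id: str, content: str) -> str:
    """Stable id for a person fact: the same (person_id, content) pair
    always yields the same external_id, mirroring the ingest_text call."""
    digest = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16]
    return f"person_fact:{person_id}:{digest}"
```

Using a content hash rather than a timestamp in the id is what makes repeated extraction of the same fact idempotent.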
@@ -672,12 +672,10 @@ class RuntimeDataCapabilityMixin:
             limit_value = 5
 
         try:
-            from src.chat.knowledge import qa_manager
+            from src.services.memory_service import memory_service
 
-            if qa_manager is None:
-                return {"success": True, "content": "LPMM知识库已禁用"}
-
-            knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)
+            result = await memory_service.search(query, limit=limit_value)
+            knowledge_info = result.to_text(limit=limit_value)
 
             content = f"你知道这些知识: {knowledge_info}" if knowledge_info else f"你不太了解有关{query}的知识"
             return {"success": True, "content": content}
         except Exception as e:
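The replacement path above relies on `MemorySearchResult.to_text` to flatten hits into prompt-ready text: newlines collapsed, entries truncated at 160 characters, numbered one per line. A stand-alone sketch of that formatting rule over plain strings; the `hits_to_text` helper is hypothetical:

```python
def hits_to_text(hits: list[str], limit: int = 5) -> str:
    """Numbered one-line rendering of memory hits, mirroring the
    to_text contract: flatten newlines, truncate long entries."""
    lines = []
    for index, content in enumerate(hits[: max(1, int(limit))], start=1):
        content = content.strip().replace("\n", " ")
        if len(content) > 160:
            content = content[:160] + "..."
        lines.append(f"{index}. {content}")
    return "\n".join(lines)
```

Truncating per hit keeps the injected knowledge block bounded regardless of how long individual memories are.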
275  src/services/memory_flow_service.py  Normal file

@@ -0,0 +1,275 @@
from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, List, Optional

from json_repair import repair_json

from src.chat.utils.utils import is_bot_self
from src.common.message_repository import find_messages
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer

logger = get_logger("memory_flow_service")


class LongTermMemorySessionManager:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._summarizers: Dict[str, ChatHistorySummarizer] = {}

    async def on_message(self, message: Any) -> None:
        if not bool(getattr(global_config.memory, "long_term_auto_summary_enabled", True)):
            return
        session_id = str(getattr(message, "session_id", "") or "").strip()
        if not session_id:
            return

        created = False
        async with self._lock:
            summarizer = self._summarizers.get(session_id)
            if summarizer is None:
                summarizer = ChatHistorySummarizer(session_id=session_id)
                self._summarizers[session_id] = summarizer
                created = True
        if created:
            await summarizer.start()

    async def shutdown(self) -> None:
        async with self._lock:
            items = list(self._summarizers.items())
            self._summarizers.clear()
        for session_id, summarizer in items:
            try:
                await summarizer.stop()
            except Exception as exc:
                logger.warning("停止聊天总结器失败: session=%s err=%s", session_id, exc)


class PersonFactWritebackService:
    def __init__(self) -> None:
        self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
        self._worker_task: Optional[asyncio.Task] = None
        self._stopping = False
        self._extractor = LLMRequest(
            model_set=model_config.model_task_config.utils,
            request_type="person_fact_writeback",
        )

    async def start(self) -> None:
        if self._worker_task is not None and not self._worker_task.done():
            return
        self._stopping = False
        self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_person_fact_writeback")

    async def shutdown(self) -> None:
        self._stopping = True
        worker = self._worker_task
        self._worker_task = None
        if worker is None:
            return
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            logger.warning("关闭人物事实写回 worker 失败: %s", exc)

    async def enqueue(self, message: Any) -> None:
        if not bool(getattr(global_config.memory, "person_fact_writeback_enabled", True)):
            return
        if self._stopping:
            return
        try:
            self._queue.put_nowait(message)
        except asyncio.QueueFull:
            logger.warning("人物事实写回队列已满,跳过本次回复")

    async def _worker_loop(self) -> None:
        try:
            while not self._stopping:
                message = await self._queue.get()
                try:
                    await self._handle_message(message)
                except Exception as exc:
                    logger.warning("人物事实写回处理失败: %s", exc, exc_info=True)
                finally:
                    self._queue.task_done()
        except asyncio.CancelledError:
            raise

    async def _handle_message(self, message: Any) -> None:
        reply_text = str(getattr(message, "processed_plain_text", "") or "").strip()
        if not reply_text:
            return
        if self._looks_ephemeral(reply_text):
            return

        target_person = self._resolve_target_person(message)
        if target_person is None or not target_person.is_known:
            return

        facts = await self._extract_facts(target_person, reply_text)
        if not facts:
            return

        session_id = str(
            getattr(message, "session_id", "")
            or getattr(getattr(message, "session", None), "session_id", "")
            or ""
        ).strip()
        if not session_id:
            return

        person_name = str(getattr(target_person, "person_name", "") or getattr(target_person, "nickname", "") or "").strip()
        if not person_name:
            return

        for fact in facts:
            await store_person_memory_from_answer(person_name, fact, session_id)

    def _resolve_target_person(self, message: Any) -> Optional[Person]:
        session = getattr(message, "session", None)
        session_platform = str(getattr(session, "platform", "") or getattr(message, "platform", "") or "").strip()
        session_user_id = str(getattr(session, "user_id", "") or "").strip()
        group_id = str(getattr(session, "group_id", "") or "").strip()

        if session_platform and session_user_id and not group_id:
            if is_bot_self(session_platform, session_user_id):
                return None
            person_id = get_person_id(session_platform, session_user_id)
            person = Person(person_id=person_id)
            return person if person.is_known else None

        reply_to = str(getattr(message, "reply_to", "") or "").strip()
        if not reply_to:
            return None
        try:
            replies = find_messages(message_id=reply_to, limit=1)
        except Exception as exc:
            logger.debug("查询 reply_to 目标失败: %s", exc)
            return None
        if not replies:
            return None
        reply_message = replies[0]
        reply_platform = str(getattr(reply_message, "platform", "") or session_platform or "").strip()
        reply_user_info = getattr(getattr(reply_message, "message_info", None), "user_info", None)
        reply_user_id = str(getattr(reply_user_info, "user_id", "") or "").strip()
        if not reply_platform or not reply_user_id or is_bot_self(reply_platform, reply_user_id):
            return None
        person_id = get_person_id(reply_platform, reply_user_id)
        person = Person(person_id=person_id)
        return person if person.is_known else None

    async def _extract_facts(self, person: Person, reply_text: str) -> List[str]:
        person_name = str(getattr(person, "person_name", "") or getattr(person, "nickname", "") or person.person_id)
        prompt = f"""你要从一条机器人刚刚发送的回复中,提取“关于{person_name}的稳定事实”。

目标人物:{person_name}
机器人回复:
{reply_text}

请只提取满足以下条件的事实:
1. 明确是关于目标人物本人的信息。
2. 具有相对稳定性,可以作为长期记忆保存。
3. 用简洁中文陈述句表达。

不要提取:
- 机器人的情绪、计划、临时动作、客套话
- 只适用于当前时刻的短期安排
- 不确定、猜测、反问
- 与目标人物无关的信息

严格输出 JSON 数组,例如:
["他喜欢深夜打游戏", "他养了一只猫"]
如果没有可写入的事实,输出 []"""
        try:
            response, _ = await self._extractor.generate_response_async(prompt)
        except Exception as exc:
            logger.debug("人物事实提取模型调用失败: %s", exc)
            return []
        return self._parse_fact_list(response)

    @staticmethod
    def _parse_fact_list(raw: str) -> List[str]:
        text = str(raw or "").strip()
        if not text:
            return []
        try:
            repaired = repair_json(text)
            payload = json.loads(repaired) if isinstance(repaired, str) else repaired
        except Exception:
            payload = None
        if not isinstance(payload, list):
            return []

        items: List[str] = []
        seen = set()
        for item in payload:
            fact = str(item or "").strip().strip("- ")
            if not fact or len(fact) < 4:
                continue
            if fact in seen:
                continue
            seen.add(fact)
            items.append(fact)
        return items[:5]

    @staticmethod
    def _looks_ephemeral(text: str) -> bool:
        content = str(text or "").strip()
        if not content:
            return True
        ephemeral_markers = (
            "哈哈",
            "好的",
            "收到",
            "嗯嗯",
            "晚安",
            "早安",
            "拜拜",
            "谢谢",
            "在吗",
            "?",
        )
        if len(content) <= 8 and any(marker in content for marker in ephemeral_markers):
            return True
        return False


class MemoryAutomationService:
    def __init__(self) -> None:
        self.session_manager = LongTermMemorySessionManager()
        self.fact_writeback = PersonFactWritebackService()
        self._started = False

    async def start(self) -> None:
        if self._started:
            return
        await self.fact_writeback.start()
        self._started = True

    async def shutdown(self) -> None:
        if not self._started:
            return
        await self.session_manager.shutdown()
        await self.fact_writeback.shutdown()
        self._started = False

    async def on_incoming_message(self, message: Any) -> None:
        if not self._started:
            await self.start()
        await self.session_manager.on_message(message)

    async def on_message_sent(self, message: Any) -> None:
        if not self._started:
            await self.start()
        await self.fact_writeback.enqueue(message)


memory_automation_service = MemoryAutomationService()
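The `_parse_fact_list` helper above first runs the model reply through `json_repair` before filtering. A stand-alone sketch of the same filtering contract over well-formed JSON only (strip, drop near-empty items, dedupe, cap the count); the `parse_fact_list` function here is a hypothetical stdlib-only analogue, not the service's code:

```python
import json


def parse_fact_list(raw: str, max_items: int = 5) -> list[str]:
    """Parse a model reply expected to be a JSON array of short fact
    strings. Assumes valid JSON; the real helper repairs malformed
    output first, then applies the same filters as below."""
    text = (raw or "").strip()
    if not text:
        return []
    try:
        payload = json.loads(text)
    except ValueError:
        return []
    if not isinstance(payload, list):
        return []
    items, seen = [], set()
    for item in payload:
        fact = str(item or "").strip().strip("- ")
        if not fact or len(fact) < 4 or fact in seen:
            continue  # drop empty, too-short, or duplicate facts
        seen.add(fact)
        items.append(fact)
    return items[:max_items]
```

The minimum-length and dedupe filters keep throwaway fragments out of the writeback queue even when the model over-generates.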
428  src/services/memory_service.py  Normal file

@@ -0,0 +1,428 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from src.common.logger import get_logger
from src.plugin_runtime.integration import get_plugin_runtime_manager


logger = get_logger("memory_service")

PLUGIN_ID = "A_Memorix"


@dataclass
class MemoryHit:
    content: str
    score: float = 0.0
    hit_type: str = ""
    source: str = ""
    hash_value: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
    episode_id: str = ""
    title: str = ""

    def to_dict(self) -> Dict[str, Any]:
        return {
            "content": self.content,
            "score": self.score,
            "type": self.hit_type,
            "source": self.source,
            "hash": self.hash_value,
            "metadata": self.metadata,
            "episode_id": self.episode_id,
            "title": self.title,
        }


@dataclass
class MemorySearchResult:
    summary: str = ""
    hits: List[MemoryHit] = field(default_factory=list)
    filtered: bool = False

    def to_text(self, limit: int = 5) -> str:
        if not self.hits:
            return ""
        lines = []
        for index, item in enumerate(self.hits[: max(1, int(limit))], start=1):
            content = item.content.strip().replace("\n", " ")
            if len(content) > 160:
                content = content[:160] + "..."
            lines.append(f"{index}. {content}")
        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "summary": self.summary,
            "hits": [item.to_dict() for item in self.hits],
            "filtered": self.filtered,
        }


@dataclass
class MemoryWriteResult:
    success: bool
    stored_ids: List[str] = field(default_factory=list)
    skipped_ids: List[str] = field(default_factory=list)
    detail: str = ""

    def to_dict(self) -> Dict[str, Any]:
        return {
            "success": self.success,
            "stored_ids": self.stored_ids,
            "skipped_ids": self.skipped_ids,
            "detail": self.detail,
        }


@dataclass
class PersonProfileResult:
    summary: str = ""
    traits: List[str] = field(default_factory=list)
    evidence: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        return {"summary": self.summary, "traits": self.traits, "evidence": self.evidence}


class MemoryService:
    async def _invoke(self, component_name: str, args: Optional[Dict[str, Any]] = None, *, timeout_ms: int = 30000) -> Any:
        runtime = get_plugin_runtime_manager()
        if not runtime.is_running:
            raise RuntimeError("plugin_runtime 未启动")
        return await runtime.invoke_plugin(
            method="plugin.invoke_tool",
            plugin_id=PLUGIN_ID,
            component_name=component_name,
            args=args or {},
            timeout_ms=max(1000, int(timeout_ms or 30000)),
        )

    async def _invoke_admin(
        self,
        component_name: str,
        *,
        action: str,
        timeout_ms: int = 30000,
        **kwargs,
    ) -> Dict[str, Any]:
        payload = await self._invoke(component_name, {"action": action, **kwargs}, timeout_ms=timeout_ms)
        return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"}

    @staticmethod
    def _coerce_write_result(payload: Any) -> MemoryWriteResult:
        if not isinstance(payload, dict):
            return MemoryWriteResult(success=False, detail="invalid_payload")
        stored_ids = [str(item) for item in (payload.get("stored_ids") or []) if str(item).strip()]
        skipped_ids = [str(item) for item in (payload.get("skipped_ids") or []) if str(item).strip()]
        detail = str(payload.get("detail") or payload.get("reason") or "")
        if stored_ids or skipped_ids:
            success = True
        elif "success" in payload:
            success = bool(payload.get("success"))
        else:
            success = not bool(detail)
        return MemoryWriteResult(
            success=success,
            stored_ids=stored_ids,
            skipped_ids=skipped_ids,
            detail=detail,
        )

    @staticmethod
    def _coerce_search_result(payload: Any) -> MemorySearchResult:
        if not isinstance(payload, dict):
            return MemorySearchResult()
        hits: List[MemoryHit] = []
        for item in payload.get("hits", []) or []:
            if not isinstance(item, dict):
                continue
            metadata = item.get("metadata", {}) or {}
            if not isinstance(metadata, dict):
                metadata = {}
            if "source_branches" in item and "source_branches" not in metadata:
                metadata["source_branches"] = item.get("source_branches") or []
            if "rank" in item and "rank" not in metadata:
                metadata["rank"] = item.get("rank")
            hits.append(
                MemoryHit(
                    content=str(item.get("content", "") or ""),
                    score=float(item.get("score", 0.0) or 0.0),
                    hit_type=str(item.get("type", "") or ""),
                    source=str(item.get("source", "") or ""),
                    hash_value=str(item.get("hash", "") or ""),
                    metadata=metadata,
                    episode_id=str(item.get("episode_id", "") or ""),
|
||||||
|
title=str(item.get("title", "") or ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return MemorySearchResult(
|
||||||
|
summary=str(payload.get("summary", "") or ""),
|
||||||
|
hits=hits,
|
||||||
|
filtered=bool(payload.get("filtered", False)),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_profile_result(payload: Any) -> PersonProfileResult:
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return PersonProfileResult()
|
||||||
|
return PersonProfileResult(
|
||||||
|
summary=str(payload.get("summary", "") or ""),
|
||||||
|
traits=[str(item) for item in (payload.get("traits") or []) if str(item).strip()],
|
||||||
|
evidence=[item for item in (payload.get("evidence") or []) if isinstance(item, dict)],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
*,
|
||||||
|
limit: int = 5,
|
||||||
|
mode: str = "hybrid",
|
||||||
|
chat_id: str = "",
|
||||||
|
person_id: str = "",
|
||||||
|
time_start: str | float | None = None,
|
||||||
|
time_end: str | float | None = None,
|
||||||
|
respect_filter: bool = True,
|
||||||
|
user_id: str = "",
|
||||||
|
group_id: str = "",
|
||||||
|
) -> MemorySearchResult:
|
||||||
|
clean_query = str(query or "").strip()
|
||||||
|
normalized_time_start = None if time_start in {None, ""} else time_start
|
||||||
|
normalized_time_end = None if time_end in {None, ""} else time_end
|
||||||
|
if not clean_query and normalized_time_start is None and normalized_time_end is None:
|
||||||
|
return MemorySearchResult()
|
||||||
|
try:
|
||||||
|
payload = await self._invoke(
|
||||||
|
"search_memory",
|
||||||
|
{
|
||||||
|
"query": clean_query,
|
||||||
|
"limit": max(1, int(limit)),
|
||||||
|
"mode": mode,
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"person_id": person_id,
|
||||||
|
"time_start": normalized_time_start,
|
||||||
|
"time_end": normalized_time_end,
|
||||||
|
"respect_filter": bool(respect_filter),
|
||||||
|
"user_id": str(user_id or "").strip(),
|
||||||
|
"group_id": str(group_id or "").strip(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return self._coerce_search_result(payload)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("长期记忆搜索失败: %s", exc)
|
||||||
|
return MemorySearchResult()
|
||||||
|
|
||||||
|
async def ingest_summary(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
external_id: str,
|
||||||
|
chat_id: str,
|
||||||
|
text: str,
|
||||||
|
participants: Optional[List[str]] = None,
|
||||||
|
time_start: float | None = None,
|
||||||
|
time_end: float | None = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
respect_filter: bool = True,
|
||||||
|
user_id: str = "",
|
||||||
|
group_id: str = "",
|
||||||
|
) -> MemoryWriteResult:
|
||||||
|
try:
|
||||||
|
payload = await self._invoke(
|
||||||
|
"ingest_summary",
|
||||||
|
{
|
||||||
|
"external_id": external_id,
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"text": text,
|
||||||
|
"participants": participants or [],
|
||||||
|
"time_start": time_start,
|
||||||
|
"time_end": time_end,
|
||||||
|
"tags": tags or [],
|
||||||
|
"metadata": metadata or {},
|
||||||
|
"respect_filter": bool(respect_filter),
|
||||||
|
"user_id": str(user_id or "").strip(),
|
||||||
|
"group_id": str(group_id or "").strip(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return self._coerce_write_result(payload)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("长期记忆写入摘要失败: %s", exc)
|
||||||
|
return MemoryWriteResult(success=False, detail=str(exc))
|
||||||
|
|
||||||
|
async def ingest_text(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
external_id: str,
|
||||||
|
source_type: str,
|
||||||
|
text: str,
|
||||||
|
chat_id: str = "",
|
||||||
|
person_ids: Optional[List[str]] = None,
|
||||||
|
participants: Optional[List[str]] = None,
|
||||||
|
timestamp: float | None = None,
|
||||||
|
time_start: float | None = None,
|
||||||
|
time_end: float | None = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
entities: Optional[List[str]] = None,
|
||||||
|
relations: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
respect_filter: bool = True,
|
||||||
|
user_id: str = "",
|
||||||
|
group_id: str = "",
|
||||||
|
) -> MemoryWriteResult:
|
||||||
|
try:
|
||||||
|
payload = await self._invoke(
|
||||||
|
"ingest_text",
|
||||||
|
{
|
||||||
|
"external_id": external_id,
|
||||||
|
"source_type": source_type,
|
||||||
|
"text": text,
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"person_ids": person_ids or [],
|
||||||
|
"participants": participants or [],
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"time_start": time_start,
|
||||||
|
"time_end": time_end,
|
||||||
|
"tags": tags or [],
|
||||||
|
"metadata": metadata or {},
|
||||||
|
"entities": entities or [],
|
||||||
|
"relations": relations or [],
|
||||||
|
"respect_filter": bool(respect_filter),
|
||||||
|
"user_id": str(user_id or "").strip(),
|
||||||
|
"group_id": str(group_id or "").strip(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return self._coerce_write_result(payload)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("长期记忆写入文本失败: %s", exc)
|
||||||
|
return MemoryWriteResult(success=False, detail=str(exc))
|
||||||
|
|
||||||
|
async def get_person_profile(self, person_id: str, *, chat_id: str = "", limit: int = 10) -> PersonProfileResult:
|
||||||
|
clean_person_id = str(person_id or "").strip()
|
||||||
|
if not clean_person_id:
|
||||||
|
return PersonProfileResult()
|
||||||
|
try:
|
||||||
|
payload = await self._invoke(
|
||||||
|
"get_person_profile",
|
||||||
|
{"person_id": clean_person_id, "chat_id": chat_id, "limit": max(1, int(limit))},
|
||||||
|
)
|
||||||
|
return self._coerce_profile_result(payload)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("获取人物画像失败: %s", exc)
|
||||||
|
return PersonProfileResult()
|
||||||
|
|
||||||
|
async def maintain_memory(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
action: str,
|
||||||
|
target: str = "",
|
||||||
|
hours: float | None = None,
|
||||||
|
reason: str = "",
|
||||||
|
limit: int = 50,
|
||||||
|
) -> MemoryWriteResult:
|
||||||
|
try:
|
||||||
|
payload = await self._invoke(
|
||||||
|
"maintain_memory",
|
||||||
|
{"action": action, "target": target, "hours": hours, "reason": reason, "limit": limit},
|
||||||
|
)
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return MemoryWriteResult(success=False, detail="invalid_payload")
|
||||||
|
return MemoryWriteResult(success=bool(payload.get("success")), detail=str(payload.get("detail", "") or ""))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("记忆维护失败: %s", exc)
|
||||||
|
return MemoryWriteResult(success=False, detail=str(exc))
|
||||||
|
|
||||||
|
async def memory_stats(self) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
payload = await self._invoke("memory_stats", {})
|
||||||
|
return payload if isinstance(payload, dict) else {}
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("获取记忆统计失败: %s", exc)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return await self._invoke_admin("memory_graph_admin", action=action, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("图谱管理调用失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return await self._invoke_admin("memory_source_admin", action=action, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("来源管理调用失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return await self._invoke_admin("memory_episode_admin", action=action, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Episode 管理调用失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return await self._invoke_admin("memory_profile_admin", action=action, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("画像管理调用失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("运行时管理调用失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def import_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return await self._invoke_admin("memory_import_admin", action=action, timeout_ms=timeout_ms, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("导入管理调用失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def tuning_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return await self._invoke_admin("memory_tuning_admin", action=action, timeout_ms=timeout_ms, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("调优管理调用失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def v5_admin(self, *, action: str, timeout_ms: int = 30000, **kwargs) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return await self._invoke_admin("memory_v5_admin", action=action, timeout_ms=timeout_ms, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("V5 记忆管理调用失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def delete_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
return await self._invoke_admin("memory_delete_admin", action=action, timeout_ms=timeout_ms, **kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("删除管理调用失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def get_recycle_bin(self, *, limit: int = 50) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
payload = await self._invoke("maintain_memory", {"action": "recycle_bin", "limit": max(1, int(limit or 50))})
|
||||||
|
return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"}
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("获取回收站失败: %s", exc)
|
||||||
|
return {"success": False, "error": str(exc)}
|
||||||
|
|
||||||
|
async def restore_memory(self, *, target: str) -> MemoryWriteResult:
|
||||||
|
return await self.maintain_memory(action="restore", target=target)
|
||||||
|
|
||||||
|
async def reinforce_memory(self, *, target: str) -> MemoryWriteResult:
|
||||||
|
return await self.maintain_memory(action="reinforce", target=target)
|
||||||
|
|
||||||
|
async def freeze_memory(self, *, target: str) -> MemoryWriteResult:
|
||||||
|
return await self.maintain_memory(action="freeze", target=target)
|
||||||
|
|
||||||
|
async def protect_memory(self, *, target: str, hours: float | None = None) -> MemoryWriteResult:
|
||||||
|
return await self.maintain_memory(action="protect", target=target, hours=hours)
|
||||||
|
|
||||||
|
|
||||||
|
memory_service = MemoryService()
|
||||||
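The success-inference rules in `MemoryService._coerce_write_result` are easy to get wrong, so they are worth seeing in isolation: any non-empty `stored_ids`/`skipped_ids` implies success even if the payload says otherwise, then an explicit `success` flag is honored, and finally "no detail means no error". The sketch below is a self-contained re-implementation of just those rules (the function name `coerce_write_result` and the standalone dataclass are only for this sketch; the real method lives on the service above):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class MemoryWriteResult:
    success: bool
    stored_ids: List[str] = field(default_factory=list)
    skipped_ids: List[str] = field(default_factory=list)
    detail: str = ""


def coerce_write_result(payload: Any) -> MemoryWriteResult:
    # Mirrors the static method: ids win, then the explicit flag,
    # then "empty detail" is treated as success.
    if not isinstance(payload, dict):
        return MemoryWriteResult(success=False, detail="invalid_payload")
    stored_ids = [str(i) for i in (payload.get("stored_ids") or []) if str(i).strip()]
    skipped_ids = [str(i) for i in (payload.get("skipped_ids") or []) if str(i).strip()]
    detail = str(payload.get("detail") or payload.get("reason") or "")
    if stored_ids or skipped_ids:
        success = True
    elif "success" in payload:
        success = bool(payload.get("success"))
    else:
        success = not bool(detail)
    return MemoryWriteResult(success, stored_ids, skipped_ids, detail)


# Stored ids win over an explicit failure flag:
print(coerce_write_result({"success": False, "stored_ids": ["m1"]}).success)  # True
# An empty payload counts as success (no detail means no error):
print(coerce_write_result({}).success)  # True
# A bare error string counts as failure:
print(coerce_write_result({"reason": "quota exceeded"}).success)  # False
```

Note the asymmetry: write paths that return ids never report failure, which matches how the service treats partial writes (some stored, some skipped) as successes with a `detail` string.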
@@ -17,14 +17,14 @@ def get_all_routers() -> List[APIRouter]:
     from src.webui.api.planner import router as planner_router
     from src.webui.api.replier import router as replier_router
     from src.webui.routers.chat import router as chat_router
-    from src.webui.routers.knowledge import router as knowledge_router
+    from src.webui.routers.memory import compat_router as memory_compat_router
     from src.webui.routers.websocket.logs import router as logs_router
     from src.webui.routes import router as main_router

     return [
         main_router,
+        memory_compat_router,
         logs_router,
-        knowledge_router,
         chat_router,
         planner_router,
         replier_router,
1395 src/webui/routers/memory.py Normal file
File diff suppressed because it is too large
@@ -16,6 +16,7 @@ from src.webui.routers.config import router as config_router
 from src.webui.routers.emoji import router as emoji_router
 from src.webui.routers.expression import router as expression_router
 from src.webui.routers.jargon import router as jargon_router
+from src.webui.routers.memory import router as memory_router
 from src.webui.routers.model import router as model_router
 from src.webui.routers.person import router as person_router
 from src.webui.routers.plugin import get_progress_router
@@ -49,6 +50,8 @@ router.include_router(get_progress_router())
 router.include_router(system_router)
 # 注册模型列表获取路由
 router.include_router(model_router)
+# 注册长期记忆管理路由
+router.include_router(memory_router)
 # 注册 WebSocket 认证路由
 router.include_router(ws_auth_router)

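Both hunks follow the same registration convention: import a sub-router, then mount it on the parent with `include_router`, one comment per registration. A minimal FastAPI-independent sketch of that aggregation pattern (the `Router` class and paths here are hypothetical stand-ins, not FastAPI's `APIRouter`):

```python
class Router:
    """Tiny stand-in for an APIRouter-like aggregator, enough to show the pattern."""

    def __init__(self, prefix: str = ""):
        self.prefix = prefix
        self.routes: list[str] = []

    def add(self, path: str) -> None:
        # Register a single endpoint under this router's prefix.
        self.routes.append(self.prefix + path)

    def include_router(self, other: "Router") -> None:
        # Mounting copies the child's already-prefixed routes into the parent,
        # so registration order determines route order.
        self.routes.extend(other.routes)


memory_router = Router(prefix="/api/memory")
memory_router.add("/search")
memory_router.add("/stats")

root = Router()
root.include_router(memory_router)
print(root.routes)  # ['/api/memory/search', '/api/memory/stats']
```

Because mounting is order-sensitive, placing `memory_router` between `model_router` and `ws_auth_router`, as the second hunk does, fixes where its routes appear relative to the rest of the WebUI.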