fix:解决 upstream 合并冲突
在临时整合分支中合并 upstream/r-dev 保留人物事实写回与反馈纠错配置 移除已下线的聊天总结配置并同步测试
This commit is contained in:
@@ -1,148 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.memory_system import chat_history_summarizer as summarizer_module
|
||||
|
||||
|
||||
def _build_summarizer() -> summarizer_module.ChatHistorySummarizer:
    """Create a bare ChatHistorySummarizer for unit tests.

    Uses ``__new__`` to skip ``__init__`` (and whatever setup it performs),
    then assigns only the attributes the methods under test read.
    """
    instance = summarizer_module.ChatHistorySummarizer.__new__(summarizer_module.ChatHistorySummarizer)
    instance.session_id = "session-1"
    instance.log_prefix = "[session-1]"
    return instance
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch):
    """Happy path: the summarizer forwards a summary payload to ingest_summary.

    Verifies external_id format, session/user/group routing, participants,
    respect_filter, and that theme/summary are embedded in the ingested text.
    """
    calls = []
    summarizer = _build_summarizer()

    # Successful ingest stub; records the keyword arguments for inspection.
    async def fake_ingest_summary(**kwargs):
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    # Session lookup resolves to a private chat (empty group_id).
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
    )
    monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8)))
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)

    await summarizer._import_to_long_term_memory(
        record_id=1,
        theme="旅行计划",
        summary="我们讨论了春游安排",
        participants=["Alice", "Bob"],
        start_time=1.0,
        end_time=2.0,
        original_text="long text",
    )

    # Exactly one ingest call, carrying the expected routing + content fields.
    assert len(calls) == 1
    payload = calls[0]
    assert payload["external_id"] == "chat_history:1"
    assert payload["chat_id"] == "session-1"
    assert payload["participants"] == ["Alice", "Bob"]
    assert payload["respect_filter"] is True
    assert payload["user_id"] == "user-1"
    assert payload["group_id"] == ""
    assert "主题:旅行计划" in payload["text"]
    assert "概括:我们讨论了春游安排" in payload["text"]
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch):
    """When theme and summary are both empty, the fallback import path is used.

    The fallback receives the original record id and raw text so nothing is lost.
    """
    summarizer = _build_summarizer()
    fallback_calls = []

    # Capture fallback invocations instead of performing a real import.
    async def fake_fallback(**kwargs):
        fallback_calls.append(kwargs)

    summarizer._fallback_import_to_long_term_memory = fake_fallback
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
    )

    # Empty theme/summary should trigger the fallback route.
    await summarizer._import_to_long_term_memory(
        record_id=2,
        theme="",
        summary="",
        participants=[],
        start_time=10.0,
        end_time=20.0,
        original_text="raw chat",
    )

    assert len(fallback_calls) == 1
    assert fallback_calls[0]["record_id"] == 2
    assert fallback_calls[0]["original_text"] == "raw chat"
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch):
    """When ingest_summary reports failure, the summarizer falls back.

    The fallback call must keep the original theme so the retry loses no context.
    """
    summarizer = _build_summarizer()
    fallback_calls = []

    # Ingest stub that reports a failure (success=False) to force the fallback.
    async def fake_ingest_summary(**kwargs):
        return SimpleNamespace(success=False, detail="boom", stored_ids=[])

    async def fake_fallback(**kwargs):
        fallback_calls.append(kwargs)

    summarizer._fallback_import_to_long_term_memory = fake_fallback
    # This time the session resolves to a group chat (non-empty group_id).
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")),
    )
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)

    await summarizer._import_to_long_term_memory(
        record_id=3,
        theme="电影",
        summary="聊了电影推荐",
        participants=["Alice"],
        start_time=3.0,
        end_time=4.0,
        original_text="raw",
    )

    assert len(fallback_calls) == 1
    assert fallback_calls[0]["theme"] == "电影"
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch):
    """The fallback path marks the payload as chat-derived.

    Checks metadata flags generate_from_chat / context_length (taken from
    global_config.memory.chat_history_topic_check_message_threshold) and
    respect_filter on the ingest call.
    """
    calls = []
    summarizer = _build_summarizer()

    # The stub succeeds but stores nothing ("chat_filtered"); we only inspect kwargs.
    async def fake_ingest_summary(**kwargs):
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[])

    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")),
    )
    # Threshold of 12 is expected to surface as metadata["context_length"].
    monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12)))
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)

    await summarizer._fallback_import_to_long_term_memory(
        record_id=4,
        theme="工作",
        participants=["Alice"],
        start_time=5.0,
        end_time=6.0,
        original_text="a" * 128,
    )

    assert len(calls) == 1
    metadata = calls[0]["metadata"]
    assert metadata["generate_from_chat"] is True
    assert metadata["context_length"] == 12
    assert calls[0]["respect_filter"] is True
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,687 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
|
||||
from src.memory_system import chat_history_summarizer as summarizer_module
|
||||
from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory
|
||||
from src.person_info import person_info as person_info_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services.memory_service import MemorySearchResult, memory_service
|
||||
|
||||
|
||||
# Benchmark dataset consumed by the test and the destination for its JSON report.
DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json"
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json"
|
||||
|
||||
|
||||
def _load_benchmark_fixture() -> Dict[str, Any]:
    """Read and parse the benchmark dataset JSON from DATA_FILE."""
    raw_text = DATA_FILE.read_text(encoding="utf-8")
    return json.loads(raw_text)
|
||||
|
||||
|
||||
class _FakeEmbeddingAdapter:
|
||||
def __init__(self, dimension: int = 32) -> None:
|
||||
self.dimension = dimension
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
return self.dimension
|
||||
|
||||
async def encode(self, texts, dimensions=None):
|
||||
dim = int(dimensions or self.dimension)
|
||||
if isinstance(texts, str):
|
||||
sequence = [texts]
|
||||
single = True
|
||||
else:
|
||||
sequence = list(texts)
|
||||
single = False
|
||||
|
||||
rows = []
|
||||
for text in sequence:
|
||||
vec = np.zeros(dim, dtype=np.float32)
|
||||
for ch in str(text or ""):
|
||||
code = ord(ch)
|
||||
vec[code % dim] += 1.0
|
||||
vec[(code * 7) % dim] += 0.5
|
||||
if not vec.any():
|
||||
vec[0] = 1.0
|
||||
norm = np.linalg.norm(vec)
|
||||
if norm > 0:
|
||||
vec = vec / norm
|
||||
rows.append(vec)
|
||||
payload = np.vstack(rows)
|
||||
return payload[0] if single else payload
|
||||
|
||||
|
||||
class _KnownPerson:
|
||||
def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None:
|
||||
self.person_id = person_id
|
||||
self.is_known = person_id in reverse_registry
|
||||
self.person_name = reverse_registry.get(person_id, "")
|
||||
self._registry = registry
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
    """Adapter exposing an SDKMemoryKernel through the host-service ``invoke`` API.

    ``search_memory`` gets an explicit payload -> KernelSearchRequest mapping;
    every other component name is dispatched to a kernel method of the same name.
    """

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
        """Invoke a kernel component by name with a dict payload.

        ``timeout_ms`` is accepted for interface parity but not enforced here.
        """
        del timeout_ms
        payload = args or {}
        if component_name == "search_memory":
            # Normalize every field defensively: missing / None values collapse
            # to the documented defaults before building the request object.
            return await self.kernel.search_memory(
                KernelSearchRequest(
                    query=str(payload.get("query", "") or ""),
                    limit=int(payload.get("limit", 5) or 5),
                    mode=str(payload.get("mode", "hybrid") or "hybrid"),
                    chat_id=str(payload.get("chat_id", "") or ""),
                    person_id=str(payload.get("person_id", "") or ""),
                    time_start=payload.get("time_start"),
                    time_end=payload.get("time_end"),
                    respect_filter=bool(payload.get("respect_filter", True)),
                    user_id=str(payload.get("user_id", "") or ""),
                    group_id=str(payload.get("group_id", "") or ""),
                )
            )

        # Generic path: call the kernel method directly; await only if the
        # method happened to be a coroutine function.
        handler = getattr(self.kernel, component_name)
        result = handler(**payload)
        return await result if inspect.isawaitable(result) else result
|
||||
|
||||
|
||||
async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]:
    """Poll import_admin until the import task reaches a terminal status.

    Args:
        task_id: id of the import task to watch.
        max_rounds: maximum number of polls before giving up.
        sleep_seconds: delay between polls (clamped to at least 10ms).

    Returns:
        The full task detail payload once the task is terminal.

    Raises:
        AssertionError: if the task does not finish within the polling window.
    """
    for _ in range(max_rounds):
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        task = detail.get("task") or {}
        status = str(task.get("status", "") or "")
        # All terminal states, including partial success and cancellation.
        if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
            return detail
        await asyncio.sleep(max(0.01, float(sleep_seconds)))
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
|
||||
|
||||
|
||||
def _join_hit_content(search_result: MemorySearchResult) -> str:
|
||||
return "\n".join(hit.content for hit in search_result.hits)
|
||||
|
||||
|
||||
def _keyword_hits(text: str, keywords: List[str]) -> int:
|
||||
haystack = str(text or "")
|
||||
return sum(1 for keyword in keywords if keyword in haystack)
|
||||
|
||||
|
||||
def _keyword_recall(text: str, keywords: List[str]) -> float:
    """Fraction of keywords found in *text*; vacuously 1.0 for an empty list."""
    total = len(keywords)
    if total == 0:
        return 1.0
    return _keyword_hits(text, keywords) / float(total)
|
||||
|
||||
|
||||
def _hit_blob(hit) -> str:
|
||||
meta = hit.metadata if isinstance(hit.metadata, dict) else {}
|
||||
return "\n".join(
|
||||
[
|
||||
str(hit.content or ""),
|
||||
str(hit.title or ""),
|
||||
str(hit.source or ""),
|
||||
json.dumps(meta, ensure_ascii=False),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int:
    """Return the 1-based rank of the first relevant hit in the top 5, or 0.

    A hit is relevant when its flattened blob matches at least
    ``minimum_keyword_hits`` keywords (falling back to ``len(keywords)``,
    floored at 1, when the threshold is falsy).
    """
    required = max(1, int(minimum_keyword_hits or len(keywords)))
    for rank, hit in enumerate(search_result.hits[:5], start=1):
        if _keyword_hits(_hit_blob(hit), keywords) >= required:
            return rank
    return 0
|
||||
|
||||
|
||||
def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str:
|
||||
return "\n".join(
|
||||
(
|
||||
f"{item.get('title', '')}\n"
|
||||
f"{item.get('summary', '')}\n"
|
||||
f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n"
|
||||
f"{json.dumps(item.get('participants', []), ensure_ascii=False)}"
|
||||
)
|
||||
for item in items
|
||||
)
|
||||
|
||||
|
||||
def _episode_blob_from_hits(search_result: MemorySearchResult) -> str:
|
||||
chunks = []
|
||||
for hit in search_result.hits:
|
||||
meta = hit.metadata if isinstance(hit.metadata, dict) else {}
|
||||
chunks.append(
|
||||
"\n".join(
|
||||
[
|
||||
str(hit.title or ""),
|
||||
str(hit.content or ""),
|
||||
json.dumps(meta.get("keywords", []) or [], ensure_ascii=False),
|
||||
json.dumps(meta.get("participants", []) or [], ensure_ascii=False),
|
||||
]
|
||||
)
|
||||
)
|
||||
return "\n".join(chunks)
|
||||
|
||||
|
||||
async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Score auto-generated episodes for a session against keyword cases.

    Queries up to 20 episodes for source ``chat_summary:<session_id>``, renders
    them into one blob, and evaluates every case's expected keywords against
    that shared blob. Returns averaged success_rate / keyword_recall plus the
    per-case reports.
    """
    episode_source = f"chat_summary:{session_id}"
    payload = await memory_service.episode_admin(
        action="query",
        source=episode_source,
        limit=20,
    )
    items = payload.get("items") or []
    # One blob covering all episodes; every case is scored against it.
    blob = _episode_blob_from_items(items)
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0

    for case in episode_cases:
        recall = _keyword_recall(blob, case["expected_keywords"])
        # A case succeeds only if episodes exist AND recall meets its threshold.
        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(items),
                "top_episode": items[0] if items else None,
            }
        )

    # Guard against division by zero when no cases are supplied.
    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "episode_count": len(items),
        "reports": reports,
    }
|
||||
|
||||
|
||||
async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Score episode_admin's per-query retrieval against keyword cases.

    Unlike _evaluate_episode_generation, this issues one episode_admin query
    per case (limit 5) and scores only the returned items, so it measures the
    query path rather than the generated corpus as a whole.
    """
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0
    episode_source = f"chat_summary:{session_id}"

    for case in episode_cases:
        payload = await memory_service.episode_admin(
            action="query",
            source=episode_source,
            query=case["query"],
            limit=5,
        )
        items = payload.get("items") or []
        # Per-case blob: title + summary + keywords of the returned episodes only.
        blob = "\n".join(
            f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}"
            for item in items
        )
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(items),
                "top_episode": items[0] if items else None,
            }
        )

    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }
|
||||
|
||||
|
||||
async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Score memory_service.search in ``episode`` mode against keyword cases.

    Each case runs a 5-hit episode-mode search scoped to the session (with
    respect_filter disabled) and is scored on keyword recall over the
    flattened hits.
    """
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0

    for case in episode_cases:
        result = await memory_service.search(
            case["query"],
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
            limit=5,
        )
        blob = _episode_blob_from_hits(result)
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(result.hits),
                "top_episode": result.hits[0].to_dict() if result.hits else None,
            }
        )

    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }
|
||||
|
||||
|
||||
async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]:
    """Exercise the query_long_term_memory tool across its four modes.

    Builds one case per mode (search / time / episode / aggregate), runs the
    tool, and scores the returned text on keyword recall and on the absence
    of the tool's failure markers.
    """
    # NOTE(review): search_case is assigned but never read below — looks like
    # a leftover; confirm before removing.
    search_case = dataset["search_cases"][0]
    episode_case = dataset["episode_cases"][0]
    aggregate_case = dataset["knowledge_fetcher_cases"][0]
    first_record = (dataset.get("chat_history_records") or [{}])[0]
    # Anchor the time-mode case on the first record's timestamp; fall back to
    # a relative expression when the dataset provides none.
    reference_ts = first_record.get("end_time") or first_record.get("start_time") or 0
    if reference_ts:
        time_expression = datetime.fromtimestamp(float(reference_ts)).strftime("%Y/%m/%d")
    else:
        time_expression = "最近7天"
    tool_cases = [
        {
            "name": "search",
            "kwargs": {
                "query": "蓝漆铁盒 北塔木梯",
                "mode": "search",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "time",
            "kwargs": {
                "query": "蓝漆铁盒 北塔",
                "mode": "time",
                "chat_id": session_id,
                "limit": 5,
                "time_expression": time_expression,
            },
            "expected_keywords": ["蓝漆铁盒", "北塔木梯"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "episode",
            "kwargs": {
                "query": episode_case["query"],
                "mode": "episode",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": episode_case["expected_keywords"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "aggregate",
            "kwargs": {
                "query": aggregate_case["query"],
                "mode": "aggregate",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": aggregate_case["expected_keywords"],
            "minimum_keyword_recall": 0.67,
        },
    ]
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0

    for case in tool_cases:
        text = await query_long_term_memory(**case["kwargs"])
        recall = _keyword_recall(text, case["expected_keywords"])
        # Success requires both keyword recall AND none of the tool's
        # error phrases ("失败"/"无法解析"/"未找到") in the output.
        success = (
            "失败" not in text
            and "无法解析" not in text
            and "未找到" not in text
            and recall >= float(case["minimum_keyword_recall"])
        )
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "name": case["name"],
                "success": success,
                "keyword_recall": recall,
                "preview": text[:320],
            }
        )

    total = max(1, len(tool_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def benchmark_env(monkeypatch, tmp_path):
    """Build a fully stubbed benchmark environment around a real kernel.

    Loads the dataset, fakes the chat manager / person registry / embedding
    adapter, boots an SDKMemoryKernel against a tmp_path data dir, wires it
    into memory_service as the host service, and yields the shared context.
    The kernel is always shut down on teardown.
    """
    dataset = _load_benchmark_fixture()
    session_cfg = dataset["session"]
    # The single chat session the whole benchmark runs inside.
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    # Resolves only the benchmark session; anything else yields None / echo.
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )

    # name -> id and id -> name maps derived from the writeback payloads.
    registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
    reverse_registry = {value: key for key, value in registry.items()}

    # Replace the external embedding service with the deterministic fake.
    monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())

    async def fake_self_check(**kwargs):
        return {"ok": True, "message": "ok", "encoded_dimension": 32}

    monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
    # Every module that looks up chat sessions sees the same fake manager.
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
    monkeypatch.setattr(
        person_info_module,
        "Person",
        lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
    )

    # Kernel persists into a per-test tmp directory; auto-save disabled so
    # teardown is cheap and deterministic.
    data_dir = (tmp_path / "a_memorix_benchmark_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    # memory_service now routes all host-service calls through this kernel.
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)

    await kernel.initialize()
    try:
        yield {
            "dataset": dataset,
            "kernel": kernel,
            "session": session,
            "person_registry": registry,
        }
    finally:
        await kernel.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_long_novel_memory_benchmark(benchmark_env):
    """End-to-end memory benchmark over a long-novel dataset.

    Pipeline: import the corpus, replay chat-history summaries and person
    writebacks, process pending episodes, then measure search quality,
    writeback recall, knowledge fetching, person profiles, episode
    generation/query (before and after a rebuild), and the tool modes.
    Writes a full JSON report to REPORT_FILE and asserts quality floors.
    """
    dataset = benchmark_env["dataset"]
    session_id = benchmark_env["session"].session_id

    # --- Stage 1: bulk-import the novel corpus (no LLM involvement). ---
    created = await memory_service.import_admin(
        action="create_paste",
        name="long_novel_memory_benchmark.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(dataset["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True

    import_detail = await _wait_for_import_task(created["task"]["task_id"])
    assert import_detail["task"]["status"] == "completed"

    # --- Stage 2: replay summarized chat-history records into long-term memory. ---
    for record in dataset["chat_history_records"]:
        summarizer = summarizer_module.ChatHistorySummarizer(session_id)
        await summarizer._import_to_long_term_memory(
            record_id=record["record_id"],
            theme=record["theme"],
            summary=record["summary"],
            participants=record["participants"],
            start_time=record["start_time"],
            end_time=record["end_time"],
            original_text=record["original_text"],
        )

    # --- Stage 3: write back person facts. ---
    for payload in dataset["person_writebacks"]:
        await person_info_module.store_person_memory_from_answer(
            payload["person_name"],
            payload["memory_content"],
            session_id,
        )

    # Flush any queued episode work before measuring.
    await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)

    # --- Metric 1: plain search quality (acc@1, recall@5, precision@5, MRR). ---
    search_case_reports: List[Dict[str, Any]] = []
    search_accuracy_at_1 = 0.0
    search_recall_at_5 = 0.0
    search_precision_at_5 = 0.0
    search_mrr = 0.0
    search_keyword_recall = 0.0

    for case in dataset["search_cases"]:
        result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
        joined = _join_hit_content(result)
        rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
        relevant_hits = sum(
            1
            for hit in result.hits[:5]
            if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
        )
        keyword_recall = _keyword_recall(joined, case["expected_keywords"])
        search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
        search_recall_at_5 += 1.0 if rank > 0 else 0.0
        # Precision denominator is capped at the number of hits actually returned.
        search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
        search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
        search_keyword_recall += keyword_recall
        search_case_reports.append(
            {
                "query": case["query"],
                "rank_of_first_relevant": rank,
                "relevant_hits_top5": relevant_hits,
                "keyword_recall_top5": keyword_recall,
                "top_hit": result.hits[0].to_dict() if result.hits else None,
            }
        )

    search_total = max(1, len(dataset["search_cases"]))

    # --- Metric 2: person-writeback retrieval (person-scoped search). ---
    writeback_reports: List[Dict[str, Any]] = []
    writeback_success_rate = 0.0
    writeback_keyword_recall = 0.0
    for payload in dataset["person_writebacks"]:
        query = " ".join(payload["expected_keywords"])
        result = await memory_service.search(
            query,
            mode="search",
            chat_id=session_id,
            person_id=payload["person_id"],
            respect_filter=False,
            limit=5,
        )
        joined = _join_hit_content(result)
        recall = _keyword_recall(joined, payload["expected_keywords"])
        success = bool(result.hits) and recall >= 0.67
        writeback_success_rate += 1.0 if success else 0.0
        writeback_keyword_recall += recall
        writeback_reports.append(
            {
                "person_id": payload["person_id"],
                "success": success,
                "keyword_recall": recall,
                "hit_count": len(result.hits),
            }
        )
    writeback_total = max(1, len(dataset["person_writebacks"]))

    # --- Metric 3: KnowledgeFetcher (PFC brain-chat path). ---
    knowledge_reports: List[Dict[str, Any]] = []
    knowledge_success_rate = 0.0
    knowledge_keyword_recall = 0.0
    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=dataset["session"]["display_name"],
        stream_id=session_id,
    )
    for case in dataset["knowledge_fetcher_cases"]:
        knowledge_text, _ = await fetcher.fetch(case["query"], [])
        recall = _keyword_recall(knowledge_text, case["expected_keywords"])
        success = recall >= float(case.get("minimum_keyword_recall", 1.0))
        knowledge_success_rate += 1.0 if success else 0.0
        knowledge_keyword_recall += recall
        knowledge_reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "preview": knowledge_text[:300],
            }
        )
    knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))

    # --- Metric 4: person profiles (summary recall + evidence presence). ---
    profile_reports: List[Dict[str, Any]] = []
    profile_success_rate = 0.0
    profile_keyword_recall = 0.0
    profile_evidence_rate = 0.0
    for case in dataset["profile_cases"]:
        profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
        recall = _keyword_recall(profile.summary, case["expected_keywords"])
        has_evidence = bool(profile.evidence)
        success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
        profile_success_rate += 1.0 if success else 0.0
        profile_keyword_recall += recall
        profile_evidence_rate += 1.0 if has_evidence else 0.0
        profile_reports.append(
            {
                "person_id": case["person_id"],
                "success": success,
                "keyword_recall": recall,
                "evidence_count": len(profile.evidence),
                "summary_preview": profile.summary[:240],
            }
        )
    profile_total = max(1, len(dataset["profile_cases"]))

    # --- Metric 5: episode pipelines, before and after an explicit rebuild. ---
    episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_rebuild = await memory_service.episode_admin(
        action="rebuild",
        source=f"chat_summary:{session_id}",
    )
    episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)

    # --- Assemble the full benchmark report. ---
    report = {
        "dataset": dataset["meta"],
        "import": {
            "task_id": created["task"]["task_id"],
            "status": import_detail["task"]["status"],
            "paragraph_count": len(dataset["import_payload"]["paragraphs"]),
        },
        "metrics": {
            "search": {
                "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
                "recall_at_5": round(search_recall_at_5 / search_total, 4),
                "precision_at_5": round(search_precision_at_5 / search_total, 4),
                "mrr": round(search_mrr / search_total, 4),
                "keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
            },
            "writeback": {
                "success_rate": round(writeback_success_rate / writeback_total, 4),
                "keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
            },
            "knowledge_fetcher": {
                "success_rate": round(knowledge_success_rate / knowledge_total, 4),
                "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
            },
            "profile": {
                "success_rate": round(profile_success_rate / profile_total, 4),
                "keyword_recall": round(profile_keyword_recall / profile_total, 4),
                "evidence_rate": round(profile_evidence_rate / profile_total, 4),
            },
            "tool_modes": {
                "success_rate": tool_modes["success_rate"],
                "keyword_recall": tool_modes["keyword_recall"],
            },
            "episode_generation_auto": {
                "success_rate": episode_generation_auto["success_rate"],
                "keyword_recall": episode_generation_auto["keyword_recall"],
                "episode_count": episode_generation_auto["episode_count"],
            },
            "episode_generation_after_rebuild": {
                "success_rate": episode_generation_after_rebuild["success_rate"],
                "keyword_recall": episode_generation_after_rebuild["keyword_recall"],
                "episode_count": episode_generation_after_rebuild["episode_count"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_admin_query_auto": {
                "success_rate": episode_admin_query_auto["success_rate"],
                "keyword_recall": episode_admin_query_auto["keyword_recall"],
            },
            "episode_admin_query_after_rebuild": {
                "success_rate": episode_admin_query_after_rebuild["success_rate"],
                "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_search_mode_auto": {
                "success_rate": episode_search_mode_auto["success_rate"],
                "keyword_recall": episode_search_mode_auto["keyword_recall"],
            },
            "episode_search_mode_after_rebuild": {
                "success_rate": episode_search_mode_after_rebuild["success_rate"],
                "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
        },
        "cases": {
            "search": search_case_reports,
            "writeback": writeback_reports,
            "knowledge_fetcher": knowledge_reports,
            "profile": profile_reports,
            "tool_modes": tool_modes["reports"],
            "episode_generation_auto": episode_generation_auto["reports"],
            "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
            "episode_admin_query_auto": episode_admin_query_auto["reports"],
            "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
            "episode_search_mode_auto": episode_search_mode_auto["reports"],
            "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
        },
    }

    # Persist the report and echo the metrics for CI log inspection.
    REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
    REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))

    # --- Quality floors; tuned thresholds for the deterministic fake embedder. ---
    assert report["import"]["status"] == "completed"
    assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35
    assert report["metrics"]["search"]["recall_at_5"] >= 0.6
    assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8
    assert report["metrics"]["writeback"]["success_rate"] >= 0.66
    assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66
    assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75
    assert report["metrics"]["profile"]["success_rate"] >= 0.66
    assert report["metrics"]["profile"]["evidence_rate"] >= 1.0
    assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75
    assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True
    assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"]
|
||||
@@ -1,342 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
from pytests.A_memorix_test.test_long_novel_memory_benchmark import (
|
||||
_evaluate_episode_admin_query,
|
||||
_evaluate_episode_generation,
|
||||
_evaluate_episode_search_mode,
|
||||
_evaluate_tool_modes,
|
||||
_KernelBackedRuntimeManager,
|
||||
_KnownPerson,
|
||||
_first_relevant_rank,
|
||||
_hit_blob,
|
||||
_join_hit_content,
|
||||
_keyword_hits,
|
||||
_keyword_recall,
|
||||
_load_benchmark_fixture,
|
||||
_wait_for_import_task,
|
||||
)
|
||||
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
|
||||
from src.memory_system import chat_history_summarizer as summarizer_module
|
||||
from src.person_info import person_info as person_info_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
|
||||
reason="需要显式开启真实 external embedding benchmark",
|
||||
)
|
||||
|
||||
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def benchmark_live_env(monkeypatch, tmp_path):
|
||||
dataset = _load_benchmark_fixture()
|
||||
session_cfg = dataset["session"]
|
||||
session = SimpleNamespace(
|
||||
session_id=session_cfg["session_id"],
|
||||
platform=session_cfg["platform"],
|
||||
user_id=session_cfg["user_id"],
|
||||
group_id=session_cfg["group_id"],
|
||||
)
|
||||
fake_chat_manager = SimpleNamespace(
|
||||
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
|
||||
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
|
||||
)
|
||||
|
||||
registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
|
||||
reverse_registry = {value: key for key, value in registry.items()}
|
||||
|
||||
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
|
||||
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
|
||||
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
|
||||
monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
|
||||
monkeypatch.setattr(
|
||||
person_info_module,
|
||||
"Person",
|
||||
lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
|
||||
)
|
||||
|
||||
data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve()
|
||||
kernel = SDKMemoryKernel(
|
||||
plugin_root=tmp_path / "plugin_root",
|
||||
config={
|
||||
"storage": {"data_dir": str(data_dir)},
|
||||
"advanced": {"enable_auto_save": False},
|
||||
"memory": {"base_decay_interval_hours": 24},
|
||||
"person_profile": {"refresh_interval_minutes": 5},
|
||||
},
|
||||
)
|
||||
manager = _KernelBackedRuntimeManager(kernel)
|
||||
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
|
||||
|
||||
await kernel.initialize()
|
||||
try:
|
||||
yield {
|
||||
"dataset": dataset,
|
||||
"kernel": kernel,
|
||||
"session": session,
|
||||
}
|
||||
finally:
|
||||
await kernel.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_novel_memory_benchmark_live(benchmark_live_env):
|
||||
dataset = benchmark_live_env["dataset"]
|
||||
session_id = benchmark_live_env["session"].session_id
|
||||
|
||||
self_check = await memory_service.runtime_admin(action="refresh_self_check")
|
||||
assert self_check["success"] is True
|
||||
assert self_check["report"]["ok"] is True
|
||||
|
||||
created = await memory_service.import_admin(
|
||||
action="create_paste",
|
||||
name="long_novel_memory_benchmark.live.json",
|
||||
input_mode="json",
|
||||
llm_enabled=False,
|
||||
content=json.dumps(dataset["import_payload"], ensure_ascii=False),
|
||||
)
|
||||
assert created["success"] is True
|
||||
|
||||
import_detail = await _wait_for_import_task(
|
||||
created["task"]["task_id"],
|
||||
max_rounds=2400,
|
||||
sleep_seconds=0.25,
|
||||
)
|
||||
assert import_detail["task"]["status"] == "completed"
|
||||
|
||||
for record in dataset["chat_history_records"]:
|
||||
summarizer = summarizer_module.ChatHistorySummarizer(session_id)
|
||||
await summarizer._import_to_long_term_memory(
|
||||
record_id=record["record_id"],
|
||||
theme=record["theme"],
|
||||
summary=record["summary"],
|
||||
participants=record["participants"],
|
||||
start_time=record["start_time"],
|
||||
end_time=record["end_time"],
|
||||
original_text=record["original_text"],
|
||||
)
|
||||
|
||||
for payload in dataset["person_writebacks"]:
|
||||
await person_info_module.store_person_memory_from_answer(
|
||||
payload["person_name"],
|
||||
payload["memory_content"],
|
||||
session_id,
|
||||
)
|
||||
|
||||
await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)
|
||||
|
||||
search_case_reports: List[Dict[str, Any]] = []
|
||||
search_accuracy_at_1 = 0.0
|
||||
search_recall_at_5 = 0.0
|
||||
search_precision_at_5 = 0.0
|
||||
search_mrr = 0.0
|
||||
search_keyword_recall = 0.0
|
||||
for case in dataset["search_cases"]:
|
||||
result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
|
||||
joined = _join_hit_content(result)
|
||||
rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
|
||||
relevant_hits = sum(
|
||||
1
|
||||
for hit in result.hits[:5]
|
||||
if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
|
||||
)
|
||||
keyword_recall = _keyword_recall(joined, case["expected_keywords"])
|
||||
search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
|
||||
search_recall_at_5 += 1.0 if rank > 0 else 0.0
|
||||
search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
|
||||
search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
|
||||
search_keyword_recall += keyword_recall
|
||||
search_case_reports.append(
|
||||
{
|
||||
"query": case["query"],
|
||||
"rank_of_first_relevant": rank,
|
||||
"relevant_hits_top5": relevant_hits,
|
||||
"keyword_recall_top5": keyword_recall,
|
||||
"top_hit": result.hits[0].to_dict() if result.hits else None,
|
||||
}
|
||||
)
|
||||
search_total = max(1, len(dataset["search_cases"]))
|
||||
|
||||
writeback_reports: List[Dict[str, Any]] = []
|
||||
writeback_success_rate = 0.0
|
||||
writeback_keyword_recall = 0.0
|
||||
for payload in dataset["person_writebacks"]:
|
||||
query = " ".join(payload["expected_keywords"])
|
||||
result = await memory_service.search(
|
||||
query,
|
||||
mode="search",
|
||||
chat_id=session_id,
|
||||
person_id=payload["person_id"],
|
||||
respect_filter=False,
|
||||
limit=5,
|
||||
)
|
||||
joined = _join_hit_content(result)
|
||||
recall = _keyword_recall(joined, payload["expected_keywords"])
|
||||
success = bool(result.hits) and recall >= 0.67
|
||||
writeback_success_rate += 1.0 if success else 0.0
|
||||
writeback_keyword_recall += recall
|
||||
writeback_reports.append(
|
||||
{
|
||||
"person_id": payload["person_id"],
|
||||
"success": success,
|
||||
"keyword_recall": recall,
|
||||
"hit_count": len(result.hits),
|
||||
}
|
||||
)
|
||||
writeback_total = max(1, len(dataset["person_writebacks"]))
|
||||
|
||||
knowledge_reports: List[Dict[str, Any]] = []
|
||||
knowledge_success_rate = 0.0
|
||||
knowledge_keyword_recall = 0.0
|
||||
fetcher = knowledge_module.KnowledgeFetcher(
|
||||
private_name=dataset["session"]["display_name"],
|
||||
stream_id=session_id,
|
||||
)
|
||||
for case in dataset["knowledge_fetcher_cases"]:
|
||||
knowledge_text, _ = await fetcher.fetch(case["query"], [])
|
||||
recall = _keyword_recall(knowledge_text, case["expected_keywords"])
|
||||
success = recall >= float(case.get("minimum_keyword_recall", 1.0))
|
||||
knowledge_success_rate += 1.0 if success else 0.0
|
||||
knowledge_keyword_recall += recall
|
||||
knowledge_reports.append(
|
||||
{
|
||||
"query": case["query"],
|
||||
"success": success,
|
||||
"keyword_recall": recall,
|
||||
"preview": knowledge_text[:300],
|
||||
}
|
||||
)
|
||||
knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))
|
||||
|
||||
profile_reports: List[Dict[str, Any]] = []
|
||||
profile_success_rate = 0.0
|
||||
profile_keyword_recall = 0.0
|
||||
profile_evidence_rate = 0.0
|
||||
for case in dataset["profile_cases"]:
|
||||
profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
|
||||
recall = _keyword_recall(profile.summary, case["expected_keywords"])
|
||||
has_evidence = bool(profile.evidence)
|
||||
success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
|
||||
profile_success_rate += 1.0 if success else 0.0
|
||||
profile_keyword_recall += recall
|
||||
profile_evidence_rate += 1.0 if has_evidence else 0.0
|
||||
profile_reports.append(
|
||||
{
|
||||
"person_id": case["person_id"],
|
||||
"success": success,
|
||||
"keyword_recall": recall,
|
||||
"evidence_count": len(profile.evidence),
|
||||
"summary_preview": profile.summary[:240],
|
||||
}
|
||||
)
|
||||
profile_total = max(1, len(dataset["profile_cases"]))
|
||||
|
||||
episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
|
||||
episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
|
||||
episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
|
||||
episode_rebuild = await memory_service.episode_admin(
|
||||
action="rebuild",
|
||||
source=f"chat_summary:{session_id}",
|
||||
)
|
||||
episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
|
||||
episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
|
||||
episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
|
||||
tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)
|
||||
|
||||
report = {
|
||||
"dataset": dataset["meta"],
|
||||
"runtime_self_check": self_check["report"],
|
||||
"import": {
|
||||
"task_id": created["task"]["task_id"],
|
||||
"status": import_detail["task"]["status"],
|
||||
"paragraph_count": len(dataset["import_payload"]["paragraphs"]),
|
||||
},
|
||||
"metrics": {
|
||||
"search": {
|
||||
"accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
|
||||
"recall_at_5": round(search_recall_at_5 / search_total, 4),
|
||||
"precision_at_5": round(search_precision_at_5 / search_total, 4),
|
||||
"mrr": round(search_mrr / search_total, 4),
|
||||
"keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
|
||||
},
|
||||
"writeback": {
|
||||
"success_rate": round(writeback_success_rate / writeback_total, 4),
|
||||
"keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
|
||||
},
|
||||
"knowledge_fetcher": {
|
||||
"success_rate": round(knowledge_success_rate / knowledge_total, 4),
|
||||
"keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
|
||||
},
|
||||
"profile": {
|
||||
"success_rate": round(profile_success_rate / profile_total, 4),
|
||||
"keyword_recall": round(profile_keyword_recall / profile_total, 4),
|
||||
"evidence_rate": round(profile_evidence_rate / profile_total, 4),
|
||||
},
|
||||
"tool_modes": {
|
||||
"success_rate": tool_modes["success_rate"],
|
||||
"keyword_recall": tool_modes["keyword_recall"],
|
||||
},
|
||||
"episode_generation_auto": {
|
||||
"success_rate": episode_generation_auto["success_rate"],
|
||||
"keyword_recall": episode_generation_auto["keyword_recall"],
|
||||
"episode_count": episode_generation_auto["episode_count"],
|
||||
},
|
||||
"episode_generation_after_rebuild": {
|
||||
"success_rate": episode_generation_after_rebuild["success_rate"],
|
||||
"keyword_recall": episode_generation_after_rebuild["keyword_recall"],
|
||||
"episode_count": episode_generation_after_rebuild["episode_count"],
|
||||
"rebuild_success": bool(episode_rebuild.get("success", False)),
|
||||
},
|
||||
"episode_admin_query_auto": {
|
||||
"success_rate": episode_admin_query_auto["success_rate"],
|
||||
"keyword_recall": episode_admin_query_auto["keyword_recall"],
|
||||
},
|
||||
"episode_admin_query_after_rebuild": {
|
||||
"success_rate": episode_admin_query_after_rebuild["success_rate"],
|
||||
"keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
|
||||
"rebuild_success": bool(episode_rebuild.get("success", False)),
|
||||
},
|
||||
"episode_search_mode_auto": {
|
||||
"success_rate": episode_search_mode_auto["success_rate"],
|
||||
"keyword_recall": episode_search_mode_auto["keyword_recall"],
|
||||
},
|
||||
"episode_search_mode_after_rebuild": {
|
||||
"success_rate": episode_search_mode_after_rebuild["success_rate"],
|
||||
"keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
|
||||
"rebuild_success": bool(episode_rebuild.get("success", False)),
|
||||
},
|
||||
},
|
||||
"cases": {
|
||||
"search": search_case_reports,
|
||||
"writeback": writeback_reports,
|
||||
"knowledge_fetcher": knowledge_reports,
|
||||
"profile": profile_reports,
|
||||
"tool_modes": tool_modes["reports"],
|
||||
"episode_generation_auto": episode_generation_auto["reports"],
|
||||
"episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
|
||||
"episode_admin_query_auto": episode_admin_query_auto["reports"],
|
||||
"episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
|
||||
"episode_search_mode_auto": episode_search_mode_auto["reports"],
|
||||
"episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
|
||||
},
|
||||
}
|
||||
|
||||
REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))
|
||||
|
||||
assert report["import"]["status"] == "completed"
|
||||
assert report["runtime_self_check"]["ok"] is True
|
||||
@@ -5,68 +5,6 @@ import pytest
|
||||
from src.services import memory_flow_service as memory_flow_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch):
|
||||
starts: list[str] = []
|
||||
summarizers: list[object] = []
|
||||
|
||||
class FakeSummarizer:
|
||||
def __init__(self, session_id: str):
|
||||
self.session_id = session_id
|
||||
summarizers.append(self)
|
||||
|
||||
async def start(self):
|
||||
starts.append(self.session_id)
|
||||
|
||||
async def stop(self):
|
||||
starts.append(f"stop:{self.session_id}")
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module,
|
||||
"global_config",
|
||||
SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)
|
||||
|
||||
manager = memory_flow_module.LongTermMemorySessionManager()
|
||||
message = SimpleNamespace(session_id="session-1")
|
||||
|
||||
await manager.on_message(message)
|
||||
await manager.on_message(message)
|
||||
|
||||
assert len(summarizers) == 1
|
||||
assert starts == ["session-1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch):
|
||||
stopped: list[str] = []
|
||||
|
||||
class FakeSummarizer:
|
||||
def __init__(self, session_id: str):
|
||||
self.session_id = session_id
|
||||
|
||||
async def start(self):
|
||||
return None
|
||||
|
||||
async def stop(self):
|
||||
stopped.append(self.session_id)
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module,
|
||||
"global_config",
|
||||
SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)
|
||||
|
||||
manager = memory_flow_module.LongTermMemorySessionManager()
|
||||
await manager.on_message(SimpleNamespace(session_id="session-a"))
|
||||
await manager.on_message(SimpleNamespace(session_id="session-b"))
|
||||
await manager.shutdown()
|
||||
|
||||
assert stopped == ["session-a", "session-b"]
|
||||
|
||||
|
||||
def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
|
||||
raw = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]'
|
||||
|
||||
@@ -101,16 +39,9 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
|
||||
async def test_memory_automation_service_auto_starts_and_delegates():
|
||||
events: list[tuple[str, str]] = []
|
||||
|
||||
class FakeSessionManager:
|
||||
async def on_message(self, message):
|
||||
events.append(("incoming", message.session_id))
|
||||
|
||||
async def shutdown(self):
|
||||
events.append(("shutdown", "session"))
|
||||
|
||||
class FakeFactWriteback:
|
||||
async def start(self):
|
||||
events.append(("start", "fact"))
|
||||
@@ -122,17 +53,13 @@ async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
|
||||
events.append(("shutdown", "fact"))
|
||||
|
||||
service = memory_flow_module.MemoryAutomationService()
|
||||
service.session_manager = FakeSessionManager()
|
||||
service.fact_writeback = FakeFactWriteback()
|
||||
|
||||
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
|
||||
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
|
||||
await service.shutdown()
|
||||
|
||||
assert events == [
|
||||
("start", "fact"),
|
||||
("incoming", "session-1"),
|
||||
("sent", "session-1"),
|
||||
("shutdown", "session"),
|
||||
("shutdown", "fact"),
|
||||
]
|
||||
|
||||
@@ -1,324 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
|
||||
from src.memory_system import chat_history_summarizer as summarizer_module
|
||||
from src.person_info import person_info as person_info_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
|
||||
DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"
|
||||
|
||||
|
||||
def _load_dialogue_fixture() -> Dict[str, Any]:
|
||||
return json.loads(DATA_FILE.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
class _FakeEmbeddingAdapter:
|
||||
def __init__(self, dimension: int = 16) -> None:
|
||||
self.dimension = dimension
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
return self.dimension
|
||||
|
||||
async def encode(self, texts, dimensions=None):
|
||||
dim = int(dimensions or self.dimension)
|
||||
if isinstance(texts, str):
|
||||
sequence = [texts]
|
||||
single = True
|
||||
else:
|
||||
sequence = list(texts)
|
||||
single = False
|
||||
|
||||
rows = []
|
||||
for text in sequence:
|
||||
vec = np.zeros(dim, dtype=np.float32)
|
||||
for ch in str(text or ""):
|
||||
vec[ord(ch) % dim] += 1.0
|
||||
if not vec.any():
|
||||
vec[0] = 1.0
|
||||
norm = np.linalg.norm(vec)
|
||||
if norm > 0:
|
||||
vec = vec / norm
|
||||
rows.append(vec)
|
||||
payload = np.vstack(rows)
|
||||
return payload[0] if single else payload
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
|
||||
def __init__(self, kernel: SDKMemoryKernel) -> None:
|
||||
self.kernel = kernel
|
||||
|
||||
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
|
||||
del timeout_ms
|
||||
payload = args or {}
|
||||
if component_name == "search_memory":
|
||||
return await self.kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=str(payload.get("query", "") or ""),
|
||||
limit=int(payload.get("limit", 5) or 5),
|
||||
mode=str(payload.get("mode", "hybrid") or "hybrid"),
|
||||
chat_id=str(payload.get("chat_id", "") or ""),
|
||||
person_id=str(payload.get("person_id", "") or ""),
|
||||
time_start=payload.get("time_start"),
|
||||
time_end=payload.get("time_end"),
|
||||
respect_filter=bool(payload.get("respect_filter", True)),
|
||||
user_id=str(payload.get("user_id", "") or ""),
|
||||
group_id=str(payload.get("group_id", "") or ""),
|
||||
)
|
||||
)
|
||||
|
||||
handler = getattr(self.kernel, component_name)
|
||||
result = handler(**payload)
|
||||
return await result if inspect.isawaitable(result) else result
|
||||
|
||||
|
||||
async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]:
|
||||
for _ in range(max_rounds):
|
||||
detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
|
||||
task = detail.get("task") or {}
|
||||
status = str(task.get("status", "") or "")
|
||||
if status in {"completed", "completed_with_errors", "failed", "cancelled"}:
|
||||
return detail
|
||||
await asyncio.sleep(0.05)
|
||||
raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
|
||||
|
||||
|
||||
def _join_hit_content(search_result) -> str:
|
||||
return "\n".join(hit.content for hit in search_result.hits)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def real_dialogue_env(monkeypatch, tmp_path):
|
||||
scenario = _load_dialogue_fixture()
|
||||
session_cfg = scenario["session"]
|
||||
session = SimpleNamespace(
|
||||
session_id=session_cfg["session_id"],
|
||||
platform=session_cfg["platform"],
|
||||
user_id=session_cfg["user_id"],
|
||||
group_id=session_cfg["group_id"],
|
||||
)
|
||||
fake_chat_manager = SimpleNamespace(
|
||||
get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
|
||||
get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())
|
||||
|
||||
async def fake_self_check(**kwargs):
|
||||
return {"ok": True, "message": "ok"}
|
||||
|
||||
monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
|
||||
monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
|
||||
monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
|
||||
monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
|
||||
|
||||
data_dir = (tmp_path / "a_memorix_data").resolve()
|
||||
kernel = SDKMemoryKernel(
|
||||
plugin_root=tmp_path / "plugin_root",
|
||||
config={
|
||||
"storage": {"data_dir": str(data_dir)},
|
||||
"advanced": {"enable_auto_save": False},
|
||||
"memory": {"base_decay_interval_hours": 24},
|
||||
"person_profile": {"refresh_interval_minutes": 5},
|
||||
},
|
||||
)
|
||||
manager = _KernelBackedRuntimeManager(kernel)
|
||||
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
|
||||
|
||||
await kernel.initialize()
|
||||
try:
|
||||
yield {
|
||||
"scenario": scenario,
|
||||
"kernel": kernel,
|
||||
"session": session,
|
||||
}
|
||||
finally:
|
||||
await kernel.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env):
|
||||
scenario = real_dialogue_env["scenario"]
|
||||
|
||||
created = await memory_service.import_admin(
|
||||
action="create_paste",
|
||||
name="private_alice.json",
|
||||
input_mode="json",
|
||||
llm_enabled=False,
|
||||
content=json.dumps(scenario["import_payload"], ensure_ascii=False),
|
||||
)
|
||||
|
||||
assert created["success"] is True
|
||||
detail = await _wait_for_import_task(created["task"]["task_id"])
|
||||
assert detail["task"]["status"] == "completed"
|
||||
|
||||
search = await memory_service.search(
|
||||
scenario["search_queries"]["direct"],
|
||||
mode="search",
|
||||
respect_filter=False,
|
||||
)
|
||||
|
||||
assert search.hits
|
||||
joined = _join_hit_content(search)
|
||||
for keyword in scenario["expectations"]["search_keywords"]:
|
||||
assert keyword in joined
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env):
|
||||
scenario = real_dialogue_env["scenario"]
|
||||
record = scenario["chat_history_record"]
|
||||
|
||||
summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
|
||||
await summarizer._import_to_long_term_memory(
|
||||
record_id=record["record_id"],
|
||||
theme=record["theme"],
|
||||
summary=record["summary"],
|
||||
participants=record["participants"],
|
||||
start_time=record["start_time"],
|
||||
end_time=record["end_time"],
|
||||
original_text=record["original_text"],
|
||||
)
|
||||
|
||||
search = await memory_service.search(
|
||||
scenario["search_queries"]["direct"],
|
||||
mode="search",
|
||||
chat_id=real_dialogue_env["session"].session_id,
|
||||
)
|
||||
|
||||
assert search.hits
|
||||
joined = _join_hit_content(search)
|
||||
for keyword in scenario["expectations"]["search_keywords"]:
|
||||
assert keyword in joined
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch):
|
||||
scenario = real_dialogue_env["scenario"]
|
||||
|
||||
class _KnownPerson:
|
||||
def __init__(self, person_id: str) -> None:
|
||||
self.person_id = person_id
|
||||
self.is_known = True
|
||||
self.person_name = scenario["person"]["person_name"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
person_info_module,
|
||||
"get_person_id_by_person_name",
|
||||
lambda person_name: scenario["person"]["person_id"],
|
||||
)
|
||||
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
|
||||
|
||||
await person_info_module.store_person_memory_from_answer(
|
||||
scenario["person"]["person_name"],
|
||||
scenario["person_fact"]["memory_content"],
|
||||
real_dialogue_env["session"].session_id,
|
||||
)
|
||||
|
||||
search = await memory_service.search(
|
||||
scenario["search_queries"]["direct"],
|
||||
mode="search",
|
||||
chat_id=real_dialogue_env["session"].session_id,
|
||||
person_id=scenario["person"]["person_id"],
|
||||
)
|
||||
|
||||
assert search.hits
|
||||
joined = _join_hit_content(search)
|
||||
for keyword in scenario["expectations"]["search_keywords"]:
|
||||
assert keyword in joined
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env):
|
||||
scenario = real_dialogue_env["scenario"]
|
||||
|
||||
await memory_service.ingest_text(
|
||||
external_id="fixture:knowledge_fetcher",
|
||||
source_type="dialogue_note",
|
||||
text=scenario["person_fact"]["memory_content"],
|
||||
chat_id=real_dialogue_env["session"].session_id,
|
||||
person_ids=[scenario["person"]["person_id"]],
|
||||
participants=[scenario["person"]["person_name"]],
|
||||
respect_filter=False,
|
||||
)
|
||||
|
||||
fetcher = knowledge_module.KnowledgeFetcher(
|
||||
private_name=scenario["session"]["display_name"],
|
||||
stream_id=real_dialogue_env["session"].session_id,
|
||||
)
|
||||
knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])
|
||||
|
||||
for keyword in scenario["expectations"]["search_keywords"]:
|
||||
assert keyword in knowledge_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch):
|
||||
scenario = real_dialogue_env["scenario"]
|
||||
|
||||
class _KnownPerson:
|
||||
def __init__(self, person_id: str) -> None:
|
||||
self.person_id = person_id
|
||||
self.is_known = True
|
||||
self.person_name = scenario["person"]["person_name"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
person_info_module,
|
||||
"get_person_id_by_person_name",
|
||||
lambda person_name: scenario["person"]["person_id"],
|
||||
)
|
||||
monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
|
||||
|
||||
await person_info_module.store_person_memory_from_answer(
|
||||
scenario["person"]["person_name"],
|
||||
scenario["person_fact"]["memory_content"],
|
||||
real_dialogue_env["session"].session_id,
|
||||
)
|
||||
|
||||
profile = await memory_service.get_person_profile(
|
||||
scenario["person"]["person_id"],
|
||||
chat_id=real_dialogue_env["session"].session_id,
|
||||
)
|
||||
|
||||
assert profile.evidence
|
||||
assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env):
|
||||
scenario = real_dialogue_env["scenario"]
|
||||
record = scenario["chat_history_record"]
|
||||
|
||||
summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
|
||||
await summarizer._import_to_long_term_memory(
|
||||
record_id=record["record_id"],
|
||||
theme=record["theme"],
|
||||
summary=record["summary"],
|
||||
participants=record["participants"],
|
||||
start_time=record["start_time"],
|
||||
end_time=record["end_time"],
|
||||
original_text=record["original_text"],
|
||||
)
|
||||
|
||||
episodes = await memory_service.episode_admin(
|
||||
action="query",
|
||||
source=scenario["expectations"]["episode_source"],
|
||||
limit=5,
|
||||
)
|
||||
|
||||
assert episodes["success"] is True
|
||||
assert int(episodes["count"]) >= 1
|
||||
@@ -1,301 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
|
||||
from src.memory_system import chat_history_summarizer as summarizer_module
|
||||
from src.person_info import person_info as person_info_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
|
||||
# Live integration tests require a real embedding backend; they must be
# opted into explicitly via an environment variable.
pytestmark = pytest.mark.skipif(
    os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
    reason="需要显式开启真实 embedding / self-check 集成测试",
)

# Recorded private conversation used by every live test below.
DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"


def _load_dialogue_fixture() -> Dict[str, Any]:
    """Read and parse the recorded dialogue fixture from disk."""
    raw = DATA_FILE.read_text(encoding="utf-8")
    return json.loads(raw)
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
|
||||
def __init__(self, kernel: SDKMemoryKernel) -> None:
|
||||
self.kernel = kernel
|
||||
|
||||
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
|
||||
del timeout_ms
|
||||
payload = args or {}
|
||||
if component_name == "search_memory":
|
||||
return await self.kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=str(payload.get("query", "") or ""),
|
||||
limit=int(payload.get("limit", 5) or 5),
|
||||
mode=str(payload.get("mode", "hybrid") or "hybrid"),
|
||||
chat_id=str(payload.get("chat_id", "") or ""),
|
||||
person_id=str(payload.get("person_id", "") or ""),
|
||||
time_start=payload.get("time_start"),
|
||||
time_end=payload.get("time_end"),
|
||||
respect_filter=bool(payload.get("respect_filter", True)),
|
||||
user_id=str(payload.get("user_id", "") or ""),
|
||||
group_id=str(payload.get("group_id", "") or ""),
|
||||
)
|
||||
)
|
||||
|
||||
handler = getattr(self.kernel, component_name)
|
||||
result = handler(**payload)
|
||||
return await result if inspect.isawaitable(result) else result
|
||||
|
||||
|
||||
async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]:
    """Poll an import task until it reaches a terminal status or the deadline passes.

    Raises AssertionError if the task is still running when the window closes.
    """
    terminal_statuses = {"completed", "completed_with_errors", "failed", "cancelled"}
    deadline = asyncio.get_running_loop().time() + max(1.0, float(timeout_seconds))
    while asyncio.get_running_loop().time() < deadline:
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        task = detail.get("task") or {}
        if str(task.get("status", "") or "") in terminal_statuses:
            return detail
        await asyncio.sleep(0.2)
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
|
||||
|
||||
|
||||
def _join_hit_content(search_result) -> str:
|
||||
return "\n".join(hit.content for hit in search_result.hits)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def live_dialogue_env(monkeypatch, tmp_path):
    """Boot a real SDK memory kernel and patch chat-manager lookups for one session.

    Yields a dict with the loaded scenario, the live kernel, and the fake session.
    The kernel is always shut down on teardown.
    """
    scenario = _load_dialogue_fixture()
    session_cfg = scenario["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )

    def _lookup_session(session_id):
        return session if session_id == session.session_id else None

    def _lookup_name(session_id):
        return session_cfg["display_name"] if session_id == session.session_id else session_id

    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=_lookup_session,
        get_session_name=_lookup_name,
    )
    # Every module that resolves sessions must see the same fake manager.
    for module in (summarizer_module, knowledge_module, person_info_module):
        monkeypatch.setattr(module, "_chat_manager", fake_chat_manager)

    data_dir = (tmp_path / "a_memorix_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)

    await kernel.initialize()
    try:
        yield {
            "scenario": scenario,
            "kernel": kernel,
            "session": session,
        }
    finally:
        await kernel.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_live_runtime_self_check_passes(live_dialogue_env):
    """The runtime self-check must pass against the live kernel."""
    report = await memory_service.runtime_admin(action="refresh_self_check")

    assert report["success"] is True
    inner = report["report"]
    assert inner["ok"] is True
    assert inner["encoded_dimension"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env):
    """Importing the fixture payload should make its content retrievable via search."""
    scenario = live_dialogue_env["scenario"]

    created = await memory_service.import_admin(
        action="create_paste",
        name="private_alice.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(scenario["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True

    detail = await _wait_for_import_task(created["task"]["task_id"])
    assert detail["task"]["status"] == "completed"

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        respect_filter=False,
    )
    assert search.hits

    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env):
    """A summarized chat record must land in long-term memory and be searchable."""
    env = live_dialogue_env
    scenario = env["scenario"]
    record = scenario["chat_history_record"]

    summarizer = summarizer_module.ChatHistorySummarizer(env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=env["session"].session_id,
    )
    assert search.hits

    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch):
    """Facts written back for a known person should be retrievable per-person."""
    scenario = live_dialogue_env["scenario"]
    person_cfg = scenario["person"]

    class _KnownPerson:
        """Person stub that is always recognized under the scenario's name."""

        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = person_cfg["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: person_cfg["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)

    await person_info_module.store_person_memory_from_answer(
        person_cfg["person_name"],
        scenario["person_fact"]["memory_content"],
        live_dialogue_env["session"].session_id,
    )

    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=live_dialogue_env["session"].session_id,
        person_id=person_cfg["person_id"],
    )
    assert search.hits

    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env):
    """The private-chat knowledge fetcher should surface ingested long-term memory."""
    scenario = live_dialogue_env["scenario"]
    session_id = live_dialogue_env["session"].session_id

    await memory_service.ingest_text(
        external_id="fixture:knowledge_fetcher",
        source_type="dialogue_note",
        text=scenario["person_fact"]["memory_content"],
        chat_id=session_id,
        person_ids=[scenario["person"]["person_id"]],
        participants=[scenario["person"]["person_name"]],
        respect_filter=False,
    )

    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=scenario["session"]["display_name"],
        stream_id=session_id,
    )
    knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])

    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in knowledge_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch):
    """Stored person facts should surface in the generated person profile."""
    scenario = live_dialogue_env["scenario"]
    person_cfg = scenario["person"]

    class _KnownPerson:
        """Person stub that is always recognized under the scenario's name."""

        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = person_cfg["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: person_cfg["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)

    await person_info_module.store_person_memory_from_answer(
        person_cfg["person_name"],
        scenario["person_fact"]["memory_content"],
        live_dialogue_env["session"].session_id,
    )

    profile = await memory_service.get_person_profile(
        person_cfg["person_id"],
        chat_id=live_dialogue_env["session"].session_id,
    )

    assert profile.evidence
    assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env):
    """Summarizing a chat record should produce at least one queryable episode."""
    env = live_dialogue_env
    record = env["scenario"]["chat_history_record"]

    summarizer = summarizer_module.ChatHistorySummarizer(env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )

    episodes = await memory_service.episode_admin(
        action="query",
        source=env["scenario"]["expectations"]["episode_source"],
        limit=5,
    )

    assert episodes["success"] is True
    assert int(episodes["count"]) >= 1
|
||||
@@ -1,23 +0,0 @@
|
||||
from src.maisaka.chat_loop_service import MaisakaChatLoopService
|
||||
|
||||
|
||||
def test_build_tool_names_log_text_supports_openai_function_schema() -> None:
    """Tool-name log text must handle both OpenAI function schemas and flat dicts."""
    openai_style = {
        "type": "function",
        "function": {
            "name": "mute_user",
            "description": "禁言指定用户",
            "parameters": {
                "type": "object",
                "properties": {},
            },
        },
    }
    flat_style = {
        "name": "reply",
        "description": "发送回复",
    }

    text = MaisakaChatLoopService._build_tool_names_log_text([openai_style, flat_style])

    assert text == "mute_user、reply"
|
||||
@@ -1,339 +0,0 @@
|
||||
"""MutePlugin SDK 回归测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
|
||||
from maibot_sdk.context import PluginContext
|
||||
from maibot_sdk.plugin import MaiBotPlugin
|
||||
|
||||
from plugins.MutePlugin.plugin import create_plugin
|
||||
from src.core.tooling import ToolExecutionContext, ToolInvocation
|
||||
from src.plugin_runtime.component_query import ComponentQueryService
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
|
||||
def _build_plugin() -> MaiBotPlugin:
    """Create a plugin instance pre-loaded with its default configuration."""
    plugin = create_plugin()
    defaults = plugin.get_default_config()
    plugin.set_plugin_config(defaults)
    return plugin
|
||||
|
||||
|
||||
def test_mute_plugin_manifest_is_valid_v2() -> None:
    """MutePlugin's manifest must satisfy the current runtime requirements."""
    validator = ManifestValidator(host_version="1.0.0", sdk_version="2.3.0")
    manifest = validator.load_from_plugin_path(Path("plugins/MutePlugin"))

    assert manifest is not None
    assert manifest.id == "sengokucola.mute-plugin"
    assert manifest.manifest_version == 2
|
||||
|
||||
|
||||
def test_create_plugin_returns_sdk_plugin() -> None:
    """The plugin entry point must return an SDK plugin instance."""
    plugin = create_plugin()
    assert isinstance(plugin, MaiBotPlugin)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mute_command_calls_napcat_group_ban_api() -> None:
    """The manual mute command should execute through the new NapCat adapter API."""
    plugin = _build_plugin()
    config = dict(plugin.get_default_config())
    config["components"] = {
        "enable_smart_mute": True,
        "enable_mute_command": True,
    }
    plugin.set_plugin_config(config)

    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call":
            if payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
                return {"success": True, "result": {"role": "member"}}
            return {"success": True, "result": {"status": "ok", "retcode": 0}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message, intercept = await plugin.handle_mute_command(
        stream_id="group-10001",
        group_id="10001",
        user_id="42",
        matched_groups={"target": "张三", "duration": "120", "reason": "刷屏"},
    )

    assert success is True
    assert message == "成功禁言 张三"
    assert intercept is True

    api_call = next(
        c
        for c in capability_calls
        if c["capability"] == "api.call" and c["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
    )
    assert api_call["args"]["version"] == "1"
    assert api_call["args"]["args"] == {"group_id": "10001", "user_id": "123456", "duration": 120}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mute_tool_requires_target_person_name() -> None:
    """The mute tool must fail fast and notify the chat when no target is given."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        return {"success": True}

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="",
        duration="60",
        reason="测试",
    )

    assert success is False
    assert message == "禁言目标不能为空"
    last_call = capability_calls[-1]
    assert last_call["capability"] == "send.text"
    assert last_call["args"]["text"] == "没有指定禁言对象哦"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mute_tool_can_unwrap_nested_person_user_id_response() -> None:
    """The mute tool should unwrap capability results that are doubly wrapped."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            # Nested success/result wrapper around the actual payload.
            return {"success": True, "result": {"success": True, "person_id": "person-1"}}
        if capability == "person.get_value":
            return {"success": True, "result": {"success": True, "value": "123456"}}
        if capability == "api.call":
            if payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
                return {"success": True, "result": {"role": "member"}}
            return {"success": True, "result": {"status": "ok"}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )

    assert success is True
    assert message == "成功禁言 张三"

    api_call = next(
        c
        for c in capability_calls
        if c["capability"] == "api.call" and c["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
    )
    assert api_call["args"]["args"]["user_id"] == "123456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mute_tool_rejects_owner_before_group_ban_call() -> None:
    """When the target is the group owner, the tool must bail out before set_group_ban."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            # Report the target as the group owner to trigger the early exit.
            return {"success": True, "result": {"role": "owner"}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )

    assert success is False
    assert message == "张三 是群主,不能被禁言"
    ban_attempts = [
        c
        for c in capability_calls
        if c["capability"] == "api.call" and c["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
    ]
    assert not ban_attempts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mute_tool_maps_cannot_ban_owner_error_message() -> None:
    """A NapCat 'cannot ban owner' failure should map to a clear Chinese message."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call":
            api_name = payload["args"]["api_name"]
            if api_name == "adapter.napcat.group.get_group_member_info":
                return {"success": True, "result": {"role": "member"}}
            if api_name == "adapter.napcat.group.set_group_ban":
                return {"success": False, "error": "NapCat 动作返回失败: action=set_group_ban message=cannot ban owner"}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )

    assert success is False
    assert message == "张三 是群主,不能被禁言"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mute_tool_accepts_nested_ok_api_result() -> None:
    """A nested success/result payload with status=ok must count as success."""
    plugin = _build_plugin()

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None

        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call":
            api_name = payload["args"]["api_name"]
            if api_name == "adapter.napcat.group.get_group_member_info":
                return {"success": True, "result": {"role": "member"}}
            if api_name == "adapter.napcat.group.set_group_ban":
                nested = {
                    "status": "ok",
                    "retcode": 0,
                    "data": None,
                    "message": "",
                    "wording": "",
                }
                return {"success": True, "result": nested}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))

    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )

    assert success is True
    assert message == "成功禁言 张三"
|
||||
|
||||
|
||||
def test_tool_invocation_payload_injects_group_and_user_context() -> None:
    """Tool payload building should auto-fill group/user context from the anchor message."""
    entry = SimpleNamespace(invoke_method="plugin.invoke_tool")
    anchor_message = SimpleNamespace(
        message_info=SimpleNamespace(
            group_info=SimpleNamespace(group_id="10001"),
            user_info=SimpleNamespace(user_id="20002"),
        )
    )
    invocation = ToolInvocation(tool_name="mute", arguments={"target": "张三"}, stream_id="session-1")
    context = ToolExecutionContext(
        session_id="session-1",
        stream_id="session-1",
        reasoning="test",
        metadata={"anchor_message": anchor_message},
    )

    payload = ComponentQueryService._build_tool_invocation_payload(entry, invocation, context)

    expected_subset = {
        "target": "张三",
        "stream_id": "session-1",
        "chat_id": "session-1",
        "group_id": "10001",
        "user_id": "20002",
    }
    for key, value in expected_subset.items():
        assert payload[key] == value
|
||||
@@ -1,227 +0,0 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
|
||||
class DummyLogger:
    """Test double for a logger: captures warnings, swallows every other level."""

    def __init__(self) -> None:
        # Only warnings matter to the tests; other levels are no-ops.
        self.warning_messages: list[str] = []

    def debug(self, _msg: str) -> None:
        pass

    def info(self, _msg: str) -> None:
        pass

    def warning(self, msg: str) -> None:
        self.warning_messages.append(msg)

    def error(self, _msg: str) -> None:
        pass
|
||||
|
||||
|
||||
def load_utils_module(monkeypatch, qq_account=123456, platforms=None):
    """Load src/chat/utils/utils.py in isolation with every dependency stubbed.

    Returns a ``(module, logger)`` pair; the logger records warnings for
    later assertions.  All stubbing goes through ``monkeypatch`` so it is
    undone automatically after each test.
    """
    logger = DummyLogger()
    configured_platforms = platforms or []

    def _stub_module(name: str) -> ModuleType:
        module = ModuleType(name)
        monkeypatch.setitem(sys.modules, name, module)
        return module

    # Register empty parent packages so the dotted stub modules resolve.
    package_names = [
        "src",
        "src.chat",
        "src.chat.message_receive",
        "src.chat.utils",
        "src.common",
        "src.config",
        "src.llm_models",
        "src.person_info",
    ]
    for package_name in package_names:
        if package_name not in sys.modules:
            package_module = ModuleType(package_name)
            package_module.__path__ = []
            monkeypatch.setitem(sys.modules, package_name, package_module)

    # jieba segmentation is replaced by a character splitter.
    jieba_module = ModuleType("jieba")
    jieba_module.cut = lambda text: list(text)
    monkeypatch.setitem(sys.modules, "jieba", jieba_module)

    logger_module = _stub_module("src.common.logger")
    logger_module.get_logger = lambda _name: logger

    config_module = _stub_module("src.config.config")
    config_module.global_config = SimpleNamespace(
        bot=SimpleNamespace(
            qq_account=qq_account,
            platforms=configured_platforms,
            nickname="MaiBot",
            alias_names=[],
        ),
        chat=SimpleNamespace(
            at_bot_inevitable_reply=1,
            mentioned_bot_reply=1,
        ),
    )
    config_module.model_config = SimpleNamespace()

    message_module = _stub_module("src.chat.message_receive.message")

    class SessionMessage:
        pass

    message_module.SessionMessage = SessionMessage

    chat_manager_module = _stub_module("src.chat.message_receive.chat_manager")
    chat_manager_module.chat_manager = SimpleNamespace(get_session_by_session_id=lambda _chat_id: None)

    llm_module = _stub_module("src.llm_models.utils_model")

    class LLMRequest:
        def __init__(self, *args, **kwargs) -> None:
            del args, kwargs

    llm_module.LLMRequest = LLMRequest

    person_module = _stub_module("src.person_info.person_info")

    class Person:
        pass

    person_module.Person = Person

    typo_generator_module = _stub_module("src.chat.utils.typo_generator")

    class ChineseTypoGenerator:
        def __init__(self, *args, **kwargs) -> None:
            del args, kwargs

        def create_typo_sentence(self, sentence: str):
            return sentence, ""

    typo_generator_module.ChineseTypoGenerator = ChineseTypoGenerator

    # With the fake environment in place, load the real utils module from source.
    file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "utils" / "utils.py"
    spec = importlib.util.spec_from_file_location("src.chat.utils.utils", file_path)
    utils_module = importlib.util.module_from_spec(spec)
    utils_module.__package__ = "src.chat.utils"
    monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
    assert spec.loader is not None
    spec.loader.exec_module(utils_module)
    return utils_module, logger
|
||||
|
||||
|
||||
def test_platform_specific_bot_accounts(monkeypatch):
    """Per-platform bot accounts (with messy whitespace/case) must resolve correctly."""
    utils_module, _logger = load_utils_module(
        monkeypatch,
        qq_account=123456,
        platforms=[" TG : tg_bot ", "discord: disc_bot"],
    )

    expected_accounts = {
        "qq": "123456",
        "webui": "123456",
        "telegram": "tg_bot",
        "tg": "tg_bot",
        "discord": "disc_bot",
    }
    for platform, account in expected_accounts.items():
        assert utils_module.get_bot_account(platform) == account

    assert utils_module.is_bot_self("qq", "123456")
    assert utils_module.is_bot_self("webui", "123456")
    assert utils_module.is_bot_self("telegram", "tg_bot")
    assert utils_module.is_bot_self(" TG ", "tg_bot")
|
||||
|
||||
|
||||
def test_get_all_bot_accounts_includes_runtime_aliases(monkeypatch):
    """get_all_bot_accounts should expose both canonical platform names and aliases."""
    utils_module, _logger = load_utils_module(
        monkeypatch,
        qq_account=123456,
        platforms=["TG:tg_bot", "discord:disc_bot"],
    )

    accounts = utils_module.get_all_bot_accounts()
    assert accounts == {
        "qq": "123456",
        "webui": "123456",
        "telegram": "tg_bot",
        "tg": "tg_bot",
        "discord": "disc_bot",
    }
|
||||
|
||||
|
||||
def test_get_all_bot_accounts_keeps_canonical_qq_identity(monkeypatch):
    """Platform overrides must not replace the canonical qq/webui identity."""
    utils_module, _logger = load_utils_module(
        monkeypatch,
        qq_account=123456,
        platforms=["qq:999999", "webui:888888", "TG:tg_bot"],
    )

    assert utils_module.get_all_bot_accounts()["qq"] == "123456"
    assert utils_module.get_all_bot_accounts()["webui"] == "123456"
|
||||
|
||||
|
||||
def test_unknown_platform_no_longer_falls_back_to_qq(monkeypatch):
    """Unknown platforms are rejected (with a warning) instead of falling back to qq."""
    utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])

    assert utils_module.is_bot_self("unknown_platform", "123456") is False
    assert logger.warning_messages
    assert "unknown_platform" in logger.warning_messages[-1]
|
||||
|
||||
|
||||
def test_unknown_platform_warns_only_once(monkeypatch):
    """Repeated lookups for the same unknown platform emit a single warning."""
    utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])

    assert utils_module.is_bot_self("unknown_platform", "first") is False
    # Same platform with stray whitespace must not warn a second time.
    assert utils_module.is_bot_self(" unknown_platform ", "second") is False
    assert len(logger.warning_messages) == 1
|
||||
|
||||
|
||||
def test_unconfigured_qq_account_disables_qq_and_webui_identity(monkeypatch):
    """With qq_account unset (0), the qq and webui identities are disabled entirely."""
    utils_module, _logger = load_utils_module(monkeypatch, qq_account=0, platforms=["telegram:tg_bot"])

    for platform in ("qq", "webui"):
        assert utils_module.get_bot_account(platform) == ""
        assert utils_module.is_bot_self(platform, "0") is False
|
||||
|
||||
|
||||
def test_is_mentioned_bot_in_message_uses_platform_account(monkeypatch):
    """An @-mention using the platform-specific account should force a reply."""
    utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=["TG:tg_bot"])

    message = SimpleNamespace(
        processed_plain_text="@tg_bot 你好",
        platform="telegram",
        is_mentioned=False,
        message_segment=None,
        message_info=SimpleNamespace(
            additional_config={},
            user_info=SimpleNamespace(user_id="user_1"),
        ),
    )

    is_mentioned, is_at, reply_probability = utils_module.is_mentioned_bot_in_message(message)

    assert is_mentioned is True
    assert is_at is True
    assert reply_probability == 1.0
|
||||
|
||||
|
||||
def test_is_mentioned_bot_in_message_normalizes_qq_platform(monkeypatch):
    """A messy ' QQ ' platform string should normalize and still detect the mention."""
    utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])

    message = SimpleNamespace(
        processed_plain_text="@<MaiBot:123456> 你好",
        platform=" QQ ",
        is_mentioned=False,
        message_segment=None,
        message_info=SimpleNamespace(
            additional_config={},
            user_info=SimpleNamespace(user_id="user_1"),
        ),
    )

    is_mentioned, is_at, reply_probability = utils_module.is_mentioned_bot_in_message(message)

    assert is_mentioned is True
    assert is_at is True
    assert reply_probability == 1.0
|
||||
Reference in New Issue
Block a user