fix: 解决 upstream 合并冲突

在临时整合分支中合并 upstream/r-dev
保留人物事实写回与反馈纠错配置
移除已下线的聊天总结配置并同步测试
This commit is contained in:
A-Dawn
2026-04-16 10:27:25 +08:00
31 changed files with 349 additions and 5128 deletions

View File

@@ -1,148 +0,0 @@
from types import SimpleNamespace
import pytest
from src.memory_system import chat_history_summarizer as summarizer_module
def _build_summarizer() -> summarizer_module.ChatHistorySummarizer:
    """Build a ChatHistorySummarizer test double without invoking __init__.

    Uses ``__new__`` so the real constructor's side effects are skipped, then
    fills in only the attributes the methods under test read.
    """
    instance = summarizer_module.ChatHistorySummarizer.__new__(
        summarizer_module.ChatHistorySummarizer
    )
    instance.session_id = "session-1"
    instance.log_prefix = "[session-1]"
    return instance
@pytest.mark.asyncio
async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch):
    """A non-empty theme/summary is ingested as one structured summary payload."""
    calls = []
    summarizer = _build_summarizer()

    async def fake_ingest_summary(**kwargs):
        # Capture the full payload so the assertions below can inspect it.
        calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    # Session lookup resolves to a private chat (empty group_id).
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
    )
    monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8)))
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)
    await summarizer._import_to_long_term_memory(
        record_id=1,
        theme="旅行计划",
        summary="我们讨论了春游安排",
        participants=["Alice", "Bob"],
        start_time=1.0,
        end_time=2.0,
        original_text="long text",
    )
    assert len(calls) == 1
    payload = calls[0]
    # external_id ties the long-term memory back to the chat-history record.
    assert payload["external_id"] == "chat_history:1"
    assert payload["chat_id"] == "session-1"
    assert payload["participants"] == ["Alice", "Bob"]
    assert payload["respect_filter"] is True
    assert payload["user_id"] == "user-1"
    assert payload["group_id"] == ""
    # Theme and summary are folded into the ingested text body.
    assert "主题:旅行计划" in payload["text"]
    assert "概括:我们讨论了春游安排" in payload["text"]
@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch):
    """Empty theme and summary route the import through the raw-text fallback."""
    summarizer = _build_summarizer()
    fallback_calls = []

    async def fake_fallback(**kwargs):
        fallback_calls.append(kwargs)

    summarizer._fallback_import_to_long_term_memory = fake_fallback
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")),
    )
    await summarizer._import_to_long_term_memory(
        record_id=2,
        theme="",
        summary="",
        participants=[],
        start_time=10.0,
        end_time=20.0,
        original_text="raw chat",
    )
    assert len(fallback_calls) == 1
    assert fallback_calls[0]["record_id"] == 2
    # The fallback must receive the original raw text, not the (empty) summary.
    assert fallback_calls[0]["original_text"] == "raw chat"
@pytest.mark.asyncio
async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch):
    """A failed ingest_summary call triggers the raw-text fallback import."""
    summarizer = _build_summarizer()
    fallback_calls = []

    async def fake_ingest_summary(**kwargs):
        # Simulate the memory service rejecting the summary payload.
        return SimpleNamespace(success=False, detail="boom", stored_ids=[])

    async def fake_fallback(**kwargs):
        fallback_calls.append(kwargs)

    summarizer._fallback_import_to_long_term_memory = fake_fallback
    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")),
    )
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)
    await summarizer._import_to_long_term_memory(
        record_id=3,
        theme="电影",
        summary="聊了电影推荐",
        participants=["Alice"],
        start_time=3.0,
        end_time=4.0,
        original_text="raw",
    )
    assert len(fallback_calls) == 1
    # The fallback is handed the same theme that failed to ingest.
    assert fallback_calls[0]["theme"] == "电影"
@pytest.mark.asyncio
async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch):
    """The fallback import marks the payload for chat-based memory generation."""
    calls = []
    summarizer = _build_summarizer()

    async def fake_ingest_summary(**kwargs):
        calls.append(kwargs)
        # "chat_filtered" with no stored ids is still a successful outcome here.
        return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[])

    monkeypatch.setattr(
        summarizer_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")),
    )
    monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12)))
    monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary)
    await summarizer._fallback_import_to_long_term_memory(
        record_id=4,
        theme="工作",
        participants=["Alice"],
        start_time=5.0,
        end_time=6.0,
        original_text="a" * 128,
    )
    assert len(calls) == 1
    metadata = calls[0]["metadata"]
    assert metadata["generate_from_chat"] is True
    # context_length mirrors the configured message threshold patched above (12).
    assert metadata["context_length"] == 12
    assert calls[0]["respect_filter"] is True

View File

@@ -1,687 +0,0 @@
from __future__ import annotations
import asyncio
import inspect
import json
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
import numpy as np
import pytest
import pytest_asyncio
from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import MemorySearchResult, memory_service
# Benchmark dataset (long-novel corpus) that drives the end-to-end memory test.
DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json"
# Destination for the JSON metrics report written at the end of the benchmark run.
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json"
def _load_benchmark_fixture() -> Dict[str, Any]:
    """Read and parse the benchmark dataset JSON from disk."""
    raw = DATA_FILE.read_text(encoding="utf-8")
    return json.loads(raw)
class _FakeEmbeddingAdapter:
def __init__(self, dimension: int = 32) -> None:
self.dimension = dimension
async def _detect_dimension(self) -> int:
return self.dimension
async def encode(self, texts, dimensions=None):
dim = int(dimensions or self.dimension)
if isinstance(texts, str):
sequence = [texts]
single = True
else:
sequence = list(texts)
single = False
rows = []
for text in sequence:
vec = np.zeros(dim, dtype=np.float32)
for ch in str(text or ""):
code = ord(ch)
vec[code % dim] += 1.0
vec[(code * 7) % dim] += 0.5
if not vec.any():
vec[0] = 1.0
norm = np.linalg.norm(vec)
if norm > 0:
vec = vec / norm
rows.append(vec)
payload = np.vstack(rows)
return payload[0] if single else payload
class _KnownPerson:
def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None:
self.person_id = person_id
self.is_known = person_id in reverse_registry
self.person_name = reverse_registry.get(person_id, "")
self._registry = registry
class _KernelBackedRuntimeManager:
    """Routes runtime ``invoke`` calls straight to an in-process kernel.

    Substitutes for the real A_memorix host service in tests so
    memory_service talks to a local SDKMemoryKernel instead of a plugin
    runtime process.
    """

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
        # Timeouts are meaningless for an in-process kernel; discard them.
        del timeout_ms
        payload = args or {}
        if component_name == "search_memory":
            # search_memory takes a typed request object rather than loose
            # kwargs, so translate the payload dict with defensive coercions
            # (empty strings / zeros fall back to sane defaults).
            return await self.kernel.search_memory(
                KernelSearchRequest(
                    query=str(payload.get("query", "") or ""),
                    limit=int(payload.get("limit", 5) or 5),
                    mode=str(payload.get("mode", "hybrid") or "hybrid"),
                    chat_id=str(payload.get("chat_id", "") or ""),
                    person_id=str(payload.get("person_id", "") or ""),
                    time_start=payload.get("time_start"),
                    time_end=payload.get("time_end"),
                    respect_filter=bool(payload.get("respect_filter", True)),
                    user_id=str(payload.get("user_id", "") or ""),
                    group_id=str(payload.get("group_id", "") or ""),
                )
            )
        # Every other component maps 1:1 onto a kernel method; support both
        # sync and async handlers transparently.
        handler = getattr(self.kernel, component_name)
        result = handler(**payload)
        return await result if inspect.isawaitable(result) else result
async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]:
    """Poll the import-admin API until the task reaches a terminal status.

    Returns the final task detail payload; raises AssertionError when the
    task is still running after ``max_rounds`` polls.
    """
    terminal_statuses = {"completed", "completed_with_errors", "failed", "cancelled"}
    delay = max(0.01, float(sleep_seconds))
    for _ in range(max_rounds):
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        task_info = detail.get("task") or {}
        if str(task_info.get("status", "") or "") in terminal_statuses:
            return detail
        await asyncio.sleep(delay)
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
def _join_hit_content(search_result: MemorySearchResult) -> str:
return "\n".join(hit.content for hit in search_result.hits)
def _keyword_hits(text: str, keywords: List[str]) -> int:
haystack = str(text or "")
return sum(1 for keyword in keywords if keyword in haystack)
def _keyword_recall(text: str, keywords: List[str]) -> float:
if not keywords:
return 1.0
return _keyword_hits(text, keywords) / float(len(keywords))
def _hit_blob(hit) -> str:
meta = hit.metadata if isinstance(hit.metadata, dict) else {}
return "\n".join(
[
str(hit.content or ""),
str(hit.title or ""),
str(hit.source or ""),
json.dumps(meta, ensure_ascii=False),
]
)
def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int:
    """Return the 1-based rank of the first sufficiently relevant top-5 hit.

    A hit counts as relevant when it contains at least ``minimum_keyword_hits``
    keywords (falling back to all keywords when the threshold is falsy).
    Returns 0 when nothing in the top 5 qualifies.
    """
    threshold = max(1, int(minimum_keyword_hits or len(keywords)))
    for position, hit in enumerate(search_result.hits[:5], start=1):
        if _keyword_hits(_hit_blob(hit), keywords) >= threshold:
            return position
    return 0
def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str:
return "\n".join(
(
f"{item.get('title', '')}\n"
f"{item.get('summary', '')}\n"
f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n"
f"{json.dumps(item.get('participants', []), ensure_ascii=False)}"
)
for item in items
)
def _episode_blob_from_hits(search_result: MemorySearchResult) -> str:
chunks = []
for hit in search_result.hits:
meta = hit.metadata if isinstance(hit.metadata, dict) else {}
chunks.append(
"\n".join(
[
str(hit.title or ""),
str(hit.content or ""),
json.dumps(meta.get("keywords", []) or [], ensure_ascii=False),
json.dumps(meta.get("participants", []) or [], ensure_ascii=False),
]
)
)
return "\n".join(chunks)
async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Score auto-generated episodes for a session against expected keywords.

    Fetches every episode sourced from this session's chat summaries, then
    checks per case whether the combined episode text covers enough of the
    expected keywords. Returns aggregate success_rate / keyword_recall,
    the episode count, and per-case reports.
    """
    episode_source = f"chat_summary:{session_id}"
    payload = await memory_service.episode_admin(
        action="query",
        source=episode_source,
        limit=20,
    )
    items = payload.get("items") or []
    blob = _episode_blob_from_items(items)
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0
    for case in episode_cases:
        recall = _keyword_recall(blob, case["expected_keywords"])
        # A case succeeds only if episodes exist AND recall clears the threshold.
        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(items),
                "top_episode": items[0] if items else None,
            }
        )
    # max(1, ...) guards the division when there are no cases at all.
    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "episode_count": len(items),
        "reports": reports,
    }
async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Score per-query episode_admin lookups against expected keywords.

    Unlike _evaluate_episode_generation this issues one admin query per case
    (query text included), so it measures episode retrieval rather than
    episode generation coverage.
    """
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0
    episode_source = f"chat_summary:{session_id}"
    for case in episode_cases:
        payload = await memory_service.episode_admin(
            action="query",
            source=episode_source,
            query=case["query"],
            limit=5,
        )
        items = payload.get("items") or []
        # Title + summary + keywords form the searchable text per episode.
        blob = "\n".join(
            f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}"
            for item in items
        )
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(items),
                "top_episode": items[0] if items else None,
            }
        )
    # Guard against division by zero for an empty case list.
    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }
async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Score memory_service.search(mode="episode") against expected keywords.

    Exercises the public search entry point in episode mode (filter disabled)
    rather than the admin API, returning the same aggregate metrics shape as
    the other episode evaluators.
    """
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0
    for case in episode_cases:
        result = await memory_service.search(
            case["query"],
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
            limit=5,
        )
        blob = _episode_blob_from_hits(result)
        recall = _keyword_recall(blob, case["expected_keywords"])
        success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0))
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "episode_count": len(result.hits),
                "top_episode": result.hits[0].to_dict() if result.hits else None,
            }
        )
    # Guard against division by zero for an empty case list.
    total = max(1, len(episode_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }
async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]:
    """Exercise query_long_term_memory across its search/time/episode/aggregate modes.

    Builds one case per tool mode from the dataset, calls the retrieval tool,
    and scores each response by keyword recall plus the absence of the tool's
    failure phrases. Returns aggregate metrics and per-mode reports.
    """
    search_case = dataset["search_cases"][0]
    episode_case = dataset["episode_cases"][0]
    aggregate_case = dataset["knowledge_fetcher_cases"][0]
    first_record = (dataset.get("chat_history_records") or [{}])[0]
    # Anchor the time-mode query on the first record's timestamp when present;
    # otherwise fall back to a relative time expression.
    reference_ts = first_record.get("end_time") or first_record.get("start_time") or 0
    if reference_ts:
        time_expression = datetime.fromtimestamp(float(reference_ts)).strftime("%Y/%m/%d")
    else:
        time_expression = "最近7天"
    tool_cases = [
        {
            "name": "search",
            "kwargs": {
                "query": "蓝漆铁盒 北塔木梯",
                "mode": "search",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "time",
            "kwargs": {
                "query": "蓝漆铁盒 北塔",
                "mode": "time",
                "chat_id": session_id,
                "limit": 5,
                "time_expression": time_expression,
            },
            "expected_keywords": ["蓝漆铁盒", "北塔木梯"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "episode",
            "kwargs": {
                "query": episode_case["query"],
                "mode": "episode",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": episode_case["expected_keywords"],
            "minimum_keyword_recall": 0.67,
        },
        {
            "name": "aggregate",
            "kwargs": {
                "query": aggregate_case["query"],
                "mode": "aggregate",
                "chat_id": session_id,
                "limit": 5,
            },
            "expected_keywords": aggregate_case["expected_keywords"],
            "minimum_keyword_recall": 0.67,
        },
    ]
    reports: List[Dict[str, Any]] = []
    success_rate = 0.0
    keyword_recall = 0.0
    for case in tool_cases:
        text = await query_long_term_memory(**case["kwargs"])
        recall = _keyword_recall(text, case["expected_keywords"])
        # The tool reports errors in-band as Chinese failure phrases, so a
        # success requires none of them AND enough keyword coverage.
        success = (
            "失败" not in text
            and "无法解析" not in text
            and "未找到" not in text
            and recall >= float(case["minimum_keyword_recall"])
        )
        success_rate += 1.0 if success else 0.0
        keyword_recall += recall
        reports.append(
            {
                "name": case["name"],
                "success": success,
                "keyword_recall": recall,
                "preview": text[:320],
            }
        )
    # Guard against division by zero for an empty case list.
    total = max(1, len(tool_cases))
    return {
        "success_rate": round(success_rate / total, 4),
        "keyword_recall": round(keyword_recall / total, 4),
        "reports": reports,
    }
@pytest_asyncio.fixture
async def benchmark_env(monkeypatch, tmp_path):
    """Build a fully mocked benchmark environment around a local kernel.

    Fakes the chat manager, person registry, and embedding adapter, boots an
    SDKMemoryKernel on a temp data dir, wires it into memory_service as the
    host service, and yields dataset/kernel/session handles. The kernel is
    shut down on teardown.
    """
    dataset = _load_benchmark_fixture()
    session_cfg = dataset["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )
    # name -> id and id -> name lookups backing the fake person registry.
    registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
    reverse_registry = {value: key for key, value in registry.items()}
    # Replace the real embedding adapter with the deterministic fake and make
    # the runtime self-check always pass.
    monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())

    async def fake_self_check(**kwargs):
        return {"ok": True, "message": "ok", "encoded_dimension": 32}

    monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
    monkeypatch.setattr(
        person_info_module,
        "Person",
        lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
    )
    data_dir = (tmp_path / "a_memorix_benchmark_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
    await kernel.initialize()
    try:
        yield {
            "dataset": dataset,
            "kernel": kernel,
            "session": session,
            "person_registry": registry,
        }
    finally:
        # Always release kernel resources, even if the test body fails.
        await kernel.shutdown()
@pytest.mark.asyncio
async def test_long_novel_memory_benchmark(benchmark_env):
    """End-to-end memory benchmark over the long-novel corpus.

    Pipeline: import the corpus, write chat-history summaries and person
    memories, then evaluate search, writeback recall, knowledge fetching,
    person profiles, tool modes, and episode generation/query (before and
    after an episode rebuild). Writes a JSON metrics report and asserts
    minimum quality thresholds.
    """
    dataset = benchmark_env["dataset"]
    session_id = benchmark_env["session"].session_id
    # --- Phase 1: import the benchmark corpus and wait for completion. ---
    created = await memory_service.import_admin(
        action="create_paste",
        name="long_novel_memory_benchmark.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(dataset["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True
    import_detail = await _wait_for_import_task(created["task"]["task_id"])
    assert import_detail["task"]["status"] == "completed"
    # --- Phase 2: write chat-history summaries and person memories. ---
    for record in dataset["chat_history_records"]:
        summarizer = summarizer_module.ChatHistorySummarizer(session_id)
        await summarizer._import_to_long_term_memory(
            record_id=record["record_id"],
            theme=record["theme"],
            summary=record["summary"],
            participants=record["participants"],
            start_time=record["start_time"],
            end_time=record["end_time"],
            original_text=record["original_text"],
        )
    for payload in dataset["person_writebacks"]:
        await person_info_module.store_person_memory_from_answer(
            payload["person_name"],
            payload["memory_content"],
            session_id,
        )
    # Flush any queued episode generation before measuring.
    await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)
    # --- Phase 3: search-quality metrics (acc@1, recall@5, P@5, MRR). ---
    search_case_reports: List[Dict[str, Any]] = []
    search_accuracy_at_1 = 0.0
    search_recall_at_5 = 0.0
    search_precision_at_5 = 0.0
    search_mrr = 0.0
    search_keyword_recall = 0.0
    for case in dataset["search_cases"]:
        result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
        joined = _join_hit_content(result)
        rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
        relevant_hits = sum(
            1
            for hit in result.hits[:5]
            if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
        )
        keyword_recall = _keyword_recall(joined, case["expected_keywords"])
        search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
        search_recall_at_5 += 1.0 if rank > 0 else 0.0
        # Precision is computed over the hits actually returned (capped at 5).
        search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
        search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
        search_keyword_recall += keyword_recall
        search_case_reports.append(
            {
                "query": case["query"],
                "rank_of_first_relevant": rank,
                "relevant_hits_top5": relevant_hits,
                "keyword_recall_top5": keyword_recall,
                "top_hit": result.hits[0].to_dict() if result.hits else None,
            }
        )
    search_total = max(1, len(dataset["search_cases"]))
    # --- Phase 4: person-memory writeback retrieval. ---
    writeback_reports: List[Dict[str, Any]] = []
    writeback_success_rate = 0.0
    writeback_keyword_recall = 0.0
    for payload in dataset["person_writebacks"]:
        query = " ".join(payload["expected_keywords"])
        result = await memory_service.search(
            query,
            mode="search",
            chat_id=session_id,
            person_id=payload["person_id"],
            respect_filter=False,
            limit=5,
        )
        joined = _join_hit_content(result)
        recall = _keyword_recall(joined, payload["expected_keywords"])
        success = bool(result.hits) and recall >= 0.67
        writeback_success_rate += 1.0 if success else 0.0
        writeback_keyword_recall += recall
        writeback_reports.append(
            {
                "person_id": payload["person_id"],
                "success": success,
                "keyword_recall": recall,
                "hit_count": len(result.hits),
            }
        )
    writeback_total = max(1, len(dataset["person_writebacks"]))
    # --- Phase 5: PFC knowledge fetcher quality. ---
    knowledge_reports: List[Dict[str, Any]] = []
    knowledge_success_rate = 0.0
    knowledge_keyword_recall = 0.0
    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=dataset["session"]["display_name"],
        stream_id=session_id,
    )
    for case in dataset["knowledge_fetcher_cases"]:
        knowledge_text, _ = await fetcher.fetch(case["query"], [])
        recall = _keyword_recall(knowledge_text, case["expected_keywords"])
        success = recall >= float(case.get("minimum_keyword_recall", 1.0))
        knowledge_success_rate += 1.0 if success else 0.0
        knowledge_keyword_recall += recall
        knowledge_reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "preview": knowledge_text[:300],
            }
        )
    knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))
    # --- Phase 6: person-profile summary and evidence quality. ---
    profile_reports: List[Dict[str, Any]] = []
    profile_success_rate = 0.0
    profile_keyword_recall = 0.0
    profile_evidence_rate = 0.0
    for case in dataset["profile_cases"]:
        profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
        recall = _keyword_recall(profile.summary, case["expected_keywords"])
        has_evidence = bool(profile.evidence)
        # A profile succeeds only with enough keyword recall AND some evidence.
        success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
        profile_success_rate += 1.0 if success else 0.0
        profile_keyword_recall += recall
        profile_evidence_rate += 1.0 if has_evidence else 0.0
        profile_reports.append(
            {
                "person_id": case["person_id"],
                "success": success,
                "keyword_recall": recall,
                "evidence_count": len(profile.evidence),
                "summary_preview": profile.summary[:240],
            }
        )
    profile_total = max(1, len(dataset["profile_cases"]))
    # --- Phase 7: episode metrics before and after a rebuild. ---
    episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_rebuild = await memory_service.episode_admin(
        action="rebuild",
        source=f"chat_summary:{session_id}",
    )
    episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)
    # --- Phase 8: assemble the JSON report. ---
    report = {
        "dataset": dataset["meta"],
        "import": {
            "task_id": created["task"]["task_id"],
            "status": import_detail["task"]["status"],
            "paragraph_count": len(dataset["import_payload"]["paragraphs"]),
        },
        "metrics": {
            "search": {
                "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
                "recall_at_5": round(search_recall_at_5 / search_total, 4),
                "precision_at_5": round(search_precision_at_5 / search_total, 4),
                "mrr": round(search_mrr / search_total, 4),
                "keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
            },
            "writeback": {
                "success_rate": round(writeback_success_rate / writeback_total, 4),
                "keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
            },
            "knowledge_fetcher": {
                "success_rate": round(knowledge_success_rate / knowledge_total, 4),
                "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
            },
            "profile": {
                "success_rate": round(profile_success_rate / profile_total, 4),
                "keyword_recall": round(profile_keyword_recall / profile_total, 4),
                "evidence_rate": round(profile_evidence_rate / profile_total, 4),
            },
            "tool_modes": {
                "success_rate": tool_modes["success_rate"],
                "keyword_recall": tool_modes["keyword_recall"],
            },
            "episode_generation_auto": {
                "success_rate": episode_generation_auto["success_rate"],
                "keyword_recall": episode_generation_auto["keyword_recall"],
                "episode_count": episode_generation_auto["episode_count"],
            },
            "episode_generation_after_rebuild": {
                "success_rate": episode_generation_after_rebuild["success_rate"],
                "keyword_recall": episode_generation_after_rebuild["keyword_recall"],
                "episode_count": episode_generation_after_rebuild["episode_count"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_admin_query_auto": {
                "success_rate": episode_admin_query_auto["success_rate"],
                "keyword_recall": episode_admin_query_auto["keyword_recall"],
            },
            "episode_admin_query_after_rebuild": {
                "success_rate": episode_admin_query_after_rebuild["success_rate"],
                "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_search_mode_auto": {
                "success_rate": episode_search_mode_auto["success_rate"],
                "keyword_recall": episode_search_mode_auto["keyword_recall"],
            },
            "episode_search_mode_after_rebuild": {
                "success_rate": episode_search_mode_after_rebuild["success_rate"],
                "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
        },
        "cases": {
            "search": search_case_reports,
            "writeback": writeback_reports,
            "knowledge_fetcher": knowledge_reports,
            "profile": profile_reports,
            "tool_modes": tool_modes["reports"],
            "episode_generation_auto": episode_generation_auto["reports"],
            "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
            "episode_admin_query_auto": episode_admin_query_auto["reports"],
            "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
            "episode_search_mode_auto": episode_search_mode_auto["reports"],
            "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
        },
    }
    # --- Phase 9: persist the report and enforce quality floors. ---
    REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
    REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))
    assert report["import"]["status"] == "completed"
    assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35
    assert report["metrics"]["search"]["recall_at_5"] >= 0.6
    assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8
    assert report["metrics"]["writeback"]["success_rate"] >= 0.66
    assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66
    assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75
    assert report["metrics"]["profile"]["success_rate"] >= 0.66
    assert report["metrics"]["profile"]["evidence_rate"] >= 1.0
    assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75
    assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True
    # Rebuilding must not lose episodes relative to the auto-generated set.
    assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"]

View File

@@ -1,342 +0,0 @@
from __future__ import annotations
import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
import pytest_asyncio
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
from pytests.A_memorix_test.test_long_novel_memory_benchmark import (
_evaluate_episode_admin_query,
_evaluate_episode_generation,
_evaluate_episode_search_mode,
_evaluate_tool_modes,
_KernelBackedRuntimeManager,
_KnownPerson,
_first_relevant_rank,
_hit_blob,
_join_hit_content,
_keyword_hits,
_keyword_recall,
_load_benchmark_fixture,
_wait_for_import_task,
)
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
# This live benchmark hits a real external embedding backend; it only runs
# when explicitly enabled via the environment flag below.
pytestmark = pytest.mark.skipif(
    os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
    reason="需要显式开启真实 external embedding benchmark",
)
# The live-run report is kept separate from the mocked benchmark's report.
REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json"
@pytest_asyncio.fixture
async def benchmark_live_env(monkeypatch, tmp_path):
    """Benchmark environment that keeps the real embedding adapter.

    Same wiring as the mocked ``benchmark_env`` fixture (fake chat manager,
    fake person registry, local kernel on a temp data dir) except the
    embedding adapter and runtime self-check are NOT patched, so real
    external embedding calls are exercised.
    """
    dataset = _load_benchmark_fixture()
    session_cfg = dataset["session"]
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )
    # name -> id and id -> name lookups backing the fake person registry.
    registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]}
    reverse_registry = {value: key for key, value in registry.items()}
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), ""))
    monkeypatch.setattr(
        person_info_module,
        "Person",
        lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry),
    )
    data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
    await kernel.initialize()
    try:
        yield {
            "dataset": dataset,
            "kernel": kernel,
            "session": session,
        }
    finally:
        # Always release kernel resources, even if the test body fails.
        await kernel.shutdown()
@pytest.mark.asyncio
async def test_long_novel_memory_benchmark_live(benchmark_live_env):
    """End-to-end live benchmark: import a long-novel dataset, replay summaries and
    person writebacks, then score search / writeback / knowledge / profile / episode
    quality and persist the full report to ``REPORT_FILE``.
    """
    dataset = benchmark_live_env["dataset"]
    session_id = benchmark_live_env["session"].session_id
    # Runtime must self-check cleanly before any benchmark numbers are meaningful.
    self_check = await memory_service.runtime_admin(action="refresh_self_check")
    assert self_check["success"] is True
    assert self_check["report"]["ok"] is True
    # --- Phase 1: bulk import of the benchmark corpus ---
    created = await memory_service.import_admin(
        action="create_paste",
        name="long_novel_memory_benchmark.live.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(dataset["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True
    # NOTE(review): relies on the module-level _wait_for_import_task accepting
    # sleep_seconds — confirm against its definition earlier in this file.
    import_detail = await _wait_for_import_task(
        created["task"]["task_id"],
        max_rounds=2400,
        sleep_seconds=0.25,
    )
    assert import_detail["task"]["status"] == "completed"
    # --- Phase 2: replay chat-history summaries into long-term memory ---
    for record in dataset["chat_history_records"]:
        summarizer = summarizer_module.ChatHistorySummarizer(session_id)
        await summarizer._import_to_long_term_memory(
            record_id=record["record_id"],
            theme=record["theme"],
            summary=record["summary"],
            participants=record["participants"],
            start_time=record["start_time"],
            end_time=record["end_time"],
            original_text=record["original_text"],
        )
    # --- Phase 3: replay person-fact writebacks ---
    for payload in dataset["person_writebacks"]:
        await person_info_module.store_person_memory_from_answer(
            payload["person_name"],
            payload["memory_content"],
            session_id,
        )
    await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2)
    # --- Metric: search quality (accuracy@1, recall@5, precision@5, MRR, keyword recall) ---
    search_case_reports: List[Dict[str, Any]] = []
    search_accuracy_at_1 = 0.0
    search_recall_at_5 = 0.0
    search_precision_at_5 = 0.0
    search_mrr = 0.0
    search_keyword_recall = 0.0
    for case in dataset["search_cases"]:
        result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5)
        joined = _join_hit_content(result)
        rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"])))
        relevant_hits = sum(
            1
            for hit in result.hits[:5]
            if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"]))))
        )
        keyword_recall = _keyword_recall(joined, case["expected_keywords"])
        search_accuracy_at_1 += 1.0 if rank == 1 else 0.0
        search_recall_at_5 += 1.0 if rank > 0 else 0.0
        search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits))))
        search_mrr += 1.0 / float(rank) if rank > 0 else 0.0
        search_keyword_recall += keyword_recall
        search_case_reports.append(
            {
                "query": case["query"],
                "rank_of_first_relevant": rank,
                "relevant_hits_top5": relevant_hits,
                "keyword_recall_top5": keyword_recall,
                "top_hit": result.hits[0].to_dict() if result.hits else None,
            }
        )
    search_total = max(1, len(dataset["search_cases"]))
    # --- Metric: person-fact writeback retrievability ---
    writeback_reports: List[Dict[str, Any]] = []
    writeback_success_rate = 0.0
    writeback_keyword_recall = 0.0
    for payload in dataset["person_writebacks"]:
        query = " ".join(payload["expected_keywords"])
        result = await memory_service.search(
            query,
            mode="search",
            chat_id=session_id,
            person_id=payload["person_id"],
            respect_filter=False,
            limit=5,
        )
        joined = _join_hit_content(result)
        recall = _keyword_recall(joined, payload["expected_keywords"])
        # 0.67 ~= at least 2 of 3 expected keywords found in top-5 hit text.
        success = bool(result.hits) and recall >= 0.67
        writeback_success_rate += 1.0 if success else 0.0
        writeback_keyword_recall += recall
        writeback_reports.append(
            {
                "person_id": payload["person_id"],
                "success": success,
                "keyword_recall": recall,
                "hit_count": len(result.hits),
            }
        )
    writeback_total = max(1, len(dataset["person_writebacks"]))
    # --- Metric: KnowledgeFetcher retrieval quality ---
    knowledge_reports: List[Dict[str, Any]] = []
    knowledge_success_rate = 0.0
    knowledge_keyword_recall = 0.0
    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=dataset["session"]["display_name"],
        stream_id=session_id,
    )
    for case in dataset["knowledge_fetcher_cases"]:
        knowledge_text, _ = await fetcher.fetch(case["query"], [])
        recall = _keyword_recall(knowledge_text, case["expected_keywords"])
        success = recall >= float(case.get("minimum_keyword_recall", 1.0))
        knowledge_success_rate += 1.0 if success else 0.0
        knowledge_keyword_recall += recall
        knowledge_reports.append(
            {
                "query": case["query"],
                "success": success,
                "keyword_recall": recall,
                "preview": knowledge_text[:300],
            }
        )
    knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"]))
    # --- Metric: person-profile summaries (keyword recall + evidence presence) ---
    profile_reports: List[Dict[str, Any]] = []
    profile_success_rate = 0.0
    profile_keyword_recall = 0.0
    profile_evidence_rate = 0.0
    for case in dataset["profile_cases"]:
        profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id)
        recall = _keyword_recall(profile.summary, case["expected_keywords"])
        has_evidence = bool(profile.evidence)
        success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence
        profile_success_rate += 1.0 if success else 0.0
        profile_keyword_recall += recall
        profile_evidence_rate += 1.0 if has_evidence else 0.0
        profile_reports.append(
            {
                "person_id": case["person_id"],
                "success": success,
                "keyword_recall": recall,
                "evidence_count": len(profile.evidence),
                "summary_preview": profile.summary[:240],
            }
        )
    profile_total = max(1, len(dataset["profile_cases"]))
    # --- Metric: episode generation/query/search, before and after a rebuild ---
    episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_rebuild = await memory_service.episode_admin(
        action="rebuild",
        source=f"chat_summary:{session_id}",
    )
    episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"])
    episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"])
    tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset)
    # --- Assemble and persist the report ---
    report = {
        "dataset": dataset["meta"],
        "runtime_self_check": self_check["report"],
        "import": {
            "task_id": created["task"]["task_id"],
            "status": import_detail["task"]["status"],
            "paragraph_count": len(dataset["import_payload"]["paragraphs"]),
        },
        "metrics": {
            "search": {
                "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4),
                "recall_at_5": round(search_recall_at_5 / search_total, 4),
                "precision_at_5": round(search_precision_at_5 / search_total, 4),
                "mrr": round(search_mrr / search_total, 4),
                "keyword_recall_at_5": round(search_keyword_recall / search_total, 4),
            },
            "writeback": {
                "success_rate": round(writeback_success_rate / writeback_total, 4),
                "keyword_recall": round(writeback_keyword_recall / writeback_total, 4),
            },
            "knowledge_fetcher": {
                "success_rate": round(knowledge_success_rate / knowledge_total, 4),
                "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4),
            },
            "profile": {
                "success_rate": round(profile_success_rate / profile_total, 4),
                "keyword_recall": round(profile_keyword_recall / profile_total, 4),
                "evidence_rate": round(profile_evidence_rate / profile_total, 4),
            },
            "tool_modes": {
                "success_rate": tool_modes["success_rate"],
                "keyword_recall": tool_modes["keyword_recall"],
            },
            "episode_generation_auto": {
                "success_rate": episode_generation_auto["success_rate"],
                "keyword_recall": episode_generation_auto["keyword_recall"],
                "episode_count": episode_generation_auto["episode_count"],
            },
            "episode_generation_after_rebuild": {
                "success_rate": episode_generation_after_rebuild["success_rate"],
                "keyword_recall": episode_generation_after_rebuild["keyword_recall"],
                "episode_count": episode_generation_after_rebuild["episode_count"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_admin_query_auto": {
                "success_rate": episode_admin_query_auto["success_rate"],
                "keyword_recall": episode_admin_query_auto["keyword_recall"],
            },
            "episode_admin_query_after_rebuild": {
                "success_rate": episode_admin_query_after_rebuild["success_rate"],
                "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
            "episode_search_mode_auto": {
                "success_rate": episode_search_mode_auto["success_rate"],
                "keyword_recall": episode_search_mode_auto["keyword_recall"],
            },
            "episode_search_mode_after_rebuild": {
                "success_rate": episode_search_mode_after_rebuild["success_rate"],
                "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"],
                "rebuild_success": bool(episode_rebuild.get("success", False)),
            },
        },
        "cases": {
            "search": search_case_reports,
            "writeback": writeback_reports,
            "knowledge_fetcher": knowledge_reports,
            "profile": profile_reports,
            "tool_modes": tool_modes["reports"],
            "episode_generation_auto": episode_generation_auto["reports"],
            "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"],
            "episode_admin_query_auto": episode_admin_query_auto["reports"],
            "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"],
            "episode_search_mode_auto": episode_search_mode_auto["reports"],
            "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"],
        },
    }
    REPORT_FILE.parent.mkdir(parents=True, exist_ok=True)
    REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    # Metrics are printed so the CI log shows the headline numbers directly.
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))
    assert report["import"]["status"] == "completed"
    assert report["runtime_self_check"]["ok"] is True

View File

@@ -5,68 +5,6 @@ import pytest
from src.services import memory_flow_service as memory_flow_module
@pytest.mark.asyncio
async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch):
starts: list[str] = []
summarizers: list[object] = []
class FakeSummarizer:
def __init__(self, session_id: str):
self.session_id = session_id
summarizers.append(self)
async def start(self):
starts.append(self.session_id)
async def stop(self):
starts.append(f"stop:{self.session_id}")
monkeypatch.setattr(
memory_flow_module,
"global_config",
SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
)
monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)
manager = memory_flow_module.LongTermMemorySessionManager()
message = SimpleNamespace(session_id="session-1")
await manager.on_message(message)
await manager.on_message(message)
assert len(summarizers) == 1
assert starts == ["session-1"]
@pytest.mark.asyncio
async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch):
stopped: list[str] = []
class FakeSummarizer:
def __init__(self, session_id: str):
self.session_id = session_id
async def start(self):
return None
async def stop(self):
stopped.append(self.session_id)
monkeypatch.setattr(
memory_flow_module,
"global_config",
SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)),
)
monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer)
manager = memory_flow_module.LongTermMemorySessionManager()
await manager.on_message(SimpleNamespace(session_id="session-a"))
await manager.on_message(SimpleNamespace(session_id="session-b"))
await manager.shutdown()
assert stopped == ["session-a", "session-b"]
def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
raw = '["他喜欢猫", "他喜欢猫", "", "", "他会弹吉他"]'
@@ -101,16 +39,9 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
@pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
async def test_memory_automation_service_auto_starts_and_delegates():
events: list[tuple[str, str]] = []
class FakeSessionManager:
async def on_message(self, message):
events.append(("incoming", message.session_id))
async def shutdown(self):
events.append(("shutdown", "session"))
class FakeFactWriteback:
async def start(self):
events.append(("start", "fact"))
@@ -122,17 +53,13 @@ async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch):
events.append(("shutdown", "fact"))
service = memory_flow_module.MemoryAutomationService()
service.session_manager = FakeSessionManager()
service.fact_writeback = FakeFactWriteback()
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
await service.shutdown()
assert events == [
("start", "fact"),
("incoming", "session-1"),
("sent", "session-1"),
("shutdown", "session"),
("shutdown", "fact"),
]

View File

@@ -1,324 +0,0 @@
from __future__ import annotations
import asyncio
import inspect
import json
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import numpy as np
import pytest
import pytest_asyncio
from A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"
def _load_dialogue_fixture() -> Dict[str, Any]:
    """Load the recorded private-chat scenario from the fixture JSON file."""
    raw_text = DATA_FILE.read_text(encoding="utf-8")
    return json.loads(raw_text)
class _FakeEmbeddingAdapter:
    """Deterministic, network-free embedding stub.

    Each character of the input is hashed into one of ``dimension`` buckets and
    the resulting count vector is L2-normalized, so equal texts always produce
    equal embeddings.
    """

    def __init__(self, dimension: int = 16) -> None:
        self.dimension = dimension

    async def _detect_dimension(self) -> int:
        return self.dimension

    async def encode(self, texts, dimensions=None):
        """Encode one string (1-D result) or an iterable of strings (2-D result)."""
        dim = int(dimensions or self.dimension)
        is_single = isinstance(texts, str)
        items = [texts] if is_single else list(texts)
        vectors = []
        for item in items:
            bucket = np.zeros(dim, dtype=np.float32)
            for char in str(item or ""):
                bucket[ord(char) % dim] += 1.0
            # Empty text still gets a valid non-zero vector.
            if not bucket.any():
                bucket[0] = 1.0
            magnitude = np.linalg.norm(bucket)
            if magnitude > 0:
                bucket = bucket / magnitude
            vectors.append(bucket)
        stacked = np.vstack(vectors)
        return stacked[0] if is_single else stacked
class _KernelBackedRuntimeManager:
    """Routes memory_service component invocations into an in-process kernel."""

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
        """Dispatch a component call to the kernel, awaiting the result if needed."""
        del timeout_ms  # in-process calls need no RPC deadline
        payload = args or {}
        if component_name == "search_memory":
            # Search goes through the typed request object rather than kwargs.
            request = KernelSearchRequest(
                query=str(payload.get("query", "") or ""),
                limit=int(payload.get("limit", 5) or 5),
                mode=str(payload.get("mode", "hybrid") or "hybrid"),
                chat_id=str(payload.get("chat_id", "") or ""),
                person_id=str(payload.get("person_id", "") or ""),
                time_start=payload.get("time_start"),
                time_end=payload.get("time_end"),
                respect_filter=bool(payload.get("respect_filter", True)),
                user_id=str(payload.get("user_id", "") or ""),
                group_id=str(payload.get("group_id", "") or ""),
            )
            return await self.kernel.search_memory(request)
        handler = getattr(self.kernel, component_name)
        outcome = handler(**payload)
        if inspect.isawaitable(outcome):
            return await outcome
        return outcome
async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]:
    """Poll the import task until it reaches a terminal status.

    Raises AssertionError when the task is still running after ``max_rounds``
    polls (50 ms apart).
    """
    terminal_states = {"completed", "completed_with_errors", "failed", "cancelled"}
    attempts = 0
    while attempts < max_rounds:
        attempts += 1
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        current = str((detail.get("task") or {}).get("status", "") or "")
        if current in terminal_states:
            return detail
        await asyncio.sleep(0.05)
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
def _join_hit_content(search_result) -> str:
    """Concatenate the text content of every search hit, one hit per line."""
    contents = [hit.content for hit in search_result.hits]
    return "\n".join(contents)
@pytest_asyncio.fixture
async def real_dialogue_env(monkeypatch, tmp_path):
    """Spin up an isolated memory kernel with a fake embedding adapter for the
    recorded private-dialogue scenario; yields scenario, kernel and session.
    """
    scenario = _load_dialogue_fixture()
    session_cfg = scenario["session"]
    # Minimal session stand-in: only the fields the services under test read.
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )
    # Replace the real embedding adapter and its self-check so no network is needed.
    monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter())
    async def fake_self_check(**kwargs):
        return {"ok": True, "message": "ok"}
    monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check)
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    data_dir = (tmp_path / "a_memorix_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    # Route memory_service calls into the in-process kernel instead of the host runtime.
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
    await kernel.initialize()
    try:
        yield {
            "scenario": scenario,
            "kernel": kernel,
            "session": session,
        }
    finally:
        # Always shut the kernel down, even when the test body fails.
        await kernel.shutdown()
@pytest.mark.asyncio
async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env):
    """Importing the fixture payload should make its content retrievable via search."""
    scenario = real_dialogue_env["scenario"]
    created = await memory_service.import_admin(
        action="create_paste",
        name="private_alice.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(scenario["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True
    detail = await _wait_for_import_task(created["task"]["task_id"])
    assert detail["task"]["status"] == "completed"
    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        respect_filter=False,
    )
    assert search.hits
    # Every expected keyword must appear somewhere in the joined hit text.
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined
@pytest.mark.asyncio
async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env):
    """A chat-history summary written back via the summarizer should be searchable by chat_id."""
    scenario = real_dialogue_env["scenario"]
    record = scenario["chat_history_record"]
    summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )
    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=real_dialogue_env["session"].session_id,
    )
    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined
@pytest.mark.asyncio
async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch):
    """A person fact stored via store_person_memory_from_answer should be searchable by person_id."""
    scenario = real_dialogue_env["scenario"]

    class _KnownPerson:
        # Minimal Person stub that always resolves to the fixture person.
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        real_dialogue_env["session"].session_id,
    )
    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=real_dialogue_env["session"].session_id,
        person_id=scenario["person"]["person_id"],
    )
    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined
@pytest.mark.asyncio
async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env):
    """KnowledgeFetcher should surface long-term memory ingested for the private session."""
    scenario = real_dialogue_env["scenario"]
    await memory_service.ingest_text(
        external_id="fixture:knowledge_fetcher",
        source_type="dialogue_note",
        text=scenario["person_fact"]["memory_content"],
        chat_id=real_dialogue_env["session"].session_id,
        person_ids=[scenario["person"]["person_id"]],
        participants=[scenario["person"]["person_name"]],
        respect_filter=False,
    )
    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=scenario["session"]["display_name"],
        stream_id=real_dialogue_env["session"].session_id,
    )
    knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in knowledge_text
@pytest.mark.asyncio
async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch):
    """After a fact writeback, the person profile should carry evidence and expected traits."""
    scenario = real_dialogue_env["scenario"]

    class _KnownPerson:
        # Minimal Person stub that always resolves to the fixture person.
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        real_dialogue_env["session"].session_id,
    )
    profile = await memory_service.get_person_profile(
        scenario["person"]["person_id"],
        chat_id=real_dialogue_env["session"].session_id,
    )
    assert profile.evidence
    # At least one profile keyword is enough; the summary wording may vary.
    assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])
@pytest.mark.asyncio
async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env):
    """Writing back a chat summary should produce at least one queryable episode."""
    scenario = real_dialogue_env["scenario"]
    record = scenario["chat_history_record"]
    summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )
    episodes = await memory_service.episode_admin(
        action="query",
        source=scenario["expectations"]["episode_source"],
        limit=5,
    )
    assert episodes["success"] is True
    assert int(episodes["count"]) >= 1

View File

@@ -1,301 +0,0 @@
from __future__ import annotations
import asyncio
import inspect
import json
import os
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import pytest
import pytest_asyncio
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.memory_system import chat_history_summarizer as summarizer_module
from src.person_info import person_info as person_info_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
pytestmark = pytest.mark.skipif(
os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1",
reason="需要显式开启真实 embedding / self-check 集成测试",
)
DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json"
def _load_dialogue_fixture() -> Dict[str, Any]:
    """Parse and return the recorded private-chat scenario fixture."""
    return json.loads(DATA_FILE.read_text(encoding="utf-8"))
class _KernelBackedRuntimeManager:
    """Bridges memory_service component invocations to an in-process kernel."""

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
        """Dispatch a component call to the kernel, awaiting the result if needed."""
        del timeout_ms  # in-process calls need no RPC deadline
        payload = args or {}
        if component_name == "search_memory":
            # Search goes through the typed request object rather than kwargs.
            request = KernelSearchRequest(
                query=str(payload.get("query", "") or ""),
                limit=int(payload.get("limit", 5) or 5),
                mode=str(payload.get("mode", "hybrid") or "hybrid"),
                chat_id=str(payload.get("chat_id", "") or ""),
                person_id=str(payload.get("person_id", "") or ""),
                time_start=payload.get("time_start"),
                time_end=payload.get("time_end"),
                respect_filter=bool(payload.get("respect_filter", True)),
                user_id=str(payload.get("user_id", "") or ""),
                group_id=str(payload.get("group_id", "") or ""),
            )
            return await self.kernel.search_memory(request)
        handler = getattr(self.kernel, component_name)
        outcome = handler(**payload)
        if inspect.isawaitable(outcome):
            return await outcome
        return outcome
async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]:
    """Poll the import task until a terminal status, failing after the timeout window.

    The window is clamped to at least one second; polls happen every 200 ms.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max(1.0, float(timeout_seconds))
    terminal_states = {"completed", "completed_with_errors", "failed", "cancelled"}
    while loop.time() < deadline:
        detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True)
        status_text = str((detail.get("task") or {}).get("status", "") or "")
        if status_text in terminal_states:
            return detail
        await asyncio.sleep(0.2)
    raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}")
def _join_hit_content(search_result) -> str:
    """Join the content of all search hits into a single newline-separated string."""
    pieces = []
    for hit in search_result.hits:
        pieces.append(hit.content)
    return "\n".join(pieces)
@pytest_asyncio.fixture
async def live_dialogue_env(monkeypatch, tmp_path):
    """Boot an in-process memory kernel for the live (real embedding) dialogue tests.

    Unlike the mocked variant, the embedding adapter is NOT replaced here — the
    suite is gated behind MAIBOT_RUN_LIVE_MEMORY_TESTS at module level.
    """
    scenario = _load_dialogue_fixture()
    session_cfg = scenario["session"]
    # Minimal session stand-in: only the fields the services under test read.
    session = SimpleNamespace(
        session_id=session_cfg["session_id"],
        platform=session_cfg["platform"],
        user_id=session_cfg["user_id"],
        group_id=session_cfg["group_id"],
    )
    fake_chat_manager = SimpleNamespace(
        get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None,
        get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id,
    )
    monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    data_dir = (tmp_path / "a_memorix_data").resolve()
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str(data_dir)},
            "advanced": {"enable_auto_save": False},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )
    manager = _KernelBackedRuntimeManager(kernel)
    # Route memory_service calls into the in-process kernel instead of the host runtime.
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", manager)
    await kernel.initialize()
    try:
        yield {
            "scenario": scenario,
            "kernel": kernel,
            "session": session,
        }
    finally:
        # Always shut the kernel down, even when the test body fails.
        await kernel.shutdown()
@pytest.mark.asyncio
async def test_live_runtime_self_check_passes(live_dialogue_env):
    """The real embedding runtime self-check should pass and report a positive dimension."""
    report = await memory_service.runtime_admin(action="refresh_self_check")
    assert report["success"] is True
    assert report["report"]["ok"] is True
    assert report["report"]["encoded_dimension"] > 0
@pytest.mark.asyncio
async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env):
    """Importing the fixture payload against the live runtime should make it searchable."""
    scenario = live_dialogue_env["scenario"]
    created = await memory_service.import_admin(
        action="create_paste",
        name="private_alice.json",
        input_mode="json",
        llm_enabled=False,
        content=json.dumps(scenario["import_payload"], ensure_ascii=False),
    )
    assert created["success"] is True
    detail = await _wait_for_import_task(created["task"]["task_id"])
    assert detail["task"]["status"] == "completed"
    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        respect_filter=False,
    )
    assert search.hits
    # Every expected keyword must appear somewhere in the joined hit text.
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined
@pytest.mark.asyncio
async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env):
    """A chat-history summary written back via the summarizer should be searchable by chat_id."""
    scenario = live_dialogue_env["scenario"]
    record = scenario["chat_history_record"]
    summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )
    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=live_dialogue_env["session"].session_id,
    )
    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined
@pytest.mark.asyncio
async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch):
    """A person fact stored via store_person_memory_from_answer should be searchable by person_id."""
    scenario = live_dialogue_env["scenario"]

    class _KnownPerson:
        # Minimal Person stub that always resolves to the fixture person.
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        live_dialogue_env["session"].session_id,
    )
    search = await memory_service.search(
        scenario["search_queries"]["direct"],
        mode="search",
        chat_id=live_dialogue_env["session"].session_id,
        person_id=scenario["person"]["person_id"],
    )
    assert search.hits
    joined = _join_hit_content(search)
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in joined
@pytest.mark.asyncio
async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env):
    """KnowledgeFetcher should surface long-term memory ingested for the private session."""
    scenario = live_dialogue_env["scenario"]
    await memory_service.ingest_text(
        external_id="fixture:knowledge_fetcher",
        source_type="dialogue_note",
        text=scenario["person_fact"]["memory_content"],
        chat_id=live_dialogue_env["session"].session_id,
        person_ids=[scenario["person"]["person_id"]],
        participants=[scenario["person"]["person_name"]],
        respect_filter=False,
    )
    fetcher = knowledge_module.KnowledgeFetcher(
        private_name=scenario["session"]["display_name"],
        stream_id=live_dialogue_env["session"].session_id,
    )
    knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], [])
    for keyword in scenario["expectations"]["search_keywords"]:
        assert keyword in knowledge_text
@pytest.mark.asyncio
async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch):
    """After a fact writeback, the person profile should carry evidence and expected traits."""
    scenario = live_dialogue_env["scenario"]

    class _KnownPerson:
        # Minimal Person stub that always resolves to the fixture person.
        def __init__(self, person_id: str) -> None:
            self.person_id = person_id
            self.is_known = True
            self.person_name = scenario["person"]["person_name"]

    monkeypatch.setattr(
        person_info_module,
        "get_person_id_by_person_name",
        lambda person_name: scenario["person"]["person_id"],
    )
    monkeypatch.setattr(person_info_module, "Person", _KnownPerson)
    await person_info_module.store_person_memory_from_answer(
        scenario["person"]["person_name"],
        scenario["person_fact"]["memory_content"],
        live_dialogue_env["session"].session_id,
    )
    profile = await memory_service.get_person_profile(
        scenario["person"]["person_id"],
        chat_id=live_dialogue_env["session"].session_id,
    )
    assert profile.evidence
    # At least one profile keyword is enough; the summary wording may vary.
    assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"])
@pytest.mark.asyncio
async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env):
    """Writing back a chat summary should produce at least one queryable episode."""
    scenario = live_dialogue_env["scenario"]
    record = scenario["chat_history_record"]
    summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id)
    await summarizer._import_to_long_term_memory(
        record_id=record["record_id"],
        theme=record["theme"],
        summary=record["summary"],
        participants=record["participants"],
        start_time=record["start_time"],
        end_time=record["end_time"],
        original_text=record["original_text"],
    )
    episodes = await memory_service.episode_admin(
        action="query",
        source=scenario["expectations"]["episode_source"],
        limit=5,
    )
    assert episodes["success"] is True
    assert int(episodes["count"]) >= 1

View File

@@ -1,23 +0,0 @@
from src.maisaka.chat_loop_service import MaisakaChatLoopService
def test_build_tool_names_log_text_supports_openai_function_schema() -> None:
    """Both OpenAI function-wrapped and flat tool definitions should contribute their names."""
    wrapped_tool = {
        "type": "function",
        "function": {
            "name": "mute_user",
            "description": "禁言指定用户",
            "parameters": {
                "type": "object",
                "properties": {},
            },
        },
    }
    flat_tool = {
        "name": "reply",
        "description": "发送回复",
    }

    log_text = MaisakaChatLoopService._build_tool_names_log_text([wrapped_tool, flat_tool])
    assert log_text == "mute_user、reply"

View File

@@ -1,339 +0,0 @@
"""MutePlugin SDK 回归测试。"""
from __future__ import annotations
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
import pytest
from maibot_sdk.context import PluginContext
from maibot_sdk.plugin import MaiBotPlugin
from plugins.MutePlugin.plugin import create_plugin
from src.core.tooling import ToolExecutionContext, ToolInvocation
from src.plugin_runtime.component_query import ComponentQueryService
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
def _build_plugin() -> MaiBotPlugin:
    """Create a plugin instance pre-loaded with its default configuration."""
    instance = create_plugin()
    default_config = instance.get_default_config()
    instance.set_plugin_config(default_config)
    return instance
def test_mute_plugin_manifest_is_valid_v2() -> None:
    """The MutePlugin manifest must satisfy the current runtime requirements."""
    validator = ManifestValidator(host_version="1.0.0", sdk_version="2.3.0")
    loaded_manifest = validator.load_from_plugin_path(Path("plugins/MutePlugin"))

    assert loaded_manifest is not None
    assert loaded_manifest.manifest_version == 2
    assert loaded_manifest.id == "sengokucola.mute-plugin"
def test_create_plugin_returns_sdk_plugin() -> None:
    """The plugin entry point should hand back an SDK plugin instance."""
    assert isinstance(create_plugin(), MaiBotPlugin)
@pytest.mark.asyncio
async def test_mute_command_calls_napcat_group_ban_api() -> None:
    """The manual mute command should execute through the new NapCat Adapter API."""
    plugin = _build_plugin()
    # Re-apply the defaults with both mute components switched on so the
    # command handler is active.
    plugin.set_plugin_config(
        {
            **plugin.get_default_config(),
            "components": {
                "enable_smart_mute": True,
                "enable_mute_command": True,
            },
        }
    )
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Fake host RPC endpoint: record every capability payload and answer
        # each capability the mute flow is expected to invoke.
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            # Name lookup resolves the target to an internal person id.
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            # Person record lookup yields the concrete platform user id.
            return {"success": True, "value": "123456"}
        # NOTE: this branch must precede the generic api.call branch or it
        # would be shadowed; the target is reported as a plain member.
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call":
            # Any other adapter API call (e.g. set_group_ban) succeeds.
            return {"success": True, "result": {"status": "ok", "retcode": 0}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message, intercept = await plugin.handle_mute_command(
        stream_id="group-10001",
        group_id="10001",
        user_id="42",
        matched_groups={
            "target": "张三",
            "duration": "120",
            "reason": "刷屏",
        },
    )
    assert success is True
    assert message == "成功禁言 张三"
    assert intercept is True
    # The ban must have gone through the new set_group_ban adapter API with
    # the resolved user id and the parsed duration.
    api_call = next(
        call
        for call in capability_calls
        if call["capability"] == "api.call"
        and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
    )
    assert api_call["args"]["version"] == "1"
    assert api_call["args"]["args"] == {
        "group_id": "10001",
        "user_id": "123456",
        "duration": 120,
    }
@pytest.mark.asyncio
async def test_mute_tool_requires_target_person_name() -> None:
    """The mute tool should fail fast with a hint when no target is given."""
    plugin = _build_plugin()
    recorded_payloads: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        assert method == "cap.call"
        assert payload is not None
        recorded_payloads.append(payload)
        return {"success": True}

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="",
        duration="60",
        reason="测试",
    )

    assert success is False
    assert message == "禁言目标不能为空"
    last_call = recorded_payloads[-1]
    assert last_call["capability"] == "send.text"
    assert last_call["args"]["text"] == "没有指定禁言对象哦"
@pytest.mark.asyncio
async def test_mute_tool_can_unwrap_nested_person_user_id_response() -> None:
    """The mute tool should unwrap capability results nested one envelope deep."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Fake host RPC: person.* answers are wrapped in an extra
        # {"success": ..., "result": {...}} envelope on purpose.
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "result": {"success": True, "person_id": "person-1"}}
        if capability == "person.get_value":
            return {"success": True, "result": {"success": True, "value": "123456"}}
        # NOTE: must precede the generic api.call branch below.
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call":
            return {"success": True, "result": {"status": "ok"}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )
    assert success is True
    assert message == "成功禁言 张三"
    # The user id buried in the double envelope must still reach set_group_ban.
    api_call = next(
        call
        for call in capability_calls
        if call["capability"] == "api.call"
        and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
    )
    assert api_call["args"]["args"]["user_id"] == "123456"
@pytest.mark.asyncio
async def test_mute_tool_rejects_owner_before_group_ban_call() -> None:
    """The mute tool should bail out with a clear hint when the target is the group owner."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Fake host RPC: member info reports the target as group owner.
        # Deliberately no set_group_ban branch — attempting the ban would
        # hit the AssertionError below.
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "owner"}}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )
    assert success is False
    assert message == "张三 是群主,不能被禁言"
    # The ban API must never have been attempted.
    assert not any(
        call["capability"] == "api.call" and call["args"]["api_name"] == "adapter.napcat.group.set_group_ban"
        for call in capability_calls
    )
@pytest.mark.asyncio
async def test_mute_tool_maps_cannot_ban_owner_error_message() -> None:
    """A NapCat 'cannot ban owner' failure should map to an explicit Chinese hint."""
    plugin = _build_plugin()
    capability_calls: List[Dict[str, Any]] = []

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Fake host RPC: member info claims "member" (the owner check passes),
        # but the ban API itself fails with the NapCat owner-ban error text.
        assert method == "cap.call"
        assert payload is not None
        capability_calls.append(payload)
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
            return {"success": False, "error": "NapCat 动作返回失败: action=set_group_ban message=cannot ban owner"}
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )
    assert success is False
    # The raw English adapter error must have been translated for the user.
    assert message == "张三 是群主,不能被禁言"
@pytest.mark.asyncio
async def test_mute_tool_accepts_nested_ok_api_result() -> None:
    """A nested success/result/status=ok payload should also count as success."""
    plugin = _build_plugin()

    async def fake_rpc_call(method: str, plugin_id: str = "", payload: Dict[str, Any] | None = None) -> Dict[str, Any]:
        # Fake host RPC: set_group_ban answers with the full OneBot-style
        # envelope (status/retcode/data/message/wording) nested under result.
        assert method == "cap.call"
        assert payload is not None
        capability = payload["capability"]
        if capability == "person.get_id_by_name":
            return {"success": True, "person_id": "person-1"}
        if capability == "person.get_value":
            return {"success": True, "value": "123456"}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.get_group_member_info":
            return {"success": True, "result": {"role": "member"}}
        if capability == "api.call" and payload["args"]["api_name"] == "adapter.napcat.group.set_group_ban":
            return {
                "success": True,
                "result": {
                    "status": "ok",
                    "retcode": 0,
                    "data": None,
                    "message": "",
                    "wording": "",
                },
            }
        if capability == "send.text":
            return {"success": True}
        raise AssertionError(f"unexpected capability: {capability}")

    plugin._set_context(PluginContext(plugin_id="mute", rpc_call=fake_rpc_call))
    success, message = await plugin.handle_mute_tool(
        stream_id="group-10001",
        group_id="10001",
        target="张三",
        duration=60,
        reason="测试",
    )
    assert success is True
    assert message == "成功禁言 张三"
def test_tool_invocation_payload_injects_group_and_user_context() -> None:
    """Tool payloads should be auto-enriched with group-chat context fields."""
    registry_entry = SimpleNamespace(invoke_method="plugin.invoke_tool")
    group_info = SimpleNamespace(group_id="10001")
    user_info = SimpleNamespace(user_id="20002")
    anchor = SimpleNamespace(
        message_info=SimpleNamespace(group_info=group_info, user_info=user_info)
    )

    invocation = ToolInvocation(tool_name="mute", arguments={"target": "张三"}, stream_id="session-1")
    execution_context = ToolExecutionContext(
        session_id="session-1",
        stream_id="session-1",
        reasoning="test",
        metadata={"anchor_message": anchor},
    )

    payload = ComponentQueryService._build_tool_invocation_payload(
        registry_entry, invocation, execution_context
    )

    expected_fields = {
        "target": "张三",
        "stream_id": "session-1",
        "chat_id": "session-1",
        "group_id": "10001",
        "user_id": "20002",
    }
    for field_name, expected_value in expected_fields.items():
        assert payload[field_name] == expected_value

View File

@@ -1,227 +0,0 @@
from pathlib import Path
from types import ModuleType, SimpleNamespace
import importlib.util
import sys
class DummyLogger:
    """Logger double that swallows everything except warnings, which it records."""

    def __init__(self) -> None:
        # Warning texts in the order they were emitted.
        self.warning_messages: list[str] = []

    def debug(self, _msg: str) -> None:
        """Discard debug output."""
        return

    def info(self, _msg: str) -> None:
        """Discard info output."""
        return

    def warning(self, msg: str) -> None:
        """Record the warning text for later assertions."""
        self.warning_messages.append(msg)

    def error(self, _msg: str) -> None:
        """Discard error output."""
        return
def load_utils_module(monkeypatch, qq_account=123456, platforms=None):
    """Load src/chat/utils/utils.py in isolation, stubbing all of its imports.

    Registers stand-in modules for every dependency utils.py pulls in, then
    executes the real file via importlib so its behaviour can be tested
    without importing the rest of the project.

    Args:
        monkeypatch: pytest fixture used to patch ``sys.modules`` reversibly.
        qq_account: value exposed as ``global_config.bot.qq_account``.
        platforms: optional list of ``"platform:account"`` mapping strings.

    Returns:
        Tuple of (loaded utils module, DummyLogger capturing its warnings).
    """
    logger = DummyLogger()
    configured_platforms = platforms or []

    def _stub_module(name: str) -> ModuleType:
        # Register an empty module so imports of `name` resolve to the stub.
        module = ModuleType(name)
        monkeypatch.setitem(sys.modules, name, module)
        return module

    # Parent packages must exist (with a __path__) before submodules can be
    # registered under them; reuse any already present in sys.modules.
    for package_name in [
        "src",
        "src.chat",
        "src.chat.message_receive",
        "src.chat.utils",
        "src.common",
        "src.config",
        "src.llm_models",
        "src.person_info",
    ]:
        if package_name not in sys.modules:
            package_module = ModuleType(package_name)
            package_module.__path__ = []
            monkeypatch.setitem(sys.modules, package_name, package_module)
    # jieba segmentation stub: split text into single characters.
    jieba_module = ModuleType("jieba")
    jieba_module.cut = lambda text: list(text)
    monkeypatch.setitem(sys.modules, "jieba", jieba_module)
    # Route the module's logger to the capturing DummyLogger.
    logger_module = _stub_module("src.common.logger")
    logger_module.get_logger = lambda _name: logger
    config_module = _stub_module("src.config.config")
    # Minimal global_config shape that utils.py reads.
    config_module.global_config = SimpleNamespace(
        bot=SimpleNamespace(
            qq_account=qq_account,
            platforms=configured_platforms,
            nickname="MaiBot",
            alias_names=[],
        ),
        chat=SimpleNamespace(
            at_bot_inevitable_reply=1,
            mentioned_bot_reply=1,
        ),
    )
    config_module.model_config = SimpleNamespace()
    message_module = _stub_module("src.chat.message_receive.message")

    class SessionMessage:
        # Placeholder type; utils.py only needs the name to exist.
        pass

    message_module.SessionMessage = SessionMessage
    chat_manager_module = _stub_module("src.chat.message_receive.chat_manager")
    chat_manager_module.chat_manager = SimpleNamespace(get_session_by_session_id=lambda _chat_id: None)
    llm_module = _stub_module("src.llm_models.utils_model")

    class LLMRequest:
        # Accept any constructor arguments and do nothing.
        def __init__(self, *args, **kwargs) -> None:
            del args, kwargs

    llm_module.LLMRequest = LLMRequest
    person_module = _stub_module("src.person_info.person_info")

    class Person:
        # Placeholder type; utils.py only needs the name to exist.
        pass

    person_module.Person = Person
    typo_generator_module = _stub_module("src.chat.utils.typo_generator")

    class ChineseTypoGenerator:
        # Identity typo generator: returns the sentence unchanged.
        def __init__(self, *args, **kwargs) -> None:
            del args, kwargs

        def create_typo_sentence(self, sentence: str):
            return sentence, ""

    typo_generator_module.ChineseTypoGenerator = ChineseTypoGenerator
    # Execute the real utils.py from disk under its canonical module name;
    # registration in sys.modules must happen before exec_module so relative
    # lookups inside the module resolve.
    file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "utils" / "utils.py"
    spec = importlib.util.spec_from_file_location("src.chat.utils.utils", file_path)
    utils_module = importlib.util.module_from_spec(spec)
    utils_module.__package__ = "src.chat.utils"
    monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
    assert spec.loader is not None
    spec.loader.exec_module(utils_module)
    return utils_module, logger
def test_platform_specific_bot_accounts(monkeypatch):
    """Per-platform account mappings should resolve, including whitespace-padded ones."""
    utils_module, _logger = load_utils_module(
        monkeypatch,
        qq_account=123456,
        platforms=[" TG : tg_bot ", "discord: disc_bot"],
    )

    expected_accounts = {
        "qq": "123456",
        "webui": "123456",
        "telegram": "tg_bot",
        "tg": "tg_bot",
        "discord": "disc_bot",
    }
    for platform, account in expected_accounts.items():
        assert utils_module.get_bot_account(platform) == account

    identity_checks = [
        ("qq", "123456"),
        ("webui", "123456"),
        ("telegram", "tg_bot"),
        (" TG ", "tg_bot"),
    ]
    for platform, account in identity_checks:
        assert utils_module.is_bot_self(platform, account)
def test_get_all_bot_accounts_includes_runtime_aliases(monkeypatch):
    """The full account map should contain runtime aliases such as 'tg'."""
    utils_module, _logger = load_utils_module(
        monkeypatch,
        qq_account=123456,
        platforms=["TG:tg_bot", "discord:disc_bot"],
    )

    expected = {
        "qq": "123456",
        "webui": "123456",
        "telegram": "tg_bot",
        "tg": "tg_bot",
        "discord": "disc_bot",
    }
    assert utils_module.get_all_bot_accounts() == expected
def test_get_all_bot_accounts_keeps_canonical_qq_identity(monkeypatch):
    """Platform overrides must not replace the configured QQ/WebUI identity."""
    utils_module, _logger = load_utils_module(
        monkeypatch,
        qq_account=123456,
        platforms=["qq:999999", "webui:888888", "TG:tg_bot"],
    )

    assert utils_module.get_all_bot_accounts()["qq"] == "123456"
    assert utils_module.get_all_bot_accounts()["webui"] == "123456"
def test_unknown_platform_no_longer_falls_back_to_qq(monkeypatch):
    """An unconfigured platform should not match the QQ account and should warn."""
    utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])

    assert utils_module.is_bot_self("unknown_platform", "123456") is False
    assert logger.warning_messages
    assert "unknown_platform" in logger.warning_messages[-1]
def test_unknown_platform_warns_only_once(monkeypatch):
    """Repeated lookups of the same unknown platform should log a single warning."""
    utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])

    first_result = utils_module.is_bot_self("unknown_platform", "first")
    second_result = utils_module.is_bot_self(" unknown_platform ", "second")

    assert first_result is False
    assert second_result is False
    assert len(logger.warning_messages) == 1
def test_unconfigured_qq_account_disables_qq_and_webui_identity(monkeypatch):
    """A zero QQ account should disable both the qq and webui identities."""
    utils_module, _logger = load_utils_module(monkeypatch, qq_account=0, platforms=["telegram:tg_bot"])

    for platform in ("qq", "webui"):
        assert utils_module.get_bot_account(platform) == ""
        assert utils_module.is_bot_self(platform, "0") is False
def test_is_mentioned_bot_in_message_uses_platform_account(monkeypatch):
    """Mention detection should match against the platform-specific bot account."""
    utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=["TG:tg_bot"])

    incoming = SimpleNamespace(
        processed_plain_text="@tg_bot 你好",
        platform="telegram",
        is_mentioned=False,
        message_segment=None,
        message_info=SimpleNamespace(
            additional_config={},
            user_info=SimpleNamespace(user_id="user_1"),
        ),
    )

    mentioned, at_flag, probability = utils_module.is_mentioned_bot_in_message(incoming)

    assert mentioned is True
    assert at_flag is True
    assert probability == 1.0
def test_is_mentioned_bot_in_message_normalizes_qq_platform(monkeypatch):
    """Whitespace-padded, mixed-case platform names should normalize to qq."""
    utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])

    incoming = SimpleNamespace(
        processed_plain_text="@<MaiBot:123456> 你好",
        platform=" QQ ",
        is_mentioned=False,
        message_segment=None,
        message_info=SimpleNamespace(
            additional_config={},
            user_info=SimpleNamespace(user_id="user_1"),
        ),
    )

    mentioned, at_flag, probability = utils_module.is_mentioned_bot_in_message(incoming)

    assert mentioned is True
    assert at_flag is True
    assert probability == 1.0