Files
mai-bot/pytests/webui/test_memory_routes_integration.py
DawnARC da95b06f96 feat:完善长期记忆控制台导入链路与联调测试
summary:
- 扩展长期记忆控制台导入、调优与删除相关 UI/接口,补充中文化展示与任务细粒度状态管理
- 强化 memory API 与后端路由能力,补齐导入任务、图谱检索、配置与运行态相关字段
- 新增与增强前后端测试,覆盖导入多文件类型、检索、调优、删除及图谱查询关键路径

description:
- dashboard: 重构 knowledge-base 页面与 memory-api,统一任务队列、分块分页、来源删除恢复、调优闭环交互
- backend: 扩展 webui memory 路由与 A_Memorix 内核检索逻辑,完善服务侧能力与配置 schema
- tests: 增加 webui 集成测试和 kernel 单测,提升导入/检索/调优/删除全流程回归保障
2026-04-03 19:50:08 +08:00

500 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from pathlib import Path
from time import monotonic, sleep
from typing import Any, Dict, Generator
from uuid import uuid4
import asyncio
import json
from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest
import tomlkit
from src.A_memorix import host_service as host_service_module
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module
# Presumably a per-request time budget in seconds; not referenced in the
# visible part of this module -- TODO confirm whether it is still needed.
REQUEST_TIMEOUT_SECONDS = 30
# Default polling deadline (seconds) for _wait_for_import_task_terminal.
IMPORT_TIMEOUT_SECONDS = 120
# Default polling deadline (seconds) for _wait_for_tuning_task_terminal.
TUNING_TIMEOUT_SECONDS = 420
# Import task statuses at which the import polling helper stops waiting.
IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "cancelled"}
# Tuning task statuses at which the tuning polling helper stops waiting.
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}
def _build_test_config(data_dir: Path) -> Dict[str, Any]:
    """Build the configuration mapping written to the test TOML file.

    Args:
        data_dir: Directory stored (as a string) under ``storage.data_dir``.

    Returns:
        A freshly built nested dict; every call returns a new object, so
        callers may mutate the result freely.
    """
    return {
        "storage": {
            "data_dir": str(data_dir),
        },
        "advanced": {
            "enable_auto_save": False,
        },
        "embedding": {
            "dimension": 64,
            "batch_size": 4,
            "max_concurrent": 1,
            "retry": {
                "max_attempts": 1,
                "min_wait_seconds": 0.1,
                "max_wait_seconds": 0.2,
                "backoff_multiplier": 1.0,
            },
            "fallback": {
                "enabled": True,
                "allow_metadata_only_write": True,
                "probe_interval_seconds": 30,
            },
            "paragraph_vector_backfill": {
                "enabled": False,
                "interval_seconds": 60,
                "batch_size": 32,
                "max_retry": 2,
            },
        },
        "retrieval": {
            "enable_parallel": False,
            "enable_ppr": False,
            "top_k_paragraphs": 20,
            "top_k_relations": 10,
            "top_k_final": 10,
            "alpha": 0.5,
            "search": {
                "smart_fallback": {
                    "enabled": True,
                },
            },
            "sparse": {
                "enabled": True,
                "mode": "auto",
                "candidate_k": 80,
                "relation_candidate_k": 60,
            },
            "fusion": {
                "method": "weighted_rrf",
                "rrf_k": 60,
                "vector_weight": 0.7,
                "bm25_weight": 0.3,
            },
        },
        "threshold": {
            "percentile": 70.0,
            "min_results": 1,
        },
        "web": {
            "tuning": {
                "enabled": True,
                "poll_interval_ms": 300,
                "max_queue_size": 4,
                "default_objective": "balanced",
                "default_intensity": "quick",
                "default_sample_size": 4,
                "default_top_k_eval": 5,
                "eval_query_timeout_seconds": 1.0,
                "llm_retry": {
                    "max_attempts": 1,
                    "min_wait_seconds": 0.1,
                    "max_wait_seconds": 0.2,
                    "backoff_multiplier": 1.0,
                },
            },
        },
    }
def _assert_response_ok(response: Any) -> Dict[str, Any]:
    """Assert that an HTTP response succeeded and return its JSON body.

    Args:
        response: Object exposing ``status_code``, ``text`` and ``json()``.

    Returns:
        The decoded JSON payload.

    Raises:
        AssertionError: If the status code is not 200, the body is not a JSON
            object, or the payload carries a ``success`` field that is not
            ``True``. A missing ``success`` field counts as success.
    """
    assert response.status_code == 200, response.text
    payload = response.json()
    # Fail with the payload in the message instead of an AttributeError on
    # ``.get`` when the endpoint returns a list or scalar.
    assert isinstance(payload, dict), payload
    assert payload.get("success", True) is True, payload
    return payload
def _wait_for_import_task_terminal(
    client: TestClient,
    task_id: str,
    *,
    timeout_seconds: float = IMPORT_TIMEOUT_SECONDS,
) -> Dict[str, Any]:
    """Poll an import task until it reports a terminal status.

    Args:
        client: Test client used to call the import task endpoint.
        task_id: Identifier of the import task to poll.
        timeout_seconds: How long to keep polling before giving up.

    Returns:
        The ``task`` object of the first response whose status is in
        ``IMPORT_TERMINAL_STATUSES``.

    Raises:
        AssertionError: If a poll response is not OK, or the task is still
            not terminal once the deadline has passed.
    """
    deadline = monotonic() + timeout_seconds
    last_payload: Dict[str, Any] = {}
    while True:
        response = client.get(
            f"/api/webui/memory/import/tasks/{task_id}",
            params={"include_chunks": True},
        )
        last_payload = _assert_response_ok(response)
        task = last_payload.get("task") or {}
        if str(task.get("status", "") or "") in IMPORT_TERMINAL_STATUSES:
            return task
        # Check the deadline only after polling: the task is always queried at
        # least once, and once more after the final sleep, so a task that
        # finishes during that last sleep is not reported as a timeout.
        if monotonic() >= deadline:
            break
        sleep(0.2)
    raise AssertionError(f"导入任务超时: task_id={task_id}, last_payload={last_payload}")
def _wait_for_tuning_task_terminal(
    client: TestClient,
    task_id: str,
    *,
    timeout_seconds: float = TUNING_TIMEOUT_SECONDS,
) -> Dict[str, Any]:
    """Poll a retrieval-tuning task until it reports a terminal status.

    Args:
        client: Test client used to call the tuning task endpoint.
        task_id: Identifier of the tuning task to poll.
        timeout_seconds: How long to keep polling before giving up.

    Returns:
        The ``task`` object of the first response whose status is in
        ``TUNING_TERMINAL_STATUSES``.

    Raises:
        AssertionError: If a poll response is not OK, or the task is still
            not terminal once the deadline has passed.
    """
    deadline = monotonic() + timeout_seconds
    last_payload: Dict[str, Any] = {}
    while True:
        response = client.get(
            f"/api/webui/memory/retrieval_tuning/tasks/{task_id}",
            params={"include_rounds": False},
        )
        last_payload = _assert_response_ok(response)
        task = last_payload.get("task") or {}
        if str(task.get("status", "") or "") in TUNING_TERMINAL_STATUSES:
            return task
        # Check the deadline only after polling so the task is queried at
        # least once and once more after the final sleep.
        if monotonic() >= deadline:
            break
        sleep(0.3)
    raise AssertionError(f"调优任务超时: task_id={task_id}, last_payload={last_payload}")
def _wait_for_query_hit(
    client: TestClient,
    query: str,
    *,
    timeout_seconds: float = 30.0,
) -> Dict[str, Any]:
    """Poll the aggregate query endpoint until it returns at least one hit.

    Args:
        client: Test client used to call the aggregate query endpoint.
        query: Query string sent to the endpoint.
        timeout_seconds: How long to keep polling before giving up.

    Returns:
        The full payload of the first response whose ``hits`` is a non-empty
        list.

    Raises:
        AssertionError: If a poll response is not OK, or no hit was returned
            before the deadline passed.
    """
    deadline = monotonic() + timeout_seconds
    last_payload: Dict[str, Any] = {}
    while True:
        last_payload = _assert_response_ok(
            client.get(
                "/api/webui/memory/query/aggregate",
                params={"query": query, "limit": 20},
            )
        )
        hits = last_payload.get("hits") or []
        if isinstance(hits, list) and len(hits) > 0:
            return last_payload
        # Check the deadline only after querying so the endpoint is hit at
        # least once and once more after the final sleep.
        if monotonic() >= deadline:
            break
        sleep(0.2)
    raise AssertionError(f"检索命中超时: query={query}, last_payload={last_payload}")
def _get_source_item(client: TestClient, source_name: str) -> Dict[str, Any] | None:
    """Return the entry for ``source_name`` from the sources listing.

    Args:
        client: Test client used to call the sources endpoint.
        source_name: Value matched against each item's ``source`` field.

    Returns:
        The first dict item whose ``source`` equals ``source_name``, or
        ``None`` when no such item is listed. Non-dict items are ignored.

    Raises:
        AssertionError: If the sources request is not OK.
    """
    payload = _assert_response_ok(client.get("/api/webui/memory/sources"))
    for item in payload.get("items") or []:
        if isinstance(item, dict) and str(item.get("source", "") or "") == source_name:
            return item
    return None
def _source_paragraph_count(item: Dict[str, Any] | None) -> int:
    """Extract the paragraph count from a source listing item.

    Args:
        item: A source item dict, or ``None``.

    Returns:
        ``paragraph_count`` as an int when that key is present (a falsy value
        such as ``None`` yields 0); otherwise ``count`` as an int, defaulting
        to 0. ``None`` or an empty item yields 0.
    """
    payload = item or {}
    # Key presence, not truthiness, decides which field is used: an explicit
    # ``paragraph_count`` of None/0 does not fall back to ``count``.
    key = "paragraph_count" if "paragraph_count" in payload else "count"
    return int(payload.get(key, 0) or 0)
def _wait_for_source_paragraph_count(
    client: TestClient,
    source_name: str,
    *,
    min_count: int,
    timeout_seconds: float = 30.0,
) -> Dict[str, Any]:
    """Poll the sources listing until a source holds enough paragraphs.

    Args:
        client: Test client used to call the sources endpoint.
        source_name: Source to look up in the listing.
        min_count: Minimum paragraph count that ends the wait.
        timeout_seconds: How long to keep polling before giving up.

    Returns:
        The matching source item once its count is at least ``min_count``
        (an empty dict if the threshold is met without a listed item).

    Raises:
        AssertionError: If a listing request is not OK, or the count stays
            below ``min_count`` until the deadline has passed.
    """
    deadline = monotonic() + timeout_seconds
    last_item: Dict[str, Any] = {}
    while True:
        item = _get_source_item(client, source_name)
        if _source_paragraph_count(item) >= int(min_count):
            return item or {}
        if item:
            # Keep a snapshot for the timeout message.
            last_item = dict(item)
        # Check the deadline only after polling so the listing is fetched at
        # least once and once more after the final sleep.
        if monotonic() >= deadline:
            break
        sleep(0.2)
    raise AssertionError(
        f"等待来源段落计数超时: source={source_name}, min_count={min_count}, last_item={last_item}"
    )
def _create_multitype_upload_task(client: TestClient) -> str:
    """Upload one txt, one markdown and two JSON files as a single import task.

    Args:
        client: Test client used to call the import upload endpoint.

    Returns:
        The non-empty ``task_id`` reported for the created import task.

    Raises:
        AssertionError: If the upload response is not OK or carries no
            task id.
    """
    structured_json = {
        "paragraphs": [
            {
                "content": "Alice 携带地图前往火星港。",
                "source": "integration-upload-json",
                "entities": ["Alice", "地图", "火星港"],
                "relations": [
                    {"subject": "Alice", "predicate": "携带", "object": "地图"},
                    {"subject": "Alice", "predicate": "前往", "object": "火星港"},
                ],
            }
        ]
    }
    extra_json = {
        "paragraphs": [
            {
                "content": "Carol 记录了一条补充说明。",
                "source": "integration-upload-json-extra",
                "entities": ["Carol"],
                "relations": [],
            }
        ]
    }
    # Import options travel as a JSON string in a form field next to the files.
    payload_json = json.dumps(
        {
            "input_mode": "text",
            "llm_enabled": False,
            "file_concurrency": 2,
            "chunk_concurrency": 2,
            "dedupe_policy": "none",
        },
        ensure_ascii=False,
    )
    files = [
        ("files", ("integration-notes.txt", "Alice 在测试环境记录了一条长期记忆。".encode("utf-8"), "text/plain")),
        ("files", ("integration-diary.md", "# 日志\nBob 与 Alice 讨论了导图。".encode("utf-8"), "text/markdown")),
        (
            "files",
            (
                "integration-structured.json",
                json.dumps(structured_json, ensure_ascii=False).encode("utf-8"),
                "application/json",
            ),
        ),
        (
            "files",
            (
                "integration-extra.json",
                json.dumps(extra_json, ensure_ascii=False).encode("utf-8"),
                "application/json",
            ),
        ),
    ]
    response = client.post(
        "/api/webui/memory/import/upload",
        data={"payload_json": payload_json},
        files=files,
    )
    payload = _assert_response_ok(response)
    task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, payload
    return task_id
def _create_seed_paste_task(client: TestClient, *, source: str, unique_token: str) -> str:
    """Create a paste-import task carrying two seed paragraphs.

    Both paragraphs embed ``unique_token`` in their content and are tagged
    with ``source``.

    Args:
        client: Test client used to call the import paste endpoint.
        source: Source name assigned to both seed paragraphs.
        unique_token: Marker string embedded in both paragraph contents.

    Returns:
        The non-empty ``task_id`` reported for the created import task.

    Raises:
        AssertionError: If the paste response is not OK or carries no
            task id.
    """
    seed_payload = {
        "paragraphs": [
            {
                "content": f"Alice 在火星港携带地图并记录了口令 {unique_token}",
                "source": source,
                "entities": ["Alice", "火星港", "地图"],
                "relations": [
                    {"subject": "Alice", "predicate": "前往", "object": "火星港"},
                    {"subject": "Alice", "predicate": "携带", "object": "地图"},
                ],
            },
            {
                "content": f"Bob 在火星港遇见 Alice并重复口令 {unique_token}",
                "source": source,
                "entities": ["Bob", "Alice", "火星港"],
                "relations": [
                    {"subject": "Bob", "predicate": "遇见", "object": "Alice"},
                    {"subject": "Bob", "predicate": "位于", "object": "火星港"},
                ],
            },
        ]
    }
    response = client.post(
        "/api/webui/memory/import/paste",
        json={
            "name": "integration-seed.json",
            "input_mode": "json",
            "llm_enabled": False,
            "content": json.dumps(seed_payload, ensure_ascii=False),
            "dedupe_policy": "none",
        },
    )
    payload = _assert_response_ok(response)
    task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, payload
    return task_id
@pytest.fixture(scope="module")
def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dict[str, Any], None, None]:
    """Module-scoped fixture: an isolated memory service seeded with test data.

    Setup writes a temporary TOML config, redirects the host service config
    path, the router staging root and the tuning artifacts root into a temp
    directory, mounts the memory routers on a fresh FastAPI app with auth
    overridden, then runs one multi-file upload import and one seeded paste
    import and waits until the seed token is retrievable.

    Yields:
        A dict with ``client``, ``upload_task``, ``seed_task``,
        ``source_name`` and ``unique_token``.

    Teardown stops the host service, clears its config cache and undoes the
    monkeypatches. It runs even when setup fails part-way.
    """
    tmp_root = tmp_path_factory.mktemp("memory_routes_integration")
    data_dir = (tmp_root / "data").resolve()
    staging_dir = (tmp_root / "upload_staging").resolve()
    artifacts_dir = (tmp_root / "artifacts").resolve()
    config_file = (tmp_root / "config" / "a_memorix.toml").resolve()
    config_file.parent.mkdir(parents=True, exist_ok=True)
    config_file.write_text(tomlkit.dumps(_build_test_config(data_dir)), encoding="utf-8")

    patches = pytest.MonkeyPatch()
    try:
        patches.setattr(host_service_module, "config_path", lambda: config_file)
        patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
        patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)

        # Stop the service and drop its cached config before the first request.
        asyncio.run(host_service_module.a_memorix_host_service.stop())
        host_service_module.a_memorix_host_service._config_cache = None  # type: ignore[attr-defined]

        app = FastAPI()
        app.dependency_overrides[require_auth] = lambda: "ok"
        app.include_router(memory_router_module.router, prefix="/api/webui")
        app.include_router(memory_router_module.compat_router)

        unique_token = f"INTEG_TOKEN_{uuid4().hex[:12]}"
        source_name = f"integration-source-{uuid4().hex[:8]}"

        with TestClient(app) as client:
            upload_task_id = _create_multitype_upload_task(client)
            upload_task = _wait_for_import_task_terminal(client, upload_task_id)
            seed_task_id = _create_seed_paste_task(client, source=source_name, unique_token=unique_token)
            seed_task = _wait_for_import_task_terminal(client, seed_task_id)
            assert str(seed_task.get("status", "") or "") in {"completed", "completed_with_errors"}, seed_task
            _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
            yield {
                "client": client,
                "upload_task": upload_task,
                "seed_task": seed_task,
                "source_name": source_name,
                "unique_token": unique_token,
            }
    finally:
        # Code after ``yield`` is skipped when setup raises before yielding,
        # so cleanup lives in ``finally`` to avoid leaking the patches and a
        # running service into the rest of the test session.
        try:
            asyncio.run(host_service_module.a_memorix_host_service.stop())
            host_service_module.a_memorix_host_service._config_cache = None  # type: ignore[attr-defined]
        finally:
            patches.undo()
def test_import_module_end_to_end_supports_multitype_upload(integration_state: Dict[str, Any]) -> None:
    """The multi-file upload task finished and lists all four uploaded files."""
    upload_task = integration_state["upload_task"]
    assert str(upload_task.get("status", "") or "") in {"completed", "completed_with_errors"}, upload_task

    files = upload_task.get("files") or []
    assert isinstance(files, list)
    assert len(files) >= 4

    file_names = {str(item.get("name", "") or "") for item in files if isinstance(item, dict)}
    assert "integration-notes.txt" in file_names
    assert "integration-diary.md" in file_names
    assert "integration-structured.json" in file_names
    assert "integration-extra.json" in file_names
def test_retrieval_module_end_to_end_queries_seeded_data(integration_state: Dict[str, Any]) -> None:
    """Seeded content is found by aggregate query and graph search returns an entity."""
    client = integration_state["client"]
    unique_token = integration_state["unique_token"]

    aggregate_payload = _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
    hits = aggregate_payload.get("hits") or []
    joined_content = "\n".join(str(item.get("content", "") or "") for item in hits if isinstance(item, dict))
    assert unique_token in joined_content

    graph_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/graph/search",
            params={"query": "Alice", "limit": 20},
        )
    )
    graph_items = graph_payload.get("items") or []
    assert isinstance(graph_items, list)
    assert any(
        str(item.get("type", "") or "") == "entity" for item in graph_items if isinstance(item, dict)
    ), graph_items
def test_tuning_module_end_to_end_create_and_apply_best(integration_state: Dict[str, Any]) -> None:
    """A quick tuning task completes and its best result can be applied."""
    client = integration_state["client"]

    create_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/retrieval_tuning/tasks",
            json={
                "objective": "balanced",
                "intensity": "quick",
                "rounds": 2,
                "sample_size": 4,
                "top_k_eval": 5,
                "llm_enabled": False,
                "eval_query_timeout_seconds": 1.0,
                "seed": 20260403,
            },
        )
    )
    task_id = str((create_payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, create_payload

    task = _wait_for_tuning_task_terminal(client, task_id)
    assert str(task.get("status", "") or "") == "completed", task

    apply_payload = _assert_response_ok(
        client.post(
            f"/api/webui/memory/retrieval_tuning/tasks/{task_id}/apply-best",
        )
    )
    assert "applied" in apply_payload
def test_delete_module_end_to_end_preview_execute_restore(integration_state: Dict[str, Any]) -> None:
    """Source deletion can be previewed, executed, restored and then audited."""
    client = integration_state["client"]
    unique_token = integration_state["unique_token"]
    source_name = integration_state["source_name"]

    # Precondition: the seeded source holds at least one paragraph.
    before_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
    assert _source_paragraph_count(before_source_item) >= 1

    # Preview reports at least one paragraph for the source selector.
    preview_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/preview",
            json={
                "mode": "source",
                "selector": {"sources": [source_name]},
                "reason": "integration_delete_preview",
                "requested_by": "pytest_integration",
            },
        )
    )
    preview_counts = preview_payload.get("counts") or {}
    assert int(preview_counts.get("paragraphs", 0) or 0) >= 1, preview_payload

    # Execute the deletion; the operation id is needed for the restore below.
    execute_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/execute",
            json={
                "mode": "source",
                "selector": {"sources": [source_name]},
                "reason": "integration_delete_execute",
                "requested_by": "pytest_integration",
            },
        )
    )
    operation_id = str(execute_payload.get("operation_id", "") or "").strip()
    assert operation_id, execute_payload

    # The seed token must no longer appear in query results.
    after_delete_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/query/aggregate",
            params={"query": unique_token, "limit": 20},
        )
    )
    after_delete_hits = after_delete_payload.get("hits") or []
    after_delete_text = "\n".join(
        str(item.get("content", "") or "")
        for item in after_delete_hits
        if isinstance(item, dict)
    )
    assert unique_token not in after_delete_text
    assert int(execute_payload.get("deleted_paragraph_count", 0) or 0) >= 1, execute_payload

    # Restore the operation and wait for the source to list paragraphs again.
    _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/restore",
            json={
                "operation_id": operation_id,
                "requested_by": "pytest_integration",
            },
        )
    )
    restored_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
    assert _source_paragraph_count(restored_source_item) >= 1

    # The operation appears in the listing and its detail status is "restored".
    operations_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/delete/operations",
            params={"limit": 20, "mode": "source"},
        )
    )
    operation_items = operations_payload.get("items") or []
    operation_ids = {
        str(item.get("operation_id", "") or "")
        for item in operation_items
        if isinstance(item, dict)
    }
    assert operation_id in operation_ids

    operation_detail_payload = _assert_response_ok(client.get(f"/api/webui/memory/delete/operations/{operation_id}"))
    detail_operation = operation_detail_payload.get("operation") or {}
    assert str(detail_operation.get("status", "") or "") == "restored"