# Integration tests for the webui memory routers: import, retrieval,
# tuning, and delete flows against an isolated A-memorix host service.
from __future__ import annotations
|
||
|
||
from pathlib import Path
|
||
from time import monotonic, sleep
|
||
from typing import Any, Dict, Generator
|
||
from uuid import uuid4
|
||
|
||
import asyncio
|
||
import json
|
||
|
||
from fastapi import FastAPI
|
||
from fastapi.testclient import TestClient
|
||
import pytest
|
||
import tomlkit
|
||
|
||
from src.A_memorix import host_service as host_service_module
|
||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
|
||
from src.webui.dependencies import require_auth
|
||
from src.webui.routers import memory as memory_router_module
|
||
|
||
|
||
# Timeouts (seconds) for the polling helpers below: per-request, whole-import
# workflow, and whole-tuning workflow respectively.
# NOTE(review): REQUEST_TIMEOUT_SECONDS is not referenced in this file — confirm external use.
REQUEST_TIMEOUT_SECONDS = 30
IMPORT_TIMEOUT_SECONDS = 120
TUNING_TIMEOUT_SECONDS = 420

# Task statuses that end polling: the task will not change state again.
IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "cancelled"}
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}
|
||
|
||
|
||
class _FakeEmbeddingManager:
|
||
def __init__(self, dimension: int = 64) -> None:
|
||
self.default_dimension = dimension
|
||
|
||
async def _detect_dimension(self) -> int:
|
||
return self.default_dimension
|
||
|
||
async def encode(self, text: Any, **kwargs: Any) -> Any:
|
||
del kwargs
|
||
import numpy as np
|
||
|
||
def _encode_one(raw: Any) -> Any:
|
||
content = str(raw or "")
|
||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
||
for index, byte in enumerate(content.encode("utf-8")):
|
||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
||
norm = float(np.linalg.norm(vector))
|
||
if norm > 0:
|
||
vector /= norm
|
||
return vector
|
||
|
||
if isinstance(text, (list, tuple)):
|
||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
||
return _encode_one(text).astype(np.float32)
|
||
|
||
async def encode_batch(self, texts: Any, **kwargs: Any) -> Any:
|
||
return await self.encode(texts, **kwargs)
|
||
|
||
|
||
def _build_test_config(data_dir: Path) -> Dict[str, Any]:
|
||
return {
|
||
"storage": {
|
||
"data_dir": str(data_dir),
|
||
},
|
||
"advanced": {
|
||
"enable_auto_save": False,
|
||
},
|
||
"embedding": {
|
||
"dimension": 64,
|
||
"batch_size": 4,
|
||
"max_concurrent": 1,
|
||
"retry": {
|
||
"max_attempts": 1,
|
||
"min_wait_seconds": 0.1,
|
||
"max_wait_seconds": 0.2,
|
||
"backoff_multiplier": 1.0,
|
||
},
|
||
"fallback": {
|
||
"enabled": True,
|
||
"allow_metadata_only_write": True,
|
||
"probe_interval_seconds": 30,
|
||
},
|
||
"paragraph_vector_backfill": {
|
||
"enabled": False,
|
||
"interval_seconds": 60,
|
||
"batch_size": 32,
|
||
"max_retry": 2,
|
||
},
|
||
},
|
||
"retrieval": {
|
||
"enable_parallel": False,
|
||
"enable_ppr": False,
|
||
"top_k_paragraphs": 20,
|
||
"top_k_relations": 10,
|
||
"top_k_final": 10,
|
||
"alpha": 0.5,
|
||
"search": {
|
||
"smart_fallback": {
|
||
"enabled": True,
|
||
},
|
||
},
|
||
"sparse": {
|
||
"enabled": True,
|
||
"mode": "auto",
|
||
"candidate_k": 80,
|
||
"relation_candidate_k": 60,
|
||
},
|
||
"fusion": {
|
||
"method": "weighted_rrf",
|
||
"rrf_k": 60,
|
||
"vector_weight": 0.7,
|
||
"bm25_weight": 0.3,
|
||
},
|
||
},
|
||
"threshold": {
|
||
"percentile": 70.0,
|
||
"min_results": 1,
|
||
},
|
||
"web": {
|
||
"tuning": {
|
||
"enabled": True,
|
||
"poll_interval_ms": 300,
|
||
"max_queue_size": 4,
|
||
"default_objective": "balanced",
|
||
"default_intensity": "quick",
|
||
"default_sample_size": 4,
|
||
"default_top_k_eval": 5,
|
||
"eval_query_timeout_seconds": 1.0,
|
||
"llm_retry": {
|
||
"max_attempts": 1,
|
||
"min_wait_seconds": 0.1,
|
||
"max_wait_seconds": 0.2,
|
||
"backoff_multiplier": 1.0,
|
||
},
|
||
},
|
||
},
|
||
}
|
||
|
||
|
||
def _assert_response_ok(response: Any) -> Dict[str, Any]:
|
||
assert response.status_code == 200, response.text
|
||
payload = response.json()
|
||
assert payload.get("success", True) is True, payload
|
||
return payload
|
||
|
||
|
||
def _wait_for_import_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = IMPORT_TIMEOUT_SECONDS) -> Dict[str, Any]:
    """Poll the import-task endpoint until *task_id* reaches a terminal status.

    Returns the task dict; raises AssertionError if the timeout elapses first.
    """
    last_payload: Dict[str, Any] = {}
    deadline = monotonic() + timeout_seconds
    url = f"/api/webui/memory/import/tasks/{task_id}"
    while monotonic() < deadline:
        last_payload = _assert_response_ok(
            client.get(url, params={"include_chunks": True})
        )
        task = last_payload.get("task") or {}
        if str(task.get("status", "") or "") in IMPORT_TERMINAL_STATUSES:
            return task
        sleep(0.2)
    raise AssertionError(f"导入任务超时: task_id={task_id}, last_payload={last_payload}")
|
||
|
||
|
||
def _wait_for_tuning_task_terminal(client: TestClient, task_id: str, *, timeout_seconds: float = TUNING_TIMEOUT_SECONDS) -> Dict[str, Any]:
    """Poll the tuning-task endpoint until *task_id* reaches a terminal status.

    Returns the task dict; raises AssertionError if the timeout elapses first.
    """
    last_payload: Dict[str, Any] = {}
    deadline = monotonic() + timeout_seconds
    url = f"/api/webui/memory/retrieval_tuning/tasks/{task_id}"
    while monotonic() < deadline:
        last_payload = _assert_response_ok(
            client.get(url, params={"include_rounds": False})
        )
        task = last_payload.get("task") or {}
        if str(task.get("status", "") or "") in TUNING_TERMINAL_STATUSES:
            return task
        sleep(0.3)
    raise AssertionError(f"调优任务超时: task_id={task_id}, last_payload={last_payload}")
|
||
|
||
|
||
def _wait_for_query_hit(client: TestClient, query: str, *, timeout_seconds: float = 30.0) -> Dict[str, Any]:
    """Poll the aggregate query endpoint until at least one hit is returned.

    Returns the full payload; raises AssertionError on timeout.
    """
    last_payload: Dict[str, Any] = {}
    deadline = monotonic() + timeout_seconds
    while monotonic() < deadline:
        response = client.get(
            "/api/webui/memory/query/aggregate",
            params={"query": query, "limit": 20},
        )
        last_payload = _assert_response_ok(response)
        hits = last_payload.get("hits") or []
        if isinstance(hits, list) and hits:
            return last_payload
        sleep(0.2)
    raise AssertionError(f"检索命中超时: query={query}, last_payload={last_payload}")
|
||
|
||
|
||
def _get_source_item(client: TestClient, source_name: str) -> Dict[str, Any] | None:
    """Return the sources-list entry whose "source" equals *source_name*, or None."""
    listing = _assert_response_ok(client.get("/api/webui/memory/sources"))
    matches = (
        entry
        for entry in (listing.get("items") or [])
        if isinstance(entry, dict) and str(entry.get("source", "") or "") == source_name
    )
    return next(matches, None)
|
||
|
||
|
||
def _source_paragraph_count(item: Dict[str, Any] | None) -> int:
|
||
payload = item or {}
|
||
if "paragraph_count" in payload:
|
||
return int(payload.get("paragraph_count", 0) or 0)
|
||
return int(payload.get("count", 0) or 0)
|
||
|
||
|
||
def _wait_for_source_paragraph_count(
    client: TestClient,
    source_name: str,
    *,
    min_count: int,
    timeout_seconds: float = 30.0,
) -> Dict[str, Any]:
    """Poll the sources endpoint until *source_name* reports >= *min_count* paragraphs.

    Returns the matching source entry; raises AssertionError on timeout.
    """
    last_item: Dict[str, Any] = {}
    deadline = monotonic() + timeout_seconds
    while monotonic() < deadline:
        entry = _get_source_item(client, source_name)
        if _source_paragraph_count(entry) >= int(min_count):
            return entry or {}
        if entry:
            # Keep a snapshot for the timeout diagnostics below.
            last_item = dict(entry)
        sleep(0.2)
    raise AssertionError(
        f"等待来源段落计数超时: source={source_name}, min_count={min_count}, last_item={last_item}"
    )
|
||
|
||
|
||
def _create_multitype_upload_task(client: TestClient) -> str:
    """Upload a mixed batch of txt/md/json files and return the new import task id."""
    structured_json = {
        "paragraphs": [
            {
                "content": "Alice 携带地图前往火星港。",
                "source": "integration-upload-json",
                "entities": ["Alice", "地图", "火星港"],
                "relations": [
                    {"subject": "Alice", "predicate": "携带", "object": "地图"},
                    {"subject": "Alice", "predicate": "前往", "object": "火星港"},
                ],
            }
        ]
    }
    extra_json = {
        "paragraphs": [
            {
                "content": "Carol 记录了一条补充说明。",
                "source": "integration-upload-json-extra",
                "entities": ["Carol"],
                "relations": [],
            }
        ]
    }
    # Import options: plain-text mode, no LLM extraction, no dedupe.
    options = {
        "input_mode": "text",
        "llm_enabled": False,
        "file_concurrency": 2,
        "chunk_concurrency": 2,
        "dedupe_policy": "none",
    }
    attachments = [
        ("files", ("integration-notes.txt", "Alice 在测试环境记录了一条长期记忆。".encode("utf-8"), "text/plain")),
        ("files", ("integration-diary.md", "# 日志\nBob 与 Alice 讨论了导图。".encode("utf-8"), "text/markdown")),
        ("files", ("integration-structured.json", json.dumps(structured_json, ensure_ascii=False).encode("utf-8"), "application/json")),
        ("files", ("integration-extra.json", json.dumps(extra_json, ensure_ascii=False).encode("utf-8"), "application/json")),
    ]

    payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/import/upload",
            data={"payload_json": json.dumps(options, ensure_ascii=False)},
            files=attachments,
        )
    )
    task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, payload
    return task_id
|
||
|
||
|
||
def _create_seed_paste_task(client: TestClient, *, source: str, unique_token: str) -> str:
    """Paste two seed paragraphs tagged with *unique_token* and return the task id."""
    paragraphs = [
        {
            "content": f"Alice 在火星港携带地图并记录了口令 {unique_token}。",
            "source": source,
            "entities": ["Alice", "火星港", "地图"],
            "relations": [
                {"subject": "Alice", "predicate": "前往", "object": "火星港"},
                {"subject": "Alice", "predicate": "携带", "object": "地图"},
            ],
        },
        {
            "content": f"Bob 在火星港遇见 Alice,并重复口令 {unique_token}。",
            "source": source,
            "entities": ["Bob", "Alice", "火星港"],
            "relations": [
                {"subject": "Bob", "predicate": "遇见", "object": "Alice"},
                {"subject": "Bob", "predicate": "位于", "object": "火星港"},
            ],
        },
    ]
    payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/import/paste",
            json={
                "name": "integration-seed.json",
                "input_mode": "json",
                "llm_enabled": False,
                "content": json.dumps({"paragraphs": paragraphs}, ensure_ascii=False),
                "dedupe_policy": "none",
            },
        )
    )
    task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, payload
    return task_id
|
||
|
||
|
||
@pytest.fixture(scope="module")
def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dict[str, Any], None, None]:
    """Module-scoped fixture: boot an isolated memory app and seed it with data.

    Patches the host service config, embedding adapter, upload staging root
    and tuning artifacts root to temporary paths, runs one multi-file upload
    import and one JSON paste import, waits until the seeded token is
    retrievable, then yields the shared state dict (client, tasks, source
    name, token). Teardown stops the host service, clears its config cache
    and undoes all monkeypatches.

    Fix: setup and yield are wrapped in try/finally so that teardown also
    runs when setup fails before the yield (e.g. an import task times out);
    previously the monkeypatches and running service leaked into other
    test modules in that case.
    """
    tmp_root = tmp_path_factory.mktemp("memory_routes_integration")
    data_dir = (tmp_root / "data").resolve()
    staging_dir = (tmp_root / "upload_staging").resolve()
    artifacts_dir = (tmp_root / "artifacts").resolve()
    config_file = (tmp_root / "config" / "bot_config.toml").resolve()
    runtime_config = _build_test_config(data_dir)

    patches = pytest.MonkeyPatch()
    try:
        patches.setattr(host_service_module.a_memorix_host_service, "_read_config", lambda: dict(runtime_config))
        patches.setattr(host_service_module.a_memorix_host_service, "get_config_path", lambda: config_file)
        # Replace the real embedding backend with the deterministic fake.
        patches.setattr(
            kernel_module,
            "create_embedding_api_adapter",
            lambda **kwargs: _FakeEmbeddingManager(dimension=64),
        )
        patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
        patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)

        # Ensure no stale service state/config from a previous module survives.
        asyncio.run(host_service_module.a_memorix_host_service.stop())
        host_service_module.a_memorix_host_service._config_cache = None  # type: ignore[attr-defined]

        app = FastAPI()
        app.dependency_overrides[require_auth] = lambda: "ok"
        app.include_router(memory_router_module.router, prefix="/api/webui")
        app.include_router(memory_router_module.compat_router)

        unique_token = f"INTEG_TOKEN_{uuid4().hex[:12]}"
        source_name = f"integration-source-{uuid4().hex[:8]}"

        with TestClient(app) as client:
            upload_task_id = _create_multitype_upload_task(client)
            upload_task = _wait_for_import_task_terminal(client, upload_task_id)

            seed_task_id = _create_seed_paste_task(client, source=source_name, unique_token=unique_token)
            seed_task = _wait_for_import_task_terminal(client, seed_task_id)
            assert str(seed_task.get("status", "") or "") in {"completed", "completed_with_errors"}, seed_task

            # Seeded data must actually be retrievable before tests start.
            _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)

            yield {
                "client": client,
                "upload_task": upload_task,
                "seed_task": seed_task,
                "source_name": source_name,
                "unique_token": unique_token,
            }
    finally:
        # Runs on both normal teardown and setup failure, so monkeypatches and
        # the running host service never leak into other test modules.
        asyncio.run(host_service_module.a_memorix_host_service.stop())
        host_service_module.a_memorix_host_service._config_cache = None  # type: ignore[attr-defined]
        patches.undo()
|
||
|
||
|
||
def test_import_module_end_to_end_supports_multitype_upload(integration_state: Dict[str, Any]) -> None:
    """The upload import task must finish and report all four uploaded files."""
    upload_task = integration_state["upload_task"]

    assert str(upload_task.get("status", "") or "") in {"completed", "completed_with_errors"}, upload_task
    reported_files = upload_task.get("files") or []
    assert isinstance(reported_files, list)
    assert len(reported_files) >= 4

    reported_names = {str(entry.get("name", "") or "") for entry in reported_files if isinstance(entry, dict)}
    expected_names = {
        "integration-notes.txt",
        "integration-diary.md",
        "integration-structured.json",
        "integration-extra.json",
    }
    assert expected_names <= reported_names
|
||
|
||
|
||
def test_retrieval_module_end_to_end_queries_seeded_data(integration_state: Dict[str, Any]) -> None:
    """Aggregate query must surface the seeded token; graph search must yield entities."""
    client = integration_state["client"]
    unique_token = integration_state["unique_token"]

    aggregate_payload = _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
    hits = aggregate_payload.get("hits") or []
    combined_content = "\n".join(
        str(entry.get("content", "") or "") for entry in hits if isinstance(entry, dict)
    )
    assert unique_token in combined_content

    graph_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/graph/search",
            params={"query": "Alice", "limit": 20},
        )
    )
    graph_items = graph_payload.get("items") or []
    assert isinstance(graph_items, list)
    entity_hits = [
        entry
        for entry in graph_items
        if isinstance(entry, dict) and str(entry.get("type", "") or "") == "entity"
    ]
    assert entity_hits, graph_items
|
||
|
||
|
||
def test_tuning_module_end_to_end_create_and_apply_best(integration_state: Dict[str, Any]) -> None:
    """Create a quick tuning task, wait for completion, then apply its best config."""
    client = integration_state["client"]

    # Fixed seed keeps the quick tuning run deterministic across reruns.
    create_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/retrieval_tuning/tasks",
            json={
                "objective": "balanced",
                "intensity": "quick",
                "rounds": 2,
                "sample_size": 4,
                "top_k_eval": 5,
                "llm_enabled": False,
                "eval_query_timeout_seconds": 1.0,
                "seed": 20260403,
            },
        )
    )
    task_id = str((create_payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, create_payload

    finished_task = _wait_for_tuning_task_terminal(client, task_id)
    assert str(finished_task.get("status", "") or "") == "completed", finished_task

    apply_payload = _assert_response_ok(
        client.post(f"/api/webui/memory/retrieval_tuning/tasks/{task_id}/apply-best")
    )
    assert "applied" in apply_payload
|
||
|
||
|
||
def test_delete_module_end_to_end_preview_execute_restore(integration_state: Dict[str, Any]) -> None:
    """Full delete lifecycle: preview -> execute -> restore -> audit trail."""
    client = integration_state["client"]
    unique_token = integration_state["unique_token"]
    source_name = integration_state["source_name"]

    # The seeded source must be visible before any delete is attempted.
    before_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
    assert _source_paragraph_count(before_source_item) >= 1

    selector = {"sources": [source_name]}
    preview_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/preview",
            json={
                "mode": "source",
                "selector": selector,
                "reason": "integration_delete_preview",
                "requested_by": "pytest_integration",
            },
        )
    )
    preview_counts = preview_payload.get("counts") or {}
    assert int(preview_counts.get("paragraphs", 0) or 0) >= 1, preview_payload

    execute_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/execute",
            json={
                "mode": "source",
                "selector": selector,
                "reason": "integration_delete_execute",
                "requested_by": "pytest_integration",
            },
        )
    )
    operation_id = str(execute_payload.get("operation_id", "") or "").strip()
    assert operation_id, execute_payload

    # After execution the seeded token must no longer be retrievable.
    after_delete_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/query/aggregate",
            params={"query": unique_token, "limit": 20},
        )
    )
    remaining_text = "\n".join(
        str(entry.get("content", "") or "")
        for entry in (after_delete_payload.get("hits") or [])
        if isinstance(entry, dict)
    )
    assert unique_token not in remaining_text
    assert int(execute_payload.get("deleted_paragraph_count", 0) or 0) >= 1, execute_payload

    _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/restore",
            json={
                "operation_id": operation_id,
                "requested_by": "pytest_integration",
            },
        )
    )

    # Restore must bring the source's paragraphs back.
    restored_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
    assert _source_paragraph_count(restored_source_item) >= 1

    # The operation must appear in the audit listing and read back as restored.
    operations_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/delete/operations",
            params={"limit": 20, "mode": "source"},
        )
    )
    listed_ids = {
        str(entry.get("operation_id", "") or "")
        for entry in (operations_payload.get("items") or [])
        if isinstance(entry, dict)
    }
    assert operation_id in listed_ids

    operation_detail_payload = _assert_response_ok(client.get(f"/api/webui/memory/delete/operations/{operation_id}"))
    detail_operation = operation_detail_payload.get("operation") or {}
    assert str(detail_operation.get("status", "") or "") == "restored"
|