from __future__ import annotations

from pathlib import Path
from time import monotonic, sleep
from typing import Any, Dict, Generator
from uuid import uuid4
import asyncio
import json

from fastapi import FastAPI
from fastapi.testclient import TestClient
import pytest
import tomlkit

from src.A_memorix import host_service as host_service_module
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
from src.webui.dependencies import require_auth
from src.webui.routers import memory as memory_router_module


REQUEST_TIMEOUT_SECONDS = 30
IMPORT_TIMEOUT_SECONDS = 120
TUNING_TIMEOUT_SECONDS = 420
IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "cancelled"}
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}


class _FakeEmbeddingManager:
    """Deterministic stand-in for the embedding adapter used by these tests.

    Vectors are derived purely from the UTF-8 bytes of the input, so the same
    text always maps to the same unit-length vector and no external embedding
    service is contacted.
    """

    def __init__(self, dimension: int = 64) -> None:
        self.default_dimension = dimension

    async def _detect_dimension(self) -> int:
        """Return the configured vector dimension."""
        return self.default_dimension

    async def encode(self, text: Any, **kwargs: Any) -> Any:
        """Encode one text, or a list/tuple of texts, into float32 vectors.

        Args:
            text: A single value (stringified; falsy values become ``""``) or
                a list/tuple of such values.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            A 1-D array of length ``default_dimension`` for a single input, or
            a 2-D array with one row per item for a list/tuple. An empty
            list/tuple yields an array of shape ``(0, default_dimension)``.
        """
        del kwargs
        import numpy as np

        def _encode_one(raw: Any) -> Any:
            content = str(raw or "")
            vector = np.zeros(self.default_dimension, dtype=np.float32)
            # Fold every byte into a bucket so the result depends on both the
            # byte values and their positions.
            for index, byte in enumerate(content.encode("utf-8")):
                vector[index % self.default_dimension] += float((byte % 17) + 1)
            norm = float(np.linalg.norm(vector))
            if norm > 0:
                vector /= norm
            return vector

        if isinstance(text, (list, tuple)):
            if not text:
                # np.stack raises ValueError on an empty sequence; return an
                # empty batch with the right second dimension instead.
                return np.zeros((0, self.default_dimension), dtype=np.float32)
            return np.stack([_encode_one(item) for item in text]).astype(np.float32)
        return _encode_one(text).astype(np.float32)

    async def encode_batch(self, texts: Any, **kwargs: Any) -> Any:
        """Encode a batch of texts; delegates to :meth:`encode`."""
        return await self.encode(texts, **kwargs)


def _build_test_config(data_dir: Path) -> Dict[str, Any]:
    """Build the runtime configuration dict used by the integration fixture.

    Args:
        data_dir: Directory stored (as a string) under ``storage.data_dir``.

    Returns:
        A freshly built nested configuration dictionary.
    """
    return {
        "storage": {
            "data_dir": str(data_dir),
        },
        "advanced": {
            "enable_auto_save": False,
        },
        "embedding": {
            "dimension": 64,
            "batch_size": 4,
            "max_concurrent": 1,
            "retry": {
                "max_attempts": 1,
                "min_wait_seconds": 0.1,
                "max_wait_seconds": 0.2,
                "backoff_multiplier": 1.0,
            },
            "fallback": {
                "enabled": True,
                "allow_metadata_only_write": True,
                "probe_interval_seconds": 30,
            },
            "paragraph_vector_backfill": {
                "enabled": False,
                "interval_seconds": 60,
                "batch_size": 32,
                "max_retry": 2,
            },
        },
        "retrieval": {
            "enable_parallel": False,
            "enable_ppr": False,
            "top_k_paragraphs": 20,
            "top_k_relations": 10,
            "top_k_final": 10,
            "alpha": 0.5,
            "search": {
                "smart_fallback": {
                    "enabled": True,
                },
            },
            "sparse": {
                "enabled": True,
                "mode": "auto",
                "candidate_k": 80,
                "relation_candidate_k": 60,
            },
            "fusion": {
                "method": "weighted_rrf",
                "rrf_k": 60,
                "vector_weight": 0.7,
                "bm25_weight": 0.3,
            },
        },
        "threshold": {
            "percentile": 70.0,
            "min_results": 1,
        },
        "web": {
            "tuning": {
                "enabled": True,
                "poll_interval_ms": 300,
                "max_queue_size": 4,
                "default_objective": "balanced",
                "default_intensity": "quick",
                "default_sample_size": 4,
                "default_top_k_eval": 5,
                "eval_query_timeout_seconds": 1.0,
                "llm_retry": {
                    "max_attempts": 1,
                    "min_wait_seconds": 0.1,
                    "max_wait_seconds": 0.2,
                    "backoff_multiplier": 1.0,
                },
            },
        },
    }


def _assert_response_ok(response: Any) -> Dict[str, Any]:
    """Assert an HTTP 200 response whose JSON body does not report failure.

    A missing ``success`` key is treated as success.

    Returns:
        The decoded JSON payload.
    """
    assert response.status_code == 200, response.text
    payload = response.json()
    assert payload.get("success", True) is True, payload
    return payload


def _poll_task_until_terminal(
    client: TestClient,
    url: str,
    *,
    params: Dict[str, Any],
    task_id: str,
    terminal_statuses: Any,
    poll_interval_seconds: float,
    timeout_seconds: float,
    timeout_label: str,
) -> Dict[str, Any]:
    """Poll a task endpoint until the task's status is in ``terminal_statuses``.

    Returns:
        The ``task`` object from the first payload with a terminal status.

    Raises:
        AssertionError: If a poll response is not OK, or if no terminal status
            is observed before ``timeout_seconds`` elapse.
    """
    deadline = monotonic() + timeout_seconds
    last_payload: Dict[str, Any] = {}
    while monotonic() < deadline:
        payload = _assert_response_ok(client.get(url, params=params))
        last_payload = payload
        task = payload.get("task") or {}
        status = str(task.get("status", "") or "")
        if status in terminal_statuses:
            return task
        sleep(poll_interval_seconds)
    raise AssertionError(f"{timeout_label}: task_id={task_id}, last_payload={last_payload}")


def _wait_for_import_task_terminal(
    client: TestClient,
    task_id: str,
    *,
    timeout_seconds: float = IMPORT_TIMEOUT_SECONDS,
) -> Dict[str, Any]:
    """Wait for an import task to reach one of ``IMPORT_TERMINAL_STATUSES``.

    Returns:
        The terminal ``task`` object.

    Raises:
        AssertionError: On a failed poll response or on timeout.
    """
    return _poll_task_until_terminal(
        client,
        f"/api/webui/memory/import/tasks/{task_id}",
        params={"include_chunks": True},
        task_id=task_id,
        terminal_statuses=IMPORT_TERMINAL_STATUSES,
        poll_interval_seconds=0.2,
        timeout_seconds=timeout_seconds,
        timeout_label="导入任务超时",
    )


def _wait_for_tuning_task_terminal(
    client: TestClient,
    task_id: str,
    *,
    timeout_seconds: float = TUNING_TIMEOUT_SECONDS,
) -> Dict[str, Any]:
    """Wait for a tuning task to reach one of ``TUNING_TERMINAL_STATUSES``.

    Returns:
        The terminal ``task`` object.

    Raises:
        AssertionError: On a failed poll response or on timeout.
    """
    return _poll_task_until_terminal(
        client,
        f"/api/webui/memory/retrieval_tuning/tasks/{task_id}",
        params={"include_rounds": False},
        task_id=task_id,
        terminal_statuses=TUNING_TERMINAL_STATUSES,
        poll_interval_seconds=0.3,
        timeout_seconds=timeout_seconds,
        timeout_label="调优任务超时",
    )


def _wait_for_query_hit(client: TestClient, query: str, *, timeout_seconds: float = 30.0) -> Dict[str, Any]:
    """Poll the aggregate query endpoint until it returns at least one hit.

    Returns:
        The first payload whose ``hits`` is a non-empty list.

    Raises:
        AssertionError: On a failed poll response or on timeout.
    """
    deadline = monotonic() + timeout_seconds
    last_payload: Dict[str, Any] = {}
    while monotonic() < deadline:
        payload = _assert_response_ok(
            client.get(
                "/api/webui/memory/query/aggregate",
                params={"query": query, "limit": 20},
            )
        )
        last_payload = payload
        hits = payload.get("hits") or []
        if isinstance(hits, list) and len(hits) > 0:
            return payload
        sleep(0.2)
    raise AssertionError(f"检索命中超时: query={query}, last_payload={last_payload}")


def _get_source_item(client: TestClient, source_name: str) -> Dict[str, Any] | None:
    """Return the sources-list entry whose ``source`` equals ``source_name``.

    Returns:
        The matching dict item, or ``None`` when no entry matches.
    """
    payload = _assert_response_ok(client.get("/api/webui/memory/sources"))
    items = payload.get("items") or []
    for item in items:
        if not isinstance(item, dict):
            continue
        if str(item.get("source", "") or "") == source_name:
            return item
    return None


def _source_paragraph_count(item: Dict[str, Any] | None) -> int:
    """Extract a paragraph count from a source item.

    Reads ``paragraph_count`` when that key is present and falls back to
    ``count`` otherwise; a missing item or empty value counts as 0.
    """
    payload = item or {}
    if "paragraph_count" in payload:
        return int(payload.get("paragraph_count", 0) or 0)
    return int(payload.get("count", 0) or 0)


def _wait_for_source_paragraph_count(
    client: TestClient,
    source_name: str,
    *,
    min_count: int,
    timeout_seconds: float = 30.0,
) -> Dict[str, Any]:
    """Poll the sources list until ``source_name`` has at least ``min_count`` paragraphs.

    Returns:
        The matching source item (or ``{}`` when the threshold is met with no
        item, which can only happen for ``min_count <= 0``).

    Raises:
        AssertionError: On a failed poll response or on timeout.
    """
    deadline = monotonic() + timeout_seconds
    last_item: Dict[str, Any] = {}
    while monotonic() < deadline:
        item = _get_source_item(client, source_name)
        count = _source_paragraph_count(item)
        if count >= int(min_count):
            return item or {}
        if item:
            # Keep the most recent non-empty item for the timeout message.
            last_item = dict(item)
        sleep(0.2)
    raise AssertionError(
        f"等待来源段落计数超时: source={source_name}, min_count={min_count}, last_item={last_item}"
    )


def _create_multitype_upload_task(client: TestClient) -> str:
    """Upload one txt, one md and two json files as a single import task.

    Returns:
        The non-empty task id reported by the upload endpoint.
    """
    structured_json = {
        "paragraphs": [
            {
                "content": "Alice 携带地图前往火星港。",
                "source": "integration-upload-json",
                "entities": ["Alice", "地图", "火星港"],
                "relations": [
                    {"subject": "Alice", "predicate": "携带", "object": "地图"},
                    {"subject": "Alice", "predicate": "前往", "object": "火星港"},
                ],
            }
        ]
    }
    extra_json = {
        "paragraphs": [
            {
                "content": "Carol 记录了一条补充说明。",
                "source": "integration-upload-json-extra",
                "entities": ["Carol"],
                "relations": [],
            }
        ]
    }
    payload_json = json.dumps(
        {
            "input_mode": "text",
            "llm_enabled": False,
            "file_concurrency": 2,
            "chunk_concurrency": 2,
            "dedupe_policy": "none",
        },
        ensure_ascii=False,
    )
    files = [
        ("files", ("integration-notes.txt", "Alice 在测试环境记录了一条长期记忆。".encode("utf-8"), "text/plain")),
        ("files", ("integration-diary.md", "# 日志\nBob 与 Alice 讨论了导图。".encode("utf-8"), "text/markdown")),
        (
            "files",
            (
                "integration-structured.json",
                json.dumps(structured_json, ensure_ascii=False).encode("utf-8"),
                "application/json",
            ),
        ),
        (
            "files",
            (
                "integration-extra.json",
                json.dumps(extra_json, ensure_ascii=False).encode("utf-8"),
                "application/json",
            ),
        ),
    ]
    response = client.post(
        "/api/webui/memory/import/upload",
        data={"payload_json": payload_json},
        files=files,
    )
    payload = _assert_response_ok(response)
    task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, payload
    return task_id


def _create_seed_paste_task(client: TestClient, *, source: str, unique_token: str) -> str:
    """Submit a two-paragraph JSON paste import tagged with ``unique_token``.

    Both paragraphs use ``source`` as their source name and embed
    ``unique_token`` in their content so later queries can find them.

    Returns:
        The non-empty task id reported by the paste endpoint.
    """
    seed_payload = {
        "paragraphs": [
            {
                "content": f"Alice 在火星港携带地图并记录了口令 {unique_token}。",
                "source": source,
                "entities": ["Alice", "火星港", "地图"],
                "relations": [
                    {"subject": "Alice", "predicate": "前往", "object": "火星港"},
                    {"subject": "Alice", "predicate": "携带", "object": "地图"},
                ],
            },
            {
                "content": f"Bob 在火星港遇见 Alice,并重复口令 {unique_token}。",
                "source": source,
                "entities": ["Bob", "Alice", "火星港"],
                "relations": [
                    {"subject": "Bob", "predicate": "遇见", "object": "Alice"},
                    {"subject": "Bob", "predicate": "位于", "object": "火星港"},
                ],
            },
        ]
    }
    response = client.post(
        "/api/webui/memory/import/paste",
        json={
            "name": "integration-seed.json",
            "input_mode": "json",
            "llm_enabled": False,
            "content": json.dumps(seed_payload, ensure_ascii=False),
            "dedupe_policy": "none",
        },
    )
    payload = _assert_response_ok(response)
    task_id = str((payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, payload
    return task_id


@pytest.fixture(scope="module")
def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dict[str, Any], None, None]:
    """Module-scoped fixture: patched host service, test app and seeded data.

    Patches the host service config, the embedding adapter factory, the upload
    staging root and the tuning artifacts root to point at temporary
    locations, then runs one multi-file upload import and one seed paste
    import before yielding.

    Yields:
        A dict with ``client``, ``upload_task``, ``seed_task``,
        ``source_name`` and ``unique_token``.
    """
    tmp_root = tmp_path_factory.mktemp("memory_routes_integration")
    data_dir = (tmp_root / "data").resolve()
    staging_dir = (tmp_root / "upload_staging").resolve()
    artifacts_dir = (tmp_root / "artifacts").resolve()
    config_file = (tmp_root / "config" / "bot_config.toml").resolve()
    runtime_config = _build_test_config(data_dir)
    host_service = host_service_module.a_memorix_host_service

    # The context manager guarantees the patches are undone even when setup
    # fails before the yield (e.g. an import task times out); otherwise the
    # module-level patches would leak into the rest of the test session.
    with pytest.MonkeyPatch.context() as patches:
        patches.setattr(host_service, "_read_config", lambda: dict(runtime_config))
        patches.setattr(host_service, "get_config_path", lambda: config_file)
        patches.setattr(
            kernel_module,
            "create_embedding_api_adapter",
            lambda **kwargs: _FakeEmbeddingManager(dimension=64),
        )
        patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
        patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)

        try:
            # Stop any already-running service and drop its cached config so
            # the patched configuration is the one that gets loaded.
            asyncio.run(host_service.stop())
            host_service._config_cache = None  # type: ignore[attr-defined]

            app = FastAPI()
            app.dependency_overrides[require_auth] = lambda: "ok"
            app.include_router(memory_router_module.router, prefix="/api/webui")
            app.include_router(memory_router_module.compat_router)

            unique_token = f"INTEG_TOKEN_{uuid4().hex[:12]}"
            source_name = f"integration-source-{uuid4().hex[:8]}"

            with TestClient(app) as client:
                upload_task_id = _create_multitype_upload_task(client)
                upload_task = _wait_for_import_task_terminal(client, upload_task_id)

                seed_task_id = _create_seed_paste_task(client, source=source_name, unique_token=unique_token)
                seed_task = _wait_for_import_task_terminal(client, seed_task_id)
                assert str(seed_task.get("status", "") or "") in {"completed", "completed_with_errors"}, seed_task

                _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)

                yield {
                    "client": client,
                    "upload_task": upload_task,
                    "seed_task": seed_task,
                    "source_name": source_name,
                    "unique_token": unique_token,
                }
        finally:
            # Runs on normal teardown and on setup failure alike, before the
            # patches are undone by the enclosing context manager.
            asyncio.run(host_service.stop())
            host_service._config_cache = None  # type: ignore[attr-defined]


def test_import_module_end_to_end_supports_multitype_upload(integration_state: Dict[str, Any]) -> None:
    """The upload task finished and lists all four uploaded file names."""
    upload_task = integration_state["upload_task"]
    assert str(upload_task.get("status", "") or "") in {"completed", "completed_with_errors"}, upload_task

    files = upload_task.get("files") or []
    assert isinstance(files, list)
    assert len(files) >= 4

    file_names = {str(item.get("name", "") or "") for item in files if isinstance(item, dict)}
    assert "integration-notes.txt" in file_names
    assert "integration-diary.md" in file_names
    assert "integration-structured.json" in file_names
    assert "integration-extra.json" in file_names


def test_retrieval_module_end_to_end_queries_seeded_data(integration_state: Dict[str, Any]) -> None:
    """Aggregate query finds the seeded token; graph search returns an entity."""
    client = integration_state["client"]
    unique_token = integration_state["unique_token"]

    aggregate_payload = _wait_for_query_hit(client, unique_token, timeout_seconds=45.0)
    hits = aggregate_payload.get("hits") or []
    joined_content = "\n".join(str(item.get("content", "") or "") for item in hits if isinstance(item, dict))
    assert unique_token in joined_content

    graph_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/graph/search",
            params={"query": "Alice", "limit": 20},
        )
    )
    graph_items = graph_payload.get("items") or []
    assert isinstance(graph_items, list)
    assert any(
        str(item.get("type", "") or "") == "entity" for item in graph_items if isinstance(item, dict)
    ), graph_items


def test_tuning_module_end_to_end_create_and_apply_best(integration_state: Dict[str, Any]) -> None:
    """A quick tuning task completes and its best result can be applied."""
    client = integration_state["client"]

    create_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/retrieval_tuning/tasks",
            json={
                "objective": "balanced",
                "intensity": "quick",
                "rounds": 2,
                "sample_size": 4,
                "top_k_eval": 5,
                "llm_enabled": False,
                "eval_query_timeout_seconds": 1.0,
                "seed": 20260403,
            },
        )
    )
    task_id = str((create_payload.get("task") or {}).get("task_id") or "").strip()
    assert task_id, create_payload

    task = _wait_for_tuning_task_terminal(client, task_id)
    assert str(task.get("status", "") or "") == "completed", task

    apply_payload = _assert_response_ok(
        client.post(
            f"/api/webui/memory/retrieval_tuning/tasks/{task_id}/apply-best",
        )
    )
    assert "applied" in apply_payload


def test_delete_module_end_to_end_preview_execute_restore(integration_state: Dict[str, Any]) -> None:
    """Delete-by-source: preview, execute, verify removal, restore, verify record."""
    client = integration_state["client"]
    unique_token = integration_state["unique_token"]
    source_name = integration_state["source_name"]

    before_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
    assert _source_paragraph_count(before_source_item) >= 1

    preview_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/preview",
            json={
                "mode": "source",
                "selector": {"sources": [source_name]},
                "reason": "integration_delete_preview",
                "requested_by": "pytest_integration",
            },
        )
    )
    preview_counts = preview_payload.get("counts") or {}
    assert int(preview_counts.get("paragraphs", 0) or 0) >= 1, preview_payload

    execute_payload = _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/execute",
            json={
                "mode": "source",
                "selector": {"sources": [source_name]},
                "reason": "integration_delete_execute",
                "requested_by": "pytest_integration",
            },
        )
    )
    operation_id = str(execute_payload.get("operation_id", "") or "").strip()
    assert operation_id, execute_payload

    # After deletion the seeded token must no longer appear in query results.
    after_delete_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/query/aggregate",
            params={"query": unique_token, "limit": 20},
        )
    )
    after_delete_hits = after_delete_payload.get("hits") or []
    after_delete_text = "\n".join(
        str(item.get("content", "") or "") for item in after_delete_hits if isinstance(item, dict)
    )
    assert unique_token not in after_delete_text
    assert int(execute_payload.get("deleted_paragraph_count", 0) or 0) >= 1, execute_payload

    _assert_response_ok(
        client.post(
            "/api/webui/memory/delete/restore",
            json={
                "operation_id": operation_id,
                "requested_by": "pytest_integration",
            },
        )
    )

    restored_source_item = _wait_for_source_paragraph_count(client, source_name, min_count=1, timeout_seconds=45.0)
    assert _source_paragraph_count(restored_source_item) >= 1

    operations_payload = _assert_response_ok(
        client.get(
            "/api/webui/memory/delete/operations",
            params={"limit": 20, "mode": "source"},
        )
    )
    operation_items = operations_payload.get("items") or []
    operation_ids = {
        str(item.get("operation_id", "") or "") for item in operation_items if isinstance(item, dict)
    }
    assert operation_id in operation_ids

    operation_detail_payload = _assert_response_ok(client.get(f"/api/webui/memory/delete/operations/{operation_id}"))
    detail_operation = operation_detail_payload.get("operation") or {}
    assert str(detail_operation.get("status", "") or "") == "restored"