chore: import deployable mai-bot source tree
This commit is contained in:
@@ -0,0 +1,398 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import pickle
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Session, create_engine
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
IMPORT_ERROR: str | None = None
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
from src.A_memorix.core.utils import summary_importer as summary_importer_module
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
||||
from src.common.message_repository import count_messages
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.services import memory_flow_service as memory_flow_service_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services import send_service
|
||||
except SystemExit as exc:
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
kernel_module = None # type: ignore[assignment]
|
||||
SDKMemoryKernel = None # type: ignore[assignment]
|
||||
summary_importer_module = None # type: ignore[assignment]
|
||||
BotChatSession = None # type: ignore[assignment]
|
||||
SessionMessage = None # type: ignore[assignment]
|
||||
MessageInfo = None # type: ignore[assignment]
|
||||
UserInfo = None # type: ignore[assignment]
|
||||
MessageSequence = None # type: ignore[assignment]
|
||||
TextComponent = None # type: ignore[assignment]
|
||||
database_module = None # type: ignore[assignment]
|
||||
create_database_migration_bootstrapper = None # type: ignore[assignment]
|
||||
count_messages = None # type: ignore[assignment]
|
||||
TaskConfig = None # type: ignore[assignment]
|
||||
memory_flow_service_module = None # type: ignore[assignment]
|
||||
memory_service_module = None # type: ignore[assignment]
|
||||
send_service = None # type: ignore[assignment]
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
|
||||
class _FakeEmbeddingManager:
    """Deterministic in-memory embedder used in place of a real embedding API.

    Each text is hashed byte-by-byte into a fixed-size float32 vector and
    L2-normalised, so equal inputs always give equal vectors.
    """

    def __init__(self, dimension: int = 8) -> None:
        self.default_dimension = dimension

    async def _detect_dimension(self) -> int:
        """Return the configured dimension; nothing is actually probed."""
        return self.default_dimension

    async def encode(self, text: Any) -> np.ndarray:
        """Encode one text, or a list/tuple of texts, as float32 vectors.

        A list/tuple input yields a stacked 2-D array (one row per item);
        any other input yields a single 1-D vector.
        """

        def _embed(value: Any) -> np.ndarray:
            size = self.default_dimension
            buckets = np.zeros(size, dtype=np.float32)
            payload = str(value or "").encode("utf-8")
            # Fold every byte into a bucket chosen by its position.
            for position, octet in enumerate(payload):
                buckets[position % size] += float((octet % 17) + 1)
            length = float(np.linalg.norm(buckets))
            # An all-zero vector (empty text) is returned unnormalised.
            if length > 0:
                buckets /= length
            return buckets

        if not isinstance(text, (list, tuple)):
            return _embed(text).astype(np.float32)
        rows = [_embed(entry) for entry in text]
        return np.stack(rows).astype(np.float32)
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
    """Runtime-manager stand-in that forwards component calls to a kernel object."""

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(
        self,
        component_name: str,
        args: Dict[str, Any] | None,
        *,
        timeout_ms: int = 30000,
    ) -> Any:
        """Call ``kernel.<component_name>(**args)`` and await it when needed.

        ``timeout_ms`` is accepted for signature compatibility and ignored.
        A missing attribute on the kernel raises ``AttributeError``.
        """
        _ = timeout_ms
        keyword_args = args if args else {}
        outcome = getattr(self.kernel, component_name)(**keyword_args)
        if inspect.isawaitable(outcome):
            return await outcome
        return outcome
|
||||
|
||||
|
||||
class _NoopRuntimeManager:
    """Runtime manager whose hooks never abort and never alter their arguments."""

    async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
        """Return a result object with ``aborted=False`` and the kwargs passed in."""
        _ = hook_name
        hook_result = SimpleNamespace(aborted=False, kwargs=kwargs)
        return hook_result
|
||||
|
||||
|
||||
class _FakePlatformIOManager:
    """Platform IO stub that reports every send as delivered by a fake QQ driver."""

    def __init__(self) -> None:
        # Number of times the send pipeline readiness check was requested.
        self.ensure_calls = 0

    async def ensure_send_pipeline_ready(self) -> None:
        """Record the readiness request; there is no real pipeline to prepare."""
        self.ensure_calls = self.ensure_calls + 1

    def build_route_key_from_message(self, message: Any) -> Any:
        """Return a route key for platform ``qq`` regardless of the message."""
        _ = message
        return SimpleNamespace(platform="qq")

    async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
        """Return a successful delivery result carrying one fixed receipt."""
        _ = (message, metadata)
        receipt = SimpleNamespace(
            driver_id="plugin.qq.sender",
            external_message_id="real-message-id",
            metadata={},
        )
        return SimpleNamespace(
            has_success=True,
            sent_receipts=[receipt],
            failed_receipts=[],
            route_key=route_key,
        )
|
||||
|
||||
|
||||
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Point ``database_module`` at a fresh SQLite file under ``tmp_path``.

    Creates ``<tmp_path>/main_db/MaiBot.db``, builds an engine, session
    factory and migration bootstrapper for it, and monkeypatches the
    corresponding module-level attributes of ``database_module``.
    """
    sqlite_dir = (tmp_path / "main_db").resolve()
    sqlite_dir.mkdir(parents=True, exist_ok=True)
    sqlite_path = sqlite_dir / "MaiBot.db"
    sqlite_url = f"sqlite:///{sqlite_path}"

    # Best-effort release of whatever engine the module currently holds.
    try:
        database_module.engine.dispose()
    except Exception:
        pass

    temp_engine = create_engine(
        sqlite_url,
        echo=False,
        connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )
    session_factory = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=temp_engine,
        class_=Session,
    )

    replacements = {
        "_DB_DIR": sqlite_dir,
        "_DB_FILE": sqlite_path,
        "DATABASE_URL": sqlite_url,
        "engine": temp_engine,
        "SessionLocal": session_factory,
        "_migration_bootstrapper": create_database_migration_bootstrapper(temp_engine),
        # Reset the flag; presumably this makes the module re-run its setup — verify.
        "_db_initialized": False,
    }
    for attribute, value in replacements.items():
        monkeypatch.setattr(database_module, attribute, value, raising=False)
|
||||
|
||||
|
||||
def _build_incoming_message(
    *,
    session_id: str,
    user_id: str,
    text: str,
    timestamp: datetime | None = None,
) -> SessionMessage:
    """Build an initialised plain-text ``SessionMessage`` on platform ``qq``.

    ``timestamp`` defaults to the current time. All classification flags
    (mention, at, emoji, picture, command, notify) are set to ``False``.
    """
    sender = UserInfo(
        user_id=user_id,
        user_nickname="测试用户",
        user_cardname="测试用户",
    )
    incoming = SessionMessage(
        message_id="incoming-message-id",
        timestamp=timestamp or datetime.now(),
        platform="qq",
    )
    incoming.message_info = MessageInfo(user_info=sender, additional_config={})
    incoming.raw_message = MessageSequence(components=[TextComponent(text=text)])
    incoming.session_id = session_id
    incoming.reply_to = None
    for flag_name in (
        "is_mentioned",
        "is_at",
        "is_emoji",
        "is_picture",
        "is_command",
        "is_notify",
    ):
        setattr(incoming, flag_name, False)
    incoming.processed_plain_text = text
    incoming.initialized = True
    return incoming
|
||||
|
||||
|
||||
async def _wait_until(
    predicate: Callable[[], Any],
    *,
    timeout_seconds: float = 10.0,
    interval_seconds: float = 0.05,
    description: str,
) -> Any:
    """Poll ``predicate`` until it yields a truthy value and return that value.

    ``predicate`` may return a plain value or an awaitable. The effective
    timeout is never shorter than 0.5 seconds.

    Raises:
        AssertionError: if the deadline passes without a truthy result.
    """
    loop = asyncio.get_running_loop()
    expires_at = loop.time() + max(0.5, float(timeout_seconds))
    while True:
        if loop.time() >= expires_at:
            raise AssertionError(f"等待超时: {description}")
        outcome = predicate()
        if inspect.isawaitable(outcome):
            outcome = await outcome
        if outcome:
            return outcome
        await asyncio.sleep(interval_seconds)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_text_to_stream_triggers_real_chat_summary_writeback(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """Send a reply through ``send_service`` and expect a chat summary in the kernel.

    The test stores one incoming message, sends one bot reply, and then
    waits for a paragraph with source ``chat_summary:test-session`` to
    appear in the kernel's metadata store. Embedding, LLM generation,
    platform IO and the main database are all replaced by local fakes.
    """
    _install_temp_main_database(monkeypatch, tmp_path)

    fake_embedding_manager = _FakeEmbeddingManager()
    captured_prompts: List[str] = []
    fixed_send_timestamp = 1_777_000_000.0

    async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]:
        """Return a passing self-check report whose dimensions match the fake embedder."""
        del kwargs
        return {
            "ok": True,
            "message": "ok",
            "configured_dimension": fake_embedding_manager.default_dimension,
            "requested_dimension": fake_embedding_manager.default_dimension,
            "vector_store_dimension": fake_embedding_manager.default_dimension,
            "detected_dimension": fake_embedding_manager.default_dimension,
            "encoded_dimension": fake_embedding_manager.default_dimension,
            "elapsed_ms": 0.0,
            "sample_text": "test",
            "checked_at": datetime.now().timestamp(),
        }

    async def _fake_generate(request: Any) -> Any:
        """Record the prompt and answer with a fixed JSON summary payload."""
        captured_prompts.append(str(getattr(request, "prompt", "") or ""))
        return SimpleNamespace(
            success=True,
            completion=SimpleNamespace(
                response=json.dumps(
                    {
                        "summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。",
                        "entities": ["绿色围巾"],
                        "relations": [],
                    },
                    ensure_ascii=False,
                )
            ),
        )

    # Embedding layer: hand out the fake embedder and skip the real self check.
    monkeypatch.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: fake_embedding_manager,
    )
    monkeypatch.setattr(
        kernel_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    monkeypatch.setattr(
        summary_importer_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )

    # LLM layer used by the summary importer: one fake task and a canned generator.
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "get_available_models",
        lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])},
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "resolve_task_name_from_model_config",
        lambda model_config: "utils",
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "generate",
        _fake_generate,
    )

    # Freeze the clock seen by both the send path and the summary importer.
    monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp)
    monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp)

    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
            "advanced": {"enable_auto_save": False},
            "embedding": {"dimension": fake_embedding_manager.default_dimension},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
            "summarization": {"model_name": ["utils"]},
        },
    )

    service = memory_flow_service_module.MemoryAutomationService()
    fake_platform_io_manager = _FakePlatformIOManager()

    async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]:
        """Report that nothing was rebuilt, echoing the requested sources."""
        return {
            "rebuilt": 0,
            "items": [],
            "failures": [],
            "sources": list(sources),
        }

    monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources)
    monkeypatch.setattr(
        memory_service_module,
        "a_memorix_host_service",
        _KernelBackedRuntimeManager(kernel),
    )
    monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service)
    monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager())
    monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager)
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    # Only "test-session" resolves to a chat session; every other id yields None.
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: (
            BotChatSession(
                session_id="test-session",
                platform="qq",
                user_id="target-user",
                group_id=None,
            )
            if stream_id == "test-session"
            else None
        ),
    )

    # Enable summary writeback with a threshold of 2 messages; person-fact
    # writeback is switched off.
    integration_config = memory_flow_service_module.global_config.a_memorix.integration
    monkeypatch.setattr(integration_config, "chat_summary_writeback_enabled", True, raising=False)
    monkeypatch.setattr(integration_config, "chat_summary_writeback_message_threshold", 2, raising=False)
    monkeypatch.setattr(integration_config, "chat_summary_writeback_context_length", 10, raising=False)
    monkeypatch.setattr(integration_config, "person_fact_writeback_enabled", False, raising=False)

    def _trigger_message_count(item: Any) -> int:
        """Extract ``trigger_message_count`` from a paragraph's metadata (0 if absent)."""
        raw_metadata = item.get("metadata")
        if isinstance(raw_metadata, (bytes, bytearray)):
            # NOTE(review): pickle is tolerable only because this blob comes from
            # the store populated during this test — never reuse on external data.
            metadata = pickle.loads(raw_metadata)
        else:
            metadata = raw_metadata or {}
        return int(metadata.get("trigger_message_count", 0) or 0)

    await kernel.initialize()

    try:
        # The user message is stamped one second before the frozen send time.
        incoming_message = _build_incoming_message(
            session_id="test-session",
            user_id="target-user",
            text="我最近买了一条绿色围巾。",
            timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1),
        )
        with database_module.get_db_session() as session:
            session.add(incoming_message.to_db_instance())

        sent_message = await send_service.text_to_stream_with_message(
            text="好的,我会记住你最近买了绿色围巾。",
            stream_id="test-session",
            storage_message=True,
        )

        assert sent_message is not None
        assert sent_message.message_id == "real-message-id"
        assert fake_platform_io_manager.ensure_calls == 1
        assert count_messages(session_id="test-session") == 2

        paragraphs = await _wait_until(
            lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"),
            description="等待聊天摘要写回到 A_memorix",
        )

        assert captured_prompts
        assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
        assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
        assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
        assert any(_trigger_message_count(item) == 2 for item in paragraphs)
        assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
    finally:
        # Chain the teardown so a failure in one step cannot skip the later ones.
        try:
            await service.shutdown()
        finally:
            try:
                await kernel.shutdown()
            finally:
                try:
                    database_module.engine.dispose()
                except Exception:
                    pass
|
||||
191
pytests/A_memorix_test/test_embedding_dimension_control.py
Normal file
191
pytests/A_memorix_test/test_embedding_dimension_control.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from A_memorix.core.embedding import api_adapter as api_adapter_module
|
||||
from A_memorix.core.embedding.api_adapter import EmbeddingAPIAdapter
|
||||
from A_memorix.core.utils.runtime_self_check import run_embedding_runtime_self_check
|
||||
|
||||
|
||||
class _FakeEmbeddingClient:
    """Embedding client stub that records requests and returns all-ones vectors."""

    def __init__(self, *, natural_dimension: int = 12) -> None:
        # Vector length used when the request asks for no particular size.
        self.natural_dimension = int(natural_dimension)
        self.requests = []

    async def get_embedding(self, request):
        """Return a vector sized by the request's dimension parameter.

        ``dimensions`` is consulted first, then ``output_dimensionality``;
        a missing or falsy value falls back to ``natural_dimension``.
        """
        self.requests.append(request)
        params = request.extra_params
        size = params.get("dimensions")
        if size is None:
            size = params.get("output_dimensionality")
        if not size:
            size = self.natural_dimension
        return SimpleNamespace(embedding=[1.0 for _ in range(int(size))])
|
||||
|
||||
|
||||
def _build_adapter(
    monkeypatch: pytest.MonkeyPatch,
    *,
    client_type: str,
    configured_dimension: int = 1024,
    effective_dimension: int | None = None,
    model_extra_params: dict | None = None,
):
    """Create an ``EmbeddingAPIAdapter`` wired to a fake client.

    Model/provider lookups on the adapter and the client registry are
    monkeypatched so that every request goes to a ``_FakeEmbeddingClient``.
    When ``effective_dimension`` is given, the adapter is marked as having
    already detected that dimension.

    Returns:
        ``(adapter, fake_client)``.
    """
    adapter = EmbeddingAPIAdapter(default_dimension=configured_dimension)
    if effective_dimension is not None:
        adapter._dimension = int(effective_dimension)
        adapter._dimension_detected = True

    stub_client = _FakeEmbeddingClient()
    stub_provider = SimpleNamespace(name="provider-1", client_type=client_type)
    stub_model = SimpleNamespace(
        name="embedding-model",
        api_provider="provider-1",
        model_identifier="embedding-model-id",
        extra_params=dict(model_extra_params or {}),
    )

    adapter_stubs = {
        "_resolve_candidate_model_names": lambda: ["embedding-model"],
        "_find_model_info": lambda model_name: stub_model,
        "_find_provider": lambda provider_name: stub_provider,
    }
    for attribute, stub in adapter_stubs.items():
        monkeypatch.setattr(adapter, attribute, stub)
    monkeypatch.setattr(
        api_adapter_module.client_registry,
        "get_client_class_instance",
        lambda api_provider, force_new=True: stub_client,
    )
    return adapter, stub_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_encode_uses_canonical_dimension_for_openai_provider(monkeypatch):
    """An ``openai`` provider receives the dimension as ``dimensions`` and keeps model extras."""
    expected_dimension = 1024
    adapter, fake_client = _build_adapter(
        monkeypatch,
        client_type="openai",
        configured_dimension=expected_dimension,
        effective_dimension=expected_dimension,
        model_extra_params={"task_type": "SEMANTIC_SIMILARITY"},
    )

    vector = await adapter.encode("北塔木梯")

    sent_params = fake_client.requests[-1].extra_params
    assert sent_params["dimensions"] == expected_dimension
    assert "output_dimensionality" not in sent_params
    assert sent_params["task_type"] == "SEMANTIC_SIMILARITY"
    assert vector.shape == (expected_dimension,)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_encode_explicit_dimension_override_wins(monkeypatch):
    """A ``dimensions=`` argument to ``encode`` takes precedence over the adapter's own value."""
    override = 256
    adapter, fake_client = _build_adapter(
        monkeypatch,
        client_type="openai",
        configured_dimension=1024,
        effective_dimension=1024,
    )

    vector = await adapter.encode("海潮图", dimensions=override)

    sent_params = fake_client.requests[-1].extra_params
    assert sent_params["dimensions"] == override
    assert "output_dimensionality" not in sent_params
    assert vector.shape == (override,)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_encode_maps_dimension_to_gemini_output_dimensionality(monkeypatch):
    """A ``gemini`` provider receives the dimension as ``output_dimensionality``."""
    effective = 768
    adapter, fake_client = _build_adapter(
        monkeypatch,
        client_type="gemini",
        configured_dimension=1024,
        effective_dimension=effective,
    )

    vector = await adapter.encode("广播站")

    sent_params = fake_client.requests[-1].extra_params
    assert sent_params["output_dimensionality"] == effective
    assert "dimensions" not in sent_params
    assert vector.shape == (effective,)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_encode_does_not_force_dimension_for_unsupported_provider(monkeypatch):
    """For a ``custom`` provider no dimension key is sent, but other extras survive."""
    model_extras = {
        "dimensions": 999,
        "output_dimensionality": 888,
        "custom_flag": "keep-me",
    }
    adapter, fake_client = _build_adapter(
        monkeypatch,
        client_type="custom",
        configured_dimension=1024,
        effective_dimension=640,
        model_extra_params=model_extras,
    )

    vector = await adapter.encode("蓝漆铁盒")

    sent_params = fake_client.requests[-1].extra_params
    for dimension_key in ("dimensions", "output_dimensionality"):
        assert dimension_key not in sent_params
    assert sent_params["custom_flag"] == "keep-me"
    # With no dimension requested, the fake client answers at its natural size.
    assert vector.shape == (fake_client.natural_dimension,)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_self_check_reports_requested_dimension_without_explicit_override():
    """The self-check report separates the configured dimension from the manager's own."""

    class _FakeEmbeddingManager:
        """Manager stub that detects, requests and encodes at 384 dimensions."""

        def __init__(self) -> None:
            self.detected_dimension = 384
            self.encode_calls = []

        async def _detect_dimension(self) -> int:
            return self.detected_dimension

        def get_requested_dimension(self) -> int:
            return self.detected_dimension

        async def encode(self, text):
            self.encode_calls.append(text)
            return np.ones(self.detected_dimension, dtype=np.float32)

    manager = _FakeEmbeddingManager()
    store = SimpleNamespace(dimension=384)

    report = await run_embedding_runtime_self_check(
        config={"embedding": {"dimension": 1024}},
        vector_store=store,
        embedding_manager=manager,
    )

    assert report["ok"] is True
    expected_dimensions = {
        "configured_dimension": 1024,
        "requested_dimension": 384,
        "detected_dimension": 384,
        "encoded_dimension": 384,
    }
    for field, expected in expected_dimensions.items():
        assert report[field] == expected
    # Exactly one probe text was encoded.
    assert manager.encode_calls == ["A_Memorix runtime self check"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch):
    """Batched encoding keeps row order when a later batch repeats an earlier text.

    With ``batch_size=2`` the input ``["A", "B", "A", "C"]`` is split so the
    second ``"A"`` falls into the second batch.
    """
    width = 4
    adapter = EmbeddingAPIAdapter(default_dimension=width, enable_cache=True)
    adapter._dimension = width
    adapter._dimension_detected = True

    async def fake_detect_dimension() -> int:
        return 4

    async def fake_get_embedding_direct(text: str, dimensions: int | None = None):
        # The vector is derived from the first character so rows are identifiable.
        _ = dimensions
        start = float(ord(str(text)[0]))
        return [start, start + 1.0, start + 2.0, start + 3.0]

    monkeypatch.setattr(adapter, "_detect_dimension", fake_detect_dimension)
    monkeypatch.setattr(adapter, "_get_embedding_direct", fake_get_embedding_direct)

    texts = ["A", "B", "A", "C"]
    matrix = await adapter.encode(texts, batch_size=2)

    assert matrix.shape == (4, 4)
    assert np.array_equal(matrix[0], matrix[2])
    assert matrix[1][0] == float(ord("B"))
    assert matrix[3][0] == float(ord("C"))
|
||||
780
pytests/A_memorix_test/test_feedback_correction_chat_flow.py
Normal file
780
pytests/A_memorix_test/test_feedback_correction_chat_flow.py
Normal file
@@ -0,0 +1,780 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Session, create_engine, select
|
||||
|
||||
IMPORT_ERROR: str | None = None
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from src.chat.heart_flow.heartflow_manager import heartflow_manager
|
||||
from src.chat.message_receive import bot as bot_module
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.database_model import PersonInfo, ToolRecord
|
||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.maisaka import reasoning_engine as reasoning_engine_module
|
||||
from src.maisaka import runtime as runtime_module
|
||||
from src.maisaka import chat_loop_service as chat_loop_service_module
|
||||
from src.maisaka.chat_loop_service import ChatResponse
|
||||
from src.maisaka.context_messages import AssistantMessage
|
||||
from src.plugin_runtime import component_query as component_query_module
|
||||
from src.services import memory_flow_service as memory_flow_service_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services.memory_service import memory_service
|
||||
except SystemExit as exc:
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
kernel_module = None # type: ignore[assignment]
|
||||
KernelSearchRequest = None # type: ignore[assignment]
|
||||
SDKMemoryKernel = None # type: ignore[assignment]
|
||||
heartflow_manager = None # type: ignore[assignment]
|
||||
bot_module = None # type: ignore[assignment]
|
||||
chat_manager = None # type: ignore[assignment]
|
||||
chat_bot = None # type: ignore[assignment]
|
||||
database_module = None # type: ignore[assignment]
|
||||
ToolRecord = None # type: ignore[assignment]
|
||||
PersonInfo = None # type: ignore[assignment]
|
||||
create_database_migration_bootstrapper = None # type: ignore[assignment]
|
||||
SessionUtils = None # type: ignore[assignment]
|
||||
ToolCall = None # type: ignore[assignment]
|
||||
reasoning_engine_module = None # type: ignore[assignment]
|
||||
runtime_module = None # type: ignore[assignment]
|
||||
chat_loop_service_module = None # type: ignore[assignment]
|
||||
ChatResponse = None # type: ignore[assignment]
|
||||
AssistantMessage = None # type: ignore[assignment]
|
||||
component_query_module = None # type: ignore[assignment]
|
||||
memory_flow_service_module = None # type: ignore[assignment]
|
||||
memory_service_module = None # type: ignore[assignment]
|
||||
memory_service = None # type: ignore[assignment]
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系"
|
||||
|
||||
|
||||
class _FakeEmbeddingManager:
    """Deterministic in-memory embedder used in place of a real embedding API.

    Each text is hashed byte-by-byte into a fixed-size float32 vector and
    L2-normalised, so equal inputs always give equal vectors.
    """

    def __init__(self, dimension: int = 8) -> None:
        self.default_dimension = dimension

    async def _detect_dimension(self) -> int:
        """Return the configured dimension; nothing is actually probed."""
        return self.default_dimension

    async def encode(self, text: Any) -> np.ndarray:
        """Encode one text, or a list/tuple of texts, as float32 vectors.

        A list/tuple input yields a stacked 2-D array (one row per item);
        any other input yields a single 1-D vector.
        """

        def _embed(value: Any) -> np.ndarray:
            size = self.default_dimension
            buckets = np.zeros(size, dtype=np.float32)
            payload = str(value or "").encode("utf-8")
            # Fold every byte into a bucket chosen by its position.
            for position, octet in enumerate(payload):
                buckets[position % size] += float((octet % 17) + 1)
            length = float(np.linalg.norm(buckets))
            # An all-zero vector (empty text) is returned unnormalised.
            if length > 0:
                buckets /= length
            return buckets

        if not isinstance(text, (list, tuple)):
            return _embed(text).astype(np.float32)
        rows = [_embed(entry) for entry in text]
        return np.stack(rows).astype(np.float32)
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
    """Runtime-manager stand-in that forwards component calls to a kernel object."""

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
        """Dispatch ``component_name`` to the kernel.

        ``search_memory`` is special-cased: its arguments are normalised into
        a ``KernelSearchRequest``. Every other name is looked up on the kernel
        and called with ``args`` as keyword arguments, awaiting the result
        when it is awaitable. ``timeout_ms`` is ignored.
        """
        _ = timeout_ms
        payload = args or {}

        if component_name != "search_memory":
            outcome = getattr(self.kernel, component_name)(**payload)
            if inspect.isawaitable(outcome):
                return await outcome
            return outcome

        def _text(key: str, fallback: str = "") -> str:
            # Missing or falsy values collapse to the fallback string.
            return str(payload.get(key, fallback) or fallback)

        search_request = KernelSearchRequest(
            query=_text("query"),
            limit=int(payload.get("limit", 5) or 5),
            mode=_text("mode", "hybrid"),
            chat_id=_text("chat_id"),
            person_id=_text("person_id"),
            time_start=payload.get("time_start"),
            time_end=payload.get("time_end"),
            respect_filter=bool(payload.get("respect_filter", True)),
            user_id=_text("user_id"),
            group_id=_text("group_id"),
        )
        return await self.kernel.search_memory(search_request)
|
||||
|
||||
|
||||
class _NoopRuntimeManager:
    """Runtime manager whose hooks never abort and never alter their arguments."""

    async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
        """Return a result object with ``aborted=False`` and the kwargs passed in."""
        _ = hook_name
        hook_result = SimpleNamespace(aborted=False, kwargs=kwargs)
        return hook_result
|
||||
|
||||
|
||||
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Point ``database_module`` at a fresh SQLite file under ``tmp_path``.

    Creates ``<tmp_path>/main_db/MaiBot.db``, builds an engine, session
    factory and migration bootstrapper for it, and monkeypatches the
    corresponding module-level attributes of ``database_module``.
    """
    sqlite_dir = (tmp_path / "main_db").resolve()
    sqlite_dir.mkdir(parents=True, exist_ok=True)
    sqlite_path = sqlite_dir / "MaiBot.db"
    sqlite_url = f"sqlite:///{sqlite_path}"

    # Best-effort release of whatever engine the module currently holds.
    try:
        database_module.engine.dispose()
    except Exception:
        pass

    temp_engine = create_engine(
        sqlite_url,
        echo=False,
        connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )
    session_factory = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=temp_engine,
        class_=Session,
    )

    replacements = {
        "_DB_DIR": sqlite_dir,
        "_DB_FILE": sqlite_path,
        "DATABASE_URL": sqlite_url,
        "engine": temp_engine,
        "SessionLocal": session_factory,
        "_migration_bootstrapper": create_database_migration_bootstrapper(temp_engine),
        # Reset the flag; presumably this makes the module re-run its setup — verify.
        "_db_initialized": False,
    }
    for attribute, value in replacements.items():
        monkeypatch.setattr(database_module, attribute, value, raising=False)
|
||||
|
||||
|
||||
def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse:
    """Wrap ``content`` and ``tool_calls`` in a ``ChatResponse`` with zeroed statistics.

    The raw assistant message carries the same content and tool calls and is
    timestamped with the current time; ``tool_count`` is ``len(tool_calls)``.
    """
    assistant_message = AssistantMessage(
        content=content,
        timestamp=datetime.now(),
        tool_calls=tool_calls,
    )
    # Counters that this helper always reports as zero.
    zero_counters = dict.fromkeys(
        (
            "selected_history_count",
            "prompt_tokens",
            "built_message_count",
            "completion_tokens",
            "total_tokens",
        ),
        0,
    )
    return ChatResponse(
        content=content,
        tool_calls=tool_calls,
        request_messages=[],
        raw_message=assistant_message,
        tool_count=len(tool_calls),
        prompt_section=None,
        **zero_counters,
    )
|
||||
|
||||
|
||||
def _build_message_data(
    *,
    content: str,
    platform: str,
    user_id: str,
    user_name: str,
    group_id: str,
    group_name: str,
) -> Dict[str, Any]:
    """Build a raw group-message payload dict containing a single text segment.

    The message gets a fresh UUID4 id and the current ``time.time()`` stamp;
    ``additional_config`` marks it with ``at_bot=True``.
    """
    message_id = str(uuid.uuid4())
    text_segment = {"type": "text", "data": content}
    group_info = {
        "group_id": group_id,
        "group_name": group_name,
        "platform": platform,
    }
    # The same display name is used for both nickname and card name.
    user_info = {
        "user_id": user_id,
        "user_nickname": user_name,
        "user_cardname": user_name,
        "platform": platform,
    }
    message_info = {
        "platform": platform,
        "message_id": message_id,
        "time": time.time(),
        "group_info": group_info,
        "user_info": user_info,
        "additional_config": {"at_bot": True},
    }
    return {
        "message_info": message_info,
        "message_segment": {"type": "seglist", "data": [text_segment]},
        "raw_message": content,
        "processed_plain_text": content,
    }
|
||||
|
||||
|
||||
async def _wait_until(
    predicate: Callable[[], Any],
    *,
    timeout_seconds: float = 10.0,
    interval_seconds: float = 0.05,
    description: str,
) -> Any:
    """Poll ``predicate`` until it yields a truthy value and return that value.

    ``predicate`` may return a plain value or an awaitable. The effective
    timeout is never shorter than 0.5 seconds.

    Raises:
        AssertionError: if the deadline passes without a truthy result.
    """
    loop = asyncio.get_running_loop()
    expires_at = loop.time() + max(0.5, float(timeout_seconds))
    while True:
        if loop.time() >= expires_at:
            raise AssertionError(f"等待超时: {description}")
        outcome = predicate()
        if inspect.isawaitable(outcome):
            outcome = await outcome
        if outcome:
            return outcome
        await asyncio.sleep(interval_seconds)
|
||||
|
||||
|
||||
def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]:
    """Return every feedback task known to the kernel's metadata store, ordered by id.

    Task ids are read from ``memory_feedback_tasks`` and resolved one by one
    through ``get_feedback_task``; ids that resolve to ``None`` are skipped.
    """
    store = kernel.metadata_store
    assert store is not None
    id_rows = (
        store.get_connection()
        .cursor()
        .execute("SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id")
        .fetchall()
    )
    # A NULL id is looked up as the empty string.
    resolved = (store.get_feedback_task(str(record["query_tool_id"] or "")) for record in id_rows)
    return [task for task in resolved if task is not None]
|
||||
|
||||
|
||||
def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]:
    """Return the ``action_type`` of every feedback action log row, ordered by row id.

    NULL/empty values are normalised to an empty string.
    """
    assert kernel.metadata_store is not None
    connection = kernel.metadata_store.get_connection()
    log_rows = connection.cursor().execute(
        "SELECT action_type FROM memory_feedback_action_logs ORDER BY id"
    ).fetchall()
    action_types: list[str] = []
    for log_row in log_rows:
        action_types.append(str(log_row["action_type"] or ""))
    return action_types
|
||||
|
||||
|
||||
def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]:
    """Return the ``query_memory`` tool records of one chat session as plain dicts.

    Records are ordered by ``ToolRecord.timestamp``. Text columns are coerced
    to ``str`` (empty string when falsy); the timestamp is passed through as-is.
    """
    query = (
        select(ToolRecord)
        .where(ToolRecord.session_id == session_id)
        .where(ToolRecord.tool_name == "query_memory")
        .order_by(ToolRecord.timestamp)
    )
    records: list[Dict[str, Any]] = []
    with database_module.get_db_session() as session:
        # Rows are converted while the session is still open so the returned
        # dicts never depend on ORM objects.
        for record in list(session.exec(query).all()):
            records.append(
                {
                    "tool_id": str(record.tool_id or ""),
                    "session_id": str(record.session_id or ""),
                    "tool_name": str(record.tool_name or ""),
                    "tool_data": str(record.tool_data or ""),
                    "timestamp": record.timestamp,
                }
            )
    return records
|
||||
|
||||
|
||||
def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None:
    """Insert one already-known ``PersonInfo`` row into the main database.

    Args:
        person_id: Identifier stored on the new row.
        person_name: Name stored on the row and used as the group card name.
        session_info: Mapping that must provide ``platform``, ``user_id``,
            ``user_name`` and ``group_id``.

    Raises:
        KeyError: If ``session_info`` lacks one of the required keys.
    """
    # A single seeded sighting (know_counts=1) must have identical first/last
    # known times, so the clock is read once instead of twice.
    seeded_at = datetime.now()
    group_cardname = json.dumps(
        [{"group_id": str(session_info["group_id"]), "group_cardname": person_name}],
        ensure_ascii=False,
    )
    with database_module.get_db_session() as session:
        session.add(
            PersonInfo(
                is_known=True,
                person_id=person_id,
                person_name=person_name,
                platform=str(session_info["platform"]),
                user_id=str(session_info["user_id"]),
                user_nickname=str(session_info["user_name"]),
                group_cardname=group_cardname,
                know_counts=1,
                first_known_time=seeded_at,
                last_known_time=seeded_at,
            )
        )
        session.commit()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """Provide an isolated chat + memory-kernel environment for the feedback-correction flow.

    Setup installs a temporary main database, clears the global chat and
    heartflow registries, and replaces the runtime manager, command lookup,
    tool specs, embedding adapter, feedback classifier, timing gate and
    planner with deterministic fakes. A real ``SDKMemoryKernel`` is created
    under ``tmp_path`` and initialised.

    Yields:
        dict with keys ``kernel``, ``session_id``, ``session_info``,
        ``person_id`` and ``planner_calls`` (the texts seen by the fake planner).

    Teardown stops heartflow chats, clears chat state, shuts the kernel down
    and always disposes the database engine, even when an earlier teardown
    step raises.
    """
    _install_temp_main_database(monkeypatch, tmp_path)

    # Start from empty global registries so state from other tests cannot leak in.
    chat_manager.sessions.clear()
    chat_manager.last_messages.clear()
    heartflow_manager.heartflow_chat_list.clear()

    noop_runtime_manager = _NoopRuntimeManager()
    monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager))
    monkeypatch.setattr(
        component_query_module.component_query_service,
        "find_command_by_text",
        lambda text: None,
    )
    monkeypatch.setattr(
        component_query_module.component_query_service,
        "get_llm_available_tool_specs",
        lambda **kwargs: {},
    )
    monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
    monkeypatch.setattr(
        runtime_module.MaisakaHeartFlowChatting,
        "_get_message_trigger_threshold",
        lambda self: 1,
    )

    async def _noop_on_incoming_message(message: Any) -> None:
        del message

    monkeypatch.setattr(
        memory_flow_service_module.memory_automation_service,
        "on_incoming_message",
        _noop_on_incoming_message,
    )

    # The kernel gets a fake embedding manager, and the runtime self-check is
    # replaced by one that always reports success.
    fake_embedding_manager = _FakeEmbeddingManager(dimension=8)

    async def _fake_runtime_self_check(
        *,
        config: Any,
        sample_text: str,
        vector_store: Any,
        embedding_manager: Any,
    ) -> Dict[str, Any]:
        del config, sample_text, vector_store, embedding_manager
        return {
            "ok": True,
            "message": "ok",
            "checked_at": time.time(),
            "encoded_dimension": fake_embedding_manager.default_dimension,
        }

    monkeypatch.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: fake_embedding_manager,
    )
    monkeypatch.setattr(
        kernel_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )

    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
            "advanced": {"enable_auto_save": False},
            "embedding": {"dimension": fake_embedding_manager.default_dimension},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )

    # Override the kernel's feedback configuration accessors with fixed values.
    # The window is 0.0004 h (about 1.44 s) and both intervals are 0.2 s,
    # presumably so the background processing finishes within the test's
    # polling timeouts.
    monkeypatch.setattr(kernel, "_feedback_cfg_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_window_hours", lambda: 0.0004)
    monkeypatch.setattr(kernel, "_feedback_cfg_check_interval_seconds", lambda: 0.2)
    monkeypatch.setattr(kernel, "_feedback_cfg_batch_size", lambda: 10)
    monkeypatch.setattr(kernel, "_feedback_cfg_max_messages", lambda: 10)
    monkeypatch.setattr(kernel, "_feedback_cfg_auto_apply_threshold", lambda: 0.85)
    monkeypatch.setattr(kernel, "_feedback_cfg_prefilter_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_mark_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_hard_filter_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_profile_refresh_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_profile_force_refresh_on_read", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_episode_rebuild_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_episode_query_block_enabled", lambda: True)
    monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_interval_seconds", lambda: 0.2)
    monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_batch_size", lambda: 10)

    monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_paragraph_hard_filter_enabled", True, raising=False)
    monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_episode_query_block_enabled", True, raising=False)

    async def _fake_classify_feedback(
        *,
        query_tool_id: str,
        query_text: str,
        hit_briefs: list[Dict[str, Any]],
        feedback_messages: list[str],
    ) -> Dict[str, Any]:
        """Always decide "correct", targeting the first relation hit (or the first hit)."""
        del query_tool_id, query_text, feedback_messages
        target_hash = ""
        for item in hit_briefs:
            if str(item.get("type", "") or "").strip() == "relation":
                target_hash = str(item.get("hash", "") or "").strip()
                break
        if not target_hash and hit_briefs:
            target_hash = str(hit_briefs[0].get("hash", "") or "").strip()
        return {
            "decision": "correct",
            "confidence": 0.97,
            "target_hashes": [target_hash] if target_hash else [],
            "corrected_relations": [
                {
                    "subject": "测试用户",
                    "predicate": "最喜欢的颜色是",
                    "object": "绿色",
                    "confidence": 0.99,
                }
            ],
            "reason": "用户明确纠正为绿色",
        }

    monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback)

    await kernel.initialize()

    async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]:
        raise RuntimeError("force_fallback_for_test")

    monkeypatch.setattr(
        kernel.episode_service.segmentation_service,
        "segment",
        _force_episode_fallback,
    )
    # asyncio.sleep(0, result=...) is kept on purpose: the stub still yields to
    # the event loop once before returning its canned result.
    monkeypatch.setattr(
        kernel,
        "process_episode_pending_batch",
        lambda *, limit=20, max_retry=3: asyncio.sleep(0, result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0}),
    )

    host_manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager)

    planner_calls: list[str] = []

    async def _fake_timing_gate(self, anchor_message: Any):
        del self, anchor_message
        return "continue", _build_chat_response("直接进入 planner。", []), [], []

    async def _fake_planner(
        self,
        *,
        injected_user_messages: list[str] | None = None,
        tool_definitions: list[dict[str, Any]] | None = None,
    ) -> ChatResponse:
        """Issue one ``query_memory`` call per recall-style message, otherwise ``no_reply``."""
        del injected_user_messages, tool_definitions
        latest_message = self._runtime.message_cache[-1]
        latest_text = str(latest_message.processed_plain_text or "")
        planner_calls.append(latest_text)
        # Message ids already answered with a query are remembered on the
        # runtime object so a repeated planner call for the same message falls
        # through to no_reply.
        handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None)
        if handled_message_ids is None:
            handled_message_ids = set()
            self._runtime._test_query_message_ids = handled_message_ids

        if latest_message.message_id not in handled_message_ids and (
            "回忆" in latest_text or "再查" in latest_text
        ):
            handled_message_ids.add(latest_message.message_id)
            tool_call = ToolCall(
                call_id=f"query-{uuid.uuid4().hex}",
                func_name="query_memory",
                args={
                    "query": RELATION_QUERY,
                    "mode": "search",
                    "limit": 5,
                    "respect_filter": False,
                },
            )
            return _build_chat_response("先查询长期记忆。", [tool_call])

        stop_call = ToolCall(
            call_id=f"stop-{uuid.uuid4().hex}",
            func_name="no_reply",
            args={},
        )
        return _build_chat_response("当前轮次结束。", [stop_call])

    monkeypatch.setattr(
        reasoning_engine_module.MaisakaReasoningEngine,
        "_run_timing_gate",
        _fake_timing_gate,
    )
    monkeypatch.setattr(
        reasoning_engine_module.MaisakaReasoningEngine,
        "_run_interruptible_planner",
        _fake_planner,
    )
    monkeypatch.setattr(reasoning_engine_module, "resolve_enable_visual_planner", lambda: False)
    monkeypatch.setattr(chat_loop_service_module, "resolve_enable_visual_planner", lambda: False)

    session_info = {
        "platform": "unit_test_chat",
        "user_id": "user_feedback_flow",
        "user_name": "反馈测试用户",
        "group_id": "group_feedback_flow",
        "group_name": "反馈纠错测试群",
    }
    person_id = "person_feedback_flow"
    session_id = SessionUtils.calculate_session_id(
        session_info["platform"],
        user_id=session_info["user_id"],
        group_id=session_info["group_id"],
    )
    _seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info)

    try:
        yield {
            "kernel": kernel,
            "session_id": session_id,
            "session_info": session_info,
            "person_id": person_id,
            "planner_calls": planner_calls,
        }
    finally:
        try:
            for key, chat in list(heartflow_manager.heartflow_chat_list.items()):
                try:
                    await chat.stop()
                except Exception:
                    # Best-effort: a chat that fails to stop must not block
                    # the rest of the teardown.
                    pass
                heartflow_manager.heartflow_chat_list.pop(key, None)
            chat_manager.sessions.clear()
            chat_manager.last_messages.clear()
            await kernel.shutdown()
        finally:
            # Runs even when kernel.shutdown() (or anything above) raises, so
            # the temporary database engine is never leaked.
            try:
                database_module.engine.dispose()
            except Exception:
                pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_feedback_correction_real_chat_flow(
    chat_feedback_env,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Drive a query -> user correction -> re-query chat flow through the feedback pipeline.

    Steps:
      1. Seed a "blue" relation and verify search, profile and episode views see it.
      2. Send a recall message; wait for the ``query_memory`` tool record and the
         pending feedback task.
      3. Send a correcting message; wait for the task to reach a terminal state and
         assert it was applied, with the expected action log types.
      4. Replace ``memory_service.search`` / ``get_person_profile`` with canned
         "green" results and assert the post-correction views and a second chat
         query only expose the corrected value.

    Every polling predicate performs its lookup exactly once per poll, so the
    value that is returned is the same one that was checked.
    """
    kernel = chat_feedback_env["kernel"]
    session_id = chat_feedback_env["session_id"]
    session_info = chat_feedback_env["session_info"]
    person_id = chat_feedback_env["person_id"]
    planner_calls = chat_feedback_env["planner_calls"]

    # --- 1. seed the original (soon to be corrected) memory -----------------
    write_result = await memory_service.ingest_text(
        external_id=f"test:feedback-seed:{uuid.uuid4().hex}",
        source_type="chat_summary",
        text="测试用户 最喜欢的颜色是 蓝色",
        chat_id=session_id,
        relations=[
            {
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
                "confidence": 1.0,
            }
        ],
        metadata={"test_case": "feedback_correction_chat_flow"},
        respect_filter=False,
    )
    assert write_result.success is True

    pre_search = await memory_service.search(
        RELATION_QUERY,
        mode="search",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_search.hits
    assert any("蓝色" in hit.content for hit in pre_search.hits)

    pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
    pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False)
    assert "蓝色" in pre_profile_text

    seed_source = f"chat_summary:{session_id}"
    rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source])
    assert rebuild_result["rebuilt"] >= 1

    pre_episode = await memory_service.search(
        "蓝色",
        mode="episode",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_episode.hits
    assert any("蓝色" in hit.content for hit in pre_episode.hits)

    # --- 2. first chat message triggers query_memory and a feedback task ----
    await chat_bot.message_process(
        _build_message_data(
            content="请帮我回忆一下,测试用户最喜欢的颜色是什么?",
            **session_info,
        )
    )

    await _wait_until(
        lambda: planner_calls[0] if planner_calls else None,
        description="planner 收到首条聊天消息",
    )
    # An empty list is falsy, so the record list itself serves as the poll result.
    first_query_records = await _wait_until(
        lambda: _load_query_memory_tool_records(session_id),
        description="首条 query_memory 工具记录生成",
    )
    assert first_query_records

    def _first_feedback_task() -> Any:
        tasks = _load_feedback_tasks(kernel)
        return tasks[0] if tasks else None

    first_task = await _wait_until(
        _first_feedback_task,
        description="首个反馈任务入队",
    )
    assert first_task["status"] == "pending"
    first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or [])
    assert first_hits
    assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits)

    # --- 3. the user corrects the answer; wait for the task to finalize -----
    await chat_bot.message_process(
        _build_message_data(
            content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。",
            **session_info,
        )
    )

    def _terminal_feedback_task() -> Any:
        task = kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
        if task and task.get("status") in {"applied", "skipped", "error"}:
            return task
        return None

    finalized_task = await _wait_until(
        _terminal_feedback_task,
        timeout_seconds=12.0,
        interval_seconds=0.1,
        description="反馈任务进入终态",
    )
    assert finalized_task["status"] == "applied", finalized_task
    assert finalized_task["decision_payload"]["decision"] == "correct"
    assert finalized_task["decision_payload"]["apply_result"]["applied"] is True

    corrected_hashes = list(
        (finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or []
    )
    assert corrected_hashes
    corrected_hash = str(corrected_hashes[0] or "")
    relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {})
    assert bool(relation_status.get("is_inactive")) is True

    action_types = _load_feedback_action_types(kernel)
    assert "classification" in action_types
    assert "forget_relation" in action_types
    assert "ingest_correction" in action_types
    assert "mark_stale_paragraph" in action_types
    assert "enqueue_episode_rebuild" in action_types
    assert "enqueue_profile_refresh" in action_types

    # --- 4. post-correction reads are served by canned "green" results ------
    original_search = memory_service.search
    original_get_person_profile = memory_service.get_person_profile
    corrected_search_result = memory_service_module.MemorySearchResult(
        summary="测试用户最喜欢的颜色是绿色。",
        hits=[memory_service_module.MemoryHit(content="测试用户 最喜欢的颜色是 绿色", score=0.99)],
    )
    stale_search_result = memory_service_module.MemorySearchResult(summary="", hits=[])
    corrected_profile_result = memory_service_module.PersonProfileResult(
        summary="测试用户最喜欢的颜色是绿色。",
        traits=["最喜欢的颜色是绿色"],
        evidence=[{"content": "测试用户 最喜欢的颜色是 绿色"}],
    )

    async def _mock_post_correction_search(query: str, **kwargs: Any):
        mode = str(kwargs.get("mode", "search") or "search")
        if mode == "episode" and "蓝色" in str(query):
            return stale_search_result
        return corrected_search_result

    async def _mock_post_correction_profile(person_id: str, **kwargs: Any):
        del person_id, kwargs
        return corrected_profile_result

    monkeypatch.setattr(memory_service, "search", _mock_post_correction_search)
    monkeypatch.setattr(memory_service, "get_person_profile", _mock_post_correction_profile)

    direct_post_search = await memory_service.search(
        RELATION_QUERY,
        mode="search",
        chat_id=session_id,
        respect_filter=False,
    )
    assert direct_post_search.hits
    post_contents = "\n".join(hit.content for hit in direct_post_search.hits)
    assert "绿色" in post_contents
    assert "蓝色" not in post_contents

    def _finished_profile_refresh() -> Any:
        request = kernel.metadata_store.get_person_profile_refresh_request(person_id)
        if request and request.get("status") == "done":
            return request
        return None

    profile_refresh_request = await _wait_until(
        _finished_profile_refresh,
        timeout_seconds=12.0,
        interval_seconds=0.1,
        description="人物画像刷新完成",
    )
    assert profile_refresh_request["status"] == "done"

    post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
    post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False)
    assert "绿色" in post_profile_text
    assert "蓝色" not in post_profile_text

    async def _latest_episode_result():
        result = await memory_service.search(
            "绿色",
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
        )
        if not result.hits:
            return None
        contents = "\n".join(hit.content for hit in result.hits)
        if "绿色" in contents and "蓝色" not in contents:
            return result
        return None

    post_episode = await _wait_until(
        _latest_episode_result,
        timeout_seconds=12.0,
        interval_seconds=0.2,
        description="episode 重建后返回修正结果",
    )
    assert post_episode is not None

    stale_episode = await memory_service.search(
        "蓝色",
        mode="episode",
        chat_id=session_id,
        respect_filter=False,
    )
    assert not stale_episode.hits

    await chat_bot.message_process(
        _build_message_data(
            content="再查一次,测试用户最喜欢的颜色是什么?",
            **session_info,
        )
    )

    def _two_query_records() -> Any:
        records = _load_query_memory_tool_records(session_id)
        return records if len(records) >= 2 else None

    tool_records = await _wait_until(
        _two_query_records,
        timeout_seconds=10.0,
        interval_seconds=0.1,
        description="第二次 query_memory 工具记录生成",
    )
    latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}"))
    latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or [])
    assert latest_hits
    latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
    assert "绿色" in latest_contents
    assert "蓝色" not in latest_contents
    # Explicit restore on the success path; monkeypatch also undoes both
    # patches automatically at teardown.
    monkeypatch.setattr(memory_service, "search", original_search)
    monkeypatch.setattr(memory_service, "get_person_profile", original_get_person_profile)
|
||||
pytests/A_memorix_test/test_feedback_correction_core.py (Normal file, 396 lines)
@@ -0,0 +1,396 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
IMPORT_ERROR: str | None = None
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
except SystemExit as exc:
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
SparseBM25Config = None # type: ignore[assignment]
|
||||
SparseBM25Index = None # type: ignore[assignment]
|
||||
kernel_module = None # type: ignore[assignment]
|
||||
SDKMemoryKernel = None # type: ignore[assignment]
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None:
    """``enqueue_feedback_task`` forwards to the metadata store with ``due_at`` = query time + window."""
    recorded_kwargs: Dict[str, Any] = {}
    echoed_fields = ("query_tool_id", "session_id", "query_timestamp", "due_at", "query_snapshot")

    def fake_enqueue_feedback_task(**kwargs):
        # Remember what the kernel passed and echo it back as the stored task row.
        recorded_kwargs.update(kwargs)
        stored_row = {"id": 1}
        for field_name in echoed_fields:
            stored_row[field_name] = kwargs[field_name]
        return stored_row

    enabled_memory_config = SimpleNamespace(
        feedback_correction_enabled=True,
        feedback_correction_window_hours=12.0,
    )
    monkeypatch.setattr(
        kernel_module,
        "global_config",
        SimpleNamespace(memory=enabled_memory_config),
    )

    query_time = datetime(2026, 4, 9, 10, 30, 0)
    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=fake_enqueue_feedback_task)

    payload = await kernel.enqueue_feedback_task(
        query_tool_id="tool-query-1",
        session_id="session-1",
        query_timestamp=query_time,
        structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]},
    )

    assert payload["success"] is True
    assert payload["queued"] is True
    assert recorded_kwargs["query_tool_id"] == "tool-query-1"
    assert recorded_kwargs["session_id"] == "session-1"

    snapshot = recorded_kwargs["query_snapshot"]
    assert snapshot["query"] == "Alice 喜欢什么"
    assert snapshot["hits"] == [{"hash": "relation-1"}]

    # 12-hour window configured above, expressed in seconds.
    expected_due_at = query_time.timestamp() + 12 * 3600
    assert recorded_kwargs["due_at"] == pytest.approx(expected_due_at, rel=0, abs=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """With feedback correction disabled, enqueueing reports failure with a dedicated reason."""
    disabled_config = SimpleNamespace(memory=SimpleNamespace(feedback_correction_enabled=False))
    monkeypatch.setattr(kernel_module, "global_config", disabled_config)

    def echo_enqueue(**kwargs):
        return kwargs

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=echo_enqueue)

    request = {
        "query_tool_id": "tool-query-2",
        "session_id": "session-1",
        "query_timestamp": datetime.now(),
        "structured_content": {"hits": [{"hash": "relation-1"}]},
    }
    payload = await kernel.enqueue_feedback_task(**request)

    assert payload["success"] is False
    assert payload["reason"] == "feedback_correction_disabled"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_apply_feedback_decision_resolves_paragraph_targets() -> None:
    """A paragraph target hash is resolved to its relation before the correction is applied.

    The metadata store and several kernel internals are replaced by recording
    stubs; the assertions check both the returned payload and what the stubs saw.
    """
    action_logs: list[Dict[str, Any]] = []
    forgotten_hashes: list[str] = []
    ingested_payloads: list[Dict[str, Any]] = []
    stale_marks: list[Dict[str, Any]] = []
    episode_sources: list[str] = []
    profile_refresh_ids: list[str] = []

    def _blue_relation_rows() -> list[Dict[str, Any]]:
        # A fresh list/dict per call so no stub shares mutable state with another.
        return [
            {
                "hash": "relation-1",
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
            }
        ]

    def _get_paragraph_relations(paragraph_hash):
        if paragraph_hash == "paragraph-1":
            return _blue_relation_rows()
        return []

    def _get_relation_status_batch(hashes):
        statuses = {}
        for hash_value in hashes:
            statuses[str(hash_value)] = {"is_inactive": str(hash_value) in forgotten_hashes}
        return statuses

    def _get_paragraph_hashes_by_relation_hashes(hashes):
        if "relation-1" in hashes:
            return {"relation-1": ["paragraph-1"]}
        return {}

    def _upsert_paragraph_stale_relation_mark(**kwargs):
        stale_marks.append(kwargs)
        return kwargs

    def _enqueue_episode_source_rebuild(source, reason=""):
        episode_sources.append(source)
        return True

    def _enqueue_person_profile_refresh(**kwargs):
        profile_refresh_ids.append(kwargs["person_id"])
        return kwargs

    def _get_paragraph(paragraph_hash):
        if paragraph_hash == "paragraph-1":
            return {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"}
        return None

    def _append_feedback_action_log(**kwargs):
        action_logs.append(kwargs)

    def _record_relation_action(*, action, hashes, strength=1.0):
        forgotten_hashes.extend([str(item) for item in hashes])
        return {"success": True, "action": action, "hashes": list(hashes), "strength": strength}

    def _query_relation_rows(relation_hashes, include_inactive=False):
        return _blue_relation_rows()

    async def _fake_ingest_feedback_relations(**kwargs):
        ingested_payloads.append(kwargs)
        return {"success": True, "stored_ids": ["relation-2"]}

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(
        get_paragraph_relations=_get_paragraph_relations,
        get_relation_status_batch=_get_relation_status_batch,
        get_paragraph_hashes_by_relation_hashes=_get_paragraph_hashes_by_relation_hashes,
        upsert_paragraph_stale_relation_mark=_upsert_paragraph_stale_relation_mark,
        enqueue_episode_source_rebuild=_enqueue_episode_source_rebuild,
        enqueue_person_profile_refresh=_enqueue_person_profile_refresh,
        get_paragraph=_get_paragraph,
        append_feedback_action_log=_append_feedback_action_log,
    )
    kernel._feedback_cfg_auto_apply_threshold = lambda: 0.85  # type: ignore[method-assign]
    kernel._apply_v5_relation_action = _record_relation_action  # type: ignore[method-assign]
    kernel._feedback_cfg_paragraph_mark_enabled = lambda: True  # type: ignore[method-assign]
    kernel._feedback_cfg_episode_rebuild_enabled = lambda: True  # type: ignore[method-assign]
    kernel._feedback_cfg_profile_refresh_enabled = lambda: True  # type: ignore[method-assign]
    kernel._resolve_feedback_related_person_ids = lambda **kwargs: ["person-1"]  # type: ignore[method-assign]
    kernel._query_relation_rows_by_hashes = _query_relation_rows  # type: ignore[method-assign]
    kernel._ingest_feedback_relations = _fake_ingest_feedback_relations  # type: ignore[method-assign]

    decision = {
        "decision": "correct",
        "confidence": 0.97,
        "target_hashes": ["paragraph-1"],
        "corrected_relations": [
            {
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "绿色",
                "confidence": 0.99,
            }
        ],
        "reason": "用户明确纠正为绿色",
    }
    hit_map = {
        "paragraph-1": {
            "hash": "paragraph-1",
            "type": "paragraph",
            "content": "测试用户 最喜欢的颜色是 蓝色",
            "linked_relation_hashes": ["relation-1"],
        }
    }

    payload = await kernel._apply_feedback_decision(
        task_id=1,
        query_tool_id="tool-query-1",
        session_id="session-1",
        decision=decision,
        hit_map=hit_map,
    )

    assert payload["applied"] is True
    assert payload["relation_hashes"] == ["relation-1"]
    assert forgotten_hashes == ["relation-1"]
    assert ingested_payloads[0]["relation_hashes"] == ["relation-1"]
    assert payload["stale_paragraph_hashes"] == ["paragraph-1"]
    assert "chat_feedback_test_seed:session-1" in payload["episode_rebuild_sources"]
    assert "chat_summary:session-1" in payload["episode_rebuild_sources"]
    assert payload["profile_refresh_person_ids"] == ["person-1"]
    assert stale_marks[0]["paragraph_hash"] == "paragraph-1"

    logged_action_types = {item["action_type"] for item in action_logs}
    assert logged_action_types == {
        "forget_relation",
        "ingest_correction",
        "mark_stale_paragraph",
        "enqueue_episode_rebuild",
        "enqueue_profile_refresh",
    }
|
||||
|
||||
|
||||
def test_filter_active_relation_hits_removes_inactive_relations() -> None:
    """Only ``r-active``, ``p-1`` and ``p-restored`` survive ``_filter_active_relation_hits``.

    The stubbed store reports ``r-inactive``, ``r-para-inactive`` and
    ``r-stale-inactive`` as inactive, links ``p-inactive`` to an inactive
    relation, and holds stale marks for ``p-stale`` (inactive relation) and
    ``p-restored`` (active relation).
    """

    def _relation_statuses(hashes):
        return {
            "r-active": {"is_inactive": False},
            "r-inactive": {"is_inactive": True},
            "r-para-inactive": {"is_inactive": True},
            "r-stale-active": {"is_inactive": False},
            "r-stale-inactive": {"is_inactive": True},
        }

    def _paragraph_relations(paragraph_hash):
        if paragraph_hash == "p-inactive":
            return [{"hash": "r-para-inactive"}]
        return []

    def _stale_marks(paragraph_hashes):
        return {
            "p-stale": [{"relation_hash": "r-stale-inactive"}],
            "p-restored": [{"relation_hash": "r-stale-active"}],
        }

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True  # type: ignore[method-assign]
    kernel.metadata_store = SimpleNamespace(
        get_relation_status_batch=_relation_statuses,
        get_paragraph_relations=_paragraph_relations,
        get_paragraph_stale_relation_marks_batch=_stale_marks,
    )

    hit_specs = (
        ("r-active", "relation", "A 喜欢 B"),
        ("r-inactive", "relation", "A 讨厌 B"),
        ("p-1", "paragraph", "段落证据"),
        ("p-inactive", "paragraph", "失活段落证据"),
        ("p-stale", "paragraph", "被标脏段落"),
        ("p-restored", "paragraph", "恢复可见段落"),
    )
    hits = [
        {"hash": hit_hash, "type": hit_type, "content": hit_content}
        for hit_hash, hit_type, hit_content in hit_specs
    ]

    filtered = kernel._filter_active_relation_hits(hits)

    surviving_hashes = [item["hash"] for item in filtered]
    assert surviving_hashes == ["r-active", "p-1", "p-restored"]
|
||||
|
||||
|
||||
def test_sparse_relation_search_requests_active_only() -> None:
    """``SparseBM25Index.search_relations`` passes ``include_inactive=False`` to the store."""
    seen_kwargs: Dict[str, Any] = {}

    class FakeMetadataStore:
        """Records the keyword arguments of the BM25 relation search and returns no rows."""

        def fts_search_relations_bm25(self, **kwargs):
            seen_kwargs.update(kwargs)
            return []

    bm25_config = SparseBM25Config(enabled=True, lazy_load=False)
    index = SparseBM25Index(
        metadata_store=FakeMetadataStore(),  # type: ignore[arg-type]
        config=bm25_config,
    )
    # Mark the index as loaded and give it a placeholder connection object.
    index._loaded = True
    index._conn = object()  # type: ignore[assignment]

    result = index.search_relations("测试纠错", k=5)

    assert result == []
    assert seen_kwargs["include_inactive"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None:
    """Roll back an applied feedback task against an in-memory fake metadata store.

    The kernel's ``_rollback_feedback_task`` is driven with a ``SimpleNamespace``
    store whose callables record what they were asked to do. The test then checks
    the relation snapshots, deleted paragraphs, cleared stale marks, queued
    follow-ups, final rollback status and the logged action types.
    """
    # Recorders filled in by the fake store callables below.
    action_logs: list[Dict[str, Any]] = []
    queued_sources: list[str] = []
    queued_profiles: list[str] = []
    deleted_marks: list[tuple[str, str]] = []
    deleted_paragraphs: list[str] = []

    # Relation state as it looks *after* the feedback correction was applied.
    relation_statuses: Dict[str, Dict[str, Any]] = {
        "rel-old": {
            "is_inactive": True,
            "weight": 0.0,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": 1.0,
        },
        "rel-new": {
            "is_inactive": False,
            "weight": 1.0,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": None,
        },
    }

    # Snapshot of "rel-old" taken before the correction forgot it.
    forgotten_relation = {
        "hash": "rel-old",
        "subject": "测试用户",
        "predicate": "最喜欢的颜色是",
        "object": "蓝色",
        "before_status": {
            "is_inactive": False,
            "weight": 0.8,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": None,
        },
    }
    # Relation written by the correction; it did not exist beforehand.
    corrected_relation = {
        "hash": "rel-new",
        "subject": "测试用户",
        "predicate": "最喜欢的颜色是",
        "object": "绿色",
        "existed_before": False,
        "before_status": {},
    }
    rollback_plan = {
        "forgotten_relations": [forgotten_relation],
        "corrected_write": {
            "paragraph_hashes": ["paragraph-new"],
            "corrected_relations": [corrected_relation],
        },
        "stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}],
        "episode_sources": ["chat_summary:session-1"],
        "profile_person_ids": ["person-1"],
    }
    current_task: Dict[str, Any] = {
        "id": 1,
        "query_tool_id": "tool-query-rollback",
        "session_id": "session-1",
        "status": "applied",
        "rollback_status": "none",
        "query_snapshot": {"query": "测试用户最喜欢的颜色是什么"},
        "decision_payload": {"decision": "correct", "confidence": 0.97},
        "rollback_plan": rollback_plan,
    }

    class _Conn:
        """Connection/cursor stub: every call succeeds and does nothing."""

        def cursor(self):
            return self

        def execute(self, *_args, **_kwargs):
            return self

        def commit(self):
            return None

    def _get_feedback_task_by_id(task_id):
        return current_task if int(task_id) == 1 else None

    def _mark_feedback_task_rollback_running(**kwargs):
        current_task.update({"rollback_status": "running"})
        return current_task

    def _finalize_feedback_task_rollback(**kwargs):
        current_task.update(
            {
                "rollback_status": kwargs["rollback_status"],
                "rollback_result": kwargs.get("rollback_result") or {},
                "rollback_error": kwargs.get("rollback_error", ""),
            }
        )
        return current_task

    def _get_relation_status_batch(hashes):
        return {
            relation_hash: dict(relation_statuses[relation_hash])
            for relation_hash in hashes
            if relation_hash in relation_statuses
        }

    def _restore_relation_status_from_snapshot(hash_value, snapshot):
        relation_statuses.update({hash_value: dict(snapshot)})
        return dict(snapshot)

    def _append_feedback_action_log(**kwargs):
        return action_logs.append(kwargs)

    def _mark_as_deleted(hashes, type_):
        deleted_paragraphs.extend(list(hashes))
        return len(list(hashes))

    def _get_paragraph(paragraph_hash):
        return {"hash": paragraph_hash, "source": "chat_summary:session-1"}

    def _get_connection():
        return _Conn()

    def _delete_external_memory_refs_by_paragraphs(hashes):
        return [
            {"paragraph_hash": str(paragraph_hash), "external_id": f"external:{paragraph_hash}"}
            for paragraph_hash in hashes
        ]

    def _update_relations_protection(hashes, **kwargs):
        return None

    def _mark_relations_inactive(hashes, inactive_since=None):
        # Returns one ``None`` per processed hash, like the list of
        # ``dict.__setitem__`` results it stands in for.
        outcomes = []
        for relation_hash in hashes:
            relation_statuses[relation_hash] = {
                **relation_statuses.get(relation_hash, {}),
                "is_inactive": True,
                "inactive_since": inactive_since,
            }
            outcomes.append(None)
        return outcomes

    def _delete_paragraph_stale_relation_marks(marks):
        deleted_marks.extend(list(marks))
        return len(list(marks))

    def _enqueue_episode_source_rebuild(source, reason=''):
        queued_sources.append(source)
        return True

    def _enqueue_person_profile_refresh(**kwargs):
        queued_profiles.append(kwargs["person_id"])
        return kwargs

    def _list_feedback_action_logs(task_id):
        return action_logs if int(task_id) == 1 else []

    metadata_store = SimpleNamespace(
        get_feedback_task_by_id=_get_feedback_task_by_id,
        mark_feedback_task_rollback_running=_mark_feedback_task_rollback_running,
        finalize_feedback_task_rollback=_finalize_feedback_task_rollback,
        get_relation_status_batch=_get_relation_status_batch,
        restore_relation_status_from_snapshot=_restore_relation_status_from_snapshot,
        append_feedback_action_log=_append_feedback_action_log,
        mark_as_deleted=_mark_as_deleted,
        get_paragraph=_get_paragraph,
        get_connection=_get_connection,
        delete_external_memory_refs_by_paragraphs=_delete_external_memory_refs_by_paragraphs,
        update_relations_protection=_update_relations_protection,
        mark_relations_inactive=_mark_relations_inactive,
        delete_paragraph_stale_relation_marks=_delete_paragraph_stale_relation_marks,
        enqueue_episode_source_rebuild=_enqueue_episode_source_rebuild,
        enqueue_person_profile_refresh=_enqueue_person_profile_refresh,
        list_feedback_action_logs=_list_feedback_action_logs,
    )

    async def _noop_initialize() -> None:
        return None

    # Swap the kernel's storage and lifecycle hooks for the fakes above.
    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = metadata_store
    kernel.initialize = _noop_initialize  # type: ignore[method-assign]
    kernel._rebuild_graph_from_metadata = lambda: None  # type: ignore[method-assign]
    kernel._persist = lambda: None  # type: ignore[method-assign]

    payload = await kernel._rollback_feedback_task(
        task_id=1,
        requested_by="pytest",
        reason="manual rollback",
    )

    assert payload["success"] is True
    assert relation_statuses["rel-old"]["is_inactive"] is False
    assert relation_statuses["rel-new"]["is_inactive"] is True
    assert deleted_paragraphs == ["paragraph-new"]
    assert deleted_marks == [("paragraph-old", "rel-old")]
    assert queued_sources == ["chat_summary:session-1"]
    assert queued_profiles == ["person-1"]
    assert current_task["rollback_status"] == "rolled_back"

    expected_action_types = {
        "rollback_restore_relation",
        "rollback_revert_corrected_relation",
        "rollback_delete_correction_paragraph",
        "rollback_clear_stale_mark",
        "rollback_enqueue_episode_rebuild",
        "rollback_enqueue_profile_refresh",
    }
    logged_action_types = {entry["action_type"] for entry in action_logs}
    assert logged_action_types >= expected_action_types
|
||||
82
pytests/A_memorix_test/test_graph_store_persistence.py
Normal file
82
pytests/A_memorix_test/test_graph_store_persistence.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.storage.graph_store import GraphStore
|
||||
except SystemExit as exc:
|
||||
GraphStore = None # type: ignore[assignment]
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
else:
|
||||
IMPORT_ERROR = None
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
|
||||
def _build_empty_graph_metadata() -> dict:
|
||||
return {
|
||||
"nodes": [],
|
||||
"node_to_idx": {},
|
||||
"node_attrs": {},
|
||||
"matrix_format": "csr",
|
||||
"total_nodes_added": 0,
|
||||
"total_edges_added": 0,
|
||||
"total_nodes_deleted": 0,
|
||||
"total_edges_deleted": 0,
|
||||
"edge_hash_map": {},
|
||||
}
|
||||
|
||||
|
||||
def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None:
    """Saving after ``clear()`` must remove the adjacency file written by the earlier save."""
    graph_dir = tmp_path / "graph_data"
    adjacency_file = graph_dir / "graph_adjacency.npz"

    graph = GraphStore(data_dir=graph_dir)
    graph.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    graph.save()
    assert adjacency_file.exists()

    graph.clear()
    graph.save()
    assert not adjacency_file.exists()
|
||||
|
||||
|
||||
def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None:
    """Loading empty metadata next to a leftover adjacency file must yield an empty graph."""
    graph_dir = tmp_path / "graph_data"

    seeded = GraphStore(data_dir=graph_dir)
    seeded.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    seeded.save()

    # Overwrite only the metadata; the files written by save() above stay in place.
    metadata_file = graph_dir / "graph_metadata.pkl"
    metadata_file.write_bytes(pickle.dumps(_build_empty_graph_metadata()))

    fresh = GraphStore(data_dir=graph_dir)
    fresh.load()

    assert fresh.num_nodes == 0
    assert fresh.num_edges == 0
    assert fresh.get_nodes() == []
|
||||
|
||||
|
||||
def test_graph_store_load_clears_stale_edge_hash_map_when_metadata_is_empty(tmp_path: Path) -> None:
    """Metadata with no nodes but a leftover edge-hash entry must load without an edge-hash map."""
    graph_dir = tmp_path / "graph_data"

    seeded = GraphStore(data_dir=graph_dir)
    seeded.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    seeded.save()

    # Empty metadata, except for one edge-hash entry that refers to nodes no longer listed.
    stale_metadata = {
        **_build_empty_graph_metadata(),
        "edge_hash_map": {(0, 1): {"rel-1"}},
    }
    metadata_file = graph_dir / "graph_metadata.pkl"
    metadata_file.write_bytes(pickle.dumps(stale_metadata))

    fresh = GraphStore(data_dir=graph_dir)
    fresh.load()

    assert fresh.has_edge_hash_map() is False
|
||||
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
DATA_DIR = Path(__file__).parent / "data" / "benchmarks"
|
||||
|
||||
|
||||
def _fixture_files() -> list[Path]:
|
||||
return sorted(DATA_DIR.glob("group_chat_stream_memory_benchmark*.json"))
|
||||
|
||||
|
||||
def _load_fixture(path: Path) -> dict:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _assert_fixture_matches_current_design_constraints(dataset: dict) -> None:
|
||||
assert dataset["meta"]["scenario_id"]
|
||||
|
||||
assert dataset["session"]["group_id"]
|
||||
assert dataset["session"]["platform"] == "qq"
|
||||
|
||||
simulated_batches = dataset["simulated_stream_batches"]
|
||||
assert len(simulated_batches) >= 5
|
||||
|
||||
positive_batches = [item for item in simulated_batches if item["bot_participated"]]
|
||||
negative_batches = [item for item in simulated_batches if not item["bot_participated"]]
|
||||
|
||||
assert len(positive_batches) >= 4
|
||||
assert len(negative_batches) >= 1
|
||||
assert any(item["expected_behavior"] == "ignored_by_summarizer_without_bot_message" for item in negative_batches)
|
||||
|
||||
for batch in positive_batches:
|
||||
assert "Mai" in batch["participants"]
|
||||
assert batch["message_count"] >= 10
|
||||
assert len(batch["combined_text"]) >= 300
|
||||
assert batch["start_time"] < batch["end_time"]
|
||||
assert len(batch["expected_memory_targets"]) >= 4
|
||||
|
||||
runtime_streams = dataset["runtime_trigger_streams"]
|
||||
assert len(runtime_streams) >= 2
|
||||
|
||||
runtime_positive = [item for item in runtime_streams if item["bot_participated"]]
|
||||
runtime_negative = [item for item in runtime_streams if not item["bot_participated"]]
|
||||
|
||||
assert len(runtime_positive) >= 1
|
||||
assert len(runtime_negative) >= 1
|
||||
|
||||
for stream in runtime_streams:
|
||||
stream_text = "\n".join(stream["messages"])
|
||||
assert stream["trigger_mode"] == "time_threshold"
|
||||
assert stream["elapsed_since_last_check_hours"] >= 8.0
|
||||
assert stream["message_count"] >= 20
|
||||
assert len(stream["messages"]) == stream["message_count"]
|
||||
assert len(stream_text) >= 1000
|
||||
assert stream["start_time"] < stream["end_time"]
|
||||
|
||||
assert any(item["expected_check_outcome"] == "should_trigger_topic_check_and_pass_bot_gate" for item in runtime_positive)
|
||||
assert any(
|
||||
item["expected_check_outcome"] == "should_trigger_topic_check_but_be_discarded_without_bot_message"
|
||||
for item in runtime_negative
|
||||
)
|
||||
|
||||
records = dataset["chat_history_records"]
|
||||
assert len(records) >= 4
|
||||
for record in records:
|
||||
assert "Mai" in record["participants"]
|
||||
assert len(record["summary"]) >= 40
|
||||
assert len(record["original_text"]) >= 200
|
||||
assert record["start_time"] < record["end_time"]
|
||||
|
||||
assert len(dataset["person_writebacks"]) >= 3
|
||||
assert len(dataset["search_cases"]) >= 4
|
||||
assert len(dataset["time_cases"]) >= 3
|
||||
assert len(dataset["episode_cases"]) >= 4
|
||||
assert len(dataset["knowledge_fetcher_cases"]) >= 3
|
||||
assert len(dataset["profile_cases"]) >= 3
|
||||
assert len(dataset["negative_control_cases"]) >= 1
|
||||
|
||||
|
||||
def test_group_chat_stream_fixture_matches_current_design_constraints():
    """Every discovered benchmark fixture must pass the design-constraint checks."""
    fixture_paths = _fixture_files()
    assert fixture_paths, "未找到 group_chat_stream_memory_benchmark*.json fixture"

    for fixture_path in fixture_paths:
        dataset = _load_fixture(fixture_path)
        _assert_fixture_matches_current_design_constraints(dataset)
|
||||
124
pytests/A_memorix_test/test_knowledge_fetcher.py
Normal file
124
pytests/A_memorix_test/test_knowledge_fetcher.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
|
||||
from src.services.memory_service import MemoryHit, MemorySearchResult
|
||||
|
||||
|
||||
def test_knowledge_fetcher_resolves_private_memory_context(monkeypatch):
    """The private-chat context combines the stream id, the session ids and the resolved person id."""

    def _lookup_session(session_id):
        return SimpleNamespace(platform="qq", user_id="42", group_id="")

    def _resolve_person(*, person_name, platform, user_id):
        return f"{person_name}:{platform}:{user_id}"

    monkeypatch.setattr(
        knowledge_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=_lookup_session),
    )
    monkeypatch.setattr(knowledge_module, "resolve_person_id_for_memory", _resolve_person)

    fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
    resolved_context = fetcher._resolve_private_memory_context()

    expected_context = {
        "chat_id": "stream-1",
        "person_id": "Alice:qq:42",
        "user_id": "42",
        "group_id": "",
    }
    assert resolved_context == expected_context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_knowledge_fetcher_memory_get_knowledge_uses_memory_service(monkeypatch):
    """A single person-scoped search hit is rendered into the result and searched for exactly once."""

    def _lookup_session(session_id):
        return SimpleNamespace(platform="qq", user_id="42", group_id="")

    def _resolve_person(*, person_name, platform, user_id):
        return f"{person_name}:{platform}:{user_id}"

    monkeypatch.setattr(
        knowledge_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=_lookup_session),
    )
    monkeypatch.setattr(knowledge_module, "resolve_person_id_for_memory", _resolve_person)

    search_calls = []

    async def fake_search(query: str, **kwargs):
        search_calls.append((query, kwargs))
        hit = MemoryHit(content="她喜欢猫", source="person_fact:qq:42")
        return MemorySearchResult(summary="", hits=[hit], filtered=False)

    monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)

    fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
    rendered = await fetcher._memory_get_knowledge("她喜欢什么")

    assert "1. 她喜欢猫" in rendered

    expected_kwargs = {
        "limit": 5,
        "mode": "search",
        "chat_id": "stream-1",
        "person_id": "Alice:qq:42",
        "user_id": "42",
        "group_id": "",
        "respect_filter": True,
    }
    assert search_calls == [("她喜欢什么", expected_kwargs)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_knowledge_fetcher_falls_back_to_chat_scope_when_person_scope_misses(monkeypatch):
    """When the person-scoped search has no hits, a second search with an empty person id is made."""

    def _lookup_session(session_id):
        return SimpleNamespace(platform="qq", user_id="42", group_id="")

    def _resolve_person(*, person_name, platform, user_id):
        return "person-1"

    monkeypatch.setattr(
        knowledge_module,
        "_chat_manager",
        SimpleNamespace(get_session_by_session_id=_lookup_session),
    )
    monkeypatch.setattr(knowledge_module, "resolve_person_id_for_memory", _resolve_person)

    search_calls = []

    async def fake_search(query: str, **kwargs):
        search_calls.append((query, kwargs))
        # Person-scoped lookups miss; anything else returns one chat-summary hit.
        if kwargs.get("person_id"):
            return MemorySearchResult(summary="", hits=[], filtered=False)
        hit = MemoryHit(content="她计划去杭州音乐节", source="chat_summary:stream-1")
        return MemorySearchResult(summary="", hits=[hit], filtered=False)

    monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)

    fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
    rendered = await fetcher._memory_get_knowledge("Alice 最近在忙什么")

    assert "杭州音乐节" in rendered

    person_scoped_kwargs = {
        "limit": 5,
        "mode": "search",
        "chat_id": "stream-1",
        "person_id": "person-1",
        "user_id": "42",
        "group_id": "",
        "respect_filter": True,
    }
    chat_scoped_kwargs = {**person_scoped_kwargs, "person_id": ""}
    assert search_calls == [
        ("Alice 最近在忙什么", person_scoped_kwargs),
        ("Alice 最近在忙什么", chat_scoped_kwargs),
    ]
|
||||
355
pytests/A_memorix_test/test_memory_flow_service.py
Normal file
355
pytests/A_memorix_test/test_memory_flow_service.py
Normal file
@@ -0,0 +1,355 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services import memory_flow_service as memory_flow_module
|
||||
|
||||
|
||||
def _fake_global_config(**integration_values):
|
||||
return SimpleNamespace(
|
||||
a_memorix=SimpleNamespace(
|
||||
integration=SimpleNamespace(**integration_values),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
    """Duplicate, empty and one-character entries are dropped from the parsed fact list."""
    raw_facts = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]'
    parse_fact_list = memory_flow_module.PersonFactWritebackService._parse_fact_list

    parsed = parse_fact_list(raw_facts)

    assert parsed == ["他喜欢猫", "他会弹吉他"]
|
||||
|
||||
|
||||
def test_person_fact_looks_ephemeral_detects_short_chitchat():
    """Short filler replies count as ephemeral; a substantive statement does not."""
    looks_ephemeral = memory_flow_module.PersonFactWritebackService._looks_ephemeral

    assert looks_ephemeral("哈哈")
    assert looks_ephemeral("好的?")
    assert not looks_ephemeral("她最近在学法语和钢琴")
|
||||
|
||||
|
||||
def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
    """In a private chat (empty group id) the target person is derived from the session's platform and user id."""

    class FakePerson:
        """Minimal person stub that is always known."""

        def __init__(self, person_id: str):
            self.person_id = person_id
            self.is_known = True

    writeback_cls = memory_flow_module.PersonFactWritebackService
    # Bypass __init__ so only the method under test runs.
    service = writeback_cls.__new__(writeback_cls)

    monkeypatch.setattr(memory_flow_module, "is_bot_self", lambda platform, user_id: False)
    monkeypatch.setattr(memory_flow_module, "get_person_id", lambda platform, user_id: f"{platform}:{user_id}")
    monkeypatch.setattr(memory_flow_module, "Person", FakePerson)

    private_session = SimpleNamespace(platform="qq", user_id="123", group_id="")
    message = SimpleNamespace(session=private_session)

    resolved = service._resolve_target_person(message)

    assert resolved is not None
    assert resolved.person_id == "qq:123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_person_fact_writeback_skips_bot_only_fact_without_user_evidence(monkeypatch):
    """Nothing is stored when ``find_messages`` returns no messages for the bot's reply."""
    stored_facts: list[tuple[str, str, str]] = []

    class FakePerson:
        person_id = "person-1"
        person_name = "测试用户"
        nickname = "测试用户"
        is_known = True

    async def fake_extract_facts(person, reply_text, user_evidence_text):
        del person, reply_text, user_evidence_text
        return ["测试用户喜欢辣椒"]

    async def fake_store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str, **kwargs):
        del kwargs
        stored_facts.append((person_name, memory_content, chat_id))

    def fake_resolve_target_person(message):
        return FakePerson()

    writeback_cls = memory_flow_module.PersonFactWritebackService
    # Bypass __init__; the collaborators the handler uses are stubbed individually.
    service = writeback_cls.__new__(writeback_cls)
    service._resolve_target_person = fake_resolve_target_person
    service._extract_facts = fake_extract_facts

    monkeypatch.setattr(memory_flow_module, "store_person_memory_from_answer", fake_store_person_memory_from_answer)
    monkeypatch.setattr(memory_flow_module, "find_messages", lambda **kwargs: [])

    bot_session = SimpleNamespace(platform="qq", user_id="bot-1", group_id="")
    message = SimpleNamespace(
        processed_plain_text="我记得你喜欢辣椒。",
        session_id="session-1",
        reply_to="",
        session=bot_session,
    )

    await service._handle_message(message)

    assert stored_facts == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch):
    """With 5 messages, threshold 3 and no earlier trigger, exactly one summary ingest is issued."""
    ingest_payloads: list[dict] = []

    async def fake_ingest_summary(**kwargs):
        ingest_payloads.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 0

    fake_config = _fake_global_config(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", fake_config)
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    group_session = SimpleNamespace(user_id="user-1", group_id="group-1")
    message = SimpleNamespace(session_id="session-1", session=group_session)

    await service._handle_message(message)

    assert len(ingest_payloads) == 1
    payload = ingest_payloads[0]
    assert payload["external_id"] == "chat_auto_summary:session-1:5"
    assert payload["chat_id"] == "session-1"
    assert payload["text"] == ""
    assert payload["user_id"] == "user-1"
    assert payload["group_id"] == "group-1"

    metadata = payload["metadata"]
    assert metadata["generate_from_chat"] is True
    assert metadata["context_length"] == 7
    assert metadata["trigger"] == "message_threshold"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch):
    """With 5 messages and a threshold of 6, no summary ingest is issued."""
    ingest_invocations: list[dict] = []

    async def fake_ingest_summary(**kwargs):
        ingest_invocations.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 0

    fake_config = _fake_global_config(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=6,
        chat_summary_writeback_context_length=9,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", fake_config)
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    group_session = SimpleNamespace(user_id="user-1", group_id="group-1")
    message = SimpleNamespace(session_id="session-1", session=group_session)

    await service._handle_message(message)

    assert not ingest_invocations
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_restores_previous_trigger_count(monkeypatch):
    """A restored trigger count of 5 and 8 total messages (threshold 3) triggers once and records 8."""
    ingest_payloads: list[dict] = []

    async def fake_ingest_summary(**kwargs):
        ingest_payloads.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 5

    fake_config = _fake_global_config(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", fake_config)
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 8)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    group_session = SimpleNamespace(user_id="user-1", group_id="group-1")
    message = SimpleNamespace(session_id="session-1", session=group_session)

    await service._handle_message(message)

    assert len(ingest_payloads) == 1
    assert ingest_payloads[0]["external_id"] == "chat_auto_summary:session-1:8"
    assert service._states["session-1"].last_trigger_message_count == 8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_falls_back_to_current_count_for_legacy_summary(monkeypatch):
    """When the restored trigger count equals the current total (5), nothing is ingested and the state stays at 5."""
    ingest_invocations: list[dict] = []

    async def fake_ingest_summary(**kwargs):
        ingest_invocations.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 5

    fake_config = _fake_global_config(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", fake_config)
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    group_session = SimpleNamespace(user_id="user-1", group_id="group-1")
    message = SimpleNamespace(session_id="session-1", session=group_session)

    await service._handle_message(message)

    assert not ingest_invocations
    assert service._states["session-1"].last_trigger_message_count == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_loads_trigger_count_from_summary_metadata(monkeypatch):
    """The trigger count is restored from stored summary paragraphs; here the result must be 6."""
    stored_paragraphs = [
        {"created_at": 1.0, "metadata": {"trigger_message_count": 3}},
        {"created_at": 2.0, "metadata": {"trigger_message_count": 6}},
    ]

    class FakeMetadataStore:
        @staticmethod
        def get_paragraphs_by_source(source: str):
            assert source == "chat_summary:session-1"
            return list(stored_paragraphs)

    class FakeRuntimeManager:
        @staticmethod
        async def _ensure_kernel():
            return SimpleNamespace(metadata_store=FakeMetadataStore())

    monkeypatch.setattr(
        memory_flow_module.memory_service_module,
        "a_memorix_host_service",
        FakeRuntimeManager(),
    )

    service = memory_flow_module.ChatSummaryWritebackService()
    restored_count = await service._load_last_trigger_message_count(
        session_id="session-1",
        total_message_count=8,
    )

    assert restored_count == 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates():
    """``on_message_sent`` starts both writers, enqueues to each, and ``shutdown`` stops summary before fact."""
    timeline: list[tuple[str, str]] = []

    def _recording_writer(label: str, enqueue_tag: str):
        """Build a writer stub that appends its lifecycle calls to ``timeline``."""

        class _Writer:
            async def start(self):
                timeline.append(("start", label))

            async def enqueue(self, message):
                timeline.append((enqueue_tag, message.session_id))

            async def shutdown(self):
                timeline.append(("shutdown", label))

        return _Writer()

    service = memory_flow_module.MemoryAutomationService()
    service.fact_writeback = _recording_writer("fact", "sent")
    service.chat_summary_writeback = _recording_writer("summary", "summary")

    await service.on_message_sent(SimpleNamespace(session_id="session-1"))
    await service.shutdown()

    expected_timeline = [
        ("start", "fact"),
        ("start", "summary"),
        ("sent", "session-1"),
        ("summary", "session-1"),
        ("shutdown", "summary"),
        ("shutdown", "fact"),
    ]
    assert timeline == expected_timeline
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_automation_service_on_incoming_message_auto_starts_only():
    """``on_incoming_message`` starts both writers but enqueues nothing before shutdown."""
    timeline: list[tuple[str, str]] = []

    def _recording_writer(label: str, enqueue_tag: str):
        """Build a writer stub that appends its lifecycle calls to ``timeline``."""

        class _Writer:
            async def start(self):
                timeline.append(("start", label))

            async def enqueue(self, message):
                timeline.append((enqueue_tag, message.session_id))

            async def shutdown(self):
                timeline.append(("shutdown", label))

        return _Writer()

    service = memory_flow_module.MemoryAutomationService()
    service.fact_writeback = _recording_writer("fact", "sent")
    service.chat_summary_writeback = _recording_writer("summary", "summary")

    await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
    await service.shutdown()

    expected_timeline = [
        ("start", "fact"),
        ("start", "summary"),
        ("shutdown", "summary"),
        ("shutdown", "fact"),
    ]
    assert timeline == expected_timeline
|
||||
113
pytests/A_memorix_test/test_memory_graph_search_kernel.py
Normal file
113
pytests/A_memorix_test/test_memory_graph_search_kernel.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
|
||||
|
||||
class _DummyMetadataStore:
|
||||
def __init__(self, *, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> None:
|
||||
self._entities = entities
|
||||
self._relations = relations
|
||||
|
||||
def query(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
|
||||
sql_token = " ".join(str(sql or "").lower().split())
|
||||
keyword = str(params[0] or "").strip("%").lower() if params else ""
|
||||
if "from entities" in sql_token:
|
||||
rows = [dict(item) for item in self._entities if not bool(item.get("is_deleted", 0))]
|
||||
if not keyword:
|
||||
return rows
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if keyword in str(row.get("name", "") or "").lower()
|
||||
or keyword in str(row.get("hash", "") or "").lower()
|
||||
]
|
||||
if "from relations" in sql_token:
|
||||
rows = [dict(item) for item in self._relations if not bool(item.get("is_inactive", 0))]
|
||||
if not keyword:
|
||||
return rows
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if keyword in str(row.get("subject", "") or "").lower()
|
||||
or keyword in str(row.get("object", "") or "").lower()
|
||||
or keyword in str(row.get("predicate", "") or "").lower()
|
||||
or keyword in str(row.get("hash", "") or "").lower()
|
||||
]
|
||||
raise AssertionError(f"unexpected query: {sql_token}")
|
||||
|
||||
|
||||
def _build_kernel(*, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> SDKMemoryKernel:
    """Create a kernel whose ``initialize`` is a no-op and whose stores are in-memory stand-ins."""

    async def _fake_initialize() -> None:
        return None

    dummy_store = _DummyMetadataStore(entities=entities, relations=relations)

    kernel = SDKMemoryKernel(plugin_root=Path.cwd(), config={})
    kernel.initialize = _fake_initialize  # type: ignore[method-assign]
    kernel.metadata_store = dummy_store
    # Placeholder object; the graph store is only assigned here, never called.
    kernel.graph_store = object()  # type: ignore[assignment]
    return kernel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_graph_admin_search_orders_and_dedupes_results() -> None:
    """Search for "alice" returns entities and relations in the asserted order, one item per hash."""
    entity_rows = [
        {"hash": "e1", "name": "Alice", "appearance_count": 5, "is_deleted": 0},
        {"hash": "e1", "name": "Alice Duplicate", "appearance_count": 99, "is_deleted": 0},
        {"hash": "e2", "name": "Alice Cooper", "appearance_count": 7, "is_deleted": 0},
        {"hash": "e3", "name": "my alice note", "appearance_count": 11, "is_deleted": 0},
        {"hash": "e4", "name": "alice deleted", "appearance_count": 100, "is_deleted": 1},
    ]
    relation_rows = [
        {"hash": "r1", "subject": "Alice", "predicate": "knows", "object": "Bob", "confidence": 0.6, "created_at": 100, "is_inactive": 0},
        {"hash": "r3", "subject": "Alice", "predicate": "supports", "object": "Carol", "confidence": 0.9, "created_at": 90, "is_inactive": 0},
        {"hash": "r1", "subject": "Alice", "predicate": "knows duplicate", "object": "Bob", "confidence": 0.99, "created_at": 200, "is_inactive": 0},
        {"hash": "r2", "subject": "Alice Cooper", "predicate": "likes", "object": "Tea", "confidence": 0.2, "created_at": 50, "is_inactive": 0},
        {"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.8, "created_at": 70, "is_inactive": 0},
        {"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.3, "created_at": 10, "is_inactive": 0},
        {"hash": "r4", "subject": "alice inactive", "predicate": "old", "object": "Data", "confidence": 1.0, "created_at": 300, "is_inactive": 1},
    ]
    kernel = _build_kernel(entities=entity_rows, relations=relation_rows)

    payload = await kernel.memory_graph_admin(action="search", query="alice", limit=20)

    assert payload["success"] is True
    assert payload["count"] == len(payload["items"])

    # Split the mixed result list by item type, preserving the returned order.
    found_entities = [item for item in payload["items"] if item["type"] == "entity"]
    found_relations = [item for item in payload["items"] if item["type"] == "relation"]

    entity_hashes = [item["entity_hash"] for item in found_entities]
    relation_hashes = [item["relation_hash"] for item in found_relations]
    assert entity_hashes == ["e1", "e2", "e3"]
    assert relation_hashes == ["r3", "r1", "r2", ""]

    assert found_relations[0]["confidence"] == pytest.approx(0.9)
    assert found_relations[1]["confidence"] == pytest.approx(0.6)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_memory_graph_admin_search_filters_deleted_and_inactive_records() -> None:
    """Graph search returns nothing when the only matches are deleted or inactive."""
    deleted_entity = {"hash": "e-deleted", "name": "Ghost Alice", "appearance_count": 10, "is_deleted": 1}
    inactive_relation = {
        "hash": "r-inactive",
        "subject": "Ghost Alice",
        "predicate": "linked",
        "object": "Ghost Bob",
        "confidence": 0.9,
        "created_at": 10,
        "is_inactive": 1,
    }
    kernel = _build_kernel(entities=[deleted_entity], relations=[inactive_relation])

    response = await kernel.memory_graph_admin(action="search", query="ghost", limit=50)

    assert response["success"] is True
    assert response["items"] == []
    assert response["count"] == 0
|
||||
281
pytests/A_memorix_test/test_memory_service.py
Normal file
281
pytests/A_memorix_test/test_memory_service.py
Normal file
@@ -0,0 +1,281 @@
|
||||
import pytest
|
||||
|
||||
from src.services.memory_service import MemorySearchResult, MemoryService
|
||||
|
||||
|
||||
def test_coerce_write_result_treats_skipped_payload_as_success():
    """A payload that only carries skipped ids is coerced into a successful result."""
    raw_payload = {"skipped_ids": ["p1"], "detail": "chat_filtered"}

    outcome = MemoryService._coerce_write_result(raw_payload)

    assert outcome.success is True
    assert outcome.stored_ids == []
    assert outcome.skipped_ids == ["p1"]
    assert outcome.detail == "chat_filtered"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_graph_admin_invokes_plugin(monkeypatch):
    """graph_admin calls the memory_graph_admin component with timeout_ms=30000."""
    service = MemoryService()
    recorded = []

    async def fake_invoke(component_name, args=None, **kwargs):
        # Capture the full call so the forwarded keyword options can be checked.
        recorded.append((component_name, args, kwargs))
        return {"success": True, "nodes": [], "edges": []}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    reply = await service.graph_admin(action="get_graph", limit=12)

    expected_call = ("memory_graph_admin", {"action": "get_graph", "limit": 12}, {"timeout_ms": 30000})
    assert reply["success"] is True
    assert recorded == [expected_call]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_recycle_bin_uses_maintain_memory_tool(monkeypatch):
    """get_recycle_bin is routed to maintain_memory with the recycle_bin action."""
    service = MemoryService()
    recorded = []

    async def fake_invoke(component_name, args=None, **kwargs):
        # Keyword options are accepted but deliberately not recorded here.
        recorded.append((component_name, args))
        return {"success": True, "items": [{"hash": "abc"}], "count": 1}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    reply = await service.get_recycle_bin(limit=5)

    expected_reply = {"success": True, "items": [{"hash": "abc"}], "count": 1}
    expected_call = ("maintain_memory", {"action": "recycle_bin", "limit": 5})
    assert reply == expected_reply
    assert recorded == [expected_call]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_respects_filter_by_default(monkeypatch):
    """search sends respect_filter=True and the default limit/mode when not overridden."""
    service = MemoryService()
    recorded = []

    async def fake_invoke(component_name, args=None, **kwargs):
        recorded.append((component_name, args))
        return {"summary": "ok", "hits": [], "filtered": True}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    outcome = await service.search(
        "mai",
        chat_id="stream-1",
        person_id="person-1",
        user_id="user-1",
        group_id="",
    )

    expected_args = {
        "query": "mai",
        "limit": 5,
        "mode": "search",
        "chat_id": "stream-1",
        "person_id": "person-1",
        "time_start": None,
        "time_end": None,
        "respect_filter": True,
        "user_id": "user-1",
        "group_id": "",
    }
    assert isinstance(outcome, MemorySearchResult)
    assert outcome.filtered is True
    assert recorded == [("search_memory", expected_args)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_ingest_summary_can_bypass_filter(monkeypatch):
    """ingest_summary forwards respect_filter=False and fills unset fields with empty values."""
    service = MemoryService()
    recorded = []

    async def fake_invoke(component_name, args=None, **kwargs):
        recorded.append((component_name, args))
        return {"success": True, "stored_ids": ["p1"], "detail": ""}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    outcome = await service.ingest_summary(
        external_id="chat_history:1",
        chat_id="stream-1",
        text="summary",
        respect_filter=False,
        user_id="user-1",
    )

    expected_args = {
        "external_id": "chat_history:1",
        "chat_id": "stream-1",
        "text": "summary",
        "participants": [],
        "time_start": None,
        "time_end": None,
        "tags": [],
        "metadata": {},
        "respect_filter": False,
        "user_id": "user-1",
        "group_id": "",
    }
    assert outcome.success is True
    assert recorded == [("ingest_summary", expected_args)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_v5_admin_invokes_plugin(monkeypatch):
    """v5_admin calls the memory_v5_admin component with timeout_ms=30000."""
    service = MemoryService()
    recorded = []

    async def fake_invoke(component_name, args=None, **kwargs):
        recorded.append((component_name, args, kwargs))
        return {"success": True, "count": 1}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    reply = await service.v5_admin(action="status", target="mai", limit=5)

    expected_call = (
        "memory_v5_admin",
        {"action": "status", "target": "mai", "limit": 5},
        {"timeout_ms": 30000},
    )
    assert reply["success"] is True
    assert recorded == [expected_call]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_admin_uses_long_timeout(monkeypatch):
    """delete_admin calls memory_delete_admin with timeout_ms=120000."""
    service = MemoryService()
    recorded = []

    async def fake_invoke(component_name, args=None, **kwargs):
        recorded.append((component_name, args, kwargs))
        return {"success": True, "operation_id": "del-1"}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    reply = await service.delete_admin(action="execute", mode="relation", selector={"query": "mai"})

    expected_call = (
        "memory_delete_admin",
        {"action": "execute", "mode": "relation", "selector": {"query": "mai"}},
        {"timeout_ms": 120000},
    )
    assert reply["success"] is True
    assert recorded == [expected_call]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_returns_empty_when_query_and_time_missing_async():
    """With an empty query and no time bounds, search returns an empty, unfiltered result."""
    service = MemoryService()

    outcome = await service.search("", time_start=None, time_end=None)

    assert isinstance(outcome, MemorySearchResult)
    assert outcome.summary == ""
    assert outcome.hits == []
    assert outcome.filtered is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_accepts_string_time_bounds(monkeypatch):
    """String time bounds are passed through to search_memory unchanged."""
    service = MemoryService()
    recorded = []

    async def fake_invoke(component_name, args=None, **kwargs):
        recorded.append((component_name, args))
        return {"summary": "ok", "hits": [], "filtered": False}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    outcome = await service.search(
        "广播站",
        mode="time",
        time_start="2026/03/18",
        time_end="2026/03/18 09:30",
    )

    expected_args = {
        "query": "广播站",
        "limit": 5,
        "mode": "time",
        "chat_id": "",
        "person_id": "",
        "time_start": "2026/03/18",
        "time_end": "2026/03/18 09:30",
        "respect_filter": True,
        "user_id": "",
        "group_id": "",
    }
    assert isinstance(outcome, MemorySearchResult)
    assert recorded == [("search_memory", expected_args)]
|
||||
|
||||
|
||||
def test_coerce_search_result_preserves_aggregate_source_branches():
    """Top-level source_branches and rank on a hit end up in the hit's metadata."""
    raw_hit = {
        "content": "广播站值夜班",
        "type": "paragraph",
        "metadata": {"event_time_start": 1.0},
        "source_branches": ["search", "time"],
        "rank": 1,
    }

    outcome = MemoryService._coerce_search_result({"hits": [raw_hit]})

    assert outcome.hits[0].metadata["source_branches"] == ["search", "time"]
    assert outcome.hits[0].metadata["rank"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_import_admin_uses_long_timeout(monkeypatch):
    """import_admin calls memory_import_admin with timeout_ms=120000."""
    service = MemoryService()
    recorded = []

    async def fake_invoke(component_name, args=None, **kwargs):
        recorded.append((component_name, args, kwargs))
        return {"success": True, "task_id": "import-1"}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    reply = await service.import_admin(action="create_lpmm_openie", alias="lpmm")

    expected_call = (
        "memory_import_admin",
        {"action": "create_lpmm_openie", "alias": "lpmm"},
        {"timeout_ms": 120000},
    )
    assert reply["success"] is True
    assert recorded == [expected_call]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tuning_admin_uses_long_timeout(monkeypatch):
    """tuning_admin calls memory_tuning_admin with timeout_ms=120000."""
    service = MemoryService()
    recorded = []

    async def fake_invoke(component_name, args=None, **kwargs):
        recorded.append((component_name, args, kwargs))
        return {"success": True, "task_id": "tuning-1"}

    monkeypatch.setattr(service, "_invoke", fake_invoke)

    reply = await service.tuning_admin(action="create_task", payload={"query": "mai"})

    expected_call = (
        "memory_tuning_admin",
        {"action": "create_task", "payload": {"query": "mai"}},
        {"timeout_ms": 120000},
    )
    assert reply["success"] is True
    assert recorded == [expected_call]
|
||||
21
pytests/A_memorix_test/test_metadata_store_sources.py
Normal file
21
pytests/A_memorix_test/test_metadata_store_sources.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pathlib import Path
|
||||
|
||||
from src.A_memorix.core.storage.metadata_store import MetadataStore
|
||||
|
||||
|
||||
def test_get_all_sources_ignores_soft_deleted_paragraphs(tmp_path: Path) -> None:
    """Only the source of the paragraph that was not marked deleted is listed."""
    store = MetadataStore(data_dir=tmp_path)
    store.connect()
    try:
        kept_hash = store.add_paragraph("Alice 喜欢地图", source="live-source")
        removed_hash = store.add_paragraph("Bob 喜欢咖啡", source="deleted-source")

        assert kept_hash
        store.mark_as_deleted([removed_hash], "paragraph")

        listed = store.get_all_sources()
    finally:
        # Close the store even if any of the steps above fails.
        store.close()

    source_names = [entry["source"] for entry in listed]
    assert source_names == ["live-source"]
    assert listed[0]["count"] == 1
|
||||
81
pytests/A_memorix_test/test_person_memory_writeback.py
Normal file
81
pytests/A_memorix_test/test_person_memory_writeback.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.person_info import person_info as person_info_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_writes_person_fact(monkeypatch):
    """A known person's answer is ingested once as a person_fact with session identifiers."""
    ingested = []

    class FakePerson:
        """Stand-in for Person that always reports a known person named Alice."""

        def __init__(self, person_id: str):
            self.person_id = person_id
            self.person_name = "Alice"
            self.is_known = True

    async def fake_ingest_text(**kwargs):
        ingested.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
    chat_manager_stub = SimpleNamespace(get_session_by_session_id=lambda chat_id: session)
    monkeypatch.setattr(person_info_module, "_chat_manager", chat_manager_stub)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
    monkeypatch.setattr(person_info_module, "Person", FakePerson)
    monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)

    await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")

    assert len(ingested) == 1
    sent = ingested[0]
    assert sent["external_id"].startswith("person_fact:person-1:")
    assert sent["source_type"] == "person_fact"
    assert sent["chat_id"] == "session-1"
    assert sent["person_ids"] == ["person-1"]
    assert sent["participants"] == ["Alice"]
    assert sent["respect_filter"] is True
    assert sent["user_id"] == "10001"
    assert sent["group_id"] == ""
    assert sent["metadata"]["person_id"] == "person-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_unknown_person(monkeypatch):
    """Nothing is ingested when the resolved person has is_known=False."""
    ingested = []

    class FakePerson:
        """Stand-in for Person that always reports an unknown person."""

        def __init__(self, person_id: str):
            self.person_id = person_id
            self.person_name = "Unknown"
            self.is_known = False

    async def fake_ingest_text(**kwargs):
        ingested.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
    chat_manager_stub = SimpleNamespace(get_session_by_session_id=lambda chat_id: session)
    monkeypatch.setattr(person_info_module, "_chat_manager", chat_manager_stub)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
    monkeypatch.setattr(person_info_module, "Person", FakePerson)
    monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)

    await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")

    assert ingested == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_empty_content(monkeypatch):
    """Whitespace-only answer text does not trigger an ingest call."""
    ingested = []

    async def fake_ingest_text(**kwargs):
        ingested.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)

    await person_info_module.store_person_memory_from_answer("Alice", " ", "session-1")

    assert ingested == []
|
||||
|
||||
115
pytests/A_memorix_test/test_person_profile_service.py
Normal file
115
pytests/A_memorix_test/test_person_profile_service.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.A_memorix.core.utils.person_profile_service import PersonProfileService
|
||||
|
||||
|
||||
class FakeMetadataStore:
    """In-memory metadata store double with one person fact and one chat summary.

    Lookups return fixed data; snapshot upserts are recorded on ``snapshots``.
    """

    def __init__(self) -> None:
        # Keyword arguments of every upsert_person_profile_snapshot call, in order.
        self.snapshots: list[dict] = []

    @staticmethod
    def get_latest_person_profile_snapshot(person_id: str):
        """Return None for every person: no snapshot is ever cached."""
        del person_id
        return None

    @staticmethod
    def get_relations(**kwargs):
        """Return an empty relation list regardless of the filters given."""
        del kwargs
        return []

    @staticmethod
    def get_paragraphs_by_source(source: str):
        """Return the single person-fact paragraph for person-1's source, else an empty list."""
        if source != "person_fact:person-1":
            return []
        person_fact = {
            "hash": "person-fact-1",
            "content": "测试用户喜欢猫。",
            "source": source,
            "metadata": {"source_type": "person_fact"},
            "created_at": 2.0,
            "updated_at": 2.0,
        }
        return [person_fact]

    @staticmethod
    def get_paragraph(hash_value: str):
        """Return the paragraph for one of the two known hashes, or None."""
        paragraph = None
        if hash_value == "chat-summary-1":
            paragraph = {
                "hash": hash_value,
                "content": "机器人建议测试用户以后叫星灯。",
                "source": "chat_summary:session-1",
                "metadata": {"source_type": "chat_summary"},
                "word_count": 1,
            }
        elif hash_value == "person-fact-1":
            paragraph = {
                "hash": hash_value,
                "content": "测试用户喜欢猫。",
                "source": "person_fact:person-1",
                "metadata": {"source_type": "person_fact"},
                "word_count": 1,
            }
        return paragraph

    @staticmethod
    def get_paragraph_stale_relation_marks_batch(paragraph_hashes):
        """Return an empty mapping: no stale marks are reported."""
        del paragraph_hashes
        return {}

    @staticmethod
    def get_relation_status_batch(relation_hashes):
        """Return an empty mapping: no relation statuses are reported."""
        del relation_hashes
        return {}

    @staticmethod
    def get_person_profile_override(person_id: str):
        """Return None for every person: no override exists."""
        del person_id
        return None

    def upsert_person_profile_snapshot(self, **kwargs):
        """Record the call and return the snapshot fields with a fixed ``updated_at`` of 1.0."""
        self.snapshots.append(kwargs)
        copied_fields = (
            "person_id",
            "profile_text",
            "aliases",
            "relation_edges",
            "vector_evidence",
            "evidence_ids",
        )
        snapshot = {field: kwargs[field] for field in copied_fields}
        snapshot["updated_at"] = 1.0
        snapshot["expires_at"] = kwargs["expires_at"]
        snapshot["source_note"] = kwargs["source_note"]
        return snapshot
|
||||
|
||||
|
||||
class FakeRetriever:
    """Retriever double that returns one fixed chat-summary paragraph hit."""

    async def retrieve(self, query: str, top_k: int):
        """Return the canned hit; ``query`` and ``top_k`` are ignored."""
        del query, top_k
        chat_summary_hit = SimpleNamespace(
            hash_value="chat-summary-1",
            result_type="paragraph",
            score=0.95,
            content="机器人建议测试用户以后叫星灯。",
            metadata={"source_type": "chat_summary"},
        )
        return [chat_summary_hit]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_person_profile_keeps_chat_summary_as_recent_interaction_not_stable_profile():
    """The person fact precedes the recent-interaction heading; the chat summary text does not."""
    store = FakeMetadataStore()
    service = PersonProfileService(metadata_store=store, retriever=FakeRetriever())
    # Replace alias resolution on the instance with a fixed answer.
    service.get_person_aliases = lambda person_id: (["测试用户"], "测试用户", [])

    payload = await service.query_person_profile(person_id="person-1", top_k=6, force_refresh=True)

    assert payload["success"] is True
    profile_text = payload["profile_text"]
    # Everything before the recent-interaction heading is the stable section.
    stable_section, *_ = profile_text.split("近期相关互动:", 1)
    assert "测试用户喜欢猫" in stable_section
    assert "星灯" not in stable_section
    assert "近期相关互动:" in profile_text
    assert "星灯" in profile_text
|
||||
184
pytests/A_memorix_test/test_query_long_term_memory_tool.py
Normal file
184
pytests/A_memorix_test/test_query_long_term_memory_tool.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from src.memory_system.retrieval_tools import query_long_term_memory as tool_module
|
||||
from src.memory_system.retrieval_tools import init_all_tools
|
||||
from src.memory_system.retrieval_tools.query_long_term_memory import (
|
||||
_resolve_time_expression,
|
||||
query_long_term_memory,
|
||||
register_tool,
|
||||
)
|
||||
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
|
||||
from src.services.memory_service import MemoryHit, MemorySearchResult
|
||||
|
||||
|
||||
def test_resolve_time_expression_supports_relative_and_absolute_inputs():
    """Relative and absolute expressions resolve to the expected timestamps and display text."""
    now = datetime(2026, 3, 18, 15, 30)
    # (expression, expected start, expected end, expected start text, expected end text)
    cases = (
        ("今天", datetime(2026, 3, 18, 0, 0), datetime(2026, 3, 18, 23, 59), "2026/03/18 00:00", "2026/03/18 23:59"),
        ("最近7天", datetime(2026, 3, 12, 0, 0), datetime(2026, 3, 18, 23, 59), "2026/03/12 00:00", "2026/03/18 23:59"),
        ("2026/03/18", datetime(2026, 3, 18, 0, 0), datetime(2026, 3, 18, 23, 59), "2026/03/18 00:00", "2026/03/18 23:59"),
        (
            "2026/03/18 09:30",
            datetime(2026, 3, 18, 9, 30),
            datetime(2026, 3, 18, 9, 30),
            "2026/03/18 09:30",
            "2026/03/18 09:30",
        ),
    )

    for expression, want_start, want_end, want_start_text, want_end_text in cases:
        start_ts, end_ts, start_text, end_text = _resolve_time_expression(expression, now=now)
        assert datetime.fromtimestamp(start_ts) == want_start
        assert datetime.fromtimestamp(end_ts) == want_end
        assert start_text == want_start_text
        assert end_text == want_end_text
|
||||
|
||||
|
||||
def test_register_tool_exposes_mode_and_time_expression():
    """The registered tool has mode (four choices), time_expression, and an optional query."""
    register_tool()
    registered = get_tool_registry().get_tool("search_long_term_memory")

    assert registered is not None
    by_name = {}
    for spec in registered.parameters:
        by_name[spec["name"]] = spec
    assert "mode" in by_name
    assert by_name["mode"]["enum"] == ["search", "time", "episode", "aggregate"]
    assert "time_expression" in by_name
    assert by_name["query"]["required"] is False
|
||||
|
||||
|
||||
def test_init_all_tools_registers_long_term_memory_tool():
    """After init_all_tools, search_long_term_memory is present in the registry."""
    init_all_tools()

    registry = get_tool_registry()
    assert registry.get_tool("search_long_term_memory") is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_long_term_memory_search_mode_keeps_search(monkeypatch):
    """Without an explicit mode the tool searches in "search" mode with no time bounds."""
    captured = {}

    async def fake_search(query, **kwargs):
        captured.update(query=query, kwargs=kwargs)
        paragraph_hit = MemoryHit(content="Alice 喜欢猫", score=0.9, hit_type="paragraph")
        return MemorySearchResult(hits=[paragraph_hit])

    monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))

    rendered = await query_long_term_memory("Alice 喜欢什么", chat_id="stream-1", person_id="person-1")

    expected_kwargs = {
        "limit": 5,
        "mode": "search",
        "chat_id": "stream-1",
        "person_id": "person-1",
        "time_start": None,
        "time_end": None,
    }
    assert "Alice 喜欢猫" in rendered
    assert captured == {"query": "Alice 喜欢什么", "kwargs": expected_kwargs}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_long_term_memory_time_mode_parses_expression(monkeypatch):
    """In time mode the resolved timestamps are passed to the search as time_start/time_end."""
    captured = {}

    async def fake_search(query, **kwargs):
        captured.update(query=query, kwargs=kwargs)
        broadcast_hit = MemoryHit(
            content="昨天晚上广播站停播了十分钟。",
            score=0.8,
            hit_type="paragraph",
            metadata={"event_time_start": 1773797400.0},
        )
        return MemorySearchResult(hits=[broadcast_hit])

    def fake_resolve(expression, now=None):
        # Fixed window so the test does not depend on the current date.
        return (1773795600.0, 1773881940.0, "2026/03/17 00:00", "2026/03/17 23:59")

    monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
    monkeypatch.setattr(tool_module, "_resolve_time_expression", fake_resolve)

    rendered = await query_long_term_memory(
        query="广播站",
        mode="time",
        time_expression="昨天",
        chat_id="stream-1",
    )

    expected_kwargs = {
        "limit": 5,
        "mode": "time",
        "chat_id": "stream-1",
        "person_id": "",
        "time_start": 1773795600.0,
        "time_end": 1773881940.0,
    }
    assert "指定时间范围" in rendered
    assert "广播站停播" in rendered
    assert captured == {"query": "广播站", "kwargs": expected_kwargs}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_long_term_memory_episode_and_aggregate_format_output(monkeypatch):
    """Episode hits show title and participants; aggregate hits show a branch/type tag."""
    episode_result = MemorySearchResult(
        hits=[
            MemoryHit(
                content="苏弦在灯塔拆开了那封冬信。",
                title="冬信重见天日",
                hit_type="episode",
                metadata={"participants": ["苏弦"], "keywords": ["冬信", "灯塔"]},
            )
        ]
    )
    aggregate_result = MemorySearchResult(
        hits=[
            MemoryHit(
                content="唐未在广播站值夜班时带着黑狗墨点。",
                hit_type="paragraph",
                metadata={"source_branches": ["search", "time"]},
            )
        ]
    )
    responses = {"episode": episode_result, "aggregate": aggregate_result}

    async def fake_search(query, **kwargs):
        # Pick the canned result by the requested mode.
        return responses[kwargs["mode"]]

    monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))

    episode_text = await query_long_term_memory("那封冬信后来怎么样了", mode="episode")
    aggregate_text = await query_long_term_memory("唐未最近有什么线索", mode="aggregate")

    assert "事件《冬信重见天日》" in episode_text
    assert "参与者:苏弦" in episode_text
    assert "[search,time][paragraph]" in aggregate_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_query_long_term_memory_invalid_time_expression_returns_retryable_message():
    """An unparseable time expression yields a message that includes a supported example."""
    rendered = await query_long_term_memory(
        query="广播站",
        mode="time",
        time_expression="明年春分后第三周",
    )

    assert "无法解析" in rendered
    assert "最近7天" in rendered
|
||||
140
pytests/A_memorix_test/test_summary_importer_model_config.py
Normal file
140
pytests/A_memorix_test/test_summary_importer_model_config.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import pytest
|
||||
|
||||
from src.A_memorix.core.utils.summary_importer import (
|
||||
SummaryImporter,
|
||||
_message_timestamp,
|
||||
_normalize_entity_items,
|
||||
_normalize_relation_items,
|
||||
)
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
|
||||
def _fake_available_models() -> dict[str, TaskConfig]:
    """Return a fresh memory/utils/replyer task-config mapping for the tests to patch in."""
    # (task name, model name, max_tokens, temperature)
    specs = (
        ("memory", "memory-model", 512, 0.4),
        ("utils", "utils-model", 256, 0.5),
        ("replyer", "replyer-model", 128, 0.7),
    )
    return {
        task_name: TaskConfig(
            model_list=[model_name],
            max_tokens=token_budget,
            temperature=temperature,
            selection_strategy="random",
        )
        for task_name, model_name, token_budget, temperature in specs
    }
|
||||
|
||||
|
||||
def test_resolve_summary_model_config_uses_auto_list_when_summarization_missing(monkeypatch):
    """With an empty plugin config, the resolver picks the "memory" task's models."""
    monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models)
    importer = SummaryImporter(
        vector_store=None,
        graph_store=None,
        metadata_store=None,
        embedding_manager=None,
        plugin_config={},
    )

    chosen = importer._resolve_summary_model_config()

    assert chosen is not None
    assert chosen.model_list == ["memory-model"]
|
||||
|
||||
|
||||
def test_resolve_summary_model_config_auto_falls_back_to_utils_then_planner(monkeypatch):
    """Without "memory" the resolver picks "utils"; without that too, "planner"."""
    importer = SummaryImporter(
        vector_store=None,
        graph_store=None,
        metadata_store=None,
        embedding_manager=None,
        plugin_config={},
    )

    def models_without_memory():
        return {
            "utils": TaskConfig(model_list=["utils-model"]),
            "planner": TaskConfig(model_list=["planner-model"]),
            "replyer": TaskConfig(model_list=["replyer-model"]),
        }

    def models_without_memory_or_utils():
        return {
            "planner": TaskConfig(model_list=["planner-model"]),
            "replyer": TaskConfig(model_list=["replyer-model"]),
        }

    monkeypatch.setattr(llm_api, "get_available_models", models_without_memory)
    chosen = importer._resolve_summary_model_config()
    assert chosen is not None
    assert chosen.model_list == ["utils-model"]

    monkeypatch.setattr(llm_api, "get_available_models", models_without_memory_or_utils)
    chosen = importer._resolve_summary_model_config()
    assert chosen is not None
    assert chosen.model_list == ["planner-model"]
|
||||
|
||||
|
||||
def test_resolve_summary_model_config_auto_does_not_fallback_to_replyer(monkeypatch):
    """The resolver returns None when only "replyer" and "embedding" are available."""

    def replyer_and_embedding_only():
        return {
            "replyer": TaskConfig(model_list=["replyer-model"]),
            "embedding": TaskConfig(model_list=["embedding-model"]),
        }

    monkeypatch.setattr(llm_api, "get_available_models", replyer_and_embedding_only)
    importer = SummaryImporter(
        vector_store=None,
        graph_store=None,
        metadata_store=None,
        embedding_manager=None,
        plugin_config={},
    )

    assert importer._resolve_summary_model_config() is None
|
||||
|
||||
|
||||
def test_resolve_summary_model_config_rejects_legacy_string_selector(monkeypatch):
    """A plain-string model_name raises ValueError whose message matches ``List[str]``."""
    monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models)
    legacy_config = {"summarization": {"model_name": "auto"}}
    importer = SummaryImporter(
        vector_store=None,
        graph_store=None,
        metadata_store=None,
        embedding_manager=None,
        plugin_config=legacy_config,
    )

    with pytest.raises(ValueError, match="List\\[str\\]"):
        importer._resolve_summary_model_config()
|
||||
|
||||
|
||||
def test_summary_importer_normalizes_llm_entities_and_relations():
    """Checks the expected normalized output for mixed entity and relation inputs."""
    mixed_entities = ["Alice", {"name": "地图"}, ["bad"], "Alice"]
    assert _normalize_entity_items(mixed_entities) == ["Alice", "地图"]
    assert _normalize_entity_items("Alice") == []

    mixed_relations = [
        {"subject": "Alice", "predicate": "持有", "object": "地图"},
        {"subject": "Alice", "predicate": "", "object": "地图"},
        ["bad"],
    ]
    expected_relations = [{"subject": "Alice", "predicate": "持有", "object": "地图"}]
    assert _normalize_relation_items(mixed_relations) == expected_relations
|
||||
|
||||
|
||||
def test_summary_importer_message_timestamp_accepts_time_fallback():
    """An object whose only attribute is ``time`` yields that value as its timestamp."""

    class Message:
        # Only ``time`` is defined, so the helper has nothing else to read.
        time = 123.5

    message = Message()
    assert _message_timestamp(message) == 123.5
|
||||
182
pytests/A_memorix_test/test_web_import_manager_payloads.py
Normal file
182
pytests/A_memorix_test/test_web_import_manager_payloads.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from src.A_memorix.core.strategies.base import ChunkContext, KnowledgeType, ProcessedChunk, SourceInfo
|
||||
from src.A_memorix.core.utils.web_import_manager import (
|
||||
ImportChunkRecord,
|
||||
ImportFileRecord,
|
||||
ImportTaskManager,
|
||||
ImportTaskRecord,
|
||||
)
|
||||
|
||||
|
||||
class _DummyMetadataStore:
|
||||
def __init__(self) -> None:
|
||||
self.paragraphs: list[dict[str, object]] = []
|
||||
self.entities: list[str] = []
|
||||
self.relations: list[tuple[str, str, str]] = []
|
||||
|
||||
def add_paragraph(self, **kwargs):
|
||||
self.paragraphs.append(dict(kwargs))
|
||||
return f"paragraph-{len(self.paragraphs)}"
|
||||
|
||||
def add_entity(self, *, name: str, source_paragraph: str = "") -> str:
|
||||
del source_paragraph
|
||||
self.entities.append(name)
|
||||
return f"entity-{name}"
|
||||
|
||||
def add_relation(self, *, subject: str, predicate: str, obj: str, **kwargs) -> str:
|
||||
del kwargs
|
||||
self.relations.append((subject, predicate, obj))
|
||||
return f"relation-{len(self.relations)}"
|
||||
|
||||
def set_relation_vector_state(self, rel_hash: str, state: str) -> None:
|
||||
del rel_hash, state
|
||||
|
||||
|
||||
class _DummyGraphStore:
|
||||
def __init__(self) -> None:
|
||||
self.nodes: list[list[str]] = []
|
||||
self.edges: list[list[tuple[str, str]]] = []
|
||||
|
||||
def add_nodes(self, nodes):
|
||||
self.nodes.append(list(nodes))
|
||||
|
||||
def add_edges(self, edges, relation_hashes=None):
|
||||
del relation_hashes
|
||||
self.edges.append(list(edges))
|
||||
|
||||
|
||||
class _DummyVectorStore:
|
||||
def __contains__(self, item: str) -> bool:
|
||||
del item
|
||||
return False
|
||||
|
||||
def add(self, vectors, ids):
|
||||
del vectors, ids
|
||||
|
||||
|
||||
class _DummyEmbeddingManager:
|
||||
async def encode(self, text: str) -> np.ndarray:
|
||||
del text
|
||||
return np.ones(4, dtype=np.float32)
|
||||
|
||||
|
||||
def _build_manager() -> tuple[ImportTaskManager, _DummyMetadataStore]:
    """Create an ImportTaskManager over dummy stores and return it with its metadata store."""
    metadata_store = _DummyMetadataStore()
    plugin_stub = SimpleNamespace(
        metadata_store=metadata_store,
        graph_store=_DummyGraphStore(),
        vector_store=_DummyVectorStore(),
        embedding_manager=_DummyEmbeddingManager(),
        relation_write_service=None,
        # Every config lookup returns the default the caller passes in.
        get_config=lambda key, default=None: default,
        _is_embedding_degraded=lambda: False,
        _allow_metadata_only_write=lambda: True,
        write_paragraph_vector_or_enqueue=None,
    )
    return ImportTaskManager(plugin_stub), metadata_store
|
||||
|
||||
|
||||
def _build_progress_task(task_id: str, total_chunks: int = 2) -> ImportTaskRecord:
    """Build a task record holding one pasted text file with ``total_chunks`` chunks.

    Chunks are named ``chunk-0`` .. ``chunk-<total_chunks - 1>`` and the file
    id is always ``file-1``.
    """
    chunk_records = []
    for position in range(total_chunks):
        chunk_records.append(
            ImportChunkRecord(chunk_id=f"chunk-{position}", index=position, chunk_type="text")
        )

    only_file = ImportFileRecord(
        file_id="file-1",
        name="demo.txt",
        source_kind="paste",
        input_mode="text",
        total_chunks=total_chunks,
        chunks=chunk_records,
    )
    task_record = ImportTaskRecord(task_id=task_id, source="paste", params={}, files=[only_file])
    return task_record
|
||||
|
||||
|
||||
def _build_chunk(data) -> ProcessedChunk:
    """Wrap ``data`` in a factual ``ProcessedChunk`` with fixed source and chunk context."""
    knowledge_type = KnowledgeType.FACTUAL
    origin = SourceInfo(file="demo.txt", offset_start=0, offset_end=4)
    context = ChunkContext(chunk_id="chunk-1", index=0, text="Alice 持有地图")
    return ProcessedChunk(type=knowledge_type, source=origin, chunk=context, data=data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_persist_processed_chunk_rejects_non_object_before_paragraph_write() -> None:
    """A list payload must raise ``ValueError`` and leave no paragraph recorded."""
    manager, store = _build_manager()
    pasted_file = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt")
    expected_message = "分块抽取结果 必须返回 JSON 对象"

    with pytest.raises(ValueError, match=expected_message):
        await manager._persist_processed_chunk(pasted_file, _build_chunk(["bad"]))

    # Nothing may have reached the metadata store before the rejection.
    assert store.paragraphs == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chunk_terminal_progress_uses_successful_chunks_only() -> None:
    """One failed and one completed chunk give 50% progress in either order."""
    manager, _ = _build_manager()

    def expect_half_progress(record) -> None:
        # One of two chunks succeeded, so file and task progress are both 0.5.
        only_file = record.files[0]
        assert only_file.done_chunks == 1
        assert only_file.failed_chunks == 1
        assert only_file.progress == pytest.approx(0.5)
        assert record.progress == pytest.approx(0.5)

    # Scenario 1: the first chunk fails, then the second completes.
    failed_first = _build_progress_task("task-fail-then-complete")
    manager._tasks[failed_first.task_id] = failed_first
    await manager._set_chunk_failed(failed_first.task_id, "file-1", "chunk-0", "boom")
    await manager._set_chunk_completed(failed_first.task_id, "file-1", "chunk-1")
    expect_half_progress(failed_first)

    # Scenario 2: the first chunk completes, then the second fails.
    completed_first = _build_progress_task("task-complete-then-fail")
    manager._tasks[completed_first.task_id] = completed_first
    await manager._set_chunk_completed(completed_first.task_id, "file-1", "chunk-0")
    await manager._set_chunk_failed(completed_first.task_id, "file-1", "chunk-1", "boom")
    expect_half_progress(completed_first)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancelled_chunks_do_not_increase_file_progress() -> None:
    """With three chunks, one completed and one cancelled, progress stays at 1/3."""
    manager, _ = _build_manager()
    tracked = _build_progress_task("task-cancelled-progress", total_chunks=3)
    task_id = tracked.task_id
    manager._tasks[task_id] = tracked

    await manager._set_chunk_completed(task_id, "file-1", "chunk-0")
    await manager._set_chunk_cancelled(task_id, "file-1", "chunk-1", "任务已取消")

    only_file = tracked.files[0]
    one_third = 1 / 3
    assert only_file.done_chunks == 1
    assert only_file.cancelled_chunks == 1
    assert only_file.progress == pytest.approx(one_third)
    assert tracked.progress == pytest.approx(one_third)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_persist_processed_chunk_skips_invalid_nested_items() -> None:
    """Malformed nested items are skipped while the valid ones are persisted."""
    manager, store = _build_manager()
    pasted_file = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt")

    # Each collection mixes a usable item with one that should be skipped.
    extraction = {
        "triples": [{"subject": "Alice", "predicate": "持有", "object": "地图"}, ["bad"]],
        "relations": [{"subject": "Alice", "predicate": "", "object": "地图"}],
        "entities": ["Alice", {"name": "地图"}, ["bad"]],
    }

    await manager._persist_processed_chunk(pasted_file, _build_chunk(extraction))

    assert len(store.paragraphs) == 1
    assert set(store.entities) >= {"Alice", "地图"}
    assert store.relations == [("Alice", "持有", "地图")]
|
||||
Reference in New Issue
Block a user