chore: import deployable mai-bot source tree

This commit is contained in:
2026-05-11 00:51:12 +00:00
parent 4813699b3e
commit 7a54015f94
1009 changed files with 312999 additions and 16 deletions

View File

@@ -0,0 +1,398 @@
from __future__ import annotations
from datetime import datetime, timedelta
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict, List
import asyncio
import inspect
import json
import pickle
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine
import numpy as np
import pytest
# Holds a human-readable reason when the guarded project imports abort with
# SystemExit; stays None when every import succeeds.
IMPORT_ERROR: str | None = None
try:
    from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
    from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
    from src.A_memorix.core.utils import summary_importer as summary_importer_module
    from src.chat.message_receive.chat_manager import BotChatSession
    from src.chat.message_receive.message import SessionMessage
    from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
    from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
    from src.common.database import database as database_module
    from src.common.database.migrations import create_database_migration_bootstrapper
    from src.common.message_repository import count_messages
    from src.config.model_configs import TaskConfig
    from src.services import memory_flow_service as memory_flow_service_module
    from src.services import memory_service as memory_service_module
    from src.services import send_service
except SystemExit as exc:
    # Only SystemExit is intercepted; any other import failure still propagates.
    IMPORT_ERROR = f"config initialization exited during import: {exc}"
    # Bind every guarded name to None so the names exist at module level even
    # though the imports did not complete.
    kernel_module = None # type: ignore[assignment]
    SDKMemoryKernel = None # type: ignore[assignment]
    summary_importer_module = None # type: ignore[assignment]
    BotChatSession = None # type: ignore[assignment]
    SessionMessage = None # type: ignore[assignment]
    MessageInfo = None # type: ignore[assignment]
    UserInfo = None # type: ignore[assignment]
    MessageSequence = None # type: ignore[assignment]
    TextComponent = None # type: ignore[assignment]
    database_module = None # type: ignore[assignment]
    create_database_migration_bootstrapper = None # type: ignore[assignment]
    count_messages = None # type: ignore[assignment]
    TaskConfig = None # type: ignore[assignment]
    memory_flow_service_module = None # type: ignore[assignment]
    memory_service_module = None # type: ignore[assignment]
    send_service = None # type: ignore[assignment]
# Skip every test in this module when the guarded imports failed, reporting
# the recorded reason.
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 8) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any) -> np.ndarray:
def _encode_one(raw: Any) -> np.ndarray:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(
self,
component_name: str,
args: Dict[str, Any] | None,
*,
timeout_ms: int = 30000,
) -> Any:
del timeout_ms
payload = args or {}
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
class _NoopRuntimeManager:
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
del hook_name
return SimpleNamespace(aborted=False, kwargs=kwargs)
class _FakePlatformIOManager:
def __init__(self) -> None:
self.ensure_calls = 0
async def ensure_send_pipeline_ready(self) -> None:
self.ensure_calls += 1
def build_route_key_from_message(self, message: Any) -> Any:
del message
return SimpleNamespace(platform="qq")
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
del message, metadata
return SimpleNamespace(
has_success=True,
sent_receipts=[
SimpleNamespace(
driver_id="plugin.qq.sender",
external_message_id="real-message-id",
metadata={},
)
],
failed_receipts=[],
route_key=route_key,
)
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Point ``database_module`` at a SQLite database file under ``tmp_path``.

    Creates ``<tmp_path>/main_db``, builds an engine and session factory for
    ``MaiBot.db`` inside it, and patches the module-level database attributes
    through ``monkeypatch``.
    """
    storage_dir = (tmp_path / "main_db").resolve()
    storage_dir.mkdir(parents=True, exist_ok=True)
    sqlite_path = storage_dir / "MaiBot.db"
    url = f"sqlite:///{sqlite_path}"

    # Best effort: dispose of whatever engine the module currently holds;
    # any failure here is deliberately ignored.
    try:
        database_module.engine.dispose()
    except Exception:
        pass

    temp_engine = create_engine(
        url,
        echo=False,
        connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )
    session_factory = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=temp_engine,
        class_=Session,
    )
    overrides = {
        "_DB_DIR": storage_dir,
        "_DB_FILE": sqlite_path,
        "DATABASE_URL": url,
        "engine": temp_engine,
        "SessionLocal": session_factory,
        "_migration_bootstrapper": create_database_migration_bootstrapper(temp_engine),
        "_db_initialized": False,
    }
    # raising=False lets the patch succeed even if an attribute is absent.
    for attribute, replacement in overrides.items():
        monkeypatch.setattr(database_module, attribute, replacement, raising=False)
def _build_incoming_message(
    *,
    session_id: str,
    user_id: str,
    text: str,
    timestamp: datetime | None = None,
) -> SessionMessage:
    """Build an initialized plain-text ``SessionMessage`` on the ``qq`` platform.

    Args:
        session_id: Session the message is assigned to.
        user_id: Id stored in the message's user info.
        text: Text used for both the raw component and the processed plain text.
        timestamp: Message time; the current local time is used when falsy.
    """
    incoming = SessionMessage(
        message_id="incoming-message-id",
        timestamp=timestamp if timestamp else datetime.now(),
        platform="qq",
    )
    sender = UserInfo(
        user_id=user_id,
        user_nickname="测试用户",
        user_cardname="测试用户",
    )
    incoming.message_info = MessageInfo(user_info=sender, additional_config={})
    incoming.raw_message = MessageSequence(components=[TextComponent(text=text)])
    incoming.session_id = session_id
    incoming.reply_to = None
    # Every classification flag starts out cleared.
    for flag_name in (
        "is_mentioned",
        "is_at",
        "is_emoji",
        "is_picture",
        "is_command",
        "is_notify",
    ):
        setattr(incoming, flag_name, False)
    incoming.processed_plain_text = text
    incoming.initialized = True
    return incoming
async def _wait_until(
predicate: Callable[[], Any],
*,
timeout_seconds: float = 10.0,
interval_seconds: float = 0.05,
description: str,
) -> Any:
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
value = predicate()
if inspect.isawaitable(value):
value = await value
if value:
return value
await asyncio.sleep(interval_seconds)
raise AssertionError(f"等待超时: {description}")
@pytest.mark.asyncio
async def test_text_to_stream_triggers_real_chat_summary_writeback(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """Send a reply through ``send_service`` and expect a chat summary in the kernel.

    The test stores one incoming message, sends one reply with storage enabled,
    and then waits for paragraphs under the ``chat_summary:test-session``
    source. Embedding, LLM generation, platform IO and the runtime manager are
    all replaced with in-test fakes.
    """
    _install_temp_main_database(monkeypatch, tmp_path)
    fake_embedding_manager = _FakeEmbeddingManager()
    # Prompts passed to the fake LLM, in call order.
    captured_prompts: List[str] = []
    # Value returned by the patched time.time() below.
    fixed_send_timestamp = 1_777_000_000.0

    async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]:
        # Always report a passing self check whose dimensions all match the fake manager.
        del kwargs
        return {
            "ok": True,
            "message": "ok",
            "configured_dimension": fake_embedding_manager.default_dimension,
            "requested_dimension": fake_embedding_manager.default_dimension,
            "vector_store_dimension": fake_embedding_manager.default_dimension,
            "detected_dimension": fake_embedding_manager.default_dimension,
            "encoded_dimension": fake_embedding_manager.default_dimension,
            "elapsed_ms": 0.0,
            "sample_text": "test",
            "checked_at": datetime.now().timestamp(),
        }

    async def _fake_generate(request: Any) -> Any:
        # Record the prompt, then answer with a fixed JSON summary payload.
        captured_prompts.append(str(getattr(request, "prompt", "") or ""))
        return SimpleNamespace(
            success=True,
            completion=SimpleNamespace(
                response=json.dumps(
                    {
                        "summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。",
                        "entities": ["绿色围巾"],
                        "relations": [],
                    },
                    ensure_ascii=False,
                )
            ),
        )

    # Route embedding creation and the runtime self check to the fakes.
    monkeypatch.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: fake_embedding_manager,
    )
    monkeypatch.setattr(
        kernel_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    monkeypatch.setattr(
        summary_importer_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    # Replace the summary importer's LLM API with a single fake "utils" task.
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "get_available_models",
        lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])},
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "resolve_task_name_from_model_config",
        lambda model_config: "utils",
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "generate",
        _fake_generate,
    )
    # Pin time.time() on the `time` objects referenced by both modules.
    monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp)
    monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp)
    # Kernel whose storage lives entirely under tmp_path.
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
            "advanced": {"enable_auto_save": False},
            "embedding": {"dimension": fake_embedding_manager.default_dimension},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
            "summarization": {"model_name": ["utils"]},
        },
    )
    service = memory_flow_service_module.MemoryAutomationService()
    fake_platform_io_manager = _FakePlatformIOManager()

    async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]:
        # Report that nothing was rebuilt, echoing the requested sources.
        return {
            "rebuilt": 0,
            "items": [],
            "failures": [],
            "sources": list(sources),
        }

    monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources)
    # Memory service calls are dispatched straight to the kernel built above.
    monkeypatch.setattr(
        memory_service_module,
        "a_memorix_host_service",
        _KernelBackedRuntimeManager(kernel),
    )
    monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service)
    monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager())
    monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager)
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    # Only "test-session" resolves to a chat session; any other id yields None.
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: (
            BotChatSession(
                session_id="test-session",
                platform="qq",
                user_id="target-user",
                group_id=None,
            )
            if stream_id == "test-session"
            else None
        ),
    )
    # Enable summary writeback with a threshold of 2 messages; disable person-fact writeback.
    integration_config = memory_flow_service_module.global_config.a_memorix.integration
    monkeypatch.setattr(integration_config, "chat_summary_writeback_enabled", True, raising=False)
    monkeypatch.setattr(integration_config, "chat_summary_writeback_message_threshold", 2, raising=False)
    monkeypatch.setattr(integration_config, "chat_summary_writeback_context_length", 10, raising=False)
    monkeypatch.setattr(integration_config, "person_fact_writeback_enabled", False, raising=False)
    await kernel.initialize()
    try:
        # Incoming user message dated one second before the pinned send time.
        incoming_message = _build_incoming_message(
            session_id="test-session",
            user_id="target-user",
            text="我最近买了一条绿色围巾。",
            timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1),
        )
        with database_module.get_db_session() as session:
            session.add(incoming_message.to_db_instance())
        sent_message = await send_service.text_to_stream_with_message(
            text="好的,我会记住你最近买了绿色围巾。",
            stream_id="test-session",
            storage_message=True,
        )
        assert sent_message is not None
        # The id matches the receipt returned by the fake platform IO manager.
        assert sent_message.message_id == "real-message-id"
        assert fake_platform_io_manager.ensure_calls == 1
        # One stored incoming message plus the stored outgoing reply.
        assert count_messages(session_id="test-session") == 2
        # Poll until paragraphs exist for this session's summary source.
        paragraphs = await _wait_until(
            lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"),
            description="等待聊天摘要写回到 A_memorix",
        )
        # The last summarization prompt must contain both sides of the exchange.
        assert captured_prompts
        assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
        assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
        assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
        # Metadata is unpickled when it is bytes/bytearray and used as-is
        # otherwise; at least one paragraph must record a trigger count of 2.
        assert any(
            int(
                (
                    pickle.loads(item.get("metadata"))
                    if isinstance(item.get("metadata"), (bytes, bytearray))
                    else item.get("metadata")
                    or {}
                ).get("trigger_message_count", 0)
                or 0
            )
            == 2
            for item in paragraphs
        )
        assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
    finally:
        await service.shutdown()
        await kernel.shutdown()
        # Best effort: dispose of the temporary engine; failures are ignored.
        try:
            database_module.engine.dispose()
        except Exception:
            pass

View File

@@ -0,0 +1,191 @@
from __future__ import annotations
from types import SimpleNamespace
import numpy as np
import pytest
from A_memorix.core.embedding import api_adapter as api_adapter_module
from A_memorix.core.embedding.api_adapter import EmbeddingAPIAdapter
from A_memorix.core.utils.runtime_self_check import run_embedding_runtime_self_check
class _FakeEmbeddingClient:
def __init__(self, *, natural_dimension: int = 12) -> None:
self.natural_dimension = int(natural_dimension)
self.requests = []
async def get_embedding(self, request):
self.requests.append(request)
requested_dimension = request.extra_params.get("dimensions")
if requested_dimension is None:
requested_dimension = request.extra_params.get("output_dimensionality")
dimension = int(requested_dimension or self.natural_dimension)
return SimpleNamespace(embedding=[1.0] * dimension)
def _build_adapter(
    monkeypatch: pytest.MonkeyPatch,
    *,
    client_type: str,
    configured_dimension: int = 1024,
    effective_dimension: int | None = None,
    model_extra_params: dict | None = None,
):
    """Create an ``EmbeddingAPIAdapter`` wired to a ``_FakeEmbeddingClient``.

    Args:
        monkeypatch: Fixture used to install the lookups and the fake client.
        client_type: ``client_type`` reported by the fake provider.
        configured_dimension: Passed to the adapter as ``default_dimension``.
        effective_dimension: When given, stored as the adapter's already
            detected dimension.
        model_extra_params: Copied into the fake model's ``extra_params``.

    Returns:
        The ``(adapter, fake_client)`` pair.
    """
    adapter = EmbeddingAPIAdapter(default_dimension=configured_dimension)
    if effective_dimension is not None:
        adapter._dimension = int(effective_dimension)
        adapter._dimension_detected = True

    fake_client = _FakeEmbeddingClient()
    extra_params = dict(model_extra_params or {})
    model_info = SimpleNamespace(
        name="embedding-model",
        api_provider="provider-1",
        model_identifier="embedding-model-id",
        extra_params=extra_params,
    )
    provider = SimpleNamespace(name="provider-1", client_type=client_type)

    def _candidate_names():
        return ["embedding-model"]

    def _lookup_model(model_name):
        return model_info

    def _lookup_provider(provider_name):
        return provider

    def _client_for(api_provider, force_new=True):
        return fake_client

    monkeypatch.setattr(adapter, "_resolve_candidate_model_names", _candidate_names)
    monkeypatch.setattr(adapter, "_find_model_info", _lookup_model)
    monkeypatch.setattr(adapter, "_find_provider", _lookup_provider)
    monkeypatch.setattr(
        api_adapter_module.client_registry,
        "get_client_class_instance",
        _client_for,
    )
    return adapter, fake_client
@pytest.mark.asyncio
async def test_encode_uses_canonical_dimension_for_openai_provider(monkeypatch):
    """An ``openai`` provider receives ``dimensions`` and keeps the model's extra params."""
    adapter, client = _build_adapter(
        monkeypatch,
        client_type="openai",
        configured_dimension=1024,
        effective_dimension=1024,
        model_extra_params={"task_type": "SEMANTIC_SIMILARITY"},
    )

    vector = await adapter.encode("北塔木梯")

    last_request = client.requests[-1]
    assert last_request.extra_params["dimensions"] == 1024
    assert "output_dimensionality" not in last_request.extra_params
    assert last_request.extra_params["task_type"] == "SEMANTIC_SIMILARITY"
    assert vector.shape == (1024,)
@pytest.mark.asyncio
async def test_encode_explicit_dimension_override_wins(monkeypatch):
    """A ``dimensions`` argument passed to ``encode`` overrides the adapter's dimension."""
    adapter, client = _build_adapter(
        monkeypatch,
        client_type="openai",
        configured_dimension=1024,
        effective_dimension=1024,
    )

    vector = await adapter.encode("海潮图", dimensions=256)

    last_request = client.requests[-1]
    assert last_request.extra_params["dimensions"] == 256
    assert "output_dimensionality" not in last_request.extra_params
    assert vector.shape == (256,)
@pytest.mark.asyncio
async def test_encode_maps_dimension_to_gemini_output_dimensionality(monkeypatch):
    """A ``gemini`` provider receives ``output_dimensionality`` instead of ``dimensions``."""
    adapter, client = _build_adapter(
        monkeypatch,
        client_type="gemini",
        configured_dimension=1024,
        effective_dimension=768,
    )

    vector = await adapter.encode("广播站")

    last_request = client.requests[-1]
    assert last_request.extra_params["output_dimensionality"] == 768
    assert "dimensions" not in last_request.extra_params
    assert vector.shape == (768,)
@pytest.mark.asyncio
async def test_encode_does_not_force_dimension_for_unsupported_provider(monkeypatch):
    """A ``custom`` provider gets neither dimension key, but other extra params survive."""
    configured_params = {
        "dimensions": 999,
        "output_dimensionality": 888,
        "custom_flag": "keep-me",
    }
    adapter, client = _build_adapter(
        monkeypatch,
        client_type="custom",
        configured_dimension=1024,
        effective_dimension=640,
        model_extra_params=configured_params,
    )

    vector = await adapter.encode("蓝漆铁盒")

    last_request = client.requests[-1]
    for dimension_key in ("dimensions", "output_dimensionality"):
        assert dimension_key not in last_request.extra_params
    assert last_request.extra_params["custom_flag"] == "keep-me"
    # With no dimension requested, the fake client answers at its natural size.
    assert vector.shape == (client.natural_dimension,)
@pytest.mark.asyncio
async def test_runtime_self_check_reports_requested_dimension_without_explicit_override():
    """The self-check report carries the manager's dimension alongside the configured one."""

    class _FakeEmbeddingManager:
        """Manager stand-in that always works at 384 dimensions and logs encode calls."""

        def __init__(self) -> None:
            self.detected_dimension = 384
            self.encode_calls = []

        async def _detect_dimension(self) -> int:
            return self.detected_dimension

        def get_requested_dimension(self) -> int:
            return self.detected_dimension

        async def encode(self, text):
            self.encode_calls.append(text)
            return np.ones(self.detected_dimension, dtype=np.float32)

    manager = _FakeEmbeddingManager()
    store = SimpleNamespace(dimension=384)
    report = await run_embedding_runtime_self_check(
        config={"embedding": {"dimension": 1024}},
        vector_store=store,
        embedding_manager=manager,
    )

    assert report["ok"] is True
    expected_dimensions = {
        "configured_dimension": 1024,
        "requested_dimension": 384,
        "detected_dimension": 384,
        "encoded_dimension": 384,
    }
    for field, expected in expected_dimensions.items():
        assert report[field] == expected
    assert manager.encode_calls == ["A_Memorix runtime self check"]
@pytest.mark.asyncio
async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch):
    """Rows stay aligned with inputs when a later batch repeats an earlier batch's text."""
    adapter = EmbeddingAPIAdapter(default_dimension=4, enable_cache=True)
    adapter._dimension = 4
    adapter._dimension_detected = True

    async def fake_detect_dimension() -> int:
        return 4

    async def fake_get_embedding_direct(text: str, dimensions: int | None = None):
        # The embedding depends only on the first character of the text.
        del dimensions
        first_code = float(ord(str(text)[0]))
        return [first_code + offset for offset in (0.0, 1.0, 2.0, 3.0)]

    monkeypatch.setattr(adapter, "_detect_dimension", fake_detect_dimension)
    monkeypatch.setattr(adapter, "_get_embedding_direct", fake_get_embedding_direct)

    # batch_size=2 places the second "A" in a different batch from the first.
    matrix = await adapter.encode(["A", "B", "A", "C"], batch_size=2)

    assert matrix.shape == (4, 4)
    assert np.array_equal(matrix[0], matrix[2])
    assert matrix[1][0] == float(ord("B"))
    assert matrix[3][0] == float(ord("C"))

View File

@@ -0,0 +1,780 @@
from __future__ import annotations
import asyncio
import inspect
import json
import time
import uuid
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict
import numpy as np
import pytest
import pytest_asyncio
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine, select
# Holds a human-readable reason when the guarded project imports abort with
# SystemExit; stays None when every import succeeds.
IMPORT_ERROR: str | None = None
try:
    from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
    from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
    from src.chat.heart_flow.heartflow_manager import heartflow_manager
    from src.chat.message_receive import bot as bot_module
    from src.chat.message_receive.chat_manager import chat_manager
    from src.chat.message_receive.bot import chat_bot
    from src.common.database import database as database_module
    from src.common.database.database_model import PersonInfo, ToolRecord
    from src.common.database.migrations import create_database_migration_bootstrapper
    from src.common.utils.utils_session import SessionUtils
    from src.llm_models.payload_content.tool_option import ToolCall
    from src.maisaka import reasoning_engine as reasoning_engine_module
    from src.maisaka import runtime as runtime_module
    from src.maisaka import chat_loop_service as chat_loop_service_module
    from src.maisaka.chat_loop_service import ChatResponse
    from src.maisaka.context_messages import AssistantMessage
    from src.plugin_runtime import component_query as component_query_module
    from src.services import memory_flow_service as memory_flow_service_module
    from src.services import memory_service as memory_service_module
    from src.services.memory_service import memory_service
except SystemExit as exc:
    # Only SystemExit is intercepted; any other import failure still propagates.
    IMPORT_ERROR = f"config initialization exited during import: {exc}"
    # Bind every guarded name to None so the names exist at module level even
    # though the imports did not complete.
    kernel_module = None # type: ignore[assignment]
    KernelSearchRequest = None # type: ignore[assignment]
    SDKMemoryKernel = None # type: ignore[assignment]
    heartflow_manager = None # type: ignore[assignment]
    bot_module = None # type: ignore[assignment]
    chat_manager = None # type: ignore[assignment]
    chat_bot = None # type: ignore[assignment]
    database_module = None # type: ignore[assignment]
    ToolRecord = None # type: ignore[assignment]
    PersonInfo = None # type: ignore[assignment]
    create_database_migration_bootstrapper = None # type: ignore[assignment]
    SessionUtils = None # type: ignore[assignment]
    ToolCall = None # type: ignore[assignment]
    reasoning_engine_module = None # type: ignore[assignment]
    runtime_module = None # type: ignore[assignment]
    chat_loop_service_module = None # type: ignore[assignment]
    ChatResponse = None # type: ignore[assignment]
    AssistantMessage = None # type: ignore[assignment]
    component_query_module = None # type: ignore[assignment]
    memory_flow_service_module = None # type: ignore[assignment]
    memory_service_module = None # type: ignore[assignment]
    memory_service = None # type: ignore[assignment]
# Skip every test in this module when the guarded imports failed, reporting
# the recorded reason.
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
# Query text the fake planner passes to the `query_memory` tool call.
RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系"
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 8) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any) -> np.ndarray:
def _encode_one(raw: Any) -> np.ndarray:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
del timeout_ms
payload = args or {}
if component_name == "search_memory":
return await self.kernel.search_memory(
KernelSearchRequest(
query=str(payload.get("query", "") or ""),
limit=int(payload.get("limit", 5) or 5),
mode=str(payload.get("mode", "hybrid") or "hybrid"),
chat_id=str(payload.get("chat_id", "") or ""),
person_id=str(payload.get("person_id", "") or ""),
time_start=payload.get("time_start"),
time_end=payload.get("time_end"),
respect_filter=bool(payload.get("respect_filter", True)),
user_id=str(payload.get("user_id", "") or ""),
group_id=str(payload.get("group_id", "") or ""),
)
)
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
class _NoopRuntimeManager:
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
del hook_name
return SimpleNamespace(aborted=False, kwargs=kwargs)
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Point ``database_module`` at a SQLite database file under ``tmp_path``.

    Creates ``<tmp_path>/main_db``, builds an engine and session factory for
    ``MaiBot.db`` inside it, and patches the module-level database attributes
    through ``monkeypatch``.
    """
    target_dir = (tmp_path / "main_db").resolve()
    target_dir.mkdir(parents=True, exist_ok=True)
    target_file = target_dir / "MaiBot.db"
    connection_url = f"sqlite:///{target_file}"

    # Best effort: dispose of whatever engine the module currently holds;
    # any failure here is deliberately ignored.
    try:
        database_module.engine.dispose()
    except Exception:
        pass

    scratch_engine = create_engine(
        connection_url,
        echo=False,
        connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )
    make_session = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=scratch_engine,
        class_=Session,
    )
    migration_bootstrapper = create_database_migration_bootstrapper(scratch_engine)

    patches = (
        ("_DB_DIR", target_dir),
        ("_DB_FILE", target_file),
        ("DATABASE_URL", connection_url),
        ("engine", scratch_engine),
        ("SessionLocal", make_session),
        ("_migration_bootstrapper", migration_bootstrapper),
        ("_db_initialized", False),
    )
    # raising=False lets the patch succeed even if an attribute is absent.
    for attribute, replacement in patches:
        monkeypatch.setattr(database_module, attribute, replacement, raising=False)
def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse:
    """Wrap ``content`` and ``tool_calls`` in a ``ChatResponse`` with zeroed counters.

    The raw assistant message carries the same content and tool calls and is
    stamped with the current local time; ``tool_count`` is ``len(tool_calls)``.
    """
    assistant_message = AssistantMessage(
        content=content,
        timestamp=datetime.now(),
        tool_calls=tool_calls,
    )
    number_of_tools = len(tool_calls)
    return ChatResponse(
        content=content,
        tool_calls=tool_calls,
        request_messages=[],
        raw_message=assistant_message,
        selected_history_count=0,
        tool_count=number_of_tools,
        prompt_tokens=0,
        built_message_count=0,
        completion_tokens=0,
        total_tokens=0,
        prompt_section=None,
    )
def _build_message_data(
*,
content: str,
platform: str,
user_id: str,
user_name: str,
group_id: str,
group_name: str,
) -> Dict[str, Any]:
message_id = str(uuid.uuid4())
return {
"message_info": {
"platform": platform,
"message_id": message_id,
"time": time.time(),
"group_info": {
"group_id": group_id,
"group_name": group_name,
"platform": platform,
},
"user_info": {
"user_id": user_id,
"user_nickname": user_name,
"user_cardname": user_name,
"platform": platform,
},
"additional_config": {
"at_bot": True,
},
},
"message_segment": {
"type": "seglist",
"data": [
{
"type": "text",
"data": content,
},
],
},
"raw_message": content,
"processed_plain_text": content,
}
async def _wait_until(
predicate: Callable[[], Any],
*,
timeout_seconds: float = 10.0,
interval_seconds: float = 0.05,
description: str,
) -> Any:
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
value = predicate()
if inspect.isawaitable(value):
value = await value
if value:
return value
await asyncio.sleep(interval_seconds)
raise AssertionError(f"等待超时: {description}")
def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]:
assert kernel.metadata_store is not None
cursor = kernel.metadata_store.get_connection().cursor()
rows = cursor.execute(
"SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id"
).fetchall()
tasks: list[Dict[str, Any]] = []
for row in rows:
task = kernel.metadata_store.get_feedback_task(str(row["query_tool_id"] or ""))
if task is not None:
tasks.append(task)
return tasks
def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]:
assert kernel.metadata_store is not None
cursor = kernel.metadata_store.get_connection().cursor()
rows = cursor.execute(
"SELECT action_type FROM memory_feedback_action_logs ORDER BY id"
).fetchall()
return [str(row["action_type"] or "") for row in rows]
def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]:
    """Return the ``query_memory`` tool records of ``session_id`` as plain dicts.

    Rows are ordered by timestamp. Text fields are coerced to ``str`` with
    falsy values mapped to ``""``; the timestamp is passed through unchanged.
    """
    query = (
        select(ToolRecord)
        .where(ToolRecord.session_id == session_id)
        .where(ToolRecord.tool_name == "query_memory")
        .order_by(ToolRecord.timestamp)
    )
    with database_module.get_db_session() as session:
        records = list(session.exec(query).all())
        # Row attributes are read while the session is still open.
        summaries: list[Dict[str, Any]] = []
        for record in records:
            summaries.append(
                {
                    "tool_id": str(record.tool_id or ""),
                    "session_id": str(record.session_id or ""),
                    "tool_name": str(record.tool_name or ""),
                    "tool_data": str(record.tool_data or ""),
                    "timestamp": record.timestamp,
                }
            )
        return summaries
def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None:
    """Insert and commit one known ``PersonInfo`` row built from ``session_info``.

    Args:
        person_id: Stored as the row's person id.
        person_name: Stored as the person name and as the group card name.
        session_info: Mapping read for ``platform``, ``user_id``,
            ``user_name`` and ``group_id``.
    """
    with database_module.get_db_session() as session:
        card_entry = {"group_id": str(session_info["group_id"]), "group_cardname": person_name}
        person = PersonInfo(
            is_known=True,
            person_id=person_id,
            person_name=person_name,
            platform=str(session_info["platform"]),
            user_id=str(session_info["user_id"]),
            user_nickname=str(session_info["user_name"]),
            group_cardname=json.dumps([card_entry], ensure_ascii=False),
            know_counts=1,
            first_known_time=datetime.now(),
            last_known_time=datetime.now(),
        )
        session.add(person)
        session.commit()
@pytest_asyncio.fixture
async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
_install_temp_main_database(monkeypatch, tmp_path)
chat_manager.sessions.clear()
chat_manager.last_messages.clear()
heartflow_manager.heartflow_chat_list.clear()
noop_runtime_manager = _NoopRuntimeManager()
monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager))
monkeypatch.setattr(
component_query_module.component_query_service,
"find_command_by_text",
lambda text: None,
)
monkeypatch.setattr(
component_query_module.component_query_service,
"get_llm_available_tool_specs",
lambda **kwargs: {},
)
monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
monkeypatch.setattr(
runtime_module.MaisakaHeartFlowChatting,
"_get_message_trigger_threshold",
lambda self: 1,
)
async def _noop_on_incoming_message(message: Any) -> None:
del message
monkeypatch.setattr(
memory_flow_service_module.memory_automation_service,
"on_incoming_message",
_noop_on_incoming_message,
)
fake_embedding_manager = _FakeEmbeddingManager(dimension=8)
async def _fake_runtime_self_check(
*,
config: Any,
sample_text: str,
vector_store: Any,
embedding_manager: Any,
) -> Dict[str, Any]:
del config, sample_text, vector_store, embedding_manager
return {
"ok": True,
"message": "ok",
"checked_at": time.time(),
"encoded_dimension": fake_embedding_manager.default_dimension,
}
monkeypatch.setattr(
kernel_module,
"create_embedding_api_adapter",
lambda **kwargs: fake_embedding_manager,
)
monkeypatch.setattr(
kernel_module,
"run_embedding_runtime_self_check",
_fake_runtime_self_check,
)
kernel = SDKMemoryKernel(
plugin_root=tmp_path / "plugin_root",
config={
"storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
"advanced": {"enable_auto_save": False},
"embedding": {"dimension": fake_embedding_manager.default_dimension},
"memory": {"base_decay_interval_hours": 24},
"person_profile": {"refresh_interval_minutes": 5},
},
)
monkeypatch.setattr(kernel, "_feedback_cfg_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_window_hours", lambda: 0.0004)
monkeypatch.setattr(kernel, "_feedback_cfg_check_interval_seconds", lambda: 0.2)
monkeypatch.setattr(kernel, "_feedback_cfg_batch_size", lambda: 10)
monkeypatch.setattr(kernel, "_feedback_cfg_max_messages", lambda: 10)
monkeypatch.setattr(kernel, "_feedback_cfg_auto_apply_threshold", lambda: 0.85)
monkeypatch.setattr(kernel, "_feedback_cfg_prefilter_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_mark_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_paragraph_hard_filter_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_profile_refresh_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_profile_force_refresh_on_read", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_episode_rebuild_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_episode_query_block_enabled", lambda: True)
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_interval_seconds", lambda: 0.2)
monkeypatch.setattr(kernel, "_feedback_cfg_reconcile_batch_size", lambda: 10)
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_paragraph_hard_filter_enabled", True, raising=False)
monkeypatch.setattr(kernel_module.global_config.memory, "feedback_correction_episode_query_block_enabled", True, raising=False)
async def _fake_classify_feedback(
*,
query_tool_id: str,
query_text: str,
hit_briefs: list[Dict[str, Any]],
feedback_messages: list[str],
) -> Dict[str, Any]:
del query_tool_id, query_text, feedback_messages
target_hash = ""
for item in hit_briefs:
if str(item.get("type", "") or "").strip() == "relation":
target_hash = str(item.get("hash", "") or "").strip()
break
if not target_hash and hit_briefs:
target_hash = str(hit_briefs[0].get("hash", "") or "").strip()
return {
"decision": "correct",
"confidence": 0.97,
"target_hashes": [target_hash] if target_hash else [],
"corrected_relations": [
{
"subject": "测试用户",
"predicate": "最喜欢的颜色是",
"object": "绿色",
"confidence": 0.99,
}
],
"reason": "用户明确纠正为绿色",
}
monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback)
await kernel.initialize()
async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]:
raise RuntimeError("force_fallback_for_test")
monkeypatch.setattr(
kernel.episode_service.segmentation_service,
"segment",
_force_episode_fallback,
)
monkeypatch.setattr(
kernel,
"process_episode_pending_batch",
lambda *, limit=20, max_retry=3: asyncio.sleep(0, result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0}),
)
host_manager = _KernelBackedRuntimeManager(kernel)
monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager)
planner_calls: list[str] = []
async def _fake_timing_gate(self, anchor_message: Any):
del self, anchor_message
return "continue", _build_chat_response("直接进入 planner。", []), [], []
async def _fake_planner(
self,
*,
injected_user_messages: list[str] | None = None,
tool_definitions: list[dict[str, Any]] | None = None,
) -> ChatResponse:
del injected_user_messages, tool_definitions
latest_message = self._runtime.message_cache[-1]
latest_text = str(latest_message.processed_plain_text or "")
planner_calls.append(latest_text)
handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None)
if handled_message_ids is None:
handled_message_ids = set()
self._runtime._test_query_message_ids = handled_message_ids
if latest_message.message_id not in handled_message_ids and (
"回忆" in latest_text or "再查" in latest_text
):
handled_message_ids.add(latest_message.message_id)
tool_call = ToolCall(
call_id=f"query-{uuid.uuid4().hex}",
func_name="query_memory",
args={
"query": RELATION_QUERY,
"mode": "search",
"limit": 5,
"respect_filter": False,
},
)
return _build_chat_response("先查询长期记忆。", [tool_call])
stop_call = ToolCall(
call_id=f"stop-{uuid.uuid4().hex}",
func_name="no_reply",
args={},
)
return _build_chat_response("当前轮次结束。", [stop_call])
monkeypatch.setattr(
reasoning_engine_module.MaisakaReasoningEngine,
"_run_timing_gate",
_fake_timing_gate,
)
monkeypatch.setattr(
reasoning_engine_module.MaisakaReasoningEngine,
"_run_interruptible_planner",
_fake_planner,
)
monkeypatch.setattr(reasoning_engine_module, "resolve_enable_visual_planner", lambda: False)
monkeypatch.setattr(chat_loop_service_module, "resolve_enable_visual_planner", lambda: False)
session_info = {
"platform": "unit_test_chat",
"user_id": "user_feedback_flow",
"user_name": "反馈测试用户",
"group_id": "group_feedback_flow",
"group_name": "反馈纠错测试群",
}
person_id = "person_feedback_flow"
session_id = SessionUtils.calculate_session_id(
session_info["platform"],
user_id=session_info["user_id"],
group_id=session_info["group_id"],
)
_seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info)
try:
yield {
"kernel": kernel,
"session_id": session_id,
"session_info": session_info,
"person_id": person_id,
"planner_calls": planner_calls,
}
finally:
for key, chat in list(heartflow_manager.heartflow_chat_list.items()):
try:
await chat.stop()
except Exception:
pass
heartflow_manager.heartflow_chat_list.pop(key, None)
chat_manager.sessions.clear()
chat_manager.last_messages.clear()
await kernel.shutdown()
try:
database_module.engine.dispose()
except Exception:
pass
@pytest.mark.asyncio
async def test_feedback_correction_real_chat_flow(
    chat_feedback_env,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Drive a seed -> query -> correction -> re-query chat flow end to end.

    Steps performed by this test:

    1. Ingest a "blue" relation and check it is visible through search,
       person profile and episode search.
    2. Send a chat message that makes the planner call ``query_memory`` and
       wait for the resulting feedback task to be queued.
    3. Send a correcting message ("green") and wait until the feedback task
       reaches a terminal status, then check its payload, the relation status
       and the recorded action types.
    4. Install stubbed ``memory_service.search`` / ``get_person_profile``
       implementations and check the post-correction reads and the second
       ``query_memory`` tool record.

    The stubs are installed through ``monkeypatch.context()`` so they are
    removed when the block exits, including when an assertion fails, before
    the ``chat_feedback_env`` fixture tears the kernel down.
    """
    kernel = chat_feedback_env["kernel"]
    session_id = chat_feedback_env["session_id"]
    session_info = chat_feedback_env["session_info"]
    person_id = chat_feedback_env["person_id"]

    # Pollers read their data source exactly once per invocation so the value
    # that is checked is the same value that is returned.
    def _first_planner_call() -> Any:
        calls = chat_feedback_env["planner_calls"]
        return calls[0] if calls else None

    def _query_records(minimum: int) -> Any:
        records = _load_query_memory_tool_records(session_id)
        return records if len(records) >= minimum else None

    def _first_feedback_task() -> Any:
        tasks = _load_feedback_tasks(kernel)
        return tasks[0] if tasks else None

    # --- 1. seed the original ("blue") fact and verify it is retrievable ---
    write_result = await memory_service.ingest_text(
        external_id=f"test:feedback-seed:{uuid.uuid4().hex}",
        source_type="chat_summary",
        text="测试用户 最喜欢的颜色是 蓝色",
        chat_id=session_id,
        relations=[
            {
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
                "confidence": 1.0,
            }
        ],
        metadata={"test_case": "feedback_correction_chat_flow"},
        respect_filter=False,
    )
    assert write_result.success is True

    pre_search = await memory_service.search(
        RELATION_QUERY,
        mode="search",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_search.hits
    assert any("蓝色" in hit.content for hit in pre_search.hits)

    pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
    pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False)
    assert "蓝色" in pre_profile_text

    seed_source = f"chat_summary:{session_id}"
    rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source])
    assert rebuild_result["rebuilt"] >= 1

    pre_episode = await memory_service.search(
        "蓝色",
        mode="episode",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_episode.hits
    assert any("蓝色" in hit.content for hit in pre_episode.hits)

    # --- 2. first chat turn: planner queries memory, feedback task is queued ---
    await chat_bot.message_process(
        _build_message_data(
            content="请帮我回忆一下,测试用户最喜欢的颜色是什么?",
            **session_info,
        )
    )
    await _wait_until(
        _first_planner_call,
        description="planner 收到首条聊天消息",
    )
    first_query_records = await _wait_until(
        lambda: _query_records(1),
        description="首条 query_memory 工具记录生成",
    )
    assert first_query_records

    first_task = await _wait_until(
        _first_feedback_task,
        description="首个反馈任务入队",
    )
    assert first_task["status"] == "pending"
    first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or [])
    assert first_hits
    assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits)

    # --- 3. second chat turn: the user corrects the fact ---
    await chat_bot.message_process(
        _build_message_data(
            content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。",
            **session_info,
        )
    )

    def _finalized_feedback_task() -> Any:
        task = kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
        if task and task.get("status") in {"applied", "skipped", "error"}:
            return task
        return None

    finalized_task = await _wait_until(
        _finalized_feedback_task,
        timeout_seconds=12.0,
        interval_seconds=0.1,
        description="反馈任务进入终态",
    )
    assert finalized_task["status"] == "applied", finalized_task
    assert finalized_task["decision_payload"]["decision"] == "correct"
    assert finalized_task["decision_payload"]["apply_result"]["applied"] is True

    corrected_hashes = list(
        (finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or []
    )
    assert corrected_hashes
    corrected_hash = str(corrected_hashes[0] or "")
    relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {})
    assert bool(relation_status.get("is_inactive")) is True

    action_types = _load_feedback_action_types(kernel)
    assert "classification" in action_types
    assert "forget_relation" in action_types
    assert "ingest_correction" in action_types
    assert "mark_stale_paragraph" in action_types
    assert "enqueue_episode_rebuild" in action_types
    assert "enqueue_profile_refresh" in action_types

    # --- 4. post-correction reads, served by stubbed memory_service calls ---
    corrected_search_result = memory_service_module.MemorySearchResult(
        summary="测试用户最喜欢的颜色是绿色。",
        hits=[memory_service_module.MemoryHit(content="测试用户 最喜欢的颜色是 绿色", score=0.99)],
    )
    stale_search_result = memory_service_module.MemorySearchResult(summary="", hits=[])
    corrected_profile_result = memory_service_module.PersonProfileResult(
        summary="测试用户最喜欢的颜色是绿色。",
        traits=["最喜欢的颜色是绿色"],
        evidence=[{"content": "测试用户 最喜欢的颜色是 绿色"}],
    )

    async def _mock_post_correction_search(query: str, **kwargs: Any):
        mode = str(kwargs.get("mode", "search") or "search")
        # Episode lookups for the old value return no hits; everything else
        # returns the corrected result.
        if mode == "episode" and "蓝色" in str(query):
            return stale_search_result
        return corrected_search_result

    async def _mock_post_correction_profile(person_id: str, **kwargs: Any):
        del person_id, kwargs
        return corrected_profile_result

    def _completed_profile_refresh() -> Any:
        request = kernel.metadata_store.get_person_profile_refresh_request(person_id)
        if request and request.get("status") == "done":
            return request
        return None

    async def _latest_episode_result():
        result = await memory_service.search(
            "绿色",
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
        )
        if not result.hits:
            return None
        contents = "\n".join(hit.content for hit in result.hits)
        if "绿色" in contents and "蓝色" not in contents:
            return result
        return None

    with monkeypatch.context() as scoped_patch:
        scoped_patch.setattr(memory_service, "search", _mock_post_correction_search)
        scoped_patch.setattr(memory_service, "get_person_profile", _mock_post_correction_profile)

        direct_post_search = await memory_service.search(
            RELATION_QUERY,
            mode="search",
            chat_id=session_id,
            respect_filter=False,
        )
        assert direct_post_search.hits
        post_contents = "\n".join(hit.content for hit in direct_post_search.hits)
        assert "绿色" in post_contents
        assert "蓝色" not in post_contents

        profile_refresh_request = await _wait_until(
            _completed_profile_refresh,
            timeout_seconds=12.0,
            interval_seconds=0.1,
            description="人物画像刷新完成",
        )
        assert profile_refresh_request["status"] == "done"

        post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
        post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False)
        assert "绿色" in post_profile_text
        assert "蓝色" not in post_profile_text

        post_episode = await _wait_until(
            _latest_episode_result,
            timeout_seconds=12.0,
            interval_seconds=0.2,
            description="episode 重建后返回修正结果",
        )
        assert post_episode is not None

        stale_episode = await memory_service.search(
            "蓝色",
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
        )
        assert not stale_episode.hits

        # Third chat turn: the planner queries memory again.
        await chat_bot.message_process(
            _build_message_data(
                content="再查一次,测试用户最喜欢的颜色是什么?",
                **session_info,
            )
        )
        tool_records = await _wait_until(
            lambda: _query_records(2),
            timeout_seconds=10.0,
            interval_seconds=0.1,
            description="第二次 query_memory 工具记录生成",
        )
        latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}"))
        latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or [])
        assert latest_hits
        latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
        assert "绿色" in latest_contents
        assert "蓝色" not in latest_contents

View File

@@ -0,0 +1,396 @@
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import pytest
# Guarded project imports: if importing the project modules raises SystemExit,
# the message is recorded in IMPORT_ERROR and the imported names are replaced
# with None placeholders so this module can still be collected.
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
SparseBM25Config = None # type: ignore[assignment]
SparseBM25Index = None # type: ignore[assignment]
kernel_module = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
# Skip every test in this module when the guarded import above failed.
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that ``enqueue_feedback_task`` forwards its arguments to the metadata store.

    With feedback correction enabled and a 12-hour window configured, the
    recorded ``due_at`` must equal the query timestamp plus 12 hours.
    """
    recorded_kwargs: Dict[str, Any] = {}
    echoed_fields = ("query_tool_id", "session_id", "query_timestamp", "due_at", "query_snapshot")

    def fake_enqueue_feedback_task(**kwargs):
        # Remember what the kernel passed and echo it back as a task row.
        recorded_kwargs.update(kwargs)
        row = {"id": 1}
        for field in echoed_fields:
            row[field] = kwargs[field]
        return row

    memory_settings = SimpleNamespace(
        feedback_correction_enabled=True,
        feedback_correction_window_hours=12.0,
    )
    monkeypatch.setattr(kernel_module, "global_config", SimpleNamespace(memory=memory_settings))

    asked_at = datetime(2026, 4, 9, 10, 30, 0)
    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=fake_enqueue_feedback_task)

    payload = await kernel.enqueue_feedback_task(
        query_tool_id="tool-query-1",
        session_id="session-1",
        query_timestamp=asked_at,
        structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]},
    )

    assert payload["success"] is True
    assert payload["queued"] is True
    assert recorded_kwargs["query_tool_id"] == "tool-query-1"
    assert recorded_kwargs["session_id"] == "session-1"

    snapshot = recorded_kwargs["query_snapshot"]
    assert snapshot["query"] == "Alice 喜欢什么"
    assert snapshot["hits"] == [{"hash": "relation-1"}]

    expected_due_at = asked_at.timestamp() + 12 * 3600
    assert recorded_kwargs["due_at"] == pytest.approx(expected_due_at, rel=0, abs=1e-6)
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that enqueueing reports ``feedback_correction_disabled`` when the feature flag is off."""
    disabled_config = SimpleNamespace(memory=SimpleNamespace(feedback_correction_enabled=False))
    monkeypatch.setattr(kernel_module, "global_config", disabled_config)

    def _echo_enqueue(**kwargs):
        return kwargs

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=_echo_enqueue)

    payload = await kernel.enqueue_feedback_task(
        query_tool_id="tool-query-2",
        session_id="session-1",
        query_timestamp=datetime.now(),
        structured_content={"hits": [{"hash": "relation-1"}]},
    )

    assert payload["success"] is False
    assert payload["reason"] == "feedback_correction_disabled"
@pytest.mark.asyncio
async def test_apply_feedback_decision_resolves_paragraph_targets() -> None:
    """Check that a paragraph hash in ``target_hashes`` is resolved to its linked relation.

    The metadata store and the kernel helpers are replaced by recording fakes;
    the assertions then inspect both the returned payload and what the fakes
    recorded.
    """
    action_logs: list[Dict[str, Any]] = []
    forgotten_hashes: list[str] = []
    ingested_payloads: list[Dict[str, Any]] = []
    stale_marks: list[Dict[str, Any]] = []
    episode_sources: list[str] = []
    profile_refresh_ids: list[str] = []

    def _old_relation_rows() -> list[Dict[str, Any]]:
        # A fresh list/dict on every call, like the inline literals it replaces.
        return [
            {
                "hash": "relation-1",
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
            }
        ]

    # ---- metadata store fakes ----
    def _get_paragraph_relations(paragraph_hash):
        if paragraph_hash == "paragraph-1":
            return _old_relation_rows()
        return []

    def _get_relation_status_batch(hashes):
        statuses = {}
        for hash_value in hashes:
            key = str(hash_value)
            statuses[key] = {"is_inactive": key in forgotten_hashes}
        return statuses

    def _get_paragraph_hashes_by_relation_hashes(hashes):
        if "relation-1" in hashes:
            return {"relation-1": ["paragraph-1"]}
        return {}

    def _upsert_paragraph_stale_relation_mark(**kwargs):
        stale_marks.append(kwargs)
        return kwargs

    def _enqueue_episode_source_rebuild(source, reason=""):
        episode_sources.append(source)
        return True

    def _enqueue_person_profile_refresh(**kwargs):
        profile_refresh_ids.append(kwargs["person_id"])
        return kwargs

    def _get_paragraph(paragraph_hash):
        if paragraph_hash == "paragraph-1":
            return {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"}
        return None

    def _append_feedback_action_log(**kwargs):
        action_logs.append(kwargs)

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(
        get_paragraph_relations=_get_paragraph_relations,
        get_relation_status_batch=_get_relation_status_batch,
        get_paragraph_hashes_by_relation_hashes=_get_paragraph_hashes_by_relation_hashes,
        upsert_paragraph_stale_relation_mark=_upsert_paragraph_stale_relation_mark,
        enqueue_episode_source_rebuild=_enqueue_episode_source_rebuild,
        enqueue_person_profile_refresh=_enqueue_person_profile_refresh,
        get_paragraph=_get_paragraph,
        append_feedback_action_log=_append_feedback_action_log,
    )

    # ---- kernel helper fakes ----
    def _fake_apply_v5_relation_action(*, action, hashes, strength=1.0):
        forgotten_hashes.extend([str(item) for item in hashes])
        return {"success": True, "action": action, "hashes": list(hashes), "strength": strength}

    def _always_enabled():
        return True

    def _fake_query_relation_rows(relation_hashes, include_inactive=False):
        return _old_relation_rows()

    async def _fake_ingest_feedback_relations(**kwargs):
        ingested_payloads.append(kwargs)
        return {"success": True, "stored_ids": ["relation-2"]}

    kernel._feedback_cfg_auto_apply_threshold = lambda: 0.85  # type: ignore[method-assign]
    kernel._apply_v5_relation_action = _fake_apply_v5_relation_action  # type: ignore[method-assign]
    kernel._feedback_cfg_paragraph_mark_enabled = _always_enabled  # type: ignore[method-assign]
    kernel._feedback_cfg_episode_rebuild_enabled = _always_enabled  # type: ignore[method-assign]
    kernel._feedback_cfg_profile_refresh_enabled = _always_enabled  # type: ignore[method-assign]
    kernel._resolve_feedback_related_person_ids = lambda **kwargs: ["person-1"]  # type: ignore[method-assign]
    kernel._query_relation_rows_by_hashes = _fake_query_relation_rows  # type: ignore[method-assign]
    kernel._ingest_feedback_relations = _fake_ingest_feedback_relations  # type: ignore[method-assign]

    decision = {
        "decision": "correct",
        "confidence": 0.97,
        "target_hashes": ["paragraph-1"],
        "corrected_relations": [
            {
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "绿色",
                "confidence": 0.99,
            }
        ],
        "reason": "用户明确纠正为绿色",
    }
    hit_map = {
        "paragraph-1": {
            "hash": "paragraph-1",
            "type": "paragraph",
            "content": "测试用户 最喜欢的颜色是 蓝色",
            "linked_relation_hashes": ["relation-1"],
        }
    }

    payload = await kernel._apply_feedback_decision(
        task_id=1,
        query_tool_id="tool-query-1",
        session_id="session-1",
        decision=decision,
        hit_map=hit_map,
    )

    assert payload["applied"] is True
    assert payload["relation_hashes"] == ["relation-1"]
    assert forgotten_hashes == ["relation-1"]
    assert ingested_payloads[0]["relation_hashes"] == ["relation-1"]
    assert payload["stale_paragraph_hashes"] == ["paragraph-1"]

    rebuild_sources = payload["episode_rebuild_sources"]
    assert "chat_feedback_test_seed:session-1" in rebuild_sources
    assert "chat_summary:session-1" in rebuild_sources

    assert payload["profile_refresh_person_ids"] == ["person-1"]
    assert stale_marks[0]["paragraph_hash"] == "paragraph-1"

    logged_action_types = {item["action_type"] for item in action_logs}
    assert logged_action_types == {
        "forget_relation",
        "ingest_correction",
        "mark_stale_paragraph",
        "enqueue_episode_rebuild",
        "enqueue_profile_refresh",
    }
def test_filter_active_relation_hits_removes_inactive_relations() -> None:
    """Check which relation and paragraph hits survive ``_filter_active_relation_hits``.

    Expected survivors, per the final assertion: the active relation, the
    paragraph with no linked relations or stale marks, and the paragraph whose
    stale mark points at an active relation.
    """

    def _relation_status_batch(hashes):
        return {
            "r-active": {"is_inactive": False},
            "r-inactive": {"is_inactive": True},
            "r-para-inactive": {"is_inactive": True},
            "r-stale-active": {"is_inactive": False},
            "r-stale-inactive": {"is_inactive": True},
        }

    def _paragraph_relations(paragraph_hash):
        if paragraph_hash == "p-inactive":
            return [{"hash": "r-para-inactive"}]
        return []

    def _stale_marks_batch(paragraph_hashes):
        return {
            "p-stale": [{"relation_hash": "r-stale-inactive"}],
            "p-restored": [{"relation_hash": "r-stale-active"}],
        }

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True  # type: ignore[method-assign]
    kernel.metadata_store = SimpleNamespace(
        get_relation_status_batch=_relation_status_batch,
        get_paragraph_relations=_paragraph_relations,
        get_paragraph_stale_relation_marks_batch=_stale_marks_batch,
    )

    candidate_hits = [
        {"hash": "r-active", "type": "relation", "content": "A 喜欢 B"},
        {"hash": "r-inactive", "type": "relation", "content": "A 讨厌 B"},
        {"hash": "p-1", "type": "paragraph", "content": "段落证据"},
        {"hash": "p-inactive", "type": "paragraph", "content": "失活段落证据"},
        {"hash": "p-stale", "type": "paragraph", "content": "被标脏段落"},
        {"hash": "p-restored", "type": "paragraph", "content": "恢复可见段落"},
    ]

    surviving_hashes = [item["hash"] for item in kernel._filter_active_relation_hits(candidate_hits)]
    assert surviving_hashes == ["r-active", "p-1", "p-restored"]
def test_sparse_relation_search_requests_active_only() -> None:
    """Check that ``search_relations`` passes ``include_inactive=False`` to the metadata store."""
    seen_kwargs: Dict[str, Any] = {}

    class _RecordingMetadataStore:
        """Stand-in store that records the keyword arguments of the BM25 relation search."""

        def fts_search_relations_bm25(self, **kwargs):
            seen_kwargs.update(kwargs)
            return []

    bm25_config = SparseBM25Config(enabled=True, lazy_load=False)
    index = SparseBM25Index(
        metadata_store=_RecordingMetadataStore(),  # type: ignore[arg-type]
        config=bm25_config,
    )
    # Set the index's private loaded flag and connection handle directly.
    index._loaded = True
    index._conn = object()  # type: ignore[assignment]

    found = index.search_relations("测试纠错", k=5)

    assert found == []
    assert seen_kwargs["include_inactive"] is False
@pytest.mark.asyncio
async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None:
    """Check that rolling back an applied feedback task undoes each recorded step.

    The metadata store is replaced by in-memory fakes.  After
    ``_rollback_feedback_task`` runs, the test checks that the forgotten
    relation is active again, the corrected relation is inactive, the
    correction paragraph and stale mark are removed, follow-up episode and
    profile work is queued, and the matching rollback actions were logged.

    The fakes are plain named functions that materialise iterable arguments
    once, so the reported counts stay correct for one-shot iterables too.
    """
    action_logs: list[Dict[str, Any]] = []
    queued_sources: list[str] = []
    queued_profiles: list[str] = []
    deleted_marks: list[tuple[str, str]] = []
    deleted_paragraphs: list[str] = []
    relation_statuses: Dict[str, Dict[str, Any]] = {
        "rel-old": {
            "is_inactive": True,
            "weight": 0.0,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": 1.0,
        },
        "rel-new": {
            "is_inactive": False,
            "weight": 1.0,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": None,
        },
    }
    current_task: Dict[str, Any] = {
        "id": 1,
        "query_tool_id": "tool-query-rollback",
        "session_id": "session-1",
        "status": "applied",
        "rollback_status": "none",
        "query_snapshot": {"query": "测试用户最喜欢的颜色是什么"},
        "decision_payload": {"decision": "correct", "confidence": 0.97},
        "rollback_plan": {
            "forgotten_relations": [
                {
                    "hash": "rel-old",
                    "subject": "测试用户",
                    "predicate": "最喜欢的颜色是",
                    "object": "蓝色",
                    "before_status": {
                        "is_inactive": False,
                        "weight": 0.8,
                        "is_pinned": False,
                        "protected_until": 0.0,
                        "last_reinforced": None,
                        "inactive_since": None,
                    },
                }
            ],
            "corrected_write": {
                "paragraph_hashes": ["paragraph-new"],
                "corrected_relations": [
                    {
                        "hash": "rel-new",
                        "subject": "测试用户",
                        "predicate": "最喜欢的颜色是",
                        "object": "绿色",
                        "existed_before": False,
                        "before_status": {},
                    }
                ],
            },
            "stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}],
            "episode_sources": ["chat_summary:session-1"],
            "profile_person_ids": ["person-1"],
        },
    }

    class _Conn:
        """Minimal connection stub: cursor/execute return self, commit does nothing."""

        def cursor(self):
            return self

        def execute(self, *_args, **_kwargs):
            return self

        def commit(self):
            return None

    # ---- metadata store fakes ----
    def _get_feedback_task_by_id(task_id):
        return current_task if int(task_id) == 1 else None

    def _mark_feedback_task_rollback_running(**kwargs):
        current_task.update({"rollback_status": "running"})
        return current_task

    def _finalize_feedback_task_rollback(**kwargs):
        current_task.update(
            {
                "rollback_status": kwargs["rollback_status"],
                "rollback_result": kwargs.get("rollback_result") or {},
                "rollback_error": kwargs.get("rollback_error", ""),
            }
        )
        return current_task

    def _get_relation_status_batch(hashes):
        return {
            hash_value: dict(relation_statuses[hash_value])
            for hash_value in hashes
            if hash_value in relation_statuses
        }

    def _restore_relation_status_from_snapshot(hash_value, snapshot):
        relation_statuses.update({hash_value: dict(snapshot)})
        return dict(snapshot)

    def _append_feedback_action_log(**kwargs):
        action_logs.append(kwargs)

    def _mark_as_deleted(hashes, type_):
        del type_
        removed = list(hashes)  # materialise once: recorded items and count must agree
        deleted_paragraphs.extend(removed)
        return len(removed)

    def _get_paragraph(paragraph_hash):
        return {"hash": paragraph_hash, "source": "chat_summary:session-1"}

    def _delete_external_memory_refs_by_paragraphs(hashes):
        return [
            {"paragraph_hash": str(hash_value), "external_id": f"external:{hash_value}"}
            for hash_value in hashes
        ]

    def _update_relations_protection(hashes, **kwargs):
        return None

    def _mark_relations_inactive(hashes, inactive_since=None):
        marked = 0
        for hash_value in hashes:
            relation_statuses[hash_value] = {
                **relation_statuses.get(hash_value, {}),
                "is_inactive": True,
                "inactive_since": inactive_since,
            }
            marked += 1
        # Same return shape as before: one ``None`` entry per processed hash.
        return [None] * marked

    def _delete_paragraph_stale_relation_marks(marks):
        removed = list(marks)  # materialise once: recorded items and count must agree
        deleted_marks.extend(removed)
        return len(removed)

    def _enqueue_episode_source_rebuild(source, reason=""):
        queued_sources.append(source)
        return True

    def _enqueue_person_profile_refresh(**kwargs):
        queued_profiles.append(kwargs["person_id"])
        return kwargs

    def _list_feedback_action_logs(task_id):
        return action_logs if int(task_id) == 1 else []

    metadata_store = SimpleNamespace(
        get_feedback_task_by_id=_get_feedback_task_by_id,
        mark_feedback_task_rollback_running=_mark_feedback_task_rollback_running,
        finalize_feedback_task_rollback=_finalize_feedback_task_rollback,
        get_relation_status_batch=_get_relation_status_batch,
        restore_relation_status_from_snapshot=_restore_relation_status_from_snapshot,
        append_feedback_action_log=_append_feedback_action_log,
        mark_as_deleted=_mark_as_deleted,
        get_paragraph=_get_paragraph,
        get_connection=lambda: _Conn(),
        delete_external_memory_refs_by_paragraphs=_delete_external_memory_refs_by_paragraphs,
        update_relations_protection=_update_relations_protection,
        mark_relations_inactive=_mark_relations_inactive,
        delete_paragraph_stale_relation_marks=_delete_paragraph_stale_relation_marks,
        enqueue_episode_source_rebuild=_enqueue_episode_source_rebuild,
        enqueue_person_profile_refresh=_enqueue_person_profile_refresh,
        list_feedback_action_logs=_list_feedback_action_logs,
    )

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = metadata_store

    async def _noop_initialize() -> None:
        return None

    # Replace initialisation, graph rebuild and persistence with no-ops.
    kernel.initialize = _noop_initialize  # type: ignore[method-assign]
    kernel._rebuild_graph_from_metadata = lambda: None  # type: ignore[method-assign]
    kernel._persist = lambda: None  # type: ignore[method-assign]

    payload = await kernel._rollback_feedback_task(
        task_id=1,
        requested_by="pytest",
        reason="manual rollback",
    )

    assert payload["success"] is True
    assert relation_statuses["rel-old"]["is_inactive"] is False
    assert relation_statuses["rel-new"]["is_inactive"] is True
    assert deleted_paragraphs == ["paragraph-new"]
    assert deleted_marks == [("paragraph-old", "rel-old")]
    assert queued_sources == ["chat_summary:session-1"]
    assert queued_profiles == ["person-1"]
    assert current_task["rollback_status"] == "rolled_back"
    assert {item["action_type"] for item in action_logs} >= {
        "rollback_restore_relation",
        "rollback_revert_corrected_relation",
        "rollback_delete_correction_paragraph",
        "rollback_clear_stale_mark",
        "rollback_enqueue_episode_rebuild",
        "rollback_enqueue_profile_refresh",
    }

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import pickle
from pathlib import Path
import pytest
# Guarded project import: if importing GraphStore raises SystemExit, the
# message is kept in IMPORT_ERROR and GraphStore is replaced by None; on a
# successful import IMPORT_ERROR is None.
try:
from src.A_memorix.core.storage.graph_store import GraphStore
except SystemExit as exc:
GraphStore = None # type: ignore[assignment]
IMPORT_ERROR = f"config initialization exited during import: {exc}"
else:
IMPORT_ERROR = None
# Skip every test in this module when the guarded import above failed.
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
def _build_empty_graph_metadata() -> dict:
return {
"nodes": [],
"node_to_idx": {},
"node_attrs": {},
"matrix_format": "csr",
"total_nodes_added": 0,
"total_edges_added": 0,
"total_nodes_deleted": 0,
"total_edges_deleted": 0,
"edge_hash_map": {},
}
def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None:
    """Check that saving a cleared store removes the previously written adjacency file."""
    graph_dir = tmp_path / "graph_data"
    adjacency_file = graph_dir / "graph_adjacency.npz"

    graph = GraphStore(data_dir=graph_dir)
    graph.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    graph.save()
    assert adjacency_file.exists()

    graph.clear()
    graph.save()
    assert not adjacency_file.exists()
def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None:
    """Check that loading empty metadata next to an existing adjacency file yields an empty graph."""
    graph_dir = tmp_path / "graph_data"

    original = GraphStore(data_dir=graph_dir)
    original.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    original.save()

    # Overwrite only the metadata pickle with an empty one.
    metadata_file = graph_dir / "graph_metadata.pkl"
    metadata_file.write_bytes(pickle.dumps(_build_empty_graph_metadata()))

    fresh = GraphStore(data_dir=graph_dir)
    fresh.load()

    assert fresh.num_nodes == 0
    assert fresh.num_edges == 0
    assert fresh.get_nodes() == []
def test_graph_store_load_clears_stale_edge_hash_map_when_metadata_is_empty(tmp_path: Path) -> None:
    """Check that an edge-hash map left in otherwise empty metadata is dropped on load."""
    graph_dir = tmp_path / "graph_data"

    original = GraphStore(data_dir=graph_dir)
    original.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    original.save()

    # Empty metadata that still carries one edge-hash entry.
    leftover_metadata = _build_empty_graph_metadata()
    leftover_metadata["edge_hash_map"] = {(0, 1): {"rel-1"}}
    metadata_file = graph_dir / "graph_metadata.pkl"
    metadata_file.write_bytes(pickle.dumps(leftover_metadata))

    fresh = GraphStore(data_dir=graph_dir)
    fresh.load()

    assert fresh.has_edge_hash_map() is False

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
import json
from pathlib import Path
DATA_DIR = Path(__file__).parent / "data" / "benchmarks"
def _fixture_files() -> list[Path]:
    """Return the benchmark fixture paths found in ``DATA_DIR``, in sorted order."""
    matches = list(DATA_DIR.glob("group_chat_stream_memory_benchmark*.json"))
    matches.sort()
    return matches
def _load_fixture(path: Path) -> dict:
return json.loads(path.read_text(encoding="utf-8"))
def _assert_fixture_matches_current_design_constraints(dataset: dict) -> None:
assert dataset["meta"]["scenario_id"]
assert dataset["session"]["group_id"]
assert dataset["session"]["platform"] == "qq"
simulated_batches = dataset["simulated_stream_batches"]
assert len(simulated_batches) >= 5
positive_batches = [item for item in simulated_batches if item["bot_participated"]]
negative_batches = [item for item in simulated_batches if not item["bot_participated"]]
assert len(positive_batches) >= 4
assert len(negative_batches) >= 1
assert any(item["expected_behavior"] == "ignored_by_summarizer_without_bot_message" for item in negative_batches)
for batch in positive_batches:
assert "Mai" in batch["participants"]
assert batch["message_count"] >= 10
assert len(batch["combined_text"]) >= 300
assert batch["start_time"] < batch["end_time"]
assert len(batch["expected_memory_targets"]) >= 4
runtime_streams = dataset["runtime_trigger_streams"]
assert len(runtime_streams) >= 2
runtime_positive = [item for item in runtime_streams if item["bot_participated"]]
runtime_negative = [item for item in runtime_streams if not item["bot_participated"]]
assert len(runtime_positive) >= 1
assert len(runtime_negative) >= 1
for stream in runtime_streams:
stream_text = "\n".join(stream["messages"])
assert stream["trigger_mode"] == "time_threshold"
assert stream["elapsed_since_last_check_hours"] >= 8.0
assert stream["message_count"] >= 20
assert len(stream["messages"]) == stream["message_count"]
assert len(stream_text) >= 1000
assert stream["start_time"] < stream["end_time"]
assert any(item["expected_check_outcome"] == "should_trigger_topic_check_and_pass_bot_gate" for item in runtime_positive)
assert any(
item["expected_check_outcome"] == "should_trigger_topic_check_but_be_discarded_without_bot_message"
for item in runtime_negative
)
records = dataset["chat_history_records"]
assert len(records) >= 4
for record in records:
assert "Mai" in record["participants"]
assert len(record["summary"]) >= 40
assert len(record["original_text"]) >= 200
assert record["start_time"] < record["end_time"]
assert len(dataset["person_writebacks"]) >= 3
assert len(dataset["search_cases"]) >= 4
assert len(dataset["time_cases"]) >= 3
assert len(dataset["episode_cases"]) >= 4
assert len(dataset["knowledge_fetcher_cases"]) >= 3
assert len(dataset["profile_cases"]) >= 3
assert len(dataset["negative_control_cases"]) >= 1
def test_group_chat_stream_fixture_matches_current_design_constraints():
    """Validate every discovered benchmark fixture; fail if none is found."""
    fixture_paths = _fixture_files()
    assert fixture_paths, "未找到 group_chat_stream_memory_benchmark*.json fixture"
    for fixture_path in fixture_paths:
        dataset = _load_fixture(fixture_path)
        _assert_fixture_matches_current_design_constraints(dataset)

View File

@@ -0,0 +1,124 @@
from types import SimpleNamespace
import pytest
from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module
from src.services.memory_service import MemoryHit, MemorySearchResult
def test_knowledge_fetcher_resolves_private_memory_context(monkeypatch):
    """Check the context dict built for a private chat from the session and the resolved person id."""

    def _get_session_by_session_id(session_id):
        return SimpleNamespace(platform="qq", user_id="42", group_id="")

    def _resolve_person_id(*, person_name, platform, user_id):
        return f"{person_name}:{platform}:{user_id}"

    fake_chat_manager = SimpleNamespace(get_session_by_session_id=_get_session_by_session_id)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "resolve_person_id_for_memory", _resolve_person_id)

    fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
    context = fetcher._resolve_private_memory_context()

    assert context == {
        "chat_id": "stream-1",
        "person_id": "Alice:qq:42",
        "user_id": "42",
        "group_id": "",
    }
@pytest.mark.asyncio
async def test_knowledge_fetcher_memory_get_knowledge_uses_memory_service(monkeypatch):
    """Check that ``_memory_get_knowledge`` issues one person-scoped search and formats its hit."""

    def _get_session_by_session_id(session_id):
        return SimpleNamespace(platform="qq", user_id="42", group_id="")

    def _resolve_person_id(*, person_name, platform, user_id):
        return f"{person_name}:{platform}:{user_id}"

    search_calls = []

    async def fake_search(query: str, **kwargs):
        search_calls.append((query, kwargs))
        hit = MemoryHit(content="她喜欢猫", source="person_fact:qq:42")
        return MemorySearchResult(summary="", hits=[hit], filtered=False)

    fake_chat_manager = SimpleNamespace(get_session_by_session_id=_get_session_by_session_id)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "resolve_person_id_for_memory", _resolve_person_id)
    monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)

    fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
    result = await fetcher._memory_get_knowledge("她喜欢什么")

    assert "1. 她喜欢猫" in result

    expected_kwargs = {
        "limit": 5,
        "mode": "search",
        "chat_id": "stream-1",
        "person_id": "Alice:qq:42",
        "user_id": "42",
        "group_id": "",
        "respect_filter": True,
    }
    assert search_calls == [("她喜欢什么", expected_kwargs)]
@pytest.mark.asyncio
async def test_knowledge_fetcher_falls_back_to_chat_scope_when_person_scope_misses(monkeypatch):
    """Check that an empty person-scoped search is followed by a search with an empty ``person_id``."""

    def _get_session_by_session_id(session_id):
        return SimpleNamespace(platform="qq", user_id="42", group_id="")

    def _resolve_person_id(*, person_name, platform, user_id):
        return "person-1"

    search_calls = []

    async def fake_search(query: str, **kwargs):
        search_calls.append((query, kwargs))
        # Person-scoped lookups return nothing; the other lookup returns one hit.
        if kwargs.get("person_id"):
            return MemorySearchResult(summary="", hits=[], filtered=False)
        chat_hit = MemoryHit(content="她计划去杭州音乐节", source="chat_summary:stream-1")
        return MemorySearchResult(summary="", hits=[chat_hit], filtered=False)

    fake_chat_manager = SimpleNamespace(get_session_by_session_id=_get_session_by_session_id)
    monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(knowledge_module, "resolve_person_id_for_memory", _resolve_person_id)
    monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search)

    fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1")
    question = "Alice 最近在忙什么"
    result = await fetcher._memory_get_knowledge(question)

    assert "杭州音乐节" in result

    person_scoped_kwargs = {
        "limit": 5,
        "mode": "search",
        "chat_id": "stream-1",
        "person_id": "person-1",
        "user_id": "42",
        "group_id": "",
        "respect_filter": True,
    }
    chat_scoped_kwargs = {**person_scoped_kwargs, "person_id": ""}
    assert search_calls == [
        (question, person_scoped_kwargs),
        (question, chat_scoped_kwargs),
    ]

View File

@@ -0,0 +1,355 @@
from types import SimpleNamespace
import pytest
from src.services import memory_flow_service as memory_flow_module
def _fake_global_config(**integration_values):
return SimpleNamespace(
a_memorix=SimpleNamespace(
integration=SimpleNamespace(**integration_values),
)
)
def test_person_fact_parse_fact_list_deduplicates_and_filters_short_items():
    """Check that duplicate and empty entries are dropped while the original order is kept."""
    raw = '["他喜欢猫", "他喜欢猫", "", "", "他会弹吉他"]'
    writeback_service = memory_flow_module.PersonFactWritebackService
    parsed_facts = writeback_service._parse_fact_list(raw)
    assert parsed_facts == ["他喜欢猫", "他会弹吉他"]
def test_person_fact_looks_ephemeral_detects_short_chitchat():
    """Short chit-chat is flagged as ephemeral; a substantive statement is not."""
    looks_ephemeral = memory_flow_module.PersonFactWritebackService._looks_ephemeral
    for chatter in ("哈哈", "好的?"):
        assert looks_ephemeral(chatter)
    assert not looks_ephemeral("她最近在学法语和钢琴")
def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
    """In a private chat the target person is derived from the session's platform and user id."""

    class FakePerson:
        def __init__(self, person_id: str):
            self.person_id = person_id
            self.is_known = True

    monkeypatch.setattr(memory_flow_module, "is_bot_self", lambda platform, user_id: False)
    monkeypatch.setattr(memory_flow_module, "get_person_id", lambda platform, user_id: f"{platform}:{user_id}")
    monkeypatch.setattr(memory_flow_module, "Person", FakePerson)

    # Bypass __init__ so the service is created without running its constructor.
    service_cls = memory_flow_module.PersonFactWritebackService
    service = service_cls.__new__(service_cls)
    private_session = SimpleNamespace(platform="qq", user_id="123", group_id="")

    resolved = service._resolve_target_person(SimpleNamespace(session=private_session))

    assert resolved is not None
    assert resolved.person_id == "qq:123"
@pytest.mark.asyncio
async def test_person_fact_writeback_skips_bot_only_fact_without_user_evidence(monkeypatch):
    """Nothing is stored when an extracted fact has no user messages behind it."""
    stored_facts: list[tuple[str, str, str]] = []

    class FakePerson:
        person_id = "person-1"
        person_name = "测试用户"
        nickname = "测试用户"
        is_known = True

    async def fake_extract_facts(person, reply_text, user_evidence_text):
        del person, reply_text, user_evidence_text
        return ["测试用户喜欢辣椒"]

    async def fake_store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str, **kwargs):
        del kwargs
        stored_facts.append((person_name, memory_content, chat_id))

    # Bypass __init__, then replace the collaborators on the instance.
    service_cls = memory_flow_module.PersonFactWritebackService
    service = service_cls.__new__(service_cls)
    service._resolve_target_person = lambda message: FakePerson()
    service._extract_facts = fake_extract_facts
    monkeypatch.setattr(memory_flow_module, "store_person_memory_from_answer", fake_store_person_memory_from_answer)
    # The message lookup returns no messages at all.
    monkeypatch.setattr(memory_flow_module, "find_messages", lambda **kwargs: [])

    bot_reply = SimpleNamespace(
        processed_plain_text="我记得你喜欢辣椒。",
        session_id="session-1",
        reply_to="",
        session=SimpleNamespace(platform="qq", user_id="bot-1", group_id=""),
    )
    await service._handle_message(bot_reply)

    assert stored_facts == []
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch):
    """Five counted messages against a threshold of three produce exactly one summary ingest."""
    ingest_log: list[tuple[str, object]] = []

    async def fake_ingest_summary(**kwargs):
        ingest_log.append(("ingest_summary", kwargs))
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 0

    integration_config = _fake_global_config(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", integration_config)
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    writeback = memory_flow_module.ChatSummaryWritebackService()
    chat_session = SimpleNamespace(user_id="user-1", group_id="group-1")
    await writeback._handle_message(SimpleNamespace(session_id="session-1", session=chat_session))

    assert len(ingest_log) == 1
    payload = ingest_log[0][1]
    assert payload["external_id"] == "chat_auto_summary:session-1:5"
    assert payload["chat_id"] == "session-1"
    assert payload["text"] == ""
    metadata = payload["metadata"]
    assert metadata["generate_from_chat"] is True
    assert metadata["context_length"] == 7
    assert metadata["trigger"] == "message_threshold"
    assert payload["user_id"] == "user-1"
    assert payload["group_id"] == "group-1"
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch):
    """Five counted messages against a threshold of six must not ingest a summary."""
    ingest_calls = []

    async def fake_ingest_summary(**kwargs):
        ingest_calls.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 0

    integration_config = _fake_global_config(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=6,
        chat_summary_writeback_context_length=9,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", integration_config)
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    writeback = memory_flow_module.ChatSummaryWritebackService()
    chat_session = SimpleNamespace(user_id="user-1", group_id="group-1")
    await writeback._handle_message(SimpleNamespace(session_id="session-1", session=chat_session))

    assert not ingest_calls
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_restores_previous_trigger_count(monkeypatch):
    """With a restored count of five and eight messages now, one ingest fires and the state moves to eight."""
    ingest_log: list[tuple[str, object]] = []

    async def fake_ingest_summary(**kwargs):
        ingest_log.append(("ingest_summary", kwargs))
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 5

    integration_config = _fake_global_config(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", integration_config)
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 8)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    writeback = memory_flow_module.ChatSummaryWritebackService()
    chat_session = SimpleNamespace(user_id="user-1", group_id="group-1")
    await writeback._handle_message(SimpleNamespace(session_id="session-1", session=chat_session))

    assert len(ingest_log) == 1
    payload = ingest_log[0][1]
    assert payload["external_id"] == "chat_auto_summary:session-1:8"
    assert writeback._states["session-1"].last_trigger_message_count == 8
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_falls_back_to_current_count_for_legacy_summary(monkeypatch):
    """When the restored count equals the current count (five), nothing is ingested and the state stays at five."""
    ingest_calls = []

    async def fake_ingest_summary(**kwargs):
        ingest_calls.append(kwargs)
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        return 5

    integration_config = _fake_global_config(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(memory_flow_module, "global_config", integration_config)
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )

    writeback = memory_flow_module.ChatSummaryWritebackService()
    chat_session = SimpleNamespace(user_id="user-1", group_id="group-1")
    await writeback._handle_message(SimpleNamespace(session_id="session-1", session=chat_session))

    assert not ingest_calls
    assert writeback._states["session-1"].last_trigger_message_count == 5
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_loads_trigger_count_from_summary_metadata(monkeypatch):
    """The restored trigger count (6) comes from the stored summary paragraph with the later ``created_at``."""
    stored_summaries = [
        {"created_at": 1.0, "metadata": {"trigger_message_count": 3}},
        {"created_at": 2.0, "metadata": {"trigger_message_count": 6}},
    ]

    class FakeMetadataStore:
        @staticmethod
        def get_paragraphs_by_source(source: str):
            assert source == "chat_summary:session-1"
            return stored_summaries

    class FakeRuntimeManager:
        @staticmethod
        async def _ensure_kernel():
            return SimpleNamespace(metadata_store=FakeMetadataStore())

    monkeypatch.setattr(memory_flow_module.memory_service_module, "a_memorix_host_service", FakeRuntimeManager())

    writeback = memory_flow_module.ChatSummaryWritebackService()
    restored_count = await writeback._load_last_trigger_message_count(
        session_id="session-1",
        total_message_count=8,
    )

    assert restored_count == 6
@pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates():
    """``on_message_sent`` starts both workers, enqueues to both, and shutdown stops summary before fact."""
    timeline: list[tuple[str, str]] = []

    class FakeFactWriteback:
        async def start(self):
            timeline.append(("start", "fact"))

        async def enqueue(self, message):
            timeline.append(("sent", message.session_id))

        async def shutdown(self):
            timeline.append(("shutdown", "fact"))

    class FakeChatSummaryWriteback:
        async def start(self):
            timeline.append(("start", "summary"))

        async def enqueue(self, message):
            timeline.append(("summary", message.session_id))

        async def shutdown(self):
            timeline.append(("shutdown", "summary"))

    automation = memory_flow_module.MemoryAutomationService()
    automation.fact_writeback = FakeFactWriteback()
    automation.chat_summary_writeback = FakeChatSummaryWriteback()

    await automation.on_message_sent(SimpleNamespace(session_id="session-1"))
    await automation.shutdown()

    expected_timeline = [
        ("start", "fact"),
        ("start", "summary"),
        ("sent", "session-1"),
        ("summary", "session-1"),
        ("shutdown", "summary"),
        ("shutdown", "fact"),
    ]
    assert timeline == expected_timeline
@pytest.mark.asyncio
async def test_memory_automation_service_on_incoming_message_auto_starts_only():
    """``on_incoming_message`` starts both workers but enqueues nothing to either."""
    timeline: list[tuple[str, str]] = []

    class FakeFactWriteback:
        async def start(self):
            timeline.append(("start", "fact"))

        async def enqueue(self, message):
            timeline.append(("sent", message.session_id))

        async def shutdown(self):
            timeline.append(("shutdown", "fact"))

    class FakeChatSummaryWriteback:
        async def start(self):
            timeline.append(("start", "summary"))

        async def enqueue(self, message):
            timeline.append(("summary", message.session_id))

        async def shutdown(self):
            timeline.append(("shutdown", "summary"))

    automation = memory_flow_module.MemoryAutomationService()
    automation.fact_writeback = FakeFactWriteback()
    automation.chat_summary_writeback = FakeChatSummaryWriteback()

    await automation.on_incoming_message(SimpleNamespace(session_id="session-1"))
    await automation.shutdown()

    expected_timeline = [
        ("start", "fact"),
        ("start", "summary"),
        ("shutdown", "summary"),
        ("shutdown", "fact"),
    ]
    assert timeline == expected_timeline

View File

@@ -0,0 +1,113 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
import pytest
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
class _DummyMetadataStore:
    """In-memory stand-in that answers ``query`` from two lists of row dicts.

    The SQL text is only inspected for ``from entities`` / ``from relations``;
    the first bound parameter, stripped of ``%``, is used as a case-insensitive
    substring filter.
    """

    # Columns searched for the keyword, per table.
    _ENTITY_COLUMNS = ("name", "hash")
    _RELATION_COLUMNS = ("subject", "object", "predicate", "hash")

    def __init__(self, *, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> None:
        self._entities = entities
        self._relations = relations

    @staticmethod
    def _select(
        records: list[dict[str, Any]],
        hidden_flag: str,
        keyword: str,
        columns: tuple[str, ...],
    ) -> list[dict[str, Any]]:
        """Return copies of the rows not flagged by *hidden_flag* that match *keyword*."""
        visible = [dict(record) for record in records if not bool(record.get(hidden_flag, 0))]
        if not keyword:
            return visible
        return [
            row
            for row in visible
            if any(keyword in str(row.get(column, "") or "").lower() for column in columns)
        ]

    def query(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
        """Answer an entities or relations query; any other SQL raises ``AssertionError``."""
        # Collapse whitespace and case so the table check ignores SQL formatting.
        normalized_sql = " ".join(str(sql or "").lower().split())
        keyword = ""
        if params:
            keyword = str(params[0] or "").strip("%").lower()
        if "from entities" in normalized_sql:
            return self._select(self._entities, "is_deleted", keyword, self._ENTITY_COLUMNS)
        if "from relations" in normalized_sql:
            return self._select(self._relations, "is_inactive", keyword, self._RELATION_COLUMNS)
        raise AssertionError(f"unexpected query: {normalized_sql}")
def _build_kernel(*, entities: list[dict[str, Any]], relations: list[dict[str, Any]]) -> SDKMemoryKernel:
    """Create an ``SDKMemoryKernel`` whose stores are replaced with in-memory fakes.

    ``initialize`` is swapped for a no-op coroutine, ``metadata_store`` for a
    ``_DummyMetadataStore`` over the given rows, and ``graph_store`` for a bare
    placeholder object.
    """

    async def _fake_initialize() -> None:
        return None

    fake_store = _DummyMetadataStore(entities=entities, relations=relations)
    kernel = SDKMemoryKernel(plugin_root=Path.cwd(), config={})
    kernel.initialize = _fake_initialize  # type: ignore[method-assign]
    kernel.metadata_store = fake_store
    kernel.graph_store = object()  # type: ignore[assignment]
    return kernel
@pytest.mark.asyncio
async def test_memory_graph_admin_search_orders_and_dedupes_results() -> None:
    """Search results are deduplicated and returned in the expected hash order per item type."""
    entity_rows = [
        {"hash": "e1", "name": "Alice", "appearance_count": 5, "is_deleted": 0},
        {"hash": "e1", "name": "Alice Duplicate", "appearance_count": 99, "is_deleted": 0},
        {"hash": "e2", "name": "Alice Cooper", "appearance_count": 7, "is_deleted": 0},
        {"hash": "e3", "name": "my alice note", "appearance_count": 11, "is_deleted": 0},
        {"hash": "e4", "name": "alice deleted", "appearance_count": 100, "is_deleted": 1},
    ]
    relation_rows = [
        {"hash": "r1", "subject": "Alice", "predicate": "knows", "object": "Bob", "confidence": 0.6, "created_at": 100, "is_inactive": 0},
        {"hash": "r3", "subject": "Alice", "predicate": "supports", "object": "Carol", "confidence": 0.9, "created_at": 90, "is_inactive": 0},
        {"hash": "r1", "subject": "Alice", "predicate": "knows duplicate", "object": "Bob", "confidence": 0.99, "created_at": 200, "is_inactive": 0},
        {"hash": "r2", "subject": "Alice Cooper", "predicate": "likes", "object": "Tea", "confidence": 0.2, "created_at": 50, "is_inactive": 0},
        {"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.8, "created_at": 70, "is_inactive": 0},
        {"hash": "", "subject": "Carol", "predicate": "mentions alice", "object": "Topic", "confidence": 0.3, "created_at": 10, "is_inactive": 0},
        {"hash": "r4", "subject": "alice inactive", "predicate": "old", "object": "Data", "confidence": 1.0, "created_at": 300, "is_inactive": 1},
    ]
    kernel = _build_kernel(entities=entity_rows, relations=relation_rows)

    payload = await kernel.memory_graph_admin(action="search", query="alice", limit=20)

    assert payload["success"] is True
    assert payload["count"] == len(payload["items"])

    # Partition the mixed result list by item type, keeping the returned order.
    found_entities = []
    found_relations = []
    for item in payload["items"]:
        item_type = item["type"]
        if item_type == "entity":
            found_entities.append(item)
        if item_type == "relation":
            found_relations.append(item)

    assert [entity["entity_hash"] for entity in found_entities] == ["e1", "e2", "e3"]
    assert [relation["relation_hash"] for relation in found_relations] == ["r3", "r1", "r2", ""]
    assert found_relations[0]["confidence"] == pytest.approx(0.9)
    assert found_relations[1]["confidence"] == pytest.approx(0.6)
@pytest.mark.asyncio
async def test_memory_graph_admin_search_filters_deleted_and_inactive_records() -> None:
    """A deleted entity and an inactive relation yield an empty search result."""
    deleted_entity = {"hash": "e-deleted", "name": "Ghost Alice", "appearance_count": 10, "is_deleted": 1}
    inactive_relation = {
        "hash": "r-inactive",
        "subject": "Ghost Alice",
        "predicate": "linked",
        "object": "Ghost Bob",
        "confidence": 0.9,
        "created_at": 10,
        "is_inactive": 1,
    }
    kernel = _build_kernel(entities=[deleted_entity], relations=[inactive_relation])

    payload = await kernel.memory_graph_admin(action="search", query="ghost", limit=50)

    assert payload["success"] is True
    assert payload["items"] == []
    assert payload["count"] == 0

View File

@@ -0,0 +1,281 @@
import pytest
from src.services.memory_service import MemorySearchResult, MemoryService
def test_coerce_write_result_treats_skipped_payload_as_success():
    """A payload carrying only skipped ids and a detail is coerced to a successful result."""
    raw_payload = {"skipped_ids": ["p1"], "detail": "chat_filtered"}
    coerced = MemoryService._coerce_write_result(raw_payload)
    assert coerced.success is True
    assert coerced.stored_ids == []
    assert coerced.skipped_ids == ["p1"]
    assert coerced.detail == "chat_filtered"
@pytest.mark.asyncio
async def test_graph_admin_invokes_plugin(monkeypatch):
    """``graph_admin`` forwards its arguments to ``memory_graph_admin`` with a 30 s timeout."""
    invocations = []

    async def fake_invoke(component_name, args=None, **kwargs):
        invocations.append((component_name, args, kwargs))
        return {"success": True, "nodes": [], "edges": []}

    service = MemoryService()
    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.graph_admin(action="get_graph", limit=12)

    assert result["success"] is True
    expected_invocation = ("memory_graph_admin", {"action": "get_graph", "limit": 12}, {"timeout_ms": 30000})
    assert invocations == [expected_invocation]
@pytest.mark.asyncio
async def test_get_recycle_bin_uses_maintain_memory_tool(monkeypatch):
    """``get_recycle_bin`` calls ``maintain_memory`` with the ``recycle_bin`` action and returns its payload."""
    invocations = []
    plugin_reply = {"success": True, "items": [{"hash": "abc"}], "count": 1}

    async def fake_invoke(component_name, args=None, **kwargs):
        invocations.append((component_name, args))
        return {"success": True, "items": [{"hash": "abc"}], "count": 1}

    service = MemoryService()
    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.get_recycle_bin(limit=5)

    assert result == plugin_reply
    assert invocations == [("maintain_memory", {"action": "recycle_bin", "limit": 5})]
@pytest.mark.asyncio
async def test_search_respects_filter_by_default(monkeypatch):
    """Without an explicit flag, ``search`` sends ``respect_filter=True`` to the plugin."""
    invocations = []

    async def fake_invoke(component_name, args=None, **kwargs):
        invocations.append((component_name, args))
        return {"summary": "ok", "hits": [], "filtered": True}

    service = MemoryService()
    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.search(
        "mai",
        chat_id="stream-1",
        person_id="person-1",
        user_id="user-1",
        group_id="",
    )

    assert isinstance(result, MemorySearchResult)
    assert result.filtered is True
    expected_args = {
        "query": "mai",
        "limit": 5,
        "mode": "search",
        "chat_id": "stream-1",
        "person_id": "person-1",
        "time_start": None,
        "time_end": None,
        "respect_filter": True,
        "user_id": "user-1",
        "group_id": "",
    }
    assert invocations == [("search_memory", expected_args)]
@pytest.mark.asyncio
async def test_ingest_summary_can_bypass_filter(monkeypatch):
    """``respect_filter=False`` is passed through and omitted arguments get their defaults."""
    invocations = []

    async def fake_invoke(component_name, args=None, **kwargs):
        invocations.append((component_name, args))
        return {"success": True, "stored_ids": ["p1"], "detail": ""}

    service = MemoryService()
    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.ingest_summary(
        external_id="chat_history:1",
        chat_id="stream-1",
        text="summary",
        respect_filter=False,
        user_id="user-1",
    )

    assert result.success is True
    expected_args = {
        "external_id": "chat_history:1",
        "chat_id": "stream-1",
        "text": "summary",
        "participants": [],
        "time_start": None,
        "time_end": None,
        "tags": [],
        "metadata": {},
        "respect_filter": False,
        "user_id": "user-1",
        "group_id": "",
    }
    assert invocations == [("ingest_summary", expected_args)]
@pytest.mark.asyncio
async def test_v5_admin_invokes_plugin(monkeypatch):
    """``v5_admin`` forwards its arguments to ``memory_v5_admin`` with a 30 s timeout."""
    invocations = []

    async def fake_invoke(component_name, args=None, **kwargs):
        invocations.append((component_name, args, kwargs))
        return {"success": True, "count": 1}

    service = MemoryService()
    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.v5_admin(action="status", target="mai", limit=5)

    assert result["success"] is True
    expected_invocation = (
        "memory_v5_admin",
        {"action": "status", "target": "mai", "limit": 5},
        {"timeout_ms": 30000},
    )
    assert invocations == [expected_invocation]
@pytest.mark.asyncio
async def test_delete_admin_uses_long_timeout(monkeypatch):
    """``delete_admin`` invokes ``memory_delete_admin`` with a 120 s timeout."""
    invocations = []

    async def fake_invoke(component_name, args=None, **kwargs):
        invocations.append((component_name, args, kwargs))
        return {"success": True, "operation_id": "del-1"}

    service = MemoryService()
    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.delete_admin(action="execute", mode="relation", selector={"query": "mai"})

    assert result["success"] is True
    expected_invocation = (
        "memory_delete_admin",
        {"action": "execute", "mode": "relation", "selector": {"query": "mai"}},
        {"timeout_ms": 120000},
    )
    assert invocations == [expected_invocation]
@pytest.mark.asyncio
async def test_search_returns_empty_when_query_and_time_missing_async():
    """An empty query with no time bounds returns an empty, unfiltered result."""
    empty_result = await MemoryService().search("", time_start=None, time_end=None)
    assert isinstance(empty_result, MemorySearchResult)
    assert empty_result.summary == ""
    assert empty_result.hits == []
    assert empty_result.filtered is False
@pytest.mark.asyncio
async def test_search_accepts_string_time_bounds(monkeypatch):
    """String time bounds are forwarded to the plugin unchanged."""
    invocations = []

    async def fake_invoke(component_name, args=None, **kwargs):
        invocations.append((component_name, args))
        return {"summary": "ok", "hits": [], "filtered": False}

    service = MemoryService()
    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.search(
        "广播站",
        mode="time",
        time_start="2026/03/18",
        time_end="2026/03/18 09:30",
    )

    assert isinstance(result, MemorySearchResult)
    expected_args = {
        "query": "广播站",
        "limit": 5,
        "mode": "time",
        "chat_id": "",
        "person_id": "",
        "time_start": "2026/03/18",
        "time_end": "2026/03/18 09:30",
        "respect_filter": True,
        "user_id": "",
        "group_id": "",
    }
    assert invocations == [("search_memory", expected_args)]
def test_coerce_search_result_preserves_aggregate_source_branches():
    """Top-level ``source_branches`` and ``rank`` of a raw hit end up in the hit's metadata."""
    raw_hit = {
        "content": "广播站值夜班",
        "type": "paragraph",
        "metadata": {"event_time_start": 1.0},
        "source_branches": ["search", "time"],
        "rank": 1,
    }
    coerced = MemoryService._coerce_search_result({"hits": [raw_hit]})
    first_hit_metadata = coerced.hits[0].metadata
    assert first_hit_metadata["source_branches"] == ["search", "time"]
    assert coerced.hits[0].metadata["rank"] == 1
@pytest.mark.asyncio
async def test_import_admin_uses_long_timeout(monkeypatch):
    """``import_admin`` invokes ``memory_import_admin`` with a 120 s timeout."""
    invocations = []

    async def fake_invoke(component_name, args=None, **kwargs):
        invocations.append((component_name, args, kwargs))
        return {"success": True, "task_id": "import-1"}

    service = MemoryService()
    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.import_admin(action="create_lpmm_openie", alias="lpmm")

    assert result["success"] is True
    expected_invocation = (
        "memory_import_admin",
        {"action": "create_lpmm_openie", "alias": "lpmm"},
        {"timeout_ms": 120000},
    )
    assert invocations == [expected_invocation]
@pytest.mark.asyncio
async def test_tuning_admin_uses_long_timeout(monkeypatch):
    """``tuning_admin`` invokes ``memory_tuning_admin`` with a 120 s timeout."""
    invocations = []

    async def fake_invoke(component_name, args=None, **kwargs):
        invocations.append((component_name, args, kwargs))
        return {"success": True, "task_id": "tuning-1"}

    service = MemoryService()
    monkeypatch.setattr(service, "_invoke", fake_invoke)

    result = await service.tuning_admin(action="create_task", payload={"query": "mai"})

    assert result["success"] is True
    expected_invocation = (
        "memory_tuning_admin",
        {"action": "create_task", "payload": {"query": "mai"}},
        {"timeout_ms": 120000},
    )
    assert invocations == [expected_invocation]

View File

@@ -0,0 +1,21 @@
from pathlib import Path
from src.A_memorix.core.storage.metadata_store import MetadataStore
def test_get_all_sources_ignores_soft_deleted_paragraphs(tmp_path: Path) -> None:
    """A source whose only paragraph was marked deleted is not listed by ``get_all_sources``."""
    metadata_store = MetadataStore(data_dir=tmp_path)
    metadata_store.connect()
    try:
        kept_hash = metadata_store.add_paragraph("Alice 喜欢地图", source="live-source")
        removed_hash = metadata_store.add_paragraph("Bob 喜欢咖啡", source="deleted-source")
        assert kept_hash
        metadata_store.mark_as_deleted([removed_hash], "paragraph")
        listed_sources = metadata_store.get_all_sources()
    finally:
        # Close the store even if one of the calls above fails.
        metadata_store.close()

    source_names = [entry["source"] for entry in listed_sources]
    assert source_names == ["live-source"]
    assert listed_sources[0]["count"] == 1

View File

@@ -0,0 +1,81 @@
from types import SimpleNamespace
import pytest
from src.person_info import person_info as person_info_module
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_writes_person_fact(monkeypatch):
    """A known person's answer is ingested once as a ``person_fact`` scoped to the session."""
    ingest_calls = []

    class FakePerson:
        def __init__(self, person_id: str):
            self.person_id = person_id
            self.person_name = "Alice"
            self.is_known = True

    async def fake_ingest_text(**kwargs):
        ingest_calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
    fake_chat_manager = SimpleNamespace(get_session_by_session_id=lambda chat_id: session)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
    monkeypatch.setattr(person_info_module, "Person", FakePerson)
    monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)

    await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")

    assert len(ingest_calls) == 1
    payload = ingest_calls[0]
    assert payload["external_id"].startswith("person_fact:person-1:")
    expected_fields = {
        "source_type": "person_fact",
        "chat_id": "session-1",
        "person_ids": ["person-1"],
        "participants": ["Alice"],
        "user_id": "10001",
        "group_id": "",
    }
    for field_name, expected_value in expected_fields.items():
        assert payload[field_name] == expected_value
    assert payload["respect_filter"] is True
    assert payload["metadata"]["person_id"] == "person-1"
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_unknown_person(monkeypatch):
    """Nothing is ingested when the resolved person has ``is_known`` set to False."""
    ingest_calls = []

    class FakePerson:
        def __init__(self, person_id: str):
            self.person_id = person_id
            self.person_name = "Unknown"
            self.is_known = False

    async def fake_ingest_text(**kwargs):
        ingest_calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1")
    fake_chat_manager = SimpleNamespace(get_session_by_session_id=lambda chat_id: session)
    monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager)
    monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1")
    monkeypatch.setattr(person_info_module, "Person", FakePerson)
    monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)

    await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1")

    assert ingest_calls == []
@pytest.mark.asyncio
async def test_store_person_memory_from_answer_skips_empty_content(monkeypatch):
    """Whitespace-only memory content is never passed to ``ingest_text``."""
    ingest_calls = []

    async def fake_ingest_text(**kwargs):
        ingest_calls.append(kwargs)
        return SimpleNamespace(success=True, detail="", stored_ids=["p1"])

    monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text)

    await person_info_module.store_person_memory_from_answer("Alice", " ", "session-1")

    assert ingest_calls == []

View File

@@ -0,0 +1,115 @@
from types import SimpleNamespace
import pytest
from src.A_memorix.core.utils.person_profile_service import PersonProfileService
class FakeMetadataStore:
    """Canned metadata store for the person-profile test.

    Serves one ``person_fact`` paragraph and one ``chat_summary`` paragraph,
    returns empty or ``None`` for every other lookup, and records each snapshot
    passed to ``upsert_person_profile_snapshot``.
    """

    def __init__(self) -> None:
        # Keyword arguments of every upsert call, in call order.
        self.snapshots: list[dict] = []

    @staticmethod
    def _paragraph_row(hash_value: str, content: str, source: str, source_type: str, **extra_columns):
        """Build a fresh paragraph dict so callers never share mutable state."""
        row = {
            "hash": hash_value,
            "content": content,
            "source": source,
            "metadata": {"source_type": source_type},
        }
        row.update(extra_columns)
        return row

    @staticmethod
    def get_latest_person_profile_snapshot(person_id: str):
        del person_id
        return None

    @staticmethod
    def get_relations(**kwargs):
        del kwargs
        return []

    @staticmethod
    def get_paragraphs_by_source(source: str):
        if source != "person_fact:person-1":
            return []
        return [
            FakeMetadataStore._paragraph_row(
                "person-fact-1",
                "测试用户喜欢猫。",
                source,
                "person_fact",
                created_at=2.0,
                updated_at=2.0,
            )
        ]

    @staticmethod
    def get_paragraph(hash_value: str):
        if hash_value == "chat-summary-1":
            return FakeMetadataStore._paragraph_row(
                hash_value,
                "机器人建议测试用户以后叫星灯。",
                "chat_summary:session-1",
                "chat_summary",
                word_count=1,
            )
        if hash_value == "person-fact-1":
            return FakeMetadataStore._paragraph_row(
                hash_value,
                "测试用户喜欢猫。",
                "person_fact:person-1",
                "person_fact",
                word_count=1,
            )
        return None

    @staticmethod
    def get_paragraph_stale_relation_marks_batch(paragraph_hashes):
        del paragraph_hashes
        return {}

    @staticmethod
    def get_relation_status_batch(relation_hashes):
        del relation_hashes
        return {}

    @staticmethod
    def get_person_profile_override(person_id: str):
        del person_id
        return None

    def upsert_person_profile_snapshot(self, **kwargs):
        """Record the call and echo the snapshot back with a fixed ``updated_at``."""
        self.snapshots.append(kwargs)
        echoed_keys = (
            "person_id",
            "profile_text",
            "aliases",
            "relation_edges",
            "vector_evidence",
            "evidence_ids",
        )
        stored = {key: kwargs[key] for key in echoed_keys}
        stored["updated_at"] = 1.0
        stored["expires_at"] = kwargs["expires_at"]
        stored["source_note"] = kwargs["source_note"]
        return stored
class FakeRetriever:
    """Retriever stub that always returns a single chat-summary paragraph hit."""

    async def retrieve(self, query: str, top_k: int):
        """Ignore both arguments and return the one canned hit."""
        del query, top_k
        chat_summary_hit = SimpleNamespace(
            hash_value="chat-summary-1",
            result_type="paragraph",
            score=0.95,
            content="机器人建议测试用户以后叫星灯。",
            metadata={"source_type": "chat_summary"},
        )
        return [chat_summary_hit]
@pytest.mark.asyncio
async def test_person_profile_keeps_chat_summary_as_recent_interaction_not_stable_profile():
    """Chat-summary content appears only after the recent-interaction marker, never before it."""
    recent_marker = "近期相关互动:"
    profile_service = PersonProfileService(metadata_store=FakeMetadataStore(), retriever=FakeRetriever())
    profile_service.get_person_aliases = lambda person_id: (["测试用户"], "测试用户", [])

    payload = await profile_service.query_person_profile(person_id="person-1", top_k=6, force_refresh=True)

    assert payload["success"] is True
    profile_text = payload["profile_text"]
    # Everything before the marker is the stable section of the profile.
    stable_section = profile_text.partition(recent_marker)[0]
    assert "测试用户喜欢猫" in stable_section
    assert "星灯" not in stable_section
    assert recent_marker in profile_text
    assert "星灯" in profile_text

View File

@@ -0,0 +1,184 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
import pytest
from src.memory_system.retrieval_tools import query_long_term_memory as tool_module
from src.memory_system.retrieval_tools import init_all_tools
from src.memory_system.retrieval_tools.query_long_term_memory import (
_resolve_time_expression,
query_long_term_memory,
register_tool,
)
from src.memory_system.retrieval_tools.tool_registry import get_tool_registry
from src.services.memory_service import MemoryHit, MemorySearchResult
def test_resolve_time_expression_supports_relative_and_absolute_inputs():
    """Relative keywords, a bare date and a date-time all resolve to the expected window."""
    now = datetime(2026, 3, 18, 15, 30)
    day_start = datetime(2026, 3, 18, 0, 0)
    day_end = datetime(2026, 3, 18, 23, 59)
    half_past_nine = datetime(2026, 3, 18, 9, 30)

    # (expression, expected start, expected end, expected start text, expected end text)
    cases = [
        ("今天", day_start, day_end, "2026/03/18 00:00", "2026/03/18 23:59"),
        ("最近7天", datetime(2026, 3, 12, 0, 0), day_end, "2026/03/12 00:00", "2026/03/18 23:59"),
        ("2026/03/18", day_start, day_end, "2026/03/18 00:00", "2026/03/18 23:59"),
        ("2026/03/18 09:30", half_past_nine, half_past_nine, "2026/03/18 09:30", "2026/03/18 09:30"),
    ]
    for expression, want_start, want_end, want_start_text, want_end_text in cases:
        start_ts, end_ts, start_text, end_text = _resolve_time_expression(expression, now=now)
        assert datetime.fromtimestamp(start_ts) == want_start
        assert datetime.fromtimestamp(end_ts) == want_end
        assert start_text == want_start_text
        assert end_text == want_end_text
def test_register_tool_exposes_mode_and_time_expression():
    """The registered tool declares ``mode`` and ``time_expression`` and makes ``query`` optional."""
    register_tool()
    registered = get_tool_registry().get_tool("search_long_term_memory")
    assert registered is not None

    params_by_name = {}
    for parameter in registered.parameters:
        params_by_name[parameter["name"]] = parameter

    assert "mode" in params_by_name
    assert params_by_name["mode"]["enum"] == ["search", "time", "episode", "aggregate"]
    assert "time_expression" in params_by_name
    assert params_by_name["query"]["required"] is False
def test_init_all_tools_registers_long_term_memory_tool():
    """After ``init_all_tools`` the long-term memory tool can be looked up in the registry."""
    init_all_tools()
    registry = get_tool_registry()
    assert registry.get_tool("search_long_term_memory") is not None
@pytest.mark.asyncio
async def test_query_long_term_memory_search_mode_keeps_search(monkeypatch):
    """With no mode given, the tool searches in ``search`` mode with no time bounds."""
    captured = {}

    async def fake_search(query, **kwargs):
        captured["query"] = query
        captured["kwargs"] = kwargs
        cat_fact = MemoryHit(content="Alice 喜欢猫", score=0.9, hit_type="paragraph")
        return MemorySearchResult(hits=[cat_fact])

    monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))

    rendered = await query_long_term_memory("Alice 喜欢什么", chat_id="stream-1", person_id="person-1")

    assert "Alice 喜欢猫" in rendered
    expected_kwargs = {
        "limit": 5,
        "mode": "search",
        "chat_id": "stream-1",
        "person_id": "person-1",
        "time_start": None,
        "time_end": None,
    }
    assert captured == {"query": "Alice 喜欢什么", "kwargs": expected_kwargs}
@pytest.mark.asyncio
async def test_query_long_term_memory_time_mode_parses_expression(monkeypatch):
    """In ``time`` mode the resolved timestamps are passed to the search as its time bounds."""
    window_start = 1773795600.0
    window_end = 1773881940.0
    captured = {}

    async def fake_search(query, **kwargs):
        captured["query"] = query
        captured["kwargs"] = kwargs
        outage = MemoryHit(
            content="昨天晚上广播站停播了十分钟。",
            score=0.8,
            hit_type="paragraph",
            metadata={"event_time_start": 1773797400.0},
        )
        return MemorySearchResult(hits=[outage])

    monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))
    # Pin the resolver so the expected bounds do not depend on the current date.
    monkeypatch.setattr(
        tool_module,
        "_resolve_time_expression",
        lambda expression, now=None: (1773795600.0, 1773881940.0, "2026/03/17 00:00", "2026/03/17 23:59"),
    )

    rendered = await query_long_term_memory(
        query="广播站",
        mode="time",
        time_expression="昨天",
        chat_id="stream-1",
    )

    assert "指定时间范围" in rendered
    assert "广播站停播" in rendered
    expected_kwargs = {
        "limit": 5,
        "mode": "time",
        "chat_id": "stream-1",
        "person_id": "",
        "time_start": window_start,
        "time_end": window_end,
    }
    assert captured == {"query": "广播站", "kwargs": expected_kwargs}
@pytest.mark.asyncio
async def test_query_long_term_memory_episode_and_aggregate_format_output(monkeypatch):
    """Episode hits render title and participants; aggregate hits render their source branches."""
    episode_hit = MemoryHit(
        content="苏弦在灯塔拆开了那封冬信。",
        title="冬信重见天日",
        hit_type="episode",
        metadata={"participants": ["苏弦"], "keywords": ["冬信", "灯塔"]},
    )
    aggregate_hit = MemoryHit(
        content="唐未在广播站值夜班时带着黑狗墨点。",
        hit_type="paragraph",
        metadata={"source_branches": ["search", "time"]},
    )
    canned_results = {
        "episode": MemorySearchResult(hits=[episode_hit]),
        "aggregate": MemorySearchResult(hits=[aggregate_hit]),
    }

    async def fake_search(query, **kwargs):
        # Serve the canned result that matches the requested mode.
        return canned_results[kwargs["mode"]]

    monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search))

    episode_text = await query_long_term_memory("那封冬信后来怎么样了", mode="episode")
    aggregate_text = await query_long_term_memory("唐未最近有什么线索", mode="aggregate")

    assert "事件《冬信重见天日》" in episode_text
    assert "参与者:苏弦" in episode_text
    assert "[search,time][paragraph]" in aggregate_text
@pytest.mark.asyncio
async def test_query_long_term_memory_invalid_time_expression_returns_retryable_message():
    """An unparseable time expression yields a message with both expected hint fragments."""
    reply = await query_long_term_memory(
        query="广播站",
        mode="time",
        time_expression="明年春分后第三周",
    )

    for expected_fragment in ("无法解析", "最近7天"):
        assert expected_fragment in reply

View File

@@ -0,0 +1,140 @@
import pytest
from src.A_memorix.core.utils.summary_importer import (
SummaryImporter,
_message_timestamp,
_normalize_entity_items,
_normalize_relation_items,
)
from src.config.model_configs import TaskConfig
from src.services import llm_service as llm_api
def _fake_available_models() -> dict[str, TaskConfig]:
    """Return a fixed task-name -> ``TaskConfig`` mapping used to stub ``get_available_models``.

    The mapping has ``memory``, ``utils`` and ``replyer`` entries, in that order,
    each with a single model and the ``random`` selection strategy.
    """
    # (task name, model name, max_tokens, temperature)
    specs = (
        ("memory", "memory-model", 512, 0.4),
        ("utils", "utils-model", 256, 0.5),
        ("replyer", "replyer-model", 128, 0.7),
    )
    return {
        task_name: TaskConfig(
            model_list=[model_name],
            max_tokens=token_cap,
            temperature=temperature,
            selection_strategy="random",
        )
        for task_name, model_name, token_cap, temperature in specs
    }
def test_resolve_summary_model_config_uses_auto_list_when_summarization_missing(monkeypatch):
    """With an empty plugin config, resolution picks the ``memory`` task's models."""
    monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models)

    summary_importer = SummaryImporter(
        vector_store=None,
        graph_store=None,
        metadata_store=None,
        embedding_manager=None,
        plugin_config={},
    )
    model_config = summary_importer._resolve_summary_model_config()

    assert model_config is not None
    assert model_config.model_list == ["memory-model"]
def test_resolve_summary_model_config_auto_falls_back_to_utils_then_planner(monkeypatch):
    """Without a ``memory`` task, resolution picks ``utils``; without that too, ``planner``."""
    summary_importer = SummaryImporter(
        vector_store=None,
        graph_store=None,
        metadata_store=None,
        embedding_manager=None,
        plugin_config={},
    )

    def models_without_memory():
        return {
            "utils": TaskConfig(model_list=["utils-model"]),
            "planner": TaskConfig(model_list=["planner-model"]),
            "replyer": TaskConfig(model_list=["replyer-model"]),
        }

    def models_without_memory_or_utils():
        return {
            "planner": TaskConfig(model_list=["planner-model"]),
            "replyer": TaskConfig(model_list=["replyer-model"]),
        }

    # First stage: "utils" is available, so it should be chosen.
    monkeypatch.setattr(llm_api, "get_available_models", models_without_memory)
    first_choice = summary_importer._resolve_summary_model_config()
    assert first_choice is not None
    assert first_choice.model_list == ["utils-model"]

    # Second stage: "utils" is gone as well, so "planner" should be chosen.
    monkeypatch.setattr(llm_api, "get_available_models", models_without_memory_or_utils)
    second_choice = summary_importer._resolve_summary_model_config()
    assert second_choice is not None
    assert second_choice.model_list == ["planner-model"]
def test_resolve_summary_model_config_auto_does_not_fallback_to_replyer(monkeypatch):
    """When only ``replyer`` and ``embedding`` tasks exist, resolution returns ``None``."""

    def only_replyer_and_embedding():
        return {
            "replyer": TaskConfig(model_list=["replyer-model"]),
            "embedding": TaskConfig(model_list=["embedding-model"]),
        }

    monkeypatch.setattr(llm_api, "get_available_models", only_replyer_and_embedding)

    summary_importer = SummaryImporter(
        vector_store=None,
        graph_store=None,
        metadata_store=None,
        embedding_manager=None,
        plugin_config={},
    )
    model_config = summary_importer._resolve_summary_model_config()

    assert model_config is None
def test_resolve_summary_model_config_rejects_legacy_string_selector(monkeypatch):
    """A plain-string ``summarization.model_name`` raises ``ValueError`` mentioning ``List[str]``."""
    monkeypatch.setattr(llm_api, "get_available_models", _fake_available_models)

    legacy_config = {"summarization": {"model_name": "auto"}}
    summary_importer = SummaryImporter(
        vector_store=None,
        graph_store=None,
        metadata_store=None,
        embedding_manager=None,
        plugin_config=legacy_config,
    )

    with pytest.raises(ValueError, match="List\\[str\\]"):
        summary_importer._resolve_summary_model_config()
def test_summary_importer_normalizes_llm_entities_and_relations():
    """Entity and relation normalizers drop malformed items and keep the valid ones."""
    # Strings and {"name": ...} dicts survive; the nested list and the duplicate are dropped.
    mixed_entities = ["Alice", {"name": "地图"}, ["bad"], "Alice"]
    assert _normalize_entity_items(mixed_entities) == ["Alice", "地图"]

    # A bare string (not a list) produces no entities.
    assert _normalize_entity_items("Alice") == []

    valid_relation = {"subject": "Alice", "predicate": "持有", "object": "地图"}
    empty_predicate_relation = {"subject": "Alice", "predicate": "", "object": "地图"}
    mixed_relations = [valid_relation, empty_predicate_relation, ["bad"]]
    normalized_relations = _normalize_relation_items(mixed_relations)
    assert normalized_relations == [{"subject": "Alice", "predicate": "持有", "object": "地图"}]
def test_summary_importer_message_timestamp_accepts_time_fallback():
    """An object exposing only a ``time`` attribute has that value returned as its timestamp."""

    class Message:
        time = 123.5

    message = Message()
    timestamp = _message_timestamp(message)

    assert timestamp == 123.5

View File

@@ -0,0 +1,182 @@
from types import SimpleNamespace
import numpy as np
import pytest
from src.A_memorix.core.strategies.base import ChunkContext, KnowledgeType, ProcessedChunk, SourceInfo
from src.A_memorix.core.utils.web_import_manager import (
ImportChunkRecord,
ImportFileRecord,
ImportTaskManager,
ImportTaskRecord,
)
class _DummyMetadataStore:
    """In-memory test double that records paragraphs, entities and relations it is given."""

    def __init__(self) -> None:
        # Each list grows in call order so tests can inspect exactly what was written.
        self.paragraphs: list[dict[str, object]] = []
        self.entities: list[str] = []
        self.relations: list[tuple[str, str, str]] = []

    def add_paragraph(self, **kwargs):
        """Store a copy of the keyword arguments and return ``paragraph-<count>``."""
        record = {**kwargs}
        self.paragraphs.append(record)
        ordinal = len(self.paragraphs)
        return f"paragraph-{ordinal}"

    def add_entity(self, *, name: str, source_paragraph: str = "") -> str:
        """Record the entity name and return ``entity-<name>``; ``source_paragraph`` is ignored."""
        self.entities.append(name)
        return f"entity-{name}"

    def add_relation(self, *, subject: str, predicate: str, obj: str, **kwargs) -> str:
        """Record the triple and return ``relation-<count>``; extra keyword arguments are ignored."""
        triple = (subject, predicate, obj)
        self.relations.append(triple)
        ordinal = len(self.relations)
        return f"relation-{ordinal}"

    def set_relation_vector_state(self, rel_hash: str, state: str) -> None:
        """Accept a vector-state update and do nothing with it."""
        return None
class _DummyGraphStore:
    """In-memory test double that records every batch of nodes and edges it receives."""

    def __init__(self) -> None:
        # One inner list per add_nodes / add_edges call, in call order.
        self.nodes: list[list[str]] = []
        self.edges: list[list[tuple[str, str]]] = []

    def add_nodes(self, nodes):
        """Append the given nodes as one batch (materialized into a new list)."""
        batch = [*nodes]
        self.nodes.append(batch)

    def add_edges(self, edges, relation_hashes=None):
        """Append the given edges as one batch; ``relation_hashes`` is ignored."""
        batch = [*edges]
        self.edges.append(batch)
class _DummyVectorStore:
    """Test double for a vector store that reports containing nothing and discards writes."""

    def __contains__(self, item: str) -> bool:
        """Report every item as absent."""
        return False

    def add(self, vectors, ids):
        """Accept vectors and ids and discard them."""
        return None
class _DummyEmbeddingManager:
    """Test double whose ``encode`` returns the same vector for any text."""

    async def encode(self, text: str) -> np.ndarray:
        """Return a length-4 ``float32`` vector of ones; ``text`` is ignored."""
        embedding = np.ones(4, dtype=np.float32)
        return embedding
def _build_manager() -> tuple[ImportTaskManager, _DummyMetadataStore]:
    """Create an ``ImportTaskManager`` backed by dummy stores.

    Returns:
        The manager and the metadata store it writes to, so tests can inspect
        what was persisted.
    """
    recording_store = _DummyMetadataStore()

    def get_config(key, default=None):
        # Every config lookup falls back to the caller-supplied default.
        return default

    fake_plugin = SimpleNamespace(
        metadata_store=recording_store,
        graph_store=_DummyGraphStore(),
        vector_store=_DummyVectorStore(),
        embedding_manager=_DummyEmbeddingManager(),
        relation_write_service=None,
        get_config=get_config,
        _is_embedding_degraded=lambda: False,
        _allow_metadata_only_write=lambda: True,
        write_paragraph_vector_or_enqueue=None,
    )
    return ImportTaskManager(fake_plugin), recording_store
def _build_progress_task(task_id: str, total_chunks: int = 2) -> ImportTaskRecord:
    """Build a paste-sourced task with one file (``file-1``) holding ``total_chunks`` text chunks.

    Chunks are named ``chunk-0`` .. ``chunk-<total_chunks - 1>``.
    """
    chunk_records = [
        ImportChunkRecord(chunk_id=f"chunk-{position}", index=position, chunk_type="text")
        for position in range(total_chunks)
    ]
    demo_file = ImportFileRecord(
        file_id="file-1",
        name="demo.txt",
        source_kind="paste",
        input_mode="text",
        total_chunks=total_chunks,
        chunks=chunk_records,
    )
    return ImportTaskRecord(task_id=task_id, source="paste", params={}, files=[demo_file])
def _build_chunk(data) -> ProcessedChunk:
    """Wrap ``data`` in a factual ``ProcessedChunk`` with fixed source and chunk context."""
    source_info = SourceInfo(file="demo.txt", offset_start=0, offset_end=4)
    chunk_context = ChunkContext(chunk_id="chunk-1", index=0, text="Alice 持有地图")
    return ProcessedChunk(
        type=KnowledgeType.FACTUAL,
        source=source_info,
        chunk=chunk_context,
        data=data,
    )
@pytest.mark.asyncio
async def test_persist_processed_chunk_rejects_non_object_before_paragraph_write() -> None:
    """A list payload raises ``ValueError`` and no paragraph is written."""
    manager, recording_store = _build_manager()
    paste_file = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt")
    non_object_chunk = _build_chunk(["bad"])

    with pytest.raises(ValueError, match="分块抽取结果 必须返回 JSON 对象"):
        await manager._persist_processed_chunk(paste_file, non_object_chunk)

    assert recording_store.paragraphs == []
@pytest.mark.asyncio
async def test_chunk_terminal_progress_uses_successful_chunks_only() -> None:
    """Progress counts only completed chunks, whichever of fail/complete happens first.

    Two two-chunk tasks are run on the same manager: one fails chunk-0 and then
    completes chunk-1, the other completes chunk-0 and then fails chunk-1. Both
    must end at 50% progress with one done and one failed chunk.
    """
    manager, _ = _build_manager()
    scenarios = (
        ("task-fail-then-complete", True),
        ("task-complete-then-fail", False),
    )

    for task_id, fail_first in scenarios:
        task = _build_progress_task(task_id)
        manager._tasks[task.task_id] = task

        if fail_first:
            await manager._set_chunk_failed(task.task_id, "file-1", "chunk-0", "boom")
            await manager._set_chunk_completed(task.task_id, "file-1", "chunk-1")
        else:
            await manager._set_chunk_completed(task.task_id, "file-1", "chunk-0")
            await manager._set_chunk_failed(task.task_id, "file-1", "chunk-1", "boom")

        tracked_file = task.files[0]
        assert tracked_file.done_chunks == 1
        assert tracked_file.failed_chunks == 1
        assert tracked_file.progress == pytest.approx(0.5)
        assert task.progress == pytest.approx(0.5)
@pytest.mark.asyncio
async def test_cancelled_chunks_do_not_increase_file_progress() -> None:
    """With one of three chunks completed and one cancelled, progress stays at 1/3."""
    manager, _ = _build_manager()
    task = _build_progress_task("task-cancelled-progress", total_chunks=3)
    task_id = task.task_id
    manager._tasks[task_id] = task

    await manager._set_chunk_completed(task_id, "file-1", "chunk-0")
    await manager._set_chunk_cancelled(task_id, "file-1", "chunk-1", "任务已取消")

    tracked_file = task.files[0]
    one_third = pytest.approx(1 / 3)
    assert tracked_file.done_chunks == 1
    assert tracked_file.cancelled_chunks == 1
    assert tracked_file.progress == one_third
    assert task.progress == one_third
@pytest.mark.asyncio
async def test_persist_processed_chunk_skips_invalid_nested_items() -> None:
    """Malformed triples, relations and entities are skipped while valid ones persist.

    The payload mixes a valid triple with a nested list, a relation with an empty
    predicate, and entities given as a string, a ``{"name": ...}`` dict and a list.
    """
    manager, recording_store = _build_manager()
    paste_file = SimpleNamespace(source_path="", source_kind="paste", name="demo.txt")
    extraction = {
        "triples": [{"subject": "Alice", "predicate": "持有", "object": "地图"}, ["bad"]],
        "relations": [{"subject": "Alice", "predicate": "", "object": "地图"}],
        "entities": ["Alice", {"name": "地图"}, ["bad"]],
    }

    await manager._persist_processed_chunk(paste_file, _build_chunk(extraction))

    assert len(recording_store.paragraphs) == 1
    assert set(recording_store.entities) >= {"Alice", "地图"}
    assert recording_store.relations == [("Alice", "持有", "地图")]