fix:完善Maisaka记忆写回链路

补齐聊天摘要自动写回、发送后同步与图存储清理逻辑,对齐 visual 新配置字段并补充相关回归测试,同时忽略 algorithm_redesign 设计目录。
This commit is contained in:
A-Dawn
2026-04-16 19:04:08 +08:00
parent 7ed5630583
commit 459927e7c0
13 changed files with 918 additions and 65 deletions

View File

@@ -0,0 +1,404 @@
from __future__ import annotations
from datetime import datetime, timedelta
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict, List
import asyncio
import inspect
import json
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine
import numpy as np
import pytest
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
from src.A_memorix.core.utils import summary_importer as summary_importer_module
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.database import database as database_module
from src.common.database.migrations import create_database_migration_bootstrapper
from src.common.message_repository import count_messages
from src.config.model_configs import TaskConfig
from src.services import memory_flow_service as memory_flow_service_module
from src.services import memory_service as memory_service_module
from src.services import send_service
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
kernel_module = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
summary_importer_module = None # type: ignore[assignment]
BotChatSession = None # type: ignore[assignment]
SessionMessage = None # type: ignore[assignment]
MessageInfo = None # type: ignore[assignment]
UserInfo = None # type: ignore[assignment]
MessageSequence = None # type: ignore[assignment]
TextComponent = None # type: ignore[assignment]
database_module = None # type: ignore[assignment]
create_database_migration_bootstrapper = None # type: ignore[assignment]
count_messages = None # type: ignore[assignment]
TaskConfig = None # type: ignore[assignment]
memory_flow_service_module = None # type: ignore[assignment]
memory_service_module = None # type: ignore[assignment]
send_service = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
class _FakeEmbeddingManager:
def __init__(self, dimension: int = 8) -> None:
self.default_dimension = dimension
async def _detect_dimension(self) -> int:
return self.default_dimension
async def encode(self, text: Any) -> np.ndarray:
def _encode_one(raw: Any) -> np.ndarray:
content = str(raw or "")
vector = np.zeros(self.default_dimension, dtype=np.float32)
for index, byte in enumerate(content.encode("utf-8")):
vector[index % self.default_dimension] += float((byte % 17) + 1)
norm = float(np.linalg.norm(vector))
if norm > 0:
vector /= norm
return vector
if isinstance(text, (list, tuple)):
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
return _encode_one(text).astype(np.float32)
class _KernelBackedRuntimeManager:
def __init__(self, kernel: SDKMemoryKernel) -> None:
self.kernel = kernel
async def invoke(
self,
component_name: str,
args: Dict[str, Any] | None,
*,
timeout_ms: int = 30000,
) -> Any:
del timeout_ms
payload = args or {}
handler = getattr(self.kernel, component_name)
result = handler(**payload)
return await result if inspect.isawaitable(result) else result
class _NoopRuntimeManager:
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
del hook_name
return SimpleNamespace(aborted=False, kwargs=kwargs)
class _FakePlatformIOManager:
def __init__(self) -> None:
self.ensure_calls = 0
async def ensure_send_pipeline_ready(self) -> None:
self.ensure_calls += 1
def build_route_key_from_message(self, message: Any) -> Any:
del message
return SimpleNamespace(platform="qq")
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
del message, metadata
return SimpleNamespace(
has_success=True,
sent_receipts=[
SimpleNamespace(
driver_id="plugin.qq.sender",
external_message_id="real-message-id",
metadata={},
)
],
failed_receipts=[],
route_key=route_key,
)
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Point the global database module at a throwaway SQLite file under *tmp_path*."""
    db_dir = (tmp_path / "main_db").resolve()
    db_dir.mkdir(parents=True, exist_ok=True)
    db_file = db_dir / "MaiBot.db"
    database_url = f"sqlite:///{db_file}"
    # Release any connection pool held by the previously configured engine.
    try:
        database_module.engine.dispose()
    except Exception:
        pass
    engine = create_engine(
        database_url,
        echo=False,
        connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )
    session_local = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=engine,
        class_=Session,
    )
    bootstrapper = create_database_migration_bootstrapper(engine)
    # Swap every module-level handle so production code sees the temp database.
    for attribute, value in (
        ("_DB_DIR", db_dir),
        ("_DB_FILE", db_file),
        ("DATABASE_URL", database_url),
        ("engine", engine),
        ("SessionLocal", session_local),
        ("_migration_bootstrapper", bootstrapper),
        ("_db_initialized", False),
    ):
        monkeypatch.setattr(database_module, attribute, value, raising=False)
def _build_incoming_message(
    *,
    session_id: str,
    user_id: str,
    text: str,
    timestamp: datetime | None = None,
) -> SessionMessage:
    """Construct a fully-initialized inbound private QQ message for fixtures."""
    message = SessionMessage(
        message_id="incoming-message-id",
        timestamp=timestamp or datetime.now(),
        platform="qq",
    )
    message.message_info = MessageInfo(
        user_info=UserInfo(
            user_id=user_id,
            user_nickname="测试用户",
            user_cardname="测试用户",
        ),
        additional_config={},
    )
    message.raw_message = MessageSequence(components=[TextComponent(text=text)])
    message.session_id = session_id
    message.reply_to = None
    # Every classification flag defaults to "plain text message".
    for flag in (
        "is_mentioned",
        "is_at",
        "is_emoji",
        "is_picture",
        "is_command",
        "is_notify",
    ):
        setattr(message, flag, False)
    message.processed_plain_text = text
    message.display_message = text
    message.initialized = True
    return message
async def _wait_until(
predicate: Callable[[], Any],
*,
timeout_seconds: float = 10.0,
interval_seconds: float = 0.05,
description: str,
) -> Any:
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
while asyncio.get_running_loop().time() < deadline:
value = predicate()
if inspect.isawaitable(value):
value = await value
if value:
return value
await asyncio.sleep(interval_seconds)
raise AssertionError(f"等待超时: {description}")
@pytest.mark.asyncio
async def test_text_to_stream_triggers_real_chat_summary_writeback(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """End-to-end check: sending a reply via send_service triggers the chat
    summary writeback pipeline, which lands a summary paragraph in the
    A_memorix metadata store under source ``chat_summary:<session_id>``.
    """
    _install_temp_main_database(monkeypatch, tmp_path)
    fake_embedding_manager = _FakeEmbeddingManager()
    captured_prompts: List[str] = []
    # Fixed wall clock so stored timestamps and ordering are deterministic.
    fixed_send_timestamp = 1_777_000_000.0

    async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]:
        # Report a healthy embedding runtime whose dimensions all agree.
        del kwargs
        return {
            "ok": True,
            "message": "ok",
            "configured_dimension": fake_embedding_manager.default_dimension,
            "requested_dimension": fake_embedding_manager.default_dimension,
            "vector_store_dimension": fake_embedding_manager.default_dimension,
            "detected_dimension": fake_embedding_manager.default_dimension,
            "encoded_dimension": fake_embedding_manager.default_dimension,
            "elapsed_ms": 0.0,
            "sample_text": "test",
            "checked_at": datetime.now().timestamp(),
        }

    async def _fake_generate(request: Any) -> Any:
        # Capture the summarization prompt and answer with a canned JSON summary.
        captured_prompts.append(str(getattr(request, "prompt", "") or ""))
        return SimpleNamespace(
            success=True,
            completion=SimpleNamespace(
                response=json.dumps(
                    {
                        "summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。",
                        "entities": ["绿色围巾"],
                        "relations": [],
                    },
                    ensure_ascii=False,
                )
            ),
        )

    # Replace embedding plumbing in both the kernel and the summary importer.
    monkeypatch.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: fake_embedding_manager,
    )
    monkeypatch.setattr(
        kernel_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    monkeypatch.setattr(
        summary_importer_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    # Route the importer's LLM calls to the canned generator above.
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "get_available_models",
        lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])},
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "resolve_task_name_from_model_config",
        lambda model_config: "utils",
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "generate",
        _fake_generate,
    )
    # Freeze time.time() in both services for deterministic message ordering.
    monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp)
    monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp)
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
            "advanced": {"enable_auto_save": False},
            "embedding": {"dimension": fake_embedding_manager.default_dimension},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
            "summarization": {"model_name": ["utils"]},
        },
    )
    service = memory_flow_service_module.MemoryAutomationService()
    fake_platform_io_manager = _FakePlatformIOManager()

    async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]:
        # Episode rebuilding is not under test; return an empty result.
        return {
            "rebuilt": 0,
            "items": [],
            "failures": [],
            "sources": list(sources),
        }

    monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources)
    # Back the memory service directly with the kernel (no plugin runtime).
    monkeypatch.setattr(
        memory_service_module,
        "a_memorix_host_service",
        _KernelBackedRuntimeManager(kernel),
    )
    monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service)
    monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager())
    monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager)
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    # Only the session under test resolves; every other stream id yields None.
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: (
            BotChatSession(
                session_id="test-session",
                platform="qq",
                user_id="target-user",
                group_id=None,
            )
            if stream_id == "test-session"
            else None
        ),
    )
    # Enable summary writeback with threshold 2 so one exchange triggers it.
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "chat_summary_writeback_enabled",
        True,
        raising=False,
    )
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "chat_summary_writeback_message_threshold",
        2,
        raising=False,
    )
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "chat_summary_writeback_context_length",
        10,
        raising=False,
    )
    # Person-fact writeback is out of scope for this test.
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "person_fact_writeback_enabled",
        False,
        raising=False,
    )
    await kernel.initialize()
    try:
        # Seed one inbound message so the sent reply becomes the second message.
        incoming_message = _build_incoming_message(
            session_id="test-session",
            user_id="target-user",
            text="我最近买了一条绿色围巾。",
            timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1),
        )
        with database_module.get_db_session() as session:
            session.add(incoming_message.to_db_instance())
        sent_message = await send_service.text_to_stream_with_message(
            text="好的,我会记住你最近买了绿色围巾。",
            stream_id="test-session",
            storage_message=True,
        )
        assert sent_message is not None
        assert sent_message.message_id == "real-message-id"
        assert fake_platform_io_manager.ensure_calls == 1
        assert count_messages(session_id="test-session") == 2
        # The writeback runs asynchronously; wait for the paragraph to appear.
        paragraphs = await _wait_until(
            lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"),
            description="等待聊天摘要写回到 A_memorix",
        )
        # Both sides of the exchange must have been fed to the summarizer.
        assert captured_prompts
        assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
        assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
        assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
        assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
    finally:
        await service.shutdown()
        await kernel.shutdown()
        try:
            database_module.engine.dispose()
        except Exception:
            pass

View File

@@ -455,10 +455,15 @@ async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
async def _fake_timing_gate(self, anchor_message: Any):
del self, anchor_message
return "continue", _build_chat_response("直接进入 planner。", []), []
return "continue", _build_chat_response("直接进入 planner。", []), [], []
async def _fake_planner(self, *, tool_definitions: list[dict[str, Any]] | None = None) -> ChatResponse:
del tool_definitions
async def _fake_planner(
self,
*,
injected_user_messages: list[str] | None = None,
tool_definitions: list[dict[str, Any]] | None = None,
) -> ChatResponse:
del injected_user_messages, tool_definitions
latest_message = self._runtime.message_cache[-1]
latest_text = str(latest_message.processed_plain_text or "")
planner_calls.append(latest_text)

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
import pickle
from pathlib import Path
import pytest
try:
from src.A_memorix.core.storage.graph_store import GraphStore
except SystemExit as exc:
GraphStore = None # type: ignore[assignment]
IMPORT_ERROR = f"config initialization exited during import: {exc}"
else:
IMPORT_ERROR = None
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
def _build_empty_graph_metadata() -> dict:
return {
"nodes": [],
"node_to_idx": {},
"node_attrs": {},
"matrix_format": "csr",
"total_nodes_added": 0,
"total_edges_added": 0,
"total_nodes_deleted": 0,
"total_edges_deleted": 0,
"edge_hash_map": {},
}
def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None:
    """After clear() + save(), the persisted adjacency matrix must be deleted."""
    data_dir = tmp_path / "graph_data"
    store = GraphStore(data_dir=data_dir)
    store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    store.save()
    adjacency_file = data_dir / "graph_adjacency.npz"
    # Saving a populated store must have materialized the matrix on disk.
    assert adjacency_file.exists()
    store.clear()
    store.save()
    # Saving the now-empty store must remove the stale matrix file.
    assert not adjacency_file.exists()
def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None:
    """Loading empty metadata must discard any previously saved adjacency state."""
    data_dir = tmp_path / "graph_data"
    populated = GraphStore(data_dir=data_dir)
    populated.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    populated.save()
    # Overwrite the persisted metadata with an empty-store payload, leaving
    # the (now stale) adjacency file behind on disk.
    with (data_dir / "graph_metadata.pkl").open("wb") as handle:
        pickle.dump(_build_empty_graph_metadata(), handle)
    reloaded = GraphStore(data_dir=data_dir)
    reloaded.load()
    assert reloaded.num_nodes == 0
    assert reloaded.num_edges == 0
    assert reloaded.get_nodes() == []

View File

@@ -33,34 +33,3 @@ def test_legacy_learning_list_with_numeric_fourth_column_is_migrated():
"enable_jargon_learning": False,
},
]
def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None:
    """Legacy ``visual.multimodal_replyer=True`` becomes ``visual.replyer_mode="multimodal"``."""
    legacy_config = {
        "visual": {
            "multimodal_replyer": True,
        }
    }
    outcome = try_migrate_legacy_bot_config_dict(legacy_config)
    assert outcome.migrated is True
    assert "visual.multimodal_replyer_moved_to_visual.replyer_mode" in outcome.reason
    # The legacy flag is replaced by the new mode field, not kept alongside it.
    assert outcome.data["visual"]["replyer_mode"] == "multimodal"
    assert "multimodal_replyer" not in outcome.data["visual"]
def test_chat_replyer_generator_type_is_migrated_to_replyer_mode() -> None:
    """Legacy ``chat.replyer_generator_type`` is relocated into ``visual.replyer_mode``."""
    legacy_config = {
        "chat": {
            "replyer_generator_type": "legacy",
        },
        "visual": {},
    }
    outcome = try_migrate_legacy_bot_config_dict(legacy_config)
    assert outcome.migrated is True
    assert "chat.replyer_generator_type_moved_to_visual.replyer_mode" in outcome.reason
    # The old chat-level value maps to text mode and is removed from "chat".
    assert outcome.data["visual"]["replyer_mode"] == "text"
    assert "replyer_generator_type" not in outcome.data["chat"]

View File

@@ -38,6 +38,78 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
assert person.person_id == "qq:123"
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch):
    """Once the stored-message count reaches the threshold, a summary ingest is
    queued with a deterministic external id and the configured context length."""
    recorded: list[tuple[str, object]] = []
    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(
            memory=SimpleNamespace(
                chat_summary_writeback_enabled=True,
                chat_summary_writeback_message_threshold=3,
                chat_summary_writeback_context_length=7,
            )
        ),
    )
    # Report five stored messages, which exceeds the threshold of three.
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)

    async def fake_ingest_summary(**kwargs):
        recorded.append(("ingest_summary", kwargs))
        return SimpleNamespace(success=True, detail="ok")

    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    writeback = memory_flow_module.ChatSummaryWritebackService()
    inbound = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
    await writeback._handle_message(inbound)
    assert len(recorded) == 1
    _, payload = recorded[0]
    # External id encodes the session and the message count at trigger time.
    assert payload["external_id"] == "chat_auto_summary:session-1:5"
    assert payload["chat_id"] == "session-1"
    # Text is empty because the summary is generated from chat history.
    assert payload["text"] == ""
    assert payload["metadata"]["generate_from_chat"] is True
    assert payload["metadata"]["context_length"] == 7
    assert payload["metadata"]["trigger"] == "message_threshold"
    assert payload["user_id"] == "user-1"
    assert payload["group_id"] == "group-1"
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch):
    """Below the message threshold the service must not ingest anything."""
    ingest_invoked = False
    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(
            memory=SimpleNamespace(
                chat_summary_writeback_enabled=True,
                chat_summary_writeback_message_threshold=6,
                chat_summary_writeback_context_length=9,
            )
        ),
    )
    # Only five messages are stored — short of the threshold of six.
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)

    async def fake_ingest_summary(**kwargs):
        nonlocal ingest_invoked
        ingest_invoked = True
        return SimpleNamespace(success=True, detail="ok")

    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    writeback = memory_flow_module.ChatSummaryWritebackService()
    inbound = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
    await writeback._handle_message(inbound)
    assert ingest_invoked is False
@pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates():
events: list[tuple[str, str]] = []
@@ -52,14 +124,67 @@ async def test_memory_automation_service_auto_starts_and_delegates():
async def shutdown(self):
events.append(("shutdown", "fact"))
class FakeChatSummaryWriteback:
async def start(self):
events.append(("start", "summary"))
async def enqueue(self, message):
events.append(("summary", message.session_id))
async def shutdown(self):
events.append(("shutdown", "summary"))
service = memory_flow_module.MemoryAutomationService()
service.fact_writeback = FakeFactWriteback()
service.chat_summary_writeback = FakeChatSummaryWriteback()
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
await service.shutdown()
assert events == [
("start", "fact"),
("start", "summary"),
("sent", "session-1"),
("summary", "session-1"),
("shutdown", "summary"),
("shutdown", "fact"),
]
@pytest.mark.asyncio
async def test_memory_automation_service_on_incoming_message_auto_starts_only():
    """An incoming message auto-starts both writeback workers but enqueues nothing."""
    events: list[tuple[str, str]] = []

    class _FactStub:
        async def start(self):
            events.append(("start", "fact"))

        async def enqueue(self, message):
            events.append(("sent", message.session_id))

        async def shutdown(self):
            events.append(("shutdown", "fact"))

    class _SummaryStub:
        async def start(self):
            events.append(("start", "summary"))

        async def enqueue(self, message):
            events.append(("summary", message.session_id))

        async def shutdown(self):
            events.append(("shutdown", "summary"))

    automation = memory_flow_module.MemoryAutomationService()
    automation.fact_writeback = _FactStub()
    automation.chat_summary_writeback = _SummaryStub()
    await automation.on_incoming_message(SimpleNamespace(session_id="session-1"))
    await automation.shutdown()
    # Neither stub's enqueue ran — only lifecycle events are recorded,
    # and shutdown stops the summary worker before the fact worker.
    assert events == [
        ("start", "fact"),
        ("start", "summary"),
        ("shutdown", "summary"),
        ("shutdown", "fact"),
    ]

View File

@@ -1,5 +1,6 @@
"""发送服务回归测试。"""
import sys
from types import SimpleNamespace
from typing import Any, Dict, List
@@ -182,6 +183,75 @@ async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pyt
assert stored_messages[0].message_id == "real-message-id"
@pytest.mark.asyncio
async def test_text_to_stream_with_message_triggers_memory_and_syncs_maisaka_history(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A successful send must fan out to the memory automation service and,
    when requested, append the sent message to the Maisaka chat history with
    the caller-supplied source kind.
    """
    fake_manager = _FakePlatformIOManager(
        delivery_batch=SimpleNamespace(
            has_success=True,
            sent_receipts=[
                SimpleNamespace(
                    driver_id="plugin.qq.sender",
                    external_message_id="real-message-id",
                    metadata={},
                )
            ],
            failed_receipts=[],
            route_key=SimpleNamespace(platform="qq"),
        )
    )
    stored_messages: List[Any] = []
    memory_events: List[str] = []
    history_events: List[tuple[str, str]] = []

    class FakeMemoryAutomationService:
        # Records which message ids were reported as sent.
        async def on_message_sent(self, message: Any) -> None:
            memory_events.append(str(message.message_id))

    class FakeRuntime:
        # Records history syncs together with their source kind.
        def append_sent_message_to_chat_history(self, message: Any, *, source_kind: str = "guided_reply") -> None:
            history_events.append((str(message.message_id), source_kind))

    monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
    )
    monkeypatch.setattr(
        send_service.MessageUtils,
        "store_message_to_db",
        lambda message: stored_messages.append(message),
    )
    # send_service imports these modules lazily, so stub them via sys.modules.
    monkeypatch.setitem(
        sys.modules,
        "src.services.memory_flow_service",
        SimpleNamespace(memory_automation_service=FakeMemoryAutomationService()),
    )
    monkeypatch.setitem(
        sys.modules,
        "src.chat.heart_flow.heartflow_manager",
        SimpleNamespace(heartflow_manager=SimpleNamespace(heartflow_chat_list={"test-session": FakeRuntime()})),
    )
    sent_message = await send_service.text_to_stream_with_message(
        text="你好",
        stream_id="test-session",
        sync_to_maisaka_history=True,
        maisaka_source_kind="guided_reply",
    )
    assert sent_message is not None
    assert sent_message.message_id == "real-message-id"
    assert fake_manager.ensure_calls == 1
    assert len(stored_messages) == 1
    assert stored_messages[0].message_id == "real-message-id"
    # Memory automation saw the send, and history sync used the requested kind.
    assert memory_events == ["real-message-id"]
    assert history_events == [("real-message-id", "guided_reply")]
@pytest.mark.asyncio
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
fake_manager = _FakePlatformIOManager(