From 459927e7c0544b2888cb95b5747a713429c1e87d Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Thu, 16 Apr 2026 19:04:08 +0800 Subject: [PATCH] =?UTF-8?q?fix:=E5=AE=8C=E5=96=84Maisaka=E8=AE=B0=E5=BF=86?= =?UTF-8?q?=E5=86=99=E5=9B=9E=E9=93=BE=E8=B7=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 补齐聊天摘要自动写回、发送后同步与图存储清理逻辑,对齐 visual 新配置字段并补充相关回归测试,同时忽略 algorithm_redesign 设计目录。 --- .gitignore | 1 + ...test_chat_summary_writeback_integration.py | 404 ++++++++++++++++++ .../test_feedback_correction_chat_flow.py | 11 +- .../test_graph_store_persistence.py | 64 +++ .../test_legacy_config_migration.py | 31 -- .../test_memory_flow_service.py | 125 ++++++ pytests/test_send_service.py | 70 +++ .../core/runtime/sdk_memory_kernel.py | 3 + src/A_memorix/core/storage/graph_store.py | 19 +- src/A_memorix/core/utils/summary_importer.py | 22 +- src/config/legacy_migration.py | 13 +- src/config/official_configs.py | 46 +- src/services/memory_flow_service.py | 174 +++++++- 13 files changed, 918 insertions(+), 65 deletions(-) create mode 100644 pytests/A_memorix_test/test_chat_summary_writeback_integration.py create mode 100644 pytests/A_memorix_test/test_graph_store_persistence.py diff --git a/.gitignore b/.gitignore index c5a687ca..960c78ed 100644 --- a/.gitignore +++ b/.gitignore @@ -371,3 +371,4 @@ packages/ .claude/ .omc/ /.venv312 +/src/A_memorix/algorithm_redesign diff --git a/pytests/A_memorix_test/test_chat_summary_writeback_integration.py b/pytests/A_memorix_test/test_chat_summary_writeback_integration.py new file mode 100644 index 00000000..0feab922 --- /dev/null +++ b/pytests/A_memorix_test/test_chat_summary_writeback_integration.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Callable, Dict, List + +import asyncio +import inspect +import json + +from 
sqlalchemy.orm import sessionmaker +from sqlmodel import Session, create_engine +import numpy as np +import pytest + +IMPORT_ERROR: str | None = None + +try: + from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module + from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel + from src.A_memorix.core.utils import summary_importer as summary_importer_module + from src.chat.message_receive.chat_manager import BotChatSession + from src.chat.message_receive.message import SessionMessage + from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo + from src.common.data_models.message_component_data_model import MessageSequence, TextComponent + from src.common.database import database as database_module + from src.common.database.migrations import create_database_migration_bootstrapper + from src.common.message_repository import count_messages + from src.config.model_configs import TaskConfig + from src.services import memory_flow_service as memory_flow_service_module + from src.services import memory_service as memory_service_module + from src.services import send_service +except SystemExit as exc: + IMPORT_ERROR = f"config initialization exited during import: {exc}" + kernel_module = None # type: ignore[assignment] + SDKMemoryKernel = None # type: ignore[assignment] + summary_importer_module = None # type: ignore[assignment] + BotChatSession = None # type: ignore[assignment] + SessionMessage = None # type: ignore[assignment] + MessageInfo = None # type: ignore[assignment] + UserInfo = None # type: ignore[assignment] + MessageSequence = None # type: ignore[assignment] + TextComponent = None # type: ignore[assignment] + database_module = None # type: ignore[assignment] + create_database_migration_bootstrapper = None # type: ignore[assignment] + count_messages = None # type: ignore[assignment] + TaskConfig = None # type: ignore[assignment] + memory_flow_service_module = None # type: ignore[assignment] + 
memory_service_module = None # type: ignore[assignment] + send_service = None # type: ignore[assignment] + + +pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "") + + +class _FakeEmbeddingManager: + def __init__(self, dimension: int = 8) -> None: + self.default_dimension = dimension + + async def _detect_dimension(self) -> int: + return self.default_dimension + + async def encode(self, text: Any) -> np.ndarray: + def _encode_one(raw: Any) -> np.ndarray: + content = str(raw or "") + vector = np.zeros(self.default_dimension, dtype=np.float32) + for index, byte in enumerate(content.encode("utf-8")): + vector[index % self.default_dimension] += float((byte % 17) + 1) + norm = float(np.linalg.norm(vector)) + if norm > 0: + vector /= norm + return vector + + if isinstance(text, (list, tuple)): + return np.stack([_encode_one(item) for item in text]).astype(np.float32) + return _encode_one(text).astype(np.float32) + + +class _KernelBackedRuntimeManager: + def __init__(self, kernel: SDKMemoryKernel) -> None: + self.kernel = kernel + + async def invoke( + self, + component_name: str, + args: Dict[str, Any] | None, + *, + timeout_ms: int = 30000, + ) -> Any: + del timeout_ms + payload = args or {} + handler = getattr(self.kernel, component_name) + result = handler(**payload) + return await result if inspect.isawaitable(result) else result + + +class _NoopRuntimeManager: + async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any: + del hook_name + return SimpleNamespace(aborted=False, kwargs=kwargs) + + +class _FakePlatformIOManager: + def __init__(self) -> None: + self.ensure_calls = 0 + + async def ensure_send_pipeline_ready(self) -> None: + self.ensure_calls += 1 + + def build_route_key_from_message(self, message: Any) -> Any: + del message + return SimpleNamespace(platform="qq") + + async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any: + del message, metadata + return SimpleNamespace( + 
has_success=True, + sent_receipts=[ + SimpleNamespace( + driver_id="plugin.qq.sender", + external_message_id="real-message-id", + metadata={}, + ) + ], + failed_receipts=[], + route_key=route_key, + ) + + +def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + db_dir = (tmp_path / "main_db").resolve() + db_dir.mkdir(parents=True, exist_ok=True) + db_file = db_dir / "MaiBot.db" + database_url = f"sqlite:///{db_file}" + + try: + database_module.engine.dispose() + except Exception: + pass + + engine = create_engine( + database_url, + echo=False, + connect_args={"check_same_thread": False}, + pool_pre_ping=True, + ) + session_local = sessionmaker( + autocommit=False, + autoflush=False, + bind=engine, + class_=Session, + ) + bootstrapper = create_database_migration_bootstrapper(engine) + + monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False) + monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False) + monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False) + monkeypatch.setattr(database_module, "engine", engine, raising=False) + monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False) + monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False) + monkeypatch.setattr(database_module, "_db_initialized", False, raising=False) + + +def _build_incoming_message( + *, + session_id: str, + user_id: str, + text: str, + timestamp: datetime | None = None, +) -> SessionMessage: + message = SessionMessage( + message_id="incoming-message-id", + timestamp=timestamp or datetime.now(), + platform="qq", + ) + message.message_info = MessageInfo( + user_info=UserInfo( + user_id=user_id, + user_nickname="测试用户", + user_cardname="测试用户", + ), + additional_config={}, + ) + message.raw_message = MessageSequence(components=[TextComponent(text=text)]) + message.session_id = session_id + message.reply_to = None + message.is_mentioned = False + 
message.is_at = False + message.is_emoji = False + message.is_picture = False + message.is_command = False + message.is_notify = False + message.processed_plain_text = text + message.display_message = text + message.initialized = True + return message + + +async def _wait_until( + predicate: Callable[[], Any], + *, + timeout_seconds: float = 10.0, + interval_seconds: float = 0.05, + description: str, +) -> Any: + deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds)) + while asyncio.get_running_loop().time() < deadline: + value = predicate() + if inspect.isawaitable(value): + value = await value + if value: + return value + await asyncio.sleep(interval_seconds) + raise AssertionError(f"等待超时: {description}") + + +@pytest.mark.asyncio +async def test_text_to_stream_triggers_real_chat_summary_writeback( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + _install_temp_main_database(monkeypatch, tmp_path) + + fake_embedding_manager = _FakeEmbeddingManager() + captured_prompts: List[str] = [] + fixed_send_timestamp = 1_777_000_000.0 + + async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]: + del kwargs + return { + "ok": True, + "message": "ok", + "configured_dimension": fake_embedding_manager.default_dimension, + "requested_dimension": fake_embedding_manager.default_dimension, + "vector_store_dimension": fake_embedding_manager.default_dimension, + "detected_dimension": fake_embedding_manager.default_dimension, + "encoded_dimension": fake_embedding_manager.default_dimension, + "elapsed_ms": 0.0, + "sample_text": "test", + "checked_at": datetime.now().timestamp(), + } + + async def _fake_generate(request: Any) -> Any: + captured_prompts.append(str(getattr(request, "prompt", "") or "")) + return SimpleNamespace( + success=True, + completion=SimpleNamespace( + response=json.dumps( + { + "summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。", + "entities": ["绿色围巾"], + "relations": [], + }, + ensure_ascii=False, + ) + ), + ) + + 
monkeypatch.setattr( + kernel_module, + "create_embedding_api_adapter", + lambda **kwargs: fake_embedding_manager, + ) + monkeypatch.setattr( + kernel_module, + "run_embedding_runtime_self_check", + _fake_runtime_self_check, + ) + monkeypatch.setattr( + summary_importer_module, + "run_embedding_runtime_self_check", + _fake_runtime_self_check, + ) + monkeypatch.setattr( + summary_importer_module.llm_api, + "get_available_models", + lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])}, + ) + monkeypatch.setattr( + summary_importer_module.llm_api, + "resolve_task_name_from_model_config", + lambda model_config: "utils", + ) + monkeypatch.setattr( + summary_importer_module.llm_api, + "generate", + _fake_generate, + ) + monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp) + monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp) + + kernel = SDKMemoryKernel( + plugin_root=tmp_path / "plugin_root", + config={ + "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())}, + "advanced": {"enable_auto_save": False}, + "embedding": {"dimension": fake_embedding_manager.default_dimension}, + "memory": {"base_decay_interval_hours": 24}, + "person_profile": {"refresh_interval_minutes": 5}, + "summarization": {"model_name": ["utils"]}, + }, + ) + + service = memory_flow_service_module.MemoryAutomationService() + fake_platform_io_manager = _FakePlatformIOManager() + + async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]: + return { + "rebuilt": 0, + "items": [], + "failures": [], + "sources": list(sources), + } + + monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources) + monkeypatch.setattr( + memory_service_module, + "a_memorix_host_service", + _KernelBackedRuntimeManager(kernel), + ) + monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service) + monkeypatch.setattr(send_service, "_get_runtime_manager", 
lambda: _NoopRuntimeManager()) + monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager) + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: ( + BotChatSession( + session_id="test-session", + platform="qq", + user_id="target-user", + group_id=None, + ) + if stream_id == "test-session" + else None + ), + ) + monkeypatch.setattr( + memory_flow_service_module.global_config.memory, + "chat_summary_writeback_enabled", + True, + raising=False, + ) + monkeypatch.setattr( + memory_flow_service_module.global_config.memory, + "chat_summary_writeback_message_threshold", + 2, + raising=False, + ) + monkeypatch.setattr( + memory_flow_service_module.global_config.memory, + "chat_summary_writeback_context_length", + 10, + raising=False, + ) + monkeypatch.setattr( + memory_flow_service_module.global_config.memory, + "person_fact_writeback_enabled", + False, + raising=False, + ) + + await kernel.initialize() + + try: + incoming_message = _build_incoming_message( + session_id="test-session", + user_id="target-user", + text="我最近买了一条绿色围巾。", + timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1), + ) + with database_module.get_db_session() as session: + session.add(incoming_message.to_db_instance()) + + sent_message = await send_service.text_to_stream_with_message( + text="好的,我会记住你最近买了绿色围巾。", + stream_id="test-session", + storage_message=True, + ) + + assert sent_message is not None + assert sent_message.message_id == "real-message-id" + assert fake_platform_io_manager.ensure_calls == 1 + assert count_messages(session_id="test-session") == 2 + + paragraphs = await _wait_until( + lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"), + description="等待聊天摘要写回到 A_memorix", + ) + + assert captured_prompts + assert "我最近买了一条绿色围巾。" in captured_prompts[-1] + assert 
"好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1] + assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs) + assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2 + finally: + await service.shutdown() + await kernel.shutdown() + try: + database_module.engine.dispose() + except Exception: + pass diff --git a/pytests/A_memorix_test/test_feedback_correction_chat_flow.py b/pytests/A_memorix_test/test_feedback_correction_chat_flow.py index 26844eba..3d53c618 100644 --- a/pytests/A_memorix_test/test_feedback_correction_chat_flow.py +++ b/pytests/A_memorix_test/test_feedback_correction_chat_flow.py @@ -455,10 +455,15 @@ async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): async def _fake_timing_gate(self, anchor_message: Any): del self, anchor_message - return "continue", _build_chat_response("直接进入 planner。", []), [] + return "continue", _build_chat_response("直接进入 planner。", []), [], [] - async def _fake_planner(self, *, tool_definitions: list[dict[str, Any]] | None = None) -> ChatResponse: - del tool_definitions + async def _fake_planner( + self, + *, + injected_user_messages: list[str] | None = None, + tool_definitions: list[dict[str, Any]] | None = None, + ) -> ChatResponse: + del injected_user_messages, tool_definitions latest_message = self._runtime.message_cache[-1] latest_text = str(latest_message.processed_plain_text or "") planner_calls.append(latest_text) diff --git a/pytests/A_memorix_test/test_graph_store_persistence.py b/pytests/A_memorix_test/test_graph_store_persistence.py new file mode 100644 index 00000000..1f2bc4a7 --- /dev/null +++ b/pytests/A_memorix_test/test_graph_store_persistence.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import pickle +from pathlib import Path + +import pytest + +try: + from src.A_memorix.core.storage.graph_store import GraphStore +except SystemExit as exc: + GraphStore = None # type: ignore[assignment] + IMPORT_ERROR = f"config 
initialization exited during import: {exc}" +else: + IMPORT_ERROR = None + + +pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "") + + +def _build_empty_graph_metadata() -> dict: + return { + "nodes": [], + "node_to_idx": {}, + "node_attrs": {}, + "matrix_format": "csr", + "total_nodes_added": 0, + "total_edges_added": 0, + "total_nodes_deleted": 0, + "total_edges_deleted": 0, + "edge_hash_map": {}, + } + + +def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None: + data_dir = tmp_path / "graph_data" + store = GraphStore(data_dir=data_dir) + store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"]) + store.save() + + matrix_path = data_dir / "graph_adjacency.npz" + assert matrix_path.exists() + + store.clear() + store.save() + + assert not matrix_path.exists() + + +def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None: + data_dir = tmp_path / "graph_data" + store = GraphStore(data_dir=data_dir) + store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"]) + store.save() + + metadata_path = data_dir / "graph_metadata.pkl" + with metadata_path.open("wb") as handle: + pickle.dump(_build_empty_graph_metadata(), handle) + + reloaded = GraphStore(data_dir=data_dir) + reloaded.load() + + assert reloaded.num_nodes == 0 + assert reloaded.num_edges == 0 + assert reloaded.get_nodes() == [] diff --git a/pytests/A_memorix_test/test_legacy_config_migration.py b/pytests/A_memorix_test/test_legacy_config_migration.py index 15599241..c382e4f3 100644 --- a/pytests/A_memorix_test/test_legacy_config_migration.py +++ b/pytests/A_memorix_test/test_legacy_config_migration.py @@ -33,34 +33,3 @@ def test_legacy_learning_list_with_numeric_fourth_column_is_migrated(): "enable_jargon_learning": False, }, ] - - -def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None: - payload = { - "visual": { - "multimodal_replyer": True, - } - } - - result = 
try_migrate_legacy_bot_config_dict(payload) - - assert result.migrated is True - assert "visual.multimodal_replyer_moved_to_visual.replyer_mode" in result.reason - assert result.data["visual"]["replyer_mode"] == "multimodal" - assert "multimodal_replyer" not in result.data["visual"] - - -def test_chat_replyer_generator_type_is_migrated_to_replyer_mode() -> None: - payload = { - "chat": { - "replyer_generator_type": "legacy", - }, - "visual": {}, - } - - result = try_migrate_legacy_bot_config_dict(payload) - - assert result.migrated is True - assert "chat.replyer_generator_type_moved_to_visual.replyer_mode" in result.reason - assert result.data["visual"]["replyer_mode"] == "text" - assert "replyer_generator_type" not in result.data["chat"] diff --git a/pytests/A_memorix_test/test_memory_flow_service.py b/pytests/A_memorix_test/test_memory_flow_service.py index a11dd621..dbe8809a 100644 --- a/pytests/A_memorix_test/test_memory_flow_service.py +++ b/pytests/A_memorix_test/test_memory_flow_service.py @@ -38,6 +38,78 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch): assert person.person_id == "qq:123" +@pytest.mark.asyncio +async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch): + events: list[tuple[str, object]] = [] + + monkeypatch.setattr( + memory_flow_module, + "global_config", + SimpleNamespace( + memory=SimpleNamespace( + chat_summary_writeback_enabled=True, + chat_summary_writeback_message_threshold=3, + chat_summary_writeback_context_length=7, + ) + ), + ) + monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5) + + async def fake_ingest_summary(**kwargs): + events.append(("ingest_summary", kwargs)) + return SimpleNamespace(success=True, detail="ok") + + monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary) + + service = memory_flow_module.ChatSummaryWritebackService() + message = SimpleNamespace(session_id="session-1", 
session=SimpleNamespace(user_id="user-1", group_id="group-1")) + + await service._handle_message(message) + + assert len(events) == 1 + _, payload = events[0] + assert payload["external_id"] == "chat_auto_summary:session-1:5" + assert payload["chat_id"] == "session-1" + assert payload["text"] == "" + assert payload["metadata"]["generate_from_chat"] is True + assert payload["metadata"]["context_length"] == 7 + assert payload["metadata"]["trigger"] == "message_threshold" + assert payload["user_id"] == "user-1" + assert payload["group_id"] == "group-1" + + +@pytest.mark.asyncio +async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch): + called = False + + monkeypatch.setattr( + memory_flow_module, + "global_config", + SimpleNamespace( + memory=SimpleNamespace( + chat_summary_writeback_enabled=True, + chat_summary_writeback_message_threshold=6, + chat_summary_writeback_context_length=9, + ) + ), + ) + monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5) + + async def fake_ingest_summary(**kwargs): + nonlocal called + called = True + return SimpleNamespace(success=True, detail="ok") + + monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary) + + service = memory_flow_module.ChatSummaryWritebackService() + message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1")) + + await service._handle_message(message) + + assert called is False + + @pytest.mark.asyncio async def test_memory_automation_service_auto_starts_and_delegates(): events: list[tuple[str, str]] = [] @@ -52,14 +124,67 @@ async def test_memory_automation_service_auto_starts_and_delegates(): async def shutdown(self): events.append(("shutdown", "fact")) + class FakeChatSummaryWriteback: + async def start(self): + events.append(("start", "summary")) + + async def enqueue(self, message): + events.append(("summary", message.session_id)) + + async def shutdown(self): 
+ events.append(("shutdown", "summary")) + service = memory_flow_module.MemoryAutomationService() service.fact_writeback = FakeFactWriteback() + service.chat_summary_writeback = FakeChatSummaryWriteback() await service.on_message_sent(SimpleNamespace(session_id="session-1")) await service.shutdown() assert events == [ ("start", "fact"), + ("start", "summary"), ("sent", "session-1"), + ("summary", "session-1"), + ("shutdown", "summary"), + ("shutdown", "fact"), + ] + + +@pytest.mark.asyncio +async def test_memory_automation_service_on_incoming_message_auto_starts_only(): + events: list[tuple[str, str]] = [] + + class FakeFactWriteback: + async def start(self): + events.append(("start", "fact")) + + async def enqueue(self, message): + events.append(("sent", message.session_id)) + + async def shutdown(self): + events.append(("shutdown", "fact")) + + class FakeChatSummaryWriteback: + async def start(self): + events.append(("start", "summary")) + + async def enqueue(self, message): + events.append(("summary", message.session_id)) + + async def shutdown(self): + events.append(("shutdown", "summary")) + + service = memory_flow_module.MemoryAutomationService() + service.fact_writeback = FakeFactWriteback() + service.chat_summary_writeback = FakeChatSummaryWriteback() + + await service.on_incoming_message(SimpleNamespace(session_id="session-1")) + await service.shutdown() + + assert events == [ + ("start", "fact"), + ("start", "summary"), + ("shutdown", "summary"), ("shutdown", "fact"), ] diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py index 5007f364..5af18d1f 100644 --- a/pytests/test_send_service.py +++ b/pytests/test_send_service.py @@ -1,5 +1,6 @@ """发送服务回归测试。""" +import sys from types import SimpleNamespace from typing import Any, Dict, List @@ -182,6 +183,75 @@ async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pyt assert stored_messages[0].message_id == "real-message-id" +@pytest.mark.asyncio +async def 
test_text_to_stream_with_message_triggers_memory_and_syncs_maisaka_history( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake_manager = _FakePlatformIOManager( + delivery_batch=SimpleNamespace( + has_success=True, + sent_receipts=[ + SimpleNamespace( + driver_id="plugin.qq.sender", + external_message_id="real-message-id", + metadata={}, + ) + ], + failed_receipts=[], + route_key=SimpleNamespace(platform="qq"), + ) + ) + stored_messages: List[Any] = [] + memory_events: List[str] = [] + history_events: List[tuple[str, str]] = [] + + class FakeMemoryAutomationService: + async def on_message_sent(self, message: Any) -> None: + memory_events.append(str(message.message_id)) + + class FakeRuntime: + def append_sent_message_to_chat_history(self, message: Any, *, source_kind: str = "guided_reply") -> None: + history_events.append((str(message.message_id), source_kind)) + + monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager) + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq") + monkeypatch.setattr( + send_service._chat_manager, + "get_session_by_session_id", + lambda stream_id: _build_private_stream() if stream_id == "test-session" else None, + ) + monkeypatch.setattr( + send_service.MessageUtils, + "store_message_to_db", + lambda message: stored_messages.append(message), + ) + monkeypatch.setitem( + sys.modules, + "src.services.memory_flow_service", + SimpleNamespace(memory_automation_service=FakeMemoryAutomationService()), + ) + monkeypatch.setitem( + sys.modules, + "src.chat.heart_flow.heartflow_manager", + SimpleNamespace(heartflow_manager=SimpleNamespace(heartflow_chat_list={"test-session": FakeRuntime()})), + ) + + sent_message = await send_service.text_to_stream_with_message( + text="你好", + stream_id="test-session", + sync_to_maisaka_history=True, + maisaka_source_kind="guided_reply", + ) + + assert sent_message is not None + assert sent_message.message_id == "real-message-id" + assert 
fake_manager.ensure_calls == 1 + assert len(stored_messages) == 1 + assert stored_messages[0].message_id == "real-message-id" + assert memory_events == ["real-message-id"] + assert history_events == [("real-message-id", "guided_reply")] + + @pytest.mark.asyncio async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None: fake_manager = _FakePlatformIOManager( diff --git a/src/A_memorix/core/runtime/sdk_memory_kernel.py b/src/A_memorix/core/runtime/sdk_memory_kernel.py index c006838a..19c04df0 100644 --- a/src/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/src/A_memorix/core/runtime/sdk_memory_kernel.py @@ -802,6 +802,7 @@ class SDKMemoryKernel: chat_id: str, context_length: Optional[int] = None, include_personality: Optional[bool] = None, + time_end: Optional[float] = None, ) -> Dict[str, Any]: await self.initialize() assert self.summary_importer @@ -809,6 +810,7 @@ class SDKMemoryKernel: stream_id=str(chat_id or "").strip(), context_length=context_length, include_personality=include_personality, + time_end=time_end, ) if success: await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])]) @@ -851,6 +853,7 @@ class SDKMemoryKernel: chat_id=chat_id, context_length=self._optional_int(summary_meta.get("context_length")), include_personality=summary_meta.get("include_personality"), + time_end=time_end, ) result.setdefault("external_id", external_id) result.setdefault("chat_id", chat_id) diff --git a/src/A_memorix/core/storage/graph_store.py b/src/A_memorix/core/storage/graph_store.py index 88574a12..85a6cddb 100644 --- a/src/A_memorix/core/storage/graph_store.py +++ b/src/A_memorix/core/storage/graph_store.py @@ -1190,11 +1190,14 @@ class GraphStore: data_dir.mkdir(parents=True, exist_ok=True) # 保存邻接矩阵 + matrix_path = data_dir / "graph_adjacency.npz" if self._adjacency is not None: - matrix_path = data_dir / "graph_adjacency.npz" with atomic_write(matrix_path, "wb") as f: save_npz(f, 
self._adjacency) logger.debug(f"保存邻接矩阵: {matrix_path}") + elif matrix_path.exists(): + matrix_path.unlink() + logger.debug(f"删除陈旧邻接矩阵: {matrix_path}") # 保存元数据 metadata = { @@ -1288,9 +1291,20 @@ class GraphStore: if self._adjacency is not None: adj_n = self._adjacency.shape[0] current_n = len(self._nodes) - if current_n > adj_n: + if current_n == 0: + logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。") + self._adjacency = None + elif current_n > adj_n: logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...") self._expand_adjacency_matrix(current_n - adj_n) + elif current_n < adj_n: + logger.warning( + f"检测到过期邻接矩阵: 节点数={current_n}, 矩阵大小={adj_n}. 正在重置邻接矩阵..." + ) + if self.matrix_format == "csc": + self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32) + else: + self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32) self._adjacency_dirty = True logger.info( @@ -1445,4 +1459,3 @@ class GraphStore: self._adjacency_dirty = True logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边") return count - diff --git a/src/A_memorix/core/utils/summary_importer.py b/src/A_memorix/core/utils/summary_importer.py index 68a0052c..94e5ebb2 100644 --- a/src/A_memorix/core/utils/summary_importer.py +++ b/src/A_memorix/core/utils/summary_importer.py @@ -5,12 +5,13 @@ 导入到 A_memorix 的存储组件中。 """ -import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + import json import re +import time import traceback -from typing import List, Dict, Any, Tuple, Optional -from pathlib import Path from src.common.logger import get_logger from src.services import llm_service as llm_api @@ -222,7 +223,8 @@ class SummaryImporter: self, stream_id: str, context_length: Optional[int] = None, - include_personality: Optional[bool] = None + include_personality: Optional[bool] = None, + time_end: Optional[float] = None, ) -> Tuple[bool, str]: """ 从指定的聊天流中提取记录并执行总结导入 @@ -231,6 +233,7 @@ class SummaryImporter: stream_id: 聊天流 ID 
context_length: 总结的历史消息条数 include_personality: 是否包含人设 + time_end: 用于截取聊天记录的时间上界(闭区间) Returns: Tuple[bool, str]: (是否成功, 结果消息) @@ -248,12 +251,13 @@ class SummaryImporter: include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True) # 2. 获取历史消息 - # 获取当前时间之前的消息 - now = time.time() - messages = message_api.get_messages_before_time_in_chat( + query_time_end = time.time() if time_end is None else float(time_end) + messages = message_api.get_messages_by_time_in_chat( chat_id=stream_id, - timestamp=now, - limit=context_length + start_time=0.0, + end_time=query_time_end, + limit=context_length, + limit_mode="latest", ) if not messages: diff --git a/src/config/legacy_migration.py b/src/config/legacy_migration.py index 1fb5c4ac..f62e7eca 100644 --- a/src/config/legacy_migration.py +++ b/src/config/legacy_migration.py @@ -282,6 +282,10 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult: personality = _as_dict(data.get("personality")) visual = _as_dict(data.get("visual")) + if visual is None and personality is not None and "visual_style" in personality: + visual = {} + data["visual"] = visual + if visual is not None and personality is not None and "visual_style" in personality: if "visual_style" not in visual: visual["visual_style"] = personality["visual_style"] @@ -289,15 +293,6 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult: migrated_any = True reasons.append("personality.visual_style_moved_to_visual.visual_style") - if visual is not None and "multimodal_planner" in visual and "planner_mode" not in visual: - multimodal_planner = visual.pop("multimodal_planner") - if isinstance(multimodal_planner, bool): - visual["planner_mode"] = "multimodal" if multimodal_planner else "text" - migrated_any = True - reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode") - else: - visual["multimodal_planner"] = multimodal_planner - memory = _as_dict(data.get("memory")) if 
memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"): migrated_any = True diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 4ecd8de4..d014b62d 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -145,23 +145,23 @@ class VisualConfig(ConfigBase): __ui_label__ = "视觉" __ui_icon__ = "image" - multimodal_planner: bool = Field( - default=True, + planner_mode: Literal["text", "multimodal", "auto"] = Field( + default="text", json_schema_extra={ - "x-widget": "switch", + "x-widget": "select", "x-icon": "image", }, ) - """是否直接输入图片""" + """Planner 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择""" - multimodal_replyer: bool = Field( - default=False, + replyer_mode: Literal["text", "multimodal", "auto"] = Field( + default="text", json_schema_extra={ - "x-widget": "switch", + "x-widget": "select", "x-icon": "git-branch", }, ) - """是否启用 Maisaka 多模态 replyer 生成器""" + """Replyer 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择""" visual_style: str = Field( default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本", @@ -424,6 +424,36 @@ class MemoryConfig(ConfigBase): ) """是否在发送回复后自动提取并写回人物事实到长期记忆""" + chat_summary_writeback_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "scroll-text", + }, + ) + """是否在 Maisaka 聊天过程中按消息窗口自动写回聊天摘要到长期记忆""" + + chat_summary_writeback_message_threshold: int = Field( + default=12, + ge=1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "messages-square", + }, + ) + """自动写回聊天摘要的消息窗口阈值""" + + chat_summary_writeback_context_length: int = Field( + default=50, + ge=1, + le=500, + json_schema_extra={ + "x-widget": "input", + "x-icon": "rows-3", + }, + ) + """自动写回聊天摘要时,从聊天流中回看的消息条数""" + feedback_correction_enabled: bool = Field( default=False, json_schema_extra={ diff --git a/src/services/memory_flow_service.py b/src/services/memory_flow_service.py index 969b8a23..237f925c 100644 --- 
@dataclass
class ChatSummaryWritebackState:
    """Per-session bookkeeping for automatic chat-summary writeback."""

    # Session message count observed at the last successful writeback.
    last_trigger_message_count: int = 0
    # ``time.time()`` value recorded at the last successful writeback.
    last_trigger_time: float = 0.0


class ChatSummaryWritebackService:
    """Background worker that writes chat summaries into long-term memory.

    Messages are pushed onto a bounded queue by :meth:`enqueue`. A single
    worker task pops them and, once enough new messages have accumulated in
    the session since the last successful writeback, asks ``memory_service``
    to ingest a summary generated from the chat history.
    """

    def __init__(self) -> None:
        self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
        self._worker_task: Optional[asyncio.Task] = None
        self._stopping = False
        # Keyed by session id; kept in memory only.
        self._states: dict[str, ChatSummaryWritebackState] = {}

    async def start(self) -> None:
        """Start the worker task unless one is already running."""
        if self._worker_task is not None and not self._worker_task.done():
            return
        self._stopping = False
        self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_chat_summary_writeback")

    async def shutdown(self) -> None:
        """Cancel the worker task and wait for it to finish.

        Items still sitting in the queue are left untouched.
        """
        self._stopping = True
        worker = self._worker_task
        self._worker_task = None
        if worker is None:
            return
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            logger.warning("关闭聊天摘要写回 worker 失败: %s", exc)

    async def enqueue(self, message: Any) -> None:
        """Queue *message* as a potential writeback trigger.

        The message is dropped silently when the feature is disabled or the
        service is stopping, and dropped with a warning when the queue is full.
        """
        if not bool(getattr(global_config.memory, "chat_summary_writeback_enabled", True)):
            return
        if self._stopping:
            return
        try:
            self._queue.put_nowait(message)
        except asyncio.QueueFull:
            logger.warning("聊天摘要写回队列已满,跳过本次触发")

    async def _worker_loop(self) -> None:
        """Consume queued messages until stopped or cancelled.

        A failure while handling one message is logged and does not stop the
        loop. Cancellation propagates to the awaiting ``shutdown`` call.
        """
        while not self._stopping:
            message = await self._queue.get()
            try:
                await self._handle_message(message)
            except Exception as exc:
                logger.warning("聊天摘要写回处理失败: %s", exc, exc_info=True)
            finally:
                self._queue.task_done()

    async def _handle_message(self, message: Any) -> None:
        """Trigger a summary writeback if the session passed the threshold.

        The per-session state is advanced only after a successful ingest, so
        a failed attempt is retried when the next message for that session is
        processed.
        """
        session_id = self._resolve_session_id(message)
        if not session_id:
            return

        # NOTE(review): count_messages is called synchronously on the event
        # loop; if it performs blocking DB I/O, consider asyncio.to_thread.
        total_message_count = count_messages(session_id=session_id)
        if total_message_count <= 0:
            return

        threshold = self._message_threshold()
        state = self._states.setdefault(session_id, ChatSummaryWritebackState())
        pending_message_count = max(0, total_message_count - state.last_trigger_message_count)
        if pending_message_count < threshold:
            return

        context_length = self._context_length()
        message_time = self._extract_message_timestamp(message)
        result = await memory_service.ingest_summary(
            external_id=f"chat_auto_summary:{session_id}:{total_message_count}",
            chat_id=session_id,
            text="",
            participants=[],
            time_end=message_time,
            metadata={
                "generate_from_chat": True,
                "context_length": context_length,
                "writeback_source": "memory_flow_service",
                "trigger": "message_threshold",
                "trigger_message_count": total_message_count,
            },
            respect_filter=True,
            user_id=self._extract_session_user_id(message),
            group_id=self._extract_session_group_id(message),
        )
        if not getattr(result, "success", False):
            logger.warning(
                "聊天摘要自动写回失败: session_id=%s detail=%s",
                session_id,
                getattr(result, "detail", ""),
            )
            return

        state.last_trigger_message_count = total_message_count
        state.last_trigger_time = time.time()
        logger.info(
            "聊天摘要自动写回成功: session_id=%s trigger=%s total_messages=%s context_length=%s detail=%s",
            session_id,
            "message_threshold",
            total_message_count,
            context_length,
            getattr(result, "detail", ""),
        )

    @staticmethod
    def _resolve_session_id(message: Any) -> str:
        """Return ``message.session_id`` or ``message.session.session_id``, stripped."""
        return str(
            getattr(message, "session_id", "")
            or getattr(getattr(message, "session", None), "session_id", "")
            or ""
        ).strip()

    @staticmethod
    def _extract_session_user_id(message: Any) -> str:
        """Return ``message.session.user_id`` or ``message.user_id``, stripped."""
        return str(
            getattr(getattr(message, "session", None), "user_id", "")
            or getattr(message, "user_id", "")
            or ""
        ).strip()

    @staticmethod
    def _extract_session_group_id(message: Any) -> str:
        """Return ``message.session.group_id`` or ``message.group_id``, stripped."""
        return str(
            getattr(getattr(message, "session", None), "group_id", "")
            or getattr(message, "group_id", "")
            or ""
        ).strip()

    @staticmethod
    def _extract_message_timestamp(message: Any) -> float | None:
        """Return ``message.timestamp`` as a float, or ``None`` if unusable.

        Accepts ``datetime`` objects, any object exposing a callable
        ``timestamp()``, and plain ``int``/``float`` values.
        """
        raw_timestamp = getattr(message, "timestamp", None)
        if isinstance(raw_timestamp, datetime):
            return raw_timestamp.timestamp()
        if hasattr(raw_timestamp, "timestamp") and callable(raw_timestamp.timestamp):
            try:
                return float(raw_timestamp.timestamp())
            except Exception:
                return None
        if isinstance(raw_timestamp, (int, float)):
            return float(raw_timestamp)
        return None

    @staticmethod
    def _message_threshold() -> int:
        """Return the configured message threshold (falls back to 12, minimum 1)."""
        return max(1, int(getattr(global_config.memory, "chat_summary_writeback_message_threshold", 12) or 12))

    @staticmethod
    def _context_length() -> int:
        """Return the configured look-back length (falls back to 50, minimum 1)."""
        return max(1, int(getattr(global_config.memory, "chat_summary_writeback_context_length", 50) or 50))


class MemoryAutomationService:
    """Facade that owns and drives the memory writeback workers."""

    def __init__(self) -> None:
        self.fact_writeback = PersonFactWritebackService()
        self.chat_summary_writeback = ChatSummaryWritebackService()
        self._started = False

    async def start(self) -> None:
        """Start both writeback services; no-op when already started."""
        if self._started:
            return
        await self.fact_writeback.start()
        await self.chat_summary_writeback.start()
        self._started = True

    async def shutdown(self) -> None:
        """Stop both writeback services; no-op when not started.

        The second service is shut down and ``_started`` is reset even if the
        first shutdown raises or is cancelled, so no worker is left running
        behind a facade that still reports itself as started.
        """
        if not self._started:
            return
        try:
            await self.chat_summary_writeback.shutdown()
        finally:
            try:
                await self.fact_writeback.shutdown()
            finally:
                self._started = False

    async def on_incoming_message(self, message: Any) -> None:
        """Ensure the services are running; the message itself is not used."""
        del message
        if not self._started:
            await self.start()

    async def on_message_sent(self, message: Any) -> None:
        """Feed a sent message to both writeback services, starting them if needed."""
        if not self._started:
            await self.start()
        await self.fact_writeback.enqueue(message)
        await self.chat_summary_writeback.enqueue(message)


memory_automation_service = MemoryAutomationService()