fix:完善Maisaka记忆写回链路

补齐聊天摘要自动写回、发送后同步与图存储清理逻辑,对齐 visual 新配置字段并补充相关回归测试,同时忽略 algorithm_redesign 设计目录。
This commit is contained in:
A-Dawn
2026-04-16 19:04:08 +08:00
parent 7ed5630583
commit 459927e7c0
13 changed files with 918 additions and 65 deletions

1
.gitignore vendored
View File

@@ -371,3 +371,4 @@ packages/
.claude/
.omc/
/.venv312
/src/A_memorix/algorithm_redesign

View File

@@ -0,0 +1,404 @@
from __future__ import annotations
from datetime import datetime, timedelta
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict, List
import asyncio
import inspect
import json
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine
import numpy as np
import pytest
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
from src.A_memorix.core.utils import summary_importer as summary_importer_module
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
from src.common.database import database as database_module
from src.common.database.migrations import create_database_migration_bootstrapper
from src.common.message_repository import count_messages
from src.config.model_configs import TaskConfig
from src.services import memory_flow_service as memory_flow_service_module
from src.services import memory_service as memory_service_module
from src.services import send_service
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
kernel_module = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
summary_importer_module = None # type: ignore[assignment]
BotChatSession = None # type: ignore[assignment]
SessionMessage = None # type: ignore[assignment]
MessageInfo = None # type: ignore[assignment]
UserInfo = None # type: ignore[assignment]
MessageSequence = None # type: ignore[assignment]
TextComponent = None # type: ignore[assignment]
database_module = None # type: ignore[assignment]
create_database_migration_bootstrapper = None # type: ignore[assignment]
count_messages = None # type: ignore[assignment]
TaskConfig = None # type: ignore[assignment]
memory_flow_service_module = None # type: ignore[assignment]
memory_service_module = None # type: ignore[assignment]
send_service = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
class _FakeEmbeddingManager:
    """Deterministic stand-in for the real embedding manager.

    Derives a unit-normalised vector from the UTF-8 bytes of the input, so
    the same text always yields the same embedding — no network, no model.
    """

    def __init__(self, dimension: int = 8) -> None:
        self.default_dimension = dimension

    async def _detect_dimension(self) -> int:
        # The fake dimension is fixed at construction time.
        return self.default_dimension

    async def encode(self, text: Any) -> np.ndarray:
        def _vectorize(raw: Any) -> np.ndarray:
            payload = str(raw or "")
            accumulator = np.zeros(self.default_dimension, dtype=np.float32)
            for position, code_unit in enumerate(payload.encode("utf-8")):
                accumulator[position % self.default_dimension] += float((code_unit % 17) + 1)
            magnitude = float(np.linalg.norm(accumulator))
            if magnitude > 0:
                accumulator /= magnitude
            return accumulator

        if not isinstance(text, (list, tuple)):
            return _vectorize(text).astype(np.float32)
        # Batch input: stack one vector per item.
        return np.stack([_vectorize(entry) for entry in text]).astype(np.float32)
class _KernelBackedRuntimeManager:
    """Routes runtime ``invoke`` calls directly to methods on the kernel."""

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(
        self,
        component_name: str,
        args: Dict[str, Any] | None,
        *,
        timeout_ms: int = 30000,
    ) -> Any:
        # timeout_ms exists to match the real interface; the fake ignores it.
        del timeout_ms
        target = getattr(self.kernel, component_name)
        outcome = target(**(args or {}))
        if inspect.isawaitable(outcome):
            return await outcome
        return outcome
class _NoopRuntimeManager:
    """Hook runner that never aborts and simply echoes its kwargs back."""

    async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
        del hook_name  # hook identity is irrelevant on the no-op path
        result = SimpleNamespace(aborted=False, kwargs=kwargs)
        return result
class _FakePlatformIOManager:
    """Fake platform IO layer: counts readiness checks and always reports a
    single successful QQ delivery with a fixed external message id."""

    def __init__(self) -> None:
        self.ensure_calls = 0

    async def ensure_send_pipeline_ready(self) -> None:
        # Record each pipeline-priming request so tests can assert on it.
        self.ensure_calls += 1

    def build_route_key_from_message(self, message: Any) -> Any:
        del message
        return SimpleNamespace(platform="qq")

    async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
        del message, metadata
        receipt = SimpleNamespace(
            driver_id="plugin.qq.sender",
            external_message_id="real-message-id",
            metadata={},
        )
        return SimpleNamespace(
            has_success=True,
            sent_receipts=[receipt],
            failed_receipts=[],
            route_key=route_key,
        )
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Point the main database module at a throwaway SQLite file under tmp_path."""
    db_dir = (tmp_path / "main_db").resolve()
    db_dir.mkdir(parents=True, exist_ok=True)
    db_file = db_dir / "MaiBot.db"
    database_url = f"sqlite:///{db_file}"
    # Release connections held by the previously configured engine (best effort).
    try:
        database_module.engine.dispose()
    except Exception:
        pass
    engine = create_engine(
        database_url,
        echo=False,
        # Allow the engine to be used across asyncio/test threads.
        connect_args={"check_same_thread": False},
        pool_pre_ping=True,
    )
    session_local = sessionmaker(
        autocommit=False,
        autoflush=False,
        bind=engine,
        class_=Session,
    )
    bootstrapper = create_database_migration_bootstrapper(engine)
    # Swap every module-level handle so downstream code only sees the temp DB.
    monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
    monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
    monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
    monkeypatch.setattr(database_module, "engine", engine, raising=False)
    monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
    monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
    # Force schema/migration bootstrap to run again on first use.
    monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
def _build_incoming_message(
    *,
    session_id: str,
    user_id: str,
    text: str,
    timestamp: datetime | None = None,
) -> SessionMessage:
    """Build a minimal, fully initialised incoming plain-text message for tests."""
    message = SessionMessage(
        message_id="incoming-message-id",
        timestamp=timestamp or datetime.now(),
        platform="qq",
    )
    message.message_info = MessageInfo(
        user_info=UserInfo(
            user_id=user_id,
            user_nickname="测试用户",
            user_cardname="测试用户",
        ),
        additional_config={},
    )
    message.raw_message = MessageSequence(components=[TextComponent(text=text)])
    message.session_id = session_id
    message.reply_to = None
    # Plain text message: every classification flag is off.
    message.is_mentioned = False
    message.is_at = False
    message.is_emoji = False
    message.is_picture = False
    message.is_command = False
    message.is_notify = False
    message.processed_plain_text = text
    message.display_message = text
    message.initialized = True
    return message
async def _wait_until(
    predicate: Callable[[], Any],
    *,
    timeout_seconds: float = 10.0,
    interval_seconds: float = 0.05,
    description: str,
) -> Any:
    """Poll *predicate* until it yields a truthy value or the deadline passes.

    The predicate may be synchronous or return an awaitable.  On timeout an
    AssertionError tagged with *description* is raised.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max(0.5, float(timeout_seconds))
    while loop.time() < deadline:
        outcome = predicate()
        if inspect.isawaitable(outcome):
            outcome = await outcome
        if outcome:
            return outcome
        await asyncio.sleep(interval_seconds)
    raise AssertionError(f"等待超时: {description}")
@pytest.mark.asyncio
async def test_text_to_stream_triggers_real_chat_summary_writeback(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """End-to-end: sending a reply should auto-summarise the chat and write the
    summary back into A_memorix via the real kernel + service wiring."""
    _install_temp_main_database(monkeypatch, tmp_path)
    fake_embedding_manager = _FakeEmbeddingManager()
    captured_prompts: List[str] = []
    fixed_send_timestamp = 1_777_000_000.0

    # Fake embedding self-check that always reports the fake dimension as healthy.
    async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]:
        del kwargs
        return {
            "ok": True,
            "message": "ok",
            "configured_dimension": fake_embedding_manager.default_dimension,
            "requested_dimension": fake_embedding_manager.default_dimension,
            "vector_store_dimension": fake_embedding_manager.default_dimension,
            "detected_dimension": fake_embedding_manager.default_dimension,
            "encoded_dimension": fake_embedding_manager.default_dimension,
            "elapsed_ms": 0.0,
            "sample_text": "test",
            "checked_at": datetime.now().timestamp(),
        }

    # Fake LLM generation: records the prompt and returns a canned summary JSON.
    async def _fake_generate(request: Any) -> Any:
        captured_prompts.append(str(getattr(request, "prompt", "") or ""))
        return SimpleNamespace(
            success=True,
            completion=SimpleNamespace(
                response=json.dumps(
                    {
                        "summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。",
                        "entities": ["绿色围巾"],
                        "relations": [],
                    },
                    ensure_ascii=False,
                )
            ),
        )

    # Replace embedding + LLM integration points with the fakes above.
    monkeypatch.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: fake_embedding_manager,
    )
    monkeypatch.setattr(
        kernel_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    monkeypatch.setattr(
        summary_importer_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "get_available_models",
        lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])},
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "resolve_task_name_from_model_config",
        lambda model_config: "utils",
    )
    monkeypatch.setattr(
        summary_importer_module.llm_api,
        "generate",
        _fake_generate,
    )
    # Freeze time in both modules so message timestamps are deterministic.
    monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp)
    monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp)
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
            "advanced": {"enable_auto_save": False},
            "embedding": {"dimension": fake_embedding_manager.default_dimension},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
            "summarization": {"model_name": ["utils"]},
        },
    )
    service = memory_flow_service_module.MemoryAutomationService()
    fake_platform_io_manager = _FakePlatformIOManager()

    # Skip real episode rebuilding; only record which sources were passed.
    async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]:
        return {
            "rebuilt": 0,
            "items": [],
            "failures": [],
            "sources": list(sources),
        }

    monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources)
    # Wire the memory service straight to the kernel and stub the send pipeline.
    monkeypatch.setattr(
        memory_service_module,
        "a_memorix_host_service",
        _KernelBackedRuntimeManager(kernel),
    )
    monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service)
    monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager())
    monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager)
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: (
            BotChatSession(
                session_id="test-session",
                platform="qq",
                user_id="target-user",
                group_id=None,
            )
            if stream_id == "test-session"
            else None
        ),
    )
    # Enable summary writeback with a low threshold so two messages trigger it;
    # disable person-fact writeback so it cannot interfere.
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "chat_summary_writeback_enabled",
        True,
        raising=False,
    )
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "chat_summary_writeback_message_threshold",
        2,
        raising=False,
    )
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "chat_summary_writeback_context_length",
        10,
        raising=False,
    )
    monkeypatch.setattr(
        memory_flow_service_module.global_config.memory,
        "person_fact_writeback_enabled",
        False,
        raising=False,
    )
    await kernel.initialize()
    try:
        # Seed one incoming user message just before the frozen send time.
        incoming_message = _build_incoming_message(
            session_id="test-session",
            user_id="target-user",
            text="我最近买了一条绿色围巾。",
            timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1),
        )
        with database_module.get_db_session() as session:
            session.add(incoming_message.to_db_instance())
        sent_message = await send_service.text_to_stream_with_message(
            text="好的,我会记住你最近买了绿色围巾。",
            stream_id="test-session",
            storage_message=True,
        )
        assert sent_message is not None
        assert sent_message.message_id == "real-message-id"
        assert fake_platform_io_manager.ensure_calls == 1
        assert count_messages(session_id="test-session") == 2
        # The writeback runs asynchronously; poll until the summary lands.
        paragraphs = await _wait_until(
            lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"),
            description="等待聊天摘要写回到 A_memorix",
        )
        # The LLM prompt must contain both sides of the conversation.
        assert captured_prompts
        assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
        assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
        assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
        # Trigger bookkeeping advanced to the current message count.
        assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
    finally:
        await service.shutdown()
        await kernel.shutdown()
        try:
            database_module.engine.dispose()
        except Exception:
            pass

View File

@@ -455,10 +455,15 @@ async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
async def _fake_timing_gate(self, anchor_message: Any):
del self, anchor_message
return "continue", _build_chat_response("直接进入 planner。", []), []
return "continue", _build_chat_response("直接进入 planner。", []), [], []
async def _fake_planner(self, *, tool_definitions: list[dict[str, Any]] | None = None) -> ChatResponse:
del tool_definitions
async def _fake_planner(
self,
*,
injected_user_messages: list[str] | None = None,
tool_definitions: list[dict[str, Any]] | None = None,
) -> ChatResponse:
del injected_user_messages, tool_definitions
latest_message = self._runtime.message_cache[-1]
latest_text = str(latest_message.processed_plain_text or "")
planner_calls.append(latest_text)

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
import pickle
from pathlib import Path
import pytest
try:
from src.A_memorix.core.storage.graph_store import GraphStore
except SystemExit as exc:
GraphStore = None # type: ignore[assignment]
IMPORT_ERROR = f"config initialization exited during import: {exc}"
else:
IMPORT_ERROR = None
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
def _build_empty_graph_metadata() -> dict:
    """Return on-disk graph metadata equivalent to a freshly cleared store."""
    metadata: dict = {
        "nodes": [],
        "node_to_idx": {},
        "node_attrs": {},
        "matrix_format": "csr",
    }
    # All lifetime counters start at zero for an empty graph.
    for counter in (
        "total_nodes_added",
        "total_edges_added",
        "total_nodes_deleted",
        "total_edges_deleted",
    ):
        metadata[counter] = 0
    metadata["edge_hash_map"] = {}
    return metadata
def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None:
    """Saving an emptied store must delete the previously persisted adjacency file."""
    storage_dir = tmp_path / "graph_data"
    adjacency_file = storage_dir / "graph_adjacency.npz"
    graph = GraphStore(data_dir=storage_dir)
    graph.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    graph.save()
    assert adjacency_file.exists()
    # Clearing then saving again should remove the now-stale matrix file.
    graph.clear()
    graph.save()
    assert not adjacency_file.exists()
def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None:
    """Loading empty metadata next to a stale adjacency file must yield an empty graph."""
    storage_dir = tmp_path / "graph_data"
    original = GraphStore(data_dir=storage_dir)
    original.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    original.save()
    # Overwrite the persisted metadata with an empty-graph payload while the
    # adjacency matrix file is left behind as stale state.
    with (storage_dir / "graph_metadata.pkl").open("wb") as handle:
        pickle.dump(_build_empty_graph_metadata(), handle)
    reloaded = GraphStore(data_dir=storage_dir)
    reloaded.load()
    assert reloaded.num_nodes == 0
    assert reloaded.num_edges == 0
    assert reloaded.get_nodes() == []

View File

@@ -33,34 +33,3 @@ def test_legacy_learning_list_with_numeric_fourth_column_is_migrated():
"enable_jargon_learning": False,
},
]
def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None:
payload = {
"visual": {
"multimodal_replyer": True,
}
}
result = try_migrate_legacy_bot_config_dict(payload)
assert result.migrated is True
assert "visual.multimodal_replyer_moved_to_visual.replyer_mode" in result.reason
assert result.data["visual"]["replyer_mode"] == "multimodal"
assert "multimodal_replyer" not in result.data["visual"]
def test_chat_replyer_generator_type_is_migrated_to_replyer_mode() -> None:
payload = {
"chat": {
"replyer_generator_type": "legacy",
},
"visual": {},
}
result = try_migrate_legacy_bot_config_dict(payload)
assert result.migrated is True
assert "chat.replyer_generator_type_moved_to_visual.replyer_mode" in result.reason
assert result.data["visual"]["replyer_mode"] == "text"
assert "replyer_generator_type" not in result.data["chat"]

View File

@@ -38,6 +38,78 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
assert person.person_id == "qq:123"
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch):
    """With 5 stored messages and a threshold of 3, one summary ingest must fire."""
    recorded: list[tuple[str, object]] = []
    fake_memory_config = SimpleNamespace(
        chat_summary_writeback_enabled=True,
        chat_summary_writeback_message_threshold=3,
        chat_summary_writeback_context_length=7,
    )
    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(memory=fake_memory_config),
    )
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)

    async def fake_ingest_summary(**kwargs):
        recorded.append(("ingest_summary", kwargs))
        return SimpleNamespace(success=True, detail="ok")

    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    service = memory_flow_module.ChatSummaryWritebackService()
    incoming = SimpleNamespace(
        session_id="session-1",
        session=SimpleNamespace(user_id="user-1", group_id="group-1"),
    )
    await service._handle_message(incoming)
    assert len(recorded) == 1
    kwargs = recorded[0][1]
    assert kwargs["external_id"] == "chat_auto_summary:session-1:5"
    assert kwargs["chat_id"] == "session-1"
    assert kwargs["text"] == ""
    assert kwargs["metadata"]["generate_from_chat"] is True
    assert kwargs["metadata"]["context_length"] == 7
    assert kwargs["metadata"]["trigger"] == "message_threshold"
    assert kwargs["user_id"] == "user-1"
    assert kwargs["group_id"] == "group-1"
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch):
    """With only 5 messages against a threshold of 6, no ingest may happen."""
    ingest_invoked = False
    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(
            memory=SimpleNamespace(
                chat_summary_writeback_enabled=True,
                chat_summary_writeback_message_threshold=6,
                chat_summary_writeback_context_length=9,
            )
        ),
    )
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)

    async def fake_ingest_summary(**kwargs):
        nonlocal ingest_invoked
        ingest_invoked = True
        return SimpleNamespace(success=True, detail="ok")

    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    pending = SimpleNamespace(
        session_id="session-1",
        session=SimpleNamespace(user_id="user-1", group_id="group-1"),
    )
    await memory_flow_module.ChatSummaryWritebackService()._handle_message(pending)
    assert ingest_invoked is False
@pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates():
events: list[tuple[str, str]] = []
@@ -52,14 +124,67 @@ async def test_memory_automation_service_auto_starts_and_delegates():
async def shutdown(self):
events.append(("shutdown", "fact"))
class FakeChatSummaryWriteback:
async def start(self):
events.append(("start", "summary"))
async def enqueue(self, message):
events.append(("summary", message.session_id))
async def shutdown(self):
events.append(("shutdown", "summary"))
service = memory_flow_module.MemoryAutomationService()
service.fact_writeback = FakeFactWriteback()
service.chat_summary_writeback = FakeChatSummaryWriteback()
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
await service.shutdown()
assert events == [
("start", "fact"),
("start", "summary"),
("sent", "session-1"),
("summary", "session-1"),
("shutdown", "summary"),
("shutdown", "fact"),
]
@pytest.mark.asyncio
async def test_memory_automation_service_on_incoming_message_auto_starts_only():
    """Incoming messages auto-start both workers but never enqueue writeback jobs."""
    timeline: list[tuple[str, str]] = []

    class RecordingFactWriteback:
        async def start(self):
            timeline.append(("start", "fact"))

        async def enqueue(self, message):
            timeline.append(("sent", message.session_id))

        async def shutdown(self):
            timeline.append(("shutdown", "fact"))

    class RecordingChatSummaryWriteback:
        async def start(self):
            timeline.append(("start", "summary"))

        async def enqueue(self, message):
            timeline.append(("summary", message.session_id))

        async def shutdown(self):
            timeline.append(("shutdown", "summary"))

    automation = memory_flow_module.MemoryAutomationService()
    automation.fact_writeback = RecordingFactWriteback()
    automation.chat_summary_writeback = RecordingChatSummaryWriteback()
    await automation.on_incoming_message(SimpleNamespace(session_id="session-1"))
    await automation.shutdown()
    # Only start/shutdown events appear: nothing was enqueued for incoming traffic.
    assert timeline == [
        ("start", "fact"),
        ("start", "summary"),
        ("shutdown", "summary"),
        ("shutdown", "fact"),
    ]

View File

@@ -1,5 +1,6 @@
"""发送服务回归测试。"""
import sys
from types import SimpleNamespace
from typing import Any, Dict, List
@@ -182,6 +183,75 @@ async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pyt
assert stored_messages[0].message_id == "real-message-id"
@pytest.mark.asyncio
async def test_text_to_stream_with_message_triggers_memory_and_syncs_maisaka_history(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Sending via text_to_stream_with_message must store the message, notify the
    memory automation service, and append the sent message to Maisaka history."""
    # Fake delivery result: one successful QQ receipt with a known message id.
    fake_manager = _FakePlatformIOManager(
        delivery_batch=SimpleNamespace(
            has_success=True,
            sent_receipts=[
                SimpleNamespace(
                    driver_id="plugin.qq.sender",
                    external_message_id="real-message-id",
                    metadata={},
                )
            ],
            failed_receipts=[],
            route_key=SimpleNamespace(platform="qq"),
        )
    )
    stored_messages: List[Any] = []
    memory_events: List[str] = []
    history_events: List[tuple[str, str]] = []

    class FakeMemoryAutomationService:
        # Records the id of every message reported as sent.
        async def on_message_sent(self, message: Any) -> None:
            memory_events.append(str(message.message_id))

    class FakeRuntime:
        # Records (message_id, source_kind) pairs pushed into chat history.
        def append_sent_message_to_chat_history(self, message: Any, *, source_kind: str = "guided_reply") -> None:
            history_events.append((str(message.message_id), source_kind))

    monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
    )
    monkeypatch.setattr(
        send_service.MessageUtils,
        "store_message_to_db",
        lambda message: stored_messages.append(message),
    )
    # send_service imports these modules lazily, so substitute them in sys.modules.
    monkeypatch.setitem(
        sys.modules,
        "src.services.memory_flow_service",
        SimpleNamespace(memory_automation_service=FakeMemoryAutomationService()),
    )
    monkeypatch.setitem(
        sys.modules,
        "src.chat.heart_flow.heartflow_manager",
        SimpleNamespace(heartflow_manager=SimpleNamespace(heartflow_chat_list={"test-session": FakeRuntime()})),
    )
    sent_message = await send_service.text_to_stream_with_message(
        text="你好",
        stream_id="test-session",
        sync_to_maisaka_history=True,
        maisaka_source_kind="guided_reply",
    )
    assert sent_message is not None
    assert sent_message.message_id == "real-message-id"
    assert fake_manager.ensure_calls == 1
    assert len(stored_messages) == 1
    assert stored_messages[0].message_id == "real-message-id"
    assert memory_events == ["real-message-id"]
    assert history_events == [("real-message-id", "guided_reply")]
@pytest.mark.asyncio
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
fake_manager = _FakePlatformIOManager(

View File

@@ -802,6 +802,7 @@ class SDKMemoryKernel:
chat_id: str,
context_length: Optional[int] = None,
include_personality: Optional[bool] = None,
time_end: Optional[float] = None,
) -> Dict[str, Any]:
await self.initialize()
assert self.summary_importer
@@ -809,6 +810,7 @@ class SDKMemoryKernel:
stream_id=str(chat_id or "").strip(),
context_length=context_length,
include_personality=include_personality,
time_end=time_end,
)
if success:
await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])])
@@ -851,6 +853,7 @@ class SDKMemoryKernel:
chat_id=chat_id,
context_length=self._optional_int(summary_meta.get("context_length")),
include_personality=summary_meta.get("include_personality"),
time_end=time_end,
)
result.setdefault("external_id", external_id)
result.setdefault("chat_id", chat_id)

View File

@@ -1190,11 +1190,14 @@ class GraphStore:
data_dir.mkdir(parents=True, exist_ok=True)
# 保存邻接矩阵
matrix_path = data_dir / "graph_adjacency.npz"
if self._adjacency is not None:
matrix_path = data_dir / "graph_adjacency.npz"
with atomic_write(matrix_path, "wb") as f:
save_npz(f, self._adjacency)
logger.debug(f"保存邻接矩阵: {matrix_path}")
elif matrix_path.exists():
matrix_path.unlink()
logger.debug(f"删除陈旧邻接矩阵: {matrix_path}")
# 保存元数据
metadata = {
@@ -1288,9 +1291,20 @@ class GraphStore:
if self._adjacency is not None:
adj_n = self._adjacency.shape[0]
current_n = len(self._nodes)
if current_n > adj_n:
if current_n == 0:
logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。")
self._adjacency = None
elif current_n > adj_n:
logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...")
self._expand_adjacency_matrix(current_n - adj_n)
elif current_n < adj_n:
logger.warning(
f"检测到过期邻接矩阵: 节点数={current_n}, 矩阵大小={adj_n}. 正在重置邻接矩阵..."
)
if self.matrix_format == "csc":
self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32)
else:
self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32)
self._adjacency_dirty = True
logger.info(
@@ -1445,4 +1459,3 @@ class GraphStore:
self._adjacency_dirty = True
logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边")
return count

View File

@@ -5,12 +5,13 @@
导入到 A_memorix 的存储组件中。
"""
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import json
import re
import time
import traceback
from typing import List, Dict, Any, Tuple, Optional
from pathlib import Path
from src.common.logger import get_logger
from src.services import llm_service as llm_api
@@ -222,7 +223,8 @@ class SummaryImporter:
self,
stream_id: str,
context_length: Optional[int] = None,
include_personality: Optional[bool] = None
include_personality: Optional[bool] = None,
time_end: Optional[float] = None,
) -> Tuple[bool, str]:
"""
从指定的聊天流中提取记录并执行总结导入
@@ -231,6 +233,7 @@ class SummaryImporter:
stream_id: 聊天流 ID
context_length: 总结的历史消息条数
include_personality: 是否包含人设
time_end: 用于截取聊天记录的时间上界(闭区间)
Returns:
Tuple[bool, str]: (是否成功, 结果消息)
@@ -248,12 +251,13 @@ class SummaryImporter:
include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)
# 2. 获取历史消息
# 获取当前时间之前的消息
now = time.time()
messages = message_api.get_messages_before_time_in_chat(
query_time_end = time.time() if time_end is None else float(time_end)
messages = message_api.get_messages_by_time_in_chat(
chat_id=stream_id,
timestamp=now,
limit=context_length
start_time=0.0,
end_time=query_time_end,
limit=context_length,
limit_mode="latest",
)
if not messages:

View File

@@ -282,6 +282,10 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
personality = _as_dict(data.get("personality"))
visual = _as_dict(data.get("visual"))
if visual is None and personality is not None and "visual_style" in personality:
visual = {}
data["visual"] = visual
if visual is not None and personality is not None and "visual_style" in personality:
if "visual_style" not in visual:
visual["visual_style"] = personality["visual_style"]
@@ -289,15 +293,6 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
migrated_any = True
reasons.append("personality.visual_style_moved_to_visual.visual_style")
if visual is not None and "multimodal_planner" in visual and "planner_mode" not in visual:
multimodal_planner = visual.pop("multimodal_planner")
if isinstance(multimodal_planner, bool):
visual["planner_mode"] = "multimodal" if multimodal_planner else "text"
migrated_any = True
reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode")
else:
visual["multimodal_planner"] = multimodal_planner
memory = _as_dict(data.get("memory"))
if memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"):
migrated_any = True

View File

@@ -145,23 +145,23 @@ class VisualConfig(ConfigBase):
__ui_label__ = "视觉"
__ui_icon__ = "image"
multimodal_planner: bool = Field(
default=True,
planner_mode: Literal["text", "multimodal", "auto"] = Field(
default="text",
json_schema_extra={
"x-widget": "switch",
"x-widget": "select",
"x-icon": "image",
},
)
"""是否直接输入图片"""
"""Planner 视觉模式text 仅文本multimodal 强制多模态auto 按模型能力自动选择"""
multimodal_replyer: bool = Field(
default=False,
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
default="text",
json_schema_extra={
"x-widget": "switch",
"x-widget": "select",
"x-icon": "git-branch",
},
)
"""是否启用 Maisaka 多模态 replyer 生成器"""
"""Replyer 视觉模式text 仅文本multimodal 强制多模态auto 按模型能力自动选择"""
visual_style: str = Field(
default="请用中文描述这张图片的内容。如果有文字请把文字描述概括出来请留意其主题直观感受输出为一段平文本最多30字请注意不要分点就输出一段文本",
@@ -424,6 +424,36 @@ class MemoryConfig(ConfigBase):
)
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
chat_summary_writeback_enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "scroll-text",
},
)
"""是否在 Maisaka 聊天过程中按消息窗口自动写回聊天摘要到长期记忆"""
chat_summary_writeback_message_threshold: int = Field(
default=12,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "messages-square",
},
)
"""自动写回聊天摘要的消息窗口阈值"""
chat_summary_writeback_context_length: int = Field(
default=50,
ge=1,
le=500,
json_schema_extra={
"x-widget": "input",
"x-icon": "rows-3",
},
)
"""自动写回聊天摘要时,从聊天流中回看的消息条数"""
feedback_correction_enabled: bool = Field(
default=False,
json_schema_extra={

View File

@@ -1,16 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Any, List, Optional
import asyncio
import json
from typing import Any, List, Optional
import time
from json_repair import repair_json
from src.chat.utils.utils import is_bot_self
from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
from src.services.memory_service import memory_service
from src.services.llm_service import LLMServiceClient
logger = get_logger("memory_flow_service")
@@ -210,27 +215,192 @@ class PersonFactWritebackService:
return False
@dataclass
class ChatSummaryWritebackState:
    """Per-session trigger bookkeeping for chat-summary writeback."""

    # Total message count in the session at the last successful writeback.
    last_trigger_message_count: int = 0
    # Wall-clock time (time.time()) of the last successful writeback.
    last_trigger_time: float = 0.0
class ChatSummaryWritebackService:
    """Background worker that writes chat summaries back to long-term memory.

    Sent messages are enqueued; once a session has accumulated at least the
    configured number of new messages since its last writeback, a summary
    ingest request is issued to the memory service.
    """

    def __init__(self) -> None:
        # Bounded queue so a stalled worker cannot grow memory without limit.
        self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
        self._worker_task: Optional[asyncio.Task] = None
        self._stopping = False
        # Per-session trigger state, keyed by session_id.
        self._states: dict[str, ChatSummaryWritebackState] = {}

    async def start(self) -> None:
        """Launch the worker loop; idempotent while a worker is still alive."""
        if self._worker_task is not None and not self._worker_task.done():
            return
        self._stopping = False
        self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_chat_summary_writeback")

    async def shutdown(self) -> None:
        """Cancel the worker task and wait for it to finish."""
        self._stopping = True
        worker = self._worker_task
        self._worker_task = None
        if worker is None:
            return
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            logger.warning("关闭聊天摘要写回 worker 失败: %s", exc)

    async def enqueue(self, message: Any) -> None:
        """Queue a sent message for threshold evaluation; drops when the queue is full."""
        if not bool(getattr(global_config.memory, "chat_summary_writeback_enabled", True)):
            return
        if self._stopping:
            return
        try:
            self._queue.put_nowait(message)
        except asyncio.QueueFull:
            logger.warning("聊天摘要写回队列已满,跳过本次触发")

    async def _worker_loop(self) -> None:
        # Drain the queue until cancelled; per-message failures are logged so
        # one bad message cannot kill the worker.
        try:
            while not self._stopping:
                message = await self._queue.get()
                try:
                    await self._handle_message(message)
                except Exception as exc:
                    logger.warning("聊天摘要写回处理失败: %s", exc, exc_info=True)
                finally:
                    self._queue.task_done()
        except asyncio.CancelledError:
            raise

    async def _handle_message(self, message: Any) -> None:
        """Issue one summary ingest when the pending-message threshold is met."""
        session_id = self._resolve_session_id(message)
        if not session_id:
            return
        total_message_count = count_messages(session_id=session_id)
        if total_message_count <= 0:
            return
        threshold = self._message_threshold()
        state = self._states.setdefault(session_id, ChatSummaryWritebackState())
        pending_message_count = max(0, total_message_count - state.last_trigger_message_count)
        if pending_message_count < threshold:
            return
        context_length = self._context_length()
        message_time = self._extract_message_timestamp(message)
        # Empty text + generate_from_chat=True asks the importer to build the
        # summary from the chat history itself.
        result = await memory_service.ingest_summary(
            external_id=f"chat_auto_summary:{session_id}:{total_message_count}",
            chat_id=session_id,
            text="",
            participants=[],
            time_end=message_time,
            metadata={
                "generate_from_chat": True,
                "context_length": context_length,
                "writeback_source": "memory_flow_service",
                "trigger": "message_threshold",
                "trigger_message_count": total_message_count,
            },
            respect_filter=True,
            user_id=self._extract_session_user_id(message),
            group_id=self._extract_session_group_id(message),
        )
        if not getattr(result, "success", False):
            logger.warning(
                "聊天摘要自动写回失败: session_id=%s detail=%s",
                session_id,
                getattr(result, "detail", ""),
            )
            return
        # Only a successful writeback advances the per-session trigger state.
        state.last_trigger_message_count = total_message_count
        state.last_trigger_time = time.time()
        logger.info(
            "聊天摘要自动写回成功: session_id=%s trigger=%s total_messages=%s context_length=%s detail=%s",
            session_id,
            "message_threshold",
            total_message_count,
            context_length,
            getattr(result, "detail", ""),
        )

    @staticmethod
    def _resolve_session_id(message: Any) -> str:
        # Prefer the message's own session_id, then its session object's.
        return str(
            getattr(message, "session_id", "")
            or getattr(getattr(message, "session", None), "session_id", "")
            or ""
        ).strip()

    @staticmethod
    def _extract_session_user_id(message: Any) -> str:
        return str(
            getattr(getattr(message, "session", None), "user_id", "")
            or getattr(message, "user_id", "")
            or ""
        ).strip()

    @staticmethod
    def _extract_session_group_id(message: Any) -> str:
        return str(
            getattr(getattr(message, "session", None), "group_id", "")
            or getattr(message, "group_id", "")
            or ""
        ).strip()

    @staticmethod
    def _extract_message_timestamp(message: Any) -> float | None:
        # Accept datetime, datetime-like objects exposing .timestamp(), or numbers.
        raw_timestamp = getattr(message, "timestamp", None)
        if isinstance(raw_timestamp, datetime):
            return raw_timestamp.timestamp()
        if hasattr(raw_timestamp, "timestamp") and callable(raw_timestamp.timestamp):
            try:
                return float(raw_timestamp.timestamp())
            except Exception:
                return None
        if isinstance(raw_timestamp, (int, float)):
            return float(raw_timestamp)
        return None

    @staticmethod
    def _message_threshold() -> int:
        # Config value clamped to >= 1; falsy/missing values fall back to 12.
        return max(1, int(getattr(global_config.memory, "chat_summary_writeback_message_threshold", 12) or 12))

    @staticmethod
    def _context_length() -> int:
        # Config value clamped to >= 1; falsy/missing values fall back to 50.
        return max(1, int(getattr(global_config.memory, "chat_summary_writeback_context_length", 50) or 50))
class MemoryAutomationService:
    """Facade that lazily starts the writeback workers and fans messages out."""

    def __init__(self) -> None:
        self.fact_writeback = PersonFactWritebackService()
        self.chat_summary_writeback = ChatSummaryWritebackService()
        self._started = False

    async def start(self) -> None:
        """Start both workers exactly once (fact first, then summary)."""
        if self._started:
            return
        await self.fact_writeback.start()
        await self.chat_summary_writeback.start()
        self._started = True

    async def shutdown(self) -> None:
        """Stop the workers in reverse start order."""
        if not self._started:
            return
        await self.chat_summary_writeback.shutdown()
        await self.fact_writeback.shutdown()
        self._started = False

    async def _ensure_started(self) -> None:
        # Lazy bootstrap shared by both message hooks.
        if not self._started:
            await self.start()

    async def on_incoming_message(self, message: Any) -> None:
        """Incoming messages only bootstrap the workers; nothing is enqueued."""
        del message
        await self._ensure_started()

    async def on_message_sent(self, message: Any) -> None:
        """Sent messages bootstrap the workers and feed both writeback queues."""
        await self._ensure_started()
        await self.fact_writeback.enqueue(message)
        await self.chat_summary_writeback.enqueue(message)
memory_automation_service = MemoryAutomationService()