fix:完善Maisaka记忆写回链路
补齐聊天摘要自动写回、发送后同步与图存储清理逻辑,对齐 visual 新配置字段并补充相关回归测试,同时忽略 algorithm_redesign 设计目录。
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -371,3 +371,4 @@ packages/
|
||||
.claude/
|
||||
.omc/
|
||||
/.venv312
|
||||
/src/A_memorix/algorithm_redesign
|
||||
|
||||
@@ -0,0 +1,404 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Session, create_engine
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
IMPORT_ERROR: str | None = None
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
from src.A_memorix.core.utils import summary_importer as summary_importer_module
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
||||
from src.common.message_repository import count_messages
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.services import memory_flow_service as memory_flow_service_module
|
||||
from src.services import memory_service as memory_service_module
|
||||
from src.services import send_service
|
||||
except SystemExit as exc:
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
kernel_module = None # type: ignore[assignment]
|
||||
SDKMemoryKernel = None # type: ignore[assignment]
|
||||
summary_importer_module = None # type: ignore[assignment]
|
||||
BotChatSession = None # type: ignore[assignment]
|
||||
SessionMessage = None # type: ignore[assignment]
|
||||
MessageInfo = None # type: ignore[assignment]
|
||||
UserInfo = None # type: ignore[assignment]
|
||||
MessageSequence = None # type: ignore[assignment]
|
||||
TextComponent = None # type: ignore[assignment]
|
||||
database_module = None # type: ignore[assignment]
|
||||
create_database_migration_bootstrapper = None # type: ignore[assignment]
|
||||
count_messages = None # type: ignore[assignment]
|
||||
TaskConfig = None # type: ignore[assignment]
|
||||
memory_flow_service_module = None # type: ignore[assignment]
|
||||
memory_service_module = None # type: ignore[assignment]
|
||||
send_service = None # type: ignore[assignment]
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
|
||||
class _FakeEmbeddingManager:
|
||||
def __init__(self, dimension: int = 8) -> None:
|
||||
self.default_dimension = dimension
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(self, text: Any) -> np.ndarray:
|
||||
def _encode_one(raw: Any) -> np.ndarray:
|
||||
content = str(raw or "")
|
||||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
||||
for index, byte in enumerate(content.encode("utf-8")):
|
||||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
||||
norm = float(np.linalg.norm(vector))
|
||||
if norm > 0:
|
||||
vector /= norm
|
||||
return vector
|
||||
|
||||
if isinstance(text, (list, tuple)):
|
||||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
||||
return _encode_one(text).astype(np.float32)
|
||||
|
||||
|
||||
class _KernelBackedRuntimeManager:
|
||||
def __init__(self, kernel: SDKMemoryKernel) -> None:
|
||||
self.kernel = kernel
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None,
|
||||
*,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
del timeout_ms
|
||||
payload = args or {}
|
||||
handler = getattr(self.kernel, component_name)
|
||||
result = handler(**payload)
|
||||
return await result if inspect.isawaitable(result) else result
|
||||
|
||||
|
||||
class _NoopRuntimeManager:
|
||||
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
|
||||
del hook_name
|
||||
return SimpleNamespace(aborted=False, kwargs=kwargs)
|
||||
|
||||
|
||||
class _FakePlatformIOManager:
|
||||
def __init__(self) -> None:
|
||||
self.ensure_calls = 0
|
||||
|
||||
async def ensure_send_pipeline_ready(self) -> None:
|
||||
self.ensure_calls += 1
|
||||
|
||||
def build_route_key_from_message(self, message: Any) -> Any:
|
||||
del message
|
||||
return SimpleNamespace(platform="qq")
|
||||
|
||||
async def send_message(self, message: Any, route_key: Any, metadata: Dict[str, Any]) -> Any:
|
||||
del message, metadata
|
||||
return SimpleNamespace(
|
||||
has_success=True,
|
||||
sent_receipts=[
|
||||
SimpleNamespace(
|
||||
driver_id="plugin.qq.sender",
|
||||
external_message_id="real-message-id",
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
failed_receipts=[],
|
||||
route_key=route_key,
|
||||
)
|
||||
|
||||
|
||||
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
db_dir = (tmp_path / "main_db").resolve()
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
db_file = db_dir / "MaiBot.db"
|
||||
database_url = f"sqlite:///{db_file}"
|
||||
|
||||
try:
|
||||
database_module.engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
engine = create_engine(
|
||||
database_url,
|
||||
echo=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
session_local = sessionmaker(
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
bind=engine,
|
||||
class_=Session,
|
||||
)
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
monkeypatch.setattr(database_module, "_DB_DIR", db_dir, raising=False)
|
||||
monkeypatch.setattr(database_module, "_DB_FILE", db_file, raising=False)
|
||||
monkeypatch.setattr(database_module, "DATABASE_URL", database_url, raising=False)
|
||||
monkeypatch.setattr(database_module, "engine", engine, raising=False)
|
||||
monkeypatch.setattr(database_module, "SessionLocal", session_local, raising=False)
|
||||
monkeypatch.setattr(database_module, "_migration_bootstrapper", bootstrapper, raising=False)
|
||||
monkeypatch.setattr(database_module, "_db_initialized", False, raising=False)
|
||||
|
||||
|
||||
def _build_incoming_message(
|
||||
*,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
text: str,
|
||||
timestamp: datetime | None = None,
|
||||
) -> SessionMessage:
|
||||
message = SessionMessage(
|
||||
message_id="incoming-message-id",
|
||||
timestamp=timestamp or datetime.now(),
|
||||
platform="qq",
|
||||
)
|
||||
message.message_info = MessageInfo(
|
||||
user_info=UserInfo(
|
||||
user_id=user_id,
|
||||
user_nickname="测试用户",
|
||||
user_cardname="测试用户",
|
||||
),
|
||||
additional_config={},
|
||||
)
|
||||
message.raw_message = MessageSequence(components=[TextComponent(text=text)])
|
||||
message.session_id = session_id
|
||||
message.reply_to = None
|
||||
message.is_mentioned = False
|
||||
message.is_at = False
|
||||
message.is_emoji = False
|
||||
message.is_picture = False
|
||||
message.is_command = False
|
||||
message.is_notify = False
|
||||
message.processed_plain_text = text
|
||||
message.display_message = text
|
||||
message.initialized = True
|
||||
return message
|
||||
|
||||
|
||||
async def _wait_until(
|
||||
predicate: Callable[[], Any],
|
||||
*,
|
||||
timeout_seconds: float = 10.0,
|
||||
interval_seconds: float = 0.05,
|
||||
description: str,
|
||||
) -> Any:
|
||||
deadline = asyncio.get_running_loop().time() + max(0.5, float(timeout_seconds))
|
||||
while asyncio.get_running_loop().time() < deadline:
|
||||
value = predicate()
|
||||
if inspect.isawaitable(value):
|
||||
value = await value
|
||||
if value:
|
||||
return value
|
||||
await asyncio.sleep(interval_seconds)
|
||||
raise AssertionError(f"等待超时: {description}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_triggers_real_chat_summary_writeback(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
_install_temp_main_database(monkeypatch, tmp_path)
|
||||
|
||||
fake_embedding_manager = _FakeEmbeddingManager()
|
||||
captured_prompts: List[str] = []
|
||||
fixed_send_timestamp = 1_777_000_000.0
|
||||
|
||||
async def _fake_runtime_self_check(**kwargs: Any) -> Dict[str, Any]:
|
||||
del kwargs
|
||||
return {
|
||||
"ok": True,
|
||||
"message": "ok",
|
||||
"configured_dimension": fake_embedding_manager.default_dimension,
|
||||
"requested_dimension": fake_embedding_manager.default_dimension,
|
||||
"vector_store_dimension": fake_embedding_manager.default_dimension,
|
||||
"detected_dimension": fake_embedding_manager.default_dimension,
|
||||
"encoded_dimension": fake_embedding_manager.default_dimension,
|
||||
"elapsed_ms": 0.0,
|
||||
"sample_text": "test",
|
||||
"checked_at": datetime.now().timestamp(),
|
||||
}
|
||||
|
||||
async def _fake_generate(request: Any) -> Any:
|
||||
captured_prompts.append(str(getattr(request, "prompt", "") or ""))
|
||||
return SimpleNamespace(
|
||||
success=True,
|
||||
completion=SimpleNamespace(
|
||||
response=json.dumps(
|
||||
{
|
||||
"summary": "这段对话记录了用户提到自己买了绿色围巾,机器人表示会记住这件事。",
|
||||
"entities": ["绿色围巾"],
|
||||
"relations": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"create_embedding_api_adapter",
|
||||
lambda **kwargs: fake_embedding_manager,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
kernel_module,
|
||||
"run_embedding_runtime_self_check",
|
||||
_fake_runtime_self_check,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
summary_importer_module,
|
||||
"run_embedding_runtime_self_check",
|
||||
_fake_runtime_self_check,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
summary_importer_module.llm_api,
|
||||
"get_available_models",
|
||||
lambda: {"utils": TaskConfig(model_list=["fake-summary-model"])},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
summary_importer_module.llm_api,
|
||||
"resolve_task_name_from_model_config",
|
||||
lambda model_config: "utils",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
summary_importer_module.llm_api,
|
||||
"generate",
|
||||
_fake_generate,
|
||||
)
|
||||
monkeypatch.setattr(send_service.time, "time", lambda: fixed_send_timestamp)
|
||||
monkeypatch.setattr(summary_importer_module.time, "time", lambda: fixed_send_timestamp)
|
||||
|
||||
kernel = SDKMemoryKernel(
|
||||
plugin_root=tmp_path / "plugin_root",
|
||||
config={
|
||||
"storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
|
||||
"advanced": {"enable_auto_save": False},
|
||||
"embedding": {"dimension": fake_embedding_manager.default_dimension},
|
||||
"memory": {"base_decay_interval_hours": 24},
|
||||
"person_profile": {"refresh_interval_minutes": 5},
|
||||
"summarization": {"model_name": ["utils"]},
|
||||
},
|
||||
)
|
||||
|
||||
service = memory_flow_service_module.MemoryAutomationService()
|
||||
fake_platform_io_manager = _FakePlatformIOManager()
|
||||
|
||||
async def _fake_rebuild_episodes_for_sources(sources: List[str]) -> Dict[str, Any]:
|
||||
return {
|
||||
"rebuilt": 0,
|
||||
"items": [],
|
||||
"failures": [],
|
||||
"sources": list(sources),
|
||||
}
|
||||
|
||||
monkeypatch.setattr(kernel, "rebuild_episodes_for_sources", _fake_rebuild_episodes_for_sources)
|
||||
monkeypatch.setattr(
|
||||
memory_service_module,
|
||||
"a_memorix_host_service",
|
||||
_KernelBackedRuntimeManager(kernel),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_service_module, "memory_automation_service", service)
|
||||
monkeypatch.setattr(send_service, "_get_runtime_manager", lambda: _NoopRuntimeManager())
|
||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_platform_io_manager)
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: (
|
||||
BotChatSession(
|
||||
session_id="test-session",
|
||||
platform="qq",
|
||||
user_id="target-user",
|
||||
group_id=None,
|
||||
)
|
||||
if stream_id == "test-session"
|
||||
else None
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.global_config.memory,
|
||||
"chat_summary_writeback_enabled",
|
||||
True,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.global_config.memory,
|
||||
"chat_summary_writeback_message_threshold",
|
||||
2,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.global_config.memory,
|
||||
"chat_summary_writeback_context_length",
|
||||
10,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_flow_service_module.global_config.memory,
|
||||
"person_fact_writeback_enabled",
|
||||
False,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
await kernel.initialize()
|
||||
|
||||
try:
|
||||
incoming_message = _build_incoming_message(
|
||||
session_id="test-session",
|
||||
user_id="target-user",
|
||||
text="我最近买了一条绿色围巾。",
|
||||
timestamp=datetime.fromtimestamp(fixed_send_timestamp) - timedelta(seconds=1),
|
||||
)
|
||||
with database_module.get_db_session() as session:
|
||||
session.add(incoming_message.to_db_instance())
|
||||
|
||||
sent_message = await send_service.text_to_stream_with_message(
|
||||
text="好的,我会记住你最近买了绿色围巾。",
|
||||
stream_id="test-session",
|
||||
storage_message=True,
|
||||
)
|
||||
|
||||
assert sent_message is not None
|
||||
assert sent_message.message_id == "real-message-id"
|
||||
assert fake_platform_io_manager.ensure_calls == 1
|
||||
assert count_messages(session_id="test-session") == 2
|
||||
|
||||
paragraphs = await _wait_until(
|
||||
lambda: kernel.metadata_store.get_paragraphs_by_source("chat_summary:test-session"),
|
||||
description="等待聊天摘要写回到 A_memorix",
|
||||
)
|
||||
|
||||
assert captured_prompts
|
||||
assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
|
||||
assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
|
||||
assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
|
||||
assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
|
||||
finally:
|
||||
await service.shutdown()
|
||||
await kernel.shutdown()
|
||||
try:
|
||||
database_module.engine.dispose()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -455,10 +455,15 @@ async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
|
||||
async def _fake_timing_gate(self, anchor_message: Any):
|
||||
del self, anchor_message
|
||||
return "continue", _build_chat_response("直接进入 planner。", []), []
|
||||
return "continue", _build_chat_response("直接进入 planner。", []), [], []
|
||||
|
||||
async def _fake_planner(self, *, tool_definitions: list[dict[str, Any]] | None = None) -> ChatResponse:
|
||||
del tool_definitions
|
||||
async def _fake_planner(
|
||||
self,
|
||||
*,
|
||||
injected_user_messages: list[str] | None = None,
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> ChatResponse:
|
||||
del injected_user_messages, tool_definitions
|
||||
latest_message = self._runtime.message_cache[-1]
|
||||
latest_text = str(latest_message.processed_plain_text or "")
|
||||
planner_calls.append(latest_text)
|
||||
|
||||
64
pytests/A_memorix_test/test_graph_store_persistence.py
Normal file
64
pytests/A_memorix_test/test_graph_store_persistence.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from src.A_memorix.core.storage.graph_store import GraphStore
|
||||
except SystemExit as exc:
|
||||
GraphStore = None # type: ignore[assignment]
|
||||
IMPORT_ERROR = f"config initialization exited during import: {exc}"
|
||||
else:
|
||||
IMPORT_ERROR = None
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
|
||||
|
||||
|
||||
def _build_empty_graph_metadata() -> dict:
|
||||
return {
|
||||
"nodes": [],
|
||||
"node_to_idx": {},
|
||||
"node_attrs": {},
|
||||
"matrix_format": "csr",
|
||||
"total_nodes_added": 0,
|
||||
"total_edges_added": 0,
|
||||
"total_nodes_deleted": 0,
|
||||
"total_edges_deleted": 0,
|
||||
"edge_hash_map": {},
|
||||
}
|
||||
|
||||
|
||||
def test_graph_store_clear_save_removes_stale_adjacency(tmp_path: Path) -> None:
|
||||
data_dir = tmp_path / "graph_data"
|
||||
store = GraphStore(data_dir=data_dir)
|
||||
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
|
||||
store.save()
|
||||
|
||||
matrix_path = data_dir / "graph_adjacency.npz"
|
||||
assert matrix_path.exists()
|
||||
|
||||
store.clear()
|
||||
store.save()
|
||||
|
||||
assert not matrix_path.exists()
|
||||
|
||||
|
||||
def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path: Path) -> None:
|
||||
data_dir = tmp_path / "graph_data"
|
||||
store = GraphStore(data_dir=data_dir)
|
||||
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
|
||||
store.save()
|
||||
|
||||
metadata_path = data_dir / "graph_metadata.pkl"
|
||||
with metadata_path.open("wb") as handle:
|
||||
pickle.dump(_build_empty_graph_metadata(), handle)
|
||||
|
||||
reloaded = GraphStore(data_dir=data_dir)
|
||||
reloaded.load()
|
||||
|
||||
assert reloaded.num_nodes == 0
|
||||
assert reloaded.num_edges == 0
|
||||
assert reloaded.get_nodes() == []
|
||||
@@ -33,34 +33,3 @@ def test_legacy_learning_list_with_numeric_fourth_column_is_migrated():
|
||||
"enable_jargon_learning": False,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_visual_multimodal_replyer_is_migrated_to_replyer_mode() -> None:
|
||||
payload = {
|
||||
"visual": {
|
||||
"multimodal_replyer": True,
|
||||
}
|
||||
}
|
||||
|
||||
result = try_migrate_legacy_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is True
|
||||
assert "visual.multimodal_replyer_moved_to_visual.replyer_mode" in result.reason
|
||||
assert result.data["visual"]["replyer_mode"] == "multimodal"
|
||||
assert "multimodal_replyer" not in result.data["visual"]
|
||||
|
||||
|
||||
def test_chat_replyer_generator_type_is_migrated_to_replyer_mode() -> None:
|
||||
payload = {
|
||||
"chat": {
|
||||
"replyer_generator_type": "legacy",
|
||||
},
|
||||
"visual": {},
|
||||
}
|
||||
|
||||
result = try_migrate_legacy_bot_config_dict(payload)
|
||||
|
||||
assert result.migrated is True
|
||||
assert "chat.replyer_generator_type_moved_to_visual.replyer_mode" in result.reason
|
||||
assert result.data["visual"]["replyer_mode"] == "text"
|
||||
assert "replyer_generator_type" not in result.data["chat"]
|
||||
|
||||
@@ -38,6 +38,78 @@ def test_person_fact_resolve_target_person_for_private_chat(monkeypatch):
|
||||
assert person.person_id == "qq:123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_summary_writeback_service_triggers_when_threshold_reached(monkeypatch):
|
||||
events: list[tuple[str, object]] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
memory=SimpleNamespace(
|
||||
chat_summary_writeback_enabled=True,
|
||||
chat_summary_writeback_message_threshold=3,
|
||||
chat_summary_writeback_context_length=7,
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
|
||||
|
||||
async def fake_ingest_summary(**kwargs):
|
||||
events.append(("ingest_summary", kwargs))
|
||||
return SimpleNamespace(success=True, detail="ok")
|
||||
|
||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||
|
||||
service = memory_flow_module.ChatSummaryWritebackService()
|
||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||
|
||||
await service._handle_message(message)
|
||||
|
||||
assert len(events) == 1
|
||||
_, payload = events[0]
|
||||
assert payload["external_id"] == "chat_auto_summary:session-1:5"
|
||||
assert payload["chat_id"] == "session-1"
|
||||
assert payload["text"] == ""
|
||||
assert payload["metadata"]["generate_from_chat"] is True
|
||||
assert payload["metadata"]["context_length"] == 7
|
||||
assert payload["metadata"]["trigger"] == "message_threshold"
|
||||
assert payload["user_id"] == "user-1"
|
||||
assert payload["group_id"] == "group-1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(monkeypatch):
|
||||
called = False
|
||||
|
||||
monkeypatch.setattr(
|
||||
memory_flow_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
memory=SimpleNamespace(
|
||||
chat_summary_writeback_enabled=True,
|
||||
chat_summary_writeback_message_threshold=6,
|
||||
chat_summary_writeback_context_length=9,
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
|
||||
|
||||
async def fake_ingest_summary(**kwargs):
|
||||
nonlocal called
|
||||
called = True
|
||||
return SimpleNamespace(success=True, detail="ok")
|
||||
|
||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||
|
||||
service = memory_flow_module.ChatSummaryWritebackService()
|
||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||
|
||||
await service._handle_message(message)
|
||||
|
||||
assert called is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_automation_service_auto_starts_and_delegates():
|
||||
events: list[tuple[str, str]] = []
|
||||
@@ -52,14 +124,67 @@ async def test_memory_automation_service_auto_starts_and_delegates():
|
||||
async def shutdown(self):
|
||||
events.append(("shutdown", "fact"))
|
||||
|
||||
class FakeChatSummaryWriteback:
|
||||
async def start(self):
|
||||
events.append(("start", "summary"))
|
||||
|
||||
async def enqueue(self, message):
|
||||
events.append(("summary", message.session_id))
|
||||
|
||||
async def shutdown(self):
|
||||
events.append(("shutdown", "summary"))
|
||||
|
||||
service = memory_flow_module.MemoryAutomationService()
|
||||
service.fact_writeback = FakeFactWriteback()
|
||||
service.chat_summary_writeback = FakeChatSummaryWriteback()
|
||||
|
||||
await service.on_message_sent(SimpleNamespace(session_id="session-1"))
|
||||
await service.shutdown()
|
||||
|
||||
assert events == [
|
||||
("start", "fact"),
|
||||
("start", "summary"),
|
||||
("sent", "session-1"),
|
||||
("summary", "session-1"),
|
||||
("shutdown", "summary"),
|
||||
("shutdown", "fact"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_automation_service_on_incoming_message_auto_starts_only():
|
||||
events: list[tuple[str, str]] = []
|
||||
|
||||
class FakeFactWriteback:
|
||||
async def start(self):
|
||||
events.append(("start", "fact"))
|
||||
|
||||
async def enqueue(self, message):
|
||||
events.append(("sent", message.session_id))
|
||||
|
||||
async def shutdown(self):
|
||||
events.append(("shutdown", "fact"))
|
||||
|
||||
class FakeChatSummaryWriteback:
|
||||
async def start(self):
|
||||
events.append(("start", "summary"))
|
||||
|
||||
async def enqueue(self, message):
|
||||
events.append(("summary", message.session_id))
|
||||
|
||||
async def shutdown(self):
|
||||
events.append(("shutdown", "summary"))
|
||||
|
||||
service = memory_flow_module.MemoryAutomationService()
|
||||
service.fact_writeback = FakeFactWriteback()
|
||||
service.chat_summary_writeback = FakeChatSummaryWriteback()
|
||||
|
||||
await service.on_incoming_message(SimpleNamespace(session_id="session-1"))
|
||||
await service.shutdown()
|
||||
|
||||
assert events == [
|
||||
("start", "fact"),
|
||||
("start", "summary"),
|
||||
("shutdown", "summary"),
|
||||
("shutdown", "fact"),
|
||||
]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""发送服务回归测试。"""
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@@ -182,6 +183,75 @@ async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pyt
|
||||
assert stored_messages[0].message_id == "real-message-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_with_message_triggers_memory_and_syncs_maisaka_history(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
delivery_batch=SimpleNamespace(
|
||||
has_success=True,
|
||||
sent_receipts=[
|
||||
SimpleNamespace(
|
||||
driver_id="plugin.qq.sender",
|
||||
external_message_id="real-message-id",
|
||||
metadata={},
|
||||
)
|
||||
],
|
||||
failed_receipts=[],
|
||||
route_key=SimpleNamespace(platform="qq"),
|
||||
)
|
||||
)
|
||||
stored_messages: List[Any] = []
|
||||
memory_events: List[str] = []
|
||||
history_events: List[tuple[str, str]] = []
|
||||
|
||||
class FakeMemoryAutomationService:
|
||||
async def on_message_sent(self, message: Any) -> None:
|
||||
memory_events.append(str(message.message_id))
|
||||
|
||||
class FakeRuntime:
|
||||
def append_sent_message_to_chat_history(self, message: Any, *, source_kind: str = "guided_reply") -> None:
|
||||
history_events.append((str(message.message_id), source_kind))
|
||||
|
||||
monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
|
||||
monkeypatch.setattr(
|
||||
send_service._chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
send_service.MessageUtils,
|
||||
"store_message_to_db",
|
||||
lambda message: stored_messages.append(message),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"src.services.memory_flow_service",
|
||||
SimpleNamespace(memory_automation_service=FakeMemoryAutomationService()),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"src.chat.heart_flow.heartflow_manager",
|
||||
SimpleNamespace(heartflow_manager=SimpleNamespace(heartflow_chat_list={"test-session": FakeRuntime()})),
|
||||
)
|
||||
|
||||
sent_message = await send_service.text_to_stream_with_message(
|
||||
text="你好",
|
||||
stream_id="test-session",
|
||||
sync_to_maisaka_history=True,
|
||||
maisaka_source_kind="guided_reply",
|
||||
)
|
||||
|
||||
assert sent_message is not None
|
||||
assert sent_message.message_id == "real-message-id"
|
||||
assert fake_manager.ensure_calls == 1
|
||||
assert len(stored_messages) == 1
|
||||
assert stored_messages[0].message_id == "real-message-id"
|
||||
assert memory_events == ["real-message-id"]
|
||||
assert history_events == [("real-message-id", "guided_reply")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
fake_manager = _FakePlatformIOManager(
|
||||
|
||||
@@ -802,6 +802,7 @@ class SDKMemoryKernel:
|
||||
chat_id: str,
|
||||
context_length: Optional[int] = None,
|
||||
include_personality: Optional[bool] = None,
|
||||
time_end: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
await self.initialize()
|
||||
assert self.summary_importer
|
||||
@@ -809,6 +810,7 @@ class SDKMemoryKernel:
|
||||
stream_id=str(chat_id or "").strip(),
|
||||
context_length=context_length,
|
||||
include_personality=include_personality,
|
||||
time_end=time_end,
|
||||
)
|
||||
if success:
|
||||
await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])])
|
||||
@@ -851,6 +853,7 @@ class SDKMemoryKernel:
|
||||
chat_id=chat_id,
|
||||
context_length=self._optional_int(summary_meta.get("context_length")),
|
||||
include_personality=summary_meta.get("include_personality"),
|
||||
time_end=time_end,
|
||||
)
|
||||
result.setdefault("external_id", external_id)
|
||||
result.setdefault("chat_id", chat_id)
|
||||
|
||||
@@ -1190,11 +1190,14 @@ class GraphStore:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存邻接矩阵
|
||||
matrix_path = data_dir / "graph_adjacency.npz"
|
||||
if self._adjacency is not None:
|
||||
matrix_path = data_dir / "graph_adjacency.npz"
|
||||
with atomic_write(matrix_path, "wb") as f:
|
||||
save_npz(f, self._adjacency)
|
||||
logger.debug(f"保存邻接矩阵: {matrix_path}")
|
||||
elif matrix_path.exists():
|
||||
matrix_path.unlink()
|
||||
logger.debug(f"删除陈旧邻接矩阵: {matrix_path}")
|
||||
|
||||
# 保存元数据
|
||||
metadata = {
|
||||
@@ -1288,9 +1291,20 @@ class GraphStore:
|
||||
if self._adjacency is not None:
|
||||
adj_n = self._adjacency.shape[0]
|
||||
current_n = len(self._nodes)
|
||||
if current_n > adj_n:
|
||||
if current_n == 0:
|
||||
logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。")
|
||||
self._adjacency = None
|
||||
elif current_n > adj_n:
|
||||
logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...")
|
||||
self._expand_adjacency_matrix(current_n - adj_n)
|
||||
elif current_n < adj_n:
|
||||
logger.warning(
|
||||
f"检测到过期邻接矩阵: 节点数={current_n}, 矩阵大小={adj_n}. 正在重置邻接矩阵..."
|
||||
)
|
||||
if self.matrix_format == "csc":
|
||||
self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32)
|
||||
else:
|
||||
self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32)
|
||||
|
||||
self._adjacency_dirty = True
|
||||
logger.info(
|
||||
@@ -1445,4 +1459,3 @@ class GraphStore:
|
||||
self._adjacency_dirty = True
|
||||
logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边")
|
||||
return count
|
||||
|
||||
|
||||
@@ -5,12 +5,13 @@
|
||||
导入到 A_memorix 的存储组件中。
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services import llm_service as llm_api
|
||||
@@ -222,7 +223,8 @@ class SummaryImporter:
|
||||
self,
|
||||
stream_id: str,
|
||||
context_length: Optional[int] = None,
|
||||
include_personality: Optional[bool] = None
|
||||
include_personality: Optional[bool] = None,
|
||||
time_end: Optional[float] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
从指定的聊天流中提取记录并执行总结导入
|
||||
@@ -231,6 +233,7 @@ class SummaryImporter:
|
||||
stream_id: 聊天流 ID
|
||||
context_length: 总结的历史消息条数
|
||||
include_personality: 是否包含人设
|
||||
time_end: 用于截取聊天记录的时间上界(闭区间)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 结果消息)
|
||||
@@ -248,12 +251,13 @@ class SummaryImporter:
|
||||
include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)
|
||||
|
||||
# 2. 获取历史消息
|
||||
# 获取当前时间之前的消息
|
||||
now = time.time()
|
||||
messages = message_api.get_messages_before_time_in_chat(
|
||||
query_time_end = time.time() if time_end is None else float(time_end)
|
||||
messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=now,
|
||||
limit=context_length
|
||||
start_time=0.0,
|
||||
end_time=query_time_end,
|
||||
limit=context_length,
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
if not messages:
|
||||
|
||||
@@ -282,6 +282,10 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
|
||||
personality = _as_dict(data.get("personality"))
|
||||
visual = _as_dict(data.get("visual"))
|
||||
if visual is None and personality is not None and "visual_style" in personality:
|
||||
visual = {}
|
||||
data["visual"] = visual
|
||||
|
||||
if visual is not None and personality is not None and "visual_style" in personality:
|
||||
if "visual_style" not in visual:
|
||||
visual["visual_style"] = personality["visual_style"]
|
||||
@@ -289,15 +293,6 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append("personality.visual_style_moved_to_visual.visual_style")
|
||||
|
||||
if visual is not None and "multimodal_planner" in visual and "planner_mode" not in visual:
|
||||
multimodal_planner = visual.pop("multimodal_planner")
|
||||
if isinstance(multimodal_planner, bool):
|
||||
visual["planner_mode"] = "multimodal" if multimodal_planner else "text"
|
||||
migrated_any = True
|
||||
reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode")
|
||||
else:
|
||||
visual["multimodal_planner"] = multimodal_planner
|
||||
|
||||
memory = _as_dict(data.get("memory"))
|
||||
if memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
|
||||
@@ -145,23 +145,23 @@ class VisualConfig(ConfigBase):
|
||||
__ui_label__ = "视觉"
|
||||
__ui_icon__ = "image"
|
||||
|
||||
multimodal_planner: bool = Field(
|
||||
default=True,
|
||||
planner_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="text",
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-widget": "select",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否直接输入图片"""
|
||||
"""Planner 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择"""
|
||||
|
||||
multimodal_replyer: bool = Field(
|
||||
default=False,
|
||||
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="text",
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""是否启用 Maisaka 多模态 replyer 生成器"""
|
||||
"""Replyer 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择"""
|
||||
|
||||
visual_style: str = Field(
|
||||
default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本",
|
||||
@@ -424,6 +424,36 @@ class MemoryConfig(ConfigBase):
|
||||
)
|
||||
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
|
||||
|
||||
chat_summary_writeback_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "scroll-text",
|
||||
},
|
||||
)
|
||||
"""是否在 Maisaka 聊天过程中按消息窗口自动写回聊天摘要到长期记忆"""
|
||||
|
||||
chat_summary_writeback_message_threshold: int = Field(
|
||||
default=12,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "messages-square",
|
||||
},
|
||||
)
|
||||
"""自动写回聊天摘要的消息窗口阈值"""
|
||||
|
||||
chat_summary_writeback_context_length: int = Field(
|
||||
default=50,
|
||||
ge=1,
|
||||
le=500,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "rows-3",
|
||||
},
|
||||
)
|
||||
"""自动写回聊天摘要时,从聊天流中回看的消息条数"""
|
||||
|
||||
feedback_correction_enabled: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, List, Optional
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
|
||||
from src.services.memory_service import memory_service
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("memory_flow_service")
|
||||
@@ -210,27 +215,192 @@ class PersonFactWritebackService:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSummaryWritebackState:
|
||||
last_trigger_message_count: int = 0
|
||||
last_trigger_time: float = 0.0
|
||||
|
||||
|
||||
class ChatSummaryWritebackService:
|
||||
def __init__(self) -> None:
|
||||
self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
|
||||
self._worker_task: Optional[asyncio.Task] = None
|
||||
self._stopping = False
|
||||
self._states: dict[str, ChatSummaryWritebackState] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._worker_task is not None and not self._worker_task.done():
|
||||
return
|
||||
self._stopping = False
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_chat_summary_writeback")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self._stopping = True
|
||||
worker = self._worker_task
|
||||
self._worker_task = None
|
||||
if worker is None:
|
||||
return
|
||||
worker.cancel()
|
||||
try:
|
||||
await worker
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("关闭聊天摘要写回 worker 失败: %s", exc)
|
||||
|
||||
async def enqueue(self, message: Any) -> None:
|
||||
if not bool(getattr(global_config.memory, "chat_summary_writeback_enabled", True)):
|
||||
return
|
||||
if self._stopping:
|
||||
return
|
||||
try:
|
||||
self._queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("聊天摘要写回队列已满,跳过本次触发")
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
try:
|
||||
while not self._stopping:
|
||||
message = await self._queue.get()
|
||||
try:
|
||||
await self._handle_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning("聊天摘要写回处理失败: %s", exc, exc_info=True)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
async def _handle_message(self, message: Any) -> None:
|
||||
session_id = self._resolve_session_id(message)
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
total_message_count = count_messages(session_id=session_id)
|
||||
if total_message_count <= 0:
|
||||
return
|
||||
|
||||
threshold = self._message_threshold()
|
||||
state = self._states.setdefault(session_id, ChatSummaryWritebackState())
|
||||
pending_message_count = max(0, total_message_count - state.last_trigger_message_count)
|
||||
if pending_message_count < threshold:
|
||||
return
|
||||
|
||||
context_length = self._context_length()
|
||||
message_time = self._extract_message_timestamp(message)
|
||||
result = await memory_service.ingest_summary(
|
||||
external_id=f"chat_auto_summary:{session_id}:{total_message_count}",
|
||||
chat_id=session_id,
|
||||
text="",
|
||||
participants=[],
|
||||
time_end=message_time,
|
||||
metadata={
|
||||
"generate_from_chat": True,
|
||||
"context_length": context_length,
|
||||
"writeback_source": "memory_flow_service",
|
||||
"trigger": "message_threshold",
|
||||
"trigger_message_count": total_message_count,
|
||||
},
|
||||
respect_filter=True,
|
||||
user_id=self._extract_session_user_id(message),
|
||||
group_id=self._extract_session_group_id(message),
|
||||
)
|
||||
if not getattr(result, "success", False):
|
||||
logger.warning(
|
||||
"聊天摘要自动写回失败: session_id=%s detail=%s",
|
||||
session_id,
|
||||
getattr(result, "detail", ""),
|
||||
)
|
||||
return
|
||||
|
||||
state.last_trigger_message_count = total_message_count
|
||||
state.last_trigger_time = time.time()
|
||||
logger.info(
|
||||
"聊天摘要自动写回成功: session_id=%s trigger=%s total_messages=%s context_length=%s detail=%s",
|
||||
session_id,
|
||||
"message_threshold",
|
||||
total_message_count,
|
||||
context_length,
|
||||
getattr(result, "detail", ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_session_id(message: Any) -> str:
|
||||
return str(
|
||||
getattr(message, "session_id", "")
|
||||
or getattr(getattr(message, "session", None), "session_id", "")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_user_id(message: Any) -> str:
|
||||
return str(
|
||||
getattr(getattr(message, "session", None), "user_id", "")
|
||||
or getattr(message, "user_id", "")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_group_id(message: Any) -> str:
|
||||
return str(
|
||||
getattr(getattr(message, "session", None), "group_id", "")
|
||||
or getattr(message, "group_id", "")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_timestamp(message: Any) -> float | None:
|
||||
raw_timestamp = getattr(message, "timestamp", None)
|
||||
if isinstance(raw_timestamp, datetime):
|
||||
return raw_timestamp.timestamp()
|
||||
if hasattr(raw_timestamp, "timestamp") and callable(raw_timestamp.timestamp):
|
||||
try:
|
||||
return float(raw_timestamp.timestamp())
|
||||
except Exception:
|
||||
return None
|
||||
if isinstance(raw_timestamp, (int, float)):
|
||||
return float(raw_timestamp)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _message_threshold() -> int:
|
||||
return max(1, int(getattr(global_config.memory, "chat_summary_writeback_message_threshold", 12) or 12))
|
||||
|
||||
@staticmethod
|
||||
def _context_length() -> int:
|
||||
return max(1, int(getattr(global_config.memory, "chat_summary_writeback_context_length", 50) or 50))
|
||||
|
||||
|
||||
class MemoryAutomationService:
|
||||
def __init__(self) -> None:
|
||||
self.fact_writeback = PersonFactWritebackService()
|
||||
self.chat_summary_writeback = ChatSummaryWritebackService()
|
||||
self._started = False
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._started:
|
||||
return
|
||||
await self.fact_writeback.start()
|
||||
await self.chat_summary_writeback.start()
|
||||
self._started = True
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._started:
|
||||
return
|
||||
await self.chat_summary_writeback.shutdown()
|
||||
await self.fact_writeback.shutdown()
|
||||
self._started = False
|
||||
|
||||
async def on_incoming_message(self, message: Any) -> None:
|
||||
del message
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
async def on_message_sent(self, message: Any) -> None:
|
||||
if not self._started:
|
||||
await self.start()
|
||||
await self.fact_writeback.enqueue(message)
|
||||
await self.chat_summary_writeback.enqueue(message)
|
||||
|
||||
|
||||
memory_automation_service = MemoryAutomationService()
|
||||
|
||||
Reference in New Issue
Block a user