fix:完善Maisaka记忆写回链路

补齐聊天摘要自动写回、发送后同步与图存储清理逻辑,对齐 visual 新配置字段并补充相关回归测试,同时忽略 algorithm_redesign 设计目录。
This commit is contained in:
A-Dawn
2026-04-16 19:04:08 +08:00
parent 7ed5630583
commit 459927e7c0
13 changed files with 918 additions and 65 deletions

View File

@@ -1,16 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Any, List, Optional
import asyncio
import json
from typing import Any, List, Optional
import time
from json_repair import repair_json
from src.chat.utils.utils import is_bot_self
from src.common.logger import get_logger
from src.common.message_repository import find_messages
from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
from src.services.memory_service import memory_service
from src.services.llm_service import LLMServiceClient
logger = get_logger("memory_flow_service")
@@ -210,27 +215,192 @@ class PersonFactWritebackService:
return False
@dataclass
class ChatSummaryWritebackState:
last_trigger_message_count: int = 0
last_trigger_time: float = 0.0
class ChatSummaryWritebackService:
def __init__(self) -> None:
self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
self._worker_task: Optional[asyncio.Task] = None
self._stopping = False
self._states: dict[str, ChatSummaryWritebackState] = {}
async def start(self) -> None:
if self._worker_task is not None and not self._worker_task.done():
return
self._stopping = False
self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_chat_summary_writeback")
async def shutdown(self) -> None:
self._stopping = True
worker = self._worker_task
self._worker_task = None
if worker is None:
return
worker.cancel()
try:
await worker
except asyncio.CancelledError:
pass
except Exception as exc:
logger.warning("关闭聊天摘要写回 worker 失败: %s", exc)
async def enqueue(self, message: Any) -> None:
if not bool(getattr(global_config.memory, "chat_summary_writeback_enabled", True)):
return
if self._stopping:
return
try:
self._queue.put_nowait(message)
except asyncio.QueueFull:
logger.warning("聊天摘要写回队列已满,跳过本次触发")
async def _worker_loop(self) -> None:
try:
while not self._stopping:
message = await self._queue.get()
try:
await self._handle_message(message)
except Exception as exc:
logger.warning("聊天摘要写回处理失败: %s", exc, exc_info=True)
finally:
self._queue.task_done()
except asyncio.CancelledError:
raise
async def _handle_message(self, message: Any) -> None:
session_id = self._resolve_session_id(message)
if not session_id:
return
total_message_count = count_messages(session_id=session_id)
if total_message_count <= 0:
return
threshold = self._message_threshold()
state = self._states.setdefault(session_id, ChatSummaryWritebackState())
pending_message_count = max(0, total_message_count - state.last_trigger_message_count)
if pending_message_count < threshold:
return
context_length = self._context_length()
message_time = self._extract_message_timestamp(message)
result = await memory_service.ingest_summary(
external_id=f"chat_auto_summary:{session_id}:{total_message_count}",
chat_id=session_id,
text="",
participants=[],
time_end=message_time,
metadata={
"generate_from_chat": True,
"context_length": context_length,
"writeback_source": "memory_flow_service",
"trigger": "message_threshold",
"trigger_message_count": total_message_count,
},
respect_filter=True,
user_id=self._extract_session_user_id(message),
group_id=self._extract_session_group_id(message),
)
if not getattr(result, "success", False):
logger.warning(
"聊天摘要自动写回失败: session_id=%s detail=%s",
session_id,
getattr(result, "detail", ""),
)
return
state.last_trigger_message_count = total_message_count
state.last_trigger_time = time.time()
logger.info(
"聊天摘要自动写回成功: session_id=%s trigger=%s total_messages=%s context_length=%s detail=%s",
session_id,
"message_threshold",
total_message_count,
context_length,
getattr(result, "detail", ""),
)
@staticmethod
def _resolve_session_id(message: Any) -> str:
return str(
getattr(message, "session_id", "")
or getattr(getattr(message, "session", None), "session_id", "")
or ""
).strip()
@staticmethod
def _extract_session_user_id(message: Any) -> str:
return str(
getattr(getattr(message, "session", None), "user_id", "")
or getattr(message, "user_id", "")
or ""
).strip()
@staticmethod
def _extract_session_group_id(message: Any) -> str:
return str(
getattr(getattr(message, "session", None), "group_id", "")
or getattr(message, "group_id", "")
or ""
).strip()
@staticmethod
def _extract_message_timestamp(message: Any) -> float | None:
raw_timestamp = getattr(message, "timestamp", None)
if isinstance(raw_timestamp, datetime):
return raw_timestamp.timestamp()
if hasattr(raw_timestamp, "timestamp") and callable(raw_timestamp.timestamp):
try:
return float(raw_timestamp.timestamp())
except Exception:
return None
if isinstance(raw_timestamp, (int, float)):
return float(raw_timestamp)
return None
@staticmethod
def _message_threshold() -> int:
return max(1, int(getattr(global_config.memory, "chat_summary_writeback_message_threshold", 12) or 12))
@staticmethod
def _context_length() -> int:
return max(1, int(getattr(global_config.memory, "chat_summary_writeback_context_length", 50) or 50))
class MemoryAutomationService:
def __init__(self) -> None:
self.fact_writeback = PersonFactWritebackService()
self.chat_summary_writeback = ChatSummaryWritebackService()
self._started = False
async def start(self) -> None:
if self._started:
return
await self.fact_writeback.start()
await self.chat_summary_writeback.start()
self._started = True
async def shutdown(self) -> None:
if not self._started:
return
await self.chat_summary_writeback.shutdown()
await self.fact_writeback.shutdown()
self._started = False
async def on_incoming_message(self, message: Any) -> None:
del message
if not self._started:
await self.start()
async def on_message_sent(self, message: Any) -> None:
if not self._started:
await self.start()
await self.fact_writeback.enqueue(message)
await self.chat_summary_writeback.enqueue(message)
memory_automation_service = MemoryAutomationService()