fix:完善Maisaka记忆写回链路
补齐聊天摘要自动写回、发送后同步与图存储清理逻辑,对齐 visual 新配置字段并补充相关回归测试,同时忽略 algorithm_redesign 设计目录。
This commit is contained in:
@@ -1,16 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, List, Optional
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
|
||||
from src.services.memory_service import memory_service
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("memory_flow_service")
|
||||
@@ -210,27 +215,192 @@ class PersonFactWritebackService:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatSummaryWritebackState:
|
||||
last_trigger_message_count: int = 0
|
||||
last_trigger_time: float = 0.0
|
||||
|
||||
|
||||
class ChatSummaryWritebackService:
|
||||
def __init__(self) -> None:
|
||||
self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
|
||||
self._worker_task: Optional[asyncio.Task] = None
|
||||
self._stopping = False
|
||||
self._states: dict[str, ChatSummaryWritebackState] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._worker_task is not None and not self._worker_task.done():
|
||||
return
|
||||
self._stopping = False
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_chat_summary_writeback")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self._stopping = True
|
||||
worker = self._worker_task
|
||||
self._worker_task = None
|
||||
if worker is None:
|
||||
return
|
||||
worker.cancel()
|
||||
try:
|
||||
await worker
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("关闭聊天摘要写回 worker 失败: %s", exc)
|
||||
|
||||
async def enqueue(self, message: Any) -> None:
|
||||
if not bool(getattr(global_config.memory, "chat_summary_writeback_enabled", True)):
|
||||
return
|
||||
if self._stopping:
|
||||
return
|
||||
try:
|
||||
self._queue.put_nowait(message)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("聊天摘要写回队列已满,跳过本次触发")
|
||||
|
||||
async def _worker_loop(self) -> None:
|
||||
try:
|
||||
while not self._stopping:
|
||||
message = await self._queue.get()
|
||||
try:
|
||||
await self._handle_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning("聊天摘要写回处理失败: %s", exc, exc_info=True)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
||||
async def _handle_message(self, message: Any) -> None:
|
||||
session_id = self._resolve_session_id(message)
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
total_message_count = count_messages(session_id=session_id)
|
||||
if total_message_count <= 0:
|
||||
return
|
||||
|
||||
threshold = self._message_threshold()
|
||||
state = self._states.setdefault(session_id, ChatSummaryWritebackState())
|
||||
pending_message_count = max(0, total_message_count - state.last_trigger_message_count)
|
||||
if pending_message_count < threshold:
|
||||
return
|
||||
|
||||
context_length = self._context_length()
|
||||
message_time = self._extract_message_timestamp(message)
|
||||
result = await memory_service.ingest_summary(
|
||||
external_id=f"chat_auto_summary:{session_id}:{total_message_count}",
|
||||
chat_id=session_id,
|
||||
text="",
|
||||
participants=[],
|
||||
time_end=message_time,
|
||||
metadata={
|
||||
"generate_from_chat": True,
|
||||
"context_length": context_length,
|
||||
"writeback_source": "memory_flow_service",
|
||||
"trigger": "message_threshold",
|
||||
"trigger_message_count": total_message_count,
|
||||
},
|
||||
respect_filter=True,
|
||||
user_id=self._extract_session_user_id(message),
|
||||
group_id=self._extract_session_group_id(message),
|
||||
)
|
||||
if not getattr(result, "success", False):
|
||||
logger.warning(
|
||||
"聊天摘要自动写回失败: session_id=%s detail=%s",
|
||||
session_id,
|
||||
getattr(result, "detail", ""),
|
||||
)
|
||||
return
|
||||
|
||||
state.last_trigger_message_count = total_message_count
|
||||
state.last_trigger_time = time.time()
|
||||
logger.info(
|
||||
"聊天摘要自动写回成功: session_id=%s trigger=%s total_messages=%s context_length=%s detail=%s",
|
||||
session_id,
|
||||
"message_threshold",
|
||||
total_message_count,
|
||||
context_length,
|
||||
getattr(result, "detail", ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_session_id(message: Any) -> str:
|
||||
return str(
|
||||
getattr(message, "session_id", "")
|
||||
or getattr(getattr(message, "session", None), "session_id", "")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_user_id(message: Any) -> str:
|
||||
return str(
|
||||
getattr(getattr(message, "session", None), "user_id", "")
|
||||
or getattr(message, "user_id", "")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_session_group_id(message: Any) -> str:
|
||||
return str(
|
||||
getattr(getattr(message, "session", None), "group_id", "")
|
||||
or getattr(message, "group_id", "")
|
||||
or ""
|
||||
).strip()
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_timestamp(message: Any) -> float | None:
|
||||
raw_timestamp = getattr(message, "timestamp", None)
|
||||
if isinstance(raw_timestamp, datetime):
|
||||
return raw_timestamp.timestamp()
|
||||
if hasattr(raw_timestamp, "timestamp") and callable(raw_timestamp.timestamp):
|
||||
try:
|
||||
return float(raw_timestamp.timestamp())
|
||||
except Exception:
|
||||
return None
|
||||
if isinstance(raw_timestamp, (int, float)):
|
||||
return float(raw_timestamp)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _message_threshold() -> int:
|
||||
return max(1, int(getattr(global_config.memory, "chat_summary_writeback_message_threshold", 12) or 12))
|
||||
|
||||
@staticmethod
|
||||
def _context_length() -> int:
|
||||
return max(1, int(getattr(global_config.memory, "chat_summary_writeback_context_length", 50) or 50))
|
||||
|
||||
|
||||
class MemoryAutomationService:
|
||||
def __init__(self) -> None:
|
||||
self.fact_writeback = PersonFactWritebackService()
|
||||
self.chat_summary_writeback = ChatSummaryWritebackService()
|
||||
self._started = False
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._started:
|
||||
return
|
||||
await self.fact_writeback.start()
|
||||
await self.chat_summary_writeback.start()
|
||||
self._started = True
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._started:
|
||||
return
|
||||
await self.chat_summary_writeback.shutdown()
|
||||
await self.fact_writeback.shutdown()
|
||||
self._started = False
|
||||
|
||||
async def on_incoming_message(self, message: Any) -> None:
|
||||
del message
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
async def on_message_sent(self, message: Any) -> None:
|
||||
if not self._started:
|
||||
await self.start()
|
||||
await self.fact_writeback.enqueue(message)
|
||||
await self.chat_summary_writeback.enqueue(message)
|
||||
|
||||
|
||||
memory_automation_service = MemoryAutomationService()
|
||||
|
||||
Reference in New Issue
Block a user