fix:完善Maisaka记忆写回链路
补齐聊天摘要自动写回、发送后同步与图存储清理逻辑,对齐 visual 新配置字段并补充相关回归测试,同时忽略 algorithm_redesign 设计目录。
This commit is contained in:
@@ -802,6 +802,7 @@ class SDKMemoryKernel:
|
||||
chat_id: str,
|
||||
context_length: Optional[int] = None,
|
||||
include_personality: Optional[bool] = None,
|
||||
time_end: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
await self.initialize()
|
||||
assert self.summary_importer
|
||||
@@ -809,6 +810,7 @@ class SDKMemoryKernel:
|
||||
stream_id=str(chat_id or "").strip(),
|
||||
context_length=context_length,
|
||||
include_personality=include_personality,
|
||||
time_end=time_end,
|
||||
)
|
||||
if success:
|
||||
await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])])
|
||||
@@ -851,6 +853,7 @@ class SDKMemoryKernel:
|
||||
chat_id=chat_id,
|
||||
context_length=self._optional_int(summary_meta.get("context_length")),
|
||||
include_personality=summary_meta.get("include_personality"),
|
||||
time_end=time_end,
|
||||
)
|
||||
result.setdefault("external_id", external_id)
|
||||
result.setdefault("chat_id", chat_id)
|
||||
|
||||
@@ -1190,11 +1190,14 @@ class GraphStore:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存邻接矩阵
|
||||
matrix_path = data_dir / "graph_adjacency.npz"
|
||||
if self._adjacency is not None:
|
||||
matrix_path = data_dir / "graph_adjacency.npz"
|
||||
with atomic_write(matrix_path, "wb") as f:
|
||||
save_npz(f, self._adjacency)
|
||||
logger.debug(f"保存邻接矩阵: {matrix_path}")
|
||||
elif matrix_path.exists():
|
||||
matrix_path.unlink()
|
||||
logger.debug(f"删除陈旧邻接矩阵: {matrix_path}")
|
||||
|
||||
# 保存元数据
|
||||
metadata = {
|
||||
@@ -1288,9 +1291,20 @@ class GraphStore:
|
||||
if self._adjacency is not None:
|
||||
adj_n = self._adjacency.shape[0]
|
||||
current_n = len(self._nodes)
|
||||
if current_n > adj_n:
|
||||
if current_n == 0:
|
||||
logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。")
|
||||
self._adjacency = None
|
||||
elif current_n > adj_n:
|
||||
logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...")
|
||||
self._expand_adjacency_matrix(current_n - adj_n)
|
||||
elif current_n < adj_n:
|
||||
logger.warning(
|
||||
f"检测到过期邻接矩阵: 节点数={current_n}, 矩阵大小={adj_n}. 正在重置邻接矩阵..."
|
||||
)
|
||||
if self.matrix_format == "csc":
|
||||
self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32)
|
||||
else:
|
||||
self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32)
|
||||
|
||||
self._adjacency_dirty = True
|
||||
logger.info(
|
||||
@@ -1445,4 +1459,3 @@ class GraphStore:
|
||||
self._adjacency_dirty = True
|
||||
logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边")
|
||||
return count
|
||||
|
||||
|
||||
@@ -5,12 +5,13 @@
|
||||
导入到 A_memorix 的存储组件中。
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services import llm_service as llm_api
|
||||
@@ -222,7 +223,8 @@ class SummaryImporter:
|
||||
self,
|
||||
stream_id: str,
|
||||
context_length: Optional[int] = None,
|
||||
include_personality: Optional[bool] = None
|
||||
include_personality: Optional[bool] = None,
|
||||
time_end: Optional[float] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
从指定的聊天流中提取记录并执行总结导入
|
||||
@@ -231,6 +233,7 @@ class SummaryImporter:
|
||||
stream_id: 聊天流 ID
|
||||
context_length: 总结的历史消息条数
|
||||
include_personality: 是否包含人设
|
||||
time_end: 用于截取聊天记录的时间上界(闭区间)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 结果消息)
|
||||
@@ -248,12 +251,13 @@ class SummaryImporter:
|
||||
include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True)
|
||||
|
||||
# 2. 获取历史消息
|
||||
# 获取当前时间之前的消息
|
||||
now = time.time()
|
||||
messages = message_api.get_messages_before_time_in_chat(
|
||||
query_time_end = time.time() if time_end is None else float(time_end)
|
||||
messages = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=now,
|
||||
limit=context_length
|
||||
start_time=0.0,
|
||||
end_time=query_time_end,
|
||||
limit=context_length,
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
if not messages:
|
||||
|
||||
@@ -282,6 +282,10 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
|
||||
personality = _as_dict(data.get("personality"))
|
||||
visual = _as_dict(data.get("visual"))
|
||||
if visual is None and personality is not None and "visual_style" in personality:
|
||||
visual = {}
|
||||
data["visual"] = visual
|
||||
|
||||
if visual is not None and personality is not None and "visual_style" in personality:
|
||||
if "visual_style" not in visual:
|
||||
visual["visual_style"] = personality["visual_style"]
|
||||
@@ -289,15 +293,6 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
migrated_any = True
|
||||
reasons.append("personality.visual_style_moved_to_visual.visual_style")
|
||||
|
||||
if visual is not None and "multimodal_planner" in visual and "planner_mode" not in visual:
|
||||
multimodal_planner = visual.pop("multimodal_planner")
|
||||
if isinstance(multimodal_planner, bool):
|
||||
visual["planner_mode"] = "multimodal" if multimodal_planner else "text"
|
||||
migrated_any = True
|
||||
reasons.append("visual.multimodal_planner_moved_to_visual.planner_mode")
|
||||
else:
|
||||
visual["multimodal_planner"] = multimodal_planner
|
||||
|
||||
memory = _as_dict(data.get("memory"))
|
||||
if memory is not None and _migrate_target_item_list(memory, "global_memory_blacklist"):
|
||||
migrated_any = True
|
||||
|
||||
@@ -145,23 +145,23 @@ class VisualConfig(ConfigBase):
|
||||
__ui_label__ = "视觉"
|
||||
__ui_icon__ = "image"
|
||||
|
||||
multimodal_planner: bool = Field(
|
||||
default=True,
|
||||
planner_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="text",
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-widget": "select",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""是否直接输入图片"""
|
||||
"""Planner 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择"""
|
||||
|
||||
multimodal_replyer: bool = Field(
|
||||
default=False,
|
||||
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||
default="text",
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
},
|
||||
)
|
||||
"""是否启用 Maisaka 多模态 replyer 生成器"""
|
||||
"""Replyer 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择"""
|
||||
|
||||
visual_style: str = Field(
|
||||
default="请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多30字,请注意不要分点,就输出一段文本",
|
||||
@@ -424,6 +424,36 @@ class MemoryConfig(ConfigBase):
|
||||
)
|
||||
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
|
||||
|
||||
chat_summary_writeback_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "scroll-text",
|
||||
},
|
||||
)
|
||||
"""是否在 Maisaka 聊天过程中按消息窗口自动写回聊天摘要到长期记忆"""
|
||||
|
||||
chat_summary_writeback_message_threshold: int = Field(
|
||||
default=12,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "messages-square",
|
||||
},
|
||||
)
|
||||
"""自动写回聊天摘要的消息窗口阈值"""
|
||||
|
||||
chat_summary_writeback_context_length: int = Field(
|
||||
default=50,
|
||||
ge=1,
|
||||
le=500,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "rows-3",
|
||||
},
|
||||
)
|
||||
"""自动写回聊天摘要时,从聊天流中回看的消息条数"""
|
||||
|
||||
feedback_correction_enabled: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, List, Optional
|
||||
import time
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
|
||||
from src.services.memory_service import memory_service
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("memory_flow_service")
|
||||
@@ -210,27 +215,192 @@ class PersonFactWritebackService:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
class ChatSummaryWritebackState:
    """Per-session bookkeeping for automatic chat-summary write-back.

    Attributes:
        last_trigger_message_count: Session message count recorded at the
            most recent successful write-back (0 if none has happened yet).
        last_trigger_time: Wall-clock time (``time.time()``) of the most
            recent successful write-back (0.0 if none has happened yet).
    """

    # Baseline used to work out how many messages arrived since the last write-back.
    last_trigger_message_count: int = 0

    # Epoch seconds of the last successful write-back.
    last_trigger_time: float = 0.0
|
||||
|
||||
|
||||
class ChatSummaryWritebackService:
    """Queue-backed worker that writes chat summaries to long-term memory.

    Messages passed to :meth:`enqueue` are put on a bounded queue and consumed
    by a single asyncio task. For each message the worker counts the messages
    stored for the session and, once at least the configured threshold has
    accumulated since the last successful write-back, calls
    ``memory_service.ingest_summary`` with ``generate_from_chat`` metadata.

    Per-session progress is kept only in ``self._states`` (in memory), so it
    starts from zero again whenever a new instance is created.
    """

    # Fallbacks used when the corresponding config value is missing or invalid.
    _DEFAULT_MESSAGE_THRESHOLD = 12
    _DEFAULT_CONTEXT_LENGTH = 50

    def __init__(self) -> None:
        # Bounded so a stalled worker cannot make the backlog grow without limit.
        self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
        self._worker_task: Optional[asyncio.Task] = None
        self._stopping = False
        # session_id -> progress since the last successful write-back.
        self._states: dict[str, ChatSummaryWritebackState] = {}

    async def start(self) -> None:
        """Start the background worker unless one is already running."""
        if self._worker_task is not None and not self._worker_task.done():
            return
        self._stopping = False
        self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_chat_summary_writeback")

    async def shutdown(self) -> None:
        """Cancel the background worker and wait for it to finish.

        Errors raised by the worker while it winds down are logged, not
        propagated.
        """
        self._stopping = True
        worker = self._worker_task
        self._worker_task = None
        if worker is None:
            return
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            logger.warning("关闭聊天摘要写回 worker 失败: %s", exc)

    async def enqueue(self, message: Any) -> None:
        """Queue ``message`` for processing.

        The message is dropped when the feature is disabled in config, when
        the service is stopping, or (with a warning) when the queue is full.
        """
        if not bool(getattr(global_config.memory, "chat_summary_writeback_enabled", True)):
            return
        if self._stopping:
            return
        try:
            self._queue.put_nowait(message)
        except asyncio.QueueFull:
            logger.warning("聊天摘要写回队列已满,跳过本次触发")

    async def _worker_loop(self) -> None:
        """Consume queued messages until stopped or cancelled.

        A failure while handling one message is logged and does not stop the
        loop. Cancellation propagates to the awaiting caller on its own, so no
        explicit re-raise is needed.
        """
        while not self._stopping:
            message = await self._queue.get()
            try:
                await self._handle_message(message)
            except Exception as exc:
                logger.warning("聊天摘要写回处理失败: %s", exc, exc_info=True)
            finally:
                self._queue.task_done()

    async def _handle_message(self, message: Any) -> None:
        """Trigger a summary write-back for the message's session when due."""
        session_id = self._resolve_session_id(message)
        if not session_id:
            return

        total_message_count = count_messages(session_id=session_id)
        if total_message_count <= 0:
            return

        threshold = self._message_threshold()
        state = self._states.setdefault(session_id, ChatSummaryWritebackState())
        if total_message_count < state.last_trigger_message_count:
            # The stored count went down (e.g. history was pruned). Rebase the
            # baseline so new messages count again; otherwise no trigger could
            # fire until the total climbed back above the stale baseline.
            state.last_trigger_message_count = total_message_count
        pending_message_count = total_message_count - state.last_trigger_message_count
        if pending_message_count < threshold:
            return

        context_length = self._context_length()
        message_time = self._extract_message_timestamp(message)
        result = await memory_service.ingest_summary(
            # The message count is part of the id, so each trigger point gets its own id.
            external_id=f"chat_auto_summary:{session_id}:{total_message_count}",
            chat_id=session_id,
            text="",
            participants=[],
            time_end=message_time,
            metadata={
                "generate_from_chat": True,
                "context_length": context_length,
                "writeback_source": "memory_flow_service",
                "trigger": "message_threshold",
                "trigger_message_count": total_message_count,
            },
            respect_filter=True,
            user_id=self._extract_session_user_id(message),
            group_id=self._extract_session_group_id(message),
        )
        if not getattr(result, "success", False):
            # Leave the baseline untouched so the next message retries.
            logger.warning(
                "聊天摘要自动写回失败: session_id=%s detail=%s",
                session_id,
                getattr(result, "detail", ""),
            )
            return

        state.last_trigger_message_count = total_message_count
        state.last_trigger_time = time.time()
        logger.info(
            "聊天摘要自动写回成功: session_id=%s trigger=%s total_messages=%s context_length=%s detail=%s",
            session_id,
            "message_threshold",
            total_message_count,
            context_length,
            getattr(result, "detail", ""),
        )

    @staticmethod
    def _resolve_session_id(message: Any) -> str:
        """Return the session id from ``message`` or ``message.session`` ("" if absent)."""
        return str(
            getattr(message, "session_id", "")
            or getattr(getattr(message, "session", None), "session_id", "")
            or ""
        ).strip()

    @staticmethod
    def _extract_session_user_id(message: Any) -> str:
        """Return the user id from ``message.session`` or ``message`` ("" if absent)."""
        return str(
            getattr(getattr(message, "session", None), "user_id", "")
            or getattr(message, "user_id", "")
            or ""
        ).strip()

    @staticmethod
    def _extract_session_group_id(message: Any) -> str:
        """Return the group id from ``message.session`` or ``message`` ("" if absent)."""
        return str(
            getattr(getattr(message, "session", None), "group_id", "")
            or getattr(message, "group_id", "")
            or ""
        ).strip()

    @staticmethod
    def _extract_message_timestamp(message: Any) -> Optional[float]:
        """Return ``message.timestamp`` as epoch seconds, or ``None`` if unusable.

        Accepts real numbers and objects exposing a callable ``timestamp()``
        (such as ``datetime``). Booleans, non-finite values and anything that
        fails to convert yield ``None``.
        """
        raw_timestamp = getattr(message, "timestamp", None)
        # bool is a subclass of int; a True/False flag is not a point in time.
        if isinstance(raw_timestamp, bool):
            return None
        try:
            if isinstance(raw_timestamp, (int, float)):
                value = float(raw_timestamp)
            elif callable(getattr(raw_timestamp, "timestamp", None)):
                value = float(raw_timestamp.timestamp())
            else:
                return None
        except Exception:
            # Out-of-range datetimes, huge ints or odd duck-typed objects.
            return None
        # NaN and +/-inf both fail this chained comparison.
        if not float("-inf") < value < float("inf"):
            return None
        return value

    @staticmethod
    def _positive_int_setting(name: str, default: int) -> int:
        """Read ``global_config.memory.<name>`` as an int clamped to at least 1.

        Missing or falsy values use ``default``; values that cannot be
        converted to ``int`` are reported and also fall back to ``default``.
        """
        raw_value = getattr(global_config.memory, name, default)
        try:
            parsed = int(raw_value or default)
        except (TypeError, ValueError):
            logger.warning("聊天摘要写回配置项 %s 的值 %r 无效,回退到默认值 %s", name, raw_value, default)
            return default
        return max(1, parsed)

    @staticmethod
    def _message_threshold() -> int:
        """Return how many new messages must accumulate before a write-back."""
        return ChatSummaryWritebackService._positive_int_setting(
            "chat_summary_writeback_message_threshold",
            ChatSummaryWritebackService._DEFAULT_MESSAGE_THRESHOLD,
        )

    @staticmethod
    def _context_length() -> int:
        """Return the ``context_length`` value sent with each summary request."""
        return ChatSummaryWritebackService._positive_int_setting(
            "chat_summary_writeback_context_length",
            ChatSummaryWritebackService._DEFAULT_CONTEXT_LENGTH,
        )
|
||||
|
||||
|
||||
class MemoryAutomationService:
    """Owns the memory write-back workers and routes message events to them.

    Both workers are started on first use, either by :meth:`start` or by the
    first incoming or sent message.
    """

    def __init__(self) -> None:
        self.fact_writeback = PersonFactWritebackService()
        self.chat_summary_writeback = ChatSummaryWritebackService()
        self._started = False

    async def start(self) -> None:
        """Start both write-back workers; does nothing if already started."""
        if self._started:
            return
        await self.fact_writeback.start()
        await self.chat_summary_writeback.start()
        self._started = True

    async def shutdown(self) -> None:
        """Shut down both workers in reverse start order.

        The fact write-back worker is shut down and the started flag is
        cleared even when an earlier shutdown step raises, so a failure in one
        worker cannot leave the other running or block a later restart. The
        exception still propagates to the caller.
        """
        if not self._started:
            return
        try:
            await self.chat_summary_writeback.shutdown()
        finally:
            try:
                await self.fact_writeback.shutdown()
            finally:
                self._started = False

    async def on_incoming_message(self, message: Any) -> None:
        """Make sure the workers are running; the message itself is not used."""
        del message
        if not self._started:
            await self.start()

    async def on_message_sent(self, message: Any) -> None:
        """Hand a sent message to both write-back workers, starting them if needed."""
        if not self._started:
            await self.start()
        await self.fact_writeback.enqueue(message)
        await self.chat_summary_writeback.enqueue(message)
|
||||
|
||||
|
||||
# Module-level MemoryAutomationService instance created at import time.
memory_automation_service = MemoryAutomationService()
|
||||
|
||||
Reference in New Issue
Block a user