feat: 添加记忆自动化钩子与回写
在接收和发送消息时注册记忆自动化,并重构人物记忆回写逻辑以使用 memory_service.ingest_text。主要改动如下: 在接收消息时调用 memory_automation_service.on_incoming_message(bot 侧),在发送消息时调用 on_message_sent(send_service 侧),并加入安全的错误处理。 在 person_info 中,用 memory_service.ingest_text 替换手动操作 person.memory_points 的方式;新增 resolve_person_id_for_memory 辅助方法,并为回写计算一个 external_id 指纹。 扩展插件运行时的记忆搜索能力,使其支持 mode、chat_id、person_id、user_id、group_id、时间范围以及 respect_filter 选项。 改进 find_messages 的数据库会话处理,改为使用单一 session,并修复排序和过滤逻辑。 从 KnowledgeFetcher 中移除未使用的 LLMRequest 导入和初始化。 更新术语解释器(jargon explainer)的导入路径,使用新的模块位置。 更新 .gitignore 例外规则,允许特定的 pytest 数据文件被纳入版本控制。 文档小调整:明确人物事实提取规则(将直接使用的 “you” 改写为第三人称)。
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -1,4 +1,10 @@
|
|||||||
data/
|
data/
|
||||||
|
!pytests/A_memorix_test/data/
|
||||||
|
!pytests/A_memorix_test/data/benchmarks/
|
||||||
|
!pytests/A_memorix_test/data/benchmarks/long_novel_memory_benchmark.json
|
||||||
|
!pytests/A_memorix_test/data/real_dialogues/
|
||||||
|
!pytests/A_memorix_test/data/real_dialogues/private_alice_weekend.json
|
||||||
|
pytests/A_memorix_test/data/benchmarks/results/
|
||||||
data1/
|
data1/
|
||||||
mongodb/
|
mongodb/
|
||||||
NapCat.Framework.Windows.Once/
|
NapCat.Framework.Windows.Once/
|
||||||
|
|||||||
@@ -3,10 +3,6 @@ from typing import Any, Dict, List, Tuple
|
|||||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
|
|
||||||
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
|
||||||
from src.config.config import model_config
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.person_info.person_info import resolve_person_id_for_memory
|
from src.person_info.person_info import resolve_person_id_for_memory
|
||||||
from src.services.memory_service import memory_service
|
from src.services.memory_service import memory_service
|
||||||
|
|
||||||
@@ -17,7 +13,6 @@ class KnowledgeFetcher:
|
|||||||
"""知识调取器"""
|
"""知识调取器"""
|
||||||
|
|
||||||
def __init__(self, private_name: str, stream_id: str):
|
def __init__(self, private_name: str, stream_id: str):
|
||||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
|
|
||||||
self.private_name = private_name
|
self.private_name = private_name
|
||||||
self.stream_id = stream_id
|
self.stream_id = stream_id
|
||||||
|
|
||||||
|
|||||||
@@ -325,6 +325,13 @@ class ChatBot:
|
|||||||
scope=scope,
|
scope=scope,
|
||||||
) # 确保会话存在
|
) # 确保会话存在
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.services.memory_flow_service import memory_automation_service
|
||||||
|
|
||||||
|
await memory_automation_service.on_incoming_message(message)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"[{session_id}] 长期记忆自动摘要注册失败: {exc}")
|
||||||
|
|
||||||
# message.update_chat_stream(chat)
|
# message.update_chat_stream(chat)
|
||||||
|
|
||||||
# 命令处理 - 使用新插件系统检查并处理命令。
|
# 命令处理 - 使用新插件系统检查并处理命令。
|
||||||
|
|||||||
@@ -189,39 +189,37 @@ def find_messages(
|
|||||||
conditions.append(Messages.is_command == False) # noqa: E712
|
conditions.append(Messages.is_command == False) # noqa: E712
|
||||||
|
|
||||||
statement = select(Messages).where(*conditions)
|
statement = select(Messages).where(*conditions)
|
||||||
if limit > 0:
|
with get_db_session(auto_commit=False) as session:
|
||||||
if limit_mode == "earliest":
|
if limit > 0:
|
||||||
statement = statement.order_by(col(Messages.timestamp)).limit(limit)
|
if limit_mode == "earliest":
|
||||||
with get_db_session() as session:
|
statement = statement.order_by(col(Messages.timestamp)).limit(limit)
|
||||||
results = list(session.exec(statement).all())
|
results = list(session.exec(statement).all())
|
||||||
|
else:
|
||||||
|
statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit)
|
||||||
|
results = list(session.exec(statement).all())
|
||||||
|
results = list(reversed(results))
|
||||||
else:
|
else:
|
||||||
statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit)
|
if sort:
|
||||||
with get_db_session() as session:
|
order_terms: list[Any] = []
|
||||||
results = list(session.exec(statement).all())
|
for field_name, direction in sort:
|
||||||
results = list(reversed(results))
|
sort_field = _resolve_field(field_name)
|
||||||
else:
|
if sort_field is None:
|
||||||
if sort:
|
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||||
order_terms: list[Any] = []
|
continue
|
||||||
for field_name, direction in sort:
|
order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc())
|
||||||
sort_field = _resolve_field(field_name)
|
if order_terms:
|
||||||
if sort_field is None:
|
statement = statement.order_by(*order_terms)
|
||||||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
|
||||||
continue
|
|
||||||
order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc())
|
|
||||||
if order_terms:
|
|
||||||
statement = statement.order_by(*order_terms)
|
|
||||||
with get_db_session() as session:
|
|
||||||
results = list(session.exec(statement).all())
|
results = list(session.exec(statement).all())
|
||||||
|
|
||||||
if filter_intercept_message_level is not None:
|
if filter_intercept_message_level is not None:
|
||||||
filtered_results = []
|
filtered_results = []
|
||||||
for msg in results:
|
for msg in results:
|
||||||
config = _parse_additional_config(msg)
|
config = _parse_additional_config(msg)
|
||||||
if config.get("intercept_message_level", 0) <= filter_intercept_message_level:
|
if config.get("intercept_message_level", 0) <= filter_intercept_message_level:
|
||||||
filtered_results.append(msg)
|
filtered_results.append(msg)
|
||||||
results = filtered_results
|
results = filtered_results
|
||||||
|
|
||||||
return [_message_to_instance(msg) for msg in results]
|
return [_message_to_instance(msg) for msg in results]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = (
|
log_message = (
|
||||||
"使用 SQLModel 查找消息失败 "
|
"使用 SQLModel 查找消息失败 "
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
|||||||
from src.services.llm_service import LLMServiceClient
|
from src.services.llm_service import LLMServiceClient
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.learners.jargon_miner_old import search_jargon
|
from src.learners.jargon_explainer import search_jargon
|
||||||
from src.learners.learner_utils_old import (
|
from src.learners.learner_utils_old import (
|
||||||
is_bot_message,
|
is_bot_message,
|
||||||
contains_bot_self_name,
|
contains_bot_self_name,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from src.common.database.database import get_db_session
|
|||||||
from src.common.database.database_model import PersonInfo
|
from src.common.database.database_model import PersonInfo
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.services.memory_service import memory_service
|
||||||
from src.services.llm_service import LLMServiceClient
|
from src.services.llm_service import LLMServiceClient
|
||||||
|
|
||||||
|
|
||||||
@@ -66,15 +67,45 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
|||||||
def get_person_id_by_person_name(person_name: str) -> str:
|
def get_person_id_by_person_name(person_name: str) -> str:
|
||||||
"""根据用户名获取用户ID"""
|
"""根据用户名获取用户ID"""
|
||||||
try:
|
try:
|
||||||
with get_db_session() as session:
|
with get_db_session(auto_commit=False) as session:
|
||||||
statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
|
statement = select(PersonInfo.person_id).where(col(PersonInfo.person_name) == person_name).limit(1)
|
||||||
record = session.exec(statement).first()
|
person_id = session.exec(statement).first()
|
||||||
return record.person_id if record else ""
|
return str(person_id) if person_id else ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_person_id_for_memory(
|
||||||
|
*,
|
||||||
|
person_name: str = "",
|
||||||
|
platform: str = "",
|
||||||
|
user_id: Union[int, str, None] = None,
|
||||||
|
strict_known: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""解析长期记忆检索/写入使用的人物 ID。
|
||||||
|
|
||||||
|
解析顺序:
|
||||||
|
1. 优先按 `person_name` 映射数据库中的 `person_id`
|
||||||
|
2. 回退到 `platform + user_id` 生成稳定 `person_id`
|
||||||
|
3. 若 `strict_known=True`,则要求该 `person_id` 已被认识
|
||||||
|
"""
|
||||||
|
clean_name = str(person_name or "").strip()
|
||||||
|
if clean_name:
|
||||||
|
if by_name := get_person_id_by_person_name(clean_name):
|
||||||
|
return by_name
|
||||||
|
|
||||||
|
clean_platform = str(platform or "").strip()
|
||||||
|
clean_user_id = str(user_id or "").strip()
|
||||||
|
if clean_platform and clean_user_id:
|
||||||
|
candidate = get_person_id(clean_platform, clean_user_id)
|
||||||
|
if strict_known and not is_person_known(person_id=candidate):
|
||||||
|
return ""
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def is_person_known(
|
def is_person_known(
|
||||||
person_id: Optional[str] = None,
|
person_id: Optional[str] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
@@ -800,75 +831,83 @@ person_info_manager = PersonInfoManager()
|
|||||||
|
|
||||||
|
|
||||||
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
|
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
|
||||||
"""将人物信息存入person_info的memory_points
|
"""将人物事实写入长期记忆系统。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
person_name: 人物名称
|
person_name: 人物名称
|
||||||
memory_content: 记忆内容
|
memory_content: 记忆内容
|
||||||
chat_id: 聊天ID
|
chat_id: 聊天ID
|
||||||
"""
|
"""
|
||||||
|
clean_content = str(memory_content or "").strip()
|
||||||
|
if not clean_content:
|
||||||
|
logger.debug("人物事实写回跳过:memory_content 为空")
|
||||||
|
return
|
||||||
|
|
||||||
|
clean_chat_id = str(chat_id or "").strip()
|
||||||
|
if not clean_chat_id:
|
||||||
|
logger.warning("人物事实写回失败:chat_id 为空")
|
||||||
|
return
|
||||||
|
|
||||||
|
clean_person_name = str(person_name or "").strip()
|
||||||
try:
|
try:
|
||||||
# 从 chat_id 获取 session
|
# 从 chat_id 获取 session
|
||||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
session = _chat_manager.get_session_by_session_id(clean_chat_id)
|
||||||
if not session:
|
if not session:
|
||||||
logger.warning(f"无法获取session for chat_id: {chat_id}")
|
logger.warning(f"无法获取session for chat_id: {clean_chat_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
platform = session.platform
|
session_platform = str(getattr(session, "platform", "") or "").strip()
|
||||||
|
session_user_id = str(getattr(session, "user_id", "") or "").strip()
|
||||||
# 尝试从person_name查找person_id
|
session_group_id = str(getattr(session, "group_id", "") or "").strip()
|
||||||
# 首先尝试通过person_name查找
|
|
||||||
person_id = get_person_id_by_person_name(person_name)
|
|
||||||
|
|
||||||
|
person_id = resolve_person_id_for_memory(
|
||||||
|
person_name=clean_person_name,
|
||||||
|
platform=session_platform,
|
||||||
|
user_id=session_user_id,
|
||||||
|
)
|
||||||
if not person_id:
|
if not person_id:
|
||||||
# 如果通过person_name找不到,尝试从 session 获取 user_id
|
logger.warning(f"无法确定person_id for person_name: {clean_person_name}, chat_id: {clean_chat_id}")
|
||||||
if platform and session.user_id:
|
|
||||||
user_id = session.user_id
|
|
||||||
person_id = get_person_id(platform, user_id)
|
|
||||||
else:
|
|
||||||
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 创建或获取Person对象
|
|
||||||
person = Person(person_id=person_id)
|
|
||||||
|
|
||||||
if not person.is_known:
|
|
||||||
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 确定记忆分类(可以根据memory_content判断,这里使用通用分类)
|
person = Person(person_id=person_id)
|
||||||
category = "其他" # 默认分类,可以根据需要调整
|
if not person.is_known:
|
||||||
|
logger.warning(f"用户 {clean_person_name or person_id} (person_id: {person_id}) 尚未认识,跳过写回")
|
||||||
|
return
|
||||||
|
|
||||||
# 记忆点格式:category:content:weight
|
participant_name = str(getattr(person, "person_name", "") or getattr(person, "nickname", "") or "").strip()
|
||||||
weight = "1.0" # 默认权重
|
if not participant_name:
|
||||||
memory_point = f"{category}:{memory_content}:{weight}"
|
participant_name = clean_person_name or person_id
|
||||||
|
|
||||||
# 添加到memory_points
|
payload_fingerprint = hashlib.md5(f"{person_id}|{clean_chat_id}|{clean_content}".encode()).hexdigest()
|
||||||
if not person.memory_points:
|
external_id = f"person_fact:{person_id}:{payload_fingerprint}"
|
||||||
person.memory_points = []
|
|
||||||
|
|
||||||
# 检查是否已存在相似的记忆点(避免重复)
|
result = await memory_service.ingest_text(
|
||||||
is_duplicate = False
|
external_id=external_id,
|
||||||
for existing_point in person.memory_points:
|
source_type="person_fact",
|
||||||
if existing_point and isinstance(existing_point, str):
|
text=clean_content,
|
||||||
parts = existing_point.split(":", 2)
|
chat_id=clean_chat_id,
|
||||||
if len(parts) >= 2:
|
person_ids=[person_id],
|
||||||
existing_content = parts[1].strip()
|
participants=[participant_name],
|
||||||
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
tags=["person_fact"],
|
||||||
if (
|
metadata={
|
||||||
existing_content == memory_content
|
"person_id": person_id,
|
||||||
or memory_content in existing_content
|
"person_name": participant_name,
|
||||||
or existing_content in memory_content
|
"writeback_source": "memory_flow_service",
|
||||||
):
|
},
|
||||||
is_duplicate = True
|
respect_filter=True,
|
||||||
break
|
user_id=session_user_id,
|
||||||
|
group_id=session_group_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not is_duplicate:
|
if getattr(result, "success", False):
|
||||||
person.memory_points.append(memory_point)
|
logger.info(
|
||||||
person.sync_to_database()
|
f"成功写回人物事实到长期记忆: person={participant_name} person_id={person_id} chat_id={clean_chat_id}"
|
||||||
logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"记忆点已存在,跳过: {memory_point}")
|
logger.warning(
|
||||||
|
f"人物事实写回长期记忆失败: person={participant_name} person_id={person_id} "
|
||||||
|
f"chat_id={clean_chat_id} detail={getattr(result, 'detail', '')}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"存储人物记忆失败: {e}")
|
logger.error(f"存储人物记忆失败: {e}")
|
||||||
|
|||||||
@@ -671,10 +671,30 @@ class RuntimeDataCapabilityMixin:
|
|||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
limit_value = 5
|
limit_value = 5
|
||||||
|
|
||||||
|
mode = str(args.get("mode", "search") or "search").strip() or "search"
|
||||||
|
chat_id = str(args.get("chat_id", "") or "").strip()
|
||||||
|
person_id = str(args.get("person_id", "") or "").strip()
|
||||||
|
user_id = str(args.get("user_id", "") or "").strip()
|
||||||
|
group_id = str(args.get("group_id", "") or "").strip()
|
||||||
|
respect_filter = bool(args.get("respect_filter", True))
|
||||||
|
time_start = args.get("time_start")
|
||||||
|
time_end = args.get("time_end")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.services.memory_service import memory_service
|
from src.services.memory_service import memory_service
|
||||||
|
|
||||||
result = await memory_service.search(query, limit=limit_value)
|
result = await memory_service.search(
|
||||||
|
query,
|
||||||
|
limit=limit_value,
|
||||||
|
mode=mode,
|
||||||
|
chat_id=chat_id,
|
||||||
|
person_id=person_id,
|
||||||
|
time_start=time_start,
|
||||||
|
time_end=time_end,
|
||||||
|
respect_filter=respect_filter,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
if not result.success:
|
if not result.success:
|
||||||
return {"success": False, "error": result.error or "长期记忆检索失败"}
|
return {"success": False, "error": result.error or "长期记忆检索失败"}
|
||||||
knowledge_info = result.to_text(limit=limit_value)
|
knowledge_info = result.to_text(limit=limit_value)
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ class PersonFactWritebackService:
|
|||||||
1. 明确是关于目标人物本人的信息。
|
1. 明确是关于目标人物本人的信息。
|
||||||
2. 具有相对稳定性,可以作为长期记忆保存。
|
2. 具有相对稳定性,可以作为长期记忆保存。
|
||||||
3. 用简洁中文陈述句表达。
|
3. 用简洁中文陈述句表达。
|
||||||
|
4. 如果回复是在直接对目标人物说话,出现“你/你的/你自己”时,默认都指目标人物,请先改写成关于目标人物的第三人称事实再输出。
|
||||||
|
|
||||||
不要提取:
|
不要提取:
|
||||||
- 机器人的情绪、计划、临时动作、客套话
|
- 机器人的情绪、计划、临时动作、客套话
|
||||||
|
|||||||
@@ -434,6 +434,21 @@ def _store_sent_message(message: SessionMessage) -> None:
|
|||||||
MessageUtils.store_message_to_db(message)
|
MessageUtils.store_message_to_db(message)
|
||||||
|
|
||||||
|
|
||||||
|
async def _notify_memory_automation_on_message_sent(message: SessionMessage) -> None:
|
||||||
|
"""在发送成功后通知长期记忆自动化服务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 已成功发送的内部消息对象。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.services.memory_flow_service import memory_automation_service
|
||||||
|
|
||||||
|
await memory_automation_service.on_message_sent(message)
|
||||||
|
except Exception as exc:
|
||||||
|
session_id = message.session_id or "unknown-session"
|
||||||
|
logger.warning(f"[{session_id}] 长期记忆人物事实写回注册失败: {exc}")
|
||||||
|
|
||||||
|
|
||||||
def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None:
|
def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None:
|
||||||
"""输出 Platform IO 批量发送失败详情。
|
"""输出 Platform IO 批量发送失败详情。
|
||||||
|
|
||||||
@@ -503,6 +518,7 @@ async def _send_via_platform_io(
|
|||||||
if delivery_batch.has_success:
|
if delivery_batch.has_success:
|
||||||
if storage_message:
|
if storage_message:
|
||||||
_store_sent_message(message)
|
_store_sent_message(message)
|
||||||
|
await _notify_memory_automation_on_message_sent(message)
|
||||||
if show_log:
|
if show_log:
|
||||||
successful_driver_ids = [
|
successful_driver_ids = [
|
||||||
receipt.driver_id or "unknown"
|
receipt.driver_id or "unknown"
|
||||||
|
|||||||
Reference in New Issue
Block a user