feat:新增记忆测试、检索工具与服务
新增完整的长期记忆支持及测试:引入中文记忆检索提示词、query_long_term_memory 检索工具、记忆服务与记忆流程服务,以及 WebUI 的记忆路由。新增大规模测试套件(包括单元测试与基准/在线测试),覆盖聊天历史摘要、知识获取器、事件(episode)生成、写回机制以及用户画像检索等功能。 更新多个模块以集成记忆检索能力(包括 knowledge fetcher、chat summarizer、memory_retrieval、person_info、config/legacy 迁移以及 WebUI 路由),并移除遗留的 lpmm 知识模块。这些变更完成了记忆运行时的接入,同时为基准测试提供嵌入适配器的 mock,并支持新测试与工具所需的导入与 episode 处理流程。
This commit is contained in:
@@ -7,7 +7,7 @@ from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.bw_learner.jargon_miner_old import search_jargon
|
||||
from src.bw_learner.jargon_explainer import search_jargon
|
||||
from src.bw_learner.learner_utils_old import (
|
||||
is_bot_message,
|
||||
contains_bot_self_name,
|
||||
|
||||
@@ -196,6 +196,32 @@ def contains_bot_self_name(content: str) -> bool:
|
||||
return any(name in target for name in candidates)
|
||||
|
||||
|
||||
def is_bot_message(msg: Any) -> bool:
|
||||
"""判断消息是否来自机器人自身。"""
|
||||
if msg is None:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
known_accounts = {
|
||||
str(getattr(bot_config, "qq_account", "") or "").strip(),
|
||||
str(getattr(bot_config, "telegram_account", "") or "").strip(),
|
||||
}
|
||||
|
||||
for platform in getattr(bot_config, "platforms", []) or []:
|
||||
account = str(getattr(platform, "account", "") or getattr(platform, "id", "") or "").strip()
|
||||
if account:
|
||||
known_accounts.add(account)
|
||||
|
||||
return user_id in {account for account in known_accounts if account}
|
||||
|
||||
|
||||
# def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
|
||||
# """
|
||||
# 构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
|
||||
|
||||
@@ -55,7 +55,7 @@ class Conversation:
|
||||
self.action_planner = ActionPlanner(self.stream_id, self.private_name)
|
||||
self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name)
|
||||
self.reply_generator = ReplyGenerator(self.stream_id, self.private_name)
|
||||
self.knowledge_fetcher = KnowledgeFetcher(self.private_name)
|
||||
self.knowledge_fetcher = KnowledgeFetcher(self.private_name, self.stream_id)
|
||||
self.waiter = Waiter(self.stream_id, self.private_name)
|
||||
self.direct_sender = DirectMessageSender(self.private_name)
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from typing import List, Tuple, Dict, Any
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
|
||||
# from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.chat.knowledge import qa_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import resolve_person_id_for_memory
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
logger = get_logger("knowledge_fetcher")
|
||||
|
||||
@@ -13,11 +16,39 @@ logger = get_logger("knowledge_fetcher")
|
||||
class KnowledgeFetcher:
|
||||
"""知识调取器"""
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
def __init__(self, private_name: str, stream_id: str):
|
||||
self.llm = LLMRequest(model_set=model_config.model_task_config.utils)
|
||||
self.private_name = private_name
|
||||
self.stream_id = stream_id
|
||||
|
||||
def _lpmm_get_knowledge(self, query: str) -> str:
|
||||
def _resolve_private_memory_context(self) -> Dict[str, str]:
|
||||
session = _chat_manager.get_session_by_session_id(self.stream_id)
|
||||
if session is None:
|
||||
return {"chat_id": self.stream_id}
|
||||
|
||||
group_id = str(getattr(session, "group_id", "") or "").strip()
|
||||
user_id = str(getattr(session, "user_id", "") or "").strip()
|
||||
platform = str(getattr(session, "platform", "") or "").strip()
|
||||
|
||||
person_id = ""
|
||||
if not group_id:
|
||||
try:
|
||||
person_id = resolve_person_id_for_memory(
|
||||
person_name=self.private_name,
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(f"[私聊][{self.private_name}]解析人物ID失败: {exc}")
|
||||
|
||||
return {
|
||||
"chat_id": self.stream_id,
|
||||
"person_id": person_id,
|
||||
"user_id": user_id,
|
||||
"group_id": group_id,
|
||||
}
|
||||
|
||||
async def _memory_get_knowledge(self, query: str) -> str:
|
||||
"""获取相关知识
|
||||
|
||||
Args:
|
||||
@@ -27,13 +58,32 @@ class KnowledgeFetcher:
|
||||
str: 构造好的,带相关度的知识
|
||||
"""
|
||||
|
||||
logger.debug(f"[私聊][{self.private_name}]正在从LPMM知识库中获取知识")
|
||||
logger.debug(f"[私聊][{self.private_name}]正在从长期记忆中获取知识")
|
||||
try:
|
||||
knowledge_info = qa_manager.get_knowledge(query)
|
||||
logger.debug(f"[私聊][{self.private_name}]LPMM知识库查询结果: {knowledge_info:150}")
|
||||
return knowledge_info
|
||||
context = self._resolve_private_memory_context()
|
||||
search_kwargs = {
|
||||
"limit": 5,
|
||||
"mode": "search",
|
||||
"chat_id": context.get("chat_id", ""),
|
||||
"person_id": context.get("person_id", ""),
|
||||
"user_id": context.get("user_id", ""),
|
||||
"group_id": context.get("group_id", ""),
|
||||
"respect_filter": True,
|
||||
}
|
||||
result = await memory_service.search(query, **search_kwargs)
|
||||
if not result.filtered and not result.hits and search_kwargs["person_id"]:
|
||||
fallback_kwargs = dict(search_kwargs)
|
||||
fallback_kwargs["person_id"] = ""
|
||||
logger.debug(f"[私聊][{self.private_name}]人物过滤未命中,退回仅按会话检索长期记忆")
|
||||
result = await memory_service.search(query, **fallback_kwargs)
|
||||
knowledge_info = result.to_text(limit=5)
|
||||
if result.filtered:
|
||||
logger.debug(f"[私聊][{self.private_name}]长期记忆查询被聊天过滤策略跳过")
|
||||
else:
|
||||
logger.debug(f"[私聊][{self.private_name}]长期记忆查询结果: {knowledge_info[:150]}")
|
||||
return knowledge_info or "未找到匹配的知识"
|
||||
except Exception as e:
|
||||
logger.error(f"[私聊][{self.private_name}]LPMM知识库搜索工具执行失败: {str(e)}")
|
||||
logger.error(f"[私聊][{self.private_name}]长期记忆搜索工具执行失败: {str(e)}")
|
||||
return "未找到匹配的知识"
|
||||
|
||||
async def fetch(self, query: str, chat_history: List[Dict[str, Any]]) -> Tuple[str, str]:
|
||||
@@ -72,7 +122,7 @@ class KnowledgeFetcher:
|
||||
# sources_text = ",".join(sources)
|
||||
|
||||
knowledge_text += "\n现在有以下**知识**可供参考:\n "
|
||||
knowledge_text += self._lpmm_get_knowledge(query)
|
||||
knowledge_text += await self._memory_get_knowledge(query)
|
||||
knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n"
|
||||
|
||||
return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配"
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.config.config import global_config
|
||||
import os
|
||||
|
||||
INVALID_ENTITY = [
|
||||
"",
|
||||
"你",
|
||||
"他",
|
||||
"她",
|
||||
"它",
|
||||
"我们",
|
||||
"你们",
|
||||
"他们",
|
||||
"她们",
|
||||
"它们",
|
||||
]
|
||||
|
||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
||||
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
||||
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
|
||||
|
||||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
|
||||
def get_qa_manager():
|
||||
return qa_manager
|
||||
|
||||
|
||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
# 检查LPMM知识库是否启用
|
||||
if global_config.lpmm_knowledge.enable:
|
||||
logger.info("正在初始化Mai-LPMM")
|
||||
logger.info("创建LLM客户端")
|
||||
|
||||
# 初始化Embedding库
|
||||
embed_manager = EmbeddingManager(
|
||||
max_workers=global_config.lpmm_knowledge.max_embedding_workers,
|
||||
chunk_size=global_config.lpmm_knowledge.embedding_chunk_size,
|
||||
)
|
||||
logger.info("正在从文件加载Embedding库")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}")
|
||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||
logger.info("Embedding库加载完成")
|
||||
# 初始化KG
|
||||
kg_manager = KGManager()
|
||||
logger.info("正在从文件加载KG")
|
||||
try:
|
||||
kg_manager.load_from_file()
|
||||
except Exception as e:
|
||||
logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}")
|
||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||
logger.info("KG加载完成")
|
||||
|
||||
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
||||
logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
|
||||
|
||||
# 数据比对:Embedding库与KG的段落hash集合
|
||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||
# 使用与EmbeddingStore中一致的命名空间格式
|
||||
key = f"paragraph-{pg_hash}"
|
||||
if key not in embed_manager.stored_pg_hashes:
|
||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||
global qa_manager
|
||||
# 问答系统(用于知识库)
|
||||
qa_manager = QAManager(
|
||||
embed_manager,
|
||||
kg_manager,
|
||||
)
|
||||
|
||||
# # 记忆激活(用于记忆库)
|
||||
# global inspire_manager
|
||||
# inspire_manager = MemoryActiveManager(
|
||||
# embed_manager,
|
||||
# llm_client_list[global_config["embedding"]["provider"]],
|
||||
# )
|
||||
else:
|
||||
logger.info("LPMM知识库已禁用,跳过初始化")
|
||||
# 创建空的占位符对象,避免导入错误
|
||||
@@ -1,380 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import List, Callable, Any
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge import get_qa_manager, lpmm_start_up
|
||||
|
||||
logger = get_logger("LPMM-Plugin-API")
|
||||
|
||||
|
||||
class LPMMOperations:
|
||||
"""
|
||||
LPMM 内部操作接口。
|
||||
封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._initialized = False
|
||||
|
||||
async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any:
|
||||
"""
|
||||
在线程池中执行可取消的同步操作。
|
||||
当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。
|
||||
注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。
|
||||
|
||||
Args:
|
||||
func: 要执行的同步函数
|
||||
*args: 函数的位置参数
|
||||
**kwargs: 函数的关键字参数
|
||||
|
||||
Returns:
|
||||
函数的返回值
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: 当任务被取消时
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
# 在线程池中执行,当协程被取消时会立即响应
|
||||
# 虽然线程池中的操作可能仍在运行,但协程不会阻塞
|
||||
return await loop.run_in_executor(None, func, *args, **kwargs)
|
||||
|
||||
async def _get_managers(self) -> tuple[EmbeddingManager, KGManager, QAManager]:
|
||||
"""获取并确保 LPMM 管理器已初始化"""
|
||||
qa_mgr = get_qa_manager()
|
||||
if qa_mgr is None:
|
||||
# 如果全局没初始化,尝试初始化
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
|
||||
|
||||
lpmm_start_up()
|
||||
qa_mgr = get_qa_manager()
|
||||
|
||||
if qa_mgr is None:
|
||||
raise RuntimeError("无法获取 LPMM QAManager,请检查 LPMM 是否已正确安装和配置。")
|
||||
|
||||
return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr
|
||||
|
||||
async def add_content(self, text: str, auto_split: bool = True) -> dict:
|
||||
"""
|
||||
向知识库添加新内容。
|
||||
|
||||
Args:
|
||||
text: 原始文本。
|
||||
auto_split: 是否自动按双换行符分割段落。
|
||||
- True: 自动分割(默认),支持多段文本(用双换行分隔)
|
||||
- False: 不分割,将整个文本作为完整一段处理
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/error", "count": 导入段落数, "message": "描述"}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
# 1. 分段处理
|
||||
if auto_split:
|
||||
# 自动按双换行符分割
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
else:
|
||||
# 不分割,作为完整一段
|
||||
text_stripped = text.strip()
|
||||
if not text_stripped:
|
||||
return {"status": "error", "message": "文本内容为空"}
|
||||
paragraphs = [text_stripped]
|
||||
|
||||
if not paragraphs:
|
||||
return {"status": "error", "message": "文本内容为空"}
|
||||
|
||||
# 2. 实体与三元组抽取 (内部调用大模型)
|
||||
from src.chat.knowledge.ie_process import IEProcess
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
|
||||
llm_ner = LLMRequest(
|
||||
model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
|
||||
)
|
||||
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
|
||||
ie_process = IEProcess(llm_ner, llm_rdf)
|
||||
|
||||
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
|
||||
extracted_docs = await ie_process.process_paragraphs(paragraphs)
|
||||
|
||||
# 3. 构造并导入数据
|
||||
# 这里我们手动实现导入逻辑,不依赖外部脚本
|
||||
# a. 准备段落
|
||||
raw_paragraphs = {doc["idx"]: doc["passage"] for doc in extracted_docs}
|
||||
# b. 准备三元组
|
||||
triple_list_data = {doc["idx"]: doc["extracted_triples"] for doc in extracted_docs}
|
||||
|
||||
# 向量化并入库
|
||||
# 注意:此处模仿 import_openie.py 的核心逻辑
|
||||
# 1. 先进行去重检查,只处理新段落
|
||||
# store_new_data_set 期望的格式:raw_paragraphs 的键是段落hash(不带前缀),值是段落文本
|
||||
new_raw_paragraphs = {}
|
||||
new_triple_list_data = {}
|
||||
|
||||
for pg_hash, passage in raw_paragraphs.items():
|
||||
key = f"paragraph-{pg_hash}"
|
||||
if key not in embed_mgr.stored_pg_hashes:
|
||||
new_raw_paragraphs[pg_hash] = passage
|
||||
new_triple_list_data[pg_hash] = triple_list_data[pg_hash]
|
||||
|
||||
if not new_raw_paragraphs:
|
||||
return {"status": "success", "count": 0, "message": "内容已存在,无需重复导入"}
|
||||
|
||||
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
|
||||
# store_new_data_set 会自动处理嵌入生成和存储
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data)
|
||||
|
||||
# 3. 构建知识图谱(只需要三元组数据和embedding_manager)
|
||||
await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr)
|
||||
|
||||
# 4. 持久化
|
||||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"count": len(new_raw_paragraphs),
|
||||
"message": f"成功导入 {len(new_raw_paragraphs)} 条知识",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 导入操作被用户中断")
|
||||
return {"status": "cancelled", "message": "导入操作已被用户中断"}
|
||||
except Exception as e:
|
||||
logger.error(f"[Plugin API] 导入知识失败: {e}", exc_info=True)
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def search(self, query: str, top_k: int = 3) -> List[str]:
|
||||
"""
|
||||
检索知识库。
|
||||
|
||||
Args:
|
||||
query: 查询问题。
|
||||
top_k: 返回最相关的条目数。
|
||||
|
||||
Returns:
|
||||
List[str]: 相关文段列表。
|
||||
"""
|
||||
try:
|
||||
_, _, qa_mgr = await self._get_managers()
|
||||
# 直接调用 QAManager 的检索接口
|
||||
knowledge = qa_mgr.get_knowledge(query, top_k=top_k)
|
||||
# 返回通常是拼接好的字符串,这里我们可以尝试按其内部规则切分回列表,或者直接返回
|
||||
return [knowledge] if knowledge else []
|
||||
except Exception as e:
|
||||
logger.error(f"[Plugin API] 检索知识失败: {e}")
|
||||
return []
|
||||
|
||||
async def delete(self, keyword: str, exact_match: bool = False) -> dict:
|
||||
"""
|
||||
根据关键词或完整文段删除知识库内容。
|
||||
|
||||
Args:
|
||||
keyword: 匹配关键词或完整文段。
|
||||
exact_match: 是否使用完整文段匹配(True=完全匹配,False=关键词模糊匹配)。
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
# 1. 查找匹配的段落
|
||||
to_delete_keys = []
|
||||
to_delete_hashes = []
|
||||
|
||||
for key, item in embed_mgr.paragraphs_embedding_store.store.items():
|
||||
if exact_match:
|
||||
# 完整文段匹配
|
||||
if item.str.strip() == keyword.strip():
|
||||
to_delete_keys.append(key)
|
||||
to_delete_hashes.append(key.replace("paragraph-", "", 1))
|
||||
else:
|
||||
# 关键词模糊匹配
|
||||
if keyword in item.str:
|
||||
to_delete_keys.append(key)
|
||||
to_delete_hashes.append(key.replace("paragraph-", "", 1))
|
||||
|
||||
if not to_delete_keys:
|
||||
match_type = "完整文段" if exact_match else "关键词"
|
||||
return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
|
||||
|
||||
# 2. 执行删除
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
|
||||
# a. 从向量库删除
|
||||
deleted_count, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||
|
||||
# b. 从知识图谱删除
|
||||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
# 3. 持久化
|
||||
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
match_type = "完整文段" if exact_match else "关键词"
|
||||
return {
|
||||
"status": "success",
|
||||
"deleted_count": deleted_count,
|
||||
"message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)",
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 删除操作被用户中断")
|
||||
return {"status": "cancelled", "message": "删除操作已被用户中断"}
|
||||
except Exception as e:
|
||||
logger.error(f"[Plugin API] 删除知识失败: {e}", exc_info=True)
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def clear_all(self) -> dict:
|
||||
"""
|
||||
清空整个LPMM知识库(删除所有段落、实体、关系和知识图谱数据)。
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success/error", "message": "描述", "stats": {...}}
|
||||
"""
|
||||
try:
|
||||
embed_mgr, kg_mgr, _ = await self._get_managers()
|
||||
|
||||
# 记录清空前的统计信息
|
||||
before_stats = {
|
||||
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
|
||||
"entities": len(embed_mgr.entities_embedding_store.store),
|
||||
"relations": len(embed_mgr.relation_embedding_store.store),
|
||||
"kg_nodes": len(kg_mgr.graph.get_node_list()),
|
||||
"kg_edges": len(kg_mgr.graph.get_edge_list()),
|
||||
}
|
||||
|
||||
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
|
||||
|
||||
# 1. 清空所有向量库
|
||||
# 获取所有keys
|
||||
para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
|
||||
ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
|
||||
rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
|
||||
|
||||
# 删除所有段落向量
|
||||
para_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.paragraphs_embedding_store.delete_items, para_keys
|
||||
)
|
||||
embed_mgr.stored_pg_hashes.clear()
|
||||
|
||||
# 删除所有实体向量
|
||||
if ent_keys:
|
||||
ent_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.entities_embedding_store.delete_items, ent_keys
|
||||
)
|
||||
else:
|
||||
ent_deleted = 0
|
||||
|
||||
# 删除所有关系向量
|
||||
if rel_keys:
|
||||
rel_deleted, _ = await self._run_cancellable_executor(
|
||||
embed_mgr.relation_embedding_store.delete_items, rel_keys
|
||||
)
|
||||
else:
|
||||
rel_deleted = 0
|
||||
|
||||
# 2. 清空所有 embedding store 的索引和映射
|
||||
# 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件
|
||||
def _clear_embedding_indices():
|
||||
# 清空段落索引
|
||||
embed_mgr.paragraphs_embedding_store.faiss_index = None
|
||||
embed_mgr.paragraphs_embedding_store.idx2hash = None
|
||||
embed_mgr.paragraphs_embedding_store.dirty = False
|
||||
# 删除旧的索引文件
|
||||
if os.path.exists(embed_mgr.paragraphs_embedding_store.index_file_path):
|
||||
os.remove(embed_mgr.paragraphs_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path)
|
||||
|
||||
# 清空实体索引
|
||||
embed_mgr.entities_embedding_store.faiss_index = None
|
||||
embed_mgr.entities_embedding_store.idx2hash = None
|
||||
embed_mgr.entities_embedding_store.dirty = False
|
||||
# 删除旧的索引文件
|
||||
if os.path.exists(embed_mgr.entities_embedding_store.index_file_path):
|
||||
os.remove(embed_mgr.entities_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path)
|
||||
|
||||
# 清空关系索引
|
||||
embed_mgr.relation_embedding_store.faiss_index = None
|
||||
embed_mgr.relation_embedding_store.idx2hash = None
|
||||
embed_mgr.relation_embedding_store.dirty = False
|
||||
# 删除旧的索引文件
|
||||
if os.path.exists(embed_mgr.relation_embedding_store.index_file_path):
|
||||
os.remove(embed_mgr.relation_embedding_store.index_file_path)
|
||||
if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path):
|
||||
os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path)
|
||||
|
||||
await self._run_cancellable_executor(_clear_embedding_indices)
|
||||
|
||||
# 3. 清空知识图谱
|
||||
# 获取所有段落hash
|
||||
all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
|
||||
if all_pg_hashes:
|
||||
# 删除所有段落节点(这会自动清理相关的边和孤立实体)
|
||||
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
|
||||
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
|
||||
delete_func = partial(
|
||||
kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True
|
||||
)
|
||||
await self._run_cancellable_executor(delete_func)
|
||||
|
||||
# 完全清空KG:创建新的空图(无论是否有段落hash都要执行)
|
||||
from quick_algo import di_graph
|
||||
|
||||
kg_mgr.graph = di_graph.DiGraph()
|
||||
kg_mgr.stored_paragraph_hashes.clear()
|
||||
kg_mgr.ent_appear_cnt.clear()
|
||||
|
||||
# 4. 保存所有数据(此时所有store都是空的,索引也是None)
|
||||
# 注意:即使store为空,save_to_file也会保存空的DataFrame,这是正确的
|
||||
await self._run_cancellable_executor(embed_mgr.save_to_file)
|
||||
await self._run_cancellable_executor(kg_mgr.save_to_file)
|
||||
|
||||
after_stats = {
|
||||
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
|
||||
"entities": len(embed_mgr.entities_embedding_store.store),
|
||||
"relations": len(embed_mgr.relation_embedding_store.store),
|
||||
"kg_nodes": len(kg_mgr.graph.get_node_list()),
|
||||
"kg_edges": len(kg_mgr.graph.get_edge_list()),
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"已成功清空LPMM知识库(删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
|
||||
"stats": {
|
||||
"before": before_stats,
|
||||
"after": after_stats,
|
||||
},
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("[Plugin API] 清空操作被用户中断")
|
||||
return {"status": "cancelled", "message": "清空操作已被用户中断"}
|
||||
except Exception as e:
|
||||
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
# 内部使用的单例
|
||||
lpmm_ops = LPMMOperations()
|
||||
@@ -360,6 +360,12 @@ class ChatBot:
|
||||
user_id = user_info.user_id
|
||||
group_id = group_info.group_id if group_info else None
|
||||
_ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在
|
||||
try:
|
||||
from src.services.memory_flow_service import memory_automation_service
|
||||
|
||||
await memory_automation_service.on_incoming_message(message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[长期记忆自动总结] 注册会话总结器失败: {exc}")
|
||||
|
||||
# message.update_chat_stream(chat)
|
||||
|
||||
|
||||
@@ -383,6 +383,13 @@ class UniversalMessageSender:
|
||||
with get_db_session() as db_session:
|
||||
db_session.add(message.to_db_instance())
|
||||
|
||||
try:
|
||||
from src.services.memory_flow_service import memory_automation_service
|
||||
|
||||
await memory_automation_service.on_message_sent(message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[{chat_id}] 长期记忆人物事实写回注册失败: {exc}")
|
||||
|
||||
return sent_msg
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import traceback
|
||||
import time
|
||||
import asyncio
|
||||
import importlib
|
||||
import random
|
||||
import re
|
||||
|
||||
@@ -36,6 +35,7 @@ from src.services import llm_service as llm_api
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
||||
from src.memory_system.retrieval_tools import get_tool_registry
|
||||
from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
||||
@@ -1164,29 +1164,14 @@ class DefaultReplyer:
|
||||
async def get_prompt_info(self, message: str, sender: str, target: str):
|
||||
related_info = ""
|
||||
start_time = time.time()
|
||||
try:
|
||||
knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge")
|
||||
except ImportError:
|
||||
logger.debug("LPMM知识库工具模块不存在,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None)
|
||||
search_knowledge_tool = get_tool_registry().get_tool("search_long_term_memory")
|
||||
if search_knowledge_tool is None:
|
||||
logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识库内容")
|
||||
logger.debug("长期记忆检索工具未注册,跳过获取知识内容")
|
||||
return ""
|
||||
|
||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
# 从LPMM知识库获取知识
|
||||
logger.debug(f"获取长期记忆内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
try:
|
||||
# 检查LPMM知识库是否启用
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用,跳过获取知识库内容")
|
||||
return ""
|
||||
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
return ""
|
||||
|
||||
template_prompt = prompt_manager.get_prompt("lpmm_get_knowledge")
|
||||
template_prompt = prompt_manager.get_prompt("memory_get_knowledge")
|
||||
template_prompt.add_context("bot_name", global_config.bot.nickname)
|
||||
template_prompt.add_context("time_now", lambda _: time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
template_prompt.add_context("chat_history", message)
|
||||
@@ -1202,24 +1187,31 @@ class DefaultReplyer:
|
||||
# logger.info(f"工具调用提示词: {prompt}")
|
||||
# logger.info(f"工具调用: {tool_calls}")
|
||||
|
||||
if tool_calls:
|
||||
result = await self.tool_executor.execute_tool_call(tool_calls[0])
|
||||
end_time = time.time()
|
||||
if not result or not result.get("content"):
|
||||
logger.debug("从LPMM知识库获取知识失败,返回空知识...")
|
||||
return ""
|
||||
found_knowledge_from_lpmm = result.get("content", "")
|
||||
logger.info(
|
||||
f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
|
||||
)
|
||||
related_info += found_knowledge_from_lpmm
|
||||
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
|
||||
return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
|
||||
else:
|
||||
logger.debug("模型认为不需要使用LPMM知识库")
|
||||
if not tool_calls:
|
||||
logger.debug("模型认为不需要使用长期记忆")
|
||||
return ""
|
||||
|
||||
related_chunks: List[str] = []
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.func_name != "search_long_term_memory":
|
||||
continue
|
||||
tool_args = dict(tool_call.args or {})
|
||||
tool_args.setdefault("chat_id", self.chat_stream.session_id)
|
||||
result_text = await search_knowledge_tool.execute(**tool_args)
|
||||
if result_text and "未找到" not in result_text:
|
||||
related_chunks.append(result_text)
|
||||
|
||||
if not related_chunks:
|
||||
logger.debug("长期记忆未返回有效信息")
|
||||
return ""
|
||||
|
||||
related_info = "\n".join(related_chunks)
|
||||
end_time = time.time()
|
||||
logger.info(f"从长期记忆获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
|
||||
return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||
return ""
|
||||
|
||||
@@ -55,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.1.0"
|
||||
CONFIG_VERSION: str = "8.1.1"
|
||||
MODEL_CONFIG_VERSION: str = "1.12.0"
|
||||
|
||||
logger = get_logger("config")
|
||||
|
||||
@@ -94,6 +94,11 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
|
||||
["", "enable", "enable", "enable"],
|
||||
["qq:1919810:group", "enable", "enable", "enable"],
|
||||
]
|
||||
兼容旧旧格式:
|
||||
learning_list = [
|
||||
["qq:1919810:group", "enable", "enable", "0.5"],
|
||||
["", "disable", "disable", "0.1"],
|
||||
]
|
||||
新:
|
||||
[[expression.learning_list]]
|
||||
platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true
|
||||
@@ -117,6 +122,16 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool:
|
||||
use_expression = _parse_enable_disable(r[1])
|
||||
enable_learning = _parse_enable_disable(r[2])
|
||||
enable_jargon_learning = _parse_enable_disable(r[3])
|
||||
if enable_jargon_learning is None:
|
||||
# 更早期的配置在第 4 列记录的是一个已废弃的数值权重/阈值,
|
||||
# 当前 schema 已没有对应字段。这里按保守策略兼容迁移:
|
||||
# 丢弃旧数值,并将 enable_jargon_learning 置为 False。
|
||||
try:
|
||||
float(str(r[3]))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
else:
|
||||
enable_jargon_learning = False
|
||||
if use_expression is None or enable_learning is None or enable_jargon_learning is None:
|
||||
return False
|
||||
|
||||
|
||||
@@ -416,6 +416,24 @@ class MemoryConfig(ConfigBase):
|
||||
)
|
||||
"""_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索"""
|
||||
|
||||
long_term_auto_summary_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "book-open",
|
||||
},
|
||||
)
|
||||
"""是否自动启动聊天总结并导入长期记忆"""
|
||||
|
||||
person_fact_writeback_enabled: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "user-round-pen",
|
||||
},
|
||||
)
|
||||
"""是否在发送回复后自动提取并写回人物事实到长期记忆"""
|
||||
|
||||
chat_history_topic_check_message_threshold: int = Field(
|
||||
default=80,
|
||||
ge=1,
|
||||
|
||||
10
src/main.py
10
src/main.py
@@ -6,7 +6,6 @@ import time
|
||||
|
||||
from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.knowledge import lpmm_start_up
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
@@ -19,6 +18,7 @@ from src.config.config import config_manager, global_config
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.memory_flow_service import memory_automation_service
|
||||
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
@@ -88,9 +88,6 @@ class MainSystem:
|
||||
# start_api_server()
|
||||
# logger.info("API服务器启动成功")
|
||||
|
||||
# 启动LPMM
|
||||
lpmm_start_up()
|
||||
|
||||
# 启动插件运行时(内置插件 + 第三方插件双子进程)
|
||||
await get_plugin_runtime_manager().start()
|
||||
|
||||
@@ -103,6 +100,7 @@ class MainSystem:
|
||||
asyncio.create_task(chat_manager.regularly_save_sessions())
|
||||
|
||||
logger.info(t("startup.chat_manager_initialized"))
|
||||
await memory_automation_service.start()
|
||||
|
||||
# await asyncio.sleep(0.5) #防止logger输出飞了
|
||||
|
||||
@@ -164,6 +162,10 @@ async def main():
|
||||
system.schedule_tasks(),
|
||||
)
|
||||
finally:
|
||||
await memory_automation_service.shutdown()
|
||||
await get_plugin_runtime_manager().bridge_event("on_stop")
|
||||
await get_plugin_runtime_manager().stop()
|
||||
await async_task_manager.stop_and_wait_all_tasks()
|
||||
await config_manager.stop_file_watcher()
|
||||
|
||||
|
||||
|
||||
@@ -931,12 +931,14 @@ class ChatHistorySummarizer:
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
|
||||
|
||||
# 同时导入到LPMM知识库
|
||||
if global_config.lpmm_knowledge.enable:
|
||||
await self._import_to_lpmm_knowledge(
|
||||
if saved_record and saved_record.get("id") is not None:
|
||||
await self._import_to_long_term_memory(
|
||||
record_id=int(saved_record["id"]),
|
||||
theme=theme,
|
||||
summary=summary,
|
||||
participants=participants,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
original_text=original_text,
|
||||
)
|
||||
|
||||
@@ -947,76 +949,131 @@ class ChatHistorySummarizer:
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
async def _import_to_lpmm_knowledge(
|
||||
async def _import_to_long_term_memory(
|
||||
self,
|
||||
record_id: int,
|
||||
theme: str,
|
||||
summary: str,
|
||||
participants: List[str],
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
original_text: str,
|
||||
):
|
||||
"""
|
||||
将聊天历史总结导入到LPMM知识库
|
||||
将聊天历史总结导入到统一长期记忆
|
||||
|
||||
Args:
|
||||
record_id: chat_history 主键
|
||||
theme: 话题主题
|
||||
summary: 概括内容
|
||||
participants: 参与者列表
|
||||
start_time: 开始时间
|
||||
end_time: 结束时间
|
||||
original_text: 原始文本(可能很长,需要截断)
|
||||
"""
|
||||
try:
|
||||
from src.chat.knowledge.lpmm_ops import lpmm_ops
|
||||
from src.services.memory_service import memory_service
|
||||
session = _chat_manager.get_session_by_session_id(self.session_id)
|
||||
session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else ""
|
||||
session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else ""
|
||||
|
||||
# 构造要导入的文本内容
|
||||
# 格式:主题 + 概括 + 参与者信息 + 原始内容摘要
|
||||
# 注意:使用单换行符连接,确保整个内容作为一段导入,不被LPMM分段
|
||||
content_parts = []
|
||||
|
||||
# 1. 话题主题
|
||||
# if theme:
|
||||
# content_parts.append(f"话题:{theme}")
|
||||
|
||||
# 2. 概括内容
|
||||
if theme:
|
||||
content_parts.append(f"主题:{theme}")
|
||||
if summary:
|
||||
content_parts.append(f"概括:{summary}")
|
||||
|
||||
# 3. 参与者信息
|
||||
if participants:
|
||||
participants_text = "、".join(participants)
|
||||
content_parts.append(f"参与者:{participants_text}")
|
||||
|
||||
# 4. 原始文本摘要(如果原始文本太长,只取前500字)
|
||||
# if original_text:
|
||||
# # 截断原始文本,避免过长
|
||||
# max_original_length = 500
|
||||
# if len(original_text) > max_original_length:
|
||||
# truncated_text = original_text[:max_original_length] + "..."
|
||||
# content_parts.append(f"原始内容摘要:{truncated_text}")
|
||||
# else:
|
||||
# content_parts.append(f"原始内容:{original_text}")
|
||||
|
||||
# 将所有部分合并为一个完整段落(使用单换行符,避免被LPMM分段)
|
||||
# LPMM使用 \n\n 作为段落分隔符,所以这里使用 \n 确保不会被分段
|
||||
content_to_import = "\n".join(content_parts)
|
||||
|
||||
if not content_to_import.strip():
|
||||
logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,跳过导入知识库")
|
||||
logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,改用插件侧 generate_from_chat 兜底")
|
||||
await self._fallback_import_to_long_term_memory(
|
||||
record_id=record_id,
|
||||
theme=theme,
|
||||
participants=participants,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
original_text=original_text,
|
||||
)
|
||||
return
|
||||
|
||||
# 调用lpmm_ops导入
|
||||
result = await lpmm_ops.add_content(text=content_to_import, auto_split=False)
|
||||
|
||||
if result["status"] == "success":
|
||||
logger.info(
|
||||
f"{self.log_prefix} 成功将聊天历史总结导入到LPMM知识库 | 话题: {theme} | 新增段落数: {result.get('count', 0)}"
|
||||
)
|
||||
result = await memory_service.ingest_summary(
|
||||
external_id=f"chat_history:{record_id}",
|
||||
chat_id=self.session_id,
|
||||
text=content_to_import,
|
||||
participants=participants,
|
||||
time_start=start_time,
|
||||
time_end=end_time,
|
||||
tags=[theme] if theme else [],
|
||||
metadata={"theme": theme, "original_text_length": len(original_text or "")},
|
||||
respect_filter=True,
|
||||
user_id=session_user_id,
|
||||
group_id=session_group_id,
|
||||
)
|
||||
if result.success:
|
||||
if result.detail == "chat_filtered":
|
||||
logger.debug(f"{self.log_prefix} 聊天历史总结被聊天过滤策略跳过 | 话题: {theme}")
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 成功将聊天历史总结导入到长期记忆 | 话题: {theme}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self.log_prefix} 将聊天历史总结导入到LPMM知识库失败 | 话题: {theme} | 错误: {result.get('message', '未知错误')}"
|
||||
logger.warning(f"{self.log_prefix} 将聊天历史总结导入到长期记忆失败,尝试插件侧兜底 | 话题: {theme} | 错误: {result.detail}")
|
||||
await self._fallback_import_to_long_term_memory(
|
||||
record_id=record_id,
|
||||
theme=theme,
|
||||
participants=participants,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
original_text=original_text,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 导入失败不应该影响数据库存储,只记录错误
|
||||
logger.error(f"{self.log_prefix} 导入聊天历史总结到LPMM知识库时出错: {e}", exc_info=True)
|
||||
logger.error(f"{self.log_prefix} 导入聊天历史总结到长期记忆时出错: {e}", exc_info=True)
|
||||
|
||||
async def _fallback_import_to_long_term_memory(
|
||||
self,
|
||||
*,
|
||||
record_id: int,
|
||||
theme: str,
|
||||
participants: List[str],
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
original_text: str,
|
||||
) -> None:
|
||||
try:
|
||||
from src.services.memory_service import memory_service
|
||||
session = _chat_manager.get_session_by_session_id(self.session_id)
|
||||
session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else ""
|
||||
session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else ""
|
||||
|
||||
result = await memory_service.ingest_summary(
|
||||
external_id=f"chat_history:{record_id}",
|
||||
chat_id=self.session_id,
|
||||
text="",
|
||||
participants=participants,
|
||||
time_start=start_time,
|
||||
time_end=end_time,
|
||||
tags=[theme] if theme else [],
|
||||
metadata={
|
||||
"theme": theme,
|
||||
"original_text_length": len(original_text or ""),
|
||||
"generate_from_chat": True,
|
||||
"context_length": global_config.memory.chat_history_topic_check_message_threshold,
|
||||
},
|
||||
respect_filter=True,
|
||||
user_id=session_user_id,
|
||||
group_id=session_group_id,
|
||||
)
|
||||
if result.success:
|
||||
if result.detail == "chat_filtered":
|
||||
logger.debug(f"{self.log_prefix} 插件侧 generate_from_chat 兜底被聊天过滤策略跳过 | 话题: {theme}")
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入成功 | 话题: {theme}")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入失败 | 话题: {theme} | 错误: {result.detail}")
|
||||
except Exception as exc:
|
||||
logger.error(f"{self.log_prefix} 插件侧兜底导入长期记忆失败: {exc}", exc_info=True)
|
||||
|
||||
async def start(self):
|
||||
"""启动后台定期检查循环"""
|
||||
|
||||
@@ -237,8 +237,8 @@ async def _react_agent_solve_question(
|
||||
if first_head_prompt is None:
|
||||
# 第一次构建,使用初始的collected_info(即initial_info)
|
||||
initial_collected_info = initial_info or ""
|
||||
# 使用 LPMM 知识库检索 prompt
|
||||
first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_lpmm")
|
||||
# 使用统一长期记忆检索 prompt
|
||||
first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_memory")
|
||||
first_head_prompt_template.add_context("bot_name", bot_name)
|
||||
first_head_prompt_template.add_context("time_now", time_now)
|
||||
first_head_prompt_template.add_context("chat_history", chat_history)
|
||||
|
||||
@@ -10,21 +10,17 @@ from .tool_registry import (
|
||||
get_tool_registry,
|
||||
)
|
||||
|
||||
# 导入所有工具的注册函数
|
||||
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
||||
from .query_words import register_tool as register_query_words
|
||||
from .return_information import register_tool as register_return_information
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
# 延迟导入,避免在仅使用部分工具或单元测试阶段触发不必要的依赖链。
|
||||
from .query_long_term_memory import register_tool as register_long_term_memory
|
||||
from .query_words import register_tool as register_query_words
|
||||
from .return_information import register_tool as register_return_information
|
||||
|
||||
register_query_words()
|
||||
register_return_information()
|
||||
|
||||
# LPMM知识库检索工具
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
register_lpmm_knowledge()
|
||||
register_long_term_memory()
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
304
src/memory_system/retrieval_tools/query_long_term_memory.py
Normal file
304
src/memory_system/retrieval_tools/query_long_term_memory.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""通过统一长期记忆服务查询信息。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from calendar import monthrange
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Iterable, Literal, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.services.memory_service import MemoryHit, MemorySearchResult, memory_service
|
||||
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
_SUPPORTED_MODES = {"search", "time", "episode", "aggregate"}
|
||||
_RELATIVE_DAYS_RE = re.compile(r"^最近\s*(\d+)\s*天$")
|
||||
_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
|
||||
_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}\s+\d{2}:\d{2}$")
|
||||
_TIME_EXPRESSION_HELP = (
|
||||
"请改用更具体的时间表达,例如:今天、昨天、前天、本周、上周、本月、上月、最近7天、"
|
||||
"2026/03/18、2026/03/18 09:30。"
|
||||
)
|
||||
|
||||
|
||||
def _format_query_datetime(dt: datetime) -> str:
|
||||
return dt.strftime("%Y/%m/%d %H:%M")
|
||||
|
||||
|
||||
def _resolve_time_expression(
|
||||
expression: str,
|
||||
*,
|
||||
now: datetime | None = None,
|
||||
) -> Tuple[float, float, str, str]:
|
||||
clean = str(expression or "").strip()
|
||||
if not clean:
|
||||
raise ValueError(f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}")
|
||||
|
||||
current = now or datetime.now()
|
||||
day_start = current.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if clean == "今天":
|
||||
start = day_start
|
||||
end = day_start.replace(hour=23, minute=59)
|
||||
elif clean == "昨天":
|
||||
start = day_start - timedelta(days=1)
|
||||
end = start.replace(hour=23, minute=59)
|
||||
elif clean == "前天":
|
||||
start = day_start - timedelta(days=2)
|
||||
end = start.replace(hour=23, minute=59)
|
||||
elif clean == "本周":
|
||||
start = day_start - timedelta(days=day_start.weekday())
|
||||
end = start + timedelta(days=6, hours=23, minutes=59)
|
||||
elif clean == "上周":
|
||||
this_week_start = day_start - timedelta(days=day_start.weekday())
|
||||
start = this_week_start - timedelta(days=7)
|
||||
end = start + timedelta(days=6, hours=23, minutes=59)
|
||||
elif clean == "本月":
|
||||
start = day_start.replace(day=1)
|
||||
last_day = monthrange(start.year, start.month)[1]
|
||||
end = start.replace(day=last_day, hour=23, minute=59)
|
||||
elif clean == "上月":
|
||||
year = day_start.year
|
||||
month = day_start.month - 1
|
||||
if month == 0:
|
||||
year -= 1
|
||||
month = 12
|
||||
start = day_start.replace(year=year, month=month, day=1)
|
||||
last_day = monthrange(year, month)[1]
|
||||
end = start.replace(day=last_day, hour=23, minute=59)
|
||||
else:
|
||||
relative_match = _RELATIVE_DAYS_RE.fullmatch(clean)
|
||||
if relative_match:
|
||||
days = max(1, int(relative_match.group(1)))
|
||||
start = day_start - timedelta(days=max(0, days - 1))
|
||||
end = day_start.replace(hour=23, minute=59)
|
||||
elif _DATE_RE.fullmatch(clean):
|
||||
start = datetime.strptime(clean, "%Y/%m/%d")
|
||||
end = start.replace(hour=23, minute=59)
|
||||
elif _MINUTE_RE.fullmatch(clean):
|
||||
start = datetime.strptime(clean, "%Y/%m/%d %H:%M")
|
||||
end = start
|
||||
else:
|
||||
raise ValueError(f"时间表达“{clean}”无法解析。{_TIME_EXPRESSION_HELP}")
|
||||
|
||||
return start.timestamp(), end.timestamp(), _format_query_datetime(start), _format_query_datetime(end)
|
||||
|
||||
|
||||
def _extract_time_label(metadata: dict) -> str:
|
||||
if not isinstance(metadata, dict):
|
||||
return ""
|
||||
start = metadata.get("event_time_start")
|
||||
end = metadata.get("event_time_end")
|
||||
event_time = metadata.get("event_time")
|
||||
|
||||
def _fmt(value: object) -> str:
|
||||
if value in {None, ""}:
|
||||
return ""
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value)).strftime("%Y/%m/%d %H:%M")
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
start_text = _fmt(start or event_time)
|
||||
end_text = _fmt(end)
|
||||
if start_text and end_text:
|
||||
return f"{start_text} - {end_text}"
|
||||
return start_text or end_text
|
||||
|
||||
|
||||
def _truncate(text: str, limit: int = 160) -> str:
|
||||
compact = str(text or "").strip().replace("\n", " ")
|
||||
if len(compact) <= limit:
|
||||
return compact
|
||||
return compact[:limit] + "..."
|
||||
|
||||
|
||||
def _format_search_lines(hits: Iterable[MemoryHit], *, limit: int, include_time: bool = False) -> str:
|
||||
lines = []
|
||||
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||
time_label = _extract_time_label(item.metadata) if include_time else ""
|
||||
prefix = f"[{time_label}] " if time_label else ""
|
||||
lines.append(f"{index}. {prefix}{_truncate(item.content)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_episode_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
|
||||
lines = []
|
||||
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||
metadata = item.metadata if isinstance(item.metadata, dict) else {}
|
||||
title = str(item.title or "").strip() or "未命名事件"
|
||||
summary = _truncate(item.content, limit=180)
|
||||
participants = [str(x).strip() for x in (metadata.get("participants") or []) if str(x).strip()]
|
||||
keywords = [str(x).strip() for x in (metadata.get("keywords") or []) if str(x).strip()]
|
||||
extras = []
|
||||
if participants:
|
||||
extras.append(f"参与者:{'、'.join(participants[:4])}")
|
||||
if keywords:
|
||||
extras.append(f"关键词:{'、'.join(keywords[:6])}")
|
||||
time_label = _extract_time_label(metadata)
|
||||
if time_label:
|
||||
extras.append(f"时间:{time_label}")
|
||||
suffix = f"({';'.join(extras)})" if extras else ""
|
||||
lines.append(f"{index}. 事件《{title}》:{summary}{suffix}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_aggregate_lines(hits: Iterable[MemoryHit], *, limit: int) -> str:
|
||||
lines = []
|
||||
for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1):
|
||||
metadata = item.metadata if isinstance(item.metadata, dict) else {}
|
||||
source_branches = [str(x).strip() for x in (metadata.get("source_branches") or []) if str(x).strip()]
|
||||
branch_text = f"[{','.join(source_branches)}]" if source_branches else ""
|
||||
item_type = str(item.hit_type or "").strip().lower() or "memory"
|
||||
if item_type == "episode":
|
||||
title = str(item.title or "").strip() or "未命名事件"
|
||||
lines.append(f"{index}. {branch_text}[episode] 《{title}》:{_truncate(item.content, 160)}")
|
||||
else:
|
||||
lines.append(f"{index}. {branch_text}[{item_type}] {_truncate(item.content, 160)}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_tool_result(
|
||||
*,
|
||||
result: MemorySearchResult,
|
||||
mode: Literal["search", "time", "episode", "aggregate"],
|
||||
limit: int,
|
||||
query: str,
|
||||
time_range_text: str = "",
|
||||
) -> str:
|
||||
if not result.hits:
|
||||
if mode == "time":
|
||||
return f"在指定时间范围内未找到相关的长期记忆{time_range_text}"
|
||||
if mode == "episode":
|
||||
return f"未找到与“{query}”相关的事件或情节记忆"
|
||||
if mode == "aggregate":
|
||||
return f"未找到可用于综合回忆的长期记忆线索{f'(query:{query})' if query else ''}"
|
||||
return f"在长期记忆中未找到与“{query}”相关的信息"
|
||||
|
||||
if mode == "episode":
|
||||
text = _format_episode_lines(result.hits, limit=limit)
|
||||
return f"你从长期记忆的事件/情节中找到以下信息:\n{text}"
|
||||
|
||||
if mode == "aggregate":
|
||||
text = _format_aggregate_lines(result.hits, limit=limit)
|
||||
return f"你从长期记忆中综合找到了以下线索:\n{text}"
|
||||
|
||||
if mode == "time":
|
||||
text = _format_search_lines(result.hits, limit=limit, include_time=True)
|
||||
return f"你从指定时间范围内的长期记忆中找到以下信息{time_range_text}:\n{text}"
|
||||
|
||||
text = _format_search_lines(result.hits, limit=limit)
|
||||
return f"你从长期记忆中找到以下信息:\n{text}"
|
||||
|
||||
|
||||
async def query_long_term_memory(
|
||||
query: str = "",
|
||||
limit: int = 5,
|
||||
chat_id: str = "",
|
||||
person_id: str = "",
|
||||
mode: str = "search",
|
||||
time_expression: str = "",
|
||||
) -> str:
|
||||
content = str(query or "").strip()
|
||||
safe_limit = max(1, int(limit or 5))
|
||||
normalized_mode = str(mode or "search").strip().lower() or "search"
|
||||
if normalized_mode not in _SUPPORTED_MODES:
|
||||
return f"不支持的长期记忆检索模式:{normalized_mode}。可用模式:search、time、episode、aggregate。"
|
||||
|
||||
if normalized_mode == "search" and not content:
|
||||
return "查询关键词为空,请提供你想查找的长期记忆内容。"
|
||||
if normalized_mode == "time" and not str(time_expression or "").strip():
|
||||
return f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}"
|
||||
if normalized_mode in {"episode", "aggregate"} and not content and not str(time_expression or "").strip():
|
||||
return f"{normalized_mode} 模式至少需要提供 query 或 time_expression。"
|
||||
|
||||
time_start = None
|
||||
time_end = None
|
||||
time_range_text = ""
|
||||
if str(time_expression or "").strip():
|
||||
try:
|
||||
time_start, time_end, time_start_text, time_end_text = _resolve_time_expression(time_expression)
|
||||
except ValueError as exc:
|
||||
return str(exc)
|
||||
time_range_text = f"(时间范围:{time_start_text} 至 {time_end_text})"
|
||||
|
||||
backend_mode = "hybrid" if normalized_mode == "search" else normalized_mode
|
||||
|
||||
try:
|
||||
result = await memory_service.search(
|
||||
content,
|
||||
limit=safe_limit,
|
||||
mode=backend_mode,
|
||||
chat_id=str(chat_id or "").strip(),
|
||||
person_id=str(person_id or "").strip(),
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
)
|
||||
text = _format_tool_result(
|
||||
result=result,
|
||||
mode=normalized_mode, # type: ignore[arg-type]
|
||||
limit=safe_limit,
|
||||
query=content,
|
||||
time_range_text=time_range_text,
|
||||
)
|
||||
logger.debug(f"长期记忆查询结果({normalized_mode}): {text}")
|
||||
return text
|
||||
except Exception as exc:
|
||||
logger.error(f"长期记忆查询失败: {exc}")
|
||||
return f"长期记忆查询失败:{exc}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
register_memory_retrieval_tool(
|
||||
name="search_long_term_memory",
|
||||
description=(
|
||||
"从长期记忆中检索信息。支持 search(普通事实检索)、time(按时间范围检索)、"
|
||||
"episode(按事件/情节检索)、aggregate(综合检索)四种模式。"
|
||||
),
|
||||
parameters=[
|
||||
{
|
||||
"name": "query",
|
||||
"type": "string",
|
||||
"description": "需要查询的问题。search 模式建议用自然语言问句;time/episode/aggregate 模式也可用关键词短语。",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "mode",
|
||||
"type": "string",
|
||||
"description": "检索模式:search(普通长期记忆)、time(按时间窗口)、episode(事件/情节)、aggregate(综合检索)。",
|
||||
"required": False,
|
||||
"enum": ["search", "time", "episode", "aggregate"],
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"type": "integer",
|
||||
"description": "希望返回的相关知识条数,默认为5",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "chat_id",
|
||||
"type": "string",
|
||||
"description": "当前聊天流ID,可选。提供后优先检索当前聊天上下文相关的长期记忆。",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "person_id",
|
||||
"type": "string",
|
||||
"description": "相关人物ID,可选。提供后优先检索该人物相关的长期记忆。",
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "time_expression",
|
||||
"type": "string",
|
||||
"description": (
|
||||
"时间表达,可选。time 模式必填;episode/aggregate 模式可选。支持:今天、昨天、前天、本周、上周、本月、上月、"
|
||||
"最近N天,以及 YYYY/MM/DD、YYYY/MM/DD HH:mm。"
|
||||
),
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=query_long_term_memory,
|
||||
)
|
||||
@@ -1,75 +0,0 @@
|
||||
"""
|
||||
通过LPMM知识库查询信息 - 工具实现
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge import get_qa_manager
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_lpmm_knowledge(query: str, limit: int = 5) -> str:
|
||||
"""在LPMM知识库中查询相关信息
|
||||
|
||||
Args:
|
||||
query: 查询关键词
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
content = str(query).strip()
|
||||
if not content:
|
||||
return "查询关键词为空"
|
||||
|
||||
try:
|
||||
limit_value = int(limit)
|
||||
except (TypeError, ValueError):
|
||||
limit_value = 5
|
||||
limit_value = max(1, limit_value)
|
||||
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用")
|
||||
return "LPMM知识库未启用"
|
||||
|
||||
qa_manager = get_qa_manager()
|
||||
if qa_manager is None:
|
||||
logger.debug("LPMM知识库未初始化,跳过查询")
|
||||
return "LPMM知识库未初始化"
|
||||
|
||||
knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value)
|
||||
logger.debug(f"LPMM知识库查询结果: {knowledge_info}")
|
||||
|
||||
if knowledge_info:
|
||||
return f"你从LPMM知识库中找到以下信息:\n{knowledge_info}"
|
||||
|
||||
return f"在LPMM知识库中未找到与“{content}”相关的信息"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LPMM知识库查询失败: {e}")
|
||||
return f"LPMM知识库查询失败:{str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册LPMM知识库查询工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="lpmm_search_knowledge",
|
||||
description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索",
|
||||
parameters=[
|
||||
{
|
||||
"name": "query",
|
||||
"type": "string",
|
||||
"description": "需要查询的问题,使用一句疑问句提问,例如:什么是AI?",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"type": "integer",
|
||||
"description": "希望返回的相关知识条数,默认为5",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=query_lpmm_knowledge,
|
||||
)
|
||||
@@ -6,9 +6,10 @@ import random
|
||||
import math
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import Union, Optional, Dict
|
||||
from typing import Union, Optional, Dict, List
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -17,6 +18,7 @@ from src.common.database.database_model import PersonInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
|
||||
logger = get_logger("person_info")
|
||||
@@ -37,16 +39,60 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
|
||||
def get_person_id_by_person_name(person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
clean_name = str(person_name or "").strip()
|
||||
if not clean_name:
|
||||
return ""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
|
||||
statement = (
|
||||
select(PersonInfo)
|
||||
.where(
|
||||
or_(
|
||||
col(PersonInfo.person_name) == clean_name,
|
||||
col(PersonInfo.user_nickname) == clean_name,
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
record = session.exec(statement).first()
|
||||
if record and record.person_id:
|
||||
return record.person_id
|
||||
|
||||
statement = (
|
||||
select(PersonInfo)
|
||||
.where(PersonInfo.group_cardname.contains(clean_name))
|
||||
.limit(1)
|
||||
)
|
||||
record = session.exec(statement).first()
|
||||
return record.person_id if record else ""
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
||||
logger.error(f"根据用户名 {clean_name} 获取用户ID时出错: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def resolve_person_id_for_memory(
|
||||
*,
|
||||
person_name: str = "",
|
||||
platform: str = "",
|
||||
user_id: Optional[Union[int, str]] = None,
|
||||
) -> str:
|
||||
"""统一人物记忆链路中的 person_id 解析。
|
||||
|
||||
优先使用已知的人物名称/别名,其次退回到平台 + user_id 的稳定 ID。
|
||||
"""
|
||||
name_token = str(person_name or "").strip()
|
||||
if name_token:
|
||||
resolved = get_person_id_by_person_name(name_token)
|
||||
if resolved:
|
||||
return resolved
|
||||
|
||||
platform_token = str(platform or "").strip()
|
||||
user_token = str(user_id or "").strip()
|
||||
if platform_token and user_token:
|
||||
return get_person_id(platform_token, user_token)
|
||||
return ""
|
||||
|
||||
|
||||
def is_person_known(
|
||||
person_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
@@ -537,79 +583,79 @@ class Person:
|
||||
async def build_relationship(self, chat_content: str = "", info_type=""):
|
||||
if not self.is_known:
|
||||
return ""
|
||||
# 构建points文本
|
||||
|
||||
nickname_str = ""
|
||||
if self.person_name != self.nickname:
|
||||
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
|
||||
|
||||
relation_info = ""
|
||||
async def _select_traits(query_text: str, traits: List[str], limit: int = 3) -> List[str]:
|
||||
clean_traits = [trait.strip() for trait in traits if isinstance(trait, str) and trait.strip()]
|
||||
if not clean_traits:
|
||||
return []
|
||||
if not query_text:
|
||||
return clean_traits[:limit]
|
||||
|
||||
points_text = ""
|
||||
category_list = self.get_all_category()
|
||||
numbered_traits = "\n".join(f"{index}. {trait}" for index, trait in enumerate(clean_traits, start=1))
|
||||
prompt = f"""当前关注内容:
|
||||
{query_text}
|
||||
|
||||
if chat_content:
|
||||
prompt = f"""当前聊天内容:
|
||||
{chat_content}
|
||||
候选人物信息:
|
||||
{numbered_traits}
|
||||
|
||||
分类列表:
|
||||
{category_list}
|
||||
**要求**:请你根据当前聊天内容,从以下分类中选择一个与聊天内容相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
|
||||
例如:
|
||||
<分类1><分类2><分类3>......
|
||||
如果没有相关的分类,请输出<none>"""
|
||||
请从候选人物信息中选择与当前关注内容最相关的编号,并用<>包裹输出,不要输出其他内容。
|
||||
例如:
|
||||
<1><3>
|
||||
如果都不相关,请输出<none>"""
|
||||
|
||||
response, _ = await relation_selection_model.generate_response_async(prompt)
|
||||
# print(prompt)
|
||||
# print(response)
|
||||
category_list = extract_categories_from_response(response)
|
||||
if "none" not in category_list:
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category, 2)
|
||||
if random_memory:
|
||||
random_memory_str = "\n".join(
|
||||
[get_memory_content_from_memory(memory) for memory in random_memory]
|
||||
)
|
||||
points_text = f"有关 {category} 的内容:{random_memory_str}"
|
||||
break
|
||||
elif info_type:
|
||||
prompt = f"""你需要获取用户{self.person_name}的 **{info_type}** 信息。
|
||||
try:
|
||||
response, _ = await relation_selection_model.generate_response_async(prompt)
|
||||
selected_traits: List[str] = []
|
||||
for raw_index in extract_categories_from_response(response):
|
||||
if raw_index == "none":
|
||||
return []
|
||||
try:
|
||||
trait_index = int(raw_index) - 1
|
||||
except ValueError:
|
||||
continue
|
||||
if 0 <= trait_index < len(clean_traits):
|
||||
trait = clean_traits[trait_index]
|
||||
if trait not in selected_traits:
|
||||
selected_traits.append(trait)
|
||||
if selected_traits:
|
||||
return selected_traits[:limit]
|
||||
except Exception as e:
|
||||
logger.debug(f"筛选人物画像信息失败,使用默认画像摘要: {e}")
|
||||
|
||||
现有信息类别列表:
|
||||
{category_list}
|
||||
**要求**:请你根据**{info_type}**,从以下分类中选择一个与**{info_type}**相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹:
|
||||
例如:
|
||||
<分类1><分类2><分类3>......
|
||||
如果没有相关的分类,请输出<none>"""
|
||||
response, _ = await relation_selection_model.generate_response_async(prompt)
|
||||
# print(prompt)
|
||||
# print(response)
|
||||
category_list = extract_categories_from_response(response)
|
||||
if "none" not in category_list:
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category, 3)
|
||||
if random_memory:
|
||||
random_memory_str = "\n".join(
|
||||
[get_memory_content_from_memory(memory) for memory in random_memory]
|
||||
)
|
||||
points_text = f"有关 {category} 的内容:{random_memory_str}"
|
||||
break
|
||||
else:
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category, 1)[0]
|
||||
if random_memory:
|
||||
points_text = f"有关 {category} 的内容:{get_memory_content_from_memory(random_memory)}"
|
||||
break
|
||||
return clean_traits[:limit]
|
||||
|
||||
profile = await memory_service.get_person_profile(self.person_id, limit=8)
|
||||
relation_parts: List[str] = []
|
||||
if profile.summary.strip():
|
||||
relation_parts.append(profile.summary.strip())
|
||||
|
||||
query_text = str(chat_content or info_type or "").strip()
|
||||
selected_traits = await _select_traits(query_text, profile.traits, limit=3)
|
||||
if not selected_traits and not query_text:
|
||||
selected_traits = [trait for trait in profile.traits if trait][:2]
|
||||
|
||||
for trait in selected_traits:
|
||||
clean_trait = str(trait).strip()
|
||||
if clean_trait and clean_trait not in relation_parts:
|
||||
relation_parts.append(clean_trait)
|
||||
|
||||
for evidence in profile.evidence:
|
||||
content = str(evidence.get("content", "") or "").strip()
|
||||
if content and content not in relation_parts:
|
||||
relation_parts.append(content)
|
||||
if len(relation_parts) >= 4:
|
||||
break
|
||||
|
||||
points_info = ""
|
||||
if points_text:
|
||||
points_info = f"你还记得有关{self.person_name}的内容:{points_text}"
|
||||
if relation_parts:
|
||||
points_info = f"你还记得有关{self.person_name}的内容:{';'.join(relation_parts[:3])}"
|
||||
|
||||
if not (nickname_str or points_info):
|
||||
return ""
|
||||
relation_info = f"{self.person_name}:{nickname_str}{points_info}"
|
||||
|
||||
return relation_info
|
||||
return f"{self.person_name}:{nickname_str}{points_info}"
|
||||
|
||||
|
||||
class PersonInfoManager:
|
||||
@@ -776,7 +822,7 @@ person_info_manager = PersonInfoManager()
|
||||
|
||||
|
||||
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
|
||||
"""将人物信息存入person_info的memory_points
|
||||
"""将人物事实写入统一长期记忆
|
||||
|
||||
Args:
|
||||
person_name: 人物名称
|
||||
@@ -784,6 +830,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
chat_id: 聊天ID
|
||||
"""
|
||||
try:
|
||||
content = str(memory_content or "").strip()
|
||||
if not content:
|
||||
logger.debug("人物记忆内容为空,跳过写入")
|
||||
return
|
||||
|
||||
# 从 chat_id 获取 session
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if not session:
|
||||
@@ -794,16 +845,14 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
|
||||
# 尝试从person_name查找person_id
|
||||
# 首先尝试通过person_name查找
|
||||
person_id = get_person_id_by_person_name(person_name)
|
||||
|
||||
person_id = resolve_person_id_for_memory(
|
||||
person_name=person_name,
|
||||
platform=platform,
|
||||
user_id=session.user_id,
|
||||
)
|
||||
if not person_id:
|
||||
# 如果通过person_name找不到,尝试从 session 获取 user_id
|
||||
if platform and session.user_id:
|
||||
user_id = session.user_id
|
||||
person_id = get_person_id(platform, user_id)
|
||||
else:
|
||||
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
|
||||
return
|
||||
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
|
||||
return
|
||||
|
||||
# 创建或获取Person对象
|
||||
person = Person(person_id=person_id)
|
||||
@@ -812,39 +861,34 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
|
||||
return
|
||||
|
||||
# 确定记忆分类(可以根据memory_content判断,这里使用通用分类)
|
||||
category = "其他" # 默认分类,可以根据需要调整
|
||||
memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16]
|
||||
result = await memory_service.ingest_text(
|
||||
external_id=f"person_fact:{person_id}:{memory_hash}",
|
||||
source_type="person_fact",
|
||||
text=content,
|
||||
chat_id=chat_id,
|
||||
person_ids=[person_id],
|
||||
participants=[person.person_name or person_name],
|
||||
timestamp=time.time(),
|
||||
tags=["person_fact"],
|
||||
metadata={
|
||||
"person_id": person_id,
|
||||
"person_name": person.person_name or person_name,
|
||||
"platform": platform,
|
||||
"source": "person_info.store_person_memory_from_answer",
|
||||
},
|
||||
respect_filter=True,
|
||||
user_id=str(session.user_id or "").strip(),
|
||||
group_id=str(session.group_id or "").strip(),
|
||||
)
|
||||
|
||||
# 记忆点格式:category:content:weight
|
||||
weight = "1.0" # 默认权重
|
||||
memory_point = f"{category}:{memory_content}:{weight}"
|
||||
|
||||
# 添加到memory_points
|
||||
if not person.memory_points:
|
||||
person.memory_points = []
|
||||
|
||||
# 检查是否已存在相似的记忆点(避免重复)
|
||||
is_duplicate = False
|
||||
for existing_point in person.memory_points:
|
||||
if existing_point and isinstance(existing_point, str):
|
||||
parts = existing_point.split(":", 2)
|
||||
if len(parts) >= 2:
|
||||
existing_content = parts[1].strip()
|
||||
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
||||
if (
|
||||
existing_content == memory_content
|
||||
or memory_content in existing_content
|
||||
or existing_content in memory_content
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
person.memory_points.append(memory_point)
|
||||
person.sync_to_database()
|
||||
logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
|
||||
if result.success:
|
||||
if result.detail == "chat_filtered":
|
||||
logger.debug(f"人物长期记忆被聊天过滤策略跳过: {person_name} (person_id: {person_id})")
|
||||
else:
|
||||
logger.info(f"成功写入人物长期记忆: {person_name} (person_id: {person_id})")
|
||||
else:
|
||||
logger.debug(f"记忆点已存在,跳过: {memory_point}")
|
||||
logger.warning(f"写入人物长期记忆失败: {person_name} (person_id: {person_id}) | {result.detail}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储人物记忆失败: {e}")
|
||||
|
||||
@@ -672,12 +672,10 @@ class RuntimeDataCapabilityMixin:
|
||||
limit_value = 5
|
||||
|
||||
try:
|
||||
from src.chat.knowledge import qa_manager
|
||||
from src.services.memory_service import memory_service
|
||||
|
||||
if qa_manager is None:
|
||||
return {"success": True, "content": "LPMM知识库已禁用"}
|
||||
|
||||
knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)
|
||||
result = await memory_service.search(query, limit=limit_value)
|
||||
knowledge_info = result.to_text(limit=limit_value)
|
||||
content = f"你知道这些知识: {knowledge_info}" if knowledge_info else f"你不太了解有关{query}的知识"
|
||||
return {"success": True, "content": content}
|
||||
except Exception as e:
|
||||
|
||||
275
src/services/memory_flow_service.py
Normal file
275
src/services/memory_flow_service.py
Normal file
@@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.memory_system.chat_history_summarizer import ChatHistorySummarizer
|
||||
from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer
|
||||
|
||||
logger = get_logger("memory_flow_service")
|
||||
|
||||
|
||||
class LongTermMemorySessionManager:
    """Lazily creates one ChatHistorySummarizer per chat session and stops them all on shutdown."""

    def __init__(self) -> None:
        # Guards concurrent access to the per-session summarizer registry.
        self._lock = asyncio.Lock()
        self._summarizers: Dict[str, ChatHistorySummarizer] = {}

    async def on_message(self, message: Any) -> None:
        """Ensure a running summarizer exists for the session carried by *message*.

        No-op when auto-summary is disabled in config or the message has no
        usable session id.
        """
        if not bool(getattr(global_config.memory, "long_term_auto_summary_enabled", True)):
            return
        sid = str(getattr(message, "session_id", "") or "").strip()
        if not sid:
            return

        fresh = None
        async with self._lock:
            if sid not in self._summarizers:
                fresh = ChatHistorySummarizer(session_id=sid)
                self._summarizers[sid] = fresh
        # Start outside the lock so a slow start() cannot block other sessions.
        if fresh is not None:
            await fresh.start()

    async def shutdown(self) -> None:
        """Stop every tracked summarizer; individual failures are logged, not raised."""
        async with self._lock:
            drained = list(self._summarizers.items())
            self._summarizers.clear()
        for sid, summarizer in drained:
            try:
                await summarizer.stop()
            except Exception as exc:
                logger.warning("停止聊天总结器失败: session=%s err=%s", sid, exc)
|
||||
|
||||
|
||||
class PersonFactWritebackService:
    """Background worker that mines stable person facts out of the bot's own replies.

    Replies are queued via :meth:`enqueue`, processed one at a time by a single
    asyncio worker task, run through an LLM fact extractor, and each extracted
    fact is written back through ``store_person_memory_from_answer``.
    """

    def __init__(self) -> None:
        # Bounded queue: when full, new replies are dropped (see enqueue) rather
        # than blocking the send path.
        self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256)
        self._worker_task: Optional[asyncio.Task] = None
        self._stopping = False
        # Lightweight "utils" model is enough for short extraction prompts.
        self._extractor = LLMRequest(
            model_set=model_config.model_task_config.utils,
            request_type="person_fact_writeback",
        )

    async def start(self) -> None:
        """Spawn the worker task; idempotent while a live worker exists."""
        if self._worker_task is not None and not self._worker_task.done():
            return
        self._stopping = False
        self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_person_fact_writeback")

    async def shutdown(self) -> None:
        """Cancel the worker task and wait for it to finish; safe to call twice."""
        self._stopping = True
        worker = self._worker_task
        self._worker_task = None
        if worker is None:
            return
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            logger.warning("关闭人物事实写回 worker 失败: %s", exc)

    async def enqueue(self, message: Any) -> None:
        """Queue a sent message for fact extraction (best-effort, never blocks)."""
        if not bool(getattr(global_config.memory, "person_fact_writeback_enabled", True)):
            return
        if self._stopping:
            return
        try:
            self._queue.put_nowait(message)
        except asyncio.QueueFull:
            # Dropping is acceptable: fact write-back is opportunistic.
            logger.warning("人物事实写回队列已满,跳过本次回复")

    async def _worker_loop(self) -> None:
        """Drain the queue forever; per-message errors are logged and swallowed."""
        try:
            while not self._stopping:
                message = await self._queue.get()
                try:
                    await self._handle_message(message)
                except Exception as exc:
                    logger.warning("人物事实写回处理失败: %s", exc, exc_info=True)
                finally:
                    self._queue.task_done()
        except asyncio.CancelledError:
            # Propagate so shutdown() observes a clean cancellation.
            raise

    async def _handle_message(self, message: Any) -> None:
        """Extract and persist person facts from a single bot reply.

        Bails out early when the reply is empty/ephemeral, when no known
        target person can be resolved, or when no session id is available.
        """
        reply_text = str(getattr(message, "processed_plain_text", "") or "").strip()
        if not reply_text:
            return
        if self._looks_ephemeral(reply_text):
            return

        target_person = self._resolve_target_person(message)
        if target_person is None or not target_person.is_known:
            return

        facts = await self._extract_facts(target_person, reply_text)
        if not facts:
            return

        session_id = str(
            getattr(message, "session_id", "")
            or getattr(getattr(message, "session", None), "session_id", "")
            or ""
        ).strip()
        if not session_id:
            return

        person_name = str(getattr(target_person, "person_name", "") or getattr(target_person, "nickname", "") or "").strip()
        if not person_name:
            return

        for fact in facts:
            await store_person_memory_from_answer(person_name, fact, session_id)

    def _resolve_target_person(self, message: Any) -> Optional[Person]:
        """Figure out which known person the reply is about, or None.

        Private chats: the session's counterpart (unless it is the bot itself).
        Group chats: only the author of the message being replied to, looked up
        via ``reply_to``; returns None when that cannot be resolved.
        """
        session = getattr(message, "session", None)
        session_platform = str(getattr(session, "platform", "") or getattr(message, "platform", "") or "").strip()
        session_user_id = str(getattr(session, "user_id", "") or "").strip()
        group_id = str(getattr(session, "group_id", "") or "").strip()

        # Private chat: the counterpart is unambiguous.
        if session_platform and session_user_id and not group_id:
            if is_bot_self(session_platform, session_user_id):
                return None
            person_id = get_person_id(session_platform, session_user_id)
            person = Person(person_id=person_id)
            return person if person.is_known else None

        reply_to = str(getattr(message, "reply_to", "") or "").strip()
        if not reply_to:
            return None
        try:
            replies = find_messages(message_id=reply_to, limit=1)
        except Exception as exc:
            logger.debug("查询 reply_to 目标失败: %s", exc)
            return None
        if not replies:
            return None
        reply_message = replies[0]
        reply_platform = str(getattr(reply_message, "platform", "") or session_platform or "").strip()
        reply_user_info = getattr(getattr(reply_message, "message_info", None), "user_info", None)
        reply_user_id = str(getattr(reply_user_info, "user_id", "") or "").strip()
        if not reply_platform or not reply_user_id or is_bot_self(reply_platform, reply_user_id):
            return None
        person_id = get_person_id(reply_platform, reply_user_id)
        person = Person(person_id=person_id)
        return person if person.is_known else None

    async def _extract_facts(self, person: Person, reply_text: str) -> List[str]:
        """Ask the LLM for stable facts about *person* found in *reply_text*.

        Returns an empty list on any model failure; output is parsed via
        :meth:`_parse_fact_list`.
        """
        person_name = str(getattr(person, "person_name", "") or getattr(person, "nickname", "") or person.person_id)
        prompt = f"""你要从一条机器人刚刚发送的回复中,提取“关于{person_name}的稳定事实”。

目标人物:{person_name}
机器人回复:
{reply_text}

请只提取满足以下条件的事实:
1. 明确是关于目标人物本人的信息。
2. 具有相对稳定性,可以作为长期记忆保存。
3. 用简洁中文陈述句表达。

不要提取:
- 机器人的情绪、计划、临时动作、客套话
- 只适用于当前时刻的短期安排
- 不确定、猜测、反问
- 与目标人物无关的信息

严格输出 JSON 数组,例如:
["他喜欢深夜打游戏", "他养了一只猫"]
如果没有可写入的事实,输出 []"""
        try:
            response, _ = await self._extractor.generate_response_async(prompt)
        except Exception as exc:
            logger.debug("人物事实提取模型调用失败: %s", exc)
            return []
        return self._parse_fact_list(response)

    @staticmethod
    def _parse_fact_list(raw: str) -> List[str]:
        """Parse the model output into a deduplicated list of at most 5 facts.

        Tolerates malformed JSON via ``repair_json``; silently returns an empty
        list on anything that is not ultimately a JSON array.
        """
        text = str(raw or "").strip()
        if not text:
            return []
        try:
            repaired = repair_json(text)
            payload = json.loads(repaired) if isinstance(repaired, str) else repaired
        except Exception:
            payload = None
        if not isinstance(payload, list):
            return []

        items: List[str] = []
        seen = set()
        for item in payload:
            # Strip list-style "- " bullets the model sometimes emits.
            fact = str(item or "").strip().strip("- ")
            # Facts shorter than 4 chars are too thin to be worth storing.
            if not fact or len(fact) < 4:
                continue
            if fact in seen:
                continue
            seen.add(fact)
            items.append(fact)
        return items[:5]

    @staticmethod
    def _looks_ephemeral(text: str) -> bool:
        """Heuristic: True for short pleasantries/ack replies not worth mining."""
        content = str(text or "").strip()
        if not content:
            return True
        ephemeral_markers = (
            "哈哈",
            "好的",
            "收到",
            "嗯嗯",
            "晚安",
            "早安",
            "拜拜",
            "谢谢",
            "在吗",
            "?",
        )
        # Only very short replies are treated as ephemeral, even with a marker.
        if len(content) <= 8 and any(marker in content for marker in ephemeral_markers):
            return True
        return False
|
||||
|
||||
|
||||
class MemoryAutomationService:
    """Facade wiring incoming/outgoing messages into the long-term memory pipeline."""

    def __init__(self) -> None:
        self.session_manager = LongTermMemorySessionManager()
        self.fact_writeback = PersonFactWritebackService()
        self._started = False

    async def start(self) -> None:
        """Idempotently start the background fact-writeback worker."""
        if self._started:
            return
        await self.fact_writeback.start()
        self._started = True

    async def shutdown(self) -> None:
        """Stop both sub-services; no-op when never started."""
        if not self._started:
            return
        await self.session_manager.shutdown()
        await self.fact_writeback.shutdown()
        self._started = False

    async def _ensure_started(self) -> None:
        # Lazy start keeps heavy side effects out of module import time.
        if not self._started:
            await self.start()

    async def on_incoming_message(self, message: Any) -> None:
        """Feed a received message to the per-session summarizer."""
        await self._ensure_started()
        await self.session_manager.on_message(message)

    async def on_message_sent(self, message: Any) -> None:
        """Queue a bot reply for person-fact extraction."""
        await self._ensure_started()
        await self.fact_writeback.enqueue(message)


# Module-level singleton shared across the application.
memory_automation_service = MemoryAutomationService()
|
||||
428
src/services/memory_service.py
Normal file
428
src/services/memory_service.py
Normal file
@@ -0,0 +1,428 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
|
||||
logger = get_logger("memory_service")
|
||||
|
||||
PLUGIN_ID = "A_Memorix"
|
||||
|
||||
|
||||
@dataclass
class MemoryHit:
    """A single retrieved long-term-memory entry as returned by the plugin."""

    content: str
    score: float = 0.0
    hit_type: str = ""
    source: str = ""
    hash_value: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)
    episode_id: str = ""
    title: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Serialize using the plugin wire keys ("type"/"hash" instead of field names)."""
        payload: Dict[str, Any] = {}
        payload["content"] = self.content
        payload["score"] = self.score
        payload["type"] = self.hit_type
        payload["source"] = self.source
        payload["hash"] = self.hash_value
        payload["metadata"] = self.metadata
        payload["episode_id"] = self.episode_id
        payload["title"] = self.title
        return payload
|
||||
|
||||
|
||||
@dataclass
class MemorySearchResult:
    """Result of a long-term-memory search: optional summary plus ranked hits."""

    summary: str = ""
    hits: List[MemoryHit] = field(default_factory=list)
    filtered: bool = False

    def to_text(self, limit: int = 5) -> str:
        """Render at most *limit* hits (minimum 1) as a numbered plain-text list."""
        if not self.hits:
            return ""
        shown = self.hits[: max(1, int(limit))]
        rendered = []
        for rank, hit in enumerate(shown, start=1):
            body = hit.content.strip().replace("\n", " ")
            # Keep each entry short enough for prompt injection.
            if len(body) > 160:
                body = body[:160] + "..."
            rendered.append(f"{rank}. {body}")
        return "\n".join(rendered)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize, recursively converting each hit to its dict form."""
        return {
            "summary": self.summary,
            "hits": [hit.to_dict() for hit in self.hits],
            "filtered": self.filtered,
        }
|
||||
|
||||
|
||||
@dataclass
class MemoryWriteResult:
    """Outcome of a memory write/maintenance call against the plugin."""

    success: bool
    stored_ids: List[str] = field(default_factory=list)
    skipped_ids: List[str] = field(default_factory=list)
    detail: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Serialize all fields for API/log consumption."""
        return dict(
            success=self.success,
            stored_ids=self.stored_ids,
            skipped_ids=self.skipped_ids,
            detail=self.detail,
        )
|
||||
|
||||
|
||||
@dataclass
class PersonProfileResult:
    """Aggregated person profile: a summary, trait strings, and raw evidence records."""

    summary: str = ""
    traits: List[str] = field(default_factory=list)
    evidence: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize all fields verbatim."""
        return {
            "summary": self.summary,
            "traits": self.traits,
            "evidence": self.evidence,
        }
|
||||
|
||||
|
||||
class MemoryService:
    """Async façade over the "A_Memorix" memory plugin.

    Every public method delegates to the plugin runtime via RPC.  Public
    methods never raise to callers: failures are logged and mapped to empty
    or failed result objects so chat flow is not interrupted by memory
    outages.
    """

    async def _invoke(self, component_name: str, args: Optional[Dict[str, Any]] = None, *, timeout_ms: int = 30000) -> Any:
        """Call one plugin tool component; raises RuntimeError when the runtime is down."""
        runtime = get_plugin_runtime_manager()
        if not runtime.is_running:
            raise RuntimeError("plugin_runtime 未启动")
        return await runtime.invoke_plugin(
            method="plugin.invoke_tool",
            plugin_id=PLUGIN_ID,
            component_name=component_name,
            args=args or {},
            # Clamp to at least 1s; falsy timeout falls back to 30s.
            timeout_ms=max(1000, int(timeout_ms or 30000)),
        )

    async def _invoke_admin(
        self,
        component_name: str,
        *,
        action: str,
        timeout_ms: int = 30000,
        **kwargs,
    ) -> Dict[str, Any]:
        """Invoke an admin-style component ({"action": ..., **kwargs} payload convention)."""
        payload = await self._invoke(component_name, {"action": action, **kwargs}, timeout_ms=timeout_ms)
        return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"}

    @staticmethod
    def _coerce_write_result(payload: Any) -> MemoryWriteResult:
        """Normalize a raw plugin write response into a MemoryWriteResult.

        Success is inferred in priority order: any stored/skipped ids -> True;
        explicit "success" key -> its boolean; otherwise success iff no detail
        text was returned.
        """
        if not isinstance(payload, dict):
            return MemoryWriteResult(success=False, detail="invalid_payload")
        stored_ids = [str(item) for item in (payload.get("stored_ids") or []) if str(item).strip()]
        skipped_ids = [str(item) for item in (payload.get("skipped_ids") or []) if str(item).strip()]
        # "reason" is an accepted alias for "detail" in older plugin responses.
        detail = str(payload.get("detail") or payload.get("reason") or "")
        if stored_ids or skipped_ids:
            success = True
        elif "success" in payload:
            success = bool(payload.get("success"))
        else:
            success = not bool(detail)
        return MemoryWriteResult(
            success=success,
            stored_ids=stored_ids,
            skipped_ids=skipped_ids,
            detail=detail,
        )

    @staticmethod
    def _coerce_search_result(payload: Any) -> MemorySearchResult:
        """Normalize a raw search response into a MemorySearchResult.

        Non-dict payloads and malformed hit entries are dropped silently;
        top-level "source_branches"/"rank" keys are folded into each hit's
        metadata when metadata does not already carry them.
        """
        if not isinstance(payload, dict):
            return MemorySearchResult()
        hits: List[MemoryHit] = []
        for item in payload.get("hits", []) or []:
            if not isinstance(item, dict):
                continue
            metadata = item.get("metadata", {}) or {}
            if not isinstance(metadata, dict):
                metadata = {}
            if "source_branches" in item and "source_branches" not in metadata:
                metadata["source_branches"] = item.get("source_branches") or []
            if "rank" in item and "rank" not in metadata:
                metadata["rank"] = item.get("rank")
            hits.append(
                MemoryHit(
                    content=str(item.get("content", "") or ""),
                    score=float(item.get("score", 0.0) or 0.0),
                    hit_type=str(item.get("type", "") or ""),
                    source=str(item.get("source", "") or ""),
                    hash_value=str(item.get("hash", "") or ""),
                    metadata=metadata,
                    episode_id=str(item.get("episode_id", "") or ""),
                    title=str(item.get("title", "") or ""),
                )
            )
        return MemorySearchResult(
            summary=str(payload.get("summary", "") or ""),
            hits=hits,
            filtered=bool(payload.get("filtered", False)),
        )

    @staticmethod
    def _coerce_profile_result(payload: Any) -> PersonProfileResult:
        """Normalize a raw profile response; drops blank traits and non-dict evidence."""
        if not isinstance(payload, dict):
            return PersonProfileResult()
        return PersonProfileResult(
            summary=str(payload.get("summary", "") or ""),
            traits=[str(item) for item in (payload.get("traits") or []) if str(item).strip()],
            evidence=[item for item in (payload.get("evidence") or []) if isinstance(item, dict)],
        )

    async def search(
        self,
        query: str,
        *,
        limit: int = 5,
        mode: str = "hybrid",
        chat_id: str = "",
        person_id: str = "",
        time_start: str | float | None = None,
        time_end: str | float | None = None,
        respect_filter: bool = True,
        user_id: str = "",
        group_id: str = "",
    ) -> MemorySearchResult:
        """Search long-term memory.

        Returns an empty result when both the query and the time window are
        empty, or when the plugin call fails for any reason.
        """
        clean_query = str(query or "").strip()
        # Treat empty strings as "no bound" so callers can pass "" freely.
        normalized_time_start = None if time_start in {None, ""} else time_start
        normalized_time_end = None if time_end in {None, ""} else time_end
        if not clean_query and normalized_time_start is None and normalized_time_end is None:
            return MemorySearchResult()
        try:
            payload = await self._invoke(
                "search_memory",
                {
                    "query": clean_query,
                    "limit": max(1, int(limit)),
                    "mode": mode,
                    "chat_id": chat_id,
                    "person_id": person_id,
                    "time_start": normalized_time_start,
                    "time_end": normalized_time_end,
                    "respect_filter": bool(respect_filter),
                    "user_id": str(user_id or "").strip(),
                    "group_id": str(group_id or "").strip(),
                },
            )
            return self._coerce_search_result(payload)
        except Exception as exc:
            logger.warning("长期记忆搜索失败: %s", exc)
            return MemorySearchResult()

    async def ingest_summary(
        self,
        *,
        external_id: str,
        chat_id: str,
        text: str,
        participants: Optional[List[str]] = None,
        time_start: float | None = None,
        time_end: float | None = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        respect_filter: bool = True,
        user_id: str = "",
        group_id: str = "",
    ) -> MemoryWriteResult:
        """Store a chat-history summary; never raises (failure reported in the result)."""
        try:
            payload = await self._invoke(
                "ingest_summary",
                {
                    "external_id": external_id,
                    "chat_id": chat_id,
                    "text": text,
                    "participants": participants or [],
                    "time_start": time_start,
                    "time_end": time_end,
                    "tags": tags or [],
                    "metadata": metadata or {},
                    "respect_filter": bool(respect_filter),
                    "user_id": str(user_id or "").strip(),
                    "group_id": str(group_id or "").strip(),
                },
            )
            return self._coerce_write_result(payload)
        except Exception as exc:
            logger.warning("长期记忆写入摘要失败: %s", exc)
            return MemoryWriteResult(success=False, detail=str(exc))

    async def ingest_text(
        self,
        *,
        external_id: str,
        source_type: str,
        text: str,
        chat_id: str = "",
        person_ids: Optional[List[str]] = None,
        participants: Optional[List[str]] = None,
        timestamp: float | None = None,
        time_start: float | None = None,
        time_end: float | None = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        entities: Optional[List[str]] = None,
        relations: Optional[List[Dict[str, Any]]] = None,
        respect_filter: bool = True,
        user_id: str = "",
        group_id: str = "",
    ) -> MemoryWriteResult:
        """Store an arbitrary text item (e.g. a person fact); never raises."""
        try:
            payload = await self._invoke(
                "ingest_text",
                {
                    "external_id": external_id,
                    "source_type": source_type,
                    "text": text,
                    "chat_id": chat_id,
                    "person_ids": person_ids or [],
                    "participants": participants or [],
                    "timestamp": timestamp,
                    "time_start": time_start,
                    "time_end": time_end,
                    "tags": tags or [],
                    "metadata": metadata or {},
                    "entities": entities or [],
                    "relations": relations or [],
                    "respect_filter": bool(respect_filter),
                    "user_id": str(user_id or "").strip(),
                    "group_id": str(group_id or "").strip(),
                },
            )
            return self._coerce_write_result(payload)
        except Exception as exc:
            logger.warning("长期记忆写入文本失败: %s", exc)
            return MemoryWriteResult(success=False, detail=str(exc))

    async def get_person_profile(self, person_id: str, *, chat_id: str = "", limit: int = 10) -> PersonProfileResult:
        """Fetch a person's profile; empty result for blank ids or on failure."""
        clean_person_id = str(person_id or "").strip()
        if not clean_person_id:
            return PersonProfileResult()
        try:
            payload = await self._invoke(
                "get_person_profile",
                {"person_id": clean_person_id, "chat_id": chat_id, "limit": max(1, int(limit))},
            )
            return self._coerce_profile_result(payload)
        except Exception as exc:
            logger.warning("获取人物画像失败: %s", exc)
            return PersonProfileResult()

    async def maintain_memory(
        self,
        *,
        action: str,
        target: str = "",
        hours: float | None = None,
        reason: str = "",
        limit: int = 50,
    ) -> MemoryWriteResult:
        """Run a maintenance action (restore/reinforce/freeze/protect/...); never raises."""
        try:
            payload = await self._invoke(
                "maintain_memory",
                {"action": action, "target": target, "hours": hours, "reason": reason, "limit": limit},
            )
            if not isinstance(payload, dict):
                return MemoryWriteResult(success=False, detail="invalid_payload")
            return MemoryWriteResult(success=bool(payload.get("success")), detail=str(payload.get("detail", "") or ""))
        except Exception as exc:
            logger.warning("记忆维护失败: %s", exc)
            return MemoryWriteResult(success=False, detail=str(exc))

    async def memory_stats(self) -> Dict[str, Any]:
        """Return plugin statistics as a raw dict; empty dict on failure."""
        try:
            payload = await self._invoke("memory_stats", {})
            return payload if isinstance(payload, dict) else {}
        except Exception as exc:
            logger.warning("获取记忆统计失败: %s", exc)
            return {}

    async def graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
        """Admin entry point for the memory graph component."""
        try:
            return await self._invoke_admin("memory_graph_admin", action=action, **kwargs)
        except Exception as exc:
            logger.warning("图谱管理调用失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
        """Admin entry point for memory sources."""
        try:
            return await self._invoke_admin("memory_source_admin", action=action, **kwargs)
        except Exception as exc:
            logger.warning("来源管理调用失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
        """Admin entry point for episodes."""
        try:
            return await self._invoke_admin("memory_episode_admin", action=action, **kwargs)
        except Exception as exc:
            logger.warning("Episode 管理调用失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
        """Admin entry point for person profiles."""
        try:
            return await self._invoke_admin("memory_profile_admin", action=action, **kwargs)
        except Exception as exc:
            logger.warning("画像管理调用失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
        """Admin entry point for the memory runtime."""
        try:
            return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs)
        except Exception as exc:
            logger.warning("运行时管理调用失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def import_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
        """Admin entry point for bulk imports (longer default timeout)."""
        try:
            return await self._invoke_admin("memory_import_admin", action=action, timeout_ms=timeout_ms, **kwargs)
        except Exception as exc:
            logger.warning("导入管理调用失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def tuning_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
        """Admin entry point for tuning jobs (longer default timeout)."""
        try:
            return await self._invoke_admin("memory_tuning_admin", action=action, timeout_ms=timeout_ms, **kwargs)
        except Exception as exc:
            logger.warning("调优管理调用失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def v5_admin(self, *, action: str, timeout_ms: int = 30000, **kwargs) -> Dict[str, Any]:
        """Admin entry point for the V5 memory subsystem."""
        try:
            return await self._invoke_admin("memory_v5_admin", action=action, timeout_ms=timeout_ms, **kwargs)
        except Exception as exc:
            logger.warning("V5 记忆管理调用失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def delete_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
        """Admin entry point for deletion jobs (longer default timeout)."""
        try:
            return await self._invoke_admin("memory_delete_admin", action=action, timeout_ms=timeout_ms, **kwargs)
        except Exception as exc:
            logger.warning("删除管理调用失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def get_recycle_bin(self, *, limit: int = 50) -> Dict[str, Any]:
        """List soft-deleted memories via the maintenance "recycle_bin" action."""
        try:
            payload = await self._invoke("maintain_memory", {"action": "recycle_bin", "limit": max(1, int(limit or 50))})
            return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"}
        except Exception as exc:
            logger.warning("获取回收站失败: %s", exc)
            return {"success": False, "error": str(exc)}

    async def restore_memory(self, *, target: str) -> MemoryWriteResult:
        """Convenience wrapper: restore a soft-deleted memory."""
        return await self.maintain_memory(action="restore", target=target)

    async def reinforce_memory(self, *, target: str) -> MemoryWriteResult:
        """Convenience wrapper: strengthen a memory against decay."""
        return await self.maintain_memory(action="reinforce", target=target)

    async def freeze_memory(self, *, target: str) -> MemoryWriteResult:
        """Convenience wrapper: freeze a memory (exclude it from decay)."""
        return await self.maintain_memory(action="freeze", target=target)

    async def protect_memory(self, *, target: str, hours: float | None = None) -> MemoryWriteResult:
        """Convenience wrapper: protect a memory, optionally for a limited number of hours."""
        return await self.maintain_memory(action="protect", target=target, hours=hours)


# Module-level singleton shared across the application.
memory_service = MemoryService()
|
||||
@@ -17,14 +17,14 @@ def get_all_routers() -> List[APIRouter]:
|
||||
from src.webui.api.planner import router as planner_router
|
||||
from src.webui.api.replier import router as replier_router
|
||||
from src.webui.routers.chat import router as chat_router
|
||||
from src.webui.routers.knowledge import router as knowledge_router
|
||||
from src.webui.routers.memory import compat_router as memory_compat_router
|
||||
from src.webui.routers.websocket.logs import router as logs_router
|
||||
from src.webui.routes import router as main_router
|
||||
|
||||
return [
|
||||
main_router,
|
||||
memory_compat_router,
|
||||
logs_router,
|
||||
knowledge_router,
|
||||
chat_router,
|
||||
planner_router,
|
||||
replier_router,
|
||||
|
||||
1395
src/webui/routers/memory.py
Normal file
1395
src/webui/routers/memory.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -16,6 +16,7 @@ from src.webui.routers.config import router as config_router
|
||||
from src.webui.routers.emoji import router as emoji_router
|
||||
from src.webui.routers.expression import router as expression_router
|
||||
from src.webui.routers.jargon import router as jargon_router
|
||||
from src.webui.routers.memory import router as memory_router
|
||||
from src.webui.routers.model import router as model_router
|
||||
from src.webui.routers.person import router as person_router
|
||||
from src.webui.routers.plugin import get_progress_router
|
||||
@@ -49,6 +50,8 @@ router.include_router(get_progress_router())
|
||||
router.include_router(system_router)
|
||||
# 注册模型列表获取路由
|
||||
router.include_router(model_router)
|
||||
# 注册长期记忆管理路由
|
||||
router.include_router(memory_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user