feat:lpmm可选接入memory agent,将memory agent改为标准工具格式,修改llm_utils以兼容

This commit is contained in:
SengokuCola
2025-11-13 18:55:37 +08:00
parent e52a81e90b
commit f2819be5e9
18 changed files with 868 additions and 432 deletions

View File

@@ -11,6 +11,7 @@ retrieval_tools/
├── tool_utils.py # 工具函数库(共用函数)
├── query_jargon.py # 查询jargon工具
├── query_chat_history.py # 查询聊天历史工具
├── query_lpmm_knowledge.py # 查询LPMM知识库工具
└── README.md # 本文件
```
@@ -145,6 +146,11 @@ result = await tool.execute(time_range="2025-01-15 10:00:00 - 2025-01-15 20:00:0
- `time_point` (可选) - 时间点格式YYYY-MM-DD HH:MM:SS用于查询某个时间点附近发生了什么与time_range二选一
- `time_range` (可选) - 时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'与time_point二选一
### query_lpmm_knowledge
从LPMM知识库中检索与关键词相关的知识内容
- 参数:
- `query` (必填) - 查询的关键词或问题描述
## 注意事项
- 所有工具函数必须是异步函数(`async def`

View File

@@ -13,13 +13,17 @@ from .tool_registry import (
# 导入所有工具的注册函数
from .query_jargon import register_tool as register_query_jargon
from .query_chat_history import register_tool as register_query_chat_history
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from src.config.config import global_config
def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_jargon()
register_query_chat_history()
if global_config.lpmm_knowledge.lpmm_mode == "agent":
register_lpmm_knowledge()
__all__ = [
"MemoryRetrievalTool",

View File

@@ -9,7 +9,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
from src.chat.utils.utils import parse_keywords_string
from .tool_registry import register_memory_retrieval_tool
from .tool_utils import parse_datetime_to_timestamp, parse_time_range
from ..memory_utils import parse_datetime_to_timestamp, parse_time_range
logger = get_logger("memory_retrieval_tools")
@@ -17,7 +17,8 @@ logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str,
keyword: Optional[str] = None,
time_range: Optional[str] = None
time_range: Optional[str] = None,
fuzzy: bool = True
) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
@@ -27,6 +28,9 @@ async def query_chat_history(
time_range: 时间范围或时间点,格式:
- 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
- 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录)
fuzzy: 是否使用模糊匹配模式默认True
- True: 模糊匹配只要包含任意一个关键词即匹配OR关系
- False: 全匹配必须包含所有关键词才匹配AND关系
Returns:
str: 查询结果
@@ -62,9 +66,6 @@ async def query_chat_history(
# 执行查询
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
if not records:
return "未找到相关聊天记录概述"
# 如果有关键词,进一步过滤
if keyword:
# 解析多个关键词(支持空格、逗号等分隔符)
@@ -96,24 +97,48 @@ async def query_chat_history(
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 检查是否包含任意一个关键词OR关系
# 根据匹配模式检查关键词
matched = False
for kw in keywords_lower:
if (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
matched = True
break
if fuzzy:
# 模糊匹配只要包含任意一个关键词即匹配OR关系
for kw in keywords_lower:
if (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
matched = True
break
else:
# 全匹配必须包含所有关键词才匹配AND关系
matched = True
for kw in keywords_lower:
kw_matched = (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list))
if not kw_matched:
matched = False
break
if matched:
filtered_records.append(record)
if not filtered_records:
keywords_str = "".join(keywords_list)
return f"未找到包含关键词'{keywords_str}'的聊天记录概述"
match_mode = "包含任意一个关键词" if fuzzy else "包含所有关键词"
if time_range:
return f"未找到{match_mode}'{keywords_str}'且在指定时间范围内的聊天记录概述"
else:
return f"未找到{match_mode}'{keywords_str}'的聊天记录概述"
records = filtered_records
# 如果没有记录(可能是时间范围查询但没有匹配的记录)
if not records:
if time_range:
return "未找到指定时间范围内的聊天记录概述"
else:
return "未找到相关聊天记录概述"
# 对即将返回的记录增加使用计数
records_to_use = records[:3]
@@ -168,12 +193,12 @@ def register_tool():
"""注册工具"""
register_memory_retrieval_tool(
name="query_chat_history",
description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件或根据关键词搜索消息概述",
description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)",
parameters=[
{
"name": "keyword",
"type": "string",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
"required": False
},
{
@@ -181,6 +206,12 @@ def register_tool():
"type": "string",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False
},
{
"name": "fuzzy",
"type": "boolean",
"description": "是否使用模糊匹配模式默认True。True表示模糊匹配只要包含任意一个关键词即匹配OR关系False表示全匹配必须包含所有关键词才匹配AND关系",
"required": False
}
],
execute_func=query_chat_history

View File

@@ -0,0 +1,65 @@
"""
通过LPMM知识库查询信息 - 工具实现
"""
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import get_qa_manager
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_lpmm_knowledge(query: str) -> str:
"""在LPMM知识库中查询相关信息
Args:
query: 查询关键词
Returns:
str: 查询结果
"""
try:
content = str(query).strip()
if not content:
return "查询关键词为空"
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用")
return "LPMM知识库未启用"
qa_manager = get_qa_manager()
if qa_manager is None:
logger.debug("LPMM知识库未初始化跳过查询")
return "LPMM知识库未初始化"
knowledge_info = await qa_manager.get_knowledge(content)
logger.debug(f"LPMM知识库查询结果: {knowledge_info}")
if knowledge_info:
return f"你从LPMM知识库中找到以下信息\n{knowledge_info}"
return f"在LPMM知识库中未找到与“{content}”相关的信息"
except Exception as e:
logger.error(f"LPMM知识库查询失败: {e}")
return f"LPMM知识库查询失败{str(e)}"
def register_tool():
"""注册LPMM知识库查询工具"""
register_memory_retrieval_tool(
name="lpmm_search_knowledge",
description="从LPMM知识库中搜索相关信息适用于需要知识支持的场景。",
parameters=[
{
"name": "query",
"type": "string",
"description": "需要查询的关键词或问题",
"required": True,
}
],
execute_func=query_lpmm_knowledge,
)

View File

@@ -3,8 +3,9 @@
提供统一的工具注册和管理接口
"""
from typing import List, Dict, Any, Optional, Callable, Awaitable
from typing import List, Dict, Any, Optional, Callable, Awaitable, Tuple
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolParamType
logger = get_logger("memory_retrieval_tools")
@@ -50,6 +51,48 @@ class MemoryRetrievalTool:
async def execute(self, **kwargs) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
def get_tool_definition(self) -> Dict[str, Any]:
"""获取工具定义用于LLM function calling
Returns:
Dict[str, Any]: 工具定义字典格式与BaseTool一致
格式: {"name": str, "description": str, "parameters": List[Tuple]}
"""
# 转换参数格式为元组列表格式与BaseTool一致
# 格式: [("param_name", ToolParamType, "description", required, enum_values)]
param_tuples = []
for param in self.parameters:
param_name = param.get("name", "")
param_type_str = param.get("type", "string").lower()
param_desc = param.get("description", "")
is_required = param.get("required", False)
enum_values = param.get("enum", None)
# 转换类型字符串到ToolParamType
type_mapping = {
"string": ToolParamType.STRING,
"integer": ToolParamType.INTEGER,
"int": ToolParamType.INTEGER,
"float": ToolParamType.FLOAT,
"boolean": ToolParamType.BOOLEAN,
"bool": ToolParamType.BOOLEAN,
}
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
# 构建参数元组
param_tuple = (param_name, param_type, param_desc, is_required, enum_values)
param_tuples.append(param_tuple)
# 构建工具定义格式与BaseTool.get_tool_definition()一致
tool_def = {
"name": self.name,
"description": self.description,
"parameters": param_tuples
}
return tool_def
class MemoryRetrievalToolRegistry:
@@ -60,6 +103,9 @@ class MemoryRetrievalToolRegistry:
def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具"""
if tool.name in self.tools:
logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册")
return
self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}")
@@ -79,11 +125,19 @@ class MemoryRetrievalToolRegistry:
return "\n".join(descriptions)
def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt"""
"""获取所有动作类型的列表用于prompt(已废弃,保留用于兼容)"""
action_types = [tool.name for tool in self.tools.values()]
action_types.append("final_answer")
action_types.append("no_answer")
return "".join([f'"{at}"' for at in action_types])
def get_tool_definitions(self) -> List[Dict[str, Any]]:
"""获取所有工具的定义列表用于LLM function calling
Returns:
List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典
"""
return [tool.get_tool_definition() for tool in self.tools.values()]
# 全局工具注册器实例

View File

@@ -1,64 +0,0 @@
"""
工具函数库
包含所有工具共用的工具函数
"""
from datetime import datetime
from typing import Tuple
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳(秒)
支持示例:
- 2025-09-29
- 2025-09-29 00:00:00
- 2025/09/29 00:00
- 2025-09-29T00:00:00
"""
value = value.strip()
fmts = [
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y/%m/%d %H:%M:%S",
"%Y/%m/%d %H:%M",
"%Y-%m-%d",
"%Y/%m/%d",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M",
]
last_err = None
for fmt in fmts:
try:
dt = datetime.strptime(value, fmt)
return dt.timestamp()
except Exception as e:
last_err = e
raise ValueError(f"无法解析时间: {value} ({last_err})")
def parse_time_range(time_range: str) -> Tuple[float, float]:
"""
解析时间范围字符串,返回开始和结束时间戳
Args:
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
Returns:
Tuple[float, float]: (开始时间戳, 结束时间戳)
"""
if " - " not in time_range:
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
parts = time_range.split(" - ", 1)
if len(parts) != 2:
raise ValueError(f"时间范围格式错误: {time_range}")
start_str = parts[0].strip()
end_str = parts[1].strip()
start_timestamp = parse_datetime_to_timestamp(start_str)
end_timestamp = parse_datetime_to_timestamp(end_str)
return start_timestamp, end_timestamp