feat:lpmm可选接入memory agent,将memory agent改为标准工具格式,修改llm_utils以兼容

This commit is contained in:
SengokuCola
2025-11-13 18:55:37 +08:00
parent e52a81e90b
commit f2819be5e9
18 changed files with 868 additions and 432 deletions

View File

@@ -3,8 +3,9 @@
提供统一的工具注册和管理接口
"""
from typing import List, Dict, Any, Optional, Callable, Awaitable
from typing import List, Dict, Any, Optional, Callable, Awaitable, Tuple
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolParamType
logger = get_logger("memory_retrieval_tools")
@@ -50,6 +51,48 @@ class MemoryRetrievalTool:
async def execute(self, **kwargs) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
def get_tool_definition(self) -> Dict[str, Any]:
"""获取工具定义用于LLM function calling
Returns:
Dict[str, Any]: 工具定义字典格式与BaseTool一致
格式: {"name": str, "description": str, "parameters": List[Tuple]}
"""
# 转换参数格式为元组列表格式与BaseTool一致
# 格式: [("param_name", ToolParamType, "description", required, enum_values)]
param_tuples = []
for param in self.parameters:
param_name = param.get("name", "")
param_type_str = param.get("type", "string").lower()
param_desc = param.get("description", "")
is_required = param.get("required", False)
enum_values = param.get("enum", None)
# 转换类型字符串到ToolParamType
type_mapping = {
"string": ToolParamType.STRING,
"integer": ToolParamType.INTEGER,
"int": ToolParamType.INTEGER,
"float": ToolParamType.FLOAT,
"boolean": ToolParamType.BOOLEAN,
"bool": ToolParamType.BOOLEAN,
}
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
# 构建参数元组
param_tuple = (param_name, param_type, param_desc, is_required, enum_values)
param_tuples.append(param_tuple)
# 构建工具定义格式与BaseTool.get_tool_definition()一致
tool_def = {
"name": self.name,
"description": self.description,
"parameters": param_tuples
}
return tool_def
class MemoryRetrievalToolRegistry:
@@ -60,6 +103,9 @@ class MemoryRetrievalToolRegistry:
def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具"""
if tool.name in self.tools:
logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册")
return
self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}")
@@ -79,11 +125,19 @@ class MemoryRetrievalToolRegistry:
return "\n".join(descriptions)
def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt"""
"""获取所有动作类型的列表用于prompt(已废弃,保留用于兼容)"""
action_types = [tool.name for tool in self.tools.values()]
action_types.append("final_answer")
action_types.append("no_answer")
return "".join([f'"{at}"' for at in action_types])
def get_tool_definitions(self) -> List[Dict[str, Any]]:
"""获取所有工具的定义列表用于LLM function calling
Returns:
List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典
"""
return [tool.get_tool_definition() for tool in self.tools.values()]
# 全局工具注册器实例