This commit is contained in:
SengokuCola
2025-11-13 19:00:59 +08:00
46 changed files with 1000 additions and 1041 deletions

View File

@@ -12,17 +12,13 @@ logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool:
"""记忆检索工具基类"""
def __init__(
self,
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
):
"""
初始化工具
Args:
name: 工具名称
description: 工具描述
@@ -33,7 +29,7 @@ class MemoryRetrievalTool:
self.description = description
self.parameters = parameters
self.execute_func = execute_func
def get_tool_description(self) -> str:
"""获取工具的文本描述用于prompt"""
param_descriptions = []
@@ -44,10 +40,10 @@ class MemoryRetrievalTool:
required = param.get("required", True)
required_str = "必填" if required else "可选"
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
async def execute(self, **kwargs) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
@@ -97,10 +93,10 @@ class MemoryRetrievalTool:
class MemoryRetrievalToolRegistry:
"""工具注册器"""
def __init__(self):
self.tools: Dict[str, MemoryRetrievalTool] = {}
def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具"""
if tool.name in self.tools:
@@ -108,22 +104,22 @@ class MemoryRetrievalToolRegistry:
return
self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}")
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
"""获取工具"""
return self.tools.get(name)
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
"""获取所有工具"""
return self.tools.copy()
def get_tools_description(self) -> str:
"""获取所有工具的描述用于prompt"""
descriptions = []
for i, tool in enumerate(self.tools.values(), 1):
descriptions.append(f"{i}. {tool.get_tool_description()}")
return "\n".join(descriptions)
def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt已废弃保留用于兼容"""
action_types = [tool.name for tool in self.tools.values()]
@@ -145,13 +141,10 @@ _tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool(
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
) -> None:
"""注册记忆检索工具的便捷函数
Args:
name: 工具名称
description: 工具描述
@@ -165,4 +158,3 @@ def register_memory_retrieval_tool(
def get_tool_registry() -> MemoryRetrievalToolRegistry:
"""获取工具注册器实例"""
return _tool_registry