🤖 自动格式化代码 [skip ci]
This commit is contained in:
@@ -10,41 +10,39 @@ logger = get_module_logger("base_tool")
|
||||
# 工具注册表
|
||||
TOOL_REGISTRY = {}
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""所有工具的基类"""
|
||||
|
||||
# 工具名称,子类必须重写
|
||||
name = None
|
||||
# 工具描述,子类必须重写
|
||||
description = None
|
||||
# 工具参数定义,子类必须重写
|
||||
parameters = None
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> Dict[str, Any]:
|
||||
"""获取工具定义,用于LLM工具调用
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": cls.name,
|
||||
"description": cls.description,
|
||||
"parameters": cls.parameters
|
||||
}
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]:
|
||||
"""执行工具函数
|
||||
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
message_txt: 原始消息文本
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
@@ -53,17 +51,17 @@ class BaseTool:
|
||||
|
||||
def register_tool(tool_class: Type[BaseTool]):
|
||||
"""注册工具到全局注册表
|
||||
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
"""
|
||||
if not issubclass(tool_class, BaseTool):
|
||||
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
|
||||
|
||||
|
||||
tool_name = tool_class.name
|
||||
if not tool_name:
|
||||
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
|
||||
|
||||
|
||||
TOOL_REGISTRY[tool_name] = tool_class
|
||||
logger.info(f"已注册工具: {tool_name}")
|
||||
|
||||
@@ -73,27 +71,27 @@ def discover_tools():
|
||||
# 获取当前目录路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
package_name = os.path.basename(current_dir)
|
||||
|
||||
|
||||
# 遍历包中的所有模块
|
||||
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
|
||||
# 跳过当前模块和__pycache__
|
||||
if module_name == "base_tool" or module_name.startswith("__"):
|
||||
continue
|
||||
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(f"src.do_tool.{package_name}.{module_name}")
|
||||
|
||||
|
||||
# 查找模块中的工具类
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||
register_tool(obj)
|
||||
|
||||
|
||||
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
|
||||
|
||||
|
||||
def get_all_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict]: 工具定义列表
|
||||
"""
|
||||
@@ -102,14 +100,14 @@ def get_all_tool_definitions() -> List[Dict[str, Any]]:
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取指定名称的工具实例
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[BaseTool]: 工具实例,如果找不到则返回None
|
||||
"""
|
||||
tool_class = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool_class:
|
||||
return None
|
||||
return tool_class()
|
||||
return tool_class()
|
||||
|
||||
Reference in New Issue
Block a user