"""统一工具抽象。 该模块定义主程序内部统一使用的工具声明、调用与执行结果模型, 用于收敛插件 Tool、兼容旧 Action、MaiSaka 内置 Tool 与 MCP Tool。 """ from __future__ import annotations from copy import deepcopy from dataclasses import dataclass, field import json from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable from src.common.logger import get_logger from src.llm_models.payload_content.tool_option import ToolDefinitionInput logger = get_logger("core.tooling") def _normalize_schema_type(raw_type: Any) -> str: """将原始 Schema 类型值规范化为可读字符串。 Args: raw_type: 原始类型值。 Returns: str: 规范化后的类型名称。 """ normalized_type = str(raw_type or "").strip().lower() if not normalized_type: return "string" if normalized_type == "number": return "number" if normalized_type == "integer": return "integer" if normalized_type == "boolean": return "boolean" if normalized_type == "array": return "array" if normalized_type == "object": return "object" return normalized_type def build_tool_detailed_description( parameters_schema: Optional[Dict[str, Any]], fallback_description: str = "", ) -> str: """根据参数 Schema 构建工具详细描述。 Args: parameters_schema: 工具参数对象级 Schema。 fallback_description: 无法从 Schema 解析时使用的兜底说明。 Returns: str: 生成后的详细描述文本。 """ if not parameters_schema: return fallback_description.strip() properties = parameters_schema.get("properties") if not isinstance(properties, dict) or not properties: return fallback_description.strip() required_names = { str(name).strip() for name in parameters_schema.get("required", []) if str(name).strip() } lines = ["参数说明:"] for parameter_name, parameter_schema in properties.items(): if not isinstance(parameter_schema, dict): continue normalized_name = str(parameter_name).strip() parameter_type = _normalize_schema_type(parameter_schema.get("type")) required_text = "必填" if normalized_name in required_names else "可选" parameter_description = str(parameter_schema.get("description", "") or "").strip() or "无额外说明" line = f"- {normalized_name}:{parameter_type},{required_text}。{parameter_description}" if isinstance(parameter_schema.get("enum"), list) and parameter_schema["enum"]: enum_values = "、".join(str(item) for item in parameter_schema["enum"]) line += f" 可选值:{enum_values}。" if "default" in parameter_schema: line += f" 默认值:{parameter_schema['default']}。" lines.append(line) if len(lines) == 1: return fallback_description.strip() if fallback_description.strip(): lines.append("") lines.append(fallback_description.strip()) return "\n".join(lines).strip() @dataclass(slots=True) class ToolIcon: """统一工具图标信息。""" src: str mime_type: str = "" sizes: list[str] = field(default_factory=list) @dataclass(slots=True) class ToolAnnotation: """统一工具注解信息。""" audience: list[str] = field(default_factory=list) priority: float | None = None metadata: Dict[str, Any] = field(default_factory=dict) @dataclass(slots=True) class ToolContentItem: """统一工具内容项。""" content_type: Literal["text", "image", "audio", "resource_link", "resource", "binary", "unknown"] text: str = "" data: str = "" mime_type: str = "" uri: str = "" name: str = "" description: str = "" annotation: ToolAnnotation | None = None metadata: Dict[str, Any] = field(default_factory=dict) def build_history_text(self) -> str: """生成适合写入历史消息的文本摘要。 Returns: str: 当前内容项对应的历史摘要文本。 """ if self.content_type == "text" and self.text.strip(): return self.text.strip() if self.content_type == "image": return f"[图片内容 {self.mime_type or 'unknown'}]" if self.content_type == "audio": return f"[音频内容 {self.mime_type or 'unknown'}]" if self.content_type == "resource_link": label = self.name or self.uri or "资源链接" return f"[资源链接] {label}" if self.content_type == "resource": if self.text.strip(): return self.text.strip() label = self.name or self.uri or "嵌入资源" return f"[嵌入资源] {label}" if self.content_type == "binary": return f"[二进制内容 {self.mime_type or 'unknown'}]" return f"[{self.content_type} 内容]" @dataclass(slots=True) class ToolSpec: """统一工具声明。""" name: str brief_description: str detailed_description: str = "" title: str = "" parameters_schema: Dict[str, Any] | None = None output_schema: Dict[str, Any] | None = None provider_name: str = "" provider_type: str = "" enabled: bool = True icons: list[ToolIcon] = field(default_factory=list) annotation: ToolAnnotation | None = None metadata: Dict[str, Any] = field(default_factory=dict) def build_llm_description(self) -> str: """构建供 LLM 使用的描述文本。 Returns: str: 合并后的单段工具描述。 """ return self.brief_description.strip() def to_llm_definition(self) -> ToolDefinitionInput: """转换为统一的 LLM 工具定义。 Returns: ToolDefinitionInput: 可直接交给模型层的工具定义。 """ definition: Dict[str, Any] = { "name": self.name, "description": self.build_llm_description(), } if self.parameters_schema is not None: definition["parameters_schema"] = deepcopy(self.parameters_schema) return definition @dataclass(slots=True) class ToolInvocation: """统一工具调用请求。""" tool_name: str arguments: Dict[str, Any] = field(default_factory=dict) call_id: str = "" session_id: str = "" stream_id: str = "" reasoning: str = "" metadata: Dict[str, Any] = field(default_factory=dict) @dataclass(slots=True) class ToolExecutionContext: """统一工具执行上下文。""" session_id: str = "" stream_id: str = "" reasoning: str = "" metadata: Dict[str, Any] = field(default_factory=dict) @dataclass(slots=True) class ToolAvailabilityContext: """工具暴露可用性判断上下文。""" session_id: str = "" stream_id: str = "" is_group_chat: bool | None = None group_id: str = "" user_id: str = "" platform: str = "" @dataclass(slots=True) class ToolExecutionResult: """统一工具执行结果。""" tool_name: str success: bool content: str = "" error_message: str = "" structured_content: Any = None content_items: list[ToolContentItem] = field(default_factory=list) metadata: Dict[str, Any] = field(default_factory=dict) def get_history_content(self) -> str: """获取适合写入对话历史的结果文本。 Returns: str: 优先使用文本内容,其次使用错误信息。 """ if self.content.strip(): return self.content.strip() if self.content_items: parts = [item.build_history_text() for item in self.content_items if item.build_history_text().strip()] if parts: return "\n".join(parts).strip() if self.structured_content is not None: if isinstance(self.structured_content, str): return self.structured_content.strip() try: return json.dumps(self.structured_content, ensure_ascii=False) except (TypeError, ValueError): return str(self.structured_content).strip() return self.error_message.strip() @runtime_checkable class ToolProvider(Protocol): """统一工具提供者协议。""" provider_name: str provider_type: str async def list_tools( self, context: Optional[ToolAvailabilityContext] = None, ) -> list[ToolSpec]: """列出当前 Provider 暴露的全部工具。""" ... async def invoke( self, invocation: ToolInvocation, context: Optional[ToolExecutionContext] = None, ) -> ToolExecutionResult: """执行指定工具调用。""" ... async def close(self) -> None: """释放 Provider 资源。""" ... class ToolRegistry: """统一工具注册表。""" def __init__(self) -> None: """初始化统一工具注册表。""" self._providers: list[ToolProvider] = [] def register_provider(self, provider: ToolProvider) -> None: """注册一个工具提供者。 Args: provider: 待注册的工具提供者。 """ self._providers = [item for item in self._providers if item.provider_name != provider.provider_name] self._providers.append(provider) def unregister_provider(self, provider_name: str) -> None: """注销指定名称的工具提供者。 Args: provider_name: 待移除的 Provider 名称。 """ self._providers = [item for item in self._providers if item.provider_name != provider_name] async def list_tools( self, context: Optional[ToolAvailabilityContext] = None, ) -> list[ToolSpec]: """按 Provider 顺序列出全部去重后的工具。 Returns: list[ToolSpec]: 去重后的工具列表。 """ collected_specs: list[ToolSpec] = [] seen_names: set[str] = set() for provider in self._providers: provider_specs = await provider.list_tools(context) for spec in provider_specs: if not spec.enabled: continue if spec.name in seen_names: logger.warning( f"检测到重复工具名 {spec.name},保留先注册的工具,跳过 provider={provider.provider_name}" ) continue seen_names.add(spec.name) collected_specs.append(spec) return collected_specs async def get_tool_spec( self, tool_name: str, context: Optional[ToolAvailabilityContext] = None, ) -> Optional[ToolSpec]: """查询指定工具声明。 Args: tool_name: 工具名称。 Returns: Optional[ToolSpec]: 匹配到的工具声明。 """ for spec in await self.list_tools(context): if spec.name == tool_name: return spec return None async def has_tool( self, tool_name: str, context: Optional[ToolAvailabilityContext] = None, ) -> bool: """判断指定工具是否存在。 Args: tool_name: 工具名称。 Returns: bool: 是否存在。 """ return await self.get_tool_spec(tool_name, context) is not None async def get_llm_definitions( self, context: Optional[ToolAvailabilityContext] = None, ) -> list[ToolDefinitionInput]: """获取供 LLM 使用的工具定义列表。 Returns: list[ToolDefinitionInput]: 统一工具定义列表。 """ return [spec.to_llm_definition() for spec in await self.list_tools(context)] async def invoke( self, invocation: ToolInvocation, context: Optional[ToolExecutionContext] = None, ) -> ToolExecutionResult: """执行一次工具调用。 Args: invocation: 工具调用请求。 context: 执行上下文。 Returns: ToolExecutionResult: 工具执行结果。 """ for provider in self._providers: provider_specs = await provider.list_tools() if any(spec.name == invocation.tool_name and spec.enabled for spec in provider_specs): try: return await provider.invoke(invocation, context) except Exception as exc: logger.exception( "工具调用异常: tool=%s provider=%s", invocation.tool_name, getattr(provider, "provider_name", ""), ) error_message = str(exc).strip() if error_message: error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}: {error_message}" else: error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}" return ToolExecutionResult( tool_name=invocation.tool_name, success=False, error_message=error_message, ) return ToolExecutionResult( tool_name=invocation.tool_name, success=False, error_message=f"未找到工具:{invocation.tool_name}", ) async def close(self) -> None: """关闭全部 Provider。""" for provider in self._providers: await provider.close()