feat: Introduce unified tooling system for plugins and MCP
- Added a new `tooling` module to define a unified model for tool declarations, invocations, and execution results, facilitating compatibility between plugins, legacy actions, and MCP tools. - Implemented `ToolProvider` interface for various tool providers including built-in tools, MCP tools, and plugin runtime tools. - Enhanced `MCPManager` and `MCPConnection` to support unified tool invocation and execution results. - Updated `ComponentRegistry` and related classes to accommodate the new tool specifications and descriptions. - Refactored existing components to utilize the new tooling system, ensuring backward compatibility with legacy actions. - Improved error handling and logging for tool invocations across different providers.
This commit is contained in:
335
src/core/tooling.py
Normal file
335
src/core/tooling.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""统一工具抽象。
|
||||
|
||||
该模块定义主程序内部统一使用的工具声明、调用与执行结果模型,
|
||||
用于收敛插件 Tool、兼容旧 Action、MaiSaka 内置 Tool 与 MCP Tool。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Protocol, runtime_checkable
|
||||
import json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||
|
||||
logger = get_logger("core.tooling")
|
||||
|
||||
|
||||
def _normalize_schema_type(raw_type: Any) -> str:
|
||||
"""将原始 Schema 类型值规范化为可读字符串。
|
||||
|
||||
Args:
|
||||
raw_type: 原始类型值。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的类型名称。
|
||||
"""
|
||||
|
||||
normalized_type = str(raw_type or "").strip().lower()
|
||||
if not normalized_type:
|
||||
return "string"
|
||||
if normalized_type == "number":
|
||||
return "number"
|
||||
if normalized_type == "integer":
|
||||
return "integer"
|
||||
if normalized_type == "boolean":
|
||||
return "boolean"
|
||||
if normalized_type == "array":
|
||||
return "array"
|
||||
if normalized_type == "object":
|
||||
return "object"
|
||||
return normalized_type
|
||||
|
||||
|
||||
def build_tool_detailed_description(
|
||||
parameters_schema: Optional[Dict[str, Any]],
|
||||
fallback_description: str = "",
|
||||
) -> str:
|
||||
"""根据参数 Schema 构建工具详细描述。
|
||||
|
||||
Args:
|
||||
parameters_schema: 工具参数对象级 Schema。
|
||||
fallback_description: 无法从 Schema 解析时使用的兜底说明。
|
||||
|
||||
Returns:
|
||||
str: 生成后的详细描述文本。
|
||||
"""
|
||||
|
||||
if not parameters_schema:
|
||||
return fallback_description.strip()
|
||||
|
||||
properties = parameters_schema.get("properties")
|
||||
if not isinstance(properties, dict) or not properties:
|
||||
return fallback_description.strip()
|
||||
|
||||
required_names = {
|
||||
str(name).strip()
|
||||
for name in parameters_schema.get("required", [])
|
||||
if str(name).strip()
|
||||
}
|
||||
|
||||
lines = ["参数说明:"]
|
||||
for parameter_name, parameter_schema in properties.items():
|
||||
if not isinstance(parameter_schema, dict):
|
||||
continue
|
||||
|
||||
normalized_name = str(parameter_name).strip()
|
||||
parameter_type = _normalize_schema_type(parameter_schema.get("type"))
|
||||
required_text = "必填" if normalized_name in required_names else "可选"
|
||||
parameter_description = str(parameter_schema.get("description", "") or "").strip() or "无额外说明"
|
||||
line = f"- {normalized_name}:{parameter_type},{required_text}。{parameter_description}"
|
||||
|
||||
if isinstance(parameter_schema.get("enum"), list) and parameter_schema["enum"]:
|
||||
enum_values = "、".join(str(item) for item in parameter_schema["enum"])
|
||||
line += f" 可选值:{enum_values}。"
|
||||
|
||||
if "default" in parameter_schema:
|
||||
line += f" 默认值:{parameter_schema['default']}。"
|
||||
|
||||
lines.append(line)
|
||||
|
||||
if len(lines) == 1:
|
||||
return fallback_description.strip()
|
||||
|
||||
if fallback_description.strip():
|
||||
lines.append("")
|
||||
lines.append(fallback_description.strip())
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolSpec:
|
||||
"""统一工具声明。"""
|
||||
|
||||
name: str
|
||||
brief_description: str
|
||||
detailed_description: str = ""
|
||||
parameters_schema: Dict[str, Any] | None = None
|
||||
provider_name: str = ""
|
||||
provider_type: str = ""
|
||||
enabled: bool = True
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def build_llm_description(self) -> str:
|
||||
"""构建供 LLM 使用的描述文本。
|
||||
|
||||
Returns:
|
||||
str: 合并后的单段工具描述。
|
||||
"""
|
||||
|
||||
parts = [self.brief_description.strip()]
|
||||
if self.detailed_description.strip():
|
||||
parts.append(self.detailed_description.strip())
|
||||
return "\n\n".join(part for part in parts if part).strip()
|
||||
|
||||
def to_llm_definition(self) -> ToolDefinitionInput:
|
||||
"""转换为统一的 LLM 工具定义。
|
||||
|
||||
Returns:
|
||||
ToolDefinitionInput: 可直接交给模型层的工具定义。
|
||||
"""
|
||||
|
||||
definition: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"description": self.build_llm_description(),
|
||||
}
|
||||
if self.parameters_schema is not None:
|
||||
definition["parameters_schema"] = deepcopy(self.parameters_schema)
|
||||
return definition
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolInvocation:
|
||||
"""统一工具调用请求。"""
|
||||
|
||||
tool_name: str
|
||||
arguments: Dict[str, Any] = field(default_factory=dict)
|
||||
call_id: str = ""
|
||||
session_id: str = ""
|
||||
stream_id: str = ""
|
||||
reasoning: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolExecutionContext:
|
||||
"""统一工具执行上下文。"""
|
||||
|
||||
session_id: str = ""
|
||||
stream_id: str = ""
|
||||
reasoning: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolExecutionResult:
|
||||
"""统一工具执行结果。"""
|
||||
|
||||
tool_name: str
|
||||
success: bool
|
||||
content: str = ""
|
||||
error_message: str = ""
|
||||
structured_content: Any = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def get_history_content(self) -> str:
|
||||
"""获取适合写入对话历史的结果文本。
|
||||
|
||||
Returns:
|
||||
str: 优先使用文本内容,其次使用错误信息。
|
||||
"""
|
||||
|
||||
if self.content.strip():
|
||||
return self.content.strip()
|
||||
if self.structured_content is not None:
|
||||
if isinstance(self.structured_content, str):
|
||||
return self.structured_content.strip()
|
||||
try:
|
||||
return json.dumps(self.structured_content, ensure_ascii=False)
|
||||
except (TypeError, ValueError):
|
||||
return str(self.structured_content).strip()
|
||||
return self.error_message.strip()
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ToolProvider(Protocol):
|
||||
"""统一工具提供者协议。"""
|
||||
|
||||
provider_name: str
|
||||
provider_type: str
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
"""列出当前 Provider 暴露的全部工具。"""
|
||||
...
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行指定工具调用。"""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""释放 Provider 资源。"""
|
||||
...
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""统一工具注册表。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._providers: list[ToolProvider] = []
|
||||
|
||||
def register_provider(self, provider: ToolProvider) -> None:
|
||||
"""注册一个工具提供者。
|
||||
|
||||
Args:
|
||||
provider: 待注册的工具提供者。
|
||||
"""
|
||||
|
||||
self._providers = [item for item in self._providers if item.provider_name != provider.provider_name]
|
||||
self._providers.append(provider)
|
||||
|
||||
def unregister_provider(self, provider_name: str) -> None:
|
||||
"""注销指定名称的工具提供者。
|
||||
|
||||
Args:
|
||||
provider_name: 待移除的 Provider 名称。
|
||||
"""
|
||||
|
||||
self._providers = [item for item in self._providers if item.provider_name != provider_name]
|
||||
|
||||
async def list_tools(self) -> list[ToolSpec]:
|
||||
"""按 Provider 顺序列出全部去重后的工具。
|
||||
|
||||
Returns:
|
||||
list[ToolSpec]: 去重后的工具列表。
|
||||
"""
|
||||
|
||||
collected_specs: list[ToolSpec] = []
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for provider in self._providers:
|
||||
provider_specs = await provider.list_tools()
|
||||
for spec in provider_specs:
|
||||
if not spec.enabled:
|
||||
continue
|
||||
if spec.name in seen_names:
|
||||
logger.warning(
|
||||
f"检测到重复工具名 {spec.name},保留先注册的工具,跳过 provider={provider.provider_name}"
|
||||
)
|
||||
continue
|
||||
seen_names.add(spec.name)
|
||||
collected_specs.append(spec)
|
||||
return collected_specs
|
||||
|
||||
async def get_tool_spec(self, tool_name: str) -> Optional[ToolSpec]:
|
||||
"""查询指定工具声明。
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称。
|
||||
|
||||
Returns:
|
||||
Optional[ToolSpec]: 匹配到的工具声明。
|
||||
"""
|
||||
|
||||
for spec in await self.list_tools():
|
||||
if spec.name == tool_name:
|
||||
return spec
|
||||
return None
|
||||
|
||||
async def has_tool(self, tool_name: str) -> bool:
|
||||
"""判断指定工具是否存在。
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称。
|
||||
|
||||
Returns:
|
||||
bool: 是否存在。
|
||||
"""
|
||||
|
||||
return await self.get_tool_spec(tool_name) is not None
|
||||
|
||||
async def get_llm_definitions(self) -> list[ToolDefinitionInput]:
|
||||
"""获取供 LLM 使用的工具定义列表。
|
||||
|
||||
Returns:
|
||||
list[ToolDefinitionInput]: 统一工具定义列表。
|
||||
"""
|
||||
|
||||
return [spec.to_llm_definition() for spec in await self.list_tools()]
|
||||
|
||||
async def invoke(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
context: Optional[ToolExecutionContext] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""执行一次工具调用。
|
||||
|
||||
Args:
|
||||
invocation: 工具调用请求。
|
||||
context: 执行上下文。
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult: 工具执行结果。
|
||||
"""
|
||||
|
||||
for provider in self._providers:
|
||||
provider_specs = await provider.list_tools()
|
||||
if any(spec.name == invocation.tool_name and spec.enabled for spec in provider_specs):
|
||||
return await provider.invoke(invocation, context)
|
||||
|
||||
return ToolExecutionResult(
|
||||
tool_name=invocation.tool_name,
|
||||
success=False,
|
||||
error_message=f"未找到工具:{invocation.tool_name}",
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭全部 Provider。"""
|
||||
|
||||
for provider in self._providers:
|
||||
await provider.close()
|
||||
Reference in New Issue
Block a user