Files
mai-bot/src/core/tooling.py
DrSmoothl dc2bf02a42 feat: Introduce unified tooling system for plugins and MCP
- Added a new `tooling` module to define a unified model for tool declarations, invocations, and execution results, facilitating compatibility between plugins, legacy actions, and MCP tools.
- Implemented `ToolProvider` interface for various tool providers including built-in tools, MCP tools, and plugin runtime tools.
- Enhanced `MCPManager` and `MCPConnection` to support unified tool invocation and execution results.
- Updated `ComponentRegistry` and related classes to accommodate the new tool specifications and descriptions.
- Refactored existing components to utilize the new tooling system, ensuring backward compatibility with legacy actions.
- Improved error handling and logging for tool invocations across different providers.
2026-03-30 23:11:56 +08:00

336 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""统一工具抽象。
该模块定义主程序内部统一使用的工具声明、调用与执行结果模型,
用于收敛插件 Tool、兼容旧 Action、MaiSaka 内置 Tool 与 MCP Tool。
"""
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Protocol, runtime_checkable
import json
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
logger = get_logger("core.tooling")
def _normalize_schema_type(raw_type: Any) -> str:
    """Normalize a raw JSON Schema ``type`` value into a readable type name.

    Args:
        raw_type: Raw ``type`` value taken from a parameter schema. JSON Schema
            allows either a single type name or a list of type names
            (e.g. ``["string", "null"]``).

    Returns:
        str: The stripped, lower-cased type name. List types are rendered as
        their non-empty members joined with ``" | "``. Missing or empty values
        fall back to ``"string"``.
    """
    if isinstance(raw_type, (list, tuple)):
        # Union type: normalize every member and drop empty ones instead of
        # rendering the Python list repr (e.g. "['string', 'null']").
        member_names = [str(item or "").strip().lower() for item in raw_type]
        joined_names = " | ".join(name for name in member_names if name)
        return joined_names or "string"
    return str(raw_type or "").strip().lower() or "string"
def build_tool_detailed_description(
    parameters_schema: Optional[Dict[str, Any]],
    fallback_description: str = "",
) -> str:
    """Build a readable parameter description from a tool's parameter schema.

    Each dict-valued entry of ``properties`` becomes one line of the form
    ``- name (type, required/optional): description``. The line is optionally
    followed by the allowed values (``enum``) and the default value. A
    non-empty fallback description is appended after a blank line.

    Args:
        parameters_schema: Object-level JSON Schema describing the tool parameters.
        fallback_description: Text returned when nothing can be derived from the
            schema. Otherwise it is appended after the parameter list.

    Returns:
        str: The generated description text (may be empty).
    """
    fallback_text = fallback_description.strip()
    if not parameters_schema:
        return fallback_text

    properties = parameters_schema.get("properties")
    if not isinstance(properties, dict) or not properties:
        return fallback_text

    raw_required = parameters_schema.get("required")
    # "required" should be a list of names. Ignore anything else rather than
    # failing on None or iterating a string character by character.
    if not isinstance(raw_required, (list, tuple)):
        raw_required = []
    required_names = {str(name).strip() for name in raw_required if str(name).strip()}

    lines = ["参数说明:"]
    for parameter_name, parameter_schema in properties.items():
        if not isinstance(parameter_schema, dict):
            continue
        normalized_name = str(parameter_name).strip()
        parameter_type = _normalize_schema_type(parameter_schema.get("type"))
        required_text = "必填" if normalized_name in required_names else "可选"
        parameter_description = str(parameter_schema.get("description", "") or "").strip() or "无额外说明"
        # Explicit separators keep name, type, requiredness and description
        # distinguishable (they were previously concatenated with no delimiter).
        line = f"- {normalized_name} ({parameter_type}, {required_text}): {parameter_description}"
        enum_items = parameter_schema.get("enum")
        if isinstance(enum_items, list) and enum_items:
            enum_values = ", ".join(str(item) for item in enum_items)
            line += f" 可选值:{enum_values}"
        if "default" in parameter_schema:
            line += f" 默认值:{parameter_schema['default']}"
        lines.append(line)

    # Only the header was produced: no usable property schemas.
    if len(lines) == 1:
        return fallback_text
    if fallback_text:
        lines.append("")
        lines.append(fallback_text)
    return "\n".join(lines).strip()
@dataclass(slots=True)
class ToolSpec:
"""统一工具声明。"""
name: str
brief_description: str
detailed_description: str = ""
parameters_schema: Dict[str, Any] | None = None
provider_name: str = ""
provider_type: str = ""
enabled: bool = True
metadata: Dict[str, Any] = field(default_factory=dict)
def build_llm_description(self) -> str:
"""构建供 LLM 使用的描述文本。
Returns:
str: 合并后的单段工具描述。
"""
parts = [self.brief_description.strip()]
if self.detailed_description.strip():
parts.append(self.detailed_description.strip())
return "\n\n".join(part for part in parts if part).strip()
def to_llm_definition(self) -> ToolDefinitionInput:
"""转换为统一的 LLM 工具定义。
Returns:
ToolDefinitionInput: 可直接交给模型层的工具定义。
"""
definition: Dict[str, Any] = {
"name": self.name,
"description": self.build_llm_description(),
}
if self.parameters_schema is not None:
definition["parameters_schema"] = deepcopy(self.parameters_schema)
return definition
@dataclass(slots=True)
class ToolInvocation:
"""统一工具调用请求。"""
tool_name: str
arguments: Dict[str, Any] = field(default_factory=dict)
call_id: str = ""
session_id: str = ""
stream_id: str = ""
reasoning: str = ""
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class ToolExecutionContext:
"""统一工具执行上下文。"""
session_id: str = ""
stream_id: str = ""
reasoning: str = ""
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class ToolExecutionResult:
"""统一工具执行结果。"""
tool_name: str
success: bool
content: str = ""
error_message: str = ""
structured_content: Any = None
metadata: Dict[str, Any] = field(default_factory=dict)
def get_history_content(self) -> str:
"""获取适合写入对话历史的结果文本。
Returns:
str: 优先使用文本内容,其次使用错误信息。
"""
if self.content.strip():
return self.content.strip()
if self.structured_content is not None:
if isinstance(self.structured_content, str):
return self.structured_content.strip()
try:
return json.dumps(self.structured_content, ensure_ascii=False)
except (TypeError, ValueError):
return str(self.structured_content).strip()
return self.error_message.strip()
@runtime_checkable
class ToolProvider(Protocol):
    """Structural interface that every tool provider satisfies.

    The protocol is ``runtime_checkable``, so ``isinstance`` only checks that
    these members are present; method signatures are not validated at runtime.
    """

    # Identifier of the provider.
    provider_name: str
    # Kind/category of the provider.
    provider_type: str

    async def list_tools(self) -> list[ToolSpec]:
        """Return every tool this provider currently exposes."""

    async def invoke(
        self,
        invocation: ToolInvocation,
        context: Optional[ToolExecutionContext] = None,
    ) -> ToolExecutionResult:
        """Run the given tool invocation, optionally with an execution context."""

    async def close(self) -> None:
        """Release any resources held by the provider."""
class ToolRegistry:
    """Unified tool registry that aggregates several tool providers.

    Providers are consulted in registration order. When several providers
    expose an enabled tool with the same name, the provider registered first
    wins, both for listing and for invocation.
    """

    def __init__(self) -> None:
        # Ordered by priority (earliest registration first). Mutators rebind this
        # attribute to a new list instead of editing it in place, so loops that
        # await provider calls keep iterating a stable snapshot.
        self._providers: list[ToolProvider] = []

    def register_provider(self, provider: ToolProvider) -> None:
        """Register a tool provider.

        Any provider already registered under the same ``provider_name`` is
        removed first. The new provider is appended, so it gets the lowest
        priority.

        Args:
            provider: Provider to register.
        """
        remaining = [item for item in self._providers if item.provider_name != provider.provider_name]
        remaining.append(provider)
        self._providers = remaining

    def unregister_provider(self, provider_name: str) -> None:
        """Unregister the provider with the given name (no-op if none matches).

        Args:
            provider_name: Name of the provider to remove.
        """
        self._providers = [item for item in self._providers if item.provider_name != provider_name]

    async def _list_provider_tools(self, provider: ToolProvider) -> list[ToolSpec]:
        """Return one provider's tools, or an empty list if listing them fails."""
        try:
            return await provider.list_tools()
        except Exception as exc:
            # A single failing provider must not hide every other provider's
            # tools or break unrelated invocations; log it and treat it as empty.
            logger.error(f"获取工具列表失败,已跳过 provider={provider.provider_name}: {exc}")
            return []

    async def list_tools(self) -> list[ToolSpec]:
        """List all enabled tools across providers, deduplicated by name.

        Disabled specs are skipped. For duplicate names the first provider (in
        registration order) wins and later duplicates are logged and dropped.
        Providers whose listing raises are logged and skipped.

        Returns:
            list[ToolSpec]: The deduplicated tool specs.
        """
        collected_specs: list[ToolSpec] = []
        seen_names: set[str] = set()
        for provider in self._providers:
            for spec in await self._list_provider_tools(provider):
                if not spec.enabled:
                    continue
                if spec.name in seen_names:
                    logger.warning(
                        f"检测到重复工具名 {spec.name},保留先注册的工具,跳过 provider={provider.provider_name}"
                    )
                    continue
                seen_names.add(spec.name)
                collected_specs.append(spec)
        return collected_specs

    async def get_tool_spec(self, tool_name: str) -> Optional[ToolSpec]:
        """Look up the spec of an enabled tool by name.

        Args:
            tool_name: Tool name.

        Returns:
            Optional[ToolSpec]: The matching spec, or None if there is none.
        """
        for spec in await self.list_tools():
            if spec.name == tool_name:
                return spec
        return None

    async def has_tool(self, tool_name: str) -> bool:
        """Return whether an enabled tool with the given name exists.

        Args:
            tool_name: Tool name.

        Returns:
            bool: True if the tool exists.
        """
        return await self.get_tool_spec(tool_name) is not None

    async def get_llm_definitions(self) -> list[ToolDefinitionInput]:
        """Return LLM tool definitions for every listed tool.

        Returns:
            list[ToolDefinitionInput]: One definition per deduplicated tool.
        """
        return [spec.to_llm_definition() for spec in await self.list_tools()]

    async def invoke(
        self,
        invocation: ToolInvocation,
        context: Optional[ToolExecutionContext] = None,
    ) -> ToolExecutionResult:
        """Dispatch a tool call to the first provider exposing it as enabled.

        Args:
            invocation: Tool call request.
            context: Optional execution context, passed through to the provider.

        Returns:
            ToolExecutionResult: The provider's result, or a failed result when
            no provider exposes the tool. Exceptions raised by the matched
            provider's ``invoke`` propagate unchanged.
        """
        for provider in self._providers:
            provider_specs = await self._list_provider_tools(provider)
            if any(spec.name == invocation.tool_name and spec.enabled for spec in provider_specs):
                return await provider.invoke(invocation, context)
        return ToolExecutionResult(
            tool_name=invocation.tool_name,
            success=False,
            error_message=f"未找到工具:{invocation.tool_name}",
        )

    async def close(self) -> None:
        """Close every registered provider.

        Every provider is attempted even if some fail, so one failing ``close``
        does not leak the rest. After all providers have been attempted, the
        first failure is re-raised and any later failures are logged.
        """
        first_error: Optional[Exception] = None
        for provider in self._providers:
            try:
                await provider.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
                else:
                    logger.error(f"关闭工具提供者失败 provider={provider.provider_name}: {exc}")
        if first_error is not None:
            raise first_error