448 lines
14 KiB
Python
448 lines
14 KiB
Python
"""统一工具抽象。
|
||
|
||
该模块定义主程序内部统一使用的工具声明、调用与执行结果模型,
|
||
用于收敛插件 Tool、兼容旧 Action、MaiSaka 内置 Tool 与 MCP Tool。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from copy import deepcopy
|
||
from dataclasses import dataclass, field
|
||
import json
|
||
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable
|
||
|
||
from src.common.logger import get_logger
|
||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||
|
||
# Module-level logger, used by ToolRegistry for duplicate-name warnings and invocation errors.
logger = get_logger("core.tooling")
|
||
|
||
|
||
def _normalize_schema_type(raw_type: Any) -> str:
    """Normalize a raw JSON Schema ``type`` value into a readable string.

    Type names are trimmed and lower-cased; unknown type names are passed through.
    JSON Schema also allows ``type`` to be an array of type names
    (e.g. ``["string", "null"]``); such arrays are rendered as ``"string | null"``.

    Args:
        raw_type: The raw ``type`` value taken from a schema (string, list, None, ...).

    Returns:
        str: The normalized type name, or ``"string"`` when no type is given.
    """

    if isinstance(raw_type, (list, tuple)):
        member_types = [str(item or "").strip().lower() for item in raw_type]
        member_types = [item for item in member_types if item]
        return " | ".join(member_types) if member_types else "string"

    normalized_type = str(raw_type or "").strip().lower()
    # An absent/empty type defaults to "string"; any other value is returned as-is.
    return normalized_type or "string"
|
||
|
||
|
||
def build_tool_detailed_description(
    parameters_schema: Optional[Dict[str, Any]],
    fallback_description: str = "",
) -> str:
    """Build a human-readable parameter description from an object-level schema.

    Each property whose schema is a dict becomes one line with its type, whether it
    is required, its description, and (when present) its enum values and default.
    The trimmed fallback description is appended after a blank line.

    Args:
        parameters_schema: Object-level parameter schema of the tool (``properties`` /
            ``required``). ``None``, empty, or non-dict values mean "no usable schema".
        fallback_description: Text returned when nothing can be derived from the
            schema; otherwise appended after the parameter list.

    Returns:
        str: The generated description, or the trimmed fallback when the schema has
        no usable properties.
    """

    fallback_text = fallback_description.strip()
    if not isinstance(parameters_schema, dict) or not parameters_schema:
        return fallback_text

    properties = parameters_schema.get("properties")
    if not isinstance(properties, dict) or not properties:
        return fallback_text

    # ``required`` should be a list of names. Tolerate None or other malformed values:
    # the original raised TypeError on None and iterated characters of a string.
    raw_required = parameters_schema.get("required")
    if not isinstance(raw_required, (list, tuple, set)):
        raw_required = []
    required_names = {str(name).strip() for name in raw_required if str(name).strip()}

    lines = ["参数说明:"]
    for parameter_name, parameter_schema in properties.items():
        if not isinstance(parameter_schema, dict):
            continue

        normalized_name = str(parameter_name).strip()
        parameter_type = _normalize_schema_type(parameter_schema.get("type"))
        required_text = "必填" if normalized_name in required_names else "可选"
        parameter_description = str(parameter_schema.get("description", "") or "").strip() or "无额外说明"
        line = f"- {normalized_name}:{parameter_type},{required_text}。{parameter_description}"

        enum_values = parameter_schema.get("enum")
        if isinstance(enum_values, list) and enum_values:
            joined_enum = "、".join(str(item) for item in enum_values)
            line += f" 可选值:{joined_enum}。"

        if "default" in parameter_schema:
            line += f" 默认值:{parameter_schema['default']}。"

        lines.append(line)

    # Only the header was produced: no property had a dict schema.
    if len(lines) == 1:
        return fallback_text

    if fallback_text:
        lines.append("")
        lines.append(fallback_text)
    return "\n".join(lines).strip()
|
||
|
||
|
||
@dataclass(slots=True)
class ToolIcon:
    """Icon metadata attached to a unified tool declaration (plain data, no validation)."""

    # Icon source string (presumably a URL or data URI — not validated here).
    src: str
    # MIME type of the icon; defaults to an empty string.
    mime_type: str = ""
    # Icon size descriptors; format is not enforced here (assumed e.g. "48x48" — TODO confirm).
    sizes: list[str] = field(default_factory=list)
|
||
|
||
|
||
@dataclass(slots=True)
class ToolAnnotation:
    """Annotation metadata that can be attached to a tool declaration or a content item."""

    # Audience labels; their meaning is defined by providers, not constrained here.
    audience: list[str] = field(default_factory=list)
    # Optional priority hint; None means unset. The value range is not validated here.
    priority: float | None = None
    # Free-form extra annotation data.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
@dataclass(slots=True)
|
||
class ToolContentItem:
|
||
"""统一工具内容项。"""
|
||
|
||
content_type: Literal["text", "image", "audio", "resource_link", "resource", "binary", "unknown"]
|
||
text: str = ""
|
||
data: str = ""
|
||
mime_type: str = ""
|
||
uri: str = ""
|
||
name: str = ""
|
||
description: str = ""
|
||
annotation: ToolAnnotation | None = None
|
||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
def build_history_text(self) -> str:
|
||
"""生成适合写入历史消息的文本摘要。
|
||
|
||
Returns:
|
||
str: 当前内容项对应的历史摘要文本。
|
||
"""
|
||
|
||
if self.content_type == "text" and self.text.strip():
|
||
return self.text.strip()
|
||
if self.content_type == "image":
|
||
return f"[图片内容 {self.mime_type or 'unknown'}]"
|
||
if self.content_type == "audio":
|
||
return f"[音频内容 {self.mime_type or 'unknown'}]"
|
||
if self.content_type == "resource_link":
|
||
label = self.name or self.uri or "资源链接"
|
||
return f"[资源链接] {label}"
|
||
if self.content_type == "resource":
|
||
if self.text.strip():
|
||
return self.text.strip()
|
||
label = self.name or self.uri or "嵌入资源"
|
||
return f"[嵌入资源] {label}"
|
||
if self.content_type == "binary":
|
||
return f"[二进制内容 {self.mime_type or 'unknown'}]"
|
||
return f"[{self.content_type} 内容]"
|
||
|
||
|
||
@dataclass(slots=True)
|
||
class ToolSpec:
|
||
"""统一工具声明。"""
|
||
|
||
name: str
|
||
brief_description: str
|
||
detailed_description: str = ""
|
||
title: str = ""
|
||
parameters_schema: Dict[str, Any] | None = None
|
||
output_schema: Dict[str, Any] | None = None
|
||
provider_name: str = ""
|
||
provider_type: str = ""
|
||
enabled: bool = True
|
||
icons: list[ToolIcon] = field(default_factory=list)
|
||
annotation: ToolAnnotation | None = None
|
||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
def build_llm_description(self) -> str:
|
||
"""构建供 LLM 使用的描述文本。
|
||
|
||
Returns:
|
||
str: 合并后的单段工具描述。
|
||
"""
|
||
|
||
return self.brief_description.strip()
|
||
|
||
def to_llm_definition(self) -> ToolDefinitionInput:
|
||
"""转换为统一的 LLM 工具定义。
|
||
|
||
Returns:
|
||
ToolDefinitionInput: 可直接交给模型层的工具定义。
|
||
"""
|
||
|
||
definition: Dict[str, Any] = {
|
||
"name": self.name,
|
||
"description": self.build_llm_description(),
|
||
}
|
||
if self.parameters_schema is not None:
|
||
definition["parameters_schema"] = deepcopy(self.parameters_schema)
|
||
return definition
|
||
|
||
|
||
@dataclass(slots=True)
class ToolInvocation:
    """A unified tool call request, as passed to ``ToolProvider.invoke``."""

    # Name of the tool to call; ToolRegistry.invoke matches it against ToolSpec.name.
    tool_name: str
    # Call arguments (presumably keyed by parameter name per the tool's parameters_schema — verify).
    arguments: Dict[str, Any] = field(default_factory=dict)
    # Identifier of this call (presumably the LLM tool-call id — TODO confirm); defaults to "".
    call_id: str = ""
    # Session identifier; defaults to "".
    session_id: str = ""
    # Stream identifier; defaults to "".
    stream_id: str = ""
    # Reasoning text associated with the call, if any.
    reasoning: str = ""
    # Free-form extra data.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
@dataclass(slots=True)
class ToolExecutionContext:
    """Execution context passed to ``ToolProvider.invoke`` alongside a ``ToolInvocation``."""

    # Session identifier; defaults to "".
    session_id: str = ""
    # Stream identifier; defaults to "".
    stream_id: str = ""
    # Reasoning text associated with the execution, if any.
    reasoning: str = ""
    # Free-form extra data.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
@dataclass(slots=True)
class ToolAvailabilityContext:
    """Context that providers may use to decide which tools to expose in ``list_tools``."""

    # Session identifier; defaults to "".
    session_id: str = ""
    # Stream identifier; defaults to "".
    stream_id: str = ""
    # Whether the conversation is a group chat; None presumably means "unknown" — confirm.
    is_group_chat: bool | None = None
    # Group identifier; defaults to "".
    group_id: str = ""
    # User identifier; defaults to "".
    user_id: str = ""
    # Platform identifier (free-form string); defaults to "".
    platform: str = ""
|
||
|
||
|
||
@dataclass(slots=True)
|
||
class ToolExecutionResult:
|
||
"""统一工具执行结果。"""
|
||
|
||
tool_name: str
|
||
success: bool
|
||
content: str = ""
|
||
error_message: str = ""
|
||
structured_content: Any = None
|
||
content_items: list[ToolContentItem] = field(default_factory=list)
|
||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
def get_history_content(self) -> str:
|
||
"""获取适合写入对话历史的结果文本。
|
||
|
||
Returns:
|
||
str: 优先使用文本内容,其次使用错误信息。
|
||
"""
|
||
|
||
if self.content.strip():
|
||
return self.content.strip()
|
||
if self.content_items:
|
||
parts = [item.build_history_text() for item in self.content_items if item.build_history_text().strip()]
|
||
if parts:
|
||
return "\n".join(parts).strip()
|
||
if self.structured_content is not None:
|
||
if isinstance(self.structured_content, str):
|
||
return self.structured_content.strip()
|
||
try:
|
||
return json.dumps(self.structured_content, ensure_ascii=False)
|
||
except (TypeError, ValueError):
|
||
return str(self.structured_content).strip()
|
||
return self.error_message.strip()
|
||
|
||
|
||
@runtime_checkable
class ToolProvider(Protocol):
    """Structural protocol implemented by every unified tool provider.

    Because the protocol is ``runtime_checkable``, ``isinstance`` checks only verify
    that the members exist, not their signatures.
    """

    # Provider name; ToolRegistry.register_provider replaces any provider with the same name.
    provider_name: str
    # Provider category label (free-form string).
    provider_type: str

    async def list_tools(
        self,
        context: Optional[ToolAvailabilityContext] = None,
    ) -> list[ToolSpec]:
        """List every tool currently exposed by this provider.

        Args:
            context: Optional availability context the provider may use to filter tools.

        Returns:
            list[ToolSpec]: Tool declarations. ToolRegistry skips specs whose ``enabled``
            flag is False.
        """

        ...

    async def invoke(
        self,
        invocation: ToolInvocation,
        context: Optional[ToolExecutionContext] = None,
    ) -> ToolExecutionResult:
        """Execute the given tool call.

        Args:
            invocation: The tool call request.
            context: Optional execution context.

        Returns:
            ToolExecutionResult: The execution result. Exceptions raised here are caught
            by ToolRegistry.invoke and turned into a failed result.
        """

        ...

    async def close(self) -> None:
        """Release any resources held by the provider."""

        ...
|
||
|
||
|
||
class ToolRegistry:
    """Unified tool registry that aggregates tools from multiple providers.

    Providers are kept in registration order. When two providers expose a tool with
    the same name, the provider that appears earlier in that order wins.
    """

    def __init__(self) -> None:
        """Create an empty registry."""

        self._providers: list[ToolProvider] = []

    def register_provider(self, provider: ToolProvider) -> None:
        """Register a tool provider.

        Any provider already registered under the same ``provider_name`` is removed
        first (it is not closed), and the new provider is appended to the end.

        Args:
            provider: The provider to register.
        """

        self._providers = [item for item in self._providers if item.provider_name != provider.provider_name]
        self._providers.append(provider)

    def unregister_provider(self, provider_name: str) -> None:
        """Remove every provider registered under ``provider_name`` (it is not closed).

        Args:
            provider_name: Name of the provider to remove.
        """

        self._providers = [item for item in self._providers if item.provider_name != provider_name]

    async def list_tools(
        self,
        context: Optional[ToolAvailabilityContext] = None,
    ) -> list[ToolSpec]:
        """List all enabled tools in provider order, de-duplicated by name.

        Args:
            context: Optional availability context forwarded to each provider.

        Returns:
            list[ToolSpec]: Enabled tool specs. A name seen earlier wins; later
            duplicates are skipped with a warning.
        """

        collected_specs: list[ToolSpec] = []
        seen_names: set[str] = set()

        for provider in self._providers:
            provider_specs = await provider.list_tools(context)
            for spec in provider_specs:
                if not spec.enabled:
                    continue
                if spec.name in seen_names:
                    logger.warning(
                        f"检测到重复工具名 {spec.name},保留先注册的工具,跳过 provider={provider.provider_name}"
                    )
                    continue
                seen_names.add(spec.name)
                collected_specs.append(spec)
        return collected_specs

    async def get_tool_spec(
        self,
        tool_name: str,
        context: Optional[ToolAvailabilityContext] = None,
    ) -> Optional[ToolSpec]:
        """Look up the declaration of a tool by name.

        Args:
            tool_name: Tool name.
            context: Optional availability context forwarded to the providers.

        Returns:
            Optional[ToolSpec]: The matching spec, or None when no enabled tool matches.
        """

        for spec in await self.list_tools(context):
            if spec.name == tool_name:
                return spec
        return None

    async def has_tool(
        self,
        tool_name: str,
        context: Optional[ToolAvailabilityContext] = None,
    ) -> bool:
        """Return whether an enabled tool named ``tool_name`` is available.

        Args:
            tool_name: Tool name.
            context: Optional availability context forwarded to the providers.

        Returns:
            bool: True when the tool exists.
        """

        return await self.get_tool_spec(tool_name, context) is not None

    async def get_llm_definitions(
        self,
        context: Optional[ToolAvailabilityContext] = None,
    ) -> list[ToolDefinitionInput]:
        """Return the LLM tool definitions of all available tools.

        Args:
            context: Optional availability context forwarded to the providers.

        Returns:
            list[ToolDefinitionInput]: One definition per tool from list_tools().
        """

        return [spec.to_llm_definition() for spec in await self.list_tools(context)]

    async def invoke(
        self,
        invocation: ToolInvocation,
        context: Optional[ToolExecutionContext] = None,
    ) -> ToolExecutionResult:
        """Execute a tool call on the first provider that exposes the tool, enabled.

        Exceptions raised by the provider's ``invoke`` are logged and converted into a
        failed result. Exceptions raised by a provider's ``list_tools`` propagate.

        Args:
            invocation: The tool call request.
            context: Execution context forwarded to the provider.

        Returns:
            ToolExecutionResult: The provider's result, a failed result describing the
            exception, or a failed "not found" result.
        """

        for provider in self._providers:
            # NOTE(review): list_tools() is called without an availability context here,
            # so context-based filtering done by providers does not apply to invocation.
            provider_specs = await provider.list_tools()
            if any(spec.name == invocation.tool_name and spec.enabled for spec in provider_specs):
                try:
                    return await provider.invoke(invocation, context)
                except Exception as exc:
                    logger.exception(
                        "工具调用异常: tool=%s provider=%s",
                        invocation.tool_name,
                        getattr(provider, "provider_name", ""),
                    )
                    error_message = str(exc).strip()
                    if error_message:
                        error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}: {error_message}"
                    else:
                        error_message = f"工具 {invocation.tool_name} 调用失败:{exc.__class__.__name__}"
                    return ToolExecutionResult(
                        tool_name=invocation.tool_name,
                        success=False,
                        error_message=error_message,
                    )

        return ToolExecutionResult(
            tool_name=invocation.tool_name,
            success=False,
            error_message=f"未找到工具:{invocation.tool_name}",
        )

    async def close(self) -> None:
        """Close every registered provider.

        Every provider is closed even when an earlier one fails. Each failure is
        logged, and the first one is re-raised after all providers have been attempted
        so callers still observe the error.

        Raises:
            Exception: The first exception raised by a provider's ``close``.
        """

        first_error: Optional[Exception] = None
        # Iterate over a snapshot in case a provider's close() mutates the registry.
        for provider in list(self._providers):
            try:
                await provider.close()
            except Exception as exc:
                logger.exception(f"关闭工具 Provider 失败: provider={getattr(provider, 'provider_name', '')}")
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error
|