Ruff Format

This commit is contained in:
DrSmoothl
2026-02-21 16:24:24 +08:00
parent 2cb512120b
commit eaef7f0e98
82 changed files with 1881 additions and 1900 deletions

View File

@@ -1,2 +1 @@
"""Core helpers for MCP Bridge Plugin."""

View File

@@ -167,4 +167,3 @@ def legacy_servers_list_to_claude_config(servers_list_json: str) -> str:
if not mcp_servers:
return ""
return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -22,21 +22,24 @@ from typing import Any, Dict, List, Optional, Tuple
try:
from src.common.logger import get_logger
logger = get_logger("mcp_tool_chain")
except ImportError:
import logging
logger = logging.getLogger("mcp_tool_chain")
@dataclass
class ToolChainStep:
"""工具链步骤"""
tool_name: str # 要调用的工具名(如 mcp_server_tool
args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换
output_key: str = "" # 输出存储的键名,供后续步骤引用
description: str = "" # 步骤描述
optional: bool = False # 是否可选(失败时继续执行)
def to_dict(self) -> Dict[str, Any]:
return {
"tool_name": self.tool_name,
@@ -45,7 +48,7 @@ class ToolChainStep:
"description": self.description,
"optional": self.optional,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep":
return cls(
@@ -60,12 +63,13 @@ class ToolChainStep:
@dataclass
class ToolChainDefinition:
"""工具链定义"""
name: str # 工具链名称(将作为组合工具的名称)
description: str # 工具链描述(供 LLM 理解)
steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤
input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述}
enabled: bool = True # 是否启用
def to_dict(self) -> Dict[str, Any]:
return {
"name": self.name,
@@ -74,7 +78,7 @@ class ToolChainDefinition:
"input_params": self.input_params,
"enabled": self.enabled,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition":
steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])]
@@ -90,12 +94,13 @@ class ToolChainDefinition:
@dataclass
class ChainExecutionResult:
"""工具链执行结果"""
success: bool
final_output: str # 最终输出(最后一个步骤的结果)
step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果
error: str = ""
total_duration_ms: float = 0.0
def to_summary(self) -> str:
"""生成执行摘要"""
lines = []
@@ -103,7 +108,7 @@ class ChainExecutionResult:
status = "" if step.get("success") else ""
tool = step.get("tool_name", "unknown")
duration = step.get("duration_ms", 0)
lines.append(f"{status} 步骤{i+1}: {tool} ({duration:.0f}ms)")
lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)")
if not step.get("success") and step.get("error"):
lines.append(f" 错误: {step['error'][:50]}")
return "\n".join(lines)
@@ -111,49 +116,49 @@ class ChainExecutionResult:
class ToolChainExecutor:
"""工具链执行器"""
# 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev}
VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')
VAR_PATTERN = re.compile(r"\$\{([^}]+)\}")
def __init__(self, mcp_manager):
self._mcp_manager = mcp_manager
def _resolve_tool_key(self, tool_name: str) -> Optional[str]:
"""解析工具名,返回有效的 tool_key
支持:
- 直接使用 tool_key如 mcp_server_tool
- 使用注册后的工具名(会自动转换 - 和 . 为 _
"""
all_tools = self._mcp_manager.all_tools
# 直接匹配
if tool_name in all_tools:
return tool_name
# 尝试转换后匹配(用户可能使用了注册后的名称)
normalized = tool_name.replace("-", "_").replace(".", "_")
if normalized in all_tools:
return normalized
# 尝试查找包含该名称的工具
for key in all_tools.keys():
if key.endswith(f"_{tool_name}") or key.endswith(f"_{normalized}"):
return key
return None
async def execute(
self,
chain: ToolChainDefinition,
input_args: Dict[str, Any],
) -> ChainExecutionResult:
"""执行工具链
Args:
chain: 工具链定义
input_args: 用户输入的参数
Returns:
ChainExecutionResult: 执行结果
"""
@@ -164,15 +169,15 @@ class ToolChainExecutor:
"step": {}, # 各步骤输出,按 output_key 存储
"prev": "", # 上一步的输出
}
final_output = ""
# 验证必需的输入参数
missing_params = []
for param_name in chain.input_params.keys():
if param_name not in context["input"]:
missing_params.append(param_name)
if missing_params:
return ChainExecutionResult(
success=False,
@@ -180,7 +185,7 @@ class ToolChainExecutor:
error=f"缺少必需参数: {', '.join(missing_params)}",
total_duration_ms=(time.time() - start_time) * 1000,
)
for i, step in enumerate(chain.steps):
step_start = time.time()
step_result = {
@@ -191,96 +196,96 @@ class ToolChainExecutor:
"error": "",
"duration_ms": 0,
}
try:
# 替换参数中的变量
resolved_args = self._resolve_args(step.args_template, context)
step_result["resolved_args"] = resolved_args
# 解析工具名
tool_key = self._resolve_tool_key(step.tool_name)
if not tool_key:
step_result["error"] = f"工具 {step.tool_name} 不存在"
logger.warning(f"工具链步骤 {i+1}: 工具 {step.tool_name} 不存在")
logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在")
if not step.optional:
step_results.append(step_result)
return ChainExecutionResult(
success=False,
final_output="",
step_results=step_results,
error=f"步骤 {i+1}: 工具 {step.tool_name} 不存在",
error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在",
total_duration_ms=(time.time() - start_time) * 1000,
)
step_results.append(step_result)
continue
logger.debug(f"工具链步骤 {i+1}: 调用 {tool_key},参数: {resolved_args}")
logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}")
# 调用工具
result = await self._mcp_manager.call_tool(tool_key, resolved_args)
step_duration = (time.time() - step_start) * 1000
step_result["duration_ms"] = step_duration
if result.success:
step_result["success"] = True
# 确保 content 不为 None
content = result.content if result.content is not None else ""
step_result["output"] = content
# 更新上下文
context["prev"] = content
if step.output_key:
context["step"][step.output_key] = content
final_output = content
content_preview = content[:100] if content else "(空)"
logger.debug(f"工具链步骤 {i+1} 成功: {content_preview}...")
logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...")
else:
step_result["error"] = result.error or "未知错误"
logger.warning(f"工具链步骤 {i+1} 失败: {result.error}")
logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}")
if not step.optional:
step_results.append(step_result)
return ChainExecutionResult(
success=False,
final_output="",
step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 失败: {result.error}",
error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}",
total_duration_ms=(time.time() - start_time) * 1000,
)
except Exception as e:
step_duration = (time.time() - step_start) * 1000
step_result["duration_ms"] = step_duration
step_result["error"] = str(e)
logger.error(f"工具链步骤 {i+1} 异常: {e}")
logger.error(f"工具链步骤 {i + 1} 异常: {e}")
if not step.optional:
step_results.append(step_result)
return ChainExecutionResult(
success=False,
final_output="",
step_results=step_results,
error=f"步骤 {i+1} ({step.tool_name}) 异常: {e}",
error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}",
total_duration_ms=(time.time() - start_time) * 1000,
)
step_results.append(step_result)
total_duration = (time.time() - start_time) * 1000
return ChainExecutionResult(
success=True,
final_output=final_output,
step_results=step_results,
total_duration_ms=total_duration,
)
def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
"""解析参数模板,替换变量
支持的变量格式:
- ${input.param_name}: 用户输入的参数
- ${step.output_key}: 某个步骤的输出
@@ -288,50 +293,48 @@ class ToolChainExecutor:
- ${prev.field}: 上一步输出JSON的某个字段
"""
resolved = {}
for key, value in args_template.items():
if isinstance(value, str):
resolved[key] = self._substitute_vars(value, context)
elif isinstance(value, dict):
resolved[key] = self._resolve_args(value, context)
elif isinstance(value, list):
resolved[key] = [
self._substitute_vars(v, context) if isinstance(v, str) else v
for v in value
]
resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value]
else:
resolved[key] = value
return resolved
def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str:
"""替换字符串中的变量"""
def replacer(match):
var_path = match.group(1)
return self._get_var_value(var_path, context)
return self.VAR_PATTERN.sub(replacer, template)
def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str:
"""获取变量值
Args:
var_path: 变量路径,如 "input.query", "step.search_result", "prev", "prev.id"
context: 上下文
"""
parts = self._parse_var_path(var_path)
if not parts:
return ""
# 获取根对象
root = parts[0]
if root not in context:
logger.warning(f"变量 {var_path} 的根 '{root}' 不存在")
return ""
value = context[root]
# 遍历路径
for part in parts[1:]:
if isinstance(value, str):
@@ -349,7 +352,7 @@ class ToolChainExecutor:
value = ""
else:
value = ""
# 确保返回字符串
if isinstance(value, (dict, list)):
return json.dumps(value, ensure_ascii=False)
@@ -448,39 +451,39 @@ class ToolChainExecutor:
class ToolChainManager:
"""工具链管理器"""
_instance: Optional["ToolChainManager"] = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._initialized = True
self._chains: Dict[str, ToolChainDefinition] = {}
self._executor: Optional[ToolChainExecutor] = None
def set_executor(self, mcp_manager) -> None:
"""设置执行器"""
self._executor = ToolChainExecutor(mcp_manager)
def add_chain(self, chain: ToolChainDefinition) -> bool:
"""添加工具链"""
if not chain.name:
logger.error("工具链名称不能为空")
return False
if chain.name in self._chains:
logger.warning(f"工具链 {chain.name} 已存在,将被覆盖")
self._chains[chain.name] = chain
logger.info(f"已添加工具链: {chain.name} ({len(chain.steps)} 个步骤)")
return True
def remove_chain(self, name: str) -> bool:
"""移除工具链"""
if name in self._chains:
@@ -488,19 +491,19 @@ class ToolChainManager:
logger.info(f"已移除工具链: {name}")
return True
return False
def get_chain(self, name: str) -> Optional[ToolChainDefinition]:
"""获取工具链"""
return self._chains.get(name)
def get_all_chains(self) -> Dict[str, ToolChainDefinition]:
"""获取所有工具链"""
return self._chains.copy()
def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]:
"""获取所有启用的工具链"""
return {name: chain for name, chain in self._chains.items() if chain.enabled}
async def execute_chain(
self,
chain_name: str,
@@ -514,64 +517,64 @@ class ToolChainManager:
final_output="",
error=f"工具链 {chain_name} 不存在",
)
if not chain.enabled:
return ChainExecutionResult(
success=False,
final_output="",
error=f"工具链 {chain_name} 已禁用",
)
if not self._executor:
return ChainExecutionResult(
success=False,
final_output="",
error="工具链执行器未初始化",
)
return await self._executor.execute(chain, input_args)
def load_from_json(self, json_str: str) -> Tuple[int, List[str]]:
"""从 JSON 字符串加载工具链配置
Returns:
(成功加载数量, 错误列表)
"""
errors = []
loaded = 0
try:
data = json.loads(json_str) if json_str.strip() else []
except json.JSONDecodeError as e:
return 0, [f"JSON 解析失败: {e}"]
if not isinstance(data, list):
data = [data]
for i, item in enumerate(data):
try:
chain = ToolChainDefinition.from_dict(item)
if not chain.name:
errors.append(f"{i+1} 个工具链缺少名称")
errors.append(f"{i + 1} 个工具链缺少名称")
continue
if not chain.steps:
errors.append(f"工具链 {chain.name} 没有步骤")
continue
self.add_chain(chain)
loaded += 1
except Exception as e:
errors.append(f"{i+1} 个工具链解析失败: {e}")
errors.append(f"{i + 1} 个工具链解析失败: {e}")
return loaded, errors
def export_to_json(self, pretty: bool = True) -> str:
"""导出所有工具链为 JSON"""
chains_data = [chain.to_dict() for chain in self._chains.values()]
if pretty:
return json.dumps(chains_data, ensure_ascii=False, indent=2)
return json.dumps(chains_data, ensure_ascii=False)
def clear(self) -> None:
"""清空所有工具链"""
self._chains.clear()