diff --git a/plugins/MaiBot_MCPBridgePlugin/.gitignore b/plugins/MaiBot_MCPBridgePlugin/.gitignore deleted file mode 100644 index ebef83b0..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/.gitignore +++ /dev/null @@ -1,30 +0,0 @@ -# 运行时配置(包含用户敏感信息) -config.toml - -# 备份文件 -*.backup.* -*.bak - -# 日志 -logs/ -*.log -*.jsonl - -# Python 缓存 -__pycache__/ -*.py[cod] -*$py.class -*.so - -# 本地测试脚本(仓库不提交) -test_*.py - -# IDE -.idea/ -.vscode/ -*.swp -*.swo - -# 系统文件 -.DS_Store -Thumbs.db diff --git a/plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md b/plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md deleted file mode 100644 index 0c3feb46..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md +++ /dev/null @@ -1,24 +0,0 @@ -# Changelog - -本文件记录 `MaiBot_MCPBridgePlugin` 的用户可感知变更。 - -## 2.0.0 - -- 配置入口统一:MCP 服务器仅使用 Claude Desktop `mcpServers` JSON(`servers.claude_config_json`) -- 兼容迁移:自动识别旧版 `servers.list` 并迁移为 `mcpServers`(需在 WebUI 保存一次固化) -- 保持功能不变:保留 Workflow(硬流程/工具链)与 ReAct(软流程)双轨制能力 -- 精简实现:移除旧的 WebUI 导入导出/快速添加服务器实现与 `tomlkit` 依赖 -- 易用性:完善 Workflow 变量替换(支持数组下标与 bracket 写法),并优化 WebUI 配置区顺序 - -## 1.9.0 - -- 双轨制架构:ReAct(软流程)+ Workflow(硬流程/工具链) - -## 1.8.0 - -- Workflow(工具链):多工具顺序执行、变量替换、自定义 Workflow 并注册为组合工具 - -## 1.7.0 - -- 断路器模式、状态刷新、工具搜索等易用性增强 - diff --git a/plugins/MaiBot_MCPBridgePlugin/DEVELOPMENT.md b/plugins/MaiBot_MCPBridgePlugin/DEVELOPMENT.md deleted file mode 100644 index 7299fe13..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/DEVELOPMENT.md +++ /dev/null @@ -1,356 +0,0 @@ -# MCP 桥接插件开发文档 - -本文档面向开发者,介绍插件的架构设计、核心模块和扩展方式。 - -## 架构概览 - -``` -MaiBot_MCPBridgePlugin/ -├── plugin.py # 主插件文件,包含所有核心逻辑 -├── mcp_client.py # MCP 客户端封装 -├── tool_chain.py # 工具链(Workflow)模块 -├── core/ -│ └── claude_config.py # Claude Desktop mcpServers 解析/迁移 -├── config.toml # 运行时配置 -└── _manifest.json # 插件元数据 -``` - -## 核心模块 - -### 1. MCP 客户端 (`mcp_client.py`) - -封装了与 MCP 服务器的通信逻辑。 - -```python -from .mcp_client import mcp_manager, MCPServerConfig, TransportType - -# 添加服务器 -config = MCPServerConfig( - name="my-server", - transport=TransportType.STREAMABLE_HTTP, - url="https://mcp.example.com/mcp" -) -await mcp_manager.add_server(config) - -# 调用工具 -result = await mcp_manager.call_tool("server_tool_name", {"param": "value"}) -if result.success: - print(result.content) -``` - -**支持的传输类型:** -- `STDIO`: 本地进程通信 -- `SSE`: Server-Sent Events -- `HTTP`: HTTP 请求 -- `STREAMABLE_HTTP`: 流式 HTTP(推荐) - -### 2. 工具注册系统 - -MCP 工具通过动态类创建注册到 MaiBot: - -```python -# 创建工具代理类 -class MCPToolProxy(BaseTool): - name = "mcp_server_tool" - description = "工具描述" - parameters = [("param", ToolParamType.STRING, "参数描述", True, None)] - available_for_llm = True - - async def execute(self, function_args): - result = await mcp_manager.call_tool(self._mcp_tool_key, function_args) - return {"name": self.name, "content": result.content} -``` - -### 3. 工具链模块 (`tool_chain.py`) - -实现 Workflow 硬流程,支持多工具顺序执行。 - -```python -from .tool_chain import ToolChainDefinition, ToolChainStep, tool_chain_manager - -# 定义工具链 -chain = ToolChainDefinition( - name="search_and_detail", - description="搜索并获取详情", - input_params={"query": "搜索关键词"}, - steps=[ - ToolChainStep( - tool_name="mcp_server_search", - args_template={"keyword": "${input.query}"}, - output_key="search_result" - ), - ToolChainStep( - tool_name="mcp_server_detail", - args_template={"id": "${prev}"} - ) - ] -) - -# 注册并执行 -tool_chain_manager.add_chain(chain) -result = await tool_chain_manager.execute_chain("search_and_detail", {"query": "test"}) -``` - -**变量替换语法:** -- `${input.参数名}`: 用户输入 -- `${step.输出键}`: 指定步骤的输出 -- `${prev}`: 上一步输出 -- `${prev.字段}`: 上一步输出(JSON)的字段 -- `${step.geo.return.0.location}` / `${step.geo.return[0].location}`: 数组下标访问 -- `${step.geo['return'][0]['location']}`: bracket 写法(最通用) - -## 双轨制架构 - -### ReAct 软流程 - -将 MCP 工具注册到 MaiBot 的记忆检索 ReAct 系统,LLM 自主决策调用。 - -```python -def _register_tools_to_react(self) -> int: - from src.memory_system.retrieval_tools import register_memory_retrieval_tool - - def make_execute_func(tool_key: str): - async def execute_func(**kwargs) -> str: - result = await mcp_manager.call_tool(tool_key, kwargs) - return result.content if result.success else f"失败: {result.error}" - return execute_func - - register_memory_retrieval_tool( - name="mcp_tool_name", - description="工具描述", - parameters=[{"name": "param", "type": "string", "required": True}], - execute_func=make_execute_func("tool_key") - ) -``` - -### Workflow 硬流程 - -用户预定义的固定执行流程,注册为组合工具。 - -```python -def _register_tool_chains(self) -> None: - from src.plugin_system.core.component_registry import component_registry - - for chain_name, chain in tool_chain_manager.get_enabled_chains().items(): - info, tool_class = tool_chain_registry.register_chain(chain) - info.plugin_name = self.plugin_name - component_registry.register_component(info, tool_class) -``` - -## 配置系统 - -### MCP 服务器配置(Claude Desktop 规范) - -插件只接受 Claude Desktop 的 `mcpServers` JSON(见 `core/claude_config.py`)。配置入口统一为: - -- WebUI/配置文件:`[servers].claude_config_json` -- 命令:`/mcp import`(合并 `mcpServers`)与 `/mcp export`(导出当前 `mcpServers`) - -兼容迁移: -- 若检测到旧版 `servers.list`,会自动迁移为 `servers.claude_config_json`(仅迁移到内存配置,需 WebUI 保存一次固化)。 - -### WebUI 配置 Schema - -使用 `ConfigField` 定义 WebUI 配置项: - -```python -config_schema = { - "section_name": { - "field_name": ConfigField( - type=str, # 类型: str, bool, int, float - default="default_value", # 默认值 - description="字段描述", - label="显示标签", - input_type="textarea", # 输入类型: text, textarea, password - rows=5, # textarea 行数 - disabled=True, # 只读 - choices=["a", "b"], # 下拉选项 - hint="提示信息", - order=1, # 排序 - ), - }, -} -``` - -### 配置读取 - -```python -# 在组件中读取配置 -value = self.get_config("section.key", default="fallback") - -# 在插件类中读取 -value = self.config.get("section", {}).get("key", "default") -``` - -## 事件处理 - -### 启动事件 - -```python -class MCPStartupHandler(BaseEventHandler): - event_type = EventType.ON_START - handler_name = "mcp_startup" - - async def execute(self, message): - global _plugin_instance - if _plugin_instance: - await _plugin_instance._async_connect_servers() - return (True, True, None, None, None) -``` - -### 停止事件 - -```python -class MCPStopHandler(BaseEventHandler): - event_type = EventType.ON_STOP - handler_name = "mcp_stop" - - async def execute(self, message): - await mcp_manager.shutdown() - return (True, True, None, None, None) -``` - -## 命令系统 - -```python -class MCPStatusCommand(BaseCommand): - command_name = "mcp_status" - command_pattern = r"^/mcp(?:\s+(?P\S+))?(?:\s+(?P.+))?$" - - async def execute(self) -> Tuple[bool, str, bool]: - action = self.matched_groups.get("action", "") - arg = self.matched_groups.get("arg", "") - - if action == "tools": - await self.send_text("工具列表...") - elif action == "reconnect": - await self._handle_reconnect(arg) - - return (True, None, True) # (成功, 消息, 拦截) -``` - -## 高级功能 - -### 调用追踪 - -```python -from plugin import tool_call_tracer, ToolCallRecord - -# 记录调用 -record = ToolCallRecord( - call_id="xxx", - timestamp=time.time(), - tool_name="tool", - server_name="server", - arguments={"key": "value"}, - success=True, - duration_ms=100.0 -) -tool_call_tracer.record(record) - -# 查询记录 -recent = tool_call_tracer.get_recent(10) -by_tool = tool_call_tracer.get_by_tool("tool_name") -``` - -### 调用缓存 - -```python -from plugin import tool_call_cache - -# 配置缓存 -tool_call_cache.configure( - enabled=True, - ttl=300, # 秒 - max_entries=200, - exclude_tools="mcp_*_time_*" # 排除模式 -) - -# 使用缓存 -cached = tool_call_cache.get("tool_name", {"param": "value"}) -if cached is None: - result = await call_tool(...) - tool_call_cache.set("tool_name", {"param": "value"}, result) -``` - -### 权限控制 - -```python -from plugin import permission_checker - -# 配置权限 -permission_checker.configure( - enabled=True, - default_mode="allow_all", # 或 "deny_all" - rules_json='[{"tool": "mcp_*_delete_*", "denied": ["qq:123:group"]}]', - quick_deny_groups="123456789", - quick_allow_users="111111111" -) - -# 检查权限 -allowed = permission_checker.check( - tool_name="mcp_server_delete", - chat_id="123456", - user_id="789", - is_group=True -) -``` - -### 断路器模式 - -MCP 客户端内置断路器,故障服务器快速失败: - -- 连续失败 N 次后熔断 -- 熔断期间直接返回错误 -- 定期尝试恢复 - -## 扩展开发 - -### 添加新的传输类型 - -1. 在 `mcp_client.py` 中添加 `TransportType` 枚举值 -2. 实现对应的连接逻辑 -3. 更新 `_create_transport()` 方法 - -### 添加新的工具类型 - -1. 继承 `BaseTool` 创建新类 -2. 在 `get_plugin_components()` 中注册 -3. 实现 `execute()` 方法 - -### 添加新的命令 - -1. 在 `MCPStatusCommand.execute()` 中添加新的 action 分支 -2. 或创建新的 `BaseCommand` 子类 - -## 调试技巧 - -### 日志级别 - -```python -from src.common.logger import get_logger -logger = get_logger("mcp_bridge_plugin") - -logger.debug("详细调试信息") -logger.info("一般信息") -logger.warning("警告") -logger.error("错误") -``` - -### 常用调试命令 - -```bash -/mcp # 查看状态 -/mcp tools # 查看工具列表 -/mcp trace # 查看调用记录 -/mcp cache # 查看缓存状态 -/mcp chain # 查看工具链 -``` - -## 更新日志 - -见 `plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md` - -## 开发约定 - -- 本仓库不提交测试脚本/临时复现文件;如需本地验证,可自行在工作区创建未跟踪文件(建议放到 `.local/` 并加入 `.gitignore`)。 diff --git a/plugins/MaiBot_MCPBridgePlugin/README.md b/plugins/MaiBot_MCPBridgePlugin/README.md deleted file mode 100644 index 61aca8f5..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/README.md +++ /dev/null @@ -1,357 +0,0 @@ -# MCP 桥接插件 - -将 [MCP (Model Context Protocol)](https://modelcontextprotocol.io/) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。 - -image - -## 🚀 快速开始 - -### 1. 安装 - -```bash -# 克隆到 MaiBot 插件目录 -cd /path/to/MaiBot/plugins -git clone https://github.com/CharTyr/MaiBot_MCPBridgePlugin.git MCPBridgePlugin - -# 安装依赖 -pip install mcp - -# 复制配置文件 -cd MCPBridgePlugin -cp config.example.toml config.toml -``` - -### 2. 添加服务器 - -编辑 `config.toml`,在 `[servers]` 的 `claude_config_json` 中填写 Claude Desktop 的 `mcpServers` JSON: - -```toml -[servers] -claude_config_json = ''' -{ - "mcpServers": { - "time": { "transport": "streamable_http", "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time" }, - "my-server": { "transport": "streamable_http", "url": "https://mcp.xxx.com/mcp", "headers": { "Authorization": "Bearer 你的密钥" } }, - "fetch": { "command": "uvx", "args": ["mcp-server-fetch"] } - } -} -''' -``` - -### 3. 启动 - -重启 MaiBot,或发送 `/mcp reconnect` - ---- - -## 📚 去哪找 MCP 服务器? - -| 平台 | 说明 | -|------|------| -| [mcp.modelscope.cn](https://mcp.modelscope.cn/) | 魔搭 ModelScope,免费推荐 | -| [smithery.ai](https://smithery.ai/) | MCP 服务器注册中心 | -| [github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) | 官方服务器列表 | - ---- - -## 💡 常用命令 - -| 命令 | 说明 | -|------|------| -| `/mcp` | 查看连接状态 | -| `/mcp tools` | 查看可用工具 | -| `/mcp reconnect` | 重连服务器 | -| `/mcp trace` | 查看调用记录 | -| `/mcp cache` | 查看缓存状态 | -| `/mcp perm` | 查看权限配置 | -| `/mcp import ` | 🆕 导入 Claude Desktop 配置 | -| `/mcp export` | 🆕 导出配置 | -| `/mcp search <关键词>` | 🆕 搜索工具 | -| `/mcp chain` | 🆕 查看工具链 | -| `/mcp chain <名称>` | 🆕 查看工具链详情 | -| `/mcp chain test <名称> <参数>` | 🆕 测试执行工具链 | - ---- - -## ✨ 功能特性 - -### 核心功能 -- 🔌 多服务器同时连接 -- 📡 支持 stdio / SSE / HTTP / Streamable HTTP -- 🔄 自动重试、心跳检测、断线重连 -- 🖥️ WebUI 完整配置支持 - -### 双轨制架构 -- 🔄 **ReAct(软流程)**:LLM 自主决策,多轮动态调用 MCP 工具(适合探索式场景) -- 🔗 **Workflow(硬流程/工具链)**:用户预定义步骤顺序与参数传递(适合可控可复用场景) - -### 高级功能 -- 📦 Resources 支持(实验性) -- 📝 Prompts 支持(实验性) -- 🔄 结果后处理(LLM 摘要提炼) -- 🔍 调用追踪 / 🗄️ 调用缓存 / 🔐 权限控制 / 🚫 工具禁用 - -### 更新日志 -- 见 `plugins/MaiBot_MCPBridgePlugin/CHANGELOG.md` - ---- - -## ⚙️ 配置说明 - -### 服务器配置 - -```json -{ - "mcpServers": { - "server_name": { - "transport": "streamable_http", - "url": "https://..." - } - } -} -``` - -| 字段 | 说明 | -|------|------| -| `mcpServers.` | 服务器名称(唯一) | -| `enabled` | 是否启用(可选,默认 true) | -| `transport` | `stdio` / `sse` / `http` / `streamable_http` | -| `url` | 远程服务器地址 | -| `headers` | 🆕 鉴权头(如 `{"Authorization": "Bearer xxx"}`) | -| `command` / `args` | 本地服务器启动命令 | - -### 权限控制 - -**快捷配置(推荐):** -```toml -[permissions] -perm_enabled = true -quick_deny_groups = "123456789" # 禁用的群号 -quick_allow_users = "111111111" # 管理员白名单 -``` - -**高级规则:** -```json -[{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}] -``` - -### 工具禁用 - -```toml -[tools] -disabled_tools = ''' -mcp_filesystem_delete_file -mcp_filesystem_write_file -''' -``` - -### 调用缓存 - -```toml -[settings] -cache_enabled = true -cache_ttl = 300 -cache_exclude_tools = "mcp_*_time_*" -``` - ---- - -## ❓ 常见问题 - -**Q: 工具没有注册?** -- 检查 `enabled = true` -- 检查 MaiBot 日志错误信息 -- 确认 `pip install mcp` - -**Q: JSON 格式报错?** -- 多行 JSON 用 `'''` 三引号包裹 -- 使用英文双引号 `"` - -**Q: 如何手动重连?** -- `/mcp reconnect` 或 `/mcp reconnect 服务器名` - ---- - -## 📥 配置导入导出(Claude mcpServers) - -### 从 Claude Desktop 导入 - -如果你已有 Claude Desktop 的 MCP 配置,可以直接导入: - -``` -/mcp import {"mcpServers":{"time":{"command":"uvx","args":["mcp-server-time"]},"fetch":{"command":"uvx","args":["mcp-server-fetch"]}}} -``` - -支持的格式: -- Claude Desktop 格式(`mcpServers` 对象) -- 兼容旧版:MaiBot servers 列表数组(将自动迁移为 `mcpServers`) - -### 导出配置 - -``` -/mcp export # 导出为 Claude Desktop 格式(默认) -/mcp export claude # 导出为 Claude Desktop 格式 -``` - -### 注意事项 -- 导入时会自动跳过同名服务器 -- 导入后需要发送 `/mcp reconnect` 使配置生效 -- 支持 stdio、sse、http、streamable_http 全部传输类型 - ---- - -## 🔗 Workflow(硬流程/工具链) - -工具链允许你将多个 MCP 工具按顺序执行,后续工具可以使用前序工具的输出作为输入。 - -### 1 分钟上手(推荐 WebUI) -1. 先完成 MCP 服务器配置并 `/mcp reconnect` -2. 发送 `/mcp tools`,复制你要用的工具名 -3. 打开 WebUI → 「Workflow(硬流程/工具链)」→ 用“快速添加”表单填入: - - 名称/描述 - - 输入参数(每行 `参数名=描述`) - - 执行步骤(每行 `工具名|参数JSON|输出键`) -4. 在“确认添加”中输入 `ADD` 并保存 - -### 快速添加工具链(推荐) - -在 WebUI 的「工具链」配置区,使用表单快速添加: - -1. **名称**: 填写工具链名称(英文,如 `search_and_detail`) -2. **描述**: 填写工具链用途(供 LLM 理解何时使用) -3. **输入参数**: 每行一个,格式 `参数名=描述` - ``` - query=搜索关键词 - max_results=最大结果数 - ``` -4. **执行步骤**: 每行一个,格式 `工具名|参数JSON|输出键` - ``` - mcp_server_search|{"keyword":"${input.query}"}|search_result - mcp_server_detail|{"id":"${prev}"}| - ``` -5. **确认添加**: 输入 `ADD` 并保存 - -### JSON 配置方式 - -也可以直接在「工具链列表」中编写 JSON: - -```json -[ - { - "name": "search_and_detail", - "description": "先搜索模组,再获取详情", - "input_params": { - "query": "搜索关键词" - }, - "steps": [ - { - "tool_name": "mcp_mcmod_search_mod", - "args_template": {"keyword": "${input.query}", "limit": 1}, - "output_key": "search_result", - "description": "搜索模组" - }, - { - "tool_name": "mcp_mcmod_get_mod_detail", - "args_template": {"mod_id": "${prev}"}, - "description": "获取详情" - } - ] - } -] -``` - -### 变量替换 - -| 变量格式 | 说明 | -|---------|------| -| `${input.参数名}` | 用户输入的参数 | -| `${step.输出键}` | 某个步骤的输出(通过 `output_key` 指定) | -| `${prev}` | 上一步的输出 | -| `${prev.字段}` | 上一步输出(JSON)的某个字段 | -| `${step.geo.return.0.location}` | 数组下标访问(dot) | -| `${step.geo.return[0].location}` | 数组下标访问([]) | -| `${step.geo['return'][0]['location']}` | bracket 写法(最通用) | - -### 工具链字段说明 - -| 字段 | 说明 | -|------|------| -| `name` | 工具链名称,将生成 `chain_xxx` 工具 | -| `description` | 描述,供 LLM 理解何时使用 | -| `input_params` | 输入参数定义 `{参数名: 描述}` | -| `steps` | 执行步骤数组 | -| `steps[].tool_name` | 要调用的工具名 | -| `steps[].args_template` | 参数模板,支持变量替换 | -| `steps[].output_key` | 输出存储键名(可选) | -| `steps[].optional` | 是否可选,失败时继续执行(默认 false) | - -### 命令 - -```bash -/mcp chain # 查看所有工具链 -/mcp chain list # 列出工具链 -/mcp chain <名称> # 查看详情 -/mcp chain test <名称> {"query": "JEI"} # 测试执行 -/mcp chain reload # 重新加载配置 -``` - ---- - -## 🔄 双轨制架构 - -MCP 桥接插件支持两种工具调用模式,可根据场景选择: - -### ReAct 软流程 - -LLM 自主决策的多轮工具调用模式,适合复杂、不确定的场景。 - -**工作原理:** -1. 用户提问 → LLM 分析需要什么信息 -2. LLM 选择调用工具 → 获取结果 -3. LLM 观察结果 → 决定是否需要更多信息 -4. 重复 2-3 直到信息足够 → 生成最终回答 - -**启用方式:** -在 WebUI「ReAct (软流程)」配置区启用,MCP 工具将自动注册到 MaiBot 的记忆检索 ReAct 系统。 - -**适用场景:** -- 复杂问题需要多步推理 -- 不确定需要调用哪些工具 -- 需要根据中间结果动态调整 - -### Workflow 硬流程 - -用户预定义的工作流,固定执行顺序,适合可靠、可控的场景。 - -**工作原理:** -1. 用户定义步骤顺序和参数传递 -2. 按顺序执行每个步骤 -3. 后续步骤可使用前序步骤的输出 -4. 返回最终结果 - -**适用场景:** -- 流程固定、可预测 -- 需要可靠、可重复的执行 -- 希望精确控制工具调用顺序 - -### 对比 - -| 特性 | ReAct 软流程 | Workflow 硬流程 | -|------|-------------|----------------| -| 决策者 | LLM 自主决策 | 用户预定义 | -| 灵活性 | 高,动态调整 | 低,固定流程 | -| 可预测性 | 低 | 高 | -| 适用场景 | 复杂、探索性任务 | 固定、重复性任务 | -| 配置方式 | 启用即可 | 需要定义步骤 | - ---- - -## 📋 依赖 - -- MaiBot >= 0.11.6 -- Python >= 3.10 -- mcp >= 1.0.0 - -## 📄 许可证 - -AGPL-3.0 diff --git a/plugins/MaiBot_MCPBridgePlugin/__init__.py b/plugins/MaiBot_MCPBridgePlugin/__init__.py deleted file mode 100644 index 80e2ae47..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -""" -MCP 桥接插件 -将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot - -v1.1.0 新增功能: -- 心跳检测和自动重连 -- 调用统计(次数、成功率、耗时) -- 更好的错误处理 - -v1.2.0 新增功能: -- Resources 支持(资源读取) -- Prompts 支持(提示模板) -""" - -from .plugin import MCPBridgePlugin, mcp_tool_registry, MCPStartupHandler, MCPStopHandler -from .mcp_client import ( - mcp_manager, - MCPClientManager, - MCPServerConfig, - TransportType, - MCPCallResult, - MCPToolInfo, - MCPResourceInfo, - MCPPromptInfo, - ToolCallStats, - ServerStats, -) - -__all__ = [ - "MCPBridgePlugin", - "mcp_tool_registry", - "mcp_manager", - "MCPClientManager", - "MCPServerConfig", - "TransportType", - "MCPCallResult", - "MCPToolInfo", - "MCPResourceInfo", - "MCPPromptInfo", - "ToolCallStats", - "ServerStats", - "MCPStartupHandler", - "MCPStopHandler", -] diff --git a/plugins/MaiBot_MCPBridgePlugin/_manifest.json b/plugins/MaiBot_MCPBridgePlugin/_manifest.json deleted file mode 100644 index d2e08ab4..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/_manifest.json +++ /dev/null @@ -1,42 +0,0 @@ -{ - "manifest_version": 2, - "version": "2.0.0", - "name": "MCP桥接插件", - "description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。", - "author": { - "name": "CharTyr", - "url": "https://github.com/CharTyr" - }, - "license": "AGPL-3.0", - "urls": { - "repository": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "homepage": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "documentation": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin", - "issues": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin/issues" - }, - "host_application": { - "min_version": "0.11.6", - "max_version": "1.0.0" - }, - "sdk": { - "min_version": "2.0.0", - "max_version": "2.99.99" - }, - "dependencies": [ - { - "type": "python_package", - "name": "mcp", - "version_spec": ">=0.0.0" - } - ], - "capabilities": [ - "send.text" - ], - "i18n": { - "default_locale": "zh-CN", - "supported_locales": [ - "zh-CN" - ] - }, - "id": "chartyr.mcpbridge-plugin" -} diff --git a/plugins/MaiBot_MCPBridgePlugin/config.example.toml b/plugins/MaiBot_MCPBridgePlugin/config.example.toml deleted file mode 100644 index 4edac27a..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/config.example.toml +++ /dev/null @@ -1,309 +0,0 @@ -# MCP桥接插件 - 配置文件示例 -# 将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot -# -# 使用方法:复制此文件为 config.toml,然后根据需要修改配置 -# -# ============================================================ -# 🎯 快速开始(三步) -# ============================================================ -# 1. 在下方 [servers] 添加 MCP 服务器配置 -# 2. 将 enabled 改为 true 启用服务器 -# 3. 重启 MaiBot 或发送 /mcp reconnect -# -# ============================================================ -# 📚 去哪找 MCP 服务器? -# ============================================================ -# -# 【远程服务(推荐新手)】 -# - ModelScope: https://mcp.modelscope.cn/ (免费,推荐) -# - Smithery: https://smithery.ai/ -# - Glama: https://glama.ai/mcp/servers -# -# 【本地服务(需要 npx 或 uvx)】 -# - 官方列表: https://github.com/modelcontextprotocol/servers -# -# ============================================================ - -# ============================================================ -# 🔌 MCP 服务器配置 -# ============================================================ -# -# ⚠️ 重要:配置格式(Claude Desktop 规范) -# ──────────────────────────────────────────────────────────── -# 统一使用 Claude Desktop 的 mcpServers JSON。 -# -# claude_config_json 的内容应为 JSON 对象: -# { -# "mcpServers": { -# "server_name": { ...server config... }, -# "another": { ... } -# } -# } -# -# 每个服务器支持字段: -# transport - 传输方式: "stdio" / "sse" / "http" / "streamable_http"(可选) -# url - 服务器地址(sse/http/streamable_http 模式) -# command - 启动命令(stdio 模式,如 "npx" / "uvx") -# args - 命令参数数组(stdio 模式) -# env - 环境变量对象(stdio 模式,可选) -# headers - 鉴权头(可选,如 {"Authorization": "Bearer xxx"}) -# enabled - 是否启用(可选,默认 true) -# post_process - 服务器级别后处理配置(可选) -# -# ============================================================ - -[servers] -claude_config_json = ''' -{ - "mcpServers": { - "time-mcp-server": { - "enabled": false, - "transport": "streamable_http", - "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time" - }, - "my-auth-server": { - "enabled": false, - "transport": "streamable_http", - "url": "https://mcp.api-inference.modelscope.net/xxxxxx/mcp", - "headers": { - "Authorization": "Bearer ms-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" - } - }, - "fetch-local": { - "enabled": false, - "command": "uvx", - "args": ["mcp-server-fetch"] - } - } -} -''' - -# ============================================================ -# 插件基本信息 -# ============================================================ -[plugin] -name = "mcp_bridge_plugin" -version = "2.0.0" -config_version = "2.0.0" -enabled = false # 默认禁用,在 WebUI 中启用 - -# ============================================================ -# Workflow(硬流程/工具链) -# ============================================================ -# -# 作用:把多个工具按顺序执行;后续步骤可引用前序输出。 -# -# ✅ 推荐配置方式:WebUI「Workflow(硬流程/工具链)」里用“快速添加”表单。 -# ✅ 也可以直接写 chains_list(JSON 数组)。 -# -# 变量替换: -# ${input.xxx} - 用户输入 -# ${step.} - 指定步骤输出(需设置 output_key) -# ${prev} - 上一步输出 -# ${prev.字段} - 上一步输出(JSON)的字段 -# ${step.geo.return.0.location} - 数组/下标访问(dot) -# ${step.geo.return[0].location} - 数组/下标访问([]) -# ${step.geo['return'][0]['location']} - bracket 写法 -# -# ============================================================ - -[tool_chains] -chains_enabled = true - -chains_list = ''' -[ - { - "name": "search_and_detail", - "description": "先搜索,再根据结果获取详情", - "input_params": { "query": "搜索关键词" }, - "steps": [ - { "tool_name": "把这里替换成你的搜索工具名", "args_template": { "keyword": "${input.query}" }, "output_key": "search" }, - { "tool_name": "把这里替换成你的详情工具名", "args_template": { "id": "${prev}" } } - ] - } -] -''' - -# ============================================================ -# ReAct(软流程) -# ============================================================ -# -# 作用:把 MCP 工具注册到 MaiBot 的 ReAct 系统,LLM 可自主多轮调用。 -# -# 注意:ReAct 适合“探索式/不确定”场景;Workflow 适合“固定/可控”场景。 -# -# ============================================================ - -[react] -react_enabled = false -filter_mode = "whitelist" # whitelist / blacklist -tool_filter = "" # 每行一个工具名,支持通配符 * - -# ============================================================ -# 全局设置(高级设置建议保持默认) -# ============================================================ -[settings] -# 🏷️ 工具前缀 - 用于区分 MCP 工具和原生工具 -tool_prefix = "mcp" - -# ⏱️ 连接超时(秒) -connect_timeout = 30.0 - -# ⏱️ 调用超时(秒) -call_timeout = 60.0 - -# 🔄 自动连接 - 启动时自动连接所有已启用的服务器 -auto_connect = true - -# 🔁 重试次数 - 连接失败时的重试次数 -retry_attempts = 3 - -# ⏳ 重试间隔(秒) -retry_interval = 5.0 - -# 💓 心跳检测 - 定期检测服务器连接状态 -heartbeat_enabled = true - -# 💓 心跳间隔(秒)- 建议 30-120 秒 -heartbeat_interval = 60.0 - -# 🔄 自动重连 - 检测到断开时自动尝试重连 -auto_reconnect = true - -# 🔄 最大重连次数 - 连续重连失败后暂停重连 -max_reconnect_attempts = 3 - -# ============================================================ -# 高级功能(实验性) -# ============================================================ -# 📦 启用 Resources - 允许读取 MCP 服务器提供的资源 -enable_resources = false - -# 📝 启用 Prompts - 允许使用 MCP 服务器提供的提示模板 -enable_prompts = false - -# ============================================================ -# 结果后处理功能 -# ============================================================ -# 当 MCP 工具返回的内容过长时,使用 LLM 对结果进行摘要提炼 - -# 🔄 启用结果后处理 -post_process_enabled = false - -# 📏 后处理阈值(字符数)- 结果长度超过此值才触发后处理 -post_process_threshold = 500 - -# 🔢 后处理输出限制 - LLM 摘要输出的最大 token 数 -post_process_max_tokens = 500 - -# 🤖 后处理模型(可选)- 留空则使用 utils 模型组 -post_process_model = "" - -# 🧠 后处理提示词模板 -post_process_prompt = '''用户问题:{query} - -工具返回内容: -{result} - -请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:''' - -# ============================================================ -# 调用链路追踪 -# ============================================================ -# 记录工具调用详情,便于调试和分析 - -# 🔍 启用调用追踪 -trace_enabled = true - -# 📊 追踪记录上限 - 内存中保留的最大记录数 -trace_max_records = 50 - -# 📝 追踪日志文件 - 是否将追踪记录写入日志文件 -# 启用后记录写入 plugins/MaiBot_MCPBridgePlugin/logs/trace.jsonl -trace_log_enabled = false - -# ============================================================ -# 工具调用缓存 -# ============================================================ -# 缓存相同参数的调用结果,减少重复请求 - -# 🗄️ 启用调用缓存 -cache_enabled = false - -# ⏱️ 缓存有效期(秒) -cache_ttl = 300 - -# 📦 最大缓存条目 - 超出后 LRU 淘汰 -cache_max_entries = 200 - -# 🚫 缓存排除列表 - 即不缓存的工具(每行一个,支持通配符 *) -# 时间类、随机类工具建议排除 -cache_exclude_tools = ''' -mcp_*_time_* -mcp_*_random_* -''' - -# ============================================================ -# 工具管理 -# ============================================================ -[tools] -# 📋 工具清单(只读)- 启动后自动生成 -tool_list = "(启动后自动生成)" - -# 🚫 禁用工具列表 - 要禁用的工具名(每行一个) -# 从上方工具清单复制工具名,禁用后该工具不会被 LLM 调用 -# 示例: -# disabled_tools = ''' -# mcp_filesystem_delete_file -# mcp_filesystem_write_file -# ''' -disabled_tools = "" - -# ============================================================ -# 权限控制 -# ============================================================ -[permissions] -# 🔐 启用权限控制 - 按群/用户限制工具使用 -perm_enabled = false - -# 📋 默认模式 -# allow_all: 未配置规则的工具默认允许 -# deny_all: 未配置规则的工具默认禁止 -perm_default_mode = "allow_all" - -# ──────────────────────────────────────────────────────────── -# 🚀 快捷配置(推荐新手使用) -# ──────────────────────────────────────────────────────────── - -# 🚫 禁用群列表 - 这些群无法使用任何 MCP 工具(每行一个群号) -# 示例: -# quick_deny_groups = ''' -# 123456789 -# 987654321 -# ''' -quick_deny_groups = "" - -# ✅ 管理员白名单 - 这些用户始终可以使用所有工具(每行一个QQ号) -# 示例: -# quick_allow_users = ''' -# 111111111 -# ''' -quick_allow_users = "" - -# ──────────────────────────────────────────────────────────── -# 📜 高级权限规则(可选,针对特定工具配置) -# ──────────────────────────────────────────────────────────── -# 格式: qq:ID:group/private/user,工具名支持通配符 * -# 示例: -# perm_rules = ''' -# [ -# {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]} -# ] -# ''' -perm_rules = "[]" - -# ============================================================ -# 状态显示(只读) -# ============================================================ -[status] -connection_status = "未初始化" diff --git a/plugins/MaiBot_MCPBridgePlugin/core/__init__.py b/plugins/MaiBot_MCPBridgePlugin/core/__init__.py deleted file mode 100644 index d5656a8e..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/core/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Core helpers for MCP Bridge Plugin.""" diff --git a/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py b/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py deleted file mode 100644 index f2a6f011..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/core/claude_config.py +++ /dev/null @@ -1,169 +0,0 @@ -import json -from dataclasses import dataclass, field -from typing import Any, Dict, List, Literal, Optional - - -class ClaudeConfigError(ValueError): - pass - - -Transport = Literal["stdio", "sse", "http", "streamable_http"] - - -@dataclass(frozen=True) -class ClaudeMcpServer: - name: str - transport: Transport - command: str = "" - args: List[str] = field(default_factory=list) - env: Dict[str, str] = field(default_factory=dict) - url: str = "" - headers: Dict[str, str] = field(default_factory=dict) - enabled: bool = True - - -def _normalize_transport(value: Optional[str]) -> Transport: - if not value: - return "streamable_http" - v = value.strip().lower().replace("-", "_") - if v in ("streamable_http", "streamablehttp", "streamable"): - return "streamable_http" - if v in ("http",): - return "http" - if v in ("sse",): - return "sse" - if v in ("stdio",): - return "stdio" - raise ClaudeConfigError(f"unsupported transport: {value}") - - -def _coerce_str_list(value: Any, field_name: str) -> List[str]: - if value is None: - return [] - if isinstance(value, list): - return [str(v) for v in value] - raise ClaudeConfigError(f"{field_name} must be a list") - - -def _coerce_str_dict(value: Any, field_name: str) -> Dict[str, str]: - if value is None: - return {} - if isinstance(value, dict): - return {str(k): str(v) for k, v in value.items()} - raise ClaudeConfigError(f"{field_name} must be an object") - - -def parse_claude_mcp_config(config_json: str) -> List[ClaudeMcpServer]: - """Parse Claude Desktop style MCP config JSON. - - Supported: - - Full object: {"mcpServers": {...}} - - Direct mapping: {...} treated as mcpServers - """ - text = (config_json or "").strip() - if not text: - return [] - - try: - data = json.loads(text) - except json.JSONDecodeError as e: - raise ClaudeConfigError(f"invalid JSON: {e}") from e - - if not isinstance(data, dict): - raise ClaudeConfigError("config must be a JSON object") - - servers_obj = data.get("mcpServers", data) - if not isinstance(servers_obj, dict): - raise ClaudeConfigError("mcpServers must be an object") - - servers: List[ClaudeMcpServer] = [] - for name, raw in servers_obj.items(): - if not isinstance(name, str) or not name.strip(): - raise ClaudeConfigError("server name must be a non-empty string") - if not isinstance(raw, dict): - raise ClaudeConfigError(f"server '{name}' must be an object") - - enabled = bool(raw.get("enabled", True)) - command = str(raw.get("command", "") or "") - url = str(raw.get("url", "") or "") - args = _coerce_str_list(raw.get("args"), "args") - env = _coerce_str_dict(raw.get("env"), "env") - headers = _coerce_str_dict(raw.get("headers"), "headers") - - transport_hint = raw.get("transport", raw.get("type")) - - if command: - transport: Transport = "stdio" - elif url: - try: - transport = _normalize_transport(str(transport_hint) if transport_hint is not None else None) - except ClaudeConfigError: - transport = "streamable_http" - else: - raise ClaudeConfigError(f"server '{name}' must have either 'command' or 'url'") - - servers.append( - ClaudeMcpServer( - name=name, - transport=transport, - command=command, - args=args, - env=env, - url=url, - headers=headers, - enabled=enabled, - ) - ) - - return servers - - -def legacy_servers_list_to_claude_config(servers_list_json: str) -> str: - """Convert legacy v1.x servers list (JSON array) to Claude mcpServers JSON. - - Legacy item schema: - {"name","enabled","transport","url","headers","command","args","env"} - """ - text = (servers_list_json or "").strip() - if not text: - return "" - try: - data = json.loads(text) - except json.JSONDecodeError: - return "" - if isinstance(data, dict): - data = [data] - if not isinstance(data, list): - return "" - - mcp_servers: Dict[str, Any] = {} - for item in data: - if not isinstance(item, dict): - continue - name = str(item.get("name", "") or "").strip() - if not name: - continue - enabled = bool(item.get("enabled", True)) - transport = str(item.get("transport", "") or "").strip().lower().replace("-", "_") - - if transport == "stdio" or item.get("command"): - entry: Dict[str, Any] = { - "enabled": enabled, - "command": item.get("command", "") or "", - "args": item.get("args", []) or [], - } - if item.get("env"): - entry["env"] = item.get("env") - mcp_servers[name] = entry - continue - - entry = {"enabled": enabled, "url": item.get("url", "") or ""} - if item.get("headers"): - entry["headers"] = item.get("headers") - if transport: - entry["transport"] = transport - mcp_servers[name] = entry - - if not mcp_servers: - return "" - return json.dumps({"mcpServers": mcp_servers}, ensure_ascii=False, indent=2) diff --git a/plugins/MaiBot_MCPBridgePlugin/mcp_client.py b/plugins/MaiBot_MCPBridgePlugin/mcp_client.py deleted file mode 100644 index de5abab2..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/mcp_client.py +++ /dev/null @@ -1,1485 +0,0 @@ -""" -MCP 客户端封装模块 -负责与 MCP 服务器建立连接、获取工具列表、执行工具调用 - -v1.7.0 稳定性优化: -- 断路器模式:连续失败 5 次后熔断,60 秒后试探恢复 -- 熔断期间快速失败,避免等待超时 -- 连接成功时自动重置断路器 - -v1.5.2 性能优化: -- 智能心跳间隔:根据服务器稳定性动态调整心跳频率 -- 稳定服务器逐渐增加间隔(最高 3x),减少不必要的检测 -- 断开的服务器使用较短间隔快速重连 - -v1.1.0 新增功能: -- 调用统计(次数、成功率、耗时) -- 心跳检测 -- 自动重连 -- 更好的错误处理 - -v1.2.0 新增功能: -- Resources 支持(资源读取) -- Prompts 支持(提示模板) -- 新增配置项: enable_resources, enable_prompts -""" - -import asyncio -import time -import logging -from typing import Any, Dict, List, Optional, Tuple -from dataclasses import dataclass, field -from enum import Enum - -# 尝试导入 MaiBot 的 logger,如果失败则使用标准 logging -try: - from src.common.logger import get_logger - - logger = get_logger("mcp_client") -except ImportError: - # Fallback: 使用标准 logging - logger = logging.getLogger("mcp_client") - if not logger.handlers: - handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter("[%(levelname)s] %(name)s: %(message)s")) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - - -class TransportType(Enum): - """MCP 传输类型""" - - STDIO = "stdio" # 本地进程通信 - SSE = "sse" # Server-Sent Events (旧版 HTTP) - HTTP = "http" # HTTP Streamable (新版,推荐) - STREAMABLE_HTTP = "streamable_http" # HTTP Streamable 的别名 - - -@dataclass -class MCPToolInfo: - """MCP 工具信息""" - - name: str - description: str - input_schema: Dict[str, Any] - server_name: str - - -@dataclass -class MCPResourceInfo: - """MCP 资源信息""" - - uri: str - name: str - description: str - mime_type: Optional[str] - server_name: str - - -@dataclass -class MCPPromptInfo: - """MCP 提示模板信息""" - - name: str - description: str - arguments: List[Dict[str, Any]] # [{name, description, required}] - server_name: str - - -@dataclass -class MCPServerConfig: - """MCP 服务器配置""" - - name: str - enabled: bool = True - transport: TransportType = TransportType.STDIO - # stdio 配置 - command: str = "" - args: List[str] = field(default_factory=list) - env: Dict[str, str] = field(default_factory=dict) - # http/sse 配置 - url: str = "" - headers: Dict[str, str] = field(default_factory=dict) # v1.4.2: 鉴权头支持 - - -@dataclass -class MCPCallResult: - """MCP 工具调用结果""" - - success: bool - content: Any - error: Optional[str] = None - duration_ms: float = 0.0 # 调用耗时(毫秒) - circuit_broken: bool = False # v1.7.0: 是否被断路器拦截 - - -class CircuitState(Enum): - """断路器状态""" - - CLOSED = "closed" # 正常状态,允许请求 - OPEN = "open" # 熔断状态,拒绝请求 - HALF_OPEN = "half_open" # 半开状态,允许少量试探请求 - - -@dataclass -class CircuitBreaker: - """v1.7.0: 断路器 - 防止对故障服务器持续请求 - - 状态转换: - - CLOSED -> OPEN: 连续失败次数达到阈值 - - OPEN -> HALF_OPEN: 熔断时间到期 - - HALF_OPEN -> CLOSED: 试探请求成功 - - HALF_OPEN -> OPEN: 试探请求失败 - """ - - # 配置 - failure_threshold: int = 5 # 连续失败多少次后熔断 - recovery_timeout: float = 60.0 # 熔断后多久尝试恢复(秒) - half_open_max_calls: int = 1 # 半开状态最多允许几次试探调用 - - # 状态 - state: CircuitState = field(default=CircuitState.CLOSED) - failure_count: int = 0 - success_count: int = 0 - last_failure_time: float = 0.0 - last_state_change: float = field(default_factory=time.time) - half_open_calls: int = 0 - - def can_execute(self) -> Tuple[bool, Optional[str]]: - """检查是否允许执行请求 - - Returns: - (是否允许, 拒绝原因) - """ - current_time = time.time() - - if self.state == CircuitState.CLOSED: - return True, None - - if self.state == CircuitState.OPEN: - # 检查是否到了恢复时间 - time_since_failure = current_time - self.last_failure_time - if time_since_failure >= self.recovery_timeout: - # 转换到半开状态 - self._transition_to(CircuitState.HALF_OPEN) - return True, None - else: - remaining = self.recovery_timeout - time_since_failure - return False, f"断路器熔断中,{remaining:.0f}秒后重试" - - if self.state == CircuitState.HALF_OPEN: - # 半开状态,检查是否还有试探配额 - if self.half_open_calls < self.half_open_max_calls: - return True, None - else: - return False, "断路器半开状态,等待试探结果" - - return True, None - - def record_success(self) -> None: - """记录成功调用""" - self.success_count += 1 - - if self.state == CircuitState.HALF_OPEN: - # 半开状态下成功,恢复到关闭状态 - self._transition_to(CircuitState.CLOSED) - logger.info("断路器恢复正常(试探成功)") - elif self.state == CircuitState.CLOSED: - # 正常状态下成功,重置失败计数 - self.failure_count = 0 - - def record_failure(self) -> None: - """记录失败调用""" - self.failure_count += 1 - self.last_failure_time = time.time() - - if self.state == CircuitState.HALF_OPEN: - # 半开状态下失败,重新熔断 - self._transition_to(CircuitState.OPEN) - logger.warning("断路器重新熔断(试探失败)") - elif self.state == CircuitState.CLOSED: - # 检查是否达到熔断阈值 - if self.failure_count >= self.failure_threshold: - self._transition_to(CircuitState.OPEN) - logger.warning(f"断路器熔断(连续失败 {self.failure_count} 次)") - - def _transition_to(self, new_state: CircuitState) -> None: - """状态转换""" - old_state = self.state - self.state = new_state - self.last_state_change = time.time() - - if new_state == CircuitState.CLOSED: - self.failure_count = 0 - self.half_open_calls = 0 - elif new_state == CircuitState.HALF_OPEN: - self.half_open_calls = 0 - - logger.debug(f"断路器状态: {old_state.value} -> {new_state.value}") - - def reset(self) -> None: - """重置断路器""" - self.state = CircuitState.CLOSED - self.failure_count = 0 - self.success_count = 0 - self.half_open_calls = 0 - self.last_state_change = time.time() - - def get_status(self) -> Dict[str, Any]: - """获取断路器状态""" - return { - "state": self.state.value, - "failure_count": self.failure_count, - "success_count": self.success_count, - "failure_threshold": self.failure_threshold, - "recovery_timeout": self.recovery_timeout, - "time_since_last_failure": time.time() - self.last_failure_time if self.last_failure_time > 0 else None, - } - - -@dataclass -class ToolCallStats: - """工具调用统计""" - - tool_key: str - total_calls: int = 0 - success_calls: int = 0 - failed_calls: int = 0 - total_duration_ms: float = 0.0 - last_call_time: Optional[float] = None - last_error: Optional[str] = None - - @property - def success_rate(self) -> float: - """成功率(0-100)""" - if self.total_calls == 0: - return 0.0 - return (self.success_calls / self.total_calls) * 100 - - @property - def avg_duration_ms(self) -> float: - """平均耗时(毫秒)""" - if self.success_calls == 0: - return 0.0 - return self.total_duration_ms / self.success_calls - - def record_call(self, success: bool, duration_ms: float, error: Optional[str] = None) -> None: - """记录一次调用""" - self.total_calls += 1 - self.last_call_time = time.time() - if success: - self.success_calls += 1 - self.total_duration_ms += duration_ms - else: - self.failed_calls += 1 - self.last_error = error - - def to_dict(self) -> Dict[str, Any]: - """转换为字典""" - return { - "tool_key": self.tool_key, - "total_calls": self.total_calls, - "success_calls": self.success_calls, - "failed_calls": self.failed_calls, - "success_rate": round(self.success_rate, 2), - "avg_duration_ms": round(self.avg_duration_ms, 2), - "last_call_time": self.last_call_time, - "last_error": self.last_error, - } - - -@dataclass -class ServerStats: - """服务器统计""" - - server_name: str - connect_count: int = 0 # 连接次数 - disconnect_count: int = 0 # 断开次数 - reconnect_count: int = 0 # 重连次数 - last_connect_time: Optional[float] = None - last_disconnect_time: Optional[float] = None - last_heartbeat_time: Optional[float] = None - consecutive_failures: int = 0 # 连续失败次数 - - def record_connect(self) -> None: - self.connect_count += 1 - self.last_connect_time = time.time() - self.consecutive_failures = 0 - - def record_disconnect(self) -> None: - self.disconnect_count += 1 - self.last_disconnect_time = time.time() - - def record_reconnect(self) -> None: - self.reconnect_count += 1 - self.consecutive_failures = 0 - - def record_failure(self) -> None: - self.consecutive_failures += 1 - - def record_heartbeat(self) -> None: - self.last_heartbeat_time = time.time() - - def to_dict(self) -> Dict[str, Any]: - return { - "server_name": self.server_name, - "connect_count": self.connect_count, - "disconnect_count": self.disconnect_count, - "reconnect_count": self.reconnect_count, - "last_connect_time": self.last_connect_time, - "last_disconnect_time": self.last_disconnect_time, - "last_heartbeat_time": self.last_heartbeat_time, - "consecutive_failures": self.consecutive_failures, - } - - -class MCPClientSession: - """MCP 客户端会话,管理与单个 MCP 服务器的连接""" - - def __init__(self, config: MCPServerConfig, call_timeout: float = 60.0): - self.config = config - self.call_timeout = call_timeout - self._session = None - self._read_stream = None - self._write_stream = None - self._process: Optional[asyncio.subprocess.Process] = None - self._tools: List[MCPToolInfo] = [] - self._resources: List[MCPResourceInfo] = [] # v1.2.0: Resources 支持 - self._prompts: List[MCPPromptInfo] = [] # v1.2.0: Prompts 支持 - self._connected = False - self._lock = asyncio.Lock() - - # 功能支持标记(服务器可能不支持某些功能) - self._supports_resources: bool = False - self._supports_prompts: bool = False - - # 统计信息 - self.stats = ServerStats(server_name=config.name) - self._tool_stats: Dict[str, ToolCallStats] = {} - - # v1.7.0: 断路器 - self._circuit_breaker = CircuitBreaker() - - @property - def is_connected(self) -> bool: - return self._connected - - @property - def tools(self) -> List[MCPToolInfo]: - return self._tools.copy() - - @property - def resources(self) -> List[MCPResourceInfo]: - """v1.2.0: 获取资源列表""" - return self._resources.copy() - - @property - def prompts(self) -> List[MCPPromptInfo]: - """v1.2.0: 获取提示模板列表""" - return self._prompts.copy() - - @property - def supports_resources(self) -> bool: - """v1.2.0: 服务器是否支持 Resources""" - return self._supports_resources - - @property - def supports_prompts(self) -> bool: - """v1.2.0: 服务器是否支持 Prompts""" - return self._supports_prompts - - @property - def server_name(self) -> str: - return self.config.name - - def get_tool_stats(self, tool_name: str) -> Optional[ToolCallStats]: - """获取工具统计""" - return self._tool_stats.get(tool_name) - - def get_circuit_breaker_status(self) -> Dict[str, Any]: - """v1.7.0: 获取断路器状态""" - return self._circuit_breaker.get_status() - - def reset_circuit_breaker(self) -> None: - """v1.7.0: 重置断路器""" - self._circuit_breaker.reset() - logger.info(f"[{self.server_name}] 断路器已重置") - - def get_all_tool_stats(self) -> Dict[str, ToolCallStats]: - """获取所有工具统计""" - return self._tool_stats.copy() - - async def connect(self) -> bool: - """连接到 MCP 服务器""" - async with self._lock: - if self._connected: - return True - - try: - success = False - if self.config.transport == TransportType.STDIO: - success = await self._connect_stdio() - elif self.config.transport == TransportType.SSE: - success = await self._connect_sse() - elif self.config.transport in (TransportType.HTTP, TransportType.STREAMABLE_HTTP): - success = await self._connect_http() - else: - logger.error(f"[{self.server_name}] 不支持的传输类型: {self.config.transport}") - return False - - if success: - self.stats.record_connect() - # v1.7.0: 连接成功时重置断路器 - self._circuit_breaker.reset() - else: - self.stats.record_failure() - return success - - except Exception as e: - logger.error(f"[{self.server_name}] 连接失败: {e}") - self._connected = False - self.stats.record_failure() - return False - - async def _connect_stdio(self) -> bool: - """通过 stdio 连接 MCP 服务器""" - try: - try: - from mcp import ClientSession, StdioServerParameters - from mcp.client.stdio import stdio_client - except ImportError: - logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") - return False - - server_params = StdioServerParameters( - command=self.config.command, args=self.config.args, env=self.config.env if self.config.env else None - ) - - self._stdio_context = stdio_client(server_params) - self._read_stream, self._write_stream = await self._stdio_context.__aenter__() - - self._session_context = ClientSession(self._read_stream, self._write_stream) - self._session = await self._session_context.__aenter__() - - await self._session.initialize() - await self._fetch_tools() - - self._connected = True - logger.info(f"[{self.server_name}] stdio 连接成功,发现 {len(self._tools)} 个工具") - return True - - except Exception as e: - logger.error(f"[{self.server_name}] stdio 连接失败: {e}") - await self._cleanup() - return False - - async def _connect_sse(self) -> bool: - """通过 SSE 连接 MCP 服务器""" - try: - try: - from mcp import ClientSession - from mcp.client.sse import sse_client - except ImportError: - logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") - return False - - if not self.config.url: - logger.error(f"[{self.server_name}] SSE 传输需要配置 url") - return False - - logger.debug(f"[{self.server_name}] 正在连接 SSE MCP 服务器: {self.config.url}") - - # v1.4.2: 支持 headers 鉴权 - sse_kwargs = { - "url": self.config.url, - "timeout": 60.0, - "sse_read_timeout": 300.0, - } - if self.config.headers: - sse_kwargs["headers"] = self.config.headers - - self._sse_context = sse_client(**sse_kwargs) - self._read_stream, self._write_stream = await self._sse_context.__aenter__() - - self._session_context = ClientSession(self._read_stream, self._write_stream) - self._session = await self._session_context.__aenter__() - - await self._session.initialize() - await self._fetch_tools() - - self._connected = True - logger.info(f"[{self.server_name}] SSE 连接成功,发现 {len(self._tools)} 个工具") - return True - - except Exception as e: - logger.error(f"[{self.server_name}] SSE 连接失败: {e}") - import traceback - - logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") - await self._cleanup() - return False - - async def _connect_http(self) -> bool: - """通过 HTTP Streamable 连接 MCP 服务器""" - try: - try: - from mcp import ClientSession - from mcp.client.streamable_http import streamablehttp_client - except ImportError: - logger.error(f"[{self.server_name}] 未安装 mcp 库,请运行: pip install mcp") - return False - - if not self.config.url: - logger.error(f"[{self.server_name}] HTTP 传输需要配置 url") - return False - - logger.debug(f"[{self.server_name}] 正在连接 HTTP MCP 服务器: {self.config.url}") - - # v1.4.2: 支持 headers 鉴权 - http_kwargs = { - "url": self.config.url, - "timeout": 60.0, - "sse_read_timeout": 300.0, - } - if self.config.headers: - http_kwargs["headers"] = self.config.headers - - self._http_context = streamablehttp_client(**http_kwargs) - self._read_stream, self._write_stream, self._get_session_id = await self._http_context.__aenter__() - - self._session_context = ClientSession(self._read_stream, self._write_stream) - self._session = await self._session_context.__aenter__() - - await self._session.initialize() - await self._fetch_tools() - - self._connected = True - logger.info(f"[{self.server_name}] HTTP 连接成功,发现 {len(self._tools)} 个工具") - return True - - except Exception as e: - logger.error(f"[{self.server_name}] HTTP 连接失败: {e}") - import traceback - - logger.debug(f"[{self.server_name}] 详细错误: {traceback.format_exc()}") - await self._cleanup() - return False - - async def _fetch_tools(self) -> None: - """获取 MCP 服务器的工具列表""" - if not self._session: - return - - try: - result = await self._session.list_tools() - self._tools = [] - - for tool in result.tools: - tool_info = MCPToolInfo( - name=tool.name, - description=tool.description or f"MCP tool: {tool.name}", - input_schema=tool.inputSchema if hasattr(tool, "inputSchema") else {}, - server_name=self.server_name, - ) - self._tools.append(tool_info) - # 初始化工具统计 - if tool.name not in self._tool_stats: - self._tool_stats[tool.name] = ToolCallStats(tool_key=tool.name) - logger.debug(f"[{self.server_name}] 发现工具: {tool.name}") - - except Exception as e: - logger.error(f"[{self.server_name}] 获取工具列表失败: {e}") - self._tools = [] - - async def fetch_resources(self) -> bool: - """v1.2.0: 获取 MCP 服务器的资源列表 - - Returns: - bool: 是否成功获取(服务器不支持时返回 False) - """ - if not self._session: - return False - - try: - result = await asyncio.wait_for(self._session.list_resources(), timeout=self.call_timeout) - self._resources = [] - - for resource in result.resources: - resource_info = MCPResourceInfo( - uri=str(resource.uri), - name=resource.name or str(resource.uri), - description=resource.description or "", - mime_type=resource.mimeType if hasattr(resource, "mimeType") else None, - server_name=self.server_name, - ) - self._resources.append(resource_info) - logger.debug(f"[{self.server_name}] 发现资源: {resource_info.uri}") - - self._supports_resources = True - logger.info(f"[{self.server_name}] 获取到 {len(self._resources)} 个资源") - return True - - except Exception as e: - # 服务器可能不支持 resources,这不是错误 - error_str = str(e).lower() - if "not supported" in error_str or "not implemented" in error_str or "method not found" in error_str: - logger.debug(f"[{self.server_name}] 服务器不支持 Resources 功能") - else: - logger.warning(f"[{self.server_name}] 获取资源列表失败: {e}") - self._supports_resources = False - self._resources = [] - return False - - async def fetch_prompts(self) -> bool: - """v1.2.0: 获取 MCP 服务器的提示模板列表 - - Returns: - bool: 是否成功获取(服务器不支持时返回 False) - """ - if not self._session: - return False - - try: - result = await asyncio.wait_for(self._session.list_prompts(), timeout=self.call_timeout) - self._prompts = [] - - for prompt in result.prompts: - # 解析参数 - arguments = [] - if hasattr(prompt, "arguments") and prompt.arguments: - for arg in prompt.arguments: - arguments.append( - { - "name": arg.name, - "description": arg.description or "", - "required": arg.required if hasattr(arg, "required") else False, - } - ) - - prompt_info = MCPPromptInfo( - name=prompt.name, - description=prompt.description or f"MCP prompt: {prompt.name}", - arguments=arguments, - server_name=self.server_name, - ) - self._prompts.append(prompt_info) - logger.debug(f"[{self.server_name}] 发现提示模板: {prompt.name}") - - self._supports_prompts = True - logger.info(f"[{self.server_name}] 获取到 {len(self._prompts)} 个提示模板") - return True - - except Exception as e: - # 服务器可能不支持 prompts,这不是错误 - error_str = str(e).lower() - if "not supported" in error_str or "not implemented" in error_str or "method not found" in error_str: - logger.debug(f"[{self.server_name}] 服务器不支持 Prompts 功能") - else: - logger.warning(f"[{self.server_name}] 获取提示模板列表失败: {e}") - self._supports_prompts = False - self._prompts = [] - return False - - async def read_resource(self, uri: str) -> MCPCallResult: - """v1.2.0: 读取指定资源的内容 - - Args: - uri: 资源 URI - - Returns: - MCPCallResult: 包含资源内容的结果 - """ - start_time = time.time() - - if not self._connected or not self._session: - return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接") - - if not self._supports_resources: - return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Resources 功能") - - try: - result = await asyncio.wait_for(self._session.read_resource(uri), timeout=self.call_timeout) - - duration_ms = (time.time() - start_time) * 1000 - - # 处理返回内容 - content_parts = [] - for content in result.contents: - if hasattr(content, "text"): - content_parts.append(content.text) - elif hasattr(content, "blob"): - # 二进制数据,返回 base64 或提示 - import base64 - - blob_data = content.blob - if len(blob_data) < 10000: # 小于 10KB 返回 base64 - content_parts.append(f"[base64]{base64.b64encode(blob_data).decode()}") - else: - content_parts.append(f"[二进制数据: {len(blob_data)} bytes]") - else: - content_parts.append(str(content)) - - return MCPCallResult( - success=True, content="\n".join(content_parts) if content_parts else "", duration_ms=duration_ms - ) - - except asyncio.TimeoutError: - duration_ms = (time.time() - start_time) * 1000 - return MCPCallResult( - success=False, content=None, error=f"读取资源超时({self.call_timeout}秒)", duration_ms=duration_ms - ) - except Exception as e: - duration_ms = (time.time() - start_time) * 1000 - logger.error(f"[{self.server_name}] 读取资源 {uri} 失败: {e}") - return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms) - - async def get_prompt(self, name: str, arguments: Optional[Dict[str, str]] = None) -> MCPCallResult: - """v1.2.0: 获取提示模板的内容 - - Args: - name: 提示模板名称 - arguments: 模板参数 - - Returns: - MCPCallResult: 包含提示内容的结果 - """ - start_time = time.time() - - if not self._connected or not self._session: - return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 未连接") - - if not self._supports_prompts: - return MCPCallResult(success=False, content=None, error=f"服务器 {self.server_name} 不支持 Prompts 功能") - - try: - result = await asyncio.wait_for( - self._session.get_prompt(name, arguments=arguments or {}), timeout=self.call_timeout - ) - - duration_ms = (time.time() - start_time) * 1000 - - # 处理返回的消息 - messages = [] - for msg in result.messages: - role = msg.role if hasattr(msg, "role") else "unknown" - content_text = "" - if hasattr(msg, "content"): - if hasattr(msg.content, "text"): - content_text = msg.content.text - elif isinstance(msg.content, str): - content_text = msg.content - else: - content_text = str(msg.content) - messages.append(f"[{role}]: {content_text}") - - return MCPCallResult( - success=True, content="\n\n".join(messages) if messages else "", duration_ms=duration_ms - ) - - except asyncio.TimeoutError: - duration_ms = (time.time() - start_time) * 1000 - return MCPCallResult( - success=False, content=None, error=f"获取提示模板超时({self.call_timeout}秒)", duration_ms=duration_ms - ) - except Exception as e: - duration_ms = (time.time() - start_time) * 1000 - logger.error(f"[{self.server_name}] 获取提示模板 {name} 失败: {e}") - return MCPCallResult(success=False, content=None, error=str(e), duration_ms=duration_ms) - - async def check_health(self) -> bool: - """检查连接健康状态(心跳检测) - - 通过调用 list_tools 来验证连接是否正常 - """ - if not self._connected or not self._session: - return False - - try: - # 使用 list_tools 作为心跳检测 - await asyncio.wait_for(self._session.list_tools(), timeout=10.0) - self.stats.record_heartbeat() - return True - except Exception as e: - logger.warning(f"[{self.server_name}] 心跳检测失败: {e}") - # 标记为断开 - self._connected = False - self.stats.record_disconnect() - return False - - async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> MCPCallResult: - """调用 MCP 工具""" - start_time = time.time() - - # v1.7.0: 断路器检查 - can_execute, reject_reason = self._circuit_breaker.can_execute() - if not can_execute: - return MCPCallResult(success=False, content=None, error=f"⚡ {reject_reason}", circuit_broken=True) - - # 半开状态下增加试探计数 - if self._circuit_breaker.state == CircuitState.HALF_OPEN: - self._circuit_breaker.half_open_calls += 1 - - if not self._connected or not self._session: - error_msg = f"服务器 {self.server_name} 未连接" - # 记录失败 - if tool_name in self._tool_stats: - self._tool_stats[tool_name].record_call(False, 0, error_msg) - self._circuit_breaker.record_failure() - return MCPCallResult(success=False, content=None, error=error_msg) - - try: - result = await asyncio.wait_for( - self._session.call_tool(tool_name, arguments=arguments), timeout=self.call_timeout - ) - - duration_ms = (time.time() - start_time) * 1000 - - # 处理返回内容 - content_parts = [] - for content in result.content: - if hasattr(content, "text"): - content_parts.append(content.text) - elif hasattr(content, "data"): - content_parts.append(f"[二进制数据: {len(content.data)} bytes]") - else: - content_parts.append(str(content)) - - # 记录成功 - if tool_name in self._tool_stats: - self._tool_stats[tool_name].record_call(True, duration_ms) - - # v1.7.0: 断路器记录成功 - self._circuit_breaker.record_success() - - return MCPCallResult( - success=True, - content="\n".join(content_parts) if content_parts else "执行成功(无返回内容)", - duration_ms=duration_ms, - ) - - except asyncio.TimeoutError: - duration_ms = (time.time() - start_time) * 1000 - error_msg = f"工具调用超时({self.call_timeout}秒)" - if tool_name in self._tool_stats: - self._tool_stats[tool_name].record_call(False, duration_ms, error_msg) - # v1.7.0: 断路器记录失败 - self._circuit_breaker.record_failure() - return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms) - - except Exception as e: - duration_ms = (time.time() - start_time) * 1000 - error_msg = str(e) - logger.error(f"[{self.server_name}] 调用工具 {tool_name} 失败: {e}") - if tool_name in self._tool_stats: - self._tool_stats[tool_name].record_call(False, duration_ms, error_msg) - # v1.7.0: 断路器记录失败 - self._circuit_breaker.record_failure() - # 检查是否是连接问题 - if "connection" in error_msg.lower() or "closed" in error_msg.lower(): - self._connected = False - self.stats.record_disconnect() - return MCPCallResult(success=False, content=None, error=error_msg, duration_ms=duration_ms) - - async def disconnect(self) -> None: - """断开连接""" - async with self._lock: - if self._connected: - self.stats.record_disconnect() - await self._cleanup() - - async def _cleanup(self) -> None: - """清理资源""" - self._connected = False - self._tools = [] - self._resources = [] # v1.2.0 - self._prompts = [] # v1.2.0 - self._supports_resources = False # v1.2.0 - self._supports_prompts = False # v1.2.0 - - try: - if hasattr(self, "_session_context") and self._session_context: - await self._session_context.__aexit__(None, None, None) - except Exception as e: - logger.debug(f"[{self.server_name}] 关闭会话时出错: {e}") - - try: - if hasattr(self, "_stdio_context") and self._stdio_context: - await self._stdio_context.__aexit__(None, None, None) - except Exception as e: - logger.debug(f"[{self.server_name}] 关闭 stdio 连接时出错: {e}") - - try: - if hasattr(self, "_http_context") and self._http_context: - await self._http_context.__aexit__(None, None, None) - except Exception as e: - logger.debug(f"[{self.server_name}] 关闭 HTTP 连接时出错: {e}") - - try: - if hasattr(self, "_sse_context") and self._sse_context: - await self._sse_context.__aexit__(None, None, None) - except Exception as e: - logger.debug(f"[{self.server_name}] 关闭 SSE 连接时出错: {e}") - - self._session = None - self._session_context = None - self._stdio_context = None - self._http_context = None - self._sse_context = None - self._read_stream = None - self._write_stream = None - - logger.debug(f"[{self.server_name}] 连接已关闭") - - -class MCPClientManager: - """MCP 客户端管理器,管理多个 MCP 服务器连接 - - 功能: - - 管理多个 MCP 服务器连接 - - 心跳检测和自动重连 - - 调用统计 - """ - - _instance: Optional["MCPClientManager"] = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self): - if self._initialized: - return - self._initialized = True - self._clients: Dict[str, MCPClientSession] = {} - self._all_tools: Dict[str, Tuple[MCPToolInfo, MCPClientSession]] = {} - self._all_resources: Dict[str, Tuple[MCPResourceInfo, MCPClientSession]] = {} # v1.2.0 - self._all_prompts: Dict[str, Tuple[MCPPromptInfo, MCPClientSession]] = {} # v1.2.0 - self._settings: Dict[str, Any] = {} - self._lock = asyncio.Lock() - - # 心跳检测任务 - self._heartbeat_task: Optional[asyncio.Task] = None - self._heartbeat_running = False - - # 状态变化回调 - self._on_status_change: Optional[callable] = None - - # 全局统计 - self._global_stats = { - "total_tool_calls": 0, - "successful_calls": 0, - "failed_calls": 0, - "start_time": time.time(), - } - - def configure(self, settings: Dict[str, Any]) -> None: - """配置管理器""" - self._settings = settings - - def set_status_change_callback(self, callback: callable) -> None: - """设置状态变化回调函数""" - self._on_status_change = callback - - def _notify_status_change(self) -> None: - """通知状态变化""" - if self._on_status_change: - try: - self._on_status_change() - except Exception as e: - logger.debug(f"状态变化回调出错: {e}") - - @property - def all_tools(self) -> Dict[str, Tuple[MCPToolInfo, MCPClientSession]]: - """获取所有已注册的工具""" - return self._all_tools.copy() - - @property - def all_resources(self) -> Dict[str, Tuple[MCPResourceInfo, MCPClientSession]]: - """v1.2.0: 获取所有已注册的资源""" - return self._all_resources.copy() - - @property - def all_prompts(self) -> Dict[str, Tuple[MCPPromptInfo, MCPClientSession]]: - """v1.2.0: 获取所有已注册的提示模板""" - return self._all_prompts.copy() - - @property - def connected_servers(self) -> List[str]: - """获取已连接的服务器列表""" - return [name for name, client in self._clients.items() if client.is_connected] - - @property - def disconnected_servers(self) -> List[str]: - """获取已断开的服务器列表""" - return [name for name, client in self._clients.items() if not client.is_connected and client.config.enabled] - - async def add_server(self, config: MCPServerConfig) -> bool: - """添加并连接 MCP 服务器""" - async with self._lock: - if config.name in self._clients: - logger.warning(f"服务器 {config.name} 已存在") - return False - - call_timeout = self._settings.get("call_timeout", 60.0) - client = MCPClientSession(config, call_timeout) - self._clients[config.name] = client - - if not config.enabled: - logger.info(f"服务器 {config.name} 已添加但未启用") - return True - - # 尝试连接 - retry_attempts = self._settings.get("retry_attempts", 3) - retry_interval = self._settings.get("retry_interval", 5.0) - - for attempt in range(1, retry_attempts + 1): - if await client.connect(): - self._register_tools(client) - return True - - if attempt < retry_attempts: - logger.warning( - f"服务器 {config.name} 连接失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})" - ) - await asyncio.sleep(retry_interval) - - logger.error(f"服务器 {config.name} 连接失败,已达最大重试次数 ({retry_attempts})") - # 连接失败,但保留在 _clients 中以便后续重连 - return False - - def _register_tools(self, client: MCPClientSession) -> None: - """注册客户端的工具""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - - for tool in client.tools: - if tool.name.startswith(f"{tool_prefix}_{client.server_name}_"): - tool_key = tool.name - else: - tool_key = f"{tool_prefix}_{client.server_name}_{tool.name}" - self._all_tools[tool_key] = (tool, client) - logger.debug(f"注册 MCP 工具: {tool_key}") - - def _unregister_tools(self, server_name: str) -> List[str]: - """注销服务器的工具,返回被注销的工具键列表""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - prefix = f"{tool_prefix}_{server_name}_" - - keys_to_remove = [k for k in self._all_tools.keys() if k.startswith(prefix)] - for key in keys_to_remove: - del self._all_tools[key] - logger.debug(f"注销 MCP 工具: {key}") - return keys_to_remove - - def _register_resources(self, client: MCPClientSession) -> None: - """v1.2.0: 注册客户端的资源""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - - for resource in client.resources: - # 资源键格式: mcp_{server}_{uri_safe_name} - # 将 URI 转换为安全的键名 - safe_uri = resource.uri.replace("://", "_").replace("/", "_").replace(".", "_") - resource_key = f"{tool_prefix}_{client.server_name}_res_{safe_uri}" - self._all_resources[resource_key] = (resource, client) - logger.debug(f"注册 MCP 资源: {resource_key}") - - def _unregister_resources(self, server_name: str) -> List[str]: - """v1.2.0: 注销服务器的资源""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - prefix = f"{tool_prefix}_{server_name}_res_" - - keys_to_remove = [k for k in self._all_resources.keys() if k.startswith(prefix)] - for key in keys_to_remove: - del self._all_resources[key] - logger.debug(f"注销 MCP 资源: {key}") - return keys_to_remove - - def _register_prompts(self, client: MCPClientSession) -> None: - """v1.2.0: 注册客户端的提示模板""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - - for prompt in client.prompts: - prompt_key = f"{tool_prefix}_{client.server_name}_prompt_{prompt.name}" - self._all_prompts[prompt_key] = (prompt, client) - logger.debug(f"注册 MCP 提示模板: {prompt_key}") - - def _unregister_prompts(self, server_name: str) -> List[str]: - """v1.2.0: 注销服务器的提示模板""" - tool_prefix = self._settings.get("tool_prefix", "mcp") - prefix = f"{tool_prefix}_{server_name}_prompt_" - - keys_to_remove = [k for k in self._all_prompts.keys() if k.startswith(prefix)] - for key in keys_to_remove: - del self._all_prompts[key] - logger.debug(f"注销 MCP 提示模板: {key}") - return keys_to_remove - - async def remove_server(self, server_name: str) -> bool: - """移除 MCP 服务器""" - async with self._lock: - if server_name not in self._clients: - return False - - client = self._clients[server_name] - await client.disconnect() - self._unregister_tools(server_name) - self._unregister_resources(server_name) # v1.2.0 - self._unregister_prompts(server_name) # v1.2.0 - del self._clients[server_name] - - logger.info(f"服务器 {server_name} 已移除") - return True - - async def reconnect_server(self, server_name: str) -> bool: - """重新连接服务器""" - if server_name not in self._clients: - return False - - client = self._clients[server_name] - - async with self._lock: - self._unregister_tools(server_name) - self._unregister_resources(server_name) # v1.2.0 - self._unregister_prompts(server_name) # v1.2.0 - await client.disconnect() - - # 尝试重连 - retry_attempts = self._settings.get("retry_attempts", 3) - retry_interval = self._settings.get("retry_interval", 5.0) - - for attempt in range(1, retry_attempts + 1): - if await client.connect(): - async with self._lock: - self._register_tools(client) - # v1.2.0: 重连后也尝试获取 resources 和 prompts - if self._settings.get("enable_resources", False): - await client.fetch_resources() - self._register_resources(client) - if self._settings.get("enable_prompts", False): - await client.fetch_prompts() - self._register_prompts(client) - client.stats.record_reconnect() - logger.info(f"服务器 {server_name} 重连成功") - return True - - if attempt < retry_attempts: - logger.warning(f"服务器 {server_name} 重连失败,{retry_interval}秒后重试 ({attempt}/{retry_attempts})") - await asyncio.sleep(retry_interval) - - logger.error(f"服务器 {server_name} 重连失败") - return False - - async def call_tool(self, tool_key: str, arguments: Dict[str, Any]) -> MCPCallResult: - """调用 MCP 工具""" - if tool_key not in self._all_tools: - return MCPCallResult(success=False, content=None, error=f"工具 {tool_key} 不存在") - - tool_info, client = self._all_tools[tool_key] - - # 更新全局统计 - self._global_stats["total_tool_calls"] += 1 - - result = await client.call_tool(tool_info.name, arguments) - - if result.success: - self._global_stats["successful_calls"] += 1 - else: - self._global_stats["failed_calls"] += 1 - - return result - - async def fetch_resources_for_server(self, server_name: str) -> bool: - """v1.2.0: 获取指定服务器的资源列表""" - if server_name not in self._clients: - return False - - client = self._clients[server_name] - if not client.is_connected: - return False - - success = await client.fetch_resources() - if success: - async with self._lock: - self._register_resources(client) - return success - - async def fetch_prompts_for_server(self, server_name: str) -> bool: - """v1.2.0: 获取指定服务器的提示模板列表""" - if server_name not in self._clients: - return False - - client = self._clients[server_name] - if not client.is_connected: - return False - - success = await client.fetch_prompts() - if success: - async with self._lock: - self._register_prompts(client) - return success - - async def read_resource(self, uri: str, server_name: Optional[str] = None) -> MCPCallResult: - """v1.2.0: 读取资源内容 - - Args: - uri: 资源 URI - server_name: 指定服务器名称(可选,不指定则自动查找) - """ - # 如果指定了服务器 - if server_name: - if server_name not in self._clients: - return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在") - client = self._clients[server_name] - return await client.read_resource(uri) - - # 自动查找拥有该资源的服务器 - for _resource_key, (resource_info, client) in self._all_resources.items(): - if resource_info.uri == uri: - return await client.read_resource(uri) - - # 尝试在所有支持 resources 的服务器上查找 - for client in self._clients.values(): - if client.is_connected and client.supports_resources: - result = await client.read_resource(uri) - if result.success: - return result - - return MCPCallResult(success=False, content=None, error=f"未找到资源: {uri}") - - async def get_prompt( - self, name: str, arguments: Optional[Dict[str, str]] = None, server_name: Optional[str] = None - ) -> MCPCallResult: - """v1.2.0: 获取提示模板内容 - - Args: - name: 提示模板名称 - arguments: 模板参数 - server_name: 指定服务器名称(可选) - """ - # 如果指定了服务器 - if server_name: - if server_name not in self._clients: - return MCPCallResult(success=False, content=None, error=f"服务器 {server_name} 不存在") - client = self._clients[server_name] - return await client.get_prompt(name, arguments) - - # 自动查找拥有该提示模板的服务器 - for _prompt_key, (prompt_info, client) in self._all_prompts.items(): - if prompt_info.name == name: - return await client.get_prompt(name, arguments) - - return MCPCallResult(success=False, content=None, error=f"未找到提示模板: {name}") - - # ==================== 心跳检测 ==================== - - async def start_heartbeat(self) -> None: - """启动心跳检测任务""" - if self._heartbeat_running: - logger.warning("心跳检测任务已在运行") - return - - heartbeat_enabled = self._settings.get("heartbeat_enabled", True) - if not heartbeat_enabled: - logger.info("心跳检测已禁用") - return - - self._heartbeat_running = True - self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) - logger.info("心跳检测任务已启动") - - async def stop_heartbeat(self) -> None: - """停止心跳检测任务""" - self._heartbeat_running = False - if self._heartbeat_task: - self._heartbeat_task.cancel() - try: - await self._heartbeat_task - except asyncio.CancelledError: - pass - self._heartbeat_task = None - logger.info("心跳检测任务已停止") - - async def _heartbeat_loop(self) -> None: - """心跳检测循环(v1.5.2: 智能心跳间隔)""" - base_interval = self._settings.get("heartbeat_interval", 60.0) - auto_reconnect = self._settings.get("auto_reconnect", True) - max_reconnect_attempts = self._settings.get("max_reconnect_attempts", 3) - - # v1.5.2: 智能心跳配置 - adaptive_enabled = self._settings.get("heartbeat_adaptive", True) - max_multiplier = self._settings.get("heartbeat_max_multiplier", 3.0) - - # 每个服务器独立的心跳间隔(根据稳定性动态调整) - server_intervals: Dict[str, float] = {} - min_interval = max(base_interval * 0.5, 30.0) # 最小间隔 - max_interval = base_interval * max_multiplier # 最大间隔 - - mode_str = "智能" if adaptive_enabled else "固定" - logger.info(f"心跳检测循环启动,{mode_str}模式,基准间隔: {base_interval}秒") - - while self._heartbeat_running: - try: - # 使用最小的服务器间隔作为循环间隔 - current_interval = min(server_intervals.values()) if server_intervals else base_interval - current_interval = max(current_interval, min_interval) - - await asyncio.sleep(current_interval) - - if not self._heartbeat_running: - break - - current_time = time.time() - - # 检查所有已启用的服务器 - for server_name, client in list(self._clients.items()): - if not client.config.enabled: - continue - - # 初始化服务器间隔 - if server_name not in server_intervals: - server_intervals[server_name] = base_interval - - # 检查是否到达该服务器的心跳时间 - last_heartbeat = client.stats.last_heartbeat_time or 0 - if current_time - last_heartbeat < server_intervals[server_name] * 0.9: - continue # 还没到心跳时间 - - if client.is_connected: - # 检查健康状态 - healthy = await client.check_health() - if healthy: - # v1.5.2: 智能心跳 - 稳定服务器逐渐增加间隔 - if adaptive_enabled and client.stats.consecutive_failures == 0: - new_interval = min(server_intervals[server_name] * 1.2, max_interval) - if new_interval != server_intervals[server_name]: - server_intervals[server_name] = new_interval - logger.debug(f"[{server_name}] 稳定,心跳间隔调整为 {new_interval:.0f}s") - else: - logger.warning(f"[{server_name}] 心跳检测失败,连接可能已断开") - # 失败后重置为基准间隔 - if adaptive_enabled: - server_intervals[server_name] = base_interval - self._notify_status_change() - if auto_reconnect: - await self._try_reconnect(server_name, max_reconnect_attempts) - else: - # 服务器未连接,尝试重连 - if adaptive_enabled: - # 智能心跳:断开的服务器使用较短间隔 - server_intervals[server_name] = min_interval - if auto_reconnect and client.stats.consecutive_failures < max_reconnect_attempts: - logger.info(f"[{server_name}] 检测到断开,尝试重连...") - await self._try_reconnect(server_name, max_reconnect_attempts) - elif client.stats.consecutive_failures >= max_reconnect_attempts: - if adaptive_enabled: - # 达到最大重连次数,降低检测频率 - server_intervals[server_name] = max_interval - logger.debug(f"[{server_name}] 已达最大重连次数,降低检测频率") - - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"心跳检测循环出错: {e}") - await asyncio.sleep(5) - - async def _try_reconnect(self, server_name: str, max_attempts: int) -> bool: - """尝试重连服务器""" - client = self._clients.get(server_name) - if not client: - return False - - if client.stats.consecutive_failures >= max_attempts: - logger.warning(f"[{server_name}] 连续失败次数已达上限 ({max_attempts}),暂停重连") - return False - - logger.info(f"[{server_name}] 尝试重连 (失败次数: {client.stats.consecutive_failures}/{max_attempts})") - - success = await self.reconnect_server(server_name) - if not success: - client.stats.record_failure() - - self._notify_status_change() # 重连后更新状态 - return success - - # ==================== 统计和状态 ==================== - - def get_tool_stats(self, tool_key: str) -> Optional[Dict[str, Any]]: - """获取指定工具的统计信息""" - if tool_key not in self._all_tools: - return None - - tool_info, client = self._all_tools[tool_key] - stats = client.get_tool_stats(tool_info.name) - return stats.to_dict() if stats else None - - def get_all_stats(self) -> Dict[str, Any]: - """获取所有统计信息""" - server_stats = {} - tool_stats = {} - - for server_name, client in self._clients.items(): - server_stats[server_name] = client.stats.to_dict() - for tool_name, stats in client.get_all_tool_stats().items(): - full_key = f"{self._settings.get('tool_prefix', 'mcp')}_{server_name}_{tool_name}" - tool_stats[full_key] = stats.to_dict() - - uptime = time.time() - self._global_stats["start_time"] - - return { - "global": { - **self._global_stats, - "uptime_seconds": round(uptime, 2), - "calls_per_minute": round(self._global_stats["total_tool_calls"] / (uptime / 60), 2) - if uptime > 0 - else 0, - }, - "servers": server_stats, - "tools": tool_stats, - } - - async def shutdown(self) -> None: - """关闭所有连接""" - # 停止心跳检测 - await self.stop_heartbeat() - - async with self._lock: - for client in self._clients.values(): - await client.disconnect() - self._clients.clear() - self._all_tools.clear() - self._all_resources.clear() # v1.2.0 - self._all_prompts.clear() # v1.2.0 - logger.info("MCP 客户端管理器已关闭") - - def get_status(self) -> Dict[str, Any]: - """获取状态信息""" - return { - "total_servers": len(self._clients), - "connected_servers": len(self.connected_servers), - "disconnected_servers": len(self.disconnected_servers), - "total_tools": len(self._all_tools), - "total_resources": len(self._all_resources), # v1.2.0 - "total_prompts": len(self._all_prompts), # v1.2.0 - "heartbeat_running": self._heartbeat_running, - "servers": { - name: { - "connected": client.is_connected, - "enabled": client.config.enabled, - "tools_count": len(client.tools), - "resources_count": len(client.resources), # v1.2.0 - "prompts_count": len(client.prompts), # v1.2.0 - "supports_resources": client.supports_resources, # v1.2.0 - "supports_prompts": client.supports_prompts, # v1.2.0 - "transport": client.config.transport.value, - "consecutive_failures": client.stats.consecutive_failures, - "circuit_breaker": client.get_circuit_breaker_status(), # v1.7.0 - } - for name, client in self._clients.items() - }, - "global_stats": self._global_stats, - } - - -# 全局单例 -mcp_manager = MCPClientManager() diff --git a/plugins/MaiBot_MCPBridgePlugin/plugin.py b/plugins/MaiBot_MCPBridgePlugin/plugin.py deleted file mode 100644 index 1d965e25..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/plugin.py +++ /dev/null @@ -1,3733 +0,0 @@ -""" -MCP 桥接插件 v2.0.0 -将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot - -v2.0.0 配置与架构精简(功能保持不变): -- MCP 服务器配置统一为 Claude Desktop 的 mcpServers JSON(WebUI / config.toml 同一入口) -- 兼容迁移:检测到旧版 servers.list 时自动迁移为 mcpServers(仅迁移,避免多入口混淆) -- 移除 WebUI 导入导出/快速添加服务器的旧实现(避免 tomlkit 依赖与格式混乱) - -v1.9.0 双轨制架构: -- 软流程 (ReAct): LLM 自主决策,动态多轮调用 MCP 工具,灵活应对复杂场景 -- 硬流程 (Workflow): 用户预定义的工作流,固定执行顺序,可靠可控 -- 工具链重命名为 Workflow,更清晰地表达其"预定义流程"的本质 -- 命令更新:/mcp workflow 替代 /mcp chain - -v1.8.1 工具链易用性优化: -- 快速添加工具链:WebUI 表单式配置,无需手写 JSON -- 工具链模板:提供常用工具链配置模板参考 -- 使用指南:内置变量语法和命令说明 -- 状态显示优化:详细展示工具链步骤和参数信息 - -v1.8.0 工具链支持: -- 工具链:将多个工具按顺序执行,后续工具可使用前序工具的输出 -- 自定义工具链:在 WebUI 配置工具链,自动注册为组合工具供 LLM 调用 -- 变量替换:支持 ${input.参数}、${step.输出键}、${prev} 变量 -- 工具链命令:/mcp chain 查看、测试、管理工具链 - -v1.7.0 稳定性与易用性优化: -- 断路器模式:故障服务器快速失败,避免拖慢整体响应 -- 状态实时刷新:WebUI 每 10 秒自动更新连接状态 -- 断路器状态显示:在状态面板显示熔断/试探状态 - -v1.6.0 配置导入导出: -- 新增 /mcp import 命令,支持从 Claude Desktop 格式导入配置 -- 新增 /mcp export 命令,导出为 Claude Desktop (mcpServers) 格式 -- 支持 stdio、sse、http、streamable_http 全部传输类型 -- 自动跳过同名服务器,防止重复导入 - -v1.5.4 易用性优化: -- 新增 MCP 服务器获取快捷入口(魔搭、Smithery、Glama 等) -- 优化快速入门指南,提供配置示例 -- 帮助新用户快速上手 MCP - -v1.5.3 配置优化: -- 新增智能心跳 WebUI 配置项:启用开关、最大间隔倍数 -- 支持在 WebUI 中开启/关闭智能心跳功能 - -v1.5.2 性能优化: -- 智能心跳间隔:根据服务器稳定性动态调整心跳频率 -- 稳定服务器逐渐增加间隔,减少不必要的网络请求 -- 断开的服务器使用较短间隔快速重连 - -v1.5.1 易用性优化(v2.0.0 起已移除): -- 「快速添加服务器」表单式配置(已统一为 Claude mcpServers JSON,避免多入口混淆) - -v1.5.0 性能优化: -- 服务器并行连接:多个服务器同时连接,大幅减少启动时间 -- 连接耗时统计:日志显示并行连接总耗时 - -v1.4.4 修复: -- 修复首次生成默认配置文件时多行字符串导致 TOML 解析失败的问题 -- 简化 config_schema 默认值,避免主程序 json.dumps 产生无效 TOML - -v1.4.3 修复: -- 修复 WebUI 保存配置后多行字符串格式错误导致配置文件无法读取的问题 -- 清理未使用的导入 - -v1.4.0 新增功能: -- 工具禁用管理 -- 调用链路追踪 -- 工具调用缓存 -- 工具权限控制 -""" - -import asyncio -import fnmatch -import hashlib -import json -import re -import time -import uuid -from collections import OrderedDict, deque -from dataclasses import asdict, dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Type - -from src.common.logger import get_logger -from src.plugin_system import ( - BasePlugin, - register_plugin, - BaseTool, - BaseCommand, - ComponentInfo, - ConfigField, - ToolParamType, -) -from src.plugin_system.base.config_types import section_meta -from src.plugin_system.base.component_types import ToolInfo, ComponentType, EventType -from src.plugin_system.base.base_events_handler import BaseEventHandler - -from .mcp_client import ( - MCPServerConfig, - MCPToolInfo, - MCPResourceInfo, - MCPPromptInfo, - TransportType, - mcp_manager, -) -from .core.claude_config import ( - ClaudeConfigError, - legacy_servers_list_to_claude_config, - parse_claude_mcp_config, -) -from .tool_chain import ( - ToolChainDefinition, - tool_chain_manager, -) - -logger = get_logger("mcp_bridge_plugin") - - -# ============================================================================ -# v1.4.0: 调用链路追踪 -# ============================================================================ - - -@dataclass -class ToolCallRecord: - """工具调用记录""" - - call_id: str - timestamp: float - tool_name: str - server_name: str - chat_id: str = "" - user_id: str = "" - user_query: str = "" - arguments: Dict = field(default_factory=dict) - raw_result: str = "" - processed_result: str = "" - duration_ms: float = 0.0 - success: bool = True - error: str = "" - post_processed: bool = False - cache_hit: bool = False - - -class ToolCallTracer: - """工具调用追踪器""" - - def __init__(self, max_records: int = 100): - self._records: deque[ToolCallRecord] = deque(maxlen=max_records) - self._enabled: bool = True - self._log_enabled: bool = False - self._log_path: Optional[Path] = None - - def configure(self, enabled: bool, max_records: int, log_enabled: bool, log_path: Optional[Path] = None) -> None: - """配置追踪器""" - self._enabled = enabled - self._records = deque(self._records, maxlen=max_records) - self._log_enabled = log_enabled - self._log_path = log_path - - def record(self, record: ToolCallRecord) -> None: - """添加调用记录""" - if not self._enabled: - return - - self._records.append(record) - - if self._log_enabled and self._log_path: - self._write_to_log(record) - - def get_recent(self, n: int = 10) -> List[ToolCallRecord]: - """获取最近 N 条记录""" - return list(self._records)[-n:] - - def get_by_tool(self, tool_name: str) -> List[ToolCallRecord]: - """按工具名筛选记录""" - return [r for r in self._records if r.tool_name == tool_name] - - def get_by_server(self, server_name: str) -> List[ToolCallRecord]: - """按服务器名筛选记录""" - return [r for r in self._records if r.server_name == server_name] - - def clear(self) -> None: - """清空记录""" - self._records.clear() - - def _write_to_log(self, record: ToolCallRecord) -> None: - """写入 JSONL 日志文件""" - try: - if self._log_path: - self._log_path.parent.mkdir(parents=True, exist_ok=True) - with open(self._log_path, "a", encoding="utf-8") as f: - f.write(json.dumps(asdict(record), ensure_ascii=False) + "\n") - except Exception as e: - logger.warning(f"写入追踪日志失败: {e}") - - @property - def total_records(self) -> int: - return len(self._records) - - -# 全局追踪器实例 -tool_call_tracer = ToolCallTracer() - - -# ============================================================================ -# v1.4.0: 工具调用缓存 -# ============================================================================ - - -@dataclass -class CacheEntry: - """缓存条目""" - - tool_name: str - args_hash: str - result: str - created_at: float - expires_at: float - hit_count: int = 0 - - -class ToolCallCache: - """工具调用缓存(LRU)""" - - def __init__(self, max_entries: int = 200, ttl: int = 300): - self._cache: OrderedDict[str, CacheEntry] = OrderedDict() - self._max_entries = max_entries - self._ttl = ttl - self._enabled = False - self._exclude_patterns: List[str] = [] - self._stats = {"hits": 0, "misses": 0} - - def configure(self, enabled: bool, ttl: int, max_entries: int, exclude_tools: str) -> None: - """配置缓存""" - self._enabled = enabled - self._ttl = ttl - self._max_entries = max_entries - self._exclude_patterns = [p.strip() for p in exclude_tools.strip().split("\n") if p.strip()] - - def get(self, tool_name: str, args: Dict) -> Optional[str]: - """获取缓存""" - if not self._enabled: - return None - - if self._is_excluded(tool_name): - return None - - key = self._generate_key(tool_name, args) - - if key not in self._cache: - self._stats["misses"] += 1 - return None - - entry = self._cache[key] - - # 检查是否过期 - if time.time() > entry.expires_at: - del self._cache[key] - self._stats["misses"] += 1 - return None - - # LRU: 移到末尾 - self._cache.move_to_end(key) - entry.hit_count += 1 - self._stats["hits"] += 1 - - return entry.result - - def set(self, tool_name: str, args: Dict, result: str) -> None: - """设置缓存""" - if not self._enabled: - return - - if self._is_excluded(tool_name): - return - - key = self._generate_key(tool_name, args) - now = time.time() - - entry = CacheEntry( - tool_name=tool_name, - args_hash=key, - result=result, - created_at=now, - expires_at=now + self._ttl, - ) - - # 如果已存在,更新 - if key in self._cache: - self._cache[key] = entry - self._cache.move_to_end(key) - else: - # 检查容量 - self._evict_if_needed() - self._cache[key] = entry - - def clear(self) -> None: - """清空缓存""" - self._cache.clear() - self._stats = {"hits": 0, "misses": 0} - - def _generate_key(self, tool_name: str, args: Dict) -> str: - """生成缓存键""" - args_str = json.dumps(args, sort_keys=True, ensure_ascii=False) - content = f"{tool_name}:{args_str}" - return hashlib.md5(content.encode()).hexdigest() - - def _is_excluded(self, tool_name: str) -> bool: - """检查是否在排除列表中""" - for pattern in self._exclude_patterns: - if fnmatch.fnmatch(tool_name, pattern): - return True - return False - - def _evict_if_needed(self) -> None: - """必要时淘汰条目""" - # 先清理过期的 - now = time.time() - expired_keys = [k for k, v in self._cache.items() if now > v.expires_at] - for k in expired_keys: - del self._cache[k] - - # LRU 淘汰 - while len(self._cache) >= self._max_entries: - self._cache.popitem(last=False) - - def get_stats(self) -> Dict[str, Any]: - """获取缓存统计""" - total = self._stats["hits"] + self._stats["misses"] - hit_rate = (self._stats["hits"] / total * 100) if total > 0 else 0 - return { - "enabled": self._enabled, - "entries": len(self._cache), - "max_entries": self._max_entries, - "ttl": self._ttl, - "hits": self._stats["hits"], - "misses": self._stats["misses"], - "hit_rate": f"{hit_rate:.1f}%", - } - - -# 全局缓存实例 -tool_call_cache = ToolCallCache() - - -# ============================================================================ -# v1.4.0: 工具权限控制 -# ============================================================================ - - -class PermissionChecker: - """工具权限检查器""" - - def __init__(self): - self._enabled = False - self._default_mode = "allow_all" # allow_all 或 deny_all - self._rules: List[Dict] = [] - self._quick_deny_groups: set = set() - self._quick_allow_users: set = set() - - def configure( - self, - enabled: bool, - default_mode: str, - rules_json: str, - quick_deny_groups: str = "", - quick_allow_users: str = "", - ) -> None: - """配置权限检查器""" - self._enabled = enabled - self._default_mode = default_mode if default_mode in ("allow_all", "deny_all") else "allow_all" - - # 解析快捷配置 - self._quick_deny_groups = {g.strip() for g in quick_deny_groups.strip().split("\n") if g.strip()} - self._quick_allow_users = {u.strip() for u in quick_allow_users.strip().split("\n") if u.strip()} - - try: - self._rules = json.loads(rules_json) if rules_json.strip() else [] - except json.JSONDecodeError as e: - logger.warning(f"权限规则 JSON 解析失败: {e}") - self._rules = [] - - def check(self, tool_name: str, chat_id: str, user_id: str, is_group: bool) -> bool: - """检查权限 - - Args: - tool_name: 工具名称 - chat_id: 聊天 ID(群号或私聊 ID) - user_id: 用户 ID - is_group: 是否为群聊 - - Returns: - True 表示允许,False 表示拒绝 - """ - if not self._enabled: - return True - - # 快捷配置优先级最高 - # 1. 管理员白名单(始终允许) - if user_id and user_id in self._quick_allow_users: - return True - - # 2. 禁用群列表(始终拒绝) - if is_group and chat_id and chat_id in self._quick_deny_groups: - return False - - # 查找匹配的规则 - for rule in self._rules: - tool_pattern = rule.get("tool", "") - if not self._match_tool(tool_pattern, tool_name): - continue - - # 找到匹配的规则 - mode = rule.get("mode", "") - allowed = rule.get("allowed", []) - denied = rule.get("denied", []) - - # 构建当前上下文的 ID 列表 - context_ids = self._build_context_ids(chat_id, user_id, is_group) - - # 检查 denied 列表(优先级最高) - if denied: - for ctx_id in context_ids: - if self._match_id_list(denied, ctx_id): - return False - - # 检查 allowed 列表 - if allowed: - for ctx_id in context_ids: - if self._match_id_list(allowed, ctx_id): - return True - # 如果是 whitelist 模式且不在 allowed 中,拒绝 - if mode == "whitelist": - return False - - # 规则匹配但没有明确允许/拒绝,继续检查下一条规则 - - # 没有匹配的规则,使用默认模式 - return self._default_mode == "allow_all" - - def _match_tool(self, pattern: str, tool_name: str) -> bool: - """工具名通配符匹配""" - if not pattern: - return False - return fnmatch.fnmatch(tool_name, pattern) - - def _build_context_ids(self, chat_id: str, user_id: str, is_group: bool) -> List[str]: - """构建上下文 ID 列表""" - ids = [] - - # 用户级别(任何场景生效) - if user_id: - ids.append(f"qq:{user_id}:user") - - # 场景级别 - if is_group and chat_id: - ids.append(f"qq:{chat_id}:group") - elif chat_id: - ids.append(f"qq:{chat_id}:private") - - return ids - - def _match_id_list(self, id_list: List[str], context_id: str) -> bool: - """检查 ID 是否在列表中""" - for rule_id in id_list: - if fnmatch.fnmatch(context_id, rule_id): - return True - return False - - def get_rules_for_tool(self, tool_name: str) -> List[Dict]: - """获取特定工具的权限规则""" - return [r for r in self._rules if self._match_tool(r.get("tool", ""), tool_name)] - - -# 全局权限检查器实例 -permission_checker = PermissionChecker() - - -# ============================================================================ -# 工具类型转换 -# ============================================================================ - - -def convert_json_type_to_tool_param_type(json_type: str) -> ToolParamType: - """将 JSON Schema 类型转换为 MaiBot 的 ToolParamType""" - type_mapping = { - "string": ToolParamType.STRING, - "integer": ToolParamType.INTEGER, - "number": ToolParamType.FLOAT, - "boolean": ToolParamType.BOOLEAN, - "array": ToolParamType.STRING, - "object": ToolParamType.STRING, - } - return type_mapping.get(json_type, ToolParamType.STRING) - - -def parse_mcp_parameters( - input_schema: Dict[str, Any], -) -> List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]]: - """解析 MCP 工具的参数 schema,转换为 MaiBot 的参数格式""" - parameters = [] - - if not input_schema: - # 为无参数的工具添加占位参数,避免某些模型报错 - parameters.append(("_placeholder", ToolParamType.STRING, "占位参数,无需填写", False, None)) - return parameters - - properties = input_schema.get("properties", {}) - required = input_schema.get("required", []) - - # 如果没有任何参数,添加占位参数 - if not properties: - parameters.append(("_placeholder", ToolParamType.STRING, "占位参数,无需填写", False, None)) - return parameters - - for param_name, param_info in properties.items(): - json_type = param_info.get("type", "string") - param_type = convert_json_type_to_tool_param_type(json_type) - description = param_info.get("description", f"参数 {param_name}") - - if json_type == "array": - description = f"{description} (JSON 数组格式)" - elif json_type == "object": - description = f"{description} (JSON 对象格式)" - - is_required = param_name in required - enum_values = param_info.get("enum") - - if enum_values is not None: - enum_values = [str(v) for v in enum_values] - - parameters.append((param_name, param_type, description, is_required, enum_values)) - - return parameters - - -# ============================================================================ -# MCP 工具代理 -# ============================================================================ - - -class MCPToolProxy(BaseTool): - """MCP 工具代理基类""" - - name: str = "" - description: str = "" - parameters: List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]] = [] - available_for_llm: bool = True - - _mcp_tool_key: str = "" - _mcp_original_name: str = "" - _mcp_server_name: str = "" - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - """执行 MCP 工具调用""" - global _plugin_instance - - call_id = str(uuid.uuid4())[:8] - start_time = time.time() - - # 移除 MaiBot 内部标记 - args = {k: v for k, v in function_args.items() if k != "llm_called"} - - # 解析 JSON 字符串参数 - parsed_args = {} - for key, value in args.items(): - if isinstance(value, str): - try: - if value.startswith(("[", "{")): - parsed_args[key] = json.loads(value) - else: - parsed_args[key] = value - except json.JSONDecodeError: - parsed_args[key] = value - else: - parsed_args[key] = value - - # 获取上下文信息 - chat_id, user_id, is_group, user_query = self._get_context_info() - - # v1.4.0: 权限检查 - if not permission_checker.check(self.name, chat_id, user_id, is_group): - logger.warning(f"权限拒绝: 工具 {self.name}, chat={chat_id}, user={user_id}") - return {"name": self.name, "content": f"⛔ 权限不足:工具 {self.name} 在当前场景下不可用"} - - logger.debug(f"调用 MCP 工具: {self._mcp_tool_key}, 参数: {parsed_args}") - - # v1.4.0: 检查缓存 - cache_hit = False - cached_result = tool_call_cache.get(self.name, parsed_args) - - if cached_result is not None: - cache_hit = True - content = cached_result - raw_result = cached_result - success = True - error = "" - logger.debug(f"MCP 工具 {self.name} 命中缓存") - else: - # 调用 MCP - result = await mcp_manager.call_tool(self._mcp_tool_key, parsed_args) - - if result.success: - content = result.content - raw_result = content - success = True - error = "" - - # 存入缓存 - tool_call_cache.set(self.name, parsed_args, content) - else: - content = self._format_error_message(result.error, result.duration_ms) - raw_result = result.error - success = False - error = result.error - logger.warning(f"MCP 工具 {self.name} 调用失败: {result.error}") - - # v1.3.0: 后处理 - post_processed = False - processed_result = content - if success: - processed_content = await self._post_process_result(content) - if processed_content != content: - post_processed = True - processed_result = processed_content - content = processed_content - - duration_ms = (time.time() - start_time) * 1000 - - # v1.4.0: 记录调用追踪 - record = ToolCallRecord( - call_id=call_id, - timestamp=start_time, - tool_name=self.name, - server_name=self._mcp_server_name, - chat_id=chat_id, - user_id=user_id, - user_query=user_query, - arguments=parsed_args, - raw_result=raw_result[:1000] if raw_result else "", - processed_result=processed_result[:1000] if processed_result else "", - duration_ms=duration_ms, - success=success, - error=error, - post_processed=post_processed, - cache_hit=cache_hit, - ) - tool_call_tracer.record(record) - - return {"name": self.name, "content": content} - - def _get_context_info(self) -> Tuple[str, str, bool, str]: - """获取上下文信息""" - chat_id = "" - user_id = "" - is_group = False - user_query = "" - - if self.chat_stream and hasattr(self.chat_stream, "context") and self.chat_stream.context: - try: - ctx = self.chat_stream.context - if hasattr(ctx, "chat_id"): - chat_id = str(ctx.chat_id) if ctx.chat_id else "" - if hasattr(ctx, "user_id"): - user_id = str(ctx.user_id) if ctx.user_id else "" - if hasattr(ctx, "is_group"): - is_group = bool(ctx.is_group) - - last_message = ctx.get_last_message() - if last_message and hasattr(last_message, "processed_plain_text"): - user_query = last_message.processed_plain_text or "" - except Exception as e: - logger.debug(f"获取上下文信息失败: {e}") - - return chat_id, user_id, is_group, user_query - - async def _post_process_result(self, content: str) -> str: - """v1.3.0: 对工具返回结果进行后处理(摘要提炼)""" - global _plugin_instance - - if _plugin_instance is None: - return content - - settings = _plugin_instance.config.get("settings", {}) - - if not settings.get("post_process_enabled", False): - return content - - server_post_config = self._get_server_post_process_config() - - if server_post_config is not None: - if not server_post_config.get("enabled", True): - return content - - threshold = settings.get("post_process_threshold", 500) - if server_post_config and "threshold" in server_post_config: - threshold = server_post_config["threshold"] - - content_length = len(content) if content else 0 - if content_length <= threshold: - return content - - user_query = self._get_context_info()[3] - if not user_query: - return content - - max_tokens = settings.get("post_process_max_tokens", 500) - if server_post_config and "max_tokens" in server_post_config: - max_tokens = server_post_config["max_tokens"] - - prompt_template = settings.get("post_process_prompt", "") - if server_post_config and "prompt" in server_post_config: - prompt_template = server_post_config["prompt"] - - if not prompt_template: - prompt_template = """用户问题:{query} - -工具返回内容: -{result} - -请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:""" - - try: - prompt = prompt_template.format(query=user_query, result=content) - except KeyError as e: - logger.warning(f"后处理 prompt 模板格式错误: {e}") - return content - - try: - processed_content = await self._call_post_process_llm(prompt, max_tokens, settings, server_post_config) - if processed_content: - logger.info(f"MCP 工具 {self.name} 后处理完成: {content_length} -> {len(processed_content)} 字符") - return processed_content - return content - except Exception as e: - logger.error(f"MCP 工具 {self.name} 后处理失败: {e}") - return content - - def _get_server_post_process_config(self) -> Optional[Dict[str, Any]]: - """获取当前服务器的后处理配置""" - global _plugin_instance - - if _plugin_instance is None: - return None - - servers = _plugin_instance._load_mcp_servers_config() - for server_conf in servers: - if server_conf.get("name") == self._mcp_server_name: - return server_conf.get("post_process") - - return None - - async def _call_post_process_llm( - self, prompt: str, max_tokens: int, settings: Dict[str, Any], server_config: Optional[Dict[str, Any]] - ) -> Optional[str]: - """调用 LLM 进行后处理""" - from src.config.config import model_config - from src.config.model_configs import TaskConfig - from src.llm_models.utils_model import LLMRequest - - model_name = settings.get("post_process_model", "") - if server_config and "model" in server_config: - model_name = server_config["model"] - - if model_name: - task_config = TaskConfig( - model_list=[model_name], - max_tokens=max_tokens, - temperature=0.3, - slow_threshold=30.0, - ) - else: - task_config = model_config.model_task_config.utils - - llm_request = LLMRequest(model_set=task_config, request_type="mcp_post_process") - - response, (reasoning, model_used, _) = await llm_request.generate_response_async( - prompt=prompt, - max_tokens=max_tokens, - temperature=0.3, - ) - - return response.strip() if response else None - - def _format_error_message(self, error: str, duration_ms: float) -> str: - """格式化友好的错误消息""" - if not error: - return "工具调用失败(未知错误)" - - error_lower = error.lower() - - if "未连接" in error or "not connected" in error_lower: - return f"⚠️ MCP 服务器 [{self._mcp_server_name}] 未连接,请检查服务器状态或等待自动重连" - - if "超时" in error or "timeout" in error_lower: - return f"⏱️ 工具调用超时(耗时 {duration_ms:.0f}ms),服务器响应过慢,请稍后重试" - - if "connection" in error_lower and ("closed" in error_lower or "reset" in error_lower): - return f"🔌 与 MCP 服务器 [{self._mcp_server_name}] 的连接已断开,正在尝试重连..." - - if "invalid" in error_lower and "argument" in error_lower: - return f"❌ 参数错误: {error}" - - return f"❌ 工具调用失败: {error}" - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - """直接执行(供其他插件调用)""" - return await self.execute(function_args) - - -def create_mcp_tool_class( - tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False -) -> Type[MCPToolProxy]: - """根据 MCP 工具信息动态创建 BaseTool 子类""" - parameters = parse_mcp_parameters(tool_info.input_schema) - - class_name = f"MCPTool_{tool_info.server_name}_{tool_info.name}".replace("-", "_").replace(".", "_") - tool_name = tool_key.replace("-", "_").replace(".", "_") - - description = tool_info.description - if not description.endswith(f"[来自 MCP 服务器: {tool_info.server_name}]"): - description = f"{description} [来自 MCP 服务器: {tool_info.server_name}]" - - tool_class = type( - class_name, - (MCPToolProxy,), - { - "name": tool_name, - "description": description, - "parameters": parameters, - "available_for_llm": not disabled, # v1.4.0: 禁用的工具不可被 LLM 调用 - "_mcp_tool_key": tool_key, - "_mcp_original_name": tool_info.name, - "_mcp_server_name": tool_info.server_name, - }, - ) - - return tool_class - - -class MCPToolRegistry: - """MCP 工具注册表""" - - def __init__(self): - self._tool_classes: Dict[str, Type[MCPToolProxy]] = {} - self._tool_infos: Dict[str, ToolInfo] = {} - - def register_tool( - self, tool_key: str, tool_info: MCPToolInfo, tool_prefix: str, disabled: bool = False - ) -> Tuple[ToolInfo, Type[MCPToolProxy]]: - """注册 MCP 工具""" - tool_class = create_mcp_tool_class(tool_key, tool_info, tool_prefix, disabled) - - self._tool_classes[tool_key] = tool_class - - info = ToolInfo( - name=tool_class.name, - tool_description=tool_class.description, - enabled=True, - tool_parameters=tool_class.parameters, - component_type=ComponentType.TOOL, - ) - self._tool_infos[tool_key] = info - - return info, tool_class - - def unregister_tool(self, tool_key: str) -> bool: - """注销工具""" - if tool_key in self._tool_classes: - del self._tool_classes[tool_key] - del self._tool_infos[tool_key] - return True - return False - - def get_all_components(self) -> List[Tuple[ComponentInfo, Type]]: - """获取所有工具组件""" - return [(self._tool_infos[key], self._tool_classes[key]) for key in self._tool_classes.keys()] - - def clear(self) -> None: - """清空所有注册""" - self._tool_classes.clear() - self._tool_infos.clear() - - -# 全局工具注册表 -mcp_tool_registry = MCPToolRegistry() - -# 全局插件实例引用 -_plugin_instance: Optional["MCPBridgePlugin"] = None - - -# ============================================================================ -# 内置工具 -# ============================================================================ - - -class MCPReadResourceTool(BaseTool): - """v1.2.0: MCP 资源读取工具""" - - name = "mcp_read_resource" - description = "读取 MCP 服务器提供的资源内容(如文件、数据库记录等)。使用前请先用 mcp_status 查看可用资源。" - parameters = [ - ("uri", ToolParamType.STRING, "资源 URI(如 file:///path/to/file 或自定义 URI)", True, None), - ("server_name", ToolParamType.STRING, "指定服务器名称(可选,不指定则自动查找)", False, None), - ] - available_for_llm = True - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - uri = function_args.get("uri", "") - server_name = function_args.get("server_name") - - if not uri: - return {"name": self.name, "content": "❌ 请提供资源 URI"} - - result = await mcp_manager.read_resource(uri, server_name) - - if result.success: - return {"name": self.name, "content": result.content} - else: - return {"name": self.name, "content": f"❌ 读取资源失败: {result.error}"} - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - return await self.execute(function_args) - - -class MCPGetPromptTool(BaseTool): - """v1.2.0: MCP 提示模板工具""" - - name = "mcp_get_prompt" - description = "获取 MCP 服务器提供的提示模板内容。使用前请先用 mcp_status 查看可用模板。" - parameters = [ - ("name", ToolParamType.STRING, "提示模板名称", True, None), - ("arguments", ToolParamType.STRING, "模板参数(JSON 对象格式)", False, None), - ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), - ] - available_for_llm = True - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - prompt_name = function_args.get("name", "") - arguments_str = function_args.get("arguments", "") - server_name = function_args.get("server_name") - - if not prompt_name: - return {"name": self.name, "content": "❌ 请提供提示模板名称"} - - arguments = None - if arguments_str: - try: - arguments = json.loads(arguments_str) - except json.JSONDecodeError: - return {"name": self.name, "content": "❌ 参数格式错误,请使用 JSON 对象格式"} - - result = await mcp_manager.get_prompt(prompt_name, arguments, server_name) - - if result.success: - return {"name": self.name, "content": result.content} - else: - return {"name": self.name, "content": f"❌ 获取提示模板失败: {result.error}"} - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - return await self.execute(function_args) - - -# ============================================================================ -# v1.8.0: 工具链代理工具 -# ============================================================================ - - -class ToolChainProxyBase(BaseTool): - """工具链代理基类""" - - name: str = "" - description: str = "" - parameters: List[Tuple[str, ToolParamType, str, bool, Optional[List[str]]]] = [] - available_for_llm: bool = True - - _chain_name: str = "" - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - """执行工具链""" - # 移除内部标记 - args = {k: v for k, v in function_args.items() if k != "llm_called"} - - logger.debug(f"执行工具链 {self._chain_name},参数: {args}") - - result = await tool_chain_manager.execute_chain(self._chain_name, args) - - if result.success: - # 构建输出 - output_parts = [] - output_parts.append(result.final_output) - - # 可选:添加执行摘要 - # output_parts.append(f"\n\n---\n执行摘要:\n{result.to_summary()}") - - return {"name": self.name, "content": "\n".join(output_parts)} - else: - error_msg = f"⚠️ 工具链执行失败: {result.error}" - if result.step_results: - error_msg += f"\n\n执行详情:\n{result.to_summary()}" - return {"name": self.name, "content": error_msg} - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - return await self.execute(function_args) - - -def create_chain_tool_class(chain: ToolChainDefinition) -> Type[ToolChainProxyBase]: - """根据工具链定义动态创建工具类""" - # 构建参数列表 - parameters = [] - for param_name, param_desc in chain.input_params.items(): - parameters.append((param_name, ToolParamType.STRING, param_desc, True, None)) - - # 生成类名和工具名 - class_name = f"ToolChain_{chain.name}".replace("-", "_").replace(".", "_") - tool_name = f"chain_{chain.name}".replace("-", "_").replace(".", "_") - - # 构建描述 - description = chain.description - if chain.steps: - step_names = [s.tool_name.split("_")[-1] for s in chain.steps[:3]] - description += f" (执行流程: {' → '.join(step_names)}{'...' if len(chain.steps) > 3 else ''})" - - tool_class = type( - class_name, - (ToolChainProxyBase,), - { - "name": tool_name, - "description": description, - "parameters": parameters, - "available_for_llm": True, - "_chain_name": chain.name, - }, - ) - - return tool_class - - -class ToolChainRegistry: - """工具链注册表""" - - def __init__(self): - self._tool_classes: Dict[str, Type[ToolChainProxyBase]] = {} - self._tool_infos: Dict[str, ToolInfo] = {} - - def register_chain(self, chain: ToolChainDefinition) -> Tuple[ToolInfo, Type[ToolChainProxyBase]]: - """注册工具链为组合工具""" - tool_class = create_chain_tool_class(chain) - - self._tool_classes[chain.name] = tool_class - - info = ToolInfo( - name=tool_class.name, - tool_description=tool_class.description, - enabled=True, - tool_parameters=tool_class.parameters, - component_type=ComponentType.TOOL, - ) - self._tool_infos[chain.name] = info - - return info, tool_class - - def unregister_chain(self, chain_name: str) -> bool: - """注销工具链""" - if chain_name in self._tool_classes: - del self._tool_classes[chain_name] - del self._tool_infos[chain_name] - return True - return False - - def get_all_components(self) -> List[Tuple[ComponentInfo, Type]]: - """获取所有工具链组件""" - return [(self._tool_infos[key], self._tool_classes[key]) for key in self._tool_classes.keys()] - - def clear(self) -> None: - """清空所有注册""" - self._tool_classes.clear() - self._tool_infos.clear() - - -# 全局工具链注册表 -tool_chain_registry = ToolChainRegistry() - - -class MCPStatusTool(BaseTool): - """MCP 状态查询工具""" - - name = "mcp_status" - description = "查询 MCP 桥接插件的状态,包括服务器连接状态、可用工具列表、工具链列表、资源列表、提示模板列表、调用统计、追踪记录等信息" - parameters = [ - ( - "query_type", - ToolParamType.STRING, - "查询类型", - False, - ["status", "tools", "chains", "resources", "prompts", "stats", "trace", "cache", "all"], - ), - ("server_name", ToolParamType.STRING, "指定服务器名称(可选)", False, None), - ] - available_for_llm = True - - async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: - query_type = function_args.get("query_type", "status") - server_name = function_args.get("server_name") - - result_parts = [] - - if query_type in ("status", "all"): - result_parts.append(self._format_status(server_name)) - - if query_type in ("tools", "all"): - result_parts.append(self._format_tools(server_name)) - - if query_type in ("chains", "all"): - result_parts.append(self._format_chains()) - - if query_type in ("resources", "all"): - result_parts.append(self._format_resources(server_name)) - - if query_type in ("prompts", "all"): - result_parts.append(self._format_prompts(server_name)) - - if query_type in ("stats", "all"): - result_parts.append(self._format_stats(server_name)) - - # v1.4.0: 追踪记录 - if query_type in ("trace",): - result_parts.append(self._format_trace()) - - # v1.4.0: 缓存状态 - if query_type in ("cache",): - result_parts.append(self._format_cache()) - - return {"name": self.name, "content": "\n\n".join(result_parts) if result_parts else "未知的查询类型"} - - def _format_status(self, server_name: Optional[str] = None) -> str: - status = mcp_manager.get_status() - lines = ["📊 MCP 桥接插件状态"] - lines.append(f" 总服务器数: {status['total_servers']}") - lines.append(f" 已连接: {status['connected_servers']}") - lines.append(f" 已断开: {status['disconnected_servers']}") - lines.append(f" 可用工具数: {status['total_tools']}") - lines.append(f" 心跳检测: {'运行中' if status['heartbeat_running'] else '已停止'}") - - lines.append("\n🔌 服务器详情:") - for name, info in status["servers"].items(): - if server_name and name != server_name: - continue - status_icon = "✅" if info["connected"] else "❌" - enabled_text = "" if info["enabled"] else " (已禁用)" - lines.append(f" {status_icon} {name}{enabled_text}") - lines.append(f" 传输: {info['transport']}, 工具数: {info['tools_count']}") - if info["consecutive_failures"] > 0: - lines.append(f" ⚠️ 连续失败: {info['consecutive_failures']} 次") - - return "\n".join(lines) - - def _format_tools(self, server_name: Optional[str] = None) -> str: - tools = mcp_manager.all_tools - lines = ["🔧 可用 MCP 工具"] - - by_server: Dict[str, List[str]] = {} - for tool_key, (tool_info, _) in tools.items(): - if server_name and tool_info.server_name != server_name: - continue - if tool_info.server_name not in by_server: - by_server[tool_info.server_name] = [] - by_server[tool_info.server_name].append(f" • {tool_key}: {tool_info.description[:50]}...") - - for srv_name, tool_list in by_server.items(): - lines.append(f"\n📦 {srv_name} ({len(tool_list)} 个工具):") - lines.extend(tool_list) - - if not by_server: - lines.append(" (无可用工具)") - - return "\n".join(lines) - - def _format_stats(self, server_name: Optional[str] = None) -> str: - stats = mcp_manager.get_all_stats() - lines = ["📈 调用统计"] - - g = stats["global"] - lines.append(f" 总调用次数: {g['total_tool_calls']}") - lines.append(f" 成功: {g['successful_calls']}, 失败: {g['failed_calls']}") - if g["total_tool_calls"] > 0: - success_rate = (g["successful_calls"] / g["total_tool_calls"]) * 100 - lines.append(f" 成功率: {success_rate:.1f}%") - lines.append(f" 运行时间: {g['uptime_seconds']:.0f} 秒") - - return "\n".join(lines) - - def _format_resources(self, server_name: Optional[str] = None) -> str: - resources = mcp_manager.all_resources - if not resources: - return "📦 当前没有可用的 MCP 资源" - - lines = ["📦 可用 MCP 资源"] - by_server: Dict[str, List[MCPResourceInfo]] = {} - for _key, (resource_info, _) in resources.items(): - if server_name and resource_info.server_name != server_name: - continue - if resource_info.server_name not in by_server: - by_server[resource_info.server_name] = [] - by_server[resource_info.server_name].append(resource_info) - - for srv_name, resource_list in by_server.items(): - lines.append(f"\n🔌 {srv_name} ({len(resource_list)} 个资源):") - for res in resource_list: - lines.append(f" • {res.name}: {res.uri}") - - return "\n".join(lines) - - def _format_prompts(self, server_name: Optional[str] = None) -> str: - prompts = mcp_manager.all_prompts - if not prompts: - return "📝 当前没有可用的 MCP 提示模板" - - lines = ["📝 可用 MCP 提示模板"] - by_server: Dict[str, List[MCPPromptInfo]] = {} - for _key, (prompt_info, _) in prompts.items(): - if server_name and prompt_info.server_name != server_name: - continue - if prompt_info.server_name not in by_server: - by_server[prompt_info.server_name] = [] - by_server[prompt_info.server_name].append(prompt_info) - - for srv_name, prompt_list in by_server.items(): - lines.append(f"\n🔌 {srv_name} ({len(prompt_list)} 个模板):") - for prompt in prompt_list: - lines.append(f" • {prompt.name}") - - return "\n".join(lines) - - def _format_trace(self) -> str: - """v1.4.0: 格式化追踪记录""" - records = tool_call_tracer.get_recent(10) - if not records: - return "🔍 暂无调用追踪记录" - - lines = ["🔍 最近调用追踪记录"] - for r in reversed(records): - status = "✅" if r.success else "❌" - cache = "📦" if r.cache_hit else "" - post = "🔄" if r.post_processed else "" - lines.append(f" {status}{cache}{post} {r.tool_name} ({r.duration_ms:.0f}ms)") - if r.error: - lines.append(f" 错误: {r.error[:50]}") - - return "\n".join(lines) - - def _format_cache(self) -> str: - """v1.4.0: 格式化缓存状态""" - stats = tool_call_cache.get_stats() - lines = ["🗄️ 缓存状态"] - lines.append(f" 启用: {'是' if stats['enabled'] else '否'}") - lines.append(f" 条目数: {stats['entries']}/{stats['max_entries']}") - lines.append(f" TTL: {stats['ttl']}秒") - lines.append(f" 命中: {stats['hits']}, 未命中: {stats['misses']}") - lines.append(f" 命中率: {stats['hit_rate']}") - return "\n".join(lines) - - def _format_chains(self) -> str: - """v1.8.0: 格式化工具链列表""" - chains = tool_chain_manager.get_all_chains() - if not chains: - return "🔗 当前没有配置工具链" - - lines = ["🔗 工具链列表"] - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - lines.append(f"\n{status} {name}") - lines.append(f" 描述: {chain.description[:50]}...") - lines.append(f" 步骤: {len(chain.steps)} 个") - for i, step in enumerate(chain.steps[:3]): - lines.append(f" {i + 1}. {step.tool_name}") - if len(chain.steps) > 3: - lines.append(f" ... 还有 {len(chain.steps) - 3} 个步骤") - if chain.input_params: - params = ", ".join(chain.input_params.keys()) - lines.append(f" 参数: {params}") - - return "\n".join(lines) - - async def direct_execute(self, **function_args) -> Dict[str, Any]: - return await self.execute(function_args) - - -# ============================================================================ -# 命令处理 -# ============================================================================ - - -class MCPStatusCommand(BaseCommand): - """MCP 状态查询命令 - 通过 /mcp 命令查看服务器状态""" - - command_name = "mcp_status_command" - command_description = "查看 MCP 服务器连接状态和统计信息" - command_pattern = r"^[//]mcp(?:\s+(?Pstatus|tools|stats|reconnect|trace|cache|perm|export|search|chain))?(?:\s+(?P.+))?$" - - async def execute(self) -> Tuple[bool, Optional[str], bool]: - """执行命令""" - subcommand = self.matched_groups.get("subcommand", "status") or "status" - arg = self.matched_groups.get("arg") - - if subcommand == "reconnect": - return await self._handle_reconnect(arg) - - # v1.4.0: 追踪命令 - if subcommand == "trace": - return await self._handle_trace(arg) - - # v1.4.0: 缓存命令 - if subcommand == "cache": - return await self._handle_cache(arg) - - # v1.4.0: 权限命令 - if subcommand == "perm": - return await self._handle_perm(arg) - - # v1.6.0: 导出命令 - if subcommand == "export": - return await self._handle_export(arg) - - # v1.7.0: 工具搜索命令 - if subcommand == "search": - return await self._handle_search(arg) - - # v1.8.0: 工具链命令 - if subcommand == "chain": - return await self._handle_chain(arg) - - result = self._format_output(subcommand, arg) - await self.send_text(result) - return (True, None, True) - - def _find_similar_servers(self, name: str, max_results: int = 3) -> List[str]: - """查找相似的服务器名称""" - name_lower = name.lower() - all_servers = list(mcp_manager._clients.keys()) - - # 简单的相似度匹配:包含关系或前缀匹配 - similar = [] - for srv in all_servers: - srv_lower = srv.lower() - if name_lower in srv_lower or srv_lower in name_lower: - similar.append(srv) - elif srv_lower.startswith(name_lower[:3]) if len(name_lower) >= 3 else False: - similar.append(srv) - - return similar[:max_results] - - async def _handle_reconnect(self, server_name: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """处理重连请求""" - if server_name: - if server_name not in mcp_manager._clients: - # 提示相似的服务器名 - similar = self._find_similar_servers(server_name) - msg = f"❌ 服务器 '{server_name}' 不存在" - if similar: - msg += f"\n💡 你是不是想找: {', '.join(similar)}" - await self.send_text(msg) - return (True, None, True) - - await self.send_text(f"🔄 正在重连服务器 {server_name}...") - success = await mcp_manager.reconnect_server(server_name) - if success: - await self.send_text(f"✅ 服务器 {server_name} 重连成功") - else: - await self.send_text(f"❌ 服务器 {server_name} 重连失败") - else: - disconnected = mcp_manager.disconnected_servers - if not disconnected: - await self.send_text("✅ 所有服务器都已连接") - return (True, None, True) - - await self.send_text(f"🔄 正在重连 {len(disconnected)} 个断开的服务器...") - for srv in disconnected: - success = await mcp_manager.reconnect_server(srv) - status = "✅" if success else "❌" - await self.send_text(f"{status} {srv}") - - return (True, None, True) - - async def _handle_trace(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.4.0: 处理追踪命令""" - if arg and arg.isdigit(): - # /mcp trace 20 - 最近 N 条 - n = int(arg) - records = tool_call_tracer.get_recent(n) - elif arg: - # /mcp trace - 特定工具 - records = tool_call_tracer.get_by_tool(arg) - else: - # /mcp trace - 最近 10 条 - records = tool_call_tracer.get_recent(10) - - if not records: - await self.send_text("🔍 暂无调用追踪记录\n\n用法: /mcp trace [数量|工具名]") - return (True, None, True) - - lines = [f"🔍 调用追踪记录 ({len(records)} 条)"] - lines.append("-" * 30) - for i, r in enumerate(reversed(records)): - status_icon = "✅" if r.success else "❌" - cache_tag = " [缓存]" if r.cache_hit else "" - post_tag = " [后处理]" if r.post_processed else "" - ts = time.strftime("%H:%M:%S", time.localtime(r.timestamp)) - lines.append(f"{status_icon} [{ts}] {r.tool_name}") - lines.append(f" {r.duration_ms:.0f}ms | {r.server_name}{cache_tag}{post_tag}") - if r.error: - lines.append(f" 错误: {r.error[:50]}") - if i < len(records) - 1: - lines.append("") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - async def _handle_cache(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.4.0: 处理缓存命令""" - if arg == "clear": - tool_call_cache.clear() - await self.send_text("✅ 缓存已清空") - return (True, None, True) - - stats = tool_call_cache.get_stats() - lines = ["🗄️ 缓存状态"] - lines.append(f"├ 启用: {'是' if stats['enabled'] else '否'}") - lines.append(f"├ 条目: {stats['entries']}/{stats['max_entries']}") - lines.append(f"├ TTL: {stats['ttl']}秒") - lines.append(f"├ 命中: {stats['hits']}") - lines.append(f"├ 未命中: {stats['misses']}") - lines.append(f"└ 命中率: {stats['hit_rate']}") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - async def _handle_perm(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.4.0: 处理权限命令""" - global _plugin_instance - - if _plugin_instance is None: - await self.send_text("❌ 插件未初始化") - return (True, None, True) - - perm_config = _plugin_instance.config.get("permissions", {}) - enabled = perm_config.get("perm_enabled", False) - default_mode = perm_config.get("perm_default_mode", "allow_all") - - if arg: - # 查看特定工具的权限 - rules = permission_checker.get_rules_for_tool(arg) - if not rules: - await self.send_text(f"🔐 工具 {arg} 无特定权限规则\n默认模式: {default_mode}") - else: - lines = [f"🔐 工具 {arg} 的权限规则:"] - for r in rules: - lines.append(f" • 模式: {r.get('mode', 'default')}") - if r.get("allowed"): - lines.append(f" 允许: {', '.join(r['allowed'][:3])}...") - if r.get("denied"): - lines.append(f" 拒绝: {', '.join(r['denied'][:3])}...") - await self.send_text("\n".join(lines)) - else: - # 查看权限配置概览 - lines = ["🔐 权限控制配置"] - lines.append(f"├ 启用: {'是' if enabled else '否'}") - lines.append(f"├ 默认模式: {default_mode}") - # 快捷配置 - deny_count = len(permission_checker._quick_deny_groups) - allow_count = len(permission_checker._quick_allow_users) - if deny_count > 0: - lines.append(f"├ 禁用群: {deny_count} 个") - if allow_count > 0: - lines.append(f"├ 管理员白名单: {allow_count} 人") - lines.append(f"└ 高级规则: {len(permission_checker._rules)} 条") - await self.send_text("\n".join(lines)) - - return (True, None, True) - - async def _handle_export(self, format_type: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.6.0: 处理导出命令""" - global _plugin_instance - - if _plugin_instance is None: - await self.send_text("❌ 插件未初始化") - return (True, None, True) - - servers_section = _plugin_instance.config.get("servers", {}) - if not isinstance(servers_section, dict): - servers_section = {} - - claude_json = str(servers_section.get("claude_config_json", "") or "") - if not claude_json.strip(): - legacy_list = str(servers_section.get("list", "") or "") - claude_json = legacy_servers_list_to_claude_config(legacy_list) or "" - - if not claude_json.strip(): - await self.send_text("📤 当前没有配置任何服务器") - return (True, None, True) - - try: - pretty = json.dumps(json.loads(claude_json), ensure_ascii=False, indent=2) - except Exception: - pretty = claude_json - - lines = ["📤 导出为 Claude Desktop 格式(mcpServers):"] - if format_type and format_type.strip() and format_type.strip().lower() != "claude": - lines.append("(v2.0 已精简为仅 Claude 格式,忽略其他格式参数)") - lines.append("") - lines.append(pretty) - await self.send_text("\n".join(lines)) - - return (True, None, True) - - async def _handle_search(self, query: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.7.0: 处理工具搜索命令""" - if not query or not query.strip(): - # 显示使用帮助 - help_text = """🔍 工具搜索 - -用法: /mcp search <关键词> - -示例: - /mcp search time 搜索包含 time 的工具 - /mcp search fetch 搜索包含 fetch 的工具 - /mcp search * 列出所有工具 - -支持模糊匹配工具名称和描述""" - await self.send_text(help_text) - return (True, None, True) - - query = query.strip().lower() - tools = mcp_manager.all_tools - - if not tools: - await self.send_text("🔍 当前没有可用的 MCP 工具") - return (True, None, True) - - # 搜索匹配的工具 - matched = [] - for tool_key, (tool_info, client) in tools.items(): - tool_name = tool_key.lower() - tool_desc = (tool_info.description or "").lower() - - # * 表示列出所有 - if query == "*": - matched.append((tool_key, tool_info, client)) - elif query in tool_name or query in tool_desc: - matched.append((tool_key, tool_info, client)) - - if not matched: - await self.send_text(f"🔍 未找到匹配 '{query}' 的工具") - return (True, None, True) - - # 按服务器分组显示 - by_server: Dict[str, List[Tuple[str, Any]]] = {} - for tool_key, tool_info, _client in matched: - server_name = tool_info.server_name - if server_name not in by_server: - by_server[server_name] = [] - by_server[server_name].append((tool_key, tool_info)) - - # 如果只有一个服务器或结果较少,显示全部;否则折叠 - single_server = len(by_server) == 1 - lines = [f"🔍 搜索结果: {len(matched)} 个工具匹配 '{query}'"] - - for srv_name, tool_list in by_server.items(): - lines.append(f"\n📦 {srv_name} ({len(tool_list)} 个):") - - # 单服务器或结果少于 15 个时显示全部 - show_all = single_server or len(matched) <= 15 - display_limit = len(tool_list) if show_all else 5 - - for tool_key, tool_info in tool_list[:display_limit]: - desc = tool_info.description[:40] + "..." if len(tool_info.description) > 40 else tool_info.description - lines.append(f" • {tool_key}") - lines.append(f" {desc}") - if len(tool_list) > display_limit: - lines.append(f" ... 还有 {len(tool_list) - display_limit} 个,用 /mcp search {query} {srv_name} 筛选") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - async def _handle_chain(self, arg: Optional[str] = None) -> Tuple[bool, Optional[str], bool]: - """v1.8.0: 处理工具链命令""" - if not arg or not arg.strip(): - # 显示工具链列表和帮助 - chains = tool_chain_manager.get_all_chains() - - lines = ["🔗 工具链管理"] - lines.append("") - - if chains: - lines.append(f"已配置 {len(chains)} 个工具链:") - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - steps_count = len(chain.steps) - lines.append(f" {status} {name} ({steps_count} 步)") - else: - lines.append("当前没有配置工具链") - - lines.append("") - lines.append("命令:") - lines.append(" /mcp chain list 查看所有工具链") - lines.append(" /mcp chain <名称> 查看工具链详情") - lines.append(" /mcp chain test <名称> <参数JSON> 测试执行") - lines.append(" /mcp chain reload 重新加载配置") - lines.append("") - lines.append("💡 在 WebUI「工具链」配置区编辑工具链") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - parts = arg.strip().split(maxsplit=2) - sub_action = parts[0].lower() - - if sub_action == "list": - # 列出所有工具链 - chains = tool_chain_manager.get_all_chains() - if not chains: - await self.send_text("🔗 当前没有配置工具链") - return (True, None, True) - - lines = [f"🔗 工具链列表 ({len(chains)} 个)"] - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - lines.append(f"\n{status} {name}") - lines.append(f" {chain.description[:60]}...") - lines.append(f" 步骤: {' → '.join([s.tool_name.split('_')[-1] for s in chain.steps[:4]])}") - if chain.input_params: - lines.append(f" 参数: {', '.join(chain.input_params.keys())}") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - elif sub_action == "reload": - # 重新加载工具链配置 - global _plugin_instance - if _plugin_instance: - _plugin_instance._load_tool_chains() - chains = tool_chain_manager.get_all_chains() - from src.plugin_system.core.component_registry import component_registry - - registered = 0 - for name, _chain in tool_chain_manager.get_enabled_chains().items(): - tool_name = f"chain_{name}".replace("-", "_").replace(".", "_") - if component_registry.get_component_info(tool_name, ComponentType.TOOL): - registered += 1 - lines = ["✅ 已重新加载工具链配置"] - lines.append(f"📋 配置数: {len(chains)} 个") - lines.append(f"🔧 已注册: {registered} 个(可被 LLM 调用)") - if chains: - lines.append("") - lines.append("工具链列表:") - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - lines.append(f" {status} chain_{name}") - await self.send_text("\n".join(lines)) - else: - await self.send_text("❌ 插件未初始化") - return (True, None, True) - - elif sub_action == "test" and len(parts) >= 2: - # 测试执行工具链 - chain_name = parts[1] - args_json = parts[2] if len(parts) > 2 else "{}" - - chain = tool_chain_manager.get_chain(chain_name) - if not chain: - await self.send_text(f"❌ 工具链 '{chain_name}' 不存在") - return (True, None, True) - - try: - input_args = json.loads(args_json) - except json.JSONDecodeError: - await self.send_text("❌ 参数 JSON 格式错误") - return (True, None, True) - - await self.send_text(f"🔄 正在执行工具链 {chain_name}...") - - result = await tool_chain_manager.execute_chain(chain_name, input_args) - - lines = [] - if result.success: - lines.append(f"✅ 工具链执行成功 ({result.total_duration_ms:.0f}ms)") - lines.append("") - lines.append("执行详情:") - lines.append(result.to_summary()) - lines.append("") - lines.append("最终输出:") - output_preview = result.final_output[:500] - if len(result.final_output) > 500: - output_preview += "..." - lines.append(output_preview) - else: - lines.append("❌ 工具链执行失败") - lines.append(f"错误: {result.error}") - if result.step_results: - lines.append("") - lines.append("执行详情:") - lines.append(result.to_summary()) - - await self.send_text("\n".join(lines)) - return (True, None, True) - - else: - # 查看特定工具链详情 - chain_name = sub_action - chain = tool_chain_manager.get_chain(chain_name) - - if not chain: - # 尝试模糊匹配 - all_chains = tool_chain_manager.get_all_chains() - similar = [n for n in all_chains.keys() if chain_name.lower() in n.lower()] - msg = f"❌ 工具链 '{chain_name}' 不存在" - if similar: - msg += f"\n💡 你是不是想找: {', '.join(similar[:3])}" - await self.send_text(msg) - return (True, None, True) - - lines = [f"🔗 工具链: {chain.name}"] - lines.append(f"状态: {'✅ 启用' if chain.enabled else '❌ 禁用'}") - lines.append(f"描述: {chain.description}") - lines.append("") - - if chain.input_params: - lines.append("📥 输入参数:") - for param, desc in chain.input_params.items(): - lines.append(f" • {param}: {desc}") - lines.append("") - - lines.append(f"📋 执行步骤 ({len(chain.steps)} 个):") - for i, step in enumerate(chain.steps): - optional_tag = " (可选)" if step.optional else "" - lines.append(f" {i + 1}. {step.tool_name}{optional_tag}") - if step.description: - lines.append(f" {step.description}") - if step.output_key: - lines.append(f" 输出键: {step.output_key}") - if step.args_template: - args_preview = json.dumps(step.args_template, ensure_ascii=False)[:60] - lines.append(f" 参数: {args_preview}...") - - lines.append("") - lines.append(f"💡 测试: /mcp chain test {chain.name} " + '{"参数": "值"}') - - await self.send_text("\n".join(lines)) - return (True, None, True) - - def _format_output(self, subcommand: str, server_name: str = None) -> str: - """格式化输出""" - status = mcp_manager.get_status() - stats = mcp_manager.get_all_stats() - lines = [] - - if subcommand in ("status", "all"): - lines.append("📊 MCP 桥接插件状态") - lines.append(f"├ 服务器: {status['connected_servers']}/{status['total_servers']} 已连接") - lines.append(f"├ 工具数: {status['total_tools']}") - lines.append(f"└ 心跳: {'运行中' if status['heartbeat_running'] else '已停止'}") - - if status["servers"]: - lines.append("\n🔌 服务器列表:") - for name, info in status["servers"].items(): - if server_name and name != server_name: - continue - icon = "✅" if info["connected"] else "❌" - enabled = "" if info["enabled"] else " (禁用)" - lines.append(f" {icon} {name}{enabled}") - lines.append(f" {info['transport']} | {info['tools_count']} 工具") - # 显示断路器状态 - cb = info.get("circuit_breaker", {}) - cb_state = cb.get("state", "closed") - if cb_state == "open": - lines.append(" ⚡ 断路器熔断中") - elif cb_state == "half_open": - lines.append(" ⚡ 断路器试探中") - if info["consecutive_failures"] > 0: - lines.append(f" ⚠️ 连续失败 {info['consecutive_failures']} 次") - - if subcommand in ("tools", "all"): - tools = mcp_manager.all_tools - if tools: - lines.append("\n🔧 可用工具:") - by_server = {} - for _key, (info, _) in tools.items(): - if server_name and info.server_name != server_name: - continue - by_server.setdefault(info.server_name, []).append(info.name) - - # 如果指定了服务器名,显示全部工具;否则折叠显示 - show_all = server_name is not None - - for srv, tool_list in by_server.items(): - lines.append(f" 📦 {srv} ({len(tool_list)})") - if show_all: - # 指定服务器时显示全部 - for t in tool_list: - lines.append(f" • {t}") - else: - # 未指定时折叠显示 - for t in tool_list[:5]: - lines.append(f" • {t}") - if len(tool_list) > 5: - lines.append(f" ... 还有 {len(tool_list) - 5} 个,用 /mcp tools {srv} 查看全部") - - if subcommand in ("stats", "all"): - g = stats["global"] - lines.append("\n📈 调用统计:") - lines.append(f" 总调用: {g['total_tool_calls']}") - if g["total_tool_calls"] > 0: - rate = (g["successful_calls"] / g["total_tool_calls"]) * 100 - lines.append(f" 成功率: {rate:.1f}%") - lines.append(f" 运行: {g['uptime_seconds']:.0f}秒") - - if not lines: - lines.append("📖 MCP 桥接插件命令帮助") - lines.append("") - lines.append("状态查询:") - lines.append(" /mcp 查看连接状态") - lines.append(" /mcp tools 查看所有工具") - lines.append(" /mcp tools <服务器> 查看指定服务器工具") - lines.append(" /mcp stats 查看调用统计") - lines.append("") - lines.append("工具搜索:") - lines.append(" /mcp search <关键词> 搜索工具") - lines.append(" /mcp search * 列出所有工具") - lines.append("") - lines.append("服务器管理:") - lines.append(" /mcp reconnect 重连断开的服务器") - lines.append(" /mcp reconnect <名称> 重连指定服务器") - lines.append("") - lines.append("服务器配置(Claude):") - lines.append(" /mcp import 合并 Claude mcpServers 配置") - lines.append(" /mcp export 导出当前 mcpServers 配置") - lines.append("") - lines.append("工具链:") - lines.append(" /mcp chain 查看工具链列表") - lines.append(" /mcp chain <名称> 查看工具链详情") - lines.append(" /mcp chain test <名称> <参数> 测试执行") - lines.append("") - lines.append("其他:") - lines.append(" /mcp trace 查看调用追踪") - lines.append(" /mcp cache 查看缓存状态") - lines.append(" /mcp perm 查看权限配置") - - return "\n".join(lines) - - -class MCPImportCommand(BaseCommand): - """v1.6.0: MCP 配置导入命令 - 支持从 Claude Desktop 格式导入""" - - command_name = "mcp_import_command" - command_description = "从 Claude Desktop 或其他格式导入 MCP 服务器配置" - # 匹配 /mcp import 后面的所有内容(包括多行 JSON) - command_pattern = r"^[//]mcp\s+import(?:\s+(?P.+))?$" - - async def execute(self) -> Tuple[bool, Optional[str], bool]: - """执行导入命令""" - global _plugin_instance - - if _plugin_instance is None: - await self.send_text("❌ 插件未初始化") - return (True, None, True) - - content = self.matched_groups.get("content", "") - - if not content or not content.strip(): - # 显示使用帮助 - help_text = """📥 MCP 配置导入 - -用法: /mcp import - -支持的格式: -• Claude Desktop 格式 (mcpServers 对象) -• 兼容旧版:MaiBot servers 列表数组(将自动迁移为 mcpServers) - -示例: -/mcp import {"mcpServers":{"time":{"command":"uvx","args":["mcp-server-time"]}}} - -/mcp import {"mcpServers":{"api":{"url":"https://example.com/mcp","transport":"sse"}}}""" - await self.send_text(help_text) - return (True, None, True) - - raw_text = content.strip() - - # 解析输入:支持 Claude mcpServers 或旧版 servers 列表数组 - try: - data = json.loads(raw_text) - except json.JSONDecodeError as e: - await self.send_text(f"❌ JSON 解析失败: {e}") - return (True, None, True) - - if isinstance(data, list): - migrated = legacy_servers_list_to_claude_config(raw_text) - if not migrated: - await self.send_text("❌ 旧版 servers 列表解析失败,无法迁移") - return (True, None, True) - data = json.loads(migrated) - - if not isinstance(data, dict): - await self.send_text("❌ 配置必须是 JSON 对象(包含 mcpServers)") - return (True, None, True) - - incoming_mapping = data.get("mcpServers", data) - if not isinstance(incoming_mapping, dict): - await self.send_text("❌ mcpServers 必须是 JSON 对象") - return (True, None, True) - - # 校验输入配置 - try: - parse_claude_mcp_config(json.dumps({"mcpServers": incoming_mapping}, ensure_ascii=False)) - except ClaudeConfigError as e: - await self.send_text(f"❌ 配置校验失败: {e}") - return (True, None, True) - - servers_section = _plugin_instance.config.get("servers", {}) - if not isinstance(servers_section, dict): - servers_section = {} - - existing_json = str(servers_section.get("claude_config_json", "") or "") - if not existing_json.strip(): - legacy_list = str(servers_section.get("list", "") or "") - existing_json = legacy_servers_list_to_claude_config(legacy_list) or "" - - existing_mapping: Dict[str, Any] = {} - if existing_json.strip(): - try: - parsed = json.loads(existing_json) - mapping = parsed.get("mcpServers", parsed) - if isinstance(mapping, dict): - existing_mapping = mapping - except Exception: - existing_mapping = {} - - added: List[str] = [] - skipped: List[str] = [] - - for name, conf in incoming_mapping.items(): - if name in existing_mapping: - skipped.append(str(name)) - continue - existing_mapping[str(name)] = conf - added.append(str(name)) - - if "servers" not in _plugin_instance.config: - _plugin_instance.config["servers"] = {} - - _plugin_instance.config["servers"]["claude_config_json"] = json.dumps( - {"mcpServers": existing_mapping}, ensure_ascii=False, indent=2 - ) - - # 持久化到配置文件(使用插件基类的写入逻辑) - try: - config_path = Path(_plugin_instance.plugin_dir) / _plugin_instance.config_file_name - _plugin_instance._save_config_to_file(_plugin_instance.config, str(config_path)) - except Exception as e: - logger.warning(f"保存配置文件失败: {e}") - - lines = [] - if added: - lines.append(f"✅ 成功导入 {len(added)} 个服务器:") - for n in added[:20]: - lines.append(f" • {n}") - if len(added) > 20: - lines.append(f" ... 还有 {len(added) - 20} 个") - else: - lines.append("⚠️ 没有新服务器可导入") - - if skipped: - lines.append(f"\n⏭️ 跳过 {len(skipped)} 个已存在的服务器") - - lines.append("\n💡 发送 /mcp reconnect 使配置生效") - - await self.send_text("\n".join(lines)) - return (True, None, True) - - -# ============================================================================ -# 事件处理器 -# ============================================================================ - - -class MCPStartupHandler(BaseEventHandler): - """MCP 启动事件处理器""" - - event_type = EventType.ON_START - handler_name = "mcp_startup_handler" - handler_description = "MCP 桥接插件启动处理器" - weight = 0 - intercept_message = False - - async def execute(self, message: Optional[Any]) -> Tuple[bool, bool, Optional[str], None, None]: - """处理启动事件""" - global _plugin_instance - - if _plugin_instance is None: - logger.warning("MCP 桥接插件实例未初始化") - return (False, True, None, None, None) - - logger.info("MCP 桥接插件收到 ON_START 事件,开始连接 MCP 服务器...") - await _plugin_instance._async_connect_servers() - - await mcp_manager.start_heartbeat() - - return (True, True, None, None, None) - - -class MCPStopHandler(BaseEventHandler): - """MCP 停止事件处理器""" - - event_type = EventType.ON_STOP - handler_name = "mcp_stop_handler" - handler_description = "MCP 桥接插件停止处理器" - weight = 0 - intercept_message = False - - async def execute(self, message: Optional[Any]) -> Tuple[bool, bool, Optional[str], None, None]: - """处理停止事件""" - global _plugin_instance - - logger.info("MCP 桥接插件收到 ON_STOP 事件,正在关闭...") - - if _plugin_instance is not None: - await _plugin_instance._stop_status_refresher() - - await mcp_manager.shutdown() - mcp_tool_registry.clear() - - logger.info("MCP 桥接插件已关闭所有连接") - return (True, True, None, None, None) - - -# ============================================================================ -# 主插件类 -# ============================================================================ - - -@register_plugin -class MCPBridgePlugin(BasePlugin): - """MCP 桥接插件 v2.0.0 - 将 MCP 服务器的工具桥接到 MaiBot""" - - plugin_name: str = "mcp_bridge_plugin" - enable_plugin: bool = False # 默认禁用,用户需在 WebUI 手动启用 - dependencies: List[str] = [] - python_dependencies: List[str] = ["mcp"] - config_file_name: str = "config.toml" - - config_section_descriptions = { - "guide": section_meta("📖 快速入门", order=1), - "plugin": section_meta("🔘 插件开关", order=2), - "servers": section_meta("🔌 MCP Servers(Claude)", order=3), - "tool_chains": section_meta("🔗 Workflow(硬流程/工具链)", order=4), - "react": section_meta("🔄 ReAct(软流程)", collapsed=True, order=5), - "status": section_meta("📊 运行状态", order=10), - "tools": section_meta("🔧 工具管理", collapsed=True, order=20), - "permissions": section_meta("🔐 权限控制", collapsed=True, order=21), - "settings": section_meta("⚙️ 高级设置", collapsed=True, order=30), - } - - config_schema: dict = { - # 新手引导区(只读) - "guide": { - "quick_start": ConfigField( - type=str, - default="1. 获取 MCP 服务器 2. 在「MCP Servers(Claude)」粘贴 mcpServers 配置 3. 保存后发送 /mcp reconnect 4. (可选)在「Workflow/ ReAct」配置流程", - description="三步开始使用", - label="🚀 快速入门", - disabled=True, - order=1, - ), - "mcp_sources": ConfigField( - type=str, - default="https://modelscope.cn/mcp (魔搭·推荐) | https://smithery.ai | https://glama.ai | https://mcp.so", - description="复制链接到浏览器打开,获取免费 MCP 服务器", - label="🌐 获取 MCP 服务器", - disabled=True, - hint="魔搭 ModelScope 国内免费推荐,将 mcpServers 配置粘贴到「MCP Servers(Claude)」即可", - order=2, - ), - "example_config": ConfigField( - type=str, - default='{"mcpServers":{"time":{"url":"https://mcp.api-inference.modelscope.cn/server/mcp-server-time"}}}', - description="复制到 MCP Servers(Claude)可直接使用(免费时间服务器)", - label="📝 配置示例", - disabled=True, - order=3, - ), - }, - "plugin": { - "enabled": ConfigField( - type=bool, - default=False, - description="是否启用插件(默认关闭)", - label="启用插件", - ), - }, - "settings": { - "tool_prefix": ConfigField( - type=str, - default="mcp", - description="🏷️ 工具前缀 - 生成的工具名格式: {前缀}_{服务器名}_{工具名}", - label="🏷️ 工具前缀", - placeholder="mcp", - order=1, - ), - "connect_timeout": ConfigField( - type=float, - default=30.0, - description="⏱️ 连接超时(秒)", - label="⏱️ 连接超时(秒)", - min=5.0, - max=120.0, - step=5.0, - order=2, - ), - "call_timeout": ConfigField( - type=float, - default=60.0, - description="⏱️ 调用超时(秒)", - label="⏱️ 调用超时(秒)", - min=10.0, - max=300.0, - step=10.0, - order=3, - ), - "auto_connect": ConfigField( - type=bool, - default=True, - description="🔄 启动时自动连接所有已启用的服务器", - label="🔄 自动连接", - order=4, - ), - "retry_attempts": ConfigField( - type=int, - default=3, - description="🔁 连接失败时的重试次数", - label="🔁 重试次数", - min=0, - max=10, - order=5, - ), - "retry_interval": ConfigField( - type=float, - default=5.0, - description="⏳ 重试间隔(秒)", - label="⏳ 重试间隔(秒)", - min=1.0, - max=60.0, - step=1.0, - order=6, - ), - "heartbeat_enabled": ConfigField( - type=bool, - default=True, - description="💓 定期检测服务器连接状态", - label="💓 启用心跳检测", - order=7, - ), - "heartbeat_interval": ConfigField( - type=float, - default=60.0, - description="💓 基准心跳间隔(秒)", - label="💓 心跳间隔(秒)", - min=10.0, - max=300.0, - step=10.0, - hint="智能心跳会根据服务器稳定性自动调整", - order=8, - ), - "heartbeat_adaptive": ConfigField( - type=bool, - default=True, - description="🧠 根据服务器稳定性自动调整心跳间隔", - label="🧠 智能心跳", - hint="稳定服务器逐渐增加间隔,断开的服务器缩短间隔", - order=9, - ), - "heartbeat_max_multiplier": ConfigField( - type=float, - default=3.0, - description="稳定服务器的最大间隔倍数", - label="📈 最大间隔倍数", - min=1.5, - max=5.0, - step=0.5, - hint="稳定服务器心跳间隔最高可达 基准间隔 × 此值", - order=10, - ), - "auto_reconnect": ConfigField( - type=bool, - default=True, - description="🔄 检测到断开时自动尝试重连", - label="🔄 自动重连", - order=11, - ), - "max_reconnect_attempts": ConfigField( - type=int, - default=3, - description="🔄 连续重连失败后暂停重连", - label="🔄 最大重连次数", - min=1, - max=10, - order=12, - ), - # v1.7.0: 状态刷新配置 - "status_refresh_enabled": ConfigField( - type=bool, - default=True, - description="📊 定期更新 WebUI 状态显示", - label="📊 启用状态实时刷新", - hint="关闭后 WebUI 状态仅在启动时更新", - order=13, - ), - "status_refresh_interval": ConfigField( - type=float, - default=10.0, - description="📊 状态刷新间隔(秒)", - label="📊 状态刷新间隔(秒)", - min=5.0, - max=60.0, - step=5.0, - hint="值越小刷新越频繁,但会增加少量 CPU 消耗", - order=14, - ), - "enable_resources": ConfigField( - type=bool, - default=False, - description="📦 允许读取 MCP 服务器提供的资源", - label="📦 启用 Resources(实验性)", - order=11, - ), - "enable_prompts": ConfigField( - type=bool, - default=False, - description="📝 允许使用 MCP 服务器提供的提示模板", - label="📝 启用 Prompts(实验性)", - order=12, - ), - # v1.3.0 后处理配置 - "post_process_enabled": ConfigField( - type=bool, - default=False, - description="🔄 使用 LLM 对长结果进行摘要提炼", - label="🔄 启用结果后处理", - order=20, - ), - "post_process_threshold": ConfigField( - type=int, - default=500, - description="📏 结果长度超过此值才触发后处理", - label="📏 后处理阈值(字符)", - min=100, - max=5000, - step=100, - order=21, - ), - "post_process_max_tokens": ConfigField( - type=int, - default=500, - description="📝 LLM 摘要输出的最大 token 数", - label="📝 后处理最大输出 token", - min=100, - max=2000, - step=50, - order=22, - ), - "post_process_model": ConfigField( - type=str, - default="", - description="🤖 指定用于后处理的模型名称", - label="🤖 后处理模型(可选)", - placeholder="留空则使用 Utils 模型组", - order=23, - ), - "post_process_prompt": ConfigField( - type=str, - default="用户问题:{query}\\n\\n工具返回内容:\\n{result}\\n\\n请从上述内容中提取与用户问题最相关的关键信息,简洁准确地输出:", - description="📋 后处理提示词模板", - label="📋 后处理提示词模板", - input_type="textarea", - rows=8, - order=24, - ), - # v1.4.0 追踪配置 - "trace_enabled": ConfigField( - type=bool, - default=True, - description="🔍 记录工具调用详情", - label="🔍 启用调用追踪", - order=30, - ), - "trace_max_records": ConfigField( - type=int, - default=100, - description="内存中保留的最大记录数", - label="📊 追踪记录上限", - min=10, - max=1000, - order=31, - ), - "trace_log_enabled": ConfigField( - type=bool, - default=False, - description="是否将追踪记录写入日志文件", - label="📝 追踪日志文件", - hint="启用后记录写入 plugins/MaiBot_MCPBridgePlugin/logs/trace.jsonl", - order=32, - ), - # v1.4.0 缓存配置 - "cache_enabled": ConfigField( - type=bool, - default=False, - description="🗄️ 缓存相同参数的调用结果", - label="🗄️ 启用调用缓存", - hint="相同参数的调用会返回缓存结果,减少重复请求", - order=40, - ), - "cache_ttl": ConfigField( - type=int, - default=300, - description="缓存有效期(秒)", - label="⏱️ 缓存有效期(秒)", - min=60, - max=3600, - order=41, - ), - "cache_max_entries": ConfigField( - type=int, - default=200, - description="最大缓存条目数(超出后 LRU 淘汰)", - label="📦 最大缓存条目", - min=50, - max=1000, - order=42, - ), - "cache_exclude_tools": ConfigField( - type=str, - default="", - description="不缓存的工具(每行一个,支持通配符 *)", - label="🚫 缓存排除列表", - input_type="textarea", - rows=4, - hint="时间类、随机类工具建议排除,如 mcp_time_*", - order=43, - ), - }, - # v1.4.0 工具管理 - "tools": { - "tool_list": ConfigField( - type=str, - default="(启动后自动生成)", - description="当前已注册的 MCP 工具列表(只读)", - label="📋 工具清单", - input_type="textarea", - disabled=True, - rows=12, - hint="从此处复制工具名到下方禁用列表或工具链配置", - order=1, - ), - "disabled_tools": ConfigField( - type=str, - default="", - description="要禁用的工具名(每行一个)", - label="🚫 禁用工具列表", - input_type="textarea", - rows=6, - hint="从上方工具清单复制工具名,每行一个。禁用后该工具不会被 LLM 调用", - order=2, - ), - }, - # v1.8.0 工具链配置 - "tool_chains": { - "chains_enabled": ConfigField( - type=bool, - default=True, - description="🔗 启用工具链功能", - label="🔗 启用工具链", - hint="工具链可将多个工具按顺序执行,后续工具可使用前序工具的输出", - order=1, - ), - # 工具链使用指南 - "chains_guide": ConfigField( - type=str, - default="""工具链将多个 MCP 工具串联执行,后续步骤可使用前序步骤的输出 - -📌 变量语法: - ${input.参数名} - 用户输入的参数 - ${step.输出键} - 某步骤的输出(需设置 output_key) - ${prev} - 上一步的输出 - ${prev.字段} - 上一步输出(JSON)的某字段 - ${step.输出键.0.字段} / ${step.输出键[0].字段} - 访问数组下标 - ${step.输出键['return'][0]['location']} - 支持 bracket 写法 - -📌 测试命令: - /mcp chain list - 查看所有工具链 - /mcp chain 链名 {"参数":"值"} - 测试执行""", - description="工具链使用说明", - label="📖 使用指南", - input_type="textarea", - disabled=True, - rows=10, - order=2, - ), - # 快速添加工具链(表单式) - "quick_chain_name": ConfigField( - type=str, - default="", - description="工具链名称(英文,如 search_and_summarize)", - label="➕ 快速添加 - 名称", - placeholder="my_tool_chain", - hint="必填,将作为 LLM 可调用的工具名", - order=10, - ), - "quick_chain_desc": ConfigField( - type=str, - default="", - description="工具链描述(供 LLM 理解何时使用)", - label="➕ 快速添加 - 描述", - placeholder="先搜索内容,再获取详情并总结", - hint="必填,清晰描述工具链的用途", - order=11, - ), - "quick_chain_params": ConfigField( - type=str, - default="", - description="输入参数(每行一个,格式: 参数名=描述)", - label="➕ 快速添加 - 输入参数", - input_type="textarea", - rows=3, - placeholder="query=搜索关键词\nmax_results=最大结果数", - hint="定义用户需要提供的参数", - order=12, - ), - "quick_chain_steps": ConfigField( - type=str, - default="", - description="执行步骤(每行一个,格式: 工具名|参数JSON|输出键)", - label="➕ 快速添加 - 执行步骤", - input_type="textarea", - rows=5, - placeholder='mcp_server_search|{"keyword":"${input.query}"}|search_result\nmcp_server_detail|{"id":"${prev}"}|\n# 访问数组示例:\n# mcp_geo|{"q":"${input.query}"}|geo\n# mcp_next|{"location":"${step.geo.return.0.location}"}|', - hint="格式: 工具名|参数模板|输出键(输出键可选,用于后续步骤引用 ${step.xxx})", - order=13, - ), - "quick_chain_add": ConfigField( - type=str, - default="", - description="填写上方信息后,在此输入 ADD 并保存即可添加", - label="➕ 确认添加", - placeholder="输入 ADD 并保存", - hint="添加后会自动合并到下方工具链列表", - order=14, - ), - # 工具链模板 - "chains_templates": ConfigField( - type=str, - default="""📋 常用工具链模板(复制到下方列表使用): - -1️⃣ 搜索+详情模板: -{ - "name": "search_and_detail", - "description": "搜索内容并获取详情", - "input_params": {"query": "搜索关键词"}, - "steps": [ - {"tool_name": "搜索工具名", "args_template": {"keyword": "${input.query}"}, "output_key": "results"}, - {"tool_name": "详情工具名", "args_template": {"id": "${prev}"}} - ] -} - -2️⃣ 获取+处理模板: -{ - "name": "fetch_and_process", - "description": "获取数据并处理", - "input_params": {"url": "目标URL"}, - "steps": [ - {"tool_name": "获取工具名", "args_template": {"url": "${input.url}"}, "output_key": "data"}, - {"tool_name": "处理工具名", "args_template": {"content": "${step.data}"}} - ] -} - -3️⃣ 多步骤可选模板: -{ - "name": "multi_step_chain", - "description": "多步骤处理,部分可选", - "input_params": {"input": "输入内容"}, - "steps": [ - {"tool_name": "步骤1工具", "args_template": {"data": "${input.input}"}, "output_key": "step1"}, - {"tool_name": "步骤2工具", "args_template": {"data": "${prev}"}, "output_key": "step2", "optional": true}, - {"tool_name": "步骤3工具", "args_template": {"data": "${step.step1}"}} - ] -}""", - description="工具链配置模板参考", - label="📝 配置模板", - input_type="textarea", - disabled=True, - rows=15, - order=20, - ), - "chains_list": ConfigField( - type=str, - default="[]", - description="工具链配置(JSON 数组格式)", - label="📋 工具链列表", - input_type="textarea", - rows=20, - placeholder="""[ - { - "name": "search_and_detail", - "description": "先搜索再获取详情", - "input_params": {"query": "搜索关键词"}, - "steps": [ - {"tool_name": "mcp_server_search", "args_template": {"keyword": "${input.query}"}, "output_key": "search_result"}, - {"tool_name": "mcp_server_get_detail", "args_template": {"id": "${step.search_result}"}} - ] - } -]""", - hint="每个工具链包含 name、description、input_params、steps", - order=30, - ), - "chains_status": ConfigField( - type=str, - default="(启动后自动生成)", - description="当前已注册的工具链状态(只读)", - label="📊 工具链状态", - input_type="textarea", - disabled=True, - rows=8, - order=40, - ), - }, - # v1.9.0 ReAct 软流程配置 - "react": { - "react_enabled": ConfigField( - type=bool, - default=False, - description="🔄 将 MCP 工具注册到记忆检索 ReAct 系统", - label="🔄 启用 ReAct 集成", - hint="启用后,MaiBot 的 ReAct Agent 可在记忆检索时调用 MCP 工具", - order=1, - ), - "react_guide": ConfigField( - type=str, - default="""ReAct 软流程说明: - -📌 什么是 ReAct? -ReAct (Reasoning + Acting) 是 LLM 自主决策的多轮工具调用模式。 -与 Workflow 硬流程不同,ReAct 由 LLM 动态决定调用哪些工具。 - -📌 工作原理: -1. 用户提问 → LLM 分析需要什么信息 -2. LLM 选择调用工具 → 获取结果 -3. LLM 观察结果 → 决定是否需要更多信息 -4. 重复 2-3 直到信息足够 → 生成最终回答 - -📌 与 Workflow 的区别: -- ReAct (软流程): LLM 自主决策,灵活但不可预测 -- Workflow (硬流程): 用户预定义,固定流程,可靠可控 - -📌 使用场景: -- 复杂问题需要多步推理 -- 不确定需要调用哪些工具 -- 需要根据中间结果动态调整""", - description="ReAct 软流程使用说明", - label="📖 使用指南", - input_type="textarea", - disabled=True, - rows=15, - order=2, - ), - "filter_mode": ConfigField( - type=str, - default="whitelist", - description="过滤模式", - label="📋 过滤模式", - choices=["whitelist", "blacklist"], - hint="whitelist: 只注册列出的工具;blacklist: 排除列出的工具", - order=3, - ), - "tool_filter": ConfigField( - type=str, - default="", - description="工具过滤列表(每行一个,支持通配符 * 和精确匹配)", - label="🔍 工具过滤列表", - input_type="textarea", - rows=6, - placeholder="""# 精确匹配示例: -mcp_bing_web_search_bing_search -mcp_mcmod_search_mod - -# 通配符示例: -mcp_*_search_* -mcp_bing_*""", - hint="白名单模式: 只注册列出的工具;黑名单模式: 排除列出的工具。支持 # 注释", - order=4, - ), - "react_status": ConfigField( - type=str, - default="(启动后自动生成)", - description="当前已注册到 ReAct 的工具状态(只读)", - label="📊 ReAct 工具状态", - input_type="textarea", - disabled=True, - rows=6, - order=10, - ), - }, - # v1.4.0 权限控制 - "permissions": { - "perm_enabled": ConfigField( - type=bool, - default=False, - description="🔐 按群/用户限制工具使用", - label="🔐 启用权限控制", - order=1, - ), - "perm_default_mode": ConfigField( - type=str, - default="allow_all", - description="默认模式:allow_all(默认允许)或 deny_all(默认禁止)", - label="📋 默认模式", - placeholder="allow_all", - hint="allow_all: 未配置的默认允许;deny_all: 未配置的默认禁止", - order=2, - ), - # 快捷配置(简化版) - "quick_deny_groups": ConfigField( - type=str, - default="", - description="禁止使用所有 MCP 工具的群号(每行一个)", - label="🚫 禁用群列表(快捷)", - input_type="textarea", - rows=4, - hint="填入群号,该群将无法使用任何 MCP 工具", - order=3, - ), - "quick_allow_users": ConfigField( - type=str, - default="", - description="始终允许使用所有工具的用户 QQ 号(管理员白名单,每行一个)", - label="✅ 管理员白名单(快捷)", - input_type="textarea", - rows=3, - hint="填入 QQ 号,该用户在任何场景都可使用 MCP 工具", - order=4, - ), - # 高级配置 - "perm_rules": ConfigField( - type=str, - default="[]", - description="高级权限规则(JSON 格式,可针对特定工具配置)", - label="📜 高级权限规则(可选)", - input_type="textarea", - rows=10, - placeholder="""[ - {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]} -]""", - hint="格式: qq:ID:group/private/user,工具名支持通配符 *", - order=10, - ), - }, - # v2.0: 服务器配置统一为 Claude Desktop mcpServers 规范(JSON) - "servers": { - "claude_config_json": ConfigField( - type=str, - default='{"mcpServers":{}}', - description="Claude Desktop 规范的 MCP 配置(JSON)", - label="🔌 MCP Servers(Claude 规范)", - input_type="textarea", - rows=18, - hint="仅支持 Claude Desktop 的 mcpServers JSON。每个服务器需包含 command(stdio) 或 url(remote)。", - order=1, - ), - "claude_config_guide": ConfigField( - type=str, - default="""示例: -{ - "mcpServers": { - "fetch": { "command": "uvx", "args": ["mcp-server-fetch"] }, - "time": { "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time" } - } -} - -可选字段: -- enabled: true/false -- headers: {"Authorization":"Bearer ..."} -- env: {"KEY":"VALUE"} -- transport/type: "streamable_http" | "http" | "sse"(remote 可选,默认 streamable_http) -""", - description="配置说明(只读)", - label="📖 配置说明", - input_type="textarea", - disabled=True, - rows=12, - order=2, - ), - }, - "status": { - "connection_status": ConfigField( - type=str, - default="未初始化", - description="当前 MCP 服务器连接状态和工具列表", - label="📊 连接状态", - input_type="textarea", - disabled=True, - rows=15, - hint="此状态仅在插件启动时更新。查询实时状态请发送 /mcp 命令", - order=1, - ), - }, - } - - @staticmethod - def _fix_config_multiline_strings(config_path: Path) -> bool: - """修复配置文件中的多行字符串格式问题 - - 处理两种情况: - 1. 带转义 \\n 的单行字符串(json.dumps 生成) - 2. 跨越多行但使用普通双引号的字符串(控制字符错误) - - Returns: - bool: 是否进行了修复 - """ - if not config_path.exists(): - return False - - try: - content = config_path.read_text(encoding="utf-8") - - # 情况1: 修复带转义 \n 的单行字符串 - # 匹配: key = "内容包含\n的字符串" - pattern1 = r'^(\s*\w+\s*=\s*)"((?:[^"\\]|\\.)*\\n(?:[^"\\]|\\.)*)"(\s*)$' - - # 情况2: 修复跨越多行的普通双引号字符串 - # 匹配: key = "第一行 - # 第二行 - # 第三行" - pattern2_start = r'^(\s*\w+\s*=\s*)"([^"]*?)$' # 开始行 - pattern2_end = r'^([^"]*)"(\s*)$' # 结束行 - - lines = content.split("\n") - fixed_lines = [] - modified = False - - i = 0 - while i < len(lines): - line = lines[i] - - # 情况1: 单行带转义换行符 - match1 = re.match(pattern1, line) - if match1: - prefix = match1.group(1) - value = match1.group(2) - suffix = match1.group(3) - # 将转义的换行符还原为实际换行符 - unescaped = ( - value.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"').replace("\\\\", "\\") - ) - fixed_line = f'{prefix}"""{unescaped}"""{suffix}' - fixed_lines.append(fixed_line) - modified = True - i += 1 - continue - - # 情况2: 跨越多行的字符串 - match2_start = re.match(pattern2_start, line) - if match2_start: - prefix = match2_start.group(1) - first_part = match2_start.group(2) - - # 收集后续行直到找到结束引号 - multiline_parts = [first_part] - j = i + 1 - found_end = False - - while j < len(lines): - next_line = lines[j] - match2_end = re.match(pattern2_end, next_line) - if match2_end: - multiline_parts.append(match2_end.group(1)) - suffix = match2_end.group(2) - found_end = True - j += 1 - break - else: - multiline_parts.append(next_line) - j += 1 - - if found_end and len(multiline_parts) > 1: - # 合并为三引号字符串 - full_value = "\n".join(multiline_parts) - fixed_line = f'{prefix}"""{full_value}"""{suffix}' - fixed_lines.append(fixed_line) - modified = True - i = j - continue - - fixed_lines.append(line) - i += 1 - - if modified: - config_path.write_text("\n".join(fixed_lines), encoding="utf-8") - logger.info("已自动修复配置文件中的多行字符串格式") - return True - - return False - except Exception as e: - logger.warning(f"修复配置文件格式失败: {e}") - return False - - def __init__(self, *args, **kwargs): - global _plugin_instance - - # 在父类初始化前尝试修复配置文件格式 - config_path = Path(__file__).parent / "config.toml" - self._fix_config_multiline_strings(config_path) - - super().__init__(*args, **kwargs) - self._initialized = False - self._status_refresh_running = False - self._status_refresh_task: Optional[asyncio.Task] = None - self._last_persisted_display_hash: str = "" - self._last_servers_config_error: str = "" - _plugin_instance = self - - # 配置 MCP 管理器 - settings = self.config.get("settings", {}) - mcp_manager.configure(settings) - - # v1.4.0: 配置追踪器 - trace_log_path = Path(__file__).parent / "logs" / "trace.jsonl" - tool_call_tracer.configure( - enabled=settings.get("trace_enabled", True), - max_records=settings.get("trace_max_records", 100), - log_enabled=settings.get("trace_log_enabled", False), - log_path=trace_log_path, - ) - - # v1.4.0: 配置缓存 - tool_call_cache.configure( - enabled=settings.get("cache_enabled", False), - ttl=settings.get("cache_ttl", 300), - max_entries=settings.get("cache_max_entries", 200), - exclude_tools=settings.get("cache_exclude_tools", ""), - ) - - # v1.4.0: 配置权限检查器 - perm_config = self.config.get("permissions", {}) - permission_checker.configure( - enabled=perm_config.get("perm_enabled", False), - default_mode=perm_config.get("perm_default_mode", "allow_all"), - rules_json=perm_config.get("perm_rules", "[]"), - quick_deny_groups=perm_config.get("quick_deny_groups", ""), - quick_allow_users=perm_config.get("quick_allow_users", ""), - ) - - # 注册状态变化回调 - mcp_manager.set_status_change_callback(self._update_status_display) - - # v2.0: 服务器配置统一由 servers.claude_config_json 提供(不再通过 WebUI 导入/快速添加写入旧 servers.list) - - # v1.8.0: 初始化工具链管理器 - tool_chain_manager.set_executor(mcp_manager) - self._load_tool_chains() - - def _persist_runtime_displays(self) -> None: - """将 WebUI 只读展示字段写回配置文件,使 WebUI 能正确显示运行状态。""" - try: - config_path = Path(self.plugin_dir) / self.config_file_name - - payload = { - "status.connection_status": str(self.config.get("status", {}).get("connection_status", "") or ""), - "tools.tool_list": str(self.config.get("tools", {}).get("tool_list", "") or ""), - "tool_chains.chains_status": str(self.config.get("tool_chains", {}).get("chains_status", "") or ""), - "react.react_status": str(self.config.get("react", {}).get("react_status", "") or ""), - } - digest = hashlib.sha256(json.dumps(payload, ensure_ascii=False).encode("utf-8")).hexdigest() - if digest == self._last_persisted_display_hash: - return - - self._save_config_to_file(self.config, str(config_path)) - self._last_persisted_display_hash = digest - except Exception as e: - logger.debug(f"写回运行状态到配置文件失败: {e}") - - def _process_quick_add_chain(self) -> None: - """v1.8.0: 处理快速添加工具链表单""" - chains_config = self.config.get("tool_chains", {}) - - # 检查是否触发添加 - add_trigger = chains_config.get("quick_chain_add", "").strip().upper() - if add_trigger != "ADD": - return - - # 获取表单数据 - chain_name = chains_config.get("quick_chain_name", "").strip() - chain_desc = chains_config.get("quick_chain_desc", "").strip() - params_str = chains_config.get("quick_chain_params", "").strip() - steps_str = chains_config.get("quick_chain_steps", "").strip() - - # 验证必填字段 - if not chain_name: - logger.warning("快速添加工具链: 名称不能为空") - self._clear_quick_chain_fields() - return - - if not chain_desc: - logger.warning("快速添加工具链: 描述不能为空") - self._clear_quick_chain_fields() - return - - if not steps_str: - logger.warning("快速添加工具链: 步骤不能为空") - self._clear_quick_chain_fields() - return - - # 解析输入参数 - input_params = {} - if params_str: - for line in params_str.split("\n"): - line = line.strip() - if not line or "=" not in line: - continue - parts = line.split("=", 1) - param_name = parts[0].strip() - param_desc = parts[1].strip() if len(parts) > 1 else param_name - input_params[param_name] = param_desc - - # 解析步骤 - steps = [] - for line in steps_str.split("\n"): - line = line.strip() - if not line: - continue - - parts = line.split("|") - if len(parts) < 2: - logger.warning(f"快速添加工具链: 步骤格式错误: {line}") - continue - - tool_name = parts[0].strip() - args_str = parts[1].strip() if len(parts) > 1 else "{}" - output_key = parts[2].strip() if len(parts) > 2 else "" - - # 解析参数 JSON - try: - args_template = json.loads(args_str) if args_str else {} - except json.JSONDecodeError: - logger.warning(f"快速添加工具链: 参数 JSON 格式错误: {args_str}") - args_template = {} - - steps.append( - { - "tool_name": tool_name, - "args_template": args_template, - "output_key": output_key, - } - ) - - if not steps: - logger.warning("快速添加工具链: 没有有效的步骤") - self._clear_quick_chain_fields() - return - - # 构建新工具链 - new_chain = { - "name": chain_name, - "description": chain_desc, - "input_params": input_params, - "steps": steps, - "enabled": True, - } - - # 获取现有工具链列表 - chains_json = chains_config.get("chains_list", "[]") - try: - chains_list = json.loads(chains_json) if chains_json.strip() else [] - except json.JSONDecodeError: - chains_list = [] - - # 检查是否已存在同名工具链 - for existing in chains_list: - if existing.get("name") == chain_name: - logger.info(f"快速添加: 工具链 {chain_name} 已存在,将更新") - chains_list.remove(existing) - break - - # 添加新工具链 - chains_list.append(new_chain) - new_chains_json = json.dumps(chains_list, ensure_ascii=False, indent=2) - - # 更新配置 - self.config["tool_chains"]["chains_list"] = new_chains_json - - # 清空表单字段 - self._clear_quick_chain_fields() - - # 保存到配置文件 - self._save_chains_list(new_chains_json) - - logger.info(f"快速添加: 已添加工具链 {chain_name} ({len(steps)} 个步骤)") - - def _clear_quick_chain_fields(self) -> None: - """清空快速添加工具链表单字段""" - if "tool_chains" not in self.config: - self.config["tool_chains"] = {} - self.config["tool_chains"]["quick_chain_name"] = "" - self.config["tool_chains"]["quick_chain_desc"] = "" - self.config["tool_chains"]["quick_chain_params"] = "" - self.config["tool_chains"]["quick_chain_steps"] = "" - self.config["tool_chains"]["quick_chain_add"] = "" - - def _save_chains_list(self, chains_json: str) -> None: - """保存工具链列表到配置文件""" - try: - config_path = Path(self.plugin_dir) / self.config_file_name - self._save_config_to_file(self.config, str(config_path)) - logger.info("工具链列表已保存到配置文件") - except Exception as e: - logger.warning(f"保存工具链列表失败: {e}") - - def _load_tool_chains(self) -> None: - """v1.8.0: 加载工具链配置""" - # 先处理快速添加 - self._process_quick_add_chain() - - chains_config = self.config.get("tool_chains", {}) - if not isinstance(chains_config, dict): - chains_config = {} - - # 兼容旧版本:部分版本可能使用 tool_chain 或其他字段名 - if not chains_config: - legacy_section = self.config.get("tool_chain") - if isinstance(legacy_section, dict): - chains_config = legacy_section - self.config["tool_chains"] = legacy_section - - # 兼容旧版本:chains_list 字段名变化 - chains_json = str(chains_config.get("chains_list", "") or "") - if not chains_json.strip(): - for legacy_key in ("list", "chains", "workflow_list", "workflows", "toolchains"): - legacy_val = chains_config.get(legacy_key) - if legacy_val is None: - continue - - if isinstance(legacy_val, str) and legacy_val.strip(): - chains_json = legacy_val - break - - if isinstance(legacy_val, list): - chains_json = json.dumps(legacy_val, ensure_ascii=False, indent=2) - break - - if isinstance(legacy_val, dict): - chains_json = json.dumps([legacy_val], ensure_ascii=False, indent=2) - break - - if chains_json.strip(): - if "tool_chains" not in self.config or not isinstance(self.config.get("tool_chains"), dict): - self.config["tool_chains"] = {} - self.config["tool_chains"]["chains_list"] = chains_json - logger.info( - "检测到旧版 Workflow 配置字段,已自动迁移为 tool_chains.chains_list(请在 WebUI 保存一次以固化)" - ) - - chains_config = self.config.get("tool_chains", {}) - if not isinstance(chains_config, dict): - chains_config = {} - - if not chains_config.get("chains_enabled", True): - logger.info("工具链功能已禁用") - return - - chains_json = str(chains_config.get("chains_list", "[]") or "") - if not chains_json or not chains_json.strip(): - return - - # 清空现有工具链 - tool_chain_manager.clear() - tool_chain_registry.clear() - - # 加载新配置 - loaded, errors = tool_chain_manager.load_from_json(chains_json) - - if errors: - for err in errors: - logger.warning(f"工具链配置错误: {err}") - - if loaded > 0: - logger.info(f"已加载 {loaded} 个工具链") - # 注册工具链到组件系统 - self._register_tool_chains() - self._update_chains_status_display() - - def _register_tool_chains(self) -> None: - """v1.8.1: 将工具链注册到 MaiBot 组件系统,使 LLM 可调用""" - from src.plugin_system.core.component_registry import component_registry - - chain_count = 0 - for chain_name, chain in tool_chain_manager.get_enabled_chains().items(): - try: - expected_tool_name = f"chain_{chain.name}".replace("-", "_").replace(".", "_") - if component_registry.get_component_info(expected_tool_name, ComponentType.TOOL): - chain_count += 1 - logger.debug(f"🔗 工具链已存在,跳过重复注册: {expected_tool_name}") - continue - - info, tool_class = tool_chain_registry.register_chain(chain) - info.plugin_name = self.plugin_name - - if component_registry.register_component(info, tool_class): - chain_count += 1 - logger.info(f"🔗 注册工具链: {tool_class.name}") - else: - logger.warning(f"⚠️ 工具链注册被跳过(可能已存在): {tool_class.name}") - except Exception as e: - logger.error(f"注册工具链 {chain_name} 失败: {e}") - - if chain_count > 0: - logger.info(f"已注册 {chain_count} 个工具链到组件系统") - - def _register_tools_to_react(self) -> int: - """v1.9.0: 将 MCP 工具注册到记忆检索 ReAct 系统(软流程) - - 这样 MaiBot 的 ReAct Agent 在检索记忆时可以调用 MCP 工具, - 实现 LLM 自主决策的多轮工具调用。 - - Returns: - int: 成功注册的工具数量 - """ - try: - from src.memory_system.retrieval_tools import register_memory_retrieval_tool - except ImportError: - logger.warning("无法导入记忆检索工具注册模块,跳过 ReAct 工具注册") - return 0 - - react_config = self.config.get("react", {}) - filter_mode = react_config.get("filter_mode", "whitelist") - tool_filter = react_config.get("tool_filter", "").strip() - - # 解析过滤列表(支持 # 注释) - filter_patterns = [] - for line in tool_filter.split("\n"): - line = line.strip() - if line and not line.startswith("#"): - filter_patterns.append(line) - - registered_count = 0 - disabled_tools = self._get_disabled_tools() - registered_tools = [] # 记录已注册的工具名 - - for tool_key, (tool_info, _) in mcp_manager.all_tools.items(): - tool_name = tool_key.replace("-", "_").replace(".", "_") - - # 跳过禁用的工具 - if tool_name in disabled_tools: - continue - - # 应用过滤器 - if filter_patterns: - matched = any(fnmatch.fnmatch(tool_name, p) or tool_name == p for p in filter_patterns) - - if filter_mode == "whitelist": - # 白名单模式:只注册匹配的 - if not matched: - continue - else: - # 黑名单模式:排除匹配的 - if matched: - continue - - try: - # 转换参数格式 - parameters = self._convert_mcp_params_to_react_format(tool_info.input_schema) - - # 创建异步执行函数(使用闭包捕获 tool_key) - def make_execute_func(tk: str): - async def _execute_func(**kwargs) -> str: - result = await mcp_manager.call_tool(tk, kwargs) - if result.success: - return result.content or "(无返回内容)" - else: - return f"工具调用失败: {result.error}" - - return _execute_func - - execute_func = make_execute_func(tool_key) - - # 注册到 ReAct 系统 - register_memory_retrieval_tool( - name=f"mcp_{tool_name}", - description=f"{tool_info.description} [MCP: {tool_info.server_name}]", - parameters=parameters, - execute_func=execute_func, - ) - - registered_count += 1 - registered_tools.append(f"mcp_{tool_name}") - logger.debug(f"🔄 注册 ReAct 工具: mcp_{tool_name}") - - except Exception as e: - logger.warning(f"注册 ReAct 工具 {tool_name} 失败: {e}") - - if registered_count > 0: - mode_str = "白名单" if filter_mode == "whitelist" else "黑名单" - logger.info(f"已注册 {registered_count} 个 MCP 工具到 ReAct 系统 (过滤模式: {mode_str})") - - # 更新状态显示 - self._update_react_status_display(registered_tools, filter_mode, filter_patterns) - - return registered_count - - def _update_react_status_display( - self, registered_tools: List[str], filter_mode: str, filter_patterns: List[str] - ) -> None: - """更新 ReAct 工具状态显示""" - if not registered_tools: - status_text = "(未注册任何工具)" - else: - mode_str = "白名单" if filter_mode == "whitelist" else "黑名单" - lines = [f"📊 已注册 {len(registered_tools)} 个工具 (模式: {mode_str})"] - if filter_patterns: - lines.append(f"过滤规则: {len(filter_patterns)} 条") - lines.append("") - for tool in registered_tools[:20]: - lines.append(f" • {tool}") - if len(registered_tools) > 20: - lines.append(f" ... 还有 {len(registered_tools) - 20} 个") - status_text = "\n".join(lines) - - # 更新内存配置 - if "react" not in self.config: - self.config["react"] = {} - self.config["react"]["react_status"] = status_text - - def _convert_mcp_params_to_react_format(self, input_schema: Dict) -> List[Dict[str, Any]]: - """将 MCP 工具参数转换为 ReAct 工具参数格式""" - parameters = [] - - if not input_schema: - return parameters - - properties = input_schema.get("properties", {}) - required = input_schema.get("required", []) - - for param_name, param_info in properties.items(): - param_type = param_info.get("type", "string") - description = param_info.get("description", f"参数 {param_name}") - is_required = param_name in required - - parameters.append( - { - "name": param_name, - "type": param_type, - "description": description, - "required": is_required, - } - ) - - return parameters - - def _update_chains_status_display(self) -> None: - """v1.8.0: 更新工具链状态显示""" - chains = tool_chain_manager.get_all_chains() - - if not chains: - status_text = "(无工具链配置)" - else: - lines = [f"📊 已配置 {len(chains)} 个工具链:\n"] - for name, chain in chains.items(): - status = "✅" if chain.enabled else "❌" - # 显示工具链基本信息 - lines.append(f"{status} chain_{name}") - lines.append(f" 描述: {chain.description[:40]}{'...' if len(chain.description) > 40 else ''}") - - # 显示输入参数 - if chain.input_params: - params = ", ".join(chain.input_params.keys()) - lines.append(f" 参数: {params}") - - # 显示步骤 - lines.append(f" 步骤: {len(chain.steps)} 个") - for i, step in enumerate(chain.steps): - opt = " (可选)" if step.optional else "" - out = f" → {step.output_key}" if step.output_key else "" - lines.append(f" {i + 1}. {step.tool_name}{out}{opt}") - lines.append("") - - status_text = "\n".join(lines) - - # 更新内存配置 - if "tool_chains" not in self.config: - self.config["tool_chains"] = {} - self.config["tool_chains"]["chains_status"] = status_text - - def _get_disabled_tools(self) -> set: - """v1.4.0: 获取禁用的工具列表""" - tools_config = self.config.get("tools", {}) - disabled_str = tools_config.get("disabled_tools", "") - return {t.strip() for t in disabled_str.strip().split("\n") if t.strip()} - - async def _async_connect_servers(self) -> None: - """异步连接所有配置的 MCP 服务器(v1.5.0: 并行连接优化)""" - import asyncio - - settings = self.config.get("settings", {}) - - servers_config = self._load_mcp_servers_config() - - if not servers_config: - logger.warning("未配置任何 MCP 服务器") - self._initialized = True - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._start_status_refresher() - self._persist_runtime_displays() - return - - auto_connect = settings.get("auto_connect", True) - if not auto_connect: - logger.info("auto_connect 已禁用,跳过自动连接") - self._initialized = True - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._start_status_refresher() - self._persist_runtime_displays() - return - - tool_prefix = settings.get("tool_prefix", "mcp") - disabled_tools = self._get_disabled_tools() - enable_resources = settings.get("enable_resources", False) - enable_prompts = settings.get("enable_prompts", False) - - # 解析所有服务器配置 - enabled_configs: List[MCPServerConfig] = [] - for idx, server_conf in enumerate(servers_config): - server_name = server_conf.get("name", f"unknown_{idx}") - - if not server_conf.get("enabled", True): - logger.info(f"服务器 {server_name} 已禁用,跳过") - continue - - try: - config = self._parse_server_config(server_conf) - enabled_configs.append(config) - except Exception as e: - logger.error(f"解析服务器 {server_name} 配置失败: {e}") - - if not enabled_configs: - logger.warning("没有已启用的 MCP 服务器") - self._initialized = True - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._start_status_refresher() - self._persist_runtime_displays() - return - - logger.info(f"准备并行连接 {len(enabled_configs)} 个 MCP 服务器") - - # v1.5.0: 并行连接所有服务器 - async def connect_single_server(config: MCPServerConfig) -> Tuple[MCPServerConfig, bool]: - """连接单个服务器""" - logger.info(f"正在连接服务器: {config.name} ({config.transport.value})") - try: - success = await mcp_manager.add_server(config) - if success: - logger.info(f"✅ 服务器 {config.name} 连接成功") - # 获取资源和提示模板 - if enable_resources: - try: - await mcp_manager.fetch_resources_for_server(config.name) - except Exception as e: - logger.warning(f"服务器 {config.name} 获取资源列表失败: {e}") - if enable_prompts: - try: - await mcp_manager.fetch_prompts_for_server(config.name) - except Exception as e: - logger.warning(f"服务器 {config.name} 获取提示模板列表失败: {e}") - else: - logger.warning(f"❌ 服务器 {config.name} 连接失败") - return config, success - except Exception as e: - logger.error(f"❌ 服务器 {config.name} 连接异常: {e}") - return config, False - - # 并行执行所有连接 - start_time = time.time() - results = await asyncio.gather(*[connect_single_server(cfg) for cfg in enabled_configs], return_exceptions=True) - connect_duration = time.time() - start_time - - # 统计连接结果 - success_count = 0 - failed_count = 0 - for result in results: - if isinstance(result, Exception): - failed_count += 1 - logger.error(f"连接任务异常: {result}") - elif isinstance(result, tuple): - _, success = result - if success: - success_count += 1 - else: - failed_count += 1 - - logger.info(f"并行连接完成: {success_count} 成功, {failed_count} 失败, 耗时 {connect_duration:.2f}s") - - # 注册所有工具 - from src.plugin_system.core.component_registry import component_registry - - registered_count = 0 - - for tool_key, (tool_info, _) in mcp_manager.all_tools.items(): - tool_name = tool_key.replace("-", "_").replace(".", "_") - is_disabled = tool_name in disabled_tools - - info, tool_class = mcp_tool_registry.register_tool(tool_key, tool_info, tool_prefix, disabled=is_disabled) - info.plugin_name = self.plugin_name - - if component_registry.register_component(info, tool_class): - registered_count += 1 - status = "🚫" if is_disabled else "✅" - logger.info(f"{status} 注册 MCP 工具: {tool_class.name}") - else: - logger.warning(f"❌ 注册 MCP 工具失败: {tool_class.name}") - - chains_config = self.config.get("tool_chains", {}) - chains_enabled = bool(chains_config.get("chains_enabled", True)) if isinstance(chains_config, dict) else True - chain_count = len(tool_chain_manager.get_enabled_chains()) if chains_enabled else 0 - - # v1.9.0: 注册 MCP 工具到记忆检索 ReAct 系统(软流程) - react_count = 0 - react_config = self.config.get("react", {}) - if react_config.get("react_enabled", False): - react_count = self._register_tools_to_react() - - self._initialized = True - logger.info( - f"MCP 桥接插件初始化完成,已注册 {registered_count} 个工具,{chain_count} 个工具链,{react_count} 个 ReAct 工具" - ) - - # 更新状态显示 - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._start_status_refresher() - self._persist_runtime_displays() - - def _start_status_refresher(self) -> None: - """启动 WebUI 状态刷新任务(不写入磁盘)""" - task = getattr(self, "_status_refresh_task", None) - if task and not task.done(): - return - - self._status_refresh_running = True - self._status_refresh_task = asyncio.create_task(self._status_refresh_loop()) - - async def _stop_status_refresher(self) -> None: - """停止 WebUI 状态刷新任务""" - self._status_refresh_running = False - task = getattr(self, "_status_refresh_task", None) - if task: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - self._status_refresh_task = None - - async def _status_refresh_loop(self) -> None: - """定期刷新 WebUI 展示字段(状态/工具列表/工具链状态)""" - while getattr(self, "_status_refresh_running", False): - try: - settings = self.config.get("settings", {}) - enabled = bool(settings.get("status_refresh_enabled", True)) - interval = float(settings.get("status_refresh_interval", 10.0) or 10.0) - interval = max(5.0, min(interval, 60.0)) - - if enabled and self._initialized: - self._update_status_display() - self._update_tool_list_display() - self._update_chains_status_display() - self._persist_runtime_displays() - - await asyncio.sleep(interval if enabled else 5.0) - except asyncio.CancelledError: - break - except Exception as e: - logger.debug(f"状态刷新任务异常: {e}") - await asyncio.sleep(5.0) - - def _load_mcp_servers_config(self) -> List[Dict[str, Any]]: - """v2.0: 从 Claude mcpServers JSON 加载服务器配置。 - - - 唯一主入口:config.servers.claude_config_json - - 兼容:若旧版 servers.list 存在且 claude_config_json 为空,会自动迁移并写回内存配置 - """ - servers_section = self.config.get("servers", {}) - if not isinstance(servers_section, dict): - servers_section = {} - - claude_json = str(servers_section.get("claude_config_json", "") or "") - - if not claude_json.strip(): - legacy_list = str(servers_section.get("list", "") or "") - migrated = legacy_servers_list_to_claude_config(legacy_list) - if migrated: - claude_json = migrated - if "servers" not in self.config: - self.config["servers"] = {} - self.config["servers"]["claude_config_json"] = migrated - logger.info("检测到旧版 servers.list,已自动迁移为 Claude mcpServers(请在 WebUI 保存一次以固化)") - - if not claude_json.strip(): - self._last_servers_config_error = ( - "未配置任何 MCP 服务器(请在 WebUI 的「MCP Servers(Claude)」粘贴 mcpServers JSON)" - ) - return [] - - try: - servers = parse_claude_mcp_config(claude_json) - except ClaudeConfigError as e: - self._last_servers_config_error = str(e) - logger.error(f"Claude mcpServers 配置解析失败: {e}") - return [] - except Exception as e: - self._last_servers_config_error = str(e) - logger.error(f"Claude mcpServers 配置解析异常: {e}") - return [] - - self._last_servers_config_error = "" - - # 保留未知字段(如 post_process)供旧功能使用 - raw_mapping: Dict[str, Any] = {} - try: - parsed = json.loads(claude_json) - mapping = parsed.get("mcpServers", parsed) - if isinstance(mapping, dict): - raw_mapping = mapping - except Exception: - raw_mapping = {} - - configs: List[Dict[str, Any]] = [] - for srv in servers: - raw = raw_mapping.get(srv.name, {}) - cfg: Dict[str, Any] = raw.copy() if isinstance(raw, dict) else {} - cfg.update( - { - "name": srv.name, - "enabled": srv.enabled, - "transport": srv.transport, - "command": srv.command, - "args": srv.args, - "env": srv.env, - "url": srv.url, - "headers": srv.headers, - } - ) - configs.append(cfg) - - return configs - - def _parse_server_config(self, conf: Dict) -> MCPServerConfig: - """解析服务器配置字典""" - transport_str = conf.get("transport", "stdio").lower() - - transport_map = { - "stdio": TransportType.STDIO, - "sse": TransportType.SSE, - "http": TransportType.HTTP, - "streamable_http": TransportType.STREAMABLE_HTTP, - } - transport = transport_map.get(transport_str, TransportType.STDIO) - - return MCPServerConfig( - name=conf.get("name", "unnamed"), - enabled=conf.get("enabled", True), - transport=transport, - command=conf.get("command", ""), - args=conf.get("args", []), - env=conf.get("env", {}), - url=conf.get("url", ""), - headers=conf.get("headers", {}), # v1.4.2: 鉴权头支持 - ) - - def _update_tool_list_display(self) -> None: - """v1.4.0: 更新工具列表显示""" - tools = mcp_manager.all_tools - disabled_tools = self._get_disabled_tools() - - lines = [] - by_server: Dict[str, List[str]] = {} - - for tool_key, (tool_info, _) in tools.items(): - tool_name = tool_key.replace("-", "_").replace(".", "_") - if tool_info.server_name not in by_server: - by_server[tool_info.server_name] = [] - - is_disabled = tool_name in disabled_tools - status = " ❌" if is_disabled else "" - by_server[tool_info.server_name].append(f" • {tool_name}{status}") - - for srv_name, tool_list in by_server.items(): - lines.append(f"📦 {srv_name} ({len(tool_list)}个工具):") - lines.extend(tool_list) - lines.append("") - - if not by_server: - lines.append("(无已注册工具)") - - tool_list_text = "\n".join(lines) - - # 更新内存配置 - if "tools" not in self.config: - self.config["tools"] = {} - self.config["tools"]["tool_list"] = tool_list_text - - def _update_status_display(self) -> None: - """更新配置文件中的状态显示字段""" - status = mcp_manager.get_status() - settings = self.config.get("settings", {}) - lines = [] - - cfg_err = str(getattr(self, "_last_servers_config_error", "") or "").strip() - if cfg_err: - lines.append(f"⚠️ 配置: {cfg_err}") - lines.append("") - - lines.append(f"服务器: {status['connected_servers']}/{status['total_servers']} 已连接") - lines.append(f"工具数: {status['total_tools']}") - if settings.get("enable_resources", False): - lines.append(f"资源数: {status.get('total_resources', 0)}") - if settings.get("enable_prompts", False): - lines.append(f"模板数: {status.get('total_prompts', 0)}") - lines.append(f"心跳: {'运行中' if status['heartbeat_running'] else '已停止'}") - lines.append("") - - tools = mcp_manager.all_tools - - for name, info in status.get("servers", {}).items(): - icon = "✅" if info["connected"] else "❌" - lines.append(f"{icon} {name} ({info['transport']})") - - # v1.7.0: 显示断路器状态 - cb_status = info.get("circuit_breaker", {}) - cb_state = cb_status.get("state", "closed") - if cb_state == "open": - lines.append(" ⚡ 断路器: 熔断中") - elif cb_state == "half_open": - lines.append(" ⚡ 断路器: 试探中") - - server_tools = [t.name for key, (t, _) in tools.items() if t.server_name == name] - if server_tools: - for tool_name in server_tools: - lines.append(f" • {tool_name}") - else: - lines.append(" (无工具)") - - if not status.get("servers"): - lines.append("(无服务器)") - - status_text = "\n".join(lines) - - if "status" not in self.config: - self.config["status"] = {} - self.config["status"]["connection_status"] = status_text - - def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: - """返回插件的所有组件""" - components: List[Tuple[ComponentInfo, Type]] = [] - - # 事件处理器 - components.append((MCPStartupHandler.get_handler_info(), MCPStartupHandler)) - components.append((MCPStopHandler.get_handler_info(), MCPStopHandler)) - - # 命令 - components.append((MCPStatusCommand.get_command_info(), MCPStatusCommand)) - components.append((MCPImportCommand.get_command_info(), MCPImportCommand)) - - # 内置工具 - status_tool_info = ToolInfo( - name=MCPStatusTool.name, - tool_description=MCPStatusTool.description, - enabled=True, - tool_parameters=MCPStatusTool.parameters, - component_type=ComponentType.TOOL, - ) - components.append((status_tool_info, MCPStatusTool)) - - settings = self.config.get("settings", {}) - - if settings.get("enable_resources", False): - read_resource_info = ToolInfo( - name=MCPReadResourceTool.name, - tool_description=MCPReadResourceTool.description, - enabled=True, - tool_parameters=MCPReadResourceTool.parameters, - component_type=ComponentType.TOOL, - ) - components.append((read_resource_info, MCPReadResourceTool)) - - if settings.get("enable_prompts", False): - get_prompt_info = ToolInfo( - name=MCPGetPromptTool.name, - tool_description=MCPGetPromptTool.description, - enabled=True, - tool_parameters=MCPGetPromptTool.parameters, - component_type=ComponentType.TOOL, - ) - components.append((get_prompt_info, MCPGetPromptTool)) - - return components - - def get_status(self) -> Dict[str, Any]: - """获取插件状态""" - return { - "initialized": self._initialized, - "mcp_manager": mcp_manager.get_status(), - "registered_tools": len(mcp_tool_registry._tool_classes), - "trace_records": tool_call_tracer.total_records, - "cache_stats": tool_call_cache.get_stats(), - } - - def get_stats(self) -> Dict[str, Any]: - """获取详细统计信息""" - return mcp_manager.get_all_stats() diff --git a/plugins/MaiBot_MCPBridgePlugin/requirements.txt b/plugins/MaiBot_MCPBridgePlugin/requirements.txt deleted file mode 100644 index 7580f09e..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -# MCP 桥接插件依赖 -mcp>=1.0.0 diff --git a/plugins/MaiBot_MCPBridgePlugin/tool_chain.py b/plugins/MaiBot_MCPBridgePlugin/tool_chain.py deleted file mode 100644 index 6a1530cc..00000000 --- a/plugins/MaiBot_MCPBridgePlugin/tool_chain.py +++ /dev/null @@ -1,584 +0,0 @@ -""" -MCP Workflow 模块 v1.9.0 -支持用户自定义工作流(硬流程),将多个 MCP 工具按顺序执行 - -双轨制架构: -- 软流程 (ReAct): LLM 自主决策,动态多轮调用工具,灵活但不可预测 -- 硬流程 (Workflow): 用户预定义的工作流,固定流程,可靠可控 - -功能: -- Workflow 定义和管理 -- 顺序执行多个工具(硬流程) -- 支持变量替换(使用前序工具的输出) -- 自动注册为组合工具供 LLM 调用 -- 与 ReAct 软流程互补,用户可选择合适的执行方式 -""" - -import json -import re -import time -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple - -try: - from src.common.logger import get_logger - - logger = get_logger("mcp_tool_chain") -except ImportError: - import logging - - logger = logging.getLogger("mcp_tool_chain") - - -@dataclass -class ToolChainStep: - """工具链步骤""" - - tool_name: str # 要调用的工具名(如 mcp_server_tool) - args_template: Dict[str, Any] = field(default_factory=dict) # 参数模板,支持变量替换 - output_key: str = "" # 输出存储的键名,供后续步骤引用 - description: str = "" # 步骤描述 - optional: bool = False # 是否可选(失败时继续执行) - - def to_dict(self) -> Dict[str, Any]: - return { - "tool_name": self.tool_name, - "args_template": self.args_template, - "output_key": self.output_key, - "description": self.description, - "optional": self.optional, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ToolChainStep": - return cls( - tool_name=data.get("tool_name", ""), - args_template=data.get("args_template", {}), - output_key=data.get("output_key", ""), - description=data.get("description", ""), - optional=data.get("optional", False), - ) - - -@dataclass -class ToolChainDefinition: - """工具链定义""" - - name: str # 工具链名称(将作为组合工具的名称) - description: str # 工具链描述(供 LLM 理解) - steps: List[ToolChainStep] = field(default_factory=list) # 执行步骤 - input_params: Dict[str, str] = field(default_factory=dict) # 输入参数定义 {参数名: 描述} - enabled: bool = True # 是否启用 - - def to_dict(self) -> Dict[str, Any]: - return { - "name": self.name, - "description": self.description, - "steps": [step.to_dict() for step in self.steps], - "input_params": self.input_params, - "enabled": self.enabled, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ToolChainDefinition": - steps = [ToolChainStep.from_dict(s) for s in data.get("steps", [])] - return cls( - name=data.get("name", ""), - description=data.get("description", ""), - steps=steps, - input_params=data.get("input_params", {}), - enabled=data.get("enabled", True), - ) - - -@dataclass -class ChainExecutionResult: - """工具链执行结果""" - - success: bool - final_output: str # 最终输出(最后一个步骤的结果) - step_results: List[Dict[str, Any]] = field(default_factory=list) # 每个步骤的结果 - error: str = "" - total_duration_ms: float = 0.0 - - def to_summary(self) -> str: - """生成执行摘要""" - lines = [] - for i, step in enumerate(self.step_results): - status = "✅" if step.get("success") else "❌" - tool = step.get("tool_name", "unknown") - duration = step.get("duration_ms", 0) - lines.append(f"{status} 步骤{i + 1}: {tool} ({duration:.0f}ms)") - if not step.get("success") and step.get("error"): - lines.append(f" 错误: {step['error'][:50]}") - return "\n".join(lines) - - -class ToolChainExecutor: - """工具链执行器""" - - # 变量替换模式: ${step.output_key} 或 ${input.param_name} 或 ${prev} - VAR_PATTERN = re.compile(r"\$\{([^}]+)\}") - - def __init__(self, mcp_manager): - self._mcp_manager = mcp_manager - - def _resolve_tool_key(self, tool_name: str) -> Optional[str]: - """解析工具名,返回有效的 tool_key - - 支持: - - 直接使用 tool_key(如 mcp_server_tool) - - 使用注册后的工具名(会自动转换 - 和 . 为 _) - """ - all_tools = self._mcp_manager.all_tools - - # 直接匹配 - if tool_name in all_tools: - return tool_name - - # 尝试转换后匹配(用户可能使用了注册后的名称) - normalized = tool_name.replace("-", "_").replace(".", "_") - if normalized in all_tools: - return normalized - - # 尝试查找包含该名称的工具 - for key in all_tools.keys(): - if key.endswith(f"_{tool_name}") or key.endswith(f"_{normalized}"): - return key - - return None - - async def execute( - self, - chain: ToolChainDefinition, - input_args: Dict[str, Any], - ) -> ChainExecutionResult: - """执行工具链 - - Args: - chain: 工具链定义 - input_args: 用户输入的参数 - - Returns: - ChainExecutionResult: 执行结果 - """ - start_time = time.time() - step_results = [] - context = { - "input": input_args or {}, # 用户输入,确保不为 None - "step": {}, # 各步骤输出,按 output_key 存储 - "prev": "", # 上一步的输出 - } - - final_output = "" - - # 验证必需的输入参数 - missing_params = [] - for param_name in chain.input_params.keys(): - if param_name not in context["input"]: - missing_params.append(param_name) - - if missing_params: - return ChainExecutionResult( - success=False, - final_output="", - error=f"缺少必需参数: {', '.join(missing_params)}", - total_duration_ms=(time.time() - start_time) * 1000, - ) - - for i, step in enumerate(chain.steps): - step_start = time.time() - step_result = { - "step_index": i, - "tool_name": step.tool_name, - "success": False, - "output": "", - "error": "", - "duration_ms": 0, - } - - try: - # 替换参数中的变量 - resolved_args = self._resolve_args(step.args_template, context) - step_result["resolved_args"] = resolved_args - - # 解析工具名 - tool_key = self._resolve_tool_key(step.tool_name) - if not tool_key: - step_result["error"] = f"工具 {step.tool_name} 不存在" - logger.warning(f"工具链步骤 {i + 1}: 工具 {step.tool_name} 不存在") - - if not step.optional: - step_results.append(step_result) - return ChainExecutionResult( - success=False, - final_output="", - step_results=step_results, - error=f"步骤 {i + 1}: 工具 {step.tool_name} 不存在", - total_duration_ms=(time.time() - start_time) * 1000, - ) - step_results.append(step_result) - continue - - logger.debug(f"工具链步骤 {i + 1}: 调用 {tool_key},参数: {resolved_args}") - - # 调用工具 - result = await self._mcp_manager.call_tool(tool_key, resolved_args) - - step_duration = (time.time() - step_start) * 1000 - step_result["duration_ms"] = step_duration - - if result.success: - step_result["success"] = True - # 确保 content 不为 None - content = result.content if result.content is not None else "" - step_result["output"] = content - - # 更新上下文 - context["prev"] = content - if step.output_key: - context["step"][step.output_key] = content - - final_output = content - content_preview = content[:100] if content else "(空)" - logger.debug(f"工具链步骤 {i + 1} 成功: {content_preview}...") - else: - step_result["error"] = result.error or "未知错误" - logger.warning(f"工具链步骤 {i + 1} 失败: {result.error}") - - if not step.optional: - step_results.append(step_result) - return ChainExecutionResult( - success=False, - final_output="", - step_results=step_results, - error=f"步骤 {i + 1} ({step.tool_name}) 失败: {result.error}", - total_duration_ms=(time.time() - start_time) * 1000, - ) - - except Exception as e: - step_duration = (time.time() - step_start) * 1000 - step_result["duration_ms"] = step_duration - step_result["error"] = str(e) - logger.error(f"工具链步骤 {i + 1} 异常: {e}") - - if not step.optional: - step_results.append(step_result) - return ChainExecutionResult( - success=False, - final_output="", - step_results=step_results, - error=f"步骤 {i + 1} ({step.tool_name}) 异常: {e}", - total_duration_ms=(time.time() - start_time) * 1000, - ) - - step_results.append(step_result) - - total_duration = (time.time() - start_time) * 1000 - - return ChainExecutionResult( - success=True, - final_output=final_output, - step_results=step_results, - total_duration_ms=total_duration, - ) - - def _resolve_args(self, args_template: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: - """解析参数模板,替换变量 - - 支持的变量格式: - - ${input.param_name}: 用户输入的参数 - - ${step.output_key}: 某个步骤的输出 - - ${prev}: 上一步的输出 - - ${prev.field}: 上一步输出(JSON)的某个字段 - """ - resolved = {} - - for key, value in args_template.items(): - if isinstance(value, str): - resolved[key] = self._substitute_vars(value, context) - elif isinstance(value, dict): - resolved[key] = self._resolve_args(value, context) - elif isinstance(value, list): - resolved[key] = [self._substitute_vars(v, context) if isinstance(v, str) else v for v in value] - else: - resolved[key] = value - - return resolved - - def _substitute_vars(self, template: str, context: Dict[str, Any]) -> str: - """替换字符串中的变量""" - - def replacer(match): - var_path = match.group(1) - return self._get_var_value(var_path, context) - - return self.VAR_PATTERN.sub(replacer, template) - - def _get_var_value(self, var_path: str, context: Dict[str, Any]) -> str: - """获取变量值 - - Args: - var_path: 变量路径,如 "input.query", "step.search_result", "prev", "prev.id" - context: 上下文 - """ - parts = self._parse_var_path(var_path) - - if not parts: - return "" - - # 获取根对象 - root = parts[0] - if root not in context: - logger.warning(f"变量 {var_path} 的根 '{root}' 不存在") - return "" - - value = context[root] - - # 遍历路径 - for part in parts[1:]: - if isinstance(value, str): - parsed = self._try_parse_json(value) - if parsed is not None: - value = parsed - - if isinstance(value, dict): - value = value.get(part, "") - elif isinstance(value, list): - if part.isdigit(): - idx = int(part) - value = value[idx] if 0 <= idx < len(value) else "" - else: - value = "" - else: - value = "" - - # 确保返回字符串 - if isinstance(value, (dict, list)): - return json.dumps(value, ensure_ascii=False) - if value is None: - return "" - if value == "": - return "" - return str(value) - - def _try_parse_json(self, value: str) -> Optional[Any]: - """尝试将字符串解析为 JSON 对象,失败则返回 None。""" - if not value: - return None - try: - return json.loads(value) - except json.JSONDecodeError: - return None - - def _parse_var_path(self, var_path: str) -> List[str]: - """解析变量路径,支持点号与下标写法。 - - 支持: - - step.geo.return.0.location - - step.geo.return[0].location - - step.geo['return'][0]['location'] - """ - if not var_path: - return [] - - tokens: List[str] = [] - buf: List[str] = [] - in_bracket = False - in_quote = False - quote_char = "" - - def flush_buf() -> None: - if buf: - token = "".join(buf).strip() - if token: - tokens.append(token) - buf.clear() - - i = 0 - while i < len(var_path): - ch = var_path[i] - - if not in_bracket and ch == ".": - flush_buf() - i += 1 - continue - - if not in_bracket and ch == "[": - flush_buf() - in_bracket = True - in_quote = False - quote_char = "" - i += 1 - continue - - if in_bracket and not in_quote and ch == "]": - flush_buf() - in_bracket = False - i += 1 - continue - - if in_bracket and ch in ("'", '"'): - if not in_quote: - in_quote = True - quote_char = ch - i += 1 - continue - if quote_char == ch: - in_quote = False - quote_char = "" - i += 1 - continue - - if in_bracket and not in_quote: - if ch.isspace(): - i += 1 - continue - if ch == ",": - i += 1 - continue - - buf.append(ch) - i += 1 - - flush_buf() - - if in_bracket or in_quote: - return [p for p in var_path.split(".") if p] - - return tokens - - -class ToolChainManager: - """工具链管理器""" - - _instance: Optional["ToolChainManager"] = None - - def __new__(cls): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self): - if self._initialized: - return - self._initialized = True - self._chains: Dict[str, ToolChainDefinition] = {} - self._executor: Optional[ToolChainExecutor] = None - - def set_executor(self, mcp_manager) -> None: - """设置执行器""" - self._executor = ToolChainExecutor(mcp_manager) - - def add_chain(self, chain: ToolChainDefinition) -> bool: - """添加工具链""" - if not chain.name: - logger.error("工具链名称不能为空") - return False - - if chain.name in self._chains: - logger.warning(f"工具链 {chain.name} 已存在,将被覆盖") - - self._chains[chain.name] = chain - logger.info(f"已添加工具链: {chain.name} ({len(chain.steps)} 个步骤)") - return True - - def remove_chain(self, name: str) -> bool: - """移除工具链""" - if name in self._chains: - del self._chains[name] - logger.info(f"已移除工具链: {name}") - return True - return False - - def get_chain(self, name: str) -> Optional[ToolChainDefinition]: - """获取工具链""" - return self._chains.get(name) - - def get_all_chains(self) -> Dict[str, ToolChainDefinition]: - """获取所有工具链""" - return self._chains.copy() - - def get_enabled_chains(self) -> Dict[str, ToolChainDefinition]: - """获取所有启用的工具链""" - return {name: chain for name, chain in self._chains.items() if chain.enabled} - - async def execute_chain( - self, - chain_name: str, - input_args: Dict[str, Any], - ) -> ChainExecutionResult: - """执行工具链""" - chain = self._chains.get(chain_name) - if not chain: - return ChainExecutionResult( - success=False, - final_output="", - error=f"工具链 {chain_name} 不存在", - ) - - if not chain.enabled: - return ChainExecutionResult( - success=False, - final_output="", - error=f"工具链 {chain_name} 已禁用", - ) - - if not self._executor: - return ChainExecutionResult( - success=False, - final_output="", - error="工具链执行器未初始化", - ) - - return await self._executor.execute(chain, input_args) - - def load_from_json(self, json_str: str) -> Tuple[int, List[str]]: - """从 JSON 字符串加载工具链配置 - - Returns: - (成功加载数量, 错误列表) - """ - errors = [] - loaded = 0 - - try: - data = json.loads(json_str) if json_str.strip() else [] - except json.JSONDecodeError as e: - return 0, [f"JSON 解析失败: {e}"] - - if not isinstance(data, list): - data = [data] - - for i, item in enumerate(data): - try: - chain = ToolChainDefinition.from_dict(item) - if not chain.name: - errors.append(f"第 {i + 1} 个工具链缺少名称") - continue - if not chain.steps: - errors.append(f"工具链 {chain.name} 没有步骤") - continue - - self.add_chain(chain) - loaded += 1 - except Exception as e: - errors.append(f"第 {i + 1} 个工具链解析失败: {e}") - - return loaded, errors - - def export_to_json(self, pretty: bool = True) -> str: - """导出所有工具链为 JSON""" - chains_data = [chain.to_dict() for chain in self._chains.values()] - if pretty: - return json.dumps(chains_data, ensure_ascii=False, indent=2) - return json.dumps(chains_data, ensure_ascii=False) - - def clear(self) -> None: - """清空所有工具链""" - self._chains.clear() - - -# 全局工具链管理器实例 -tool_chain_manager = ToolChainManager() diff --git a/prompts/zh-CN/action.prompt b/prompts/zh-CN/action.prompt deleted file mode 100644 index 91831b2a..00000000 --- a/prompts/zh-CN/action.prompt +++ /dev/null @@ -1,5 +0,0 @@ -{action_name} -动作描述:{action_description} -使用条件{parallel_text}: -{action_require} -{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)"}} \ No newline at end of file diff --git a/prompts/zh-CN/chat_target_group1.prompt b/prompts/zh-CN/chat_target_group1.prompt deleted file mode 100644 index 77e89bcc..00000000 --- a/prompts/zh-CN/chat_target_group1.prompt +++ /dev/null @@ -1 +0,0 @@ -你正在qq群里聊天,下面是群里正在聊的内容: \ No newline at end of file diff --git a/prompts/zh-CN/chat_target_group2.prompt b/prompts/zh-CN/chat_target_group2.prompt deleted file mode 100644 index 5b71bace..00000000 --- a/prompts/zh-CN/chat_target_group2.prompt +++ /dev/null @@ -1 +0,0 @@ -正在群里聊天 \ No newline at end of file diff --git a/prompts/zh-CN/chat_target_private1.prompt b/prompts/zh-CN/chat_target_private1.prompt deleted file mode 100644 index 3e86c71f..00000000 --- a/prompts/zh-CN/chat_target_private1.prompt +++ /dev/null @@ -1 +0,0 @@ -你正在和{sender_name}聊天,这是你们之前聊的内容: \ No newline at end of file diff --git a/prompts/zh-CN/chat_target_private2.prompt b/prompts/zh-CN/chat_target_private2.prompt deleted file mode 100644 index 9225ec82..00000000 --- a/prompts/zh-CN/chat_target_private2.prompt +++ /dev/null @@ -1 +0,0 @@ -和{sender_name}聊天 \ No newline at end of file diff --git a/prompts/zh-CN/lpmm_get_knowledge.prompt b/prompts/zh-CN/lpmm_get_knowledge.prompt deleted file mode 100644 index 2ade0d0f..00000000 --- a/prompts/zh-CN/lpmm_get_knowledge.prompt +++ /dev/null @@ -1,10 +0,0 @@ -你是一个专门获取知识的助手。你的名字是{bot_name}。现在是{time_now}。 -群里正在进行的聊天内容: -{chat_history} - -现在,{sender}发送了内容:{target_message},你想要回复ta。 -请仔细分析聊天内容,考虑以下几点: -1. 内容中是否包含需要查询信息的问题 -2. 是否有明确的知识获取指令 - -If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed". \ No newline at end of file diff --git a/prompts/zh-CN/maidairy_replyer.prompt b/prompts/zh-CN/maidairy_replyer.prompt index 2884afd9..9e13f45b 100644 --- a/prompts/zh-CN/maidairy_replyer.prompt +++ b/prompts/zh-CN/maidairy_replyer.prompt @@ -1,14 +1,10 @@ -你的任务是根据内部想法生成一条对用户可见的自然回复。 +你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片 +其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分: -【参考信息】 -{bot_name}的人设:{identity} -回复风格要求:{reply_style} -【参考信息结束】 +{time_block} -你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复, -尽量简短一些。 -没必要刻意友好回复,符合你的人格就行。没必要刻意友好回复,符合你的人格就行。没必要刻意友好回复,符合你的人格就行。 -请注意把握聊天内容,不要回复的太有条理。 -你的风格平淡但不失讽刺,不过分兴奋,很简短。可以参考贴吧,知乎和微博的回复风格。很平淡和白话,不浮夸不长篇大论,b站评论风格,但一定注意不要过分修辞和复杂句。 -请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。 -最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。 +{identity} +你正在群里聊天,现在请你读读之前的聊天记录,把握当前的话题,然后给出日常且口语化的回复, +尽量简短一些。最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。请注意把握聊天内容。 +{reply_style} +请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。 \ No newline at end of file diff --git a/prompts/zh-CN/private_replyer_self.prompt b/prompts/zh-CN/private_replyer_self.prompt deleted file mode 100644 index f58136ef..00000000 --- a/prompts/zh-CN/private_replyer_self.prompt +++ /dev/null @@ -1,14 +0,0 @@ -{knowledge_prompt}{tool_info_block}{extra_info_block} -{expression_habits_block}{memory_retrieval}{jargon_explanation} - -你正在和{sender_name}聊天,这是你们之前聊的内容: -{time_block} -{dialogue_prompt} - -你现在想补充说明你刚刚自己的发言内容:{target},原因是{reason} -请你根据聊天内容,组织一条新回复。注意,{target} 是刚刚你自己的发言,你要在这基础上进一步发言,请按照你自己的角度来继续进行回复。注意保持上下文的连贯性。 -{identity} -{chat_prompt}尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。 -{reply_style} -请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 -{moderation_prompt}不要输出多余内容(包括冒号和引号,括号,表情包,at或 @等 )。 \ No newline at end of file diff --git a/prompts/zh-CN/replyer_light.prompt b/prompts/zh-CN/replyer_light.prompt deleted file mode 100644 index 8e3a425a..00000000 --- a/prompts/zh-CN/replyer_light.prompt +++ /dev/null @@ -1,18 +0,0 @@ -{knowledge_prompt}{tool_info_block}{extra_info_block} -{expression_habits_block}{memory_retrieval}{jargon_explanation} - -你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片 -其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分: -{time_block} -{dialogue_prompt} - -{reply_target_block}。 -{planner_reasoning} -{identity} -{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录,然后给出日常且口语化的回复, -尽量简短一些。{keywords_reaction_prompt} -请注意把握聊天内容,不要回复的太有条理。 -{reply_style} -请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。 -最好一次对一个话题进行回复,免得啰嗦或者回复内容太乱。 -现在,你说: \ No newline at end of file diff --git a/prompts/zh-CN/tool_executor.prompt b/prompts/zh-CN/tool_executor.prompt deleted file mode 100644 index 23f2b043..00000000 --- a/prompts/zh-CN/tool_executor.prompt +++ /dev/null @@ -1,11 +0,0 @@ -你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。 -群里正在进行的聊天内容: -{chat_history} - -现在,{sender}发送了内容:{target_message},你想要回复ta。 -请仔细分析聊天内容,考虑以下几点: -1. 内容中是否包含需要查询信息的问题 -2. 是否有明确的工具使用指令 -你可以选择多个动作 - -If you need to use tools, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed". \ No newline at end of file diff --git a/src/chat/replyer/maisaka_generator.py b/src/chat/replyer/maisaka_generator.py index 8b35b00b..ed0b5fc2 100644 --- a/src/chat/replyer/maisaka_generator.py +++ b/src/chat/replyer/maisaka_generator.py @@ -8,7 +8,6 @@ import time from sqlmodel import select from src.chat.message_receive.chat_manager import BotChatSession -from src.chat.message_receive.message import SessionMessage from src.common.database.database import get_db_session from src.common.database.database_model import Expression from src.common.data_models.reply_generation_data_models import ( @@ -22,15 +21,11 @@ from src.config.config import global_config from src.core.types import ActionInfo from src.services.llm_service import LLMServiceClient -from src.maisaka.message_adapter import ( - get_message_kind, - get_message_role, - get_message_source, - get_message_text, - parse_speaker_content, -) +from src.chat.message_receive.message import SessionMessage +from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, ReferenceMessage, SessionBackedMessage, ToolResultMessage +from src.maisaka.message_adapter import parse_speaker_content -logger = get_logger("maisaka_replyer") +logger = get_logger("replyer") @dataclass @@ -96,16 +91,16 @@ class MaisakaReplyGenerator: return normalized @staticmethod - def _format_message_time(message: SessionMessage) -> str: + def _format_message_time(message: LLMContextMessage) -> str: return message.timestamp.strftime("%H:%M:%S") @staticmethod - def _extract_visible_assistant_reply(message: SessionMessage) -> str: + def _extract_visible_assistant_reply(message: AssistantMessage) -> str: del message return "" - def _extract_guided_bot_reply(self, message: SessionMessage) -> str: - speaker_name, body = parse_speaker_content(get_message_text(message).strip()) + def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str: + speaker_name, body = parse_speaker_content(message.processed_plain_text.strip()) bot_nickname = global_config.bot.nickname.strip() or "Bot" if speaker_name == bot_nickname: return self._normalize_content(body.strip()) @@ -134,25 +129,24 @@ class MaisakaReplyGenerator: return segments - def _format_chat_history(self, messages: List[SessionMessage]) -> str: + def _format_chat_history(self, messages: List[LLMContextMessage]) -> str: """格式化 replyer 使用的可见聊天记录。""" bot_nickname = global_config.bot.nickname.strip() or "Bot" parts: List[str] = [] for message in messages: - role = get_message_role(message) timestamp = self._format_message_time(message) - if get_message_source(message) == "user_reference": + if isinstance(message, (ReferenceMessage, ToolResultMessage)): continue - if role == "user": + if isinstance(message, SessionBackedMessage): guided_reply = self._extract_guided_bot_reply(message) if guided_reply: parts.append(f"{timestamp} {bot_nickname}(you): {guided_reply}") continue - raw_content = get_message_text(message) + raw_content = message.processed_plain_text for speaker_name, content_body in self._split_user_message_segments(raw_content): content = self._normalize_content(content_body) if not content: @@ -161,7 +155,7 @@ class MaisakaReplyGenerator: parts.append(f"{timestamp} {visible_speaker}: {content}") continue - if role == "assistant": + if isinstance(message, AssistantMessage): visible_reply = self._extract_visible_assistant_reply(message) if visible_reply: parts.append(f"{timestamp} {bot_nickname}(you): {visible_reply}") @@ -170,7 +164,7 @@ class MaisakaReplyGenerator: def _build_prompt( self, - chat_history: List[SessionMessage], + chat_history: List[LLMContextMessage], reply_reason: str, expression_habits: str = "", ) -> str: @@ -182,6 +176,7 @@ class MaisakaReplyGenerator: system_prompt = load_prompt( "maidairy_replyer", bot_name=global_config.bot.nickname, + time_block=f"当前时间:{current_time}", identity=self._personality_prompt, reply_style=global_config.personality.reply_style, ) @@ -214,7 +209,7 @@ class MaisakaReplyGenerator: async def _build_reply_context( self, - chat_history: List[SessionMessage], + chat_history: List[LLMContextMessage], reply_message: Optional[SessionMessage], reply_reason: str, stream_id: Optional[str], @@ -239,7 +234,7 @@ class MaisakaReplyGenerator: def _build_expression_habits( self, session_id: str, - chat_history: List[SessionMessage], + chat_history: List[LLMContextMessage], reply_message: Optional[SessionMessage], reply_reason: str, ) -> tuple[str, List[int]]: @@ -301,7 +296,7 @@ class MaisakaReplyGenerator: think_level: int = 1, unknown_words: Optional[List[str]] = None, log_reply: bool = True, - chat_history: Optional[List[SessionMessage]] = None, + chat_history: Optional[List[LLMContextMessage]] = None, expression_habits: str = "", selected_expression_ids: Optional[List[int]] = None, ) -> Tuple[bool, ReplyGenerationResult]: @@ -330,9 +325,7 @@ class MaisakaReplyGenerator: filtered_history = [ message for message in chat_history - if get_message_role(message) != "system" - and get_message_kind(message) != "perception" - and get_message_source(message) != "user_reference" + if not isinstance(message, (ReferenceMessage, ToolResultMessage)) ] logger.debug(f"Maisaka replyer: filtered_history size={len(filtered_history)}") diff --git a/src/cli/maisaka_cli.py b/src/cli/maisaka_cli.py index ad4d5c9a..f7c2d792 100644 --- a/src/cli/maisaka_cli.py +++ b/src/cli/maisaka_cli.py @@ -23,7 +23,13 @@ from src.config.config import config_manager, global_config from src.mcp_module import MCPManager from src.maisaka.chat_loop_service import MaisakaChatLoopService -from src.maisaka.message_adapter import build_message, format_speaker_content, remove_last_perception +from src.maisaka.context_messages import ( + AssistantMessage, + LLMContextMessage, + SessionBackedMessage, + ToolResultMessage, +) +from src.maisaka.message_adapter import format_speaker_content from src.maisaka.tool_handlers import ( ToolHandlerContext, handle_mcp_tool, @@ -43,7 +49,7 @@ class BufferCLI: self._chat_loop_service: Optional[MaisakaChatLoopService] = None self._reply_generator = MaisakaReplyGenerator() self._reader = InputReader() - self._chat_history: Optional[list[SessionMessage]] = None + self._chat_history: Optional[list[LLMContextMessage]] = None self._knowledge_store = get_knowledge_store() self._knowledge_learner = KnowledgeLearner("maisaka_cli") self._knowledge_min_messages_for_extraction = 10 @@ -118,22 +124,78 @@ class BufferCLI: self._chat_start_time = now self._last_assistant_response_time = None self._chat_history = self._chat_loop_service.build_chat_context(user_text) - self._trigger_knowledge_learning([self._chat_history[-1]]) + self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)]) else: self._chat_history.append( - build_message( - role="user", - content=format_speaker_content( - global_config.maisaka.user_name.strip() or "User", - user_text, - now, - ), + self._build_cli_context_message( + user_text=user_text, + timestamp=now, + source_kind="user", ) ) - self._trigger_knowledge_learning([self._chat_history[-1]]) + self._trigger_knowledge_learning([self._build_cli_session_message(user_text, now)]) await self._run_llm_loop(self._chat_history) + @staticmethod + def _build_cli_context_message( + user_text: str, + timestamp: datetime, + source_kind: str = "user", + speaker_name: Optional[str] = None, + ) -> SessionBackedMessage: + """为 CLI 构造新的上下文消息。""" + resolved_speaker_name = speaker_name or global_config.maisaka.user_name.strip() or "User" + visible_text = format_speaker_content( + resolved_speaker_name, + user_text, + timestamp, + ) + planner_prefix = ( + f"[时间]{timestamp.strftime('%H:%M:%S')}\n" + f"[用户]{resolved_speaker_name}\n" + "[用户群昵称]\n" + "[msg_id]\n" + "[发言内容]" + ) + from src.common.data_models.message_component_data_model import MessageSequence, TextComponent + + return SessionBackedMessage( + raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]), + visible_text=visible_text, + timestamp=timestamp, + source_kind=source_kind, + ) + + @staticmethod + def _build_cli_session_message(user_text: str, timestamp: datetime) -> SessionMessage: + """为 CLI 的知识学习构造兼容 SessionMessage。""" + from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo + from src.common.data_models.message_component_data_model import MessageSequence + + message = SessionMessage(message_id=f"maisaka_cli_{int(timestamp.timestamp() * 1000)}", timestamp=timestamp, platform="maisaka") + message.message_info = MessageInfo( + user_info=UserInfo( + user_id="maisaka_user", + user_nickname=global_config.maisaka.user_name.strip() or "User", + user_cardname=None, + ), + group_info=None, + additional_config={}, + ) + message.session_id = "maisaka_cli" + message.raw_message = MessageSequence([]) + visible_text = format_speaker_content( + global_config.maisaka.user_name.strip() or "User", + user_text, + timestamp, + ) + message.raw_message.text(visible_text) + message.processed_plain_text = visible_text + message.display_message = visible_text + message.initialized = True + return message + def _trigger_knowledge_learning(self, messages: list[SessionMessage]) -> None: """在 CLI 会话中按批次触发 knowledge 学习。""" if not global_config.maisaka.enable_knowledge_module: @@ -161,7 +223,7 @@ class BufferCLI: except Exception as exc: console.print(f"[warning]Knowledge learning failed: {exc}[/warning]") - async def _run_llm_loop(self, chat_history: list[SessionMessage]) -> None: + async def _run_llm_loop(self, chat_history: list[LLMContextMessage]) -> None: """ Main inner loop for the Maisaka planner. @@ -210,7 +272,8 @@ class BufferCLI: ) ) - remove_last_perception(chat_history) + if chat_history and isinstance(chat_history[-1], AssistantMessage) and chat_history[-1].source == "perception": + chat_history.pop() perception_parts = [] if knowledge_analysis: @@ -218,11 +281,10 @@ class BufferCLI: if perception_parts: chat_history.append( - build_message( - role="assistant", + AssistantMessage( content="\n\n".join(perception_parts), - message_kind="perception", - source="assistant", + timestamp=datetime.now(), + source_kind="perception", ) ) elif global_config.maisaka.show_thinking: @@ -273,22 +335,19 @@ class BufferCLI: elif tool_call.func_name == "reply": reply = await self._generate_visible_reply(chat_history, response.content) chat_history.append( - build_message( - role="tool", + ToolResultMessage( content="Visible reply generated and recorded.", - source="tool", + timestamp=datetime.now(), tool_call_id=tool_call.call_id, + tool_name=tool_call.func_name, ) ) chat_history.append( - build_message( - role="user", - content=format_speaker_content( - global_config.bot.nickname.strip() or "MaiSaka", - reply, - datetime.now(), - ), - source="guided_reply", + self._build_cli_context_message( + user_text=reply, + timestamp=datetime.now(), + source_kind="guided_reply", + speaker_name=global_config.bot.nickname.strip() or "MaiSaka", ) ) @@ -296,11 +355,11 @@ class BufferCLI: if global_config.maisaka.show_thinking: console.print("[muted]No visible reply this round.[/muted]") chat_history.append( - build_message( - role="tool", + ToolResultMessage( content="No visible reply was sent for this round.", - source="tool", + timestamp=datetime.now(), tool_call_id=tool_call.call_id, + tool_name=tool_call.func_name, ) ) @@ -342,7 +401,7 @@ class BufferCLI: ) ) - async def _generate_visible_reply(self, chat_history: list[SessionMessage], latest_thought: str) -> str: + async def _generate_visible_reply(self, chat_history: list[LLMContextMessage], latest_thought: str) -> str: """根据最新思考生成并输出可见回复。""" if not latest_thought: return "" diff --git a/src/know_u/knowledge.py b/src/know_u/knowledge.py index e815e96b..7fbc0948 100644 --- a/src/know_u/knowledge.py +++ b/src/know_u/knowledge.py @@ -11,10 +11,11 @@ from src.chat.message_receive.message import SessionMessage from src.chat.utils.utils import is_bot_self from src.common.data_models.llm_service_data_models import LLMGenerationOptions from src.common.logger import get_logger +from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, SessionBackedMessage, ToolResultMessage from src.services.llm_service import LLMServiceClient from src.know_u.knowledge_store import KNOWLEDGE_CATEGORIES, get_knowledge_store -from src.maisaka.message_adapter import get_message_role, get_message_text, parse_speaker_content +from src.maisaka.message_adapter import parse_speaker_content logger = get_logger("maisaka_knowledge") @@ -53,7 +54,7 @@ def extract_category_ids_from_result(result: str) -> List[str]: async def retrieve_relevant_knowledge( knowledge_analyzer: Any, - chat_history: List[SessionMessage], + chat_history: List[LLMContextMessage], ) -> str: """Retrieve formatted knowledge snippets relevant to the current chat history.""" store = get_knowledge_store() @@ -156,14 +157,26 @@ class KnowledgeLearner: """ lines: List[str] = [] for message in self._messages_cache[-30:]: - if get_message_role(message) == "assistant": - continue - if get_message_role(message) == "tool": - continue - if is_bot_self(message.platform, message.message_info.user_info.user_id): + if isinstance(message, (AssistantMessage, ToolResultMessage)): continue + if isinstance(message, SessionBackedMessage): + if message.original_message and is_bot_self( + message.original_message.platform, + message.original_message.message_info.user_info.user_id, + ): + continue + raw_text = message.processed_plain_text.strip() + fallback_speaker = ( + message.original_message.message_info.user_info.user_nickname + if message.original_message is not None + else "用户" + ) + else: + if is_bot_self(message.platform, message.message_info.user_info.user_id): + continue + raw_text = message.processed_plain_text.strip() + fallback_speaker = message.message_info.user_info.user_nickname or "用户" - raw_text = get_message_text(message).strip() if not raw_text: continue @@ -172,7 +185,7 @@ class KnowledgeLearner: if not visible_text: continue - speaker = speaker_name or message.message_info.user_info.user_nickname or "用户" + speaker = speaker_name or fallback_speaker lines.append(f"{speaker}: {visible_text}") return "\n".join(lines) diff --git a/src/llm_models/model_client/adapter_base.py b/src/llm_models/model_client/adapter_base.py index d631870c..660a286d 100644 --- a/src/llm_models/model_client/adapter_base.py +++ b/src/llm_models/model_client/adapter_base.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Coroutine, Generic, Tuple, TypeVar, cast import asyncio +from src.common.logger import get_logger from src.config.model_configs import ModelInfo from .base_client import ( @@ -33,12 +34,14 @@ ProviderStreamResponseHandler = Callable[ ProviderResponseParser = Callable[[RawResponseT], Tuple[APIResponse, UsageTuple | None]] """Provider 专用非流式响应解析函数类型。""" +logger = get_logger("llm_adapter_base") + async def await_task_with_interrupt( task: asyncio.Task[TaskResultT], interrupt_flag: asyncio.Event | None, *, - interval_seconds: float = 0.1, + interval_seconds: float = 0.02, ) -> TaskResultT: """在支持外部中断的前提下等待异步任务完成。 @@ -55,8 +58,11 @@ async def await_task_with_interrupt( """ from src.llm_models.exceptions import ReqAbortException + started_at = asyncio.get_running_loop().time() while not task.done(): if interrupt_flag and interrupt_flag.is_set(): + elapsed = asyncio.get_running_loop().time() - started_at + logger.info(f"LLM 请求检测到中断信号,准备取消底层任务,elapsed={elapsed:.3f}s") task.cancel() raise ReqAbortException("请求被外部信号中断") await asyncio.sleep(interval_seconds) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 43fb5189..775fa663 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -22,6 +22,7 @@ from src.llm_models.exceptions import ( EmptyResponseException, ModelAttemptFailed, NetworkConnectionError, + ReqAbortException, RespNotOkException, RespParseException, ) @@ -326,16 +327,7 @@ class LLMOrchestrator: del raise_when_empty self._refresh_task_config() start_time = time.time() - if self.request_type.startswith("maisaka_"): - logger.info( - f"LLMOrchestrator[{self.request_type}] 开始执行 generate_response_with_message_async " - f"(temperature={temperature}, max_tokens={max_tokens}, tools={len(tools or [])})" - ) - if self.request_type.startswith("maisaka_"): - logger.info( - f"LLMOrchestrator[{self.request_type}] 正在根据 {len(tools or [])} 个工具构建内部工具选项" - ) tool_built = self._build_tool_options(tools) if self.request_type.startswith("maisaka_"): logger.info(f"LLMOrchestrator[{self.request_type}] 已构建 {len(tool_built or [])} 个内部工具选项") @@ -777,6 +769,9 @@ class LLMOrchestrator: ) await asyncio.sleep(api_provider.retry_interval) + except ReqAbortException: + raise + except Exception as e: logger.error(traceback.format_exc()) @@ -881,6 +876,15 @@ class LLMOrchestrator: self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) return LLMExecutionResult(api_response=response, model_info=model_info) + except ReqAbortException as e: + total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty - 1) + if self.request_type.startswith("maisaka_"): + logger.info( + f"LLMOrchestrator[{self.request_type}] 模型 model={model_info.name} 的请求已被外部信号中断" + ) + raise e + except ModelAttemptFailed as e: last_exception = e.original_exception or e logger.warning(f"模型 '{model_info.name}' 尝试失败,切换到下一个模型。原因: {e}") diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py index 2dc8c03a..00118ac5 100644 --- a/src/maisaka/chat_loop_service.py +++ b/src/maisaka/chat_loop_service.py @@ -14,9 +14,9 @@ from rich.panel import Panel from rich.pretty import Pretty from rich.text import Text -from src.chat.message_receive.message import SessionMessage from src.cli.console import console from src.common.data_models.llm_service_data_models import LLMGenerationOptions +from src.common.data_models.message_component_data_model import MessageSequence, TextComponent from src.common.logger import get_logger from src.common.prompt_i18n import load_prompt from src.config.config import global_config @@ -27,12 +27,8 @@ from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionI from src.services.llm_service import LLMServiceClient from .builtin_tools import get_builtin_tools -from .message_adapter import ( - build_message, - format_speaker_content, - get_message_role, - to_llm_message, -) +from .context_messages import AssistantMessage, LLMContextMessage, SessionBackedMessage +from .message_adapter import format_speaker_content @dataclass(slots=True) @@ -41,7 +37,7 @@ class ChatResponse: content: Optional[str] tool_calls: List[ToolCall] - raw_message: SessionMessage + raw_message: AssistantMessage logger = get_logger("maisaka_chat_loop") @@ -59,6 +55,7 @@ class MaisakaChatLoopService: self._temperature = temperature self._max_tokens = max_tokens self._extra_tools: List[ToolOption] = [] + self._interrupt_flag: asyncio.Event | None = None self._prompts_loaded = False self._prompt_load_lock = asyncio.Lock() self._personality_prompt = self._build_personality_prompt() @@ -117,18 +114,21 @@ class MaisakaChatLoopService: def set_extra_tools(self, tools: List[ToolDefinitionInput]) -> None: self._extra_tools = normalize_tool_options(tools) or [] + def set_interrupt_flag(self, interrupt_flag: asyncio.Event | None) -> None: + """设置当前 planner 请求使用的中断标记。""" + self._interrupt_flag = interrupt_flag + async def analyze_knowledge_need( self, - chat_history: List[SessionMessage], + chat_history: List[LLMContextMessage], categories_summary: str, ) -> List[str]: """分析当前对话是否需要检索知识库分类。""" visible_history: List[str] = [] for message in chat_history[-8:]: - if not message.content: + if not message.processed_plain_text: continue - role = getattr(message, "role", "") - visible_history.append(f"{role}: {message.content}") + visible_history.append(f"{message.role}: {message.processed_plain_text}") if not visible_history or not categories_summary.strip(): return [] @@ -302,7 +302,7 @@ class MaisakaChatLoopService: padding=(0, 1), ) - async def chat_loop_step(self, chat_history: List[SessionMessage]) -> ChatResponse: + async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse: await self.ensure_chat_prompt_loaded() selected_history, selection_reason = self._select_llm_context_messages(chat_history) @@ -313,7 +313,7 @@ class MaisakaChatLoopService: messages.append(system_msg.build()) for msg in selected_history: - llm_message = to_llm_message(msg) + llm_message = msg.to_llm_message() if llm_message is not None: messages.append(llm_message) @@ -342,15 +342,24 @@ class MaisakaChatLoopService: ) request_started_at = perf_counter() + logger.info( + "planner 请求开始: " + f"selected_history={len(selected_history)} " + f"llm_messages={len(built_messages)} " + f"tool_count={len(all_tools)} " + f"interrupt_enabled={self._interrupt_flag is not None}" + ) generation_result = await self._llm_chat.generate_response_with_messages( message_factory=message_factory, options=LLMGenerationOptions( tool_options=all_tools if all_tools else None, temperature=self._temperature, max_tokens=self._max_tokens, + interrupt_flag=self._interrupt_flag, ), ) - _ = perf_counter() - request_started_at + request_elapsed = perf_counter() - request_started_at + logger.info(f"planner 请求完成,elapsed={request_elapsed:.3f}s") tool_call_summaries = [ { @@ -365,11 +374,10 @@ class MaisakaChatLoopService: f"tool_calls={tool_call_summaries}" ) - raw_message = build_message( - role=RoleType.Assistant.value, + raw_message = AssistantMessage( content=generation_result.response or "", - source="assistant", - tool_calls=generation_result.tool_calls or None, + timestamp=datetime.now(), + tool_calls=generation_result.tool_calls or [], ) return ChatResponse( content=generation_result.response, @@ -378,20 +386,19 @@ class MaisakaChatLoopService: ) @staticmethod - def _select_llm_context_messages(chat_history: List[SessionMessage]) -> tuple[List[SessionMessage], str]: + def _select_llm_context_messages(chat_history: List[LLMContextMessage]) -> tuple[List[LLMContextMessage], str]: """选择真正发送给 LLM 的上下文消息。""" max_context_size = max(1, int(global_config.chat.max_context_size)) - counted_roles = {"user", "assistant"} selected_indices: List[int] = [] counted_message_count = 0 for index in range(len(chat_history) - 1, -1, -1): message = chat_history[index] - if to_llm_message(message) is None: + if message.to_llm_message() is None: continue selected_indices.append(index) - if get_message_role(message) in counted_roles: + if message.count_in_context: counted_message_count += 1 if counted_message_count >= max_context_size: break @@ -410,15 +417,25 @@ class MaisakaChatLoopService: ) @staticmethod - def build_chat_context(user_text: str) -> List[SessionMessage]: + def build_chat_context(user_text: str) -> List[LLMContextMessage]: + timestamp = datetime.now() + visible_text = format_speaker_content( + global_config.maisaka.user_name.strip() or "用户", + user_text, + timestamp, + ) + planner_prefix = ( + f"[时间]{timestamp.strftime('%H:%M:%S')}\n" + f"[用户]{global_config.maisaka.user_name.strip() or '用户'}\n" + "[用户群昵称]\n" + "[msg_id]\n" + "[发言内容]" + ) return [ - build_message( - role=RoleType.User.value, - content=format_speaker_content( - global_config.maisaka.user_name.strip() or "用户", - user_text, - datetime.now(), - ), - source="user", + SessionBackedMessage( + raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]), + visible_text=visible_text, + timestamp=timestamp, + source_kind="user", ) ] diff --git a/src/maisaka/context_messages.py b/src/maisaka/context_messages.py new file mode 100644 index 00000000..8da06a23 --- /dev/null +++ b/src/maisaka/context_messages.py @@ -0,0 +1,275 @@ +"""Maisaka 内部上下文消息抽象。""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from io import BytesIO +from typing import Optional +import base64 + +from PIL import Image as PILImage + +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent +from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType +from src.llm_models.payload_content.tool_option import ToolCall + + +def _guess_image_format(image_bytes: bytes) -> Optional[str]: + if not image_bytes: + return None + + try: + with PILImage.open(BytesIO(image_bytes)) as image: + return image.format.lower() if image.format else None + except Exception: + return None + + +def _build_message_from_sequence( + role: RoleType, + message_sequence: MessageSequence, + fallback_text: str, + *, + tool_call_id: Optional[str] = None, + tool_calls: Optional[list[ToolCall]] = None, +) -> Optional[Message]: + """根据消息片段构造统一 LLM 消息。""" + builder = MessageBuilder().set_role(role) + if role == RoleType.Assistant and tool_calls: + builder.set_tool_calls(tool_calls) + if role == RoleType.Tool and tool_call_id: + builder.add_tool_call(tool_call_id) + + has_content = False + for component in message_sequence.components: + if isinstance(component, TextComponent): + if component.text: + builder.add_text_content(component.text) + has_content = True + continue + + if isinstance(component, (EmojiComponent, ImageComponent)): + image_format = _guess_image_format(component.binary_data) + if image_format and component.binary_data: + builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8")) + has_content = True + continue + + if component.content: + builder.add_text_content(component.content) + has_content = True + + if not has_content and fallback_text: + builder.add_text_content(fallback_text) + has_content = True + + if not has_content and not (role == RoleType.Assistant and tool_calls): + return None + return builder.build() + + +class ReferenceMessageType(str, Enum): + """参考消息类型。""" + + CUSTOM = "custom" + JARGON = "jargon" + KNOWLEDGE = "knowledge" + MEMORY = "memory" + TOOL_HINT = "tool_hint" + + +class LLMContextMessage(ABC): + """Maisaka 内部用于组织 LLM 上下文的统一消息抽象。""" + + timestamp: datetime + + @property + @abstractmethod + def role(self) -> str: + """返回 LLM 消息角色。""" + + @property + @abstractmethod + def processed_plain_text(self) -> str: + """返回可读的纯文本内容。""" + + @property + def count_in_context(self) -> bool: + """是否占用普通 user/assistant 上下文窗口。""" + return True + + @property + def source(self) -> str: + """返回消息来源。""" + return self.__class__.__name__ + + @abstractmethod + def to_llm_message(self) -> Optional[Message]: + """转换为统一 LLM 消息。""" + + def consume_once(self) -> bool: + """消费一次生命周期,返回是否继续保留。""" + return True + + +@dataclass(slots=True) +class SessionBackedMessage(LLMContextMessage): + """真实会话上下文消息。""" + + raw_message: MessageSequence + visible_text: str + timestamp: datetime + message_id: Optional[str] = None + original_message: Optional[SessionMessage] = None + source_kind: str = "user" + + @property + def role(self) -> str: + return RoleType.User.value + + @property + def processed_plain_text(self) -> str: + return self.visible_text + + @property + def source(self) -> str: + return self.source_kind + + def to_llm_message(self) -> Optional[Message]: + return _build_message_from_sequence( + RoleType.User, + self.raw_message, + self.processed_plain_text, + ) + + @classmethod + def from_session_message( + cls, + session_message: SessionMessage, + *, + raw_message: MessageSequence, + visible_text: str, + source_kind: str = "user", + ) -> "SessionBackedMessage": + """从真实 SessionMessage 构造上下文消息。""" + return cls( + raw_message=raw_message, + visible_text=visible_text, + timestamp=session_message.timestamp, + message_id=session_message.message_id, + original_message=session_message, + source_kind=source_kind, + ) + + +@dataclass(slots=True) +class ReferenceMessage(LLMContextMessage): + """参考消息。""" + + content: str + timestamp: datetime + reference_type: ReferenceMessageType = ReferenceMessageType.CUSTOM + remaining_uses_value: Optional[int] = 1 + display_prefix: str = "[参考消息]" + + @property + def role(self) -> str: + return RoleType.User.value + + @property + def processed_plain_text(self) -> str: + return f"{self.display_prefix}\n{self.content}".strip() + + @property + def count_in_context(self) -> bool: + return False + + @property + def source(self) -> str: + return self.reference_type.value + + def to_llm_message(self) -> Optional[Message]: + message_sequence = MessageSequence([TextComponent(self.processed_plain_text)]) + return _build_message_from_sequence(RoleType.User, message_sequence, self.processed_plain_text) + + def consume_once(self) -> bool: + if self.remaining_uses_value is None: + return True + + self.remaining_uses_value -= 1 + return self.remaining_uses_value > 0 + + +@dataclass(slots=True) +class AssistantMessage(LLMContextMessage): + """内部 assistant 消息。""" + + content: str + timestamp: datetime + tool_calls: list[ToolCall] = field(default_factory=list) + source_kind: str = "assistant" + + @property + def role(self) -> str: + return RoleType.Assistant.value + + @property + def processed_plain_text(self) -> str: + return self.content + + @property + def count_in_context(self) -> bool: + return self.source_kind != "perception" + + @property + def source(self) -> str: + return self.source_kind + + def to_llm_message(self) -> Optional[Message]: + message_sequence = MessageSequence([]) + if self.content: + message_sequence.text(self.content) + return _build_message_from_sequence( + RoleType.Assistant, + message_sequence, + self.content, + tool_calls=self.tool_calls or None, + ) + + +@dataclass(slots=True) +class ToolResultMessage(LLMContextMessage): + """工具返回结果消息。""" + + content: str + timestamp: datetime + tool_call_id: str + tool_name: str = "" + success: bool = True + + @property + def role(self) -> str: + return RoleType.Tool.value + + @property + def processed_plain_text(self) -> str: + return self.content + + @property + def count_in_context(self) -> bool: + return False + + @property + def source(self) -> str: + return self.tool_name or "tool" + + def to_llm_message(self) -> Optional[Message]: + message_sequence = MessageSequence([TextComponent(self.content)]) + return _build_message_from_sequence( + RoleType.Tool, + message_sequence, + self.content, + tool_call_id=self.tool_call_id, + ) diff --git a/src/maisaka/message_adapter.py b/src/maisaka/message_adapter.py index ca8620eb..b52d1baa 100644 --- a/src/maisaka/message_adapter.py +++ b/src/maisaka/message_adapter.py @@ -1,148 +1,32 @@ -""" -MaiSaka 内部消息适配器。 -""" +"""Maisaka 文本与消息片段适配工具。""" from copy import deepcopy from datetime import datetime -from io import BytesIO from typing import Optional -from uuid import uuid4 -import base64 import re -from PIL import Image as PILImage - -from src.chat.message_receive.message import SessionMessage -from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent -from src.config.config import global_config -from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType -from src.llm_models.payload_content.tool_option import ToolCall -MAISAKA_PLATFORM = "maisaka" -MAISAKA_SESSION_ID = "maisaka_cli" -MESSAGE_KIND_KEY = "maisaka_message_kind" -SOURCE_KEY = "maisaka_source" -LLM_ROLE_KEY = "maisaka_llm_role" -TOOL_CALL_ID_KEY = "maisaka_tool_call_id" -TOOL_CALLS_KEY = "maisaka_tool_calls" SPEAKER_PREFIX_PATTERN = re.compile( r"^(?:(?P\d{2}:\d{2}:\d{2}))?(?:\[msg_id:(?P[^\]]+)\])?\[(?P[^\]]+)\](?P.*)$", re.DOTALL, ) -def _build_user_info_for_role(role: str) -> UserInfo: - if role == RoleType.User.value: - return UserInfo( - user_id="maisaka_user", - user_nickname=global_config.maisaka.user_name.strip() or "用户", - user_cardname=None, - ) - if role == RoleType.Tool.value: - return UserInfo(user_id="maisaka_tool", user_nickname="tool", user_cardname=None) - return UserInfo( - user_id="maisaka_assistant", - user_nickname=global_config.bot.nickname.strip() or "MaiSaka", - user_cardname=None, - ) - - -def _serialize_tool_call(tool_call: ToolCall) -> dict: - return { - "call_id": tool_call.call_id, - "func_name": tool_call.func_name, - "args": tool_call.args or {}, - } - - -def _deserialize_tool_call(data: dict) -> ToolCall: - return ToolCall( - call_id=str(data.get("call_id", "")), - func_name=str(data.get("func_name", "")), - args=data.get("args", {}) or {}, - ) - - -def _ensure_message_id_in_speaker_content(content: str, message_id: str) -> str: - """Ensure speaker-formatted visible text carries a msg_id marker.""" - match = SPEAKER_PREFIX_PATTERN.match(content or "") - if not match: - return content - - existing_message_id = match.group("message_id") - if existing_message_id: - return content - - timestamp_text = match.group("timestamp") - speaker_name = match.group("speaker") - visible_content = match.group("content") - timestamp = datetime.strptime(timestamp_text, "%H:%M:%S") if timestamp_text else None - return format_speaker_content(speaker_name, visible_content, timestamp, message_id) - - -def build_message( - role: str, - content: str = "", - *, - message_kind: str = "normal", - source: Optional[str] = None, - tool_call_id: Optional[str] = None, - tool_calls: Optional[list[ToolCall]] = None, - timestamp: Optional[datetime] = None, - message_id: Optional[str] = None, - platform: str = MAISAKA_PLATFORM, - session_id: str = MAISAKA_SESSION_ID, - user_info: Optional[UserInfo] = None, - group_info: Optional[GroupInfo] = None, - raw_message: Optional[MessageSequence] = None, - display_text: Optional[str] = None, -) -> SessionMessage: - """为 MaiSaka 会话历史构建内部 ``SessionMessage``。""" - resolved_timestamp = timestamp or datetime.now() - resolved_role = role.value if isinstance(role, RoleType) else role - message = SessionMessage( - message_id=message_id or f"maisaka_{uuid4().hex}", - timestamp=resolved_timestamp, - platform=platform, - ) - normalized_content = _ensure_message_id_in_speaker_content(content, message.message_id) if content else content - message.message_info = MessageInfo( - user_info=user_info or _build_user_info_for_role(resolved_role), - group_info=group_info, - additional_config={ - LLM_ROLE_KEY: resolved_role, - MESSAGE_KIND_KEY: message_kind, - SOURCE_KEY: source or resolved_role, - TOOL_CALL_ID_KEY: tool_call_id, - TOOL_CALLS_KEY: [_serialize_tool_call(tool_call) for tool_call in (tool_calls or [])], - }, - ) - message.session_id = session_id - message.raw_message = raw_message if raw_message is not None else MessageSequence([]) - if raw_message is None and normalized_content: - message.raw_message.text(normalized_content) - visible_text = display_text if display_text is not None else normalized_content - message.processed_plain_text = visible_text - message.display_message = visible_text - message.initialized = True - return message - - def format_speaker_content( speaker_name: str, content: str, timestamp: Optional[datetime] = None, message_id: Optional[str] = None, ) -> str: - """Format visible conversation content with an explicit speaker label.""" + """将可见文本格式化为带说话人前缀的样式。""" time_prefix = timestamp.strftime("%H:%M:%S") if timestamp is not None else "" message_id_prefix = f"[msg_id:{message_id}]" if message_id else "" return f"{time_prefix}{message_id_prefix}[{speaker_name}]{content}" def parse_speaker_content(content: str) -> tuple[Optional[str], str]: - """Parse content formatted as [speaker]message.""" + """解析形如 [speaker]message 的可见文本。""" match = SPEAKER_PREFIX_PATTERN.match(content or "") if not match: return None, content or "" @@ -150,12 +34,12 @@ def parse_speaker_content(content: str) -> tuple[Optional[str], str]: def clone_message_sequence(message_sequence: MessageSequence) -> MessageSequence: - """Create a detached copy of a message sequence.""" + """复制消息片段序列。""" return MessageSequence([deepcopy(component) for component in message_sequence.components]) def build_visible_text_from_sequence(message_sequence: MessageSequence) -> str: - """Extract visible text from a message sequence without forcing image descriptions.""" + """从消息片段序列提取可见文本。""" parts: list[str] = [] for component in message_sequence.components: if isinstance(component, TextComponent): @@ -181,112 +65,5 @@ def build_visible_text_from_sequence(message_sequence: MessageSequence) -> str: if isinstance(component, ImageComponent): parts.append("[图片]") + return "".join(parts) - - -def _guess_image_format(image_bytes: bytes) -> Optional[str]: - if not image_bytes: - return None - - try: - with PILImage.open(BytesIO(image_bytes)) as image: - return image.format.lower() if image.format else None - except Exception: - return None - - -def get_message_text(message: SessionMessage) -> str: - if message.processed_plain_text is not None: - return message.processed_plain_text - if message.display_message is not None: - return message.display_message - - parts: list[str] = [] - for component in message.raw_message.components: - text = getattr(component, "text", None) - if isinstance(text, str): - parts.append(text) - return "".join(parts) - - -def get_message_role(message: SessionMessage) -> str: - return str(message.message_info.additional_config.get(LLM_ROLE_KEY, RoleType.User.value)) - - -def get_message_kind(message: SessionMessage) -> str: - return str(message.message_info.additional_config.get(MESSAGE_KIND_KEY, "normal")) - - -def get_message_source(message: SessionMessage) -> str: - return str(message.message_info.additional_config.get(SOURCE_KEY, get_message_role(message))) - - -def is_perception_message(message: SessionMessage) -> bool: - return get_message_kind(message) == "perception" - - -def get_tool_call_id(message: SessionMessage) -> Optional[str]: - value = message.message_info.additional_config.get(TOOL_CALL_ID_KEY) - return str(value) if value else None - - -def get_tool_calls(message: SessionMessage) -> list[ToolCall]: - raw_tool_calls = message.message_info.additional_config.get(TOOL_CALLS_KEY, []) - if not isinstance(raw_tool_calls, list): - return [] - return [_deserialize_tool_call(item) for item in raw_tool_calls if isinstance(item, dict)] - - -def remove_last_perception(messages: list[SessionMessage]) -> None: - for index in range(len(messages) - 1, -1, -1): - if is_perception_message(messages[index]): - messages.pop(index) - break - - -def to_llm_message(message: SessionMessage) -> Optional[Message]: - role = get_message_role(message) - tool_call_id = get_tool_call_id(message) - tool_calls = get_tool_calls(message) - - if role == RoleType.System.value: - role_type = RoleType.System - elif role == RoleType.User.value: - role_type = RoleType.User - elif role == RoleType.Assistant.value: - role_type = RoleType.Assistant - elif role == RoleType.Tool.value: - role_type = RoleType.Tool - else: - return None - - builder = MessageBuilder().set_role(role_type) - if role_type == RoleType.Assistant and tool_calls: - builder.set_tool_calls(tool_calls) - if role_type == RoleType.Tool and tool_call_id: - builder.add_tool_call(tool_call_id) - - has_content = False - for component in message.raw_message.components: - if isinstance(component, TextComponent): - if component.text: - builder.add_text_content(component.text) - has_content = True - continue - - if isinstance(component, (ImageComponent, EmojiComponent)): - image_format = _guess_image_format(component.binary_data) - if image_format and component.binary_data: - builder.add_image_content(image_format, base64.b64encode(component.binary_data).decode("utf-8")) - has_content = True - continue - - if component.content: - builder.add_text_content(component.content) - has_content = True - - if not has_content: - content = get_message_text(message) - if content: - builder.add_text_content(content) - return builder.build() diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index bef58b59..e5cd6dae 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -6,33 +6,32 @@ from typing import TYPE_CHECKING, Optional import asyncio import difflib import json -import re import time +import traceback -from sqlmodel import select from src.chat.heart_flow.heartFC_utils import CycleDetail from src.chat.message_receive.message import SessionMessage from src.chat.replyer.replyer_manager import replyer_manager -from src.chat.utils.utils import get_bot_account, process_llm_response -from src.common.database.database import get_db_session -from src.common.database.database_model import Jargon -from src.common.data_models.mai_message_data_model import UserInfo +from src.chat.utils.utils import process_llm_response from src.common.data_models.message_component_data_model import MessageSequence, TextComponent from src.common.logger import get_logger from src.config.config import global_config from src.learners.jargon_explainer import search_jargon +from src.llm_models.exceptions import ReqAbortException from src.llm_models.payload_content.tool_option import ToolCall from src.services import database_service as database_api, send_service +from .context_messages import ( + AssistantMessage, + LLMContextMessage, + SessionBackedMessage, + ToolResultMessage, +) from .message_adapter import ( - build_message, build_visible_text_from_sequence, clone_message_sequence, format_speaker_content, - get_message_source, - get_message_text, - get_message_role, ) from .tool_handlers import ( handle_mcp_tool, @@ -51,7 +50,6 @@ class MaisakaReasoningEngine: def __init__(self, runtime: "MaisakaHeartFlowChatting") -> None: self._runtime = runtime self._last_reasoning_content: str = "" - self._shown_jargons: set[str] = set() # 已在参考消息中展示过的 jargon async def run_loop(self) -> None: """独立消费消息批次,并执行对应的内部思考轮次。""" @@ -65,6 +63,7 @@ class MaisakaReasoningEngine: self._runtime._agent_state = self._runtime._STATE_RUNNING if cached_messages: + self._append_wait_interrupted_message_if_needed() await self._ingest_messages(cached_messages) anchor_message = cached_messages[-1] else: @@ -76,26 +75,35 @@ class MaisakaReasoningEngine: self._runtime._internal_turn_queue.task_done() continue logger.info(f"{self._runtime.log_prefix} wait 超时后开始新一轮思考") - self._runtime._chat_history.append(self._build_wait_timeout_message(anchor_message)) + self._runtime._chat_history.append(self._build_wait_timeout_message()) self._trim_chat_history() try: for round_index in range(self._runtime._max_internal_rounds): cycle_detail = self._start_cycle() self._runtime._log_cycle_started(cycle_detail, round_index) try: - # 每次LLM生成前,动态添加参考消息到最新位置 - reference_added = self._append_jargon_reference_message() planner_started_at = time.time() - response = await self._runtime._chat_loop_service.chat_loop_step(self._runtime._chat_history) + logger.info( + f"{self._runtime.log_prefix} planner 开始: " + f"round={round_index + 1} " + f"history_size={len(self._runtime._chat_history)} " + f"started_at={planner_started_at:.3f}" + ) + interrupt_flag = asyncio.Event() + self._runtime._planner_interrupt_flag = interrupt_flag + self._runtime._chat_loop_service.set_interrupt_flag(interrupt_flag) + try: + response = await self._runtime._chat_loop_service.chat_loop_step(self._runtime._chat_history) + finally: + if self._runtime._planner_interrupt_flag is interrupt_flag: + self._runtime._planner_interrupt_flag = None + self._runtime._chat_loop_service.set_interrupt_flag(None) cycle_detail.time_records["planner"] = time.time() - planner_started_at - - # LLM调用后,移除刚才添加的参考消息(一次性使用) - if reference_added and self._runtime._chat_history: - # 从末尾往前查找并移除参考消息 - for i in range(len(self._runtime._chat_history) - 1, -1, -1): - if get_message_source(self._runtime._chat_history[i]) == "user_reference": - self._runtime._chat_history.pop(i) - break + logger.info( + f"{self._runtime.log_prefix} planner 完成: " + f"round={round_index + 1} " + f"elapsed={cycle_detail.time_records['planner']:.3f}s" + ) reasoning_content = response.content or "" if self._should_replace_reasoning(reasoning_content): @@ -104,9 +112,6 @@ class MaisakaReasoningEngine: logger.info(f"{self._runtime.log_prefix} reasoning content replaced due to high similarity") self._last_reasoning_content = reasoning_content - response.raw_message.platform = anchor_message.platform - response.raw_message.session_id = self._runtime.session_id - response.raw_message.message_info.group_info = self._runtime._build_group_info(anchor_message) self._runtime._chat_history.append(response.raw_message) if response.tool_calls: @@ -124,6 +129,16 @@ class MaisakaReasoningEngine: if response.content: continue + break + except ReqAbortException: + interrupted_at = time.time() + logger.info( + f"{self._runtime.log_prefix} planner 打断成功: " + f"round={round_index + 1} " + f"started_at={planner_started_at:.3f} " + f"interrupted_at={interrupted_at:.3f} " + f"elapsed={interrupted_at - planner_started_at:.3f}s" + ) break finally: self._end_cycle(cycle_detail) @@ -136,6 +151,7 @@ class MaisakaReasoningEngine: raise except Exception: logger.exception("%s Maisaka internal loop crashed", self._runtime.log_prefix) + logger.error(traceback.format_exc()) raise def _get_timeout_anchor_message(self) -> Optional[SessionMessage]: @@ -144,16 +160,31 @@ class MaisakaReasoningEngine: return self._runtime.message_cache[-1] return None - def _build_wait_timeout_message(self, anchor_message: SessionMessage) -> SessionMessage: - """构造 wait 超时后的工具结果消息,用于触发下一轮思考。""" - return build_message( - role="tool", + def _build_wait_timeout_message(self) -> ToolResultMessage: + """构造 wait 超时后的工具结果消息。""" + tool_call_id = self._runtime._pending_wait_tool_call_id or "wait_timeout" + self._runtime._pending_wait_tool_call_id = None + return ToolResultMessage( content="wait 已超时,期间没有收到新的用户输入。请基于现有上下文继续下一轮思考。", - source="tool", - platform=anchor_message.platform, - session_id=self._runtime.session_id, - group_info=self._runtime._build_group_info(anchor_message), - user_info=UserInfo(user_id="maisaka_tool", user_nickname="tool", user_cardname=None), + timestamp=datetime.now(), + tool_call_id=tool_call_id, + tool_name="wait", + ) + + def _append_wait_interrupted_message_if_needed(self) -> None: + """如果 wait 被新消息打断,则补一条对应的工具结果消息。""" + tool_call_id = self._runtime._pending_wait_tool_call_id + if not tool_call_id: + return + + self._runtime._pending_wait_tool_call_id = None + self._runtime._chat_history.append( + ToolResultMessage( + content="wait 被新的用户输入打断,已继续处理最新消息。", + timestamp=datetime.now(), + tool_call_id=tool_call_id, + tool_name="wait", + ) ) async def _ingest_messages(self, messages: list[SessionMessage]) -> None: @@ -164,17 +195,11 @@ class MaisakaReasoningEngine: if not user_sequence.components: continue - history_message = build_message( - role="user", - content=visible_text, - source="user", - timestamp=message.timestamp, - platform=message.platform, - session_id=self._runtime.session_id, - group_info=self._runtime._build_group_info(message), - user_info=self._runtime._build_runtime_user_info(), + history_message = SessionBackedMessage.from_session_message( + message, raw_message=user_sequence, - display_text=visible_text, + visible_text=visible_text, + source_kind="user", ) self._insert_chat_history_message(history_message) self._trim_chat_history() @@ -239,141 +264,10 @@ class MaisakaReasoningEngine: speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id return format_speaker_content(speaker_name, content, message.timestamp, message.message_id).strip() - def _insert_chat_history_message(self, message: SessionMessage) -> int: - """按时间顺序将消息插入聊天历史,同时保留 system 消息在最前。""" - if not self._runtime._chat_history: - self._runtime._chat_history.append(message) - return 0 - - insert_at = len(self._runtime._chat_history) - for index, existing_message in enumerate(self._runtime._chat_history): - if get_message_role(existing_message) == "system": - continue - if existing_message.timestamp > message.timestamp: - insert_at = index - break - - self._runtime._chat_history.insert(insert_at, message) - return insert_at - - def _append_jargon_reference_message(self) -> bool: - """每次LLM生成前,如果命中了黑话词条,则添加一条参考信息消息到聊天历史末尾。 - - Returns: - bool: 是否添加了参考消息 - """ - content = self._build_user_history_corpus() - if not content: - return False - - matched_words = self._find_jargon_words_in_text(content) - if not matched_words: - return False - - # 记录已展示的 jargon - for word in matched_words: - self._shown_jargons.add(word.lower()) - - reference_text = ( - "[参考信息]\n" - f"{','.join(matched_words)}可能是jargon,可以使用query_jargon来查看其含义" - ) - reference_sequence = MessageSequence([TextComponent(reference_text)]) - - # 使用当前时间作为时间戳 - reference_message = build_message( - role="user", - content="", - source="user_reference", - timestamp=datetime.now(), - platform=self._runtime.chat_stream.platform, - session_id=self._runtime.session_id, - group_info=self._runtime._build_group_info(), - user_info=self._runtime._build_runtime_user_info(), - raw_message=reference_sequence, - display_text=reference_text, - ) - self._runtime._chat_history.append(reference_message) - return True - - def _build_user_history_corpus(self) -> str: - """拼接当前聊天记录内所有用户消息的正文,用于统一匹配黑话。""" - parts: list[str] = [] - for history_message in self._runtime._chat_history: - if get_message_role(history_message) != "user": - continue - if get_message_source(history_message) != "user": - continue - text = (get_message_text(history_message) or "").strip() - if not text: - continue - parts.append(text) - - return "\n".join(parts) - - def _find_jargon_words_in_text(self, content: str) -> list[str]: - """匹配正文中出现的 jargon 词条。""" - lowered_content = content.lower() - matched_entries: list[tuple[int, int, int, str]] = [] - seen_words: set[str] = set() - - with get_db_session(auto_commit=False) as session: - query = ( - select(Jargon) - .where(Jargon.is_jargon.is_(True)) - .order_by(Jargon.count.desc()) # type: ignore[attr-defined] - ) - jargons = session.exec(query).all() - - for jargon in jargons: - jargon_content = str(jargon.content or "").strip() - if not jargon_content: - continue - # meaning 为空的不匹配 - if not str(jargon.meaning or "").strip(): - continue - normalized_content = jargon_content.lower() - if normalized_content in seen_words: - continue - # 跳过已经展示过的 jargon - if normalized_content in self._shown_jargons: - continue - if not self._is_visible_jargon(jargon): - continue - match_position = self._get_jargon_match_position(jargon_content, lowered_content, content) - if match_position is None: - continue - - seen_words.add(normalized_content) - matched_entries.append((match_position, -len(jargon_content), -int(jargon.count or 0), jargon_content)) - - matched_entries.sort() - return [matched_content for _, _, _, matched_content in matched_entries[:8]] - - def _is_visible_jargon(self, jargon: Jargon) -> bool: - """判断当前会话是否可见该 jargon。""" - if global_config.expression.all_global_jargon or bool(jargon.is_global): - return True - - try: - session_id_dict = json.loads(jargon.session_id_dict or "{}") - except (TypeError, json.JSONDecodeError): - logger.warning(f"Failed to parse jargon.session_id_dict: jargon_id={jargon.id}") - return False - return self._runtime.session_id in session_id_dict - - @staticmethod - def _get_jargon_match_position(jargon_content: str, lowered_content: str, original_content: str) -> Optional[int]: - """返回 jargon 在文本中的首次命中位置,未命中时返回 `None`。""" - if re.search(r"[\u4e00-\u9fff]", jargon_content): - match_index = original_content.lower().find(jargon_content.lower()) - return match_index if match_index >= 0 else None - - pattern = rf"\b{re.escape(jargon_content.lower())}\b" - match = re.search(pattern, lowered_content) - if match is None: - return None - return match.start() + def _insert_chat_history_message(self, message: LLMContextMessage) -> int: + """将消息按处理顺序追加到聊天历史末尾。""" + self._runtime._chat_history.append(message) + return len(self._runtime._chat_history) - 1 def _start_cycle(self) -> CycleDetail: """开始一轮 Maisaka 思考循环。""" @@ -397,10 +291,7 @@ class MaisakaReasoningEngine: def _trim_chat_history(self) -> None: """裁剪聊天历史,保证用户消息数量不超过配置限制。""" - counted_roles = {"user", "assistant"} - conversation_message_count = sum( - 1 for message in self._runtime._chat_history if get_message_role(message) in counted_roles - ) + conversation_message_count = sum(1 for message in self._runtime._chat_history if message.count_in_context) if conversation_message_count <= self._runtime._max_context_size: return @@ -410,7 +301,7 @@ class MaisakaReasoningEngine: while conversation_message_count >= self._runtime._max_context_size and trimmed_history: removed_message = trimmed_history.pop(0) removed_count += 1 - if get_message_role(removed_message) in counted_roles: + if removed_message.count_in_context: conversation_message_count -= 1 self._runtime._chat_history = trimmed_history @@ -441,6 +332,11 @@ class MaisakaReasoningEngine: bool: 是否需要替换 """ if not self._last_reasoning_content or not current_content: + logger.info( + f"{self._runtime.log_prefix} reasoning similarity skipped: " + f"last_empty={not bool(self._last_reasoning_content)} " + f"current_empty={not bool(current_content)} similarity=0.00" + ) return False similarity = self._calculate_similarity(current_content, self._last_reasoning_content) @@ -495,13 +391,7 @@ class MaisakaReasoningEngine: except (TypeError, ValueError): wait_seconds = 30 wait_seconds = max(0, wait_seconds) - self._runtime._chat_history.append( - self._build_tool_message( - tool_call, - f"Waiting for future input for up to {wait_seconds} seconds.", - ) - ) - self._runtime._enter_wait_state(seconds=wait_seconds) + self._runtime._enter_wait_state(seconds=wait_seconds, tool_call_id=tool_call.call_id) return True if tool_call.func_name == "stop": @@ -743,33 +633,27 @@ class MaisakaReasoningEngine: tool_reasoning=latest_thought, ) - target_platform = target_message.platform or anchor_message.platform bot_name = global_config.bot.nickname.strip() or "MaiSaka" - bot_user_info = UserInfo( - user_id=get_bot_account(target_platform) or "maisaka_assistant", - user_nickname=bot_name, - user_cardname=None, + reply_timestamp = datetime.now() + planner_prefix = ( + f"[时间]{reply_timestamp.strftime('%H:%M:%S')}\n" + f"[用户]{bot_name}\n" + "[用户群昵称]\n" + "[msg_id]\n" + "[发言内容]" ) - history_message = build_message( - role="user", - content="", - source="guided_reply", - platform=target_platform, - session_id=self._runtime.session_id, - group_info=self._runtime._build_group_info(target_message), - user_info=bot_user_info, - ) - history_message.raw_message = MessageSequence( - [TextComponent(f"{self._build_planner_user_prefix(history_message)}{combined_reply_text}")] + history_message = SessionBackedMessage( + raw_message=MessageSequence([TextComponent(f"{planner_prefix}{combined_reply_text}")]), + visible_text="", + timestamp=reply_timestamp, + source_kind="guided_reply", ) visible_reply_text = format_speaker_content( bot_name, combined_reply_text, - history_message.timestamp, - history_message.message_id, + reply_timestamp, ) - history_message.display_message = visible_reply_text - history_message.processed_plain_text = visible_reply_text + history_message.visible_text = visible_reply_text self._runtime._chat_history.append(history_message) return True @@ -871,14 +755,10 @@ class MaisakaReasoningEngine: self._build_tool_message(tool_call, "Failed to send emoji.") ) - def _build_tool_message(self, tool_call: ToolCall, content: str) -> SessionMessage: - return build_message( - role="tool", + def _build_tool_message(self, tool_call: ToolCall, content: str) -> ToolResultMessage: + return ToolResultMessage( content=content, - source="tool", + timestamp=datetime.now(), tool_call_id=tool_call.call_id, - platform=self._runtime.chat_stream.platform, - session_id=self._runtime.session_id, - group_info=self._runtime._build_group_info(), - user_info=UserInfo(user_id="maisaka_tool", user_nickname="tool", user_cardname=None), + tool_name=tool_call.func_name, ) diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index 90b4b961..9e34ba72 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -19,6 +19,7 @@ from src.learners.jargon_miner import JargonMiner from src.mcp_module import MCPManager from .chat_loop_service import MaisakaChatLoopService +from .context_messages import LLMContextMessage from .reasoning_engine import MaisakaReasoningEngine logger = get_logger("maisaka_runtime") @@ -40,7 +41,7 @@ class MaisakaHeartFlowChatting: session_name = chat_manager.get_session_name(session_id) or session_id self.log_prefix = f"[{session_name}]" self._chat_loop_service = MaisakaChatLoopService() - self._chat_history: list[SessionMessage] = [] + self._chat_history: list[LLMContextMessage] = [] self.history_loop: list[CycleDetail] = [] # Keep all original messages for batching and later learning. @@ -60,6 +61,8 @@ class MaisakaHeartFlowChatting: self._max_context_size = max(1, int(global_config.chat.max_context_size)) self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP self._wait_until: Optional[float] = None + self._pending_wait_tool_call_id: Optional[str] = None + self._planner_interrupt_flag: Optional[asyncio.Event] = None expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id) self._enable_expression_use = expr_use @@ -78,14 +81,14 @@ class MaisakaHeartFlowChatting: async def start(self) -> None: """Start the runtime loop.""" if self._running: + self._ensure_background_tasks_running() return if global_config.maisaka.enable_mcp: await self._init_mcp() self._running = True - self._internal_loop_task = asyncio.create_task(self._reasoning_engine.run_loop()) - self._loop_task = asyncio.create_task(self._main_loop()) + self._ensure_background_tasks_running() logger.info(f"{self.log_prefix} Maisaka runtime started") async def stop(self) -> None: @@ -128,12 +131,48 @@ class MaisakaHeartFlowChatting: async def register_message(self, message: SessionMessage) -> None: """Cache a new message and wake the main loop.""" + if self._running: + self._ensure_background_tasks_running() self.message_cache.append(message) self._source_messages_by_id[message.message_id] = message + if self._agent_state == self._STATE_RUNNING and self._planner_interrupt_flag is not None: + logger.info( + f"{self.log_prefix} 收到新消息,发起 planner 打断; " + f"msg_id={message.message_id} cache_size={len(self.message_cache)} " + f"timestamp={time.time():.3f}" + ) + self._planner_interrupt_flag.set() if self._agent_state in (self._STATE_WAIT, self._STATE_STOP): self._agent_state = self._STATE_RUNNING self._new_message_event.set() + def _ensure_background_tasks_running(self) -> None: + """确保后台任务仍在运行,若崩溃则自动拉起。""" + if not self._running: + return + + if self._internal_loop_task is None or self._internal_loop_task.done(): + if self._internal_loop_task is not None and not self._internal_loop_task.cancelled(): + try: + exc = self._internal_loop_task.exception() + except Exception: + exc = None + if exc is not None: + logger.error(f"{self.log_prefix} internal loop task exited unexpectedly: {exc}") + self._internal_loop_task = asyncio.create_task(self._reasoning_engine.run_loop()) + logger.warning(f"{self.log_prefix} restarted Maisaka internal loop task") + + if self._loop_task is None or self._loop_task.done(): + if self._loop_task is not None and not self._loop_task.cancelled(): + try: + exc = self._loop_task.exception() + except Exception: + exc = None + if exc is not None: + logger.error(f"{self.log_prefix} main loop task exited unexpectedly: {exc}") + self._loop_task = asyncio.create_task(self._main_loop()) + logger.warning(f"{self.log_prefix} restarted Maisaka main loop task") + async def _main_loop(self) -> None: try: while self._running: @@ -222,15 +261,17 @@ class MaisakaHeartFlowChatting: self._wait_until = None return "timeout" - def _enter_wait_state(self, seconds: Optional[float] = None) -> None: + def _enter_wait_state(self, seconds: Optional[float] = None, tool_call_id: Optional[str] = None) -> None: """Enter wait state.""" self._agent_state = self._STATE_WAIT self._wait_until = None if seconds is None else time.time() + seconds + self._pending_wait_tool_call_id = tool_call_id def _enter_stop_state(self) -> None: """Enter stop state.""" self._agent_state = self._STATE_STOP self._wait_until = None + self._pending_wait_tool_call_id = None async def _trigger_batch_learning(self, messages: list[SessionMessage]) -> None: """按同一批消息触发表达方式、黑话和 knowledge 学习。""" diff --git a/src/maisaka/tool_handlers.py b/src/maisaka/tool_handlers.py index 4724a2b5..904046e2 100644 --- a/src/maisaka/tool_handlers.py +++ b/src/maisaka/tool_handlers.py @@ -9,12 +9,11 @@ import json as _json from rich.panel import Panel -from src.chat.message_receive.message import SessionMessage from src.cli.console import console from src.cli.input_reader import InputReader from src.llm_models.payload_content.tool_option import ToolCall -from .message_adapter import build_message +from .context_messages import LLMContextMessage, ToolResultMessage if TYPE_CHECKING: from src.mcp_module import MCPManager @@ -33,22 +32,34 @@ class ToolHandlerContext: self.last_user_input_time: Optional[datetime] = None -async def handle_stop(tc: ToolCall, chat_history: list[SessionMessage]) -> None: +async def handle_stop(tc: ToolCall, chat_history: list[LLMContextMessage]) -> None: """处理 stop 工具。""" console.print("[accent]调用工具: stop()[/accent]") chat_history.append( - build_message(role="tool", content="当前轮次结束后将停止对话循环。", tool_call_id=tc.call_id) + ToolResultMessage( + content="当前轮次结束后将停止对话循环。", + timestamp=datetime.now(), + tool_call_id=tc.call_id, + tool_name=tc.func_name, + ) ) -async def handle_wait(tc: ToolCall, chat_history: list[SessionMessage], ctx: ToolHandlerContext) -> str: +async def handle_wait(tc: ToolCall, chat_history: list[LLMContextMessage], ctx: ToolHandlerContext) -> str: """处理 wait 工具。""" seconds = (tc.args or {}).get("seconds", 30) seconds = max(5, min(seconds, 300)) console.print(f"[accent]调用工具: wait({seconds})[/accent]") tool_result = await _do_wait(seconds, ctx) - chat_history.append(build_message(role="tool", content=tool_result, tool_call_id=tc.call_id)) + chat_history.append( + ToolResultMessage( + content=tool_result, + timestamp=datetime.now(), + tool_call_id=tc.call_id, + tool_name=tc.func_name, + ) + ) return tool_result @@ -78,7 +89,7 @@ async def _do_wait(seconds: int, ctx: ToolHandlerContext) -> str: return f"已收到用户输入: {user_input}" -async def handle_mcp_tool(tc: ToolCall, chat_history: list[SessionMessage], mcp_manager: "MCPManager") -> None: +async def handle_mcp_tool(tc: ToolCall, chat_history: list[LLMContextMessage], mcp_manager: "MCPManager") -> None: """处理 MCP 工具调用。""" args_str = _json.dumps(tc.args or {}, ensure_ascii=False) args_preview = args_str if len(args_str) <= 120 else args_str[:120] + "..." @@ -96,10 +107,24 @@ async def handle_mcp_tool(tc: ToolCall, chat_history: list[SessionMessage], mcp_ padding=(0, 1), ) ) - chat_history.append(build_message(role="tool", content=result, tool_call_id=tc.call_id)) + chat_history.append( + ToolResultMessage( + content=result, + timestamp=datetime.now(), + tool_call_id=tc.call_id, + tool_name=tc.func_name, + ) + ) -async def handle_unknown_tool(tc: ToolCall, chat_history: list[SessionMessage]) -> None: +async def handle_unknown_tool(tc: ToolCall, chat_history: list[LLMContextMessage]) -> None: """处理未知工具调用。""" console.print(f"[accent]调用未知工具: {tc.func_name}({tc.args})[/accent]") - chat_history.append(build_message(role="tool", content=f"未知工具: {tc.func_name}", tool_call_id=tc.call_id)) + chat_history.append( + ToolResultMessage( + content=f"未知工具: {tc.func_name}", + timestamp=datetime.now(), + tool_call_id=tc.call_id, + tool_name=tc.func_name, + ) + )