diff --git a/src/cli/maisaka_cli.py b/src/cli/maisaka_cli.py index f3c88d73..1174ea67 100644 --- a/src/cli/maisaka_cli.py +++ b/src/cli/maisaka_cli.py @@ -20,6 +20,7 @@ from src.chat.message_receive.message import SessionMessage from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator from src.config.config import config_manager, global_config from src.mcp_module import MCPManager +from src.mcp_module.host_llm_bridge import MCPHostLLMBridge from src.maisaka.chat_loop_service import MaisakaChatLoopService from src.maisaka.context_messages import ( @@ -66,6 +67,7 @@ class BufferCLI: self._last_assistant_response_time: Optional[datetime] = None self._user_input_times: list[datetime] = [] self._mcp_manager: Optional[MCPManager] = None + self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None self._init_llm() def _init_llm(self) -> None: @@ -383,17 +385,23 @@ class BufferCLI: async def _init_mcp(self) -> None: """初始化 MCP 服务并注册暴露的工具。""" - self._mcp_manager = await MCPManager.from_app_config(global_config.mcp) + self._mcp_host_bridge = MCPHostLLMBridge( + sampling_task_name=global_config.mcp.client.sampling.task_name, + ) + self._mcp_manager = await MCPManager.from_app_config( + global_config.mcp, + host_callbacks=self._mcp_host_bridge.build_callbacks(), + ) if self._mcp_manager and self._chat_loop_service: mcp_tools = self._mcp_manager.get_openai_tools() if mcp_tools: self._chat_loop_service.set_extra_tools(mcp_tools) - summary = self._mcp_manager.get_tool_summary() + summary = self._mcp_manager.get_feature_summary() console.print( Panel( - f"已加载 {len(mcp_tools)} 个 MCP 工具:\n{summary}", - title="MCP 工具", + f"已加载 {len(mcp_tools)} 个 MCP 工具。\n{summary}", + title="MCP 能力", border_style="green", padding=(0, 1), ) @@ -452,3 +460,4 @@ class BufferCLI: finally: if self._mcp_manager: await self._mcp_manager.close() + self._mcp_host_bridge = None diff --git a/src/config/config.py b/src/config/config.py index 318c987f..c85a170a 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -57,7 +57,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config" BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() MMC_VERSION: str = "1.0.0" -CONFIG_VERSION: str = "8.2.0" +CONFIG_VERSION: str = "8.3.0" MODEL_CONFIG_VERSION: str = "1.13.1" logger = get_logger("config") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index edb3e2c2..f3099a5c 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,6 +1,7 @@ -import re from typing import Literal, Optional +import re + from .config_base import ConfigBase, Field """ @@ -1569,6 +1570,225 @@ class MaiSakaConfig(ConfigBase): """Maisaka终端图片预览的字符宽度""" +class MCPAuthorizationConfig(ConfigBase): + """MCP HTTP 认证配置。""" + + mode: Literal["none", "bearer"] = Field( + default="none", + json_schema_extra={ + "x-widget": "select", + "x-icon": "shield", + }, + ) + """认证模式,当前支持无认证和静态 Bearer Token""" + + bearer_token: str = Field( + default="", + json_schema_extra={ + "x-widget": "password", + "x-icon": "key", + }, + ) + """静态 Bearer Token,仅在 `mode=\"bearer\"` 时使用""" + + def model_post_init(self, context: Optional[dict] = None) -> None: + """验证 MCP 认证配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ + + if self.mode == "bearer" and not self.bearer_token.strip(): + raise ValueError("MCP 使用 bearer 认证时必须填写 bearer_token") + return super().model_post_init(context) + + +class MCPRootItemConfig(ConfigBase): + """单个 MCP Root 配置。""" + + enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "power", + }, + ) + """是否启用当前 Root""" + + uri: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "folder", + }, + ) + """Root URI,通常为 `file://` 路径 URI""" + + name: str = Field( + default="", + json_schema_extra={ + "x-widget": "input", + "x-icon": "tag", + }, + ) + """Root 的显示名称""" + + def model_post_init(self, context: Optional[dict] = None) -> None: + """验证单个 Root 配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ + + if self.enabled and not self.uri.strip(): + raise ValueError("启用的 MCP Root 必须填写 uri") + return super().model_post_init(context) + + +class MCPRootsConfig(ConfigBase): + """MCP Roots 能力配置。""" + + enable: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "folder-tree", + }, + ) + """是否向 MCP 服务器暴露 Roots 能力""" + + items: list[MCPRootItemConfig] = Field( + default_factory=lambda: [], + json_schema_extra={ + "x-widget": "custom", + "x-icon": "folder", + }, + ) + """Roots 列表""" + + +class MCPSamplingConfig(ConfigBase): + """MCP Sampling 能力配置。""" + + enable: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "brain", + }, + ) + """是否启用 Sampling 能力声明""" + + task_name: str = Field( + default="planner", + json_schema_extra={ + "x-widget": "input", + "x-icon": "sparkles", + }, + ) + """执行 Sampling 请求时使用的主程序模型任务名""" + + include_context_support: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "layers", + }, + ) + """是否声明支持 `includeContext` 非 `none` 语义""" + + tool_support: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "wrench", + }, + ) + """是否声明支持在 Sampling 中继续使用工具""" + + +class MCPElicitationConfig(ConfigBase): + """MCP Elicitation 能力配置。""" + + enable: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "message-circle-question", + }, + ) + """是否启用 Elicitation 能力声明""" + + allow_form: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "form-input", + }, + ) + """是否允许表单模式 Elicitation""" + + allow_url: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "link", + }, + ) + """是否允许 URL 模式 Elicitation""" + + def model_post_init(self, context: Optional[dict] = None) -> None: + """验证 Elicitation 配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ + + if self.enable and not (self.allow_form or self.allow_url): + raise ValueError("启用 MCP Elicitation 时至少需要允许一种模式") + return super().model_post_init(context) + + +class MCPClientConfig(ConfigBase): + """MCP 客户端宿主能力配置。""" + + client_name: str = Field( + default="MaiBot", + json_schema_extra={ + "x-widget": "input", + "x-icon": "bot", + }, + ) + """MCP 客户端实现名称""" + + client_version: str = Field( + default="1.0.0", + json_schema_extra={ + "x-widget": "input", + "x-icon": "info", + }, + ) + """MCP 客户端实现版本""" + + roots: MCPRootsConfig = Field(default_factory=MCPRootsConfig) + """Roots 能力配置""" + + sampling: MCPSamplingConfig = Field(default_factory=MCPSamplingConfig) + """Sampling 能力配置""" + + elicitation: MCPElicitationConfig = Field(default_factory=MCPElicitationConfig) + """Elicitation 能力配置""" + + class MCPServerItemConfig(ConfigBase): """单个 MCP 服务器配置。""" @@ -1590,14 +1810,14 @@ class MCPServerItemConfig(ConfigBase): ) """是否启用当前 MCP 服务器""" - transport: Literal["stdio", "sse"] = Field( + transport: Literal["stdio", "streamable_http"] = Field( default="stdio", json_schema_extra={ "x-widget": "select", "x-icon": "shuffle", }, ) - """传输方式,可选 stdio 或 sse""" + """传输方式,可选 `stdio` 或 `streamable_http`""" command: str = Field( default="", @@ -1633,7 +1853,7 @@ class MCPServerItemConfig(ConfigBase): "x-icon": "link", }, ) - """sse 模式下的服务地址""" + """`streamable_http` 模式下的 MCP 端点地址""" headers: dict[str, str] = Field( default_factory=lambda: {}, @@ -1642,10 +1862,40 @@ class MCPServerItemConfig(ConfigBase): "x-icon": "file-json", }, ) - """sse 模式下附加的请求头""" + """HTTP 模式下附加的请求头""" + + http_timeout_seconds: float = Field( + default=30.0, + gt=0, + json_schema_extra={ + "x-widget": "number", + "x-icon": "clock-3", + }, + ) + """HTTP 请求超时时间,单位秒""" + + read_timeout_seconds: float = Field( + default=300.0, + gt=0, + json_schema_extra={ + "x-widget": "number", + "x-icon": "timer", + }, + ) + """会话读取超时时间,单位秒""" + + authorization: MCPAuthorizationConfig = Field(default_factory=MCPAuthorizationConfig) + """HTTP 认证配置""" def model_post_init(self, context: Optional[dict] = None) -> None: - """验证 MCP 服务器配置。""" + """验证 MCP 服务器配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ if not self.name.strip(): raise ValueError("MCPServerItemConfig.name 不能为空") @@ -1653,8 +1903,8 @@ class MCPServerItemConfig(ConfigBase): if self.transport == "stdio" and not self.command.strip(): raise ValueError(f"MCP 服务器 {self.name} 使用 stdio 时必须填写 command") - if self.transport == "sse" and not self.url.strip(): - raise ValueError(f"MCP 服务器 {self.name} 使用 sse 时必须填写 url") + if self.transport == "streamable_http" and not self.url.strip(): + raise ValueError(f"MCP 服务器 {self.name} 使用 streamable_http 时必须填写 url") return super().model_post_init(context) @@ -1673,6 +1923,9 @@ class MCPConfig(ConfigBase): ) """是否启用 MCP(Model Context Protocol)""" + client: MCPClientConfig = Field(default_factory=MCPClientConfig) + """MCP 客户端宿主能力配置""" + servers: list[MCPServerItemConfig] = Field( default_factory=lambda: [], json_schema_extra={ @@ -1683,7 +1936,14 @@ class MCPConfig(ConfigBase): """_wrap_MCP 服务器配置列表""" def model_post_init(self, context: Optional[dict] = None) -> None: - """验证 MCP 总配置。""" + """验证 MCP 总配置。 + + Args: + context: Pydantic 传入的上下文对象。 + + Returns: + None + """ server_names = [server.name.strip() for server in self.servers if server.name.strip()] if len(server_names) != len(set(server_names)): diff --git a/src/core/tooling.py b/src/core/tooling.py index 38f3486e..f9c6ec62 100644 --- a/src/core/tooling.py +++ b/src/core/tooling.py @@ -8,8 +8,8 @@ from __future__ import annotations from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Protocol, runtime_checkable import json +from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable from src.common.logger import get_logger from src.llm_models.payload_content.tool_option import ToolDefinitionInput @@ -99,6 +99,64 @@ def build_tool_detailed_description( return "\n".join(lines).strip() +@dataclass(slots=True) +class ToolIcon: + """统一工具图标信息。""" + + src: str + mime_type: str = "" + sizes: list[str] = field(default_factory=list) + + +@dataclass(slots=True) +class ToolAnnotation: + """统一工具注解信息。""" + + audience: list[str] = field(default_factory=list) + priority: float | None = None + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class ToolContentItem: + """统一工具内容项。""" + + content_type: Literal["text", "image", "audio", "resource_link", "resource", "binary", "unknown"] + text: str = "" + data: str = "" + mime_type: str = "" + uri: str = "" + name: str = "" + description: str = "" + annotation: ToolAnnotation | None = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def build_history_text(self) -> str: + """生成适合写入历史消息的文本摘要。 + + Returns: + str: 当前内容项对应的历史摘要文本。 + """ + + if self.content_type == "text" and self.text.strip(): + return self.text.strip() + if self.content_type == "image": + return f"[图片内容 {self.mime_type or 'unknown'}]" + if self.content_type == "audio": + return f"[音频内容 {self.mime_type or 'unknown'}]" + if self.content_type == "resource_link": + label = self.name or self.uri or "资源链接" + return f"[资源链接] {label}" + if self.content_type == "resource": + if self.text.strip(): + return self.text.strip() + label = self.name or self.uri or "嵌入资源" + return f"[嵌入资源] {label}" + if self.content_type == "binary": + return f"[二进制内容 {self.mime_type or 'unknown'}]" + return f"[{self.content_type} 内容]" + + @dataclass(slots=True) class ToolSpec: """统一工具声明。""" @@ -106,10 +164,14 @@ class ToolSpec: name: str brief_description: str detailed_description: str = "" + title: str = "" parameters_schema: Dict[str, Any] | None = None + output_schema: Dict[str, Any] | None = None provider_name: str = "" provider_type: str = "" enabled: bool = True + icons: list[ToolIcon] = field(default_factory=list) + annotation: ToolAnnotation | None = None metadata: Dict[str, Any] = field(default_factory=dict) def build_llm_description(self) -> str: @@ -172,6 +234,7 @@ class ToolExecutionResult: content: str = "" error_message: str = "" structured_content: Any = None + content_items: list[ToolContentItem] = field(default_factory=list) metadata: Dict[str, Any] = field(default_factory=dict) def get_history_content(self) -> str: @@ -183,6 +246,10 @@ class ToolExecutionResult: if self.content.strip(): return self.content.strip() + if self.content_items: + parts = [item.build_history_text() for item in self.content_items if item.build_history_text().strip()] + if parts: + return "\n".join(parts).strip() if self.structured_content is not None: if isinstance(self.structured_content, str): return self.structured_content.strip() @@ -221,6 +288,8 @@ class ToolRegistry: """统一工具注册表。""" def __init__(self) -> None: + """初始化统一工具注册表。""" + self._providers: list[ToolProvider] = [] def register_provider(self, provider: ToolProvider) -> None: diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index a05c1e37..21c03a06 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -17,6 +17,7 @@ from src.know_u.knowledge import KnowledgeLearner from src.learners.expression_learner import ExpressionLearner from src.learners.jargon_miner import JargonMiner from src.mcp_module import MCPManager +from src.mcp_module.host_llm_bridge import MCPHostLLMBridge from src.mcp_module.provider import MCPToolProvider from src.plugin_runtime.tool_provider import PluginToolProvider @@ -54,6 +55,7 @@ class MaisakaHeartFlowChatting: self._internal_turn_queue: asyncio.Queue[Optional[list[SessionMessage]]] = asyncio.Queue() self._mcp_manager: Optional[MCPManager] = None + self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None self._current_cycle_detail: Optional[CycleDetail] = None self._source_messages_by_id: dict[str, SessionMessage] = {} self._running = False @@ -127,6 +129,7 @@ class MaisakaHeartFlowChatting: await self._tool_registry.close() self._mcp_manager = None + self._mcp_host_bridge = None logger.info(f"{self.log_prefix} Maisaka 运行时已停止") @@ -385,7 +388,13 @@ class MaisakaHeartFlowChatting: async def _init_mcp(self) -> None: """初始化 MCP 工具并注册到统一工具层。""" - self._mcp_manager = await MCPManager.from_app_config(global_config.mcp) + self._mcp_host_bridge = MCPHostLLMBridge( + sampling_task_name=global_config.mcp.client.sampling.task_name, + ) + self._mcp_manager = await MCPManager.from_app_config( + global_config.mcp, + host_callbacks=self._mcp_host_bridge.build_callbacks(), + ) if self._mcp_manager is None: logger.info(f"{self.log_prefix} MCP 管理器不可用") return @@ -397,8 +406,8 @@ class MaisakaHeartFlowChatting: self._tool_registry.register_provider(MCPToolProvider(self._mcp_manager)) logger.info( - f"{self.log_prefix} 已向 Maisaka 加载 {len(mcp_tool_specs)} 个 MCP 工具:\n" - f"{self._mcp_manager.get_tool_summary()}" + f"{self.log_prefix} 已向 Maisaka 加载 {len(mcp_tool_specs)} 个 MCP 工具。\n" + f"{self._mcp_manager.get_feature_summary()}" ) def _build_runtime_user_info(self) -> UserInfo: diff --git a/src/mcp_module/config.py b/src/mcp_module/config.py index a9e4a432..4d4d73af 100644 --- a/src/mcp_module/config.py +++ b/src/mcp_module/config.py @@ -6,37 +6,120 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal if TYPE_CHECKING: from src.config.official_configs import MCPConfig +@dataclass(slots=True) +class MCPAuthorizationRuntimeConfig: + """MCP HTTP 认证运行时配置。""" + + mode: Literal["none", "bearer"] = "none" + bearer_token: str = "" + + +@dataclass(slots=True) +class MCPRootRuntimeConfig: + """MCP Root 运行时配置。""" + + uri: str + name: str = "" + + +@dataclass(slots=True) +class MCPClientRuntimeConfig: + """MCP 客户端宿主能力运行时配置。""" + + client_name: str = "MaiBot" + client_version: str = "1.0.0" + enable_roots: bool = False + roots: list[MCPRootRuntimeConfig] = field(default_factory=list) + enable_sampling: bool = False + sampling_task_name: str = "planner" + sampling_include_context_support: bool = False + sampling_tool_support: bool = False + enable_elicitation: bool = False + elicitation_allow_form: bool = True + elicitation_allow_url: bool = False + + @dataclass(slots=True) class MCPServerRuntimeConfig: """单个 MCP 服务器的运行时配置。""" name: str + transport: Literal["stdio", "streamable_http"] = "stdio" command: str = "" args: list[str] = field(default_factory=list) env: dict[str, str] = field(default_factory=dict) url: str = "" headers: dict[str, str] = field(default_factory=dict) + http_timeout_seconds: float = 30.0 + read_timeout_seconds: float = 300.0 + authorization: MCPAuthorizationRuntimeConfig = field(default_factory=MCPAuthorizationRuntimeConfig) @property def transport_type(self) -> str: """返回当前服务器的传输类型。 Returns: - str: ``stdio``、``sse`` 或 ``unknown``。 + str: ``stdio``、``streamable_http`` 或 ``unknown``。 """ - if self.command: + if self.transport == "stdio" and self.command: return "stdio" - if self.url: - return "sse" + if self.transport == "streamable_http" and self.url: + return "streamable_http" return "unknown" + def build_http_headers(self) -> dict[str, str]: + """构建远程 HTTP 连接需要附加的请求头。 + + Returns: + dict[str, str]: 归一化后的请求头集合。 + """ + + headers = {str(key): str(value) for key, value in self.headers.items()} + if self.authorization.mode == "bearer" and self.authorization.bearer_token.strip(): + headers["Authorization"] = f"Bearer {self.authorization.bearer_token.strip()}" + return headers + + +def build_mcp_client_runtime_config(mcp_config: "MCPConfig") -> MCPClientRuntimeConfig: + """将官方 MCP 客户端配置转换为运行时结构。 + + Args: + mcp_config: 主程序中的 MCP 官方配置对象。 + + Returns: + MCPClientRuntimeConfig: MCP 客户端宿主能力运行时配置。 + """ + + roots = [ + MCPRootRuntimeConfig( + uri=root.uri.strip(), + name=root.name.strip(), + ) + for root in mcp_config.client.roots.items + if root.enabled and root.uri.strip() + ] + + return MCPClientRuntimeConfig( + client_name=mcp_config.client.client_name.strip() or "MaiBot", + client_version=mcp_config.client.client_version.strip() or "1.0.0", + enable_roots=mcp_config.client.roots.enable and bool(roots), + roots=roots, + enable_sampling=mcp_config.client.sampling.enable, + sampling_task_name=mcp_config.client.sampling.task_name.strip() or "planner", + sampling_include_context_support=mcp_config.client.sampling.include_context_support, + sampling_tool_support=mcp_config.client.sampling.tool_support, + enable_elicitation=mcp_config.client.elicitation.enable, + elicitation_allow_form=mcp_config.client.elicitation.allow_form, + elicitation_allow_url=mcp_config.client.elicitation.allow_url, + ) + def build_mcp_server_runtime_configs(mcp_config: "MCPConfig") -> list[MCPServerRuntimeConfig]: """将官方 MCP 配置转换为运行时配置列表。 @@ -59,11 +142,18 @@ def build_mcp_server_runtime_configs(mcp_config: "MCPConfig") -> list[MCPServerR runtime_configs.append( MCPServerRuntimeConfig( name=server.name.strip(), + transport=server.transport, command=server.command.strip(), args=[str(argument) for argument in server.args], env={str(key): str(value) for key, value in server.env.items()}, url=server.url.strip(), headers={str(key): str(value) for key, value in server.headers.items()}, + http_timeout_seconds=float(server.http_timeout_seconds), + read_timeout_seconds=float(server.read_timeout_seconds), + authorization=MCPAuthorizationRuntimeConfig( + mode=server.authorization.mode, + bearer_token=server.authorization.bearer_token.strip(), + ), ) ) diff --git a/src/mcp_module/connection.py b/src/mcp_module/connection.py index 0058d518..c598e8bc 100644 --- a/src/mcp_module/connection.py +++ b/src/mcp_module/connection.py @@ -1,16 +1,31 @@ """ MaiSaka - 单个 MCP 服务器连接管理 -封装单个 MCP 服务器的连接生命周期:连接 → 发现工具 → 调用工具 → 断开。 +封装单个 MCP 服务器的连接生命周期:连接 → 发现能力 → 调用工具/读取资源 → 断开。 """ -from contextlib import AsyncExitStack -from typing import Any, Optional +from __future__ import annotations -from src.core.tooling import ToolExecutionResult +from contextlib import AsyncExitStack +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Callable, Optional, cast + +import httpx from src.cli.console import console +from src.core.tooling import ToolExecutionResult -from .config import MCPServerRuntimeConfig +from .config import MCPClientRuntimeConfig, MCPRootRuntimeConfig, MCPServerRuntimeConfig +from .hooks import MCPHostCallbacks +from .models import ( + MCPPromptResult, + MCPResourceReadResult, + build_prompt_result, + build_resource_read_result, + build_tool_content_items, +) + +if TYPE_CHECKING: + from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT # ──────────────────── MCP SDK 可选导入 ──────────────────── # @@ -18,7 +33,7 @@ from .config import MCPServerRuntimeConfig # MCPManager.from_app_config() 会检测到并返回 None,不影响主程序运行。 try: - from mcp import ClientSession + from mcp import ClientSession, types as mcp_types try: from mcp.client.stdio import StdioServerParameters @@ -26,84 +41,114 @@ try: from mcp import StdioServerParameters # type: ignore[attr-defined] from mcp.client.stdio import stdio_client + from mcp.client.streamable_http import streamable_http_client MCP_AVAILABLE = True + STREAMABLE_HTTP_AVAILABLE = True except ImportError: MCP_AVAILABLE = False + STREAMABLE_HTTP_AVAILABLE = False ClientSession = None # type: ignore[assignment,misc] StdioServerParameters = None # type: ignore[assignment,misc] + mcp_types = None # type: ignore[assignment] stdio_client = None # type: ignore[assignment] - -try: - from mcp.client.sse import sse_client - - SSE_AVAILABLE = True -except ImportError: - SSE_AVAILABLE = False - sse_client = None # type: ignore[assignment] + streamable_http_client = None # type: ignore[assignment] class MCPConnection: - """管理单个 MCP 服务器的连接生命周期。 + """管理单个 MCP 服务器的连接生命周期。""" - 支持两种传输方式: - - Stdio: 启动子进程,通过 stdin/stdout 通信 - - SSE: 连接远程 HTTP SSE 端点 - """ - - def __init__(self, config: MCPServerRuntimeConfig) -> None: + def __init__( + self, + config: MCPServerRuntimeConfig, + client_config: MCPClientRuntimeConfig, + host_callbacks: Optional[MCPHostCallbacks] = None, + ) -> None: """初始化单个 MCP 连接。 Args: config: 当前服务器的运行时配置。 + client_config: MCP 客户端宿主能力运行时配置。 + host_callbacks: 宿主侧能力回调集合。 """ self.config = config - self.session: Optional[Any] = None # mcp.ClientSession - self.tools: list = [] # mcp Tool objects + self.client_config = client_config + self.host_callbacks = host_callbacks or MCPHostCallbacks() + + self.session: Optional[Any] = None + self.server_capabilities: Optional[Any] = None + self.tools: list[Any] = [] + self.prompts: list[Any] = [] + self.resources: list[Any] = [] + self.resource_templates: list[Any] = [] + self.protocol_version: str = "" + + self._http_client: Optional[httpx.AsyncClient] = None + self._session_id_getter: Optional[Callable[[], str | None]] = None self._exit_stack = AsyncExitStack() - async def connect(self) -> bool: - """ - 连接到 MCP 服务器并发现可用工具。 + @property + def session_id(self) -> str: + """返回当前连接协商得到的 MCP 会话标识。 Returns: - True 表示连接成功,False 表示失败。 + str: 当前会话 ID;无会话时返回空字符串。 """ + + if self._session_id_getter is None: + return "" + return self._session_id_getter() or "" + + async def connect(self) -> bool: + """连接到 MCP 服务器并发现可用能力。 + + Returns: + bool: `True` 表示连接成功,`False` 表示失败。 + """ + if not MCP_AVAILABLE: console.print("[warning]⚠️ 未安装 mcp SDK,请运行: pip install mcp[/warning]") return False try: await self._exit_stack.__aenter__() + read_stream, write_stream = await self._connect_transport() + session = await self._create_client_session(read_stream, write_stream) + self.session = session + initialize_result = await session.initialize() + self.server_capabilities = getattr(initialize_result, "capabilities", None) + self.protocol_version = str(getattr(initialize_result, "protocolVersion", "") or "") - if self.config.transport_type == "stdio": - read_stream, write_stream = await self._connect_stdio() - elif self.config.transport_type == "sse": - read_stream, write_stream = await self._connect_sse() - else: - console.print(f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]") - return False - - # 创建并初始化 MCP 会话 - if ClientSession is None: - raise RuntimeError("当前环境未安装可用的 MCP ClientSession") - self.session = await self._exit_stack.enter_async_context(ClientSession(read_stream, write_stream)) - await self.session.initialize() - - # 发现工具 - result = await self.session.list_tools() - self.tools = result.tools if hasattr(result, "tools") else [] - + await self._load_server_features() return True - except Exception as e: - console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]") + except Exception as exc: + console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {exc}[/warning]") await self.close() return False - async def _connect_stdio(self): - """建立 Stdio 传输连接。""" + async def _connect_transport(self) -> tuple[Any, Any]: + """根据配置建立底层传输连接。 + + Returns: + tuple[Any, Any]: 读写流对象。 + """ + + if self.config.transport_type == "stdio": + return await self._connect_stdio() + if self.config.transport_type == "streamable_http": + return await self._connect_streamable_http() + + raise ValueError(f"MCP 服务器 '{self.config.name}' 使用了未知传输类型: {self.config.transport}") + + async def _connect_stdio(self) -> tuple[Any, Any]: + """建立 stdio 传输连接。 + + Returns: + tuple[Any, Any]: 读写流对象。 + """ + if StdioServerParameters is None or stdio_client is None: raise RuntimeError("当前环境未安装可用的 MCP stdio 客户端") if not self.config.command: @@ -116,15 +161,293 @@ class MCPConnection: ) return await self._exit_stack.enter_async_context(stdio_client(params)) - async def _connect_sse(self): - """建立 SSE 传输连接。""" - if not SSE_AVAILABLE: - raise ImportError("SSE 传输需要额外依赖,请运行: pip install mcp[sse]") - if sse_client is None: - raise RuntimeError("当前环境未安装可用的 MCP SSE 客户端") + async def _connect_streamable_http(self) -> tuple[Any, Any]: + """建立 Streamable HTTP 传输连接。 + + Returns: + tuple[Any, Any]: 读写流对象。 + """ + + if not STREAMABLE_HTTP_AVAILABLE or streamable_http_client is None: + raise ImportError("当前环境未安装可用的 MCP Streamable HTTP 客户端") if not self.config.url: - raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 SSE url 配置") - return await self._exit_stack.enter_async_context(sse_client(url=self.config.url, headers=self.config.headers)) + raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 Streamable HTTP url 配置") + + self._http_client = await self._exit_stack.enter_async_context(self._build_http_client()) + read_stream, write_stream, session_id_getter = await self._exit_stack.enter_async_context( + streamable_http_client( + url=self.config.url, + http_client=self._http_client, + terminate_on_close=True, + ) + ) + self._session_id_getter = session_id_getter + return read_stream, write_stream + + def _build_http_client(self) -> httpx.AsyncClient: + """构建 Streamable HTTP 使用的 `httpx` 客户端。 + + Returns: + httpx.AsyncClient: 预配置的异步 HTTP 客户端。 + """ + + return httpx.AsyncClient( + headers=self.config.build_http_headers(), + timeout=httpx.Timeout(self.config.http_timeout_seconds), + ) + + async def _create_client_session(self, read_stream: Any, write_stream: Any) -> Any: + """创建并返回 MCP `ClientSession`。 + + Args: + read_stream: 底层读取流。 + write_stream: 底层写入流。 + + Returns: + Any: 已初始化的 MCP `ClientSession` 实例。 + """ + + if ClientSession is None: + raise RuntimeError("当前环境未安装可用的 MCP ClientSession") + + list_roots_callback = self._build_list_roots_callback() + sampling_callback = ( + self.host_callbacks.sampling_callback + if self.client_config.enable_sampling and self.host_callbacks.sampling_callback is not None + else None + ) + elicitation_callback = ( + self.host_callbacks.elicitation_callback + if self.client_config.enable_elicitation and self.host_callbacks.elicitation_callback is not None + else None + ) + logging_callback = cast(Optional["LoggingFnT"], self.host_callbacks.logging_callback) + message_handler = cast(Optional["MessageHandlerFnT"], self.host_callbacks.message_handler) + + if self.client_config.enable_sampling and sampling_callback is None: + console.print( + f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 sampling 配置,但宿主未提供 sampling 回调,当前不会声明该能力[/warning]" + ) + if self.client_config.enable_elicitation and elicitation_callback is None: + console.print( + f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 elicitation 配置,但宿主未提供 elicitation 回调,当前不会声明该能力[/warning]" + ) + + session = await self._exit_stack.enter_async_context( + ClientSession( + read_stream, + write_stream, + read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds), + sampling_callback=cast(Optional["SamplingFnT"], sampling_callback), + elicitation_callback=cast(Optional["ElicitationFnT"], elicitation_callback), + list_roots_callback=cast(Optional["ListRootsFnT"], list_roots_callback), + logging_callback=logging_callback, + message_handler=message_handler, + client_info=self._build_client_info(), + sampling_capabilities=self._build_sampling_capabilities(sampling_callback), + ) + ) + return session + + def _build_client_info(self) -> Any: + """构建 MCP 客户端实现信息。 + + Returns: + Any: MCP SDK 的 `Implementation` 对象。 + """ + + if mcp_types is None: + raise RuntimeError("当前环境未安装可用的 MCP types 模块") + + return mcp_types.Implementation( + name=self.client_config.client_name, + version=self.client_config.client_version, + ) + + def _build_sampling_capabilities(self, sampling_callback: Any) -> Any | None: + """构建 Sampling 能力声明。 + + Args: + sampling_callback: 当前宿主侧的 Sampling 回调。 + + Returns: + Any | None: Sampling 能力对象;未启用时返回 ``None``。 + """ + + if mcp_types is None: + return None + if sampling_callback is None: + return None + + context_capability = ( + mcp_types.SamplingContextCapability() + if self.client_config.sampling_include_context_support + else None + ) + tools_capability = ( + mcp_types.SamplingToolsCapability() + if self.client_config.sampling_tool_support + else None + ) + return mcp_types.SamplingCapability( + context=context_capability, + tools=tools_capability, + ) + + def _build_list_roots_callback(self) -> Any | None: + """构建 Roots 列表回调。 + + Returns: + Any | None: 符合 MCP SDK 要求的回调;未启用时返回 ``None``。 + """ + + if mcp_types is None: + return None + if not self.client_config.enable_roots or not self.client_config.roots: + return None + + async def _list_roots(context: Any) -> Any: + """返回当前客户端声明的 Roots 列表。 + + Args: + context: MCP 请求上下文。 + + Returns: + Any: MCP `ListRootsResult` 对象。 + """ + + del context + types_module = mcp_types + if types_module is None: + raise RuntimeError("当前环境未安装可用的 MCP types 模块") + roots = [ + types_module.Root(uri=cast(Any, root.uri), name=root.name or None) + for root in self.client_config.roots + ] + return types_module.ListRootsResult(roots=roots) + + return _list_roots + + async def _load_server_features(self) -> None: + """根据服务端能力声明加载工具、Prompt 与 Resource。""" + + self.tools = await self._list_tools() if self.supports_tools() else [] + self.prompts = await self._list_prompts() if self.supports_prompts() else [] + self.resources = await self._list_resources() if self.supports_resources() else [] + self.resource_templates = ( + await self._list_resource_templates() if self.supports_resources() else [] + ) + + def supports_tools(self) -> bool: + """判断服务端是否声明支持 Tools。 + + Returns: + bool: 是否支持 Tools。 + """ + + return bool(self.server_capabilities is not None and getattr(self.server_capabilities, "tools", None) is not None) + + def supports_prompts(self) -> bool: + """判断服务端是否声明支持 Prompts。 + + Returns: + bool: 是否支持 Prompts。 + """ + + return bool( + self.server_capabilities is not None and getattr(self.server_capabilities, "prompts", None) is not None + ) + + def supports_resources(self) -> bool: + """判断服务端是否声明支持 Resources。 + + Returns: + bool: 是否支持 Resources。 + """ + + return bool( + self.server_capabilities is not None and getattr(self.server_capabilities, "resources", None) is not None + ) + + async def _list_tools(self) -> list[Any]: + """分页加载服务端暴露的全部工具。 + + Returns: + list[Any]: MCP SDK 的原始工具对象列表。 + """ + + if self.session is None: + return [] + + tools: list[Any] = [] + cursor: Optional[str] = None + while True: + result = await self.session.list_tools(cursor=cursor) + tools.extend(list(getattr(result, "tools", []) or [])) + cursor = getattr(result, "nextCursor", None) + if not cursor: + break + return tools + + async def _list_prompts(self) -> list[Any]: + """分页加载服务端暴露的全部 Prompt。 + + Returns: + list[Any]: MCP SDK 的原始 Prompt 对象列表。 + """ + + if self.session is None: + return [] + + prompts: list[Any] = [] + cursor: Optional[str] = None + while True: + result = await self.session.list_prompts(cursor=cursor) + prompts.extend(list(getattr(result, "prompts", []) or [])) + cursor = getattr(result, "nextCursor", None) + if not cursor: + break + return prompts + + async def _list_resources(self) -> list[Any]: + """分页加载服务端暴露的全部 Resource。 + + Returns: + list[Any]: MCP SDK 的原始 Resource 对象列表。 + """ + + if self.session is None: + return [] + + resources: list[Any] = [] + cursor: Optional[str] = None + while True: + result = await self.session.list_resources(cursor=cursor) + resources.extend(list(getattr(result, "resources", []) or [])) + cursor = getattr(result, "nextCursor", None) + if not cursor: + break + return resources + + async def _list_resource_templates(self) -> list[Any]: + """分页加载服务端暴露的全部 Resource Template。 + + Returns: + list[Any]: MCP SDK 的原始 Resource Template 对象列表。 + """ + + if self.session is None: + return [] + + resource_templates: list[Any] = [] + cursor: Optional[str] = None + while True: + result = await self.session.list_resource_templates(cursor=cursor) + resource_templates.extend(list(getattr(result, "resourceTemplates", []) or [])) + cursor = getattr(result, "nextCursor", None) + if not cursor: + break + return resource_templates async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> ToolExecutionResult: """调用 MCP 工具并返回统一执行结果。 @@ -137,15 +460,20 @@ class MCPConnection: ToolExecutionResult: 统一执行结果。 """ - if not self.session: + if self.session is None: return ToolExecutionResult( tool_name=tool_name, success=False, error_message=f"MCP 服务器 '{self.config.name}' 未连接", + metadata={"server_name": self.config.name}, ) try: - result = await self.session.call_tool(tool_name, arguments=arguments) + result = await self.session.call_tool( + tool_name, + arguments=arguments, + read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds), + ) except Exception as exc: return ToolExecutionResult( tool_name=tool_name, @@ -154,33 +482,78 @@ class MCPConnection: metadata={"server_name": self.config.name}, ) - text_parts: list[str] = [] - binary_parts: list[dict[str, Any]] = [] - for content in result.content: - if hasattr(content, "text"): - text_parts.append(str(content.text)) - elif hasattr(content, "data"): - content_type = getattr(content, "mimeType", "unknown") - binary_parts.append({"mime_type": content_type, "type": "binary"}) - text_parts.append(f"[{content_type} 二进制内容]") - elif hasattr(content, "type"): - text_parts.append(f"[{content.type} 内容]") + content_items = build_tool_content_items(list(getattr(result, "content", []) or [])) + text_parts = [item.text.strip() for item in content_items if item.content_type == "text" and item.text.strip()] + structured_content = getattr(result, "structuredContent", None) + is_error = bool(getattr(result, "isError", False)) + history_content = "\n".join(text_parts).strip() + error_message = history_content if is_error else "" return ToolExecutionResult( tool_name=tool_name, - success=True, - content="\n".join(text_parts) if text_parts else "工具执行成功(无输出)", + success=not is_error, + content=history_content if not is_error else "", + error_message=error_message, + structured_content=structured_content, + content_items=content_items, metadata={ "server_name": self.config.name, - "binary_parts": binary_parts, + "protocol_version": self.protocol_version, + "session_id": self.session_id, }, ) + async def get_prompt( + self, + prompt_name: str, + arguments: Optional[dict[str, str]] = None, + ) -> MCPPromptResult: + """读取指定 MCP Prompt 的内容。 + + Args: + prompt_name: Prompt 名称。 + arguments: Prompt 参数字典。 + + Returns: + MCPPromptResult: 统一 Prompt 结果。 + """ + + if self.session is None: + raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接") + + result = await self.session.get_prompt(prompt_name, arguments=arguments) + return build_prompt_result(result, prompt_name=prompt_name, server_name=self.config.name) + + async def read_resource(self, uri: str) -> MCPResourceReadResult: + """读取指定 MCP Resource 的内容。 + + Args: + uri: 资源 URI。 + + Returns: + MCPResourceReadResult: 统一资源读取结果。 + """ + + if self.session is None: + raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接") + + result = await self.session.read_resource(uri) + return build_resource_read_result(result, uri=uri, server_name=self.config.name) + async def close(self) -> None: """关闭连接并释放资源。""" + try: await self._exit_stack.aclose() except Exception: pass + self.session = None + self.server_capabilities = None self.tools = [] + self.prompts = [] + self.resources = [] + self.resource_templates = [] + self.protocol_version = "" + self._http_client = None + self._session_id_getter = None diff --git a/src/mcp_module/hooks.py b/src/mcp_module/hooks.py new file mode 100644 index 00000000..c1890390 --- /dev/null +++ b/src/mcp_module/hooks.py @@ -0,0 +1,20 @@ +"""MCP 宿主回调声明。""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Awaitable, Callable + + +@dataclass(slots=True) +class MCPHostCallbacks: + """MCP 宿主回调集合。 + + 该对象用于向 `MCPConnection` 注入宿主侧可选能力, + 例如 Sampling、Elicitation、日志消费和自定义消息处理。 + """ + + sampling_callback: Callable[..., Awaitable[Any]] | None = None + elicitation_callback: Callable[..., Awaitable[Any]] | None = None + logging_callback: Callable[..., Awaitable[None]] | None = None + message_handler: Callable[..., Awaitable[None]] | None = None diff --git a/src/mcp_module/host_llm_bridge.py b/src/mcp_module/host_llm_bridge.py new file mode 100644 index 00000000..1b8bc10d --- /dev/null +++ b/src/mcp_module/host_llm_bridge.py @@ -0,0 +1,597 @@ +"""MCP 宿主侧大模型桥接服务。""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +import json + +from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMResponseResult +from src.common.logger import get_logger +from src.core.tooling import build_tool_detailed_description +from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType +from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput +from src.services.llm_service import LLMServiceClient + +from .hooks import MCPHostCallbacks +from .models import build_tool_content_items + +if TYPE_CHECKING: + from src.llm_models.model_client.base_client import BaseClient + +try: + from mcp import types as mcp_types + + MCP_TYPES_AVAILABLE = True +except ImportError: + mcp_types = None # type: ignore[assignment] + MCP_TYPES_AVAILABLE = False + +logger = get_logger("mcp_host_llm_bridge") + + +class MCPHostLLMBridge: + """将 MCP Sampling 请求桥接到主程序大模型调用链。""" + + def __init__(self, sampling_task_name: str = "planner") -> None: + """初始化 MCP 宿主侧大模型桥接服务。 + + Args: + sampling_task_name: 执行 Sampling 请求时使用的模型任务名。 + """ + + self._sampling_task_name = sampling_task_name.strip() or "planner" + self._sampling_client = LLMServiceClient( + task_name=self._sampling_task_name, + request_type="mcp_sampling", + ) + + def build_callbacks(self) -> MCPHostCallbacks: + """构建可注入给 MCP 连接层的宿主回调集合。 + + Returns: + MCPHostCallbacks: 包含 Sampling 回调的宿主回调集合。 + """ + + return MCPHostCallbacks( + sampling_callback=self.handle_sampling_request, + ) + + async def handle_sampling_request(self, context: Any, params: Any) -> Any: + """处理服务端发起的 MCP Sampling 请求。 + + Args: + context: MCP SDK 传入的请求上下文。 + params: `sampling/createMessage` 请求参数。 + + Returns: + Any: MCP `CreateMessageResult`、`CreateMessageResultWithTools` 或 `ErrorData`。 + """ + + del context + if not MCP_TYPES_AVAILABLE or mcp_types is None: + raise RuntimeError("当前环境未安装可用的 MCP types 模块") + + try: + tool_choice_mode = self._get_tool_choice_mode(params) + tool_definitions = self._build_tool_definitions( + raw_tools=getattr(params, "tools", None), + tool_choice_mode=tool_choice_mode, + ) + message_factory = self._build_message_factory( + raw_messages=list(getattr(params, "messages", []) or []), + system_prompt=self._build_system_prompt( + raw_system_prompt=str(getattr(params, "systemPrompt", "") or ""), + tool_choice_mode=tool_choice_mode, + tool_definitions=tool_definitions, + ), + ) + + generation_result = await self._sampling_client.generate_response_with_messages( + message_factory=message_factory, + options=LLMGenerationOptions( + temperature=self._coerce_float(getattr(params, "temperature", None)), + max_tokens=int(getattr(params, "maxTokens", 1024) or 1024), + tool_options=tool_definitions, + ), + ) + + if tool_choice_mode == "required" and tool_definitions and not generation_result.tool_calls: + return mcp_types.ErrorData( + code=mcp_types.INTERNAL_ERROR, + message="Sampling 要求必须调用工具,但模型未返回任何工具调用", + ) + + return self._build_sampling_result( + generation_result=generation_result, + tools_enabled=bool(tool_definitions), + ) + + except Exception as exc: + logger.exception(f"MCP Sampling 调用失败: {exc}") + return mcp_types.ErrorData( + code=mcp_types.INTERNAL_ERROR, + message=f"MCP Sampling 调用失败: {exc}", + ) + + @staticmethod + def _coerce_float(raw_value: Any) -> float | None: + """将任意原始值转换为浮点数。 + + Args: + raw_value: 原始输入值。 + + Returns: + float | None: 转换后的浮点数;无法转换时返回 ``None``。 + """ + + if raw_value is None: + return None + if isinstance(raw_value, int | float): + return float(raw_value) + return None + + @staticmethod + def _get_tool_choice_mode(params: Any) -> str: + """读取 Sampling 请求中的工具选择模式。 + + Args: + params: Sampling 请求参数对象。 + + Returns: + str: `auto`、`required` 或 `none`;缺省时返回 `auto`。 + """ + + tool_choice = getattr(params, "toolChoice", None) + mode = str(getattr(tool_choice, "mode", "") or "").strip().lower() + if mode in {"required", "none"}: + return mode + return "auto" + + def _build_system_prompt( + self, + raw_system_prompt: str, + tool_choice_mode: str, + tool_definitions: list[ToolDefinitionInput] | None, + ) -> str: + """构建发送给主程序大模型的系统提示词。 + + Args: + raw_system_prompt: 服务端请求中的系统提示词。 + tool_choice_mode: 当前工具选择模式。 + tool_definitions: 参与本次 Sampling 的工具定义。 + + Returns: + str: 最终系统提示词。 + """ + + prompt_parts: list[str] = [] + if raw_system_prompt.strip(): + prompt_parts.append(raw_system_prompt.strip()) + if tool_choice_mode == "required" and tool_definitions: + prompt_parts.append("本轮回答必须至少调用一个工具;不要直接结束回答。") + return "\n\n".join(part for part in prompt_parts if part).strip() + + def _build_message_factory( + self, + raw_messages: list[Any], + system_prompt: str, + ) -> Any: + """构建 MCP Sampling 使用的消息工厂。 + + Args: + raw_messages: MCP Sampling 原始消息列表。 + system_prompt: 规范化后的系统提示词。 + + Returns: + Any: 供 `LLMServiceClient` 使用的消息工厂。 + """ + + def _message_factory(client: "BaseClient") -> list[Message]: + """延迟构建内部消息列表。 + + Args: + client: 当前被选中的底层模型客户端。 + + Returns: + list[Message]: 内部统一消息列表。 + """ + + messages: list[Message] = [] + if system_prompt.strip(): + messages.append( + MessageBuilder() + .set_role(RoleType.System) + .add_text_content(system_prompt.strip()) + .build() + ) + + for raw_message in raw_messages: + messages.extend(self._convert_sampling_message(raw_message, client)) + return messages + + return _message_factory + + def _convert_sampling_message(self, raw_message: Any, client: "BaseClient") -> list[Message]: + """将单条 MCP Sampling 消息转换为内部消息列表。 + + Args: + raw_message: MCP Sampling 原始消息对象。 + client: 当前底层模型客户端。 + + Returns: + list[Message]: 转换后的内部消息列表。 + """ + + role = str(getattr(raw_message, "role", "") or "").strip().lower() + content_blocks = self._get_content_blocks(getattr(raw_message, "content", None)) + + if role == "assistant": + assistant_message = self._build_assistant_message(content_blocks, client) + return [assistant_message] if assistant_message is not None else [] + + if role == "user": + return self._build_user_messages(content_blocks, client) + + raise ValueError(f"不支持的 MCP Sampling 消息角色: {role}") + + @staticmethod + def _get_content_blocks(raw_content: Any) -> list[Any]: + """将 MCP Sampling 消息内容统一为列表。 + + Args: + raw_content: 原始内容字段。 + + Returns: + list[Any]: 统一后的内容块列表。 + """ + + if raw_content is None: + return [] + if isinstance(raw_content, list): + return list(raw_content) + return [raw_content] + + def _build_assistant_message(self, content_blocks: list[Any], client: "BaseClient") -> Optional[Message]: + """构建内部 assistant 消息。 + + Args: + content_blocks: MCP assistant 内容块列表。 + client: 当前底层模型客户端。 + + Returns: + Optional[Message]: 转换后的内部 assistant 消息;无有效内容时返回 ``None``。 + """ + + message_builder = MessageBuilder().set_role(RoleType.Assistant) + tool_calls: list[ToolCall] = [] + has_visible_content = False + + for content_block in content_blocks: + content_type = self._get_content_type(content_block) + if content_type == "tool_use": + tool_calls.append( + ToolCall( + call_id=str(getattr(content_block, "id", "") or ""), + func_name=str(getattr(content_block, "name", "") or ""), + args=self._normalize_tool_call_arguments(getattr(content_block, "input", None)), + ) + ) + continue + + has_visible_content = self._append_sampling_content_to_builder( + message_builder=message_builder, + content_block=content_block, + client=client, + ) or has_visible_content + + if tool_calls: + message_builder.set_tool_calls(tool_calls) + + if not has_visible_content and not tool_calls: + return None + return message_builder.build() + + def _build_user_messages(self, content_blocks: list[Any], client: "BaseClient") -> list[Message]: + """构建内部 user/tool 消息序列。 + + Args: + content_blocks: MCP user 内容块列表。 + client: 当前底层模型客户端。 + + Returns: + list[Message]: 转换后的内部消息序列。 + """ + + messages: list[Message] = [] + message_builder = MessageBuilder().set_role(RoleType.User) + has_user_content = False + + def flush_user_message() -> None: + """在当前存在用户可见内容时落盘一条 user 消息。""" + + nonlocal message_builder, has_user_content + if not has_user_content: + return + messages.append(message_builder.build()) + message_builder = MessageBuilder().set_role(RoleType.User) + has_user_content = False + + for content_block in content_blocks: + content_type = self._get_content_type(content_block) + if content_type == "tool_result": + flush_user_message() + messages.append(self._build_tool_result_message(content_block)) + continue + + has_user_content = self._append_sampling_content_to_builder( + message_builder=message_builder, + content_block=content_block, + client=client, + ) or has_user_content + + flush_user_message() + return messages + + @staticmethod + def _get_content_type(content_block: Any) -> str: + """读取 MCP 内容块类型。 + + Args: + content_block: MCP 内容块对象。 + + Returns: + str: 规范化后的内容块类型。 + """ + + return str(getattr(content_block, "type", "text") or "text").strip().lower() + + def _append_sampling_content_to_builder( + self, + message_builder: MessageBuilder, + content_block: Any, + client: "BaseClient", + ) -> bool: + """将 MCP 普通内容块追加到内部消息构建器。 + + Args: + message_builder: 内部消息构建器。 + content_block: MCP 内容块对象。 + client: 当前底层模型客户端。 + + Returns: + bool: 是否成功追加了可见内容。 + """ + + content_type = self._get_content_type(content_block) + if content_type == "text": + text_content = str(getattr(content_block, "text", "") or "") + if text_content.strip(): + message_builder.add_text_content(text_content) + return True + return False + + if content_type == "image": + image_data = str(getattr(content_block, "data", "") or "") + image_mime_type = str(getattr(content_block, "mimeType", "") or "") + image_format = self._normalize_image_format(image_mime_type) + if image_data and image_format: + message_builder.add_image_content( + image_format=image_format, + image_base64=image_data, + support_formats=client.get_support_image_formats(), + ) + return True + + message_builder.add_text_content( + f"[图片内容:mime_type={image_mime_type or 'unknown'},当前客户端无法直接透传]" + ) + return True + + if content_type == "audio": + audio_mime_type = str(getattr(content_block, "mimeType", "") or "") + message_builder.add_text_content(f"[音频内容:mime_type={audio_mime_type or 'unknown'}]") + return True + + return False + + @staticmethod + def _normalize_image_format(mime_type: str) -> str: + """将图片 MIME 类型转换为内部图片格式名称。 + + Args: + mime_type: MCP 图片 MIME 类型。 + + Returns: + str: 内部支持的图片格式名;不支持时返回空字符串。 + """ + + normalized_mime_type = mime_type.strip().lower() + if normalized_mime_type == "image/png": + return "png" + if normalized_mime_type in {"image/jpeg", "image/jpg"}: + return "jpeg" + if normalized_mime_type == "image/webp": + return "webp" + if normalized_mime_type == "image/gif": + return "gif" + return "" + + def _build_tool_result_message(self, content_block: Any) -> Message: + """将 MCP `tool_result` 内容块转换为内部 Tool 消息。 + + Args: + content_block: MCP `tool_result` 内容块对象。 + + Returns: + Message: 转换后的内部 Tool 消息。 + """ + + message_builder = MessageBuilder().set_role(RoleType.Tool) + message_builder.set_tool_call_id(str(getattr(content_block, "toolUseId", "") or "tool_result")) + summary_text = self._summarize_tool_result_content(content_block) + message_builder.add_text_content(summary_text or "工具执行完成。") + return message_builder.build() + + def _summarize_tool_result_content(self, content_block: Any) -> str: + """汇总 MCP `tool_result` 内容块中的结果文本。 + + Args: + content_block: MCP `tool_result` 内容块对象。 + + Returns: + str: 适合发送给主程序模型的工具结果摘要文本。 + """ + + raw_contents = list(getattr(content_block, "content", []) or []) + content_items = build_tool_content_items(raw_contents) + parts = [item.build_history_text().strip() for item in content_items if item.build_history_text().strip()] + + structured_content = getattr(content_block, "structuredContent", None) + if structured_content is not None: + try: + parts.append(json.dumps(structured_content, ensure_ascii=False)) + except (TypeError, ValueError): + parts.append(str(structured_content)) + + summary_text = "\n".join(part for part in parts if part).strip() + if bool(getattr(content_block, "isError", False)) and summary_text: + return f"工具执行失败:\n{summary_text}" + if bool(getattr(content_block, "isError", False)): + return "工具执行失败。" + return summary_text + + @staticmethod + def _normalize_tool_call_arguments(raw_arguments: Any) -> dict[str, Any]: + """将原始工具调用参数规范化为字典。 + + Args: + raw_arguments: 原始工具参数。 + + Returns: + dict[str, Any]: 规范化后的参数字典。 + """ + + if isinstance(raw_arguments, dict): + return dict(raw_arguments) + if raw_arguments is None: + return {} + return {"value": raw_arguments} + + def _build_tool_definitions( + self, + raw_tools: Any, + tool_choice_mode: str, + ) -> list[ToolDefinitionInput] | None: + """将 MCP Sampling 工具定义转换为主程序内部工具定义。 + + Args: + raw_tools: MCP Sampling 请求中的工具列表。 + tool_choice_mode: 当前工具选择模式。 + + Returns: + list[ToolDefinitionInput] | None: 可传给主程序模型层的工具定义列表。 + """ + + if tool_choice_mode == "none": + return None + if not isinstance(raw_tools, list) or not raw_tools: + return None + + tool_definitions: list[ToolDefinitionInput] = [] + for raw_tool in raw_tools: + tool_name = str(getattr(raw_tool, "name", "") or "").strip() + if not tool_name: + continue + + parameters_schema = ( + dict(getattr(raw_tool, "inputSchema", {}) or {}) if getattr(raw_tool, "inputSchema", None) else {} + ) + if "$schema" in parameters_schema: + parameters_schema.pop("$schema") + + title = str(getattr(raw_tool, "title", "") or "").strip() + description = str(getattr(raw_tool, "description", "") or "").strip() + brief_description = description or title or f"工具 {tool_name}" + detailed_description = build_tool_detailed_description( + parameters_schema, + fallback_description=f"工具名称:{tool_name}", + ) + + tool_definitions.append( + { + "name": tool_name, + "description": "\n\n".join( + part for part in [brief_description, detailed_description] if part.strip() + ).strip(), + "parameters_schema": parameters_schema or {"type": "object", "properties": {}}, + } + ) + + return tool_definitions or None + + def _build_sampling_result( + self, + generation_result: LLMResponseResult, + tools_enabled: bool, + ) -> Any: + """将主程序模型响应转换为 MCP Sampling 结果。 + + Args: + generation_result: 主程序统一大模型响应结果。 + tools_enabled: 当前是否允许模型使用工具。 + + Returns: + Any: MCP `CreateMessageResult` 或 `CreateMessageResultWithTools`。 + """ + + if not MCP_TYPES_AVAILABLE or mcp_types is None: + raise RuntimeError("当前环境未安装可用的 MCP types 模块") + + text_content = str(generation_result.response or "") + tool_calls = list(generation_result.tool_calls or []) + model_name = generation_result.model_name or self._sampling_task_name + + if tools_enabled: + content_blocks: list[Any] = [] + if text_content.strip(): + content_blocks.append( + mcp_types.TextContent( + type="text", + text=text_content, + ) + ) + for tool_call in tool_calls: + content_blocks.append( + mcp_types.ToolUseContent( + type="tool_use", + name=tool_call.func_name, + id=tool_call.call_id, + input=dict(tool_call.args or {}), + ) + ) + + if not content_blocks: + content_blocks.append( + mcp_types.TextContent( + type="text", + text="", + ) + ) + + return mcp_types.CreateMessageResultWithTools( + role="assistant", + content=content_blocks[0] if len(content_blocks) == 1 else content_blocks, + model=model_name, + stopReason="toolUse" if tool_calls else "endTurn", + ) + + return mcp_types.CreateMessageResult( + role="assistant", + content=mcp_types.TextContent( + type="text", + text=text_content, + ), + model=model_name, + stopReason="endTurn", + ) diff --git a/src/mcp_module/manager.py b/src/mcp_module/manager.py index 7dbb8c3c..53c4dbc4 100644 --- a/src/mcp_module/manager.py +++ b/src/mcp_module/manager.py @@ -1,8 +1,10 @@ """ MaiSaka - MCP 管理器 -管理所有 MCP 服务器连接,提供统一的工具发现与调用接口。 +管理所有 MCP 服务器连接,提供统一的工具、Prompt 与 Resource 访问入口。 """ +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Optional from src.cli.console import console @@ -13,8 +15,26 @@ from src.core.tooling import ( build_tool_detailed_description, ) -from .config import MCPServerRuntimeConfig, build_mcp_server_runtime_configs +from .config import ( + MCPClientRuntimeConfig, + MCPServerRuntimeConfig, + build_mcp_client_runtime_config, + build_mcp_server_runtime_configs, +) from .connection import MCPConnection, MCP_AVAILABLE +from .hooks import MCPHostCallbacks +from .models import ( + MCPPromptResult, + MCPPromptSpec, + MCPResourceReadResult, + MCPResourceSpec, + MCPResourceTemplateSpec, + build_prompt_spec, + build_resource_spec, + build_resource_template_spec, + build_tool_annotation, + build_tool_icon, +) if TYPE_CHECKING: from src.config.official_configs import MCPConfig @@ -34,37 +54,44 @@ BUILTIN_TOOL_NAMES = frozenset( class MCPManager: - """MCP 服务器连接管理器。 + """MCP 服务器连接管理器。""" - 职责: - - 根据主程序官方配置连接所有 MCP 服务器 - - 将 MCP 工具转换为 OpenAI function calling 格式 - - 路由工具调用到正确的 MCP 服务器 - - 统一管理连接生命周期 - """ + def __init__( + self, + client_config: MCPClientRuntimeConfig, + host_callbacks: Optional[MCPHostCallbacks] = None, + ) -> None: + """初始化 MCP 管理器。 - def __init__(self) -> None: - """初始化 MCP 管理器。""" + Args: + client_config: MCP 客户端宿主能力运行时配置。 + host_callbacks: 宿主侧能力回调集合。 + """ - self._connections: dict[str, MCPConnection] = {} # server_name → connection - self._tool_to_server: dict[str, str] = {} # tool_name → server_name - - # ──────── 工厂方法 ──────── + self._client_config = client_config + self._host_callbacks = host_callbacks or MCPHostCallbacks() + self._connections: dict[str, MCPConnection] = {} + self._tool_to_server: dict[str, str] = {} + self._prompt_to_server: dict[str, str] = {} + self._resource_to_server: dict[str, str] = {} + self._resource_template_to_server: dict[str, str] = {} @classmethod async def from_app_config( cls, mcp_config: "MCPConfig", + host_callbacks: Optional[MCPHostCallbacks] = None, ) -> Optional["MCPManager"]: - """ - 从官方配置创建并初始化 MCPManager。 + """从官方配置创建并初始化 `MCPManager`。 Args: mcp_config: 主程序中的 MCP 配置对象。 + host_callbacks: 宿主侧能力回调集合。 Returns: - 初始化完成的 MCPManager;无可用配置或全部连接失败时返回 None。 + Optional[MCPManager]: 初始化完成的管理器;无可用配置或全部连接失败时返回 ``None``。 """ + configs = build_mcp_server_runtime_configs(mcp_config) if not configs: return None @@ -73,7 +100,10 @@ class MCPManager: console.print("[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK,请运行: pip install mcp[/warning]") return None - manager = cls() + manager = cls( + client_config=build_mcp_client_runtime_config(mcp_config), + host_callbacks=host_callbacks, + ) await manager._connect_all(configs) if not manager._connections: @@ -82,48 +112,141 @@ class MCPManager: return manager - # ──────── 连接管理 ──────── - async def _connect_all(self, configs: list[MCPServerRuntimeConfig]) -> None: - """连接所有配置的 MCP 服务器,跳过失败的连接。""" - for cfg in configs: - conn = MCPConnection(cfg) - success = await conn.connect() + """连接全部已配置的 MCP 服务器。 + + Args: + configs: 服务器运行时配置列表。 + + Returns: + None + """ + + for config in configs: + connection = MCPConnection(config, self._client_config, self._host_callbacks) + success = await connection.connect() if not success: continue - self._connections[cfg.name] = conn - - # 注册工具,检查冲突 - registered = 0 - for tool in conn.tools: - tool_name = tool.name - - if tool_name in BUILTIN_TOOL_NAMES: - console.print( - f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]" - ) - continue - - if tool_name in self._tool_to_server: - existing_server = self._tool_to_server[tool_name] - console.print( - f"[warning]⚠️ MCP 工具 '{tool_name}' " - f"(来自 {cfg.name}) 与 {existing_server} 冲突,已跳过[/warning]" - ) - continue - - self._tool_to_server[tool_name] = cfg.name - registered += 1 - + self._connections[config.name] = connection + registered_tool_count = self._register_tools(config.name, connection) + registered_prompt_count = self._register_prompts(config.name, connection) + registered_resource_count = self._register_resources(config.name, connection) + registered_template_count = self._register_resource_templates(config.name, connection) console.print( - f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] [muted]({registered} 个工具已注册)[/muted]" + "[success]✓ MCP 服务器 " + f"'{config.name}' 已连接[/success] " + f"[muted](工具 {registered_tool_count} / Prompt {registered_prompt_count} / " + f"资源 {registered_resource_count} / 模板 {registered_template_count})[/muted]" ) - # ──────── 工具发现 ──────── + def _register_tools(self, server_name: str, connection: MCPConnection) -> int: + """注册单个服务器暴露的 MCP 工具。 + + Args: + server_name: 服务器名称。 + connection: 对应连接对象。 + + Returns: + int: 成功注册的工具数量。 + """ + + registered_count = 0 + for tool in connection.tools: + tool_name = str(tool.name) + + if tool_name in BUILTIN_TOOL_NAMES: + console.print( + f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {server_name}) 与内置工具冲突,已跳过[/warning]" + ) + continue + + if tool_name in self._tool_to_server: + existing_server = self._tool_to_server[tool_name] + console.print( + f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]" + ) + continue + + self._tool_to_server[tool_name] = server_name + registered_count += 1 + return registered_count + + def _register_prompts(self, server_name: str, connection: MCPConnection) -> int: + """注册单个服务器暴露的 MCP Prompt。 + + Args: + server_name: 服务器名称。 + connection: 对应连接对象。 + + Returns: + int: 成功注册的 Prompt 数量。 + """ + + registered_count = 0 + for prompt in connection.prompts: + prompt_name = str(prompt.name) + if prompt_name in self._prompt_to_server: + existing_server = self._prompt_to_server[prompt_name] + console.print( + f"[warning]⚠️ MCP Prompt '{prompt_name}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]" + ) + continue + self._prompt_to_server[prompt_name] = server_name + registered_count += 1 + return registered_count + + def _register_resources(self, server_name: str, connection: MCPConnection) -> int: + """注册单个服务器暴露的 MCP Resource。 + + Args: + server_name: 服务器名称。 + connection: 对应连接对象。 + + Returns: + int: 成功注册的 Resource 数量。 + """ + + registered_count = 0 + for resource in connection.resources: + resource_uri = str(resource.uri) + if resource_uri in self._resource_to_server: + existing_server = self._resource_to_server[resource_uri] + console.print( + f"[warning]⚠️ MCP Resource '{resource_uri}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]" + ) + continue + self._resource_to_server[resource_uri] = server_name + registered_count += 1 + return registered_count + + def _register_resource_templates(self, server_name: str, connection: MCPConnection) -> int: + """注册单个服务器暴露的 MCP Resource Template。 + + Args: + server_name: 服务器名称。 + connection: 对应连接对象。 + + Returns: + int: 成功注册的模板数量。 + """ + + registered_count = 0 + for resource_template in connection.resource_templates: + uri_template = str(resource_template.uriTemplate) + if uri_template in self._resource_template_to_server: + existing_server = self._resource_template_to_server[uri_template] + console.print( + "[warning]⚠️ MCP Resource Template " + f"'{uri_template}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]" + ) + continue + self._resource_template_to_server[uri_template] = server_name + registered_count += 1 + return registered_count def _build_tool_parameters_schema(self, tool: Any) -> dict[str, Any] | None: - """构造单个 MCP 工具的对象级参数 Schema。 + """构造单个 MCP 工具的参数 Schema。 Args: tool: MCP SDK 返回的原始工具对象。 @@ -140,6 +263,21 @@ class MCPManager: parameters_schema.pop("$schema", None) return parameters_schema + def _build_tool_output_schema(self, tool: Any) -> dict[str, Any] | None: + """构造单个 MCP 工具的输出 Schema。 + + Args: + tool: MCP SDK 返回的原始工具对象。 + + Returns: + dict[str, Any] | None: 输出 Schema。 + """ + + output_schema = dict(tool.outputSchema) if hasattr(tool, "outputSchema") and tool.outputSchema else None + if isinstance(output_schema, dict): + output_schema.pop("$schema", None) + return output_schema + def get_tool_specs(self) -> list[ToolSpec]: """获取全部已注册 MCP 工具的统一声明。 @@ -148,31 +286,79 @@ class MCPManager: """ tool_specs: list[ToolSpec] = [] - for server_name, conn in self._connections.items(): - for tool in conn.tools: - if tool.name not in self._tool_to_server: - continue - if self._tool_to_server[tool.name] != server_name: + for server_name, connection in self._connections.items(): + for tool in connection.tools: + if self._tool_to_server.get(tool.name) != server_name: continue parameters_schema = self._build_tool_parameters_schema(tool) + output_schema = self._build_tool_output_schema(tool) brief_description = str(tool.description or f"来自 {server_name} 的 MCP 工具").strip() tool_specs.append( ToolSpec( name=str(tool.name), + title=str(getattr(tool, "title", "") or ""), brief_description=brief_description, detailed_description=build_tool_detailed_description( parameters_schema, fallback_description=f"工具来源:MCP 服务 {server_name}。", ), parameters_schema=parameters_schema, + output_schema=output_schema, provider_name="mcp", provider_type="mcp", - metadata={"server_name": server_name}, + icons=[build_tool_icon(item) for item in getattr(tool, "icons", []) or []], + annotation=build_tool_annotation(getattr(tool, "annotations", None)), + metadata={"server_name": server_name} | getattr(tool, "meta", {}), ) ) return tool_specs + def get_prompt_specs(self) -> list[MCPPromptSpec]: + """获取全部已注册 MCP Prompt 声明。 + + Returns: + list[MCPPromptSpec]: Prompt 声明列表。 + """ + + prompt_specs: list[MCPPromptSpec] = [] + for server_name, connection in self._connections.items(): + for prompt in connection.prompts: + if self._prompt_to_server.get(prompt.name) != server_name: + continue + prompt_specs.append(build_prompt_spec(prompt, server_name)) + return prompt_specs + + def get_resource_specs(self) -> list[MCPResourceSpec]: + """获取全部已注册 MCP Resource 声明。 + + Returns: + list[MCPResourceSpec]: Resource 声明列表。 + """ + + resource_specs: list[MCPResourceSpec] = [] + for server_name, connection in self._connections.items(): + for resource in connection.resources: + if self._resource_to_server.get(resource.uri) != server_name: + continue + resource_specs.append(build_resource_spec(resource, server_name)) + return resource_specs + + def get_resource_template_specs(self) -> list[MCPResourceTemplateSpec]: + """获取全部已注册 MCP Resource Template 声明。 + + Returns: + list[MCPResourceTemplateSpec]: Resource Template 声明列表。 + """ + + resource_template_specs: list[MCPResourceTemplateSpec] = [] + for server_name, connection in self._connections.items(): + for resource_template in connection.resource_templates: + if self._resource_template_to_server.get(resource_template.uriTemplate) != server_name: + continue + resource_template_specs.append(build_resource_template_spec(resource_template, server_name)) + return resource_template_specs + def get_openai_tools(self) -> list[dict[str, Any]]: """获取兼容旧模型层的 MCP 工具定义。 @@ -192,12 +378,42 @@ class MCPManager: for tool_spec in self.get_tool_specs() ] - # ──────── 工具调用 ──────── - def is_mcp_tool(self, tool_name: str) -> bool: - """判断工具名是否为已注册的 MCP 工具。""" + """判断给定名称是否为已注册 MCP 工具。 + + Args: + tool_name: 工具名称。 + + Returns: + bool: 是否存在。 + """ + return tool_name in self._tool_to_server + def is_mcp_prompt(self, prompt_name: str) -> bool: + """判断给定名称是否为已注册 MCP Prompt。 + + Args: + prompt_name: Prompt 名称。 + + Returns: + bool: 是否存在。 + """ + + return prompt_name in self._prompt_to_server + + def is_mcp_resource(self, uri: str) -> bool: + """判断给定 URI 是否为已注册 MCP Resource。 + + Args: + uri: 资源 URI。 + + Returns: + bool: 是否存在。 + """ + + return uri in self._resource_to_server + async def call_tool_invocation(self, invocation: ToolInvocation) -> ToolExecutionResult: """执行统一的 MCP 工具调用。 @@ -217,8 +433,8 @@ class MCPManager: error_message=f"MCP 工具 '{tool_name}' 未找到", ) - conn = self._connections[server_name] - return await conn.call_tool(tool_name, invocation.arguments) + connection = self._connections[server_name] + return await connection.call_tool(tool_name, invocation.arguments) async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> str: """兼容旧接口,返回 MCP 工具的文本结果。 @@ -239,36 +455,137 @@ class MCPManager: ) return result.get_history_content() - # ──────── 信息展示 ──────── + async def get_prompt( + self, + prompt_name: str, + arguments: Optional[dict[str, str]] = None, + ) -> MCPPromptResult: + """读取指定 Prompt 的内容。 + + Args: + prompt_name: Prompt 名称。 + arguments: Prompt 参数字典。 + + Returns: + MCPPromptResult: Prompt 获取结果。 + """ + + server_name = self._prompt_to_server.get(prompt_name) + if not server_name or server_name not in self._connections: + raise KeyError(f"MCP Prompt '{prompt_name}' 未找到") + + connection = self._connections[server_name] + return await connection.get_prompt(prompt_name, arguments=arguments) + + async def read_resource(self, uri: str) -> MCPResourceReadResult: + """读取指定 Resource 的内容。 + + Args: + uri: 资源 URI。 + + Returns: + MCPResourceReadResult: 资源读取结果。 + """ + + server_name = self._resource_to_server.get(uri) + if not server_name or server_name not in self._connections: + raise KeyError(f"MCP Resource '{uri}' 未找到") + + connection = self._connections[server_name] + return await connection.read_resource(uri) def get_tool_summary(self) -> str: - """获取所有已注册 MCP 工具的摘要信息。""" + """获取所有已注册 MCP 工具的摘要信息。 + + Returns: + str: 工具摘要文本。 + """ + parts: list[str] = [] - for server_name, conn in self._connections.items(): + for server_name, connection in self._connections.items(): tool_names = [ - t.name - for t in conn.tools - if t.name in self._tool_to_server and self._tool_to_server[t.name] == server_name + str(tool.name) + for tool in connection.tools + if self._tool_to_server.get(tool.name) == server_name ] if tool_names: parts.append(f" • {server_name}: {', '.join(tool_names)}") return "\n".join(parts) + def get_feature_summary(self) -> str: + """获取所有服务器能力的总体摘要。 + + Returns: + str: 多行摘要文本。 + """ + + parts: list[str] = [] + for server_name, connection in self._connections.items(): + tool_count = sum(1 for tool in connection.tools if self._tool_to_server.get(tool.name) == server_name) + prompt_count = sum( + 1 for prompt in connection.prompts if self._prompt_to_server.get(prompt.name) == server_name + ) + resource_count = sum( + 1 for resource in connection.resources if self._resource_to_server.get(resource.uri) == server_name + ) + template_count = sum( + 1 + for resource_template in connection.resource_templates + if self._resource_template_to_server.get(resource_template.uriTemplate) == server_name + ) + parts.append( + f" • {server_name}: 工具 {tool_count} / Prompt {prompt_count} / " + f"资源 {resource_count} / 模板 {template_count}" + ) + return "\n".join(parts) + @property def server_count(self) -> int: - """已连接的 MCP 服务器数量。""" + """返回已连接 MCP 服务器数量。 + + Returns: + int: 服务器数量。 + """ + return len(self._connections) @property def tool_count(self) -> int: - """已注册的 MCP 工具总数。""" + """返回已注册 MCP 工具总数。 + + Returns: + int: 工具数量。 + """ + return len(self._tool_to_server) - # ──────── 生命周期 ──────── + @property + def prompt_count(self) -> int: + """返回已注册 MCP Prompt 总数。 + + Returns: + int: Prompt 数量。 + """ + + return len(self._prompt_to_server) + + @property + def resource_count(self) -> int: + """返回已注册 MCP Resource 总数。 + + Returns: + int: Resource 数量。 + """ + + return len(self._resource_to_server) async def close(self) -> None: """关闭所有 MCP 服务器连接。""" - for conn in self._connections.values(): - await conn.close() + + for connection in self._connections.values(): + await connection.close() self._connections.clear() self._tool_to_server.clear() + self._prompt_to_server.clear() + self._resource_to_server.clear() + self._resource_template_to_server.clear() diff --git a/src/mcp_module/models.py b/src/mcp_module/models.py new file mode 100644 index 00000000..5550b8df --- /dev/null +++ b/src/mcp_module/models.py @@ -0,0 +1,418 @@ +"""MCP 结构化模型与转换工具。 + +负责在 MCP SDK 原始对象与主程序内部数据模型之间进行转换, +避免连接层和管理器层直接操作大量弱类型字段。 +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from src.core.tooling import ToolAnnotation, ToolContentItem, ToolIcon + + +def _dump_model_metadata(raw_value: Any) -> dict[str, Any]: + """提取任意 MCP 模型对象中的元数据字典。 + + Args: + raw_value: MCP SDK 返回的原始对象。 + + Returns: + dict[str, Any]: 归一化后的元数据字典。 + """ + + metadata = getattr(raw_value, "meta", None) + if isinstance(metadata, dict): + return dict(metadata) + return {} + + +def build_tool_icon(raw_icon: Any) -> ToolIcon: + """将 MCP 图标对象转换为统一图标模型。 + + Args: + raw_icon: MCP SDK 返回的图标对象。 + + Returns: + ToolIcon: 统一图标模型。 + """ + + sizes_value = getattr(raw_icon, "sizes", None) + sizes = [str(item) for item in sizes_value] if isinstance(sizes_value, list) else [] + return ToolIcon( + src=str(getattr(raw_icon, "src", "") or ""), + mime_type=str(getattr(raw_icon, "mimeType", "") or ""), + sizes=sizes, + ) + + +def build_tool_annotation(raw_annotation: Any) -> Optional[ToolAnnotation]: + """将 MCP 注解对象转换为统一注解模型。 + + Args: + raw_annotation: MCP SDK 返回的注解对象。 + + Returns: + Optional[ToolAnnotation]: 统一注解模型;无有效内容时返回 ``None``。 + """ + + if raw_annotation is None: + return None + + audience_value = getattr(raw_annotation, "audience", None) + audience = [str(item) for item in audience_value] if isinstance(audience_value, list) else [] + priority_value = getattr(raw_annotation, "priority", None) + priority = float(priority_value) if isinstance(priority_value, int | float) else None + metadata = _dump_model_metadata(raw_annotation) + + if not audience and priority is None and not metadata: + return None + + return ToolAnnotation( + audience=audience, + priority=priority, + metadata=metadata, + ) + + +def build_tool_content_item(raw_content: Any) -> ToolContentItem: + """将 MCP 内容块转换为统一工具内容项。 + + Args: + raw_content: MCP SDK 返回的内容块对象。 + + Returns: + ToolContentItem: 统一工具内容项。 + """ + + content_type = str(getattr(raw_content, "type", "") or "").strip().lower() + annotation = build_tool_annotation(getattr(raw_content, "annotations", None)) + metadata = _dump_model_metadata(raw_content) + + if content_type == "text" or hasattr(raw_content, "text"): + return ToolContentItem( + content_type="text", + text=str(getattr(raw_content, "text", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + if content_type == "image": + return ToolContentItem( + content_type="image", + data=str(getattr(raw_content, "data", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + if content_type == "audio": + return ToolContentItem( + content_type="audio", + data=str(getattr(raw_content, "data", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + if content_type == "resource_link": + return ToolContentItem( + content_type="resource_link", + uri=str(getattr(raw_content, "uri", "") or ""), + name=str(getattr(raw_content, "name", "") or ""), + description=str(getattr(raw_content, "description", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + if content_type == "resource" or hasattr(raw_content, "resource"): + resource = getattr(raw_content, "resource", None) + resource_metadata = metadata | _dump_model_metadata(resource) + return ToolContentItem( + content_type="resource", + text=str(getattr(resource, "text", "") or ""), + data=str(getattr(resource, "blob", "") or ""), + mime_type=str(getattr(resource, "mimeType", "") or ""), + uri=str(getattr(resource, "uri", "") or ""), + name=str(getattr(resource, "name", "") or ""), + annotation=annotation, + metadata=resource_metadata, + ) + + if hasattr(raw_content, "data"): + return ToolContentItem( + content_type="binary", + data=str(getattr(raw_content, "data", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + annotation=annotation, + metadata=metadata, + ) + + return ToolContentItem( + content_type="unknown", + text=str(raw_content), + annotation=annotation, + metadata=metadata, + ) + + +def build_tool_content_items(raw_contents: list[Any] | None) -> list[ToolContentItem]: + """批量转换 MCP 内容块列表。 + + Args: + raw_contents: MCP SDK 返回的内容块列表。 + + Returns: + list[ToolContentItem]: 转换后的统一内容项列表。 + """ + + if not raw_contents: + return [] + return [build_tool_content_item(item) for item in raw_contents] + + +@dataclass(slots=True) +class MCPPromptArgumentSpec: + """MCP Prompt 参数声明。""" + + name: str + description: str = "" + required: bool = False + + +@dataclass(slots=True) +class MCPPromptSpec: + """MCP Prompt 声明。""" + + name: str + server_name: str + title: str = "" + description: str = "" + arguments: list[MCPPromptArgumentSpec] = field(default_factory=list) + icons: list[ToolIcon] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class MCPPromptMessage: + """MCP Prompt 消息。""" + + role: str + content: ToolContentItem + + +@dataclass(slots=True) +class MCPPromptResult: + """MCP Prompt 获取结果。""" + + prompt_name: str + server_name: str + description: str = "" + messages: list[MCPPromptMessage] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class MCPResourceSpec: + """MCP Resource 声明。""" + + uri: str + server_name: str + name: str + title: str = "" + description: str = "" + mime_type: str = "" + size: int | None = None + icons: list[ToolIcon] = field(default_factory=list) + annotation: ToolAnnotation | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class MCPResourceTemplateSpec: + """MCP Resource Template 声明。""" + + uri_template: str + server_name: str + name: str + title: str = "" + description: str = "" + mime_type: str = "" + icons: list[ToolIcon] = field(default_factory=list) + annotation: ToolAnnotation | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class MCPResourceReadResult: + """MCP Resource 读取结果。""" + + uri: str + server_name: str + contents: list[ToolContentItem] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + +def build_prompt_argument_spec(raw_argument: Any) -> MCPPromptArgumentSpec: + """将 MCP Prompt 参数对象转换为统一结构。 + + Args: + raw_argument: MCP SDK 返回的 Prompt 参数对象。 + + Returns: + MCPPromptArgumentSpec: 统一 Prompt 参数结构。 + """ + + return MCPPromptArgumentSpec( + name=str(getattr(raw_argument, "name", "") or ""), + description=str(getattr(raw_argument, "description", "") or ""), + required=bool(getattr(raw_argument, "required", False)), + ) + + +def build_prompt_spec(raw_prompt: Any, server_name: str) -> MCPPromptSpec: + """将 MCP Prompt 定义转换为统一结构。 + + Args: + raw_prompt: MCP SDK 返回的 Prompt 对象。 + server_name: Prompt 所属的服务器名称。 + + Returns: + MCPPromptSpec: 统一 Prompt 定义。 + """ + + raw_arguments = getattr(raw_prompt, "arguments", None) + raw_icons = getattr(raw_prompt, "icons", None) + return MCPPromptSpec( + name=str(getattr(raw_prompt, "name", "") or ""), + server_name=server_name, + title=str(getattr(raw_prompt, "title", "") or ""), + description=str(getattr(raw_prompt, "description", "") or ""), + arguments=[build_prompt_argument_spec(item) for item in raw_arguments] if isinstance(raw_arguments, list) else [], + icons=[build_tool_icon(item) for item in raw_icons] if isinstance(raw_icons, list) else [], + metadata=_dump_model_metadata(raw_prompt), + ) + + +def build_prompt_result(raw_result: Any, prompt_name: str, server_name: str) -> MCPPromptResult: + """将 MCP Prompt 获取结果转换为统一结构。 + + Args: + raw_result: MCP SDK 返回的 Prompt 结果对象。 + prompt_name: Prompt 名称。 + server_name: Prompt 所属服务器名称。 + + Returns: + MCPPromptResult: 统一 Prompt 获取结果。 + """ + + messages: list[MCPPromptMessage] = [] + raw_messages = getattr(raw_result, "messages", None) + if isinstance(raw_messages, list): + for raw_message in raw_messages: + messages.append( + MCPPromptMessage( + role=str(getattr(raw_message, "role", "") or ""), + content=build_tool_content_item(getattr(raw_message, "content", None)), + ) + ) + + return MCPPromptResult( + prompt_name=prompt_name, + server_name=server_name, + description=str(getattr(raw_result, "description", "") or ""), + messages=messages, + metadata=_dump_model_metadata(raw_result), + ) + + +def build_resource_spec(raw_resource: Any, server_name: str) -> MCPResourceSpec: + """将 MCP Resource 定义转换为统一结构。 + + Args: + raw_resource: MCP SDK 返回的 Resource 对象。 + server_name: Resource 所属服务器名称。 + + Returns: + MCPResourceSpec: 统一 Resource 定义。 + """ + + raw_icons = getattr(raw_resource, "icons", None) + size_value = getattr(raw_resource, "size", None) + size = int(size_value) if isinstance(size_value, int | float) else None + return MCPResourceSpec( + uri=str(getattr(raw_resource, "uri", "") or ""), + server_name=server_name, + name=str(getattr(raw_resource, "name", "") or ""), + title=str(getattr(raw_resource, "title", "") or ""), + description=str(getattr(raw_resource, "description", "") or ""), + mime_type=str(getattr(raw_resource, "mimeType", "") or ""), + size=size, + icons=[build_tool_icon(item) for item in raw_icons] if isinstance(raw_icons, list) else [], + annotation=build_tool_annotation(getattr(raw_resource, "annotations", None)), + metadata=_dump_model_metadata(raw_resource), + ) + + +def build_resource_template_spec(raw_template: Any, server_name: str) -> MCPResourceTemplateSpec: + """将 MCP Resource Template 定义转换为统一结构。 + + Args: + raw_template: MCP SDK 返回的 ResourceTemplate 对象。 + server_name: 模板所属服务器名称。 + + Returns: + MCPResourceTemplateSpec: 统一模板定义。 + """ + + raw_icons = getattr(raw_template, "icons", None) + return MCPResourceTemplateSpec( + uri_template=str(getattr(raw_template, "uriTemplate", "") or ""), + server_name=server_name, + name=str(getattr(raw_template, "name", "") or ""), + title=str(getattr(raw_template, "title", "") or ""), + description=str(getattr(raw_template, "description", "") or ""), + mime_type=str(getattr(raw_template, "mimeType", "") or ""), + icons=[build_tool_icon(item) for item in raw_icons] if isinstance(raw_icons, list) else [], + annotation=build_tool_annotation(getattr(raw_template, "annotations", None)), + metadata=_dump_model_metadata(raw_template), + ) + + +def build_resource_read_result(raw_result: Any, uri: str, server_name: str) -> MCPResourceReadResult: + """将 MCP Resource 读取结果转换为统一结构。 + + Args: + raw_result: MCP SDK 返回的读取结果对象。 + uri: 被读取的资源 URI。 + server_name: 资源所属服务器名称。 + + Returns: + MCPResourceReadResult: 统一资源读取结果。 + """ + + contents: list[ToolContentItem] = [] + raw_contents = getattr(raw_result, "contents", None) + if isinstance(raw_contents, list): + for raw_content in raw_contents: + metadata = _dump_model_metadata(raw_content) + contents.append( + ToolContentItem( + content_type="resource", + text=str(getattr(raw_content, "text", "") or ""), + data=str(getattr(raw_content, "blob", "") or ""), + mime_type=str(getattr(raw_content, "mimeType", "") or ""), + uri=str(getattr(raw_content, "uri", "") or uri), + annotation=None, + metadata=metadata, + ) + ) + + return MCPResourceReadResult( + uri=uri, + server_name=server_name, + contents=contents, + metadata=_dump_model_metadata(raw_result), + )