feat(mcp_module): add hooks, host LLM bridge, and models for MCP integration

- Introduced MCPHostCallbacks for optional host capabilities like sampling and logging.
- Implemented MCPHostLLMBridge to handle MCP Sampling requests and bridge to LLM service.
- Created models for structured data conversion between MCP SDK and internal data models, including tool content items, prompts, and resources.
- Enhanced error handling and logging for better traceability during sampling operations.
This commit is contained in:
DrSmoothl
2026-03-30 23:51:05 +08:00
parent abb1d071b1
commit 42dbd5462a
11 changed files with 2332 additions and 170 deletions

View File

@@ -20,6 +20,7 @@ from src.chat.message_receive.message import SessionMessage
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
from src.config.config import config_manager, global_config from src.config.config import config_manager, global_config
from src.mcp_module import MCPManager from src.mcp_module import MCPManager
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
from src.maisaka.chat_loop_service import MaisakaChatLoopService from src.maisaka.chat_loop_service import MaisakaChatLoopService
from src.maisaka.context_messages import ( from src.maisaka.context_messages import (
@@ -66,6 +67,7 @@ class BufferCLI:
self._last_assistant_response_time: Optional[datetime] = None self._last_assistant_response_time: Optional[datetime] = None
self._user_input_times: list[datetime] = [] self._user_input_times: list[datetime] = []
self._mcp_manager: Optional[MCPManager] = None self._mcp_manager: Optional[MCPManager] = None
self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None
self._init_llm() self._init_llm()
def _init_llm(self) -> None: def _init_llm(self) -> None:
@@ -383,17 +385,23 @@ class BufferCLI:
async def _init_mcp(self) -> None: async def _init_mcp(self) -> None:
"""初始化 MCP 服务并注册暴露的工具。""" """初始化 MCP 服务并注册暴露的工具。"""
self._mcp_manager = await MCPManager.from_app_config(global_config.mcp) self._mcp_host_bridge = MCPHostLLMBridge(
sampling_task_name=global_config.mcp.client.sampling.task_name,
)
self._mcp_manager = await MCPManager.from_app_config(
global_config.mcp,
host_callbacks=self._mcp_host_bridge.build_callbacks(),
)
if self._mcp_manager and self._chat_loop_service: if self._mcp_manager and self._chat_loop_service:
mcp_tools = self._mcp_manager.get_openai_tools() mcp_tools = self._mcp_manager.get_openai_tools()
if mcp_tools: if mcp_tools:
self._chat_loop_service.set_extra_tools(mcp_tools) self._chat_loop_service.set_extra_tools(mcp_tools)
summary = self._mcp_manager.get_tool_summary() summary = self._mcp_manager.get_feature_summary()
console.print( console.print(
Panel( Panel(
f"已加载 {len(mcp_tools)} 个 MCP 工具\n{summary}", f"已加载 {len(mcp_tools)} 个 MCP 工具\n{summary}",
title="MCP 工具", title="MCP 能力",
border_style="green", border_style="green",
padding=(0, 1), padding=(0, 1),
) )
@@ -452,3 +460,4 @@ class BufferCLI:
finally: finally:
if self._mcp_manager: if self._mcp_manager:
await self._mcp_manager.close() await self._mcp_manager.close()
self._mcp_host_bridge = None

View File

@@ -57,7 +57,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
MMC_VERSION: str = "1.0.0" MMC_VERSION: str = "1.0.0"
CONFIG_VERSION: str = "8.2.0" CONFIG_VERSION: str = "8.3.0"
MODEL_CONFIG_VERSION: str = "1.13.1" MODEL_CONFIG_VERSION: str = "1.13.1"
logger = get_logger("config") logger = get_logger("config")

View File

@@ -1,6 +1,7 @@
import re
from typing import Literal, Optional from typing import Literal, Optional
import re
from .config_base import ConfigBase, Field from .config_base import ConfigBase, Field
""" """
@@ -1569,6 +1570,225 @@ class MaiSakaConfig(ConfigBase):
"""Maisaka终端图片预览的字符宽度""" """Maisaka终端图片预览的字符宽度"""
class MCPAuthorizationConfig(ConfigBase):
"""MCP HTTP 认证配置。"""
mode: Literal["none", "bearer"] = Field(
default="none",
json_schema_extra={
"x-widget": "select",
"x-icon": "shield",
},
)
"""认证模式,当前支持无认证和静态 Bearer Token"""
bearer_token: str = Field(
default="",
json_schema_extra={
"x-widget": "password",
"x-icon": "key",
},
)
"""静态 Bearer Token仅在 `mode=\"bearer\"` 时使用"""
def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证 MCP 认证配置。
Args:
context: Pydantic 传入的上下文对象。
Returns:
None
"""
if self.mode == "bearer" and not self.bearer_token.strip():
raise ValueError("MCP 使用 bearer 认证时必须填写 bearer_token")
return super().model_post_init(context)
class MCPRootItemConfig(ConfigBase):
"""单个 MCP Root 配置。"""
enabled: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "power",
},
)
"""是否启用当前 Root"""
uri: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "folder",
},
)
"""Root URI通常为 `file://` 路径 URI"""
name: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "tag",
},
)
"""Root 的显示名称"""
def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证单个 Root 配置。
Args:
context: Pydantic 传入的上下文对象。
Returns:
None
"""
if self.enabled and not self.uri.strip():
raise ValueError("启用的 MCP Root 必须填写 uri")
return super().model_post_init(context)
class MCPRootsConfig(ConfigBase):
"""MCP Roots 能力配置。"""
enable: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "folder-tree",
},
)
"""是否向 MCP 服务器暴露 Roots 能力"""
items: list[MCPRootItemConfig] = Field(
default_factory=lambda: [],
json_schema_extra={
"x-widget": "custom",
"x-icon": "folder",
},
)
"""Roots 列表"""
class MCPSamplingConfig(ConfigBase):
"""MCP Sampling 能力配置。"""
enable: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "brain",
},
)
"""是否启用 Sampling 能力声明"""
task_name: str = Field(
default="planner",
json_schema_extra={
"x-widget": "input",
"x-icon": "sparkles",
},
)
"""执行 Sampling 请求时使用的主程序模型任务名"""
include_context_support: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "layers",
},
)
"""是否声明支持 `includeContext` 非 `none` 语义"""
tool_support: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "wrench",
},
)
"""是否声明支持在 Sampling 中继续使用工具"""
class MCPElicitationConfig(ConfigBase):
"""MCP Elicitation 能力配置。"""
enable: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "message-circle-question",
},
)
"""是否启用 Elicitation 能力声明"""
allow_form: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "form-input",
},
)
"""是否允许表单模式 Elicitation"""
allow_url: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "link",
},
)
"""是否允许 URL 模式 Elicitation"""
def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证 Elicitation 配置。
Args:
context: Pydantic 传入的上下文对象。
Returns:
None
"""
if self.enable and not (self.allow_form or self.allow_url):
raise ValueError("启用 MCP Elicitation 时至少需要允许一种模式")
return super().model_post_init(context)
class MCPClientConfig(ConfigBase):
"""MCP 客户端宿主能力配置。"""
client_name: str = Field(
default="MaiBot",
json_schema_extra={
"x-widget": "input",
"x-icon": "bot",
},
)
"""MCP 客户端实现名称"""
client_version: str = Field(
default="1.0.0",
json_schema_extra={
"x-widget": "input",
"x-icon": "info",
},
)
"""MCP 客户端实现版本"""
roots: MCPRootsConfig = Field(default_factory=MCPRootsConfig)
"""Roots 能力配置"""
sampling: MCPSamplingConfig = Field(default_factory=MCPSamplingConfig)
"""Sampling 能力配置"""
elicitation: MCPElicitationConfig = Field(default_factory=MCPElicitationConfig)
"""Elicitation 能力配置"""
class MCPServerItemConfig(ConfigBase): class MCPServerItemConfig(ConfigBase):
"""单个 MCP 服务器配置。""" """单个 MCP 服务器配置。"""
@@ -1590,14 +1810,14 @@ class MCPServerItemConfig(ConfigBase):
) )
"""是否启用当前 MCP 服务器""" """是否启用当前 MCP 服务器"""
transport: Literal["stdio", "sse"] = Field( transport: Literal["stdio", "streamable_http"] = Field(
default="stdio", default="stdio",
json_schema_extra={ json_schema_extra={
"x-widget": "select", "x-widget": "select",
"x-icon": "shuffle", "x-icon": "shuffle",
}, },
) )
"""传输方式,可选 stdio 或 sse""" """传输方式,可选 `stdio``streamable_http`"""
command: str = Field( command: str = Field(
default="", default="",
@@ -1633,7 +1853,7 @@ class MCPServerItemConfig(ConfigBase):
"x-icon": "link", "x-icon": "link",
}, },
) )
"""sse 模式下的服务地址""" """`streamable_http` 模式下的 MCP 端点地址"""
headers: dict[str, str] = Field( headers: dict[str, str] = Field(
default_factory=lambda: {}, default_factory=lambda: {},
@@ -1642,10 +1862,40 @@ class MCPServerItemConfig(ConfigBase):
"x-icon": "file-json", "x-icon": "file-json",
}, },
) )
"""sse 模式下附加的请求头""" """HTTP 模式下附加的请求头"""
http_timeout_seconds: float = Field(
default=30.0,
gt=0,
json_schema_extra={
"x-widget": "number",
"x-icon": "clock-3",
},
)
"""HTTP 请求超时时间,单位秒"""
read_timeout_seconds: float = Field(
default=300.0,
gt=0,
json_schema_extra={
"x-widget": "number",
"x-icon": "timer",
},
)
"""会话读取超时时间,单位秒"""
authorization: MCPAuthorizationConfig = Field(default_factory=MCPAuthorizationConfig)
"""HTTP 认证配置"""
def model_post_init(self, context: Optional[dict] = None) -> None: def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证 MCP 服务器配置。""" """验证 MCP 服务器配置。
Args:
context: Pydantic 传入的上下文对象。
Returns:
None
"""
if not self.name.strip(): if not self.name.strip():
raise ValueError("MCPServerItemConfig.name 不能为空") raise ValueError("MCPServerItemConfig.name 不能为空")
@@ -1653,8 +1903,8 @@ class MCPServerItemConfig(ConfigBase):
if self.transport == "stdio" and not self.command.strip(): if self.transport == "stdio" and not self.command.strip():
raise ValueError(f"MCP 服务器 {self.name} 使用 stdio 时必须填写 command") raise ValueError(f"MCP 服务器 {self.name} 使用 stdio 时必须填写 command")
if self.transport == "sse" and not self.url.strip(): if self.transport == "streamable_http" and not self.url.strip():
raise ValueError(f"MCP 服务器 {self.name} 使用 sse 时必须填写 url") raise ValueError(f"MCP 服务器 {self.name} 使用 streamable_http 时必须填写 url")
return super().model_post_init(context) return super().model_post_init(context)
@@ -1673,6 +1923,9 @@ class MCPConfig(ConfigBase):
) )
"""是否启用 MCPModel Context Protocol""" """是否启用 MCPModel Context Protocol"""
client: MCPClientConfig = Field(default_factory=MCPClientConfig)
"""MCP 客户端宿主能力配置"""
servers: list[MCPServerItemConfig] = Field( servers: list[MCPServerItemConfig] = Field(
default_factory=lambda: [], default_factory=lambda: [],
json_schema_extra={ json_schema_extra={
@@ -1683,7 +1936,14 @@ class MCPConfig(ConfigBase):
"""_wrap_MCP 服务器配置列表""" """_wrap_MCP 服务器配置列表"""
def model_post_init(self, context: Optional[dict] = None) -> None: def model_post_init(self, context: Optional[dict] = None) -> None:
"""验证 MCP 总配置。""" """验证 MCP 总配置。
Args:
context: Pydantic 传入的上下文对象。
Returns:
None
"""
server_names = [server.name.strip() for server in self.servers if server.name.strip()] server_names = [server.name.strip() for server in self.servers if server.name.strip()]
if len(server_names) != len(set(server_names)): if len(server_names) != len(set(server_names)):

View File

@@ -8,8 +8,8 @@ from __future__ import annotations
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Protocol, runtime_checkable
import json import json
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolDefinitionInput from src.llm_models.payload_content.tool_option import ToolDefinitionInput
@@ -99,6 +99,64 @@ def build_tool_detailed_description(
return "\n".join(lines).strip() return "\n".join(lines).strip()
@dataclass(slots=True)
class ToolIcon:
"""统一工具图标信息。"""
src: str
mime_type: str = ""
sizes: list[str] = field(default_factory=list)
@dataclass(slots=True)
class ToolAnnotation:
"""统一工具注解信息。"""
audience: list[str] = field(default_factory=list)
priority: float | None = None
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class ToolContentItem:
"""统一工具内容项。"""
content_type: Literal["text", "image", "audio", "resource_link", "resource", "binary", "unknown"]
text: str = ""
data: str = ""
mime_type: str = ""
uri: str = ""
name: str = ""
description: str = ""
annotation: ToolAnnotation | None = None
metadata: Dict[str, Any] = field(default_factory=dict)
def build_history_text(self) -> str:
"""生成适合写入历史消息的文本摘要。
Returns:
str: 当前内容项对应的历史摘要文本。
"""
if self.content_type == "text" and self.text.strip():
return self.text.strip()
if self.content_type == "image":
return f"[图片内容 {self.mime_type or 'unknown'}]"
if self.content_type == "audio":
return f"[音频内容 {self.mime_type or 'unknown'}]"
if self.content_type == "resource_link":
label = self.name or self.uri or "资源链接"
return f"[资源链接] {label}"
if self.content_type == "resource":
if self.text.strip():
return self.text.strip()
label = self.name or self.uri or "嵌入资源"
return f"[嵌入资源] {label}"
if self.content_type == "binary":
return f"[二进制内容 {self.mime_type or 'unknown'}]"
return f"[{self.content_type} 内容]"
@dataclass(slots=True) @dataclass(slots=True)
class ToolSpec: class ToolSpec:
"""统一工具声明。""" """统一工具声明。"""
@@ -106,10 +164,14 @@ class ToolSpec:
name: str name: str
brief_description: str brief_description: str
detailed_description: str = "" detailed_description: str = ""
title: str = ""
parameters_schema: Dict[str, Any] | None = None parameters_schema: Dict[str, Any] | None = None
output_schema: Dict[str, Any] | None = None
provider_name: str = "" provider_name: str = ""
provider_type: str = "" provider_type: str = ""
enabled: bool = True enabled: bool = True
icons: list[ToolIcon] = field(default_factory=list)
annotation: ToolAnnotation | None = None
metadata: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict)
def build_llm_description(self) -> str: def build_llm_description(self) -> str:
@@ -172,6 +234,7 @@ class ToolExecutionResult:
content: str = "" content: str = ""
error_message: str = "" error_message: str = ""
structured_content: Any = None structured_content: Any = None
content_items: list[ToolContentItem] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict)
def get_history_content(self) -> str: def get_history_content(self) -> str:
@@ -183,6 +246,10 @@ class ToolExecutionResult:
if self.content.strip(): if self.content.strip():
return self.content.strip() return self.content.strip()
if self.content_items:
parts = [item.build_history_text() for item in self.content_items if item.build_history_text().strip()]
if parts:
return "\n".join(parts).strip()
if self.structured_content is not None: if self.structured_content is not None:
if isinstance(self.structured_content, str): if isinstance(self.structured_content, str):
return self.structured_content.strip() return self.structured_content.strip()
@@ -221,6 +288,8 @@ class ToolRegistry:
"""统一工具注册表。""" """统一工具注册表。"""
def __init__(self) -> None: def __init__(self) -> None:
"""初始化统一工具注册表。"""
self._providers: list[ToolProvider] = [] self._providers: list[ToolProvider] = []
def register_provider(self, provider: ToolProvider) -> None: def register_provider(self, provider: ToolProvider) -> None:

View File

@@ -17,6 +17,7 @@ from src.know_u.knowledge import KnowledgeLearner
from src.learners.expression_learner import ExpressionLearner from src.learners.expression_learner import ExpressionLearner
from src.learners.jargon_miner import JargonMiner from src.learners.jargon_miner import JargonMiner
from src.mcp_module import MCPManager from src.mcp_module import MCPManager
from src.mcp_module.host_llm_bridge import MCPHostLLMBridge
from src.mcp_module.provider import MCPToolProvider from src.mcp_module.provider import MCPToolProvider
from src.plugin_runtime.tool_provider import PluginToolProvider from src.plugin_runtime.tool_provider import PluginToolProvider
@@ -54,6 +55,7 @@ class MaisakaHeartFlowChatting:
self._internal_turn_queue: asyncio.Queue[Optional[list[SessionMessage]]] = asyncio.Queue() self._internal_turn_queue: asyncio.Queue[Optional[list[SessionMessage]]] = asyncio.Queue()
self._mcp_manager: Optional[MCPManager] = None self._mcp_manager: Optional[MCPManager] = None
self._mcp_host_bridge: Optional[MCPHostLLMBridge] = None
self._current_cycle_detail: Optional[CycleDetail] = None self._current_cycle_detail: Optional[CycleDetail] = None
self._source_messages_by_id: dict[str, SessionMessage] = {} self._source_messages_by_id: dict[str, SessionMessage] = {}
self._running = False self._running = False
@@ -127,6 +129,7 @@ class MaisakaHeartFlowChatting:
await self._tool_registry.close() await self._tool_registry.close()
self._mcp_manager = None self._mcp_manager = None
self._mcp_host_bridge = None
logger.info(f"{self.log_prefix} Maisaka 运行时已停止") logger.info(f"{self.log_prefix} Maisaka 运行时已停止")
@@ -385,7 +388,13 @@ class MaisakaHeartFlowChatting:
async def _init_mcp(self) -> None: async def _init_mcp(self) -> None:
"""初始化 MCP 工具并注册到统一工具层。""" """初始化 MCP 工具并注册到统一工具层。"""
self._mcp_manager = await MCPManager.from_app_config(global_config.mcp) self._mcp_host_bridge = MCPHostLLMBridge(
sampling_task_name=global_config.mcp.client.sampling.task_name,
)
self._mcp_manager = await MCPManager.from_app_config(
global_config.mcp,
host_callbacks=self._mcp_host_bridge.build_callbacks(),
)
if self._mcp_manager is None: if self._mcp_manager is None:
logger.info(f"{self.log_prefix} MCP 管理器不可用") logger.info(f"{self.log_prefix} MCP 管理器不可用")
return return
@@ -397,8 +406,8 @@ class MaisakaHeartFlowChatting:
self._tool_registry.register_provider(MCPToolProvider(self._mcp_manager)) self._tool_registry.register_provider(MCPToolProvider(self._mcp_manager))
logger.info( logger.info(
f"{self.log_prefix} 已向 Maisaka 加载 {len(mcp_tool_specs)} 个 MCP 工具:\n" f"{self.log_prefix} 已向 Maisaka 加载 {len(mcp_tool_specs)} 个 MCP 工具\n"
f"{self._mcp_manager.get_tool_summary()}" f"{self._mcp_manager.get_feature_summary()}"
) )
def _build_runtime_user_info(self) -> UserInfo: def _build_runtime_user_info(self) -> UserInfo:

View File

@@ -6,37 +6,120 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING: if TYPE_CHECKING:
from src.config.official_configs import MCPConfig from src.config.official_configs import MCPConfig
@dataclass(slots=True)
class MCPAuthorizationRuntimeConfig:
"""MCP HTTP 认证运行时配置。"""
mode: Literal["none", "bearer"] = "none"
bearer_token: str = ""
@dataclass(slots=True)
class MCPRootRuntimeConfig:
"""MCP Root 运行时配置。"""
uri: str
name: str = ""
@dataclass(slots=True)
class MCPClientRuntimeConfig:
"""MCP 客户端宿主能力运行时配置。"""
client_name: str = "MaiBot"
client_version: str = "1.0.0"
enable_roots: bool = False
roots: list[MCPRootRuntimeConfig] = field(default_factory=list)
enable_sampling: bool = False
sampling_task_name: str = "planner"
sampling_include_context_support: bool = False
sampling_tool_support: bool = False
enable_elicitation: bool = False
elicitation_allow_form: bool = True
elicitation_allow_url: bool = False
@dataclass(slots=True) @dataclass(slots=True)
class MCPServerRuntimeConfig: class MCPServerRuntimeConfig:
"""单个 MCP 服务器的运行时配置。""" """单个 MCP 服务器的运行时配置。"""
name: str name: str
transport: Literal["stdio", "streamable_http"] = "stdio"
command: str = "" command: str = ""
args: list[str] = field(default_factory=list) args: list[str] = field(default_factory=list)
env: dict[str, str] = field(default_factory=dict) env: dict[str, str] = field(default_factory=dict)
url: str = "" url: str = ""
headers: dict[str, str] = field(default_factory=dict) headers: dict[str, str] = field(default_factory=dict)
http_timeout_seconds: float = 30.0
read_timeout_seconds: float = 300.0
authorization: MCPAuthorizationRuntimeConfig = field(default_factory=MCPAuthorizationRuntimeConfig)
@property @property
def transport_type(self) -> str: def transport_type(self) -> str:
"""返回当前服务器的传输类型。 """返回当前服务器的传输类型。
Returns: Returns:
str: ``stdio``、``sse`` 或 ``unknown``。 str: ``stdio``、``streamable_http`` 或 ``unknown``。
""" """
if self.command: if self.transport == "stdio" and self.command:
return "stdio" return "stdio"
if self.url: if self.transport == "streamable_http" and self.url:
return "sse" return "streamable_http"
return "unknown" return "unknown"
def build_http_headers(self) -> dict[str, str]:
"""构建远程 HTTP 连接需要附加的请求头。
Returns:
dict[str, str]: 归一化后的请求头集合。
"""
headers = {str(key): str(value) for key, value in self.headers.items()}
if self.authorization.mode == "bearer" and self.authorization.bearer_token.strip():
headers["Authorization"] = f"Bearer {self.authorization.bearer_token.strip()}"
return headers
def build_mcp_client_runtime_config(mcp_config: "MCPConfig") -> MCPClientRuntimeConfig:
"""将官方 MCP 客户端配置转换为运行时结构。
Args:
mcp_config: 主程序中的 MCP 官方配置对象。
Returns:
MCPClientRuntimeConfig: MCP 客户端宿主能力运行时配置。
"""
roots = [
MCPRootRuntimeConfig(
uri=root.uri.strip(),
name=root.name.strip(),
)
for root in mcp_config.client.roots.items
if root.enabled and root.uri.strip()
]
return MCPClientRuntimeConfig(
client_name=mcp_config.client.client_name.strip() or "MaiBot",
client_version=mcp_config.client.client_version.strip() or "1.0.0",
enable_roots=mcp_config.client.roots.enable and bool(roots),
roots=roots,
enable_sampling=mcp_config.client.sampling.enable,
sampling_task_name=mcp_config.client.sampling.task_name.strip() or "planner",
sampling_include_context_support=mcp_config.client.sampling.include_context_support,
sampling_tool_support=mcp_config.client.sampling.tool_support,
enable_elicitation=mcp_config.client.elicitation.enable,
elicitation_allow_form=mcp_config.client.elicitation.allow_form,
elicitation_allow_url=mcp_config.client.elicitation.allow_url,
)
def build_mcp_server_runtime_configs(mcp_config: "MCPConfig") -> list[MCPServerRuntimeConfig]: def build_mcp_server_runtime_configs(mcp_config: "MCPConfig") -> list[MCPServerRuntimeConfig]:
"""将官方 MCP 配置转换为运行时配置列表。 """将官方 MCP 配置转换为运行时配置列表。
@@ -59,11 +142,18 @@ def build_mcp_server_runtime_configs(mcp_config: "MCPConfig") -> list[MCPServerR
runtime_configs.append( runtime_configs.append(
MCPServerRuntimeConfig( MCPServerRuntimeConfig(
name=server.name.strip(), name=server.name.strip(),
transport=server.transport,
command=server.command.strip(), command=server.command.strip(),
args=[str(argument) for argument in server.args], args=[str(argument) for argument in server.args],
env={str(key): str(value) for key, value in server.env.items()}, env={str(key): str(value) for key, value in server.env.items()},
url=server.url.strip(), url=server.url.strip(),
headers={str(key): str(value) for key, value in server.headers.items()}, headers={str(key): str(value) for key, value in server.headers.items()},
http_timeout_seconds=float(server.http_timeout_seconds),
read_timeout_seconds=float(server.read_timeout_seconds),
authorization=MCPAuthorizationRuntimeConfig(
mode=server.authorization.mode,
bearer_token=server.authorization.bearer_token.strip(),
),
) )
) )

View File

@@ -1,16 +1,31 @@
""" """
MaiSaka - 单个 MCP 服务器连接管理 MaiSaka - 单个 MCP 服务器连接管理
封装单个 MCP 服务器的连接生命周期:连接 → 发现工具 → 调用工具 → 断开。 封装单个 MCP 服务器的连接生命周期:连接 → 发现能力 → 调用工具/读取资源 → 断开。
""" """
from contextlib import AsyncExitStack from __future__ import annotations
from typing import Any, Optional
from src.core.tooling import ToolExecutionResult from contextlib import AsyncExitStack
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
import httpx
from src.cli.console import console from src.cli.console import console
from src.core.tooling import ToolExecutionResult
from .config import MCPServerRuntimeConfig from .config import MCPClientRuntimeConfig, MCPRootRuntimeConfig, MCPServerRuntimeConfig
from .hooks import MCPHostCallbacks
from .models import (
MCPPromptResult,
MCPResourceReadResult,
build_prompt_result,
build_resource_read_result,
build_tool_content_items,
)
if TYPE_CHECKING:
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
# ──────────────────── MCP SDK 可选导入 ──────────────────── # ──────────────────── MCP SDK 可选导入 ────────────────────
# #
@@ -18,7 +33,7 @@ from .config import MCPServerRuntimeConfig
# MCPManager.from_app_config() 会检测到并返回 None不影响主程序运行。 # MCPManager.from_app_config() 会检测到并返回 None不影响主程序运行。
try: try:
from mcp import ClientSession from mcp import ClientSession, types as mcp_types
try: try:
from mcp.client.stdio import StdioServerParameters from mcp.client.stdio import StdioServerParameters
@@ -26,84 +41,114 @@ try:
from mcp import StdioServerParameters # type: ignore[attr-defined] from mcp import StdioServerParameters # type: ignore[attr-defined]
from mcp.client.stdio import stdio_client from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client
MCP_AVAILABLE = True MCP_AVAILABLE = True
STREAMABLE_HTTP_AVAILABLE = True
except ImportError: except ImportError:
MCP_AVAILABLE = False MCP_AVAILABLE = False
STREAMABLE_HTTP_AVAILABLE = False
ClientSession = None # type: ignore[assignment,misc] ClientSession = None # type: ignore[assignment,misc]
StdioServerParameters = None # type: ignore[assignment,misc] StdioServerParameters = None # type: ignore[assignment,misc]
mcp_types = None # type: ignore[assignment]
stdio_client = None # type: ignore[assignment] stdio_client = None # type: ignore[assignment]
streamable_http_client = None # type: ignore[assignment]
try:
from mcp.client.sse import sse_client
SSE_AVAILABLE = True
except ImportError:
SSE_AVAILABLE = False
sse_client = None # type: ignore[assignment]
class MCPConnection: class MCPConnection:
"""管理单个 MCP 服务器的连接生命周期。 """管理单个 MCP 服务器的连接生命周期。"""
支持两种传输方式: def __init__(
- Stdio: 启动子进程,通过 stdin/stdout 通信 self,
- SSE: 连接远程 HTTP SSE 端点 config: MCPServerRuntimeConfig,
""" client_config: MCPClientRuntimeConfig,
host_callbacks: Optional[MCPHostCallbacks] = None,
def __init__(self, config: MCPServerRuntimeConfig) -> None: ) -> None:
"""初始化单个 MCP 连接。 """初始化单个 MCP 连接。
Args: Args:
config: 当前服务器的运行时配置。 config: 当前服务器的运行时配置。
client_config: MCP 客户端宿主能力运行时配置。
host_callbacks: 宿主侧能力回调集合。
""" """
self.config = config self.config = config
self.session: Optional[Any] = None # mcp.ClientSession self.client_config = client_config
self.tools: list = [] # mcp Tool objects self.host_callbacks = host_callbacks or MCPHostCallbacks()
self.session: Optional[Any] = None
self.server_capabilities: Optional[Any] = None
self.tools: list[Any] = []
self.prompts: list[Any] = []
self.resources: list[Any] = []
self.resource_templates: list[Any] = []
self.protocol_version: str = ""
self._http_client: Optional[httpx.AsyncClient] = None
self._session_id_getter: Optional[Callable[[], str | None]] = None
self._exit_stack = AsyncExitStack() self._exit_stack = AsyncExitStack()
async def connect(self) -> bool: @property
""" def session_id(self) -> str:
连接到 MCP 服务器并发现可用工具 """返回当前连接协商得到的 MCP 会话标识
Returns: Returns:
True 表示连接成功False 表示失败 str: 当前会话 ID无会话时返回空字符串
""" """
if self._session_id_getter is None:
return ""
return self._session_id_getter() or ""
async def connect(self) -> bool:
"""连接到 MCP 服务器并发现可用能力。
Returns:
bool: `True` 表示连接成功,`False` 表示失败。
"""
if not MCP_AVAILABLE: if not MCP_AVAILABLE:
console.print("[warning]⚠️ 未安装 mcp SDK请运行: pip install mcp[/warning]") console.print("[warning]⚠️ 未安装 mcp SDK请运行: pip install mcp[/warning]")
return False return False
try: try:
await self._exit_stack.__aenter__() await self._exit_stack.__aenter__()
read_stream, write_stream = await self._connect_transport()
session = await self._create_client_session(read_stream, write_stream)
self.session = session
initialize_result = await session.initialize()
self.server_capabilities = getattr(initialize_result, "capabilities", None)
self.protocol_version = str(getattr(initialize_result, "protocolVersion", "") or "")
if self.config.transport_type == "stdio": await self._load_server_features()
read_stream, write_stream = await self._connect_stdio()
elif self.config.transport_type == "sse":
read_stream, write_stream = await self._connect_sse()
else:
console.print(f"[warning]MCP '{self.config.name}': 未知传输类型[/warning]")
return False
# 创建并初始化 MCP 会话
if ClientSession is None:
raise RuntimeError("当前环境未安装可用的 MCP ClientSession")
self.session = await self._exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
await self.session.initialize()
# 发现工具
result = await self.session.list_tools()
self.tools = result.tools if hasattr(result, "tools") else []
return True return True
except Exception as e: except Exception as exc:
console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {e}[/warning]") console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {exc}[/warning]")
await self.close() await self.close()
return False return False
async def _connect_stdio(self): async def _connect_transport(self) -> tuple[Any, Any]:
"""建立 Stdio 传输连接。""" """根据配置建立底层传输连接。
Returns:
tuple[Any, Any]: 读写流对象。
"""
if self.config.transport_type == "stdio":
return await self._connect_stdio()
if self.config.transport_type == "streamable_http":
return await self._connect_streamable_http()
raise ValueError(f"MCP 服务器 '{self.config.name}' 使用了未知传输类型: {self.config.transport}")
async def _connect_stdio(self) -> tuple[Any, Any]:
"""建立 stdio 传输连接。
Returns:
tuple[Any, Any]: 读写流对象。
"""
if StdioServerParameters is None or stdio_client is None: if StdioServerParameters is None or stdio_client is None:
raise RuntimeError("当前环境未安装可用的 MCP stdio 客户端") raise RuntimeError("当前环境未安装可用的 MCP stdio 客户端")
if not self.config.command: if not self.config.command:
@@ -116,15 +161,293 @@ class MCPConnection:
) )
return await self._exit_stack.enter_async_context(stdio_client(params)) return await self._exit_stack.enter_async_context(stdio_client(params))
async def _connect_sse(self): async def _connect_streamable_http(self) -> tuple[Any, Any]:
"""建立 SSE 传输连接。""" """建立 Streamable HTTP 传输连接。
if not SSE_AVAILABLE:
raise ImportError("SSE 传输需要额外依赖,请运行: pip install mcp[sse]") Returns:
if sse_client is None: tuple[Any, Any]: 读写流对象。
raise RuntimeError("当前环境未安装可用的 MCP SSE 客户端") """
if not STREAMABLE_HTTP_AVAILABLE or streamable_http_client is None:
raise ImportError("当前环境未安装可用的 MCP Streamable HTTP 客户端")
if not self.config.url: if not self.config.url:
raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 SSE url 配置") raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 Streamable HTTP url 配置")
return await self._exit_stack.enter_async_context(sse_client(url=self.config.url, headers=self.config.headers))
self._http_client = await self._exit_stack.enter_async_context(self._build_http_client())
read_stream, write_stream, session_id_getter = await self._exit_stack.enter_async_context(
streamable_http_client(
url=self.config.url,
http_client=self._http_client,
terminate_on_close=True,
)
)
self._session_id_getter = session_id_getter
return read_stream, write_stream
def _build_http_client(self) -> httpx.AsyncClient:
"""构建 Streamable HTTP 使用的 `httpx` 客户端。
Returns:
httpx.AsyncClient: 预配置的异步 HTTP 客户端。
"""
return httpx.AsyncClient(
headers=self.config.build_http_headers(),
timeout=httpx.Timeout(self.config.http_timeout_seconds),
)
async def _create_client_session(self, read_stream: Any, write_stream: Any) -> Any:
"""创建并返回 MCP `ClientSession`。
Args:
read_stream: 底层读取流。
write_stream: 底层写入流。
Returns:
Any: 已初始化的 MCP `ClientSession` 实例。
"""
if ClientSession is None:
raise RuntimeError("当前环境未安装可用的 MCP ClientSession")
list_roots_callback = self._build_list_roots_callback()
sampling_callback = (
self.host_callbacks.sampling_callback
if self.client_config.enable_sampling and self.host_callbacks.sampling_callback is not None
else None
)
elicitation_callback = (
self.host_callbacks.elicitation_callback
if self.client_config.enable_elicitation and self.host_callbacks.elicitation_callback is not None
else None
)
logging_callback = cast(Optional["LoggingFnT"], self.host_callbacks.logging_callback)
message_handler = cast(Optional["MessageHandlerFnT"], self.host_callbacks.message_handler)
if self.client_config.enable_sampling and sampling_callback is None:
console.print(
f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 sampling 配置,但宿主未提供 sampling 回调,当前不会声明该能力[/warning]"
)
if self.client_config.enable_elicitation and elicitation_callback is None:
console.print(
f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 elicitation 配置,但宿主未提供 elicitation 回调,当前不会声明该能力[/warning]"
)
session = await self._exit_stack.enter_async_context(
ClientSession(
read_stream,
write_stream,
read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds),
sampling_callback=cast(Optional["SamplingFnT"], sampling_callback),
elicitation_callback=cast(Optional["ElicitationFnT"], elicitation_callback),
list_roots_callback=cast(Optional["ListRootsFnT"], list_roots_callback),
logging_callback=logging_callback,
message_handler=message_handler,
client_info=self._build_client_info(),
sampling_capabilities=self._build_sampling_capabilities(sampling_callback),
)
)
return session
def _build_client_info(self) -> Any:
"""构建 MCP 客户端实现信息。
Returns:
Any: MCP SDK 的 `Implementation` 对象。
"""
if mcp_types is None:
raise RuntimeError("当前环境未安装可用的 MCP types 模块")
return mcp_types.Implementation(
name=self.client_config.client_name,
version=self.client_config.client_version,
)
def _build_sampling_capabilities(self, sampling_callback: Any) -> Any | None:
"""构建 Sampling 能力声明。
Args:
sampling_callback: 当前宿主侧的 Sampling 回调。
Returns:
Any | None: Sampling 能力对象;未启用时返回 ``None``。
"""
if mcp_types is None:
return None
if sampling_callback is None:
return None
context_capability = (
mcp_types.SamplingContextCapability()
if self.client_config.sampling_include_context_support
else None
)
tools_capability = (
mcp_types.SamplingToolsCapability()
if self.client_config.sampling_tool_support
else None
)
return mcp_types.SamplingCapability(
context=context_capability,
tools=tools_capability,
)
def _build_list_roots_callback(self) -> Any | None:
"""构建 Roots 列表回调。
Returns:
Any | None: 符合 MCP SDK 要求的回调;未启用时返回 ``None``。
"""
if mcp_types is None:
return None
if not self.client_config.enable_roots or not self.client_config.roots:
return None
async def _list_roots(context: Any) -> Any:
"""返回当前客户端声明的 Roots 列表。
Args:
context: MCP 请求上下文。
Returns:
Any: MCP `ListRootsResult` 对象。
"""
del context
types_module = mcp_types
if types_module is None:
raise RuntimeError("当前环境未安装可用的 MCP types 模块")
roots = [
types_module.Root(uri=cast(Any, root.uri), name=root.name or None)
for root in self.client_config.roots
]
return types_module.ListRootsResult(roots=roots)
return _list_roots
async def _load_server_features(self) -> None:
"""根据服务端能力声明加载工具、Prompt 与 Resource。"""
self.tools = await self._list_tools() if self.supports_tools() else []
self.prompts = await self._list_prompts() if self.supports_prompts() else []
self.resources = await self._list_resources() if self.supports_resources() else []
self.resource_templates = (
await self._list_resource_templates() if self.supports_resources() else []
)
def supports_tools(self) -> bool:
"""判断服务端是否声明支持 Tools。
Returns:
bool: 是否支持 Tools。
"""
return bool(self.server_capabilities is not None and getattr(self.server_capabilities, "tools", None) is not None)
def supports_prompts(self) -> bool:
"""判断服务端是否声明支持 Prompts。
Returns:
bool: 是否支持 Prompts。
"""
return bool(
self.server_capabilities is not None and getattr(self.server_capabilities, "prompts", None) is not None
)
def supports_resources(self) -> bool:
"""判断服务端是否声明支持 Resources。
Returns:
bool: 是否支持 Resources。
"""
return bool(
self.server_capabilities is not None and getattr(self.server_capabilities, "resources", None) is not None
)
async def _list_tools(self) -> list[Any]:
"""分页加载服务端暴露的全部工具。
Returns:
list[Any]: MCP SDK 的原始工具对象列表。
"""
if self.session is None:
return []
tools: list[Any] = []
cursor: Optional[str] = None
while True:
result = await self.session.list_tools(cursor=cursor)
tools.extend(list(getattr(result, "tools", []) or []))
cursor = getattr(result, "nextCursor", None)
if not cursor:
break
return tools
async def _list_prompts(self) -> list[Any]:
"""分页加载服务端暴露的全部 Prompt。
Returns:
list[Any]: MCP SDK 的原始 Prompt 对象列表。
"""
if self.session is None:
return []
prompts: list[Any] = []
cursor: Optional[str] = None
while True:
result = await self.session.list_prompts(cursor=cursor)
prompts.extend(list(getattr(result, "prompts", []) or []))
cursor = getattr(result, "nextCursor", None)
if not cursor:
break
return prompts
async def _list_resources(self) -> list[Any]:
"""分页加载服务端暴露的全部 Resource。
Returns:
list[Any]: MCP SDK 的原始 Resource 对象列表。
"""
if self.session is None:
return []
resources: list[Any] = []
cursor: Optional[str] = None
while True:
result = await self.session.list_resources(cursor=cursor)
resources.extend(list(getattr(result, "resources", []) or []))
cursor = getattr(result, "nextCursor", None)
if not cursor:
break
return resources
async def _list_resource_templates(self) -> list[Any]:
"""分页加载服务端暴露的全部 Resource Template。
Returns:
list[Any]: MCP SDK 的原始 Resource Template 对象列表。
"""
if self.session is None:
return []
resource_templates: list[Any] = []
cursor: Optional[str] = None
while True:
result = await self.session.list_resource_templates(cursor=cursor)
resource_templates.extend(list(getattr(result, "resourceTemplates", []) or []))
cursor = getattr(result, "nextCursor", None)
if not cursor:
break
return resource_templates
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> ToolExecutionResult: async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> ToolExecutionResult:
"""调用 MCP 工具并返回统一执行结果。 """调用 MCP 工具并返回统一执行结果。
@@ -137,15 +460,20 @@ class MCPConnection:
ToolExecutionResult: 统一执行结果。 ToolExecutionResult: 统一执行结果。
""" """
if not self.session: if self.session is None:
return ToolExecutionResult( return ToolExecutionResult(
tool_name=tool_name, tool_name=tool_name,
success=False, success=False,
error_message=f"MCP 服务器 '{self.config.name}' 未连接", error_message=f"MCP 服务器 '{self.config.name}' 未连接",
metadata={"server_name": self.config.name},
) )
try: try:
result = await self.session.call_tool(tool_name, arguments=arguments) result = await self.session.call_tool(
tool_name,
arguments=arguments,
read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds),
)
except Exception as exc: except Exception as exc:
return ToolExecutionResult( return ToolExecutionResult(
tool_name=tool_name, tool_name=tool_name,
@@ -154,33 +482,78 @@ class MCPConnection:
metadata={"server_name": self.config.name}, metadata={"server_name": self.config.name},
) )
text_parts: list[str] = [] content_items = build_tool_content_items(list(getattr(result, "content", []) or []))
binary_parts: list[dict[str, Any]] = [] text_parts = [item.text.strip() for item in content_items if item.content_type == "text" and item.text.strip()]
for content in result.content: structured_content = getattr(result, "structuredContent", None)
if hasattr(content, "text"): is_error = bool(getattr(result, "isError", False))
text_parts.append(str(content.text)) history_content = "\n".join(text_parts).strip()
elif hasattr(content, "data"): error_message = history_content if is_error else ""
content_type = getattr(content, "mimeType", "unknown")
binary_parts.append({"mime_type": content_type, "type": "binary"})
text_parts.append(f"[{content_type} 二进制内容]")
elif hasattr(content, "type"):
text_parts.append(f"[{content.type} 内容]")
return ToolExecutionResult( return ToolExecutionResult(
tool_name=tool_name, tool_name=tool_name,
success=True, success=not is_error,
content="\n".join(text_parts) if text_parts else "工具执行成功(无输出)", content=history_content if not is_error else "",
error_message=error_message,
structured_content=structured_content,
content_items=content_items,
metadata={ metadata={
"server_name": self.config.name, "server_name": self.config.name,
"binary_parts": binary_parts, "protocol_version": self.protocol_version,
"session_id": self.session_id,
}, },
) )
async def get_prompt(
self,
prompt_name: str,
arguments: Optional[dict[str, str]] = None,
) -> MCPPromptResult:
"""读取指定 MCP Prompt 的内容。
Args:
prompt_name: Prompt 名称。
arguments: Prompt 参数字典。
Returns:
MCPPromptResult: 统一 Prompt 结果。
"""
if self.session is None:
raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接")
result = await self.session.get_prompt(prompt_name, arguments=arguments)
return build_prompt_result(result, prompt_name=prompt_name, server_name=self.config.name)
async def read_resource(self, uri: str) -> MCPResourceReadResult:
"""读取指定 MCP Resource 的内容。
Args:
uri: 资源 URI。
Returns:
MCPResourceReadResult: 统一资源读取结果。
"""
if self.session is None:
raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接")
result = await self.session.read_resource(uri)
return build_resource_read_result(result, uri=uri, server_name=self.config.name)
async def close(self) -> None: async def close(self) -> None:
"""关闭连接并释放资源。""" """关闭连接并释放资源。"""
try: try:
await self._exit_stack.aclose() await self._exit_stack.aclose()
except Exception: except Exception:
pass pass
self.session = None self.session = None
self.server_capabilities = None
self.tools = [] self.tools = []
self.prompts = []
self.resources = []
self.resource_templates = []
self.protocol_version = ""
self._http_client = None
self._session_id_getter = None

20
src/mcp_module/hooks.py Normal file
View File

@@ -0,0 +1,20 @@
"""MCP 宿主回调声明。"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Awaitable, Callable
@dataclass(slots=True)
class MCPHostCallbacks:
"""MCP 宿主回调集合。
该对象用于向 `MCPConnection` 注入宿主侧可选能力,
例如 Sampling、Elicitation、日志消费和自定义消息处理。
"""
sampling_callback: Callable[..., Awaitable[Any]] | None = None
elicitation_callback: Callable[..., Awaitable[Any]] | None = None
logging_callback: Callable[..., Awaitable[None]] | None = None
message_handler: Callable[..., Awaitable[None]] | None = None

View File

@@ -0,0 +1,597 @@
"""MCP 宿主侧大模型桥接服务。"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional
import json
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMResponseResult
from src.common.logger import get_logger
from src.core.tooling import build_tool_detailed_description
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
from src.services.llm_service import LLMServiceClient
from .hooks import MCPHostCallbacks
from .models import build_tool_content_items
if TYPE_CHECKING:
from src.llm_models.model_client.base_client import BaseClient
try:
from mcp import types as mcp_types
MCP_TYPES_AVAILABLE = True
except ImportError:
mcp_types = None # type: ignore[assignment]
MCP_TYPES_AVAILABLE = False
logger = get_logger("mcp_host_llm_bridge")
class MCPHostLLMBridge:
"""将 MCP Sampling 请求桥接到主程序大模型调用链。"""
def __init__(self, sampling_task_name: str = "planner") -> None:
"""初始化 MCP 宿主侧大模型桥接服务。
Args:
sampling_task_name: 执行 Sampling 请求时使用的模型任务名。
"""
self._sampling_task_name = sampling_task_name.strip() or "planner"
self._sampling_client = LLMServiceClient(
task_name=self._sampling_task_name,
request_type="mcp_sampling",
)
def build_callbacks(self) -> MCPHostCallbacks:
"""构建可注入给 MCP 连接层的宿主回调集合。
Returns:
MCPHostCallbacks: 包含 Sampling 回调的宿主回调集合。
"""
return MCPHostCallbacks(
sampling_callback=self.handle_sampling_request,
)
async def handle_sampling_request(self, context: Any, params: Any) -> Any:
"""处理服务端发起的 MCP Sampling 请求。
Args:
context: MCP SDK 传入的请求上下文。
params: `sampling/createMessage` 请求参数。
Returns:
Any: MCP `CreateMessageResult`、`CreateMessageResultWithTools` 或 `ErrorData`。
"""
del context
if not MCP_TYPES_AVAILABLE or mcp_types is None:
raise RuntimeError("当前环境未安装可用的 MCP types 模块")
try:
tool_choice_mode = self._get_tool_choice_mode(params)
tool_definitions = self._build_tool_definitions(
raw_tools=getattr(params, "tools", None),
tool_choice_mode=tool_choice_mode,
)
message_factory = self._build_message_factory(
raw_messages=list(getattr(params, "messages", []) or []),
system_prompt=self._build_system_prompt(
raw_system_prompt=str(getattr(params, "systemPrompt", "") or ""),
tool_choice_mode=tool_choice_mode,
tool_definitions=tool_definitions,
),
)
generation_result = await self._sampling_client.generate_response_with_messages(
message_factory=message_factory,
options=LLMGenerationOptions(
temperature=self._coerce_float(getattr(params, "temperature", None)),
max_tokens=int(getattr(params, "maxTokens", 1024) or 1024),
tool_options=tool_definitions,
),
)
if tool_choice_mode == "required" and tool_definitions and not generation_result.tool_calls:
return mcp_types.ErrorData(
code=mcp_types.INTERNAL_ERROR,
message="Sampling 要求必须调用工具,但模型未返回任何工具调用",
)
return self._build_sampling_result(
generation_result=generation_result,
tools_enabled=bool(tool_definitions),
)
except Exception as exc:
logger.exception(f"MCP Sampling 调用失败: {exc}")
return mcp_types.ErrorData(
code=mcp_types.INTERNAL_ERROR,
message=f"MCP Sampling 调用失败: {exc}",
)
@staticmethod
def _coerce_float(raw_value: Any) -> float | None:
"""将任意原始值转换为浮点数。
Args:
raw_value: 原始输入值。
Returns:
float | None: 转换后的浮点数;无法转换时返回 ``None``。
"""
if raw_value is None:
return None
if isinstance(raw_value, int | float):
return float(raw_value)
return None
@staticmethod
def _get_tool_choice_mode(params: Any) -> str:
"""读取 Sampling 请求中的工具选择模式。
Args:
params: Sampling 请求参数对象。
Returns:
str: `auto`、`required` 或 `none`;缺省时返回 `auto`。
"""
tool_choice = getattr(params, "toolChoice", None)
mode = str(getattr(tool_choice, "mode", "") or "").strip().lower()
if mode in {"required", "none"}:
return mode
return "auto"
def _build_system_prompt(
self,
raw_system_prompt: str,
tool_choice_mode: str,
tool_definitions: list[ToolDefinitionInput] | None,
) -> str:
"""构建发送给主程序大模型的系统提示词。
Args:
raw_system_prompt: 服务端请求中的系统提示词。
tool_choice_mode: 当前工具选择模式。
tool_definitions: 参与本次 Sampling 的工具定义。
Returns:
str: 最终系统提示词。
"""
prompt_parts: list[str] = []
if raw_system_prompt.strip():
prompt_parts.append(raw_system_prompt.strip())
if tool_choice_mode == "required" and tool_definitions:
prompt_parts.append("本轮回答必须至少调用一个工具;不要直接结束回答。")
return "\n\n".join(part for part in prompt_parts if part).strip()
def _build_message_factory(
self,
raw_messages: list[Any],
system_prompt: str,
) -> Any:
"""构建 MCP Sampling 使用的消息工厂。
Args:
raw_messages: MCP Sampling 原始消息列表。
system_prompt: 规范化后的系统提示词。
Returns:
Any: 供 `LLMServiceClient` 使用的消息工厂。
"""
def _message_factory(client: "BaseClient") -> list[Message]:
"""延迟构建内部消息列表。
Args:
client: 当前被选中的底层模型客户端。
Returns:
list[Message]: 内部统一消息列表。
"""
messages: list[Message] = []
if system_prompt.strip():
messages.append(
MessageBuilder()
.set_role(RoleType.System)
.add_text_content(system_prompt.strip())
.build()
)
for raw_message in raw_messages:
messages.extend(self._convert_sampling_message(raw_message, client))
return messages
return _message_factory
def _convert_sampling_message(self, raw_message: Any, client: "BaseClient") -> list[Message]:
"""将单条 MCP Sampling 消息转换为内部消息列表。
Args:
raw_message: MCP Sampling 原始消息对象。
client: 当前底层模型客户端。
Returns:
list[Message]: 转换后的内部消息列表。
"""
role = str(getattr(raw_message, "role", "") or "").strip().lower()
content_blocks = self._get_content_blocks(getattr(raw_message, "content", None))
if role == "assistant":
assistant_message = self._build_assistant_message(content_blocks, client)
return [assistant_message] if assistant_message is not None else []
if role == "user":
return self._build_user_messages(content_blocks, client)
raise ValueError(f"不支持的 MCP Sampling 消息角色: {role}")
@staticmethod
def _get_content_blocks(raw_content: Any) -> list[Any]:
"""将 MCP Sampling 消息内容统一为列表。
Args:
raw_content: 原始内容字段。
Returns:
list[Any]: 统一后的内容块列表。
"""
if raw_content is None:
return []
if isinstance(raw_content, list):
return list(raw_content)
return [raw_content]
def _build_assistant_message(self, content_blocks: list[Any], client: "BaseClient") -> Optional[Message]:
"""构建内部 assistant 消息。
Args:
content_blocks: MCP assistant 内容块列表。
client: 当前底层模型客户端。
Returns:
Optional[Message]: 转换后的内部 assistant 消息;无有效内容时返回 ``None``。
"""
message_builder = MessageBuilder().set_role(RoleType.Assistant)
tool_calls: list[ToolCall] = []
has_visible_content = False
for content_block in content_blocks:
content_type = self._get_content_type(content_block)
if content_type == "tool_use":
tool_calls.append(
ToolCall(
call_id=str(getattr(content_block, "id", "") or ""),
func_name=str(getattr(content_block, "name", "") or ""),
args=self._normalize_tool_call_arguments(getattr(content_block, "input", None)),
)
)
continue
has_visible_content = self._append_sampling_content_to_builder(
message_builder=message_builder,
content_block=content_block,
client=client,
) or has_visible_content
if tool_calls:
message_builder.set_tool_calls(tool_calls)
if not has_visible_content and not tool_calls:
return None
return message_builder.build()
def _build_user_messages(self, content_blocks: list[Any], client: "BaseClient") -> list[Message]:
"""构建内部 user/tool 消息序列。
Args:
content_blocks: MCP user 内容块列表。
client: 当前底层模型客户端。
Returns:
list[Message]: 转换后的内部消息序列。
"""
messages: list[Message] = []
message_builder = MessageBuilder().set_role(RoleType.User)
has_user_content = False
def flush_user_message() -> None:
"""在当前存在用户可见内容时落盘一条 user 消息。"""
nonlocal message_builder, has_user_content
if not has_user_content:
return
messages.append(message_builder.build())
message_builder = MessageBuilder().set_role(RoleType.User)
has_user_content = False
for content_block in content_blocks:
content_type = self._get_content_type(content_block)
if content_type == "tool_result":
flush_user_message()
messages.append(self._build_tool_result_message(content_block))
continue
has_user_content = self._append_sampling_content_to_builder(
message_builder=message_builder,
content_block=content_block,
client=client,
) or has_user_content
flush_user_message()
return messages
@staticmethod
def _get_content_type(content_block: Any) -> str:
"""读取 MCP 内容块类型。
Args:
content_block: MCP 内容块对象。
Returns:
str: 规范化后的内容块类型。
"""
return str(getattr(content_block, "type", "text") or "text").strip().lower()
def _append_sampling_content_to_builder(
self,
message_builder: MessageBuilder,
content_block: Any,
client: "BaseClient",
) -> bool:
"""将 MCP 普通内容块追加到内部消息构建器。
Args:
message_builder: 内部消息构建器。
content_block: MCP 内容块对象。
client: 当前底层模型客户端。
Returns:
bool: 是否成功追加了可见内容。
"""
content_type = self._get_content_type(content_block)
if content_type == "text":
text_content = str(getattr(content_block, "text", "") or "")
if text_content.strip():
message_builder.add_text_content(text_content)
return True
return False
if content_type == "image":
image_data = str(getattr(content_block, "data", "") or "")
image_mime_type = str(getattr(content_block, "mimeType", "") or "")
image_format = self._normalize_image_format(image_mime_type)
if image_data and image_format:
message_builder.add_image_content(
image_format=image_format,
image_base64=image_data,
support_formats=client.get_support_image_formats(),
)
return True
message_builder.add_text_content(
f"[图片内容mime_type={image_mime_type or 'unknown'},当前客户端无法直接透传]"
)
return True
if content_type == "audio":
audio_mime_type = str(getattr(content_block, "mimeType", "") or "")
message_builder.add_text_content(f"[音频内容mime_type={audio_mime_type or 'unknown'}]")
return True
return False
@staticmethod
def _normalize_image_format(mime_type: str) -> str:
"""将图片 MIME 类型转换为内部图片格式名称。
Args:
mime_type: MCP 图片 MIME 类型。
Returns:
str: 内部支持的图片格式名;不支持时返回空字符串。
"""
normalized_mime_type = mime_type.strip().lower()
if normalized_mime_type == "image/png":
return "png"
if normalized_mime_type in {"image/jpeg", "image/jpg"}:
return "jpeg"
if normalized_mime_type == "image/webp":
return "webp"
if normalized_mime_type == "image/gif":
return "gif"
return ""
def _build_tool_result_message(self, content_block: Any) -> Message:
"""将 MCP `tool_result` 内容块转换为内部 Tool 消息。
Args:
content_block: MCP `tool_result` 内容块对象。
Returns:
Message: 转换后的内部 Tool 消息。
"""
message_builder = MessageBuilder().set_role(RoleType.Tool)
message_builder.set_tool_call_id(str(getattr(content_block, "toolUseId", "") or "tool_result"))
summary_text = self._summarize_tool_result_content(content_block)
message_builder.add_text_content(summary_text or "工具执行完成。")
return message_builder.build()
def _summarize_tool_result_content(self, content_block: Any) -> str:
"""汇总 MCP `tool_result` 内容块中的结果文本。
Args:
content_block: MCP `tool_result` 内容块对象。
Returns:
str: 适合发送给主程序模型的工具结果摘要文本。
"""
raw_contents = list(getattr(content_block, "content", []) or [])
content_items = build_tool_content_items(raw_contents)
parts = [item.build_history_text().strip() for item in content_items if item.build_history_text().strip()]
structured_content = getattr(content_block, "structuredContent", None)
if structured_content is not None:
try:
parts.append(json.dumps(structured_content, ensure_ascii=False))
except (TypeError, ValueError):
parts.append(str(structured_content))
summary_text = "\n".join(part for part in parts if part).strip()
if bool(getattr(content_block, "isError", False)) and summary_text:
return f"工具执行失败:\n{summary_text}"
if bool(getattr(content_block, "isError", False)):
return "工具执行失败。"
return summary_text
@staticmethod
def _normalize_tool_call_arguments(raw_arguments: Any) -> dict[str, Any]:
"""将原始工具调用参数规范化为字典。
Args:
raw_arguments: 原始工具参数。
Returns:
dict[str, Any]: 规范化后的参数字典。
"""
if isinstance(raw_arguments, dict):
return dict(raw_arguments)
if raw_arguments is None:
return {}
return {"value": raw_arguments}
def _build_tool_definitions(
self,
raw_tools: Any,
tool_choice_mode: str,
) -> list[ToolDefinitionInput] | None:
"""将 MCP Sampling 工具定义转换为主程序内部工具定义。
Args:
raw_tools: MCP Sampling 请求中的工具列表。
tool_choice_mode: 当前工具选择模式。
Returns:
list[ToolDefinitionInput] | None: 可传给主程序模型层的工具定义列表。
"""
if tool_choice_mode == "none":
return None
if not isinstance(raw_tools, list) or not raw_tools:
return None
tool_definitions: list[ToolDefinitionInput] = []
for raw_tool in raw_tools:
tool_name = str(getattr(raw_tool, "name", "") or "").strip()
if not tool_name:
continue
parameters_schema = (
dict(getattr(raw_tool, "inputSchema", {}) or {}) if getattr(raw_tool, "inputSchema", None) else {}
)
if "$schema" in parameters_schema:
parameters_schema.pop("$schema")
title = str(getattr(raw_tool, "title", "") or "").strip()
description = str(getattr(raw_tool, "description", "") or "").strip()
brief_description = description or title or f"工具 {tool_name}"
detailed_description = build_tool_detailed_description(
parameters_schema,
fallback_description=f"工具名称:{tool_name}",
)
tool_definitions.append(
{
"name": tool_name,
"description": "\n\n".join(
part for part in [brief_description, detailed_description] if part.strip()
).strip(),
"parameters_schema": parameters_schema or {"type": "object", "properties": {}},
}
)
return tool_definitions or None
def _build_sampling_result(
self,
generation_result: LLMResponseResult,
tools_enabled: bool,
) -> Any:
"""将主程序模型响应转换为 MCP Sampling 结果。
Args:
generation_result: 主程序统一大模型响应结果。
tools_enabled: 当前是否允许模型使用工具。
Returns:
Any: MCP `CreateMessageResult` 或 `CreateMessageResultWithTools`。
"""
if not MCP_TYPES_AVAILABLE or mcp_types is None:
raise RuntimeError("当前环境未安装可用的 MCP types 模块")
text_content = str(generation_result.response or "")
tool_calls = list(generation_result.tool_calls or [])
model_name = generation_result.model_name or self._sampling_task_name
if tools_enabled:
content_blocks: list[Any] = []
if text_content.strip():
content_blocks.append(
mcp_types.TextContent(
type="text",
text=text_content,
)
)
for tool_call in tool_calls:
content_blocks.append(
mcp_types.ToolUseContent(
type="tool_use",
name=tool_call.func_name,
id=tool_call.call_id,
input=dict(tool_call.args or {}),
)
)
if not content_blocks:
content_blocks.append(
mcp_types.TextContent(
type="text",
text="",
)
)
return mcp_types.CreateMessageResultWithTools(
role="assistant",
content=content_blocks[0] if len(content_blocks) == 1 else content_blocks,
model=model_name,
stopReason="toolUse" if tool_calls else "endTurn",
)
return mcp_types.CreateMessageResult(
role="assistant",
content=mcp_types.TextContent(
type="text",
text=text_content,
),
model=model_name,
stopReason="endTurn",
)

View File

@@ -1,8 +1,10 @@
""" """
MaiSaka - MCP 管理器 MaiSaka - MCP 管理器
管理所有 MCP 服务器连接,提供统一的工具发现与调用接口。 管理所有 MCP 服务器连接,提供统一的工具、Prompt 与 Resource 访问入口。
""" """
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from src.cli.console import console from src.cli.console import console
@@ -13,8 +15,26 @@ from src.core.tooling import (
build_tool_detailed_description, build_tool_detailed_description,
) )
from .config import MCPServerRuntimeConfig, build_mcp_server_runtime_configs from .config import (
MCPClientRuntimeConfig,
MCPServerRuntimeConfig,
build_mcp_client_runtime_config,
build_mcp_server_runtime_configs,
)
from .connection import MCPConnection, MCP_AVAILABLE from .connection import MCPConnection, MCP_AVAILABLE
from .hooks import MCPHostCallbacks
from .models import (
MCPPromptResult,
MCPPromptSpec,
MCPResourceReadResult,
MCPResourceSpec,
MCPResourceTemplateSpec,
build_prompt_spec,
build_resource_spec,
build_resource_template_spec,
build_tool_annotation,
build_tool_icon,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from src.config.official_configs import MCPConfig from src.config.official_configs import MCPConfig
@@ -34,37 +54,44 @@ BUILTIN_TOOL_NAMES = frozenset(
class MCPManager: class MCPManager:
"""MCP 服务器连接管理器。 """MCP 服务器连接管理器。"""
职责: def __init__(
- 根据主程序官方配置连接所有 MCP 服务器 self,
- 将 MCP 工具转换为 OpenAI function calling 格式 client_config: MCPClientRuntimeConfig,
- 路由工具调用到正确的 MCP 服务器 host_callbacks: Optional[MCPHostCallbacks] = None,
- 统一管理连接生命周期 ) -> None:
""" """初始化 MCP 管理器。
def __init__(self) -> None: Args:
"""初始化 MCP 管理器。""" client_config: MCP 客户端宿主能力运行时配置。
host_callbacks: 宿主侧能力回调集合。
"""
self._connections: dict[str, MCPConnection] = {} # server_name → connection self._client_config = client_config
self._tool_to_server: dict[str, str] = {} # tool_name → server_name self._host_callbacks = host_callbacks or MCPHostCallbacks()
self._connections: dict[str, MCPConnection] = {}
# ──────── 工厂方法 ──────── self._tool_to_server: dict[str, str] = {}
self._prompt_to_server: dict[str, str] = {}
self._resource_to_server: dict[str, str] = {}
self._resource_template_to_server: dict[str, str] = {}
@classmethod @classmethod
async def from_app_config( async def from_app_config(
cls, cls,
mcp_config: "MCPConfig", mcp_config: "MCPConfig",
host_callbacks: Optional[MCPHostCallbacks] = None,
) -> Optional["MCPManager"]: ) -> Optional["MCPManager"]:
""" """从官方配置创建并初始化 `MCPManager`。
从官方配置创建并初始化 MCPManager。
Args: Args:
mcp_config: 主程序中的 MCP 配置对象。 mcp_config: 主程序中的 MCP 配置对象。
host_callbacks: 宿主侧能力回调集合。
Returns: Returns:
初始化完成的 MCPManager无可用配置或全部连接失败时返回 None。 Optional[MCPManager]: 初始化完成的管理器;无可用配置或全部连接失败时返回 ``None``
""" """
configs = build_mcp_server_runtime_configs(mcp_config) configs = build_mcp_server_runtime_configs(mcp_config)
if not configs: if not configs:
return None return None
@@ -73,7 +100,10 @@ class MCPManager:
console.print("[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK请运行: pip install mcp[/warning]") console.print("[warning]⚠️ 发现 MCP 配置但未安装 mcp SDK请运行: pip install mcp[/warning]")
return None return None
manager = cls() manager = cls(
client_config=build_mcp_client_runtime_config(mcp_config),
host_callbacks=host_callbacks,
)
await manager._connect_all(configs) await manager._connect_all(configs)
if not manager._connections: if not manager._connections:
@@ -82,48 +112,141 @@ class MCPManager:
return manager return manager
# ──────── 连接管理 ────────
async def _connect_all(self, configs: list[MCPServerRuntimeConfig]) -> None: async def _connect_all(self, configs: list[MCPServerRuntimeConfig]) -> None:
"""连接所有配置的 MCP 服务器,跳过失败的连接。""" """连接全部已配置的 MCP 服务器
for cfg in configs:
conn = MCPConnection(cfg) Args:
success = await conn.connect() configs: 服务器运行时配置列表。
Returns:
None
"""
for config in configs:
connection = MCPConnection(config, self._client_config, self._host_callbacks)
success = await connection.connect()
if not success: if not success:
continue continue
self._connections[cfg.name] = conn self._connections[config.name] = connection
registered_tool_count = self._register_tools(config.name, connection)
# 注册工具,检查冲突 registered_prompt_count = self._register_prompts(config.name, connection)
registered = 0 registered_resource_count = self._register_resources(config.name, connection)
for tool in conn.tools: registered_template_count = self._register_resource_templates(config.name, connection)
tool_name = tool.name
if tool_name in BUILTIN_TOOL_NAMES:
console.print(
f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {cfg.name}) 与内置工具冲突,已跳过[/warning]"
)
continue
if tool_name in self._tool_to_server:
existing_server = self._tool_to_server[tool_name]
console.print(
f"[warning]⚠️ MCP 工具 '{tool_name}' "
f"(来自 {cfg.name}) 与 {existing_server} 冲突,已跳过[/warning]"
)
continue
self._tool_to_server[tool_name] = cfg.name
registered += 1
console.print( console.print(
f"[success]✓ MCP 服务器 '{cfg.name}' 已连接[/success] [muted]({registered} 个工具已注册)[/muted]" "[success]✓ MCP 服务器 "
f"'{config.name}' 已连接[/success] "
f"[muted](工具 {registered_tool_count} / Prompt {registered_prompt_count} / "
f"资源 {registered_resource_count} / 模板 {registered_template_count})[/muted]"
) )
# ──────── 工具发现 ──────── def _register_tools(self, server_name: str, connection: MCPConnection) -> int:
"""注册单个服务器暴露的 MCP 工具。
Args:
server_name: 服务器名称。
connection: 对应连接对象。
Returns:
int: 成功注册的工具数量。
"""
registered_count = 0
for tool in connection.tools:
tool_name = str(tool.name)
if tool_name in BUILTIN_TOOL_NAMES:
console.print(
f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {server_name}) 与内置工具冲突,已跳过[/warning]"
)
continue
if tool_name in self._tool_to_server:
existing_server = self._tool_to_server[tool_name]
console.print(
f"[warning]⚠️ MCP 工具 '{tool_name}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]"
)
continue
self._tool_to_server[tool_name] = server_name
registered_count += 1
return registered_count
def _register_prompts(self, server_name: str, connection: MCPConnection) -> int:
"""注册单个服务器暴露的 MCP Prompt。
Args:
server_name: 服务器名称。
connection: 对应连接对象。
Returns:
int: 成功注册的 Prompt 数量。
"""
registered_count = 0
for prompt in connection.prompts:
prompt_name = str(prompt.name)
if prompt_name in self._prompt_to_server:
existing_server = self._prompt_to_server[prompt_name]
console.print(
f"[warning]⚠️ MCP Prompt '{prompt_name}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]"
)
continue
self._prompt_to_server[prompt_name] = server_name
registered_count += 1
return registered_count
def _register_resources(self, server_name: str, connection: MCPConnection) -> int:
"""注册单个服务器暴露的 MCP Resource。
Args:
server_name: 服务器名称。
connection: 对应连接对象。
Returns:
int: 成功注册的 Resource 数量。
"""
registered_count = 0
for resource in connection.resources:
resource_uri = str(resource.uri)
if resource_uri in self._resource_to_server:
existing_server = self._resource_to_server[resource_uri]
console.print(
f"[warning]⚠️ MCP Resource '{resource_uri}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]"
)
continue
self._resource_to_server[resource_uri] = server_name
registered_count += 1
return registered_count
def _register_resource_templates(self, server_name: str, connection: MCPConnection) -> int:
"""注册单个服务器暴露的 MCP Resource Template。
Args:
server_name: 服务器名称。
connection: 对应连接对象。
Returns:
int: 成功注册的模板数量。
"""
registered_count = 0
for resource_template in connection.resource_templates:
uri_template = str(resource_template.uriTemplate)
if uri_template in self._resource_template_to_server:
existing_server = self._resource_template_to_server[uri_template]
console.print(
"[warning]⚠️ MCP Resource Template "
f"'{uri_template}' (来自 {server_name}) 与 {existing_server} 冲突,已跳过[/warning]"
)
continue
self._resource_template_to_server[uri_template] = server_name
registered_count += 1
return registered_count
def _build_tool_parameters_schema(self, tool: Any) -> dict[str, Any] | None: def _build_tool_parameters_schema(self, tool: Any) -> dict[str, Any] | None:
"""构造单个 MCP 工具的对象级参数 Schema。 """构造单个 MCP 工具的参数 Schema。
Args: Args:
tool: MCP SDK 返回的原始工具对象。 tool: MCP SDK 返回的原始工具对象。
@@ -140,6 +263,21 @@ class MCPManager:
parameters_schema.pop("$schema", None) parameters_schema.pop("$schema", None)
return parameters_schema return parameters_schema
def _build_tool_output_schema(self, tool: Any) -> dict[str, Any] | None:
"""构造单个 MCP 工具的输出 Schema。
Args:
tool: MCP SDK 返回的原始工具对象。
Returns:
dict[str, Any] | None: 输出 Schema。
"""
output_schema = dict(tool.outputSchema) if hasattr(tool, "outputSchema") and tool.outputSchema else None
if isinstance(output_schema, dict):
output_schema.pop("$schema", None)
return output_schema
def get_tool_specs(self) -> list[ToolSpec]: def get_tool_specs(self) -> list[ToolSpec]:
"""获取全部已注册 MCP 工具的统一声明。 """获取全部已注册 MCP 工具的统一声明。
@@ -148,31 +286,79 @@ class MCPManager:
""" """
tool_specs: list[ToolSpec] = [] tool_specs: list[ToolSpec] = []
for server_name, conn in self._connections.items(): for server_name, connection in self._connections.items():
for tool in conn.tools: for tool in connection.tools:
if tool.name not in self._tool_to_server: if self._tool_to_server.get(tool.name) != server_name:
continue
if self._tool_to_server[tool.name] != server_name:
continue continue
parameters_schema = self._build_tool_parameters_schema(tool) parameters_schema = self._build_tool_parameters_schema(tool)
output_schema = self._build_tool_output_schema(tool)
brief_description = str(tool.description or f"来自 {server_name} 的 MCP 工具").strip() brief_description = str(tool.description or f"来自 {server_name} 的 MCP 工具").strip()
tool_specs.append( tool_specs.append(
ToolSpec( ToolSpec(
name=str(tool.name), name=str(tool.name),
title=str(getattr(tool, "title", "") or ""),
brief_description=brief_description, brief_description=brief_description,
detailed_description=build_tool_detailed_description( detailed_description=build_tool_detailed_description(
parameters_schema, parameters_schema,
fallback_description=f"工具来源MCP 服务 {server_name}", fallback_description=f"工具来源MCP 服务 {server_name}",
), ),
parameters_schema=parameters_schema, parameters_schema=parameters_schema,
output_schema=output_schema,
provider_name="mcp", provider_name="mcp",
provider_type="mcp", provider_type="mcp",
metadata={"server_name": server_name}, icons=[build_tool_icon(item) for item in getattr(tool, "icons", []) or []],
annotation=build_tool_annotation(getattr(tool, "annotations", None)),
metadata={"server_name": server_name} | getattr(tool, "meta", {}),
) )
) )
return tool_specs return tool_specs
def get_prompt_specs(self) -> list[MCPPromptSpec]:
"""获取全部已注册 MCP Prompt 声明。
Returns:
list[MCPPromptSpec]: Prompt 声明列表。
"""
prompt_specs: list[MCPPromptSpec] = []
for server_name, connection in self._connections.items():
for prompt in connection.prompts:
if self._prompt_to_server.get(prompt.name) != server_name:
continue
prompt_specs.append(build_prompt_spec(prompt, server_name))
return prompt_specs
def get_resource_specs(self) -> list[MCPResourceSpec]:
"""获取全部已注册 MCP Resource 声明。
Returns:
list[MCPResourceSpec]: Resource 声明列表。
"""
resource_specs: list[MCPResourceSpec] = []
for server_name, connection in self._connections.items():
for resource in connection.resources:
if self._resource_to_server.get(resource.uri) != server_name:
continue
resource_specs.append(build_resource_spec(resource, server_name))
return resource_specs
def get_resource_template_specs(self) -> list[MCPResourceTemplateSpec]:
"""获取全部已注册 MCP Resource Template 声明。
Returns:
list[MCPResourceTemplateSpec]: Resource Template 声明列表。
"""
resource_template_specs: list[MCPResourceTemplateSpec] = []
for server_name, connection in self._connections.items():
for resource_template in connection.resource_templates:
if self._resource_template_to_server.get(resource_template.uriTemplate) != server_name:
continue
resource_template_specs.append(build_resource_template_spec(resource_template, server_name))
return resource_template_specs
def get_openai_tools(self) -> list[dict[str, Any]]: def get_openai_tools(self) -> list[dict[str, Any]]:
"""获取兼容旧模型层的 MCP 工具定义。 """获取兼容旧模型层的 MCP 工具定义。
@@ -192,12 +378,42 @@ class MCPManager:
for tool_spec in self.get_tool_specs() for tool_spec in self.get_tool_specs()
] ]
# ──────── 工具调用 ────────
def is_mcp_tool(self, tool_name: str) -> bool: def is_mcp_tool(self, tool_name: str) -> bool:
"""判断工具名是否为已注册 MCP 工具。""" """判断给定名称是否为已注册 MCP 工具。
Args:
tool_name: 工具名称。
Returns:
bool: 是否存在。
"""
return tool_name in self._tool_to_server return tool_name in self._tool_to_server
def is_mcp_prompt(self, prompt_name: str) -> bool:
"""判断给定名称是否为已注册 MCP Prompt。
Args:
prompt_name: Prompt 名称。
Returns:
bool: 是否存在。
"""
return prompt_name in self._prompt_to_server
def is_mcp_resource(self, uri: str) -> bool:
"""判断给定 URI 是否为已注册 MCP Resource。
Args:
uri: 资源 URI。
Returns:
bool: 是否存在。
"""
return uri in self._resource_to_server
async def call_tool_invocation(self, invocation: ToolInvocation) -> ToolExecutionResult: async def call_tool_invocation(self, invocation: ToolInvocation) -> ToolExecutionResult:
"""执行统一的 MCP 工具调用。 """执行统一的 MCP 工具调用。
@@ -217,8 +433,8 @@ class MCPManager:
error_message=f"MCP 工具 '{tool_name}' 未找到", error_message=f"MCP 工具 '{tool_name}' 未找到",
) )
conn = self._connections[server_name] connection = self._connections[server_name]
return await conn.call_tool(tool_name, invocation.arguments) return await connection.call_tool(tool_name, invocation.arguments)
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> str: async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> str:
"""兼容旧接口,返回 MCP 工具的文本结果。 """兼容旧接口,返回 MCP 工具的文本结果。
@@ -239,36 +455,137 @@ class MCPManager:
) )
return result.get_history_content() return result.get_history_content()
# ──────── 信息展示 ──────── async def get_prompt(
self,
prompt_name: str,
arguments: Optional[dict[str, str]] = None,
) -> MCPPromptResult:
"""读取指定 Prompt 的内容。
Args:
prompt_name: Prompt 名称。
arguments: Prompt 参数字典。
Returns:
MCPPromptResult: Prompt 获取结果。
"""
server_name = self._prompt_to_server.get(prompt_name)
if not server_name or server_name not in self._connections:
raise KeyError(f"MCP Prompt '{prompt_name}' 未找到")
connection = self._connections[server_name]
return await connection.get_prompt(prompt_name, arguments=arguments)
async def read_resource(self, uri: str) -> MCPResourceReadResult:
"""读取指定 Resource 的内容。
Args:
uri: 资源 URI。
Returns:
MCPResourceReadResult: 资源读取结果。
"""
server_name = self._resource_to_server.get(uri)
if not server_name or server_name not in self._connections:
raise KeyError(f"MCP Resource '{uri}' 未找到")
connection = self._connections[server_name]
return await connection.read_resource(uri)
def get_tool_summary(self) -> str: def get_tool_summary(self) -> str:
"""获取所有已注册 MCP 工具的摘要信息。""" """获取所有已注册 MCP 工具的摘要信息。
Returns:
str: 工具摘要文本。
"""
parts: list[str] = [] parts: list[str] = []
for server_name, conn in self._connections.items(): for server_name, connection in self._connections.items():
tool_names = [ tool_names = [
t.name str(tool.name)
for t in conn.tools for tool in connection.tools
if t.name in self._tool_to_server and self._tool_to_server[t.name] == server_name if self._tool_to_server.get(tool.name) == server_name
] ]
if tool_names: if tool_names:
parts.append(f"{server_name}: {', '.join(tool_names)}") parts.append(f"{server_name}: {', '.join(tool_names)}")
return "\n".join(parts) return "\n".join(parts)
def get_feature_summary(self) -> str:
"""获取所有服务器能力的总体摘要。
Returns:
str: 多行摘要文本。
"""
parts: list[str] = []
for server_name, connection in self._connections.items():
tool_count = sum(1 for tool in connection.tools if self._tool_to_server.get(tool.name) == server_name)
prompt_count = sum(
1 for prompt in connection.prompts if self._prompt_to_server.get(prompt.name) == server_name
)
resource_count = sum(
1 for resource in connection.resources if self._resource_to_server.get(resource.uri) == server_name
)
template_count = sum(
1
for resource_template in connection.resource_templates
if self._resource_template_to_server.get(resource_template.uriTemplate) == server_name
)
parts.append(
f"{server_name}: 工具 {tool_count} / Prompt {prompt_count} / "
f"资源 {resource_count} / 模板 {template_count}"
)
return "\n".join(parts)
@property @property
def server_count(self) -> int: def server_count(self) -> int:
"""已连接 MCP 服务器数量。""" """返回已连接 MCP 服务器数量。
Returns:
int: 服务器数量。
"""
return len(self._connections) return len(self._connections)
@property @property
def tool_count(self) -> int: def tool_count(self) -> int:
"""已注册 MCP 工具总数。""" """返回已注册 MCP 工具总数。
Returns:
int: 工具数量。
"""
return len(self._tool_to_server) return len(self._tool_to_server)
# ──────── 生命周期 ──────── @property
def prompt_count(self) -> int:
"""返回已注册 MCP Prompt 总数。
Returns:
int: Prompt 数量。
"""
return len(self._prompt_to_server)
@property
def resource_count(self) -> int:
"""返回已注册 MCP Resource 总数。
Returns:
int: Resource 数量。
"""
return len(self._resource_to_server)
async def close(self) -> None: async def close(self) -> None:
"""关闭所有 MCP 服务器连接。""" """关闭所有 MCP 服务器连接。"""
for conn in self._connections.values():
await conn.close() for connection in self._connections.values():
await connection.close()
self._connections.clear() self._connections.clear()
self._tool_to_server.clear() self._tool_to_server.clear()
self._prompt_to_server.clear()
self._resource_to_server.clear()
self._resource_template_to_server.clear()

418
src/mcp_module/models.py Normal file
View File

@@ -0,0 +1,418 @@
"""MCP 结构化模型与转换工具。
负责在 MCP SDK 原始对象与主程序内部数据模型之间进行转换,
避免连接层和管理器层直接操作大量弱类型字段。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Optional
from src.core.tooling import ToolAnnotation, ToolContentItem, ToolIcon
def _dump_model_metadata(raw_value: Any) -> dict[str, Any]:
"""提取任意 MCP 模型对象中的元数据字典。
Args:
raw_value: MCP SDK 返回的原始对象。
Returns:
dict[str, Any]: 归一化后的元数据字典。
"""
metadata = getattr(raw_value, "meta", None)
if isinstance(metadata, dict):
return dict(metadata)
return {}
def build_tool_icon(raw_icon: Any) -> ToolIcon:
"""将 MCP 图标对象转换为统一图标模型。
Args:
raw_icon: MCP SDK 返回的图标对象。
Returns:
ToolIcon: 统一图标模型。
"""
sizes_value = getattr(raw_icon, "sizes", None)
sizes = [str(item) for item in sizes_value] if isinstance(sizes_value, list) else []
return ToolIcon(
src=str(getattr(raw_icon, "src", "") or ""),
mime_type=str(getattr(raw_icon, "mimeType", "") or ""),
sizes=sizes,
)
def build_tool_annotation(raw_annotation: Any) -> Optional[ToolAnnotation]:
"""将 MCP 注解对象转换为统一注解模型。
Args:
raw_annotation: MCP SDK 返回的注解对象。
Returns:
Optional[ToolAnnotation]: 统一注解模型;无有效内容时返回 ``None``。
"""
if raw_annotation is None:
return None
audience_value = getattr(raw_annotation, "audience", None)
audience = [str(item) for item in audience_value] if isinstance(audience_value, list) else []
priority_value = getattr(raw_annotation, "priority", None)
priority = float(priority_value) if isinstance(priority_value, int | float) else None
metadata = _dump_model_metadata(raw_annotation)
if not audience and priority is None and not metadata:
return None
return ToolAnnotation(
audience=audience,
priority=priority,
metadata=metadata,
)
def build_tool_content_item(raw_content: Any) -> ToolContentItem:
"""将 MCP 内容块转换为统一工具内容项。
Args:
raw_content: MCP SDK 返回的内容块对象。
Returns:
ToolContentItem: 统一工具内容项。
"""
content_type = str(getattr(raw_content, "type", "") or "").strip().lower()
annotation = build_tool_annotation(getattr(raw_content, "annotations", None))
metadata = _dump_model_metadata(raw_content)
if content_type == "text" or hasattr(raw_content, "text"):
return ToolContentItem(
content_type="text",
text=str(getattr(raw_content, "text", "") or ""),
annotation=annotation,
metadata=metadata,
)
if content_type == "image":
return ToolContentItem(
content_type="image",
data=str(getattr(raw_content, "data", "") or ""),
mime_type=str(getattr(raw_content, "mimeType", "") or ""),
annotation=annotation,
metadata=metadata,
)
if content_type == "audio":
return ToolContentItem(
content_type="audio",
data=str(getattr(raw_content, "data", "") or ""),
mime_type=str(getattr(raw_content, "mimeType", "") or ""),
annotation=annotation,
metadata=metadata,
)
if content_type == "resource_link":
return ToolContentItem(
content_type="resource_link",
uri=str(getattr(raw_content, "uri", "") or ""),
name=str(getattr(raw_content, "name", "") or ""),
description=str(getattr(raw_content, "description", "") or ""),
mime_type=str(getattr(raw_content, "mimeType", "") or ""),
annotation=annotation,
metadata=metadata,
)
if content_type == "resource" or hasattr(raw_content, "resource"):
resource = getattr(raw_content, "resource", None)
resource_metadata = metadata | _dump_model_metadata(resource)
return ToolContentItem(
content_type="resource",
text=str(getattr(resource, "text", "") or ""),
data=str(getattr(resource, "blob", "") or ""),
mime_type=str(getattr(resource, "mimeType", "") or ""),
uri=str(getattr(resource, "uri", "") or ""),
name=str(getattr(resource, "name", "") or ""),
annotation=annotation,
metadata=resource_metadata,
)
if hasattr(raw_content, "data"):
return ToolContentItem(
content_type="binary",
data=str(getattr(raw_content, "data", "") or ""),
mime_type=str(getattr(raw_content, "mimeType", "") or ""),
annotation=annotation,
metadata=metadata,
)
return ToolContentItem(
content_type="unknown",
text=str(raw_content),
annotation=annotation,
metadata=metadata,
)
def build_tool_content_items(raw_contents: list[Any] | None) -> list[ToolContentItem]:
"""批量转换 MCP 内容块列表。
Args:
raw_contents: MCP SDK 返回的内容块列表。
Returns:
list[ToolContentItem]: 转换后的统一内容项列表。
"""
if not raw_contents:
return []
return [build_tool_content_item(item) for item in raw_contents]
@dataclass(slots=True)
class MCPPromptArgumentSpec:
"""MCP Prompt 参数声明。"""
name: str
description: str = ""
required: bool = False
@dataclass(slots=True)
class MCPPromptSpec:
"""MCP Prompt 声明。"""
name: str
server_name: str
title: str = ""
description: str = ""
arguments: list[MCPPromptArgumentSpec] = field(default_factory=list)
icons: list[ToolIcon] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class MCPPromptMessage:
"""MCP Prompt 消息。"""
role: str
content: ToolContentItem
@dataclass(slots=True)
class MCPPromptResult:
"""MCP Prompt 获取结果。"""
prompt_name: str
server_name: str
description: str = ""
messages: list[MCPPromptMessage] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class MCPResourceSpec:
"""MCP Resource 声明。"""
uri: str
server_name: str
name: str
title: str = ""
description: str = ""
mime_type: str = ""
size: int | None = None
icons: list[ToolIcon] = field(default_factory=list)
annotation: ToolAnnotation | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class MCPResourceTemplateSpec:
"""MCP Resource Template 声明。"""
uri_template: str
server_name: str
name: str
title: str = ""
description: str = ""
mime_type: str = ""
icons: list[ToolIcon] = field(default_factory=list)
annotation: ToolAnnotation | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class MCPResourceReadResult:
"""MCP Resource 读取结果。"""
uri: str
server_name: str
contents: list[ToolContentItem] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
def build_prompt_argument_spec(raw_argument: Any) -> MCPPromptArgumentSpec:
"""将 MCP Prompt 参数对象转换为统一结构。
Args:
raw_argument: MCP SDK 返回的 Prompt 参数对象。
Returns:
MCPPromptArgumentSpec: 统一 Prompt 参数结构。
"""
return MCPPromptArgumentSpec(
name=str(getattr(raw_argument, "name", "") or ""),
description=str(getattr(raw_argument, "description", "") or ""),
required=bool(getattr(raw_argument, "required", False)),
)
def build_prompt_spec(raw_prompt: Any, server_name: str) -> MCPPromptSpec:
"""将 MCP Prompt 定义转换为统一结构。
Args:
raw_prompt: MCP SDK 返回的 Prompt 对象。
server_name: Prompt 所属的服务器名称。
Returns:
MCPPromptSpec: 统一 Prompt 定义。
"""
raw_arguments = getattr(raw_prompt, "arguments", None)
raw_icons = getattr(raw_prompt, "icons", None)
return MCPPromptSpec(
name=str(getattr(raw_prompt, "name", "") or ""),
server_name=server_name,
title=str(getattr(raw_prompt, "title", "") or ""),
description=str(getattr(raw_prompt, "description", "") or ""),
arguments=[build_prompt_argument_spec(item) for item in raw_arguments] if isinstance(raw_arguments, list) else [],
icons=[build_tool_icon(item) for item in raw_icons] if isinstance(raw_icons, list) else [],
metadata=_dump_model_metadata(raw_prompt),
)
def build_prompt_result(raw_result: Any, prompt_name: str, server_name: str) -> MCPPromptResult:
"""将 MCP Prompt 获取结果转换为统一结构。
Args:
raw_result: MCP SDK 返回的 Prompt 结果对象。
prompt_name: Prompt 名称。
server_name: Prompt 所属服务器名称。
Returns:
MCPPromptResult: 统一 Prompt 获取结果。
"""
messages: list[MCPPromptMessage] = []
raw_messages = getattr(raw_result, "messages", None)
if isinstance(raw_messages, list):
for raw_message in raw_messages:
messages.append(
MCPPromptMessage(
role=str(getattr(raw_message, "role", "") or ""),
content=build_tool_content_item(getattr(raw_message, "content", None)),
)
)
return MCPPromptResult(
prompt_name=prompt_name,
server_name=server_name,
description=str(getattr(raw_result, "description", "") or ""),
messages=messages,
metadata=_dump_model_metadata(raw_result),
)
def build_resource_spec(raw_resource: Any, server_name: str) -> MCPResourceSpec:
"""将 MCP Resource 定义转换为统一结构。
Args:
raw_resource: MCP SDK 返回的 Resource 对象。
server_name: Resource 所属服务器名称。
Returns:
MCPResourceSpec: 统一 Resource 定义。
"""
raw_icons = getattr(raw_resource, "icons", None)
size_value = getattr(raw_resource, "size", None)
size = int(size_value) if isinstance(size_value, int | float) else None
return MCPResourceSpec(
uri=str(getattr(raw_resource, "uri", "") or ""),
server_name=server_name,
name=str(getattr(raw_resource, "name", "") or ""),
title=str(getattr(raw_resource, "title", "") or ""),
description=str(getattr(raw_resource, "description", "") or ""),
mime_type=str(getattr(raw_resource, "mimeType", "") or ""),
size=size,
icons=[build_tool_icon(item) for item in raw_icons] if isinstance(raw_icons, list) else [],
annotation=build_tool_annotation(getattr(raw_resource, "annotations", None)),
metadata=_dump_model_metadata(raw_resource),
)
def build_resource_template_spec(raw_template: Any, server_name: str) -> MCPResourceTemplateSpec:
"""将 MCP Resource Template 定义转换为统一结构。
Args:
raw_template: MCP SDK 返回的 ResourceTemplate 对象。
server_name: 模板所属服务器名称。
Returns:
MCPResourceTemplateSpec: 统一模板定义。
"""
raw_icons = getattr(raw_template, "icons", None)
return MCPResourceTemplateSpec(
uri_template=str(getattr(raw_template, "uriTemplate", "") or ""),
server_name=server_name,
name=str(getattr(raw_template, "name", "") or ""),
title=str(getattr(raw_template, "title", "") or ""),
description=str(getattr(raw_template, "description", "") or ""),
mime_type=str(getattr(raw_template, "mimeType", "") or ""),
icons=[build_tool_icon(item) for item in raw_icons] if isinstance(raw_icons, list) else [],
annotation=build_tool_annotation(getattr(raw_template, "annotations", None)),
metadata=_dump_model_metadata(raw_template),
)
def build_resource_read_result(raw_result: Any, uri: str, server_name: str) -> MCPResourceReadResult:
"""将 MCP Resource 读取结果转换为统一结构。
Args:
raw_result: MCP SDK 返回的读取结果对象。
uri: 被读取的资源 URI。
server_name: 资源所属服务器名称。
Returns:
MCPResourceReadResult: 统一资源读取结果。
"""
contents: list[ToolContentItem] = []
raw_contents = getattr(raw_result, "contents", None)
if isinstance(raw_contents, list):
for raw_content in raw_contents:
metadata = _dump_model_metadata(raw_content)
contents.append(
ToolContentItem(
content_type="resource",
text=str(getattr(raw_content, "text", "") or ""),
data=str(getattr(raw_content, "blob", "") or ""),
mime_type=str(getattr(raw_content, "mimeType", "") or ""),
uri=str(getattr(raw_content, "uri", "") or uri),
annotation=None,
metadata=metadata,
)
)
return MCPResourceReadResult(
uri=uri,
server_name=server_name,
contents=contents,
metadata=_dump_model_metadata(raw_result),
)