Files
mai-bot/src/mcp_module/connection.py
2026-05-06 13:02:17 +08:00

609 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
MaiSaka - 单个 MCP 服务器连接管理
封装单个 MCP 服务器的连接生命周期:连接 → 发现能力 → 调用工具/读取资源 → 断开。
"""
from __future__ import annotations
from contextlib import AsyncExitStack
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
import httpx
from src.cli.console import console
from src.core.tooling import ToolExecutionResult
from .config import MCPClientRuntimeConfig, MCPServerRuntimeConfig
from .hooks import MCPHostCallbacks
from .models import (
MCPPromptResult,
MCPResourceReadResult,
build_prompt_result,
build_resource_read_result,
build_tool_content_items,
)
if TYPE_CHECKING:
    from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT

# ──────────────────── Optional MCP SDK imports ────────────────────
#
# ``mcp`` is an optional dependency. When it cannot be imported, every SDK name
# below is bound to ``None`` and every availability flag to ``False``; per the
# original note, MCPManager.from_app_config() detects this and returns None so
# the main program keeps running without MCP support.
try:
    from mcp import ClientSession
    from mcp import types as mcp_types

    # Prefer the stdio client module; fall back to the top-level package export.
    try:
        from mcp.client.stdio import StdioServerParameters
    except ImportError:
        from mcp import StdioServerParameters  # type: ignore[attr-defined]
    from mcp.client.stdio import stdio_client
    from mcp.client.streamable_http import streamable_http_client

    # The SSE client may be missing even when the core SDK is installed.
    try:
        from mcp.client.sse import sse_client
    except ImportError:
        sse_client = None  # type: ignore[assignment]
        SSE_AVAILABLE = False
    else:
        SSE_AVAILABLE = True
except ImportError:
    ClientSession = None  # type: ignore[assignment,misc]
    StdioServerParameters = None  # type: ignore[assignment,misc]
    mcp_types = None  # type: ignore[assignment]
    stdio_client = None  # type: ignore[assignment]
    streamable_http_client = None  # type: ignore[assignment]
    sse_client = None  # type: ignore[assignment]
    MCP_AVAILABLE = False
    STREAMABLE_HTTP_AVAILABLE = False
    SSE_AVAILABLE = False
else:
    MCP_AVAILABLE = True
    STREAMABLE_HTTP_AVAILABLE = True
class MCPConnection:
    """Manage the connection lifecycle of a single MCP server.

    ``connect()`` opens the configured transport (stdio, Streamable HTTP or
    SSE), starts a ``ClientSession``, runs the ``initialize`` handshake and
    loads whatever tools / prompts / resources the server declares.
    ``call_tool`` / ``get_prompt`` / ``read_resource`` use that session, and
    ``close()`` unwinds every context registered on the internal exit stack.
    """

    def __init__(
        self,
        config: MCPServerRuntimeConfig,
        client_config: MCPClientRuntimeConfig,
        host_callbacks: Optional[MCPHostCallbacks] = None,
    ) -> None:
        """Initialize an unconnected MCP connection.

        Args:
            config: Runtime configuration of this server.
            client_config: Runtime configuration of the MCP client host capabilities.
            host_callbacks: Host-side capability callbacks; an empty
                ``MCPHostCallbacks()`` is used when omitted.
        """
        self.config = config
        self.client_config = client_config
        self.host_callbacks = host_callbacks or MCPHostCallbacks()
        # Populated by connect(); reset by close().
        self.session: Optional[Any] = None
        self.server_capabilities: Optional[Any] = None
        self.tools: list[Any] = []
        self.prompts: list[Any] = []
        self.resources: list[Any] = []
        self.resource_templates: list[Any] = []
        self.protocol_version: str = ""
        self._http_client: Optional[httpx.AsyncClient] = None
        # Only set by the Streamable HTTP transport, which exposes the negotiated session id.
        self._session_id_getter: Optional[Callable[[], str | None]] = None
        # Owns every async context entered while connecting (transport, HTTP client, session).
        self._exit_stack = AsyncExitStack()

    @property
    def session_id(self) -> str:
        """Return the MCP session id negotiated for this connection.

        Returns:
            str: The current session id, or an empty string when there is no
            session id getter or it reports no id.
        """
        if self._session_id_getter is None:
            return ""
        return self._session_id_getter() or ""

    async def connect(self) -> bool:
        """Connect to the MCP server and discover its capabilities.

        If this instance already holds a session, that connection is closed
        first so a second ``connect()`` does not stack a new transport on top
        of the old one.

        Returns:
            bool: ``True`` on success, ``False`` on failure. On failure the
            error is printed to the console and ``close()`` releases anything
            acquired so far.
        """
        if not MCP_AVAILABLE:
            console.print("[warning]⚠️ 未安装 mcp SDK请运行: pip install mcp[/warning]")
            return False
        if self.session is not None:
            await self.close()
        try:
            read_stream, write_stream = await self._connect_transport()
            session = await self._create_client_session(read_stream, write_stream)
            self.session = session
            initialize_result = await session.initialize()
            self.server_capabilities = getattr(initialize_result, "capabilities", None)
            self.protocol_version = str(getattr(initialize_result, "protocolVersion", "") or "")
            await self._load_server_features()
            return True
        except Exception as exc:
            console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {exc}[/warning]")
            await self.close()
            return False

    async def _connect_transport(self) -> tuple[Any, Any]:
        """Open the underlying transport selected by ``config.transport_type``.

        Returns:
            tuple[Any, Any]: The (read stream, write stream) pair.

        Raises:
            ValueError: If the transport type is not one of
                ``"stdio"``, ``"streamable_http"`` or ``"sse"``.
        """
        transport_type = self.config.transport_type
        if transport_type == "stdio":
            return await self._connect_stdio()
        if transport_type == "streamable_http":
            return await self._connect_streamable_http()
        if transport_type == "sse":
            return await self._connect_sse()
        # Report the value that was actually compared above.
        raise ValueError(f"MCP 服务器 '{self.config.name}' 使用了未知传输类型: {transport_type}")

    async def _connect_stdio(self) -> tuple[Any, Any]:
        """Open a stdio transport by launching ``config.command``.

        Returns:
            tuple[Any, Any]: The (read stream, write stream) pair.

        Raises:
            RuntimeError: If the SDK stdio client is unavailable.
            ValueError: If no command is configured.
        """
        if StdioServerParameters is None or stdio_client is None:
            raise RuntimeError("当前环境未安装可用的 MCP stdio 客户端")
        if not self.config.command:
            raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 stdio command 配置")
        params = StdioServerParameters(
            command=self.config.command,
            args=self.config.args,
            env=self.config.env,
        )
        return await self._exit_stack.enter_async_context(stdio_client(params))

    async def _connect_streamable_http(self) -> tuple[Any, Any]:
        """Open a Streamable HTTP transport to ``config.url``.

        Also records the transport's session-id getter, which backs ``session_id``.

        Returns:
            tuple[Any, Any]: The (read stream, write stream) pair.

        Raises:
            ImportError: If the SDK Streamable HTTP client is unavailable.
            ValueError: If no URL is configured.
        """
        if not STREAMABLE_HTTP_AVAILABLE or streamable_http_client is None:
            raise ImportError("当前环境未安装可用的 MCP Streamable HTTP 客户端")
        if not self.config.url:
            raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 Streamable HTTP url 配置")
        self._http_client = await self._exit_stack.enter_async_context(self._build_http_client())
        read_stream, write_stream, session_id_getter = await self._exit_stack.enter_async_context(
            streamable_http_client(
                url=self.config.url,
                http_client=self._http_client,
                terminate_on_close=True,
            )
        )
        self._session_id_getter = session_id_getter
        return read_stream, write_stream

    async def _connect_sse(self) -> tuple[Any, Any]:
        """Open an SSE transport to ``config.url``.

        Returns:
            tuple[Any, Any]: The (read stream, write stream) pair.

        Raises:
            ImportError: If the SDK SSE client is unavailable.
            ValueError: If no URL is configured.
        """
        if not SSE_AVAILABLE or sse_client is None:
            raise ImportError("当前环境未安装可用的 MCP SSE 客户端")
        if not self.config.url:
            raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 SSE url 配置")
        read_stream, write_stream = await self._exit_stack.enter_async_context(
            sse_client(
                url=self.config.url,
                headers=self.config.build_http_headers(),
                timeout=self.config.http_timeout_seconds,
                sse_read_timeout=self.config.read_timeout_seconds,
                httpx_client_factory=self._build_http_client,
            )
        )
        return read_stream, write_stream

    def _build_http_client(
        self,
        headers: dict[str, str] | None = None,
        timeout: httpx.Timeout | None = None,
        auth: httpx.Auth | None = None,
    ) -> httpx.AsyncClient:
        """Build a pre-configured ``httpx.AsyncClient``.

        Its keyword parameters match how it is used as the SSE client's
        ``httpx_client_factory``.

        Args:
            headers: Extra headers merged on top of the configured headers.
            timeout: Timeout override; defaults to ``config.http_timeout_seconds``.
            auth: Optional httpx authentication, forwarded to the client.

        Returns:
            httpx.AsyncClient: The configured async HTTP client.
        """
        # Copy so merging extra headers never mutates whatever the config returned.
        merged_headers = dict(self.config.build_http_headers() or {})
        if headers:
            merged_headers.update(headers)
        return httpx.AsyncClient(
            headers=merged_headers,
            timeout=timeout if timeout is not None else httpx.Timeout(self.config.http_timeout_seconds),
            auth=auth,
        )

    async def _create_client_session(self, read_stream: Any, write_stream: Any) -> Any:
        """Create an MCP ``ClientSession`` and enter it on the exit stack.

        Sampling / elicitation callbacks are only passed (and therefore only
        advertised) when both enabled in the client config and provided by
        the host; a warning is printed when enabled but missing.

        Args:
            read_stream: Transport read stream.
            write_stream: Transport write stream.

        Returns:
            Any: The entered ``ClientSession``; ``initialize()`` has NOT been
            called yet (``connect()`` does that).

        Raises:
            RuntimeError: If the SDK ``ClientSession`` is unavailable.
        """
        if ClientSession is None:
            raise RuntimeError("当前环境未安装可用的 MCP ClientSession")
        list_roots_callback = self._build_list_roots_callback()
        sampling_callback = (
            self.host_callbacks.sampling_callback
            if self.client_config.enable_sampling and self.host_callbacks.sampling_callback is not None
            else None
        )
        elicitation_callback = (
            self.host_callbacks.elicitation_callback
            if self.client_config.enable_elicitation and self.host_callbacks.elicitation_callback is not None
            else None
        )
        logging_callback = cast(Optional["LoggingFnT"], self.host_callbacks.logging_callback)
        message_handler = cast(Optional["MessageHandlerFnT"], self.host_callbacks.message_handler)
        if self.client_config.enable_sampling and sampling_callback is None:
            console.print(
                f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 sampling 配置,但宿主未提供 sampling 回调,当前不会声明该能力[/warning]"
            )
        if self.client_config.enable_elicitation and elicitation_callback is None:
            console.print(
                f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 elicitation 配置,但宿主未提供 elicitation 回调,当前不会声明该能力[/warning]"
            )
        session = await self._exit_stack.enter_async_context(
            ClientSession(
                read_stream,
                write_stream,
                read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds),
                sampling_callback=cast(Optional["SamplingFnT"], sampling_callback),
                elicitation_callback=cast(Optional["ElicitationFnT"], elicitation_callback),
                list_roots_callback=cast(Optional["ListRootsFnT"], list_roots_callback),
                logging_callback=logging_callback,
                message_handler=message_handler,
                client_info=self._build_client_info(),
                sampling_capabilities=self._build_sampling_capabilities(sampling_callback),
            )
        )
        return session

    def _build_client_info(self) -> Any:
        """Build the MCP client implementation info from the client config.

        Returns:
            Any: An SDK ``Implementation`` object.

        Raises:
            RuntimeError: If the SDK types module is unavailable.
        """
        if mcp_types is None:
            raise RuntimeError("当前环境未安装可用的 MCP types 模块")
        return mcp_types.Implementation(
            name=self.client_config.client_name,
            version=self.client_config.client_version,
        )

    def _build_sampling_capabilities(self, sampling_callback: Any) -> Any | None:
        """Build the Sampling capability declaration.

        Args:
            sampling_callback: The host sampling callback actually in use.

        Returns:
            Any | None: A ``SamplingCapability``, or ``None`` when the SDK is
            unavailable or no sampling callback is in use.
        """
        if mcp_types is None:
            return None
        if sampling_callback is None:
            return None
        context_capability = (
            mcp_types.SamplingContextCapability()
            if self.client_config.sampling_include_context_support
            else None
        )
        tools_capability = (
            mcp_types.SamplingToolsCapability()
            if self.client_config.sampling_tool_support
            else None
        )
        return mcp_types.SamplingCapability(
            context=context_capability,
            tools=tools_capability,
        )

    def _build_list_roots_callback(self) -> Any | None:
        """Build the Roots listing callback.

        Returns:
            Any | None: An async callback returning the configured roots, or
            ``None`` when the SDK is unavailable, roots are disabled, or no
            roots are configured.
        """
        if mcp_types is None:
            return None
        if not self.client_config.enable_roots or not self.client_config.roots:
            return None

        async def _list_roots(context: Any) -> Any:
            """Return the roots declared in the client config.

            Args:
                context: MCP request context (unused).

            Returns:
                Any: An SDK ``ListRootsResult``.
            """
            del context
            types_module = mcp_types
            if types_module is None:
                raise RuntimeError("当前环境未安装可用的 MCP types 模块")
            roots = [
                types_module.Root(uri=cast(Any, root.uri), name=root.name or None)
                for root in self.client_config.roots
            ]
            return types_module.ListRootsResult(roots=roots)

        return _list_roots

    async def _load_server_features(self) -> None:
        """Load tools, prompts and resources for the capabilities the server declared."""
        self.tools = await self._list_tools() if self.supports_tools() else []
        self.prompts = await self._list_prompts() if self.supports_prompts() else []
        has_resources = self.supports_resources()
        self.resources = await self._list_resources() if has_resources else []
        self.resource_templates = await self._list_resource_templates() if has_resources else []

    def _has_capability(self, name: str) -> bool:
        """Return whether the server's capabilities object has a non-``None`` ``name`` field."""
        capabilities = self.server_capabilities
        return capabilities is not None and getattr(capabilities, name, None) is not None

    def supports_tools(self) -> bool:
        """Return whether the server declared Tools support."""
        return self._has_capability("tools")

    def supports_prompts(self) -> bool:
        """Return whether the server declared Prompts support."""
        return self._has_capability("prompts")

    def supports_resources(self) -> bool:
        """Return whether the server declared Resources support."""
        return self._has_capability("resources")

    async def _collect_paginated(self, fetch_page: Callable[..., Any], items_field: str) -> list[Any]:
        """Collect every item of a cursor-paginated MCP list call.

        Args:
            fetch_page: Async session method accepting a ``cursor`` keyword.
            items_field: Name of the result attribute holding the page items.

        Returns:
            list[Any]: Items from all pages, in order.
        """
        items: list[Any] = []
        cursor: Optional[str] = None
        seen_cursors: set[str] = set()
        while True:
            page = await fetch_page(cursor=cursor)
            items.extend(getattr(page, items_field, None) or [])
            cursor = getattr(page, "nextCursor", None)
            if not cursor:
                break
            if cursor in seen_cursors:
                # A server that repeats a cursor would otherwise make this loop forever.
                console.print(
                    f"[warning]⚠️ MCP 服务器 '{self.config.name}' 返回了重复的分页游标,已停止加载 {items_field}[/warning]"
                )
                break
            seen_cursors.add(cursor)
        return items

    async def _list_tools(self) -> list[Any]:
        """Load every tool the server exposes (all pages); ``[]`` when not connected."""
        if self.session is None:
            return []
        return await self._collect_paginated(self.session.list_tools, "tools")

    async def _list_prompts(self) -> list[Any]:
        """Load every prompt the server exposes (all pages); ``[]`` when not connected."""
        if self.session is None:
            return []
        return await self._collect_paginated(self.session.list_prompts, "prompts")

    async def _list_resources(self) -> list[Any]:
        """Load every resource the server exposes (all pages); ``[]`` when not connected."""
        if self.session is None:
            return []
        return await self._collect_paginated(self.session.list_resources, "resources")

    async def _list_resource_templates(self) -> list[Any]:
        """Load every resource template the server exposes (all pages); ``[]`` when not connected."""
        if self.session is None:
            return []
        return await self._collect_paginated(self.session.list_resource_templates, "resourceTemplates")

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> ToolExecutionResult:
        """Call an MCP tool and wrap the outcome in a ``ToolExecutionResult``.

        Transport/SDK exceptions and server-reported errors (``isError``) both
        produce ``success=False``; this method does not raise for them.

        Args:
            tool_name: Tool name.
            arguments: Tool arguments.

        Returns:
            ToolExecutionResult: The unified execution result.
        """
        if self.session is None:
            return ToolExecutionResult(
                tool_name=tool_name,
                success=False,
                error_message=f"MCP 服务器 '{self.config.name}' 未连接",
                metadata={"server_name": self.config.name},
            )
        try:
            result = await self.session.call_tool(
                tool_name,
                arguments=arguments,
                read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds),
            )
        except Exception as exc:
            return ToolExecutionResult(
                tool_name=tool_name,
                success=False,
                error_message=f"MCP 工具 '{tool_name}' 执行失败: {exc}",
                metadata={"server_name": self.config.name},
            )
        content_items = build_tool_content_items(list(getattr(result, "content", []) or []))
        text_parts = [item.text.strip() for item in content_items if item.content_type == "text" and item.text.strip()]
        structured_content = getattr(result, "structuredContent", None)
        is_error = bool(getattr(result, "isError", False))
        history_content = "\n".join(text_parts).strip()
        if is_error:
            # A server may flag an error without any text content; never report
            # a failure with an empty error message.
            content = ""
            error_message = history_content or f"MCP 工具 '{tool_name}' 返回错误,但未提供错误信息"
        else:
            content = history_content
            error_message = ""
        return ToolExecutionResult(
            tool_name=tool_name,
            success=not is_error,
            content=content,
            error_message=error_message,
            structured_content=structured_content,
            content_items=content_items,
            metadata={
                "server_name": self.config.name,
                "protocol_version": self.protocol_version,
                "session_id": self.session_id,
            },
        )

    async def get_prompt(
        self,
        prompt_name: str,
        arguments: Optional[dict[str, str]] = None,
    ) -> MCPPromptResult:
        """Fetch an MCP prompt.

        Args:
            prompt_name: Prompt name.
            arguments: Prompt arguments.

        Returns:
            MCPPromptResult: The unified prompt result.

        Raises:
            RuntimeError: If not connected. SDK errors propagate unchanged.
        """
        if self.session is None:
            raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接")
        result = await self.session.get_prompt(prompt_name, arguments=arguments)
        return build_prompt_result(result, prompt_name=prompt_name, server_name=self.config.name)

    async def read_resource(self, uri: str) -> MCPResourceReadResult:
        """Read an MCP resource.

        Args:
            uri: Resource URI.

        Returns:
            MCPResourceReadResult: The unified resource read result.

        Raises:
            RuntimeError: If not connected. SDK errors propagate unchanged.
        """
        if self.session is None:
            raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接")
        result = await self.session.read_resource(uri)
        return build_resource_read_result(result, uri=uri, server_name=self.config.name)

    async def close(self) -> None:
        """Close the connection and release every resource acquired by ``connect``.

        Cleanup is best-effort. An ``Exception`` raised while unwinding is
        printed as a warning instead of propagating. The instance is always
        reset to a disconnected state with a fresh exit stack, so ``connect()``
        can be called again.
        """
        exit_stack, self._exit_stack = self._exit_stack, AsyncExitStack()
        try:
            await exit_stack.aclose()
        except Exception as exc:
            console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 关闭连接时出错: {exc}[/warning]")
        finally:
            self.session = None
            self.server_capabilities = None
            self.tools = []
            self.prompts = []
            self.resources = []
            self.resource_templates = []
            self.protocol_version = ""
            self._http_client = None
            self._session_id_getter = None