609 lines
21 KiB
Python
609 lines
21 KiB
Python
"""
|
||
MaiSaka - 单个 MCP 服务器连接管理
|
||
封装单个 MCP 服务器的连接生命周期:连接 → 发现能力 → 调用工具/读取资源 → 断开。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from contextlib import AsyncExitStack
|
||
from datetime import timedelta
|
||
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
||
|
||
import httpx
|
||
|
||
from src.cli.console import console
|
||
from src.core.tooling import ToolExecutionResult
|
||
|
||
from .config import MCPClientRuntimeConfig, MCPServerRuntimeConfig
|
||
from .hooks import MCPHostCallbacks
|
||
from .models import (
|
||
MCPPromptResult,
|
||
MCPResourceReadResult,
|
||
build_prompt_result,
|
||
build_resource_read_result,
|
||
build_tool_content_items,
|
||
)
|
||
|
||
if TYPE_CHECKING:
|
||
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
|
||
|
||
# ──────────────────── Optional MCP SDK import ────────────────────
#
# ``mcp`` is an optional dependency. When it is not installed,
# MCP_AVAILABLE is False; per the original note, MCPManager.from_app_config()
# detects this and returns None so the main program keeps running.
#
# The core pieces (ClientSession, types, stdio) decide MCP_AVAILABLE. The
# Streamable HTTP and SSE transports are probed separately, so an SDK build
# that lacks one of them only disables that transport instead of the whole
# MCP integration.

try:
    from mcp import ClientSession, types as mcp_types

    try:
        from mcp.client.stdio import StdioServerParameters
    except ImportError:
        # Fallback location of StdioServerParameters.
        from mcp import StdioServerParameters  # type: ignore[attr-defined]

    from mcp.client.stdio import stdio_client

    MCP_AVAILABLE = True
except ImportError:
    MCP_AVAILABLE = False
    ClientSession = None  # type: ignore[assignment,misc]
    StdioServerParameters = None  # type: ignore[assignment,misc]
    mcp_types = None  # type: ignore[assignment]
    stdio_client = None  # type: ignore[assignment]

# Optional transports: default to "unavailable" and upgrade on successful import.
STREAMABLE_HTTP_AVAILABLE = False
SSE_AVAILABLE = False
streamable_http_client = None  # type: ignore[assignment]
sse_client = None  # type: ignore[assignment]

if MCP_AVAILABLE:
    try:
        from mcp.client.streamable_http import streamable_http_client

        STREAMABLE_HTTP_AVAILABLE = True
    except ImportError:
        streamable_http_client = None  # type: ignore[assignment]

    try:
        from mcp.client.sse import sse_client

        SSE_AVAILABLE = True
    except ImportError:
        sse_client = None  # type: ignore[assignment]
|
||
|
||
|
||
class MCPConnection:
    """Manage the connection lifecycle of a single MCP server.

    The lifecycle is: connect -> discover capabilities -> call tools / read
    prompts and resources -> close. Every transport and session context
    manager is registered on one ``AsyncExitStack`` so ``close()`` releases
    them together.
    """

    def __init__(
        self,
        config: MCPServerRuntimeConfig,
        client_config: MCPClientRuntimeConfig,
        host_callbacks: Optional[MCPHostCallbacks] = None,
    ) -> None:
        """Initialise a single MCP connection (no I/O happens here).

        Args:
            config: Runtime configuration of this server.
            client_config: Runtime configuration of the MCP client host capabilities.
            host_callbacks: Host-side capability callbacks; an empty
                ``MCPHostCallbacks()`` is used when omitted.
        """

        self.config = config
        self.client_config = client_config
        self.host_callbacks = host_callbacks or MCPHostCallbacks()

        # Populated by connect(), cleared again by close().
        self.session: Optional[Any] = None
        self.server_capabilities: Optional[Any] = None
        self.tools: list[Any] = []
        self.prompts: list[Any] = []
        self.resources: list[Any] = []
        self.resource_templates: list[Any] = []
        self.protocol_version: str = ""

        self._http_client: Optional[httpx.AsyncClient] = None
        self._session_id_getter: Optional[Callable[[], str | None]] = None
        self._exit_stack = AsyncExitStack()

    @property
    def session_id(self) -> str:
        """Return the MCP session id of the current connection.

        Returns:
            str: The session id, or an empty string when no getter is
            registered or the getter yields nothing.
        """

        getter = self._session_id_getter
        if getter is None:
            return ""
        return getter() or ""

    async def connect(self) -> bool:
        """Connect to the MCP server and discover its features.

        Any exception raised while connecting is reported on the console and
        the partially opened connection is closed.

        Returns:
            bool: ``True`` on success, ``False`` on failure.
        """

        if not MCP_AVAILABLE:
            console.print("[warning]⚠️ 未安装 mcp SDK,请运行: pip install mcp[/warning]")
            return False

        try:
            await self._exit_stack.__aenter__()
            read_stream, write_stream = await self._connect_transport()
            session = await self._create_client_session(read_stream, write_stream)
            self.session = session

            initialize_result = await session.initialize()
            self.server_capabilities = getattr(initialize_result, "capabilities", None)
            self.protocol_version = str(getattr(initialize_result, "protocolVersion", "") or "")

            await self._load_server_features()
            return True
        except Exception as exc:
            console.print(f"[warning]⚠️ MCP 服务器 '{self.config.name}' 连接失败: {exc}[/warning]")
            await self.close()
            return False

    async def _connect_transport(self) -> tuple[Any, Any]:
        """Open the underlying transport selected by ``config.transport_type``.

        Returns:
            tuple[Any, Any]: The read stream and write stream.

        Raises:
            ValueError: If the transport type is not one of ``stdio``,
                ``streamable_http`` or ``sse``.
        """

        transport_type = self.config.transport_type
        if transport_type == "stdio":
            return await self._connect_stdio()
        if transport_type == "streamable_http":
            return await self._connect_streamable_http()
        if transport_type == "sse":
            return await self._connect_sse()

        # Report the very value that was compared above.
        raise ValueError(f"MCP 服务器 '{self.config.name}' 使用了未知传输类型: {transport_type}")

    async def _connect_stdio(self) -> tuple[Any, Any]:
        """Open a stdio transport.

        Returns:
            tuple[Any, Any]: The read stream and write stream.

        Raises:
            RuntimeError: If the MCP stdio client could not be imported.
            ValueError: If the configuration has no ``command``.
        """

        if StdioServerParameters is None or stdio_client is None:
            raise RuntimeError("当前环境未安装可用的 MCP stdio 客户端")
        if not self.config.command:
            raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 stdio command 配置")

        params = StdioServerParameters(
            command=self.config.command,
            args=self.config.args,
            env=self.config.env,
        )
        return await self._exit_stack.enter_async_context(stdio_client(params))

    async def _connect_streamable_http(self) -> tuple[Any, Any]:
        """Open a Streamable HTTP transport.

        The session-id getter returned by the SDK client is stored so that
        ``session_id`` can expose it.

        Returns:
            tuple[Any, Any]: The read stream and write stream.

        Raises:
            ImportError: If the Streamable HTTP client could not be imported.
            ValueError: If the configuration has no ``url``.
        """

        if not STREAMABLE_HTTP_AVAILABLE or streamable_http_client is None:
            raise ImportError("当前环境未安装可用的 MCP Streamable HTTP 客户端")
        if not self.config.url:
            raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 Streamable HTTP url 配置")

        # The HTTP client is entered first, so it is released after the transport.
        self._http_client = await self._exit_stack.enter_async_context(self._build_http_client())
        read_stream, write_stream, session_id_getter = await self._exit_stack.enter_async_context(
            streamable_http_client(
                url=self.config.url,
                http_client=self._http_client,
                terminate_on_close=True,
            )
        )
        self._session_id_getter = session_id_getter
        return read_stream, write_stream

    async def _connect_sse(self) -> tuple[Any, Any]:
        """Open an SSE transport.

        Returns:
            tuple[Any, Any]: The read stream and write stream.

        Raises:
            ImportError: If the SSE client could not be imported.
            ValueError: If the configuration has no ``url``.
        """

        if not SSE_AVAILABLE or sse_client is None:
            raise ImportError("当前环境未安装可用的 MCP SSE 客户端")
        if not self.config.url:
            raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 SSE url 配置")

        read_stream, write_stream = await self._exit_stack.enter_async_context(
            sse_client(
                url=self.config.url,
                headers=self.config.build_http_headers(),
                timeout=self.config.http_timeout_seconds,
                sse_read_timeout=self.config.read_timeout_seconds,
                httpx_client_factory=self._build_http_client,
            )
        )
        return read_stream, write_stream

    def _build_http_client(
        self,
        headers: dict[str, str] | None = None,
        timeout: httpx.Timeout | None = None,
        auth: httpx.Auth | None = None,
    ) -> httpx.AsyncClient:
        """Build the httpx client used by the HTTP-based transports.

        Also passed to the SSE client as its ``httpx_client_factory``.

        Args:
            headers: Extra headers merged over the configured request headers.
            timeout: httpx timeout override; defaults to the configured
                ``http_timeout_seconds``.
            auth: Optional httpx authentication, forwarded to the client.

        Returns:
            httpx.AsyncClient: The pre-configured asynchronous HTTP client.
        """

        # Merge into a copy so the mapping returned by the config is never mutated.
        merged_headers = dict(self.config.build_http_headers())
        if headers:
            merged_headers.update(headers)
        return httpx.AsyncClient(
            headers=merged_headers,
            timeout=timeout or httpx.Timeout(self.config.http_timeout_seconds),
            auth=auth,
        )

    async def _create_client_session(self, read_stream: Any, write_stream: Any) -> Any:
        """Create an MCP ``ClientSession`` and enter it on the exit stack.

        Sampling and elicitation callbacks are only handed to the session when
        the capability is enabled in ``client_config`` and the host supplied a
        callback. A console warning is printed when a capability is enabled
        without a callback.

        Args:
            read_stream: Underlying read stream.
            write_stream: Underlying write stream.

        Returns:
            Any: The entered ``ClientSession``. ``initialize()`` has not been
            called yet; ``connect()`` does that.

        Raises:
            RuntimeError: If ``ClientSession`` could not be imported.
        """

        if ClientSession is None:
            raise RuntimeError("当前环境未安装可用的 MCP ClientSession")

        list_roots_callback = self._build_list_roots_callback()
        sampling_callback = (
            self.host_callbacks.sampling_callback
            if self.client_config.enable_sampling and self.host_callbacks.sampling_callback is not None
            else None
        )
        elicitation_callback = (
            self.host_callbacks.elicitation_callback
            if self.client_config.enable_elicitation and self.host_callbacks.elicitation_callback is not None
            else None
        )
        logging_callback = cast(Optional["LoggingFnT"], self.host_callbacks.logging_callback)
        message_handler = cast(Optional["MessageHandlerFnT"], self.host_callbacks.message_handler)

        if self.client_config.enable_sampling and sampling_callback is None:
            console.print(
                f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 sampling 配置,但宿主未提供 sampling 回调,当前不会声明该能力[/warning]"
            )
        if self.client_config.enable_elicitation and elicitation_callback is None:
            console.print(
                f"[warning]⚠️ MCP 服务器 '{self.config.name}' 已启用 elicitation 配置,但宿主未提供 elicitation 回调,当前不会声明该能力[/warning]"
            )

        session = await self._exit_stack.enter_async_context(
            ClientSession(
                read_stream,
                write_stream,
                read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds),
                sampling_callback=cast(Optional["SamplingFnT"], sampling_callback),
                elicitation_callback=cast(Optional["ElicitationFnT"], elicitation_callback),
                list_roots_callback=cast(Optional["ListRootsFnT"], list_roots_callback),
                logging_callback=logging_callback,
                message_handler=message_handler,
                client_info=self._build_client_info(),
                sampling_capabilities=self._build_sampling_capabilities(sampling_callback),
            )
        )
        return session

    def _build_client_info(self) -> Any:
        """Build the MCP client implementation info.

        Returns:
            Any: The SDK ``Implementation`` object carrying the configured
            client name and version.

        Raises:
            RuntimeError: If the MCP ``types`` module could not be imported.
        """

        if mcp_types is None:
            raise RuntimeError("当前环境未安装可用的 MCP types 模块")

        return mcp_types.Implementation(
            name=self.client_config.client_name,
            version=self.client_config.client_version,
        )

    def _build_sampling_capabilities(self, sampling_callback: Any) -> Any | None:
        """Build the sampling capability declaration.

        Args:
            sampling_callback: The host-side sampling callback in effect.

        Returns:
            Any | None: The sampling capability object, or ``None`` when the
            SDK types are missing or no sampling callback is in effect.
        """

        if mcp_types is None or sampling_callback is None:
            return None

        context_capability = None
        if self.client_config.sampling_include_context_support:
            context_capability = mcp_types.SamplingContextCapability()

        tools_capability = None
        if self.client_config.sampling_tool_support:
            tools_capability = mcp_types.SamplingToolsCapability()

        return mcp_types.SamplingCapability(
            context=context_capability,
            tools=tools_capability,
        )

    def _build_list_roots_callback(self) -> Any | None:
        """Build the callback that answers roots-list requests.

        Returns:
            Any | None: An async callback returning ``ListRootsResult``, or
            ``None`` when the SDK types are missing, roots are disabled, or no
            roots are configured.
        """

        if mcp_types is None:
            return None
        if not self.client_config.enable_roots or not self.client_config.roots:
            return None

        async def _list_roots(context: Any) -> Any:
            """Return the roots declared by this client.

            Args:
                context: MCP request context (unused).

            Returns:
                Any: The MCP ``ListRootsResult`` object.
            """

            del context
            # Re-read the module global at call time rather than closing over it.
            types_module = mcp_types
            if types_module is None:
                raise RuntimeError("当前环境未安装可用的 MCP types 模块")
            roots = [
                types_module.Root(uri=cast(Any, root.uri), name=root.name or None)
                for root in self.client_config.roots
            ]
            return types_module.ListRootsResult(roots=roots)

        return _list_roots

    async def _load_server_features(self) -> None:
        """Load tools, prompts and resources according to the server capabilities."""

        self.tools = await self._list_tools() if self.supports_tools() else []
        self.prompts = await self._list_prompts() if self.supports_prompts() else []

        # Resources and resource templates are both gated by the same capability.
        has_resources = self.supports_resources()
        self.resources = await self._list_resources() if has_resources else []
        self.resource_templates = await self._list_resource_templates() if has_resources else []

    def _server_declares(self, capability: str) -> bool:
        """Return whether the server capabilities carry a non-``None`` *capability*."""

        capabilities = self.server_capabilities
        return capabilities is not None and getattr(capabilities, capability, None) is not None

    def supports_tools(self) -> bool:
        """Tell whether the server declared Tools support.

        Returns:
            bool: Whether Tools are supported.
        """

        return self._server_declares("tools")

    def supports_prompts(self) -> bool:
        """Tell whether the server declared Prompts support.

        Returns:
            bool: Whether Prompts are supported.
        """

        return self._server_declares("prompts")

    def supports_resources(self) -> bool:
        """Tell whether the server declared Resources support.

        Returns:
            bool: Whether Resources are supported.
        """

        return self._server_declares("resources")

    async def _list_paginated(self, list_method_name: str, items_attr: str) -> list[Any]:
        """Collect every page of one cursor-paginated listing call.

        Args:
            list_method_name: Name of the session method to call, e.g. ``"list_tools"``.
            items_attr: Attribute of each page result that holds the items.

        Returns:
            list[Any]: All items across pages; empty when there is no session.
        """

        if self.session is None:
            return []

        fetch_page = getattr(self.session, list_method_name)
        collected: list[Any] = []
        seen_cursors: set[str] = set()
        cursor: Optional[str] = None
        while True:
            page = await fetch_page(cursor=cursor)
            collected.extend(getattr(page, items_attr, None) or [])
            cursor = getattr(page, "nextCursor", None)
            # Stop on the last page, and also when the server hands back a
            # cursor it already issued; otherwise this loop would never end.
            if not cursor or cursor in seen_cursors:
                break
            seen_cursors.add(cursor)
        return collected

    async def _list_tools(self) -> list[Any]:
        """Load every tool exposed by the server, following pagination.

        Returns:
            list[Any]: The raw SDK tool objects.
        """

        return await self._list_paginated("list_tools", "tools")

    async def _list_prompts(self) -> list[Any]:
        """Load every prompt exposed by the server, following pagination.

        Returns:
            list[Any]: The raw SDK prompt objects.
        """

        return await self._list_paginated("list_prompts", "prompts")

    async def _list_resources(self) -> list[Any]:
        """Load every resource exposed by the server, following pagination.

        Returns:
            list[Any]: The raw SDK resource objects.
        """

        return await self._list_paginated("list_resources", "resources")

    async def _list_resource_templates(self) -> list[Any]:
        """Load every resource template exposed by the server, following pagination.

        Returns:
            list[Any]: The raw SDK resource template objects.
        """

        return await self._list_paginated("list_resource_templates", "resourceTemplates")

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> ToolExecutionResult:
        """Call an MCP tool and wrap the outcome in a unified result.

        This method does not raise for a missing session or a failing call;
        both are reported through ``success=False`` and ``error_message``.

        Args:
            tool_name: Name of the tool.
            arguments: Tool arguments.

        Returns:
            ToolExecutionResult: The unified execution result.
        """

        if self.session is None:
            return ToolExecutionResult(
                tool_name=tool_name,
                success=False,
                error_message=f"MCP 服务器 '{self.config.name}' 未连接",
                metadata={"server_name": self.config.name},
            )

        try:
            result = await self.session.call_tool(
                tool_name,
                arguments=arguments,
                read_timeout_seconds=timedelta(seconds=self.config.read_timeout_seconds),
            )
        except Exception as exc:
            return ToolExecutionResult(
                tool_name=tool_name,
                success=False,
                error_message=f"MCP 工具 '{tool_name}' 执行失败: {exc}",
                metadata={"server_name": self.config.name},
            )

        content_items = build_tool_content_items(list(getattr(result, "content", []) or []))
        text_parts = [item.text.strip() for item in content_items if item.content_type == "text" and item.text.strip()]
        structured_content = getattr(result, "structuredContent", None)
        is_error = bool(getattr(result, "isError", False))
        history_content = "\n".join(text_parts).strip()

        if is_error:
            content = ""
            # An error result without any text must still carry a message,
            # otherwise callers see success=False with nothing to display.
            error_message = history_content or f"MCP 工具 '{tool_name}' 返回了错误结果"
        else:
            content = history_content
            error_message = ""

        return ToolExecutionResult(
            tool_name=tool_name,
            success=not is_error,
            content=content,
            error_message=error_message,
            structured_content=structured_content,
            content_items=content_items,
            metadata={
                "server_name": self.config.name,
                "protocol_version": self.protocol_version,
                "session_id": self.session_id,
            },
        )

    async def get_prompt(
        self,
        prompt_name: str,
        arguments: Optional[dict[str, str]] = None,
    ) -> MCPPromptResult:
        """Fetch the content of an MCP prompt.

        Args:
            prompt_name: Name of the prompt.
            arguments: Prompt arguments.

        Returns:
            MCPPromptResult: The unified prompt result.

        Raises:
            RuntimeError: If the server is not connected.
        """

        if self.session is None:
            raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接")

        result = await self.session.get_prompt(prompt_name, arguments=arguments)
        return build_prompt_result(result, prompt_name=prompt_name, server_name=self.config.name)

    async def read_resource(self, uri: str) -> MCPResourceReadResult:
        """Read the content of an MCP resource.

        Args:
            uri: Resource URI.

        Returns:
            MCPResourceReadResult: The unified resource read result.

        Raises:
            RuntimeError: If the server is not connected.
        """

        if self.session is None:
            raise RuntimeError(f"MCP 服务器 '{self.config.name}' 未连接")

        result = await self.session.read_resource(uri)
        return build_resource_read_result(result, uri=uri, server_name=self.config.name)

    async def close(self) -> None:
        """Close the connection and release its resources.

        Ordinary exceptions raised while unwinding the exit stack are
        suppressed on purpose (best-effort shutdown). The cached connection
        state is cleared in ``finally`` so it is reset even when unwinding is
        interrupted by a ``BaseException`` such as task cancellation.
        """

        try:
            await self._exit_stack.aclose()
        except Exception:
            # Deliberate best-effort: a failing teardown must not break the caller.
            pass
        finally:
            self.session = None
            self.server_capabilities = None
            self.tools = []
            self.prompts = []
            self.resources = []
            self.resource_templates = []
            self.protocol_version = ""
            self._http_client = None
            self._session_id_getter = None
|