feat(mcp_module): add hooks, host LLM bridge, and models for MCP integration

- Introduced MCPHostCallbacks for optional host capabilities like sampling and logging.
- Implemented MCPHostLLMBridge to handle MCP Sampling requests and bridge to LLM service.
- Created models for structured data conversion between MCP SDK and internal data models, including tool content items, prompts, and resources.
- Enhanced error handling and logging for better traceability during sampling operations.
This commit is contained in:
DrSmoothl
2026-03-30 23:51:05 +08:00
parent abb1d071b1
commit 42dbd5462a
11 changed files with 2332 additions and 170 deletions

View File

@@ -0,0 +1,597 @@
"""MCP 宿主侧大模型桥接服务。"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional
import json
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMResponseResult
from src.common.logger import get_logger
from src.core.tooling import build_tool_detailed_description
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
from src.services.llm_service import LLMServiceClient
from .hooks import MCPHostCallbacks
from .models import build_tool_content_items
if TYPE_CHECKING:
from src.llm_models.model_client.base_client import BaseClient
try:
from mcp import types as mcp_types
MCP_TYPES_AVAILABLE = True
except ImportError:
mcp_types = None # type: ignore[assignment]
MCP_TYPES_AVAILABLE = False
logger = get_logger("mcp_host_llm_bridge")
class MCPHostLLMBridge:
"""将 MCP Sampling 请求桥接到主程序大模型调用链。"""
def __init__(self, sampling_task_name: str = "planner") -> None:
"""初始化 MCP 宿主侧大模型桥接服务。
Args:
sampling_task_name: 执行 Sampling 请求时使用的模型任务名。
"""
self._sampling_task_name = sampling_task_name.strip() or "planner"
self._sampling_client = LLMServiceClient(
task_name=self._sampling_task_name,
request_type="mcp_sampling",
)
def build_callbacks(self) -> MCPHostCallbacks:
"""构建可注入给 MCP 连接层的宿主回调集合。
Returns:
MCPHostCallbacks: 包含 Sampling 回调的宿主回调集合。
"""
return MCPHostCallbacks(
sampling_callback=self.handle_sampling_request,
)
async def handle_sampling_request(self, context: Any, params: Any) -> Any:
"""处理服务端发起的 MCP Sampling 请求。
Args:
context: MCP SDK 传入的请求上下文。
params: `sampling/createMessage` 请求参数。
Returns:
Any: MCP `CreateMessageResult`、`CreateMessageResultWithTools` 或 `ErrorData`。
"""
del context
if not MCP_TYPES_AVAILABLE or mcp_types is None:
raise RuntimeError("当前环境未安装可用的 MCP types 模块")
try:
tool_choice_mode = self._get_tool_choice_mode(params)
tool_definitions = self._build_tool_definitions(
raw_tools=getattr(params, "tools", None),
tool_choice_mode=tool_choice_mode,
)
message_factory = self._build_message_factory(
raw_messages=list(getattr(params, "messages", []) or []),
system_prompt=self._build_system_prompt(
raw_system_prompt=str(getattr(params, "systemPrompt", "") or ""),
tool_choice_mode=tool_choice_mode,
tool_definitions=tool_definitions,
),
)
generation_result = await self._sampling_client.generate_response_with_messages(
message_factory=message_factory,
options=LLMGenerationOptions(
temperature=self._coerce_float(getattr(params, "temperature", None)),
max_tokens=int(getattr(params, "maxTokens", 1024) or 1024),
tool_options=tool_definitions,
),
)
if tool_choice_mode == "required" and tool_definitions and not generation_result.tool_calls:
return mcp_types.ErrorData(
code=mcp_types.INTERNAL_ERROR,
message="Sampling 要求必须调用工具,但模型未返回任何工具调用",
)
return self._build_sampling_result(
generation_result=generation_result,
tools_enabled=bool(tool_definitions),
)
except Exception as exc:
logger.exception(f"MCP Sampling 调用失败: {exc}")
return mcp_types.ErrorData(
code=mcp_types.INTERNAL_ERROR,
message=f"MCP Sampling 调用失败: {exc}",
)
@staticmethod
def _coerce_float(raw_value: Any) -> float | None:
"""将任意原始值转换为浮点数。
Args:
raw_value: 原始输入值。
Returns:
float | None: 转换后的浮点数;无法转换时返回 ``None``。
"""
if raw_value is None:
return None
if isinstance(raw_value, int | float):
return float(raw_value)
return None
@staticmethod
def _get_tool_choice_mode(params: Any) -> str:
"""读取 Sampling 请求中的工具选择模式。
Args:
params: Sampling 请求参数对象。
Returns:
str: `auto`、`required` 或 `none`;缺省时返回 `auto`。
"""
tool_choice = getattr(params, "toolChoice", None)
mode = str(getattr(tool_choice, "mode", "") or "").strip().lower()
if mode in {"required", "none"}:
return mode
return "auto"
def _build_system_prompt(
self,
raw_system_prompt: str,
tool_choice_mode: str,
tool_definitions: list[ToolDefinitionInput] | None,
) -> str:
"""构建发送给主程序大模型的系统提示词。
Args:
raw_system_prompt: 服务端请求中的系统提示词。
tool_choice_mode: 当前工具选择模式。
tool_definitions: 参与本次 Sampling 的工具定义。
Returns:
str: 最终系统提示词。
"""
prompt_parts: list[str] = []
if raw_system_prompt.strip():
prompt_parts.append(raw_system_prompt.strip())
if tool_choice_mode == "required" and tool_definitions:
prompt_parts.append("本轮回答必须至少调用一个工具;不要直接结束回答。")
return "\n\n".join(part for part in prompt_parts if part).strip()
def _build_message_factory(
self,
raw_messages: list[Any],
system_prompt: str,
) -> Any:
"""构建 MCP Sampling 使用的消息工厂。
Args:
raw_messages: MCP Sampling 原始消息列表。
system_prompt: 规范化后的系统提示词。
Returns:
Any: 供 `LLMServiceClient` 使用的消息工厂。
"""
def _message_factory(client: "BaseClient") -> list[Message]:
"""延迟构建内部消息列表。
Args:
client: 当前被选中的底层模型客户端。
Returns:
list[Message]: 内部统一消息列表。
"""
messages: list[Message] = []
if system_prompt.strip():
messages.append(
MessageBuilder()
.set_role(RoleType.System)
.add_text_content(system_prompt.strip())
.build()
)
for raw_message in raw_messages:
messages.extend(self._convert_sampling_message(raw_message, client))
return messages
return _message_factory
def _convert_sampling_message(self, raw_message: Any, client: "BaseClient") -> list[Message]:
"""将单条 MCP Sampling 消息转换为内部消息列表。
Args:
raw_message: MCP Sampling 原始消息对象。
client: 当前底层模型客户端。
Returns:
list[Message]: 转换后的内部消息列表。
"""
role = str(getattr(raw_message, "role", "") or "").strip().lower()
content_blocks = self._get_content_blocks(getattr(raw_message, "content", None))
if role == "assistant":
assistant_message = self._build_assistant_message(content_blocks, client)
return [assistant_message] if assistant_message is not None else []
if role == "user":
return self._build_user_messages(content_blocks, client)
raise ValueError(f"不支持的 MCP Sampling 消息角色: {role}")
@staticmethod
def _get_content_blocks(raw_content: Any) -> list[Any]:
"""将 MCP Sampling 消息内容统一为列表。
Args:
raw_content: 原始内容字段。
Returns:
list[Any]: 统一后的内容块列表。
"""
if raw_content is None:
return []
if isinstance(raw_content, list):
return list(raw_content)
return [raw_content]
def _build_assistant_message(self, content_blocks: list[Any], client: "BaseClient") -> Optional[Message]:
"""构建内部 assistant 消息。
Args:
content_blocks: MCP assistant 内容块列表。
client: 当前底层模型客户端。
Returns:
Optional[Message]: 转换后的内部 assistant 消息;无有效内容时返回 ``None``。
"""
message_builder = MessageBuilder().set_role(RoleType.Assistant)
tool_calls: list[ToolCall] = []
has_visible_content = False
for content_block in content_blocks:
content_type = self._get_content_type(content_block)
if content_type == "tool_use":
tool_calls.append(
ToolCall(
call_id=str(getattr(content_block, "id", "") or ""),
func_name=str(getattr(content_block, "name", "") or ""),
args=self._normalize_tool_call_arguments(getattr(content_block, "input", None)),
)
)
continue
has_visible_content = self._append_sampling_content_to_builder(
message_builder=message_builder,
content_block=content_block,
client=client,
) or has_visible_content
if tool_calls:
message_builder.set_tool_calls(tool_calls)
if not has_visible_content and not tool_calls:
return None
return message_builder.build()
def _build_user_messages(self, content_blocks: list[Any], client: "BaseClient") -> list[Message]:
"""构建内部 user/tool 消息序列。
Args:
content_blocks: MCP user 内容块列表。
client: 当前底层模型客户端。
Returns:
list[Message]: 转换后的内部消息序列。
"""
messages: list[Message] = []
message_builder = MessageBuilder().set_role(RoleType.User)
has_user_content = False
def flush_user_message() -> None:
"""在当前存在用户可见内容时落盘一条 user 消息。"""
nonlocal message_builder, has_user_content
if not has_user_content:
return
messages.append(message_builder.build())
message_builder = MessageBuilder().set_role(RoleType.User)
has_user_content = False
for content_block in content_blocks:
content_type = self._get_content_type(content_block)
if content_type == "tool_result":
flush_user_message()
messages.append(self._build_tool_result_message(content_block))
continue
has_user_content = self._append_sampling_content_to_builder(
message_builder=message_builder,
content_block=content_block,
client=client,
) or has_user_content
flush_user_message()
return messages
@staticmethod
def _get_content_type(content_block: Any) -> str:
"""读取 MCP 内容块类型。
Args:
content_block: MCP 内容块对象。
Returns:
str: 规范化后的内容块类型。
"""
return str(getattr(content_block, "type", "text") or "text").strip().lower()
def _append_sampling_content_to_builder(
self,
message_builder: MessageBuilder,
content_block: Any,
client: "BaseClient",
) -> bool:
"""将 MCP 普通内容块追加到内部消息构建器。
Args:
message_builder: 内部消息构建器。
content_block: MCP 内容块对象。
client: 当前底层模型客户端。
Returns:
bool: 是否成功追加了可见内容。
"""
content_type = self._get_content_type(content_block)
if content_type == "text":
text_content = str(getattr(content_block, "text", "") or "")
if text_content.strip():
message_builder.add_text_content(text_content)
return True
return False
if content_type == "image":
image_data = str(getattr(content_block, "data", "") or "")
image_mime_type = str(getattr(content_block, "mimeType", "") or "")
image_format = self._normalize_image_format(image_mime_type)
if image_data and image_format:
message_builder.add_image_content(
image_format=image_format,
image_base64=image_data,
support_formats=client.get_support_image_formats(),
)
return True
message_builder.add_text_content(
f"[图片内容mime_type={image_mime_type or 'unknown'},当前客户端无法直接透传]"
)
return True
if content_type == "audio":
audio_mime_type = str(getattr(content_block, "mimeType", "") or "")
message_builder.add_text_content(f"[音频内容mime_type={audio_mime_type or 'unknown'}]")
return True
return False
@staticmethod
def _normalize_image_format(mime_type: str) -> str:
"""将图片 MIME 类型转换为内部图片格式名称。
Args:
mime_type: MCP 图片 MIME 类型。
Returns:
str: 内部支持的图片格式名;不支持时返回空字符串。
"""
normalized_mime_type = mime_type.strip().lower()
if normalized_mime_type == "image/png":
return "png"
if normalized_mime_type in {"image/jpeg", "image/jpg"}:
return "jpeg"
if normalized_mime_type == "image/webp":
return "webp"
if normalized_mime_type == "image/gif":
return "gif"
return ""
def _build_tool_result_message(self, content_block: Any) -> Message:
"""将 MCP `tool_result` 内容块转换为内部 Tool 消息。
Args:
content_block: MCP `tool_result` 内容块对象。
Returns:
Message: 转换后的内部 Tool 消息。
"""
message_builder = MessageBuilder().set_role(RoleType.Tool)
message_builder.set_tool_call_id(str(getattr(content_block, "toolUseId", "") or "tool_result"))
summary_text = self._summarize_tool_result_content(content_block)
message_builder.add_text_content(summary_text or "工具执行完成。")
return message_builder.build()
def _summarize_tool_result_content(self, content_block: Any) -> str:
"""汇总 MCP `tool_result` 内容块中的结果文本。
Args:
content_block: MCP `tool_result` 内容块对象。
Returns:
str: 适合发送给主程序模型的工具结果摘要文本。
"""
raw_contents = list(getattr(content_block, "content", []) or [])
content_items = build_tool_content_items(raw_contents)
parts = [item.build_history_text().strip() for item in content_items if item.build_history_text().strip()]
structured_content = getattr(content_block, "structuredContent", None)
if structured_content is not None:
try:
parts.append(json.dumps(structured_content, ensure_ascii=False))
except (TypeError, ValueError):
parts.append(str(structured_content))
summary_text = "\n".join(part for part in parts if part).strip()
if bool(getattr(content_block, "isError", False)) and summary_text:
return f"工具执行失败:\n{summary_text}"
if bool(getattr(content_block, "isError", False)):
return "工具执行失败。"
return summary_text
@staticmethod
def _normalize_tool_call_arguments(raw_arguments: Any) -> dict[str, Any]:
"""将原始工具调用参数规范化为字典。
Args:
raw_arguments: 原始工具参数。
Returns:
dict[str, Any]: 规范化后的参数字典。
"""
if isinstance(raw_arguments, dict):
return dict(raw_arguments)
if raw_arguments is None:
return {}
return {"value": raw_arguments}
def _build_tool_definitions(
self,
raw_tools: Any,
tool_choice_mode: str,
) -> list[ToolDefinitionInput] | None:
"""将 MCP Sampling 工具定义转换为主程序内部工具定义。
Args:
raw_tools: MCP Sampling 请求中的工具列表。
tool_choice_mode: 当前工具选择模式。
Returns:
list[ToolDefinitionInput] | None: 可传给主程序模型层的工具定义列表。
"""
if tool_choice_mode == "none":
return None
if not isinstance(raw_tools, list) or not raw_tools:
return None
tool_definitions: list[ToolDefinitionInput] = []
for raw_tool in raw_tools:
tool_name = str(getattr(raw_tool, "name", "") or "").strip()
if not tool_name:
continue
parameters_schema = (
dict(getattr(raw_tool, "inputSchema", {}) or {}) if getattr(raw_tool, "inputSchema", None) else {}
)
if "$schema" in parameters_schema:
parameters_schema.pop("$schema")
title = str(getattr(raw_tool, "title", "") or "").strip()
description = str(getattr(raw_tool, "description", "") or "").strip()
brief_description = description or title or f"工具 {tool_name}"
detailed_description = build_tool_detailed_description(
parameters_schema,
fallback_description=f"工具名称:{tool_name}",
)
tool_definitions.append(
{
"name": tool_name,
"description": "\n\n".join(
part for part in [brief_description, detailed_description] if part.strip()
).strip(),
"parameters_schema": parameters_schema or {"type": "object", "properties": {}},
}
)
return tool_definitions or None
def _build_sampling_result(
self,
generation_result: LLMResponseResult,
tools_enabled: bool,
) -> Any:
"""将主程序模型响应转换为 MCP Sampling 结果。
Args:
generation_result: 主程序统一大模型响应结果。
tools_enabled: 当前是否允许模型使用工具。
Returns:
Any: MCP `CreateMessageResult` 或 `CreateMessageResultWithTools`。
"""
if not MCP_TYPES_AVAILABLE or mcp_types is None:
raise RuntimeError("当前环境未安装可用的 MCP types 模块")
text_content = str(generation_result.response or "")
tool_calls = list(generation_result.tool_calls or [])
model_name = generation_result.model_name or self._sampling_task_name
if tools_enabled:
content_blocks: list[Any] = []
if text_content.strip():
content_blocks.append(
mcp_types.TextContent(
type="text",
text=text_content,
)
)
for tool_call in tool_calls:
content_blocks.append(
mcp_types.ToolUseContent(
type="tool_use",
name=tool_call.func_name,
id=tool_call.call_id,
input=dict(tool_call.args or {}),
)
)
if not content_blocks:
content_blocks.append(
mcp_types.TextContent(
type="text",
text="",
)
)
return mcp_types.CreateMessageResultWithTools(
role="assistant",
content=content_blocks[0] if len(content_blocks) == 1 else content_blocks,
model=model_name,
stopReason="toolUse" if tool_calls else "endTurn",
)
return mcp_types.CreateMessageResult(
role="assistant",
content=mcp_types.TextContent(
type="text",
text=text_content,
),
model=model_name,
stopReason="endTurn",
)