feat(mcp_module): add hooks, host LLM bridge, and models for MCP integration
- Introduced MCPHostCallbacks for optional host capabilities like sampling and logging.
- Implemented MCPHostLLMBridge to handle MCP Sampling requests and bridge to LLM service.
- Created models for structured data conversion between MCP SDK and internal data models, including tool content items, prompts, and resources.
- Enhanced error handling and logging for better traceability during sampling operations.
This commit is contained in:
597
src/mcp_module/host_llm_bridge.py
Normal file
597
src/mcp_module/host_llm_bridge.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""MCP 宿主侧大模型桥接服务。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import json
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMResponseResult
|
||||
from src.common.logger import get_logger
|
||||
from src.core.tooling import build_tool_detailed_description
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .hooks import MCPHostCallbacks
|
||||
from .models import build_tool_content_items
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
|
||||
try:
|
||||
from mcp import types as mcp_types
|
||||
|
||||
MCP_TYPES_AVAILABLE = True
|
||||
except ImportError:
|
||||
mcp_types = None # type: ignore[assignment]
|
||||
MCP_TYPES_AVAILABLE = False
|
||||
|
||||
logger = get_logger("mcp_host_llm_bridge")
|
||||
|
||||
|
||||
class MCPHostLLMBridge:
    """Bridge MCP Sampling requests into the host program's LLM call chain."""

    def __init__(self, sampling_task_name: str = "planner") -> None:
        """Initialize the host-side LLM bridge for MCP.

        Args:
            sampling_task_name: Model task name used when serving Sampling requests;
                falls back to ``"planner"`` when blank.
        """

        normalized_task_name = sampling_task_name.strip()
        self._sampling_task_name = normalized_task_name if normalized_task_name else "planner"
        # A dedicated client tagged "mcp_sampling" so these calls are
        # distinguishable from regular host LLM traffic.
        self._sampling_client = LLMServiceClient(
            task_name=self._sampling_task_name,
            request_type="mcp_sampling",
        )
|
||||
|
||||
def build_callbacks(self) -> MCPHostCallbacks:
|
||||
"""构建可注入给 MCP 连接层的宿主回调集合。
|
||||
|
||||
Returns:
|
||||
MCPHostCallbacks: 包含 Sampling 回调的宿主回调集合。
|
||||
"""
|
||||
|
||||
return MCPHostCallbacks(
|
||||
sampling_callback=self.handle_sampling_request,
|
||||
)
|
||||
|
||||
    async def handle_sampling_request(self, context: Any, params: Any) -> Any:
        """Handle an MCP Sampling request initiated by the server.

        Args:
            context: Request context passed in by the MCP SDK (unused here).
            params: The ``sampling/createMessage`` request parameters.

        Returns:
            Any: An MCP ``CreateMessageResult``, ``CreateMessageResultWithTools``
                or ``ErrorData``.

        Raises:
            RuntimeError: If the MCP types module is not installed.
        """

        del context  # The SDK request context is not needed for sampling.
        if not MCP_TYPES_AVAILABLE or mcp_types is None:
            raise RuntimeError("当前环境未安装可用的 MCP types 模块")

        try:
            # Resolve the tool-choice mode first: it drives both the tool list
            # and the system-prompt augmentation below.
            tool_choice_mode = self._get_tool_choice_mode(params)
            tool_definitions = self._build_tool_definitions(
                raw_tools=getattr(params, "tools", None),
                tool_choice_mode=tool_choice_mode,
            )
            message_factory = self._build_message_factory(
                raw_messages=list(getattr(params, "messages", []) or []),
                system_prompt=self._build_system_prompt(
                    raw_system_prompt=str(getattr(params, "systemPrompt", "") or ""),
                    tool_choice_mode=tool_choice_mode,
                    tool_definitions=tool_definitions,
                ),
            )

            generation_result = await self._sampling_client.generate_response_with_messages(
                message_factory=message_factory,
                options=LLMGenerationOptions(
                    temperature=self._coerce_float(getattr(params, "temperature", None)),
                    # Fall back to 1024 tokens when maxTokens is missing or falsy.
                    max_tokens=int(getattr(params, "maxTokens", 1024) or 1024),
                    tool_options=tool_definitions,
                ),
            )

            # "required" mode demands at least one tool call; report a
            # protocol-level error instead of returning a plain answer.
            if tool_choice_mode == "required" and tool_definitions and not generation_result.tool_calls:
                return mcp_types.ErrorData(
                    code=mcp_types.INTERNAL_ERROR,
                    message="Sampling 要求必须调用工具,但模型未返回任何工具调用",
                )

            return self._build_sampling_result(
                generation_result=generation_result,
                tools_enabled=bool(tool_definitions),
            )

        except Exception as exc:
            # Sampling failures are returned to the server as ErrorData rather
            # than propagated, keeping the MCP session alive.
            logger.exception(f"MCP Sampling 调用失败: {exc}")
            return mcp_types.ErrorData(
                code=mcp_types.INTERNAL_ERROR,
                message=f"MCP Sampling 调用失败: {exc}",
            )
|
||||
|
||||
@staticmethod
|
||||
def _coerce_float(raw_value: Any) -> float | None:
|
||||
"""将任意原始值转换为浮点数。
|
||||
|
||||
Args:
|
||||
raw_value: 原始输入值。
|
||||
|
||||
Returns:
|
||||
float | None: 转换后的浮点数;无法转换时返回 ``None``。
|
||||
"""
|
||||
|
||||
if raw_value is None:
|
||||
return None
|
||||
if isinstance(raw_value, int | float):
|
||||
return float(raw_value)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_choice_mode(params: Any) -> str:
|
||||
"""读取 Sampling 请求中的工具选择模式。
|
||||
|
||||
Args:
|
||||
params: Sampling 请求参数对象。
|
||||
|
||||
Returns:
|
||||
str: `auto`、`required` 或 `none`;缺省时返回 `auto`。
|
||||
"""
|
||||
|
||||
tool_choice = getattr(params, "toolChoice", None)
|
||||
mode = str(getattr(tool_choice, "mode", "") or "").strip().lower()
|
||||
if mode in {"required", "none"}:
|
||||
return mode
|
||||
return "auto"
|
||||
|
||||
def _build_system_prompt(
|
||||
self,
|
||||
raw_system_prompt: str,
|
||||
tool_choice_mode: str,
|
||||
tool_definitions: list[ToolDefinitionInput] | None,
|
||||
) -> str:
|
||||
"""构建发送给主程序大模型的系统提示词。
|
||||
|
||||
Args:
|
||||
raw_system_prompt: 服务端请求中的系统提示词。
|
||||
tool_choice_mode: 当前工具选择模式。
|
||||
tool_definitions: 参与本次 Sampling 的工具定义。
|
||||
|
||||
Returns:
|
||||
str: 最终系统提示词。
|
||||
"""
|
||||
|
||||
prompt_parts: list[str] = []
|
||||
if raw_system_prompt.strip():
|
||||
prompt_parts.append(raw_system_prompt.strip())
|
||||
if tool_choice_mode == "required" and tool_definitions:
|
||||
prompt_parts.append("本轮回答必须至少调用一个工具;不要直接结束回答。")
|
||||
return "\n\n".join(part for part in prompt_parts if part).strip()
|
||||
|
||||
def _build_message_factory(
|
||||
self,
|
||||
raw_messages: list[Any],
|
||||
system_prompt: str,
|
||||
) -> Any:
|
||||
"""构建 MCP Sampling 使用的消息工厂。
|
||||
|
||||
Args:
|
||||
raw_messages: MCP Sampling 原始消息列表。
|
||||
system_prompt: 规范化后的系统提示词。
|
||||
|
||||
Returns:
|
||||
Any: 供 `LLMServiceClient` 使用的消息工厂。
|
||||
"""
|
||||
|
||||
def _message_factory(client: "BaseClient") -> list[Message]:
|
||||
"""延迟构建内部消息列表。
|
||||
|
||||
Args:
|
||||
client: 当前被选中的底层模型客户端。
|
||||
|
||||
Returns:
|
||||
list[Message]: 内部统一消息列表。
|
||||
"""
|
||||
|
||||
messages: list[Message] = []
|
||||
if system_prompt.strip():
|
||||
messages.append(
|
||||
MessageBuilder()
|
||||
.set_role(RoleType.System)
|
||||
.add_text_content(system_prompt.strip())
|
||||
.build()
|
||||
)
|
||||
|
||||
for raw_message in raw_messages:
|
||||
messages.extend(self._convert_sampling_message(raw_message, client))
|
||||
return messages
|
||||
|
||||
return _message_factory
|
||||
|
||||
def _convert_sampling_message(self, raw_message: Any, client: "BaseClient") -> list[Message]:
|
||||
"""将单条 MCP Sampling 消息转换为内部消息列表。
|
||||
|
||||
Args:
|
||||
raw_message: MCP Sampling 原始消息对象。
|
||||
client: 当前底层模型客户端。
|
||||
|
||||
Returns:
|
||||
list[Message]: 转换后的内部消息列表。
|
||||
"""
|
||||
|
||||
role = str(getattr(raw_message, "role", "") or "").strip().lower()
|
||||
content_blocks = self._get_content_blocks(getattr(raw_message, "content", None))
|
||||
|
||||
if role == "assistant":
|
||||
assistant_message = self._build_assistant_message(content_blocks, client)
|
||||
return [assistant_message] if assistant_message is not None else []
|
||||
|
||||
if role == "user":
|
||||
return self._build_user_messages(content_blocks, client)
|
||||
|
||||
raise ValueError(f"不支持的 MCP Sampling 消息角色: {role}")
|
||||
|
||||
@staticmethod
|
||||
def _get_content_blocks(raw_content: Any) -> list[Any]:
|
||||
"""将 MCP Sampling 消息内容统一为列表。
|
||||
|
||||
Args:
|
||||
raw_content: 原始内容字段。
|
||||
|
||||
Returns:
|
||||
list[Any]: 统一后的内容块列表。
|
||||
"""
|
||||
|
||||
if raw_content is None:
|
||||
return []
|
||||
if isinstance(raw_content, list):
|
||||
return list(raw_content)
|
||||
return [raw_content]
|
||||
|
||||
def _build_assistant_message(self, content_blocks: list[Any], client: "BaseClient") -> Optional[Message]:
|
||||
"""构建内部 assistant 消息。
|
||||
|
||||
Args:
|
||||
content_blocks: MCP assistant 内容块列表。
|
||||
client: 当前底层模型客户端。
|
||||
|
||||
Returns:
|
||||
Optional[Message]: 转换后的内部 assistant 消息;无有效内容时返回 ``None``。
|
||||
"""
|
||||
|
||||
message_builder = MessageBuilder().set_role(RoleType.Assistant)
|
||||
tool_calls: list[ToolCall] = []
|
||||
has_visible_content = False
|
||||
|
||||
for content_block in content_blocks:
|
||||
content_type = self._get_content_type(content_block)
|
||||
if content_type == "tool_use":
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=str(getattr(content_block, "id", "") or ""),
|
||||
func_name=str(getattr(content_block, "name", "") or ""),
|
||||
args=self._normalize_tool_call_arguments(getattr(content_block, "input", None)),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
has_visible_content = self._append_sampling_content_to_builder(
|
||||
message_builder=message_builder,
|
||||
content_block=content_block,
|
||||
client=client,
|
||||
) or has_visible_content
|
||||
|
||||
if tool_calls:
|
||||
message_builder.set_tool_calls(tool_calls)
|
||||
|
||||
if not has_visible_content and not tool_calls:
|
||||
return None
|
||||
return message_builder.build()
|
||||
|
||||
    def _build_user_messages(self, content_blocks: list[Any], client: "BaseClient") -> list[Message]:
        """Build the internal user/tool message sequence.

        ``tool_result`` blocks split the surrounding user content into separate
        internal messages while preserving the original ordering.

        Args:
            content_blocks: MCP user content block list.
            client: Currently selected low-level model client.

        Returns:
            list[Message]: Converted internal message sequence.
        """

        messages: list[Message] = []
        message_builder = MessageBuilder().set_role(RoleType.User)
        has_user_content = False

        def flush_user_message() -> None:
            """Emit one user message if any visible user content is pending."""

            nonlocal message_builder, has_user_content
            if not has_user_content:
                return
            messages.append(message_builder.build())
            # Start a fresh builder for content after this flush point.
            message_builder = MessageBuilder().set_role(RoleType.User)
            has_user_content = False

        for content_block in content_blocks:
            content_type = self._get_content_type(content_block)
            if content_type == "tool_result":
                # Flush buffered user content first so ordering is preserved.
                flush_user_message()
                messages.append(self._build_tool_result_message(content_block))
                continue

            # `or` keeps the flag sticky across blocks.
            has_user_content = self._append_sampling_content_to_builder(
                message_builder=message_builder,
                content_block=content_block,
                client=client,
            ) or has_user_content

        # Emit any trailing user content after the loop.
        flush_user_message()
        return messages
|
||||
|
||||
@staticmethod
|
||||
def _get_content_type(content_block: Any) -> str:
|
||||
"""读取 MCP 内容块类型。
|
||||
|
||||
Args:
|
||||
content_block: MCP 内容块对象。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的内容块类型。
|
||||
"""
|
||||
|
||||
return str(getattr(content_block, "type", "text") or "text").strip().lower()
|
||||
|
||||
    def _append_sampling_content_to_builder(
        self,
        message_builder: MessageBuilder,
        content_block: Any,
        client: "BaseClient",
    ) -> bool:
        """Append a plain MCP content block to the internal message builder.

        Args:
            message_builder: Internal message builder to append to.
            content_block: MCP content block object.
            client: Currently selected low-level model client.

        Returns:
            bool: Whether visible content was actually appended.
        """

        content_type = self._get_content_type(content_block)
        if content_type == "text":
            text_content = str(getattr(content_block, "text", "") or "")
            # Whitespace-only text is dropped entirely.
            if text_content.strip():
                message_builder.add_text_content(text_content)
                return True
            return False

        if content_type == "image":
            image_data = str(getattr(content_block, "data", "") or "")
            image_mime_type = str(getattr(content_block, "mimeType", "") or "")
            image_format = self._normalize_image_format(image_mime_type)
            if image_data and image_format:
                message_builder.add_image_content(
                    image_format=image_format,
                    image_base64=image_data,
                    support_formats=client.get_support_image_formats(),
                )
                return True

            # Unsupported or empty image: degrade to a textual placeholder so
            # the model still sees that an image was present.
            message_builder.add_text_content(
                f"[图片内容:mime_type={image_mime_type or 'unknown'},当前客户端无法直接透传]"
            )
            return True

        if content_type == "audio":
            # Audio is never forwarded; only a textual placeholder is added.
            audio_mime_type = str(getattr(content_block, "mimeType", "") or "")
            message_builder.add_text_content(f"[音频内容:mime_type={audio_mime_type or 'unknown'}]")
            return True

        # Unknown content types contribute nothing.
        return False
|
||||
|
||||
@staticmethod
|
||||
def _normalize_image_format(mime_type: str) -> str:
|
||||
"""将图片 MIME 类型转换为内部图片格式名称。
|
||||
|
||||
Args:
|
||||
mime_type: MCP 图片 MIME 类型。
|
||||
|
||||
Returns:
|
||||
str: 内部支持的图片格式名;不支持时返回空字符串。
|
||||
"""
|
||||
|
||||
normalized_mime_type = mime_type.strip().lower()
|
||||
if normalized_mime_type == "image/png":
|
||||
return "png"
|
||||
if normalized_mime_type in {"image/jpeg", "image/jpg"}:
|
||||
return "jpeg"
|
||||
if normalized_mime_type == "image/webp":
|
||||
return "webp"
|
||||
if normalized_mime_type == "image/gif":
|
||||
return "gif"
|
||||
return ""
|
||||
|
||||
def _build_tool_result_message(self, content_block: Any) -> Message:
|
||||
"""将 MCP `tool_result` 内容块转换为内部 Tool 消息。
|
||||
|
||||
Args:
|
||||
content_block: MCP `tool_result` 内容块对象。
|
||||
|
||||
Returns:
|
||||
Message: 转换后的内部 Tool 消息。
|
||||
"""
|
||||
|
||||
message_builder = MessageBuilder().set_role(RoleType.Tool)
|
||||
message_builder.set_tool_call_id(str(getattr(content_block, "toolUseId", "") or "tool_result"))
|
||||
summary_text = self._summarize_tool_result_content(content_block)
|
||||
message_builder.add_text_content(summary_text or "工具执行完成。")
|
||||
return message_builder.build()
|
||||
|
||||
def _summarize_tool_result_content(self, content_block: Any) -> str:
|
||||
"""汇总 MCP `tool_result` 内容块中的结果文本。
|
||||
|
||||
Args:
|
||||
content_block: MCP `tool_result` 内容块对象。
|
||||
|
||||
Returns:
|
||||
str: 适合发送给主程序模型的工具结果摘要文本。
|
||||
"""
|
||||
|
||||
raw_contents = list(getattr(content_block, "content", []) or [])
|
||||
content_items = build_tool_content_items(raw_contents)
|
||||
parts = [item.build_history_text().strip() for item in content_items if item.build_history_text().strip()]
|
||||
|
||||
structured_content = getattr(content_block, "structuredContent", None)
|
||||
if structured_content is not None:
|
||||
try:
|
||||
parts.append(json.dumps(structured_content, ensure_ascii=False))
|
||||
except (TypeError, ValueError):
|
||||
parts.append(str(structured_content))
|
||||
|
||||
summary_text = "\n".join(part for part in parts if part).strip()
|
||||
if bool(getattr(content_block, "isError", False)) and summary_text:
|
||||
return f"工具执行失败:\n{summary_text}"
|
||||
if bool(getattr(content_block, "isError", False)):
|
||||
return "工具执行失败。"
|
||||
return summary_text
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_call_arguments(raw_arguments: Any) -> dict[str, Any]:
|
||||
"""将原始工具调用参数规范化为字典。
|
||||
|
||||
Args:
|
||||
raw_arguments: 原始工具参数。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 规范化后的参数字典。
|
||||
"""
|
||||
|
||||
if isinstance(raw_arguments, dict):
|
||||
return dict(raw_arguments)
|
||||
if raw_arguments is None:
|
||||
return {}
|
||||
return {"value": raw_arguments}
|
||||
|
||||
def _build_tool_definitions(
|
||||
self,
|
||||
raw_tools: Any,
|
||||
tool_choice_mode: str,
|
||||
) -> list[ToolDefinitionInput] | None:
|
||||
"""将 MCP Sampling 工具定义转换为主程序内部工具定义。
|
||||
|
||||
Args:
|
||||
raw_tools: MCP Sampling 请求中的工具列表。
|
||||
tool_choice_mode: 当前工具选择模式。
|
||||
|
||||
Returns:
|
||||
list[ToolDefinitionInput] | None: 可传给主程序模型层的工具定义列表。
|
||||
"""
|
||||
|
||||
if tool_choice_mode == "none":
|
||||
return None
|
||||
if not isinstance(raw_tools, list) or not raw_tools:
|
||||
return None
|
||||
|
||||
tool_definitions: list[ToolDefinitionInput] = []
|
||||
for raw_tool in raw_tools:
|
||||
tool_name = str(getattr(raw_tool, "name", "") or "").strip()
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
parameters_schema = (
|
||||
dict(getattr(raw_tool, "inputSchema", {}) or {}) if getattr(raw_tool, "inputSchema", None) else {}
|
||||
)
|
||||
if "$schema" in parameters_schema:
|
||||
parameters_schema.pop("$schema")
|
||||
|
||||
title = str(getattr(raw_tool, "title", "") or "").strip()
|
||||
description = str(getattr(raw_tool, "description", "") or "").strip()
|
||||
brief_description = description or title or f"工具 {tool_name}"
|
||||
detailed_description = build_tool_detailed_description(
|
||||
parameters_schema,
|
||||
fallback_description=f"工具名称:{tool_name}",
|
||||
)
|
||||
|
||||
tool_definitions.append(
|
||||
{
|
||||
"name": tool_name,
|
||||
"description": "\n\n".join(
|
||||
part for part in [brief_description, detailed_description] if part.strip()
|
||||
).strip(),
|
||||
"parameters_schema": parameters_schema or {"type": "object", "properties": {}},
|
||||
}
|
||||
)
|
||||
|
||||
return tool_definitions or None
|
||||
|
||||
    def _build_sampling_result(
        self,
        generation_result: LLMResponseResult,
        tools_enabled: bool,
    ) -> Any:
        """Convert the host model response into an MCP Sampling result.

        Args:
            generation_result: Unified host LLM response result.
            tools_enabled: Whether the model was allowed to use tools.

        Returns:
            Any: An MCP ``CreateMessageResult`` or ``CreateMessageResultWithTools``.

        Raises:
            RuntimeError: If the MCP types module is not installed.
        """

        if not MCP_TYPES_AVAILABLE or mcp_types is None:
            raise RuntimeError("当前环境未安装可用的 MCP types 模块")

        text_content = str(generation_result.response or "")
        tool_calls = list(generation_result.tool_calls or [])
        # Fall back to the task name when the model did not report its name.
        model_name = generation_result.model_name or self._sampling_task_name

        if tools_enabled:
            # Tools enabled: build the richer with-tools result shape —
            # text first, then one tool_use block per call.
            content_blocks: list[Any] = []
            if text_content.strip():
                content_blocks.append(
                    mcp_types.TextContent(
                        type="text",
                        text=text_content,
                    )
                )
            for tool_call in tool_calls:
                content_blocks.append(
                    mcp_types.ToolUseContent(
                        type="tool_use",
                        name=tool_call.func_name,
                        id=tool_call.call_id,
                        input=dict(tool_call.args or {}),
                    )
                )

            # Guarantee at least one content block (empty-text fallback).
            if not content_blocks:
                content_blocks.append(
                    mcp_types.TextContent(
                        type="text",
                        text="",
                    )
                )

            return mcp_types.CreateMessageResultWithTools(
                role="assistant",
                # A single block is passed bare; multiple blocks as a list.
                content=content_blocks[0] if len(content_blocks) == 1 else content_blocks,
                model=model_name,
                stopReason="toolUse" if tool_calls else "endTurn",
            )

        # No tools: plain text-only result.
        return mcp_types.CreateMessageResult(
            role="assistant",
            content=mcp_types.TextContent(
                type="text",
                text=text_content,
            ),
            model=model_name,
            stopReason="endTurn",
        )
|
||||
Reference in New Issue
Block a user