Files
mai-bot/src/llm_models/model_client/plugin_client.py
DrSmoothl 742e21a727 feat: Add LLM Provider support in plugin runtime
- Introduced LLM Provider declarations in plugin manifests, allowing plugins to specify their LLM capabilities.
- Implemented validation for LLM Provider declarations to prevent duplicates and conflicts.
- Enhanced the PluginRunner to handle LLM Provider invocation requests, enabling plugins to interact with LLM Providers seamlessly.
- Added a ClientRegistry to manage LLM Provider registrations and ensure no conflicts arise between different plugins.
- Created a PluginLLMClient to facilitate communication with LLM Providers through the plugin runtime.
- Developed tests to ensure proper registration and conflict handling of LLM Providers.
2026-04-27 16:49:44 +08:00

192 lines
7.5 KiB
Python

from typing import Any, Dict, List
from src.config.model_configs import APIProvider
from src.llm_models.exceptions import RespParseException
from src.llm_models.model_client.base_client import (
APIResponse,
AudioTranscriptionRequest,
BaseClient,
EmbeddingRequest,
ResponseRequest,
UsageRecord,
)
from src.llm_models.request_snapshot import (
deserialize_tool_calls_snapshot,
serialize_audio_request_snapshot,
serialize_embedding_request_snapshot,
serialize_response_request_snapshot,
)
class PluginLLMClient(BaseClient):
    """Client proxy that calls a third-party LLM Provider through the plugin Runner RPC.

    Each request is serialized into a snapshot dict, forwarded to the owning plugin
    via ``supervisor.invoke_llm_provider``, and the plugin's dict result is converted
    back into a unified ``APIResponse``.
    """

    def __init__(
        self,
        api_provider: APIProvider,
        supervisor: Any,
        plugin_id: str,
        client_type: str,
    ) -> None:
        """Initialize the plugin LLM Provider client proxy.

        Args:
            api_provider: Current API Provider configuration.
            supervisor: Supervisor that owns the target plugin; it must expose an
                awaitable ``invoke_llm_provider`` method.
            plugin_id: ID of the target plugin.
            client_type: Client type registered by the target plugin.
        """
        super().__init__(api_provider)
        self._supervisor = supervisor
        self._plugin_id = plugin_id
        self._client_type = client_type

    async def get_response(self, request: ResponseRequest) -> APIResponse:
        """Get a chat/response completion from the plugin provider.

        Args:
            request: Unified response request object.

        Returns:
            APIResponse: Unified response object.

        Raises:
            RespParseException: If the request uses a Host-side stream handler or
                response parser (unsupported for plugin providers), if the RPC call
                fails, or if the plugin result cannot be converted to an APIResponse.
        """
        # Host-side callbacks are not supported for plugin providers (presumably
        # because they cannot be carried across the RPC boundary), so reject early.
        if request.stream_response_handler is not None or request.async_response_parser is not None:
            raise RespParseException(message="插件 LLM Provider 暂不支持 Host 侧自定义流式处理器或响应解析器")
        payload = serialize_response_request_snapshot(request)
        result = await self._invoke_provider("response", payload)
        return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)

    async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
        """Get a text embedding from the plugin provider.

        Args:
            request: Unified embedding request object.

        Returns:
            APIResponse: Embedding response.

        Raises:
            RespParseException: If the RPC call fails or the plugin result is malformed.
        """
        result = await self._invoke_provider("embedding", serialize_embedding_request_snapshot(request))
        return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)

    async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
        """Get an audio transcription from the plugin provider.

        Args:
            request: Unified audio transcription request object.

        Returns:
            APIResponse: Audio transcription response.

        Raises:
            RespParseException: If the RPC call fails or the plugin result is malformed.
        """
        result = await self._invoke_provider("audio_transcription", serialize_audio_request_snapshot(request))
        return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)

    def get_support_image_formats(self) -> List[str]:
        """Return the image formats accepted by plugin providers by default.

        Returns:
            List[str]: Image format names; a fresh list on every call.
        """
        return ["jpeg", "jpg", "png", "webp"]

    def _build_api_provider_snapshot(self) -> Dict[str, Any]:
        """Build a JSON-compatible snapshot of the API Provider config for the plugin.

        NOTE: per the original design this snapshot includes authentication
        information, so it is handed to the plugin as-is.

        Returns:
            Dict[str, Any]: API Provider configuration dumped with ``mode="json"``.
        """
        return self.api_provider.model_dump(mode="json")

    def _compute_timeout_ms(self) -> int:
        """Convert the configured provider timeout (seconds) into RPC milliseconds.

        The value is converted through ``float`` so fractional seconds are kept
        (``int(timeout) * 1000`` would turn 2.5s into 2000ms), and the result is
        clamped to at least 1000ms.

        Returns:
            int: Timeout in milliseconds, never below 1000.

        Raises:
            TypeError, ValueError, OverflowError: If the timeout is missing or not a
                finite number. ``_invoke_provider`` calls this inside its ``try`` so
                these surface as ``RespParseException``.
        """
        return max(1000, int(float(self.api_provider.timeout) * 1000))

    async def _invoke_provider(self, operation: str, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke the plugin provider over RPC and return its result dict.

        Args:
            operation: Request operation type (e.g. ``"response"``, ``"embedding"``).
            request_payload: Already-serialized internal request.

        Returns:
            Dict[str, Any]: The ``result`` dict returned by the plugin.

        Raises:
            RespParseException: If the RPC call fails, the plugin reports an error or
                ``success`` is falsy, or ``result`` is not a dict.
        """
        try:
            response = await self._supervisor.invoke_llm_provider(
                plugin_id=self._plugin_id,
                client_type=self._client_type,
                operation=operation,
                request={
                    **request_payload,
                    "api_provider": self._build_api_provider_snapshot(),
                },
                timeout_ms=self._compute_timeout_ms(),
            )
        except Exception as exc:
            raise RespParseException(message=f"插件 LLM Provider RPC 调用失败: {exc}") from exc

        error = response.error
        if error:
            # The error is normally a dict carrying "message", but tolerate other
            # shapes (e.g. a plain string) instead of crashing with AttributeError.
            message = error.get("message") if isinstance(error, dict) else error
            raise RespParseException(message=str(message or "插件 LLM Provider 调用失败"))

        payload = response.payload if isinstance(response.payload, dict) else {}
        result = payload.get("result")
        if not payload.get("success", False):
            # On failure, the plugin's "result" field carries the failure reason.
            raise RespParseException(message=str(result or "插件 LLM Provider 返回失败"))
        if not isinstance(result, dict):
            raise RespParseException(message="插件 LLM Provider 返回值必须是字典")
        return result

    @staticmethod
    def _read_token_count(raw_usage: Dict[str, Any], key: str) -> int:
        """Read one token counter from the plugin usage dict.

        Args:
            raw_usage: Usage dict returned by the plugin.
            key: Counter name to read.

        Returns:
            int: The counter value, or 0 when the key is missing or falsy.

        Raises:
            RespParseException: If the value cannot be converted to ``int``.
        """
        value = raw_usage.get(key)
        if not value:
            return 0
        try:
            return int(value)
        except (TypeError, ValueError) as exc:
            raise RespParseException(message=f"插件 LLM Provider 返回的使用量字段 {key} 不是整数: {value!r}") from exc

    @staticmethod
    def _build_usage_record(raw_usage: Any, model_name: str, provider_name: str) -> UsageRecord | None:
        """Rebuild a usage record from the plugin's return value.

        Args:
            raw_usage: Usage field returned by the plugin.
            model_name: Current model name, used when the plugin omits one.
            provider_name: Current provider name, used when the plugin omits one.

        Returns:
            UsageRecord | None: Usage record to attach to the APIResponse, or
            ``None`` when ``raw_usage`` is not a dict.

        Raises:
            RespParseException: If a token counter is not convertible to ``int``.
        """
        if not isinstance(raw_usage, dict):
            return None
        read = PluginLLMClient._read_token_count
        prompt_tokens = read(raw_usage, "prompt_tokens")
        completion_tokens = read(raw_usage, "completion_tokens")
        # Plugins that only report prompt/completion counts still get a usable total.
        total_tokens = read(raw_usage, "total_tokens") or (prompt_tokens + completion_tokens)
        return UsageRecord(
            model_name=str(raw_usage.get("model_name") or model_name),
            provider_name=str(raw_usage.get("provider_name") or provider_name),
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            prompt_cache_hit_tokens=read(raw_usage, "prompt_cache_hit_tokens"),
            prompt_cache_miss_tokens=read(raw_usage, "prompt_cache_miss_tokens"),
        )

    @staticmethod
    def _build_api_response(result: Dict[str, Any], model_name: str, provider_name: str) -> APIResponse:
        """Rebuild a unified APIResponse from the plugin's result dict.

        ``content`` falls back to ``response`` and ``reasoning_content`` falls back
        to ``reasoning`` when the primary key is not a string.

        Args:
            result: Result dict returned by the plugin.
            model_name: Current model name.
            provider_name: Current provider name.

        Returns:
            APIResponse: Unified response object.

        Raises:
            RespParseException: If the embedding or usage data cannot be converted.
        """
        raw_embedding = result.get("embedding")
        embedding = None
        if isinstance(raw_embedding, list):
            try:
                embedding = [float(item) for item in raw_embedding]
            except (TypeError, ValueError) as exc:
                raise RespParseException(message="插件 LLM Provider 返回的嵌入向量包含非数值元素") from exc

        content = result.get("content")
        if not isinstance(content, str):
            content = result.get("response")
        reasoning_content = result.get("reasoning_content")
        if not isinstance(reasoning_content, str):
            reasoning_content = result.get("reasoning")

        return APIResponse(
            content=content if isinstance(content, str) else None,
            reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
            tool_calls=deserialize_tool_calls_snapshot(result.get("tool_calls")) or None,
            embedding=embedding,
            usage=PluginLLMClient._build_usage_record(result.get("usage"), model_name, provider_name),
            # Fall back to the whole result dict when the plugin omits "raw_data".
            raw_data=result.get("raw_data", result),
        )