- Introduced LLM Provider declarations in plugin manifests, allowing plugins to specify their LLM capabilities. - Implemented validation for LLM Provider declarations to prevent duplicates and conflicts. - Enhanced the PluginRunner to handle LLM Provider invocation requests, enabling plugins to interact with LLM Providers seamlessly. - Added a ClientRegistry to manage LLM Provider registrations and ensure no conflicts arise between different plugins. - Created a PluginLLMClient to facilitate communication with LLM Providers through the plugin runtime. - Developed tests to ensure proper registration and conflict handling of LLM Providers.
192 lines · 7.5 KiB · Python
from typing import Any, Dict, List
|
|
|
|
from src.config.model_configs import APIProvider
|
|
from src.llm_models.exceptions import RespParseException
|
|
from src.llm_models.model_client.base_client import (
|
|
APIResponse,
|
|
AudioTranscriptionRequest,
|
|
BaseClient,
|
|
EmbeddingRequest,
|
|
ResponseRequest,
|
|
UsageRecord,
|
|
)
|
|
from src.llm_models.request_snapshot import (
|
|
deserialize_tool_calls_snapshot,
|
|
serialize_audio_request_snapshot,
|
|
serialize_embedding_request_snapshot,
|
|
serialize_response_request_snapshot,
|
|
)
|
|
|
|
|
|
class PluginLLMClient(BaseClient):
    """Client proxy that routes LLM Provider calls to a third-party plugin via the Runner RPC."""

    def __init__(
        self,
        api_provider: APIProvider,
        supervisor: Any,
        plugin_id: str,
        client_type: str,
    ) -> None:
        """Initialize the plugin-backed LLM Provider client proxy.

        Args:
            api_provider: Configuration of the current API Provider.
            supervisor: Supervisor that owns the target plugin.
            plugin_id: Identifier of the target plugin.
            client_type: Client type registered by the target plugin.
        """
        super().__init__(api_provider)
        self._supervisor = supervisor
        self._plugin_id = plugin_id
        self._client_type = client_type

    async def get_response(self, request: ResponseRequest) -> APIResponse:
        """Fetch a chat/completion response through the plugin provider.

        Args:
            request: Unified response request object.

        Returns:
            APIResponse: Unified response object.

        Raises:
            RespParseException: Raised when the plugin result cannot be converted
                into a unified response, or when host-side streaming/parsing
                customization is requested (not supported over the plugin RPC).
        """
        # Host-side stream handlers / custom parsers cannot cross the RPC boundary.
        customizations = (request.stream_response_handler, request.async_response_parser)
        if any(item is not None for item in customizations):
            raise RespParseException(message="插件 LLM Provider 暂不支持 Host 侧自定义流式处理器或响应解析器")
        snapshot = serialize_response_request_snapshot(request)
        provider_result = await self._invoke_provider("response", snapshot)
        model = request.model_info
        return self._build_api_response(provider_result, model.name, model.api_provider)

    async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
        """Fetch text embeddings through the plugin provider.

        Args:
            request: Unified embedding request object.

        Returns:
            APIResponse: Embedding response.
        """
        snapshot = serialize_embedding_request_snapshot(request)
        provider_result = await self._invoke_provider("embedding", snapshot)
        model = request.model_info
        return self._build_api_response(provider_result, model.name, model.api_provider)

    async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
        """Fetch an audio transcription through the plugin provider.

        Args:
            request: Unified audio transcription request object.

        Returns:
            APIResponse: Audio transcription response.
        """
        snapshot = serialize_audio_request_snapshot(request)
        provider_result = await self._invoke_provider("audio_transcription", snapshot)
        model = request.model_info
        return self._build_api_response(provider_result, model.name, model.api_provider)

    def get_support_image_formats(self) -> List[str]:
        """Return the image formats accepted by plugin providers by default.

        Returns:
            List[str]: Default list of accepted image format names.
        """
        # Fresh list on every call, matching the base-client contract.
        return "jpeg jpg png webp".split()

    def _build_api_provider_snapshot(self) -> Dict[str, Any]:
        """Build an API Provider configuration snapshot to hand to the plugin.

        Returns:
            Dict[str, Any]: API Provider config dict, including credentials.
        """
        snapshot = self.api_provider.model_dump(mode="json")
        return snapshot

    async def _invoke_provider(self, operation: str, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke the plugin provider over the supervisor RPC.

        Args:
            operation: Operation kind ("response", "embedding", "audio_transcription").
            request_payload: Already-serialized internal request.

        Returns:
            Dict[str, Any]: Result dict returned by the plugin.

        Raises:
            RespParseException: Raised when the RPC call fails or the plugin
                returns an error / malformed payload.
        """
        try:
            # The provider snapshot and timeout derivation stay inside the try
            # so any failure surfaces as a RespParseException, like the call itself.
            outbound = {
                **request_payload,
                "api_provider": self._build_api_provider_snapshot(),
            }
            # Provider timeout is configured in seconds; RPC wants ms, floor 1s.
            deadline_ms = max(1000, int(self.api_provider.timeout) * 1000)
            response = await self._supervisor.invoke_llm_provider(
                plugin_id=self._plugin_id,
                client_type=self._client_type,
                operation=operation,
                request=outbound,
                timeout_ms=deadline_ms,
            )
        except Exception as exc:
            raise RespParseException(message=f"插件 LLM Provider RPC 调用失败: {exc}") from exc
        if response.error:
            detail = response.error.get("message", "插件 LLM Provider 调用失败")
            raise RespParseException(message=str(detail))

        body = response.payload if isinstance(response.payload, dict) else {}
        result = body.get("result")
        if not bool(body.get("success", False)):
            raise RespParseException(message=str(result or "插件 LLM Provider 返回失败"))
        if not isinstance(result, dict):
            raise RespParseException(message="插件 LLM Provider 返回值必须是字典")
        return result

    @staticmethod
    def _build_usage_record(raw_usage: Any, model_name: str, provider_name: str) -> UsageRecord | None:
        """Rebuild a usage record from the plugin's return value.

        Args:
            raw_usage: Usage field returned by the plugin.
            model_name: Current model name (fallback when the plugin omits it).
            provider_name: Current provider name (fallback when omitted).

        Returns:
            UsageRecord | None: Usage record to attach to the APIResponse, or
                ``None`` when the plugin supplied no usage dict.
        """
        if not isinstance(raw_usage, dict):
            return None

        def token_count(key: str) -> int:
            # Missing / None / falsy values all collapse to 0.
            return int(raw_usage.get(key) or 0)

        return UsageRecord(
            model_name=str(raw_usage.get("model_name") or model_name),
            provider_name=str(raw_usage.get("provider_name") or provider_name),
            prompt_tokens=token_count("prompt_tokens"),
            completion_tokens=token_count("completion_tokens"),
            total_tokens=token_count("total_tokens"),
            prompt_cache_hit_tokens=token_count("prompt_cache_hit_tokens"),
            prompt_cache_miss_tokens=token_count("prompt_cache_miss_tokens"),
        )

    @staticmethod
    def _build_api_response(result: Dict[str, Any], model_name: str, provider_name: str) -> APIResponse:
        """Rebuild a unified APIResponse from the plugin's return value.

        Args:
            result: Result dict returned by the plugin.
            model_name: Current model name.
            provider_name: Current provider name.

        Returns:
            APIResponse: Unified response object.
        """
        raw_embedding = result.get("embedding")
        embedding = None
        if isinstance(raw_embedding, list):
            embedding = [float(item) for item in raw_embedding]

        def string_field(primary: str, fallback: str) -> Any:
            # Prefer the primary key when it is a str; otherwise fall back.
            value = result.get(primary)
            return value if isinstance(value, str) else result.get(fallback)

        content = string_field("content", "response")
        reasoning = string_field("reasoning_content", "reasoning")
        return APIResponse(
            content=content if isinstance(content, str) else None,
            reasoning_content=reasoning if isinstance(reasoning, str) else None,
            tool_calls=deserialize_tool_calls_snapshot(result.get("tool_calls")) or None,
            embedding=embedding,
            usage=PluginLLMClient._build_usage_record(result.get("usage"), model_name, provider_name),
            raw_data=result.get("raw_data", result),
        )