feat: Add LLM Provider support in plugin runtime

- Introduced LLM Provider declarations in plugin manifests, allowing plugins to specify their LLM capabilities.
- Implemented validation for LLM Provider declarations to prevent duplicates and conflicts.
- Enhanced the PluginRunner to handle LLM Provider invocation requests, enabling plugins to interact with LLM Providers seamlessly.
- Added a ClientRegistry to manage LLM Provider registrations and ensure no conflicts arise between different plugins.
- Created a PluginLLMClient to facilitate communication with LLM Providers through the plugin runtime.
- Developed tests to ensure proper registration and conflict handling of LLM Providers.
This commit is contained in:
DrSmoothl
2026-04-27 16:49:44 +08:00
parent 1fe9dc8786
commit 742e21a727
11 changed files with 903 additions and 13 deletions

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
from typing import Any, Callable, Coroutine, Dict, List, Set, Tuple, Type
import asyncio
@@ -204,19 +204,48 @@ class BaseClient(ABC):
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
ClientFactory = Callable[[APIProvider], BaseClient]
"""根据 APIProvider 创建客户端实例的工厂函数。"""
@dataclass(slots=True)
class ClientProviderRegistration:
"""LLM Provider 客户端类型注册信息。"""
client_type: str
"""客户端类型标识,对应模型配置中的 `api_providers[].client_type`。"""
factory: ClientFactory
"""客户端实例工厂。"""
owner_plugin_id: str | None = None
"""拥有该客户端类型的插件 ID主程序内置类型为 ``None``。"""
version: str = "1.0.0"
"""Provider 实现版本。"""
description: str = ""
"""Provider 描述文本。"""
builtin: bool = False
"""是否为主程序内置 Provider。"""
class ClientRegistry:
"""客户端注册表。"""
def __init__(self) -> None:
"""初始化注册表并绑定配置重载回调。"""
self.client_registry: Dict[str, Type[BaseClient]] = {}
"""APIProvider.type -> BaseClient的映射表"""
self.client_registry: Dict[str, ClientProviderRegistration] = {}
"""APIProvider.client_type -> Provider 注册信息映射表"""
self.client_instance_cache: Dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表"""
self._owner_client_types: Dict[str, Set[str]] = {}
"""插件 ID -> 该插件拥有的 client_type 集合。"""
config_manager.register_reload_callback(self.clear_client_instance_cache)
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
"""注册 API 客户端类。
"""注册主程序内置 API 客户端类。
Args:
client_type: 客户端类型标识。
@@ -226,13 +255,180 @@ class ClientRegistry:
"""
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
"""将内置客户端类注册到全局客户端注册表。
Args:
cls: 待注册的客户端类。
Returns:
Type[BaseClient]: 原始客户端类。
"""
if not issubclass(cls, BaseClient):
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
self.client_registry[client_type] = cls
self.register_provider(
ClientProviderRegistration(
client_type=client_type,
factory=cls,
builtin=True,
)
)
return cls
return decorator
@staticmethod
def _normalize_client_type(client_type: str) -> str:
"""规范化客户端类型标识。
Args:
client_type: 原始客户端类型标识。
Returns:
str: 去除首尾空白后的客户端类型标识。
Raises:
ValueError: 当客户端类型为空时抛出。
"""
normalized_client_type = str(client_type or "").strip()
if not normalized_client_type:
raise ValueError("client_type 不能为空")
return normalized_client_type
def register_provider(self, registration: ClientProviderRegistration) -> None:
"""注册单个客户端类型。
Args:
registration: Provider 注册信息。
Raises:
ValueError: 当客户端类型冲突时抛出。
"""
client_type = self._normalize_client_type(registration.client_type)
existing = self.client_registry.get(client_type)
if existing is not None and existing.owner_plugin_id != registration.owner_plugin_id:
raise ValueError(
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
)
self.client_registry[client_type] = ClientProviderRegistration(
client_type=client_type,
factory=registration.factory,
owner_plugin_id=registration.owner_plugin_id,
version=registration.version,
description=registration.description,
builtin=registration.builtin,
)
if registration.owner_plugin_id:
self._owner_client_types.setdefault(registration.owner_plugin_id, set()).add(client_type)
self.clear_client_instance_cache_by_client_type(client_type)
def validate_plugin_provider_replacement(self, plugin_id: str, client_types: List[str]) -> None:
"""校验插件 Provider 替换是否会造成运行时冲突。
Args:
plugin_id: 目标插件 ID。
client_types: 插件即将注册的客户端类型列表。
Raises:
ValueError: 当客户端类型为空、重复或与其他 owner 冲突时抛出。
"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
raise ValueError("plugin_id 不能为空")
normalized_client_types = [self._normalize_client_type(client_type) for client_type in client_types]
duplicate_client_types = sorted(
{
client_type
for client_type in normalized_client_types
if normalized_client_types.count(client_type) > 1
}
)
if duplicate_client_types:
raise ValueError(f"插件 {normalized_plugin_id} 重复声明 LLM Provider: {', '.join(duplicate_client_types)}")
for client_type in normalized_client_types:
existing = self.client_registry.get(client_type)
if existing is None or existing.owner_plugin_id == normalized_plugin_id:
continue
raise ValueError(
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
)
def replace_plugin_providers(
self,
plugin_id: str,
registrations: List[ClientProviderRegistration],
) -> None:
"""原子替换一个插件拥有的全部 Provider 注册。
Args:
plugin_id: 目标插件 ID。
registrations: 插件当前上报的 Provider 注册列表。
Raises:
ValueError: 当注册信息不合法或存在冲突时抛出。
"""
normalized_plugin_id = str(plugin_id or "").strip()
self.validate_plugin_provider_replacement(
normalized_plugin_id,
[registration.client_type for registration in registrations],
)
self.unregister_plugin_providers(normalized_plugin_id)
for registration in registrations:
self.register_provider(
ClientProviderRegistration(
client_type=registration.client_type,
factory=registration.factory,
owner_plugin_id=normalized_plugin_id,
version=registration.version,
description=registration.description,
builtin=False,
)
)
def unregister_plugin_providers(self, plugin_id: str) -> int:
"""注销一个插件拥有的全部 Provider 注册。
Args:
plugin_id: 目标插件 ID。
Returns:
int: 被注销的客户端类型数量。
"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
return 0
client_types = self._owner_client_types.pop(normalized_plugin_id, set())
removed_count = 0
for client_type in client_types:
registration = self.client_registry.get(client_type)
if registration is None or registration.owner_plugin_id != normalized_plugin_id:
continue
self.client_registry.pop(client_type, None)
self.clear_client_instance_cache_by_client_type(client_type)
removed_count += 1
return removed_count
def clear_client_instance_cache_by_client_type(self, client_type: str) -> None:
"""清理指定客户端类型对应的客户端实例缓存。
Args:
client_type: 需要清理缓存的客户端类型。
"""
normalized_client_type = str(client_type or "").strip()
if not normalized_client_type:
return
stale_provider_names = [
provider_name
for provider_name, client in self.client_instance_cache.items()
if client.api_provider.client_type == normalized_client_type
]
for provider_name in stale_provider_names:
self.client_instance_cache.pop(provider_name, None)
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
"""获取注册的 API 客户端实例。
@@ -249,15 +445,14 @@ class ClientRegistry:
# 如果强制创建新实例,直接创建不使用缓存
if force_new:
if client_class := self.client_registry.get(api_provider.client_type):
return client_class(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
if registration := self.client_registry.get(api_provider.client_type):
return registration.factory(api_provider)
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
# 正常的缓存逻辑
if api_provider.name not in self.client_instance_cache:
if client_class := self.client_registry.get(api_provider.client_type):
self.client_instance_cache[api_provider.name] = client_class(api_provider)
if registration := self.client_registry.get(api_provider.client_type):
self.client_instance_cache[api_provider.name] = registration.factory(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
return self.client_instance_cache[api_provider.name]

View File

@@ -0,0 +1,191 @@
from typing import Any, Dict, List
from src.config.model_configs import APIProvider
from src.llm_models.exceptions import RespParseException
from src.llm_models.model_client.base_client import (
APIResponse,
AudioTranscriptionRequest,
BaseClient,
EmbeddingRequest,
ResponseRequest,
UsageRecord,
)
from src.llm_models.request_snapshot import (
deserialize_tool_calls_snapshot,
serialize_audio_request_snapshot,
serialize_embedding_request_snapshot,
serialize_response_request_snapshot,
)
class PluginLLMClient(BaseClient):
"""通过插件 Runner RPC 调用第三方 LLM Provider 的客户端代理。"""
def __init__(
self,
api_provider: APIProvider,
supervisor: Any,
plugin_id: str,
client_type: str,
) -> None:
"""初始化插件 LLM Provider 客户端代理。
Args:
api_provider: 当前 API Provider 配置。
supervisor: 拥有目标插件的 Supervisor。
plugin_id: 目标插件 ID。
client_type: 目标客户端类型。
"""
super().__init__(api_provider)
self._supervisor = supervisor
self._plugin_id = plugin_id
self._client_type = client_type
async def get_response(self, request: ResponseRequest) -> APIResponse:
"""获取对话响应。
Args:
request: 统一响应请求对象。
Returns:
APIResponse: 统一响应对象。
Raises:
RespParseException: 插件返回内容无法转换为统一响应时抛出。
"""
if request.stream_response_handler is not None or request.async_response_parser is not None:
raise RespParseException(message="插件 LLM Provider 暂不支持 Host 侧自定义流式处理器或响应解析器")
payload = serialize_response_request_snapshot(request)
result = await self._invoke_provider("response", payload)
return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
"""获取文本嵌入。
Args:
request: 统一嵌入请求对象。
Returns:
APIResponse: 嵌入响应。
"""
result = await self._invoke_provider("embedding", serialize_embedding_request_snapshot(request))
return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
"""获取音频转录。
Args:
request: 统一音频转录请求对象。
Returns:
APIResponse: 音频转录响应。
"""
result = await self._invoke_provider("audio_transcription", serialize_audio_request_snapshot(request))
return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)
def get_support_image_formats(self) -> List[str]:
"""获取支持的图片格式。
Returns:
List[str]: 插件 Provider 默认接收的图片格式列表。
"""
return ["jpeg", "jpg", "png", "webp"]
def _build_api_provider_snapshot(self) -> Dict[str, Any]:
"""构建可传给插件 Provider 的 API Provider 配置快照。
Returns:
Dict[str, Any]: 包含认证信息的 API Provider 配置字典。
"""
return self.api_provider.model_dump(mode="json")
async def _invoke_provider(self, operation: str, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""调用插件 Provider。
Args:
operation: 请求操作类型。
request_payload: 已序列化的内部请求。
Returns:
Dict[str, Any]: 插件返回的响应字典。
Raises:
RespParseException: 插件调用失败或返回格式不合法时抛出。
"""
try:
response = await self._supervisor.invoke_llm_provider(
plugin_id=self._plugin_id,
client_type=self._client_type,
operation=operation,
request={
**request_payload,
"api_provider": self._build_api_provider_snapshot(),
},
timeout_ms=max(1000, int(self.api_provider.timeout) * 1000),
)
except Exception as exc:
raise RespParseException(message=f"插件 LLM Provider RPC 调用失败: {exc}") from exc
if response.error:
raise RespParseException(message=str(response.error.get("message", "插件 LLM Provider 调用失败")))
payload = response.payload if isinstance(response.payload, dict) else {}
success = bool(payload.get("success", False))
result = payload.get("result")
if not success:
raise RespParseException(message=str(result or "插件 LLM Provider 返回失败"))
if not isinstance(result, dict):
raise RespParseException(message="插件 LLM Provider 返回值必须是字典")
return result
@staticmethod
def _build_usage_record(raw_usage: Any, model_name: str, provider_name: str) -> UsageRecord | None:
"""从插件返回值恢复使用量记录。
Args:
raw_usage: 插件返回的使用量字段。
model_name: 当前模型名称。
provider_name: 当前 Provider 名称。
Returns:
UsageRecord | None: 可挂载到 APIResponse 的使用量记录;缺失时返回 ``None``。
"""
if not isinstance(raw_usage, dict):
return None
return UsageRecord(
model_name=str(raw_usage.get("model_name") or model_name),
provider_name=str(raw_usage.get("provider_name") or provider_name),
prompt_tokens=int(raw_usage.get("prompt_tokens") or 0),
completion_tokens=int(raw_usage.get("completion_tokens") or 0),
total_tokens=int(raw_usage.get("total_tokens") or 0),
prompt_cache_hit_tokens=int(raw_usage.get("prompt_cache_hit_tokens") or 0),
prompt_cache_miss_tokens=int(raw_usage.get("prompt_cache_miss_tokens") or 0),
)
@staticmethod
def _build_api_response(result: Dict[str, Any], model_name: str, provider_name: str) -> APIResponse:
"""从插件返回值恢复统一 APIResponse。
Args:
result: 插件返回的响应字典。
model_name: 当前模型名称。
provider_name: 当前 Provider 名称。
Returns:
APIResponse: 统一响应对象。
"""
raw_embedding = result.get("embedding")
embedding = [float(item) for item in raw_embedding] if isinstance(raw_embedding, list) else None
content = result.get("content")
if not isinstance(content, str):
content = result.get("response")
reasoning_content = result.get("reasoning_content")
if not isinstance(reasoning_content, str):
reasoning_content = result.get("reasoning")
return APIResponse(
content=content if isinstance(content, str) else None,
reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
tool_calls=deserialize_tool_calls_snapshot(result.get("tool_calls")) or None,
embedding=embedding,
usage=PluginLLMClient._build_usage_record(result.get("usage"), model_name, provider_name),
raw_data=result.get("raw_data", result),
)

View File

@@ -1072,3 +1072,65 @@ class TempMethodsLLMUtils:
if provider.name == provider_name:
return provider
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
class LLMRequest(LLMOrchestrator):
"""兼容旧调用方的 LLM 请求入口。
新代码应优先使用 ``LLMOrchestrator`` 或服务层 ``LLMServiceClient``
该类保留旧版 ``model_set=TaskConfig`` 的构造方式,并在运行时解析
对应的最新任务配置名称。
"""
def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
"""初始化旧版 LLM 请求对象。
Args:
model_set: 旧调用方传入的任务配置对象。
request_type: 当前请求的业务类型标识。
"""
self._task_config_name = self._resolve_task_config_name(model_set)
super().__init__(task_name=self._task_config_name, request_type=request_type)
@staticmethod
def _build_task_config_signature(task_config: TaskConfig) -> Tuple[Any, ...]:
"""构造任务配置签名。
Args:
task_config: 任务配置对象。
Returns:
Tuple[Any, ...]: 可用于匹配热重载前后任务配置的签名。
"""
return (
tuple(task_config.model_list),
task_config.max_tokens,
task_config.temperature,
task_config.slow_threshold,
task_config.selection_strategy,
)
@classmethod
def _resolve_task_config_name(cls, model_set: TaskConfig) -> str:
"""根据旧版 TaskConfig 对象解析任务配置名称。
Args:
model_set: 旧调用方传入的任务配置对象。
Returns:
str: 对应 ``model_task_config`` 下的字段名。
Raises:
ValueError: 未能找到匹配任务配置时抛出。
"""
target_signature = cls._build_task_config_signature(model_set)
model_task_config = config_manager.get_model_config().model_task_config
for attr_name in dir(model_task_config):
if attr_name.startswith("_"):
continue
attr_value = getattr(model_task_config, attr_name)
if not isinstance(attr_value, TaskConfig):
continue
if attr_value is model_set or cls._build_task_config_signature(attr_value) == target_signature:
return attr_name
raise ValueError("无法根据旧版 model_set 解析任务配置名称")