feat: Add LLM Provider support in plugin runtime
- Introduced LLM Provider declarations in plugin manifests, allowing plugins to specify their LLM capabilities.
- Implemented validation for LLM Provider declarations to prevent duplicates and conflicts.
- Enhanced the PluginRunner to handle LLM Provider invocation requests, enabling plugins to interact with LLM Providers seamlessly.
- Added a ClientRegistry to manage LLM Provider registrations and ensure no conflicts arise between different plugins.
- Created a PluginLLMClient to facilitate communication with LLM Providers through the plugin runtime.
- Developed tests to ensure proper registration and conflict handling of LLM Providers.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Set, Tuple, Type
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -204,19 +204,48 @@ class BaseClient(ABC):
|
||||
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
||||
|
||||
|
||||
ClientFactory = Callable[[APIProvider], BaseClient]
|
||||
"""根据 APIProvider 创建客户端实例的工厂函数。"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ClientProviderRegistration:
    """Registration record describing one LLM Provider client type."""

    # Client type identifier; corresponds to `api_providers[].client_type`
    # in the model configuration.
    client_type: str

    # Factory that builds a client instance from an APIProvider.
    factory: ClientFactory

    # ID of the plugin that owns this client type; ``None`` for types that
    # are built into the host program.
    owner_plugin_id: str | None = None

    # Version of the provider implementation.
    version: str = "1.0.0"

    # Free-form provider description.
    description: str = ""

    # Whether this provider is built into the host program.
    builtin: bool = False
|
||||
|
||||
|
||||
class ClientRegistry:
|
||||
"""客户端注册表。"""
|
||||
|
||||
def __init__(self) -> None:
    """Initialize the registry and bind the configuration reload callback.

    Creates the provider registration table, the client instance cache and
    the per-plugin ownership index, then registers
    ``clear_client_instance_cache`` as a reload callback on
    ``config_manager``.
    """
    # APIProvider.client_type -> provider registration record.
    # (The earlier ``Dict[str, Type[BaseClient]]`` assignment to the same
    # attribute was immediately overwritten, so it has been dropped.)
    self.client_registry: Dict[str, ClientProviderRegistration] = {}
    # APIProvider.name -> cached client instance.
    self.client_instance_cache: Dict[str, BaseClient] = {}
    # Plugin ID -> set of client types owned by that plugin.
    self._owner_client_types: Dict[str, Set[str]] = {}
    config_manager.register_reload_callback(self.clear_client_instance_cache)
|
||||
|
||||
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
|
||||
"""注册 API 客户端类。
|
||||
"""注册主程序内置 API 客户端类。
|
||||
|
||||
Args:
|
||||
client_type: 客户端类型标识。
|
||||
@@ -226,13 +255,180 @@ class ClientRegistry:
|
||||
"""
|
||||
|
||||
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
|
||||
"""将内置客户端类注册到全局客户端注册表。
|
||||
|
||||
Args:
|
||||
cls: 待注册的客户端类。
|
||||
|
||||
Returns:
|
||||
Type[BaseClient]: 原始客户端类。
|
||||
"""
|
||||
if not issubclass(cls, BaseClient):
|
||||
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
|
||||
self.client_registry[client_type] = cls
|
||||
self.register_provider(
|
||||
ClientProviderRegistration(
|
||||
client_type=client_type,
|
||||
factory=cls,
|
||||
builtin=True,
|
||||
)
|
||||
)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def _normalize_client_type(client_type: str) -> str:
|
||||
"""规范化客户端类型标识。
|
||||
|
||||
Args:
|
||||
client_type: 原始客户端类型标识。
|
||||
|
||||
Returns:
|
||||
str: 去除首尾空白后的客户端类型标识。
|
||||
|
||||
Raises:
|
||||
ValueError: 当客户端类型为空时抛出。
|
||||
"""
|
||||
normalized_client_type = str(client_type or "").strip()
|
||||
if not normalized_client_type:
|
||||
raise ValueError("client_type 不能为空")
|
||||
return normalized_client_type
|
||||
|
||||
def register_provider(self, registration: ClientProviderRegistration) -> None:
    """Register a single client type.

    The client type is normalized first. A registration is rejected when the
    type is already held by a different owner; re-registering under the same
    owner overwrites the existing record. A fresh copy of the registration
    (carrying the normalized type) is stored, the owner index is updated for
    plugin-owned types, and cached client instances of that type are dropped.

    Args:
        registration: Provider registration information.

    Raises:
        ValueError: If the client type is empty or already registered by a
            different owner.
    """
    normalized_type = self._normalize_client_type(registration.client_type)
    owner_id = registration.owner_plugin_id

    current = self.client_registry.get(normalized_type)
    held_by_other_owner = current is not None and current.owner_plugin_id != owner_id
    if held_by_other_owner:
        raise ValueError(
            f"LLM Provider client_type 冲突: {normalized_type} 已由 {current.owner_plugin_id or 'host'} 注册"
        )

    stored = ClientProviderRegistration(
        client_type=normalized_type,
        factory=registration.factory,
        owner_plugin_id=owner_id,
        version=registration.version,
        description=registration.description,
        builtin=registration.builtin,
    )
    self.client_registry[normalized_type] = stored

    # Only plugin-owned types are tracked in the ownership index.
    if owner_id:
        owned_types = self._owner_client_types.setdefault(owner_id, set())
        owned_types.add(normalized_type)

    self.clear_client_instance_cache_by_client_type(normalized_type)
|
||||
|
||||
def validate_plugin_provider_replacement(self, plugin_id: str, client_types: List[str]) -> None:
    """Check that replacing a plugin's providers would not cause a conflict.

    Nothing is modified; the method either returns ``None`` or raises.

    Args:
        plugin_id: ID of the target plugin.
        client_types: Client types the plugin is about to register.

    Raises:
        ValueError: If the plugin ID or any client type is empty, if a client
            type is declared more than once, or if a client type is already
            registered by a different owner.
    """
    normalized_plugin_id = str(plugin_id or "").strip()
    if not normalized_plugin_id:
        raise ValueError("plugin_id 不能为空")

    normalized_client_types = [self._normalize_client_type(client_type) for client_type in client_types]

    # Single pass with a seen-set instead of calling list.count() per element,
    # which was quadratic in the number of declared types.
    seen_client_types = set()
    duplicate_client_types = set()
    for client_type in normalized_client_types:
        if client_type in seen_client_types:
            duplicate_client_types.add(client_type)
        else:
            seen_client_types.add(client_type)
    if duplicate_client_types:
        # Sorted so the error message is deterministic.
        raise ValueError(
            f"插件 {normalized_plugin_id} 重复声明 LLM Provider: {', '.join(sorted(duplicate_client_types))}"
        )

    # Declaration order is kept so the first conflicting type is the one reported.
    for client_type in normalized_client_types:
        existing = self.client_registry.get(client_type)
        if existing is None or existing.owner_plugin_id == normalized_plugin_id:
            continue
        raise ValueError(
            f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
        )
|
||||
|
||||
def replace_plugin_providers(
    self,
    plugin_id: str,
    registrations: List[ClientProviderRegistration],
) -> None:
    """Replace every provider registration owned by one plugin.

    The full set of declared client types is validated before anything is
    removed, so a rejected declaration leaves the plugin's current
    registrations in place. After validation the plugin's existing
    registrations are dropped and each supplied one is registered with the
    plugin as owner and ``builtin`` forced to ``False``.

    Args:
        plugin_id: ID of the target plugin.
        registrations: Provider registrations currently reported by the plugin.

    Raises:
        ValueError: If the registration data is invalid or conflicts with
            another owner.
    """
    owner_id = str(plugin_id or "").strip()
    declared_types = [item.client_type for item in registrations]

    self.validate_plugin_provider_replacement(owner_id, declared_types)
    self.unregister_plugin_providers(owner_id)

    for item in registrations:
        owned_registration = ClientProviderRegistration(
            client_type=item.client_type,
            factory=item.factory,
            owner_plugin_id=owner_id,
            version=item.version,
            description=item.description,
            builtin=False,
        )
        self.register_provider(owned_registration)
|
||||
|
||||
def unregister_plugin_providers(self, plugin_id: str) -> int:
    """Remove every provider registration owned by one plugin.

    Args:
        plugin_id: ID of the target plugin.

    Returns:
        int: Number of client types that were actually removed. ``0`` when
        the plugin ID is empty or the plugin owns nothing.
    """
    owner_id = str(plugin_id or "").strip()
    if owner_id == "":
        return 0

    # Dropping the index entry up front also clears stale bookkeeping for
    # types that are no longer present in the registry.
    owned_types = self._owner_client_types.pop(owner_id, set())

    removed = 0
    for owned_type in owned_types:
        entry = self.client_registry.get(owned_type)
        # Only remove records that still exist and still belong to this plugin.
        if entry is not None and entry.owner_plugin_id == owner_id:
            del self.client_registry[owned_type]
            self.clear_client_instance_cache_by_client_type(owned_type)
            removed += 1
    return removed
|
||||
|
||||
def clear_client_instance_cache_by_client_type(self, client_type: str) -> None:
    """Drop cached client instances that belong to one client type.

    Args:
        client_type: Client type whose cached instances should be removed.
            An empty or whitespace-only value is a no-op.
    """
    wanted_type = str(client_type or "").strip()
    if not wanted_type:
        return

    # Collect the matching keys first so the cache is not mutated while it
    # is being iterated.
    matching_names = []
    for cached_name, cached_client in self.client_instance_cache.items():
        if cached_client.api_provider.client_type == wanted_type:
            matching_names.append(cached_name)

    for cached_name in matching_names:
        self.client_instance_cache.pop(cached_name, None)
|
||||
|
||||
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
|
||||
"""获取注册的 API 客户端实例。
|
||||
|
||||
@@ -249,15 +445,14 @@ class ClientRegistry:
|
||||
|
||||
# 如果强制创建新实例,直接创建不使用缓存
|
||||
if force_new:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
return client_class(api_provider)
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
if registration := self.client_registry.get(api_provider.client_type):
|
||||
return registration.factory(api_provider)
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
|
||||
# 正常的缓存逻辑
|
||||
if api_provider.name not in self.client_instance_cache:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
self.client_instance_cache[api_provider.name] = client_class(api_provider)
|
||||
if registration := self.client_registry.get(api_provider.client_type):
|
||||
self.client_instance_cache[api_provider.name] = registration.factory(api_provider)
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
return self.client_instance_cache[api_provider.name]
|
||||
|
||||
191
src/llm_models/model_client/plugin_client.py
Normal file
191
src/llm_models/model_client/plugin_client.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from src.config.model_configs import APIProvider
|
||||
from src.llm_models.exceptions import RespParseException
|
||||
from src.llm_models.model_client.base_client import (
|
||||
APIResponse,
|
||||
AudioTranscriptionRequest,
|
||||
BaseClient,
|
||||
EmbeddingRequest,
|
||||
ResponseRequest,
|
||||
UsageRecord,
|
||||
)
|
||||
from src.llm_models.request_snapshot import (
|
||||
deserialize_tool_calls_snapshot,
|
||||
serialize_audio_request_snapshot,
|
||||
serialize_embedding_request_snapshot,
|
||||
serialize_response_request_snapshot,
|
||||
)
|
||||
|
||||
|
||||
class PluginLLMClient(BaseClient):
    """Client proxy that calls a third-party LLM Provider through the plugin Runner RPC."""

    def __init__(
        self,
        api_provider: APIProvider,
        supervisor: Any,
        plugin_id: str,
        client_type: str,
    ) -> None:
        """Initialize the plugin LLM Provider client proxy.

        Args:
            api_provider: API Provider configuration served by this client.
            supervisor: Supervisor that owns the target plugin; its awaitable
                ``invoke_llm_provider`` is used for every request.
            plugin_id: ID of the target plugin.
            client_type: Client type to invoke on the plugin side.
        """
        super().__init__(api_provider)
        self._supervisor = supervisor
        self._plugin_id = plugin_id
        self._client_type = client_type

    async def get_response(self, request: ResponseRequest) -> APIResponse:
        """Fetch a chat response from the plugin provider.

        Args:
            request: Unified response request object.

        Returns:
            APIResponse: Unified response object.

        Raises:
            RespParseException: If the request carries a host-side stream
                handler or response parser (not supported over RPC), if the
                RPC fails, or if the plugin's reply cannot be converted.
        """
        if request.stream_response_handler is not None or request.async_response_parser is not None:
            raise RespParseException(message="插件 LLM Provider 暂不支持 Host 侧自定义流式处理器或响应解析器")
        payload = serialize_response_request_snapshot(request)
        result = await self._invoke_provider("response", payload)
        return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)

    async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
        """Fetch a text embedding from the plugin provider.

        Args:
            request: Unified embedding request object.

        Returns:
            APIResponse: Embedding response.

        Raises:
            RespParseException: If the RPC fails or the reply cannot be converted.
        """
        result = await self._invoke_provider("embedding", serialize_embedding_request_snapshot(request))
        return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)

    async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
        """Fetch an audio transcription from the plugin provider.

        Args:
            request: Unified audio transcription request object.

        Returns:
            APIResponse: Audio transcription response.

        Raises:
            RespParseException: If the RPC fails or the reply cannot be converted.
        """
        result = await self._invoke_provider("audio_transcription", serialize_audio_request_snapshot(request))
        return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)

    def get_support_image_formats(self) -> List[str]:
        """Return the supported image formats.

        Returns:
            List[str]: Fixed default list of image formats accepted for
            plugin providers.
        """
        return ["jpeg", "jpg", "png", "webp"]

    def _build_api_provider_snapshot(self) -> Dict[str, Any]:
        """Build the API Provider configuration snapshot sent to the plugin.

        Returns:
            Dict[str, Any]: JSON-mode dump of the API Provider configuration.
        """
        # NOTE(review): the original docstring states this dump includes
        # authentication data; confirm that handing credentials to the plugin
        # process is intended.
        return self.api_provider.model_dump(mode="json")

    async def _invoke_provider(self, operation: str, request_payload: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke the plugin provider over RPC and unwrap its result.

        Args:
            operation: Requested operation type.
            request_payload: Already-serialized internal request.

        Returns:
            Dict[str, Any]: Result dictionary returned by the plugin.

        Raises:
            RespParseException: If the RPC call fails, the plugin reports an
                error or failure, or the result is not a dictionary.
        """
        try:
            response = await self._supervisor.invoke_llm_provider(
                plugin_id=self._plugin_id,
                client_type=self._client_type,
                operation=operation,
                request={
                    **request_payload,
                    "api_provider": self._build_api_provider_snapshot(),
                },
                # Scale to milliseconds before converting to int so that
                # fractional-second timeouts are not truncated; 1000 ms floor.
                timeout_ms=max(1000, int(round(float(self.api_provider.timeout) * 1000))),
            )
        except Exception as exc:
            # Any transport or supervisor failure is surfaced as one exception type.
            raise RespParseException(message=f"插件 LLM Provider RPC 调用失败: {exc}") from exc
        if response.error:
            raise RespParseException(message=str(response.error.get("message", "插件 LLM Provider 调用失败")))

        payload = response.payload if isinstance(response.payload, dict) else {}
        success = bool(payload.get("success", False))
        result = payload.get("result")
        if not success:
            raise RespParseException(message=str(result or "插件 LLM Provider 返回失败"))
        if not isinstance(result, dict):
            raise RespParseException(message="插件 LLM Provider 返回值必须是字典")
        return result

    @staticmethod
    def _coerce_usage_int(raw_usage: Dict[str, Any], key: str) -> int:
        """Read one usage counter as an int, treating missing/falsy values as 0.

        Args:
            raw_usage: Usage dictionary returned by the plugin.
            key: Name of the counter to read.

        Returns:
            int: The counter value.

        Raises:
            RespParseException: If the value cannot be converted to an int.
        """
        value = raw_usage.get(key) or 0
        try:
            return int(value)
        except (TypeError, ValueError) as exc:
            raise RespParseException(message=f"插件 LLM Provider 返回的 usage.{key} 不是有效整数: {value!r}") from exc

    @staticmethod
    def _build_usage_record(raw_usage: Any, model_name: str, provider_name: str) -> UsageRecord | None:
        """Rebuild a usage record from the plugin's reply.

        Args:
            raw_usage: Usage field returned by the plugin.
            model_name: Current model name, used when the plugin omits one.
            provider_name: Current provider name, used when the plugin omits one.

        Returns:
            UsageRecord | None: Usage record to attach to the APIResponse, or
            ``None`` when the usage field is not a dictionary.

        Raises:
            RespParseException: If a token counter is not a valid integer.
        """
        if not isinstance(raw_usage, dict):
            return None
        coerce = PluginLLMClient._coerce_usage_int
        return UsageRecord(
            model_name=str(raw_usage.get("model_name") or model_name),
            provider_name=str(raw_usage.get("provider_name") or provider_name),
            prompt_tokens=coerce(raw_usage, "prompt_tokens"),
            completion_tokens=coerce(raw_usage, "completion_tokens"),
            total_tokens=coerce(raw_usage, "total_tokens"),
            prompt_cache_hit_tokens=coerce(raw_usage, "prompt_cache_hit_tokens"),
            prompt_cache_miss_tokens=coerce(raw_usage, "prompt_cache_miss_tokens"),
        )

    @staticmethod
    def _build_api_response(result: Dict[str, Any], model_name: str, provider_name: str) -> APIResponse:
        """Rebuild a unified APIResponse from the plugin's reply.

        ``content`` falls back to the ``response`` key and
        ``reasoning_content`` falls back to the ``reasoning`` key; either is
        dropped when it is not a string.

        Args:
            result: Result dictionary returned by the plugin.
            model_name: Current model name.
            provider_name: Current provider name.

        Returns:
            APIResponse: Unified response object.

        Raises:
            RespParseException: If the embedding or usage values are not numeric.
        """
        raw_embedding = result.get("embedding")
        if isinstance(raw_embedding, list):
            try:
                embedding = [float(item) for item in raw_embedding]
            except (TypeError, ValueError) as exc:
                raise RespParseException(message="插件 LLM Provider 返回的 embedding 不是有效的数值列表") from exc
        else:
            embedding = None

        content = result.get("content")
        if not isinstance(content, str):
            content = result.get("response")
        reasoning_content = result.get("reasoning_content")
        if not isinstance(reasoning_content, str):
            reasoning_content = result.get("reasoning")

        return APIResponse(
            content=content if isinstance(content, str) else None,
            reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
            tool_calls=deserialize_tool_calls_snapshot(result.get("tool_calls")) or None,
            embedding=embedding,
            usage=PluginLLMClient._build_usage_record(result.get("usage"), model_name, provider_name),
            # Without an explicit raw_data entry, keep the whole result dict.
            raw_data=result.get("raw_data", result),
        )
|
||||
@@ -1072,3 +1072,65 @@ class TempMethodsLLMUtils:
|
||||
if provider.name == provider_name:
|
||||
return provider
|
||||
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
|
||||
|
||||
|
||||
class LLMRequest(LLMOrchestrator):
    """LLM request entry point kept for legacy callers.

    New code should prefer ``LLMOrchestrator`` or the service-layer
    ``LLMServiceClient``. This class keeps the legacy
    ``model_set=TaskConfig`` constructor and resolves the matching task
    configuration name when it is constructed.
    """

    def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
        """Initialize a legacy LLM request object.

        Args:
            model_set: Task configuration object passed by the legacy caller.
            request_type: Business type identifier of the current request.

        Raises:
            ValueError: If no task configuration matches ``model_set``.
        """
        self._task_config_name = self._resolve_task_config_name(model_set)
        super().__init__(task_name=self._task_config_name, request_type=request_type)

    @staticmethod
    def _build_task_config_signature(task_config: TaskConfig) -> Tuple[Any, ...]:
        """Build a comparable signature for a task configuration.

        Args:
            task_config: Task configuration object.

        Returns:
            Tuple[Any, ...]: Tuple of the model list, max tokens, temperature,
            slow threshold and selection strategy, used to match a
            configuration object against the currently loaded ones.
        """
        return (
            tuple(task_config.model_list),
            task_config.max_tokens,
            task_config.temperature,
            task_config.slow_threshold,
            task_config.selection_strategy,
        )

    @classmethod
    def _resolve_task_config_name(cls, model_set: TaskConfig) -> str:
        """Resolve the task configuration name for a legacy TaskConfig object.

        A field holding the very same object as ``model_set`` always wins.
        Only when no such field exists is the first field (in ``dir()`` order)
        with an equal signature returned. Previously a signature-equal field
        that sorted earlier could shadow the identical object.

        Args:
            model_set: Task configuration object passed by the legacy caller.

        Returns:
            str: Name of the matching field under ``model_task_config``.

        Raises:
            ValueError: If no matching task configuration is found.
        """
        target_signature = cls._build_task_config_signature(model_set)
        model_task_config = config_manager.get_model_config().model_task_config

        signature_match = None
        for attr_name in dir(model_task_config):
            if attr_name.startswith("_"):
                continue
            attr_value = getattr(model_task_config, attr_name)
            if not isinstance(attr_value, TaskConfig):
                continue
            if attr_value is model_set:
                # Exact object: no ambiguity, stop searching.
                return attr_name
            if signature_match is None and cls._build_task_config_signature(attr_value) == target_signature:
                # Remember the first look-alike but keep scanning for the exact object.
                signature_match = attr_name

        if signature_match is not None:
            return signature_match
        raise ValueError("无法根据旧版 model_set 解析任务配置名称")
|
||||
|
||||
Reference in New Issue
Block a user