feat: Add LLM Provider support in plugin runtime
- Introduced LLM Provider declarations in plugin manifests, allowing plugins to specify their LLM capabilities. - Implemented validation for LLM Provider declarations to prevent duplicates and conflicts. - Enhanced the PluginRunner to handle LLM Provider invocation requests, enabling plugins to interact with LLM Providers seamlessly. - Added a ClientRegistry to manage LLM Provider registrations and ensure no conflicts arise between different plugins. - Created a PluginLLMClient to facilitate communication with LLM Providers through the plugin runtime. - Developed tests to ensure proper registration and conflict handling of LLM Providers.
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Set, Tuple, Type
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -204,19 +204,48 @@ class BaseClient(ABC):
|
||||
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
||||
|
||||
|
||||
ClientFactory = Callable[[APIProvider], BaseClient]
|
||||
"""根据 APIProvider 创建客户端实例的工厂函数。"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ClientProviderRegistration:
|
||||
"""LLM Provider 客户端类型注册信息。"""
|
||||
|
||||
client_type: str
|
||||
"""客户端类型标识,对应模型配置中的 `api_providers[].client_type`。"""
|
||||
|
||||
factory: ClientFactory
|
||||
"""客户端实例工厂。"""
|
||||
|
||||
owner_plugin_id: str | None = None
|
||||
"""拥有该客户端类型的插件 ID;主程序内置类型为 ``None``。"""
|
||||
|
||||
version: str = "1.0.0"
|
||||
"""Provider 实现版本。"""
|
||||
|
||||
description: str = ""
|
||||
"""Provider 描述文本。"""
|
||||
|
||||
builtin: bool = False
|
||||
"""是否为主程序内置 Provider。"""
|
||||
|
||||
|
||||
class ClientRegistry:
|
||||
"""客户端注册表。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化注册表并绑定配置重载回调。"""
|
||||
self.client_registry: Dict[str, Type[BaseClient]] = {}
|
||||
"""APIProvider.type -> BaseClient的映射表"""
|
||||
self.client_registry: Dict[str, ClientProviderRegistration] = {}
|
||||
"""APIProvider.client_type -> Provider 注册信息映射表。"""
|
||||
self.client_instance_cache: Dict[str, BaseClient] = {}
|
||||
"""APIProvider.name -> BaseClient的映射表"""
|
||||
self._owner_client_types: Dict[str, Set[str]] = {}
|
||||
"""插件 ID -> 该插件拥有的 client_type 集合。"""
|
||||
config_manager.register_reload_callback(self.clear_client_instance_cache)
|
||||
|
||||
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
|
||||
"""注册 API 客户端类。
|
||||
"""注册主程序内置 API 客户端类。
|
||||
|
||||
Args:
|
||||
client_type: 客户端类型标识。
|
||||
@@ -226,13 +255,180 @@ class ClientRegistry:
|
||||
"""
|
||||
|
||||
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
|
||||
"""将内置客户端类注册到全局客户端注册表。
|
||||
|
||||
Args:
|
||||
cls: 待注册的客户端类。
|
||||
|
||||
Returns:
|
||||
Type[BaseClient]: 原始客户端类。
|
||||
"""
|
||||
if not issubclass(cls, BaseClient):
|
||||
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
|
||||
self.client_registry[client_type] = cls
|
||||
self.register_provider(
|
||||
ClientProviderRegistration(
|
||||
client_type=client_type,
|
||||
factory=cls,
|
||||
builtin=True,
|
||||
)
|
||||
)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def _normalize_client_type(client_type: str) -> str:
|
||||
"""规范化客户端类型标识。
|
||||
|
||||
Args:
|
||||
client_type: 原始客户端类型标识。
|
||||
|
||||
Returns:
|
||||
str: 去除首尾空白后的客户端类型标识。
|
||||
|
||||
Raises:
|
||||
ValueError: 当客户端类型为空时抛出。
|
||||
"""
|
||||
normalized_client_type = str(client_type or "").strip()
|
||||
if not normalized_client_type:
|
||||
raise ValueError("client_type 不能为空")
|
||||
return normalized_client_type
|
||||
|
||||
def register_provider(self, registration: ClientProviderRegistration) -> None:
|
||||
"""注册单个客户端类型。
|
||||
|
||||
Args:
|
||||
registration: Provider 注册信息。
|
||||
|
||||
Raises:
|
||||
ValueError: 当客户端类型冲突时抛出。
|
||||
"""
|
||||
client_type = self._normalize_client_type(registration.client_type)
|
||||
existing = self.client_registry.get(client_type)
|
||||
if existing is not None and existing.owner_plugin_id != registration.owner_plugin_id:
|
||||
raise ValueError(
|
||||
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
|
||||
)
|
||||
|
||||
self.client_registry[client_type] = ClientProviderRegistration(
|
||||
client_type=client_type,
|
||||
factory=registration.factory,
|
||||
owner_plugin_id=registration.owner_plugin_id,
|
||||
version=registration.version,
|
||||
description=registration.description,
|
||||
builtin=registration.builtin,
|
||||
)
|
||||
if registration.owner_plugin_id:
|
||||
self._owner_client_types.setdefault(registration.owner_plugin_id, set()).add(client_type)
|
||||
self.clear_client_instance_cache_by_client_type(client_type)
|
||||
|
||||
def validate_plugin_provider_replacement(self, plugin_id: str, client_types: List[str]) -> None:
|
||||
"""校验插件 Provider 替换是否会造成运行时冲突。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
client_types: 插件即将注册的客户端类型列表。
|
||||
|
||||
Raises:
|
||||
ValueError: 当客户端类型为空、重复或与其他 owner 冲突时抛出。
|
||||
"""
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id:
|
||||
raise ValueError("plugin_id 不能为空")
|
||||
|
||||
normalized_client_types = [self._normalize_client_type(client_type) for client_type in client_types]
|
||||
duplicate_client_types = sorted(
|
||||
{
|
||||
client_type
|
||||
for client_type in normalized_client_types
|
||||
if normalized_client_types.count(client_type) > 1
|
||||
}
|
||||
)
|
||||
if duplicate_client_types:
|
||||
raise ValueError(f"插件 {normalized_plugin_id} 重复声明 LLM Provider: {', '.join(duplicate_client_types)}")
|
||||
|
||||
for client_type in normalized_client_types:
|
||||
existing = self.client_registry.get(client_type)
|
||||
if existing is None or existing.owner_plugin_id == normalized_plugin_id:
|
||||
continue
|
||||
raise ValueError(
|
||||
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
|
||||
)
|
||||
|
||||
def replace_plugin_providers(
|
||||
self,
|
||||
plugin_id: str,
|
||||
registrations: List[ClientProviderRegistration],
|
||||
) -> None:
|
||||
"""原子替换一个插件拥有的全部 Provider 注册。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
registrations: 插件当前上报的 Provider 注册列表。
|
||||
|
||||
Raises:
|
||||
ValueError: 当注册信息不合法或存在冲突时抛出。
|
||||
"""
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
self.validate_plugin_provider_replacement(
|
||||
normalized_plugin_id,
|
||||
[registration.client_type for registration in registrations],
|
||||
)
|
||||
self.unregister_plugin_providers(normalized_plugin_id)
|
||||
for registration in registrations:
|
||||
self.register_provider(
|
||||
ClientProviderRegistration(
|
||||
client_type=registration.client_type,
|
||||
factory=registration.factory,
|
||||
owner_plugin_id=normalized_plugin_id,
|
||||
version=registration.version,
|
||||
description=registration.description,
|
||||
builtin=False,
|
||||
)
|
||||
)
|
||||
|
||||
def unregister_plugin_providers(self, plugin_id: str) -> int:
|
||||
"""注销一个插件拥有的全部 Provider 注册。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
int: 被注销的客户端类型数量。
|
||||
"""
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id:
|
||||
return 0
|
||||
|
||||
client_types = self._owner_client_types.pop(normalized_plugin_id, set())
|
||||
removed_count = 0
|
||||
for client_type in client_types:
|
||||
registration = self.client_registry.get(client_type)
|
||||
if registration is None or registration.owner_plugin_id != normalized_plugin_id:
|
||||
continue
|
||||
self.client_registry.pop(client_type, None)
|
||||
self.clear_client_instance_cache_by_client_type(client_type)
|
||||
removed_count += 1
|
||||
return removed_count
|
||||
|
||||
def clear_client_instance_cache_by_client_type(self, client_type: str) -> None:
|
||||
"""清理指定客户端类型对应的客户端实例缓存。
|
||||
|
||||
Args:
|
||||
client_type: 需要清理缓存的客户端类型。
|
||||
"""
|
||||
normalized_client_type = str(client_type or "").strip()
|
||||
if not normalized_client_type:
|
||||
return
|
||||
|
||||
stale_provider_names = [
|
||||
provider_name
|
||||
for provider_name, client in self.client_instance_cache.items()
|
||||
if client.api_provider.client_type == normalized_client_type
|
||||
]
|
||||
for provider_name in stale_provider_names:
|
||||
self.client_instance_cache.pop(provider_name, None)
|
||||
|
||||
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
|
||||
"""获取注册的 API 客户端实例。
|
||||
|
||||
@@ -249,15 +445,14 @@ class ClientRegistry:
|
||||
|
||||
# 如果强制创建新实例,直接创建不使用缓存
|
||||
if force_new:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
return client_class(api_provider)
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
if registration := self.client_registry.get(api_provider.client_type):
|
||||
return registration.factory(api_provider)
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
|
||||
# 正常的缓存逻辑
|
||||
if api_provider.name not in self.client_instance_cache:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
self.client_instance_cache[api_provider.name] = client_class(api_provider)
|
||||
if registration := self.client_registry.get(api_provider.client_type):
|
||||
self.client_instance_cache[api_provider.name] = registration.factory(api_provider)
|
||||
else:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
return self.client_instance_cache[api_provider.name]
|
||||
|
||||
Reference in New Issue
Block a user