feat: Add LLM Provider support in plugin runtime

- Introduced LLM Provider declarations in plugin manifests, allowing plugins to specify their LLM capabilities.
- Implemented validation for LLM Provider declarations to prevent duplicates and conflicts.
- Enhanced the PluginRunner to handle LLM Provider invocation requests, enabling plugins to interact with LLM Providers seamlessly.
- Added a ClientRegistry to manage LLM Provider registrations and ensure no conflicts arise between different plugins.
- Created a PluginLLMClient to facilitate communication with LLM Providers through the plugin runtime.
- Developed tests to ensure proper registration and conflict handling of LLM Providers.
This commit is contained in:
DrSmoothl
2026-04-27 16:49:44 +08:00
parent 1fe9dc8786
commit 742e21a727
11 changed files with 903 additions and 13 deletions

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
from typing import Any, Callable, Coroutine, Dict, List, Set, Tuple, Type
import asyncio
@@ -204,19 +204,48 @@ class BaseClient(ABC):
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
ClientFactory = Callable[[APIProvider], BaseClient]
"""根据 APIProvider 创建客户端实例的工厂函数。"""
@dataclass(slots=True)
class ClientProviderRegistration:
"""LLM Provider 客户端类型注册信息。"""
client_type: str
"""客户端类型标识,对应模型配置中的 `api_providers[].client_type`。"""
factory: ClientFactory
"""客户端实例工厂。"""
owner_plugin_id: str | None = None
"""拥有该客户端类型的插件 ID主程序内置类型为 ``None``。"""
version: str = "1.0.0"
"""Provider 实现版本。"""
description: str = ""
"""Provider 描述文本。"""
builtin: bool = False
"""是否为主程序内置 Provider。"""
class ClientRegistry:
"""客户端注册表。"""
def __init__(self) -> None:
"""初始化注册表并绑定配置重载回调。"""
self.client_registry: Dict[str, Type[BaseClient]] = {}
"""APIProvider.type -> BaseClient的映射表"""
self.client_registry: Dict[str, ClientProviderRegistration] = {}
"""APIProvider.client_type -> Provider 注册信息映射表"""
self.client_instance_cache: Dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表"""
self._owner_client_types: Dict[str, Set[str]] = {}
"""插件 ID -> 该插件拥有的 client_type 集合。"""
config_manager.register_reload_callback(self.clear_client_instance_cache)
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
"""注册 API 客户端类。
"""注册主程序内置 API 客户端类。
Args:
client_type: 客户端类型标识。
@@ -226,13 +255,180 @@ class ClientRegistry:
"""
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
"""将内置客户端类注册到全局客户端注册表。
Args:
cls: 待注册的客户端类。
Returns:
Type[BaseClient]: 原始客户端类。
"""
if not issubclass(cls, BaseClient):
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
self.client_registry[client_type] = cls
self.register_provider(
ClientProviderRegistration(
client_type=client_type,
factory=cls,
builtin=True,
)
)
return cls
return decorator
@staticmethod
def _normalize_client_type(client_type: str) -> str:
"""规范化客户端类型标识。
Args:
client_type: 原始客户端类型标识。
Returns:
str: 去除首尾空白后的客户端类型标识。
Raises:
ValueError: 当客户端类型为空时抛出。
"""
normalized_client_type = str(client_type or "").strip()
if not normalized_client_type:
raise ValueError("client_type 不能为空")
return normalized_client_type
def register_provider(self, registration: ClientProviderRegistration) -> None:
"""注册单个客户端类型。
Args:
registration: Provider 注册信息。
Raises:
ValueError: 当客户端类型冲突时抛出。
"""
client_type = self._normalize_client_type(registration.client_type)
existing = self.client_registry.get(client_type)
if existing is not None and existing.owner_plugin_id != registration.owner_plugin_id:
raise ValueError(
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
)
self.client_registry[client_type] = ClientProviderRegistration(
client_type=client_type,
factory=registration.factory,
owner_plugin_id=registration.owner_plugin_id,
version=registration.version,
description=registration.description,
builtin=registration.builtin,
)
if registration.owner_plugin_id:
self._owner_client_types.setdefault(registration.owner_plugin_id, set()).add(client_type)
self.clear_client_instance_cache_by_client_type(client_type)
def validate_plugin_provider_replacement(self, plugin_id: str, client_types: List[str]) -> None:
"""校验插件 Provider 替换是否会造成运行时冲突。
Args:
plugin_id: 目标插件 ID。
client_types: 插件即将注册的客户端类型列表。
Raises:
ValueError: 当客户端类型为空、重复或与其他 owner 冲突时抛出。
"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
raise ValueError("plugin_id 不能为空")
normalized_client_types = [self._normalize_client_type(client_type) for client_type in client_types]
duplicate_client_types = sorted(
{
client_type
for client_type in normalized_client_types
if normalized_client_types.count(client_type) > 1
}
)
if duplicate_client_types:
raise ValueError(f"插件 {normalized_plugin_id} 重复声明 LLM Provider: {', '.join(duplicate_client_types)}")
for client_type in normalized_client_types:
existing = self.client_registry.get(client_type)
if existing is None or existing.owner_plugin_id == normalized_plugin_id:
continue
raise ValueError(
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
)
def replace_plugin_providers(
self,
plugin_id: str,
registrations: List[ClientProviderRegistration],
) -> None:
"""原子替换一个插件拥有的全部 Provider 注册。
Args:
plugin_id: 目标插件 ID。
registrations: 插件当前上报的 Provider 注册列表。
Raises:
ValueError: 当注册信息不合法或存在冲突时抛出。
"""
normalized_plugin_id = str(plugin_id or "").strip()
self.validate_plugin_provider_replacement(
normalized_plugin_id,
[registration.client_type for registration in registrations],
)
self.unregister_plugin_providers(normalized_plugin_id)
for registration in registrations:
self.register_provider(
ClientProviderRegistration(
client_type=registration.client_type,
factory=registration.factory,
owner_plugin_id=normalized_plugin_id,
version=registration.version,
description=registration.description,
builtin=False,
)
)
def unregister_plugin_providers(self, plugin_id: str) -> int:
"""注销一个插件拥有的全部 Provider 注册。
Args:
plugin_id: 目标插件 ID。
Returns:
int: 被注销的客户端类型数量。
"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
return 0
client_types = self._owner_client_types.pop(normalized_plugin_id, set())
removed_count = 0
for client_type in client_types:
registration = self.client_registry.get(client_type)
if registration is None or registration.owner_plugin_id != normalized_plugin_id:
continue
self.client_registry.pop(client_type, None)
self.clear_client_instance_cache_by_client_type(client_type)
removed_count += 1
return removed_count
def clear_client_instance_cache_by_client_type(self, client_type: str) -> None:
"""清理指定客户端类型对应的客户端实例缓存。
Args:
client_type: 需要清理缓存的客户端类型。
"""
normalized_client_type = str(client_type or "").strip()
if not normalized_client_type:
return
stale_provider_names = [
provider_name
for provider_name, client in self.client_instance_cache.items()
if client.api_provider.client_type == normalized_client_type
]
for provider_name in stale_provider_names:
self.client_instance_cache.pop(provider_name, None)
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
"""获取注册的 API 客户端实例。
@@ -249,15 +445,14 @@ class ClientRegistry:
# 如果强制创建新实例,直接创建不使用缓存
if force_new:
if client_class := self.client_registry.get(api_provider.client_type):
return client_class(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
if registration := self.client_registry.get(api_provider.client_type):
return registration.factory(api_provider)
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
# 正常的缓存逻辑
if api_provider.name not in self.client_instance_cache:
if client_class := self.client_registry.get(api_provider.client_type):
self.client_instance_cache[api_provider.name] = client_class(api_provider)
if registration := self.client_registry.get(api_provider.client_type):
self.client_instance_cache[api_provider.name] = registration.factory(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
return self.client_instance_cache[api_provider.name]