feat: Add LLM Provider support in plugin runtime

- Introduced LLM Provider declarations in plugin manifests, allowing plugins to specify their LLM capabilities.
- Implemented validation for LLM Provider declarations to prevent duplicates and conflicts.
- Enhanced the PluginRunner to handle LLM Provider invocation requests, enabling plugins to interact with LLM Providers seamlessly.
- Added a ClientRegistry to manage LLM Provider registrations and ensure no conflicts arise between different plugins.
- Created a PluginLLMClient to facilitate communication with LLM Providers through the plugin runtime.
- Developed tests to ensure proper registration and conflict handling of LLM Providers.
This commit is contained in:
DrSmoothl
2026-04-27 16:49:44 +08:00
parent 1fe9dc8786
commit 742e21a727
11 changed files with 903 additions and 13 deletions

View File

@@ -10,6 +10,8 @@ import sys
from src.common.logger import get_logger
from src.config.config import config_manager, global_config
from src.llm_models.model_client.base_client import ClientProviderRegistration, client_registry
from src.llm_models.model_client.plugin_client import PluginLLMClient
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
@@ -30,6 +32,7 @@ from src.plugin_runtime.protocol.envelope import (
HealthPayload,
InspectPluginConfigPayload,
InspectPluginConfigResultPayload,
LLMProviderInvokePayload,
MessageGatewayStateUpdatePayload,
MessageGatewayStateUpdateResultPayload,
PROTOCOL_VERSION,
@@ -417,6 +420,38 @@ class PluginRunnerSupervisor:
timeout_ms=timeout_ms,
)
async def invoke_llm_provider(
self,
plugin_id: str,
client_type: str,
operation: str,
request: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Envelope:
"""调用插件声明的 LLM Provider。
Args:
plugin_id: 目标插件 ID。
client_type: 目标客户端类型。
operation: 请求操作类型。
request: 已序列化的 LLM 请求。
timeout_ms: RPC 超时时间,单位毫秒。
Returns:
Envelope: Runner 返回的响应信封。
"""
payload = LLMProviderInvokePayload(
client_type=client_type,
operation=operation,
request=request or {},
)
return await self._rpc_server.send_request(
"plugin.invoke_llm_provider",
plugin_id=plugin_id,
payload=payload.model_dump(),
timeout_ms=timeout_ms,
)
async def invoke_api(
self,
plugin_id: str,
@@ -779,6 +814,22 @@ class PluginRunnerSupervisor:
component_declarations = [component.model_dump() for component in payload.components]
runtime_components, api_components = self._split_component_declarations(component_declarations)
try:
client_registry.validate_plugin_provider_replacement(
payload.plugin_id,
[provider.client_type for provider in payload.llm_providers],
)
except Exception as exc:
logger.error(f"插件 {payload.plugin_id} LLM Provider 注册校验失败: {exc}")
return envelope.make_error_response(
ErrorCode.E_BAD_PAYLOAD.value,
str(exc),
details={
"plugin_id": payload.plugin_id,
"llm_provider_count": len(payload.llm_providers),
},
)
try:
registered_count = self._component_registry.register_plugin_components(
payload.plugin_id,
@@ -798,6 +849,24 @@ class PluginRunnerSupervisor:
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
client_registry.replace_plugin_providers(
payload.plugin_id,
[
ClientProviderRegistration(
client_type=provider.client_type,
factory=lambda api_provider, provider_client_type=provider.client_type: PluginLLMClient(
api_provider=api_provider,
supervisor=self,
plugin_id=payload.plugin_id,
client_type=provider_client_type,
),
owner_plugin_id=payload.plugin_id,
version=provider.version,
description=provider.description or provider.name,
)
for provider in payload.llm_providers
],
)
self._registered_plugins[payload.plugin_id] = payload
self._message_gateway_states[payload.plugin_id] = {}
@@ -810,6 +879,7 @@ class PluginRunnerSupervisor:
"message_gateways": len(
self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False)
),
"llm_providers": len(payload.llm_providers),
}
)
@@ -829,6 +899,7 @@ class PluginRunnerSupervisor:
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id)
removed_llm_providers = client_registry.unregister_plugin_providers(payload.plugin_id)
self._authorization.revoke_permission_token(payload.plugin_id)
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
@@ -841,6 +912,7 @@ class PluginRunnerSupervisor:
"reason": payload.reason,
"removed_components": removed_components,
"removed_apis": removed_apis,
"removed_llm_providers": removed_llm_providers,
"removed_registration": removed_registration,
}
)
@@ -1505,6 +1577,8 @@ class PluginRunnerSupervisor:
def _clear_runner_state(self) -> None:
"""清理当前 Runner 对应的 Host 侧注册状态。"""
for plugin_id in list(self._registered_plugins):
client_registry.unregister_plugin_providers(plugin_id)
self._authorization.clear()
self._api_registry.clear()
self._component_registry.clear()

View File

@@ -148,6 +148,35 @@ class PluginRuntimeManager(
validator = ManifestValidator(validate_python_package_dependencies=False)
return validator.build_plugin_dependency_map(plugin_dirs)
@classmethod
def _discover_llm_provider_conflicts(cls, plugin_dirs: Iterable[Path]) -> Dict[str, str]:
"""扫描插件 Manifest发现 LLM Provider client_type 冲突。
Args:
plugin_dirs: 需要扫描的插件根目录集合。
Returns:
Dict[str, str]: 需要阻止加载的插件 ID 与原因映射。
"""
validator = ManifestValidator(validate_python_package_dependencies=False)
provider_owners: Dict[str, List[str]] = {}
for _plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs, require_entrypoint=True):
for client_type in manifest.llm_provider_client_types:
provider_owners.setdefault(client_type, []).append(manifest.id)
blocked_reasons: Dict[str, str] = {}
for client_type, plugin_ids in provider_owners.items():
unique_plugin_ids = sorted(set(plugin_ids))
if len(unique_plugin_ids) <= 1:
continue
reason = (
f"LLM Provider client_type 冲突: {client_type} 被以下插件重复声明: "
f"{', '.join(unique_plugin_ids)}"
)
for plugin_id in unique_plugin_ids:
blocked_reasons[plugin_id] = reason
return blocked_reasons
@classmethod
def _build_group_start_order(
cls,
@@ -271,7 +300,11 @@ class PluginRuntimeManager(
"""
result = await self._plugin_dependency_pipeline.execute(plugin_dirs)
changed_plugin_ids = self._set_blocked_plugin_reasons(result.blocked_plugin_reasons)
blocked_plugin_reasons = {
**result.blocked_plugin_reasons,
**self._discover_llm_provider_conflicts(plugin_dirs),
}
changed_plugin_ids = self._set_blocked_plugin_reasons(blocked_plugin_reasons)
return DependencySyncState(
blocked_changed_plugin_ids=changed_plugin_ids,
environment_changed=result.environment_changed,

View File

@@ -199,6 +199,21 @@ class ComponentDeclaration(BaseModel):
"""组件元数据"""
class LLMProviderDeclaration(BaseModel):
"""单个 LLM Provider 声明。"""
client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type")
"""客户端类型标识。"""
name: str = Field(default="", description="Provider 展示名称")
"""Provider 展示名称。"""
description: str = Field(default="", description="Provider 描述")
"""Provider 描述。"""
version: str = Field(default="1.0.0", description="Provider 实现版本")
"""Provider 实现版本。"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="Provider 元数据")
"""Provider 元数据。"""
class RegisterPluginPayload(BaseModel):
"""插件组件注册请求载荷。
@@ -212,6 +227,8 @@ class RegisterPluginPayload(BaseModel):
"""插件版本"""
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
"""组件列表"""
llm_providers: List[LLMProviderDeclaration] = Field(default_factory=list, description="LLM Provider 声明列表")
"""LLM Provider 声明列表。"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
@@ -254,6 +271,17 @@ class InvokeResultPayload(BaseModel):
"""返回值"""
class LLMProviderInvokePayload(BaseModel):
"""plugin.invoke_llm_provider 请求 payload。"""
client_type: str = Field(description="目标 LLM Provider 客户端类型")
"""目标 LLM Provider 客户端类型。"""
operation: str = Field(description="请求操作类型")
"""请求操作类型,如 response、embedding、audio_transcription。"""
request: Dict[str, Any] = Field(default_factory=dict, description="已序列化的 LLM 请求")
"""已序列化的 LLM 请求。"""
# ====== 能力调用消息 ======
class CapabilityRequestPayload(BaseModel):
"""cap.* 请求 payload插件 -> Host 能力调用)"""

View File

@@ -7,7 +7,7 @@
from functools import lru_cache
from importlib import metadata as importlib_metadata
from pathlib import Path
from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
import json
import re
@@ -453,6 +453,38 @@ class PythonPackageDependencyDefinition(_StrictManifestModel):
return value
class LLMProviderManifestDeclaration(_StrictManifestModel):
"""插件 Manifest 中声明的 LLM Provider。"""
client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type")
"""客户端类型标识。"""
name: str = Field(default="", description="Provider 展示名称")
"""Provider 展示名称。"""
description: str = Field(default="", description="Provider 描述")
"""Provider 描述。"""
version: str = Field(default="1.0.0", description="Provider 实现版本")
"""Provider 实现版本。"""
@field_validator("client_type")
@classmethod
def _validate_client_type(cls, value: str) -> str:
"""校验客户端类型标识。
Args:
value: 原始客户端类型标识。
Returns:
str: 合法的客户端类型标识。
Raises:
ValueError: 当客户端类型为空时抛出。
"""
normalized_value = str(value or "").strip()
if not normalized_value:
raise ValueError("client_type 不能为空")
return normalized_value
ManifestDependencyDefinition = Annotated[
Union[PluginDependencyDefinition, PythonPackageDependencyDefinition],
Field(discriminator="type"),
@@ -472,6 +504,10 @@ class PluginManifest(_StrictManifestModel):
host_application: ManifestVersionRange = Field(description="Host 兼容区间")
sdk: ManifestVersionRange = Field(description="SDK 兼容区间")
dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明")
llm_providers: List[LLMProviderManifestDeclaration] = Field(
default_factory=list,
description="插件静态声明的 LLM Provider 列表",
)
capabilities: List[str] = Field(description="插件声明的能力请求")
i18n: ManifestI18n = Field(description="国际化配置")
id: str = Field(description="稳定插件 ID")
@@ -567,6 +603,23 @@ class PluginManifest(_StrictManifestModel):
return self
@model_validator(mode="after")
def _validate_llm_providers(self) -> "PluginManifest":
"""校验 LLM Provider 静态声明集合。
Returns:
PluginManifest: 当前对象本身。
Raises:
ValueError: 当同一 Manifest 内重复声明 client_type 时抛出。
"""
client_types: Set[str] = set()
for provider in self.llm_providers:
if provider.client_type in client_types:
raise ValueError(f"存在重复的 LLM Provider 声明: {provider.client_type}")
client_types.add(provider.client_type)
return self
@property
def plugin_dependencies(self) -> List[PluginDependencyDefinition]:
"""返回插件级依赖列表。
@@ -598,6 +651,15 @@ class PluginManifest(_StrictManifestModel):
"""
return [dependency.id for dependency in self.plugin_dependencies]
@property
def llm_provider_client_types(self) -> List[str]:
"""返回 Manifest 静态声明的 LLM Provider client_type 列表。
Returns:
List[str]: 当前插件声明的 LLM Provider client_type。
"""
return [provider.client_type for provider in self.llm_providers]
class ManifestValidator:
"""严格的插件 Manifest v2 校验器。"""

View File

@@ -54,6 +54,7 @@ class PluginMeta:
self.capabilities_required = list(manifest.capabilities)
self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
self.component_handlers: Dict[str, str] = {}
self.llm_provider_handlers: Dict[str, str] = {}
class PluginLoader:

View File

@@ -48,6 +48,8 @@ from src.plugin_runtime.protocol.envelope import (
InspectPluginConfigResultPayload,
InvokePayload,
InvokeResultPayload,
LLMProviderDeclaration,
LLMProviderInvokePayload,
RegisterPluginPayload,
ReloadPluginPayload,
ReloadPluginResultPayload,
@@ -891,6 +893,7 @@ class PluginRunner:
self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_llm_provider", self._handle_llm_provider_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
self._rpc_client.register_method("plugin.health", self._handle_health)
@@ -980,6 +983,7 @@ class PluginRunner:
"""
# 收集插件组件声明
components: List[ComponentDeclaration] = []
llm_providers: List[LLMProviderDeclaration] = []
config_reload_subscriptions: List[str] = []
instance = meta.instance
@@ -1016,11 +1020,43 @@ class PluginRunner:
)
if hasattr(instance, "get_config_reload_subscriptions"):
config_reload_subscriptions = list(instance.get_config_reload_subscriptions())
if hasattr(instance, "get_llm_providers"):
meta.llm_provider_handlers.clear()
for provider_info in instance.get_llm_providers():
if not isinstance(provider_info, dict):
continue
client_type = str(provider_info.get("client_type", "") or "").strip()
raw_metadata = provider_info.get("metadata", {})
provider_metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
if client_type:
handler_name = str(provider_metadata.get("handler_name", client_type) or client_type).strip()
meta.llm_provider_handlers[client_type] = handler_name or client_type
llm_providers.append(
LLMProviderDeclaration(
client_type=client_type,
name=str(provider_info.get("name", "") or "").strip(),
description=str(provider_info.get("description", "") or "").strip(),
version=str(provider_info.get("version", "1.0.0") or "1.0.0").strip() or "1.0.0",
metadata=provider_metadata,
)
)
declared_client_types = sorted(meta.manifest.llm_provider_client_types)
registered_client_types = sorted(provider.client_type for provider in llm_providers)
if declared_client_types != registered_client_types:
logger.error(
f"插件 {meta.plugin_id} LLM Provider 声明不一致: "
f"manifest={declared_client_types}, code={registered_client_types}"
)
return False
reg_payload = RegisterPluginPayload(
plugin_id=meta.plugin_id,
plugin_version=meta.version,
components=components,
llm_providers=llm_providers,
capabilities_required=meta.capabilities_required,
dependencies=meta.dependencies,
config_reload_subscriptions=config_reload_subscriptions,
@@ -1629,6 +1665,50 @@ class PluginRunner:
resp_payload = InvokeResultPayload(success=False, result=str(e))
return envelope.make_response(payload=resp_payload.model_dump())
async def _handle_llm_provider_invoke(self, envelope: Envelope) -> Envelope:
"""处理 LLM Provider 调用请求。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: 标准化后的 Provider 调用结果。
"""
try:
invoke = LLMProviderInvokePayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id)
if meta is None:
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
f"插件 {plugin_id} 未加载",
)
handler_name = meta.llm_provider_handlers.get(invoke.client_type, "")
handler_method = getattr(meta.instance, handler_name, None) if handler_name else None
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"插件 {plugin_id} 未注册 LLM Provider: {invoke.client_type}",
)
try:
result = handler_method(operation=invoke.operation, request=invoke.request)
if inspect.isawaitable(result):
result = await result
resp_payload = InvokeResultPayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except Exception as exc:
logger.error(
f"插件 {plugin_id} LLM Provider {invoke.client_type} 执行异常: {exc}",
exc_info=True,
)
resp_payload = InvokeResultPayload(success=False, result=str(exc))
return envelope.make_response(payload=resp_payload.model_dump())
async def _handle_event_invoke(self, envelope: Envelope) -> Envelope:
"""处理 EventHandler 调用请求