feat: Add LLM Provider support in plugin runtime
- Introduced LLM Provider declarations in plugin manifests, allowing plugins to specify their LLM capabilities. - Implemented validation for LLM Provider declarations to prevent duplicates and conflicts. - Enhanced the PluginRunner to handle LLM Provider invocation requests, enabling plugins to interact with LLM Providers seamlessly. - Added a ClientRegistry to manage LLM Provider registrations and ensure no conflicts arise between different plugins. - Created a PluginLLMClient to facilitate communication with LLM Providers through the plugin runtime. - Developed tests to ensure proper registration and conflict handling of LLM Providers.
This commit is contained in:
@@ -10,6 +10,8 @@ import sys
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.llm_models.model_client.base_client import ClientProviderRegistration, client_registry
|
||||
from src.llm_models.model_client.plugin_client import PluginLLMClient
|
||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
||||
from src.platform_io.drivers import PluginPlatformDriver
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
@@ -30,6 +32,7 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
HealthPayload,
|
||||
InspectPluginConfigPayload,
|
||||
InspectPluginConfigResultPayload,
|
||||
LLMProviderInvokePayload,
|
||||
MessageGatewayStateUpdatePayload,
|
||||
MessageGatewayStateUpdateResultPayload,
|
||||
PROTOCOL_VERSION,
|
||||
@@ -417,6 +420,38 @@ class PluginRunnerSupervisor:
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
async def invoke_llm_provider(
|
||||
self,
|
||||
plugin_id: str,
|
||||
client_type: str,
|
||||
operation: str,
|
||||
request: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""调用插件声明的 LLM Provider。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
client_type: 目标客户端类型。
|
||||
operation: 请求操作类型。
|
||||
request: 已序列化的 LLM 请求。
|
||||
timeout_ms: RPC 超时时间,单位毫秒。
|
||||
|
||||
Returns:
|
||||
Envelope: Runner 返回的响应信封。
|
||||
"""
|
||||
payload = LLMProviderInvokePayload(
|
||||
client_type=client_type,
|
||||
operation=operation,
|
||||
request=request or {},
|
||||
)
|
||||
return await self._rpc_server.send_request(
|
||||
"plugin.invoke_llm_provider",
|
||||
plugin_id=plugin_id,
|
||||
payload=payload.model_dump(),
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
|
||||
async def invoke_api(
|
||||
self,
|
||||
plugin_id: str,
|
||||
@@ -779,6 +814,22 @@ class PluginRunnerSupervisor:
|
||||
|
||||
component_declarations = [component.model_dump() for component in payload.components]
|
||||
runtime_components, api_components = self._split_component_declarations(component_declarations)
|
||||
try:
|
||||
client_registry.validate_plugin_provider_replacement(
|
||||
payload.plugin_id,
|
||||
[provider.client_type for provider in payload.llm_providers],
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"插件 {payload.plugin_id} LLM Provider 注册校验失败: {exc}")
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_BAD_PAYLOAD.value,
|
||||
str(exc),
|
||||
details={
|
||||
"plugin_id": payload.plugin_id,
|
||||
"llm_provider_count": len(payload.llm_providers),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
registered_count = self._component_registry.register_plugin_components(
|
||||
payload.plugin_id,
|
||||
@@ -798,6 +849,24 @@ class PluginRunnerSupervisor:
|
||||
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
||||
registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
|
||||
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
||||
client_registry.replace_plugin_providers(
|
||||
payload.plugin_id,
|
||||
[
|
||||
ClientProviderRegistration(
|
||||
client_type=provider.client_type,
|
||||
factory=lambda api_provider, provider_client_type=provider.client_type: PluginLLMClient(
|
||||
api_provider=api_provider,
|
||||
supervisor=self,
|
||||
plugin_id=payload.plugin_id,
|
||||
client_type=provider_client_type,
|
||||
),
|
||||
owner_plugin_id=payload.plugin_id,
|
||||
version=provider.version,
|
||||
description=provider.description or provider.name,
|
||||
)
|
||||
for provider in payload.llm_providers
|
||||
],
|
||||
)
|
||||
self._registered_plugins[payload.plugin_id] = payload
|
||||
self._message_gateway_states[payload.plugin_id] = {}
|
||||
|
||||
@@ -810,6 +879,7 @@ class PluginRunnerSupervisor:
|
||||
"message_gateways": len(
|
||||
self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False)
|
||||
),
|
||||
"llm_providers": len(payload.llm_providers),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -829,6 +899,7 @@ class PluginRunnerSupervisor:
|
||||
|
||||
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
|
||||
removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
||||
removed_llm_providers = client_registry.unregister_plugin_providers(payload.plugin_id)
|
||||
self._authorization.revoke_permission_token(payload.plugin_id)
|
||||
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
|
||||
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
||||
@@ -841,6 +912,7 @@ class PluginRunnerSupervisor:
|
||||
"reason": payload.reason,
|
||||
"removed_components": removed_components,
|
||||
"removed_apis": removed_apis,
|
||||
"removed_llm_providers": removed_llm_providers,
|
||||
"removed_registration": removed_registration,
|
||||
}
|
||||
)
|
||||
@@ -1505,6 +1577,8 @@ class PluginRunnerSupervisor:
|
||||
|
||||
def _clear_runner_state(self) -> None:
|
||||
"""清理当前 Runner 对应的 Host 侧注册状态。"""
|
||||
for plugin_id in list(self._registered_plugins):
|
||||
client_registry.unregister_plugin_providers(plugin_id)
|
||||
self._authorization.clear()
|
||||
self._api_registry.clear()
|
||||
self._component_registry.clear()
|
||||
|
||||
@@ -148,6 +148,35 @@ class PluginRuntimeManager(
|
||||
validator = ManifestValidator(validate_python_package_dependencies=False)
|
||||
return validator.build_plugin_dependency_map(plugin_dirs)
|
||||
|
||||
@classmethod
|
||||
def _discover_llm_provider_conflicts(cls, plugin_dirs: Iterable[Path]) -> Dict[str, str]:
|
||||
"""扫描插件 Manifest,发现 LLM Provider client_type 冲突。
|
||||
|
||||
Args:
|
||||
plugin_dirs: 需要扫描的插件根目录集合。
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 需要阻止加载的插件 ID 与原因映射。
|
||||
"""
|
||||
validator = ManifestValidator(validate_python_package_dependencies=False)
|
||||
provider_owners: Dict[str, List[str]] = {}
|
||||
for _plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs, require_entrypoint=True):
|
||||
for client_type in manifest.llm_provider_client_types:
|
||||
provider_owners.setdefault(client_type, []).append(manifest.id)
|
||||
|
||||
blocked_reasons: Dict[str, str] = {}
|
||||
for client_type, plugin_ids in provider_owners.items():
|
||||
unique_plugin_ids = sorted(set(plugin_ids))
|
||||
if len(unique_plugin_ids) <= 1:
|
||||
continue
|
||||
reason = (
|
||||
f"LLM Provider client_type 冲突: {client_type} 被以下插件重复声明: "
|
||||
f"{', '.join(unique_plugin_ids)}"
|
||||
)
|
||||
for plugin_id in unique_plugin_ids:
|
||||
blocked_reasons[plugin_id] = reason
|
||||
return blocked_reasons
|
||||
|
||||
@classmethod
|
||||
def _build_group_start_order(
|
||||
cls,
|
||||
@@ -271,7 +300,11 @@ class PluginRuntimeManager(
|
||||
"""
|
||||
|
||||
result = await self._plugin_dependency_pipeline.execute(plugin_dirs)
|
||||
changed_plugin_ids = self._set_blocked_plugin_reasons(result.blocked_plugin_reasons)
|
||||
blocked_plugin_reasons = {
|
||||
**result.blocked_plugin_reasons,
|
||||
**self._discover_llm_provider_conflicts(plugin_dirs),
|
||||
}
|
||||
changed_plugin_ids = self._set_blocked_plugin_reasons(blocked_plugin_reasons)
|
||||
return DependencySyncState(
|
||||
blocked_changed_plugin_ids=changed_plugin_ids,
|
||||
environment_changed=result.environment_changed,
|
||||
|
||||
@@ -199,6 +199,21 @@ class ComponentDeclaration(BaseModel):
|
||||
"""组件元数据"""
|
||||
|
||||
|
||||
class LLMProviderDeclaration(BaseModel):
|
||||
"""单个 LLM Provider 声明。"""
|
||||
|
||||
client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type")
|
||||
"""客户端类型标识。"""
|
||||
name: str = Field(default="", description="Provider 展示名称")
|
||||
"""Provider 展示名称。"""
|
||||
description: str = Field(default="", description="Provider 描述")
|
||||
"""Provider 描述。"""
|
||||
version: str = Field(default="1.0.0", description="Provider 实现版本")
|
||||
"""Provider 实现版本。"""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Provider 元数据")
|
||||
"""Provider 元数据。"""
|
||||
|
||||
|
||||
class RegisterPluginPayload(BaseModel):
|
||||
"""插件组件注册请求载荷。
|
||||
|
||||
@@ -212,6 +227,8 @@ class RegisterPluginPayload(BaseModel):
|
||||
"""插件版本"""
|
||||
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
||||
"""组件列表"""
|
||||
llm_providers: List[LLMProviderDeclaration] = Field(default_factory=list, description="LLM Provider 声明列表")
|
||||
"""LLM Provider 声明列表。"""
|
||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||
"""所需能力列表"""
|
||||
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
|
||||
@@ -254,6 +271,17 @@ class InvokeResultPayload(BaseModel):
|
||||
"""返回值"""
|
||||
|
||||
|
||||
class LLMProviderInvokePayload(BaseModel):
|
||||
"""plugin.invoke_llm_provider 请求 payload。"""
|
||||
|
||||
client_type: str = Field(description="目标 LLM Provider 客户端类型")
|
||||
"""目标 LLM Provider 客户端类型。"""
|
||||
operation: str = Field(description="请求操作类型")
|
||||
"""请求操作类型,如 response、embedding、audio_transcription。"""
|
||||
request: Dict[str, Any] = Field(default_factory=dict, description="已序列化的 LLM 请求")
|
||||
"""已序列化的 LLM 请求。"""
|
||||
|
||||
|
||||
# ====== 能力调用消息 ======
|
||||
class CapabilityRequestPayload(BaseModel):
|
||||
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
from functools import lru_cache
|
||||
from importlib import metadata as importlib_metadata
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
|
||||
from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
|
||||
|
||||
import json
|
||||
import re
|
||||
@@ -453,6 +453,38 @@ class PythonPackageDependencyDefinition(_StrictManifestModel):
|
||||
return value
|
||||
|
||||
|
||||
class LLMProviderManifestDeclaration(_StrictManifestModel):
|
||||
"""插件 Manifest 中声明的 LLM Provider。"""
|
||||
|
||||
client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type")
|
||||
"""客户端类型标识。"""
|
||||
name: str = Field(default="", description="Provider 展示名称")
|
||||
"""Provider 展示名称。"""
|
||||
description: str = Field(default="", description="Provider 描述")
|
||||
"""Provider 描述。"""
|
||||
version: str = Field(default="1.0.0", description="Provider 实现版本")
|
||||
"""Provider 实现版本。"""
|
||||
|
||||
@field_validator("client_type")
|
||||
@classmethod
|
||||
def _validate_client_type(cls, value: str) -> str:
|
||||
"""校验客户端类型标识。
|
||||
|
||||
Args:
|
||||
value: 原始客户端类型标识。
|
||||
|
||||
Returns:
|
||||
str: 合法的客户端类型标识。
|
||||
|
||||
Raises:
|
||||
ValueError: 当客户端类型为空时抛出。
|
||||
"""
|
||||
normalized_value = str(value or "").strip()
|
||||
if not normalized_value:
|
||||
raise ValueError("client_type 不能为空")
|
||||
return normalized_value
|
||||
|
||||
|
||||
ManifestDependencyDefinition = Annotated[
|
||||
Union[PluginDependencyDefinition, PythonPackageDependencyDefinition],
|
||||
Field(discriminator="type"),
|
||||
@@ -472,6 +504,10 @@ class PluginManifest(_StrictManifestModel):
|
||||
host_application: ManifestVersionRange = Field(description="Host 兼容区间")
|
||||
sdk: ManifestVersionRange = Field(description="SDK 兼容区间")
|
||||
dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明")
|
||||
llm_providers: List[LLMProviderManifestDeclaration] = Field(
|
||||
default_factory=list,
|
||||
description="插件静态声明的 LLM Provider 列表",
|
||||
)
|
||||
capabilities: List[str] = Field(description="插件声明的能力请求")
|
||||
i18n: ManifestI18n = Field(description="国际化配置")
|
||||
id: str = Field(description="稳定插件 ID")
|
||||
@@ -567,6 +603,23 @@ class PluginManifest(_StrictManifestModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_llm_providers(self) -> "PluginManifest":
|
||||
"""校验 LLM Provider 静态声明集合。
|
||||
|
||||
Returns:
|
||||
PluginManifest: 当前对象本身。
|
||||
|
||||
Raises:
|
||||
ValueError: 当同一 Manifest 内重复声明 client_type 时抛出。
|
||||
"""
|
||||
client_types: Set[str] = set()
|
||||
for provider in self.llm_providers:
|
||||
if provider.client_type in client_types:
|
||||
raise ValueError(f"存在重复的 LLM Provider 声明: {provider.client_type}")
|
||||
client_types.add(provider.client_type)
|
||||
return self
|
||||
|
||||
@property
|
||||
def plugin_dependencies(self) -> List[PluginDependencyDefinition]:
|
||||
"""返回插件级依赖列表。
|
||||
@@ -598,6 +651,15 @@ class PluginManifest(_StrictManifestModel):
|
||||
"""
|
||||
return [dependency.id for dependency in self.plugin_dependencies]
|
||||
|
||||
@property
|
||||
def llm_provider_client_types(self) -> List[str]:
|
||||
"""返回 Manifest 静态声明的 LLM Provider client_type 列表。
|
||||
|
||||
Returns:
|
||||
List[str]: 当前插件声明的 LLM Provider client_type。
|
||||
"""
|
||||
return [provider.client_type for provider in self.llm_providers]
|
||||
|
||||
|
||||
class ManifestValidator:
|
||||
"""严格的插件 Manifest v2 校验器。"""
|
||||
|
||||
@@ -54,6 +54,7 @@ class PluginMeta:
|
||||
self.capabilities_required = list(manifest.capabilities)
|
||||
self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
|
||||
self.component_handlers: Dict[str, str] = {}
|
||||
self.llm_provider_handlers: Dict[str, str] = {}
|
||||
|
||||
|
||||
class PluginLoader:
|
||||
|
||||
@@ -48,6 +48,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
InspectPluginConfigResultPayload,
|
||||
InvokePayload,
|
||||
InvokeResultPayload,
|
||||
LLMProviderDeclaration,
|
||||
LLMProviderInvokePayload,
|
||||
RegisterPluginPayload,
|
||||
ReloadPluginPayload,
|
||||
ReloadPluginResultPayload,
|
||||
@@ -891,6 +893,7 @@ class PluginRunner:
|
||||
self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_llm_provider", self._handle_llm_provider_invoke)
|
||||
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
|
||||
self._rpc_client.register_method("plugin.health", self._handle_health)
|
||||
@@ -980,6 +983,7 @@ class PluginRunner:
|
||||
"""
|
||||
# 收集插件组件声明
|
||||
components: List[ComponentDeclaration] = []
|
||||
llm_providers: List[LLMProviderDeclaration] = []
|
||||
config_reload_subscriptions: List[str] = []
|
||||
instance = meta.instance
|
||||
|
||||
@@ -1016,11 +1020,43 @@ class PluginRunner:
|
||||
)
|
||||
if hasattr(instance, "get_config_reload_subscriptions"):
|
||||
config_reload_subscriptions = list(instance.get_config_reload_subscriptions())
|
||||
if hasattr(instance, "get_llm_providers"):
|
||||
meta.llm_provider_handlers.clear()
|
||||
for provider_info in instance.get_llm_providers():
|
||||
if not isinstance(provider_info, dict):
|
||||
continue
|
||||
|
||||
client_type = str(provider_info.get("client_type", "") or "").strip()
|
||||
raw_metadata = provider_info.get("metadata", {})
|
||||
provider_metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
|
||||
if client_type:
|
||||
handler_name = str(provider_metadata.get("handler_name", client_type) or client_type).strip()
|
||||
meta.llm_provider_handlers[client_type] = handler_name or client_type
|
||||
|
||||
llm_providers.append(
|
||||
LLMProviderDeclaration(
|
||||
client_type=client_type,
|
||||
name=str(provider_info.get("name", "") or "").strip(),
|
||||
description=str(provider_info.get("description", "") or "").strip(),
|
||||
version=str(provider_info.get("version", "1.0.0") or "1.0.0").strip() or "1.0.0",
|
||||
metadata=provider_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
declared_client_types = sorted(meta.manifest.llm_provider_client_types)
|
||||
registered_client_types = sorted(provider.client_type for provider in llm_providers)
|
||||
if declared_client_types != registered_client_types:
|
||||
logger.error(
|
||||
f"插件 {meta.plugin_id} LLM Provider 声明不一致: "
|
||||
f"manifest={declared_client_types}, code={registered_client_types}"
|
||||
)
|
||||
return False
|
||||
|
||||
reg_payload = RegisterPluginPayload(
|
||||
plugin_id=meta.plugin_id,
|
||||
plugin_version=meta.version,
|
||||
components=components,
|
||||
llm_providers=llm_providers,
|
||||
capabilities_required=meta.capabilities_required,
|
||||
dependencies=meta.dependencies,
|
||||
config_reload_subscriptions=config_reload_subscriptions,
|
||||
@@ -1629,6 +1665,50 @@ class PluginRunner:
|
||||
resp_payload = InvokeResultPayload(success=False, result=str(e))
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
|
||||
async def _handle_llm_provider_invoke(self, envelope: Envelope) -> Envelope:
|
||||
"""处理 LLM Provider 调用请求。
|
||||
|
||||
Args:
|
||||
envelope: RPC 请求信封。
|
||||
|
||||
Returns:
|
||||
Envelope: 标准化后的 Provider 调用结果。
|
||||
"""
|
||||
try:
|
||||
invoke = LLMProviderInvokePayload.model_validate(envelope.payload)
|
||||
except Exception as exc:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||
|
||||
plugin_id = envelope.plugin_id
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||
f"插件 {plugin_id} 未加载",
|
||||
)
|
||||
|
||||
handler_name = meta.llm_provider_handlers.get(invoke.client_type, "")
|
||||
handler_method = getattr(meta.instance, handler_name, None) if handler_name else None
|
||||
if handler_method is None or not callable(handler_method):
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"插件 {plugin_id} 未注册 LLM Provider: {invoke.client_type}",
|
||||
)
|
||||
|
||||
try:
|
||||
result = handler_method(operation=invoke.operation, request=invoke.request)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
resp_payload = InvokeResultPayload(success=True, result=result)
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"插件 {plugin_id} LLM Provider {invoke.client_type} 执行异常: {exc}",
|
||||
exc_info=True,
|
||||
)
|
||||
resp_payload = InvokeResultPayload(success=False, result=str(exc))
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
|
||||
async def _handle_event_invoke(self, envelope: Envelope) -> Envelope:
|
||||
"""处理 EventHandler 调用请求
|
||||
|
||||
|
||||
Reference in New Issue
Block a user