feat: add runtime validation for plugin configurations

- Introduced a new method `validate_plugin_config` in `PluginRuntimeManager` to validate plugin configurations at runtime.
- Implemented the `_normalize_plugin_config` method in `PluginRunner` to normalize and persist plugin configurations.
- Enhanced the `PluginRunner` to handle configuration validation requests and return normalized configurations.
- Updated the WebUI routes to utilize runtime validation for plugin configurations, ensuring that configurations are validated and normalized before saving.
- Added tests for runtime configuration validation and normalization processes, including handling of invalid configurations.
This commit is contained in:
DrSmoothl
2026-04-01 19:39:55 +08:00
parent efb84df768
commit 7b3c12ba02
9 changed files with 946 additions and 65 deletions

View File

@@ -458,6 +458,17 @@ class RuntimeComponentCapabilityMixin:
async def _cap_component_get_plugin_info(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
"""获取指定插件的基础信息。
Args:
plugin_id: 当前调用方插件 ID。
capability: 当前能力名称。
args: 能力调用参数。
Returns:
Any: 插件基础信息响应。
"""
plugin_name: str = args.get("plugin_name", plugin_id)
try:
sv = self._get_supervisor_for_plugin(plugin_name)
@@ -473,10 +484,46 @@ class RuntimeComponentCapabilityMixin:
"description": "",
"author": "",
"enabled": True,
"default_config": reg.default_config,
"config_schema": reg.config_schema,
},
}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
async def _cap_component_get_plugin_config_schema(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
"""获取指定插件注册时上报的配置 Schema。
Args:
plugin_id: 当前调用方插件 ID。
capability: 当前能力名称。
args: 能力调用参数。
Returns:
Any: 包含配置 Schema 与默认配置的响应。
"""
plugin_name: str = args.get("plugin_name", plugin_id)
try:
sv = self._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
if sv is None:
return {"success": False, "error": f"未找到插件: {plugin_name}"}
registration = sv._registered_plugins.get(plugin_name)
if registration is None:
return {"success": False, "error": f"未找到插件: {plugin_name}"}
return {
"success": True,
"plugin_id": plugin_name,
"schema": registration.config_schema,
"default_config": registration.default_config,
}
async def _cap_component_list_loaded_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:

View File

@@ -81,6 +81,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
_register("component.get_plugin_config_schema", manager._cap_component_get_plugin_config_schema)
_register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
_register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
_register("component.enable", manager._cap_component_enable)

View File

@@ -858,5 +858,55 @@ class ComponentQueryService:
logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True)
return None
def get_plugin_default_config(self, plugin_name: str) -> Optional[dict]:
"""获取指定插件注册时上报的默认配置。
Args:
plugin_name: 插件名称。
Returns:
Optional[dict]: 默认配置字典;未找到时返回 ``None``。
"""
runtime_manager = self._get_runtime_manager()
try:
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
logger.error(f"读取插件默认配置失败: {exc}")
return None
if supervisor is None:
return None
registration = supervisor._registered_plugins.get(plugin_name)
if registration is None:
return None
return dict(registration.default_config)
def get_plugin_config_schema(self, plugin_name: str) -> Optional[dict]:
"""获取指定插件注册时上报的配置 Schema。
Args:
plugin_name: 插件名称。
Returns:
Optional[dict]: 配置 Schema未找到时返回 ``None``。
"""
runtime_manager = self._get_runtime_manager()
try:
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
logger.error(f"读取插件配置 Schema 失败: {exc}")
return None
if supervisor is None:
return None
registration = supervisor._registered_plugins.get(plugin_name)
if registration is None:
return None
return dict(registration.config_schema)
component_query_service = ComponentQueryService()

View File

@@ -39,6 +39,8 @@ from src.plugin_runtime.protocol.envelope import (
RunnerReadyPayload,
ShutdownPayload,
UnregisterPluginPayload,
ValidatePluginConfigPayload,
ValidatePluginConfigResultPayload,
)
from src.plugin_runtime.protocol.codec import MsgPackCodec
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
@@ -59,6 +61,7 @@ if TYPE_CHECKING:
logger = get_logger("plugin_runtime.host.runner_manager")
@dataclass(slots=True)
class _MessageGatewayRuntimeState:
"""保存消息网关当前的运行时连接状态。"""
@@ -100,9 +103,7 @@ class PluginRunnerSupervisor:
self._group_name: str = str(group_name or "third_party").strip() or "third_party"
self._plugin_dirs: List[Path] = plugin_dirs or []
self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0
self._runner_spawn_timeout: float = (
runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0
)
self._runner_spawn_timeout: float = runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0
self._max_restart_attempts: int = max_restart_attempts or runtime_config.max_restart_attempts or 3
self._transport = create_transport_server(socket_path=socket_path)
@@ -200,10 +201,7 @@ class PluginRunnerSupervisor:
Returns:
Dict[str, str]: 已注册插件版本映射,键为插件 ID值为插件版本。
"""
return {
plugin_id: registration.plugin_version
for plugin_id, registration in self._registered_plugins.items()
}
return {plugin_id: registration.plugin_version for plugin_id, registration in self._registered_plugins.items()}
@staticmethod
def _normalize_reload_plugin_ids(plugin_ids: Optional[List[str] | str]) -> List[str]:
@@ -550,6 +548,39 @@ class PluginRunnerSupervisor:
return bool(response.payload.get("acknowledged", False))
async def validate_plugin_config(self, plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any]:
"""请求 Runner 使用插件自身配置模型校验配置。
Args:
plugin_id: 目标插件 ID。
config_data: 待校验的配置内容。
Returns:
Dict[str, Any]: 插件模型归一化后的配置字典。
Raises:
ValueError: 插件拒绝该配置或校验失败时抛出。
"""
payload = ValidatePluginConfigPayload(config_data=config_data)
try:
response = await self._rpc_server.send_request(
"plugin.validate_config",
plugin_id=plugin_id,
payload=payload.model_dump(),
timeout_ms=10000,
)
except Exception as exc:
raise ValueError(f"插件配置校验请求失败: {exc}") from exc
if response.error:
raise ValueError(str(response.error.get("message", "插件配置校验失败")))
result = ValidatePluginConfigResultPayload.model_validate(response.payload)
if not result.success:
raise ValueError("插件配置校验失败")
return dict(result.normalized_config)
def get_config_reload_subscribers(self, scope: str) -> List[str]:
"""返回订阅指定全局配置广播的插件列表。
@@ -608,6 +639,7 @@ class PluginRunnerSupervisor:
Raises:
TimeoutError: 在超时时间内 Runner 未完成初始化。
"""
async def wait_for_ready() -> RunnerReadyPayload:
"""轮询等待 Runner 上报就绪。"""
while True:
@@ -1058,7 +1090,9 @@ class PluginRunnerSupervisor:
route_key = RouteKey(platform=platform)
route_account_id, route_scope = RouteKeyFactory.extract_components(route_metadata)
account_id = route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None
account_id = (
route_key.account_id or route_account_id or runtime_state.account_id or gateway_entry.account_id or None
)
scope = route_key.scope or route_scope or runtime_state.scope or gateway_entry.scope or None
return RouteKey(
platform=platform,

View File

@@ -9,7 +9,20 @@
"""
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Coroutine,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
)
import asyncio
@@ -364,9 +377,7 @@ class PluginRuntimeManager(
"""构建当前已注册插件到所属 Supervisor 的映射。"""
return {
plugin_id: supervisor
for supervisor in self.supervisors
for plugin_id in supervisor.get_loaded_plugin_ids()
plugin_id: supervisor for supervisor in self.supervisors for plugin_id in supervisor.get_loaded_plugin_ids()
}
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
@@ -411,9 +422,7 @@ class PluginRuntimeManager(
local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
local_dependency_map = {
plugin_id: {
dependency
for dependency in dependency_map.get(plugin_id, set())
if dependency in local_plugin_ids
dependency for dependency in dependency_map.get(plugin_id, set()) if dependency in local_plugin_ids
}
for plugin_id in local_plugin_ids
}
@@ -440,9 +449,7 @@ class PluginRuntimeManager(
"""
normalized_plugin_ids = [
normalized_plugin_id
for plugin_id in plugin_ids
if (normalized_plugin_id := str(plugin_id or "").strip())
normalized_plugin_id for plugin_id in plugin_ids if (normalized_plugin_id := str(plugin_id or "").strip())
]
if not normalized_plugin_ids:
return True
@@ -518,9 +525,7 @@ class PluginRuntimeManager(
return False
config_payload = (
config_data
if config_data is not None
else self._load_plugin_config_for_supervisor(sv, plugin_id)
config_data if config_data is not None else self._load_plugin_config_for_supervisor(sv, plugin_id)
)
return await sv.notify_plugin_config_updated(
plugin_id=plugin_id,
@@ -529,6 +534,41 @@ class PluginRuntimeManager(
config_scope=config_scope,
)
async def validate_plugin_config(self, plugin_id: str, config_data: Dict[str, Any]) -> Dict[str, Any] | None:
"""请求运行时按插件自身配置模型校验配置。
Args:
plugin_id: 目标插件 ID。
config_data: 待校验的配置内容。
Returns:
Dict[str, Any] | None: 校验成功时返回规范化后的配置;若插件当前未加载
或运行时不可用,则返回 ``None`` 以便调用方回退到静态 Schema 方案。
Raises:
ValueError: 插件已加载,但配置校验失败时抛出。
"""
if not self._started:
return None
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
logger.warning(f"插件 {plugin_id} 配置校验路由失败,将回退到静态 Schema: {exc}")
return None
if supervisor is None:
return None
try:
return await supervisor.validate_plugin_config(plugin_id, config_data)
except ValueError:
raise
except Exception as exc:
logger.warning(f"插件 {plugin_id} 运行时配置校验不可用,将回退到静态 Schema: {exc}")
return None
@staticmethod
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
"""规范化配置热重载范围列表。
@@ -869,7 +909,9 @@ class PluginRuntimeManager(
if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
return cached_path
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(
getattr(supervisor, "_plugin_dirs", [])
):
if candidate_plugin_id != plugin_id:
continue
self._plugin_path_cache[plugin_id] = plugin_path
@@ -908,9 +950,7 @@ class PluginRuntimeManager(
)
self._plugin_config_watcher_subscriptions[plugin_id] = (config_path, subscription_id)
def _build_plugin_config_change_callback(
self, plugin_id: str
) -> Callable[[Sequence[FileChange]], Awaitable[None]]:
def _build_plugin_config_change_callback(self, plugin_id: str) -> Callable[[Sequence[FileChange]], Awaitable[None]]:
"""为指定插件生成配置文件变更回调。"""
async def _callback(changes: Sequence[FileChange]) -> None:
@@ -1018,7 +1058,10 @@ class PluginRuntimeManager(
return plugin_id
for plugin_id, plugin_path in self._plugin_path_cache.items():
if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
if not any(
self._plugin_dir_matches(plugin_path, Path(plugin_dir))
for plugin_dir in getattr(supervisor, "_plugin_dirs", [])
):
continue
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
return plugin_id

View File

@@ -1,7 +1,7 @@
"""RPC Envelope 消息模型
"""RPC Envelope 消息模型
定义 Host 与 Runner 之间所有 RPC 消息的统一信封格式。
使用 Pydantic 进行 schema 定义与校验。
使用 Pydantic 进行 Schema 定义与校验。
"""
from enum import Enum
@@ -39,12 +39,23 @@ class ConfigReloadScope(str, Enum):
# ====== 请求 ID 生成器 ======
class RequestIdGenerator:
"""单调递增 int64 请求 ID 生成器"""
"""单调递增 int64 请求 ID 生成器"""
def __init__(self, start: int = 1) -> None:
"""初始化请求 ID 生成器。
Args:
start: 起始请求 ID。
"""
self._counter = start
async def next(self) -> int:
"""返回下一个请求 ID。
Returns:
int: 下一个可用的请求 ID。
"""
current = self._counter
self._counter += 1
return current
@@ -52,7 +63,7 @@ class RequestIdGenerator:
# ====== Envelope 模型 ======
class Envelope(BaseModel):
"""RPC 统一消息封装
"""RPC 统一消息封装
所有 Host <-> Runner 消息均封装为此格式。
序列化流程Envelope -> .model_dump() -> MsgPack encode
@@ -79,18 +90,44 @@ class Envelope(BaseModel):
"""错误信息 (仅 response)"""
def is_request(self) -> bool:
"""判断当前信封是否为请求消息。
Returns:
bool: 当前消息类型是否为 ``REQUEST``。
"""
return self.message_type == MessageType.REQUEST
def is_response(self) -> bool:
"""判断当前信封是否为响应消息。
Returns:
bool: 当前消息类型是否为 ``RESPONSE``。
"""
return self.message_type == MessageType.RESPONSE
def is_broadcast(self) -> bool:
"""判断当前信封是否为广播消息。
Returns:
bool: 当前消息类型是否为 ``BROADCAST``。
"""
return self.message_type == MessageType.BROADCAST
def make_response(
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
) -> "Envelope":
"""基于当前请求创建对应的响应信封"""
"""基于当前请求创建对应的响应信封
Args:
payload: 响应业务载荷。
error: 响应错误信息。
Returns:
Envelope: 对应的响应信封。
"""
return Envelope(
protocol_version=self.protocol_version,
request_id=self.request_id,
@@ -102,7 +139,16 @@ class Envelope(BaseModel):
)
def make_error_response(self, code: str, message: str = "", details: Optional[Dict[str, Any]] = None) -> "Envelope":
"""基于当前请求创建错误响应"""
"""基于当前请求创建错误响应
Args:
code: 错误码。
message: 错误描述。
details: 详细错误信息。
Returns:
Envelope: 错误响应信封。
"""
return self.make_response(
error={
"code": code,
@@ -141,9 +187,7 @@ class ComponentDeclaration(BaseModel):
name: str = Field(description="组件名称")
"""组件名称"""
component_type: str = Field(
description="组件类型action/command/tool/event_handler/hook_handler/message_gateway"
)
component_type: str = Field(description="组件类型action/command/tool/event_handler/hook_handler/message_gateway")
"""组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
plugin_id: str = Field(description="所属插件 ID")
"""所属插件 ID"""
@@ -170,6 +214,10 @@ class RegisterPluginPayload(BaseModel):
"""插件级依赖插件 ID 列表"""
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
"""订阅的全局配置热重载范围"""
default_config: Dict[str, Any] = Field(default_factory=dict, description="插件默认配置")
"""插件默认配置"""
config_schema: Dict[str, Any] = Field(default_factory=dict, description="插件配置 Schema")
"""插件配置 Schema"""
class BootstrapPluginPayload(BaseModel):
@@ -256,6 +304,24 @@ class ConfigUpdatedPayload(BaseModel):
"""配置内容"""
class ValidatePluginConfigPayload(BaseModel):
"""plugin.validate_config 请求 payload。"""
config_data: Dict[str, Any] = Field(default_factory=dict, description="待校验的配置内容")
"""待校验的配置内容"""
class ValidatePluginConfigResultPayload(BaseModel):
"""plugin.validate_config 响应 payload。"""
success: bool = Field(description="是否校验成功")
"""是否校验成功"""
normalized_config: Dict[str, Any] = Field(default_factory=dict, description="校验后的规范化配置")
"""校验后的规范化配置"""
changed: bool = Field(default=False, description="是否在校验过程中自动补齐或归一化")
"""是否在校验过程中自动补齐或归一化"""
# ====== 关停 ======
class ShutdownPayload(BaseModel):
"""plugin.shutdown / plugin.prepare_shutdown payload"""

View File

@@ -10,7 +10,7 @@
"""
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, cast
from typing import Any, Callable, Dict, List, Mapping, Optional, Protocol, Set, Tuple, cast
import asyncio
import contextlib
@@ -23,6 +23,8 @@ import sys
import time
import tomllib
import tomlkit
from src.common.logger import get_console_handler, get_logger, initialize_logging
from src.plugin_runtime import (
ENV_EXTERNAL_PLUGIN_IDS,
@@ -46,6 +48,8 @@ from src.plugin_runtime.protocol.envelope import (
ReloadPluginsResultPayload,
RunnerReadyPayload,
UnregisterPluginPayload,
ValidatePluginConfigPayload,
ValidatePluginConfigResultPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
@@ -79,6 +83,64 @@ class _ContextAwarePlugin(Protocol):
"""
class _ConfigAwarePlugin(Protocol):
"""支持声明式插件配置能力的插件协议。"""
def normalize_plugin_config(self, config_data: Optional[Mapping[str, Any]]) -> Tuple[Dict[str, Any], bool]:
"""对插件配置进行归一化与补齐。
Args:
config_data: 原始配置数据。
Returns:
Tuple[Dict[str, Any], bool]: 归一化后的配置,以及是否发生自动变更。
"""
...
def set_plugin_config(self, config: Dict[str, Any]) -> None:
"""注入插件当前配置。
Args:
config: 当前最新插件配置。
"""
...
def get_default_config(self) -> Dict[str, Any]:
"""返回插件默认配置。
Returns:
Dict[str, Any]: 默认配置字典。
"""
...
def get_webui_config_schema(
self,
*,
plugin_id: str = "",
plugin_name: str = "",
plugin_version: str = "",
plugin_description: str = "",
plugin_author: str = "",
) -> Dict[str, Any]:
"""返回插件配置 Schema。
Args:
plugin_id: 插件 ID。
plugin_name: 插件名称。
plugin_version: 插件版本。
plugin_description: 插件描述。
plugin_author: 插件作者。
Returns:
Dict[str, Any]: WebUI 配置 Schema。
"""
...
def _install_shutdown_signal_handlers(
mark_runner_shutting_down: Callable[[], None],
loop: Optional[asyncio.AbstractEventLoop] = None,
@@ -271,14 +333,11 @@ class PluginRunner:
始终绑定为当前插件实例,避免伪造其他插件身份申请能力。
"""
if plugin_id and plugin_id != bound_plugin_id:
logger.warning(
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC已强制绑定回自身身份"
)
logger.warning(f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC已强制绑定回自身身份")
normalized_method = str(method or "").strip()
if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS:
raise PermissionError(
f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: "
f"{normalized_method or '<empty>'}"
f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: {normalized_method or '<empty>'}"
)
resp = await rpc_client.send_request(
method=normalized_method,
@@ -294,17 +353,72 @@ class PluginRunner:
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None:
"""在 Runner 侧为插件实例注入当前插件配置。"""
"""在 Runner 侧为插件实例注入当前插件配置。
Args:
meta: 插件元数据。
config_data: 可选的配置数据;留空时自动从插件目录读取。
"""
instance = meta.instance
if not hasattr(instance, "set_plugin_config"):
return
plugin_config = config_data if config_data is not None else self._load_plugin_config(meta.plugin_dir)
raw_config = config_data if config_data is not None else self._load_plugin_config(meta.plugin_dir)
plugin_config, should_persist = self._normalize_plugin_config(instance, raw_config)
config_path = Path(meta.plugin_dir) / "config.toml"
default_config = self._get_plugin_default_config(instance)
should_initialize_file = not config_path.exists() and bool(default_config)
if should_persist or should_initialize_file:
self._save_plugin_config(meta.plugin_dir, plugin_config)
try:
instance.set_plugin_config(plugin_config)
cast(_ConfigAwarePlugin, instance).set_plugin_config(plugin_config)
except Exception as exc:
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
def _normalize_plugin_config(
self,
instance: object,
config_data: Optional[Dict[str, Any]],
*,
suppress_errors: bool = True,
) -> Tuple[Dict[str, Any], bool]:
"""对插件配置做统一归一化处理。
Args:
instance: 插件实例。
config_data: 原始配置数据。
suppress_errors: 是否在归一化失败时吞掉异常并回退原始配置。
Returns:
Tuple[Dict[str, Any], bool]: 归一化后的配置,以及是否需要回写文件。
"""
normalized_config = dict(config_data or {})
if not hasattr(instance, "normalize_plugin_config"):
return normalized_config, False
try:
return cast(_ConfigAwarePlugin, instance).normalize_plugin_config(normalized_config)
except Exception as exc:
if not suppress_errors:
raise
logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}")
return normalized_config, False
@staticmethod
def _save_plugin_config(plugin_dir: str, config_data: Dict[str, Any]) -> None:
"""将插件配置写回到 ``config.toml``。
Args:
plugin_dir: 插件目录。
config_data: 需要写回的配置字典。
"""
config_path = Path(plugin_dir) / "config.toml"
config_path.parent.mkdir(parents=True, exist_ok=True)
with config_path.open("w", encoding="utf-8") as handle:
handle.write(tomlkit.dumps(config_data))
@staticmethod
def _load_plugin_config(plugin_dir: str) -> Dict[str, Any]:
"""从插件目录读取 config.toml。"""
@@ -334,6 +448,7 @@ class PluginRunner:
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
self._rpc_client.register_method("plugin.validate_config", self._handle_validate_plugin_config)
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins)
@@ -451,6 +566,8 @@ class PluginRunner:
capabilities_required=meta.capabilities_required,
dependencies=meta.dependencies,
config_reload_subscriptions=config_reload_subscriptions,
default_config=self._get_plugin_default_config(instance),
config_schema=self._get_plugin_config_schema(meta),
)
try:
@@ -468,6 +585,53 @@ class PluginRunner:
logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
return False
@staticmethod
def _get_plugin_default_config(instance: object) -> Dict[str, Any]:
"""获取插件默认配置。
Args:
instance: 插件实例。
Returns:
Dict[str, Any]: 默认配置;插件未声明时返回空字典。
"""
if not hasattr(instance, "get_default_config"):
return {}
try:
default_config = cast(_ConfigAwarePlugin, instance).get_default_config()
except Exception as exc:
logger.warning(f"读取插件默认配置失败: {exc}")
return {}
return default_config if isinstance(default_config, dict) else {}
@staticmethod
def _get_plugin_config_schema(meta: PluginMeta) -> Dict[str, Any]:
"""获取插件 WebUI 配置 Schema。
Args:
meta: 插件元数据。
Returns:
Dict[str, Any]: 插件配置 Schema插件未声明时返回空字典。
"""
instance = meta.instance
if not hasattr(instance, "get_webui_config_schema"):
return {}
try:
schema = cast(_ConfigAwarePlugin, instance).get_webui_config_schema(
plugin_id=meta.plugin_id,
plugin_name=meta.manifest.name,
plugin_version=meta.version,
plugin_description=meta.manifest.description,
plugin_author=meta.manifest.author.name,
)
except Exception as exc:
logger.warning(f"构造插件配置 Schema 失败: {exc}")
return {}
return schema if isinstance(schema, dict) else {}
async def _unregister_plugin(self, plugin_id: str, reason: str) -> None:
"""通知 Host 注销指定插件。
@@ -631,7 +795,9 @@ class PluginRunner:
continue
dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids}
indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()}
indegree: Dict[str, int] = {
plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()
}
reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph}
for plugin_id, dependencies in dependency_graph.items():
@@ -677,9 +843,7 @@ class PluginRunner:
for failed_plugin_id, failure_reason in failed_plugins.items():
rollback_failure = rollback_failures.get(failed_plugin_id)
if rollback_failure:
finalized_failures[failed_plugin_id] = (
f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
)
finalized_failures[failed_plugin_id] = f"{failure_reason};且旧版本恢复失败: {rollback_failure}"
else:
finalized_failures[failed_plugin_id] = f"{failure_reason}(已恢复旧版本)"
@@ -761,9 +925,7 @@ class PluginRunner:
failed_plugins=failed_plugins,
)
target_plugin_ids: Set[str] = {
plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids
}
target_plugin_ids: Set[str] = {plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids}
if loaded_root_plugin_ids := reload_root_ids & loaded_plugin_ids:
target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids))
@@ -1127,6 +1289,42 @@ class PluginRunner:
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
return envelope.make_response(payload={"acknowledged": True})
async def _handle_validate_plugin_config(self, envelope: Envelope) -> Envelope:
"""处理插件配置校验请求。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: RPC 响应信封。
"""
try:
payload = ValidatePluginConfigPayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id)
if meta is None:
return envelope.make_error_response(ErrorCode.E_PLUGIN_NOT_FOUND.value, f"未找到插件: {plugin_id}")
try:
normalized_config, changed = self._normalize_plugin_config(
meta.instance,
payload.config_data,
suppress_errors=False,
)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
result = ValidatePluginConfigResultPayload(
success=True,
normalized_config=normalized_config,
changed=changed,
)
return envelope.make_response(payload=result.model_dump())
async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope:
"""处理按插件 ID 的精确重载请求。
@@ -1212,8 +1410,7 @@ async def _async_main() -> None:
session_token,
plugin_dirs,
external_available_plugins={
str(plugin_id): str(plugin_version)
for plugin_id, plugin_version in external_plugin_ids.items()
str(plugin_id): str(plugin_version) for plugin_id, plugin_version in external_plugin_ids.items()
},
)