feat: 添加批量插件重载功能及相关测试
This commit is contained in:
@@ -33,6 +33,8 @@ from src.plugin_runtime.protocol.envelope import (
|
||||
ReceiveExternalMessageResultPayload,
|
||||
RegisterPluginPayload,
|
||||
ReloadPluginResultPayload,
|
||||
ReloadPluginsPayload,
|
||||
ReloadPluginsResultPayload,
|
||||
RouteMessagePayload,
|
||||
RunnerReadyPayload,
|
||||
ShutdownPayload,
|
||||
@@ -194,6 +196,35 @@ class PluginRunnerSupervisor:
|
||||
for plugin_id, registration in self._registered_plugins.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_reload_plugin_ids(plugin_ids: Optional[List[str] | str]) -> List[str]:
|
||||
"""规范化批量重载入参。
|
||||
|
||||
Args:
|
||||
plugin_ids: 原始插件 ID 列表或单个插件 ID。
|
||||
|
||||
Returns:
|
||||
List[str]: 去重且去空白后的插件 ID 列表。
|
||||
"""
|
||||
|
||||
raw_plugin_ids: List[str]
|
||||
if plugin_ids is None:
|
||||
raw_plugin_ids = []
|
||||
elif isinstance(plugin_ids, str):
|
||||
raw_plugin_ids = [plugin_ids]
|
||||
else:
|
||||
raw_plugin_ids = list(plugin_ids)
|
||||
|
||||
normalized_plugin_ids: List[str] = []
|
||||
seen_plugin_ids: set[str] = set()
|
||||
for plugin_id in raw_plugin_ids:
|
||||
normalized_plugin_id = str(plugin_id or "").strip()
|
||||
if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids:
|
||||
continue
|
||||
seen_plugin_ids.add(normalized_plugin_id)
|
||||
normalized_plugin_ids.append(normalized_plugin_id)
|
||||
return normalized_plugin_ids
|
||||
|
||||
async def dispatch_event(
|
||||
self,
|
||||
event_type: str,
|
||||
@@ -420,7 +451,7 @@ class PluginRunnerSupervisor:
|
||||
|
||||
async def reload_plugins(
|
||||
self,
|
||||
plugin_ids: Optional[List[str]] = None,
|
||||
plugin_ids: Optional[List[str] | str] = None,
|
||||
reason: str = "manual",
|
||||
external_available_plugins: Optional[Dict[str, str]] = None,
|
||||
) -> bool:
|
||||
@@ -434,19 +465,37 @@ class PluginRunnerSupervisor:
|
||||
Returns:
|
||||
bool: 是否全部重载成功。
|
||||
"""
|
||||
target_plugin_ids = plugin_ids or list(self._registered_plugins.keys())
|
||||
ordered_plugin_ids = list(dict.fromkeys(target_plugin_ids))
|
||||
success = True
|
||||
ordered_plugin_ids = self._normalize_reload_plugin_ids(plugin_ids)
|
||||
if not ordered_plugin_ids:
|
||||
ordered_plugin_ids = list(self._registered_plugins.keys())
|
||||
if not ordered_plugin_ids:
|
||||
return True
|
||||
|
||||
for plugin_id in ordered_plugin_ids:
|
||||
reloaded = await self.reload_plugin(
|
||||
plugin_id=plugin_id,
|
||||
if len(ordered_plugin_ids) == 1:
|
||||
return await self.reload_plugin(
|
||||
plugin_id=ordered_plugin_ids[0],
|
||||
reason=reason,
|
||||
external_available_plugins=external_available_plugins,
|
||||
)
|
||||
success = success and reloaded
|
||||
|
||||
return success
|
||||
try:
|
||||
response = await self._rpc_server.send_request(
|
||||
"plugin.reload_batch",
|
||||
payload=ReloadPluginsPayload(
|
||||
plugin_ids=ordered_plugin_ids,
|
||||
reason=reason,
|
||||
external_available_plugins=external_available_plugins or self._external_available_plugins,
|
||||
).model_dump(),
|
||||
timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"插件批量重载请求失败: {exc}")
|
||||
return False
|
||||
|
||||
result = ReloadPluginsResultPayload.model_validate(response.payload)
|
||||
if not result.success:
|
||||
logger.warning(f"插件批量重载失败: {result.failed_plugins}")
|
||||
return result.success
|
||||
|
||||
async def notify_plugin_config_updated(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user