feat: Enhance plugin loading and management

- Added module_name parameter to PluginMeta for better module tracking.
- Improved documentation for PluginMeta and PluginLoader methods.
- Introduced methods for managing loaded plugins: set_loaded_plugin, remove_loaded_plugin, and purge_plugin_modules.
- Enhanced dependency resolution in PluginLoader with resolve_dependencies method.
- Implemented candidate discovery and loading in PluginLoader.
- Added support for plugin reloading with _reload_plugin_by_id in PluginRunner.
- Improved error handling and logging throughout the RPCClient and PluginRunner.
- Added support for handling hook invocations in PluginRunner.
- Refactored plugin registration and unregistration processes for clarity and efficiency.
This commit is contained in:
DrSmoothl
2026-03-20 22:23:47 +08:00
parent 07256182fb
commit e4850c469f
9 changed files with 1351 additions and 333 deletions

View File

@@ -9,7 +9,7 @@
6. 转发插件的能力调用到 Host
"""
from typing import Any, Callable, List, Optional, Protocol, cast
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast
from pathlib import Path
@@ -32,8 +32,11 @@ from src.plugin_runtime.protocol.envelope import (
HealthPayload,
InvokePayload,
InvokeResultPayload,
RegisterComponentsPayload,
RegisterPluginPayload,
ReloadPluginPayload,
ReloadPluginResultPayload,
RunnerReadyPayload,
UnregisterPluginPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler
@@ -44,7 +47,8 @@ logger = get_logger("plugin_runtime.runner.main")
class _ContextAwarePlugin(Protocol):
def _set_context(self, context: Any) -> None: ...
def _set_context(self, context: Any) -> None:
"""为插件注入上下文对象。"""
def _install_shutdown_signal_handlers(
@@ -90,21 +94,29 @@ class PluginRunner:
session_token: str,
plugin_dirs: List[str],
) -> None:
"""初始化 Runner。
Args:
host_address: Host 的 IPC 地址。
session_token: 握手用会话令牌。
plugin_dirs: 当前 Runner 负责扫描的插件目录列表。
"""
self._host_address: str = host_address
self._session_token: str = session_token
self._plugin_dirs: list[str] = plugin_dirs
self._plugin_dirs: List[str] = plugin_dirs
self._rpc_client: RPCClient = RPCClient(host_address, session_token)
self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, ""))
self._start_time: float = time.monotonic()
self._shutting_down: bool = False
self._reload_lock: asyncio.Lock = asyncio.Lock()
# IPC 日志 Handler握手成功后安装将所有 stdlib logging 转发到 Host
self._log_handler: Optional[RunnerIPCLogHandler] = None
self._suspended_console_handlers: list[stdlib_logging.Handler] = []
self._suspended_console_handlers: List[stdlib_logging.Handler] = []
async def run(self) -> None:
"""Runner 主入口"""
"""运行 Runner 主循环。"""
# 1. 连接 Host
logger.info(f"Runner 启动,连接 Host: {self._host_address}")
ok = await self._rpc_client.connect_and_handshake()
@@ -123,32 +135,11 @@ class PluginRunner:
logger.info(f"已加载 {len(plugins)} 个插件")
# 4. 注入 PluginContext + 调用 on_load 生命周期钩子
failed_plugins: set[str] = set()
failed_plugins: Set[str] = set(self._loader.failed_plugins.keys())
for meta in plugins:
instance = meta.instance
self._inject_context(meta.plugin_id, instance)
self._apply_plugin_config(meta)
if not await self._bootstrap_plugin(meta):
failed_plugins.add(meta.plugin_id)
continue
if hasattr(instance, "on_load"):
try:
ret = instance.on_load()
if asyncio.iscoroutine(ret):
await ret
except Exception as e:
logger.error(f"插件 {meta.plugin_id} on_load 失败,跳过注册: {e}", exc_info=True)
failed_plugins.add(meta.plugin_id)
await self._deactivate_plugin(meta)
# 5. 向 Host 注册所有插件的组件(跳过 on_load 失败的插件)
for meta in plugins:
if meta.plugin_id in failed_plugins:
continue
ok = await self._register_plugin(meta)
ok = await self._activate_plugin(meta)
if not ok:
failed_plugins.add(meta.plugin_id)
await self._deactivate_plugin(meta)
successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins]
await self._notify_ready(successful_plugins, sorted(failed_plugins))
@@ -232,7 +223,9 @@ class PluginRunner:
bound_plugin_id = plugin_id
async def _rpc_call(
method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None
method: str,
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
) -> Any:
"""桥接 PluginContext.call_capability → RPCClient.send_request。
@@ -257,7 +250,7 @@ class PluginRunner:
cast(_ContextAwarePlugin, instance)._set_context(ctx)
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None:
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None:
"""在 Runner 侧为插件实例注入当前插件配置。"""
instance = meta.instance
if not hasattr(instance, "set_plugin_config"):
@@ -270,7 +263,7 @@ class PluginRunner:
logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}")
@staticmethod
def _load_plugin_config(plugin_dir: str) -> dict[str, Any]:
def _load_plugin_config(plugin_dir: str) -> Dict[str, Any]:
"""从插件目录读取 config.toml。"""
config_path = Path(plugin_dir) / "config.toml"
if not config_path.exists():
@@ -286,16 +279,18 @@ class PluginRunner:
return loaded if isinstance(loaded, dict) else {}
def _register_handlers(self) -> None:
"""注册方法处理器"""
"""注册 Host -> Runner 的方法处理器"""
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
self._rpc_client.register_method("plugin.health", self._handle_health)
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin)
async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool:
"""向 Host 同步插件 bootstrap 能力令牌。"""
@@ -324,7 +319,14 @@ class PluginRunner:
await self._bootstrap_plugin(meta, capabilities_required=[])
async def _register_plugin(self, meta: PluginMeta) -> bool:
"""向 Host 注册单个插件"""
"""向 Host 注册单个插件
Args:
meta: 待注册的插件元数据。
Returns:
bool: 是否注册成功。
"""
# 收集插件组件声明
components: List[ComponentDeclaration] = []
instance = meta.instance
@@ -341,7 +343,7 @@ class PluginRunner:
for comp_info in instance.get_components()
)
reg_payload = RegisterComponentsPayload(
reg_payload = RegisterPluginPayload(
plugin_id=meta.plugin_id,
plugin_version=meta.version,
components=components,
@@ -361,8 +363,281 @@ class PluginRunner:
logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
return False
async def _unregister_plugin(self, plugin_id: str, reason: str) -> None:
"""通知 Host 注销指定插件。
Args:
plugin_id: 目标插件 ID。
reason: 注销原因。
"""
payload = UnregisterPluginPayload(plugin_id=plugin_id, reason=reason)
try:
await self._rpc_client.send_request(
"plugin.unregister",
plugin_id=plugin_id,
payload=payload.model_dump(),
timeout_ms=10000,
)
except Exception as exc:
logger.warning(f"插件 {plugin_id} 注销通知失败: {exc}")
async def _invoke_plugin_on_load(self, meta: PluginMeta) -> bool:
"""执行插件的 ``on_load`` 生命周期。
Args:
meta: 待初始化的插件元数据。
Returns:
bool: 生命周期是否执行成功。
"""
instance = meta.instance
if not hasattr(instance, "on_load"):
return True
try:
result = instance.on_load()
if asyncio.iscoroutine(result):
await result
return True
except Exception as exc:
logger.error(f"插件 {meta.plugin_id} on_load 失败: {exc}", exc_info=True)
return False
async def _invoke_plugin_on_unload(self, meta: PluginMeta) -> None:
"""执行插件的 ``on_unload`` 生命周期。
Args:
meta: 待卸载的插件元数据。
"""
instance = meta.instance
if not hasattr(instance, "on_unload"):
return
try:
result = instance.on_unload()
if asyncio.iscoroutine(result):
await result
except Exception as exc:
logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True)
async def _activate_plugin(self, meta: PluginMeta) -> bool:
"""完成插件注入、授权、生命周期和组件注册。
Args:
meta: 待激活的插件元数据。
Returns:
bool: 是否激活成功。
"""
self._inject_context(meta.plugin_id, meta.instance)
self._apply_plugin_config(meta)
if not await self._bootstrap_plugin(meta):
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
if not await self._invoke_plugin_on_load(meta):
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
if not await self._register_plugin(meta):
await self._invoke_plugin_on_unload(meta)
await self._deactivate_plugin(meta)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
return False
self._loader.set_loaded_plugin(meta)
return True
async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None:
"""卸载单个插件并清理 Host/Runner 两侧状态。
Args:
meta: 待卸载的插件元数据。
reason: 卸载原因。
"""
await self._invoke_plugin_on_unload(meta)
await self._unregister_plugin(meta.plugin_id, reason)
await self._deactivate_plugin(meta)
self._loader.remove_loaded_plugin(meta.plugin_id)
self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir)
def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]:
"""收集依赖指定插件的所有已加载插件。
Args:
plugin_id: 根插件 ID。
Returns:
Set[str]: 目标插件及其所有反向依赖插件集合。
"""
impacted_plugins: Set[str] = {plugin_id}
changed = True
while changed:
changed = False
for loaded_plugin_id in self._loader.list_plugins():
if loaded_plugin_id in impacted_plugins:
continue
meta = self._loader.get_plugin(loaded_plugin_id)
if meta is None:
continue
if any(dependency in impacted_plugins for dependency in meta.dependencies):
impacted_plugins.add(loaded_plugin_id)
changed = True
return impacted_plugins
def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]:
"""构建受影响插件的卸载顺序。
Args:
plugin_ids: 需要卸载的插件集合。
Returns:
List[str]: 依赖方优先的卸载顺序。
"""
dependency_graph: Dict[str, Set[str]] = {}
for plugin_id in plugin_ids:
meta = self._loader.get_plugin(plugin_id)
if meta is None:
dependency_graph[plugin_id] = set()
continue
dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids}
indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()}
reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph}
for plugin_id, dependencies in dependency_graph.items():
for dependency in dependencies:
reverse_graph.setdefault(dependency, set()).add(plugin_id)
queue: List[str] = sorted(plugin_id for plugin_id, degree in indegree.items() if degree == 0)
load_order: List[str] = []
while queue:
current_plugin_id = queue.pop(0)
load_order.append(current_plugin_id)
for dependent_plugin_id in sorted(reverse_graph.get(current_plugin_id, set())):
indegree[dependent_plugin_id] -= 1
if indegree[dependent_plugin_id] == 0:
queue.append(dependent_plugin_id)
queue.sort()
return list(reversed(load_order))
async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload:
"""按插件 ID 在 Runner 进程内执行精确重载。
Args:
plugin_id: 目标插件 ID。
reason: 重载原因。
Returns:
ReloadPluginResultPayload: 结构化重载结果。
"""
candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs)
failed_plugins: Dict[str, str] = {}
if plugin_id in duplicate_candidates:
conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id])
return ReloadPluginResultPayload(
success=False,
requested_plugin_id=plugin_id,
failed_plugins={plugin_id: f"检测到重复插件 ID: {conflict_paths}"},
)
loaded_plugin_ids = set(self._loader.list_plugins())
plugin_is_loaded = plugin_id in loaded_plugin_ids
plugin_has_candidate = plugin_id in candidates
if not plugin_is_loaded and not plugin_has_candidate:
return ReloadPluginResultPayload(
success=False,
requested_plugin_id=plugin_id,
failed_plugins={plugin_id: "插件不存在或未找到合法的 manifest/plugin.py"},
)
target_plugin_ids: Set[str] = {plugin_id}
if plugin_is_loaded:
target_plugin_ids = self._collect_reverse_dependents(plugin_id)
unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids)
unloaded_plugins: List[str] = []
retained_plugin_ids = loaded_plugin_ids - set(unload_order)
for unload_plugin_id in unload_order:
meta = self._loader.get_plugin(unload_plugin_id)
if meta is None:
continue
await self._unload_plugin(meta, reason=reason)
unloaded_plugins.append(unload_plugin_id)
reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {}
for target_plugin_id in target_plugin_ids:
candidate = candidates.get(target_plugin_id)
if candidate is None:
failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态"
continue
reload_candidates[target_plugin_id] = candidate
load_order, dependency_failures = self._loader.resolve_dependencies(
reload_candidates,
extra_available=retained_plugin_ids,
)
failed_plugins.update(dependency_failures)
available_plugins = set(retained_plugin_ids)
reloaded_plugins: List[str] = []
for load_plugin_id in load_order:
if load_plugin_id in failed_plugins:
continue
candidate = reload_candidates.get(load_plugin_id)
if candidate is None:
continue
_, manifest, _ = candidate
dependencies = PluginMeta._extract_dependencies(manifest)
missing_dependencies = [dependency for dependency in dependencies if dependency not in available_plugins]
if missing_dependencies:
failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(missing_dependencies)}"
continue
meta = self._loader.load_candidate(load_plugin_id, candidate)
if meta is None:
failed_plugins[load_plugin_id] = "插件模块加载失败"
continue
activated = await self._activate_plugin(meta)
if not activated:
failed_plugins[load_plugin_id] = "插件初始化失败"
continue
available_plugins.add(load_plugin_id)
reloaded_plugins.append(load_plugin_id)
requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins
return ReloadPluginResultPayload(
success=requested_plugin_success,
requested_plugin_id=plugin_id,
reloaded_plugins=reloaded_plugins,
unloaded_plugins=unloaded_plugins,
failed_plugins=failed_plugins,
)
async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None:
"""通知 Host 当前 generation 已完成插件初始化。"""
"""通知 Host 当前 Runner 已完成插件初始化。
Args:
loaded_plugins: 成功初始化的插件列表。
failed_plugins: 初始化失败的插件列表。
"""
payload = RunnerReadyPayload(
loaded_plugins=loaded_plugins,
failed_plugins=failed_plugins,
@@ -487,6 +762,61 @@ class PluginRunner:
logger.error(f"插件 {plugin_id} event_handler {component_name} 执行异常: {e}", exc_info=True)
return envelope.make_response(payload={"success": False, "continue_processing": True})
async def _handle_hook_invoke(self, envelope: Envelope) -> Envelope:
"""处理 HookHandler 调用请求。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: 标准化后的 Hook 调用结果。
"""
try:
invoke = InvokePayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id)
if meta is None:
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
f"插件 {plugin_id} 未加载",
)
instance = meta.instance
component_name = invoke.component_name
handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None)
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"插件 {plugin_id} 无组件: {component_name}",
)
try:
raw = (
await handler_method(**invoke.args)
if inspect.iscoroutinefunction(handler_method)
else handler_method(**invoke.args)
)
except Exception as exc:
logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True)
return envelope.make_response(payload={"success": False, "continue_processing": True})
if raw is None:
result = {"success": True, "continue_processing": True}
elif isinstance(raw, dict):
result = {
"success": True,
"continue_processing": raw.get("continue_processing", True),
"modified_kwargs": raw.get("modified_kwargs"),
"custom_result": raw.get("custom_result"),
}
else:
result = {"success": True, "continue_processing": True, "custom_result": raw}
return envelope.make_response(payload=result)
async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
"""处理 WorkflowStep 调用请求
@@ -557,15 +887,10 @@ class PluginRunner:
async def _handle_shutdown(self, envelope: Envelope) -> Envelope:
"""处理关停 — 调用所有插件的 on_unload 后退出"""
logger.info("收到 shutdown 信号,开始调用 on_unload")
for plugin_id in self._loader.list_plugins():
for plugin_id in list(self._loader.list_plugins()):
meta = self._loader.get_plugin(plugin_id)
if meta and hasattr(meta.instance, "on_unload"):
try:
ret = meta.instance.on_unload()
if asyncio.iscoroutine(ret):
await ret
except Exception as e:
logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True)
if meta is not None:
await self._unload_plugin(meta, reason="runner_shutdown")
self._shutting_down = True
return envelope.make_response(payload={"acknowledged": True})
@@ -587,6 +912,30 @@ class PluginRunner:
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
return envelope.make_response(payload={"acknowledged": True})
async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope:
"""处理按插件 ID 的精确重载请求。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: 结构化重载结果。
"""
try:
payload = ReloadPluginPayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
if self._reload_lock.locked():
return envelope.make_error_response(
ErrorCode.E_RELOAD_IN_PROGRESS.value,
f"插件 {payload.plugin_id} 重载请求被拒绝:已有重载任务正在执行",
)
async with self._reload_lock:
result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason)
return envelope.make_response(payload=result.model_dump())
def request_capability(self) -> RPCClient:
"""获取 RPC 客户端(供 SDK 使用,发起能力调用)"""
return self._rpc_client
@@ -652,13 +1001,16 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
_ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common")
def find_module(self, fullname, path=None):
def find_module(self, fullname: str, path: Any = None) -> Any:
"""决定是否拦截指定模块导入。"""
return self if self._should_block(fullname) else None
def load_module(self, fullname):
def load_module(self, fullname: str) -> None:
"""阻止被拦截模块继续导入。"""
raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}")
def _should_block(self, fullname: str) -> bool:
"""判断给定模块名是否应被阻止导入。"""
# 放行非 src.* 的导入、以及 "src" 本身
if not fullname.startswith("src.") or fullname == "src":
return False
@@ -692,6 +1044,7 @@ async def _async_main() -> None:
# 注册信号处理
def _mark_runner_shutting_down() -> None:
"""标记 Runner 即将进入关停流程。"""
runner._shutting_down = True
_install_shutdown_signal_handlers(_mark_runner_shutting_down)