diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index aa7ceb46..4223525f 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -174,7 +174,10 @@ class RuntimeComponentCapabilityMixin: if registered_supervisor is not None: try: - reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}") + reloaded = await registered_supervisor.reload_plugins( + plugin_ids=[plugin_name], + reason=f"load {plugin_name}", + ) if reloaded: return {"success": True, "count": 1} return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} @@ -186,7 +189,10 @@ class RuntimeComponentCapabilityMixin: for pdir in sv._plugin_dirs: if (pdir / plugin_name).is_dir(): try: - reloaded = await sv.reload_plugins(reason=f"load {plugin_name}") + reloaded = await sv.reload_plugins( + plugin_ids=[plugin_name], + reason=f"load {plugin_name}", + ) if reloaded: return {"success": True, "count": 1} return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} @@ -222,7 +228,10 @@ class RuntimeComponentCapabilityMixin: if sv is not None: try: - reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}") + reloaded = await sv.reload_plugins( + plugin_ids=[plugin_name], + reason=f"reload {plugin_name}", + ) if reloaded: return {"success": True} return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"} diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py index abce97dc..ead5876a 100644 --- a/src/plugin_runtime/capabilities/registry.py +++ b/src/plugin_runtime/capabilities/registry.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from src.common.logger import get_logger from src.plugin_runtime.host.supervisor import PluginSupervisor @@ -12,67 +12,78 @@ logger = get_logger("plugin_runtime.integration") def register_capability_impls(manager: "PluginRuntimeManager", supervisor: PluginSupervisor) -> None: """向指定 Supervisor 注册主程序提供的能力实现。""" cap_service = supervisor.capability_service + rpc_server = supervisor.rpc_server - cap_service.register_capability("send.text", manager._cap_send_text) - cap_service.register_capability("send.emoji", manager._cap_send_emoji) - cap_service.register_capability("send.image", manager._cap_send_image) - cap_service.register_capability("send.command", manager._cap_send_command) - cap_service.register_capability("send.custom", manager._cap_send_custom) + def _register(name: str, impl: Any) -> None: + """注册单个能力实现及其 RPC 入口。 - cap_service.register_capability("llm.generate", manager._cap_llm_generate) - cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools) - cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models) + Args: + name: 能力名称。 + impl: 能力实现函数。 + """ + cap_service.register_capability(name, impl) + rpc_server.register_method(name, cap_service.handle_capability_request) - cap_service.register_capability("config.get", manager._cap_config_get) - cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin) - cap_service.register_capability("config.get_all", manager._cap_config_get_all) + _register("send.text", manager._cap_send_text) + _register("send.emoji", manager._cap_send_emoji) + _register("send.image", manager._cap_send_image) + _register("send.command", manager._cap_send_command) + _register("send.custom", manager._cap_send_custom) - cap_service.register_capability("database.query", manager._cap_database_query) - cap_service.register_capability("database.save", manager._cap_database_save) - cap_service.register_capability("database.get", manager._cap_database_get) - cap_service.register_capability("database.delete", manager._cap_database_delete) - cap_service.register_capability("database.count", manager._cap_database_count) + _register("llm.generate", manager._cap_llm_generate) + _register("llm.generate_with_tools", manager._cap_llm_generate_with_tools) + _register("llm.get_available_models", manager._cap_llm_get_available_models) - cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams) - cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams) - cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams) - cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id) - cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id) + _register("config.get", manager._cap_config_get) + _register("config.get_plugin", manager._cap_config_get_plugin) + _register("config.get_all", manager._cap_config_get_all) - cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time) - cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat) - cap_service.register_capability("message.get_recent", manager._cap_message_get_recent) - cap_service.register_capability("message.count_new", manager._cap_message_count_new) - cap_service.register_capability("message.build_readable", manager._cap_message_build_readable) + _register("database.query", manager._cap_database_query) + _register("database.save", manager._cap_database_save) + _register("database.get", manager._cap_database_get) + _register("database.delete", manager._cap_database_delete) + _register("database.count", manager._cap_database_count) - cap_service.register_capability("person.get_id", manager._cap_person_get_id) - cap_service.register_capability("person.get_value", manager._cap_person_get_value) - cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name) + _register("chat.get_all_streams", manager._cap_chat_get_all_streams) + _register("chat.get_group_streams", manager._cap_chat_get_group_streams) + _register("chat.get_private_streams", manager._cap_chat_get_private_streams) + _register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id) + _register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id) - cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description) - cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random) - cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count) - cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions) - cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all) - cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info) - cap_service.register_capability("emoji.register", manager._cap_emoji_register) - cap_service.register_capability("emoji.delete", manager._cap_emoji_delete) + _register("message.get_by_time", manager._cap_message_get_by_time) + _register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat) + _register("message.get_recent", manager._cap_message_get_recent) + _register("message.count_new", manager._cap_message_count_new) + _register("message.build_readable", manager._cap_message_build_readable) - cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value) - cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust) - cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust) + _register("person.get_id", manager._cap_person_get_id) + _register("person.get_value", manager._cap_person_get_value) + _register("person.get_id_by_name", manager._cap_person_get_id_by_name) - cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions) + _register("emoji.get_by_description", manager._cap_emoji_get_by_description) + _register("emoji.get_random", manager._cap_emoji_get_random) + _register("emoji.get_count", manager._cap_emoji_get_count) + _register("emoji.get_emotions", manager._cap_emoji_get_emotions) + _register("emoji.get_all", manager._cap_emoji_get_all) + _register("emoji.get_info", manager._cap_emoji_get_info) + _register("emoji.register", manager._cap_emoji_register) + _register("emoji.delete", manager._cap_emoji_delete) - cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins) - cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info) - cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins) - cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins) - cap_service.register_capability("component.enable", manager._cap_component_enable) - cap_service.register_capability("component.disable", manager._cap_component_disable) - cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin) - cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin) - cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin) + _register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value) + _register("frequency.set_adjust", manager._cap_frequency_set_adjust) + _register("frequency.get_adjust", manager._cap_frequency_get_adjust) - cap_service.register_capability("knowledge.search", manager._cap_knowledge_search) + _register("tool.get_definitions", manager._cap_tool_get_definitions) + + _register("component.get_all_plugins", manager._cap_component_get_all_plugins) + _register("component.get_plugin_info", manager._cap_component_get_plugin_info) + _register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins) + _register("component.list_registered_plugins", manager._cap_component_list_registered_plugins) + _register("component.enable", manager._cap_component_enable) + _register("component.disable", manager._cap_component_disable) + _register("component.load_plugin", manager._cap_component_load_plugin) + _register("component.unload_plugin", manager._cap_component_unload_plugin) + _register("component.reload_plugin", manager._cap_component_reload_plugin) + + _register("knowledge.search", manager._cap_knowledge_search) logger.debug("已注册全部主程序能力实现") diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index 98366a07..761b20ca 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -30,6 +30,11 @@ class CapabilityService: """ def __init__(self, authorization: "AuthorizationManager") -> None: + """初始化能力服务。 + + Args: + authorization: 能力授权管理器。 + """ self._authorization = authorization # capability_name -> implementation self._implementations: Dict[str, CapabilityImpl] = {} @@ -51,13 +56,19 @@ class CapabilityService: 校验权限后调用对应实现。 """ plugin_id = envelope.plugin_id + payload = envelope.payload if isinstance(envelope.payload, dict) else {} try: - req = CapabilityRequestPayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 格式错误: {e}") + req = CapabilityRequestPayload.model_validate(payload) + capability = req.capability + args = req.args + except Exception: + capability = envelope.method + raw_args = payload.get("args", payload) + args = raw_args if isinstance(raw_args, dict) else {} - capability = req.capability + if not capability: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, "能力调用缺少 capability") # 1. 权限校验 allowed, reason = self._authorization.check_capability(plugin_id, capability) @@ -71,7 +82,7 @@ class CapabilityService: # 3. 执行 try: - result = await impl(plugin_id, capability, req.args) + result = await impl(plugin_id, capability, args) resp_payload = CapabilityResponsePayload(success=True, result=result) return envelope.make_response(payload=resp_payload.model_dump()) except RPCError as e: diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 82a5970b..5ae3bdee 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -1,32 +1,38 @@ from pathlib import Path -from typing import Optional, List, Dict, Any, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import asyncio - +import contextlib +import os +import sys from src.common.logger import get_logger from src.config.config import global_config -from src.plugin_runtime.transport.factory import create_transport_server +from src.plugin_runtime import ENV_HOST_VERSION, ENV_IPC_ADDRESS, ENV_PLUGIN_DIRS, ENV_SESSION_TOKEN from src.plugin_runtime.protocol.envelope import ( BootstrapPluginPayload, ConfigUpdatedPayload, Envelope, HealthPayload, - LogBatchPayload, + PROTOCOL_VERSION, RegisterPluginPayload, + ReloadPluginResultPayload, RunnerReadyPayload, ShutdownPayload, + UnregisterPluginPayload, ) +from src.plugin_runtime.protocol.codec import MsgPackCodec +from src.plugin_runtime.protocol.errors import ErrorCode, RPCError +from src.plugin_runtime.transport.factory import create_transport_server from .authorization import AuthorizationManager from .capability_service import CapabilityService -from .rpc_server import RPCServer -from .logger_bridge import RunnerLogBridge from .component_registry import ComponentRegistry from .event_dispatcher import EventDispatcher from .hook_dispatcher import HookDispatcher +from .logger_bridge import RunnerLogBridge from .message_gateway import MessageGateway -from .message_utils import PluginMessageUtils +from .rpc_server import RPCServer if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage @@ -35,7 +41,11 @@ logger = get_logger("plugin_runtime.host.runner_manager") class PluginRunnerSupervisor: - """插件的Runner管理器,负责管理Runner的生命周期""" + """插件 Runner 监督器。 + + 负责 Host 侧与单个 Runner 子进程之间的生命周期、内部 RPC、 + 健康检查和插件级重载协调。 + """ def __init__( self, @@ -44,13 +54,24 @@ class PluginRunnerSupervisor: health_check_interval_sec: Optional[float] = None, max_restart_attempts: Optional[int] = None, runner_spawn_timeout_sec: Optional[float] = None, - ): - _cfg = global_config.plugin_runtime - self._plugin_dirs: List[Path] = plugin_dirs or [] - self._health_interval = health_check_interval_sec or _cfg.health_check_interval_sec or 30.0 - self._runner_spawn_timeout = runner_spawn_timeout_sec or _cfg.runner_spawn_timeout_sec or 30.0 + ) -> None: + """初始化 Supervisor。 + + Args: + plugin_dirs: 由当前 Runner 负责加载的插件目录列表。 + socket_path: 自定义 IPC 地址;留空时由传输层自动生成。 + health_check_interval_sec: 健康检查间隔,单位秒。 + max_restart_attempts: 自动重启 Runner 的最大次数。 + runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。 + """ + runtime_config = global_config.plugin_runtime + self._plugin_dirs: List[Path] = plugin_dirs or [] + self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0 + self._runner_spawn_timeout: float = ( + runner_spawn_timeout_sec or runtime_config.runner_spawn_timeout_sec or 30.0 + ) + self._max_restart_attempts: int = max_restart_attempts or runtime_config.max_restart_attempts or 3 - # 基础设施 self._transport = create_transport_server(socket_path=socket_path) self._authorization = AuthorizationManager() self._capability_service = CapabilityService(self._authorization) @@ -58,61 +79,55 @@ class PluginRunnerSupervisor: self._event_dispatcher = EventDispatcher(self._component_registry) self._hook_dispatcher = HookDispatcher(self._component_registry) self._message_gateway = MessageGateway(self._component_registry) - - # 编解码和服务器 - from src.plugin_runtime.protocol.codec import MsgPackCodec + self._log_bridge = RunnerLogBridge() codec = MsgPackCodec() self._rpc_server = RPCServer(transport=self._transport, codec=codec) - # Runner 子进程 self._runner_process: Optional[asyncio.subprocess.Process] = None - self._max_restart_attempts: int = max_restart_attempts or _cfg.max_restart_attempts or 3 - self._restart_count: int = 0 - - # 已注册的插件组件信息 self._registered_plugins: Dict[str, RegisterPluginPayload] = {} self._runner_ready_events: asyncio.Event = asyncio.Event() self._runner_ready_payloads: RunnerReadyPayload = RunnerReadyPayload() + self._health_task: Optional[asyncio.Task[None]] = None + self._stderr_drain_task: Optional[asyncio.Task[None]] = None + self._restart_count: int = 0 + self._running: bool = False - # 后台任务 - self._health_task: Optional[asyncio.Task] = None - # Runner stderr 流排空任务(仅保留 stderr,用于 IPC 建立前的启动日志倒空、致命错误输出等场景) - self._stderr_drain_task: Optional[asyncio.Task] = None - self._running = False - - # Runner 日志桥(将 Runner 上报的批量日志重放到主进程 Logger) - self._log_bridge: RunnerLogBridge = RunnerLogBridge() - - # 注册内部 RPC 方法 - self._register_internal_methods() # TODO: 完成内部方法注册 + self._register_internal_methods() @property def authorization_manager(self) -> AuthorizationManager: + """返回授权管理器。""" return self._authorization @property def capability_service(self) -> CapabilityService: + """返回能力服务。""" return self._capability_service @property def component_registry(self) -> ComponentRegistry: + """返回组件注册表。""" return self._component_registry @property def event_dispatcher(self) -> EventDispatcher: + """返回事件分发器。""" return self._event_dispatcher @property def hook_dispatcher(self) -> HookDispatcher: + """返回 Hook 分发器。""" return self._hook_dispatcher @property def message_gateway(self) -> MessageGateway: + """返回消息网关。""" return self._message_gateway @property def rpc_server(self) -> RPCServer: + """返回底层 RPC 服务端。""" return self._rpc_server async def dispatch_event( @@ -121,11 +136,28 @@ class PluginRunnerSupervisor: message: Optional["SessionMessage"] = None, extra_args: Optional[Dict[str, Any]] = None, ) -> Tuple[bool, Optional["SessionMessage"]]: - """分发事件到所有对应 handler 的快捷方法。""" + """分发事件到已注册的事件处理器。 + + Args: + event_type: 事件类型。 + message: 可选的消息对象。 + extra_args: 附加参数。 + + Returns: + Tuple[bool, Optional[SessionMessage]]: 是否继续处理,以及插件可能修改后的消息。 + """ return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args) - async def dispatch_hook(self, stage: str, **kwargs): - """分发Hook事件到所有对应 handler 的快捷方法。""" + async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]: + """分发 Hook 到已注册的 Hook 处理器。 + + Args: + stage: Hook 阶段名称。 + **kwargs: 传递给 Hook 的关键字参数。 + + Returns: + Dict[str, Any]: 经 Hook 修改后的参数字典。 + """ return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs) async def send_message_to_external( @@ -135,60 +167,68 @@ class PluginRunnerSupervisor: enabled_only: bool = True, save_to_db: bool = True, ) -> bool: - """发送系统内部消息到外部平台的快捷方法。""" + """通过插件消息网关发送外部消息。 + + Args: + internal_message: 系统内部消息对象。 + enabled_only: 是否仅使用启用的网关组件。 + save_to_db: 发送成功后是否写入数据库。 + + Returns: + bool: 是否发送成功。 + """ return await self._message_gateway.send_message_to_external( - internal_message, self, enabled_only=enabled_only, save_to_db=save_to_db + internal_message, + self, + enabled_only=enabled_only, + save_to_db=save_to_db, ) async def start(self) -> None: - """启动 Supervisor + """启动 Supervisor。""" + if self._running: + logger.warning("PluginRunnerSupervisor 已在运行,跳过重复启动") + return - 1. 启动 RPC Server - 2. 拉起 Runner 子进程 - 3. 启动健康检查 - """ self._running = True + self._restart_count = 0 + self._clear_runner_state() - # 启动 RPC Server await self._rpc_server.start() - # 拉起 Runner 进程 await self._spawn_runner() - # 等待 Runner 完成连接和初始化,避免 start() 返回时 Runner 尚未就绪 try: await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) except TimeoutError: if not self._rpc_server.is_connected: - logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成连接,后续操作可能失败") + logger.warning("Runner 未在限定时间内完成连接,后续操作可能失败") else: - logger.warning(f"Runner 未在 {self._runner_spawn_timeout}s 内完成初始化,后续操作可能失败") + logger.warning("Runner 未在限定时间内完成初始化,后续操作可能失败") - # 启动健康检查 - self._health_task = asyncio.create_task(self._health_check_loop()) - - logger.info("PluginSupervisor 已启动") + self._health_task = asyncio.create_task(self._health_check_loop(), name="PluginRunnerSupervisor.health") + logger.info("PluginRunnerSupervisor 已启动") async def stop(self) -> None: - """停止 Supervisor""" + """停止 Supervisor。""" + if not self._running: + return + self._running = False - # 停止组件 - await self._event_dispatcher.stop() - await self._hook_dispatcher.stop() - - # 停止健康检查 - if self._health_task: + if self._health_task is not None: self._health_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._health_task self._health_task = None - # 优雅关停 Runner - await self._shutdown_runner() - - # 停止 RPC Server + await self._event_dispatcher.stop() + await self._hook_dispatcher.stop() + await self._shutdown_runner(reason="host_stop") await self._rpc_server.stop() + self._clear_runner_state() - logger.info("PluginSupervisor 已停止") + logger.info("PluginRunnerSupervisor 已停止") async def invoke_plugin( self, @@ -198,9 +238,17 @@ class PluginRunnerSupervisor: args: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, ) -> Envelope: - """调用插件组件 + """调用 Runner 内的插件组件。 - 由主进程业务逻辑调用,通过 RPC 转发给 Runner。 + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + component_name: 组件名。 + args: 调用参数。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: RPC 响应信封。 """ return await self._rpc_server.send_request( method, @@ -210,27 +258,421 @@ class PluginRunnerSupervisor: ) async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool: - raise NotImplementedError("等待SDK完成") # TODO: 完成对应的调用和请求逻辑 + """按插件 ID 触发精确重载。 + + Args: + plugin_id: 目标插件 ID。 + reason: 重载原因。 + + Returns: + bool: 是否重载成功。 + """ + try: + response = await self._rpc_server.send_request( + "plugin.reload", + plugin_id=plugin_id, + payload={"plugin_id": plugin_id, "reason": reason}, + timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000), + ) + except Exception as exc: + logger.error(f"插件 {plugin_id} 重载请求失败: {exc}") + return False + + result = ReloadPluginResultPayload.model_validate(response.payload) + if not result.success: + logger.warning(f"插件 {plugin_id} 重载失败: {result.failed_plugins}") + return result.success + + async def reload_plugins( + self, + plugin_ids: Optional[List[str]] = None, + reason: str = "manual", + ) -> bool: + """批量重载插件。 + + Args: + plugin_ids: 目标插件 ID 列表;为空时重载当前已注册的全部插件。 + reason: 重载原因。 + + Returns: + bool: 是否全部重载成功。 + """ + target_plugin_ids = plugin_ids or list(self._registered_plugins.keys()) + ordered_plugin_ids = list(dict.fromkeys(target_plugin_ids)) + success = True + + for plugin_id in ordered_plugin_ids: + reloaded = await self.reload_plugin(plugin_id=plugin_id, reason=reason) + success = success and reloaded + + return success + + async def notify_plugin_config_updated( + self, + plugin_id: str, + config_data: Optional[Dict[str, Any]] = None, + config_version: str = "", + ) -> bool: + """向 Runner 推送插件配置更新。 + + Args: + plugin_id: 目标插件 ID。 + config_data: 配置内容。 + config_version: 配置版本号。 + + Returns: + bool: 请求是否成功送达并被 Runner 接受。 + """ + payload = ConfigUpdatedPayload( + plugin_id=plugin_id, + config_version=config_version, + config_data=config_data or {}, + ) + try: + response = await self._rpc_server.send_request( + "plugin.config_updated", + plugin_id=plugin_id, + payload=payload.model_dump(), + timeout_ms=10000, + ) + except Exception as exc: + logger.warning(f"插件 {plugin_id} 配置更新通知失败: {exc}") + return False + + return bool(response.payload.get("acknowledged", False)) async def _wait_for_runner_connection(self, timeout_sec: float) -> None: - """等待 Runner 连接上 RPC Server""" + """等待 Runner 建立 RPC 连接。 - async def wait_for_connection(): + Args: + timeout_sec: 超时时间,单位秒。 + + Raises: + TimeoutError: 在超时时间内 Runner 未完成连接。 + """ + + async def wait_for_connection() -> None: + """轮询等待 RPC 连接建立。""" while self._running and not self._rpc_server.is_connected: await asyncio.sleep(0.1) try: await asyncio.wait_for(wait_for_connection(), timeout=timeout_sec) logger.info("Runner 已连接到 RPC Server") - except asyncio.TimeoutError as e: - raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from e + except asyncio.TimeoutError as exc: + raise TimeoutError(f"等待 Runner 连接超时({timeout_sec}s)") from exc async def _wait_for_runner_ready(self, timeout_sec: float = 30.0) -> RunnerReadyPayload: - """等待 Runner 完成初始化并上报就绪""" + """等待 Runner 完成启动初始化。 + Args: + timeout_sec: 超时时间,单位秒。 + + Returns: + RunnerReadyPayload: Runner 上报的就绪信息。 + + Raises: + TimeoutError: 在超时时间内 Runner 未完成初始化。 + """ try: await asyncio.wait_for(self._runner_ready_events.wait(), timeout=timeout_sec) logger.info("Runner 已完成初始化并上报就绪") return self._runner_ready_payloads - except asyncio.TimeoutError as e: - raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from e + except asyncio.TimeoutError as exc: + raise TimeoutError(f"等待 Runner 就绪超时({timeout_sec}s)") from exc + + def _register_internal_methods(self) -> None: + """注册 Host 侧内部 RPC 方法。""" + self._rpc_server.register_method("cap.call", self._capability_service.handle_capability_request) + self._rpc_server.register_method("plugin.bootstrap", self._handle_bootstrap_plugin) + self._rpc_server.register_method("plugin.register_components", self._handle_register_plugin) + self._rpc_server.register_method("plugin.register_plugin", self._handle_register_plugin) + self._rpc_server.register_method("plugin.unregister", self._handle_unregister_plugin) + self._rpc_server.register_method("runner.log_batch", self._log_bridge.handle_log_batch) + self._rpc_server.register_method("runner.ready", self._handle_runner_ready) + + async def _handle_bootstrap_plugin(self, envelope: Envelope) -> Envelope: + """处理插件 bootstrap 请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = BootstrapPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + if payload.capabilities_required: + self._authorization.register_plugin(payload.plugin_id, payload.capabilities_required) + else: + self._authorization.revoke_permission_token(payload.plugin_id) + + return envelope.make_response(payload={"accepted": True, "plugin_id": payload.plugin_id}) + + async def _handle_register_plugin(self, envelope: Envelope) -> Envelope: + """处理插件组件注册请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = RegisterPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + self._component_registry.remove_components_by_plugin(payload.plugin_id) + registered_count = self._component_registry.register_plugin_components( + payload.plugin_id, + [component.model_dump() for component in payload.components], + ) + self._registered_plugins[payload.plugin_id] = payload + + return envelope.make_response( + payload={ + "accepted": True, + "plugin_id": payload.plugin_id, + "registered_components": registered_count, + } + ) + + async def _handle_unregister_plugin(self, envelope: Envelope) -> Envelope: + """处理插件注销请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = UnregisterPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) + self._authorization.revoke_permission_token(payload.plugin_id) + removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None + + return envelope.make_response( + payload={ + "accepted": True, + "plugin_id": payload.plugin_id, + "reason": payload.reason, + "removed_components": removed_components, + "removed_registration": removed_registration, + } + ) + + async def _handle_runner_ready(self, envelope: Envelope) -> Envelope: + """处理 Runner 就绪通知。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: RPC 响应信封。 + """ + try: + payload = RunnerReadyPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + self._runner_ready_payloads = payload + self._runner_ready_events.set() + return envelope.make_response(payload={"accepted": True}) + + def _build_runner_environment(self) -> Dict[str, str]: + """构建拉起 Runner 所需的环境变量。 + + Returns: + Dict[str, str]: 传递给 Runner 进程的环境变量映射。 + """ + return { + ENV_HOST_VERSION: PROTOCOL_VERSION, + ENV_IPC_ADDRESS: self._transport.get_address(), + ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs), + ENV_SESSION_TOKEN: self._rpc_server.session_token, + } + + async def _spawn_runner(self) -> None: + """拉起 Runner 子进程。""" + if self._runner_process is not None and self._runner_process.returncode is None: + logger.warning("Runner 已在运行,跳过重复拉起") + return + + self._clear_runner_state() + + env = os.environ.copy() + env.update(self._build_runner_environment()) + + self._runner_process = await asyncio.create_subprocess_exec( + sys.executable, + "-m", + "src.plugin_runtime.runner.runner_main", + env=env, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.PIPE, + ) + + if self._runner_process.stderr is not None: + self._stderr_drain_task = asyncio.create_task( + self._drain_runner_stderr(self._runner_process.stderr), + name="PluginRunnerSupervisor.stderr", + ) + + logger.info(f"Runner 已拉起,pid={self._runner_process.pid}") + + async def _drain_runner_stderr(self, stream: asyncio.StreamReader) -> None: + """持续排空 Runner 的 stderr。 + + Args: + stream: Runner 的 stderr 流。 + """ + try: + while True: + line = await stream.readline() + if not line: + return + message = line.decode("utf-8", errors="replace").rstrip() + if message: + logger.warning(f"[runner-stderr] {message}") + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"排空 Runner stderr 失败: {exc}") + + async def _shutdown_runner(self, reason: str = "normal") -> None: + """优雅关闭 Runner 子进程。 + + Args: + reason: 关停原因。 + """ + process = self._runner_process + if process is None: + return + + payload = ShutdownPayload(reason=reason) + + if process.returncode is None and self._rpc_server.is_connected: + with contextlib.suppress(Exception): + await self._rpc_server.send_request( + "plugin.prepare_shutdown", + payload=payload.model_dump(), + timeout_ms=payload.drain_timeout_ms, + ) + with contextlib.suppress(Exception): + await self._rpc_server.send_request( + "plugin.shutdown", + payload=payload.model_dump(), + timeout_ms=payload.drain_timeout_ms, + ) + + if process.returncode is None: + try: + await asyncio.wait_for(process.wait(), timeout=max(payload.drain_timeout_ms / 1000.0, 1.0)) + except asyncio.TimeoutError: + logger.warning("Runner 优雅退出超时,尝试 terminate") + process.terminate() + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning("Runner terminate 超时,尝试 kill") + process.kill() + with contextlib.suppress(Exception): + await asyncio.wait_for(process.wait(), timeout=5.0) + + self._runner_process = None + + if self._stderr_drain_task is not None: + self._stderr_drain_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._stderr_drain_task + self._stderr_drain_task = None + + self._clear_runner_state() + + async def _health_check_loop(self) -> None: + """周期性检查 Runner 健康状态,并在必要时重启。""" + timeout_ms = max(int(self._health_interval * 1000), 1000) + + while self._running: + try: + await asyncio.sleep(self._health_interval) + except asyncio.CancelledError: + return + + if not self._running: + return + + process = self._runner_process + if process is None or process.returncode is not None: + reason = "runner_process_exited" if process is not None else "runner_process_missing" + restarted = await self._restart_runner(reason=reason) + if not restarted: + return + continue + + try: + response = await self._rpc_server.send_request("plugin.health", timeout_ms=timeout_ms) + health = HealthPayload.model_validate(response.payload) + if not health.healthy: + restarted = await self._restart_runner(reason="health_check_unhealthy") + if not restarted: + return + except asyncio.CancelledError: + return + except (RPCError, Exception) as exc: + logger.warning(f"Runner 健康检查失败: {exc}") + restarted = await self._restart_runner(reason="health_check_failed") + if not restarted: + return + + async def _restart_runner(self, reason: str) -> bool: + """在 Runner 异常时执行整进程级重启。 + + Args: + reason: 触发重启的原因。 + + Returns: + bool: 是否重启成功。 + """ + if not self._running: + return False + + if self._restart_count >= self._max_restart_attempts: + logger.error(f"Runner 自动重启次数已达上限,停止重启。reason={reason}") + return False + + self._restart_count += 1 + logger.warning(f"准备重启 Runner,第 {self._restart_count} 次,reason={reason}") + + await self._shutdown_runner(reason=reason) + + try: + await self._spawn_runner() + await self._wait_for_runner_connection(timeout_sec=self._runner_spawn_timeout) + await self._wait_for_runner_ready(timeout_sec=self._runner_spawn_timeout) + except Exception as exc: + logger.error(f"Runner 重启失败: {exc}", exc_info=True) + return False + + self._restart_count = 0 + logger.info("Runner 已成功重启") + return True + + def _clear_runner_state(self) -> None: + """清理当前 Runner 对应的 Host 侧注册状态。""" + self._authorization.clear() + self._component_registry.clear() + self._registered_plugins.clear() + self._runner_ready_events = asyncio.Event() + self._runner_ready_payloads = RunnerReadyPayload() + + +PluginSupervisor = PluginRunnerSupervisor diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 04c8e324..730da3e1 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -23,8 +23,10 @@ from src.plugin_runtime.capabilities import ( RuntimeDataCapabilityMixin, ) from src.plugin_runtime.capabilities.registry import register_capability_impls +from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage from src.plugin_runtime.host.supervisor import PluginSupervisor logger = get_logger("plugin_runtime.integration") @@ -223,9 +225,9 @@ class PluginRuntimeManager( async def bridge_event( self, event_type_value: str, - message_dict: Optional[Dict[str, Any]] = None, + message_dict: Optional[MessageDict] = None, extra_args: Optional[Dict[str, Any]] = None, - ) -> Tuple[bool, Optional[Dict[str, Any]]]: + ) -> Tuple[bool, Optional[MessageDict]]: """将事件分发到所有 Supervisor Returns: @@ -235,17 +237,23 @@ class PluginRuntimeManager( return True, None new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value) - modified: Optional[Dict[str, Any]] = None + modified: Optional[MessageDict] = None + current_message: Optional["SessionMessage"] = ( + PluginMessageUtils._build_session_message_from_dict(dict(message_dict)) + if message_dict is not None + else None + ) for sv in self.supervisors: try: cont, mod = await sv.dispatch_event( event_type=new_event_type, - message=modified or message_dict, + message=current_message, extra_args=extra_args, ) if mod is not None: - modified = mod + current_message = mod + modified = PluginMessageUtils._session_message_to_dict(mod) if not cont: return False, modified except Exception as e: @@ -477,7 +485,7 @@ class PluginRuntimeManager( logger.error(f"检测到重复插件 ID,跳过本次插件热重载: {details}") return - reload_supervisors: List[Any] = [] + reload_supervisors: Dict[Any, List[str]] = {} changed_paths = [change.path.resolve() for change in changes] for supervisor in self.supervisors: @@ -485,11 +493,13 @@ class PluginRuntimeManager( plugin_id = self._match_plugin_id_for_supervisor(supervisor, path) if plugin_id is None: continue - if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors: - reload_supervisors.append(supervisor) + if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py": + reload_supervisors.setdefault(supervisor, []) + if plugin_id not in reload_supervisors[supervisor]: + reload_supervisors[supervisor].append(plugin_id) - for supervisor in reload_supervisors: - await supervisor.reload_plugins(reason="file_watcher") + for supervisor, plugin_ids in reload_supervisors.items(): + await supervisor.reload_plugins(plugin_ids=plugin_ids, reason="file_watcher") if reload_supervisors: self._refresh_plugin_config_watch_subscriptions() diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index e81df019..6f95f97f 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -144,7 +144,11 @@ class ComponentDeclaration(BaseModel): class RegisterPluginPayload(BaseModel): - """plugin.register_plugin 请求 payload""" + """插件组件注册请求载荷。 + + 该模型同时用于 ``plugin.register_components`` 与兼容旧命名的 + ``plugin.register_plugin`` 请求。 + """ plugin_id: str = Field(description="插件 ID") """插件 ID""" @@ -248,6 +252,39 @@ class ShutdownPayload(BaseModel): """排空超时 (ms)""" +class UnregisterPluginPayload(BaseModel): + """插件注销请求载荷。""" + + plugin_id: str = Field(description="插件 ID") + """插件 ID""" + reason: str = Field(default="manual", description="注销原因") + """注销原因""" + + +class ReloadPluginPayload(BaseModel): + """插件重载请求载荷。""" + + plugin_id: str = Field(description="目标插件 ID") + """目标插件 ID""" + reason: str = Field(default="manual", description="重载原因") + """重载原因""" + + +class ReloadPluginResultPayload(BaseModel): + """插件重载结果载荷。""" + + success: bool = Field(description="是否重载成功") + """是否重载成功""" + requested_plugin_id: str = Field(description="请求重载的插件 ID") + """请求重载的插件 ID""" + reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表") + """成功完成重载的插件列表""" + unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表") + """本次已卸载的插件列表""" + failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因") + """重载失败的插件及原因""" + + # ====== 日志传输 ====== diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 11ba45e7..90c8bf47 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -32,11 +32,22 @@ class PluginMeta: self, plugin_id: str, plugin_dir: str, + module_name: str, plugin_instance: Any, manifest: Dict[str, Any], ) -> None: + """初始化插件元数据。 + + Args: + plugin_id: 插件 ID。 + plugin_dir: 插件目录绝对路径。 + module_name: 插件入口模块名。 + plugin_instance: 插件实例对象。 + manifest: 解析后的 manifest 内容。 + """ self.plugin_id = plugin_id self.plugin_dir = plugin_dir + self.module_name = module_name self.instance = plugin_instance self.manifest = manifest self.version = manifest.get("version", "1.0.0") @@ -45,6 +56,14 @@ class PluginMeta: @staticmethod def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]: + """从 manifest 中提取依赖列表。 + + Args: + manifest: 插件 manifest。 + + Returns: + List[str]: 规范化后的依赖插件 ID 列表。 + """ raw = manifest.get("dependencies", []) result: List[str] = [] for dep in raw: @@ -66,19 +85,24 @@ class PluginLoader: """ def __init__(self, host_version: str = "") -> None: + """初始化插件加载器。 + + Args: + host_version: Host 版本号,用于 manifest 兼容性校验。 + """ self._loaded_plugins: Dict[str, PluginMeta] = {} self._failed_plugins: Dict[str, str] = {} self._manifest_validator = ManifestValidator(host_version=host_version) self._compat_hook_installed = False def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]: - """扫描多个目录并加载所有插件(含依赖排序和 manifest 校验) + """扫描多个目录并加载所有插件。 Args: - plugin_dirs: 插件目录列表 + plugin_dirs: 插件目录列表。 Returns: - 成功加载的插件元数据列表(按依赖顺序) + List[PluginMeta]: 成功加载的插件元数据列表,按依赖顺序排列。 """ candidates, duplicate_candidates = self._discover_candidates(plugin_dirs) self._record_duplicate_candidates(duplicate_candidates) @@ -90,6 +114,18 @@ class PluginLoader: # 第三阶段:按依赖顺序加载 return self._load_plugins_in_order(load_order, candidates) + def discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]: + """扫描插件目录并返回候选插件。 + + Args: + plugin_dirs: 需要扫描的插件根目录列表。 + + Returns: + Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]: + 候选插件映射和重复插件 ID 冲突映射。 + """ + return self._discover_candidates(plugin_dirs) + def _discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]: """扫描插件目录并收集候选插件。""" candidates: Dict[str, PluginCandidate] = {} @@ -170,7 +206,6 @@ class PluginLoader: plugin_dir, manifest, plugin_path = candidates[plugin_id] try: if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path): - self._loaded_plugins[meta.plugin_id] = meta results.append(meta) except Exception as e: self._failed_plugins[plugin_id] = str(e) @@ -182,22 +217,109 @@ class PluginLoader: """获取已加载的插件""" return self._loaded_plugins.get(plugin_id) + def set_loaded_plugin(self, meta: PluginMeta) -> None: + """登记一个已经完成初始化的插件。 + + Args: + meta: 待登记的插件元数据。 + """ + self._loaded_plugins[meta.plugin_id] = meta + + def remove_loaded_plugin(self, plugin_id: str) -> Optional[PluginMeta]: + """移除一个已加载插件的元数据。 + + Args: + plugin_id: 待移除的插件 ID。 + + Returns: + Optional[PluginMeta]: 被移除的插件元数据;不存在时返回 ``None``。 + """ + return self._loaded_plugins.pop(plugin_id, None) + + def purge_plugin_modules(self, plugin_id: str, plugin_dir: str) -> List[str]: + """清理指定插件目录下的模块缓存。 + + Args: + plugin_id: 插件 ID。 + plugin_dir: 插件目录绝对路径。 + + Returns: + List[str]: 已从 ``sys.modules`` 中移除的模块名列表。 + """ + removed_modules: List[str] = [] + plugin_path = Path(plugin_dir).resolve() + synthetic_module_name = f"_maibot_plugin_{plugin_id}" + + for module_name, module in list(sys.modules.items()): + if module_name == synthetic_module_name: + removed_modules.append(module_name) + sys.modules.pop(module_name, None) + continue + + module_file = getattr(module, "__file__", None) + if module_file is None: + continue + + try: + module_path = Path(module_file).resolve() + except Exception: + continue + + if module_path.is_relative_to(plugin_path): + removed_modules.append(module_name) + sys.modules.pop(module_name, None) + + importlib.invalidate_caches() + return removed_modules + def list_plugins(self) -> List[str]: """列出所有已加载的插件 ID""" return list(self._loaded_plugins.keys()) @property def failed_plugins(self) -> Dict[str, str]: + """返回当前记录的失败插件原因映射。""" return dict(self._failed_plugins) # ──── 依赖解析 ──────────────────────────────────────────── + def resolve_dependencies( + self, + candidates: Dict[str, PluginCandidate], + extra_available: Optional[Set[str]] = None, + ) -> Tuple[List[str], Dict[str, str]]: + """解析候选插件的依赖顺序。 + + Args: + candidates: 待加载的候选插件集合。 + extra_available: 视为已满足的外部依赖插件 ID 集合。 + + Returns: + Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。 + """ + return self._resolve_dependencies(candidates, extra_available=extra_available) + + def load_candidate(self, plugin_id: str, candidate: PluginCandidate) -> Optional[PluginMeta]: + """加载单个候选插件模块。 + + Args: + plugin_id: 插件 ID。 + candidate: 候选插件三元组。 + + Returns: + Optional[PluginMeta]: 加载成功的插件元数据;失败时返回 ``None``。 + """ + plugin_dir, manifest, plugin_path = candidate + return self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path) + def _resolve_dependencies( self, candidates: Dict[str, PluginCandidate], + extra_available: Optional[Set[str]] = None, ) -> Tuple[List[str], Dict[str, str]]: """拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。""" available = set(candidates.keys()) + satisfied_dependencies = set(extra_available or set()) dep_graph: Dict[str, Set[str]] = {} failed: Dict[str, str] = {} @@ -212,6 +334,8 @@ class PluginLoader: continue if dep_name in available: resolved.add(dep_name) + elif dep_name in satisfied_dependencies: + continue else: missing.append(dep_name) if missing: @@ -271,33 +395,39 @@ class PluginLoader: sys.modules[module_name] = module plugin_parent_dir = plugin_dir.parent - with self._temporary_sys_path_entry(plugin_parent_dir): - spec.loader.exec_module(module) + try: + with self._temporary_sys_path_entry(plugin_parent_dir): + spec.loader.exec_module(module) - # 优先使用新版 create_plugin 工厂函数 - create_plugin = getattr(module, "create_plugin", None) - if create_plugin is not None: - instance = create_plugin() - logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") - return PluginMeta( - plugin_id=plugin_id, - plugin_dir=str(plugin_dir), - plugin_instance=instance, - manifest=manifest, - ) + # 优先使用新版 create_plugin 工厂函数 + create_plugin = getattr(module, "create_plugin", None) + if create_plugin is not None: + instance = create_plugin() + logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") + return PluginMeta( + plugin_id=plugin_id, + plugin_dir=str(plugin_dir), + module_name=module_name, + plugin_instance=instance, + manifest=manifest, + ) - # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类 - instance = self._try_load_legacy_plugin(module, plugin_id) - if instance is not None: - logger.info( - f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" - ) - return PluginMeta( - plugin_id=plugin_id, - plugin_dir=str(plugin_dir), - plugin_instance=instance, - manifest=manifest, - ) + # 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类 + instance = self._try_load_legacy_plugin(module, plugin_id) + if instance is not None: + logger.info( + f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk)" + ) + return PluginMeta( + plugin_id=plugin_id, + plugin_dir=str(plugin_dir), + module_name=module_name, + plugin_instance=instance, + manifest=manifest, + ) + except Exception: + sys.modules.pop(module_name, None) + raise logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin") return None diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py index 6a1d59d5..dc917cc8 100644 --- a/src/plugin_runtime/runner/rpc_client.py +++ b/src/plugin_runtime/runner/rpc_client.py @@ -1,14 +1,6 @@ -"""Runner 端 RPC Client +"""Runner 端 RPC 客户端。""" -负责: -1. 连接 Host RPC Server -2. 发送握手(runner.hello) -3. 发送组件注册请求 -4. 接收并分发 Host 的调用请求 -5. 发送能力调用请求到 Host -""" - -from typing import Any, Awaitable, Callable, Dict, Optional, cast +from typing import Any, Awaitable, Callable, Dict, Optional, Set, cast import asyncio import contextlib @@ -29,12 +21,15 @@ from src.plugin_runtime.transport.factory import create_transport_client logger = get_logger("plugin_runtime.runner.rpc_client") -# RPC 方法处理器类型 MethodHandler = Callable[[Envelope], Awaitable[Envelope]] def _get_sdk_version() -> str: - """从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。""" + """读取 SDK 版本号。 + + Returns: + str: 已安装的 SDK 版本;读取失败时回退到 ``1.0.0``。 + """ try: from importlib.metadata import version @@ -47,73 +42,78 @@ SDK_VERSION = _get_sdk_version() class RPCClient: - """Runner 端 RPC 客户端 - - 管理与 Host 的 IPC 连接,支持双向 RPC 调用。 - """ + """Runner 端 RPC 客户端。""" def __init__( self, host_address: str, session_token: str, codec: Optional[Codec] = None, - ): - self._host_address = host_address - self._session_token = session_token - self._codec = codec or MsgPackCodec() + ) -> None: + """初始化 RPC 客户端。 + + Args: + host_address: Host 的 IPC 地址。 + session_token: 握手用会话令牌。 + codec: 可选的编解码器实现。 + """ + self._host_address: str = host_address + self._session_token: str = session_token + self._codec: Codec = codec or MsgPackCodec() self._id_gen = RequestIdGenerator() self._connection: Optional[Connection] = None - self._runner_id = str(uuid.uuid4()) - self._generation: int = 0 - - # 方法处理器注册表(Host 发来的调用) + self._runner_id: str = str(uuid.uuid4()) self._method_handlers: Dict[str, MethodHandler] = {} - - # 等待响应的 pending 请求: request_id -> Future - self._pending_requests: Dict[int, asyncio.Future] = {} - - # 运行状态 - self._running = False - self._recv_task: Optional[asyncio.Task] = None - self._background_tasks: set[asyncio.Task] = set() - - @property - def generation(self) -> int: - return self._generation + self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {} + self._running: bool = False + self._recv_task: Optional[asyncio.Task[None]] = None + self._background_tasks: Set[asyncio.Task[Any]] = set() @property def is_connected(self) -> bool: + """返回当前连接是否可用。""" return self._connection is not None and not self._connection.is_closed def register_method(self, method: str, handler: MethodHandler) -> None: - """注册方法处理器(处理 Host 发来的请求)""" + """注册 Host -> Runner 的 RPC 处理器。 + + Args: + method: RPC 方法名。 + handler: 方法处理函数。 + """ self._method_handlers[method] = handler def _require_connection(self) -> Connection: - """返回当前可用连接;若连接不可用则抛出 RPCError。""" + """返回当前可用连接。 + + Returns: + Connection: 当前连接对象。 + + Raises: + RPCError: 当前未连接到 Host。 + """ connection = self._connection if connection is None or connection.is_closed: raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host") return cast(Connection, connection) async def connect_and_handshake(self) -> bool: - """连接 Host 并完成握手 + """连接 Host 并完成握手。 Returns: - 是否握手成功 + bool: 是否握手成功。 """ client = create_transport_client(self._host_address) self._connection = await client.connect() connection = self._require_connection() - # 发送 runner.hello hello = HelloPayload( runner_id=self._runner_id, sdk_version=SDK_VERSION, session_token=self._session_token, ) - request_id = self._id_gen.next() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, message_type=MessageType.REQUEST, @@ -121,33 +121,27 @@ class RPCClient: payload=hello.model_dump(), ) - data = self._codec.encode_envelope(envelope) - await connection.send_frame(data) + await connection.send_frame(self._codec.encode_envelope(envelope)) - # 接收握手响应 resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0) - resp = self._codec.decode_envelope(resp_data) + response = self._codec.decode_envelope(resp_data) + resp_payload = HelloResponsePayload.model_validate(response.payload) - resp_payload = HelloResponsePayload.model_validate(resp.payload) if not resp_payload.accepted: logger.error(f"握手被拒绝: {resp_payload.reason}") - await self._connection.close() - self._connection = None + await self.disconnect() return False - self._generation = resp_payload.assigned_generation - logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}") - - # 启动消息接收循环 + logger.info(f"握手成功: host_version={resp_payload.host_version}") self._running = True - self._recv_task = asyncio.create_task(self._recv_loop()) - + self._recv_task = asyncio.create_task(self._recv_loop(), name="RPCClient.recv") return True async def disconnect(self) -> None: - """断开连接""" + """断开与 Host 的连接并清理状态。""" self._running = False - if self._recv_task: + + if self._recv_task is not None: self._recv_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._recv_task @@ -160,13 +154,12 @@ class RPCClient: await asyncio.gather(*self._background_tasks, return_exceptions=True) self._background_tasks.clear() - # 取消所有 pending 请求 for future in self._pending_requests.values(): if not future.done(): future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭")) self._pending_requests.clear() - if self._connection: + if self._connection is not None: await self._connection.close() self._connection = None @@ -177,16 +170,27 @@ class RPCClient: payload: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, ) -> Envelope: - """向 Host 发送 RPC 请求并等待响应""" - connection = self._require_connection() + """向 Host 发送 RPC 请求并等待响应。 - request_id = self._id_gen.next() + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + payload: 请求载荷。 + timeout_ms: 超时时间,单位毫秒。 + + Returns: + Envelope: Host 返回的响应信封。 + + Raises: + RPCError: 发送失败、超时或连接异常。 + """ + connection = self._require_connection() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, message_type=MessageType.REQUEST, method=method, plugin_id=plugin_id, - generation=self._generation, timeout_ms=timeout_ms, payload=payload or {}, ) @@ -196,21 +200,16 @@ class RPCClient: self._pending_requests[request_id] = future try: - data = self._codec.encode_envelope(envelope) - await connection.send_frame(data) - - timeout_sec = timeout_ms / 1000.0 - return await asyncio.wait_for(future, timeout=timeout_sec) + await connection.send_frame(self._codec.encode_envelope(envelope)) + return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0) except asyncio.TimeoutError: self._pending_requests.pop(request_id, None) raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None - except Exception as e: + except Exception as exc: self._pending_requests.pop(request_id, None) - if isinstance(e, RPCError): + if isinstance(exc, RPCError): raise - raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e - - # ─── 内部方法 ────────────────────────────────────────────── + raise RPCError(ErrorCode.E_UNKNOWN, str(exc)) from exc async def send_event( self, @@ -218,33 +217,30 @@ class RPCClient: plugin_id: str = "", payload: Optional[Dict[str, Any]] = None, ) -> None: - """向 Host 发送单向事件(fire-and-forget,不等待响应)。 + """向 Host 发送单向广播消息。 Args: - method: RPC 方法名,如 "runner.log_batch"。 - plugin_id: 目标插件 ID(可为空,表示 Runner 级消息)。 - payload: 事件数据。 + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + payload: 广播载荷。 """ if not self.is_connected: return connection = self._require_connection() - - request_id = self._id_gen.next() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, - message_type=MessageType.EVENT, + message_type=MessageType.BROADCAST, method=method, plugin_id=plugin_id, - generation=self._generation, payload=payload or {}, ) - data = self._codec.encode_envelope(envelope) - await connection.send_frame(data) + await connection.send_frame(self._codec.encode_envelope(envelope)) async def _recv_loop(self) -> None: - """消息接收主循环""" - while self._running and self._connection and not self._connection.is_closed: + """持续接收 Host 发来的消息并分发。""" + while self._running and self._connection is not None and not self._connection.is_closed: try: data = await self._connection.recv_frame() except (asyncio.IncompleteReadError, ConnectionError): @@ -252,39 +248,47 @@ class RPCClient: break except asyncio.CancelledError: break - except Exception as e: - logger.error(f"接收帧失败: {e}") + except Exception as exc: + logger.error(f"接收帧失败: {exc}") break try: envelope = self._codec.decode_envelope(data) - except Exception as e: - logger.error(f"解码消息失败: {e}") + except Exception as exc: + logger.error(f"解码消息失败: {exc}") continue if envelope.is_response(): self._handle_response(envelope) elif envelope.is_request(): self._track_background_task(asyncio.create_task(self._handle_request(envelope))) - elif envelope.is_event(): - self._track_background_task(asyncio.create_task(self._handle_event(envelope))) + elif envelope.is_broadcast(): + self._track_background_task(asyncio.create_task(self._handle_broadcast(envelope))) def _handle_response(self, envelope: Envelope) -> None: - """处理来自 Host 的响应""" + """处理 Host 返回的响应。 + + Args: + envelope: 响应信封。 + """ future = self._pending_requests.pop(envelope.request_id, None) - if future and not future.done(): - if envelope.error: - future.set_exception(RPCError.from_dict(envelope.error)) - else: - future.set_result(envelope) + if future is None or future.done(): + return + if envelope.error: + future.set_exception(RPCError.from_dict(envelope.error)) + else: + future.set_result(envelope) async def _handle_request(self, envelope: Envelope) -> None: - """处理来自 Host 的请求(调用插件组件)""" + """处理 Host 发来的请求。 + + Args: + envelope: 请求信封。 + """ connection = self._connection if connection is None or connection.is_closed: logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应") return - connection = cast(Connection, connection) handler = self._method_handlers.get(envelope.method) if handler is None: @@ -298,23 +302,34 @@ class RPCClient: try: response = await handler(envelope) await connection.send_frame(self._codec.encode_envelope(response)) - except RPCError as e: - error_resp = envelope.make_error_response(e.code.value, e.message, e.details) + except RPCError as exc: + error_resp = envelope.make_error_response(exc.code.value, exc.message, exc.details) await connection.send_frame(self._codec.encode_envelope(error_resp)) - except Exception as e: - logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True) - error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) + except Exception as exc: + logger.error(f"处理请求 {envelope.method} 异常: {exc}", exc_info=True) + error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc)) await connection.send_frame(self._codec.encode_envelope(error_resp)) - async def _handle_event(self, envelope: Envelope) -> None: - """处理来自 Host 的事件""" - if handler := self._method_handlers.get(envelope.method): - try: - await handler(envelope) - except Exception as e: - logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) + async def _handle_broadcast(self, envelope: Envelope) -> None: + """处理 Host 发来的广播事件。 - def _track_background_task(self, task: asyncio.Task) -> None: - """保持后台任务强引用,直到其完成或被取消。""" + Args: + envelope: 广播信封。 + """ + handler = self._method_handlers.get(envelope.method) + if handler is None: + return + + try: + await handler(envelope) + except Exception as exc: + logger.error(f"处理广播 {envelope.method} 异常: {exc}", exc_info=True) + + def _track_background_task(self, task: asyncio.Task[Any]) -> None: + """持有后台任务强引用直到其结束。 + + Args: + task: 后台任务。 + """ self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index dae1cfa1..771e685f 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -9,7 +9,7 @@ 6. 转发插件的能力调用到 Host """ -from typing import Any, Callable, List, Optional, Protocol, cast +from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, cast from pathlib import Path @@ -32,8 +32,11 @@ from src.plugin_runtime.protocol.envelope import ( HealthPayload, InvokePayload, InvokeResultPayload, - RegisterComponentsPayload, + RegisterPluginPayload, + ReloadPluginPayload, + ReloadPluginResultPayload, RunnerReadyPayload, + UnregisterPluginPayload, ) from src.plugin_runtime.protocol.errors import ErrorCode from src.plugin_runtime.runner.log_handler import RunnerIPCLogHandler @@ -44,7 +47,8 @@ logger = get_logger("plugin_runtime.runner.main") class _ContextAwarePlugin(Protocol): - def _set_context(self, context: Any) -> None: ... + def _set_context(self, context: Any) -> None: + """为插件注入上下文对象。""" def _install_shutdown_signal_handlers( @@ -90,21 +94,29 @@ class PluginRunner: session_token: str, plugin_dirs: List[str], ) -> None: + """初始化 Runner。 + + Args: + host_address: Host 的 IPC 地址。 + session_token: 握手用会话令牌。 + plugin_dirs: 当前 Runner 负责扫描的插件目录列表。 + """ self._host_address: str = host_address self._session_token: str = session_token - self._plugin_dirs: list[str] = plugin_dirs + self._plugin_dirs: List[str] = plugin_dirs self._rpc_client: RPCClient = RPCClient(host_address, session_token) self._loader: PluginLoader = PluginLoader(host_version=os.getenv(ENV_HOST_VERSION, "")) self._start_time: float = time.monotonic() self._shutting_down: bool = False + self._reload_lock: asyncio.Lock = asyncio.Lock() # IPC 日志 Handler:握手成功后安装,将所有 stdlib logging 转发到 Host self._log_handler: Optional[RunnerIPCLogHandler] = None - self._suspended_console_handlers: list[stdlib_logging.Handler] = [] + self._suspended_console_handlers: List[stdlib_logging.Handler] = [] async def run(self) -> None: - """Runner 主入口""" + """运行 Runner 主循环。""" # 1. 连接 Host logger.info(f"Runner 启动,连接 Host: {self._host_address}") ok = await self._rpc_client.connect_and_handshake() @@ -123,32 +135,11 @@ class PluginRunner: logger.info(f"已加载 {len(plugins)} 个插件") # 4. 注入 PluginContext + 调用 on_load 生命周期钩子 - failed_plugins: set[str] = set() + failed_plugins: Set[str] = set(self._loader.failed_plugins.keys()) for meta in plugins: - instance = meta.instance - self._inject_context(meta.plugin_id, instance) - self._apply_plugin_config(meta) - if not await self._bootstrap_plugin(meta): - failed_plugins.add(meta.plugin_id) - continue - if hasattr(instance, "on_load"): - try: - ret = instance.on_load() - if asyncio.iscoroutine(ret): - await ret - except Exception as e: - logger.error(f"插件 {meta.plugin_id} on_load 失败,跳过注册: {e}", exc_info=True) - failed_plugins.add(meta.plugin_id) - await self._deactivate_plugin(meta) - - # 5. 向 Host 注册所有插件的组件(跳过 on_load 失败的插件) - for meta in plugins: - if meta.plugin_id in failed_plugins: - continue - ok = await self._register_plugin(meta) + ok = await self._activate_plugin(meta) if not ok: failed_plugins.add(meta.plugin_id) - await self._deactivate_plugin(meta) successful_plugins = [meta.plugin_id for meta in plugins if meta.plugin_id not in failed_plugins] await self._notify_ready(successful_plugins, sorted(failed_plugins)) @@ -232,7 +223,9 @@ class PluginRunner: bound_plugin_id = plugin_id async def _rpc_call( - method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None + method: str, + plugin_id: str = "", + payload: Optional[Dict[str, Any]] = None, ) -> Any: """桥接 PluginContext.call_capability → RPCClient.send_request。 @@ -257,7 +250,7 @@ class PluginRunner: cast(_ContextAwarePlugin, instance)._set_context(ctx) logger.debug(f"已为插件 {plugin_id} 注入 PluginContext") - def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None: + def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[Dict[str, Any]] = None) -> None: """在 Runner 侧为插件实例注入当前插件配置。""" instance = meta.instance if not hasattr(instance, "set_plugin_config"): @@ -270,7 +263,7 @@ class PluginRunner: logger.warning(f"插件 {meta.plugin_id} 配置注入失败: {exc}") @staticmethod - def _load_plugin_config(plugin_dir: str) -> dict[str, Any]: + def _load_plugin_config(plugin_dir: str) -> Dict[str, Any]: """从插件目录读取 config.toml。""" config_path = Path(plugin_dir) / "config.toml" if not config_path.exists(): @@ -286,16 +279,18 @@ class PluginRunner: return loaded if isinstance(loaded, dict) else {} def _register_handlers(self) -> None: - """注册方法处理器""" + """注册 Host -> Runner 的方法处理器。""" self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) + self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step) self._rpc_client.register_method("plugin.health", self._handle_health) self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown) self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated) + self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin) async def _bootstrap_plugin(self, meta: PluginMeta, capabilities_required: Optional[List[str]] = None) -> bool: """向 Host 同步插件 bootstrap 能力令牌。""" @@ -324,7 +319,14 @@ class PluginRunner: await self._bootstrap_plugin(meta, capabilities_required=[]) async def _register_plugin(self, meta: PluginMeta) -> bool: - """向 Host 注册单个插件""" + """向 Host 注册单个插件。 + + Args: + meta: 待注册的插件元数据。 + + Returns: + bool: 是否注册成功。 + """ # 收集插件组件声明 components: List[ComponentDeclaration] = [] instance = meta.instance @@ -341,7 +343,7 @@ class PluginRunner: for comp_info in instance.get_components() ) - reg_payload = RegisterComponentsPayload( + reg_payload = RegisterPluginPayload( plugin_id=meta.plugin_id, plugin_version=meta.version, components=components, @@ -361,8 +363,281 @@ class PluginRunner: logger.error(f"插件 {meta.plugin_id} 注册失败: {e}") return False + async def _unregister_plugin(self, plugin_id: str, reason: str) -> None: + """通知 Host 注销指定插件。 + + Args: + plugin_id: 目标插件 ID。 + reason: 注销原因。 + """ + payload = UnregisterPluginPayload(plugin_id=plugin_id, reason=reason) + try: + await self._rpc_client.send_request( + "plugin.unregister", + plugin_id=plugin_id, + payload=payload.model_dump(), + timeout_ms=10000, + ) + except Exception as exc: + logger.warning(f"插件 {plugin_id} 注销通知失败: {exc}") + + async def _invoke_plugin_on_load(self, meta: PluginMeta) -> bool: + """执行插件的 ``on_load`` 生命周期。 + + Args: + meta: 待初始化的插件元数据。 + + Returns: + bool: 生命周期是否执行成功。 + """ + instance = meta.instance + if not hasattr(instance, "on_load"): + return True + + try: + result = instance.on_load() + if asyncio.iscoroutine(result): + await result + return True + except Exception as exc: + logger.error(f"插件 {meta.plugin_id} on_load 失败: {exc}", exc_info=True) + return False + + async def _invoke_plugin_on_unload(self, meta: PluginMeta) -> None: + """执行插件的 ``on_unload`` 生命周期。 + + Args: + meta: 待卸载的插件元数据。 + """ + instance = meta.instance + if not hasattr(instance, "on_unload"): + return + + try: + result = instance.on_unload() + if asyncio.iscoroutine(result): + await result + except Exception as exc: + logger.error(f"插件 {meta.plugin_id} on_unload 失败: {exc}", exc_info=True) + + async def _activate_plugin(self, meta: PluginMeta) -> bool: + """完成插件注入、授权、生命周期和组件注册。 + + Args: + meta: 待激活的插件元数据。 + + Returns: + bool: 是否激活成功。 + """ + self._inject_context(meta.plugin_id, meta.instance) + self._apply_plugin_config(meta) + + if not await self._bootstrap_plugin(meta): + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + return False + + if not await self._invoke_plugin_on_load(meta): + await self._deactivate_plugin(meta) + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + return False + + if not await self._register_plugin(meta): + await self._invoke_plugin_on_unload(meta) + await self._deactivate_plugin(meta) + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + return False + + self._loader.set_loaded_plugin(meta) + return True + + async def _unload_plugin(self, meta: PluginMeta, reason: str) -> None: + """卸载单个插件并清理 Host/Runner 两侧状态。 + + Args: + meta: 待卸载的插件元数据。 + reason: 卸载原因。 + """ + await self._invoke_plugin_on_unload(meta) + await self._unregister_plugin(meta.plugin_id, reason) + await self._deactivate_plugin(meta) + self._loader.remove_loaded_plugin(meta.plugin_id) + self._loader.purge_plugin_modules(meta.plugin_id, meta.plugin_dir) + + def _collect_reverse_dependents(self, plugin_id: str) -> Set[str]: + """收集依赖指定插件的所有已加载插件。 + + Args: + plugin_id: 根插件 ID。 + + Returns: + Set[str]: 目标插件及其所有反向依赖插件集合。 + """ + impacted_plugins: Set[str] = {plugin_id} + changed = True + + while changed: + changed = False + for loaded_plugin_id in self._loader.list_plugins(): + if loaded_plugin_id in impacted_plugins: + continue + + meta = self._loader.get_plugin(loaded_plugin_id) + if meta is None: + continue + + if any(dependency in impacted_plugins for dependency in meta.dependencies): + impacted_plugins.add(loaded_plugin_id) + changed = True + + return impacted_plugins + + def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]: + """构建受影响插件的卸载顺序。 + + Args: + plugin_ids: 需要卸载的插件集合。 + + Returns: + List[str]: 依赖方优先的卸载顺序。 + """ + dependency_graph: Dict[str, Set[str]] = {} + for plugin_id in plugin_ids: + meta = self._loader.get_plugin(plugin_id) + if meta is None: + dependency_graph[plugin_id] = set() + continue + dependency_graph[plugin_id] = {dependency for dependency in meta.dependencies if dependency in plugin_ids} + + indegree: Dict[str, int] = {plugin_id: len(dependencies) for plugin_id, dependencies in dependency_graph.items()} + reverse_graph: Dict[str, Set[str]] = {plugin_id: set() for plugin_id in dependency_graph} + + for plugin_id, dependencies in dependency_graph.items(): + for dependency in dependencies: + reverse_graph.setdefault(dependency, set()).add(plugin_id) + + queue: List[str] = sorted(plugin_id for plugin_id, degree in indegree.items() if degree == 0) + load_order: List[str] = [] + + while queue: + current_plugin_id = queue.pop(0) + load_order.append(current_plugin_id) + for dependent_plugin_id in sorted(reverse_graph.get(current_plugin_id, set())): + indegree[dependent_plugin_id] -= 1 + if indegree[dependent_plugin_id] == 0: + queue.append(dependent_plugin_id) + queue.sort() + + return list(reversed(load_order)) + + async def _reload_plugin_by_id(self, plugin_id: str, reason: str) -> ReloadPluginResultPayload: + """按插件 ID 在 Runner 进程内执行精确重载。 + + Args: + plugin_id: 目标插件 ID。 + reason: 重载原因。 + + Returns: + ReloadPluginResultPayload: 结构化重载结果。 + """ + candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs) + failed_plugins: Dict[str, str] = {} + + if plugin_id in duplicate_candidates: + conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id]) + return ReloadPluginResultPayload( + success=False, + requested_plugin_id=plugin_id, + failed_plugins={plugin_id: f"检测到重复插件 ID: {conflict_paths}"}, + ) + + loaded_plugin_ids = set(self._loader.list_plugins()) + plugin_is_loaded = plugin_id in loaded_plugin_ids + plugin_has_candidate = plugin_id in candidates + + if not plugin_is_loaded and not plugin_has_candidate: + return ReloadPluginResultPayload( + success=False, + requested_plugin_id=plugin_id, + failed_plugins={plugin_id: "插件不存在或未找到合法的 manifest/plugin.py"}, + ) + + target_plugin_ids: Set[str] = {plugin_id} + if plugin_is_loaded: + target_plugin_ids = self._collect_reverse_dependents(plugin_id) + + unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids) + unloaded_plugins: List[str] = [] + retained_plugin_ids = loaded_plugin_ids - set(unload_order) + + for unload_plugin_id in unload_order: + meta = self._loader.get_plugin(unload_plugin_id) + if meta is None: + continue + await self._unload_plugin(meta, reason=reason) + unloaded_plugins.append(unload_plugin_id) + + reload_candidates: Dict[str, Tuple[Path, Dict[str, Any], Path]] = {} + for target_plugin_id in target_plugin_ids: + candidate = candidates.get(target_plugin_id) + if candidate is None: + failed_plugins[target_plugin_id] = "插件目录已不存在,已保持卸载状态" + continue + reload_candidates[target_plugin_id] = candidate + + load_order, dependency_failures = self._loader.resolve_dependencies( + reload_candidates, + extra_available=retained_plugin_ids, + ) + failed_plugins.update(dependency_failures) + + available_plugins = set(retained_plugin_ids) + reloaded_plugins: List[str] = [] + + for load_plugin_id in load_order: + if load_plugin_id in failed_plugins: + continue + + candidate = reload_candidates.get(load_plugin_id) + if candidate is None: + continue + + _, manifest, _ = candidate + dependencies = PluginMeta._extract_dependencies(manifest) + missing_dependencies = [dependency for dependency in dependencies if dependency not in available_plugins] + if missing_dependencies: + failed_plugins[load_plugin_id] = f"依赖未满足: {', '.join(missing_dependencies)}" + continue + + meta = self._loader.load_candidate(load_plugin_id, candidate) + if meta is None: + failed_plugins[load_plugin_id] = "插件模块加载失败" + continue + + activated = await self._activate_plugin(meta) + if not activated: + failed_plugins[load_plugin_id] = "插件初始化失败" + continue + + available_plugins.add(load_plugin_id) + reloaded_plugins.append(load_plugin_id) + + requested_plugin_success = plugin_id in reloaded_plugins and not failed_plugins + + return ReloadPluginResultPayload( + success=requested_plugin_success, + requested_plugin_id=plugin_id, + reloaded_plugins=reloaded_plugins, + unloaded_plugins=unloaded_plugins, + failed_plugins=failed_plugins, + ) + async def _notify_ready(self, loaded_plugins: List[str], failed_plugins: List[str]) -> None: - """通知 Host 当前 generation 已完成插件初始化。""" + """通知 Host 当前 Runner 已完成插件初始化。 + + Args: + loaded_plugins: 成功初始化的插件列表。 + failed_plugins: 初始化失败的插件列表。 + """ payload = RunnerReadyPayload( loaded_plugins=loaded_plugins, failed_plugins=failed_plugins, @@ -487,6 +762,61 @@ class PluginRunner: logger.error(f"插件 {plugin_id} event_handler {component_name} 执行异常: {e}", exc_info=True) return envelope.make_response(payload={"success": False, "continue_processing": True}) + async def _handle_hook_invoke(self, envelope: Envelope) -> Envelope: + """处理 HookHandler 调用请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 标准化后的 Hook 调用结果。 + """ + try: + invoke = InvokePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + plugin_id = envelope.plugin_id + meta = self._loader.get_plugin(plugin_id) + if meta is None: + return envelope.make_error_response( + ErrorCode.E_PLUGIN_NOT_FOUND.value, + f"插件 {plugin_id} 未加载", + ) + + instance = meta.instance + component_name = invoke.component_name + handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None) + if handler_method is None or not callable(handler_method): + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {plugin_id} 无组件: {component_name}", + ) + + try: + raw = ( + await handler_method(**invoke.args) + if inspect.iscoroutinefunction(handler_method) + else handler_method(**invoke.args) + ) + except Exception as exc: + logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True) + return envelope.make_response(payload={"success": False, "continue_processing": True}) + + if raw is None: + result = {"success": True, "continue_processing": True} + elif isinstance(raw, dict): + result = { + "success": True, + "continue_processing": raw.get("continue_processing", True), + "modified_kwargs": raw.get("modified_kwargs"), + "custom_result": raw.get("custom_result"), + } + else: + result = {"success": True, "continue_processing": True, "custom_result": raw} + + return envelope.make_response(payload=result) + async def _handle_workflow_step(self, envelope: Envelope) -> Envelope: """处理 WorkflowStep 调用请求 @@ -557,15 +887,10 @@ class PluginRunner: async def _handle_shutdown(self, envelope: Envelope) -> Envelope: """处理关停 — 调用所有插件的 on_unload 后退出""" logger.info("收到 shutdown 信号,开始调用 on_unload") - for plugin_id in self._loader.list_plugins(): + for plugin_id in list(self._loader.list_plugins()): meta = self._loader.get_plugin(plugin_id) - if meta and hasattr(meta.instance, "on_unload"): - try: - ret = meta.instance.on_unload() - if asyncio.iscoroutine(ret): - await ret - except Exception as e: - logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True) + if meta is not None: + await self._unload_plugin(meta, reason="runner_shutdown") self._shutting_down = True return envelope.make_response(payload={"acknowledged": True}) @@ -587,6 +912,30 @@ class PluginRunner: return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) return envelope.make_response(payload={"acknowledged": True}) + async def _handle_reload_plugin(self, envelope: Envelope) -> Envelope: + """处理按插件 ID 的精确重载请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 结构化重载结果。 + """ + try: + payload = ReloadPluginPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + if self._reload_lock.locked(): + return envelope.make_error_response( + ErrorCode.E_RELOAD_IN_PROGRESS.value, + f"插件 {payload.plugin_id} 重载请求被拒绝:已有重载任务正在执行", + ) + + async with self._reload_lock: + result = await self._reload_plugin_by_id(payload.plugin_id, payload.reason) + return envelope.make_response(payload=result.model_dump()) + def request_capability(self) -> RPCClient: """获取 RPC 客户端(供 SDK 使用,发起能力调用)""" return self._rpc_client @@ -652,13 +1001,16 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common") - def find_module(self, fullname, path=None): + def find_module(self, fullname: str, path: Any = None) -> Any: + """决定是否拦截指定模块导入。""" return self if self._should_block(fullname) else None - def load_module(self, fullname): + def load_module(self, fullname: str) -> None: + """阻止被拦截模块继续导入。""" raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}") def _should_block(self, fullname: str) -> bool: + """判断给定模块名是否应被阻止导入。""" # 放行非 src.* 的导入、以及 "src" 本身 if not fullname.startswith("src.") or fullname == "src": return False @@ -692,6 +1044,7 @@ async def _async_main() -> None: # 注册信号处理 def _mark_runner_shutting_down() -> None: + """标记 Runner 即将进入关停流程。""" runner._shutting_down = True _install_shutdown_signal_handlers(_mark_runner_shutting_down)