diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 1d93ae24..f3e1e7ce 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -676,6 +676,101 @@ class TestSDK: methods = [call["method"] for call in runner._rpc_client.calls] assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"] + @pytest.mark.asyncio + async def test_runner_batch_reload_merges_overlapping_reverse_dependents(self, monkeypatch): + """批量重载应只对重叠依赖闭包执行一次 unload/load。""" + from src.plugin_runtime.runner.runner_main import PluginRunner + + runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) + plugin_a_id = "test.plugin-a" + plugin_b_id = "test.plugin-b" + plugin_c_id = "test.plugin-c" + + def build_meta(plugin_id: str, dependencies: list[str]) -> SimpleNamespace: + return SimpleNamespace( + plugin_id=plugin_id, + dependencies=dependencies, + plugin_dir=f"/tmp/{plugin_id}", + version="1.0.0", + instance=SimpleNamespace(), + ) + + loaded_metas = { + plugin_a_id: build_meta(plugin_a_id, []), + plugin_b_id: build_meta(plugin_b_id, [plugin_a_id]), + plugin_c_id: build_meta(plugin_c_id, [plugin_b_id]), + } + reloaded_metas = { + plugin_id: build_meta(plugin_id, list(meta.dependencies)) + for plugin_id, meta in loaded_metas.items() + } + candidates = { + plugin_a_id: ( + "dir_plugin_a", + build_test_manifest_model(plugin_a_id), + "plugin_a/plugin.py", + ), + plugin_b_id: ( + "dir_plugin_b", + build_test_manifest_model( + plugin_b_id, + dependencies=[{"type": "plugin", "id": plugin_a_id, "version_spec": ">=1.0.0,<2.0.0"}], + ), + "plugin_b/plugin.py", + ), + plugin_c_id: ( + "dir_plugin_c", + build_test_manifest_model( + plugin_c_id, + dependencies=[{"type": "plugin", "id": plugin_b_id, "version_spec": ">=1.0.0,<2.0.0"}], + ), + "plugin_c/plugin.py", + ), + } + unloaded_plugins: list[str] = [] + activated_plugins: list[str] = [] + + monkeypatch.setattr(runner._loader, "discover_candidates", lambda plugin_dirs: (candidates, {})) + monkeypatch.setattr(runner._loader, "list_plugins", lambda: sorted(loaded_metas.keys())) + monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: loaded_metas.get(plugin_id)) + monkeypatch.setattr( + runner._loader, + "remove_loaded_plugin", + lambda plugin_id: loaded_metas.pop(plugin_id, None), + ) + monkeypatch.setattr(runner._loader, "purge_plugin_modules", lambda plugin_id, plugin_dir: []) + monkeypatch.setattr( + runner._loader, + "resolve_dependencies", + lambda reload_candidates, extra_available=None: (sorted(reload_candidates.keys()), {}), + ) + monkeypatch.setattr( + runner._loader, + "load_candidate", + lambda plugin_id, candidate: reloaded_metas[plugin_id], + ) + + async def fake_unload_plugin(meta, reason, purge_modules=False): + del reason, purge_modules + unloaded_plugins.append(meta.plugin_id) + loaded_metas.pop(meta.plugin_id, None) + + async def fake_activate_plugin(meta): + activated_plugins.append(meta.plugin_id) + loaded_metas[meta.plugin_id] = meta + return True + + monkeypatch.setattr(runner, "_unload_plugin", fake_unload_plugin) + monkeypatch.setattr(runner, "_activate_plugin", fake_activate_plugin) + + result = await runner._reload_plugins_by_ids([plugin_a_id, plugin_b_id], reason="manual") + + assert result.success is True + assert result.requested_plugin_ids == [plugin_a_id, plugin_b_id] + assert unloaded_plugins == [plugin_c_id, plugin_b_id, plugin_a_id] + assert activated_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] + assert result.reloaded_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] + class TestPluginSdkUsage: """验证仓库内插件按新 SDK 归一化返回值工作。""" @@ -1220,6 +1315,25 @@ class TestDependencyResolution: sys.path[:] = original_path sys.meta_path[:] = original_meta_path + def test_isolate_sys_path_blocks_disallowed_src_imports(self): + import importlib + + from src.plugin_runtime.runner import runner_main + + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + sys.modules.pop("src.forbidden_demo", None) + + try: + runner_main._isolate_sys_path([]) + + with pytest.raises(ImportError, match="不允许导入主程序模块"): + importlib.import_module("src.forbidden_demo") + finally: + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + sys.modules.pop("src.forbidden_demo", None) + # ─── Host-side ComponentRegistry 测试 ────────────────────── @@ -1264,6 +1378,30 @@ class TestComponentRegistry: assert stats["command"] == 1 assert stats["tool"] == 1 + def test_register_command_with_invalid_regex_only_warns(self, monkeypatch): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + warnings: list[str] = [] + monkeypatch.setattr( + "src.plugin_runtime.host.component_registry.logger.warning", + lambda message: warnings.append(str(message)), + ) + + success = reg.register_component( + "broken", + "command", + "plugin_a", + { + "command_pattern": "[", + }, + ) + + assert success is True + assert reg.get_component("plugin_a.broken") is not None + assert warnings + assert "plugin_a.broken" in warnings[0] + def test_query_by_type(self): from src.plugin_runtime.host.component_registry import ComponentRegistry @@ -2303,6 +2441,39 @@ class TestSupervisor: assert supervisor.component_registry.get_component("plugin_a.handler") is not None assert supervisor.component_registry.get_component("plugin_a.obsolete") is None + @pytest.mark.asyncio + async def test_reload_plugins_uses_batch_rpc_for_multiple_roots(self): + from src.plugin_runtime.host.supervisor import PluginSupervisor + from src.plugin_runtime.protocol.envelope import ReloadPluginsResultPayload + + supervisor = PluginSupervisor(plugin_dirs=[]) + sent_requests: list[tuple[str, dict[str, object], int]] = [] + + class FakeRPCServer: + async def send_request(self, method, payload, timeout_ms=5000, **kwargs): + del kwargs + sent_requests.append((method, payload, timeout_ms)) + return SimpleNamespace( + payload=ReloadPluginsResultPayload( + success=True, + requested_plugin_ids=["plugin_a", "plugin_b"], + reloaded_plugins=["plugin_a", "plugin_b", "plugin_c"], + unloaded_plugins=["plugin_c", "plugin_b", "plugin_a"], + ).model_dump() + ) + + supervisor._rpc_server = FakeRPCServer() + + reloaded = await supervisor.reload_plugins(["plugin_a", "plugin_b", "plugin_a"], reason="manual") + + assert reloaded is True + assert len(sent_requests) == 1 + method, payload, timeout_ms = sent_requests[0] + assert method == "plugin.reload_batch" + assert payload["plugin_ids"] == ["plugin_a", "plugin_b"] + assert payload["reason"] == "manual" + assert timeout_ms >= 10000 + @pytest.mark.asyncio async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch): from src.plugin_runtime.host.supervisor import PluginSupervisor diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 1c073490..97fdca30 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -75,14 +75,14 @@ class CommandEntry(ComponentEntry): """Command 组件条目""" def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - self.compiled_pattern: Optional[re.Pattern] = None + super().__init__(name, component_type, plugin_id, metadata) self.aliases: List[str] = metadata.get("aliases", []) + self.compiled_pattern: Optional[re.Pattern] = None if pattern := metadata.get("command_pattern", ""): try: self.compiled_pattern = re.compile(pattern) - except re.error as e: + except (re.error, TypeError) as e: logger.warning(f"命令 {self.full_name} 正则编译失败: {e}") - super().__init__(name, component_type, plugin_id, metadata) class ToolEntry(ComponentEntry): diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index d12014f6..08638d16 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -33,6 +33,8 @@ from src.plugin_runtime.protocol.envelope import ( ReceiveExternalMessageResultPayload, RegisterPluginPayload, ReloadPluginResultPayload, + ReloadPluginsPayload, + ReloadPluginsResultPayload, RouteMessagePayload, RunnerReadyPayload, ShutdownPayload, @@ -194,6 +196,35 @@ class PluginRunnerSupervisor: for plugin_id, registration in self._registered_plugins.items() } + @staticmethod + def _normalize_reload_plugin_ids(plugin_ids: Optional[List[str] | str]) -> List[str]: + """规范化批量重载入参。 + + Args: + plugin_ids: 原始插件 ID 列表或单个插件 ID。 + + Returns: + List[str]: 去重且去空白后的插件 ID 列表。 + """ + + raw_plugin_ids: List[str] + if plugin_ids is None: + raw_plugin_ids = [] + elif isinstance(plugin_ids, str): + raw_plugin_ids = [plugin_ids] + else: + raw_plugin_ids = list(plugin_ids) + + normalized_plugin_ids: List[str] = [] + seen_plugin_ids: set[str] = set() + for plugin_id in raw_plugin_ids: + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids: + continue + seen_plugin_ids.add(normalized_plugin_id) + normalized_plugin_ids.append(normalized_plugin_id) + return normalized_plugin_ids + async def dispatch_event( self, event_type: str, @@ -420,7 +451,7 @@ class PluginRunnerSupervisor: async def reload_plugins( self, - plugin_ids: Optional[List[str]] = None, + plugin_ids: Optional[List[str] | str] = None, reason: str = "manual", external_available_plugins: Optional[Dict[str, str]] = None, ) -> bool: @@ -434,19 +465,37 @@ class PluginRunnerSupervisor: Returns: bool: 是否全部重载成功。 """ - target_plugin_ids = plugin_ids or list(self._registered_plugins.keys()) - ordered_plugin_ids = list(dict.fromkeys(target_plugin_ids)) - success = True + ordered_plugin_ids = self._normalize_reload_plugin_ids(plugin_ids) + if not ordered_plugin_ids: + ordered_plugin_ids = list(self._registered_plugins.keys()) + if not ordered_plugin_ids: + return True - for plugin_id in ordered_plugin_ids: - reloaded = await self.reload_plugin( - plugin_id=plugin_id, + if len(ordered_plugin_ids) == 1: + return await self.reload_plugin( + plugin_id=ordered_plugin_ids[0], reason=reason, external_available_plugins=external_available_plugins, ) - success = success and reloaded - return success + try: + response = await self._rpc_server.send_request( + "plugin.reload_batch", + payload=ReloadPluginsPayload( + plugin_ids=ordered_plugin_ids, + reason=reason, + external_available_plugins=external_available_plugins or self._external_available_plugins, + ).model_dump(), + timeout_ms=max(int(self._runner_spawn_timeout * 1000), 10000), + ) + except Exception as exc: + logger.error(f"插件批量重载请求失败: {exc}") + return False + + result = ReloadPluginsResultPayload.model_validate(response.payload) + if not result.success: + logger.warning(f"插件批量重载失败: {result.failed_plugins}") + return result.success async def notify_plugin_config_updated( self, diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index c2c89a0f..e738d019 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -289,6 +289,20 @@ class ReloadPluginPayload(BaseModel): """可视为已满足的外部依赖插件版本映射""" +class ReloadPluginsPayload(BaseModel): + """批量插件重载请求载荷。""" + + plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表") + """目标插件 ID 列表""" + reason: str = Field(default="manual", description="重载原因") + """重载原因""" + external_available_plugins: Dict[str, str] = Field( + default_factory=dict, + description="可视为已满足的外部依赖插件版本映射", + ) + """可视为已满足的外部依赖插件版本映射""" + + class ReloadPluginResultPayload(BaseModel): """插件重载结果载荷。""" @@ -304,6 +318,21 @@ class ReloadPluginResultPayload(BaseModel): """重载失败的插件及原因""" +class ReloadPluginsResultPayload(BaseModel): + """批量插件重载结果载荷。""" + + success: bool = Field(description="是否重载成功") + """是否重载成功""" + requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表") + """请求重载的插件 ID 列表""" + reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表") + """成功完成重载的插件列表""" + unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表") + """本次已卸载的插件列表""" + failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因") + """重载失败的插件及原因""" + + class MessageGatewayStateUpdatePayload(BaseModel): """消息网关运行时状态更新载荷。""" diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 39c741bd..e66d2fab 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -42,6 +42,8 @@ from src.plugin_runtime.protocol.envelope import ( RegisterPluginPayload, ReloadPluginPayload, ReloadPluginResultPayload, + ReloadPluginsPayload, + ReloadPluginsResultPayload, RunnerReadyPayload, UnregisterPluginPayload, ) @@ -334,6 +336,7 @@ class PluginRunner: self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated) self._rpc_client.register_method("plugin.reload", self._handle_reload_plugin) + self._rpc_client.register_method("plugin.reload_batch", self._handle_reload_plugins) @staticmethod def _resolve_component_handler_name(meta: PluginMeta, component_name: str) -> str: @@ -597,6 +600,21 @@ class PluginRunner: return impacted_plugins + def _collect_reverse_dependents_for_roots(self, plugin_ids: Set[str]) -> Set[str]: + """收集多个根插件对应的反向依赖并集。 + + Args: + plugin_ids: 根插件 ID 集合。 + + Returns: + Set[str]: 所有根插件及其反向依赖并集。 + """ + + impacted_plugins: Set[str] = set() + for plugin_id in sorted(plugin_ids): + impacted_plugins.update(self._collect_reverse_dependents(plugin_id)) + return impacted_plugins + def _build_unload_order(self, plugin_ids: Set[str]) -> List[str]: """构建受影响插件的卸载顺序。 @@ -635,6 +653,20 @@ class PluginRunner: return list(reversed(load_order)) + @staticmethod + def _normalize_requested_plugin_ids(plugin_ids: List[str]) -> List[str]: + """规范化批量重载请求中的插件 ID 列表。""" + + normalized_plugin_ids: List[str] = [] + seen_plugin_ids: Set[str] = set() + for plugin_id in plugin_ids: + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id or normalized_plugin_id in seen_plugin_ids: + continue + seen_plugin_ids.add(normalized_plugin_id) + normalized_plugin_ids.append(normalized_plugin_id) + return normalized_plugin_ids + @staticmethod def _finalize_failed_reload_messages( failed_plugins: Dict[str, str], @@ -674,6 +706,31 @@ class PluginRunner: Returns: ReloadPluginResultPayload: 结构化重载结果。 """ + batch_result = await self._reload_plugins_by_ids( + [plugin_id], + reason, + external_available_plugins=external_available_plugins, + ) + return ReloadPluginResultPayload( + success=batch_result.success, + requested_plugin_id=plugin_id, + reloaded_plugins=batch_result.reloaded_plugins, + unloaded_plugins=batch_result.unloaded_plugins, + failed_plugins=batch_result.failed_plugins, + ) + + async def _reload_plugins_by_ids( + self, + plugin_ids: List[str], + reason: str, + external_available_plugins: Optional[Dict[str, str]] = None, + ) -> ReloadPluginsResultPayload: + """按插件 ID 列表在 Runner 进程内执行一次批量重载。""" + + normalized_plugin_ids = self._normalize_requested_plugin_ids(plugin_ids) + if not normalized_plugin_ids: + return ReloadPluginsResultPayload(success=True, requested_plugin_ids=[]) + candidates, duplicate_candidates = self._loader.discover_candidates(self._plugin_dirs) failed_plugins: Dict[str, str] = {} normalized_external_available = { @@ -682,28 +739,35 @@ class PluginRunner: if str(candidate_plugin_id or "").strip() and str(candidate_plugin_version or "").strip() } - if plugin_id in duplicate_candidates: - conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id]) - return ReloadPluginResultPayload( - success=False, - requested_plugin_id=plugin_id, - failed_plugins={plugin_id: f"检测到重复插件 ID: {conflict_paths}"}, - ) - loaded_plugin_ids = set(self._loader.list_plugins()) - plugin_is_loaded = plugin_id in loaded_plugin_ids - plugin_has_candidate = plugin_id in candidates + reload_root_ids: Set[str] = set() + for plugin_id in normalized_plugin_ids: + if plugin_id in duplicate_candidates: + conflict_paths = ", ".join(str(path) for path in duplicate_candidates[plugin_id]) + failed_plugins[plugin_id] = f"检测到重复插件 ID: {conflict_paths}" + continue - if not plugin_is_loaded and not plugin_has_candidate: - return ReloadPluginResultPayload( + plugin_is_loaded = plugin_id in loaded_plugin_ids + plugin_has_candidate = plugin_id in candidates + if not plugin_is_loaded and not plugin_has_candidate: + failed_plugins[plugin_id] = "插件不存在或未找到合法的 manifest/plugin.py" + continue + + reload_root_ids.add(plugin_id) + + if not reload_root_ids: + return ReloadPluginsResultPayload( success=False, - requested_plugin_id=plugin_id, - failed_plugins={plugin_id: "插件不存在或未找到合法的 manifest/plugin.py"}, + requested_plugin_ids=normalized_plugin_ids, + failed_plugins=failed_plugins, ) - target_plugin_ids: Set[str] = {plugin_id} - if plugin_is_loaded: - target_plugin_ids = self._collect_reverse_dependents(plugin_id) + target_plugin_ids: Set[str] = { + plugin_id for plugin_id in reload_root_ids if plugin_id not in loaded_plugin_ids + } + loaded_root_plugin_ids = reload_root_ids & loaded_plugin_ids + if loaded_root_plugin_ids: + target_plugin_ids.update(self._collect_reverse_dependents_for_roots(loaded_root_plugin_ids)) unload_order = self._build_unload_order(target_plugin_ids & loaded_plugin_ids) unloaded_plugins: List[str] = [] @@ -813,19 +877,19 @@ class PluginRunner: if not restored: rollback_failures[rollback_plugin_id] = "无法重新激活旧版本" - return ReloadPluginResultPayload( + return ReloadPluginsResultPayload( success=False, - requested_plugin_id=plugin_id, + requested_plugin_ids=normalized_plugin_ids, reloaded_plugins=[], unloaded_plugins=unloaded_plugins, failed_plugins=self._finalize_failed_reload_messages(failed_plugins, rollback_failures), ) - requested_plugin_success = plugin_id in reloaded_plugins + requested_plugin_success = all(plugin_id in reloaded_plugins for plugin_id in reload_root_ids) - return ReloadPluginResultPayload( - success=requested_plugin_success, - requested_plugin_id=plugin_id, + return ReloadPluginsResultPayload( + success=requested_plugin_success and not failed_plugins, + requested_plugin_ids=normalized_plugin_ids, reloaded_plugins=reloaded_plugins, unloaded_plugins=unloaded_plugins, failed_plugins=failed_plugins, @@ -1139,6 +1203,29 @@ class PluginRunner: ) return envelope.make_response(payload=result.model_dump()) + async def _handle_reload_plugins(self, envelope: Envelope) -> Envelope: + """处理批量插件重载请求。""" + + try: + payload = ReloadPluginsPayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + if self._reload_lock.locked(): + requested_plugin_ids = ", ".join(self._normalize_requested_plugin_ids(payload.plugin_ids)) or "" + return envelope.make_error_response( + ErrorCode.E_RELOAD_IN_PROGRESS.value, + f"插件 {requested_plugin_ids} 批量重载请求被拒绝:已有重载任务正在执行", + ) + + async with self._reload_lock: + result = await self._reload_plugins_by_ids( + list(payload.plugin_ids), + payload.reason, + external_available_plugins=dict(payload.external_available_plugins), + ) + return envelope.make_response(payload=result.model_dump()) + def request_capability(self) -> RPCClient: """获取 RPC 客户端(供 SDK 使用,发起能力调用)""" return self._rpc_client @@ -1153,6 +1240,7 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: 防止插件代码 import 主程序模块读取运行时数据。 """ import importlib.abc + from importlib.machinery import ModuleSpec import sysconfig # 保留: 标准库路径 + site-packages(含 SDK 和依赖) @@ -1195,6 +1283,20 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: # 安装 import 钩子,阻止插件导入主程序核心模块 # 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包 + class _BlockedSrcModuleLoader(importlib.abc.Loader): + """阻止被 Runner 允许列表之外的主程序模块完成导入。""" + + def __init__(self, fullname: str) -> None: + self._fullname = fullname + + def create_module(self, spec: ModuleSpec) -> None: + del spec + return None + + def exec_module(self, module: Any) -> None: + del module + raise ImportError(f"Runner 子进程不允许导入主程序模块: {self._fullname}") + class _PluginImportBlocker(importlib.abc.MetaPathFinder): """阻止 Runner 子进程导入主程序核心模块。 @@ -1203,14 +1305,15 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: """ _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common") + __maibot_runner_plugin_import_blocker__ = True - def find_module(self, fullname: str, path: Any = None) -> Any: + def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> ModuleSpec | None: """决定是否拦截指定模块导入。""" - return self if self._should_block(fullname) else None - - def load_module(self, fullname: str) -> None: - """阻止被拦截模块继续导入。""" - raise ImportError(f"Runner 子进程不允许导入主程序模块: {fullname}") + del path, target + if not self._should_block(fullname): + return None + # Python 3.13+/3.14 会优先走 find_spec,不再依赖 find_module。 + return ModuleSpec(fullname, _BlockedSrcModuleLoader(fullname), is_package=True) def _should_block(self, fullname: str) -> bool: """判断给定模块名是否应被阻止导入。""" @@ -1222,6 +1325,11 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES ) + sys.meta_path[:] = [ + finder + for finder in sys.meta_path + if not getattr(finder, "__maibot_runner_plugin_import_blocker__", False) + ] sys.meta_path.insert(0, _PluginImportBlocker())