diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index f3e1e7ce..29227658 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -1298,9 +1298,12 @@ class TestDependencyResolution: assert "on_unload" in loader.failed_plugins["test.demo-plugin"] def test_isolate_sys_path_preserves_plugin_dirs(self): + import builtins + from src.plugin_runtime.runner import runner_main plugin_root = os.path.normpath("/tmp/maibot-plugin-root") + original_import = builtins.__import__ original_path = list(sys.path) original_meta_path = list(sys.meta_path) @@ -1312,14 +1315,17 @@ class TestDependencyResolution: assert plugin_root in sys.path finally: + builtins.__import__ = original_import sys.path[:] = original_path sys.meta_path[:] = original_meta_path def test_isolate_sys_path_blocks_disallowed_src_imports(self): + import builtins import importlib from src.plugin_runtime.runner import runner_main + original_import = builtins.__import__ original_path = list(sys.path) original_meta_path = list(sys.meta_path) sys.modules.pop("src.forbidden_demo", None) @@ -1330,10 +1336,89 @@ class TestDependencyResolution: with pytest.raises(ImportError, match="不允许导入主程序模块"): importlib.import_module("src.forbidden_demo") finally: + builtins.__import__ = original_import sys.path[:] = original_path sys.meta_path[:] = original_meta_path sys.modules.pop("src.forbidden_demo", None) + def test_isolate_sys_path_blocks_preloaded_runtime_modules(self): + import builtins + import importlib + + from src.plugin_runtime.runner import runner_main + + original_import = builtins.__import__ + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + + try: + runner_main._isolate_sys_path([]) + + with pytest.raises(ImportError, match="rpc_client"): + importlib.import_module("src.plugin_runtime.runner.rpc_client") + finally: + builtins.__import__ = original_import + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + + def test_isolate_sys_path_keeps_legacy_logger_import_available(self): + import builtins + import importlib + + from src.plugin_runtime.runner import runner_main + + original_import = builtins.__import__ + original_path = list(sys.path) + original_meta_path = list(sys.meta_path) + + try: + runner_main._isolate_sys_path([]) + + logger_module = importlib.import_module("src.common.logger") + assert callable(logger_module.get_logger) + finally: + builtins.__import__ = original_import + sys.path[:] = original_path + sys.meta_path[:] = original_meta_path + + @pytest.mark.asyncio + async def test_async_main_removes_sensitive_runtime_env_vars(self, monkeypatch): + from src.plugin_runtime.runner import runner_main + + captured = {} + + class FakeRunner: + def __init__( + self, + host_address: str, + session_token: str, + plugin_dirs: list[str], + external_available_plugins: dict[str, str] | None = None, + ) -> None: + captured["host_address"] = host_address + captured["session_token"] = session_token + captured["plugin_dirs"] = plugin_dirs + captured["external_available_plugins"] = external_available_plugins or {} + + async def run(self) -> None: + assert os.environ.get(runner_main.ENV_IPC_ADDRESS) is None + assert os.environ.get(runner_main.ENV_SESSION_TOKEN) is None + + monkeypatch.setenv(runner_main.ENV_IPC_ADDRESS, "tcp://127.0.0.1:9999") + monkeypatch.setenv(runner_main.ENV_SESSION_TOKEN, "secret-token") + monkeypatch.setenv(runner_main.ENV_PLUGIN_DIRS, "/tmp/plugins") + monkeypatch.setenv(runner_main.ENV_EXTERNAL_PLUGIN_IDS, '{"demo.plugin":"1.0.0"}') + monkeypatch.setattr(runner_main, "_install_shutdown_signal_handlers", lambda callback: None) + monkeypatch.setattr(runner_main, "_isolate_sys_path", lambda plugin_dirs: None) + monkeypatch.setattr(runner_main, "PluginRunner", FakeRunner) + + await runner_main._async_main() + + assert captured["host_address"] == "tcp://127.0.0.1:9999" + assert captured["session_token"] == "secret-token" + assert captured["plugin_dirs"] == ["/tmp/plugins"] + assert captured["external_available_plugins"] == {"demo.plugin": "1.0.0"} + # ─── Host-side ComponentRegistry 测试 ────────────────────── @@ -2093,6 +2178,67 @@ class TestWorkflowExecutor: class TestRPCServer: """RPC Server 代际保护测试""" + @pytest.mark.asyncio + async def test_reject_second_active_runner_connection(self): + from src.plugin_runtime.host.rpc_server import RPCServer + from src.plugin_runtime.protocol.codec import MsgPackCodec + from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType + + class DummyTransport: + async def start(self, handler): + return None + + async def stop(self): + return None + + def get_address(self): + return "dummy" + + class FakeConnection: + def __init__(self, incoming_frames: list[bytes]): + self._incoming_frames = list(incoming_frames) + self.sent_frames: list[bytes] = [] + self.is_closed = False + + async def recv_frame(self): + return self._incoming_frames.pop(0) + + async def send_frame(self, data): + self.sent_frames.append(data) + + async def close(self): + self.is_closed = True + + codec = MsgPackCodec() + server = RPCServer(transport=DummyTransport(), session_token="session-token") + active_conn = SimpleNamespace(is_closed=False) + server._connection = active_conn + + hello = HelloPayload( + runner_id="runner-b", + sdk_version="1.0.0", + session_token="session-token", + ) + envelope = Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="runner.hello", + payload=hello.model_dump(), + ) + incoming_conn = FakeConnection([codec.encode_envelope(envelope)]) + + await server._handle_connection(incoming_conn) + + assert incoming_conn.is_closed is True + assert server._connection is active_conn + assert server.last_handshake_rejection_reason == "已有活跃 Runner 连接,拒绝新的握手" + assert len(incoming_conn.sent_frames) == 1 + + response = codec.decode_envelope(incoming_conn.sent_frames[0]) + response_payload = HelloResponsePayload.model_validate(response.payload) + assert response_payload.accepted is False + assert response_payload.reason == "已有活跃 Runner 连接,拒绝新的握手" + def test_ignore_stale_generation_response(self): from src.plugin_runtime.host.rpc_server import RPCServer from src.plugin_runtime.protocol.envelope import Envelope, MessageType diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 2c422775..eb6768c2 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -70,6 +70,7 @@ class RPCServer: self._running: bool = False self._tasks: List[asyncio.Task[None]] = [] self._last_handshake_rejection_reason: str = "" + self._connection_lock: asyncio.Lock = asyncio.Lock() @property def session_token(self) -> str: @@ -216,27 +217,33 @@ class RPCServer: async def _handle_connection(self, conn: Connection) -> None: """处理新的 Runner 连接""" logger.info("收到 Runner 连接") - self.clear_handshake_state() - # 第一条消息必须是 runner.hello 握手 try: - success = await self._handle_handshake(conn) - if not success: - await conn.close() - return + async with self._connection_lock: + self.clear_handshake_state() + success = await self._handle_handshake(conn) + if not success: + await conn.close() + return + logger.info("Runner staged 握手成功") + self._connection = conn except Exception as e: logger.error(f"握手失败: {e}") await conn.close() return - logger.info("Runner staged 握手成功") - self._connection = conn + # 启动消息接收循环 try: await self._recv_loop(conn) except Exception as e: logger.error(f"连接异常断开: {e}") finally: - self._connection = None - self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开") + should_fail_pending_requests = False + async with self._connection_lock: + if self._connection is conn: + self._connection = None + should_fail_pending_requests = True + if should_fail_pending_requests: + self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开") async def _handle_handshake(self, conn: Connection) -> bool: """处理 runner.hello 握手""" @@ -264,6 +271,15 @@ class RPCServer: await conn.send_frame(self._codec.encode_envelope(resp)) return False + # 若已有活跃连接,直接拒绝新的握手,避免后来的连接抢占当前通道。 + if self.is_connected: + logger.warning("拒绝新的 Runner 连接:已有活跃连接") + self._last_handshake_rejection_reason = "已有活跃 Runner 连接,拒绝新的握手" + resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason) + resp = envelope.make_response(payload=resp_payload.model_dump()) + await conn.send_frame(self._codec.encode_envelope(resp)) + return False + # 校验 SDK 版本 if not self._check_sdk_version(hello.sdk_version): logger.error(f"SDK 版本不兼容: {hello.sdk_version}") diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index e66d2fab..e4e47c68 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -1237,11 +1237,14 @@ class PluginRunner: def _isolate_sys_path(plugin_dirs: List[str]) -> None: """清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。 - 防止插件代码 import 主程序模块读取运行时数据。 + 同时移除插件可直接访问的主程序内部模块缓存,避免通过 ``sys.modules`` + 或常规导入绕过 SDK / capability 边界。 """ + import builtins import importlib.abc from importlib.machinery import ModuleSpec import sysconfig + from types import ModuleType # 保留: 标准库路径 + site-packages(含 SDK 和依赖) stdlib_paths = set() @@ -1271,18 +1274,68 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: for d in plugin_dir_paths: allowed.add(d) - # 添加项目根目录(使得 src.plugin_runtime / src.common 可导入) - runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) - allowed.add(runtime_root) - preserved_paths = [p for p in sys.path if p in allowed] - for extra_path in [*plugin_dir_paths, runtime_root]: + for extra_path in plugin_dir_paths: if extra_path not in preserved_paths: preserved_paths.append(extra_path) sys.path[:] = preserved_paths - # 安装 import 钩子,阻止插件导入主程序核心模块 - # 仅允许 src.plugin_runtime 和 src.common,拒绝其他 src.* 子包 + # 仅为旧版插件兼容层保留极小的 src.* 可见面: + # - src.plugin_system.*: 通过 maibot_sdk.compat 导入钩子重定向 + # - src.common.logger: 仓库内仍有少量旧插件沿用该日志入口 + allowed_src_exact_modules = frozenset( + { + "src", + "src.common", + "src.common.logger", + "src.common.logger_color_and_mapping", + } + ) + allowed_src_prefixes = ("src.plugin_system",) + + def _is_allowed_src_module(fullname: str) -> bool: + """判断给定 src.* 模块是否在 Runner 允许列表中。""" + if fullname in allowed_src_exact_modules: + return True + return any(fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in allowed_src_prefixes) + + def _format_block_message(fullname: str) -> str: + """构造统一的拒绝导入错误信息。""" + return ( + f"Runner 子进程不允许导入主程序模块: {fullname}。" + "请改用 maibot_sdk 或 src.plugin_system 兼容层提供的接口。" + ) + + def _detach_module_from_parent(fullname: str, module: ModuleType) -> None: + """从父模块上移除已清理模块的属性引用。""" + parent_name, _, child_name = fullname.rpartition(".") + if not parent_name or not child_name: + return + + parent_module = sys.modules.get(parent_name) + if parent_module is None: + return + if getattr(parent_module, child_name, None) is module: + with contextlib.suppress(AttributeError): + delattr(parent_module, child_name) + + # 清理主程序内部模块缓存,避免插件经由 sys.modules 直接拿到高权限对象。 + existing_src_modules = sorted( + ( + (module_name, module) + for module_name, module in list(sys.modules.items()) + if module_name == "src" or module_name.startswith("src.") + ), + key=lambda item: item[0].count("."), + reverse=True, + ) + for module_name, module in existing_src_modules: + if _is_allowed_src_module(module_name): + continue + _detach_module_from_parent(module_name, module) + sys.modules.pop(module_name, None) + + # 安装 import 钩子,阻止再次导入被清理掉的主程序内部模块。 class _BlockedSrcModuleLoader(importlib.abc.Loader): """阻止被 Runner 允许列表之外的主程序模块完成导入。""" @@ -1295,16 +1348,11 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: def exec_module(self, module: Any) -> None: del module - raise ImportError(f"Runner 子进程不允许导入主程序模块: {self._fullname}") + raise ImportError(_format_block_message(self._fullname)) class _PluginImportBlocker(importlib.abc.MetaPathFinder): - """阻止 Runner 子进程导入主程序核心模块。 + """阻止 Runner 子进程重新导入主程序内部 src.* 模块。""" - 只放行 src.plugin_runtime 和 src.common, - 拒绝 src.chat_module / src.services 等主程序内部包。 - """ - - _ALLOWED_SRC_PREFIXES = ("src.plugin_runtime", "src.common") __maibot_runner_plugin_import_blocker__ = True def find_spec(self, fullname: str, path: Any = None, target: Any = None) -> ModuleSpec | None: @@ -1317,13 +1365,9 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: def _should_block(self, fullname: str) -> bool: """判断给定模块名是否应被阻止导入。""" - # 放行非 src.* 的导入、以及 "src" 本身 - if not fullname.startswith("src.") or fullname == "src": + if not fullname.startswith("src"): return False - # 放行白名单前缀 - return not any( - fullname == prefix or fullname.startswith(f"{prefix}.") for prefix in self._ALLOWED_SRC_PREFIXES - ) + return not _is_allowed_src_module(fullname) sys.meta_path[:] = [ finder @@ -1332,15 +1376,28 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None: ] sys.meta_path.insert(0, _PluginImportBlocker()) + # ``import`` 语句在模块已存在于 sys.modules 时不会再经过 finder, + # 因此还需要在入口处补一层兜底。 + original_import = getattr(builtins, "__maibot_runner_original_import__", builtins.__import__) + builtins.__maibot_runner_original_import__ = original_import + + def _guarded_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any: + if level == 0 and name.startswith("src") and not _is_allowed_src_module(name): + raise ImportError(_format_block_message(name)) + return original_import(name, globals, locals, fromlist, level) + + _guarded_import.__maibot_runner_plugin_import_guard__ = True + builtins.__import__ = _guarded_import + # ─── 进程入口 ────────────────────────────────────────────── async def _async_main() -> None: """异步主入口""" - host_address = os.environ.get(ENV_IPC_ADDRESS, "") + host_address = os.environ.pop(ENV_IPC_ADDRESS, "") external_plugin_ids_raw = os.environ.get(ENV_EXTERNAL_PLUGIN_IDS, "") - session_token = os.environ.get(ENV_SESSION_TOKEN, "") + session_token = os.environ.pop(ENV_SESSION_TOKEN, "") plugin_dirs_str = os.environ.get(ENV_PLUGIN_DIRS, "") if not host_address or not session_token: