feat: 增强 RPC 服务器连接处理,添加连接锁以防止并发连接问题

This commit is contained in:
DrSmoothl
2026-03-24 11:43:23 +08:00
parent 1b61e51554
commit f4a9afc452
3 changed files with 252 additions and 33 deletions

View File

@@ -1298,9 +1298,12 @@ class TestDependencyResolution:
assert "on_unload" in loader.failed_plugins["test.demo-plugin"]
def test_isolate_sys_path_preserves_plugin_dirs(self):
import builtins
from src.plugin_runtime.runner import runner_main
plugin_root = os.path.normpath("/tmp/maibot-plugin-root")
original_import = builtins.__import__
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
@@ -1312,14 +1315,17 @@ class TestDependencyResolution:
assert plugin_root in sys.path
finally:
builtins.__import__ = original_import
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
def test_isolate_sys_path_blocks_disallowed_src_imports(self):
import builtins
import importlib
from src.plugin_runtime.runner import runner_main
original_import = builtins.__import__
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
sys.modules.pop("src.forbidden_demo", None)
@@ -1330,10 +1336,89 @@ class TestDependencyResolution:
with pytest.raises(ImportError, match="不允许导入主程序模块"):
importlib.import_module("src.forbidden_demo")
finally:
builtins.__import__ = original_import
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
sys.modules.pop("src.forbidden_demo", None)
def test_isolate_sys_path_blocks_preloaded_runtime_modules(self):
import builtins
import importlib
from src.plugin_runtime.runner import runner_main
original_import = builtins.__import__
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
try:
runner_main._isolate_sys_path([])
with pytest.raises(ImportError, match="rpc_client"):
importlib.import_module("src.plugin_runtime.runner.rpc_client")
finally:
builtins.__import__ = original_import
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
def test_isolate_sys_path_keeps_legacy_logger_import_available(self):
import builtins
import importlib
from src.plugin_runtime.runner import runner_main
original_import = builtins.__import__
original_path = list(sys.path)
original_meta_path = list(sys.meta_path)
try:
runner_main._isolate_sys_path([])
logger_module = importlib.import_module("src.common.logger")
assert callable(logger_module.get_logger)
finally:
builtins.__import__ = original_import
sys.path[:] = original_path
sys.meta_path[:] = original_meta_path
@pytest.mark.asyncio
async def test_async_main_removes_sensitive_runtime_env_vars(self, monkeypatch):
from src.plugin_runtime.runner import runner_main
captured = {}
class FakeRunner:
def __init__(
self,
host_address: str,
session_token: str,
plugin_dirs: list[str],
external_available_plugins: dict[str, str] | None = None,
) -> None:
captured["host_address"] = host_address
captured["session_token"] = session_token
captured["plugin_dirs"] = plugin_dirs
captured["external_available_plugins"] = external_available_plugins or {}
async def run(self) -> None:
assert os.environ.get(runner_main.ENV_IPC_ADDRESS) is None
assert os.environ.get(runner_main.ENV_SESSION_TOKEN) is None
monkeypatch.setenv(runner_main.ENV_IPC_ADDRESS, "tcp://127.0.0.1:9999")
monkeypatch.setenv(runner_main.ENV_SESSION_TOKEN, "secret-token")
monkeypatch.setenv(runner_main.ENV_PLUGIN_DIRS, "/tmp/plugins")
monkeypatch.setenv(runner_main.ENV_EXTERNAL_PLUGIN_IDS, '{"demo.plugin":"1.0.0"}')
monkeypatch.setattr(runner_main, "_install_shutdown_signal_handlers", lambda callback: None)
monkeypatch.setattr(runner_main, "_isolate_sys_path", lambda plugin_dirs: None)
monkeypatch.setattr(runner_main, "PluginRunner", FakeRunner)
await runner_main._async_main()
assert captured["host_address"] == "tcp://127.0.0.1:9999"
assert captured["session_token"] == "secret-token"
assert captured["plugin_dirs"] == ["/tmp/plugins"]
assert captured["external_available_plugins"] == {"demo.plugin": "1.0.0"}
# ─── Host-side ComponentRegistry 测试 ──────────────────────
@@ -2093,6 +2178,67 @@ class TestWorkflowExecutor:
class TestRPCServer:
"""RPC Server 代际保护测试"""
@pytest.mark.asyncio
async def test_reject_second_active_runner_connection(self):
from src.plugin_runtime.host.rpc_server import RPCServer
from src.plugin_runtime.protocol.codec import MsgPackCodec
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
class DummyTransport:
async def start(self, handler):
return None
async def stop(self):
return None
def get_address(self):
return "dummy"
class FakeConnection:
def __init__(self, incoming_frames: list[bytes]):
self._incoming_frames = list(incoming_frames)
self.sent_frames: list[bytes] = []
self.is_closed = False
async def recv_frame(self):
return self._incoming_frames.pop(0)
async def send_frame(self, data):
self.sent_frames.append(data)
async def close(self):
self.is_closed = True
codec = MsgPackCodec()
server = RPCServer(transport=DummyTransport(), session_token="session-token")
active_conn = SimpleNamespace(is_closed=False)
server._connection = active_conn
hello = HelloPayload(
runner_id="runner-b",
sdk_version="1.0.0",
session_token="session-token",
)
envelope = Envelope(
request_id=1,
message_type=MessageType.REQUEST,
method="runner.hello",
payload=hello.model_dump(),
)
incoming_conn = FakeConnection([codec.encode_envelope(envelope)])
await server._handle_connection(incoming_conn)
assert incoming_conn.is_closed is True
assert server._connection is active_conn
assert server.last_handshake_rejection_reason == "已有活跃 Runner 连接,拒绝新的握手"
assert len(incoming_conn.sent_frames) == 1
response = codec.decode_envelope(incoming_conn.sent_frames[0])
response_payload = HelloResponsePayload.model_validate(response.payload)
assert response_payload.accepted is False
assert response_payload.reason == "已有活跃 Runner 连接,拒绝新的握手"
def test_ignore_stale_generation_response(self):
from src.plugin_runtime.host.rpc_server import RPCServer
from src.plugin_runtime.protocol.envelope import Envelope, MessageType