feat: 增强 RPC 服务器连接处理,添加连接锁以防止并发连接问题
This commit is contained in:
@@ -1298,9 +1298,12 @@ class TestDependencyResolution:
|
||||
assert "on_unload" in loader.failed_plugins["test.demo-plugin"]
|
||||
|
||||
def test_isolate_sys_path_preserves_plugin_dirs(self):
|
||||
import builtins
|
||||
|
||||
from src.plugin_runtime.runner import runner_main
|
||||
|
||||
plugin_root = os.path.normpath("/tmp/maibot-plugin-root")
|
||||
original_import = builtins.__import__
|
||||
original_path = list(sys.path)
|
||||
original_meta_path = list(sys.meta_path)
|
||||
|
||||
@@ -1312,14 +1315,17 @@ class TestDependencyResolution:
|
||||
|
||||
assert plugin_root in sys.path
|
||||
finally:
|
||||
builtins.__import__ = original_import
|
||||
sys.path[:] = original_path
|
||||
sys.meta_path[:] = original_meta_path
|
||||
|
||||
def test_isolate_sys_path_blocks_disallowed_src_imports(self):
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
from src.plugin_runtime.runner import runner_main
|
||||
|
||||
original_import = builtins.__import__
|
||||
original_path = list(sys.path)
|
||||
original_meta_path = list(sys.meta_path)
|
||||
sys.modules.pop("src.forbidden_demo", None)
|
||||
@@ -1330,10 +1336,89 @@ class TestDependencyResolution:
|
||||
with pytest.raises(ImportError, match="不允许导入主程序模块"):
|
||||
importlib.import_module("src.forbidden_demo")
|
||||
finally:
|
||||
builtins.__import__ = original_import
|
||||
sys.path[:] = original_path
|
||||
sys.meta_path[:] = original_meta_path
|
||||
sys.modules.pop("src.forbidden_demo", None)
|
||||
|
||||
def test_isolate_sys_path_blocks_preloaded_runtime_modules(self):
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
from src.plugin_runtime.runner import runner_main
|
||||
|
||||
original_import = builtins.__import__
|
||||
original_path = list(sys.path)
|
||||
original_meta_path = list(sys.meta_path)
|
||||
|
||||
try:
|
||||
runner_main._isolate_sys_path([])
|
||||
|
||||
with pytest.raises(ImportError, match="rpc_client"):
|
||||
importlib.import_module("src.plugin_runtime.runner.rpc_client")
|
||||
finally:
|
||||
builtins.__import__ = original_import
|
||||
sys.path[:] = original_path
|
||||
sys.meta_path[:] = original_meta_path
|
||||
|
||||
def test_isolate_sys_path_keeps_legacy_logger_import_available(self):
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
from src.plugin_runtime.runner import runner_main
|
||||
|
||||
original_import = builtins.__import__
|
||||
original_path = list(sys.path)
|
||||
original_meta_path = list(sys.meta_path)
|
||||
|
||||
try:
|
||||
runner_main._isolate_sys_path([])
|
||||
|
||||
logger_module = importlib.import_module("src.common.logger")
|
||||
assert callable(logger_module.get_logger)
|
||||
finally:
|
||||
builtins.__import__ = original_import
|
||||
sys.path[:] = original_path
|
||||
sys.meta_path[:] = original_meta_path
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_main_removes_sensitive_runtime_env_vars(self, monkeypatch):
|
||||
from src.plugin_runtime.runner import runner_main
|
||||
|
||||
captured = {}
|
||||
|
||||
class FakeRunner:
|
||||
def __init__(
|
||||
self,
|
||||
host_address: str,
|
||||
session_token: str,
|
||||
plugin_dirs: list[str],
|
||||
external_available_plugins: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
captured["host_address"] = host_address
|
||||
captured["session_token"] = session_token
|
||||
captured["plugin_dirs"] = plugin_dirs
|
||||
captured["external_available_plugins"] = external_available_plugins or {}
|
||||
|
||||
async def run(self) -> None:
|
||||
assert os.environ.get(runner_main.ENV_IPC_ADDRESS) is None
|
||||
assert os.environ.get(runner_main.ENV_SESSION_TOKEN) is None
|
||||
|
||||
monkeypatch.setenv(runner_main.ENV_IPC_ADDRESS, "tcp://127.0.0.1:9999")
|
||||
monkeypatch.setenv(runner_main.ENV_SESSION_TOKEN, "secret-token")
|
||||
monkeypatch.setenv(runner_main.ENV_PLUGIN_DIRS, "/tmp/plugins")
|
||||
monkeypatch.setenv(runner_main.ENV_EXTERNAL_PLUGIN_IDS, '{"demo.plugin":"1.0.0"}')
|
||||
monkeypatch.setattr(runner_main, "_install_shutdown_signal_handlers", lambda callback: None)
|
||||
monkeypatch.setattr(runner_main, "_isolate_sys_path", lambda plugin_dirs: None)
|
||||
monkeypatch.setattr(runner_main, "PluginRunner", FakeRunner)
|
||||
|
||||
await runner_main._async_main()
|
||||
|
||||
assert captured["host_address"] == "tcp://127.0.0.1:9999"
|
||||
assert captured["session_token"] == "secret-token"
|
||||
assert captured["plugin_dirs"] == ["/tmp/plugins"]
|
||||
assert captured["external_available_plugins"] == {"demo.plugin": "1.0.0"}
|
||||
|
||||
|
||||
# ─── Host-side ComponentRegistry 测试 ──────────────────────
|
||||
|
||||
@@ -2093,6 +2178,67 @@ class TestWorkflowExecutor:
|
||||
class TestRPCServer:
|
||||
"""RPC Server 代际保护测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_second_active_runner_connection(self):
|
||||
from src.plugin_runtime.host.rpc_server import RPCServer
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
|
||||
|
||||
class DummyTransport:
|
||||
async def start(self, handler):
|
||||
return None
|
||||
|
||||
async def stop(self):
|
||||
return None
|
||||
|
||||
def get_address(self):
|
||||
return "dummy"
|
||||
|
||||
class FakeConnection:
|
||||
def __init__(self, incoming_frames: list[bytes]):
|
||||
self._incoming_frames = list(incoming_frames)
|
||||
self.sent_frames: list[bytes] = []
|
||||
self.is_closed = False
|
||||
|
||||
async def recv_frame(self):
|
||||
return self._incoming_frames.pop(0)
|
||||
|
||||
async def send_frame(self, data):
|
||||
self.sent_frames.append(data)
|
||||
|
||||
async def close(self):
|
||||
self.is_closed = True
|
||||
|
||||
codec = MsgPackCodec()
|
||||
server = RPCServer(transport=DummyTransport(), session_token="session-token")
|
||||
active_conn = SimpleNamespace(is_closed=False)
|
||||
server._connection = active_conn
|
||||
|
||||
hello = HelloPayload(
|
||||
runner_id="runner-b",
|
||||
sdk_version="1.0.0",
|
||||
session_token="session-token",
|
||||
)
|
||||
envelope = Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="runner.hello",
|
||||
payload=hello.model_dump(),
|
||||
)
|
||||
incoming_conn = FakeConnection([codec.encode_envelope(envelope)])
|
||||
|
||||
await server._handle_connection(incoming_conn)
|
||||
|
||||
assert incoming_conn.is_closed is True
|
||||
assert server._connection is active_conn
|
||||
assert server.last_handshake_rejection_reason == "已有活跃 Runner 连接,拒绝新的握手"
|
||||
assert len(incoming_conn.sent_frames) == 1
|
||||
|
||||
response = codec.decode_envelope(incoming_conn.sent_frames[0])
|
||||
response_payload = HelloResponsePayload.model_validate(response.payload)
|
||||
assert response_payload.accepted is False
|
||||
assert response_payload.reason == "已有活跃 Runner 连接,拒绝新的握手"
|
||||
|
||||
def test_ignore_stale_generation_response(self):
|
||||
from src.plugin_runtime.host.rpc_server import RPCServer
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
Reference in New Issue
Block a user