feat: 增强插件能力检查,支持 generation 校验并添加清理功能
This commit is contained in:
@@ -3,9 +3,11 @@
|
||||
验证协议层、传输层、RPC 通信链路的正确性。
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -227,6 +229,10 @@ class TestHost:
|
||||
ok, reason = engine.check_capability("unknown", "send.text")
|
||||
assert not ok
|
||||
|
||||
ok, reason = engine.check_capability("test_plugin", "send.text", generation=2)
|
||||
assert not ok
|
||||
assert "generation 不匹配" in reason
|
||||
|
||||
def test_circuit_breaker_removed(self):
|
||||
"""熔断器已移除,验证 supervisor 不依赖它"""
|
||||
pass
|
||||
@@ -702,6 +708,49 @@ class TestEventDispatcher:
|
||||
assert modified["plain_text"] == "filtered"
|
||||
|
||||
|
||||
class TestEventBus:
|
||||
"""核心事件总线与 IPC 桥接测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bridge_preserves_modified_message(self, monkeypatch):
|
||||
import types
|
||||
|
||||
fake_message_data_model = types.ModuleType("src.common.data_models.message_data_model")
|
||||
fake_message_data_model.ReplyContentType = object
|
||||
fake_message_data_model.ReplyContent = object
|
||||
fake_message_data_model.ForwardNode = object
|
||||
fake_message_data_model.ReplySetModel = object
|
||||
monkeypatch.setitem(sys.modules, "src.common.data_models.message_data_model", fake_message_data_model)
|
||||
|
||||
from src.core.event_bus import EventBus
|
||||
from src.core.types import EventType, MaiMessages
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
|
||||
bus = EventBus()
|
||||
|
||||
async def noop_handler(message):
|
||||
return True, message
|
||||
|
||||
bus.subscribe(EventType.ON_MESSAGE, noop_handler, name="noop", intercept=True)
|
||||
|
||||
class FakeManager:
|
||||
is_running = True
|
||||
|
||||
async def bridge_event(self, event_type_value, message_dict=None, extra_args=None):
|
||||
assert event_type_value == EventType.ON_MESSAGE.value
|
||||
return True, {"plain_text": "modified by ipc"}
|
||||
|
||||
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
|
||||
|
||||
original = MaiMessages(plain_text="original")
|
||||
continue_flag, modified = await bus.emit(EventType.ON_MESSAGE, original)
|
||||
|
||||
assert continue_flag is True
|
||||
assert modified is not None
|
||||
assert modified.plain_text == "modified by ipc"
|
||||
assert original.plain_text == "original"
|
||||
|
||||
|
||||
# ─── MaiMessages 测试 ─────────────────────────────────────
|
||||
|
||||
class TestMaiMessages:
|
||||
@@ -1050,3 +1099,223 @@ class TestWorkflowExecutor:
|
||||
)
|
||||
assert result.status == "completed"
|
||||
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
|
||||
|
||||
|
||||
class TestRPCServer:
|
||||
"""RPC Server 代际保护测试"""
|
||||
|
||||
def test_ignore_stale_generation_response(self):
|
||||
from src.plugin_runtime.host.rpc_server import RPCServer
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
class DummyTransport:
|
||||
async def start(self, handler):
|
||||
return None
|
||||
|
||||
async def stop(self):
|
||||
return None
|
||||
|
||||
def get_address(self):
|
||||
return "dummy"
|
||||
|
||||
server = RPCServer(transport=DummyTransport())
|
||||
server._runner_generation = 2
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
future = loop.create_future()
|
||||
server._pending_requests[1] = future
|
||||
|
||||
stale_response = Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.RESPONSE,
|
||||
method="plugin.health",
|
||||
generation=1,
|
||||
payload={"healthy": True},
|
||||
)
|
||||
server._handle_response(stale_response)
|
||||
|
||||
assert not future.done()
|
||||
assert 1 in server._pending_requests
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
class TestSupervisor:
|
||||
"""Supervisor 生命周期边界测试"""
|
||||
|
||||
@staticmethod
|
||||
def _build_register_payload(plugin_id: str = "plugin_a"):
|
||||
from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload
|
||||
|
||||
return RegisterComponentsPayload(
|
||||
plugin_id=plugin_id,
|
||||
plugin_version="1.0.0",
|
||||
components=[
|
||||
ComponentDeclaration(
|
||||
name="handler",
|
||||
component_type="event_handler",
|
||||
plugin_id=plugin_id,
|
||||
metadata={"event_type": "on_message"},
|
||||
)
|
||||
],
|
||||
capabilities_required=["send.text"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _make_process(pid: int):
|
||||
class FakeProcess:
|
||||
def __init__(self):
|
||||
self.pid = pid
|
||||
self.returncode = None
|
||||
self.stdout = None
|
||||
self.stderr = None
|
||||
self.terminated = False
|
||||
self.killed = False
|
||||
|
||||
def terminate(self):
|
||||
self.terminated = True
|
||||
self.returncode = 0
|
||||
|
||||
def kill(self):
|
||||
self.killed = True
|
||||
self.returncode = -9
|
||||
|
||||
async def wait(self):
|
||||
return self.returncode
|
||||
|
||||
return FakeProcess()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_waits_for_target_generation(self, monkeypatch):
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
from src.plugin_runtime.protocol.envelope import HealthPayload
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
old_process = self._make_process(1)
|
||||
new_process = self._make_process(2)
|
||||
|
||||
class FakeRPCServer:
|
||||
def __init__(self):
|
||||
self.runner_generation = 1
|
||||
self.is_connected = True
|
||||
|
||||
async def send_request(self, method, timeout_ms=5000, **kwargs):
|
||||
assert self.runner_generation == 2
|
||||
return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())
|
||||
|
||||
supervisor._rpc_server = FakeRPCServer()
|
||||
supervisor._runner_process = old_process
|
||||
|
||||
async def fake_spawn_runner():
|
||||
supervisor._runner_process = new_process
|
||||
|
||||
async def advance_generation():
|
||||
await asyncio.sleep(0.01)
|
||||
supervisor._rpc_server.runner_generation = 2
|
||||
|
||||
asyncio.create_task(advance_generation())
|
||||
|
||||
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
||||
|
||||
await supervisor.reload_plugins("test")
|
||||
|
||||
assert supervisor._runner_process is new_process
|
||||
assert old_process.terminated is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_restores_runtime_state_on_failure(self, monkeypatch):
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
old_process = self._make_process(1)
|
||||
new_process = self._make_process(2)
|
||||
old_reg = self._build_register_payload()
|
||||
|
||||
supervisor._runner_process = old_process
|
||||
supervisor._registered_plugins[old_reg.plugin_id] = old_reg
|
||||
supervisor._rebuild_runtime_state()
|
||||
|
||||
class FakeRPCServer:
|
||||
def __init__(self):
|
||||
self.runner_generation = 1
|
||||
self.is_connected = True
|
||||
|
||||
async def send_request(self, method, timeout_ms=5000, **kwargs):
|
||||
raise RuntimeError("new runner unhealthy")
|
||||
|
||||
supervisor._rpc_server = FakeRPCServer()
|
||||
|
||||
async def fake_spawn_runner():
|
||||
supervisor._runner_process = new_process
|
||||
supervisor._rpc_server.runner_generation = 2
|
||||
|
||||
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
|
||||
|
||||
await supervisor.reload_plugins("test")
|
||||
|
||||
assert supervisor._runner_process is old_process
|
||||
assert old_reg.plugin_id in supervisor._registered_plugins
|
||||
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_runner_output_tasks_drains_streams(self):
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
|
||||
stdout = asyncio.StreamReader()
|
||||
stdout.feed_data(b"hello stdout\n")
|
||||
stdout.feed_eof()
|
||||
|
||||
stderr = asyncio.StreamReader()
|
||||
stderr.feed_data(b"hello stderr\n")
|
||||
stderr.feed_eof()
|
||||
|
||||
process = SimpleNamespace(pid=99, stdout=stdout, stderr=stderr)
|
||||
supervisor._attach_runner_output_tasks(process)
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert not supervisor._runner_output_tasks
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""运行时集成层启动/清理测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_cleans_up_started_supervisors_on_failure(self, monkeypatch):
|
||||
from src.plugin_runtime import integration as integration_module
|
||||
|
||||
instances = []
|
||||
|
||||
class FakeCapabilityService:
|
||||
def register_capability(self, name, impl):
|
||||
return None
|
||||
|
||||
class FakeSupervisor:
|
||||
def __init__(self, plugin_dirs=None, socket_path=None):
|
||||
self.plugin_dirs = plugin_dirs or []
|
||||
self.capability_service = FakeCapabilityService()
|
||||
self.stopped = False
|
||||
instances.append(self)
|
||||
|
||||
async def start(self):
|
||||
if len(instances) == 2 and self is instances[1]:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
async def stop(self):
|
||||
self.stopped = True
|
||||
|
||||
monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"]))
|
||||
monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"]))
|
||||
|
||||
import src.plugin_runtime.host.supervisor as supervisor_module
|
||||
monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor)
|
||||
|
||||
manager = integration_module.PluginRuntimeManager()
|
||||
await manager.start()
|
||||
|
||||
assert manager.is_running is False
|
||||
assert len(instances) == 2
|
||||
assert instances[0].stopped is True
|
||||
|
||||
Reference in New Issue
Block a user