feat: 增强插件能力检查,支持 generation 校验并添加清理功能

This commit is contained in:
DrSmoothl
2026-03-12 21:22:23 +08:00
parent df39fa7584
commit d0b56abdab
8 changed files with 466 additions and 51 deletions

View File

@@ -3,9 +3,11 @@
验证协议层、传输层、RPC 通信链路的正确性。
"""
from types import SimpleNamespace
import asyncio
import sys
import os
import sys
import pytest
@@ -227,6 +229,10 @@ class TestHost:
ok, reason = engine.check_capability("unknown", "send.text")
assert not ok
ok, reason = engine.check_capability("test_plugin", "send.text", generation=2)
assert not ok
assert "generation 不匹配" in reason
def test_circuit_breaker_removed(self):
"""熔断器已移除,验证 supervisor 不依赖它"""
pass
@@ -702,6 +708,49 @@ class TestEventDispatcher:
assert modified["plain_text"] == "filtered"
class TestEventBus:
"""核心事件总线与 IPC 桥接测试"""
@pytest.mark.asyncio
async def test_bridge_preserves_modified_message(self, monkeypatch):
import types
fake_message_data_model = types.ModuleType("src.common.data_models.message_data_model")
fake_message_data_model.ReplyContentType = object
fake_message_data_model.ReplyContent = object
fake_message_data_model.ForwardNode = object
fake_message_data_model.ReplySetModel = object
monkeypatch.setitem(sys.modules, "src.common.data_models.message_data_model", fake_message_data_model)
from src.core.event_bus import EventBus
from src.core.types import EventType, MaiMessages
from src.plugin_runtime import integration as integration_module
bus = EventBus()
async def noop_handler(message):
return True, message
bus.subscribe(EventType.ON_MESSAGE, noop_handler, name="noop", intercept=True)
class FakeManager:
is_running = True
async def bridge_event(self, event_type_value, message_dict=None, extra_args=None):
assert event_type_value == EventType.ON_MESSAGE.value
return True, {"plain_text": "modified by ipc"}
monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())
original = MaiMessages(plain_text="original")
continue_flag, modified = await bus.emit(EventType.ON_MESSAGE, original)
assert continue_flag is True
assert modified is not None
assert modified.plain_text == "modified by ipc"
assert original.plain_text == "original"
# ─── MaiMessages 测试 ─────────────────────────────────────
class TestMaiMessages:
@@ -1050,3 +1099,223 @@ class TestWorkflowExecutor:
)
assert result.status == "completed"
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
class TestRPCServer:
"""RPC Server 代际保护测试"""
def test_ignore_stale_generation_response(self):
from src.plugin_runtime.host.rpc_server import RPCServer
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
class DummyTransport:
async def start(self, handler):
return None
async def stop(self):
return None
def get_address(self):
return "dummy"
server = RPCServer(transport=DummyTransport())
server._runner_generation = 2
loop = asyncio.new_event_loop()
try:
future = loop.create_future()
server._pending_requests[1] = future
stale_response = Envelope(
request_id=1,
message_type=MessageType.RESPONSE,
method="plugin.health",
generation=1,
payload={"healthy": True},
)
server._handle_response(stale_response)
assert not future.done()
assert 1 in server._pending_requests
finally:
loop.close()
class TestSupervisor:
"""Supervisor 生命周期边界测试"""
@staticmethod
def _build_register_payload(plugin_id: str = "plugin_a"):
from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload
return RegisterComponentsPayload(
plugin_id=plugin_id,
plugin_version="1.0.0",
components=[
ComponentDeclaration(
name="handler",
component_type="event_handler",
plugin_id=plugin_id,
metadata={"event_type": "on_message"},
)
],
capabilities_required=["send.text"],
)
@staticmethod
def _make_process(pid: int):
class FakeProcess:
def __init__(self):
self.pid = pid
self.returncode = None
self.stdout = None
self.stderr = None
self.terminated = False
self.killed = False
def terminate(self):
self.terminated = True
self.returncode = 0
def kill(self):
self.killed = True
self.returncode = -9
async def wait(self):
return self.returncode
return FakeProcess()
@pytest.mark.asyncio
async def test_reload_waits_for_target_generation(self, monkeypatch):
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.plugin_runtime.protocol.envelope import HealthPayload
supervisor = PluginSupervisor(plugin_dirs=[])
old_process = self._make_process(1)
new_process = self._make_process(2)
class FakeRPCServer:
def __init__(self):
self.runner_generation = 1
self.is_connected = True
async def send_request(self, method, timeout_ms=5000, **kwargs):
assert self.runner_generation == 2
return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())
supervisor._rpc_server = FakeRPCServer()
supervisor._runner_process = old_process
async def fake_spawn_runner():
supervisor._runner_process = new_process
async def advance_generation():
await asyncio.sleep(0.01)
supervisor._rpc_server.runner_generation = 2
asyncio.create_task(advance_generation())
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
await supervisor.reload_plugins("test")
assert supervisor._runner_process is new_process
assert old_process.terminated is True
@pytest.mark.asyncio
async def test_reload_restores_runtime_state_on_failure(self, monkeypatch):
from src.plugin_runtime.host.supervisor import PluginSupervisor
supervisor = PluginSupervisor(plugin_dirs=[])
old_process = self._make_process(1)
new_process = self._make_process(2)
old_reg = self._build_register_payload()
supervisor._runner_process = old_process
supervisor._registered_plugins[old_reg.plugin_id] = old_reg
supervisor._rebuild_runtime_state()
class FakeRPCServer:
def __init__(self):
self.runner_generation = 1
self.is_connected = True
async def send_request(self, method, timeout_ms=5000, **kwargs):
raise RuntimeError("new runner unhealthy")
supervisor._rpc_server = FakeRPCServer()
async def fake_spawn_runner():
supervisor._runner_process = new_process
supervisor._rpc_server.runner_generation = 2
monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)
await supervisor.reload_plugins("test")
assert supervisor._runner_process is old_process
assert old_reg.plugin_id in supervisor._registered_plugins
assert supervisor.component_registry.get_component("plugin_a.handler") is not None
@pytest.mark.asyncio
async def test_attach_runner_output_tasks_drains_streams(self):
from src.plugin_runtime.host.supervisor import PluginSupervisor
supervisor = PluginSupervisor(plugin_dirs=[])
stdout = asyncio.StreamReader()
stdout.feed_data(b"hello stdout\n")
stdout.feed_eof()
stderr = asyncio.StreamReader()
stderr.feed_data(b"hello stderr\n")
stderr.feed_eof()
process = SimpleNamespace(pid=99, stdout=stdout, stderr=stderr)
supervisor._attach_runner_output_tasks(process)
await asyncio.sleep(0.05)
assert not supervisor._runner_output_tasks
class TestIntegration:
"""运行时集成层启动/清理测试"""
@pytest.mark.asyncio
async def test_start_cleans_up_started_supervisors_on_failure(self, monkeypatch):
from src.plugin_runtime import integration as integration_module
instances = []
class FakeCapabilityService:
def register_capability(self, name, impl):
return None
class FakeSupervisor:
def __init__(self, plugin_dirs=None, socket_path=None):
self.plugin_dirs = plugin_dirs or []
self.capability_service = FakeCapabilityService()
self.stopped = False
instances.append(self)
async def start(self):
if len(instances) == 2 and self is instances[1]:
raise RuntimeError("boom")
async def stop(self):
self.stopped = True
monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"]))
monkeypatch.setattr(integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"]))
import src.plugin_runtime.host.supervisor as supervisor_module
monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor)
manager = integration_module.PluginRuntimeManager()
await manager.start()
assert manager.is_running is False
assert len(instances) == 2
assert instances[0].stopped is True