"""插件运行时框架基础测试 验证协议层、传输层、RPC 通信链路的正确性。 """ from types import SimpleNamespace import asyncio import json import os import sys import pytest # 确保项目根目录在 sys.path 中 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) # SDK 包路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk")) # ─── 协议层测试 ─────────────────────────────────────────── class TestProtocol: """协议层测试""" def test_envelope_create_and_serialize(self): """Envelope 创建与序列化""" from src.plugin_runtime.protocol.envelope import Envelope, MessageType env = Envelope( request_id=1, message_type=MessageType.REQUEST, method="plugin.invoke_command", plugin_id="test_plugin", payload={"component_name": "greet", "args": {}}, ) assert env.request_id == 1 assert env.is_request() assert env.method == "plugin.invoke_command" # 测试 make_response resp = env.make_response(payload={"success": True}) assert resp.is_response() assert resp.request_id == 1 assert resp.payload["success"] is True def test_envelope_make_error_response(self): """错误响应生成""" from src.plugin_runtime.protocol.envelope import Envelope, MessageType env = Envelope( request_id=42, message_type=MessageType.REQUEST, method="cap.request", ) err_resp = env.make_error_response("E_UNAUTHORIZED", "没有权限") assert err_resp.error is not None assert err_resp.error["code"] == "E_UNAUTHORIZED" assert err_resp.error["message"] == "没有权限" def test_msgpack_codec(self): """MsgPack 编解码""" from src.plugin_runtime.protocol.codec import MsgPackCodec from src.plugin_runtime.protocol.envelope import Envelope, MessageType codec = MsgPackCodec() env = Envelope( request_id=100, message_type=MessageType.REQUEST, method="test.method", payload={"key": "value", "number": 42}, ) # 编码 data = codec.encode_envelope(env) assert isinstance(data, bytes) # 解码 decoded = codec.decode_envelope(data) assert decoded.request_id == 100 assert decoded.method == "test.method" assert decoded.payload["key"] == "value" assert decoded.payload["number"] == 42 
def test_json_codec(self): """JSON 编解码已移除,仅保留 MsgPack""" pass def test_request_id_generator(self): """请求 ID 生成器单调递增""" from src.plugin_runtime.protocol.envelope import RequestIdGenerator gen = RequestIdGenerator() ids = [gen.next() for _ in range(100)] assert ids == list(range(1, 101)) def test_error_codes(self): """错误码枚举""" from src.plugin_runtime.protocol.errors import ErrorCode, RPCError err = RPCError(ErrorCode.E_TIMEOUT, "请求超时") assert err.code == ErrorCode.E_TIMEOUT assert "E_TIMEOUT" in str(err) # 序列化/反序列化 d = err.to_dict() err2 = RPCError.from_dict(d) assert err2.code == ErrorCode.E_TIMEOUT # ─── 传输层测试 ─────────────────────────────────────────── class TestTransport: """传输层测试""" @pytest.mark.asyncio async def test_uds_connection_framing(self): """UDS 分帧协议测试""" from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient server = UDSTransportServer() received = asyncio.Event() received_data = [] async def handler(conn): data = await conn.recv_frame() received_data.append(data) await conn.send_frame(b"pong") received.set() await server.start(handler) address = server.get_address() client = UDSTransportClient(address) conn = await client.connect() await conn.send_frame(b"ping") # 等待服务端处理 await asyncio.wait_for(received.wait(), timeout=5.0) assert received_data[0] == b"ping" # 接收服务端回复 resp = await conn.recv_frame() assert resp == b"pong" await conn.close() await server.stop() @pytest.mark.asyncio async def test_tcp_connection_framing(self): """TCP 分帧协议测试""" from src.plugin_runtime.transport.tcp import TCPTransportServer, TCPTransportClient server = TCPTransportServer() received = asyncio.Event() received_data = [] async def handler(conn): data = await conn.recv_frame() received_data.append(data) await conn.send_frame(b"tcp_pong") received.set() await server.start(handler) address = server.get_address() host, port = address.split(":") client = TCPTransportClient(host, int(port)) conn = await client.connect() await 
conn.send_frame(b"tcp_ping") await asyncio.wait_for(received.wait(), timeout=5.0) assert received_data[0] == b"tcp_ping" resp = await conn.recv_frame() assert resp == b"tcp_pong" await conn.close() await server.stop() @pytest.mark.asyncio @pytest.mark.skipif(sys.platform != "win32", reason="Windows only") async def test_named_pipe_connection_framing(self): """Windows Named Pipe 分帧协议测试""" from src.plugin_runtime.transport.named_pipe import NamedPipeTransportClient, NamedPipeTransportServer server = NamedPipeTransportServer() received = asyncio.Event() received_data = [] async def handler(conn): data = await conn.recv_frame() received_data.append(data) await conn.send_frame(b"pipe_pong") received.set() await server.start(handler) client = NamedPipeTransportClient(server.get_address()) conn = await client.connect() await conn.send_frame(b"pipe_ping") await asyncio.wait_for(received.wait(), timeout=5.0) assert received_data[0] == b"pipe_ping" resp = await conn.recv_frame() assert resp == b"pipe_pong" await conn.close() await server.stop() @pytest.mark.asyncio async def test_transport_factory(self): """传输工厂测试""" from src.plugin_runtime.transport.factory import create_transport_server, create_transport_client server = create_transport_server() assert server is not None # UDS 路径 client = create_transport_client("/tmp/test.sock") assert client is not None # Windows Named Pipe 地址 client = create_transport_client(r"\\.\pipe\maibot-test") assert client is not None # TCP 地址 client = create_transport_client("127.0.0.1:9999") assert client is not None # ─── Host 层测试 ────────────────────────────────────────── class TestHost: """Host 端基础设施测试""" def test_policy_engine(self): """策略引擎测试""" from src.plugin_runtime.host.policy_engine import PolicyEngine engine = PolicyEngine() # 注册插件 token = engine.register_plugin( plugin_id="test_plugin", generation=1, capabilities=["send.text", "db.query"], ) assert token.plugin_id == "test_plugin" assert "send.text" in token.capabilities # 能力检查 ok, 
_ = engine.check_capability("test_plugin", "send.text") assert ok ok, reason = engine.check_capability("test_plugin", "llm.generate") assert not ok assert "未获授权" in reason # 未注册插件 ok, reason = engine.check_capability("unknown", "send.text") assert not ok ok, reason = engine.check_capability("test_plugin", "send.text", generation=2) assert not ok assert "generation 不匹配" in reason def test_policy_engine_allows_parallel_generations(self): """同一插件在热重载期间应允许 active/staged 两代并行持有能力令牌。""" from src.plugin_runtime.host.policy_engine import PolicyEngine engine = PolicyEngine() engine.register_plugin("test_plugin", generation=1, capabilities=["send.text"]) engine.register_plugin("test_plugin", generation=2, capabilities=["send.text", "llm.generate"]) ok, _ = engine.check_capability("test_plugin", "send.text", generation=1) assert ok is True ok, _ = engine.check_capability("test_plugin", "llm.generate", generation=2) assert ok is True ok, reason = engine.check_capability("test_plugin", "llm.generate", generation=1) assert ok is False assert "未获授权" in reason def test_circuit_breaker_removed(self): """熔断器已移除,验证 supervisor 不依赖它""" pass def test_circuit_breaker_registry_removed(self): """熔断器注册表已移除""" pass # ─── SDK 测试 ───────────────────────────────────────────── class TestSDK: """SDK 框架测试""" def test_component_decorators(self): """组件装饰器测试""" from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler from maibot_sdk.types import ActivationType, EventType class TestPlugin(MaiBotPlugin): @Action("greet", activation_type=ActivationType.KEYWORD, activation_keywords=["hi"]) async def handle_greet(self, **kwargs): return True, "ok" @Command("echo", pattern=r"^/echo") async def handle_echo(self, **kwargs): return True, "echoed", 2 @Tool("search", parameters={"query": {"type": "string"}}) async def handle_search(self, **kwargs): return {"result": "found"} @EventHandler("on_start", event_type=EventType.ON_START) async def handle_start(self, **kwargs): return True, False, 
"started" plugin = TestPlugin() components = plugin.get_components() assert len(components) == 4 names = {c["name"] for c in components} assert "greet" in names assert "echo" in names assert "search" in names assert "on_start" in names types = {c["type"] for c in components} assert "action" in types assert "command" in types assert "tool" in types assert "event_handler" in types def test_plugin_context_not_initialized(self): """未初始化上下文时应报错""" from maibot_sdk import MaiBotPlugin plugin = MaiBotPlugin() with pytest.raises(RuntimeError, match="尚未初始化"): _ = plugin.ctx def test_plugin_context_injection(self): """上下文注入测试""" from maibot_sdk import MaiBotPlugin from maibot_sdk.context import PluginContext plugin = MaiBotPlugin() ctx = PluginContext(plugin_id="test") plugin._set_context(ctx) assert plugin.ctx.plugin_id == "test" assert plugin.ctx.send is not None assert plugin.ctx.db is not None assert plugin.ctx.llm is not None assert plugin.ctx.config is not None @pytest.mark.asyncio async def test_runner_injected_context_binds_plugin_identity(self): """Runner 注入的上下文应忽略调用方伪造的 plugin_id。""" from src.plugin_runtime.runner.runner_main import PluginRunner class DummyRPCClient: def __init__(self): self.calls = [] async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): self.calls.append( { "method": method, "plugin_id": plugin_id, "payload": payload, "timeout_ms": timeout_ms, } ) return SimpleNamespace(error=None, payload={"result": {"ok": True}}) class DummyPlugin: def _set_context(self, ctx): self.ctx = ctx runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) runner._rpc_client = DummyRPCClient() plugin = DummyPlugin() runner._inject_context("owner_plugin", plugin) plugin.ctx._plugin_id = "forged_plugin" result = await plugin.ctx.call_capability("send.text", text="hello", stream_id="stream-1") assert result == {"ok": True} assert runner._rpc_client.calls[0]["plugin_id"] == "owner_plugin" assert 
runner._rpc_client.calls[0]["method"] == "cap.request" @pytest.mark.asyncio async def test_runner_applies_initial_plugin_config(self, tmp_path): """Runner 应在 on_load 前为支持的插件实例注入 config.toml。""" from src.plugin_runtime.runner.runner_main import PluginRunner class DummyPlugin: def __init__(self): self.configs = [] def set_plugin_config(self, config): self.configs.append(config) plugin_dir = tmp_path / "demo_plugin" plugin_dir.mkdir() (plugin_dir / "config.toml").write_text("[section]\nvalue = 1\n", encoding="utf-8") runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) plugin = DummyPlugin() meta = SimpleNamespace(plugin_id="demo_plugin", plugin_dir=str(plugin_dir), instance=plugin) runner._apply_plugin_config(meta) assert plugin.configs == [{"section": {"value": 1}}] @pytest.mark.asyncio async def test_runner_config_update_refreshes_plugin_config_before_callback(self): """配置更新时应先刷新插件配置,再调用 on_config_update。""" from src.plugin_runtime.protocol.envelope import Envelope, MessageType from src.plugin_runtime.runner.runner_main import PluginRunner class DummyPlugin: def __init__(self): self.configs = [] self.updates = [] def set_plugin_config(self, config): self.configs.append(config) async def on_config_update(self, config, version): self.updates.append((config, version, list(self.configs))) runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) plugin = DummyPlugin() runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin) envelope = Envelope( request_id=1, message_type=MessageType.REQUEST, method="plugin.config_updated", plugin_id="demo_plugin", payload={"config_data": {"enabled": True}, "config_version": "v2"}, ) response = await runner._handle_config_updated(envelope) assert response.payload["acknowledged"] is True assert plugin.configs == [{"enabled": True}] assert plugin.updates == [({"enabled": True}, "v2", [{"enabled": True}])] @pytest.mark.asyncio async def 
test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch): """on_load 期间的 capability 调用应在 bootstrap 后生效。""" from src.plugin_runtime.runner.runner_main import PluginRunner class DummyRPCClient: def __init__(self): self.calls = [] async def connect_and_handshake(self): return True def register_method(self, method, handler): return None async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): self.calls.append( { "method": method, "plugin_id": plugin_id, "payload": payload, "timeout_ms": timeout_ms, } ) if method == "cap.call": bootstrap_methods = [call["method"] for call in self.calls[:-1]] assert "plugin.bootstrap" in bootstrap_methods return SimpleNamespace(error=None, payload={"success": True}) return SimpleNamespace(error=None, payload={"accepted": True}) async def disconnect(self): return None class DummyPlugin: def __init__(self, runner): self.runner = runner def _set_context(self, ctx): self.ctx = ctx def get_components(self): return [{"name": "handler", "type": "command", "metadata": {}}] async def on_load(self): result = await self.ctx.call_capability("send.text", text="hello", stream_id="stream-1") assert result is True self.runner._shutting_down = True runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) runner._rpc_client = DummyRPCClient() plugin = DummyPlugin(runner) meta = SimpleNamespace( plugin_id="demo_plugin", plugin_dir="/tmp/demo_plugin", instance=plugin, version="1.0.0", capabilities_required=["send.text"], ) monkeypatch.setattr(runner, "_install_log_handler", lambda: None) monkeypatch.setattr(runner, "_uninstall_log_handler", lambda: asyncio.sleep(0)) monkeypatch.setattr(runner._loader, "discover_and_load", lambda plugin_dirs: [meta]) await runner.run() methods = [call["method"] for call in runner._rpc_client.calls] assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"] class TestPluginSdkUsage: """验证仓库内插件按新 SDK 归一化返回值工作。""" def 
test_runner_skips_signal_handler_registration_on_windows(self, monkeypatch): """Windows 下不应尝试注册 add_signal_handler。""" from src.plugin_runtime.runner import runner_main registered_signals = [] class DummyLoop: def add_signal_handler(self, sig, callback): registered_signals.append((sig, callback)) monkeypatch.setattr(runner_main.sys, "platform", "win32") runner_main._install_shutdown_signal_handlers(lambda: None, DummyLoop()) assert not registered_signals @pytest.mark.asyncio async def test_builtin_emoji_plugin_handles_normalized_results(self): from maibot_sdk.context import PluginContext from src.plugins.built_in.emoji_plugin.plugin import EmojiPlugin async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None): assert method == "cap.request" assert payload is not None capability = payload["capability"] return { "emoji.get_random": { "success": True, "emojis": [{"base64": "img-1", "emotion": "happy"}], }, "message.get_recent": {"success": True, "messages": [{"id": 1}]}, "message.build_readable": {"success": True, "text": "最近消息"}, "llm.generate": {"success": True, "response": "happy", "reasoning": "", "model_name": "m"}, "send.emoji": {"success": True}, }[capability] plugin = EmojiPlugin() plugin._set_context(PluginContext(plugin_id="emoji", rpc_call=fake_rpc_call)) success, message = await plugin.handle_emoji(stream_id="stream-1", reasoning="测试", chat_id="chat-1") assert success is True assert "成功发送表情包" in message @pytest.mark.asyncio async def test_tts_plugin_uses_send_custom_bool_result(self): from maibot_sdk.context import PluginContext from src.plugins.built_in.tts_plugin.plugin import TTSPlugin async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None): assert method == "cap.request" assert payload is not None assert payload["capability"] == "send.custom" return {"success": True} plugin = TTSPlugin() plugin._set_context(PluginContext(plugin_id="tts", rpc_call=fake_rpc_call)) success, message = await 
plugin.handle_tts_action( stream_id="stream-1", action_data={"voice_text": "你好!!!"}, ) assert success is True assert message == "TTS动作执行成功" @pytest.mark.asyncio async def test_hello_world_plugin_handles_random_emoji_list(self): from maibot_sdk.context import PluginContext from plugins.hello_world_plugin.plugin import HelloWorldPlugin async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None): assert method == "cap.request" assert payload is not None capability = payload["capability"] return { "emoji.get_random": {"success": True, "emojis": [{"base64": "img-1"}, {"base64": "img-2"}]}, "send.forward": {"success": True}, }[capability] plugin = HelloWorldPlugin() plugin._set_context(PluginContext(plugin_id="hello", rpc_call=fake_rpc_call)) success, message, should_continue = await plugin.handle_random_emojis(stream_id="stream-1") assert success is True assert message == "已发送随机表情包" assert should_continue is True # ─── 端到端集成测试 ──────────────────────────────────────── class TestE2E: """端到端集成测试(Host + Runner 通信)""" @pytest.mark.asyncio async def test_handshake(self): """Host-Runner 握手流程测试""" from src.plugin_runtime.protocol.codec import MsgPackCodec from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType from src.plugin_runtime.transport.factory import create_transport_client, create_transport_server import secrets session_token = secrets.token_hex(16) codec = MsgPackCodec() handshake_done = asyncio.Event() server_result = {} async def server_handler(conn): # 接收握手 data = await conn.recv_frame() env = codec.decode_envelope(data) assert env.method == "runner.hello" hello = HelloPayload.model_validate(env.payload) assert hello.session_token == session_token # 发送响应 resp_payload = HelloResponsePayload( accepted=True, host_version="1.0", assigned_generation=1, ) resp = env.make_response(payload=resp_payload.model_dump()) await conn.send_frame(codec.encode_envelope(resp)) server_result["runner_id"] = 
hello.runner_id handshake_done.set() # 保持连接一会儿 await asyncio.sleep(1.0) server = create_transport_server() await server.start(server_handler) # 客户端握手 client = create_transport_client(server.get_address()) conn = await client.connect() hello = HelloPayload( runner_id="test-runner", sdk_version="1.0.0", session_token=session_token, ) env = Envelope( request_id=1, message_type=MessageType.REQUEST, method="runner.hello", payload=hello.model_dump(), ) await conn.send_frame(codec.encode_envelope(env)) resp_data = await conn.recv_frame() resp = codec.decode_envelope(resp_data) resp_payload = HelloResponsePayload.model_validate(resp.payload) assert resp_payload.accepted assert resp_payload.assigned_generation == 1 await asyncio.wait_for(handshake_done.wait(), timeout=5.0) assert server_result["runner_id"] == "test-runner" await conn.close() await server.stop() # ─── Manifest 校验测试 ───────────────────────────────────── class TestManifestValidator: """Manifest 校验器测试""" def test_valid_manifest(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator() manifest = { "manifest_version": 1, "name": "test_plugin", "version": "1.0.0", "description": "测试插件", "author": "test", } assert validator.validate(manifest) is True assert len(validator.errors) == 0 def test_missing_required_fields(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator() manifest = {"manifest_version": 1} assert validator.validate(manifest) is False assert len(validator.errors) >= 4 # name, version, description, author def test_unsupported_manifest_version(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator() manifest = { "manifest_version": 999, "name": "test", "version": "1.0", "description": "d", "author": "a", } assert validator.validate(manifest) is False assert any("manifest_version" in e for e in validator.errors) def 
test_host_version_compatibility(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator(host_version="0.8.5") manifest = { "name": "test", "version": "1.0", "description": "d", "author": "a", "host_application": {"min_version": "0.9.0"}, } assert validator.validate(manifest) is False assert any("Host 版本不兼容" in e for e in validator.errors) def test_recommended_fields_warning(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator() manifest = { "name": "test", "version": "1.0", "description": "d", "author": "a", } validator.validate(manifest) assert len(validator.warnings) >= 3 # license, keywords, categories class TestVersionComparator: """版本号比较器测试""" def test_normalize(self): from src.plugin_runtime.runner.manifest_validator import VersionComparator assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0" assert VersionComparator.normalize_version("1.2") == "1.2.0" assert VersionComparator.normalize_version("") == "0.0.0" def test_compare(self): from src.plugin_runtime.runner.manifest_validator import VersionComparator assert VersionComparator.compare("0.8.0", "0.8.0") == 0 assert VersionComparator.compare("0.8.0", "0.9.0") == -1 assert VersionComparator.compare("1.0.0", "0.9.0") == 1 def test_is_in_range(self): from src.plugin_runtime.runner.manifest_validator import VersionComparator ok, _ = VersionComparator.is_in_range("0.8.5", "0.8.0", "0.9.0") assert ok ok, _ = VersionComparator.is_in_range("0.7.0", "0.8.0", "0.9.0") assert not ok ok, _ = VersionComparator.is_in_range("1.0.0", "0.8.0", "0.9.0") assert not ok # ─── 依赖解析测试 ────────────────────────────────────────── class TestDependencyResolution: """插件依赖解析测试""" def test_topological_sort(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { "core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": 
"a"}, "plugin.py"), "auth": ( "dir_auth", {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, "plugin.py", ), "api": ( "dir_api", {"name": "api", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core", "auth"]}, "plugin.py", ), } order, failed = loader._resolve_dependencies(candidates) assert len(failed) == 0 assert order.index("core") < order.index("auth") assert order.index("auth") < order.index("api") def test_missing_dependency(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { "plugin_a": ( "dir_a", { "name": "plugin_a", "version": "1.0", "description": "d", "author": "a", "dependencies": ["nonexistent"], }, "plugin.py", ), } order, failed = loader._resolve_dependencies(candidates) assert "plugin_a" in failed assert "缺少依赖" in failed["plugin_a"] def test_circular_dependency(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { "a": ( "dir_a", {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, "p.py", ), "b": ( "dir_b", {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, "p.py", ), } order, failed = loader._resolve_dependencies(candidates) assert len(failed) >= 1 # 至少一个循环插件被标记 def test_loader_supports_package_imports_inside_create_plugin(self, tmp_path): from src.plugin_runtime.runner.plugin_loader import PluginLoader plugin_root = tmp_path / "plugins" plugin_root.mkdir() plugin_dir = plugin_root / "grok_search_plugin" plugin_dir.mkdir() (plugin_dir / "_manifest.json").write_text( json.dumps( { "name": "grok_search_plugin", "version": "1.0.0", "description": "demo", "author": "tester", } ), encoding="utf-8", ) (plugin_dir / "__init__.py").write_text("VALUE = 1\n", encoding="utf-8") (plugin_dir / "services.py").write_text("def answer():\n return 42\n", encoding="utf-8") (plugin_dir / 
"plugin.py").write_text( "class DemoPlugin:\n" " pass\n\n" "def create_plugin():\n" " from grok_search_plugin.services import answer\n" " plugin = DemoPlugin()\n" " plugin.answer = answer\n" " return plugin\n", encoding="utf-8", ) loader = PluginLoader() loaded = loader.discover_and_load([str(plugin_root)]) assert [meta.plugin_id for meta in loaded] == ["grok_search_plugin"] assert loader.failed_plugins == {} assert loaded[0].instance.answer() == 42 def test_isolate_sys_path_preserves_plugin_dirs(self): from src.plugin_runtime.runner import runner_main plugin_root = os.path.normpath("/tmp/maibot-plugin-root") original_path = list(sys.path) original_meta_path = list(sys.meta_path) try: if plugin_root in sys.path: sys.path.remove(plugin_root) runner_main._isolate_sys_path([plugin_root]) assert plugin_root in sys.path finally: sys.path[:] = original_path sys.meta_path[:] = original_meta_path # ─── Host-side ComponentRegistry 测试 ────────────────────── class TestComponentRegistry: """Host-side 组件注册表测试""" def test_register_and_query(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component( "greet", "action", "plugin_a", { "description": "打招呼", "activation_type": "keyword", "activation_keywords": ["hi"], }, ) reg.register_component( "help", "command", "plugin_a", { "command_pattern": r"^/help", }, ) reg.register_component( "search", "tool", "plugin_b", { "description": "搜索", }, ) stats = reg.get_stats() assert stats["total"] == 3 assert stats["action"] == 1 assert stats["command"] == 1 assert stats["tool"] == 1 def test_query_by_type(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component("a1", "action", "p1", {}) reg.register_component("a2", "action", "p2", {}) actions = reg.get_components_by_type("action") assert len(actions) == 2 def test_find_command_by_text(self): from src.plugin_runtime.host.component_registry import 
ComponentRegistry reg = ComponentRegistry() reg.register_component( "help", "command", "p1", { "command_pattern": r"^/help", }, ) reg.register_component( "echo", "command", "p1", { "command_pattern": r"^/echo\s", }, ) match = reg.find_command_by_text("/help me") assert match is not None comp, groups = match assert comp.name == "help" match = reg.find_command_by_text("/echo hello") assert match is not None comp, groups = match assert comp.name == "echo" match = reg.find_command_by_text("no match") assert match is None def test_enable_disable(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component("a1", "action", "p1", {}) reg.set_component_enabled("p1.a1", False) actions = reg.get_components_by_type("action", enabled_only=True) assert len(actions) == 0 actions = reg.get_components_by_type("action", enabled_only=False) assert len(actions) == 1 def test_remove_by_plugin(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component("a1", "action", "p1", {}) reg.register_component("c1", "command", "p1", {}) reg.register_component("a2", "action", "p2", {}) removed = reg.remove_components_by_plugin("p1") assert removed == 2 assert reg.get_stats()["total"] == 1 def test_reregister_same_plugin_replaces_component_set(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_plugin_components( "p1", [ {"name": "a1", "component_type": "action", "metadata": {}}, {"name": "a2", "component_type": "action", "metadata": {}}, ], ) reg.remove_components_by_plugin("p1") reg.register_plugin_components( "p1", [ {"name": "a1", "component_type": "action", "metadata": {}}, ], ) assert reg.get_component("p1.a1") is not None assert reg.get_component("p1.a2") is None def test_event_handlers_sorted_by_weight(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = 
ComponentRegistry() reg.register_component( "h_low", "event_handler", "p1", { "event_type": "on_message", "weight": 10, }, ) reg.register_component( "h_high", "event_handler", "p2", { "event_type": "on_message", "weight": 100, }, ) handlers = reg.get_event_handlers("on_message") assert handlers[0].name == "h_high" assert handlers[1].name == "h_low" def test_tools_for_llm(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component( "search", "tool", "p1", { "description": "搜索工具", "parameters_raw": {"query": {"type": "string"}}, }, ) tools = reg.get_tools_for_llm() assert len(tools) == 1 assert tools[0]["name"] == "p1.search" assert tools[0]["parameters"]["query"]["type"] == "string" # ─── EventDispatcher 测试 ───────────────────────────────── class TestEventDispatcher: """Host-side 事件分发器测试""" @pytest.mark.asyncio async def test_dispatch_non_blocking(self): from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.host.event_dispatcher import EventDispatcher reg = ComponentRegistry() reg.register_component( "h1", "event_handler", "p1", { "event_type": "on_start", "weight": 0, "intercept_message": False, }, ) dispatcher = EventDispatcher(reg) call_log = [] async def mock_invoke(plugin_id, comp_name, args): call_log.append((plugin_id, comp_name)) return {"success": True, "continue_processing": True} should_continue, modified = await dispatcher.dispatch_event("on_start", mock_invoke) assert should_continue # 非阻塞分发是异步的,等一下让 task 完成 await asyncio.sleep(0.1) assert len(call_log) == 1 assert call_log[0] == ("p1", "h1") @pytest.mark.asyncio async def test_dispatch_intercepting(self): from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.host.event_dispatcher import EventDispatcher reg = ComponentRegistry() reg.register_component( "filter", "event_handler", "p1", { "event_type": "on_message_pre_process", "weight": 100, 
                "intercept_message": True,
            },
        )
        dispatcher = EventDispatcher(reg)

        async def mock_invoke(plugin_id, comp_name, args):
            return {
                "success": True,
                "continue_processing": False,
                "modified_message": {"plain_text": "filtered"},
            }

        should_continue, modified = await dispatcher.dispatch_event(
            "on_message_pre_process", mock_invoke, message={"plain_text": "hello"}
        )
        assert not should_continue
        assert modified is not None
        assert modified["plain_text"] == "filtered"


class TestEventBus:
    """Tests for the core event bus and its IPC bridge."""

    @pytest.mark.asyncio
    async def test_bridge_preserves_modified_message(self, monkeypatch):
        """emit() returns the message supplied by the faked manager and leaves the input untouched.

        The runtime manager is replaced by a fake whose ``bridge_event`` returns
        ``{"plain_text": "modified by ipc"}``; the test asserts that text on the
        emitted result and that the original object still reads "original".
        """
        import types

        # Install a stub module under the real module name before importing the
        # event bus. NOTE(review): presumably this keeps the import from pulling
        # in the real data-model module — confirm.
        fake_message_data_model = types.ModuleType("src.common.data_models.message_data_model")
        fake_message_data_model.ReplyContentType = object
        fake_message_data_model.ReplyContent = object
        fake_message_data_model.ForwardNode = object
        fake_message_data_model.ReplySetModel = object
        monkeypatch.setitem(sys.modules, "src.common.data_models.message_data_model", fake_message_data_model)

        from src.core.event_bus import EventBus
        from src.core.types import EventType, MaiMessages
        from src.plugin_runtime import integration as integration_module

        bus = EventBus()

        async def noop_handler(message):
            return True, message

        bus.subscribe(EventType.ON_MESSAGE, noop_handler, name="noop", intercept=True)

        class FakeManager:
            is_running = True

            async def bridge_event(self, event_type_value, message_dict=None, extra_args=None):
                assert event_type_value == EventType.ON_MESSAGE.value
                return True, {"plain_text": "modified by ipc"}

        monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())

        original = MaiMessages(plain_text="original")
        continue_flag, modified = await bus.emit(EventType.ON_MESSAGE, original)

        assert continue_flag is True
        assert modified is not None
        assert modified.plain_text == "modified by ipc"
        assert original.plain_text == "original"


# ─── MaiMessages tests ────────────────────────────────────


class TestMaiMessages:
    """Tests for the unified message model."""

    def test_create_and_serialize(self):
        """Round-trip a message through to_rpc_dict / from_rpc_dict."""
        from maibot_sdk.messages import MaiMessages, MessageSegment

        msg = MaiMessages(
            message_segments=[MessageSegment(type="text", data={"text": "hello"})],
            plain_text="hello",
            stream_id="stream_1",
        )
        d = msg.to_rpc_dict()
        assert d["plain_text"] == "hello"
        assert len(d["message_segments"]) == 1

        msg2 = MaiMessages.from_rpc_dict(d)
        assert msg2.plain_text == "hello"

    def test_deepcopy(self):
        """Mutating the result of deepcopy() does not change the original message."""
        from maibot_sdk.messages import MaiMessages

        msg = MaiMessages(plain_text="original")
        msg2 = msg.deepcopy()
        msg2.plain_text = "modified"
        assert msg.plain_text == "original"

    def test_modify_flags(self):
        """With CAN_MODIFY_PROMPT cleared, modify_prompt is refused while modify_response still works."""
        from maibot_sdk.messages import MaiMessages
        from maibot_sdk.types import ModifyFlag

        msg = MaiMessages(plain_text="hello")
        assert msg.can_modify(ModifyFlag.CAN_MODIFY_PROMPT)

        msg.set_modify_flag(ModifyFlag.CAN_MODIFY_PROMPT, False)
        assert not msg.modify_prompt("new prompt")
        assert msg.llm_prompt is None

        assert msg.modify_response("new response")
        assert msg.llm_response_content == "new response"


# ─── WorkflowExecutor tests ───────────────────────────────


class TestWorkflowExecutor:
    """Tests for the host-side workflow executor (new pipeline model)."""

    @pytest.mark.asyncio
    async def test_empty_pipeline_completes(self):
        """With no workflow_step registered, every pipeline stage is skipped and the status is completed."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        executor = WorkflowExecutor(reg)

        async def mock_invoke(plugin_id, comp_name, args):
            return {"hook_result": "continue"}

        result, final_msg, ctx = await executor.execute(
            mock_invoke,
            message={"plain_text": "test"},
        )
        assert result.status == "completed"
        assert result.return_message == "workflow completed"
        assert len(ctx.timings) == 6  # 6 stages

    @pytest.mark.asyncio
    async def test_blocking_hook_modifies_message(self):
        """A blocking hook can modify the message."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        reg.register_component(
            "upper",
            "workflow_step",
            "p1",
            {
                "stage": "pre_process",
                "priority": 10,
                "blocking": True,
            },
        )

        executor = WorkflowExecutor(reg)

        async def mock_invoke(plugin_id, comp_name, args):
            msg = args.get("message", {})
            return {
                "hook_result": "continue",
                "modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
            }

        result, final_msg, ctx = await executor.execute(
            mock_invoke,
            message={"plain_text": "hello"},
        )
        assert result.status == "completed"
        assert final_msg["plain_text"] == "HELLO"
        assert len(ctx.modification_log) == 1
        assert ctx.modification_log[0].stage == "pre_process"

    @pytest.mark.asyncio
    async def test_abort_stops_pipeline(self):
        """HookResult.ABORT terminates the pipeline immediately."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        reg.register_component(
            "blocker",
            "workflow_step",
            "p1",
            {
                "stage": "pre_process",
                "priority": 10,
                "blocking": True,
            },
        )

        executor = WorkflowExecutor(reg)

        async def mock_invoke(plugin_id, comp_name, args):
            return {"hook_result": "abort"}

        result, _, ctx = await executor.execute(
            mock_invoke,
            message={"plain_text": "test"},
        )
        assert result.status == "aborted"
        assert result.stopped_at == "pre_process"

    @pytest.mark.asyncio
    async def test_skip_stage(self):
        """HookResult.SKIP_STAGE skips the remaining hooks of the current stage."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        # The high-priority hook returns skip_stage.
        reg.register_component(
            "skipper",
            "workflow_step",
            "p1",
            {
                "stage": "ingress",
                "priority": 100,
                "blocking": True,
            },
        )
        # The low-priority hook must not be executed.
        reg.register_component(
            "checker",
            "workflow_step",
            "p2",
            {
                "stage": "ingress",
                "priority": 1,
                "blocking": True,
            },
        )

        executor = WorkflowExecutor(reg)
        call_log = []

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append(comp_name)
            if comp_name == "skipper":
                return {"hook_result": "skip_stage"}
            return {"hook_result": "continue"}

        result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
        assert result.status == "completed"
        # Only skipper was invoked; checker was skipped.
        assert call_log == ["skipper"]

    @pytest.mark.asyncio
    async def test_pre_filter(self):
        """A hook is skipped when its filter condition does not match."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        reg.register_component(
            "only_dm",
            "workflow_step",
            "p1",
            {
                "stage": "ingress",
                "priority": 10,
                "blocking": True,
                "filter": {"chat_type": "direct"},
            },
        )

        executor = WorkflowExecutor(reg)
        call_log = []

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append(comp_name)
            return {"hook_result": "continue"}

        # Filter does not match — the hook must not be invoked.
        await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
        assert not call_log

        # Filter matches — the hook must be invoked.
        await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
        assert call_log == ["only_dm"]

    @pytest.mark.asyncio
    async def test_error_policy_skip(self):
        """With error_policy=skip a failing hook is skipped and execution continues."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        reg.register_component(
            "failer",
            "workflow_step",
            "p1",
            {
                "stage": "ingress",
                "priority": 100,
                "blocking": True,
                "error_policy": "skip",
            },
        )
        reg.register_component(
            "ok_step",
            "workflow_step",
            "p2",
            {
                "stage": "ingress",
                "priority": 1,
                "blocking": True,
            },
        )

        executor = WorkflowExecutor(reg)
        call_log = []

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append(comp_name)
            if comp_name == "failer":
                raise RuntimeError("boom")
            return {"hook_result": "continue"}

        result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
        assert result.status == "completed"
        assert "failer" in call_log
        assert "ok_step" in call_log
        assert any("boom" in e for e in ctx.errors)

    @pytest.mark.asyncio
    async def test_error_policy_abort(self):
        """With error_policy=abort (the default) the pipeline fails."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        reg.register_component(
            "failer",
            "workflow_step",
            "p1",
            {
                "stage": "ingress",
                "priority": 10,
                "blocking": True,
                # error_policy defaults to "abort"
            },
        )

        executor = WorkflowExecutor(reg)

        async def mock_invoke(plugin_id, comp_name, args):
            raise RuntimeError("fatal")

        result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
        assert result.status == "failed"
        assert result.stopped_at == "ingress"

    @pytest.mark.asyncio
    async def test_nonblocking_hooks_concurrent(self):
        """Non-blocking hooks run concurrently and do not modify the message."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        for i in range(3):
            reg.register_component(
                f"nb_{i}",
                "workflow_step",
                f"p{i}",
                {
                    "stage": "post_process",
                    "priority": 0,
                    "blocking": False,
                },
            )

        executor = WorkflowExecutor(reg)
        # NOTE(review): call_log is filled by the hooks but never asserted below.
        call_log = []

        async def mock_invoke(plugin_id, comp_name, args):
            call_log.append(comp_name)
            return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}

        result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
        # modified_message returned by non-blocking hooks is ignored.
        assert final_msg["plain_text"] == "original"
        # Give the async tasks time to finish.
        await asyncio.sleep(0.1)
        assert result.status == "completed"

    @pytest.mark.asyncio
    async def test_nonblocking_tasks_are_retained_until_completion(self):
        """After execute returns, non-blocking tasks must stay strongly referenced until they finish."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        reg.register_component(
            "observer",
            "workflow_step",
            "p1",
            {
                "stage": "post_process",
                "priority": 0,
                "blocking": False,
            },
        )

        executor = WorkflowExecutor(reg)
        started = asyncio.Event()
        release = asyncio.Event()

        async def mock_invoke(plugin_id, comp_name, args):
            # Signal that the hook started, then block until the test releases it.
            started.set()
            await release.wait()
            return {"hook_result": "continue"}

        result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
        # Yield once so the background hook gets a chance to start.
        await asyncio.sleep(0)

        assert result.status == "completed"
        assert final_msg["plain_text"] == "original"
        assert started.is_set()
        assert len(executor._background_tasks) == 1

        release.set()
        # Two yields: the test expects the task to finish and be dropped from
        # _background_tasks within these loop iterations.
        await asyncio.sleep(0)
        await asyncio.sleep(0)

        assert not executor._background_tasks

    @pytest.mark.asyncio
    async def test_command_routing(self):
        """Built-in command routing in the PLAN stage."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        reg.register_component(
            "help",
            "command",
            "p1",
            {
                "command_pattern": r"^/help",
            },
        )

        executor = WorkflowExecutor(reg)

        async def mock_invoke(plugin_id, comp_name, args):
            if comp_name == "help":
                return {"output": "帮助信息"}
            return {"hook_result": "continue"}

        result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"})
        assert result.status == "completed"
        assert ctx.matched_command == "p1.help"
        cmd_result = ctx.get_stage_output("plan", "command_result")
        assert cmd_result is not None
        assert cmd_result["output"] == "帮助信息"

    @pytest.mark.asyncio
    async def test_stage_outputs(self):
        """stage_outputs data is passed between stages."""
        from src.plugin_runtime.host.component_registry import ComponentRegistry
        from src.plugin_runtime.host.workflow_executor import WorkflowExecutor

        reg = ComponentRegistry()
        # The ingress stage writes data.
        reg.register_component(
            "writer",
            "workflow_step",
            "p1",
            {
                "stage": "ingress",
                "priority": 10,
                "blocking": True,
            },
        )
        # The pre_process stage reads data.
        reg.register_component(
            "reader",
            "workflow_step",
            "p2",
            {
                "stage": "pre_process",
                "priority": 10,
                "blocking": True,
            },
        )

        executor = WorkflowExecutor(reg)

        async def mock_invoke(plugin_id, comp_name, args):
            if comp_name == "writer":
                return {
                    "hook_result": "continue",
                    "stage_output": {"parsed_intent": "greeting"},
                }
            if comp_name == "reader":
                # Verify that stage_outputs was passed through to this hook.
                outputs = args.get("stage_outputs", {})
                ingress_data = outputs.get("ingress", {})
                assert ingress_data.get("parsed_intent") == "greeting"
                return {"hook_result": "continue"}
            return {"hook_result": "continue"}

        result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"})
        assert result.status == "completed"
        assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"


class TestRPCServer:
    """Generation-guard tests for the RPC server."""

    def test_ignore_stale_generation_response(self):
        """A response carrying generation 1 does not resolve a request pending under generation 2."""
        from src.plugin_runtime.host.rpc_server import RPCServer
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        class DummyTransport:
            async def start(self, handler):
                return None

            async def stop(self):
                return None

            def get_address(self):
                return "dummy"

        server = RPCServer(transport=DummyTransport())
        server._runner_generation = 2

        # A private loop is used only to create a future; it is closed in finally.
        loop = asyncio.new_event_loop()
        try:
            future = loop.create_future()
            server._pending_requests[1] = (future, 2)
            stale_response = Envelope(
                request_id=1,
                message_type=MessageType.RESPONSE,
                method="plugin.health",
                generation=1,
                payload={"healthy": True},
            )
            server._handle_response(stale_response)
            assert not future.done()
            assert 1 in server._pending_requests
        finally:
            loop.close()

    @pytest.mark.asyncio
    async def test_send_queue_backpressure_is_enforced(self):
        """With send_queue_size=1 and a blocked connection, a third send_event raises E_BACKPRESSURE."""
        from src.plugin_runtime.host.rpc_server import RPCServer
        from src.plugin_runtime.protocol.errors import ErrorCode, RPCError

        class DummyTransport:
            async def start(self, handler):
                return None

            async def stop(self):
                return None

            def get_address(self):
                return "dummy"

        class BlockingConnection:
            def __init__(self):
                self.is_closed = False
                self.release = asyncio.Event()

            async def send_frame(self, data):
                # Block every send until the test sets the release event.
                await self.release.wait()

            async def close(self):
                self.is_closed = True

        server = RPCServer(transport=DummyTransport(), send_queue_size=1)
        await server.start()
        conn = BlockingConnection()
        server._connection = conn
        server._runner_generation = 1

        # Start two sends that stay pending on the blocked connection, yielding
        # after each so it can progress before the next one is issued.
        first_send = asyncio.create_task(server.send_event("runner.log_batch"))
        await asyncio.sleep(0)
        second_send = asyncio.create_task(server.send_event("runner.log_batch"))
        await asyncio.sleep(0)

        with pytest.raises(RPCError) as exc_info:
            await server.send_event("runner.log_batch")

        assert exc_info.value.code == ErrorCode.E_BACKPRESSURE

        # Unblock the connection so the pending sends can finish before shutdown.
        conn.release.set()
        await asyncio.gather(first_send, second_send)
        await server.stop()


class TestRPCClient:
    """Lifecycle tests for the runner RPCClient's background tasks."""

    @pytest.mark.asyncio
    async def test_background_tasks_retained_and_cancelled_on_disconnect(self):
        """A tracked task stays in _background_tasks while pending and is cancelled and removed by disconnect()."""
        from src.plugin_runtime.runner.rpc_client import RPCClient

        client = RPCClient(host_address="dummy", session_token="token")
        release = asyncio.Event()

        async def pending_task():
            # Never released in this test, so the task stays pending.
            await release.wait()

        task = asyncio.create_task(pending_task())
        client._track_background_task(task)
        assert task in client._background_tasks

        await asyncio.sleep(0)
        assert task in client._background_tasks

        await client.disconnect()
        assert task.cancelled() is True
        assert not client._background_tasks


class TestSupervisor:
    """Edge-case lifecycle tests for the Supervisor."""

    @staticmethod
    def _build_register_payload(plugin_id: str = "plugin_a", component_names=None):
        """Build a RegisterComponentsPayload with one event_handler component per name.

        ``component_names`` falls back to ``["handler"]`` when it is None or empty.
        """
        from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload

        component_names = component_names or ["handler"]
        return RegisterComponentsPayload(
            plugin_id=plugin_id,
            plugin_version="1.0.0",
            components=[
                ComponentDeclaration(
                    name=name,
                    component_type="event_handler",
                    plugin_id=plugin_id,
                    metadata={"event_type": "on_message"},
                )
                for name in component_names
            ],
            capabilities_required=["send.text"],
        )

    @staticmethod
    def _make_process(pid: int):
        """Return a fake process with the given pid.

        terminate() and kill() record the call and set returncode to 0 and -9
        respectively; wait() returns the current returncode.
        """

        class FakeProcess:
            def __init__(self):
                self.pid = pid
                self.returncode = None
                self.stdout = None
                self.stderr = None
                self.terminated = False
                self.killed = False

            def terminate(self):
                self.terminated = True
                self.returncode = 0

            def kill(self):
                self.killed = True
                self.returncode = -9

            async def wait(self):
                return self.returncode

        return FakeProcess()

    @pytest.mark.asyncio
    async def test_reload_waits_for_target_generation(self, monkeypatch):
        """A successful reload adopts the new process, commits the takeover and terminates the old process.

        The fake RPC server asserts that the health request targets generation 2.
        """
        from src.plugin_runtime.host.supervisor import PluginSupervisor
        from src.plugin_runtime.protocol.envelope import HealthPayload

        supervisor = PluginSupervisor(plugin_dirs=[])
        old_process = self._make_process(1)
        new_process = self._make_process(2)

        class FakeRPCServer:
            def __init__(self):
                self.runner_generation = 1
                self.staged_generation = 0
                self.is_connected = True
                self.session_token = "fake-token"
                self.committed = False
                self.staging_started = False

            def reset_session_token(self):
                self.session_token = "new-fake-token"
                return self.session_token

            def restore_session_token(self, token):
                self.session_token = token

            def begin_staged_takeover(self):
                self.staging_started = True
                self.staged_generation = 2

            async def commit_staged_takeover(self):
                self.runner_generation = self.staged_generation
                self.staged_generation = 0
                self.committed = True

            async def rollback_staged_takeover(self):
                self.staged_generation = 0

            def has_generation(self, generation):
                return generation in {self.runner_generation, self.staged_generation}

            async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
                assert target_generation == 2
                return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())

        supervisor._rpc_server = FakeRPCServer()
        supervisor._runner_process = old_process

        async def fake_spawn_runner():
            # Simulate a new runner that registered plugin_a and already
            # signalled ready for generation 2.
            supervisor._runner_process = new_process
            supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a")
            supervisor._runner_ready_payloads[2] = SimpleNamespace(loaded_plugins=["plugin_a"], failed_plugins=[])
            supervisor._runner_ready_events[2] = asyncio.Event()
            supervisor._runner_ready_events[2].set()

        monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)

        reloaded = await supervisor.reload_plugins("test")

        assert reloaded is True
        assert supervisor._runner_process is new_process
        assert supervisor._rpc_server.committed is True
        assert old_process.terminated is True

    @pytest.mark.asyncio
    async def test_reload_restores_runtime_state_on_failure(self, monkeypatch):
        """When the health request to the new runner raises, the reload fails and the old state is kept.

        Asserted: reload_plugins returns False, the old process is still current,
        the staged takeover was rolled back, and the old plugin and its component
        remain registered.
        """
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        supervisor = PluginSupervisor(plugin_dirs=[])
        old_process = self._make_process(1)
        new_process = self._make_process(2)
        old_reg = self._build_register_payload()

        supervisor._runner_process = old_process
        supervisor._registered_plugins[old_reg.plugin_id] = old_reg
        supervisor._rebuild_runtime_state()

        class FakeRPCServer:
            def __init__(self):
                self.runner_generation = 1
                self.staged_generation = 0
                self.is_connected = True
                self.session_token = "fake-token"
                self.rolled_back = False

            def reset_session_token(self):
                self.session_token = "new-fake-token"
                return self.session_token

            def restore_session_token(self, token):
                self.session_token = token

            def begin_staged_takeover(self):
                self.staged_generation = 2

            async def commit_staged_takeover(self):
                self.runner_generation = self.staged_generation
                self.staged_generation = 0

            async def rollback_staged_takeover(self):
                self.rolled_back = True
                self.staged_generation = 0

            def has_generation(self, generation):
                return generation in {self.runner_generation, self.staged_generation}

            async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
                raise RuntimeError("new runner unhealthy")

        supervisor._rpc_server = FakeRPCServer()

        async def fake_spawn_runner():
            supervisor._runner_process = new_process
            supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a")
            supervisor._runner_ready_payloads[2] = SimpleNamespace(loaded_plugins=["plugin_a"], failed_plugins=[])
            supervisor._runner_ready_events[2] = asyncio.Event()
            supervisor._runner_ready_events[2].set()

        monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)

        reloaded = await supervisor.reload_plugins("test")

        assert reloaded is False
        assert supervisor._runner_process is old_process
        assert supervisor._rpc_server.rolled_back is True
        assert old_reg.plugin_id in supervisor._registered_plugins
        assert supervisor.component_registry.get_component("plugin_a.handler") is not None

    @pytest.mark.asyncio
    async def test_reload_rebuilds_exact_component_set(self, monkeypatch):
        """After a successful reload only the newly declared components remain registered.

        The old registration declares "handler" and "obsolete"; the new one only
        "handler". The test asserts "plugin_a.obsolete" is gone afterwards.
        """
        from src.plugin_runtime.host.supervisor import PluginSupervisor
        from src.plugin_runtime.protocol.envelope import HealthPayload

        supervisor = PluginSupervisor(plugin_dirs=[])
        old_process = self._make_process(1)
        new_process = self._make_process(2)
        old_reg = self._build_register_payload("plugin_a", component_names=["handler", "obsolete"])
        new_reg = self._build_register_payload("plugin_a", component_names=["handler"])

        supervisor._runner_process = old_process
        supervisor._registered_plugins[old_reg.plugin_id] = old_reg
        supervisor._rebuild_runtime_state()

        class FakeRPCServer:
            def __init__(self):
                self.runner_generation = 1
                self.staged_generation = 0
                self.is_connected = True
                self.session_token = "fake-token"

            def reset_session_token(self):
                self.session_token = "new-fake-token"
                return self.session_token

            def restore_session_token(self, token):
                self.session_token = token

            def begin_staged_takeover(self):
                self.staged_generation = 2

            async def commit_staged_takeover(self):
                self.runner_generation = self.staged_generation
                self.staged_generation = 0

            async def rollback_staged_takeover(self):
                self.staged_generation = 0

            def has_generation(self, generation):
                return generation in {self.runner_generation, self.staged_generation}

            async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
                return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump())

        supervisor._rpc_server = FakeRPCServer()

        async def fake_spawn_runner():
            supervisor._runner_process = new_process
            supervisor._staged_registered_plugins[new_reg.plugin_id] = new_reg
            supervisor._runner_ready_payloads[2] = SimpleNamespace(loaded_plugins=["plugin_a"], failed_plugins=[])
            supervisor._runner_ready_events[2] = asyncio.Event()
            supervisor._runner_ready_events[2].set()

        monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)

        reloaded = await supervisor.reload_plugins("test")

        assert reloaded is True
        assert supervisor.component_registry.get_component("plugin_a.handler") is not None
        assert supervisor.component_registry.get_component("plugin_a.obsolete") is None

    @pytest.mark.asyncio
    async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch):
        """Without a ready signal from the new runner the reload fails and is rolled back.

        The fake spawn never sets a ready event, and the fake server raises
        AssertionError if a commit or a health request is attempted.
        """
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        # A very short spawn timeout keeps the test fast.
        supervisor = PluginSupervisor(plugin_dirs=[], runner_spawn_timeout_sec=0.01)
        old_process = self._make_process(1)
        new_process = self._make_process(2)
        old_reg = self._build_register_payload()

        supervisor._runner_process = old_process
        supervisor._registered_plugins[old_reg.plugin_id] = old_reg
        supervisor._rebuild_runtime_state()

        class FakeRPCServer:
            def __init__(self):
                self.runner_generation = 1
                self.staged_generation = 0
                self.is_connected = True
                self.session_token = "fake-token"
                self.rolled_back = False

            def reset_session_token(self):
                self.session_token = "new-fake-token"
                return self.session_token

            def restore_session_token(self, token):
                self.session_token = token

            def begin_staged_takeover(self):
                self.staged_generation = 2

            async def commit_staged_takeover(self):
                raise AssertionError("runner.ready 未到达前不应提交 staged takeover")

            async def rollback_staged_takeover(self):
                self.rolled_back = True
                self.staged_generation = 0

            def has_generation(self, generation):
                return generation in {self.runner_generation, self.staged_generation}

            async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs):
                raise AssertionError("runner.ready 未到达前不应执行健康检查")

        supervisor._rpc_server = FakeRPCServer()

        async def fake_spawn_runner():
            # Deliberately does not create or set a ready event for generation 2.
            supervisor._runner_process = new_process
            supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a")

        monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner)

        reloaded = await supervisor.reload_plugins("test")

        assert reloaded is False
        assert supervisor._runner_process is old_process
        assert supervisor._rpc_server.rolled_back is True

    @pytest.mark.asyncio
    async def test_attach_stderr_drain_drains_stream(self):
        """_attach_stderr_drain creates a drain task for stderr; the task finishes once the stream is read."""
        from src.plugin_runtime.host.supervisor import PluginSupervisor

        supervisor = PluginSupervisor(plugin_dirs=[])

        stderr = asyncio.StreamReader()
        stderr.feed_data(b"fatal startup error\n")
        stderr.feed_eof()

        # stdout=None mimics the new architecture (stdout is no longer captured).
        process = SimpleNamespace(pid=99, stdout=None, stderr=stderr)

        supervisor._attach_stderr_drain(process)
        # Give the drain task enough time to consume the data.
        await asyncio.sleep(0.05)

        assert supervisor._stderr_drain_task is None or supervisor._stderr_drain_task.done()


class TestIntegration:
    """Startup/cleanup tests for the runtime integration layer."""

    @pytest.mark.asyncio
    async def test_cap_database_get_with_filters_does_not_reference_unbound_key_value(self, monkeypatch):
        """_cap_database_get with ``filters`` forwards model, filters and limit to db_get and wraps its result."""
        from src.plugin_runtime import integration as integration_module
        import src.common.database.database_model as real_db_models
        from src.services import database_service as real_database_service

        # Records the arguments the capability passes to the patched db_get.
        captured: dict[str, object] = {}

        class DummyModel:
            pass

        async def fake_db_get(model_class, filters=None, limit=None, order_by=None, single_result=False):
            captured["model_class"] = model_class
            captured["filters"] = filters
            captured["limit"] = limit
            captured["order_by"] = order_by
            captured["single_result"] = single_result
            return [{"id": 1}]

        monkeypatch.setattr(real_database_service, "db_get", fake_db_get)
        # raising=False: the attribute may not exist on the real module.
        monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False)

        result = await integration_module.PluginRuntimeManager._cap_database_get(
            "plugin_a",
            "database.get",
            {
                "table": "DemoTable",
                "filters": {"status": "active"},
                "limit": 5,
            },
        )

        assert result == {"success": True, "result": [{"id": 1}]}
        assert captured["model_class"] is DummyModel
        assert captured["filters"] == {"status": "active"}
        assert captured["limit"] == 5
        assert captured["single_result"] is False

    @pytest.mark.asyncio
    async def test_component_enable_rejects_ambiguous_short_name(self, monkeypatch):
        """Enabling by short name fails when two supervisors each register a component with that name."""
        from src.plugin_runtime import integration as integration_module
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        class FakeSupervisor:
            def __init__(self, plugin_id: str):
                self.component_registry = ComponentRegistry()
                self.component_registry.register_component(
                    name="shared",
                    component_type="tool",
                    plugin_id=plugin_id,
                    metadata={},
                )

        class FakeManager:
            def __init__(self):
                self.supervisors = [FakeSupervisor("plugin_a"), FakeSupervisor("plugin_b")]

        monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())

        result = await integration_module.PluginRuntimeManager._cap_component_enable(
            "plugin_a",
            "component.enable",
            {"name": "shared", "component_type": "tool", "scope": "global", "stream_id": ""},
        )

        assert result["success"] is False
        assert "组件名不唯一" in result["error"]

    @pytest.mark.asyncio
    async def test_component_disable_rejects_non_global_scope(self, monkeypatch):
        """Disabling a component with scope "stream" is rejected."""
        from src.plugin_runtime import integration as integration_module
        from src.plugin_runtime.host.component_registry import ComponentRegistry

        class FakeSupervisor:
            def __init__(self):
                self.component_registry = ComponentRegistry()
                self.component_registry.register_component(
                    name="handler",
                    component_type="tool",
                    plugin_id="plugin_a",
                    metadata={},
                )

        class FakeManager:
            def __init__(self):
                self.supervisors = [FakeSupervisor()]

        monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())

        result = await integration_module.PluginRuntimeManager._cap_component_disable(
            "plugin_a",
            "component.disable",
            {"name": "plugin_a.handler", "component_type": "tool", "scope": "stream", "stream_id": "s1"},
        )

        assert result["success"] is False
        assert "仅支持全局组件禁用" in result["error"]

    @pytest.mark.asyncio
    async def test_start_cleans_up_started_supervisors_on_failure(self, monkeypatch):
        """If the second supervisor fails to start, the manager is not running and the first one is stopped."""
        from src.plugin_runtime import integration as integration_module

        # Every FakeSupervisor appends itself here on construction.
        instances = []

        class FakeCapabilityService:
            def register_capability(self, name, impl):
                return None

        class FakeSupervisor:
            def __init__(self, plugin_dirs=None, socket_path=None):
                self.plugin_dirs = plugin_dirs or []
                self.capability_service = FakeCapabilityService()
                self.stopped = False
                instances.append(self)

            async def start(self):
                # Only the second constructed instance fails to start.
                if len(instances) == 2 and self is instances[1]:
                    raise RuntimeError("boom")

            async def stop(self):
                self.stopped = True

        monkeypatch.setattr(
            integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: ["builtin"])
        )
        monkeypatch.setattr(
            integration_module.PluginRuntimeManager, "_get_thirdparty_plugin_dirs", staticmethod(lambda: ["thirdparty"])
        )

        import src.plugin_runtime.host.supervisor as supervisor_module

        monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor)

        manager = integration_module.PluginRuntimeManager()
        await manager.start()

        assert manager.is_running is False
        assert len(instances) == 2
        assert instances[0].stopped is True

    @pytest.mark.asyncio
    async def test_handle_plugin_source_changes_only_reload_matching_supervisor(self, monkeypatch, tmp_path):
        """A source change under the third-party root reloads only the third-party supervisor's plugin.

        Also asserted: no config-update notifications are sent and the config
        watch subscriptions are refreshed exactly once.
        """
        from src.config.file_watcher import FileChange
        from src.plugin_runtime import integration as integration_module
        # NOTE(review): json is already imported at module level; this local import is redundant.
        import json

        builtin_root = tmp_path / "src" / "plugins" / "built_in"
        thirdparty_root = tmp_path / "plugins"
        alpha_dir = builtin_root / "alpha"
        beta_dir = thirdparty_root / "beta"
        alpha_dir.mkdir(parents=True)
        beta_dir.mkdir(parents=True)
        (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
        (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")
        (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
        (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
        (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
        (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")

        monkeypatch.chdir(tmp_path)

        class FakeSupervisor:
            def __init__(self, plugin_dirs, registered_plugins):
                self._plugin_dirs = plugin_dirs
                self._registered_plugins = registered_plugins
                self.reload_reasons = []
                self.config_updates = []

            async def reload_plugins(self, plugin_ids=None, reason="manual"):
                self.reload_reasons.append((plugin_ids, reason))

            async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""):
                self.config_updates.append((plugin_id, config_data, config_version))
                return True

        manager = integration_module.PluginRuntimeManager()
        manager._started = True
        manager._builtin_supervisor = FakeSupervisor([builtin_root], {"alpha": object()})
        manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"beta": object()})

        changes = [
            FileChange(change_type=1, path=beta_dir / "plugin.py"),
        ]

        refresh_calls = []

        def fake_refresh() -> None:
            refresh_calls.append(True)

        manager._refresh_plugin_config_watch_subscriptions = fake_refresh

        await manager._handle_plugin_source_changes(changes)

        assert manager._builtin_supervisor.reload_reasons == []
        assert manager._third_party_supervisor.reload_reasons == [(["beta"], "file_watcher")]
        assert manager._builtin_supervisor.config_updates == []
        assert manager._third_party_supervisor.config_updates == []
        assert refresh_calls == [True]

    @pytest.mark.asyncio
    async def test_handle_plugin_config_changes_only_reload_target_plugin(self, monkeypatch, tmp_path):
        """A config.toml change for "alpha" calls reload_plugin only on the supervisor that registered it."""
        from src.plugin_runtime import integration as integration_module
        from src.config.file_watcher import FileChange

        builtin_root = tmp_path / "src" / "plugins" / "built_in"
        thirdparty_root = tmp_path / "plugins"
        alpha_dir = builtin_root / "alpha"
        beta_dir = thirdparty_root / "beta"
        alpha_dir.mkdir(parents=True)
        beta_dir.mkdir(parents=True)
        (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8")
        (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8")

        monkeypatch.chdir(tmp_path)

        class FakeSupervisor:
            def __init__(self, plugin_dirs, plugins):
                self._plugin_dirs = plugin_dirs
                self._registered_plugins = {plugin_id: object() for plugin_id in plugins}
                self.reload_calls = []

            async def reload_plugin(self, plugin_id, reason="manual"):
                self.reload_calls.append((plugin_id, reason))
                return True

        manager = integration_module.PluginRuntimeManager()
        manager._started = True
        manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
        manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])

        refresh_calls = []

        def fake_refresh() -> None:
            refresh_calls.append(True)

        manager._refresh_plugin_config_watch_subscriptions = fake_refresh

        await manager._handle_plugin_config_changes(
            "alpha",
            [FileChange(change_type=1, path=alpha_dir / "config.toml")],
        )

        assert manager._builtin_supervisor.reload_calls == [("alpha", "config_file_changed")]
        assert manager._third_party_supervisor.reload_calls == []
        assert refresh_calls == [True]

    def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path):
        """Refreshing registers one watcher subscription per plugin, each on that plugin's config.toml."""
        from src.plugin_runtime import integration as integration_module
        # NOTE(review): json is already imported at module level; this local import is redundant.
        import json

        builtin_root = tmp_path / "src" / "plugins" / "built_in"
        thirdparty_root = tmp_path / "plugins"
        alpha_dir = builtin_root / "alpha"
        beta_dir = thirdparty_root / "beta"
        alpha_dir.mkdir(parents=True)
        beta_dir.mkdir(parents=True)
        (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
        (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8")
        (alpha_dir / "_manifest.json").write_text(json.dumps({"name": "alpha"}), encoding="utf-8")
        (beta_dir / "_manifest.json").write_text(json.dumps({"name": "beta"}), encoding="utf-8")

        class FakeWatcher:
            def __init__(self):
                self.subscriptions = []
                self.unsubscribed = []

            def subscribe(self, callback, *, paths=None, change_types=None):
                # Hand out sequential ids and remember what was subscribed.
                subscription_id = f"sub-{len(self.subscriptions) + 1}"
                self.subscriptions.append({"id": subscription_id, "callback": callback, "paths": tuple(paths or ())})
                return subscription_id

            def unsubscribe(self, subscription_id):
                self.unsubscribed.append(subscription_id)
                return True

        class FakeSupervisor:
            def __init__(self, plugin_dirs, plugins):
                self._plugin_dirs = plugin_dirs
                self._registered_plugins = {plugin_id: object() for plugin_id in plugins}

        manager = integration_module.PluginRuntimeManager()
        manager._plugin_file_watcher = FakeWatcher()
        manager._builtin_supervisor = FakeSupervisor([builtin_root], ["alpha"])
        manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["beta"])

        manager._refresh_plugin_config_watch_subscriptions()

        assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"alpha", "beta"}
        assert {
            subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions
        } == {alpha_dir / "config.toml", beta_dir / "config.toml"}

    @pytest.mark.asyncio
    async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch):
        """_cap_component_reload_plugin reports failure when the supervisor's reload_plugins returns False."""
        from src.plugin_runtime import integration as integration_module

        class FakeSupervisor:
            def __init__(self):
                self._registered_plugins = {"alpha": object()}

            async def reload_plugins(self, reason="manual"):
                return False

        class FakeManager:
            def __init__(self):
                self.supervisors = [FakeSupervisor()]

        monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())

        result = await integration_module.PluginRuntimeManager._cap_component_reload_plugin(
            "plugin_a",
            "component.reload_plugin",
            {"plugin_name": "alpha"},
        )

        assert result["success"] is False
        assert "已回滚" in result["error"]

    @pytest.mark.asyncio
    async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path):
        """_cap_component_load_plugin reports failure when the supervisor's reload_plugins returns False.

        The plugin directory "alpha" exists under the supervisor's plugin root
        but is not yet in its registered plugins.
        """
        from src.plugin_runtime import integration as integration_module

        plugin_root = tmp_path / "plugins"
        plugin_root.mkdir()
        (plugin_root / "alpha").mkdir()

        class FakeSupervisor:
            def __init__(self):
                self._registered_plugins = {}
                self._plugin_dirs = [str(plugin_root)]

            async def reload_plugins(self, reason="manual"):
                return False

        class FakeManager:
            def __init__(self):
                self.supervisors = [FakeSupervisor()]

        monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager())

        result = await integration_module.PluginRuntimeManager._cap_component_load_plugin(
            "plugin_a",
            "component.load_plugin",
            {"plugin_name": "alpha"},
        )

        assert result["success"] is False
        assert "已回滚" in result["error"]