"""插件运行时框架基础测试 验证协议层、传输层、RPC 通信链路的正确性。 """ # pyright: reportArgumentType=false, reportAttributeAccessIssue=false, reportCallIssue=false, reportIndexIssue=false, reportMissingImports=false, reportOptionalMemberAccess=false from pathlib import Path from types import SimpleNamespace from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence import asyncio import json import os import sys import pytest # 确保项目根目录在 sys.path 中 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) # SDK 包路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk")) def build_test_manifest( plugin_id: str, *, version: str = "1.0.0", name: str = "测试插件", description: str = "测试插件描述", dependencies: list[dict[str, str]] | None = None, capabilities: list[str] | None = None, host_min_version: str = "0.12.0", host_max_version: str = "1.0.0", sdk_min_version: str = "2.0.0", sdk_max_version: str = "2.99.99", ) -> dict[str, object]: """构造一个合法的 Manifest v2 测试样例。 Args: plugin_id: 插件 ID。 version: 插件版本。 name: 展示名称。 description: 插件描述。 dependencies: 依赖声明列表。 capabilities: 能力声明列表。 host_min_version: Host 最低支持版本。 host_max_version: Host 最高支持版本。 sdk_min_version: SDK 最低支持版本。 sdk_max_version: SDK 最高支持版本。 Returns: dict[str, object]: 可直接序列化为 ``_manifest.json`` 的字典。 """ return { "manifest_version": 2, "version": version, "name": name, "description": description, "author": { "name": "tester", "url": "https://example.com/tester", }, "license": "MIT", "urls": { "repository": f"https://example.com/{plugin_id}", }, "host_application": { "min_version": host_min_version, "max_version": host_max_version, }, "sdk": { "min_version": sdk_min_version, "max_version": sdk_max_version, }, "dependencies": dependencies or [], "capabilities": capabilities or [], "i18n": { "default_locale": "zh-CN", "supported_locales": ["zh-CN"], }, "id": plugin_id, } def build_test_manifest_model( plugin_id: str, *, version: str = "1.0.0", dependencies: list[dict[str, str]] | None = None, capabilities: list[str] | None = None, host_version: str = "1.0.0", sdk_version: str = "2.0.1", ) -> object: """构造一个已经通过校验的强类型 Manifest 测试对象。 Args: plugin_id: 插件 ID。 version: 插件版本。 dependencies: 依赖声明列表。 capabilities: 能力声明列表。 host_version: 当前测试使用的 Host 版本。 sdk_version: 当前测试使用的 SDK 版本。 Returns: object: ``PluginManifest`` 实例。 """ from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator(host_version=host_version, sdk_version=sdk_version) manifest = validator.parse_manifest( build_test_manifest( plugin_id, version=version, dependencies=dependencies, capabilities=capabilities, ) ) assert manifest is not None return manifest # ─── 协议层测试 ─────────────────────────────────────────── class TestProtocol: """协议层测试""" def test_envelope_create_and_serialize(self): """Envelope 创建与序列化""" from src.plugin_runtime.protocol.envelope import Envelope, MessageType env = Envelope( request_id=1, message_type=MessageType.REQUEST, method="plugin.invoke_command", plugin_id="test_plugin", payload={"component_name": "greet", "args": {}}, ) assert env.request_id == 1 assert env.is_request() assert env.method == "plugin.invoke_command" # 测试 make_response resp = env.make_response(payload={"success": True}) assert resp.is_response() assert resp.request_id == 1 assert resp.payload["success"] is True def test_envelope_make_error_response(self): """错误响应生成""" from src.plugin_runtime.protocol.envelope import Envelope, MessageType env = Envelope( request_id=42, message_type=MessageType.REQUEST, method="cap.request", ) err_resp = env.make_error_response("E_UNAUTHORIZED", "没有权限") assert err_resp.error is not None assert err_resp.error["code"] == "E_UNAUTHORIZED" assert err_resp.error["message"] == "没有权限" def test_msgpack_codec(self): """MsgPack 编解码""" from src.plugin_runtime.protocol.codec import MsgPackCodec from src.plugin_runtime.protocol.envelope import Envelope, MessageType codec = MsgPackCodec() env = Envelope( request_id=100, message_type=MessageType.REQUEST, method="test.method", payload={"key": "value", "number": 42}, ) # 编码 data = codec.encode_envelope(env) assert isinstance(data, bytes) # 解码 decoded = codec.decode_envelope(data) assert decoded.request_id == 100 assert decoded.method == "test.method" assert decoded.payload["key"] == "value" assert decoded.payload["number"] == 42 def test_json_codec(self): """JSON 编解码已移除,仅保留 MsgPack""" pass def test_request_id_generator(self): """请求 ID 生成器单调递增""" from src.plugin_runtime.protocol.envelope import RequestIdGenerator gen = RequestIdGenerator() ids = [gen.next() for _ in range(100)] assert ids == list(range(1, 101)) def test_error_codes(self): """错误码枚举""" from src.plugin_runtime.protocol.errors import ErrorCode, RPCError err = RPCError(ErrorCode.E_TIMEOUT, "请求超时") assert err.code == ErrorCode.E_TIMEOUT assert "E_TIMEOUT" in str(err) # 序列化/反序列化 d = err.to_dict() err2 = RPCError.from_dict(d) assert err2.code == ErrorCode.E_TIMEOUT # ─── 传输层测试 ─────────────────────────────────────────── class TestTransport: """传输层测试""" @pytest.mark.asyncio async def test_uds_connection_framing(self): """UDS 分帧协议测试""" from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient server = UDSTransportServer() received = asyncio.Event() received_data = [] async def handler(conn): data = await conn.recv_frame() received_data.append(data) await conn.send_frame(b"pong") received.set() await server.start(handler) address = server.get_address() client = UDSTransportClient(address) conn = await client.connect() await conn.send_frame(b"ping") # 等待服务端处理 await asyncio.wait_for(received.wait(), timeout=5.0) assert received_data[0] == b"ping" # 接收服务端回复 resp = await conn.recv_frame() assert resp == b"pong" await conn.close() await server.stop() @pytest.mark.asyncio async def test_tcp_connection_framing(self): """TCP 分帧协议测试""" from src.plugin_runtime.transport.tcp import TCPTransportServer, TCPTransportClient server = TCPTransportServer() received = asyncio.Event() received_data = [] async def handler(conn): data = await conn.recv_frame() received_data.append(data) await conn.send_frame(b"tcp_pong") received.set() await server.start(handler) address = server.get_address() host, port = address.split(":") client = TCPTransportClient(host, int(port)) conn = await client.connect() await conn.send_frame(b"tcp_ping") await asyncio.wait_for(received.wait(), timeout=5.0) assert received_data[0] == b"tcp_ping" resp = await conn.recv_frame() assert resp == b"tcp_pong" await conn.close() await server.stop() @pytest.mark.asyncio @pytest.mark.skipif(sys.platform != "win32", reason="Windows only") async def test_named_pipe_connection_framing(self): """Windows Named Pipe 分帧协议测试""" from src.plugin_runtime.transport.named_pipe import NamedPipeTransportClient, NamedPipeTransportServer server = NamedPipeTransportServer() received = asyncio.Event() received_data = [] async def handler(conn): data = await conn.recv_frame() received_data.append(data) await conn.send_frame(b"pipe_pong") received.set() await server.start(handler) client = NamedPipeTransportClient(server.get_address()) conn = await client.connect() await conn.send_frame(b"pipe_ping") await asyncio.wait_for(received.wait(), timeout=5.0) assert received_data[0] == b"pipe_ping" resp = await conn.recv_frame() assert resp == b"pipe_pong" await conn.close() await server.stop() @pytest.mark.asyncio async def test_transport_factory(self): """传输工厂测试""" from src.plugin_runtime.transport.factory import create_transport_server, create_transport_client server = create_transport_server() assert server is not None # UDS 路径 client = create_transport_client("/tmp/test.sock") assert client is not None # Windows Named Pipe 地址 client = create_transport_client(r"\\.\pipe\maibot-test") assert client is not None # TCP 地址 client = create_transport_client("127.0.0.1:9999") assert client is not None # ─── Host 层测试 ────────────────────────────────────────── class TestHost: """Host 端基础设施测试""" def test_policy_engine(self): """策略引擎测试""" from src.plugin_runtime.host.policy_engine import PolicyEngine engine = PolicyEngine() # 注册插件 token = engine.register_plugin( plugin_id="test_plugin", generation=1, capabilities=["send.text", "db.query"], ) assert token.plugin_id == "test_plugin" assert "send.text" in token.capabilities # 能力检查 ok, _ = engine.check_capability("test_plugin", "send.text") assert ok ok, reason = engine.check_capability("test_plugin", "llm.generate") assert not ok assert "未获授权" in reason # 未注册插件 ok, reason = engine.check_capability("unknown", "send.text") assert not ok ok, reason = engine.check_capability("test_plugin", "send.text", generation=2) assert not ok assert "generation 不匹配" in reason def test_policy_engine_allows_parallel_generations(self): """同一插件在热重载期间应允许 active/staged 两代并行持有能力令牌。""" from src.plugin_runtime.host.policy_engine import PolicyEngine engine = PolicyEngine() engine.register_plugin("test_plugin", generation=1, capabilities=["send.text"]) engine.register_plugin("test_plugin", generation=2, capabilities=["send.text", "llm.generate"]) ok, _ = engine.check_capability("test_plugin", "send.text", generation=1) assert ok is True ok, _ = engine.check_capability("test_plugin", "llm.generate", generation=2) assert ok is True ok, reason = engine.check_capability("test_plugin", "llm.generate", generation=1) assert ok is False assert "未获授权" in reason def test_circuit_breaker_removed(self): """熔断器已移除,验证 supervisor 不依赖它""" pass def test_circuit_breaker_registry_removed(self): """熔断器注册表已移除""" pass # ─── SDK 测试 ───────────────────────────────────────────── class TestSDK: """SDK 框架测试""" def test_component_decorators(self): """组件装饰器测试""" from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler from maibot_sdk.types import ActivationType, EventType class TestPlugin(MaiBotPlugin): @Action("greet", activation_type=ActivationType.KEYWORD, activation_keywords=["hi"]) async def handle_greet(self, **kwargs): return True, "ok" @Command("echo", pattern=r"^/echo") async def handle_echo(self, **kwargs): return True, "echoed", 2 @Tool("search", parameters={"query": {"type": "string"}}) async def handle_search(self, **kwargs): return {"result": "found"} @EventHandler("on_start", event_type=EventType.ON_START) async def handle_start(self, **kwargs): return True, False, "started" plugin = TestPlugin() components = plugin.get_components() assert len(components) == 4 names = {c["name"] for c in components} assert "greet" in names assert "echo" in names assert "search" in names assert "on_start" in names types = {c["type"] for c in components} assert "action" in types assert "command" in types assert "tool" in types assert "event_handler" in types def test_plugin_context_not_initialized(self): """未初始化上下文时应报错""" from maibot_sdk import MaiBotPlugin plugin = MaiBotPlugin() with pytest.raises(RuntimeError, match="尚未初始化"): _ = plugin.ctx def test_plugin_context_injection(self): """上下文注入测试""" from maibot_sdk import MaiBotPlugin from maibot_sdk.context import PluginContext plugin = MaiBotPlugin() ctx = PluginContext(plugin_id="test") plugin._set_context(ctx) assert plugin.ctx.plugin_id == "test" assert plugin.ctx.send is not None assert plugin.ctx.db is not None assert plugin.ctx.llm is not None assert plugin.ctx.config is not None @pytest.mark.asyncio async def test_runner_injected_context_binds_plugin_identity(self): """Runner 注入的上下文应忽略调用方伪造的 plugin_id。""" from src.plugin_runtime.runner.runner_main import PluginRunner class DummyRPCClient: def __init__(self): self.calls = [] async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): self.calls.append( { "method": method, "plugin_id": plugin_id, "payload": payload, "timeout_ms": timeout_ms, } ) return SimpleNamespace(error=None, payload={"result": {"ok": True}}) class DummyPlugin: def _set_context(self, ctx): self.ctx = ctx runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) runner._rpc_client = DummyRPCClient() plugin = DummyPlugin() runner._inject_context("owner_plugin", plugin) plugin.ctx._plugin_id = "forged_plugin" result = await plugin.ctx.call_capability("send.text", text="hello", stream_id="stream-1") assert result == {"ok": True} assert runner._rpc_client.calls[0]["plugin_id"] == "owner_plugin" assert runner._rpc_client.calls[0]["method"] == "cap.request" @pytest.mark.asyncio async def test_runner_applies_initial_plugin_config(self, tmp_path): """Runner 应在 on_load 前为支持的插件实例注入 config.toml。""" from src.plugin_runtime.runner.runner_main import PluginRunner class DummyPlugin: def __init__(self): self.configs = [] def set_plugin_config(self, config): self.configs.append(config) plugin_dir = tmp_path / "demo_plugin" plugin_dir.mkdir() (plugin_dir / "config.toml").write_text("[section]\nvalue = 1\n", encoding="utf-8") runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) plugin = DummyPlugin() meta = SimpleNamespace(plugin_id="demo_plugin", plugin_dir=str(plugin_dir), instance=plugin) runner._apply_plugin_config(meta) assert plugin.configs == [{"section": {"value": 1}}] @pytest.mark.asyncio async def test_runner_config_update_refreshes_plugin_config_before_callback(self): """配置更新时应先刷新插件配置,再调用 on_config_update。""" from src.plugin_runtime.protocol.envelope import Envelope, MessageType from src.plugin_runtime.runner.runner_main import PluginRunner class DummyPlugin: def __init__(self): self.configs = [] self.updates = [] def set_plugin_config(self, config): self.configs.append(config) async def on_config_update(self, scope, config, version): self.updates.append((scope, config, version, list(self.configs))) runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) plugin = DummyPlugin() runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin) envelope = Envelope( request_id=1, message_type=MessageType.REQUEST, method="plugin.config_updated", plugin_id="demo_plugin", payload={ "plugin_id": "demo_plugin", "config_scope": "self", "config_data": {"enabled": True}, "config_version": "v2", }, ) response = await runner._handle_config_updated(envelope) assert response.payload["acknowledged"] is True assert plugin.configs == [{"enabled": True}] assert plugin.updates == [("self", {"enabled": True}, "v2", [{"enabled": True}])] @pytest.mark.asyncio async def test_runner_global_config_update_does_not_override_plugin_config(self): """bot/model 广播不应覆盖插件自身配置缓存。""" from src.plugin_runtime.protocol.envelope import Envelope, MessageType from src.plugin_runtime.runner.runner_main import PluginRunner class DummyPlugin: def __init__(self): self.configs = [] self.updates = [] def set_plugin_config(self, config): self.configs.append(config) async def on_config_update(self, scope, config, version): self.updates.append((scope, config, version, list(self.configs))) runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) plugin = DummyPlugin() runner._loader._loaded_plugins["demo_plugin"] = SimpleNamespace(instance=plugin) plugin.set_plugin_config({"plugin_enabled": True}) envelope = Envelope( request_id=1, message_type=MessageType.REQUEST, method="plugin.config_updated", plugin_id="demo_plugin", payload={ "plugin_id": "demo_plugin", "config_scope": "model", "config_data": {"models": []}, "config_version": "", }, ) response = await runner._handle_config_updated(envelope) assert response.payload["acknowledged"] is True assert plugin.configs == [{"plugin_enabled": True}] assert plugin.updates == [("model", {"models": []}, "", [{"plugin_enabled": True}])] @pytest.mark.asyncio async def test_runner_bootstraps_capabilities_before_on_load(self, monkeypatch): """on_load 期间的 capability 调用应在 bootstrap 后生效。""" from src.plugin_runtime.runner.runner_main import PluginRunner class DummyRPCClient: def __init__(self): self.calls = [] async def connect_and_handshake(self): return True def register_method(self, method, handler): return None async def send_request(self, method, plugin_id="", payload=None, timeout_ms=30000): self.calls.append( { "method": method, "plugin_id": plugin_id, "payload": payload, "timeout_ms": timeout_ms, } ) if method == "cap.call": bootstrap_methods = [call["method"] for call in self.calls[:-1]] assert "plugin.bootstrap" in bootstrap_methods return SimpleNamespace(error=None, payload={"success": True}) return SimpleNamespace(error=None, payload={"accepted": True}) async def disconnect(self): return None class DummyPlugin: def __init__(self, runner): self.runner = runner def _set_context(self, ctx): self.ctx = ctx def get_components(self): return [{"name": "handler", "type": "command", "metadata": {}}] async def on_load(self): result = await self.ctx.call_capability("send.text", text="hello", stream_id="stream-1") assert result is True self.runner._shutting_down = True runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) runner._rpc_client = DummyRPCClient() plugin = DummyPlugin(runner) meta = SimpleNamespace( plugin_id="demo_plugin", plugin_dir="/tmp/demo_plugin", instance=plugin, version="1.0.0", capabilities_required=["send.text"], ) monkeypatch.setattr(runner, "_install_log_handler", lambda: None) monkeypatch.setattr(runner, "_uninstall_log_handler", lambda: asyncio.sleep(0)) monkeypatch.setattr(runner._loader, "discover_and_load", lambda plugin_dirs: [meta]) await runner.run() methods = [call["method"] for call in runner._rpc_client.calls] assert methods == ["plugin.bootstrap", "plugin.register_components", "cap.call", "runner.ready"] @pytest.mark.asyncio async def test_runner_batch_reload_merges_overlapping_reverse_dependents(self, monkeypatch): """批量重载应只对重叠依赖闭包执行一次 unload/load。""" from src.plugin_runtime.runner.runner_main import PluginRunner runner = PluginRunner(host_address="dummy", session_token="token", plugin_dirs=[]) plugin_a_id = "test.plugin-a" plugin_b_id = "test.plugin-b" plugin_c_id = "test.plugin-c" def build_meta(plugin_id: str, dependencies: list[str]) -> SimpleNamespace: return SimpleNamespace( plugin_id=plugin_id, dependencies=dependencies, plugin_dir=f"/tmp/{plugin_id}", version="1.0.0", instance=SimpleNamespace(), ) loaded_metas = { plugin_a_id: build_meta(plugin_a_id, []), plugin_b_id: build_meta(plugin_b_id, [plugin_a_id]), plugin_c_id: build_meta(plugin_c_id, [plugin_b_id]), } reloaded_metas = { plugin_id: build_meta(plugin_id, list(meta.dependencies)) for plugin_id, meta in loaded_metas.items() } candidates = { plugin_a_id: ( "dir_plugin_a", build_test_manifest_model(plugin_a_id), "plugin_a/plugin.py", ), plugin_b_id: ( "dir_plugin_b", build_test_manifest_model( plugin_b_id, dependencies=[{"type": "plugin", "id": plugin_a_id, "version_spec": ">=1.0.0,<2.0.0"}], ), "plugin_b/plugin.py", ), plugin_c_id: ( "dir_plugin_c", build_test_manifest_model( plugin_c_id, dependencies=[{"type": "plugin", "id": plugin_b_id, "version_spec": ">=1.0.0,<2.0.0"}], ), "plugin_c/plugin.py", ), } unloaded_plugins: list[str] = [] activated_plugins: list[str] = [] monkeypatch.setattr(runner._loader, "discover_candidates", lambda plugin_dirs: (candidates, {})) monkeypatch.setattr(runner._loader, "list_plugins", lambda: sorted(loaded_metas.keys())) monkeypatch.setattr(runner._loader, "get_plugin", lambda plugin_id: loaded_metas.get(plugin_id)) monkeypatch.setattr( runner._loader, "remove_loaded_plugin", lambda plugin_id: loaded_metas.pop(plugin_id, None), ) monkeypatch.setattr(runner._loader, "purge_plugin_modules", lambda plugin_id, plugin_dir: []) monkeypatch.setattr( runner._loader, "resolve_dependencies", lambda reload_candidates, extra_available=None: (sorted(reload_candidates.keys()), {}), ) monkeypatch.setattr( runner._loader, "load_candidate", lambda plugin_id, candidate: reloaded_metas[plugin_id], ) async def fake_unload_plugin(meta, reason, purge_modules=False): del reason, purge_modules unloaded_plugins.append(meta.plugin_id) loaded_metas.pop(meta.plugin_id, None) async def fake_activate_plugin(meta): activated_plugins.append(meta.plugin_id) loaded_metas[meta.plugin_id] = meta return True monkeypatch.setattr(runner, "_unload_plugin", fake_unload_plugin) monkeypatch.setattr(runner, "_activate_plugin", fake_activate_plugin) result = await runner._reload_plugins_by_ids([plugin_a_id, plugin_b_id], reason="manual") assert result.success is True assert result.requested_plugin_ids == [plugin_a_id, plugin_b_id] assert unloaded_plugins == [plugin_c_id, plugin_b_id, plugin_a_id] assert activated_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] assert result.reloaded_plugins == [plugin_a_id, plugin_b_id, plugin_c_id] class TestPluginSdkUsage: """验证仓库内插件按新 SDK 归一化返回值工作。""" def test_runner_skips_signal_handler_registration_on_windows(self, monkeypatch): """Windows 下不应尝试注册 add_signal_handler。""" from src.plugin_runtime.runner import runner_main registered_signals = [] class DummyLoop: def add_signal_handler(self, sig, callback): registered_signals.append((sig, callback)) monkeypatch.setattr(runner_main.sys, "platform", "win32") runner_main._install_shutdown_signal_handlers(lambda: None, DummyLoop()) assert not registered_signals @pytest.mark.asyncio async def test_builtin_emoji_plugin_handles_normalized_results(self): from maibot_sdk.context import PluginContext from src.plugins.built_in.emoji_plugin.plugin import EmojiPlugin async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None): assert method == "cap.request" assert payload is not None capability = payload["capability"] return { "emoji.get_random": { "success": True, "emojis": [{"base64": "img-1", "emotion": "happy"}], }, "message.get_recent": {"success": True, "messages": [{"id": 1}]}, "message.build_readable": {"success": True, "text": "最近消息"}, "llm.generate": {"success": True, "response": "happy", "reasoning": "", "model_name": "m"}, "send.emoji": {"success": True}, }[capability] plugin = EmojiPlugin() plugin._set_context(PluginContext(plugin_id="emoji", rpc_call=fake_rpc_call)) success, message = await plugin.handle_emoji(stream_id="stream-1", reasoning="测试", chat_id="chat-1") assert success is True assert "成功发送表情包" in message @pytest.mark.asyncio async def test_tts_plugin_uses_send_custom_bool_result(self): from maibot_sdk.context import PluginContext from src.plugins.built_in.tts_plugin.plugin import TTSPlugin async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None): assert method == "cap.request" assert payload is not None assert payload["capability"] == "send.custom" return {"success": True} plugin = TTSPlugin() plugin._set_context(PluginContext(plugin_id="tts", rpc_call=fake_rpc_call)) success, message = await plugin.handle_tts_action( stream_id="stream-1", action_data={"voice_text": "你好!!!"}, ) assert success is True assert message == "TTS动作执行成功" @pytest.mark.asyncio async def test_hello_world_plugin_handles_random_emoji_list(self): from maibot_sdk.context import PluginContext from plugins.hello_world_plugin.plugin import HelloWorldPlugin async def fake_rpc_call(method: str, plugin_id: str = "", payload: dict | None = None): assert method == "cap.request" assert payload is not None capability = payload["capability"] return { "emoji.get_random": {"success": True, "emojis": [{"base64": "img-1"}, {"base64": "img-2"}]}, "send.forward": {"success": True}, }[capability] plugin = HelloWorldPlugin() plugin._set_context(PluginContext(plugin_id="hello", rpc_call=fake_rpc_call)) success, message, should_continue = await plugin.handle_random_emojis(stream_id="stream-1") assert success is True assert message == "已发送随机表情包" assert should_continue is True # ─── 端到端集成测试 ──────────────────────────────────────── class TestE2E: """端到端集成测试(Host + Runner 通信)""" @pytest.mark.asyncio async def test_handshake(self): """Host-Runner 握手流程测试""" from src.plugin_runtime.protocol.codec import MsgPackCodec from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType from src.plugin_runtime.transport.factory import create_transport_client, create_transport_server import secrets session_token = secrets.token_hex(16) codec = MsgPackCodec() handshake_done = asyncio.Event() server_result = {} async def server_handler(conn): # 接收握手 data = await conn.recv_frame() env = codec.decode_envelope(data) assert env.method == "runner.hello" hello = HelloPayload.model_validate(env.payload) assert hello.session_token == session_token # 发送响应 resp_payload = HelloResponsePayload( accepted=True, host_version="1.0", assigned_generation=1, ) resp = env.make_response(payload=resp_payload.model_dump()) await conn.send_frame(codec.encode_envelope(resp)) server_result["runner_id"] = hello.runner_id handshake_done.set() # 保持连接一会儿 await asyncio.sleep(1.0) server = create_transport_server() await server.start(server_handler) # 客户端握手 client = create_transport_client(server.get_address()) conn = await client.connect() hello = HelloPayload( runner_id="test-runner", sdk_version="1.0.0", session_token=session_token, ) env = Envelope( request_id=1, message_type=MessageType.REQUEST, method="runner.hello", payload=hello.model_dump(), ) await conn.send_frame(codec.encode_envelope(env)) resp_data = await conn.recv_frame() resp = codec.decode_envelope(resp_data) resp_payload = HelloResponsePayload.model_validate(resp.payload) assert resp_payload.accepted assert resp_payload.assigned_generation == 1 await asyncio.wait_for(handshake_done.wait(), timeout=5.0) assert server_result["runner_id"] == "test-runner" await conn.close() await server.stop() # ─── Manifest 校验测试 ───────────────────────────────────── class TestManifestValidator: """Manifest 校验器测试""" def test_valid_manifest(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") manifest = build_test_manifest("test.valid-plugin", capabilities=["send.text"]) assert validator.validate(manifest) is True assert len(validator.errors) == 0 assert validator.warnings == [] def test_missing_required_fields(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") manifest = {"manifest_version": 2} assert validator.validate(manifest) is False assert len(validator.errors) >= 6 assert any("缺少必需字段" in error for error in validator.errors) def test_unsupported_manifest_version(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") manifest = build_test_manifest("test.invalid-version") manifest["manifest_version"] = 999 assert validator.validate(manifest) is False assert any("manifest_version" in e for e in validator.errors) def test_host_version_compatibility(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator(host_version="0.8.5", sdk_version="2.0.1") manifest = build_test_manifest( "test.host-check", host_min_version="0.9.0", host_max_version="1.0.0", ) assert validator.validate(manifest) is False assert any("Host 版本不兼容" in e for e in validator.errors) def test_sdk_version_compatibility(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator(host_version="1.0.0", sdk_version="1.9.9") manifest = build_test_manifest("test.sdk-check") assert validator.validate(manifest) is False assert any("SDK 版本不兼容" in e for e in validator.errors) def test_extra_fields_are_rejected(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") manifest = build_test_manifest("test.extra-field") manifest["unexpected"] = True assert validator.validate(manifest) is False assert any("存在未声明字段" in error for error in validator.errors) def test_python_package_conflict_rejects_manifest(self): from src.plugin_runtime.runner.manifest_validator import ManifestValidator validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") manifest = build_test_manifest( "test.numpy-conflict", dependencies=[ { "type": "python_package", "name": "numpy", "version_spec": ">=999.0.0", } ], ) assert validator.validate(manifest) is False assert any("Python 包依赖冲突" in error for error in validator.errors) class TestVersionComparator: """版本号比较器测试""" def test_normalize(self): from src.plugin_runtime.runner.manifest_validator import VersionComparator assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0" assert VersionComparator.normalize_version("1.2") == "1.2.0" assert VersionComparator.normalize_version("") == "0.0.0" def test_compare(self): from src.plugin_runtime.runner.manifest_validator import VersionComparator assert VersionComparator.compare("0.8.0", "0.8.0") == 0 assert VersionComparator.compare("0.8.0", "0.9.0") == -1 assert VersionComparator.compare("1.0.0", "0.9.0") == 1 def test_is_in_range(self): from src.plugin_runtime.runner.manifest_validator import VersionComparator ok, _ = VersionComparator.is_in_range("0.8.5", "0.8.0", "0.9.0") assert ok ok, _ = VersionComparator.is_in_range("0.7.0", "0.8.0", "0.9.0") assert not ok ok, _ = VersionComparator.is_in_range("1.0.0", "0.8.0", "0.9.0") assert not ok # ─── 依赖解析测试 ────────────────────────────────────────── class TestDependencyResolution: """插件依赖解析测试""" def test_topological_sort(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { "test.core": ( "dir_core", build_test_manifest_model("test.core"), "plugin.py", ), "test.auth": ( "dir_auth", build_test_manifest_model( "test.auth", dependencies=[ {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"}, ], ), "plugin.py", ), "test.api": ( "dir_api", build_test_manifest_model( "test.api", dependencies=[ {"type": "plugin", "id": "test.core", "version_spec": ">=1.0.0,<2.0.0"}, {"type": "plugin", "id": "test.auth", "version_spec": ">=1.0.0,<2.0.0"}, ], ), "plugin.py", ), } order, failed = loader._resolve_dependencies(candidates) assert len(failed) == 0 assert order.index("test.core") < order.index("test.auth") assert order.index("test.auth") < order.index("test.api") def test_missing_dependency(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { "test.plugin-a": ( "dir_a", build_test_manifest_model( "test.plugin-a", dependencies=[ {"type": "plugin", "id": "test.nonexistent", "version_spec": ">=1.0.0,<2.0.0"}, ], ), "plugin.py", ), } order, failed = loader._resolve_dependencies(candidates) assert "test.plugin-a" in failed assert "依赖未满足" in failed["test.plugin-a"] def test_circular_dependency(self): from src.plugin_runtime.runner.plugin_loader import PluginLoader loader = PluginLoader() candidates = { "test.a": ( "dir_a", build_test_manifest_model( "test.a", dependencies=[ {"type": "plugin", "id": "test.b", "version_spec": ">=1.0.0,<2.0.0"}, ], ), "p.py", ), "test.b": ( "dir_b", build_test_manifest_model( "test.b", dependencies=[ {"type": "plugin", "id": "test.a", "version_spec": ">=1.0.0,<2.0.0"}, ], ), "p.py", ), } order, failed = loader._resolve_dependencies(candidates) assert len(failed) >= 1 # 至少一个循环插件被标记 def test_loader_supports_package_imports_inside_create_plugin(self, tmp_path): from src.plugin_runtime.runner.plugin_loader import PluginLoader plugin_root = tmp_path / "plugins" plugin_root.mkdir() plugin_dir = plugin_root / "grok_search_plugin" plugin_dir.mkdir() (plugin_dir / "_manifest.json").write_text( json.dumps( build_test_manifest( "test.grok-search-plugin", name="grok_search_plugin", description="demo", ) ), encoding="utf-8", ) (plugin_dir / "__init__.py").write_text("VALUE = 1\n", encoding="utf-8") (plugin_dir / "services.py").write_text("def answer():\n return 42\n", encoding="utf-8") (plugin_dir / "plugin.py").write_text( "class DemoPlugin:\n" " pass\n\n" "def create_plugin():\n" " from grok_search_plugin.services import answer\n" " plugin = DemoPlugin()\n" " plugin.answer = answer\n" " return plugin\n", encoding="utf-8", ) loader = PluginLoader() loaded = loader.discover_and_load([str(plugin_root)]) assert [meta.plugin_id for meta in loaded] == ["test.grok-search-plugin"] assert loader.failed_plugins == {} assert loaded[0].instance.answer() == 42 def test_loader_requires_sdk_plugin_to_override_on_config_update(self, tmp_path): from src.plugin_runtime.runner.plugin_loader import PluginLoader plugin_root = tmp_path / "plugins" plugin_root.mkdir() plugin_dir = plugin_root / "demo_plugin" plugin_dir.mkdir() (plugin_dir / "_manifest.json").write_text( json.dumps( build_test_manifest( "test.demo-plugin", name="demo_plugin", description="demo", ) ), encoding="utf-8", ) (plugin_dir / "plugin.py").write_text( "from maibot_sdk import MaiBotPlugin\n\n" "class DemoPlugin(MaiBotPlugin):\n" " async def on_load(self):\n" " pass\n\n" " async def on_unload(self):\n" " pass\n\n" "def create_plugin():\n" " return DemoPlugin()\n", encoding="utf-8", ) loader = PluginLoader() loaded = loader.discover_and_load([str(plugin_root)]) assert loaded == [] assert "test.demo-plugin" in loader.failed_plugins assert "on_config_update" in loader.failed_plugins["test.demo-plugin"] def test_loader_requires_sdk_plugin_to_override_on_load(self, tmp_path): from src.plugin_runtime.runner.plugin_loader import PluginLoader plugin_root = tmp_path / "plugins" plugin_root.mkdir() plugin_dir = plugin_root / "demo_plugin" plugin_dir.mkdir() (plugin_dir / "_manifest.json").write_text( json.dumps( build_test_manifest( "test.demo-plugin", name="demo_plugin", description="demo", ) ), encoding="utf-8", ) (plugin_dir / "plugin.py").write_text( "from maibot_sdk import MaiBotPlugin\n\n" "class DemoPlugin(MaiBotPlugin):\n" " async def on_unload(self):\n" " pass\n\n" " async def on_config_update(self, scope, config_data, version):\n" " pass\n\n" "def create_plugin():\n" " return DemoPlugin()\n", encoding="utf-8", ) loader = PluginLoader() loaded = loader.discover_and_load([str(plugin_root)]) assert loaded == [] assert "test.demo-plugin" in loader.failed_plugins assert "on_load" in loader.failed_plugins["test.demo-plugin"] def test_loader_requires_sdk_plugin_to_override_on_unload(self, tmp_path): from src.plugin_runtime.runner.plugin_loader import PluginLoader plugin_root = tmp_path / "plugins" plugin_root.mkdir() plugin_dir = plugin_root / "demo_plugin" plugin_dir.mkdir() (plugin_dir / "_manifest.json").write_text( json.dumps( build_test_manifest( "test.demo-plugin", name="demo_plugin", description="demo", ) ), encoding="utf-8", ) (plugin_dir / "plugin.py").write_text( "from maibot_sdk import MaiBotPlugin\n\n" "class DemoPlugin(MaiBotPlugin):\n" " async def on_load(self):\n" " pass\n\n" " async def on_config_update(self, scope, config_data, version):\n" " pass\n\n" "def create_plugin():\n" " return DemoPlugin()\n", encoding="utf-8", ) loader = PluginLoader() loaded = loader.discover_and_load([str(plugin_root)]) assert loaded == [] assert "test.demo-plugin" in loader.failed_plugins assert "on_unload" in loader.failed_plugins["test.demo-plugin"] @pytest.mark.asyncio async def test_async_main_removes_sensitive_runtime_env_vars(self, monkeypatch): from src.plugin_runtime.runner import runner_main captured = {} original_path = list(sys.path) class FakeRunner: def __init__( self, host_address: str, session_token: str, plugin_dirs: list[str], external_available_plugins: dict[str, str] | None = None, ) -> None: captured["host_address"] = host_address captured["session_token"] = session_token captured["plugin_dirs"] = plugin_dirs captured["external_available_plugins"] = external_available_plugins or {} async def run(self) -> None: assert os.environ.get(runner_main.ENV_IPC_ADDRESS) is None assert os.environ.get(runner_main.ENV_SESSION_TOKEN) is None monkeypatch.setenv(runner_main.ENV_IPC_ADDRESS, "tcp://127.0.0.1:9999") monkeypatch.setenv(runner_main.ENV_SESSION_TOKEN, "secret-token") monkeypatch.setenv(runner_main.ENV_PLUGIN_DIRS, "/tmp/plugins") monkeypatch.setenv(runner_main.ENV_EXTERNAL_PLUGIN_IDS, '{"demo.plugin":"1.0.0"}') monkeypatch.setattr(runner_main, "_install_shutdown_signal_handlers", lambda callback: None) monkeypatch.setattr(runner_main, "PluginRunner", FakeRunner) await runner_main._async_main() assert captured["host_address"] == "tcp://127.0.0.1:9999" assert captured["session_token"] == "secret-token" assert captured["plugin_dirs"] == ["/tmp/plugins"] assert captured["external_available_plugins"] == {"demo.plugin": "1.0.0"} assert sys.path == original_path # ─── Host-side ComponentRegistry 测试 ────────────────────── class TestComponentRegistry: """Host-side 组件注册表测试""" def test_register_and_query(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component( "greet", "action", "plugin_a", { "description": "打招呼", "activation_type": "keyword", "activation_keywords": ["hi"], }, ) reg.register_component( "help", "command", "plugin_a", { "command_pattern": r"^/help", }, ) reg.register_component( "search", "tool", "plugin_b", { "description": "搜索", }, ) stats = reg.get_stats() assert stats["total"] == 3 assert stats["action"] == 1 assert stats["command"] == 1 assert stats["tool"] == 1 def test_register_command_with_invalid_regex_only_warns(self, monkeypatch): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() warnings: list[str] = [] monkeypatch.setattr( "src.plugin_runtime.host.component_registry.logger.warning", lambda message: warnings.append(str(message)), ) success = reg.register_component( "broken", "command", "plugin_a", { "command_pattern": "[", }, ) assert success is True assert reg.get_component("plugin_a.broken") is not None assert warnings assert "plugin_a.broken" in warnings[0] def test_register_hook_handler_rejects_unknown_hook(self): from src.plugin_runtime.host.component_registry import ComponentRegistrationError, ComponentRegistry from src.plugin_runtime.host.hook_spec_registry import HookSpecRegistry reg = ComponentRegistry(hook_spec_registry=HookSpecRegistry()) with pytest.raises(ComponentRegistrationError, match="未注册的 Hook"): reg.register_component( "broken_hook", "hook_handler", "plugin_a", { "hook": "chat.receive.unknown", "mode": "blocking", }, ) def test_register_plugin_components_is_atomic_when_hook_invalid(self): from src.plugin_runtime.host.component_registry import ComponentRegistrationError, ComponentRegistry from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry hook_spec_registry = HookSpecRegistry() hook_spec_registry.register_hook_spec(HookSpec(name="chat.receive.before_process")) reg = ComponentRegistry(hook_spec_registry=hook_spec_registry) reg.register_plugin_components( "plugin_a", [ {"name": "cmd_old", "component_type": "command", "metadata": {"command_pattern": r"^/old"}}, ], ) with pytest.raises(ComponentRegistrationError, match="未注册的 Hook"): reg.register_plugin_components( "plugin_a", [ { "name": "hook_ok", "component_type": "hook_handler", "metadata": {"hook": "chat.receive.before_process", "mode": "blocking"}, }, { "name": "hook_bad", "component_type": "hook_handler", "metadata": {"hook": "chat.receive.missing", "mode": "blocking"}, }, ], ) assert reg.get_component("plugin_a.cmd_old") is not None assert reg.get_component("plugin_a.hook_ok") is None def test_query_by_type(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component("a1", "action", "p1", {}) reg.register_component("a2", "action", "p2", {}) actions = reg.get_components_by_type("action") assert len(actions) == 2 def test_find_command_by_text(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component( "help", "command", "p1", { "command_pattern": r"^/help", }, ) reg.register_component( "echo", "command", "p1", { "command_pattern": r"^/echo\s", }, ) match = reg.find_command_by_text("/help me") assert match is not None comp, groups = match assert comp.name == "help" match = reg.find_command_by_text("/echo hello") assert match is not None comp, groups = match assert comp.name == "echo" match = reg.find_command_by_text("no match") assert match is None def test_enable_disable(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component("a1", "action", "p1", {}) reg.set_component_enabled("p1.a1", False) actions = reg.get_components_by_type("action", enabled_only=True) assert len(actions) == 0 actions = reg.get_components_by_type("action", enabled_only=False) assert len(actions) == 1 def test_remove_by_plugin(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component("a1", "action", "p1", {}) reg.register_component("c1", "command", "p1", {}) reg.register_component("a2", "action", "p2", {}) removed = reg.remove_components_by_plugin("p1") assert removed == 2 assert reg.get_stats()["total"] == 1 def test_reregister_same_plugin_replaces_component_set(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_plugin_components( "p1", [ {"name": "a1", "component_type": "action", "metadata": {}}, {"name": "a2", "component_type": "action", "metadata": {}}, ], ) reg.remove_components_by_plugin("p1") reg.register_plugin_components( "p1", [ {"name": "a1", "component_type": "action", "metadata": {}}, ], ) assert reg.get_component("p1.a1") is not None assert reg.get_component("p1.a2") is None def test_event_handlers_sorted_by_weight(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component( "h_low", "event_handler", "p1", { "event_type": "on_message", "weight": 10, }, ) reg.register_component( "h_high", "event_handler", "p2", { "event_type": "on_message", "weight": 100, }, ) handlers = reg.get_event_handlers("on_message") assert handlers[0].name == "h_high" assert handlers[1].name == "h_low" def test_tools_for_llm(self): from src.plugin_runtime.host.component_registry import ComponentRegistry reg = ComponentRegistry() reg.register_component( "search", "tool", "p1", { "description": "搜索工具", "parameters_raw": {"query": {"type": "string"}}, }, ) tools = reg.get_tools_for_llm() assert len(tools) == 1 assert tools[0]["name"] == "p1.search" assert tools[0]["parameters"]["query"]["type"] == "string" # ─── EventDispatcher 测试 ───────────────────────────────── class TestEventDispatcher: """Host-side 事件分发器测试""" @pytest.mark.asyncio async def test_dispatch_non_blocking(self): from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.host.event_dispatcher import EventDispatcher reg = ComponentRegistry() reg.register_component( "h1", "event_handler", "p1", { "event_type": "on_start", "weight": 0, "intercept_message": False, }, ) dispatcher = EventDispatcher(reg) call_log = [] async def mock_invoke(plugin_id, comp_name, args): call_log.append((plugin_id, comp_name)) return {"success": True, "continue_processing": True} should_continue, modified = await dispatcher.dispatch_event("on_start", mock_invoke) assert should_continue # 非阻塞分发是异步的,等一下让 task 完成 await asyncio.sleep(0.1) assert len(call_log) == 1 assert call_log[0] == ("p1", "h1") @pytest.mark.asyncio async def test_dispatch_intercepting(self): from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.host.event_dispatcher import EventDispatcher reg = ComponentRegistry() reg.register_component( "filter", "event_handler", "p1", { "event_type": "on_message_pre_process", "weight": 100, "intercept_message": True, }, ) dispatcher = EventDispatcher(reg) async def mock_invoke(plugin_id, comp_name, args): return { "success": True, "continue_processing": False, "modified_message": {"plain_text": "filtered"}, } should_continue, modified = await dispatcher.dispatch_event( "on_message_pre_process", mock_invoke, message={"plain_text": "hello"} ) assert not should_continue assert modified is not None assert modified["plain_text"] == "filtered" class TestEventBus: """核心事件总线与 IPC 桥接测试""" @pytest.mark.asyncio async def test_bridge_preserves_modified_message(self, monkeypatch): import types fake_message_data_model = types.ModuleType("src.common.data_models.message_data_model") fake_message_data_model.ReplyContentType = object fake_message_data_model.ReplyContent = object fake_message_data_model.ForwardNode = object fake_message_data_model.ReplySetModel = object monkeypatch.setitem(sys.modules, "src.common.data_models.message_data_model", fake_message_data_model) from src.core.event_bus import EventBus from src.core.types import EventType, MaiMessages from src.plugin_runtime import integration as integration_module bus = EventBus() async def noop_handler(message): return True, message bus.subscribe(EventType.ON_MESSAGE, noop_handler, name="noop", intercept=True) class FakeManager: is_running = True async def bridge_event(self, event_type_value, message_dict=None, extra_args=None): assert event_type_value == EventType.ON_MESSAGE.value return True, {"plain_text": "modified by ipc"} monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) original = MaiMessages(plain_text="original") continue_flag, modified = await bus.emit(EventType.ON_MESSAGE, original) assert continue_flag is True assert modified is not None assert modified.plain_text == "modified by ipc" assert original.plain_text == "original" # ─── MaiMessages 测试 ───────────────────────────────────── class TestMaiMessages: """统一消息模型测试""" def test_create_and_serialize(self): from maibot_sdk.messages import MaiMessages, MessageSegment msg = MaiMessages( message_segments=[MessageSegment(type="text", data={"text": "hello"})], plain_text="hello", stream_id="stream_1", ) d = msg.to_rpc_dict() assert d["plain_text"] == "hello" assert len(d["message_segments"]) == 1 msg2 = MaiMessages.from_rpc_dict(d) assert msg2.plain_text == "hello" def test_deepcopy(self): from maibot_sdk.messages import MaiMessages msg = MaiMessages(plain_text="original") msg2 = msg.deepcopy() msg2.plain_text = "modified" assert msg.plain_text == "original" def test_modify_flags(self): from maibot_sdk.messages import MaiMessages from maibot_sdk.types import ModifyFlag msg = MaiMessages(plain_text="hello") assert msg.can_modify(ModifyFlag.CAN_MODIFY_PROMPT) msg.set_modify_flag(ModifyFlag.CAN_MODIFY_PROMPT, False) assert not msg.modify_prompt("new prompt") assert msg.llm_prompt is None assert msg.modify_response("new response") assert msg.llm_response_content == "new response" class _FakeHookSupervisor: """用于 Hook 分发测试的简化 Supervisor。""" def __init__( self, group_name: str, component_registry: Any, handlers: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]] | Dict[str, Any]]], call_log: List[tuple[str, str]], ) -> None: """初始化测试用 Supervisor。 Args: group_name: 运行时分组名称。 component_registry: 组件注册表实例。 handlers: 处理器映射,键为 `plugin_id.component_name`。 call_log: 记录调用顺序的列表。 """ self._group_name = group_name self.component_registry = component_registry self._handlers = handlers self._call_log = call_log @property def group_name(self) -> str: """返回当前测试 Supervisor 的分组名称。""" return self._group_name async def invoke_plugin( self, method: str, plugin_id: str, component_name: str, args: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, ) -> SimpleNamespace: """模拟调用插件组件。 Args: method: RPC 方法名。 plugin_id: 目标插件 ID。 component_name: 目标组件名称。 args: 调用参数。 timeout_ms: 超时配置,测试中仅用于保持接口一致。 Returns: SimpleNamespace: 仅包含 `payload` 字段的简化响应对象。 """ del method del timeout_ms full_name = f"{plugin_id}.{component_name}" handler = self._handlers[full_name] self._call_log.append((plugin_id, component_name)) result = handler(dict(args or {})) if asyncio.iscoroutine(result): result = await result return SimpleNamespace(payload=result) # ─── HookDispatcher 测试 ──────────────────────────────── class TestHookDispatcher: """命名 Hook 分发器测试。""" @staticmethod def _import_dispatcher_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]: """导入 Hook 分发相关模块,并屏蔽配置初始化触发的退出。 Args: monkeypatch: pytest 的 monkeypatch 工具。 Returns: tuple[Any, Any]: `ComponentRegistry` 与 `HookDispatcher` 类型。 """ monkeypatch.setattr(sys, "exit", lambda code=0: None) from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.host.hook_dispatcher import HookDispatcher return ComponentRegistry, HookDispatcher @pytest.mark.asyncio async def test_empty_hook_returns_original_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None: """未注册处理器时应直接返回原始参数。""" ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) dispatcher = HookDispatcher() supervisor = _FakeHookSupervisor("builtin", ComponentRegistry(), {}, []) result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") assert result.hook_name == "heart_fc.cycle_start" assert result.kwargs == {"session_id": "s-1"} assert result.aborted is False @pytest.mark.asyncio async def test_blocking_hook_modifies_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None: """blocking 处理器可以修改参数。""" ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) registry = ComponentRegistry() registry.register_component( "upper", "HOOK_HANDLER", "p1", { "hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal", }, ) dispatcher = HookDispatcher() supervisor = _FakeHookSupervisor( "builtin", registry, { "p1.upper": lambda args: { "success": True, "action": "continue", "modified_kwargs": { "session_id": args["session_id"], "text": str(args["text"]).upper(), }, } }, [], ) result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1", text="hello") assert result.kwargs["session_id"] == "s-1" assert result.kwargs["text"] == "HELLO" assert result.aborted is False @pytest.mark.asyncio async def test_abort_stops_following_blocking_handlers(self, monkeypatch: pytest.MonkeyPatch) -> None: """blocking 处理器的 abort 应阻止后续 blocking 处理器执行。""" ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) registry = ComponentRegistry() registry.register_component( "stopper", "HOOK_HANDLER", "p1", {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, ) registry.register_component( "after_stop", "HOOK_HANDLER", "p2", {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, ) call_log: List[tuple[str, str]] = [] dispatcher = HookDispatcher() supervisor = _FakeHookSupervisor( "builtin", registry, { "p1.stopper": lambda args: {"success": True, "action": "abort"}, "p2.after_stop": lambda args: {"success": True, "action": "continue"}, }, call_log, ) result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], cycle_id="c-1") assert result.aborted is True assert result.stopped_by == "p1.stopper" assert call_log == [("p1", "stopper")] @pytest.mark.asyncio async def test_observe_handler_runs_in_background_without_mutation( self, monkeypatch: pytest.MonkeyPatch, ) -> None: """observe 处理器应后台执行且不能影响主流程参数。""" ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) registry = ComponentRegistry() registry.register_component( "observer", "HOOK_HANDLER", "p1", {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"}, ) started = asyncio.Event() release = asyncio.Event() call_log: List[tuple[str, str]] = [] async def observe_handler(args: Dict[str, Any]) -> Dict[str, Any]: """模拟耗时观察型处理器。""" started.set() await release.wait() return { "success": True, "action": "abort", "modified_kwargs": {"session_id": "changed"}, "custom_result": args["session_id"], } dispatcher = HookDispatcher() supervisor = _FakeHookSupervisor( "builtin", registry, {"p1.observer": observe_handler}, call_log, ) result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") await asyncio.sleep(0) assert result.aborted is False assert result.kwargs["session_id"] == "s-1" assert started.is_set() assert len(dispatcher._background_tasks) == 1 release.set() await asyncio.sleep(0) await asyncio.sleep(0) assert call_log == [("p1", "observer")] assert not dispatcher._background_tasks @pytest.mark.asyncio async def test_global_order_prefers_order_slot_then_source(self, monkeypatch: pytest.MonkeyPatch) -> None: """全局排序应先看 order,再看内置/第三方来源。""" ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) builtin_registry = ComponentRegistry() third_registry = ComponentRegistry() builtin_registry.register_component( "builtin_early", "HOOK_HANDLER", "b1", {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, ) builtin_registry.register_component( "builtin_normal", "HOOK_HANDLER", "b1", {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, ) third_registry.register_component( "third_early", "HOOK_HANDLER", "t1", {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, ) third_registry.register_component( "third_normal", "HOOK_HANDLER", "t1", {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, ) call_log: List[tuple[str, str]] = [] dispatcher = HookDispatcher() builtin_supervisor = _FakeHookSupervisor( "builtin", builtin_registry, { "b1.builtin_early": lambda args: {"success": True, "action": "continue"}, "b1.builtin_normal": lambda args: {"success": True, "action": "continue"}, }, call_log, ) third_supervisor = _FakeHookSupervisor( "third_party", third_registry, { "t1.third_early": lambda args: {"success": True, "action": "continue"}, "t1.third_normal": lambda args: {"success": True, "action": "continue"}, }, call_log, ) await dispatcher.invoke_hook( "heart_fc.cycle_start", [third_supervisor, builtin_supervisor], cycle_id="c-1", ) assert call_log == [ ("b1", "builtin_early"), ("t1", "third_early"), ("b1", "builtin_normal"), ("t1", "third_normal"), ] @pytest.mark.asyncio async def test_error_policy_abort_stops_dispatch(self, monkeypatch: pytest.MonkeyPatch) -> None: """error_policy=abort 时应中止本次 Hook 调用。""" ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) registry = ComponentRegistry() registry.register_component( "failer", "HOOK_HANDLER", "p1", { "hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal", "error_policy": "abort", }, ) call_log: List[tuple[str, str]] = [] async def fail_handler(args: Dict[str, Any]) -> Dict[str, Any]: """抛出异常以触发 abort 策略。""" del args raise RuntimeError("boom") dispatcher = HookDispatcher() supervisor = _FakeHookSupervisor("builtin", registry, {"p1.failer": fail_handler}, call_log) result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") assert result.aborted is True assert result.stopped_by == "p1.failer" assert any("boom" in error for error in result.errors) assert call_log == [("p1", "failer")] @pytest.mark.asyncio async def test_timeout_respects_handler_timeout_ms(self, monkeypatch: pytest.MonkeyPatch) -> None: """处理器超时应被记录为错误并继续。""" ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) registry = ComponentRegistry() registry.register_component( "slow", "HOOK_HANDLER", "p1", { "hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal", "timeout_ms": 10, }, ) call_log: List[tuple[str, str]] = [] async def slow_handler(args: Dict[str, Any]) -> Dict[str, Any]: """模拟超时处理器。""" del args await asyncio.sleep(0.05) return {"success": True, "action": "continue"} dispatcher = HookDispatcher() supervisor = _FakeHookSupervisor("builtin", registry, {"p1.slow": slow_handler}, call_log) result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") assert result.aborted is False assert any("超时" in error for error in result.errors) assert call_log == [("p1", "slow")] class TestPluginRuntimeHookEntry: """PluginRuntimeManager 命名 Hook 入口测试。""" @staticmethod def _import_manager_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]: """导入运行时管理器相关模块,并屏蔽配置初始化触发的退出。 Args: monkeypatch: pytest 的 monkeypatch 工具。 Returns: tuple[Any, Any]: `ComponentRegistry` 与 `PluginRuntimeManager` 类型。 """ monkeypatch.setattr(sys, "exit", lambda code=0: None) from src.plugin_runtime.host.component_registry import ComponentRegistry from src.plugin_runtime.integration import PluginRuntimeManager return ComponentRegistry, PluginRuntimeManager @pytest.mark.asyncio async def test_manager_invoke_hook_dispatches_across_supervisors( self, monkeypatch: pytest.MonkeyPatch, ) -> None: """PluginRuntimeManager.invoke_hook() 应调用全局 Hook 分发器。""" ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch) builtin_registry = ComponentRegistry() builtin_registry.register_component( "builtin_guard", "HOOK_HANDLER", "b1", {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, ) third_registry = ComponentRegistry() third_registry.register_component( "observer", "HOOK_HANDLER", "t1", {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"}, ) call_log: List[tuple[str, str]] = [] manager = PluginRuntimeManager() manager._started = True manager._builtin_supervisor = _FakeHookSupervisor( "builtin", builtin_registry, {"b1.builtin_guard": lambda args: {"success": True, "action": "continue"}}, call_log, ) manager._third_party_supervisor = _FakeHookSupervisor( "third_party", third_registry, {"t1.observer": lambda args: {"success": True, "action": "continue"}}, call_log, ) result = await manager.invoke_dispatcher.invoke_hook("heart_fc.cycle_start", session_id="s-1") await asyncio.sleep(0) assert manager.invoke_dispatcher is manager.hook_dispatcher assert result.aborted is False assert result.kwargs["session_id"] == "s-1" assert ("b1", "builtin_guard") in call_log def test_manager_lists_builtin_hook_specs(self, monkeypatch: pytest.MonkeyPatch) -> None: """PluginRuntimeManager 应暴露内置 Hook 规格清单。""" _ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch) manager = PluginRuntimeManager() hook_names = {spec.name for spec in manager.list_hook_specs()} assert "chat.receive.before_process" in hook_names assert "send_service.before_send" in hook_names assert "maisaka.planner.after_response" in hook_names class TestRPCServer: """RPC Server 代际保护测试""" @pytest.mark.asyncio async def test_reject_second_active_runner_connection(self): from src.plugin_runtime.host.rpc_server import RPCServer from src.plugin_runtime.protocol.codec import MsgPackCodec from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType class DummyTransport: async def start(self, handler): return None async def stop(self): return None def get_address(self): return "dummy" class FakeConnection: def __init__(self, incoming_frames: list[bytes]): self._incoming_frames = list(incoming_frames) self.sent_frames: list[bytes] = [] self.is_closed = False async def recv_frame(self): return self._incoming_frames.pop(0) async def send_frame(self, data): self.sent_frames.append(data) async def close(self): self.is_closed = True codec = MsgPackCodec() server = RPCServer(transport=DummyTransport(), session_token="session-token") active_conn = SimpleNamespace(is_closed=False) server._connection = active_conn hello = HelloPayload( runner_id="runner-b", sdk_version="1.0.0", session_token="session-token", ) envelope = Envelope( request_id=1, message_type=MessageType.REQUEST, method="runner.hello", payload=hello.model_dump(), ) incoming_conn = FakeConnection([codec.encode_envelope(envelope)]) await server._handle_connection(incoming_conn) assert incoming_conn.is_closed is True assert server._connection is active_conn assert server.last_handshake_rejection_reason == "已有活跃 Runner 连接,拒绝新的握手" assert len(incoming_conn.sent_frames) == 1 response = codec.decode_envelope(incoming_conn.sent_frames[0]) response_payload = HelloResponsePayload.model_validate(response.payload) assert response_payload.accepted is False assert response_payload.reason == "已有活跃 Runner 连接,拒绝新的握手" def test_ignore_stale_generation_response(self): from src.plugin_runtime.host.rpc_server import RPCServer from src.plugin_runtime.protocol.envelope import Envelope, MessageType class DummyTransport: async def start(self, handler): return None async def stop(self): return None def get_address(self): return "dummy" server = RPCServer(transport=DummyTransport()) server._runner_generation = 2 loop = asyncio.new_event_loop() try: future = loop.create_future() server._pending_requests[1] = (future, 2) stale_response = Envelope( request_id=1, message_type=MessageType.RESPONSE, method="plugin.health", generation=1, payload={"healthy": True}, ) server._handle_response(stale_response) assert not future.done() assert 1 in server._pending_requests finally: loop.close() @pytest.mark.asyncio async def test_send_queue_backpressure_is_enforced(self): from src.plugin_runtime.host.rpc_server import RPCServer from src.plugin_runtime.protocol.errors import ErrorCode, RPCError class DummyTransport: async def start(self, handler): return None async def stop(self): return None def get_address(self): return "dummy" class BlockingConnection: def __init__(self): self.is_closed = False self.release = asyncio.Event() async def send_frame(self, data): await self.release.wait() async def close(self): self.is_closed = True server = RPCServer(transport=DummyTransport(), send_queue_size=1) await server.start() conn = BlockingConnection() server._connection = conn server._runner_generation = 1 first_send = asyncio.create_task(server.send_event("runner.log_batch")) await asyncio.sleep(0) second_send = asyncio.create_task(server.send_event("runner.log_batch")) await asyncio.sleep(0) with pytest.raises(RPCError) as exc_info: await server.send_event("runner.log_batch") assert exc_info.value.code == ErrorCode.E_BACKPRESSURE conn.release.set() await asyncio.gather(first_send, second_send) await server.stop() class TestRPCClient: """Runner RPCClient 后台任务生命周期测试""" @pytest.mark.asyncio async def test_background_tasks_retained_and_cancelled_on_disconnect(self): from src.plugin_runtime.runner.rpc_client import RPCClient client = RPCClient(host_address="dummy", session_token="token") release = asyncio.Event() async def pending_task(): await release.wait() task = asyncio.create_task(pending_task()) client._track_background_task(task) assert task in client._background_tasks await asyncio.sleep(0) assert task in client._background_tasks await client.disconnect() assert task.cancelled() is True assert not client._background_tasks class TestSupervisor: """Supervisor 生命周期边界测试""" @staticmethod def _build_register_payload(plugin_id: str = "plugin_a", component_names=None): from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload component_names = component_names or ["handler"] return RegisterComponentsPayload( plugin_id=plugin_id, plugin_version="1.0.0", components=[ ComponentDeclaration( name=name, component_type="event_handler", plugin_id=plugin_id, metadata={"event_type": "on_message"}, ) for name in component_names ], capabilities_required=["send.text"], ) @staticmethod def _make_process(pid: int): class FakeProcess: def __init__(self): self.pid = pid self.returncode = None self.stdout = None self.stderr = None self.terminated = False self.killed = False def terminate(self): self.terminated = True self.returncode = 0 def kill(self): self.killed = True self.returncode = -9 async def wait(self): return self.returncode return FakeProcess() @pytest.mark.asyncio async def test_reload_waits_for_target_generation(self, monkeypatch): from src.plugin_runtime.host.supervisor import PluginSupervisor from src.plugin_runtime.protocol.envelope import HealthPayload supervisor = PluginSupervisor(plugin_dirs=[]) old_process = self._make_process(1) new_process = self._make_process(2) class FakeRPCServer: def __init__(self): self.runner_generation = 1 self.staged_generation = 0 self.is_connected = True self.session_token = "fake-token" self.committed = False self.staging_started = False def reset_session_token(self): self.session_token = "new-fake-token" return self.session_token def restore_session_token(self, token): self.session_token = token def begin_staged_takeover(self): self.staging_started = True self.staged_generation = 2 async def commit_staged_takeover(self): self.runner_generation = self.staged_generation self.staged_generation = 0 self.committed = True async def rollback_staged_takeover(self): self.staged_generation = 0 def has_generation(self, generation): return generation in {self.runner_generation, self.staged_generation} async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): assert target_generation == 2 return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump()) supervisor._rpc_server = FakeRPCServer() supervisor._runner_process = old_process async def fake_spawn_runner(): supervisor._runner_process = new_process supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a") supervisor._runner_ready_payloads[2] = SimpleNamespace(loaded_plugins=["plugin_a"], failed_plugins=[]) supervisor._runner_ready_events[2] = asyncio.Event() supervisor._runner_ready_events[2].set() monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) reloaded = await supervisor.reload_plugins("test") assert reloaded is True assert supervisor._runner_process is new_process assert supervisor._rpc_server.committed is True assert old_process.terminated is True @pytest.mark.asyncio async def test_reload_restores_runtime_state_on_failure(self, monkeypatch): from src.plugin_runtime.host.supervisor import PluginSupervisor supervisor = PluginSupervisor(plugin_dirs=[]) old_process = self._make_process(1) new_process = self._make_process(2) old_reg = self._build_register_payload() supervisor._runner_process = old_process supervisor._registered_plugins[old_reg.plugin_id] = old_reg supervisor._rebuild_runtime_state() class FakeRPCServer: def __init__(self): self.runner_generation = 1 self.staged_generation = 0 self.is_connected = True self.session_token = "fake-token" self.rolled_back = False def reset_session_token(self): self.session_token = "new-fake-token" return self.session_token def restore_session_token(self, token): self.session_token = token def begin_staged_takeover(self): self.staged_generation = 2 async def commit_staged_takeover(self): self.runner_generation = self.staged_generation self.staged_generation = 0 async def rollback_staged_takeover(self): self.rolled_back = True self.staged_generation = 0 def has_generation(self, generation): return generation in {self.runner_generation, self.staged_generation} async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): raise RuntimeError("new runner unhealthy") supervisor._rpc_server = FakeRPCServer() async def fake_spawn_runner(): supervisor._runner_process = new_process supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a") supervisor._runner_ready_payloads[2] = SimpleNamespace(loaded_plugins=["plugin_a"], failed_plugins=[]) supervisor._runner_ready_events[2] = asyncio.Event() supervisor._runner_ready_events[2].set() monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) reloaded = await supervisor.reload_plugins("test") assert reloaded is False assert supervisor._runner_process is old_process assert supervisor._rpc_server.rolled_back is True assert old_reg.plugin_id in supervisor._registered_plugins assert supervisor.component_registry.get_component("plugin_a.handler") is not None @pytest.mark.asyncio async def test_reload_rebuilds_exact_component_set(self, monkeypatch): from src.plugin_runtime.host.supervisor import PluginSupervisor from src.plugin_runtime.protocol.envelope import HealthPayload supervisor = PluginSupervisor(plugin_dirs=[]) old_process = self._make_process(1) new_process = self._make_process(2) old_reg = self._build_register_payload("plugin_a", component_names=["handler", "obsolete"]) new_reg = self._build_register_payload("plugin_a", component_names=["handler"]) supervisor._runner_process = old_process supervisor._registered_plugins[old_reg.plugin_id] = old_reg supervisor._rebuild_runtime_state() class FakeRPCServer: def __init__(self): self.runner_generation = 1 self.staged_generation = 0 self.is_connected = True self.session_token = "fake-token" def reset_session_token(self): self.session_token = "new-fake-token" return self.session_token def restore_session_token(self, token): self.session_token = token def begin_staged_takeover(self): self.staged_generation = 2 async def commit_staged_takeover(self): self.runner_generation = self.staged_generation self.staged_generation = 0 async def rollback_staged_takeover(self): self.staged_generation = 0 def has_generation(self, generation): return generation in {self.runner_generation, self.staged_generation} async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump()) supervisor._rpc_server = FakeRPCServer() async def fake_spawn_runner(): supervisor._runner_process = new_process supervisor._staged_registered_plugins[new_reg.plugin_id] = new_reg supervisor._runner_ready_payloads[2] = SimpleNamespace(loaded_plugins=["plugin_a"], failed_plugins=[]) supervisor._runner_ready_events[2] = asyncio.Event() supervisor._runner_ready_events[2].set() monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) reloaded = await supervisor.reload_plugins("test") assert reloaded is True assert supervisor.component_registry.get_component("plugin_a.handler") is not None assert supervisor.component_registry.get_component("plugin_a.obsolete") is None @pytest.mark.asyncio async def test_reload_plugins_uses_batch_rpc_for_multiple_roots(self): from src.plugin_runtime.host.supervisor import PluginSupervisor from src.plugin_runtime.protocol.envelope import ReloadPluginsResultPayload supervisor = PluginSupervisor(plugin_dirs=[]) sent_requests: list[tuple[str, dict[str, object], int]] = [] class FakeRPCServer: async def send_request(self, method, payload, timeout_ms=5000, **kwargs): del kwargs sent_requests.append((method, payload, timeout_ms)) return SimpleNamespace( payload=ReloadPluginsResultPayload( success=True, requested_plugin_ids=["plugin_a", "plugin_b"], reloaded_plugins=["plugin_a", "plugin_b", "plugin_c"], unloaded_plugins=["plugin_c", "plugin_b", "plugin_a"], ).model_dump() ) supervisor._rpc_server = FakeRPCServer() reloaded = await supervisor.reload_plugins(["plugin_a", "plugin_b", "plugin_a"], reason="manual") assert reloaded is True assert len(sent_requests) == 1 method, payload, timeout_ms = sent_requests[0] assert method == "plugin.reload_batch" assert payload["plugin_ids"] == ["plugin_a", "plugin_b"] assert payload["reason"] == "manual" assert timeout_ms >= 10000 @pytest.mark.asyncio async def test_reload_rolls_back_when_runner_ready_not_received(self, monkeypatch): from src.plugin_runtime.host.supervisor import PluginSupervisor supervisor = PluginSupervisor(plugin_dirs=[], runner_spawn_timeout_sec=0.01) old_process = self._make_process(1) new_process = self._make_process(2) old_reg = self._build_register_payload() supervisor._runner_process = old_process supervisor._registered_plugins[old_reg.plugin_id] = old_reg supervisor._rebuild_runtime_state() class FakeRPCServer: def __init__(self): self.runner_generation = 1 self.staged_generation = 0 self.is_connected = True self.session_token = "fake-token" self.rolled_back = False def reset_session_token(self): self.session_token = "new-fake-token" return self.session_token def restore_session_token(self, token): self.session_token = token def begin_staged_takeover(self): self.staged_generation = 2 async def commit_staged_takeover(self): raise AssertionError("runner.ready 未到达前不应提交 staged takeover") async def rollback_staged_takeover(self): self.rolled_back = True self.staged_generation = 0 def has_generation(self, generation): return generation in {self.runner_generation, self.staged_generation} async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): raise AssertionError("runner.ready 未到达前不应执行健康检查") supervisor._rpc_server = FakeRPCServer() async def fake_spawn_runner(): supervisor._runner_process = new_process supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a") monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) reloaded = await supervisor.reload_plugins("test") assert reloaded is False assert supervisor._runner_process is old_process assert supervisor._rpc_server.rolled_back is True @pytest.mark.asyncio async def test_attach_stderr_drain_drains_stream(self): """_attach_stderr_drain 为 stderr 创建排空任务,读完后任务自动完成。""" from src.plugin_runtime.host.supervisor import PluginSupervisor supervisor = PluginSupervisor(plugin_dirs=[]) stderr = asyncio.StreamReader() stderr.feed_data(b"fatal startup error\n") stderr.feed_eof() # stdout=None 模拟新架构(不再捕获 stdout) process = SimpleNamespace(pid=99, stdout=None, stderr=stderr) supervisor._attach_stderr_drain(process) # 给 drain task 足够时间消费完数据 await asyncio.sleep(0.05) assert supervisor._stderr_drain_task is None or supervisor._stderr_drain_task.done() class TestIntegration: """运行时集成层启动/清理测试""" @pytest.mark.asyncio async def test_cap_database_get_with_filters_does_not_reference_unbound_key_value(self, monkeypatch): from src.plugin_runtime import integration as integration_module import src.common.database.database_model as real_db_models from src.services import database_service as real_database_service captured: dict[str, object] = {} class DummyModel: pass async def fake_db_get(model_class, filters=None, limit=None, order_by=None, single_result=False): captured["model_class"] = model_class captured["filters"] = filters captured["limit"] = limit captured["order_by"] = order_by captured["single_result"] = single_result return [{"id": 1}] monkeypatch.setattr(real_database_service, "db_get", fake_db_get) monkeypatch.setattr(real_db_models, "DemoTable", DummyModel, raising=False) result = await integration_module.PluginRuntimeManager._cap_database_get( "plugin_a", "database.get", { "table": "DemoTable", "filters": {"status": "active"}, "limit": 5, }, ) assert result == {"success": True, "result": [{"id": 1}]} assert captured["model_class"] is DummyModel assert captured["filters"] == {"status": "active"} assert captured["limit"] == 5 assert captured["single_result"] is False @pytest.mark.asyncio async def test_component_enable_rejects_ambiguous_short_name(self, monkeypatch): from src.plugin_runtime import integration as integration_module from src.plugin_runtime.host.component_registry import ComponentRegistry class FakeSupervisor: def __init__(self, plugin_id: str): self.component_registry = ComponentRegistry() self.component_registry.register_component( name="shared", component_type="tool", plugin_id=plugin_id, metadata={}, ) class FakeManager: def __init__(self): self.supervisors = [FakeSupervisor("plugin_a"), FakeSupervisor("plugin_b")] monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) manager = integration_module.PluginRuntimeManager() manager._builtin_supervisor = FakeSupervisor("plugin_a") manager._third_party_supervisor = FakeSupervisor("plugin_b") result = await manager._cap_component_enable( "plugin_a", "component.enable", {"name": "shared", "component_type": "tool", "scope": "global", "stream_id": ""}, ) assert result["success"] is False assert "组件名不唯一" in result["error"] @pytest.mark.asyncio async def test_component_disable_rejects_non_global_scope(self, monkeypatch): from src.plugin_runtime import integration as integration_module from src.plugin_runtime.host.component_registry import ComponentRegistry class FakeSupervisor: def __init__(self): self.component_registry = ComponentRegistry() self.component_registry.register_component( name="handler", component_type="tool", plugin_id="plugin_a", metadata={}, ) class FakeManager: def __init__(self): self.supervisors = [FakeSupervisor()] monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) manager = integration_module.PluginRuntimeManager() manager._builtin_supervisor = FakeSupervisor() result = await manager._cap_component_disable( "plugin_a", "component.disable", {"name": "plugin_a.handler", "component_type": "tool", "scope": "stream", "stream_id": "s1"}, ) assert result["success"] is False assert "仅支持全局组件禁用" in result["error"] @pytest.mark.asyncio async def test_start_cleans_up_started_supervisors_on_failure(self, monkeypatch): from src.plugin_runtime import integration as integration_module instances = [] builtin_dir = Path("builtin") thirdparty_dir = Path("thirdparty") class FakeCapabilityService: def register_capability(self, name, impl): return None class FakeSupervisor: def __init__(self, plugin_dirs=None, socket_path=None): self._plugin_dirs = plugin_dirs or [] self.capability_service = FakeCapabilityService() self.external_plugin_versions = {} self.stopped = False instances.append(self) def set_external_available_plugins(self, plugin_versions): self.external_plugin_versions = dict(plugin_versions) def get_loaded_plugin_ids(self): return [] def get_loaded_plugin_versions(self): return {} async def start(self): if len(instances) == 2 and self is instances[1]: raise RuntimeError("boom") async def stop(self): self.stopped = True monkeypatch.setattr( integration_module.PluginRuntimeManager, "_get_builtin_plugin_dirs", staticmethod(lambda: [builtin_dir]) ) monkeypatch.setattr( integration_module.PluginRuntimeManager, "_get_third_party_plugin_dirs", staticmethod(lambda: [thirdparty_dir]) ) import src.plugin_runtime.host.supervisor as supervisor_module monkeypatch.setattr(supervisor_module, "PluginSupervisor", FakeSupervisor) manager = integration_module.PluginRuntimeManager() await manager.start() assert manager.is_running is False assert len(instances) == 2 assert instances[0].stopped is True @pytest.mark.asyncio async def test_handle_plugin_source_changes_restarts_supervisors_after_dependency_sync(self, monkeypatch, tmp_path): from src.config.file_watcher import FileChange from src.plugin_runtime import integration as integration_module import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" alpha_dir = builtin_root / "alpha" beta_dir = thirdparty_root / "beta" alpha_dir.mkdir(parents=True) beta_dir.mkdir(parents=True) (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") monkeypatch.chdir(tmp_path) class FakeSupervisor: def __init__(self, plugin_dirs, registered_plugins): self._plugin_dirs = plugin_dirs self._registered_plugins = registered_plugins self.config_updates = [] def get_loaded_plugin_ids(self): return sorted(self._registered_plugins.keys()) def get_loaded_plugin_versions(self): return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins} async def notify_plugin_config_updated(self, plugin_id, config_data, config_version=""): self.config_updates.append((plugin_id, config_data, config_version)) return True manager = integration_module.PluginRuntimeManager() manager._started = True manager._builtin_supervisor = FakeSupervisor([builtin_root], {"test.alpha": object()}) manager._third_party_supervisor = FakeSupervisor([thirdparty_root], {"test.beta": object()}) dependency_sync_calls = [] restart_calls = [] async def fake_sync(plugin_dirs: Sequence[Path]) -> Any: """记录依赖同步调用。""" dependency_sync_calls.append(list(plugin_dirs)) return integration_module.DependencySyncState( blocked_changed_plugin_ids={"test.beta"}, environment_changed=False, ) async def fake_restart(reason: str) -> bool: """记录 Supervisor 重启调用。""" restart_calls.append(reason) return True monkeypatch.setattr(manager, "_sync_plugin_dependencies", fake_sync) monkeypatch.setattr(manager, "_restart_supervisors", fake_restart) changes = [ FileChange(change_type=1, path=beta_dir / "plugin.py"), ] await manager._handle_plugin_source_changes(changes) assert dependency_sync_calls == [[builtin_root, thirdparty_root]] assert restart_calls == ["file_watcher_blocklist_changed"] assert manager._builtin_supervisor.config_updates == [] assert manager._third_party_supervisor.config_updates == [] @pytest.mark.asyncio async def test_reload_plugins_globally_warns_and_skips_cross_supervisor_dependents(self, monkeypatch): from src.plugin_runtime import integration as integration_module class FakeRegistration: def __init__(self, dependencies): self.dependencies = dependencies class FakeSupervisor: def __init__(self, registrations): self._registered_plugins = registrations self.reload_calls = [] def get_loaded_plugin_ids(self): return sorted(self._registered_plugins.keys()) def get_loaded_plugin_versions(self): return {plugin_id: "1.0.0" for plugin_id in self._registered_plugins} async def reload_plugins(self, plugin_ids=None, reason="manual", external_available_plugins=None): self.reload_calls.append((plugin_ids, reason, dict(sorted((external_available_plugins or {}).items())))) return True builtin_supervisor = FakeSupervisor({"test.alpha": FakeRegistration([])}) third_party_supervisor = FakeSupervisor( { "test.beta": FakeRegistration(["test.alpha"]), "test.gamma": FakeRegistration(["test.beta"]), } ) manager = integration_module.PluginRuntimeManager() manager._builtin_supervisor = builtin_supervisor manager._third_party_supervisor = third_party_supervisor warning_messages = [] monkeypatch.setattr( integration_module.logger, "warning", lambda message: warning_messages.append(message), ) reloaded = await manager.reload_plugins_globally(["test.alpha"], reason="manual") assert reloaded is True assert builtin_supervisor.reload_calls == [ (["test.alpha"], "manual", {"test.beta": "1.0.0", "test.gamma": "1.0.0"}) ] assert third_party_supervisor.reload_calls == [] assert len(warning_messages) == 1 assert "test.beta, test.gamma" in warning_messages[0] assert "跨 Supervisor API 调用仍然可用" in warning_messages[0] @pytest.mark.asyncio async def test_handle_plugin_config_changes_only_notify_target_plugin(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module from src.config.file_watcher import FileChange import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" alpha_dir = builtin_root / "alpha" beta_dir = thirdparty_root / "beta" alpha_dir.mkdir(parents=True) beta_dir.mkdir(parents=True) (alpha_dir / "config.toml").write_text("enabled = true\n", encoding="utf-8") (beta_dir / "config.toml").write_text("enabled = false\n", encoding="utf-8") (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") monkeypatch.chdir(tmp_path) class FakeSupervisor: def __init__(self, plugin_dirs, plugins): self._plugin_dirs = plugin_dirs self._registered_plugins = {plugin_id: object() for plugin_id in plugins} self.config_updates = [] async def inspect_plugin_config( self, plugin_id: str, config_data: Optional[Dict[str, Any]] = None, use_provided_config: bool = False, ) -> SimpleNamespace: """返回测试用的配置解析结果。""" del config_data, use_provided_config return SimpleNamespace(enabled=True, normalized_config={"enabled": True}, plugin_id=plugin_id) async def notify_plugin_config_updated( self, plugin_id, config_data, config_version="", config_scope="self", ): self.config_updates.append((plugin_id, config_data, config_version, config_scope)) return True manager = integration_module.PluginRuntimeManager() manager._started = True manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"]) await manager._handle_plugin_config_changes( "test.alpha", [FileChange(change_type=1, path=alpha_dir / "config.toml")], ) assert manager._builtin_supervisor.config_updates == [("test.alpha", {"enabled": True}, "", "self")] assert manager._third_party_supervisor.config_updates == [] @pytest.mark.asyncio async def test_handle_plugin_config_changes_loads_unloaded_enabled_plugin(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module from src.config.file_watcher import FileChange import json thirdparty_root = tmp_path / "plugins" alpha_dir = thirdparty_root / "alpha" alpha_dir.mkdir(parents=True) (alpha_dir / "config.toml").write_text("[plugin]\nenabled = true\n", encoding="utf-8") (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") monkeypatch.chdir(tmp_path) class FakeSupervisor: def __init__(self, plugin_dirs): self._plugin_dirs = plugin_dirs self._registered_plugins = {} async def inspect_plugin_config( self, plugin_id: str, config_data: Optional[Dict[str, Any]] = None, use_provided_config: bool = False, ) -> SimpleNamespace: """返回测试用的启用配置快照。""" del config_data, use_provided_config return SimpleNamespace(enabled=True, normalized_config={"plugin": {"enabled": True}}, plugin_id=plugin_id) manager = integration_module.PluginRuntimeManager() manager._started = True manager._third_party_supervisor = FakeSupervisor([thirdparty_root]) load_calls = [] async def fake_load_plugin_globally(plugin_id: str, reason: str = "manual") -> bool: """记录自动加载调用。""" load_calls.append((plugin_id, reason)) return True monkeypatch.setattr(manager, "load_plugin_globally", fake_load_plugin_globally) await manager._handle_plugin_config_changes( "test.alpha", [FileChange(change_type=1, path=alpha_dir / "config.toml")], ) assert load_calls == [("test.alpha", "config_enabled")] @pytest.mark.asyncio async def test_handle_plugin_config_changes_unloads_loaded_disabled_plugin(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module from src.config.file_watcher import FileChange import json builtin_root = tmp_path / "src" / "plugins" / "built_in" alpha_dir = builtin_root / "alpha" alpha_dir.mkdir(parents=True) (alpha_dir / "config.toml").write_text("[plugin]\nenabled = false\n", encoding="utf-8") (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") monkeypatch.chdir(tmp_path) class FakeSupervisor: def __init__(self, plugin_dirs, plugins): self._plugin_dirs = plugin_dirs self._registered_plugins = {plugin_id: object() for plugin_id in plugins} async def inspect_plugin_config( self, plugin_id: str, config_data: Optional[Dict[str, Any]] = None, use_provided_config: bool = False, ) -> SimpleNamespace: """返回测试用的禁用配置快照。""" del config_data, use_provided_config return SimpleNamespace( enabled=False, normalized_config={"plugin": {"enabled": False}}, plugin_id=plugin_id, ) manager = integration_module.PluginRuntimeManager() manager._started = True manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) reload_calls = [] async def fake_reload_plugins_globally(plugin_ids: Sequence[str], reason: str = "manual") -> bool: """记录自动卸载调用。""" reload_calls.append((list(plugin_ids), reason)) return True monkeypatch.setattr(manager, "reload_plugins_globally", fake_reload_plugins_globally) await manager._handle_plugin_config_changes( "test.alpha", [FileChange(change_type=1, path=alpha_dir / "config.toml")], ) assert reload_calls == [(["test.alpha"], "config_disabled")] @pytest.mark.asyncio async def test_handle_main_config_reload_only_notifies_subscribers(self, monkeypatch): from src.plugin_runtime import integration as integration_module class FakeRegistration: def __init__(self, subscriptions): self.config_reload_subscriptions = subscriptions class FakeSupervisor: def __init__(self, registrations): self._registered_plugins = registrations self.config_updates = [] def get_config_reload_subscribers(self, scope): matched_plugins = [] for plugin_id, registration in self._registered_plugins.items(): if scope in registration.config_reload_subscriptions: matched_plugins.append(plugin_id) return matched_plugins async def notify_plugin_config_updated( self, plugin_id, config_data, config_version="", config_scope="self", ): self.config_updates.append((plugin_id, config_data, config_version, config_scope)) return True fake_global = SimpleNamespace(plugin_runtime=SimpleNamespace(enabled=True)) monkeypatch.setattr( integration_module.config_manager, "get_global_config", lambda: SimpleNamespace(model_dump=lambda: {"bot": {"name": "MaiBot"}}, plugin_runtime=fake_global.plugin_runtime), ) monkeypatch.setattr( integration_module.config_manager, "get_model_config", lambda: SimpleNamespace(model_dump=lambda: {"models": [{"name": "demo"}]}), ) manager = integration_module.PluginRuntimeManager() manager._started = True manager._builtin_supervisor = FakeSupervisor( { "test.alpha": FakeRegistration(["bot"]), "test.beta": FakeRegistration([]), } ) manager._third_party_supervisor = FakeSupervisor( { "test.gamma": FakeRegistration(["model"]), } ) await manager._handle_main_config_reload(["bot", "model"]) assert manager._builtin_supervisor.config_updates == [ ("test.alpha", {"bot": {"name": "MaiBot"}}, "", "bot") ] assert manager._third_party_supervisor.config_updates == [ ("test.gamma", {"models": [{"name": "demo"}]}, "", "model") ] def test_refresh_plugin_config_watch_subscriptions_registers_per_plugin(self, tmp_path): from src.plugin_runtime import integration as integration_module import json builtin_root = tmp_path / "src" / "plugins" / "built_in" thirdparty_root = tmp_path / "plugins" alpha_dir = builtin_root / "alpha" beta_dir = thirdparty_root / "beta" alpha_dir.mkdir(parents=True) beta_dir.mkdir(parents=True) (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") class FakeWatcher: def __init__(self): self.subscriptions = [] self.unsubscribed = [] def subscribe(self, callback, *, paths=None, change_types=None): subscription_id = f"sub-{len(self.subscriptions) + 1}" self.subscriptions.append({"id": subscription_id, "callback": callback, "paths": tuple(paths or ())}) return subscription_id def unsubscribe(self, subscription_id): self.unsubscribed.append(subscription_id) return True class FakeSupervisor: def __init__(self, plugin_dirs, plugins): self._plugin_dirs = plugin_dirs self._registered_plugins = {plugin_id: object() for plugin_id in plugins} manager = integration_module.PluginRuntimeManager() manager._plugin_file_watcher = FakeWatcher() manager._builtin_supervisor = FakeSupervisor([builtin_root], ["test.alpha"]) manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.beta"]) manager._refresh_plugin_config_watch_subscriptions() assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"} assert { subscription["paths"][0] for subscription in manager._plugin_file_watcher.subscriptions } == {alpha_dir / "config.toml", beta_dir / "config.toml"} def test_refresh_plugin_config_watch_subscriptions_includes_unloaded_plugins(self, tmp_path): from src.plugin_runtime import integration as integration_module import json thirdparty_root = tmp_path / "plugins" alpha_dir = thirdparty_root / "alpha" beta_dir = thirdparty_root / "beta" alpha_dir.mkdir(parents=True) beta_dir.mkdir(parents=True) (alpha_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (beta_dir / "plugin.py").write_text("def create_plugin():\n return object()\n", encoding="utf-8") (alpha_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.alpha")), encoding="utf-8") (beta_dir / "_manifest.json").write_text(json.dumps(build_test_manifest("test.beta")), encoding="utf-8") class FakeWatcher: def __init__(self): self.subscriptions = [] def subscribe( self, callback: Any, *, paths: Optional[Sequence[Path]] = None, change_types: Any = None, ) -> str: """记录新的监听订阅。""" del callback, change_types subscription_id = f"sub-{len(self.subscriptions) + 1}" self.subscriptions.append({"id": subscription_id, "paths": tuple(paths or ())}) return subscription_id def unsubscribe(self, subscription_id: str) -> bool: """兼容 watcher 取消订阅接口。""" del subscription_id return True class FakeSupervisor: def __init__(self, plugin_dirs, plugins): self._plugin_dirs = plugin_dirs self._registered_plugins = {plugin_id: object() for plugin_id in plugins} manager = integration_module.PluginRuntimeManager() manager._plugin_file_watcher = FakeWatcher() manager._third_party_supervisor = FakeSupervisor([thirdparty_root], ["test.alpha"]) manager._refresh_plugin_config_watch_subscriptions() assert set(manager._plugin_config_watcher_subscriptions.keys()) == {"test.alpha", "test.beta"} @pytest.mark.asyncio async def test_component_reload_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch): from src.plugin_runtime import integration as integration_module manager = integration_module.PluginRuntimeManager() monkeypatch.setattr(manager, "reload_plugins_globally", lambda plugin_ids, reason="manual": asyncio.sleep(0, False)) result = await manager._cap_component_reload_plugin( "plugin_a", "component.reload_plugin", {"plugin_name": "alpha"}, ) assert result["success"] is False assert result["error"] == "插件 alpha 热重载失败" @pytest.mark.asyncio async def test_component_load_plugin_returns_failure_when_reload_rolls_back(self, monkeypatch, tmp_path): from src.plugin_runtime import integration as integration_module manager = integration_module.PluginRuntimeManager() monkeypatch.setattr(manager, "load_plugin_globally", lambda plugin_id, reason="manual": asyncio.sleep(0, False)) result = await manager._cap_component_load_plugin( "plugin_a", "component.load_plugin", {"plugin_name": "alpha"}, ) assert result["success"] is False assert result["error"] == "插件 alpha 热重载失败"