From 2f21cd00bc22ba9b0d7e7c1e823ccd2bf6e22c58 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 6 Mar 2026 11:55:59 +0800 Subject: [PATCH] feat: Enhance plugin runtime with new component registry and workflow executor - Introduced `ComponentRegistry` for managing plugin components with support for registration, enabling/disabling, and querying by type and plugin. - Added `EventDispatcher` to handle event distribution to registered event handlers, supporting both blocking and non-blocking execution. - Implemented `WorkflowExecutor` to manage a linear workflow execution across multiple stages, including command routing and error handling. - Created `ManifestValidator` for validating plugin manifests against required fields and version compatibility. - Updated `RPCClient` to use `MsgPackCodec` for message encoding. - Enhanced `PluginRunner` to support lifecycle hooks for plugins, including `on_load` and `on_unload`. - Added sys.path isolation to restrict plugin access to only necessary directories. 
--- pytests/test_plugin_runtime.py | 731 ++++++++++++++++-- src/plugin_runtime/__init__.py | 3 +- src/plugin_runtime/host/__init__.py | 2 +- src/plugin_runtime/host/capability_service.py | 12 +- src/plugin_runtime/host/circuit_breaker.py | 105 --- src/plugin_runtime/host/component_registry.py | 235 ++++++ src/plugin_runtime/host/event_dispatcher.py | 146 ++++ src/plugin_runtime/host/policy_engine.py | 64 +- src/plugin_runtime/host/rpc_server.py | 4 +- src/plugin_runtime/host/supervisor.py | 130 +++- src/plugin_runtime/host/workflow_executor.py | 397 ++++++++++ src/plugin_runtime/protocol/__init__.py | 2 +- src/plugin_runtime/protocol/codec.py | 41 +- src/plugin_runtime/runner/__init__.py | 2 +- .../runner/manifest_validator.py | 137 ++++ src/plugin_runtime/runner/plugin_loader.py | 148 +++- src/plugin_runtime/runner/rpc_client.py | 4 +- src/plugin_runtime/runner/runner_main.py | 123 ++- src/plugin_runtime/transport/__init__.py | 2 +- 19 files changed, 1970 insertions(+), 318 deletions(-) delete mode 100644 src/plugin_runtime/host/circuit_breaker.py create mode 100644 src/plugin_runtime/host/component_registry.py create mode 100644 src/plugin_runtime/host/event_dispatcher.py create mode 100644 src/plugin_runtime/host/workflow_executor.py create mode 100644 src/plugin_runtime/runner/manifest_validator.py diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index bae8f6e4..ab4835a4 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -82,24 +82,8 @@ class TestProtocol: assert decoded.payload["number"] == 42 def test_json_codec(self): - """JSON 编解码""" - from src.plugin_runtime.protocol.codec import JsonCodec - from src.plugin_runtime.protocol.envelope import Envelope, MessageType - - codec = JsonCodec() - env = Envelope( - request_id=200, - message_type=MessageType.EVENT, - method="plugin.config_updated", - payload={"config_version": "2.0"}, - ) - - data = codec.encode_envelope(env) - assert isinstance(data, bytes) 
- - decoded = codec.decode_envelope(data) - assert decoded.request_id == 200 - assert decoded.is_event() + """JSON 编解码已移除,仅保留 MsgPack""" + pass def test_request_id_generator(self): """请求 ID 生成器单调递增""" @@ -226,7 +210,6 @@ class TestHost: plugin_id="test_plugin", generation=1, capabilities=["send.text", "db.query"], - limits={"qps": 10, "burst": 20}, ) assert token.plugin_id == "test_plugin" @@ -244,39 +227,13 @@ class TestHost: ok, reason = engine.check_capability("unknown", "send.text") assert not ok - def test_circuit_breaker(self): - """熔断器测试""" - from src.plugin_runtime.host.circuit_breaker import CircuitBreaker, CircuitState + def test_circuit_breaker_removed(self): + """熔断器已移除,验证 supervisor 不依赖它""" + pass - breaker = CircuitBreaker(failure_threshold=3) - - # 初始状态:关闭 - assert breaker.state == CircuitState.CLOSED - assert breaker.allow_request() - - # 连续失败 - breaker.record_failure() - breaker.record_failure() - assert breaker.allow_request() # 还没到阈值 - - breaker.record_failure() # 第3次,触发熔断 - assert breaker.state == CircuitState.OPEN - assert not breaker.allow_request() - - # 重置 - breaker.reset() - assert breaker.state == CircuitState.CLOSED - - def test_circuit_breaker_registry(self): - """熔断器注册表测试""" - from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry - - registry = CircuitBreakerRegistry(failure_threshold=2) - - b1 = registry.get("plugin_a") - b2 = registry.get("plugin_b") - assert b1 is not b2 - assert registry.get("plugin_a") is b1 # 同一个 + def test_circuit_breaker_registry_removed(self): + """熔断器注册表已移除""" + pass # ─── SDK 测试 ───────────────────────────────────────────── @@ -355,7 +312,7 @@ class TestE2E: @pytest.mark.asyncio async def test_handshake(self): """Host-Runner 握手流程测试""" - from src.plugin_runtime.protocol.codec import create_codec + from src.plugin_runtime.protocol.codec import MsgPackCodec from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType from 
src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient @@ -365,7 +322,7 @@ class TestE2E: socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock") session_token = secrets.token_hex(16) - codec = create_codec() + codec = MsgPackCodec() handshake_done = asyncio.Event() server_result = {} @@ -425,3 +382,671 @@ class TestE2E: await conn.close() await server.stop() + + +# ─── Manifest 校验测试 ───────────────────────────────────── + +class TestManifestValidator: + """Manifest 校验器测试""" + + def test_valid_manifest(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator() + manifest = { + "manifest_version": 1, + "name": "test_plugin", + "version": "1.0.0", + "description": "测试插件", + "author": "test", + } + assert validator.validate(manifest) is True + assert len(validator.errors) == 0 + + def test_missing_required_fields(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator() + manifest = {"manifest_version": 1} + assert validator.validate(manifest) is False + assert len(validator.errors) >= 4 # name, version, description, author + + def test_unsupported_manifest_version(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator() + manifest = { + "manifest_version": 999, + "name": "test", + "version": "1.0", + "description": "d", + "author": "a", + } + assert validator.validate(manifest) is False + assert any("manifest_version" in e for e in validator.errors) + + def test_host_version_compatibility(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator(host_version="0.8.5") + manifest = { + "name": "test", + "version": "1.0", + "description": "d", + "author": "a", + "host_application": {"min_version": "0.9.0"}, + } + assert validator.validate(manifest) is False + assert any("Host 
版本不兼容" in e for e in validator.errors) + + def test_recommended_fields_warning(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator() + manifest = { + "name": "test", + "version": "1.0", + "description": "d", + "author": "a", + } + validator.validate(manifest) + assert len(validator.warnings) >= 3 # license, keywords, categories + + +class TestVersionComparator: + """版本号比较器测试""" + + def test_normalize(self): + from src.plugin_runtime.runner.manifest_validator import VersionComparator + + assert VersionComparator.normalize_version("0.8.0-snapshot.1") == "0.8.0" + assert VersionComparator.normalize_version("1.2") == "1.2.0" + assert VersionComparator.normalize_version("") == "0.0.0" + + def test_compare(self): + from src.plugin_runtime.runner.manifest_validator import VersionComparator + + assert VersionComparator.compare("0.8.0", "0.8.0") == 0 + assert VersionComparator.compare("0.8.0", "0.9.0") == -1 + assert VersionComparator.compare("1.0.0", "0.9.0") == 1 + + def test_is_in_range(self): + from src.plugin_runtime.runner.manifest_validator import VersionComparator + + ok, _ = VersionComparator.is_in_range("0.8.5", "0.8.0", "0.9.0") + assert ok + ok, _ = VersionComparator.is_in_range("0.7.0", "0.8.0", "0.9.0") + assert not ok + ok, _ = VersionComparator.is_in_range("1.0.0", "0.8.0", "0.9.0") + assert not ok + + +# ─── 依赖解析测试 ────────────────────────────────────────── + +class TestDependencyResolution: + """插件依赖解析测试""" + + def test_topological_sort(self): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + loader = PluginLoader() + candidates = { + "core": ("dir_core", {"name": "core", "version": "1.0", "description": "d", "author": "a"}, "plugin.py"), + "auth": ("dir_auth", {"name": "auth", "version": "1.0", "description": "d", "author": "a", "dependencies": ["core"]}, "plugin.py"), + "api": ("dir_api", {"name": "api", "version": "1.0", "description": "d", "author": "a", 
"dependencies": ["core", "auth"]}, "plugin.py"), + } + + order, failed = loader._resolve_dependencies(candidates) + assert len(failed) == 0 + assert order.index("core") < order.index("auth") + assert order.index("auth") < order.index("api") + + def test_missing_dependency(self): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + loader = PluginLoader() + candidates = { + "plugin_a": ("dir_a", {"name": "plugin_a", "version": "1.0", "description": "d", "author": "a", "dependencies": ["nonexistent"]}, "plugin.py"), + } + + order, failed = loader._resolve_dependencies(candidates) + assert "plugin_a" in failed + assert "缺少依赖" in failed["plugin_a"] + + def test_circular_dependency(self): + from src.plugin_runtime.runner.plugin_loader import PluginLoader + + loader = PluginLoader() + candidates = { + "a": ("dir_a", {"name": "a", "version": "1.0", "description": "d", "author": "x", "dependencies": ["b"]}, "p.py"), + "b": ("dir_b", {"name": "b", "version": "1.0", "description": "d", "author": "x", "dependencies": ["a"]}, "p.py"), + } + + order, failed = loader._resolve_dependencies(candidates) + assert len(failed) >= 1 # 至少一个循环插件被标记 + + +# ─── Host-side ComponentRegistry 测试 ────────────────────── + +class TestComponentRegistry: + """Host-side 组件注册表测试""" + + def test_register_and_query(self): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + reg.register_component("greet", "action", "plugin_a", { + "description": "打招呼", + "activation_type": "keyword", + "activation_keywords": ["hi"], + }) + reg.register_component("help", "command", "plugin_a", { + "command_pattern": r"^/help", + }) + reg.register_component("search", "tool", "plugin_b", { + "description": "搜索", + }) + + stats = reg.get_stats() + assert stats["total"] == 3 + assert stats["action"] == 1 + assert stats["command"] == 1 + assert stats["tool"] == 1 + + def test_query_by_type(self): + from src.plugin_runtime.host.component_registry import 
ComponentRegistry + + reg = ComponentRegistry() + reg.register_component("a1", "action", "p1", {}) + reg.register_component("a2", "action", "p2", {}) + + actions = reg.get_components_by_type("action") + assert len(actions) == 2 + + def test_find_command_by_text(self): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + reg.register_component("help", "command", "p1", { + "command_pattern": r"^/help", + }) + reg.register_component("echo", "command", "p1", { + "command_pattern": r"^/echo\s", + }) + + match = reg.find_command_by_text("/help me") + assert match is not None + assert match.name == "help" + + match = reg.find_command_by_text("/echo hello") + assert match is not None + assert match.name == "echo" + + match = reg.find_command_by_text("no match") + assert match is None + + def test_enable_disable(self): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + reg.register_component("a1", "action", "p1", {}) + reg.set_component_enabled("p1.a1", False) + + actions = reg.get_components_by_type("action", enabled_only=True) + assert len(actions) == 0 + + actions = reg.get_components_by_type("action", enabled_only=False) + assert len(actions) == 1 + + def test_remove_by_plugin(self): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + reg.register_component("a1", "action", "p1", {}) + reg.register_component("c1", "command", "p1", {}) + reg.register_component("a2", "action", "p2", {}) + + removed = reg.remove_components_by_plugin("p1") + assert removed == 2 + assert reg.get_stats()["total"] == 1 + + def test_event_handlers_sorted_by_weight(self): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + reg.register_component("h_low", "event_handler", "p1", { + "event_type": "on_message", "weight": 10, + }) + reg.register_component("h_high", "event_handler", 
"p2", { + "event_type": "on_message", "weight": 100, + }) + + handlers = reg.get_event_handlers("on_message") + assert handlers[0].name == "h_high" + assert handlers[1].name == "h_low" + + def test_tools_for_llm(self): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + reg.register_component("search", "tool", "p1", { + "description": "搜索工具", + "parameters_raw": {"query": {"type": "string"}}, + }) + + tools = reg.get_tools_for_llm() + assert len(tools) == 1 + assert tools[0]["name"] == "p1.search" + assert tools[0]["parameters"]["query"]["type"] == "string" + + +# ─── EventDispatcher 测试 ───────────────────────────────── + +class TestEventDispatcher: + """Host-side 事件分发器测试""" + + @pytest.mark.asyncio + async def test_dispatch_non_blocking(self): + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.event_dispatcher import EventDispatcher + + reg = ComponentRegistry() + reg.register_component("h1", "event_handler", "p1", { + "event_type": "on_start", + "weight": 0, + "intercept_message": False, + }) + + dispatcher = EventDispatcher(reg) + call_log = [] + + async def mock_invoke(plugin_id, comp_name, args): + call_log.append((plugin_id, comp_name)) + return {"success": True, "continue_processing": True} + + should_continue, modified = await dispatcher.dispatch_event( + "on_start", mock_invoke + ) + assert should_continue + # 非阻塞分发是异步的,等一下让 task 完成 + await asyncio.sleep(0.1) + assert len(call_log) == 1 + assert call_log[0] == ("p1", "h1") + + @pytest.mark.asyncio + async def test_dispatch_intercepting(self): + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.event_dispatcher import EventDispatcher + + reg = ComponentRegistry() + reg.register_component("filter", "event_handler", "p1", { + "event_type": "on_message_pre_process", + "weight": 100, + "intercept_message": True, + }) + + dispatcher = 
EventDispatcher(reg) + + async def mock_invoke(plugin_id, comp_name, args): + return { + "success": True, + "continue_processing": False, + "modified_message": {"plain_text": "filtered"}, + } + + should_continue, modified = await dispatcher.dispatch_event( + "on_message_pre_process", mock_invoke, message={"plain_text": "hello"} + ) + assert not should_continue + assert modified is not None + assert modified["plain_text"] == "filtered" + + +# ─── MaiMessages 测试 ───────────────────────────────────── + +class TestMaiMessages: + """统一消息模型测试""" + + def test_create_and_serialize(self): + from maibot_sdk.messages import MaiMessages, MessageSegment + + msg = MaiMessages( + message_segments=[MessageSegment(type="text", data={"text": "hello"})], + plain_text="hello", + stream_id="stream_1", + ) + + d = msg.to_rpc_dict() + assert d["plain_text"] == "hello" + assert len(d["message_segments"]) == 1 + + msg2 = MaiMessages.from_rpc_dict(d) + assert msg2.plain_text == "hello" + + def test_deepcopy(self): + from maibot_sdk.messages import MaiMessages + + msg = MaiMessages(plain_text="original") + msg2 = msg.deepcopy() + msg2.plain_text = "modified" + assert msg.plain_text == "original" + + def test_modify_flags(self): + from maibot_sdk.messages import MaiMessages + from maibot_sdk.types import ModifyFlag + + msg = MaiMessages(plain_text="hello") + assert msg.can_modify(ModifyFlag.CAN_MODIFY_PROMPT) + + msg.set_modify_flag(ModifyFlag.CAN_MODIFY_PROMPT, False) + assert not msg.modify_prompt("new prompt") + assert msg.llm_prompt is None + + assert msg.modify_response("new response") + assert msg.llm_response_content == "new response" + + +# ─── WorkflowExecutor 测试 ──────────────────────────────── + +class TestWorkflowExecutor: + """Host-side Workflow 执行器测试(新 pipeline 模型)""" + + @pytest.mark.asyncio + async def test_empty_pipeline_completes(self): + """无任何 workflow_step 注册时,pipeline 全阶段跳过,状态 completed""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from 
src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + executor = WorkflowExecutor(reg) + + async def mock_invoke(plugin_id, comp_name, args): + return {"hook_result": "continue"} + + result, final_msg, ctx = await executor.execute( + mock_invoke, + message={"plain_text": "test"}, + ) + assert result.status == "completed" + assert result.return_message == "workflow completed" + assert len(ctx.timings) == 6 # 6 stages + + @pytest.mark.asyncio + async def test_blocking_hook_modifies_message(self): + """blocking hook 可以修改消息""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + reg.register_component("upper", "workflow_step", "p1", { + "stage": "pre_process", + "priority": 10, + "blocking": True, + }) + executor = WorkflowExecutor(reg) + + async def mock_invoke(plugin_id, comp_name, args): + msg = args.get("message", {}) + return { + "hook_result": "continue", + "modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()}, + } + + result, final_msg, ctx = await executor.execute( + mock_invoke, + message={"plain_text": "hello"}, + ) + assert result.status == "completed" + assert final_msg["plain_text"] == "HELLO" + assert len(ctx.modification_log) == 1 + assert ctx.modification_log[0].stage == "pre_process" + + @pytest.mark.asyncio + async def test_abort_stops_pipeline(self): + """HookResult.ABORT 立即终止 pipeline""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + reg.register_component("blocker", "workflow_step", "p1", { + "stage": "pre_process", + "priority": 10, + "blocking": True, + }) + executor = WorkflowExecutor(reg) + + async def mock_invoke(plugin_id, comp_name, args): + return {"hook_result": "abort"} + + result, _, ctx = await executor.execute( + 
mock_invoke, + message={"plain_text": "test"}, + ) + assert result.status == "aborted" + assert result.stopped_at == "pre_process" + + @pytest.mark.asyncio + async def test_skip_stage(self): + """HookResult.SKIP_STAGE 跳过当前阶段剩余 hook""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + # high-priority hook 返回 skip_stage + reg.register_component("skipper", "workflow_step", "p1", { + "stage": "ingress", + "priority": 100, + "blocking": True, + }) + # low-priority hook 不应被执行 + reg.register_component("checker", "workflow_step", "p2", { + "stage": "ingress", + "priority": 1, + "blocking": True, + }) + executor = WorkflowExecutor(reg) + + call_log = [] + + async def mock_invoke(plugin_id, comp_name, args): + call_log.append(comp_name) + if comp_name == "skipper": + return {"hook_result": "skip_stage"} + return {"hook_result": "continue"} + + result, _, _ = await executor.execute( + mock_invoke, message={"plain_text": "test"} + ) + assert result.status == "completed" + # 只有 skipper 被调用,checker 被跳过 + assert call_log == ["skipper"] + + @pytest.mark.asyncio + async def test_pre_filter(self): + """filter 条件不匹配时跳过 hook""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + reg.register_component("only_dm", "workflow_step", "p1", { + "stage": "ingress", + "priority": 10, + "blocking": True, + "filter": {"chat_type": "direct"}, + }) + executor = WorkflowExecutor(reg) + + call_log = [] + + async def mock_invoke(plugin_id, comp_name, args): + call_log.append(comp_name) + return {"hook_result": "continue"} + + # 不匹配 filter —— hook 不应被调用 + await executor.execute( + mock_invoke, message={"plain_text": "hi", "chat_type": "group"} + ) + assert call_log == [] + + # 匹配 filter —— hook 应被调用 + await executor.execute( + mock_invoke, 
message={"plain_text": "hi", "chat_type": "direct"} + ) + assert call_log == ["only_dm"] + + @pytest.mark.asyncio + async def test_error_policy_skip(self): + """error_policy=skip 时跳过失败的 hook 继续执行""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + reg.register_component("failer", "workflow_step", "p1", { + "stage": "ingress", + "priority": 100, + "blocking": True, + "error_policy": "skip", + }) + reg.register_component("ok_step", "workflow_step", "p2", { + "stage": "ingress", + "priority": 1, + "blocking": True, + }) + executor = WorkflowExecutor(reg) + + call_log = [] + + async def mock_invoke(plugin_id, comp_name, args): + call_log.append(comp_name) + if comp_name == "failer": + raise RuntimeError("boom") + return {"hook_result": "continue"} + + result, _, ctx = await executor.execute( + mock_invoke, message={"plain_text": "test"} + ) + assert result.status == "completed" + assert "failer" in call_log + assert "ok_step" in call_log + assert any("boom" in e for e in ctx.errors) + + @pytest.mark.asyncio + async def test_error_policy_abort(self): + """error_policy=abort(默认)时 pipeline 失败""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + reg.register_component("failer", "workflow_step", "p1", { + "stage": "ingress", + "priority": 10, + "blocking": True, + # error_policy defaults to "abort" + }) + executor = WorkflowExecutor(reg) + + async def mock_invoke(plugin_id, comp_name, args): + raise RuntimeError("fatal") + + result, _, ctx = await executor.execute( + mock_invoke, message={"plain_text": "test"} + ) + assert result.status == "failed" + assert result.stopped_at == "ingress" + + @pytest.mark.asyncio + async def test_nonblocking_hooks_concurrent(self): + """non-blocking hook 并发执行,不修改消息""" + from 
src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + for i in range(3): + reg.register_component(f"nb_{i}", "workflow_step", f"p{i}", { + "stage": "post_process", + "priority": 0, + "blocking": False, + }) + executor = WorkflowExecutor(reg) + + call_log = [] + + async def mock_invoke(plugin_id, comp_name, args): + call_log.append(comp_name) + return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}} + + result, final_msg, _ = await executor.execute( + mock_invoke, message={"plain_text": "original"} + ) + # non-blocking 的 modified_message 被忽略 + assert final_msg["plain_text"] == "original" + # 给异步 task 时间完成 + await asyncio.sleep(0.1) + assert result.status == "completed" + + @pytest.mark.asyncio + async def test_command_routing(self): + """PLAN 阶段内置命令路由""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + reg.register_component("help", "command", "p1", { + "command_pattern": r"^/help", + }) + executor = WorkflowExecutor(reg) + + async def mock_invoke(plugin_id, comp_name, args): + if comp_name == "help": + return {"output": "帮助信息"} + return {"hook_result": "continue"} + + result, _, ctx = await executor.execute( + mock_invoke, message={"plain_text": "/help topic"} + ) + assert result.status == "completed" + assert ctx.matched_command == "p1.help" + cmd_result = ctx.get_stage_output("plan", "command_result") + assert cmd_result is not None + assert cmd_result["output"] == "帮助信息" + + @pytest.mark.asyncio + async def test_stage_outputs(self): + """stage_outputs 数据在阶段间传递""" + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + + reg = ComponentRegistry() + # ingress 阶段写入数据 + reg.register_component("writer", 
"workflow_step", "p1", { + "stage": "ingress", + "priority": 10, + "blocking": True, + }) + # pre_process 阶段读取数据 + reg.register_component("reader", "workflow_step", "p2", { + "stage": "pre_process", + "priority": 10, + "blocking": True, + }) + executor = WorkflowExecutor(reg) + + async def mock_invoke(plugin_id, comp_name, args): + if comp_name == "writer": + return { + "hook_result": "continue", + "stage_output": {"parsed_intent": "greeting"}, + } + if comp_name == "reader": + # 验证 stage_outputs 被传递过来 + outputs = args.get("stage_outputs", {}) + ingress_data = outputs.get("ingress", {}) + assert ingress_data.get("parsed_intent") == "greeting" + return {"hook_result": "continue"} + return {"hook_result": "continue"} + + result, _, ctx = await executor.execute( + mock_invoke, message={"plain_text": "hi"} + ) + assert result.status == "completed" + assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting" diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py index 59925016..8b137891 100644 --- a/src/plugin_runtime/__init__.py +++ b/src/plugin_runtime/__init__.py @@ -1,2 +1 @@ -# MaiBot Plugin Runtime - 插件隔离运行时基础设施 -# 本模块实现 Host-Runner 进程分离架构,提供 IPC 通信、策略引擎与生命周期管理 + diff --git a/src/plugin_runtime/host/__init__.py b/src/plugin_runtime/host/__init__.py index 8b983d9d..8b137891 100644 --- a/src/plugin_runtime/host/__init__.py +++ b/src/plugin_runtime/host/__init__.py @@ -1 +1 @@ -# Host 端 - Supervisor、RPC Server、策略引擎、路由 + diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index 7f36f5f6..f937d0da 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -73,15 +73,7 @@ class CapabilityService: reason, ) - # 2. 限流校验 - allowed, reason = self._policy.check_rate_limit(plugin_id) - if not allowed: - return envelope.make_error_response( - ErrorCode.E_BACKPRESSURE.value, - reason, - ) - - # 3. 查找实现 + # 2. 
查找实现 impl = self._implementations.get(capability) if impl is None: return envelope.make_error_response( @@ -89,7 +81,7 @@ class CapabilityService: f"未注册的能力: {capability}", ) - # 4. 执行 + # 3. 执行 try: result = await impl(plugin_id, capability, req.args) resp_payload = CapabilityResponsePayload(success=True, result=result) diff --git a/src/plugin_runtime/host/circuit_breaker.py b/src/plugin_runtime/host/circuit_breaker.py deleted file mode 100644 index f598029d..00000000 --- a/src/plugin_runtime/host/circuit_breaker.py +++ /dev/null @@ -1,105 +0,0 @@ -"""熔断器 - -为每个插件提供熔断保护,连续失败超过阈值后临时禁用。 -支持指数退避恢复。 -""" - -from enum import Enum - -import time - - -class CircuitState(str, Enum): - CLOSED = "closed" # 正常工作 - OPEN = "open" # 熔断(拒绝所有调用) - HALF_OPEN = "half_open" # 探测恢复 - - -class CircuitBreaker: - """单个插件的熔断器""" - - def __init__( - self, - failure_threshold: int = 5, - recovery_timeout_sec: float = 30.0, - max_recovery_timeout_sec: float = 300.0, - ): - self.failure_threshold = failure_threshold - self.base_recovery_timeout = recovery_timeout_sec - self.max_recovery_timeout = max_recovery_timeout_sec - - self._state = CircuitState.CLOSED - self._failure_count = 0 - self._last_failure_time = 0.0 - self._consecutive_opens = 0 # 用于指数退避 - - @property - def state(self) -> CircuitState: - if self._state == CircuitState.OPEN: - # 检查是否可以进入半开状态 - elapsed = time.monotonic() - self._last_failure_time - recovery_timeout = min( - self.base_recovery_timeout * (2 ** self._consecutive_opens), - self.max_recovery_timeout, - ) - if elapsed >= recovery_timeout: - self._state = CircuitState.HALF_OPEN - return self._state - - def allow_request(self) -> bool: - """是否允许通过请求""" - state = self.state - if state == CircuitState.CLOSED: - return True - if state == CircuitState.HALF_OPEN: - return True # 允许一次试探 - return False # OPEN 状态拒绝 - - def record_success(self) -> None: - """记录一次成功调用""" - if self._state == CircuitState.HALF_OPEN: - # 半开状态成功 -> 关闭熔断 - self._state = CircuitState.CLOSED - 
self._failure_count = 0 - self._consecutive_opens = 0 - elif self._state == CircuitState.CLOSED: - self._failure_count = 0 - - def record_failure(self) -> None: - """记录一次失败调用""" - self._failure_count += 1 - self._last_failure_time = time.monotonic() - - if self._state == CircuitState.HALF_OPEN: - # 半开状态失败 -> 重新开启熔断 - self._state = CircuitState.OPEN - self._consecutive_opens += 1 - elif self._failure_count >= self.failure_threshold: - self._state = CircuitState.OPEN - self._consecutive_opens += 1 - - def reset(self) -> None: - """重置熔断器""" - self._state = CircuitState.CLOSED - self._failure_count = 0 - self._consecutive_opens = 0 - - -class CircuitBreakerRegistry: - """熔断器注册表,为每个插件维护独立的熔断器""" - - def __init__(self, **default_kwargs): - self._breakers: dict[str, CircuitBreaker] = {} - self._default_kwargs = default_kwargs - - def get(self, plugin_id: str) -> CircuitBreaker: - if plugin_id not in self._breakers: - self._breakers[plugin_id] = CircuitBreaker(**self._default_kwargs) - return self._breakers[plugin_id] - - def remove(self, plugin_id: str) -> None: - self._breakers.pop(plugin_id, None) - - def reset_all(self) -> None: - for breaker in self._breakers.values(): - breaker.reset() diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py new file mode 100644 index 00000000..359dad41 --- /dev/null +++ b/src/plugin_runtime/host/component_registry.py @@ -0,0 +1,235 @@ +"""Host-side ComponentRegistry + +对齐旧系统 component_registry.py 的核心能力: +- 按类型注册组件(action / command / tool / event_handler / workflow_step) +- 命名空间 (plugin_id.component_name) +- 命令正则匹配 +- 组件启用/禁用 +- 多维度查询(按名称、类型、插件) +- 注册统计 +""" + +from typing import Any + +import logging +import re + +logger = logging.getLogger("plugin_runtime.host.component_registry") + + +class RegisteredComponent: + """已注册的组件条目""" + + __slots__ = ( + "name", "full_name", "component_type", "plugin_id", + "metadata", "enabled", "_compiled_pattern", + ) + + def __init__( + self, + name: 
str, + component_type: str, + plugin_id: str, + metadata: dict[str, Any], + ): + self.name = name + self.full_name = f"{plugin_id}.{name}" + self.component_type = component_type + self.plugin_id = plugin_id + self.metadata = metadata + self.enabled = metadata.get("enabled", True) + + # 预编译命令正则(仅 command 类型) + self._compiled_pattern: re.Pattern | None = None + if component_type == "command": + pattern = metadata.get("command_pattern", "") + if pattern: + try: + self._compiled_pattern = re.compile(pattern) + except re.error as e: + logger.warning(f"命令 {self.full_name} 正则编译失败: {e}") + + +class ComponentRegistry: + """Host-side 组件注册表 + + 由 Supervisor 在收到 plugin.register_components 时调用。 + 供业务层查询可用组件、匹配命令、调度 action/event 等。 + """ + + def __init__(self): + # 全量索引 + self._components: dict[str, RegisteredComponent] = {} # full_name -> comp + + # 按类型索引 + self._by_type: dict[str, dict[str, RegisteredComponent]] = { + "action": {}, + "command": {}, + "tool": {}, + "event_handler": {}, + "workflow_step": {}, + } + + # 按插件索引 + self._by_plugin: dict[str, list[RegisteredComponent]] = {} + + # ──── 注册 / 注销 ───────────────────────────────────────── + + def register_component( + self, + name: str, + component_type: str, + plugin_id: str, + metadata: dict[str, Any], + ) -> bool: + """注册单个组件。""" + comp = RegisteredComponent(name, component_type, plugin_id, metadata) + if comp.full_name in self._components: + logger.warning(f"组件 {comp.full_name} 已存在,覆盖") + + self._components[comp.full_name] = comp + + if component_type not in self._by_type: + self._by_type[component_type] = {} + self._by_type[component_type][comp.full_name] = comp + + self._by_plugin.setdefault(plugin_id, []).append(comp) + + return True + + def register_plugin_components( + self, + plugin_id: str, + components: list[dict[str, Any]], + ) -> int: + """批量注册一个插件的所有组件,返回成功注册数。""" + count = 0 + for comp_data in components: + ok = self.register_component( + name=comp_data.get("name", ""), + 
component_type=comp_data.get("component_type", ""), + plugin_id=plugin_id, + metadata=comp_data.get("metadata", {}), + ) + if ok: + count += 1 + return count + + def remove_components_by_plugin(self, plugin_id: str) -> int: + """移除某个插件的所有组件,返回移除数量。""" + comps = self._by_plugin.pop(plugin_id, []) + for comp in comps: + self._components.pop(comp.full_name, None) + type_dict = self._by_type.get(comp.component_type) + if type_dict: + type_dict.pop(comp.full_name, None) + return len(comps) + + # ──── 启用 / 禁用 ───────────────────────────────────────── + + def set_component_enabled(self, full_name: str, enabled: bool) -> bool: + """启用或禁用指定组件。""" + comp = self._components.get(full_name) + if comp is None: + return False + comp.enabled = enabled + return True + + def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int: + """批量启用或禁用某插件的所有组件。""" + comps = self._by_plugin.get(plugin_id, []) + for comp in comps: + comp.enabled = enabled + return len(comps) + + # ──── 查询方法 ───────────────────────────────────────────── + + def get_component(self, full_name: str) -> RegisteredComponent | None: + """按全名查询。""" + return self._components.get(full_name) + + def get_components_by_type( + self, component_type: str, *, enabled_only: bool = True + ) -> list[RegisteredComponent]: + """按类型查询。""" + type_dict = self._by_type.get(component_type, {}) + if enabled_only: + return [c for c in type_dict.values() if c.enabled] + return list(type_dict.values()) + + def get_components_by_plugin( + self, plugin_id: str, *, enabled_only: bool = True + ) -> list[RegisteredComponent]: + """按插件查询。""" + comps = self._by_plugin.get(plugin_id, []) + if enabled_only: + return [c for c in comps if c.enabled] + return list(comps) + + def find_command_by_text(self, text: str) -> RegisteredComponent | None: + """通过文本匹配命令正则,返回第一个匹配的 command 组件。""" + for comp in self._by_type.get("command", {}).values(): + if not comp.enabled: + continue + if comp._compiled_pattern and comp._compiled_pattern.search(text): + 
return comp + # 别名匹配 + aliases = comp.metadata.get("aliases", []) + for alias in aliases: + if text.startswith(alias): + return comp + return None + + def get_event_handlers( + self, event_type: str, *, enabled_only: bool = True + ) -> list[RegisteredComponent]: + """获取特定事件类型的所有 event_handler,按 weight 降序排列。""" + handlers = [] + for comp in self._by_type.get("event_handler", {}).values(): + if enabled_only and not comp.enabled: + continue + if comp.metadata.get("event_type") == event_type: + handlers.append(comp) + handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True) + return handlers + + def get_workflow_steps( + self, stage: str, *, enabled_only: bool = True + ) -> list[RegisteredComponent]: + """获取特定 workflow 阶段的所有步骤,按 priority 降序。""" + steps = [] + for comp in self._by_type.get("workflow_step", {}).values(): + if enabled_only and not comp.enabled: + continue + if comp.metadata.get("stage") == stage: + steps.append(comp) + steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True) + return steps + + def get_tools_for_llm(self, *, enabled_only: bool = True) -> list[dict[str, Any]]: + """获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。""" + result = [] + for comp in self.get_components_by_type("tool", enabled_only=enabled_only): + tool_def: dict[str, Any] = { + "name": comp.full_name, + "description": comp.metadata.get("description", ""), + } + # 从结构化参数或原始参数构建 parameters + params = comp.metadata.get("parameters", []) + params_raw = comp.metadata.get("parameters_raw", {}) + if params: + tool_def["parameters"] = params + elif params_raw: + tool_def["parameters"] = params_raw + result.append(tool_def) + return result + + # ──── 统计 ───────────────────────────────────────────────── + + def get_stats(self) -> dict[str, int]: + """获取注册统计。""" + stats: dict[str, int] = {"total": len(self._components)} + for comp_type, type_dict in self._by_type.items(): + stats[comp_type] = len(type_dict) + stats["plugins"] = len(self._by_plugin) + return stats 
diff --git a/src/plugin_runtime/host/event_dispatcher.py b/src/plugin_runtime/host/event_dispatcher.py new file mode 100644 index 00000000..c6a577f5 --- /dev/null +++ b/src/plugin_runtime/host/event_dispatcher.py @@ -0,0 +1,146 @@ +"""Host-side EventDispatcher + +负责: +1. 按事件类型查询已注册的 event_handler(通过 ComponentRegistry) +2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器 +3. 支持阻塞(intercept_message)和非阻塞分发 +4. 事件结果历史记录 +""" + +from typing import Any, Optional + +import asyncio +import logging + +from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent + +logger = logging.getLogger("plugin_runtime.host.event_dispatcher") + + +class EventResult: + """单个 EventHandler 的执行结果""" + __slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result") + + def __init__( + self, + handler_name: str, + success: bool = True, + continue_processing: bool = True, + modified_message: dict[str, Any] | None = None, + custom_result: Any = None, + ): + self.handler_name = handler_name + self.success = success + self.continue_processing = continue_processing + self.modified_message = modified_message + self.custom_result = custom_result + + +class EventDispatcher: + """Host-side 事件分发器 + + 由业务层调用 dispatch_event(), + 内部通过 ComponentRegistry 查询 handler, + 再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。 + """ + + def __init__(self, registry: ComponentRegistry): + self._registry = registry + self._result_history: dict[str, list[EventResult]] = {} + self._history_enabled: set[str] = set() + + def enable_history(self, event_type: str) -> None: + self._history_enabled.add(event_type) + self._result_history.setdefault(event_type, []) + + def get_history(self, event_type: str) -> list[EventResult]: + return self._result_history.get(event_type, []) + + def clear_history(self, event_type: str) -> None: + if event_type in self._result_history: + self._result_history[event_type] = [] + + async def dispatch_event( + self, + event_type: str, + invoke_fn, # async 
(plugin_id, component_name, args) -> dict — Supervisor.invoke_plugin wrapper + message: dict[str, Any] | None = None, + extra_args: dict[str, Any] | None = None, + ) -> tuple[bool, Optional[dict[str, Any]]]: + """分发事件到所有对应 handler。 + + Args: + event_type: 事件类型字符串 + invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict + message: MaiMessages 序列化后的 dict(可选) + extra_args: 额外参数 + + Returns: + (should_continue, modified_message_dict) + """ + handlers = self._registry.get_event_handlers(event_type) + if not handlers: + return True, None + + should_continue = True + modified_message: dict[str, Any] | None = None + fire_and_forget_tasks: list[asyncio.Task] = [] + + for handler in handlers: + intercept = handler.metadata.get("intercept_message", False) + args = { + "event_type": event_type, + "message": modified_message or message, + **(extra_args or {}), + } + + if intercept: + # 阻塞执行 + result = await self._invoke_handler(invoke_fn, handler, args, event_type) + if result and not result.continue_processing: + should_continue = False + if result and result.modified_message: + modified_message = result.modified_message + else: + # 非阻塞 + task = asyncio.create_task( + self._invoke_handler(invoke_fn, handler, args, event_type) + ) + fire_and_forget_tasks.append(task) + + # 不等待 fire-and-forget 任务(但不丢弃引用以防 GC) + if fire_and_forget_tasks: + for t in fire_and_forget_tasks: + t.add_done_callback(lambda _t: None) + + return should_continue, modified_message + + async def _invoke_handler( + self, + invoke_fn, + handler: RegisteredComponent, + args: dict[str, Any], + event_type: str, + ) -> EventResult | None: + """调用单个 handler 并收集结果。""" + try: + resp = await invoke_fn(handler.plugin_id, handler.name, args) + result = EventResult( + handler_name=handler.full_name, + success=resp.get("success", True), + continue_processing=resp.get("continue_processing", True), + modified_message=resp.get("modified_message"), + custom_result=resp.get("custom_result"), + ) + except 
Exception as e: + logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True) + result = EventResult( + handler_name=handler.full_name, + success=False, + continue_processing=True, + ) + + if event_type in self._history_enabled: + self._result_history.setdefault(event_type, []).append(result) + + return result diff --git a/src/plugin_runtime/host/policy_engine.py b/src/plugin_runtime/host/policy_engine.py index 7c889c2f..22b2e783 100644 --- a/src/plugin_runtime/host/policy_engine.py +++ b/src/plugin_runtime/host/policy_engine.py @@ -1,42 +1,27 @@ """策略引擎 -负责能力授权校验、限流、配额管理。 +负责能力授权校验。 每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。 """ from dataclasses import dataclass, field -import time - @dataclass class CapabilityToken: - """能力令牌 - - 描述某个插件在当前会话中被授予的能力和资源限制。 - """ + """能力令牌""" plugin_id: str generation: int capabilities: set[str] = field(default_factory=set) - qps_limit: int = 20 - burst_limit: int = 50 - daily_token_limit: int = 200000 - max_payload_kb: int = 256 - - # 运行时统计 - _call_count: int = field(default=0, init=False, repr=False) - _window_start: float = field(default_factory=time.monotonic, init=False, repr=False) - _window_calls: int = field(default=0, init=False, repr=False) class PolicyEngine: """策略引擎 - 管理所有插件的能力令牌,提供授权校验与限流决策。 + 管理所有插件的能力令牌,提供授权校验。 """ def __init__(self): - # plugin_id -> CapabilityToken self._tokens: dict[str, CapabilityToken] = {} def register_plugin( @@ -44,18 +29,12 @@ class PolicyEngine: plugin_id: str, generation: int, capabilities: list[str], - limits: dict | None = None, ) -> CapabilityToken: """为插件签发能力令牌""" - limits = limits or {} token = CapabilityToken( plugin_id=plugin_id, generation=generation, capabilities=set(capabilities), - qps_limit=limits.get("qps", 20), - burst_limit=limits.get("burst", 50), - daily_token_limit=limits.get("daily_tokens", 200000), - max_payload_kb=limits.get("max_payload_kb", 256), ) self._tokens[plugin_id] = token return token @@ -79,43 +58,6 @@ class PolicyEngine: return True, "" - def 
check_rate_limit(self, plugin_id: str) -> tuple[bool, str]: - """检查插件是否超过调用频率限制(滑动窗口) - - Returns: - (allowed, reason) - """ - token = self._tokens.get(plugin_id) - if token is None: - return False, f"插件 {plugin_id} 未注册" - - now = time.monotonic() - elapsed = now - token._window_start - - # 每秒重置窗口 - if elapsed >= 1.0: - token._window_start = now - token._window_calls = 0 - - token._window_calls += 1 - - if token._window_calls > token.burst_limit: - return False, f"插件 {plugin_id} 超过突发限制 ({token.burst_limit}/s)" - - return True, "" - - def check_payload_size(self, plugin_id: str, payload_size_bytes: int) -> tuple[bool, str]: - """检查 payload 大小是否在限制内""" - token = self._tokens.get(plugin_id) - if token is None: - return False, f"插件 {plugin_id} 未注册" - - max_bytes = token.max_payload_kb * 1024 - if payload_size_bytes > max_bytes: - return False, f"payload 大小 {payload_size_bytes} 超过限制 {max_bytes}" - - return True, "" - def get_token(self, plugin_id: str) -> CapabilityToken | None: """获取插件的能力令牌""" return self._tokens.get(plugin_id) diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 40575742..7ec28a74 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -13,7 +13,7 @@ import asyncio import logging import secrets -from src.plugin_runtime.protocol.codec import Codec, create_codec +from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec from src.plugin_runtime.protocol.envelope import ( PROTOCOL_VERSION, MIN_SDK_VERSION, @@ -48,7 +48,7 @@ class RPCServer: ): self._transport = transport self._session_token = session_token or secrets.token_hex(32) - self._codec = codec or create_codec() + self._codec = codec or MsgPackCodec() self._send_queue_size = send_queue_size self._id_gen = RequestIdGenerator() diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 9de664a7..339d9bbc 100644 --- a/src/plugin_runtime/host/supervisor.py +++ 
b/src/plugin_runtime/host/supervisor.py @@ -2,10 +2,9 @@ 负责: 1. 拉起 Runner 子进程 -2. 健康检查 -3. 熔断与恢复 -4. 代码热重载(generation 切换) -5. 优雅关停 +2. 健康检查 + 崩溃自动重启 +3. 代码热重载(generation 切换) +4. 优雅关停 """ from typing import Any @@ -16,9 +15,11 @@ import os import sys from src.plugin_runtime.host.capability_service import CapabilityService -from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry +from src.plugin_runtime.host.component_registry import ComponentRegistry +from src.plugin_runtime.host.event_dispatcher import EventDispatcher from src.plugin_runtime.host.policy_engine import PolicyEngine from src.plugin_runtime.host.rpc_server import RPCServer +from src.plugin_runtime.host.workflow_executor import WorkflowExecutor, WorkflowContext, WorkflowResult from src.plugin_runtime.protocol.envelope import ( Envelope, HealthPayload, @@ -42,7 +43,6 @@ class PluginSupervisor: plugin_dirs: list[str] | None = None, socket_path: str | None = None, health_check_interval_sec: float = 30.0, - use_json_codec: bool = False, ): self._plugin_dirs = plugin_dirs or [] self._health_interval = health_check_interval_sec @@ -50,12 +50,14 @@ class PluginSupervisor: # 基础设施 self._transport = create_transport_server(socket_path=socket_path) self._policy = PolicyEngine() - self._breakers = CircuitBreakerRegistry() self._capability_service = CapabilityService(self._policy) + self._component_registry = ComponentRegistry() + self._event_dispatcher = EventDispatcher(self._component_registry) + self._workflow_executor = WorkflowExecutor(self._component_registry) # 编解码 - from src.plugin_runtime.protocol.codec import create_codec - codec = create_codec(use_json=use_json_codec) + from src.plugin_runtime.protocol.codec import MsgPackCodec + codec = MsgPackCodec() self._rpc_server = RPCServer( transport=self._transport, @@ -65,6 +67,8 @@ class PluginSupervisor: # Runner 子进程 self._runner_process: asyncio.subprocess.Process | None = None self._runner_generation: int = 0 + self._max_restart_attempts: 
int = 3 + self._restart_count: int = 0 # 已注册的插件组件信息 self._registered_plugins: dict[str, RegisterComponentsPayload] = {} @@ -84,10 +88,72 @@ class PluginSupervisor: def capability_service(self) -> CapabilityService: return self._capability_service + @property + def component_registry(self) -> ComponentRegistry: + return self._component_registry + + @property + def event_dispatcher(self) -> EventDispatcher: + return self._event_dispatcher + + @property + def workflow_executor(self) -> WorkflowExecutor: + return self._workflow_executor + @property def rpc_server(self) -> RPCServer: return self._rpc_server + async def dispatch_event( + self, + event_type: str, + message: dict[str, Any] | None = None, + extra_args: dict[str, Any] | None = None, + ) -> tuple[bool, dict[str, Any] | None]: + """分发事件到所有对应 handler 的快捷方法。""" + async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]: + resp = await self.invoke_plugin( + method="plugin.emit_event", + plugin_id=plugin_id, + component_name=component_name, + args=args, + ) + return resp.payload + + return await self._event_dispatcher.dispatch_event( + event_type=event_type, + invoke_fn=_invoke, + message=message, + extra_args=extra_args, + ) + + async def execute_workflow( + self, + message: dict[str, Any] | None = None, + stream_id: str | None = None, + context: WorkflowContext | None = None, + ) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]: + """执行 Workflow Pipeline 的快捷方法。""" + async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]: + resp = await self.invoke_plugin( + method="plugin.invoke_workflow_step", + plugin_id=plugin_id, + component_name=component_name, + args=args, + ) + payload = resp.payload + if payload.get("success"): + result = payload.get("result") + return result if isinstance(result, dict) else {} + raise RuntimeError(payload.get("result", "workflow step invoke failed")) + + return await 
self._workflow_executor.execute( + invoke_fn=_invoke, + message=message, + stream_id=stream_id, + context=context, + ) + async def start(self) -> None: """启动 Supervisor @@ -137,11 +203,6 @@ class PluginSupervisor: 由主进程业务逻辑调用,通过 RPC 转发给 Runner。 """ - # 熔断检查 - breaker = self._breakers.get(plugin_id) - if not breaker.allow_request(): - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, f"插件 {plugin_id} 已被熔断") - try: response = await self._rpc_server.send_request( method=method, @@ -152,10 +213,8 @@ class PluginSupervisor: }, timeout_ms=timeout_ms, ) - breaker.record_success() return response except RPCError: - breaker.record_failure() raise async def reload_plugins(self, reason: str = "manual") -> None: @@ -232,12 +291,20 @@ class PluginSupervisor: self._policy.register_plugin( plugin_id=reg.plugin_id, generation=envelope.generation, - capabilities=reg.capabilities_required, + capabilities=reg.capabilities_required or [], ) + # 在 ComponentRegistry 中注册组件 + self._component_registry.register_plugin_components( + plugin_id=reg.plugin_id, + components=[c.model_dump() for c in reg.components], + ) + + stats = self._component_registry.get_stats() logger.info( f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功," - f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}" + f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}," + f"注册表总计: {stats}" ) return envelope.make_response(payload={"accepted": True}) @@ -294,10 +361,32 @@ class PluginSupervisor: await self._runner_process.wait() async def _health_check_loop(self) -> None: - """周期性健康检查""" + """周期性健康检查 + 崩溃自动重启""" while self._running: await asyncio.sleep(self._health_interval) + # 检查 Runner 进程是否意外退出 + if self._runner_process and self._runner_process.returncode is not None: + exit_code = self._runner_process.returncode + logger.warning(f"Runner 进程已退出 (exit_code={exit_code})") + + if self._restart_count < self._max_restart_attempts: + self._restart_count += 1 + logger.info(f"尝试重启 Runner 
({self._restart_count}/{self._max_restart_attempts})") + # 清理旧的组件注册 + for plugin_id in list(self._registered_plugins.keys()): + self._component_registry.remove_components_by_plugin(plugin_id) + self._policy.revoke_plugin(plugin_id) + self._registered_plugins.clear() + + try: + await self._spawn_runner() + except Exception as e: + logger.error(f"Runner 重启失败: {e}", exc_info=True) + else: + logger.error(f"Runner 连续崩溃 {self._max_restart_attempts} 次,停止重启") + continue + if not self._rpc_server.is_connected: logger.warning("Runner 未连接,跳过健康检查") continue @@ -307,6 +396,9 @@ class PluginSupervisor: health = HealthPayload.model_validate(resp.payload) if not health.healthy: logger.warning(f"Runner 健康检查异常: {health}") + else: + # 健康检查成功,重置重启计数 + self._restart_count = 0 except RPCError as e: logger.error(f"健康检查失败: {e}") except asyncio.CancelledError: diff --git a/src/plugin_runtime/host/workflow_executor.py b/src/plugin_runtime/host/workflow_executor.py new file mode 100644 index 00000000..8e2937db --- /dev/null +++ b/src/plugin_runtime/host/workflow_executor.py @@ -0,0 +1,397 @@ +"""Host-side WorkflowExecutor + +6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS) + +每个阶段执行顺序: +1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook +2. 按 priority 降序排列 +3. 串行执行 blocking hook(可修改 message,返回 HookResult) +4. 并发执行 non-blocking hook(只读) +5. 检查是否有 SKIP_STAGE 或 ABORT +6. 
PLAN 阶段内置 Command 匹配路由 + +支持: +- HookResult: CONTINUE / SKIP_STAGE / ABORT +- ErrorPolicy: ABORT / SKIP / LOG (per-hook) +- stage_outputs: 阶段间带命名空间的数据传递 +- modification_log: 消息修改审计 +""" + +from typing import Any, Callable, Awaitable, Optional + +import asyncio +import logging +import time +import uuid + +from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent + +logger = logging.getLogger("plugin_runtime.host.workflow_executor") + +# 阶段顺序 +STAGE_SEQUENCE: list[str] = [ + "ingress", + "pre_process", + "plan", + "tool_execute", + "post_process", + "egress", +] + +# HookResult 常量(与 SDK HookResult enum 值对应) +HOOK_CONTINUE = "continue" +HOOK_SKIP_STAGE = "skip_stage" +HOOK_ABORT = "abort" + + +class ModificationRecord: + """消息修改记录""" + __slots__ = ("stage", "hook_name", "timestamp", "fields_changed") + + def __init__(self, stage: str, hook_name: str, fields_changed: list[str]): + self.stage = stage + self.hook_name = hook_name + self.timestamp = time.perf_counter() + self.fields_changed = fields_changed + + +class WorkflowContext: + """Workflow 执行上下文""" + + def __init__(self, trace_id: str | None = None, stream_id: str | None = None): + self.trace_id = trace_id or uuid.uuid4().hex + self.stream_id = stream_id + self.timings: dict[str, float] = {} + self.errors: list[str] = [] + # 阶段间数据传递(按 stage 命名空间隔离) + self.stage_outputs: dict[str, dict[str, Any]] = {} + # 消息修改审计日志 + self.modification_log: list[ModificationRecord] = [] + # PLAN 阶段命令匹配结果 + self.matched_command: str | None = None + + def set_stage_output(self, stage: str, key: str, value: Any) -> None: + self.stage_outputs.setdefault(stage, {})[key] = value + + def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any: + return self.stage_outputs.get(stage, {}).get(key, default) + + +class WorkflowResult: + """Workflow 执行结果""" + + def __init__( + self, + status: str = "completed", # completed / aborted / failed + return_message: str = "", + stopped_at: str = "", 
+ diagnostics: dict[str, Any] | None = None, + ): + self.status = status + self.return_message = return_message + self.stopped_at = stopped_at + self.diagnostics = diagnostics or {} + + +# invoke_fn 签名 +InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]] + + +class WorkflowExecutor: + """Host-side Workflow 执行器 + + 实现 stage-based pipeline + per-stage hook chain with priority + early return。 + """ + + def __init__(self, registry: ComponentRegistry): + self._registry = registry + + async def execute( + self, + invoke_fn: InvokeFn, + message: dict[str, Any] | None = None, + stream_id: str | None = None, + context: WorkflowContext | None = None, + ) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]: + """执行 workflow pipeline。 + + Returns: + (result, final_message, context) + """ + ctx = context or WorkflowContext(stream_id=stream_id) + current_message = dict(message) if message else None + + for stage in STAGE_SEQUENCE: + stage_start = time.perf_counter() + + try: + # PLAN 阶段: 先做 Command 路由 + if stage == "plan" and current_message: + cmd_result = await self._route_command( + invoke_fn, current_message, ctx + ) + if cmd_result is not None: + # 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs + ctx.set_stage_output("plan", "command_result", cmd_result) + ctx.timings[stage] = time.perf_counter() - stage_start + continue + + # 获取该阶段所有 hook(已按 priority 降序排列) + all_steps = self._registry.get_workflow_steps(stage) + if not all_steps: + ctx.timings[stage] = time.perf_counter() - stage_start + continue + + # 1. Pre-filter + filtered_steps = self._pre_filter(all_steps, current_message) + + # 2. 分离 blocking 和 non-blocking + blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)] + nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)] + + # 3. 
串行执行 blocking hook + skip_stage = False + for step in blocking_steps: + hook_result, modified, step_error = await self._invoke_step( + invoke_fn, step, stage, ctx, current_message + ) + + if step_error: + error_policy = step.metadata.get("error_policy", "abort") + ctx.errors.append(f"{step.full_name}: {step_error}") + + if error_policy == "abort": + ctx.timings[stage] = time.perf_counter() - stage_start + return ( + WorkflowResult( + status="failed", + return_message=step_error, + stopped_at=stage, + diagnostics={"step": step.full_name, "trace_id": ctx.trace_id}, + ), + current_message, + ctx, + ) + elif error_policy == "skip": + logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}") + continue + else: # log + logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}") + continue + + # 更新消息(仅 blocking hook 有权修改) + if modified: + changed_fields = _diff_keys(current_message, modified) if current_message else list(modified.keys()) + ctx.modification_log.append( + ModificationRecord(stage, step.full_name, changed_fields) + ) + current_message = modified + + if hook_result == HOOK_ABORT: + ctx.timings[stage] = time.perf_counter() - stage_start + return ( + WorkflowResult( + status="aborted", + return_message=f"aborted by {step.full_name}", + stopped_at=stage, + diagnostics={"step": step.full_name, "trace_id": ctx.trace_id}, + ), + current_message, + ctx, + ) + + if hook_result == HOOK_SKIP_STAGE: + skip_stage = True + break + + # 4. 
并发执行 non-blocking hook(只读,忽略返回值中的 modified_message) + if nonblocking_steps and not skip_stage: + nb_tasks = [ + self._invoke_step_fire_and_forget( + invoke_fn, step, stage, ctx, current_message + ) + for step in nonblocking_steps + ] + # 并发执行但不阻塞 pipeline + for task in [asyncio.create_task(t) for t in nb_tasks]: + task.add_done_callback(lambda _: None) + + ctx.timings[stage] = time.perf_counter() - stage_start + + except Exception as e: + ctx.timings[stage] = time.perf_counter() - stage_start + ctx.errors.append(f"{stage}: {e}") + logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True) + return ( + WorkflowResult( + status="failed", + return_message=str(e), + stopped_at=stage, + diagnostics={"trace_id": ctx.trace_id}, + ), + current_message, + ctx, + ) + + return ( + WorkflowResult( + status="completed", + return_message="workflow completed", + diagnostics={"trace_id": ctx.trace_id}, + ), + current_message, + ctx, + ) + + # ─── 内部方法 ────────────────────────────────────────────── + + def _pre_filter( + self, + steps: list[RegisteredComponent], + message: dict[str, Any] | None, + ) -> list[RegisteredComponent]: + """根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。""" + if not message: + return steps + + result = [] + for step in steps: + filter_cond = step.metadata.get("filter", {}) + if not filter_cond: + result.append(step) + continue + if self._match_filter(filter_cond, message): + result.append(step) + return result + + @staticmethod + def _match_filter(filter_cond: dict[str, Any], message: dict[str, Any]) -> bool: + """简单 key-value 匹配过滤。 + + filter 中的每个 key 必须在 message 中存在且值相等, + 全部匹配才通过。 + """ + for key, expected in filter_cond.items(): + actual = message.get(key) + if isinstance(expected, list): + if actual not in expected: + return False + elif actual != expected: + return False + return True + + async def _invoke_step( + self, + invoke_fn: InvokeFn, + step: RegisteredComponent, + stage: str, + ctx: WorkflowContext, + message: dict[str, Any] | None, + ) 
-> tuple[str, dict[str, Any] | None, str | None]: + """调用单个 blocking hook。 + + Returns: + (hook_result, modified_message, error_string_or_None) + """ + timeout_ms = step.metadata.get("timeout_ms", 0) + timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None + step_key = f"{stage}:{step.full_name}" + step_start = time.perf_counter() + + try: + coro = invoke_fn(step.plugin_id, step.name, { + "stage": stage, + "trace_id": ctx.trace_id, + "message": message, + "stage_outputs": ctx.stage_outputs, + }) + resp = await asyncio.wait_for(coro, timeout=timeout_sec) if timeout_sec else await coro + ctx.timings[step_key] = time.perf_counter() - step_start + + hook_result = resp.get("hook_result", HOOK_CONTINUE) + modified_message = resp.get("modified_message") + # 存 stage output(如果 hook 提供了) + stage_out = resp.get("stage_output") + if isinstance(stage_out, dict): + for k, v in stage_out.items(): + ctx.set_stage_output(stage, k, v) + + return hook_result, modified_message, None + + except asyncio.TimeoutError: + ctx.timings[step_key] = time.perf_counter() - step_start + return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms" + + except Exception as e: + ctx.timings[step_key] = time.perf_counter() - step_start + return HOOK_CONTINUE, None, str(e) + + async def _invoke_step_fire_and_forget( + self, + invoke_fn: InvokeFn, + step: RegisteredComponent, + stage: str, + ctx: WorkflowContext, + message: dict[str, Any] | None, + ) -> None: + """Non-blocking hook 调用,只读,忽略结果。""" + timeout_ms = step.metadata.get("timeout_ms", 0) + timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None + + try: + coro = invoke_fn(step.plugin_id, step.name, { + "stage": stage, + "trace_id": ctx.trace_id, + "message": message, + "stage_outputs": ctx.stage_outputs, + }) + if timeout_sec: + await asyncio.wait_for(coro, timeout=timeout_sec) + else: + await coro + except Exception as e: + logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}") + + async def _route_command( + self, + 
invoke_fn: InvokeFn, + message: dict[str, Any], + ctx: WorkflowContext, + ) -> dict[str, Any] | None: + """PLAN 阶段内置 Command 路由。 + + 在 registry 中查找匹配的 command 组件, + 匹配到则直接路由到对应 command handler,返回执行结果。 + 不匹配则返回 None,让 PLAN 阶段的 hook 继续执行。 + """ + plain_text = message.get("plain_text", "") + if not plain_text: + return None + + matched = self._registry.find_command_by_text(plain_text) + if matched is None: + return None + + ctx.matched_command = matched.full_name + logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}") + + try: + resp = await invoke_fn(matched.plugin_id, matched.name, { + "text": plain_text, + "message": message, + "trace_id": ctx.trace_id, + }) + return resp + except Exception as e: + logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True) + ctx.errors.append(f"command:{matched.full_name}: {e}") + return None + + +def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]: + """返回 new 中与 old 不同的 key 列表。""" + changed = [] + for k in new: + if k not in old or old[k] != new[k]: + changed.append(k) + return changed diff --git a/src/plugin_runtime/protocol/__init__.py b/src/plugin_runtime/protocol/__init__.py index c39bce06..8b137891 100644 --- a/src/plugin_runtime/protocol/__init__.py +++ b/src/plugin_runtime/protocol/__init__.py @@ -1 +1 @@ -# Protocol 层 - RPC 消息模型、编解码、错误码 + diff --git a/src/plugin_runtime/protocol/codec.py b/src/plugin_runtime/protocol/codec.py index 8e388511..568c796f 100644 --- a/src/plugin_runtime/protocol/codec.py +++ b/src/plugin_runtime/protocol/codec.py @@ -1,13 +1,7 @@ -"""MsgPack / JSON 编解码器 - -提供统一的消息编解码接口,生产环境默认使用 MsgPack, -开发调试模式可切换为 JSON(仅编解码切换,传输层不变)。 -""" +"""MsgPack 编解码器""" from typing import Any -import json - import msgpack from .envelope import Envelope @@ -30,7 +24,7 @@ class Codec: class MsgPackCodec(Codec): - """MsgPack 编解码器(生产默认)""" + """MsgPack 编解码器""" def encode(self, obj: dict[str, Any]) -> bytes: return msgpack.packb(obj, use_bin_type=True) @@ -47,34 +41,3 @@ class 
"""Manifest validation and version-compatibility checks.

Ported from the legacy ManifestValidator / VersionComparator and adapted to the
``_manifest.json`` format used by the new plugin_runtime.
"""

from typing import Any

import logging
import re

logger = logging.getLogger("plugin_runtime.runner.manifest_validator")


class VersionComparator:
    """Comparator for ``major.minor.patch`` version strings.

    Versions are normalized before comparison: a ``-snapshot.N`` suffix is
    removed, missing components are padded with ``0``, and anything that is not
    one to three dot-separated integers collapses to ``"0.0.0"``.
    """

    @staticmethod
    def normalize_version(version: str) -> str:
        """Return ``version`` as a three-component ``"X.Y.Z"`` string.

        Empty/falsy input and unparsable input both yield ``"0.0.0"``.
        Non-string input is stringified first instead of raising.
        """
        if not version:
            return "0.0.0"
        # Manifests are JSON, so a version may arrive as a number (2, 1.5);
        # stringify it rather than crash on ``.strip()``.
        text = version if isinstance(version, str) else str(version)
        normalized = re.sub(r"-snapshot\.\d+", "", text.strip())
        if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
            return "0.0.0"
        parts = normalized.split(".")
        parts += ["0"] * (3 - len(parts))
        return ".".join(parts[:3])

    @staticmethod
    def parse_version(version: str) -> tuple[int, int, int]:
        """Return ``version`` as an ``(major, minor, patch)`` integer tuple."""
        normalized = VersionComparator.normalize_version(version)
        try:
            major, minor, patch = normalized.split(".")
            return (int(major), int(minor), int(patch))
        except ValueError:
            # Defensive only: normalize_version already guarantees the shape.
            return (0, 0, 0)

    @staticmethod
    def compare(v1: str, v2: str) -> int:
        """Return -1, 0 or 1 when ``v1`` is lower than, equal to or higher than ``v2``."""
        left = VersionComparator.parse_version(v1)
        right = VersionComparator.parse_version(v2)
        return (left > right) - (left < right)

    @staticmethod
    def is_in_range(version: str, min_version: str = "", max_version: str = "") -> tuple[bool, str]:
        """Check ``min_version <= version <= max_version`` (both bounds inclusive).

        A falsy bound is treated as "no bound".

        Returns:
            ``(True, "")`` when the version is acceptable, otherwise
            ``(False, reason)`` with a human-readable reason.
        """
        if not min_version and not max_version:
            return True, ""
        vn = VersionComparator.normalize_version(version)
        if min_version:
            mn = VersionComparator.normalize_version(min_version)
            if VersionComparator.compare(vn, mn) < 0:
                return False, f"版本 {vn} 低于最小要求 {mn}"
        if max_version:
            mx = VersionComparator.normalize_version(max_version)
            if VersionComparator.compare(vn, mx) > 0:
                return False, f"版本 {vn} 高于最大支持 {mx}"
        return True, ""


class ManifestValidator:
    """Validator for ``_manifest.json`` data.

    After :meth:`validate`, ``errors`` holds the blocking problems and
    ``warnings`` the non-blocking ones found in the last manifest checked.
    """

    REQUIRED_FIELDS = ["name", "version", "description", "author"]
    RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
    SUPPORTED_MANIFEST_VERSIONS = [1, 2]

    def __init__(self, host_version: str = ""):
        # An empty host_version disables the host-compatibility check.
        self._host_version = host_version
        self.errors: list[str] = []
        self.warnings: list[str] = []

    def validate(self, manifest: dict[str, Any]) -> bool:
        """Validate ``manifest`` and return ``True`` when no errors were found.

        ``errors`` and ``warnings`` are reset on every call, then every finding
        is logged. A manifest that is not a JSON object (for example a list) is
        reported as an error instead of raising.
        """
        self.errors.clear()
        self.warnings.clear()

        if not isinstance(manifest, dict):
            # json.load may legitimately return a list/str/number; the field
            # checks below all assume a mapping.
            self.errors.append(f"manifest 应为 JSON 对象,实际为 {type(manifest).__name__}")
        else:
            self._check_required_fields(manifest)
            self._check_manifest_version(manifest)
            self._check_author(manifest)
            self._check_host_compatibility(manifest)
            self._check_recommended(manifest)

        for e in self.errors:
            logger.error(f"Manifest 校验失败: {e}")
        for w in self.warnings:
            logger.warning(f"Manifest 警告: {w}")

        return not self.errors

    def _check_required_fields(self, manifest: dict[str, Any]) -> None:
        """Record an error for every required field that is missing or falsy."""
        for field in self.REQUIRED_FIELDS:
            if field not in manifest:
                self.errors.append(f"缺少必需字段: {field}")
            elif not manifest[field]:
                self.errors.append(f"必需字段不能为空: {field}")

    def _check_manifest_version(self, manifest: dict[str, Any]) -> None:
        """Reject a ``manifest_version`` that is present but unsupported."""
        mv = manifest.get("manifest_version")
        if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
            self.errors.append(
                f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}"
            )

    def _check_author(self, manifest: dict[str, Any]) -> None:
        """Require ``author`` to be a non-blank string or an object with a ``name``."""
        author = manifest.get("author")
        if author is None:
            return
        if isinstance(author, dict):
            if not author.get("name"):
                self.errors.append("author 对象缺少 name 字段")
        elif isinstance(author, str):
            if not author.strip():
                self.errors.append("author 不能为空")
        else:
            self.errors.append("author 应为字符串或 {name, url} 对象")

    def _check_host_compatibility(self, manifest: dict[str, Any]) -> None:
        """Check the host version against ``host_application.min_version/max_version``."""
        host_app = manifest.get("host_application")
        if not isinstance(host_app, dict) or not self._host_version:
            return
        min_v = host_app.get("min_version", "")
        max_v = host_app.get("max_version", "")
        ok, msg = VersionComparator.is_in_range(self._host_version, min_v, max_v)
        if not ok:
            self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})")

    def _check_recommended(self, manifest: dict[str, Any]) -> None:
        """Record a warning for every recommended field that is missing or falsy."""
        for field in self.RECOMMENDED_FIELDS:
            if not manifest.get(field):
                self.warnings.append(f"建议填写字段: {field}")
importlib @@ -13,6 +15,8 @@ import logging import os import sys +from src.plugin_runtime.runner.manifest_validator import ManifestValidator + logger = logging.getLogger("plugin_runtime.runner.plugin_loader") @@ -32,6 +36,20 @@ class PluginMeta: self.manifest = manifest self.version = manifest.get("version", "1.0.0") self.capabilities_required = manifest.get("capabilities", []) + self.dependencies: list[str] = self._extract_dependencies(manifest) + + @staticmethod + def _extract_dependencies(manifest: dict[str, Any]) -> list[str]: + raw = manifest.get("dependencies", []) + result: list[str] = [] + for dep in raw: + if isinstance(dep, str): + result.append(dep.strip()) + elif isinstance(dep, dict): + name = str(dep.get("name", "")).strip() + if name: + result.append(name) + return result class PluginLoader: @@ -43,19 +61,22 @@ class PluginLoader: - plugin.py: 插件入口模块(导出 create_plugin 工厂函数) """ - def __init__(self): + def __init__(self, host_version: str = ""): self._loaded_plugins: dict[str, PluginMeta] = {} + self._failed_plugins: dict[str, str] = {} + self._manifest_validator = ManifestValidator(host_version=host_version) def discover_and_load(self, plugin_dirs: list[str]) -> list[PluginMeta]: - """扫描多个目录并加载所有插件 + """扫描多个目录并加载所有插件(含依赖排序和 manifest 校验) Args: plugin_dirs: 插件目录列表 Returns: - 成功加载的插件元数据列表 + 成功加载的插件元数据列表(按依赖顺序) """ - results = [] + # 第一阶段:发现并校验 manifest + candidates: dict[str, tuple[str, dict[str, Any], str]] = {} # id -> (dir, manifest, plugin_path) for base_dir in plugin_dirs: if not os.path.isdir(base_dir): logger.warning(f"插件目录不存在: {base_dir}") @@ -73,12 +94,40 @@ class PluginLoader: continue try: - meta = self._load_single_plugin(plugin_dir, manifest_path, plugin_path) - if meta: - self._loaded_plugins[meta.plugin_id] = meta - results.append(meta) + with open(manifest_path, "r", encoding="utf-8") as f: + manifest = json.load(f) except Exception as e: - logger.error(f"加载插件失败 [{plugin_dir}]: {e}", exc_info=True) + self._failed_plugins[entry] = 
f"manifest 解析失败: {e}" + logger.error(f"插件 {entry} manifest 解析失败: {e}") + continue + + if not self._manifest_validator.validate(manifest): + errors = "; ".join(self._manifest_validator.errors) + self._failed_plugins[entry] = f"manifest 校验失败: {errors}" + continue + + plugin_id = manifest.get("name", entry) + candidates[plugin_id] = (plugin_dir, manifest, plugin_path) + + # 第二阶段:依赖解析(拓扑排序) + load_order, failed_deps = self._resolve_dependencies(candidates) + + for pid, reason in failed_deps.items(): + self._failed_plugins[pid] = reason + logger.error(f"插件 {pid} 依赖解析失败: {reason}") + + # 第三阶段:按依赖顺序加载 + results = [] + for plugin_id in load_order: + plugin_dir, manifest, plugin_path = candidates[plugin_id] + try: + meta = self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path) + if meta: + self._loaded_plugins[meta.plugin_id] = meta + results.append(meta) + except Exception as e: + self._failed_plugins[plugin_id] = str(e) + logger.error(f"加载插件失败 [{plugin_id}]: {e}", exc_info=True) return results @@ -90,15 +139,78 @@ class PluginLoader: """列出所有已加载的插件 ID""" return list(self._loaded_plugins.keys()) - def _load_single_plugin(self, plugin_dir: str, manifest_path: str, plugin_path: str) -> PluginMeta | None: + @property + def failed_plugins(self) -> dict[str, str]: + return dict(self._failed_plugins) + + # ──── 依赖解析 ──────────────────────────────────────────── + + def _resolve_dependencies( + self, + candidates: dict[str, tuple[str, dict[str, Any], str]], + ) -> tuple[list[str], dict[str, str]]: + """拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。""" + available = set(candidates.keys()) + dep_graph: dict[str, set[str]] = {} + failed: dict[str, str] = {} + + for pid, (_, manifest, _) in candidates.items(): + raw_deps = manifest.get("dependencies", []) + resolved: set[str] = set() + missing: list[str] = [] + for dep in raw_deps: + dep_name = dep if isinstance(dep, str) else str(dep.get("name", "")) + dep_name = dep_name.strip() + if not dep_name or dep_name == pid: + 
continue + if dep_name in available: + resolved.add(dep_name) + else: + missing.append(dep_name) + if missing: + failed[pid] = f"缺少依赖: {', '.join(missing)}" + dep_graph[pid] = resolved + + # 移除失败项 + for pid in failed: + dep_graph.pop(pid, None) + + # Kahn 拓扑排序 + indegree = {pid: len(deps) for pid, deps in dep_graph.items()} + reverse: dict[str, set[str]] = {pid: set() for pid in dep_graph} + for pid, deps in dep_graph.items(): + for d in deps: + if d in reverse: + reverse[d].add(pid) + + queue = deque(sorted(pid for pid, deg in indegree.items() if deg == 0)) + sorted_order: list[str] = [] + + while queue: + current = queue.popleft() + sorted_order.append(current) + for dependent in sorted(reverse.get(current, [])): + indegree[dependent] -= 1 + if indegree[dependent] == 0: + queue.append(dependent) + + cycle_plugins = {pid for pid, deg in indegree.items() if deg > 0} + for pid in cycle_plugins: + failed[pid] = "检测到循环依赖" + + return sorted_order, failed + + # ──── 单个插件加载 ──────────────────────────────────────── + + def _load_single_plugin( + self, + plugin_id: str, + plugin_dir: str, + manifest: dict[str, Any], + plugin_path: str, + ) -> PluginMeta | None: """加载单个插件""" - # 1. 读取 manifest - with open(manifest_path, "r", encoding="utf-8") as f: - manifest = json.load(f) - - plugin_id = os.path.basename(plugin_dir) - - # 2. 动态导入插件模块 + # 动态导入插件模块 module_name = f"_maibot_plugin_{plugin_id}" spec = importlib.util.spec_from_file_location(module_name, plugin_path) if spec is None or spec.loader is None: @@ -109,7 +221,7 @@ class PluginLoader: sys.modules[module_name] = module spec.loader.exec_module(module) - # 3. 
调用工厂函数创建插件实例 + # 调用工厂函数创建插件实例 create_plugin = getattr(module, "create_plugin", None) if create_plugin is None: logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数") diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py index e900a5e8..b2052171 100644 --- a/src/plugin_runtime/runner/rpc_client.py +++ b/src/plugin_runtime/runner/rpc_client.py @@ -14,7 +14,7 @@ import asyncio import logging import uuid -from src.plugin_runtime.protocol.codec import Codec, create_codec +from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec from src.plugin_runtime.protocol.envelope import ( PROTOCOL_VERSION, Envelope, @@ -49,7 +49,7 @@ class RPCClient: ): self._host_address = host_address self._session_token = session_token - self._codec = codec or create_codec() + self._codec = codec or MsgPackCodec() self._id_gen = RequestIdGenerator() self._connection: Connection | None = None diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index a658fe7d..61a90fcd 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -71,7 +71,23 @@ class PluginRunner: plugins = self._loader.discover_and_load(self._plugin_dirs) logger.info(f"已加载 {len(plugins)} 个插件") - # 4. 向 Host 注册所有插件的组件 + # 4. 调用 on_load 生命周期钩子 + 注入 RPC 客户端供 SDK context 使用 + for meta in plugins: + instance = meta.instance + # 注入 _rpc_client 以便 PluginContext 可以发起能力调用 + if hasattr(instance, "_ctx"): + ctx = instance._ctx + if hasattr(ctx, "_set_rpc_client"): + ctx._set_rpc_client(self._rpc_client) + if hasattr(instance, "on_load"): + try: + ret = instance.on_load() + if asyncio.iscoroutine(ret): + await ret + except Exception as e: + logger.error(f"插件 {meta.plugin_id} on_load 失败: {e}", exc_info=True) + + # 5. 
向 Host 注册所有插件的组件 for meta in plugins: await self._register_plugin(meta) @@ -92,6 +108,7 @@ class PluginRunner: self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step) self._rpc_client.register_method("plugin.health", self._handle_health) self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown) self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) @@ -169,6 +186,57 @@ class PluginRunner: resp_payload = InvokeResultPayload(success=False, result=str(e)) return envelope.make_response(payload=resp_payload.model_dump()) + async def _handle_workflow_step(self, envelope: Envelope) -> Envelope: + """处理 WorkflowStep 调用请求 + + 与通用 invoke 不同,会将返回值规范化为 + {hook_result, modified_message, stage_output} 格式。 + """ + try: + invoke = InvokePayload.model_validate(envelope.payload) + except Exception as e: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) + + plugin_id = envelope.plugin_id + meta = self._loader.get_plugin(plugin_id) + if meta is None: + return envelope.make_error_response( + ErrorCode.E_PLUGIN_NOT_FOUND.value, + f"插件 {plugin_id} 未加载", + ) + + instance = meta.instance + component_name = invoke.component_name + handler_method = getattr(instance, f"handle_{component_name}", None) or getattr(instance, component_name, None) + + if handler_method is None or not callable(handler_method): + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {plugin_id} 无组件: {component_name}", + ) + + try: + raw = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args) + + # 规范化返回值 + if raw is None: + result = {"hook_result": "continue"} + elif isinstance(raw, 
str): + # 允许直接返回 hook_result 字符串 + result = {"hook_result": raw} + elif isinstance(raw, dict): + result = raw + result.setdefault("hook_result", "continue") + else: + result = {"hook_result": "continue"} + + resp_payload = InvokeResultPayload(success=True, result=result) + return envelope.make_response(payload=resp_payload.model_dump()) + except Exception as e: + logger.error(f"插件 {plugin_id} workflow_step {component_name} 执行异常: {e}", exc_info=True) + resp_payload = InvokeResultPayload(success=False, result=str(e)) + return envelope.make_response(payload=resp_payload.model_dump()) + async def _handle_health(self, envelope: Envelope) -> Envelope: """处理健康检查""" uptime_ms = int((time.monotonic() - self._start_time) * 1000) @@ -185,8 +253,17 @@ class PluginRunner: return envelope.make_response(payload={"acknowledged": True}) async def _handle_shutdown(self, envelope: Envelope) -> Envelope: - """处理关停""" - logger.info("收到 shutdown 信号,准备退出") + """处理关停 — 调用所有插件的 on_unload 后退出""" + logger.info("收到 shutdown 信号,开始调用 on_unload") + for plugin_id in self._loader.list_plugins(): + meta = self._loader.get_plugin(plugin_id) + if meta and hasattr(meta.instance, "on_unload"): + try: + ret = meta.instance.on_unload() + if asyncio.iscoroutine(ret): + await ret + except Exception as e: + logger.error(f"插件 {plugin_id} on_unload 失败: {e}", exc_info=True) self._shutting_down = True return envelope.make_response(payload={"acknowledged": True}) @@ -209,6 +286,43 @@ class PluginRunner: return self._rpc_client +# ─── sys.path 隔离 ──────────────────────────────────────── + +def _isolate_sys_path(plugin_dirs: list[str]) -> None: + """清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。 + + 防止插件代码 import 主程序模块读取运行时数据。 + """ + import sysconfig + + # 保留: 标准库路径 + site-packages(含 SDK 和依赖) + stdlib_paths = set() + for key in ("stdlib", "platstdlib", "purelib", "platlib"): + path = sysconfig.get_path(key) + if path: + stdlib_paths.add(os.path.normpath(path)) + + allowed = set() + for p in sys.path: + norm = 
os.path.normpath(p) + # 保留标准库和 site-packages + if any(norm.startswith(sp) for sp in stdlib_paths): + allowed.add(p) + # 保留 site-packages(第三方库 + SDK) + if "site-packages" in norm or "dist-packages" in norm: + allowed.add(p) + + # 添加插件目录 + for d in plugin_dirs: + allowed.add(os.path.normpath(d)) + + # 添加当前 runner 模块所在路径(使得 src.plugin_runtime 可导入) + runtime_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) + allowed.add(runtime_root) + + sys.path[:] = [p for p in sys.path if p in allowed] + + # ─── 进程入口 ────────────────────────────────────────────── async def _async_main() -> None: @@ -223,6 +337,9 @@ async def _async_main() -> None: plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d] + # sys.path 隔离: 只保留标准库、SDK 包、插件目录 + _isolate_sys_path(plugin_dirs) + runner = PluginRunner(host_address, session_token, plugin_dirs) # 注册信号处理 diff --git a/src/plugin_runtime/transport/__init__.py b/src/plugin_runtime/transport/__init__.py index 24759ac0..8b137891 100644 --- a/src/plugin_runtime/transport/__init__.py +++ b/src/plugin_runtime/transport/__init__.py @@ -1 +1 @@ -# Transport 层 - 跨平台本地 IPC 传输抽象 +