From 0b0f47a444cef41c937ad212a2c103a612812c48 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Tue, 24 Mar 2026 19:04:05 +0800 Subject: [PATCH] feat: Enhance Hook System with HookHandler and Dispatcher - Introduced HookHandlerEntry to manage hook processing with attributes like hook name, mode, order, timeout, and error policy. - Implemented normalization methods for hook attributes to ensure valid configurations. - Updated ComponentRegistry to support retrieval of hook handlers based on hook names, with sorting by mode and order. - Refactored HookDispatcher to handle invocation of hooks, separating blocking and non-blocking handlers, and managing execution results. - Added support for registering hook specifications and invoking hooks across supervisors in PluginRuntimeManager. - Removed deprecated workflow step handling from PluginRunner, streamlining hook invocation responses. --- pytests/test_plugin_runtime.py | 703 +++++++++-------- src/plugin_runtime/host/component_registry.py | 183 ++++- src/plugin_runtime/host/hook_dispatcher.py | 738 +++++++++++++++--- src/plugin_runtime/host/supervisor.py | 26 +- src/plugin_runtime/integration.py | 54 +- src/plugin_runtime/runner/runner_main.py | 66 +- 6 files changed, 1247 insertions(+), 523 deletions(-) diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index e3247f05..5c9f39b0 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -5,6 +5,7 @@ from pathlib import Path from types import SimpleNamespace +from typing import Any, Awaitable, Callable, Dict, List, Optional import asyncio import json @@ -1831,395 +1832,445 @@ class TestMaiMessages: assert msg.llm_response_content == "new response" -# ─── WorkflowExecutor 测试 ──────────────────────────────── +class _FakeHookSupervisor: + """用于 Hook 分发测试的简化 Supervisor。""" + + def __init__( + self, + group_name: str, + component_registry: Any, + handlers: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]] | Dict[str, Any]]], + call_log: List[tuple[str, str]], + ) -> None: + """初始化测试用 Supervisor。 + + Args: + group_name: 运行时分组名称。 + component_registry: 组件注册表实例。 + handlers: 处理器映射,键为 `plugin_id.component_name`。 + call_log: 记录调用顺序的列表。 + """ + + self._group_name = group_name + self.component_registry = component_registry + self._handlers = handlers + self._call_log = call_log + + @property + def group_name(self) -> str: + """返回当前测试 Supervisor 的分组名称。""" + + return self._group_name + + async def invoke_plugin( + self, + method: str, + plugin_id: str, + component_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> SimpleNamespace: + """模拟调用插件组件。 + + Args: + method: RPC 方法名。 + plugin_id: 目标插件 ID。 + component_name: 目标组件名称。 + args: 调用参数。 + timeout_ms: 超时配置,测试中仅用于保持接口一致。 + + Returns: + SimpleNamespace: 仅包含 `payload` 字段的简化响应对象。 + """ + + del method + del timeout_ms + + full_name = f"{plugin_id}.{component_name}" + handler = self._handlers[full_name] + self._call_log.append((plugin_id, component_name)) + result = handler(dict(args or {})) + if asyncio.iscoroutine(result): + result = await result + return SimpleNamespace(payload=result) -class TestWorkflowExecutor: - """Host-side Workflow 执行器测试(新 pipeline 模型)""" +# ─── HookDispatcher 测试 ──────────────────────────────── + + +class TestHookDispatcher: + """命名 Hook 分发器测试。""" + + @staticmethod + def _import_dispatcher_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]: + """导入 Hook 分发相关模块,并屏蔽配置初始化触发的退出。 + + Args: + monkeypatch: pytest 的 monkeypatch 工具。 + + Returns: + tuple[Any, Any]: `ComponentRegistry` 与 `HookDispatcher` 类型。 + """ + + monkeypatch.setattr(sys, "exit", lambda code=0: None) + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.host.hook_dispatcher import HookDispatcher + + return ComponentRegistry, HookDispatcher @pytest.mark.asyncio - async def test_empty_pipeline_completes(self): - """无任何 workflow_step 注册时,pipeline 全阶段跳过,状态 completed""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_empty_hook_returns_original_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None: + """未注册处理器时应直接返回原始参数。""" - reg = ComponentRegistry() - executor = WorkflowExecutor(reg) + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - async def mock_invoke(plugin_id, comp_name, args): - return {"hook_result": "continue"} + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor("builtin", ComponentRegistry(), {}, []) - result, final_msg, ctx = await executor.execute( - mock_invoke, - message={"plain_text": "test"}, - ) - assert result.status == "completed" - assert result.return_message == "workflow completed" - assert len(ctx.timings) == 6 # 6 stages + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") + + assert result.hook_name == "heart_fc.cycle_start" + assert result.kwargs == {"session_id": "s-1"} + assert result.aborted is False @pytest.mark.asyncio - async def test_blocking_hook_modifies_message(self): - """blocking hook 可以修改消息""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_blocking_hook_modifies_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None: + """blocking 处理器可以修改参数。""" - reg = ComponentRegistry() - reg.register_component( + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + registry = ComponentRegistry() + registry.register_component( "upper", - "workflow_step", + "HOOK_HANDLER", "p1", { - "stage": "pre_process", - "priority": 10, - "blocking": True, + "hook": "heart_fc.cycle_start", + "mode": "blocking", + "order": "normal", }, ) - executor = WorkflowExecutor(reg) - - async def mock_invoke(plugin_id, comp_name, args): - msg = args.get("message", {}) - return { - "hook_result": "continue", - "modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()}, - } - - result, final_msg, ctx = await executor.execute( - mock_invoke, - message={"plain_text": "hello"}, + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor( + "builtin", + registry, + { + "p1.upper": lambda args: { + "success": True, + "action": "continue", + "modified_kwargs": { + "session_id": args["session_id"], + "text": str(args["text"]).upper(), + }, + } + }, + [], ) - assert result.status == "completed" - assert final_msg["plain_text"] == "HELLO" - assert len(ctx.modification_log) == 1 - assert ctx.modification_log[0].stage == "pre_process" + + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1", text="hello") + + assert result.kwargs["session_id"] == "s-1" + assert result.kwargs["text"] == "HELLO" + assert result.aborted is False @pytest.mark.asyncio - async def test_abort_stops_pipeline(self): - """HookResult.ABORT 立即终止 pipeline""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_abort_stops_following_blocking_handlers(self, monkeypatch: pytest.MonkeyPatch) -> None: + """blocking 处理器的 abort 应阻止后续 blocking 处理器执行。""" - reg = ComponentRegistry() - reg.register_component( - "blocker", - "workflow_step", + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + registry = ComponentRegistry() + registry.register_component( + "stopper", + "HOOK_HANDLER", "p1", - { - "stage": "pre_process", - "priority": 10, - "blocking": True, - }, + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, ) - executor = WorkflowExecutor(reg) - - async def mock_invoke(plugin_id, comp_name, args): - return {"hook_result": "abort"} - - result, _, ctx = await executor.execute( - mock_invoke, - message={"plain_text": "test"}, - ) - assert result.status == "aborted" - assert result.stopped_at == "pre_process" - - @pytest.mark.asyncio - async def test_skip_stage(self): - """HookResult.SKIP_STAGE 跳过当前阶段剩余 hook""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - # high-priority hook 返回 skip_stage - reg.register_component( - "skipper", - "workflow_step", - "p1", - { - "stage": "ingress", - "priority": 100, - "blocking": True, - }, - ) - # low-priority hook 不应被执行 - reg.register_component( - "checker", - "workflow_step", + registry.register_component( + "after_stop", + "HOOK_HANDLER", "p2", - { - "stage": "ingress", - "priority": 1, - "blocking": True, - }, + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, + ) + call_log: List[tuple[str, str]] = [] + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor( + "builtin", + registry, + { + "p1.stopper": lambda args: {"success": True, "action": "abort"}, + "p2.after_stop": lambda args: {"success": True, "action": "continue"}, + }, + call_log, ) - executor = WorkflowExecutor(reg) - call_log = [] + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], cycle_id="c-1") - async def mock_invoke(plugin_id, comp_name, args): - call_log.append(comp_name) - if comp_name == "skipper": - return {"hook_result": "skip_stage"} - return {"hook_result": "continue"} - - result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"}) - assert result.status == "completed" - # 只有 skipper 被调用,checker 被跳过 - assert call_log == ["skipper"] + assert result.aborted is True + assert result.stopped_by == "p1.stopper" + assert call_log == [("p1", "stopper")] @pytest.mark.asyncio - async def test_pre_filter(self): - """filter 条件不匹配时跳过 hook""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_observe_handler_runs_in_background_without_mutation( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """observe 处理器应后台执行且不能影响主流程参数。""" - reg = ComponentRegistry() - reg.register_component( - "only_dm", - "workflow_step", - "p1", - { - "stage": "ingress", - "priority": 10, - "blocking": True, - "filter": {"chat_type": "direct"}, - }, - ) - executor = WorkflowExecutor(reg) + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) - call_log = [] - - async def mock_invoke(plugin_id, comp_name, args): - call_log.append(comp_name) - return {"hook_result": "continue"} - - # 不匹配 filter —— hook 不应被调用 - await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"}) - assert not call_log - - # 匹配 filter —— hook 应被调用 - await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"}) - assert call_log == ["only_dm"] - - @pytest.mark.asyncio - async def test_error_policy_skip(self): - """error_policy=skip 时跳过失败的 hook 继续执行""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - reg.register_component( - "failer", - "workflow_step", - "p1", - { - "stage": "ingress", - "priority": 100, - "blocking": True, - "error_policy": "skip", - }, - ) - reg.register_component( - "ok_step", - "workflow_step", - "p2", - { - "stage": "ingress", - "priority": 1, - "blocking": True, - }, - ) - executor = WorkflowExecutor(reg) - - call_log = [] - - async def mock_invoke(plugin_id, comp_name, args): - call_log.append(comp_name) - if comp_name == "failer": - raise RuntimeError("boom") - return {"hook_result": "continue"} - - result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"}) - assert result.status == "completed" - assert "failer" in call_log - assert "ok_step" in call_log - assert any("boom" in e for e in ctx.errors) - - @pytest.mark.asyncio - async def test_error_policy_abort(self): - """error_policy=abort(默认)时 pipeline 失败""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - reg.register_component( - "failer", - "workflow_step", - "p1", - { - "stage": "ingress", - "priority": 10, - "blocking": True, - # error_policy defaults to "abort" - }, - ) - executor = WorkflowExecutor(reg) - - async def mock_invoke(plugin_id, comp_name, args): - raise RuntimeError("fatal") - - result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"}) - assert result.status == "failed" - assert result.stopped_at == "ingress" - - @pytest.mark.asyncio - async def test_nonblocking_hooks_concurrent(self): - """non-blocking hook 并发执行,不修改消息""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - for i in range(3): - reg.register_component( - f"nb_{i}", - "workflow_step", - f"p{i}", - { - "stage": "post_process", - "priority": 0, - "blocking": False, - }, - ) - executor = WorkflowExecutor(reg) - - call_log = [] - - async def mock_invoke(plugin_id, comp_name, args): - call_log.append(comp_name) - return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}} - - result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"}) - # non-blocking 的 modified_message 被忽略 - assert final_msg["plain_text"] == "original" - # 给异步 task 时间完成 - await asyncio.sleep(0.1) - assert result.status == "completed" - - @pytest.mark.asyncio - async def test_nonblocking_tasks_are_retained_until_completion(self): - """execute 返回后,non-blocking task 仍应保持强引用直到执行完成。""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor - - reg = ComponentRegistry() - reg.register_component( + registry = ComponentRegistry() + registry.register_component( "observer", - "workflow_step", + "HOOK_HANDLER", "p1", - { - "stage": "post_process", - "priority": 0, - "blocking": False, - }, + {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"}, ) - executor = WorkflowExecutor(reg) - started = asyncio.Event() release = asyncio.Event() + call_log: List[tuple[str, str]] = [] + + async def observe_handler(args: Dict[str, Any]) -> Dict[str, Any]: + """模拟耗时观察型处理器。""" - async def mock_invoke(plugin_id, comp_name, args): started.set() await release.wait() - return {"hook_result": "continue"} + return { + "success": True, + "action": "abort", + "modified_kwargs": {"session_id": "changed"}, + "custom_result": args["session_id"], + } - result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"}) + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor( + "builtin", + registry, + {"p1.observer": observe_handler}, + call_log, + ) + + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") await asyncio.sleep(0) - assert result.status == "completed" - assert final_msg["plain_text"] == "original" + assert result.aborted is False + assert result.kwargs["session_id"] == "s-1" assert started.is_set() - assert len(executor._background_tasks) == 1 + assert len(dispatcher._background_tasks) == 1 release.set() await asyncio.sleep(0) await asyncio.sleep(0) - assert not executor._background_tasks + assert call_log == [("p1", "observer")] + assert not dispatcher._background_tasks @pytest.mark.asyncio - async def test_command_routing(self): - """PLAN 阶段内置命令路由""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_global_order_prefers_order_slot_then_source(self, monkeypatch: pytest.MonkeyPatch) -> None: + """全局排序应先看 order,再看内置/第三方来源。""" - reg = ComponentRegistry() - reg.register_component( - "help", - "command", - "p1", - { - "command_pattern": r"^/help", - }, + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + builtin_registry = ComponentRegistry() + third_registry = ComponentRegistry() + builtin_registry.register_component( + "builtin_early", + "HOOK_HANDLER", + "b1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, + ) + builtin_registry.register_component( + "builtin_normal", + "HOOK_HANDLER", + "b1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, + ) + third_registry.register_component( + "third_early", + "HOOK_HANDLER", + "t1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, + ) + third_registry.register_component( + "third_normal", + "HOOK_HANDLER", + "t1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"}, ) - executor = WorkflowExecutor(reg) - async def mock_invoke(plugin_id, comp_name, args): - if comp_name == "help": - return {"output": "帮助信息"} - return {"hook_result": "continue"} + call_log: List[tuple[str, str]] = [] + dispatcher = HookDispatcher() + builtin_supervisor = _FakeHookSupervisor( + "builtin", + builtin_registry, + { + "b1.builtin_early": lambda args: {"success": True, "action": "continue"}, + "b1.builtin_normal": lambda args: {"success": True, "action": "continue"}, + }, + call_log, + ) + third_supervisor = _FakeHookSupervisor( + "third_party", + third_registry, + { + "t1.third_early": lambda args: {"success": True, "action": "continue"}, + "t1.third_normal": lambda args: {"success": True, "action": "continue"}, + }, + call_log, + ) - result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"}) - assert result.status == "completed" - assert ctx.matched_command == "p1.help" - cmd_result = ctx.get_stage_output("plan", "command_result") - assert cmd_result is not None - assert cmd_result["output"] == "帮助信息" + await dispatcher.invoke_hook( + "heart_fc.cycle_start", + [third_supervisor, builtin_supervisor], + cycle_id="c-1", + ) + + assert call_log == [ + ("b1", "builtin_early"), + ("t1", "third_early"), + ("b1", "builtin_normal"), + ("t1", "third_normal"), + ] @pytest.mark.asyncio - async def test_stage_outputs(self): - """stage_outputs 数据在阶段间传递""" - from src.plugin_runtime.host.component_registry import ComponentRegistry - from src.plugin_runtime.host.workflow_executor import WorkflowExecutor + async def test_error_policy_abort_stops_dispatch(self, monkeypatch: pytest.MonkeyPatch) -> None: + """error_policy=abort 时应中止本次 Hook 调用。""" - reg = ComponentRegistry() - # ingress 阶段写入数据 - reg.register_component( - "writer", - "workflow_step", + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + registry = ComponentRegistry() + registry.register_component( + "failer", + "HOOK_HANDLER", "p1", { - "stage": "ingress", - "priority": 10, - "blocking": True, + "hook": "heart_fc.cycle_start", + "mode": "blocking", + "order": "normal", + "error_policy": "abort", }, ) - # pre_process 阶段读取数据 - reg.register_component( - "reader", - "workflow_step", - "p2", + call_log: List[tuple[str, str]] = [] + + async def fail_handler(args: Dict[str, Any]) -> Dict[str, Any]: + """抛出异常以触发 abort 策略。""" + + del args + raise RuntimeError("boom") + + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor("builtin", registry, {"p1.failer": fail_handler}, call_log) + + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") + + assert result.aborted is True + assert result.stopped_by == "p1.failer" + assert any("boom" in error for error in result.errors) + assert call_log == [("p1", "failer")] + + @pytest.mark.asyncio + async def test_timeout_respects_handler_timeout_ms(self, monkeypatch: pytest.MonkeyPatch) -> None: + """处理器超时应被记录为错误并继续。""" + + ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch) + + registry = ComponentRegistry() + registry.register_component( + "slow", + "HOOK_HANDLER", + "p1", { - "stage": "pre_process", - "priority": 10, - "blocking": True, + "hook": "heart_fc.cycle_start", + "mode": "blocking", + "order": "normal", + "timeout_ms": 10, }, ) - executor = WorkflowExecutor(reg) + call_log: List[tuple[str, str]] = [] - async def mock_invoke(plugin_id, comp_name, args): - if comp_name == "writer": - return { - "hook_result": "continue", - "stage_output": {"parsed_intent": "greeting"}, - } - if comp_name == "reader": - # 验证 stage_outputs 被传递过来 - outputs = args.get("stage_outputs", {}) - ingress_data = outputs.get("ingress", {}) - assert ingress_data.get("parsed_intent") == "greeting" - return {"hook_result": "continue"} - return {"hook_result": "continue"} + async def slow_handler(args: Dict[str, Any]) -> Dict[str, Any]: + """模拟超时处理器。""" - result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"}) - assert result.status == "completed" - assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting" + del args + await asyncio.sleep(0.05) + return {"success": True, "action": "continue"} + + dispatcher = HookDispatcher() + supervisor = _FakeHookSupervisor("builtin", registry, {"p1.slow": slow_handler}, call_log) + + result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1") + + assert result.aborted is False + assert any("超时" in error for error in result.errors) + assert call_log == [("p1", "slow")] + + +class TestPluginRuntimeHookEntry: + """PluginRuntimeManager 命名 Hook 入口测试。""" + + @staticmethod + def _import_manager_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]: + """导入运行时管理器相关模块,并屏蔽配置初始化触发的退出。 + + Args: + monkeypatch: pytest 的 monkeypatch 工具。 + + Returns: + tuple[Any, Any]: `ComponentRegistry` 与 `PluginRuntimeManager` 类型。 + """ + + monkeypatch.setattr(sys, "exit", lambda code=0: None) + from src.plugin_runtime.host.component_registry import ComponentRegistry + from src.plugin_runtime.integration import PluginRuntimeManager + + return ComponentRegistry, PluginRuntimeManager + + @pytest.mark.asyncio + async def test_manager_invoke_hook_dispatches_across_supervisors( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """PluginRuntimeManager.invoke_hook() 应调用全局 Hook 分发器。""" + + ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch) + + builtin_registry = ComponentRegistry() + builtin_registry.register_component( + "builtin_guard", + "HOOK_HANDLER", + "b1", + {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"}, + ) + third_registry = ComponentRegistry() + third_registry.register_component( + "observer", + "HOOK_HANDLER", + "t1", + {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"}, + ) + + call_log: List[tuple[str, str]] = [] + manager = PluginRuntimeManager() + manager._started = True + manager._builtin_supervisor = _FakeHookSupervisor( + "builtin", + builtin_registry, + {"b1.builtin_guard": lambda args: {"success": True, "action": "continue"}}, + call_log, + ) + manager._third_party_supervisor = _FakeHookSupervisor( + "third_party", + third_registry, + {"t1.observer": lambda args: {"success": True, "action": "continue"}}, + call_log, + ) + + result = await manager.invoke_dispatcher.invoke_hook("heart_fc.cycle_start", session_id="s-1") + + await asyncio.sleep(0) + assert manager.invoke_dispatcher is manager.hook_dispatcher + assert result.aborted is False + assert result.kwargs["session_id"] == "s-1" + assert ("b1", "builtin_guard") in call_log class TestRPCServer: diff --git a/src/plugin_runtime/host/component_registry.py b/src/plugin_runtime/host/component_registry.py index 97fdca30..8f995e2a 100644 --- a/src/plugin_runtime/host/component_registry.py +++ b/src/plugin_runtime/host/component_registry.py @@ -1,7 +1,7 @@ """Host-side ComponentRegistry 对齐旧系统 component_registry.py 的核心能力: -- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway) +- 按类型注册组件(action / command / tool / event_handler / hook_handler / message_gateway) - 命名空间 (plugin_id.component_name) - 命令正则匹配 - 组件启用/禁用 @@ -106,14 +106,129 @@ class EventHandlerEntry(ComponentEntry): class HookHandlerEntry(ComponentEntry): - """WorkflowHandler 组件条目""" + """HookHandler 组件条目。""" def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - self.stage: str = metadata.get("stage", "") - self.priority: int = metadata.get("priority", 0) - self.blocking: bool = metadata.get("blocking", False) + self.hook: str = self._normalize_hook_name(metadata.get("hook", "")) + self.mode: str = self._normalize_mode(metadata.get("mode", "blocking")) + self.order: str = self._normalize_order(metadata.get("order", "normal")) + self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0)) + self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip")) super().__init__(name, component_type, plugin_id, metadata) + @staticmethod + def _normalize_error_policy(raw_value: Any) -> str: + """规范化 Hook 异常处理策略。 + + Args: + raw_value: 原始异常处理策略值。 + + Returns: + str: 规范化后的异常处理策略。 + + Raises: + ValueError: 当异常处理策略不受支持时抛出。 + """ + + normalized_source = getattr(raw_value, "value", raw_value) + normalized_value = str(normalized_source or "").strip().lower() or "skip" + if normalized_value not in {"abort", "skip", "log"}: + raise ValueError(f"HookHandler 异常处理策略不合法: {raw_value}") + return normalized_value + + @staticmethod + def _normalize_hook_name(raw_value: Any) -> str: + """规范化命名 Hook 名称。 + + Args: + raw_value: 原始 Hook 名称。 + + Returns: + str: 去空白后的 Hook 名称。 + + Raises: + ValueError: 当 Hook 名称为空时抛出。 + """ + + normalized_source = getattr(raw_value, "value", raw_value) + if not (normalized_value := str(normalized_source or "").strip()): + raise ValueError("HookHandler 的 hook 名称不能为空") + return normalized_value + + @staticmethod + def _normalize_mode(raw_value: Any) -> str: + """规范化 Hook 处理模式。 + + Args: + raw_value: 原始模式值。 + + Returns: + str: 规范化后的模式。 + + Raises: + ValueError: 当模式不受支持时抛出。 + """ + + normalized_source = getattr(raw_value, "value", raw_value) + normalized_value = str(normalized_source or "").strip().lower() or "blocking" + if normalized_value not in {"blocking", "observe"}: + raise ValueError(f"HookHandler 模式不合法: {raw_value}") + return normalized_value + + @staticmethod + def _normalize_order(raw_value: Any) -> str: + """规范化 Hook 顺序槽位。 + + Args: + raw_value: 原始顺序值。 + + Returns: + str: 规范化后的顺序槽位。 + + Raises: + ValueError: 当顺序值不受支持时抛出。 + """ + + normalized_source = getattr(raw_value, "value", raw_value) + normalized_value = str(normalized_source or "").strip().lower() or "normal" + if normalized_value not in {"early", "normal", "late"}: + raise ValueError(f"HookHandler 顺序槽位不合法: {raw_value}") + return normalized_value + + @staticmethod + def _normalize_timeout_ms(raw_value: Any) -> int: + """规范化 Hook 超时配置。 + + Args: + raw_value: 原始超时值。 + + Returns: + int: 规范化后的超时毫秒数。 + + Raises: + ValueError: 当超时值为负数或无法转换为整数时抛出。 + """ + + try: + timeout_ms = int(raw_value or 0) + except (TypeError, ValueError) as exc: + raise ValueError(f"HookHandler 超时配置不合法: {raw_value}") from exc + if timeout_ms < 0: + raise ValueError(f"HookHandler 超时配置不能为负数: {raw_value}") + return timeout_ms + + @property + def is_blocking(self) -> bool: + """返回当前 Hook 是否为阻塞模式。""" + + return self.mode == "blocking" + + @property + def is_observe(self) -> bool: + """返回当前 Hook 是否为观察模式。""" + + return self.mode == "observe" + class MessageGatewayEntry(ComponentEntry): """MessageGateway 组件条目""" @@ -454,16 +569,17 @@ class ComponentRegistry: return handlers def get_hook_handlers( - self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None + self, hook_name: str, *, enabled_only: bool = True, session_id: Optional[str] = None ) -> List[HookHandlerEntry]: - """获取特定 hook 阶段的所有步骤,按 priority 降序。 + """获取订阅指定命名 Hook 的全部处理器。 Args: - stage: hook 名称 - enabled_only: 是否仅返回启用的组件 - session_id: 可选的会话ID,若提供则考虑会话禁用状态 + hook_name: 目标 Hook 名称。 + enabled_only: 是否仅返回启用的组件。 + session_id: 可选的会话 ID,若提供则考虑会话禁用状态。 + Returns: - handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序 + List[HookHandlerEntry]: 符合条件的 HookHandler 组件列表。 """ handlers: List[HookHandlerEntry] = [] for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values(): @@ -471,11 +587,37 @@ class ComponentRegistry: continue if not isinstance(comp, HookHandlerEntry): continue - if comp.stage == stage: + if comp.hook == hook_name: handlers.append(comp) - handlers.sort(key=lambda c: c.priority, reverse=True) + handlers.sort(key=lambda comp: (self._get_hook_mode_rank(comp.mode), self._get_hook_order_rank(comp.order), comp.plugin_id, comp.name)) return handlers + @staticmethod + def _get_hook_mode_rank(mode: str) -> int: + """返回 Hook 模式的排序权重。 + + Args: + mode: Hook 模式字符串。 + + Returns: + int: 越小表示越靠前。 + """ + + return {"blocking": 0, "observe": 1}.get(mode, 99) + + @staticmethod + def _get_hook_order_rank(order: str) -> int: + """返回 Hook 顺序槽位的排序权重。 + + Args: + order: Hook 顺序槽位字符串。 + + Returns: + int: 越小表示越靠前。 + """ + + return {"early": 0, "normal": 1, "late": 2}.get(order, 99) + def get_message_gateway( self, plugin_id: str, @@ -566,8 +708,13 @@ class ComponentRegistry: Returns: stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等 """ - stats: StatusDict = {"total": len(self._components)} # type: ignore - for comp_type, type_dict in self._by_type.items(): - stats[comp_type.value.lower()] = len(type_dict) - stats["plugins"] = len(self._by_plugin) - return stats + return StatusDict( + total=len(self._components), + action=len(self._by_type[ComponentTypes.ACTION]), + command=len(self._by_type[ComponentTypes.COMMAND]), + tool=len(self._by_type[ComponentTypes.TOOL]), + event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]), + hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]), + message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]), + plugins=len(self._by_plugin), + ) diff --git a/src/plugin_runtime/host/hook_dispatcher.py b/src/plugin_runtime/host/hook_dispatcher.py index d5e88448..0406c8f6 100644 --- a/src/plugin_runtime/host/hook_dispatcher.py +++ b/src/plugin_runtime/host/hook_dispatcher.py @@ -1,166 +1,678 @@ -""" -Hook Dispatch 系统 +"""命名 Hook 分发系统。 -插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。 -每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。 -在参数/返回值匹配的情况下允许修改参数/返回值。 +主程序可以在任意执行点触发一个命名 Hook,Host 会收集所有订阅该 Hook 的 +插件处理器,并按照固定的全局顺序调度执行。 -HookDispatcher 负责: -1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry) -2. 按 priority 排序,区分 blocking 和非 blocking 模式 -3. blocking 模式:依次同步调用,支持修改参数/提前终止 -4. 非 blocking 模式:异步调用,不阻塞主流程 -5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限 +排序规则如下: + +1. `blocking` 先于 `observe` +2. `early` 先于 `normal` 先于 `late` +3. 内置插件先于第三方插件 +4. `plugin_id` +5. `handler_name` + +其中: + +- `blocking` 处理器串行执行,可修改 `kwargs`,也可中止本次 Hook 调用。 +- `observe` 处理器后台并发执行,只允许旁路观察,不参与主流程控制。 """ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set + import asyncio -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING +import contextlib from src.common.logger import get_logger from src.config.config import global_config - if TYPE_CHECKING: + from .component_registry import HookHandlerEntry from .supervisor import PluginRunnerSupervisor - from .component_registry import ComponentRegistry, HookHandlerEntry logger = get_logger("plugin_runtime.host.hook_dispatcher") -@dataclass -class HookResult: - """单个 HookHandler 的执行结果""" +@dataclass(slots=True) +class HookSpec: + """命名 Hook 的静态规格定义。 + + Attributes: + name: Hook 的唯一名称。 + description: Hook 描述。 + default_timeout_ms: 默认超时毫秒数;为 `0` 时退回系统默认值。 + allow_blocking: 是否允许注册阻塞处理器。 + allow_observe: 是否允许注册观察处理器。 + allow_abort: 是否允许处理器中止当前 Hook 调用。 + allow_kwargs_mutation: 是否允许阻塞处理器修改 `kwargs`。 + """ + + name: str + description: str = "" + default_timeout_ms: int = 0 + allow_blocking: bool = True + allow_observe: bool = True + allow_abort: bool = True + allow_kwargs_mutation: bool = True + + +@dataclass(slots=True) +class HookHandlerExecutionResult: + """单个 HookHandler 的执行结果。 + + Attributes: + handler_name: 完整处理器名称,格式通常为 `plugin_id.component_name`。 + plugin_id: 处理器所属插件 ID。 + success: 本次调用是否成功。 + action: 当前处理器要求的控制动作,仅支持 `continue` 或 `abort`。 + modified_kwargs: 处理器返回的修改后参数字典。 + custom_result: 处理器返回的附加结果。 + error_message: 失败时的错误描述。 + """ handler_name: str - success: bool = field(default=True) - continue_processing: bool = field(default=True) - modified_kwargs: Optional[Dict[str, Any]] = field(default=None) - custom_result: Any = field(default=None) + plugin_id: str + success: bool = True + action: str = "continue" + modified_kwargs: Optional[Dict[str, Any]] = None + custom_result: Any = None + error_message: str = "" + + +@dataclass(slots=True) +class HookDispatchResult: + """一次命名 Hook 调用的聚合结果。 + + Attributes: + hook_name: 本次调用的 Hook 名称。 + kwargs: 经阻塞处理器串行处理后的最终参数字典。 + aborted: 是否被某个处理器中止。 + stopped_by: 若被中止,记录触发中止的完整处理器名称。 + custom_results: 阻塞处理器返回的附加结果列表。 + errors: 本次调用中记录到的错误信息列表。 + """ + + hook_name: str + kwargs: Dict[str, Any] = field(default_factory=dict) + aborted: bool = False + stopped_by: Optional[str] = None + custom_results: List[Any] = field(default_factory=list) + errors: List[str] = field(default_factory=list) + + +@dataclass(slots=True) +class _HookInvocationTarget: + """内部使用的 Hook 调度目标。 + + Attributes: + supervisor: 负责该处理器的 Supervisor。 + entry: Hook 处理器条目。 + source_rank: 插件来源权重,内置插件为 `0`,第三方插件为 `1`。 + """ + + supervisor: "PluginRunnerSupervisor" + entry: "HookHandlerEntry" + source_rank: int class HookDispatcher: - """Host-side Hook 分发器 + """命名 Hook 分发器。""" - 由业务层调用 hook_dispatch(), - 内部通过 ComponentRegistry 查询 handler, - 再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。 - """ - - def __init__(self, component_registry: "ComponentRegistry") -> None: - """初始化 HookDispatcher + def __init__( + self, + supervisors_provider: Optional[Callable[[], Sequence["PluginRunnerSupervisor"]]] = None, + ) -> None: + """初始化 Hook 分发器。 Args: - component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler + supervisors_provider: 可选的 Supervisor 提供器。若调用 `invoke_hook()` + 时未显式传入 `supervisors`,则使用该回调获取目标 Supervisor 列表。 """ - self._component_registry: "ComponentRegistry" = component_registry - self._background_tasks: Set[asyncio.Task] = set() + + self._background_tasks: Set[asyncio.Task[Any]] = set() + self._hook_specs: Dict[str, HookSpec] = {} + self._supervisors_provider = supervisors_provider async def stop(self) -> None: - """停止 HookDispatcher,取消所有未完成的后台任务""" + """停止分发器并取消所有未完成的观察任务。""" + for task in self._background_tasks: task.cancel() await asyncio.gather(*self._background_tasks, return_exceptions=True) self._background_tasks.clear() - async def hook_dispatch( - self, - stage: str, - supervisor: "PluginRunnerSupervisor", - **kwargs: Any, - ) -> Dict[str, Any]: - """分发 hook 到所有对应 handler 的便捷方法。 - - 内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑, - 无需调用方手动构造 invoke_fn 闭包。 + def register_hook_spec(self, spec: HookSpec) -> None: + """注册单个命名 Hook 规格。 Args: - stage: hook 名称 - supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin - **kwargs: 关键字参数,会展开传递给 handler + spec: 需要注册的 Hook 规格。 + """ + + normalized_name = self._normalize_hook_name(spec.name) + self._hook_specs[normalized_name] = HookSpec( + name=normalized_name, + description=spec.description, + default_timeout_ms=max(int(spec.default_timeout_ms), 0), + allow_blocking=bool(spec.allow_blocking), + allow_observe=bool(spec.allow_observe), + allow_abort=bool(spec.allow_abort), + allow_kwargs_mutation=bool(spec.allow_kwargs_mutation), + ) + + def register_hook_specs(self, specs: Sequence[HookSpec]) -> None: + """批量注册命名 Hook 规格。 + + Args: + specs: 需要注册的 Hook 规格序列。 + """ + + for spec in specs: + self.register_hook_spec(spec) + + def get_hook_spec(self, hook_name: str) -> HookSpec: + """获取指定 Hook 的规格定义。 + + Args: + hook_name: Hook 名称。 Returns: - modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数 + HookSpec: 若未显式注册,则返回按系统默认值生成的运行时规格。 """ - handler_entries = self._component_registry.get_hook_handlers(stage) - if not handler_entries: - return kwargs - current_kwargs = kwargs.copy() - blocking_handlers: List["HookHandlerEntry"] = [] - non_blocking_handlers: List["HookHandlerEntry"] = [] + normalized_name = self._normalize_hook_name(hook_name) + if normalized_name in self._hook_specs: + return self._hook_specs[normalized_name] - # 分离 blocking 和非 blocking handler - for entry in handler_entries: - if entry.blocking: - blocking_handlers.append(entry) - else: - non_blocking_handlers.append(entry) + return HookSpec( + name=normalized_name, + default_timeout_ms=self._get_default_timeout_ms(), + ) - # 处理 blocking handlers(同步调用,支持修改参数/提前终止) - timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0 - for entry in blocking_handlers: - hook_args = {"stage": stage, **current_kwargs} - try: - # 应用超时控制 - result = await asyncio.wait_for( - self._invoke_handler(supervisor, entry, hook_args), - timeout=timeout, + async def invoke_hook( + self, + hook_name: str, + supervisors: Optional[Sequence["PluginRunnerSupervisor"]] = None, + **kwargs: Any, + ) -> HookDispatchResult: + """触发一次命名 Hook 调用。 + + Args: + hook_name: 本次触发的 Hook 名称。 + supervisors: 当前运行时中所有可参与分发的 Supervisor;留空时使用绑定的提供器。 + **kwargs: 传递给 Hook 处理器的关键字参数。 + + Returns: + HookDispatchResult: 聚合后的 Hook 调用结果。 + """ + + resolved_supervisors = list(supervisors) if supervisors is not None else list(self._resolve_supervisors()) + normalized_hook_name = self._normalize_hook_name(hook_name) + hook_spec = self.get_hook_spec(normalized_hook_name) + current_kwargs: Dict[str, Any] = dict(kwargs) + dispatch_result = HookDispatchResult(hook_name=normalized_hook_name, kwargs=dict(current_kwargs)) + invocation_targets = self._collect_invocation_targets(normalized_hook_name, resolved_supervisors) + + if not invocation_targets: + return dispatch_result + + for target in invocation_targets: + if target.entry.is_observe: + self._schedule_observe_handler( + hook_name=normalized_hook_name, + hook_spec=hook_spec, + target=target, + kwargs=current_kwargs, ) - except asyncio.TimeoutError: - logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过") - result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True) + continue - if result: - if result.modified_kwargs is not None: - current_kwargs = result.modified_kwargs - if not result.continue_processing: - logger.info(f"HookHandler {entry.full_name} 终止了后续处理") - break + if not hook_spec.allow_blocking: + error_message = ( + f"Hook {normalized_hook_name} 不允许 blocking 处理器," + f"已跳过 {target.entry.full_name}" + ) + logger.warning(error_message) + dispatch_result.errors.append(error_message) + continue - # 处理 non-blocking handlers(异步调用,不阻塞主流程) - for entry in non_blocking_handlers: - async_kwargs = current_kwargs.copy() - hook_args = {"stage": stage, **async_kwargs} - task = asyncio.create_task( - asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout) + execution_result = await self._invoke_handler( + hook_name=normalized_hook_name, + hook_spec=hook_spec, + target=target, + kwargs=current_kwargs, + ) + self._merge_blocking_result( + hook_spec=hook_spec, + target=target, + execution_result=execution_result, + dispatch_result=dispatch_result, ) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - return current_kwargs + current_kwargs = dict(dispatch_result.kwargs) + if dispatch_result.aborted: + break + + return dispatch_result + + def _resolve_supervisors(self) -> Sequence["PluginRunnerSupervisor"]: + """解析当前调用应使用的 Supervisor 列表。 + + Returns: + Sequence[PluginRunnerSupervisor]: 可参与本次 Hook 调度的 Supervisor 序列。 + + Raises: + ValueError: 当未传入 `supervisors` 且分发器也未绑定提供器时抛出。 + """ + + if self._supervisors_provider is None: + raise ValueError("当前 HookDispatcher 未绑定 supervisors_provider,请显式传入 supervisors") + return self._supervisors_provider() + + def _collect_invocation_targets( + self, + hook_name: str, + supervisors: Sequence["PluginRunnerSupervisor"], + ) -> List[_HookInvocationTarget]: + """收集并排序本次 Hook 调用的全部处理器目标。 + + Args: + hook_name: 目标 Hook 名称。 + supervisors: 当前参与调度的 Supervisor 序列。 + + Returns: + List[_HookInvocationTarget]: 已完成全局排序的处理器目标列表。 + """ + + invocation_targets: List[_HookInvocationTarget] = [] + for supervisor in supervisors: + source_rank = self._get_supervisor_source_rank(supervisor) + for entry in supervisor.component_registry.get_hook_handlers(hook_name): + invocation_targets.append( + _HookInvocationTarget( + supervisor=supervisor, + entry=entry, + source_rank=source_rank, + ) + ) + + invocation_targets.sort(key=self._build_sort_key) + return invocation_targets + + @staticmethod + def _build_sort_key(target: _HookInvocationTarget) -> tuple[int, int, int, str, str]: + """构造 Hook 处理器的全局排序键。 + + Args: + target: 待排序的处理器目标。 + + Returns: + tuple[int, int, int, str, str]: 全局排序键。 + """ + + return ( + HookDispatcher._get_mode_rank(target.entry.mode), + HookDispatcher._get_order_rank(target.entry.order), + target.source_rank, + target.entry.plugin_id, + target.entry.name, + ) + + @staticmethod + def _get_default_timeout_ms() -> int: + """读取系统级默认 Hook 超时。 + + Returns: + int: 默认超时毫秒数。 + """ + + timeout_seconds = float(global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0) + return max(int(timeout_seconds * 1000), 1) + + @staticmethod + def _get_mode_rank(mode: str) -> int: + """返回 Hook 模式的排序权重。 + + Args: + mode: Hook 模式。 + + Returns: + int: 越小表示越靠前。 + """ + + return {"blocking": 0, "observe": 1}.get(mode, 99) + + @staticmethod + def _get_order_rank(order: str) -> int: + """返回 Hook 顺序槽位的排序权重。 + + Args: + order: Hook 顺序槽位。 + + Returns: + int: 越小表示越靠前。 + """ + + return {"early": 0, "normal": 1, "late": 2}.get(order, 99) + + @staticmethod + def _get_supervisor_source_rank(supervisor: "PluginRunnerSupervisor") -> int: + """返回 Supervisor 的来源排序权重。 + + Args: + supervisor: 目标 Supervisor。 + + Returns: + int: 内置插件返回 `0`,第三方插件返回 `1`。 + """ + + return 0 if supervisor.group_name == "builtin" else 1 + + @staticmethod + def _normalize_hook_name(hook_name: str) -> str: + """规范化命名 Hook 名称。 + + Args: + hook_name: 原始 Hook 名称。 + + Returns: + str: 规范化后的 Hook 名称。 + + Raises: + ValueError: 当 Hook 名称为空时抛出。 + """ + + normalized_name = str(hook_name or "").strip() + if not normalized_name: + raise ValueError("Hook 名称不能为空") + return normalized_name + + def _resolve_timeout_ms(self, hook_spec: HookSpec, target: _HookInvocationTarget) -> int: + """计算单个处理器的实际超时。 + + Args: + hook_spec: 当前 Hook 的规格定义。 + target: 当前执行目标。 + + Returns: + int: 最终生效的超时毫秒数。 + """ + + if target.entry.timeout_ms > 0: + return target.entry.timeout_ms + if hook_spec.default_timeout_ms > 0: + return hook_spec.default_timeout_ms + return self._get_default_timeout_ms() async def _invoke_handler( self, - supervisor: "PluginRunnerSupervisor", - handler_entry: "HookHandlerEntry", - args: Dict[str, Any], - ) -> Optional[HookResult]: - """调用单个 handler 并收集结果。 + hook_name: str, + hook_spec: HookSpec, + target: _HookInvocationTarget, + kwargs: Dict[str, Any], + ) -> HookHandlerExecutionResult: + """执行单个 Hook 处理器。 Args: - supervisor: PluginRunnerSupervisor 实例 - handler_entry: HookHandlerEntry 实例 - args: 传递给 handler 的参数字典 - stage: hook 名称 + hook_name: 当前 Hook 名称。 + hook_spec: 当前 Hook 规格。 + target: 当前执行目标。 + kwargs: 当前参数字典。 Returns: - Optional[HookResult]: 执行结果,如果执行失败则返回 None + HookHandlerExecutionResult: 处理器执行结果。 """ - try: - resp_envelope = await supervisor.invoke_plugin( - "plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args - ) - resp = resp_envelope.payload - result = HookResult( - handler_name=handler_entry.full_name, - success=resp.get("success", True), - continue_processing=resp.get("continue_processing", True), - modified_kwargs=resp.get("modified_kwargs"), - custom_result=resp.get("custom_result"), - ) - except Exception as e: - logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True) - result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True) - return result + timeout_ms = self._resolve_timeout_ms(hook_spec, target) + request_args: Dict[str, Any] = {"hook_name": hook_name, **dict(kwargs)} + + try: + response_envelope = await asyncio.wait_for( + target.supervisor.invoke_plugin( + "plugin.invoke_hook", + target.entry.plugin_id, + target.entry.name, + request_args, + timeout_ms=timeout_ms, + ), + timeout=max(timeout_ms / 1000.0, 0.001), + ) + except asyncio.TimeoutError: + error_message = ( + f"HookHandler {target.entry.full_name} 执行超时,已超过 {timeout_ms}ms" + ) + logger.error(error_message) + return HookHandlerExecutionResult( + handler_name=target.entry.full_name, + plugin_id=target.entry.plugin_id, + success=False, + error_message=error_message, + ) + except Exception as exc: + error_message = f"HookHandler {target.entry.full_name} 执行失败: {exc}" + logger.error(error_message, exc_info=True) + return HookHandlerExecutionResult( + handler_name=target.entry.full_name, + plugin_id=target.entry.plugin_id, + success=False, + error_message=error_message, + ) + + response_payload = response_envelope.payload + if not isinstance(response_payload, dict): + return HookHandlerExecutionResult( + handler_name=target.entry.full_name, + plugin_id=target.entry.plugin_id, + custom_result=response_payload, + ) + + return HookHandlerExecutionResult( + handler_name=target.entry.full_name, + plugin_id=target.entry.plugin_id, + success=bool(response_payload.get("success", True)), + action=self._normalize_action(response_payload.get("action", "continue")), + modified_kwargs=self._extract_modified_kwargs(response_payload.get("modified_kwargs")), + custom_result=response_payload.get("custom_result"), + error_message=str(response_payload.get("error_message", "") or ""), + ) + + @staticmethod + def _extract_modified_kwargs(raw_value: Any) -> Optional[Dict[str, Any]]: + """提取并校验处理器返回的 `modified_kwargs`。 + + Args: + raw_value: 原始返回值。 + + Returns: + Optional[Dict[str, Any]]: 合法时返回字典,否则返回 `None`。 + """ + + if raw_value is None: + return None + if isinstance(raw_value, dict): + return dict(raw_value) + logger.warning("HookHandler 返回的 modified_kwargs 不是字典,已忽略") + return None + + @staticmethod + def _normalize_action(raw_value: Any) -> str: + """规范化处理器动作返回值。 + + Args: + raw_value: 原始动作值。 + + Returns: + str: 规范化后的动作值,仅支持 `continue` 或 `abort`。 + """ + + normalized_value = str(raw_value or "").strip().lower() or "continue" + if normalized_value not in {"continue", "abort"}: + logger.warning(f"未知的 Hook action: {raw_value},已按 continue 处理") + return "continue" + return normalized_value + + def _merge_blocking_result( + self, + hook_spec: HookSpec, + target: _HookInvocationTarget, + execution_result: HookHandlerExecutionResult, + dispatch_result: HookDispatchResult, + ) -> None: + """合并阻塞处理器结果到聚合结果。 + + Args: + hook_spec: 当前 Hook 规格。 + target: 当前执行目标。 + execution_result: 当前处理器执行结果。 + dispatch_result: 当前聚合结果对象。 + """ + + if execution_result.custom_result is not None: + dispatch_result.custom_results.append(execution_result.custom_result) + + if not execution_result.success: + error_message = execution_result.error_message or f"HookHandler {target.entry.full_name} 执行失败" + dispatch_result.errors.append(error_message) + self._apply_error_policy(target, hook_spec, dispatch_result, error_message) + return + + if execution_result.modified_kwargs is not None: + if hook_spec.allow_kwargs_mutation: + dispatch_result.kwargs = dict(execution_result.modified_kwargs) + else: + error_message = ( + f"Hook {dispatch_result.hook_name} 不允许修改 kwargs," + f"已忽略 {target.entry.full_name} 的 modified_kwargs" + ) + logger.warning(error_message) + dispatch_result.errors.append(error_message) + + if execution_result.action == "abort": + if hook_spec.allow_abort: + dispatch_result.aborted = True + dispatch_result.stopped_by = target.entry.full_name + logger.info(f"HookHandler {target.entry.full_name} 中止了 Hook {dispatch_result.hook_name}") + else: + error_message = ( + f"Hook {dispatch_result.hook_name} 不允许 abort," + f"已忽略 {target.entry.full_name} 的 abort 请求" + ) + logger.warning(error_message) + dispatch_result.errors.append(error_message) + + def _apply_error_policy( + self, + target: _HookInvocationTarget, + hook_spec: HookSpec, + dispatch_result: HookDispatchResult, + error_message: str, + ) -> None: + """根据错误策略处理阻塞处理器失败。 + + Args: + target: 触发错误的处理器目标。 + hook_spec: 当前 Hook 规格。 + dispatch_result: 当前聚合结果对象。 + error_message: 需要记录的错误描述。 + """ + + if target.entry.error_policy != "abort": + return + if not hook_spec.allow_abort: + logger.warning( + "Hook %s 禁止 abort,已将 %s 的错误策略按 skip 处理", + dispatch_result.hook_name, + target.entry.full_name, + ) + return + + dispatch_result.aborted = True + dispatch_result.stopped_by = target.entry.full_name + logger.warning( + "HookHandler %s 因错误策略 abort 中止了 Hook %s: %s", + target.entry.full_name, + dispatch_result.hook_name, + error_message, + ) + + def _schedule_observe_handler( + self, + hook_name: str, + hook_spec: HookSpec, + target: _HookInvocationTarget, + kwargs: Dict[str, Any], + ) -> None: + """后台调度观察型处理器。 + + Args: + hook_name: 当前 Hook 名称。 + hook_spec: 当前 Hook 规格。 + target: 当前观察型处理器目标。 + kwargs: 调用参数快照。 + """ + + if not hook_spec.allow_observe: + logger.warning("Hook %s 不允许 observe 处理器,已跳过 %s", hook_name, target.entry.full_name) + return + + task = asyncio.create_task( + self._run_observe_handler( + hook_name=hook_name, + hook_spec=hook_spec, + target=target, + kwargs=dict(kwargs), + ) + ) + self._background_tasks.add(task) + task.add_done_callback(self._handle_background_task_done) + + async def _run_observe_handler( + self, + hook_name: str, + hook_spec: HookSpec, + target: _HookInvocationTarget, + kwargs: Dict[str, Any], + ) -> None: + """执行观察型处理器并吞掉控制流副作用。 + + Args: + hook_name: 当前 Hook 名称。 + hook_spec: 当前 Hook 规格。 + target: 当前观察型处理器目标。 + kwargs: 调用参数快照。 + """ + + execution_result = await self._invoke_handler( + hook_name=hook_name, + hook_spec=hook_spec, + target=target, + kwargs=kwargs, + ) + + if not execution_result.success: + logger.warning( + "观察型 HookHandler %s 执行失败: %s", + target.entry.full_name, + execution_result.error_message or "未知错误", + ) + return + + if execution_result.modified_kwargs is not None: + logger.warning( + "观察型 HookHandler %s 返回了 modified_kwargs,已忽略", target.entry.full_name + ) + if execution_result.action == "abort": + logger.warning( + "观察型 HookHandler %s 请求 abort,已忽略", target.entry.full_name + ) + + def _handle_background_task_done(self, task: asyncio.Task[Any]) -> None: + """处理观察任务完成回调。 + + Args: + task: 已完成的后台任务。 + """ + + self._background_tasks.discard(task) + with contextlib.suppress(asyncio.CancelledError): + exception = task.exception() + if exception is not None: + logger.error(f"观察型 Hook 后台任务执行失败: {exception}") diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 08638d16..c94fcb3f 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -49,7 +49,7 @@ from .api_registry import APIRegistry from .capability_service import CapabilityService from .component_registry import ComponentRegistry from .event_dispatcher import EventDispatcher -from .hook_dispatcher import HookDispatcher +from .hook_dispatcher import HookDispatchResult, HookDispatcher from .logger_bridge import RunnerLogBridge from .message_gateway import MessageGateway from .rpc_server import RPCServer @@ -80,6 +80,7 @@ class PluginRunnerSupervisor: def __init__( self, plugin_dirs: Optional[List[Path]] = None, + group_name: str = "third_party", socket_path: Optional[str] = None, health_check_interval_sec: Optional[float] = None, max_restart_attempts: Optional[int] = None, @@ -89,12 +90,14 @@ class PluginRunnerSupervisor: Args: plugin_dirs: 由当前 Runner 负责加载的插件目录列表。 + group_name: 当前 Supervisor 所属运行时分组名称。 socket_path: 自定义 IPC 地址;留空时由传输层自动生成。 health_check_interval_sec: 健康检查间隔,单位秒。 max_restart_attempts: 自动重启 Runner 的最大次数。 runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。 """ runtime_config = global_config.plugin_runtime + self._group_name: str = str(group_name or "third_party").strip() or "third_party" self._plugin_dirs: List[Path] = plugin_dirs or [] self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0 self._runner_spawn_timeout: float = ( @@ -108,7 +111,7 @@ class PluginRunnerSupervisor: self._api_registry = APIRegistry() self._component_registry = ComponentRegistry() self._event_dispatcher = EventDispatcher(self._component_registry) - self._hook_dispatcher = HookDispatcher(self._component_registry) + self._hook_dispatcher = HookDispatcher(lambda: [self]) self._message_gateway = MessageGateway(self._component_registry) self._log_bridge = RunnerLogBridge() @@ -133,6 +136,12 @@ class PluginRunnerSupervisor: """返回授权管理器。""" return self._authorization + @property + def group_name(self) -> str: + """返回当前 Supervisor 的运行时分组名称。""" + + return self._group_name + @property def capability_service(self) -> CapabilityService: """返回能力服务。""" @@ -243,17 +252,18 @@ class PluginRunnerSupervisor: """ return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args) - async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]: - """分发 Hook 到已注册的 Hook 处理器。 + async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult: + """在当前 Supervisor 内触发一次命名 Hook 调用。 Args: - stage: Hook 阶段名称。 - **kwargs: 传递给 Hook 的关键字参数。 + hook_name: 本次触发的 Hook 名称。 + **kwargs: 传递给 Hook 处理器的关键字参数。 Returns: - Dict[str, Any]: 经 Hook 修改后的参数字典。 + HookDispatchResult: 聚合后的 Hook 调用结果。 """ - return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs) + + return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs) async def send_message_to_external( self, diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index c34f5ef5..264c8ed2 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -3,8 +3,9 @@ 提供 PluginRuntimeManager 单例,负责: 1. 管理双 PluginSupervisor 的生命周期(内置插件 / 第三方插件各一个子进程) 2. 将 EventType 桥接到运行时的 event dispatch -3. 在运行时的 ComponentRegistry 中查找命令 -4. 提供统一的能力实现注册接口,使插件可以调用主程序功能 +3. 触发跨 Supervisor 的命名 Hook 调用 +4. 在运行时的 ComponentRegistry 中查找命令 +5. 提供统一的能力实现注册接口,使插件可以调用主程序功能 """ from pathlib import Path @@ -24,6 +25,7 @@ from src.plugin_runtime.capabilities import ( RuntimeDataCapabilityMixin, ) from src.plugin_runtime.capabilities.registry import register_capability_impls +from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher, HookSpec from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils from src.plugin_runtime.runner.manifest_validator import ManifestValidator @@ -72,6 +74,7 @@ class PluginRuntimeManager( self._manifest_validator: ManifestValidator = ManifestValidator() self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload self._config_reload_callback_registered: bool = False + self._hook_dispatcher: HookDispatcher = HookDispatcher(lambda: self.supervisors) async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None: """接收 Platform IO 审核后的入站消息并送入主消息链。 @@ -182,6 +185,7 @@ class PluginRuntimeManager( if builtin_dirs: self._builtin_supervisor = PluginSupervisor( plugin_dirs=builtin_dirs, + group_name="builtin", socket_path=builtin_socket, ) self._register_capability_impls(self._builtin_supervisor) @@ -189,6 +193,7 @@ class PluginRuntimeManager( if third_party_dirs: self._third_party_supervisor = PluginSupervisor( plugin_dirs=third_party_dirs, + group_name="third_party", socket_path=third_party_socket, ) self._register_capability_impls(self._third_party_supervisor) @@ -235,6 +240,7 @@ class PluginRuntimeManager( await platform_io_manager.stop() except Exception as platform_io_exc: logger.warning(f"Platform IO 停止失败: {platform_io_exc}") + await self._hook_dispatcher.stop() self._started = False self._builtin_supervisor = None self._third_party_supervisor = None @@ -274,6 +280,7 @@ class PluginRuntimeManager( else: logger.info("插件运行时已停止") finally: + await self._hook_dispatcher.stop() self._started = False self._builtin_supervisor = None self._third_party_supervisor = None @@ -284,11 +291,41 @@ class PluginRuntimeManager( """返回插件运行时是否处于启动状态。""" return self._started + @property + def hook_dispatcher(self) -> HookDispatcher: + """返回跨 Supervisor 的命名 Hook 分发器。""" + + return self._hook_dispatcher + + @property + def invoke_dispatcher(self) -> HookDispatcher: + """返回命名 Hook 分发器的兼容别名。""" + + return self._hook_dispatcher + @property def supervisors(self) -> List["PluginSupervisor"]: """获取所有活跃的 Supervisor""" return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None] + def register_hook_spec(self, spec: HookSpec) -> None: + """注册单个命名 Hook 规格。 + + Args: + spec: 需要注册的 Hook 规格。 + """ + + self._hook_dispatcher.register_hook_spec(spec) + + def register_hook_specs(self, specs: Sequence[HookSpec]) -> None: + """批量注册命名 Hook 规格。 + + Args: + specs: 需要注册的 Hook 规格序列。 + """ + + self._hook_dispatcher.register_hook_specs(specs) + def _build_registered_dependency_map(self) -> Dict[str, Set[str]]: """根据当前已注册插件构建全局依赖图。""" @@ -588,6 +625,19 @@ class PluginRuntimeManager( return True, modified + async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult: + """触发一次跨 Supervisor 的命名 Hook 调用。 + + Args: + hook_name: 本次触发的 Hook 名称。 + **kwargs: 传递给 Hook 处理器的关键字参数。 + + Returns: + HookDispatchResult: 聚合后的 Hook 调用结果。 + """ + + return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs) + # ─── 命令查找 ────────────────────────────────────────────── def find_command_by_text(self, text: str) -> Optional[Dict[str, Any]]: diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index d1ebc064..9de5d977 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -330,7 +330,6 @@ class PluginRunner: self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) - self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step) self._rpc_client.register_method("plugin.health", self._handle_health) self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown) self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) @@ -1053,73 +1052,28 @@ class PluginRunner: ) except Exception as exc: logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True) - return envelope.make_response(payload={"success": False, "continue_processing": True}) + return envelope.make_response( + payload={ + "success": False, + "action": "continue", + "error_message": str(exc), + } + ) if raw is None: - result = {"success": True, "continue_processing": True} + result = {"success": True, "action": "continue"} elif isinstance(raw, dict): result = { "success": True, - "continue_processing": raw.get("continue_processing", True), + "action": str(raw.get("action", "continue") or "continue").strip().lower() or "continue", "modified_kwargs": raw.get("modified_kwargs"), "custom_result": raw.get("custom_result"), } else: - result = {"success": True, "continue_processing": True, "custom_result": raw} + result = {"success": True, "action": "continue", "custom_result": raw} return envelope.make_response(payload=result) - async def _handle_workflow_step(self, envelope: Envelope) -> Envelope: - """处理 WorkflowStep 调用请求 - - 与通用 invoke 不同,会将返回值规范化为 - {hook_result, modified_message, stage_output} 格式。 - """ - try: - invoke = InvokePayload.model_validate(envelope.payload) - except Exception as e: - return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) - - plugin_id = envelope.plugin_id - meta = self._loader.get_plugin(plugin_id) - if meta is None: - return envelope.make_error_response( - ErrorCode.E_PLUGIN_NOT_FOUND.value, - f"插件 {plugin_id} 未加载", - ) - - component_name = invoke.component_name - handler_method = self._resolve_component_handler(meta, component_name) - - if handler_method is None or not callable(handler_method): - return envelope.make_error_response( - ErrorCode.E_METHOD_NOT_ALLOWED.value, - f"插件 {plugin_id} 无组件: {component_name}", - ) - - try: - raw = ( - await handler_method(**invoke.args) - if inspect.iscoroutinefunction(handler_method) - else handler_method(**invoke.args) - ) - - # 规范化返回值 - if isinstance(raw, str): - result = {"hook_result": raw} - elif isinstance(raw, dict): - result = raw - result.setdefault("hook_result", "continue") - else: - result = {"hook_result": "continue"} - - resp_payload = InvokeResultPayload(success=True, result=result) - return envelope.make_response(payload=resp_payload.model_dump()) - except Exception as e: - logger.error(f"插件 {plugin_id} workflow_step {component_name} 执行异常: {e}", exc_info=True) - resp_payload = InvokeResultPayload(success=False, result=str(e)) - return envelope.make_response(payload=resp_payload.model_dump()) - async def _handle_health(self, envelope: Envelope) -> Envelope: """处理健康检查""" uptime_ms = int((time.monotonic() - self._start_time) * 1000)