feat: Enhance Hook System with HookHandler and Dispatcher

- Introduced HookHandlerEntry to manage hook processing with attributes like hook name, mode, order, timeout, and error policy.
- Implemented normalization methods for hook attributes to ensure valid configurations.
- Updated ComponentRegistry to support retrieval of hook handlers based on hook names, with sorting by mode and order.
- Refactored HookDispatcher to handle invocation of hooks, separating blocking and non-blocking handlers, and managing execution results.
- Added support for registering hook specifications and invoking hooks across supervisors in PluginRuntimeManager.
- Removed deprecated workflow step handling from PluginRunner, streamlining hook invocation responses.
This commit is contained in:
DrSmoothl
2026-03-24 19:04:05 +08:00
parent 865e4916e3
commit 0b0f47a444
6 changed files with 1247 additions and 523 deletions

View File

@@ -5,6 +5,7 @@
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Awaitable, Callable, Dict, List, Optional
import asyncio
import json
@@ -1831,395 +1832,445 @@ class TestMaiMessages:
assert msg.llm_response_content == "new response"
# ─── WorkflowExecutor 测试 ────────────────────────────────
class _FakeHookSupervisor:
"""用于 Hook 分发测试的简化 Supervisor。"""
def __init__(
self,
group_name: str,
component_registry: Any,
handlers: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]] | Dict[str, Any]]],
call_log: List[tuple[str, str]],
) -> None:
"""初始化测试用 Supervisor。
Args:
group_name: 运行时分组名称。
component_registry: 组件注册表实例。
handlers: 处理器映射,键为 `plugin_id.component_name`。
call_log: 记录调用顺序的列表。
"""
self._group_name = group_name
self.component_registry = component_registry
self._handlers = handlers
self._call_log = call_log
@property
def group_name(self) -> str:
"""返回当前测试 Supervisor 的分组名称。"""
return self._group_name
async def invoke_plugin(
self,
method: str,
plugin_id: str,
component_name: str,
args: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> SimpleNamespace:
"""模拟调用插件组件。
Args:
method: RPC 方法名。
plugin_id: 目标插件 ID。
component_name: 目标组件名称。
args: 调用参数。
timeout_ms: 超时配置,测试中仅用于保持接口一致。
Returns:
SimpleNamespace: 仅包含 `payload` 字段的简化响应对象。
"""
del method
del timeout_ms
full_name = f"{plugin_id}.{component_name}"
handler = self._handlers[full_name]
self._call_log.append((plugin_id, component_name))
result = handler(dict(args or {}))
if asyncio.iscoroutine(result):
result = await result
return SimpleNamespace(payload=result)
class TestWorkflowExecutor:
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
# ─── HookDispatcher 测试 ────────────────────────────────
class TestHookDispatcher:
"""命名 Hook 分发器测试。"""
@staticmethod
def _import_dispatcher_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
"""导入 Hook 分发相关模块,并屏蔽配置初始化触发的退出。
Args:
monkeypatch: pytest 的 monkeypatch 工具。
Returns:
tuple[Any, Any]: `ComponentRegistry` 与 `HookDispatcher` 类型。
"""
monkeypatch.setattr(sys, "exit", lambda code=0: None)
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.hook_dispatcher import HookDispatcher
return ComponentRegistry, HookDispatcher
@pytest.mark.asyncio
async def test_empty_pipeline_completes(self):
"""无任何 workflow_step 注册时pipeline 全阶段跳过,状态 completed"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_empty_hook_returns_original_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""未注册处理器时应直接返回原始参数。"""
reg = ComponentRegistry()
executor = WorkflowExecutor(reg)
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
async def mock_invoke(plugin_id, comp_name, args):
return {"hook_result": "continue"}
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor("builtin", ComponentRegistry(), {}, [])
result, final_msg, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "test"},
)
assert result.status == "completed"
assert result.return_message == "workflow completed"
assert len(ctx.timings) == 6 # 6 stages
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
assert result.hook_name == "heart_fc.cycle_start"
assert result.kwargs == {"session_id": "s-1"}
assert result.aborted is False
@pytest.mark.asyncio
async def test_blocking_hook_modifies_message(self):
"""blocking hook 可以修改消息"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_blocking_hook_modifies_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""blocking 处理器可以修改参数。"""
reg = ComponentRegistry()
reg.register_component(
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
registry = ComponentRegistry()
registry.register_component(
"upper",
"workflow_step",
"HOOK_HANDLER",
"p1",
{
"stage": "pre_process",
"priority": 10,
"blocking": True,
"hook": "heart_fc.cycle_start",
"mode": "blocking",
"order": "normal",
},
)
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
msg = args.get("message", {})
return {
"hook_result": "continue",
"modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()},
}
result, final_msg, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "hello"},
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor(
"builtin",
registry,
{
"p1.upper": lambda args: {
"success": True,
"action": "continue",
"modified_kwargs": {
"session_id": args["session_id"],
"text": str(args["text"]).upper(),
},
}
},
[],
)
assert result.status == "completed"
assert final_msg["plain_text"] == "HELLO"
assert len(ctx.modification_log) == 1
assert ctx.modification_log[0].stage == "pre_process"
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1", text="hello")
assert result.kwargs["session_id"] == "s-1"
assert result.kwargs["text"] == "HELLO"
assert result.aborted is False
@pytest.mark.asyncio
async def test_abort_stops_pipeline(self):
"""HookResult.ABORT 立即终止 pipeline"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_abort_stops_following_blocking_handlers(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""blocking 处理器的 abort 应阻止后续 blocking 处理器执行。"""
reg = ComponentRegistry()
reg.register_component(
"blocker",
"workflow_step",
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
registry = ComponentRegistry()
registry.register_component(
"stopper",
"HOOK_HANDLER",
"p1",
{
"stage": "pre_process",
"priority": 10,
"blocking": True,
},
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
return {"hook_result": "abort"}
result, _, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "test"},
)
assert result.status == "aborted"
assert result.stopped_at == "pre_process"
@pytest.mark.asyncio
async def test_skip_stage(self):
"""HookResult.SKIP_STAGE 跳过当前阶段剩余 hook"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
# high-priority hook 返回 skip_stage
reg.register_component(
"skipper",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 100,
"blocking": True,
},
)
# low-priority hook 不应被执行
reg.register_component(
"checker",
"workflow_step",
registry.register_component(
"after_stop",
"HOOK_HANDLER",
"p2",
{
"stage": "ingress",
"priority": 1,
"blocking": True,
},
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
)
call_log: List[tuple[str, str]] = []
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor(
"builtin",
registry,
{
"p1.stopper": lambda args: {"success": True, "action": "abort"},
"p2.after_stop": lambda args: {"success": True, "action": "continue"},
},
call_log,
)
executor = WorkflowExecutor(reg)
call_log = []
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], cycle_id="c-1")
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
if comp_name == "skipper":
return {"hook_result": "skip_stage"}
return {"hook_result": "continue"}
result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
assert result.status == "completed"
# 只有 skipper 被调用checker 被跳过
assert call_log == ["skipper"]
assert result.aborted is True
assert result.stopped_by == "p1.stopper"
assert call_log == [("p1", "stopper")]
@pytest.mark.asyncio
async def test_pre_filter(self):
"""filter 条件不匹配时跳过 hook"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_observe_handler_runs_in_background_without_mutation(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""observe 处理器应后台执行且不能影响主流程参数。"""
reg = ComponentRegistry()
reg.register_component(
"only_dm",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 10,
"blocking": True,
"filter": {"chat_type": "direct"},
},
)
executor = WorkflowExecutor(reg)
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
return {"hook_result": "continue"}
# 不匹配 filter —— hook 不应被调用
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
assert not call_log
# 匹配 filter —— hook 应被调用
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
assert call_log == ["only_dm"]
@pytest.mark.asyncio
async def test_error_policy_skip(self):
"""error_policy=skip 时跳过失败的 hook 继续执行"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component(
"failer",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 100,
"blocking": True,
"error_policy": "skip",
},
)
reg.register_component(
"ok_step",
"workflow_step",
"p2",
{
"stage": "ingress",
"priority": 1,
"blocking": True,
},
)
executor = WorkflowExecutor(reg)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
if comp_name == "failer":
raise RuntimeError("boom")
return {"hook_result": "continue"}
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
assert result.status == "completed"
assert "failer" in call_log
assert "ok_step" in call_log
assert any("boom" in e for e in ctx.errors)
@pytest.mark.asyncio
async def test_error_policy_abort(self):
"""error_policy=abort默认时 pipeline 失败"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component(
"failer",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 10,
"blocking": True,
# error_policy defaults to "abort"
},
)
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
raise RuntimeError("fatal")
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
assert result.status == "failed"
assert result.stopped_at == "ingress"
@pytest.mark.asyncio
async def test_nonblocking_hooks_concurrent(self):
"""non-blocking hook 并发执行,不修改消息"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
for i in range(3):
reg.register_component(
f"nb_{i}",
"workflow_step",
f"p{i}",
{
"stage": "post_process",
"priority": 0,
"blocking": False,
},
)
executor = WorkflowExecutor(reg)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
# non-blocking 的 modified_message 被忽略
assert final_msg["plain_text"] == "original"
# 给异步 task 时间完成
await asyncio.sleep(0.1)
assert result.status == "completed"
@pytest.mark.asyncio
async def test_nonblocking_tasks_are_retained_until_completion(self):
"""execute 返回后non-blocking task 仍应保持强引用直到执行完成。"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component(
registry = ComponentRegistry()
registry.register_component(
"observer",
"workflow_step",
"HOOK_HANDLER",
"p1",
{
"stage": "post_process",
"priority": 0,
"blocking": False,
},
{"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
)
executor = WorkflowExecutor(reg)
started = asyncio.Event()
release = asyncio.Event()
call_log: List[tuple[str, str]] = []
async def observe_handler(args: Dict[str, Any]) -> Dict[str, Any]:
"""模拟耗时观察型处理器。"""
async def mock_invoke(plugin_id, comp_name, args):
started.set()
await release.wait()
return {"hook_result": "continue"}
return {
"success": True,
"action": "abort",
"modified_kwargs": {"session_id": "changed"},
"custom_result": args["session_id"],
}
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor(
"builtin",
registry,
{"p1.observer": observe_handler},
call_log,
)
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
await asyncio.sleep(0)
assert result.status == "completed"
assert final_msg["plain_text"] == "original"
assert result.aborted is False
assert result.kwargs["session_id"] == "s-1"
assert started.is_set()
assert len(executor._background_tasks) == 1
assert len(dispatcher._background_tasks) == 1
release.set()
await asyncio.sleep(0)
await asyncio.sleep(0)
assert not executor._background_tasks
assert call_log == [("p1", "observer")]
assert not dispatcher._background_tasks
@pytest.mark.asyncio
async def test_command_routing(self):
"""PLAN 阶段内置命令路由"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_global_order_prefers_order_slot_then_source(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""全局排序应先看 order再看内置/第三方来源。"""
reg = ComponentRegistry()
reg.register_component(
"help",
"command",
"p1",
{
"command_pattern": r"^/help",
},
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
builtin_registry = ComponentRegistry()
third_registry = ComponentRegistry()
builtin_registry.register_component(
"builtin_early",
"HOOK_HANDLER",
"b1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
builtin_registry.register_component(
"builtin_normal",
"HOOK_HANDLER",
"b1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
)
third_registry.register_component(
"third_early",
"HOOK_HANDLER",
"t1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
third_registry.register_component(
"third_normal",
"HOOK_HANDLER",
"t1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
)
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
if comp_name == "help":
return {"output": "帮助信息"}
return {"hook_result": "continue"}
call_log: List[tuple[str, str]] = []
dispatcher = HookDispatcher()
builtin_supervisor = _FakeHookSupervisor(
"builtin",
builtin_registry,
{
"b1.builtin_early": lambda args: {"success": True, "action": "continue"},
"b1.builtin_normal": lambda args: {"success": True, "action": "continue"},
},
call_log,
)
third_supervisor = _FakeHookSupervisor(
"third_party",
third_registry,
{
"t1.third_early": lambda args: {"success": True, "action": "continue"},
"t1.third_normal": lambda args: {"success": True, "action": "continue"},
},
call_log,
)
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"})
assert result.status == "completed"
assert ctx.matched_command == "p1.help"
cmd_result = ctx.get_stage_output("plan", "command_result")
assert cmd_result is not None
assert cmd_result["output"] == "帮助信息"
await dispatcher.invoke_hook(
"heart_fc.cycle_start",
[third_supervisor, builtin_supervisor],
cycle_id="c-1",
)
assert call_log == [
("b1", "builtin_early"),
("t1", "third_early"),
("b1", "builtin_normal"),
("t1", "third_normal"),
]
@pytest.mark.asyncio
async def test_stage_outputs(self):
"""stage_outputs 数据在阶段间传递"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
async def test_error_policy_abort_stops_dispatch(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""error_policy=abort 时应中止本次 Hook 调用。"""
reg = ComponentRegistry()
# ingress 阶段写入数据
reg.register_component(
"writer",
"workflow_step",
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
registry = ComponentRegistry()
registry.register_component(
"failer",
"HOOK_HANDLER",
"p1",
{
"stage": "ingress",
"priority": 10,
"blocking": True,
"hook": "heart_fc.cycle_start",
"mode": "blocking",
"order": "normal",
"error_policy": "abort",
},
)
# pre_process 阶段读取数据
reg.register_component(
"reader",
"workflow_step",
"p2",
call_log: List[tuple[str, str]] = []
async def fail_handler(args: Dict[str, Any]) -> Dict[str, Any]:
"""抛出异常以触发 abort 策略。"""
del args
raise RuntimeError("boom")
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.failer": fail_handler}, call_log)
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
assert result.aborted is True
assert result.stopped_by == "p1.failer"
assert any("boom" in error for error in result.errors)
assert call_log == [("p1", "failer")]
@pytest.mark.asyncio
async def test_timeout_respects_handler_timeout_ms(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""处理器超时应被记录为错误并继续。"""
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
registry = ComponentRegistry()
registry.register_component(
"slow",
"HOOK_HANDLER",
"p1",
{
"stage": "pre_process",
"priority": 10,
"blocking": True,
"hook": "heart_fc.cycle_start",
"mode": "blocking",
"order": "normal",
"timeout_ms": 10,
},
)
executor = WorkflowExecutor(reg)
call_log: List[tuple[str, str]] = []
async def mock_invoke(plugin_id, comp_name, args):
if comp_name == "writer":
return {
"hook_result": "continue",
"stage_output": {"parsed_intent": "greeting"},
}
if comp_name == "reader":
# 验证 stage_outputs 被传递过来
outputs = args.get("stage_outputs", {})
ingress_data = outputs.get("ingress", {})
assert ingress_data.get("parsed_intent") == "greeting"
return {"hook_result": "continue"}
return {"hook_result": "continue"}
async def slow_handler(args: Dict[str, Any]) -> Dict[str, Any]:
"""模拟超时处理器。"""
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"})
assert result.status == "completed"
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting"
del args
await asyncio.sleep(0.05)
return {"success": True, "action": "continue"}
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.slow": slow_handler}, call_log)
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
assert result.aborted is False
assert any("超时" in error for error in result.errors)
assert call_log == [("p1", "slow")]
class TestPluginRuntimeHookEntry:
"""PluginRuntimeManager 命名 Hook 入口测试。"""
@staticmethod
def _import_manager_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
"""导入运行时管理器相关模块,并屏蔽配置初始化触发的退出。
Args:
monkeypatch: pytest 的 monkeypatch 工具。
Returns:
tuple[Any, Any]: `ComponentRegistry` 与 `PluginRuntimeManager` 类型。
"""
monkeypatch.setattr(sys, "exit", lambda code=0: None)
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.integration import PluginRuntimeManager
return ComponentRegistry, PluginRuntimeManager
@pytest.mark.asyncio
async def test_manager_invoke_hook_dispatches_across_supervisors(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""PluginRuntimeManager.invoke_hook() 应调用全局 Hook 分发器。"""
ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch)
builtin_registry = ComponentRegistry()
builtin_registry.register_component(
"builtin_guard",
"HOOK_HANDLER",
"b1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
third_registry = ComponentRegistry()
third_registry.register_component(
"observer",
"HOOK_HANDLER",
"t1",
{"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
)
call_log: List[tuple[str, str]] = []
manager = PluginRuntimeManager()
manager._started = True
manager._builtin_supervisor = _FakeHookSupervisor(
"builtin",
builtin_registry,
{"b1.builtin_guard": lambda args: {"success": True, "action": "continue"}},
call_log,
)
manager._third_party_supervisor = _FakeHookSupervisor(
"third_party",
third_registry,
{"t1.observer": lambda args: {"success": True, "action": "continue"}},
call_log,
)
result = await manager.invoke_dispatcher.invoke_hook("heart_fc.cycle_start", session_id="s-1")
await asyncio.sleep(0)
assert manager.invoke_dispatcher is manager.hook_dispatcher
assert result.aborted is False
assert result.kwargs["session_id"] == "s-1"
assert ("b1", "builtin_guard") in call_log
class TestRPCServer:

View File

@@ -1,7 +1,7 @@
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
- 按类型注册组件action / command / tool / event_handler / workflow_handler / message_gateway
- 按类型注册组件action / command / tool / event_handler / hook_handler / message_gateway
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
@@ -106,14 +106,129 @@ class EventHandlerEntry(ComponentEntry):
class HookHandlerEntry(ComponentEntry):
"""WorkflowHandler 组件条目"""
"""HookHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.stage: str = metadata.get("stage", "")
self.priority: int = metadata.get("priority", 0)
self.blocking: bool = metadata.get("blocking", False)
self.hook: str = self._normalize_hook_name(metadata.get("hook", ""))
self.mode: str = self._normalize_mode(metadata.get("mode", "blocking"))
self.order: str = self._normalize_order(metadata.get("order", "normal"))
self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0))
self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip"))
super().__init__(name, component_type, plugin_id, metadata)
@staticmethod
def _normalize_error_policy(raw_value: Any) -> str:
"""规范化 Hook 异常处理策略。
Args:
raw_value: 原始异常处理策略值。
Returns:
str: 规范化后的异常处理策略。
Raises:
ValueError: 当异常处理策略不受支持时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
normalized_value = str(normalized_source or "").strip().lower() or "skip"
if normalized_value not in {"abort", "skip", "log"}:
raise ValueError(f"HookHandler 异常处理策略不合法: {raw_value}")
return normalized_value
@staticmethod
def _normalize_hook_name(raw_value: Any) -> str:
"""规范化命名 Hook 名称。
Args:
raw_value: 原始 Hook 名称。
Returns:
str: 去空白后的 Hook 名称。
Raises:
ValueError: 当 Hook 名称为空时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
if not (normalized_value := str(normalized_source or "").strip()):
raise ValueError("HookHandler 的 hook 名称不能为空")
return normalized_value
@staticmethod
def _normalize_mode(raw_value: Any) -> str:
"""规范化 Hook 处理模式。
Args:
raw_value: 原始模式值。
Returns:
str: 规范化后的模式。
Raises:
ValueError: 当模式不受支持时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
normalized_value = str(normalized_source or "").strip().lower() or "blocking"
if normalized_value not in {"blocking", "observe"}:
raise ValueError(f"HookHandler 模式不合法: {raw_value}")
return normalized_value
@staticmethod
def _normalize_order(raw_value: Any) -> str:
"""规范化 Hook 顺序槽位。
Args:
raw_value: 原始顺序值。
Returns:
str: 规范化后的顺序槽位。
Raises:
ValueError: 当顺序值不受支持时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
normalized_value = str(normalized_source or "").strip().lower() or "normal"
if normalized_value not in {"early", "normal", "late"}:
raise ValueError(f"HookHandler 顺序槽位不合法: {raw_value}")
return normalized_value
@staticmethod
def _normalize_timeout_ms(raw_value: Any) -> int:
"""规范化 Hook 超时配置。
Args:
raw_value: 原始超时值。
Returns:
int: 规范化后的超时毫秒数。
Raises:
ValueError: 当超时值为负数或无法转换为整数时抛出。
"""
try:
timeout_ms = int(raw_value or 0)
except (TypeError, ValueError) as exc:
raise ValueError(f"HookHandler 超时配置不合法: {raw_value}") from exc
if timeout_ms < 0:
raise ValueError(f"HookHandler 超时配置不能为负数: {raw_value}")
return timeout_ms
@property
def is_blocking(self) -> bool:
"""返回当前 Hook 是否为阻塞模式。"""
return self.mode == "blocking"
@property
def is_observe(self) -> bool:
"""返回当前 Hook 是否为观察模式。"""
return self.mode == "observe"
class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目"""
@@ -454,16 +569,17 @@ class ComponentRegistry:
return handlers
def get_hook_handlers(
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
self, hook_name: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[HookHandlerEntry]:
"""获取特定 hook 阶段的所有步骤,按 priority 降序
"""获取订阅指定命名 Hook 的全部处理器
Args:
stage: hook 名称
enabled_only: 是否仅返回启用的组件
session_id: 可选的会话ID若提供则考虑会话禁用状态
hook_name: 目标 Hook 名称
enabled_only: 是否仅返回启用的组件
session_id: 可选的会话 ID若提供则考虑会话禁用状态
Returns:
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
List[HookHandlerEntry]: 符合条件的 HookHandler 组件列表
"""
handlers: List[HookHandlerEntry] = []
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
@@ -471,11 +587,37 @@ class ComponentRegistry:
continue
if not isinstance(comp, HookHandlerEntry):
continue
if comp.stage == stage:
if comp.hook == hook_name:
handlers.append(comp)
handlers.sort(key=lambda c: c.priority, reverse=True)
handlers.sort(key=lambda comp: (self._get_hook_mode_rank(comp.mode), self._get_hook_order_rank(comp.order), comp.plugin_id, comp.name))
return handlers
@staticmethod
def _get_hook_mode_rank(mode: str) -> int:
"""返回 Hook 模式的排序权重。
Args:
mode: Hook 模式字符串。
Returns:
int: 越小表示越靠前。
"""
return {"blocking": 0, "observe": 1}.get(mode, 99)
@staticmethod
def _get_hook_order_rank(order: str) -> int:
"""返回 Hook 顺序槽位的排序权重。
Args:
order: Hook 顺序槽位字符串。
Returns:
int: 越小表示越靠前。
"""
return {"early": 0, "normal": 1, "late": 2}.get(order, 99)
def get_message_gateway(
self,
plugin_id: str,
@@ -566,8 +708,13 @@ class ComponentRegistry:
Returns:
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
"""
stats: StatusDict = {"total": len(self._components)} # type: ignore
for comp_type, type_dict in self._by_type.items():
stats[comp_type.value.lower()] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats
return StatusDict(
total=len(self._components),
action=len(self._by_type[ComponentTypes.ACTION]),
command=len(self._by_type[ComponentTypes.COMMAND]),
tool=len(self._by_type[ComponentTypes.TOOL]),
event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]),
hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]),
message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]),
plugins=len(self._by_plugin),
)

View File

@@ -1,166 +1,678 @@
"""
Hook Dispatch 系统
"""命名 Hook 分发系统。
插件可以注册自己的Hook当特定函数被调用时Hook Dispatch系统会将调用转发给插件的Hook处理函数。
每个Hook的参数随Hook点位确定因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数
在参数/返回值匹配的情况下允许修改参数/返回值。
主程序可以在任意执行点触发一个命名 HookHost 会收集所有订阅该 Hook 的
插件处理器,并按照固定的全局顺序调度执行
HookDispatcher 负责
1. 按 stage 查询已注册的 hook_handler通过 ComponentRegistry
2. 按 priority 排序,区分 blocking 和非 blocking 模式
3. blocking 模式:依次同步调用,支持修改参数/提前终止
4. 非 blocking 模式:异步调用,不阻塞主流程
5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限
排序规则如下
1. `blocking` 先于 `observe`
2. `early` 先于 `normal` 先于 `late`
3. 内置插件先于第三方插件
4. `plugin_id`
5. `handler_name`
其中:
- `blocking` 处理器串行执行,可修改 `kwargs`,也可中止本次 Hook 调用。
- `observe` 处理器后台并发执行,只允许旁路观察,不参与主流程控制。
"""
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
import contextlib
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from .component_registry import HookHandlerEntry
from .supervisor import PluginRunnerSupervisor
from .component_registry import ComponentRegistry, HookHandlerEntry
logger = get_logger("plugin_runtime.host.hook_dispatcher")
@dataclass
class HookResult:
"""单个 HookHandler 的执行结果"""
@dataclass(slots=True)
class HookSpec:
"""命名 Hook 的静态规格定义。
Attributes:
name: Hook 的唯一名称。
description: Hook 描述。
default_timeout_ms: 默认超时毫秒数;为 `0` 时退回系统默认值。
allow_blocking: 是否允许注册阻塞处理器。
allow_observe: 是否允许注册观察处理器。
allow_abort: 是否允许处理器中止当前 Hook 调用。
allow_kwargs_mutation: 是否允许阻塞处理器修改 `kwargs`。
"""
name: str
description: str = ""
default_timeout_ms: int = 0
allow_blocking: bool = True
allow_observe: bool = True
allow_abort: bool = True
allow_kwargs_mutation: bool = True
@dataclass(slots=True)
class HookHandlerExecutionResult:
"""单个 HookHandler 的执行结果。
Attributes:
handler_name: 完整处理器名称,格式通常为 `plugin_id.component_name`。
plugin_id: 处理器所属插件 ID。
success: 本次调用是否成功。
action: 当前处理器要求的控制动作,仅支持 `continue` 或 `abort`。
modified_kwargs: 处理器返回的修改后参数字典。
custom_result: 处理器返回的附加结果。
error_message: 失败时的错误描述。
"""
handler_name: str
success: bool = field(default=True)
continue_processing: bool = field(default=True)
modified_kwargs: Optional[Dict[str, Any]] = field(default=None)
custom_result: Any = field(default=None)
plugin_id: str
success: bool = True
action: str = "continue"
modified_kwargs: Optional[Dict[str, Any]] = None
custom_result: Any = None
error_message: str = ""
@dataclass(slots=True)
class HookDispatchResult:
"""一次命名 Hook 调用的聚合结果。
Attributes:
hook_name: 本次调用的 Hook 名称。
kwargs: 经阻塞处理器串行处理后的最终参数字典。
aborted: 是否被某个处理器中止。
stopped_by: 若被中止,记录触发中止的完整处理器名称。
custom_results: 阻塞处理器返回的附加结果列表。
errors: 本次调用中记录到的错误信息列表。
"""
hook_name: str
kwargs: Dict[str, Any] = field(default_factory=dict)
aborted: bool = False
stopped_by: Optional[str] = None
custom_results: List[Any] = field(default_factory=list)
errors: List[str] = field(default_factory=list)
@dataclass(slots=True)
class _HookInvocationTarget:
"""内部使用的 Hook 调度目标。
Attributes:
supervisor: 负责该处理器的 Supervisor。
entry: Hook 处理器条目。
source_rank: 插件来源权重,内置插件为 `0`,第三方插件为 `1`。
"""
supervisor: "PluginRunnerSupervisor"
entry: "HookHandlerEntry"
source_rank: int
class HookDispatcher:
"""Host-side Hook 分发器
"""命名 Hook 分发器"""
由业务层调用 hook_dispatch()
内部通过 ComponentRegistry 查询 handler
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
def __init__(self, component_registry: "ComponentRegistry") -> None:
"""初始化 HookDispatcher
def __init__(
self,
supervisors_provider: Optional[Callable[[], Sequence["PluginRunnerSupervisor"]]] = None,
) -> None:
"""初始化 Hook 分发器。
Args:
component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler
supervisors_provider: 可选的 Supervisor 提供器。若调用 `invoke_hook()`
时未显式传入 `supervisors`,则使用该回调获取目标 Supervisor 列表。
"""
self._component_registry: "ComponentRegistry" = component_registry
self._background_tasks: Set[asyncio.Task] = set()
self._background_tasks: Set[asyncio.Task[Any]] = set()
self._hook_specs: Dict[str, HookSpec] = {}
self._supervisors_provider = supervisors_provider
async def stop(self) -> None:
"""停止 HookDispatcher取消所有未完成的后台任务"""
"""停止分发器并取消所有未完成的观察任务"""
for task in self._background_tasks:
task.cancel()
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
async def hook_dispatch(
self,
stage: str,
supervisor: "PluginRunnerSupervisor",
**kwargs: Any,
) -> Dict[str, Any]:
"""分发 hook 到所有对应 handler 的便捷方法。
内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
无需调用方手动构造 invoke_fn 闭包。
def register_hook_spec(self, spec: HookSpec) -> None:
"""注册单个命名 Hook 规格。
Args:
stage: hook 名称
supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin
**kwargs: 关键字参数,会展开传递给 handler
spec: 需要注册的 Hook 规格。
"""
normalized_name = self._normalize_hook_name(spec.name)
self._hook_specs[normalized_name] = HookSpec(
name=normalized_name,
description=spec.description,
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
allow_blocking=bool(spec.allow_blocking),
allow_observe=bool(spec.allow_observe),
allow_abort=bool(spec.allow_abort),
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
)
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
"""批量注册命名 Hook 规格。
Args:
specs: 需要注册的 Hook 规格序列。
"""
for spec in specs:
self.register_hook_spec(spec)
def get_hook_spec(self, hook_name: str) -> HookSpec:
"""获取指定 Hook 的规格定义。
Args:
hook_name: Hook 名称。
Returns:
modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数
HookSpec: 若未显式注册,则返回按系统默认值生成的运行时规格。
"""
handler_entries = self._component_registry.get_hook_handlers(stage)
if not handler_entries:
return kwargs
current_kwargs = kwargs.copy()
blocking_handlers: List["HookHandlerEntry"] = []
non_blocking_handlers: List["HookHandlerEntry"] = []
normalized_name = self._normalize_hook_name(hook_name)
if normalized_name in self._hook_specs:
return self._hook_specs[normalized_name]
# 分离 blocking 和非 blocking handler
for entry in handler_entries:
if entry.blocking:
blocking_handlers.append(entry)
else:
non_blocking_handlers.append(entry)
return HookSpec(
name=normalized_name,
default_timeout_ms=self._get_default_timeout_ms(),
)
# 处理 blocking handlers同步调用支持修改参数/提前终止)
timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0
for entry in blocking_handlers:
hook_args = {"stage": stage, **current_kwargs}
try:
# 应用超时控制
result = await asyncio.wait_for(
self._invoke_handler(supervisor, entry, hook_args),
timeout=timeout,
async def invoke_hook(
self,
hook_name: str,
supervisors: Optional[Sequence["PluginRunnerSupervisor"]] = None,
**kwargs: Any,
) -> HookDispatchResult:
"""触发一次命名 Hook 调用。
Args:
hook_name: 本次触发的 Hook 名称。
supervisors: 当前运行时中所有可参与分发的 Supervisor留空时使用绑定的提供器。
**kwargs: 传递给 Hook 处理器的关键字参数。
Returns:
HookDispatchResult: 聚合后的 Hook 调用结果。
"""
resolved_supervisors = list(supervisors) if supervisors is not None else list(self._resolve_supervisors())
normalized_hook_name = self._normalize_hook_name(hook_name)
hook_spec = self.get_hook_spec(normalized_hook_name)
current_kwargs: Dict[str, Any] = dict(kwargs)
dispatch_result = HookDispatchResult(hook_name=normalized_hook_name, kwargs=dict(current_kwargs))
invocation_targets = self._collect_invocation_targets(normalized_hook_name, resolved_supervisors)
if not invocation_targets:
return dispatch_result
for target in invocation_targets:
if target.entry.is_observe:
self._schedule_observe_handler(
hook_name=normalized_hook_name,
hook_spec=hook_spec,
target=target,
kwargs=current_kwargs,
)
except asyncio.TimeoutError:
logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
continue
if result:
if result.modified_kwargs is not None:
current_kwargs = result.modified_kwargs
if not result.continue_processing:
logger.info(f"HookHandler {entry.full_name} 终止了后续处理")
break
if not hook_spec.allow_blocking:
error_message = (
f"Hook {normalized_hook_name} 不允许 blocking 处理器,"
f"已跳过 {target.entry.full_name}"
)
logger.warning(error_message)
dispatch_result.errors.append(error_message)
continue
# 处理 non-blocking handlers异步调用不阻塞主流程
for entry in non_blocking_handlers:
async_kwargs = current_kwargs.copy()
hook_args = {"stage": stage, **async_kwargs}
task = asyncio.create_task(
asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout)
execution_result = await self._invoke_handler(
hook_name=normalized_hook_name,
hook_spec=hook_spec,
target=target,
kwargs=current_kwargs,
)
self._merge_blocking_result(
hook_spec=hook_spec,
target=target,
execution_result=execution_result,
dispatch_result=dispatch_result,
)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return current_kwargs
current_kwargs = dict(dispatch_result.kwargs)
if dispatch_result.aborted:
break
return dispatch_result
def _resolve_supervisors(self) -> Sequence["PluginRunnerSupervisor"]:
"""解析当前调用应使用的 Supervisor 列表。
Returns:
Sequence[PluginRunnerSupervisor]: 可参与本次 Hook 调度的 Supervisor 序列。
Raises:
ValueError: 当未传入 `supervisors` 且分发器也未绑定提供器时抛出。
"""
if self._supervisors_provider is None:
raise ValueError("当前 HookDispatcher 未绑定 supervisors_provider请显式传入 supervisors")
return self._supervisors_provider()
def _collect_invocation_targets(
self,
hook_name: str,
supervisors: Sequence["PluginRunnerSupervisor"],
) -> List[_HookInvocationTarget]:
"""收集并排序本次 Hook 调用的全部处理器目标。
Args:
hook_name: 目标 Hook 名称。
supervisors: 当前参与调度的 Supervisor 序列。
Returns:
List[_HookInvocationTarget]: 已完成全局排序的处理器目标列表。
"""
invocation_targets: List[_HookInvocationTarget] = []
for supervisor in supervisors:
source_rank = self._get_supervisor_source_rank(supervisor)
for entry in supervisor.component_registry.get_hook_handlers(hook_name):
invocation_targets.append(
_HookInvocationTarget(
supervisor=supervisor,
entry=entry,
source_rank=source_rank,
)
)
invocation_targets.sort(key=self._build_sort_key)
return invocation_targets
@staticmethod
def _build_sort_key(target: _HookInvocationTarget) -> tuple[int, int, int, str, str]:
"""构造 Hook 处理器的全局排序键。
Args:
target: 待排序的处理器目标。
Returns:
tuple[int, int, int, str, str]: 全局排序键。
"""
return (
HookDispatcher._get_mode_rank(target.entry.mode),
HookDispatcher._get_order_rank(target.entry.order),
target.source_rank,
target.entry.plugin_id,
target.entry.name,
)
@staticmethod
def _get_default_timeout_ms() -> int:
"""读取系统级默认 Hook 超时。
Returns:
int: 默认超时毫秒数。
"""
timeout_seconds = float(global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0)
return max(int(timeout_seconds * 1000), 1)
@staticmethod
def _get_mode_rank(mode: str) -> int:
"""返回 Hook 模式的排序权重。
Args:
mode: Hook 模式。
Returns:
int: 越小表示越靠前。
"""
return {"blocking": 0, "observe": 1}.get(mode, 99)
@staticmethod
def _get_order_rank(order: str) -> int:
"""返回 Hook 顺序槽位的排序权重。
Args:
order: Hook 顺序槽位。
Returns:
int: 越小表示越靠前。
"""
return {"early": 0, "normal": 1, "late": 2}.get(order, 99)
@staticmethod
def _get_supervisor_source_rank(supervisor: "PluginRunnerSupervisor") -> int:
"""返回 Supervisor 的来源排序权重。
Args:
supervisor: 目标 Supervisor。
Returns:
int: 内置插件返回 `0`,第三方插件返回 `1`。
"""
return 0 if supervisor.group_name == "builtin" else 1
@staticmethod
def _normalize_hook_name(hook_name: str) -> str:
"""规范化命名 Hook 名称。
Args:
hook_name: 原始 Hook 名称。
Returns:
str: 规范化后的 Hook 名称。
Raises:
ValueError: 当 Hook 名称为空时抛出。
"""
normalized_name = str(hook_name or "").strip()
if not normalized_name:
raise ValueError("Hook 名称不能为空")
return normalized_name
def _resolve_timeout_ms(self, hook_spec: HookSpec, target: _HookInvocationTarget) -> int:
"""计算单个处理器的实际超时。
Args:
hook_spec: 当前 Hook 的规格定义。
target: 当前执行目标。
Returns:
int: 最终生效的超时毫秒数。
"""
if target.entry.timeout_ms > 0:
return target.entry.timeout_ms
if hook_spec.default_timeout_ms > 0:
return hook_spec.default_timeout_ms
return self._get_default_timeout_ms()
async def _invoke_handler(
self,
supervisor: "PluginRunnerSupervisor",
handler_entry: "HookHandlerEntry",
args: Dict[str, Any],
) -> Optional[HookResult]:
"""调用单个 handler 并收集结果。
hook_name: str,
hook_spec: HookSpec,
target: _HookInvocationTarget,
kwargs: Dict[str, Any],
) -> HookHandlerExecutionResult:
"""执行单个 Hook 处理器。
Args:
supervisor: PluginRunnerSupervisor 实例
handler_entry: HookHandlerEntry 实例
args: 传递给 handler 的参数字典
stage: hook 名称
hook_name: 当前 Hook 名称。
hook_spec: 当前 Hook 规格。
target: 当前执行目标。
kwargs: 当前参数字典。
Returns:
Optional[HookResult]: 执行结果,如果执行失败则返回 None
HookHandlerExecutionResult: 处理器执行结果
"""
try:
resp_envelope = await supervisor.invoke_plugin(
"plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
)
resp = resp_envelope.payload
result = HookResult(
handler_name=handler_entry.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_kwargs=resp.get("modified_kwargs"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
return result
timeout_ms = self._resolve_timeout_ms(hook_spec, target)
request_args: Dict[str, Any] = {"hook_name": hook_name, **dict(kwargs)}
try:
response_envelope = await asyncio.wait_for(
target.supervisor.invoke_plugin(
"plugin.invoke_hook",
target.entry.plugin_id,
target.entry.name,
request_args,
timeout_ms=timeout_ms,
),
timeout=max(timeout_ms / 1000.0, 0.001),
)
except asyncio.TimeoutError:
error_message = (
f"HookHandler {target.entry.full_name} 执行超时,已超过 {timeout_ms}ms"
)
logger.error(error_message)
return HookHandlerExecutionResult(
handler_name=target.entry.full_name,
plugin_id=target.entry.plugin_id,
success=False,
error_message=error_message,
)
except Exception as exc:
error_message = f"HookHandler {target.entry.full_name} 执行失败: {exc}"
logger.error(error_message, exc_info=True)
return HookHandlerExecutionResult(
handler_name=target.entry.full_name,
plugin_id=target.entry.plugin_id,
success=False,
error_message=error_message,
)
response_payload = response_envelope.payload
if not isinstance(response_payload, dict):
return HookHandlerExecutionResult(
handler_name=target.entry.full_name,
plugin_id=target.entry.plugin_id,
custom_result=response_payload,
)
return HookHandlerExecutionResult(
handler_name=target.entry.full_name,
plugin_id=target.entry.plugin_id,
success=bool(response_payload.get("success", True)),
action=self._normalize_action(response_payload.get("action", "continue")),
modified_kwargs=self._extract_modified_kwargs(response_payload.get("modified_kwargs")),
custom_result=response_payload.get("custom_result"),
error_message=str(response_payload.get("error_message", "") or ""),
)
@staticmethod
def _extract_modified_kwargs(raw_value: Any) -> Optional[Dict[str, Any]]:
"""提取并校验处理器返回的 `modified_kwargs`。
Args:
raw_value: 原始返回值。
Returns:
Optional[Dict[str, Any]]: 合法时返回字典,否则返回 `None`。
"""
if raw_value is None:
return None
if isinstance(raw_value, dict):
return dict(raw_value)
logger.warning("HookHandler 返回的 modified_kwargs 不是字典,已忽略")
return None
@staticmethod
def _normalize_action(raw_value: Any) -> str:
"""规范化处理器动作返回值。
Args:
raw_value: 原始动作值。
Returns:
str: 规范化后的动作值,仅支持 `continue` 或 `abort`。
"""
normalized_value = str(raw_value or "").strip().lower() or "continue"
if normalized_value not in {"continue", "abort"}:
logger.warning(f"未知的 Hook action: {raw_value},已按 continue 处理")
return "continue"
return normalized_value
def _merge_blocking_result(
self,
hook_spec: HookSpec,
target: _HookInvocationTarget,
execution_result: HookHandlerExecutionResult,
dispatch_result: HookDispatchResult,
) -> None:
"""合并阻塞处理器结果到聚合结果。
Args:
hook_spec: 当前 Hook 规格。
target: 当前执行目标。
execution_result: 当前处理器执行结果。
dispatch_result: 当前聚合结果对象。
"""
if execution_result.custom_result is not None:
dispatch_result.custom_results.append(execution_result.custom_result)
if not execution_result.success:
error_message = execution_result.error_message or f"HookHandler {target.entry.full_name} 执行失败"
dispatch_result.errors.append(error_message)
self._apply_error_policy(target, hook_spec, dispatch_result, error_message)
return
if execution_result.modified_kwargs is not None:
if hook_spec.allow_kwargs_mutation:
dispatch_result.kwargs = dict(execution_result.modified_kwargs)
else:
error_message = (
f"Hook {dispatch_result.hook_name} 不允许修改 kwargs"
f"已忽略 {target.entry.full_name} 的 modified_kwargs"
)
logger.warning(error_message)
dispatch_result.errors.append(error_message)
if execution_result.action == "abort":
if hook_spec.allow_abort:
dispatch_result.aborted = True
dispatch_result.stopped_by = target.entry.full_name
logger.info(f"HookHandler {target.entry.full_name} 中止了 Hook {dispatch_result.hook_name}")
else:
error_message = (
f"Hook {dispatch_result.hook_name} 不允许 abort"
f"已忽略 {target.entry.full_name} 的 abort 请求"
)
logger.warning(error_message)
dispatch_result.errors.append(error_message)
def _apply_error_policy(
self,
target: _HookInvocationTarget,
hook_spec: HookSpec,
dispatch_result: HookDispatchResult,
error_message: str,
) -> None:
"""根据错误策略处理阻塞处理器失败。
Args:
target: 触发错误的处理器目标。
hook_spec: 当前 Hook 规格。
dispatch_result: 当前聚合结果对象。
error_message: 需要记录的错误描述。
"""
if target.entry.error_policy != "abort":
return
if not hook_spec.allow_abort:
logger.warning(
"Hook %s 禁止 abort已将 %s 的错误策略按 skip 处理",
dispatch_result.hook_name,
target.entry.full_name,
)
return
dispatch_result.aborted = True
dispatch_result.stopped_by = target.entry.full_name
logger.warning(
"HookHandler %s 因错误策略 abort 中止了 Hook %s: %s",
target.entry.full_name,
dispatch_result.hook_name,
error_message,
)
def _schedule_observe_handler(
self,
hook_name: str,
hook_spec: HookSpec,
target: _HookInvocationTarget,
kwargs: Dict[str, Any],
) -> None:
"""后台调度观察型处理器。
Args:
hook_name: 当前 Hook 名称。
hook_spec: 当前 Hook 规格。
target: 当前观察型处理器目标。
kwargs: 调用参数快照。
"""
if not hook_spec.allow_observe:
logger.warning("Hook %s 不允许 observe 处理器,已跳过 %s", hook_name, target.entry.full_name)
return
task = asyncio.create_task(
self._run_observe_handler(
hook_name=hook_name,
hook_spec=hook_spec,
target=target,
kwargs=dict(kwargs),
)
)
self._background_tasks.add(task)
task.add_done_callback(self._handle_background_task_done)
async def _run_observe_handler(
self,
hook_name: str,
hook_spec: HookSpec,
target: _HookInvocationTarget,
kwargs: Dict[str, Any],
) -> None:
"""执行观察型处理器并吞掉控制流副作用。
Args:
hook_name: 当前 Hook 名称。
hook_spec: 当前 Hook 规格。
target: 当前观察型处理器目标。
kwargs: 调用参数快照。
"""
execution_result = await self._invoke_handler(
hook_name=hook_name,
hook_spec=hook_spec,
target=target,
kwargs=kwargs,
)
if not execution_result.success:
logger.warning(
"观察型 HookHandler %s 执行失败: %s",
target.entry.full_name,
execution_result.error_message or "未知错误",
)
return
if execution_result.modified_kwargs is not None:
logger.warning(
"观察型 HookHandler %s 返回了 modified_kwargs已忽略", target.entry.full_name
)
if execution_result.action == "abort":
logger.warning(
"观察型 HookHandler %s 请求 abort已忽略", target.entry.full_name
)
def _handle_background_task_done(self, task: asyncio.Task[Any]) -> None:
"""处理观察任务完成回调。
Args:
task: 已完成的后台任务。
"""
self._background_tasks.discard(task)
with contextlib.suppress(asyncio.CancelledError):
exception = task.exception()
if exception is not None:
logger.error(f"观察型 Hook 后台任务执行失败: {exception}")

View File

@@ -49,7 +49,7 @@ from .api_registry import APIRegistry
from .capability_service import CapabilityService
from .component_registry import ComponentRegistry
from .event_dispatcher import EventDispatcher
from .hook_dispatcher import HookDispatcher
from .hook_dispatcher import HookDispatchResult, HookDispatcher
from .logger_bridge import RunnerLogBridge
from .message_gateway import MessageGateway
from .rpc_server import RPCServer
@@ -80,6 +80,7 @@ class PluginRunnerSupervisor:
def __init__(
self,
plugin_dirs: Optional[List[Path]] = None,
group_name: str = "third_party",
socket_path: Optional[str] = None,
health_check_interval_sec: Optional[float] = None,
max_restart_attempts: Optional[int] = None,
@@ -89,12 +90,14 @@ class PluginRunnerSupervisor:
Args:
plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
group_name: 当前 Supervisor 所属运行时分组名称。
socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
health_check_interval_sec: 健康检查间隔,单位秒。
max_restart_attempts: 自动重启 Runner 的最大次数。
runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。
"""
runtime_config = global_config.plugin_runtime
self._group_name: str = str(group_name or "third_party").strip() or "third_party"
self._plugin_dirs: List[Path] = plugin_dirs or []
self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0
self._runner_spawn_timeout: float = (
@@ -108,7 +111,7 @@ class PluginRunnerSupervisor:
self._api_registry = APIRegistry()
self._component_registry = ComponentRegistry()
self._event_dispatcher = EventDispatcher(self._component_registry)
self._hook_dispatcher = HookDispatcher(self._component_registry)
self._hook_dispatcher = HookDispatcher(lambda: [self])
self._message_gateway = MessageGateway(self._component_registry)
self._log_bridge = RunnerLogBridge()
@@ -133,6 +136,12 @@ class PluginRunnerSupervisor:
"""返回授权管理器。"""
return self._authorization
@property
def group_name(self) -> str:
"""返回当前 Supervisor 的运行时分组名称。"""
return self._group_name
@property
def capability_service(self) -> CapabilityService:
"""返回能力服务。"""
@@ -243,17 +252,18 @@ class PluginRunnerSupervisor:
"""
return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args)
async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]:
"""分发 Hook 到已注册的 Hook 处理器
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult:
"""在当前 Supervisor 内触发一次命名 Hook 调用
Args:
stage: Hook 阶段名称。
**kwargs: 传递给 Hook 的关键字参数。
hook_name: 本次触发的 Hook 名称。
**kwargs: 传递给 Hook 处理器的关键字参数。
Returns:
Dict[str, Any]: 经 Hook 修改后的参数字典
HookDispatchResult: 聚合后的 Hook 调用结果
"""
return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs)
return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs)
async def send_message_to_external(
self,

View File

@@ -3,8 +3,9 @@
提供 PluginRuntimeManager 单例,负责:
1. 管理双 PluginSupervisor 的生命周期(内置插件 / 第三方插件各一个子进程)
2. 将 EventType 桥接到运行时的 event dispatch
3. 在运行时的 ComponentRegistry 中查找命令
4. 提供统一的能力实现注册接口,使插件可以调用主程序功能
3. 触发跨 Supervisor 的命名 Hook 调用
4. 在运行时的 ComponentRegistry 中查找命令
5. 提供统一的能力实现注册接口,使插件可以调用主程序功能
"""
from pathlib import Path
@@ -24,6 +25,7 @@ from src.plugin_runtime.capabilities import (
RuntimeDataCapabilityMixin,
)
from src.plugin_runtime.capabilities.registry import register_capability_impls
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher, HookSpec
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
@@ -72,6 +74,7 @@ class PluginRuntimeManager(
self._manifest_validator: ManifestValidator = ManifestValidator()
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
self._config_reload_callback_registered: bool = False
self._hook_dispatcher: HookDispatcher = HookDispatcher(lambda: self.supervisors)
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
"""接收 Platform IO 审核后的入站消息并送入主消息链。
@@ -182,6 +185,7 @@ class PluginRuntimeManager(
if builtin_dirs:
self._builtin_supervisor = PluginSupervisor(
plugin_dirs=builtin_dirs,
group_name="builtin",
socket_path=builtin_socket,
)
self._register_capability_impls(self._builtin_supervisor)
@@ -189,6 +193,7 @@ class PluginRuntimeManager(
if third_party_dirs:
self._third_party_supervisor = PluginSupervisor(
plugin_dirs=third_party_dirs,
group_name="third_party",
socket_path=third_party_socket,
)
self._register_capability_impls(self._third_party_supervisor)
@@ -235,6 +240,7 @@ class PluginRuntimeManager(
await platform_io_manager.stop()
except Exception as platform_io_exc:
logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
await self._hook_dispatcher.stop()
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
@@ -274,6 +280,7 @@ class PluginRuntimeManager(
else:
logger.info("插件运行时已停止")
finally:
await self._hook_dispatcher.stop()
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
@@ -284,11 +291,41 @@ class PluginRuntimeManager(
"""返回插件运行时是否处于启动状态。"""
return self._started
@property
def hook_dispatcher(self) -> HookDispatcher:
"""返回跨 Supervisor 的命名 Hook 分发器。"""
return self._hook_dispatcher
@property
def invoke_dispatcher(self) -> HookDispatcher:
"""返回命名 Hook 分发器的兼容别名。"""
return self._hook_dispatcher
@property
def supervisors(self) -> List["PluginSupervisor"]:
"""获取所有活跃的 Supervisor"""
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
def register_hook_spec(self, spec: HookSpec) -> None:
"""注册单个命名 Hook 规格。
Args:
spec: 需要注册的 Hook 规格。
"""
self._hook_dispatcher.register_hook_spec(spec)
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
"""批量注册命名 Hook 规格。
Args:
specs: 需要注册的 Hook 规格序列。
"""
self._hook_dispatcher.register_hook_specs(specs)
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
"""根据当前已注册插件构建全局依赖图。"""
@@ -588,6 +625,19 @@ class PluginRuntimeManager(
return True, modified
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult:
"""触发一次跨 Supervisor 的命名 Hook 调用。
Args:
hook_name: 本次触发的 Hook 名称。
**kwargs: 传递给 Hook 处理器的关键字参数。
Returns:
HookDispatchResult: 聚合后的 Hook 调用结果。
"""
return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs)
# ─── 命令查找 ──────────────────────────────────────────────
def find_command_by_text(self, text: str) -> Optional[Dict[str, Any]]:

View File

@@ -330,7 +330,6 @@ class PluginRunner:
self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
self._rpc_client.register_method("plugin.health", self._handle_health)
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
@@ -1053,73 +1052,28 @@ class PluginRunner:
)
except Exception as exc:
logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True)
return envelope.make_response(payload={"success": False, "continue_processing": True})
return envelope.make_response(
payload={
"success": False,
"action": "continue",
"error_message": str(exc),
}
)
if raw is None:
result = {"success": True, "continue_processing": True}
result = {"success": True, "action": "continue"}
elif isinstance(raw, dict):
result = {
"success": True,
"continue_processing": raw.get("continue_processing", True),
"action": str(raw.get("action", "continue") or "continue").strip().lower() or "continue",
"modified_kwargs": raw.get("modified_kwargs"),
"custom_result": raw.get("custom_result"),
}
else:
result = {"success": True, "continue_processing": True, "custom_result": raw}
result = {"success": True, "action": "continue", "custom_result": raw}
return envelope.make_response(payload=result)
async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
"""处理 WorkflowStep 调用请求
与通用 invoke 不同,会将返回值规范化为
{hook_result, modified_message, stage_output} 格式。
"""
try:
invoke = InvokePayload.model_validate(envelope.payload)
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id)
if meta is None:
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
f"插件 {plugin_id} 未加载",
)
component_name = invoke.component_name
handler_method = self._resolve_component_handler(meta, component_name)
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"插件 {plugin_id} 无组件: {component_name}",
)
try:
raw = (
await handler_method(**invoke.args)
if inspect.iscoroutinefunction(handler_method)
else handler_method(**invoke.args)
)
# 规范化返回值
if isinstance(raw, str):
result = {"hook_result": raw}
elif isinstance(raw, dict):
result = raw
result.setdefault("hook_result", "continue")
else:
result = {"hook_result": "continue"}
resp_payload = InvokeResultPayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except Exception as e:
logger.error(f"插件 {plugin_id} workflow_step {component_name} 执行异常: {e}", exc_info=True)
resp_payload = InvokeResultPayload(success=False, result=str(e))
return envelope.make_response(payload=resp_payload.model_dump())
async def _handle_health(self, envelope: Envelope) -> Envelope:
"""处理健康检查"""
uptime_ms = int((time.monotonic() - self._start_time) * 1000)