Merge branch 'r-dev' of https://github.com/Mai-with-u/MaiBot into r-dev

This commit is contained in:
SengokuCola
2026-03-24 20:58:16 +08:00
6 changed files with 1247 additions and 523 deletions

View File

@@ -5,6 +5,7 @@
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any, Awaitable, Callable, Dict, List, Optional
import asyncio import asyncio
import json import json
@@ -1831,395 +1832,445 @@ class TestMaiMessages:
assert msg.llm_response_content == "new response" assert msg.llm_response_content == "new response"
# ─── WorkflowExecutor 测试 ──────────────────────────────── class _FakeHookSupervisor:
"""用于 Hook 分发测试的简化 Supervisor。"""
def __init__(
self,
group_name: str,
component_registry: Any,
handlers: Dict[str, Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]] | Dict[str, Any]]],
call_log: List[tuple[str, str]],
) -> None:
"""初始化测试用 Supervisor。
Args:
group_name: 运行时分组名称。
component_registry: 组件注册表实例。
handlers: 处理器映射,键为 `plugin_id.component_name`。
call_log: 记录调用顺序的列表。
"""
self._group_name = group_name
self.component_registry = component_registry
self._handlers = handlers
self._call_log = call_log
@property
def group_name(self) -> str:
"""返回当前测试 Supervisor 的分组名称。"""
return self._group_name
async def invoke_plugin(
self,
method: str,
plugin_id: str,
component_name: str,
args: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> SimpleNamespace:
"""模拟调用插件组件。
Args:
method: RPC 方法名。
plugin_id: 目标插件 ID。
component_name: 目标组件名称。
args: 调用参数。
timeout_ms: 超时配置,测试中仅用于保持接口一致。
Returns:
SimpleNamespace: 仅包含 `payload` 字段的简化响应对象。
"""
del method
del timeout_ms
full_name = f"{plugin_id}.{component_name}"
handler = self._handlers[full_name]
self._call_log.append((plugin_id, component_name))
result = handler(dict(args or {}))
if asyncio.iscoroutine(result):
result = await result
return SimpleNamespace(payload=result)
class TestWorkflowExecutor: # ─── HookDispatcher 测试 ────────────────────────────────
"""Host-side Workflow 执行器测试(新 pipeline 模型)"""
class TestHookDispatcher:
"""命名 Hook 分发器测试。"""
@staticmethod
def _import_dispatcher_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
"""导入 Hook 分发相关模块,并屏蔽配置初始化触发的退出。
Args:
monkeypatch: pytest 的 monkeypatch 工具。
Returns:
tuple[Any, Any]: `ComponentRegistry` 与 `HookDispatcher` 类型。
"""
monkeypatch.setattr(sys, "exit", lambda code=0: None)
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.hook_dispatcher import HookDispatcher
return ComponentRegistry, HookDispatcher
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_pipeline_completes(self): async def test_empty_hook_returns_original_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""无任何 workflow_step 注册时pipeline 全阶段跳过,状态 completed""" """未注册处理器时应直接返回原始参数。"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args): dispatcher = HookDispatcher()
return {"hook_result": "continue"} supervisor = _FakeHookSupervisor("builtin", ComponentRegistry(), {}, [])
result, final_msg, ctx = await executor.execute( result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
mock_invoke,
message={"plain_text": "test"}, assert result.hook_name == "heart_fc.cycle_start"
) assert result.kwargs == {"session_id": "s-1"}
assert result.status == "completed" assert result.aborted is False
assert result.return_message == "workflow completed"
assert len(ctx.timings) == 6 # 6 stages
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_blocking_hook_modifies_message(self): async def test_blocking_hook_modifies_kwargs(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""blocking hook 可以修改消息""" """blocking 处理器可以修改参数。"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
reg.register_component(
registry = ComponentRegistry()
registry.register_component(
"upper", "upper",
"workflow_step", "HOOK_HANDLER",
"p1", "p1",
{ {
"stage": "pre_process", "hook": "heart_fc.cycle_start",
"priority": 10, "mode": "blocking",
"blocking": True, "order": "normal",
}, },
) )
executor = WorkflowExecutor(reg) dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor(
async def mock_invoke(plugin_id, comp_name, args): "builtin",
msg = args.get("message", {}) registry,
return { {
"hook_result": "continue", "p1.upper": lambda args: {
"modified_message": {**msg, "plain_text": msg.get("plain_text", "").upper()}, "success": True,
} "action": "continue",
"modified_kwargs": {
result, final_msg, ctx = await executor.execute( "session_id": args["session_id"],
mock_invoke, "text": str(args["text"]).upper(),
message={"plain_text": "hello"}, },
}
},
[],
) )
assert result.status == "completed"
assert final_msg["plain_text"] == "HELLO" result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1", text="hello")
assert len(ctx.modification_log) == 1
assert ctx.modification_log[0].stage == "pre_process" assert result.kwargs["session_id"] == "s-1"
assert result.kwargs["text"] == "HELLO"
assert result.aborted is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_abort_stops_pipeline(self): async def test_abort_stops_following_blocking_handlers(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""HookResult.ABORT 立即终止 pipeline""" """blocking 处理器的 abort 应阻止后续 blocking 处理器执行。"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
reg.register_component(
"blocker", registry = ComponentRegistry()
"workflow_step", registry.register_component(
"stopper",
"HOOK_HANDLER",
"p1", "p1",
{ {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
"stage": "pre_process",
"priority": 10,
"blocking": True,
},
) )
executor = WorkflowExecutor(reg) registry.register_component(
"after_stop",
async def mock_invoke(plugin_id, comp_name, args): "HOOK_HANDLER",
return {"hook_result": "abort"}
result, _, ctx = await executor.execute(
mock_invoke,
message={"plain_text": "test"},
)
assert result.status == "aborted"
assert result.stopped_at == "pre_process"
@pytest.mark.asyncio
async def test_skip_stage(self):
"""HookResult.SKIP_STAGE 跳过当前阶段剩余 hook"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
# high-priority hook 返回 skip_stage
reg.register_component(
"skipper",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 100,
"blocking": True,
},
)
# low-priority hook 不应被执行
reg.register_component(
"checker",
"workflow_step",
"p2", "p2",
{ {"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
"stage": "ingress", )
"priority": 1, call_log: List[tuple[str, str]] = []
"blocking": True, dispatcher = HookDispatcher()
}, supervisor = _FakeHookSupervisor(
"builtin",
registry,
{
"p1.stopper": lambda args: {"success": True, "action": "abort"},
"p2.after_stop": lambda args: {"success": True, "action": "continue"},
},
call_log,
) )
executor = WorkflowExecutor(reg)
call_log = [] result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], cycle_id="c-1")
async def mock_invoke(plugin_id, comp_name, args): assert result.aborted is True
call_log.append(comp_name) assert result.stopped_by == "p1.stopper"
if comp_name == "skipper": assert call_log == [("p1", "stopper")]
return {"hook_result": "skip_stage"}
return {"hook_result": "continue"}
result, _, _ = await executor.execute(mock_invoke, message={"plain_text": "test"})
assert result.status == "completed"
# 只有 skipper 被调用checker 被跳过
assert call_log == ["skipper"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pre_filter(self): async def test_observe_handler_runs_in_background_without_mutation(
"""filter 条件不匹配时跳过 hook""" self,
from src.plugin_runtime.host.component_registry import ComponentRegistry monkeypatch: pytest.MonkeyPatch,
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor ) -> None:
"""observe 处理器应后台执行且不能影响主流程参数。"""
reg = ComponentRegistry() ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
reg.register_component(
"only_dm",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 10,
"blocking": True,
"filter": {"chat_type": "direct"},
},
)
executor = WorkflowExecutor(reg)
call_log = [] registry = ComponentRegistry()
registry.register_component(
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
return {"hook_result": "continue"}
# 不匹配 filter —— hook 不应被调用
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "group"})
assert not call_log
# 匹配 filter —— hook 应被调用
await executor.execute(mock_invoke, message={"plain_text": "hi", "chat_type": "direct"})
assert call_log == ["only_dm"]
@pytest.mark.asyncio
async def test_error_policy_skip(self):
"""error_policy=skip 时跳过失败的 hook 继续执行"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component(
"failer",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 100,
"blocking": True,
"error_policy": "skip",
},
)
reg.register_component(
"ok_step",
"workflow_step",
"p2",
{
"stage": "ingress",
"priority": 1,
"blocking": True,
},
)
executor = WorkflowExecutor(reg)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
if comp_name == "failer":
raise RuntimeError("boom")
return {"hook_result": "continue"}
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
assert result.status == "completed"
assert "failer" in call_log
assert "ok_step" in call_log
assert any("boom" in e for e in ctx.errors)
@pytest.mark.asyncio
async def test_error_policy_abort(self):
"""error_policy=abort默认时 pipeline 失败"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component(
"failer",
"workflow_step",
"p1",
{
"stage": "ingress",
"priority": 10,
"blocking": True,
# error_policy defaults to "abort"
},
)
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args):
raise RuntimeError("fatal")
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "test"})
assert result.status == "failed"
assert result.stopped_at == "ingress"
@pytest.mark.asyncio
async def test_nonblocking_hooks_concurrent(self):
"""non-blocking hook 并发执行,不修改消息"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
for i in range(3):
reg.register_component(
f"nb_{i}",
"workflow_step",
f"p{i}",
{
"stage": "post_process",
"priority": 0,
"blocking": False,
},
)
executor = WorkflowExecutor(reg)
call_log = []
async def mock_invoke(plugin_id, comp_name, args):
call_log.append(comp_name)
return {"hook_result": "continue", "modified_message": {"plain_text": "ignored"}}
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"})
# non-blocking 的 modified_message 被忽略
assert final_msg["plain_text"] == "original"
# 给异步 task 时间完成
await asyncio.sleep(0.1)
assert result.status == "completed"
@pytest.mark.asyncio
async def test_nonblocking_tasks_are_retained_until_completion(self):
"""execute 返回后non-blocking task 仍应保持强引用直到执行完成。"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry()
reg.register_component(
"observer", "observer",
"workflow_step", "HOOK_HANDLER",
"p1", "p1",
{ {"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
"stage": "post_process",
"priority": 0,
"blocking": False,
},
) )
executor = WorkflowExecutor(reg)
started = asyncio.Event() started = asyncio.Event()
release = asyncio.Event() release = asyncio.Event()
call_log: List[tuple[str, str]] = []
async def observe_handler(args: Dict[str, Any]) -> Dict[str, Any]:
"""模拟耗时观察型处理器。"""
async def mock_invoke(plugin_id, comp_name, args):
started.set() started.set()
await release.wait() await release.wait()
return {"hook_result": "continue"} return {
"success": True,
"action": "abort",
"modified_kwargs": {"session_id": "changed"},
"custom_result": args["session_id"],
}
result, final_msg, _ = await executor.execute(mock_invoke, message={"plain_text": "original"}) dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor(
"builtin",
registry,
{"p1.observer": observe_handler},
call_log,
)
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
await asyncio.sleep(0) await asyncio.sleep(0)
assert result.status == "completed" assert result.aborted is False
assert final_msg["plain_text"] == "original" assert result.kwargs["session_id"] == "s-1"
assert started.is_set() assert started.is_set()
assert len(executor._background_tasks) == 1 assert len(dispatcher._background_tasks) == 1
release.set() release.set()
await asyncio.sleep(0) await asyncio.sleep(0)
await asyncio.sleep(0) await asyncio.sleep(0)
assert not executor._background_tasks assert call_log == [("p1", "observer")]
assert not dispatcher._background_tasks
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_command_routing(self): async def test_global_order_prefers_order_slot_then_source(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""PLAN 阶段内置命令路由""" """全局排序应先看 order再看内置/第三方来源。"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
reg.register_component(
"help", builtin_registry = ComponentRegistry()
"command", third_registry = ComponentRegistry()
"p1", builtin_registry.register_component(
{ "builtin_early",
"command_pattern": r"^/help", "HOOK_HANDLER",
}, "b1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
builtin_registry.register_component(
"builtin_normal",
"HOOK_HANDLER",
"b1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
)
third_registry.register_component(
"third_early",
"HOOK_HANDLER",
"t1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
third_registry.register_component(
"third_normal",
"HOOK_HANDLER",
"t1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "normal"},
) )
executor = WorkflowExecutor(reg)
async def mock_invoke(plugin_id, comp_name, args): call_log: List[tuple[str, str]] = []
if comp_name == "help": dispatcher = HookDispatcher()
return {"output": "帮助信息"} builtin_supervisor = _FakeHookSupervisor(
return {"hook_result": "continue"} "builtin",
builtin_registry,
{
"b1.builtin_early": lambda args: {"success": True, "action": "continue"},
"b1.builtin_normal": lambda args: {"success": True, "action": "continue"},
},
call_log,
)
third_supervisor = _FakeHookSupervisor(
"third_party",
third_registry,
{
"t1.third_early": lambda args: {"success": True, "action": "continue"},
"t1.third_normal": lambda args: {"success": True, "action": "continue"},
},
call_log,
)
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "/help topic"}) await dispatcher.invoke_hook(
assert result.status == "completed" "heart_fc.cycle_start",
assert ctx.matched_command == "p1.help" [third_supervisor, builtin_supervisor],
cmd_result = ctx.get_stage_output("plan", "command_result") cycle_id="c-1",
assert cmd_result is not None )
assert cmd_result["output"] == "帮助信息"
assert call_log == [
("b1", "builtin_early"),
("t1", "third_early"),
("b1", "builtin_normal"),
("t1", "third_normal"),
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stage_outputs(self): async def test_error_policy_abort_stops_dispatch(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""stage_outputs 数据在阶段间传递""" """error_policy=abort 时应中止本次 Hook 调用。"""
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor
reg = ComponentRegistry() ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
# ingress 阶段写入数据
reg.register_component( registry = ComponentRegistry()
"writer", registry.register_component(
"workflow_step", "failer",
"HOOK_HANDLER",
"p1", "p1",
{ {
"stage": "ingress", "hook": "heart_fc.cycle_start",
"priority": 10, "mode": "blocking",
"blocking": True, "order": "normal",
"error_policy": "abort",
}, },
) )
# pre_process 阶段读取数据 call_log: List[tuple[str, str]] = []
reg.register_component(
"reader", async def fail_handler(args: Dict[str, Any]) -> Dict[str, Any]:
"workflow_step", """抛出异常以触发 abort 策略。"""
"p2",
del args
raise RuntimeError("boom")
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.failer": fail_handler}, call_log)
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
assert result.aborted is True
assert result.stopped_by == "p1.failer"
assert any("boom" in error for error in result.errors)
assert call_log == [("p1", "failer")]
@pytest.mark.asyncio
async def test_timeout_respects_handler_timeout_ms(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""处理器超时应被记录为错误并继续。"""
ComponentRegistry, HookDispatcher = self._import_dispatcher_modules(monkeypatch)
registry = ComponentRegistry()
registry.register_component(
"slow",
"HOOK_HANDLER",
"p1",
{ {
"stage": "pre_process", "hook": "heart_fc.cycle_start",
"priority": 10, "mode": "blocking",
"blocking": True, "order": "normal",
"timeout_ms": 10,
}, },
) )
executor = WorkflowExecutor(reg) call_log: List[tuple[str, str]] = []
async def mock_invoke(plugin_id, comp_name, args): async def slow_handler(args: Dict[str, Any]) -> Dict[str, Any]:
if comp_name == "writer": """模拟超时处理器。"""
return {
"hook_result": "continue",
"stage_output": {"parsed_intent": "greeting"},
}
if comp_name == "reader":
# 验证 stage_outputs 被传递过来
outputs = args.get("stage_outputs", {})
ingress_data = outputs.get("ingress", {})
assert ingress_data.get("parsed_intent") == "greeting"
return {"hook_result": "continue"}
return {"hook_result": "continue"}
result, _, ctx = await executor.execute(mock_invoke, message={"plain_text": "hi"}) del args
assert result.status == "completed" await asyncio.sleep(0.05)
assert ctx.get_stage_output("ingress", "parsed_intent") == "greeting" return {"success": True, "action": "continue"}
dispatcher = HookDispatcher()
supervisor = _FakeHookSupervisor("builtin", registry, {"p1.slow": slow_handler}, call_log)
result = await dispatcher.invoke_hook("heart_fc.cycle_start", [supervisor], session_id="s-1")
assert result.aborted is False
assert any("超时" in error for error in result.errors)
assert call_log == [("p1", "slow")]
class TestPluginRuntimeHookEntry:
"""PluginRuntimeManager 命名 Hook 入口测试。"""
@staticmethod
def _import_manager_modules(monkeypatch: pytest.MonkeyPatch) -> tuple[Any, Any]:
"""导入运行时管理器相关模块,并屏蔽配置初始化触发的退出。
Args:
monkeypatch: pytest 的 monkeypatch 工具。
Returns:
tuple[Any, Any]: `ComponentRegistry` 与 `PluginRuntimeManager` 类型。
"""
monkeypatch.setattr(sys, "exit", lambda code=0: None)
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.integration import PluginRuntimeManager
return ComponentRegistry, PluginRuntimeManager
@pytest.mark.asyncio
async def test_manager_invoke_hook_dispatches_across_supervisors(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""PluginRuntimeManager.invoke_hook() 应调用全局 Hook 分发器。"""
ComponentRegistry, PluginRuntimeManager = self._import_manager_modules(monkeypatch)
builtin_registry = ComponentRegistry()
builtin_registry.register_component(
"builtin_guard",
"HOOK_HANDLER",
"b1",
{"hook": "heart_fc.cycle_start", "mode": "blocking", "order": "early"},
)
third_registry = ComponentRegistry()
third_registry.register_component(
"observer",
"HOOK_HANDLER",
"t1",
{"hook": "heart_fc.cycle_start", "mode": "observe", "order": "normal"},
)
call_log: List[tuple[str, str]] = []
manager = PluginRuntimeManager()
manager._started = True
manager._builtin_supervisor = _FakeHookSupervisor(
"builtin",
builtin_registry,
{"b1.builtin_guard": lambda args: {"success": True, "action": "continue"}},
call_log,
)
manager._third_party_supervisor = _FakeHookSupervisor(
"third_party",
third_registry,
{"t1.observer": lambda args: {"success": True, "action": "continue"}},
call_log,
)
result = await manager.invoke_dispatcher.invoke_hook("heart_fc.cycle_start", session_id="s-1")
await asyncio.sleep(0)
assert manager.invoke_dispatcher is manager.hook_dispatcher
assert result.aborted is False
assert result.kwargs["session_id"] == "s-1"
assert ("b1", "builtin_guard") in call_log
class TestRPCServer: class TestRPCServer:

View File

@@ -1,7 +1,7 @@
"""Host-side ComponentRegistry """Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力: 对齐旧系统 component_registry.py 的核心能力:
- 按类型注册组件action / command / tool / event_handler / workflow_handler / message_gateway - 按类型注册组件action / command / tool / event_handler / hook_handler / message_gateway
- 命名空间 (plugin_id.component_name) - 命名空间 (plugin_id.component_name)
- 命令正则匹配 - 命令正则匹配
- 组件启用/禁用 - 组件启用/禁用
@@ -106,14 +106,129 @@ class EventHandlerEntry(ComponentEntry):
class HookHandlerEntry(ComponentEntry): class HookHandlerEntry(ComponentEntry):
"""WorkflowHandler 组件条目""" """HookHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None: def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.stage: str = metadata.get("stage", "") self.hook: str = self._normalize_hook_name(metadata.get("hook", ""))
self.priority: int = metadata.get("priority", 0) self.mode: str = self._normalize_mode(metadata.get("mode", "blocking"))
self.blocking: bool = metadata.get("blocking", False) self.order: str = self._normalize_order(metadata.get("order", "normal"))
self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0))
self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip"))
super().__init__(name, component_type, plugin_id, metadata) super().__init__(name, component_type, plugin_id, metadata)
@staticmethod
def _normalize_error_policy(raw_value: Any) -> str:
"""规范化 Hook 异常处理策略。
Args:
raw_value: 原始异常处理策略值。
Returns:
str: 规范化后的异常处理策略。
Raises:
ValueError: 当异常处理策略不受支持时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
normalized_value = str(normalized_source or "").strip().lower() or "skip"
if normalized_value not in {"abort", "skip", "log"}:
raise ValueError(f"HookHandler 异常处理策略不合法: {raw_value}")
return normalized_value
@staticmethod
def _normalize_hook_name(raw_value: Any) -> str:
"""规范化命名 Hook 名称。
Args:
raw_value: 原始 Hook 名称。
Returns:
str: 去空白后的 Hook 名称。
Raises:
ValueError: 当 Hook 名称为空时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
if not (normalized_value := str(normalized_source or "").strip()):
raise ValueError("HookHandler 的 hook 名称不能为空")
return normalized_value
@staticmethod
def _normalize_mode(raw_value: Any) -> str:
"""规范化 Hook 处理模式。
Args:
raw_value: 原始模式值。
Returns:
str: 规范化后的模式。
Raises:
ValueError: 当模式不受支持时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
normalized_value = str(normalized_source or "").strip().lower() or "blocking"
if normalized_value not in {"blocking", "observe"}:
raise ValueError(f"HookHandler 模式不合法: {raw_value}")
return normalized_value
@staticmethod
def _normalize_order(raw_value: Any) -> str:
"""规范化 Hook 顺序槽位。
Args:
raw_value: 原始顺序值。
Returns:
str: 规范化后的顺序槽位。
Raises:
ValueError: 当顺序值不受支持时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
normalized_value = str(normalized_source or "").strip().lower() or "normal"
if normalized_value not in {"early", "normal", "late"}:
raise ValueError(f"HookHandler 顺序槽位不合法: {raw_value}")
return normalized_value
@staticmethod
def _normalize_timeout_ms(raw_value: Any) -> int:
"""规范化 Hook 超时配置。
Args:
raw_value: 原始超时值。
Returns:
int: 规范化后的超时毫秒数。
Raises:
ValueError: 当超时值为负数或无法转换为整数时抛出。
"""
try:
timeout_ms = int(raw_value or 0)
except (TypeError, ValueError) as exc:
raise ValueError(f"HookHandler 超时配置不合法: {raw_value}") from exc
if timeout_ms < 0:
raise ValueError(f"HookHandler 超时配置不能为负数: {raw_value}")
return timeout_ms
@property
def is_blocking(self) -> bool:
"""返回当前 Hook 是否为阻塞模式。"""
return self.mode == "blocking"
@property
def is_observe(self) -> bool:
"""返回当前 Hook 是否为观察模式。"""
return self.mode == "observe"
class MessageGatewayEntry(ComponentEntry): class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目""" """MessageGateway 组件条目"""
@@ -454,16 +569,17 @@ class ComponentRegistry:
return handlers return handlers
def get_hook_handlers( def get_hook_handlers(
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None self, hook_name: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[HookHandlerEntry]: ) -> List[HookHandlerEntry]:
"""获取特定 hook 阶段的所有步骤,按 priority 降序 """获取订阅指定命名 Hook 的全部处理器
Args: Args:
stage: hook 名称 hook_name: 目标 Hook 名称
enabled_only: 是否仅返回启用的组件 enabled_only: 是否仅返回启用的组件
session_id: 可选的会话ID若提供则考虑会话禁用状态 session_id: 可选的会话 ID若提供则考虑会话禁用状态
Returns: Returns:
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序 List[HookHandlerEntry]: 符合条件的 HookHandler 组件列表
""" """
handlers: List[HookHandlerEntry] = [] handlers: List[HookHandlerEntry] = []
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values(): for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
@@ -471,11 +587,37 @@ class ComponentRegistry:
continue continue
if not isinstance(comp, HookHandlerEntry): if not isinstance(comp, HookHandlerEntry):
continue continue
if comp.stage == stage: if comp.hook == hook_name:
handlers.append(comp) handlers.append(comp)
handlers.sort(key=lambda c: c.priority, reverse=True) handlers.sort(key=lambda comp: (self._get_hook_mode_rank(comp.mode), self._get_hook_order_rank(comp.order), comp.plugin_id, comp.name))
return handlers return handlers
@staticmethod
def _get_hook_mode_rank(mode: str) -> int:
"""返回 Hook 模式的排序权重。
Args:
mode: Hook 模式字符串。
Returns:
int: 越小表示越靠前。
"""
return {"blocking": 0, "observe": 1}.get(mode, 99)
@staticmethod
def _get_hook_order_rank(order: str) -> int:
"""返回 Hook 顺序槽位的排序权重。
Args:
order: Hook 顺序槽位字符串。
Returns:
int: 越小表示越靠前。
"""
return {"early": 0, "normal": 1, "late": 2}.get(order, 99)
def get_message_gateway( def get_message_gateway(
self, self,
plugin_id: str, plugin_id: str,
@@ -566,8 +708,13 @@ class ComponentRegistry:
Returns: Returns:
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等 stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
""" """
stats: StatusDict = {"total": len(self._components)} # type: ignore return StatusDict(
for comp_type, type_dict in self._by_type.items(): total=len(self._components),
stats[comp_type.value.lower()] = len(type_dict) action=len(self._by_type[ComponentTypes.ACTION]),
stats["plugins"] = len(self._by_plugin) command=len(self._by_type[ComponentTypes.COMMAND]),
return stats tool=len(self._by_type[ComponentTypes.TOOL]),
event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]),
hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]),
message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]),
plugins=len(self._by_plugin),
)

View File

@@ -1,166 +1,678 @@
""" """命名 Hook 分发系统。
Hook Dispatch 系统
插件可以注册自己的Hook当特定函数被调用时Hook Dispatch系统会将调用转发给插件的Hook处理函数。 主程序可以在任意执行点触发一个命名 HookHost 会收集所有订阅该 Hook 的
每个Hook的参数随Hook点位确定因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数 插件处理器,并按照固定的全局顺序调度执行
在参数/返回值匹配的情况下允许修改参数/返回值。
HookDispatcher 负责 排序规则如下
1. 按 stage 查询已注册的 hook_handler通过 ComponentRegistry
2. 按 priority 排序,区分 blocking 和非 blocking 模式 1. `blocking` 先于 `observe`
3. blocking 模式:依次同步调用,支持修改参数/提前终止 2. `early` 先于 `normal` 先于 `late`
4. 非 blocking 模式:异步调用,不阻塞主流程 3. 内置插件先于第三方插件
5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限 4. `plugin_id`
5. `handler_name`
其中:
- `blocking` 处理器串行执行,可修改 `kwargs`,也可中止本次 Hook 调用。
- `observe` 处理器后台并发执行,只允许旁路观察,不参与主流程控制。
""" """
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set
import asyncio import asyncio
from dataclasses import dataclass, field import contextlib
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
if TYPE_CHECKING: if TYPE_CHECKING:
from .component_registry import HookHandlerEntry
from .supervisor import PluginRunnerSupervisor from .supervisor import PluginRunnerSupervisor
from .component_registry import ComponentRegistry, HookHandlerEntry
logger = get_logger("plugin_runtime.host.hook_dispatcher") logger = get_logger("plugin_runtime.host.hook_dispatcher")
@dataclass @dataclass(slots=True)
class HookResult: class HookSpec:
"""单个 HookHandler 的执行结果""" """命名 Hook 的静态规格定义。
Attributes:
name: Hook 的唯一名称。
description: Hook 描述。
default_timeout_ms: 默认超时毫秒数;为 `0` 时退回系统默认值。
allow_blocking: 是否允许注册阻塞处理器。
allow_observe: 是否允许注册观察处理器。
allow_abort: 是否允许处理器中止当前 Hook 调用。
allow_kwargs_mutation: 是否允许阻塞处理器修改 `kwargs`。
"""
name: str
description: str = ""
default_timeout_ms: int = 0
allow_blocking: bool = True
allow_observe: bool = True
allow_abort: bool = True
allow_kwargs_mutation: bool = True
@dataclass(slots=True)
class HookHandlerExecutionResult:
"""单个 HookHandler 的执行结果。
Attributes:
handler_name: 完整处理器名称,格式通常为 `plugin_id.component_name`。
plugin_id: 处理器所属插件 ID。
success: 本次调用是否成功。
action: 当前处理器要求的控制动作,仅支持 `continue` 或 `abort`。
modified_kwargs: 处理器返回的修改后参数字典。
custom_result: 处理器返回的附加结果。
error_message: 失败时的错误描述。
"""
handler_name: str handler_name: str
success: bool = field(default=True) plugin_id: str
continue_processing: bool = field(default=True) success: bool = True
modified_kwargs: Optional[Dict[str, Any]] = field(default=None) action: str = "continue"
custom_result: Any = field(default=None) modified_kwargs: Optional[Dict[str, Any]] = None
custom_result: Any = None
error_message: str = ""
@dataclass(slots=True)
class HookDispatchResult:
"""一次命名 Hook 调用的聚合结果。
Attributes:
hook_name: 本次调用的 Hook 名称。
kwargs: 经阻塞处理器串行处理后的最终参数字典。
aborted: 是否被某个处理器中止。
stopped_by: 若被中止,记录触发中止的完整处理器名称。
custom_results: 阻塞处理器返回的附加结果列表。
errors: 本次调用中记录到的错误信息列表。
"""
hook_name: str
kwargs: Dict[str, Any] = field(default_factory=dict)
aborted: bool = False
stopped_by: Optional[str] = None
custom_results: List[Any] = field(default_factory=list)
errors: List[str] = field(default_factory=list)
@dataclass(slots=True)
class _HookInvocationTarget:
"""内部使用的 Hook 调度目标。
Attributes:
supervisor: 负责该处理器的 Supervisor。
entry: Hook 处理器条目。
source_rank: 插件来源权重,内置插件为 `0`,第三方插件为 `1`。
"""
supervisor: "PluginRunnerSupervisor"
entry: "HookHandlerEntry"
source_rank: int
class HookDispatcher: class HookDispatcher:
"""Host-side Hook 分发器 """命名 Hook 分发器"""
由业务层调用 hook_dispatch() def __init__(
内部通过 ComponentRegistry 查询 handler self,
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。 supervisors_provider: Optional[Callable[[], Sequence["PluginRunnerSupervisor"]]] = None,
""" ) -> None:
"""初始化 Hook 分发器。
def __init__(self, component_registry: "ComponentRegistry") -> None:
"""初始化 HookDispatcher
Args: Args:
component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler supervisors_provider: 可选的 Supervisor 提供器。若调用 `invoke_hook()`
时未显式传入 `supervisors`,则使用该回调获取目标 Supervisor 列表。
""" """
self._component_registry: "ComponentRegistry" = component_registry
self._background_tasks: Set[asyncio.Task] = set() self._background_tasks: Set[asyncio.Task[Any]] = set()
self._hook_specs: Dict[str, HookSpec] = {}
self._supervisors_provider = supervisors_provider
async def stop(self) -> None: async def stop(self) -> None:
"""停止 HookDispatcher取消所有未完成的后台任务""" """停止分发器并取消所有未完成的观察任务"""
for task in self._background_tasks: for task in self._background_tasks:
task.cancel() task.cancel()
await asyncio.gather(*self._background_tasks, return_exceptions=True) await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear() self._background_tasks.clear()
async def hook_dispatch( def register_hook_spec(self, spec: HookSpec) -> None:
self, """注册单个命名 Hook 规格。
stage: str,
supervisor: "PluginRunnerSupervisor",
**kwargs: Any,
) -> Dict[str, Any]:
"""分发 hook 到所有对应 handler 的便捷方法。
内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
无需调用方手动构造 invoke_fn 闭包。
Args: Args:
stage: hook 名称 spec: 需要注册的 Hook 规格。
supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin """
**kwargs: 关键字参数,会展开传递给 handler
normalized_name = self._normalize_hook_name(spec.name)
self._hook_specs[normalized_name] = HookSpec(
name=normalized_name,
description=spec.description,
default_timeout_ms=max(int(spec.default_timeout_ms), 0),
allow_blocking=bool(spec.allow_blocking),
allow_observe=bool(spec.allow_observe),
allow_abort=bool(spec.allow_abort),
allow_kwargs_mutation=bool(spec.allow_kwargs_mutation),
)
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
"""批量注册命名 Hook 规格。
Args:
specs: 需要注册的 Hook 规格序列。
"""
for spec in specs:
self.register_hook_spec(spec)
def get_hook_spec(self, hook_name: str) -> HookSpec:
"""获取指定 Hook 的规格定义。
Args:
hook_name: Hook 名称。
Returns: Returns:
modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数 HookSpec: 若未显式注册,则返回按系统默认值生成的运行时规格。
""" """
handler_entries = self._component_registry.get_hook_handlers(stage)
if not handler_entries:
return kwargs
current_kwargs = kwargs.copy() normalized_name = self._normalize_hook_name(hook_name)
blocking_handlers: List["HookHandlerEntry"] = [] if normalized_name in self._hook_specs:
non_blocking_handlers: List["HookHandlerEntry"] = [] return self._hook_specs[normalized_name]
# 分离 blocking 和非 blocking handler return HookSpec(
for entry in handler_entries: name=normalized_name,
if entry.blocking: default_timeout_ms=self._get_default_timeout_ms(),
blocking_handlers.append(entry) )
else:
non_blocking_handlers.append(entry)
# 处理 blocking handlers同步调用支持修改参数/提前终止) async def invoke_hook(
timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0 self,
for entry in blocking_handlers: hook_name: str,
hook_args = {"stage": stage, **current_kwargs} supervisors: Optional[Sequence["PluginRunnerSupervisor"]] = None,
try: **kwargs: Any,
# 应用超时控制 ) -> HookDispatchResult:
result = await asyncio.wait_for( """触发一次命名 Hook 调用。
self._invoke_handler(supervisor, entry, hook_args),
timeout=timeout, Args:
hook_name: 本次触发的 Hook 名称。
supervisors: 当前运行时中所有可参与分发的 Supervisor留空时使用绑定的提供器。
**kwargs: 传递给 Hook 处理器的关键字参数。
Returns:
HookDispatchResult: 聚合后的 Hook 调用结果。
"""
resolved_supervisors = list(supervisors) if supervisors is not None else list(self._resolve_supervisors())
normalized_hook_name = self._normalize_hook_name(hook_name)
hook_spec = self.get_hook_spec(normalized_hook_name)
current_kwargs: Dict[str, Any] = dict(kwargs)
dispatch_result = HookDispatchResult(hook_name=normalized_hook_name, kwargs=dict(current_kwargs))
invocation_targets = self._collect_invocation_targets(normalized_hook_name, resolved_supervisors)
if not invocation_targets:
return dispatch_result
for target in invocation_targets:
if target.entry.is_observe:
self._schedule_observe_handler(
hook_name=normalized_hook_name,
hook_spec=hook_spec,
target=target,
kwargs=current_kwargs,
) )
except asyncio.TimeoutError: continue
logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
if result: if not hook_spec.allow_blocking:
if result.modified_kwargs is not None: error_message = (
current_kwargs = result.modified_kwargs f"Hook {normalized_hook_name} 不允许 blocking 处理器,"
if not result.continue_processing: f"已跳过 {target.entry.full_name}"
logger.info(f"HookHandler {entry.full_name} 终止了后续处理") )
break logger.warning(error_message)
dispatch_result.errors.append(error_message)
continue
# 处理 non-blocking handlers异步调用不阻塞主流程 execution_result = await self._invoke_handler(
for entry in non_blocking_handlers: hook_name=normalized_hook_name,
async_kwargs = current_kwargs.copy() hook_spec=hook_spec,
hook_args = {"stage": stage, **async_kwargs} target=target,
task = asyncio.create_task( kwargs=current_kwargs,
asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout) )
self._merge_blocking_result(
hook_spec=hook_spec,
target=target,
execution_result=execution_result,
dispatch_result=dispatch_result,
) )
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return current_kwargs current_kwargs = dict(dispatch_result.kwargs)
if dispatch_result.aborted:
break
return dispatch_result
def _resolve_supervisors(self) -> Sequence["PluginRunnerSupervisor"]:
"""解析当前调用应使用的 Supervisor 列表。
Returns:
Sequence[PluginRunnerSupervisor]: 可参与本次 Hook 调度的 Supervisor 序列。
Raises:
ValueError: 当未传入 `supervisors` 且分发器也未绑定提供器时抛出。
"""
if self._supervisors_provider is None:
raise ValueError("当前 HookDispatcher 未绑定 supervisors_provider请显式传入 supervisors")
return self._supervisors_provider()
def _collect_invocation_targets(
self,
hook_name: str,
supervisors: Sequence["PluginRunnerSupervisor"],
) -> List[_HookInvocationTarget]:
"""收集并排序本次 Hook 调用的全部处理器目标。
Args:
hook_name: 目标 Hook 名称。
supervisors: 当前参与调度的 Supervisor 序列。
Returns:
List[_HookInvocationTarget]: 已完成全局排序的处理器目标列表。
"""
invocation_targets: List[_HookInvocationTarget] = []
for supervisor in supervisors:
source_rank = self._get_supervisor_source_rank(supervisor)
for entry in supervisor.component_registry.get_hook_handlers(hook_name):
invocation_targets.append(
_HookInvocationTarget(
supervisor=supervisor,
entry=entry,
source_rank=source_rank,
)
)
invocation_targets.sort(key=self._build_sort_key)
return invocation_targets
@staticmethod
def _build_sort_key(target: _HookInvocationTarget) -> tuple[int, int, int, str, str]:
"""构造 Hook 处理器的全局排序键。
Args:
target: 待排序的处理器目标。
Returns:
tuple[int, int, int, str, str]: 全局排序键。
"""
return (
HookDispatcher._get_mode_rank(target.entry.mode),
HookDispatcher._get_order_rank(target.entry.order),
target.source_rank,
target.entry.plugin_id,
target.entry.name,
)
@staticmethod
def _get_default_timeout_ms() -> int:
"""读取系统级默认 Hook 超时。
Returns:
int: 默认超时毫秒数。
"""
timeout_seconds = float(global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0)
return max(int(timeout_seconds * 1000), 1)
@staticmethod
def _get_mode_rank(mode: str) -> int:
"""返回 Hook 模式的排序权重。
Args:
mode: Hook 模式。
Returns:
int: 越小表示越靠前。
"""
return {"blocking": 0, "observe": 1}.get(mode, 99)
@staticmethod
def _get_order_rank(order: str) -> int:
"""返回 Hook 顺序槽位的排序权重。
Args:
order: Hook 顺序槽位。
Returns:
int: 越小表示越靠前。
"""
return {"early": 0, "normal": 1, "late": 2}.get(order, 99)
@staticmethod
def _get_supervisor_source_rank(supervisor: "PluginRunnerSupervisor") -> int:
"""返回 Supervisor 的来源排序权重。
Args:
supervisor: 目标 Supervisor。
Returns:
int: 内置插件返回 `0`,第三方插件返回 `1`。
"""
return 0 if supervisor.group_name == "builtin" else 1
@staticmethod
def _normalize_hook_name(hook_name: str) -> str:
"""规范化命名 Hook 名称。
Args:
hook_name: 原始 Hook 名称。
Returns:
str: 规范化后的 Hook 名称。
Raises:
ValueError: 当 Hook 名称为空时抛出。
"""
normalized_name = str(hook_name or "").strip()
if not normalized_name:
raise ValueError("Hook 名称不能为空")
return normalized_name
def _resolve_timeout_ms(self, hook_spec: HookSpec, target: _HookInvocationTarget) -> int:
"""计算单个处理器的实际超时。
Args:
hook_spec: 当前 Hook 的规格定义。
target: 当前执行目标。
Returns:
int: 最终生效的超时毫秒数。
"""
if target.entry.timeout_ms > 0:
return target.entry.timeout_ms
if hook_spec.default_timeout_ms > 0:
return hook_spec.default_timeout_ms
return self._get_default_timeout_ms()
async def _invoke_handler( async def _invoke_handler(
self, self,
supervisor: "PluginRunnerSupervisor", hook_name: str,
handler_entry: "HookHandlerEntry", hook_spec: HookSpec,
args: Dict[str, Any], target: _HookInvocationTarget,
) -> Optional[HookResult]: kwargs: Dict[str, Any],
"""调用单个 handler 并收集结果。 ) -> HookHandlerExecutionResult:
"""执行单个 Hook 处理器。
Args: Args:
supervisor: PluginRunnerSupervisor 实例 hook_name: 当前 Hook 名称。
handler_entry: HookHandlerEntry 实例 hook_spec: 当前 Hook 规格。
args: 传递给 handler 的参数字典 target: 当前执行目标。
stage: hook 名称 kwargs: 当前参数字典。
Returns: Returns:
Optional[HookResult]: 执行结果,如果执行失败则返回 None HookHandlerExecutionResult: 处理器执行结果
""" """
try:
resp_envelope = await supervisor.invoke_plugin(
"plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
)
resp = resp_envelope.payload
result = HookResult(
handler_name=handler_entry.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_kwargs=resp.get("modified_kwargs"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
return result timeout_ms = self._resolve_timeout_ms(hook_spec, target)
request_args: Dict[str, Any] = {"hook_name": hook_name, **dict(kwargs)}
try:
response_envelope = await asyncio.wait_for(
target.supervisor.invoke_plugin(
"plugin.invoke_hook",
target.entry.plugin_id,
target.entry.name,
request_args,
timeout_ms=timeout_ms,
),
timeout=max(timeout_ms / 1000.0, 0.001),
)
except asyncio.TimeoutError:
error_message = (
f"HookHandler {target.entry.full_name} 执行超时,已超过 {timeout_ms}ms"
)
logger.error(error_message)
return HookHandlerExecutionResult(
handler_name=target.entry.full_name,
plugin_id=target.entry.plugin_id,
success=False,
error_message=error_message,
)
except Exception as exc:
error_message = f"HookHandler {target.entry.full_name} 执行失败: {exc}"
logger.error(error_message, exc_info=True)
return HookHandlerExecutionResult(
handler_name=target.entry.full_name,
plugin_id=target.entry.plugin_id,
success=False,
error_message=error_message,
)
response_payload = response_envelope.payload
if not isinstance(response_payload, dict):
return HookHandlerExecutionResult(
handler_name=target.entry.full_name,
plugin_id=target.entry.plugin_id,
custom_result=response_payload,
)
return HookHandlerExecutionResult(
handler_name=target.entry.full_name,
plugin_id=target.entry.plugin_id,
success=bool(response_payload.get("success", True)),
action=self._normalize_action(response_payload.get("action", "continue")),
modified_kwargs=self._extract_modified_kwargs(response_payload.get("modified_kwargs")),
custom_result=response_payload.get("custom_result"),
error_message=str(response_payload.get("error_message", "") or ""),
)
@staticmethod
def _extract_modified_kwargs(raw_value: Any) -> Optional[Dict[str, Any]]:
"""提取并校验处理器返回的 `modified_kwargs`。
Args:
raw_value: 原始返回值。
Returns:
Optional[Dict[str, Any]]: 合法时返回字典,否则返回 `None`。
"""
if raw_value is None:
return None
if isinstance(raw_value, dict):
return dict(raw_value)
logger.warning("HookHandler 返回的 modified_kwargs 不是字典,已忽略")
return None
@staticmethod
def _normalize_action(raw_value: Any) -> str:
"""规范化处理器动作返回值。
Args:
raw_value: 原始动作值。
Returns:
str: 规范化后的动作值,仅支持 `continue` 或 `abort`。
"""
normalized_value = str(raw_value or "").strip().lower() or "continue"
if normalized_value not in {"continue", "abort"}:
logger.warning(f"未知的 Hook action: {raw_value},已按 continue 处理")
return "continue"
return normalized_value
def _merge_blocking_result(
self,
hook_spec: HookSpec,
target: _HookInvocationTarget,
execution_result: HookHandlerExecutionResult,
dispatch_result: HookDispatchResult,
) -> None:
"""合并阻塞处理器结果到聚合结果。
Args:
hook_spec: 当前 Hook 规格。
target: 当前执行目标。
execution_result: 当前处理器执行结果。
dispatch_result: 当前聚合结果对象。
"""
if execution_result.custom_result is not None:
dispatch_result.custom_results.append(execution_result.custom_result)
if not execution_result.success:
error_message = execution_result.error_message or f"HookHandler {target.entry.full_name} 执行失败"
dispatch_result.errors.append(error_message)
self._apply_error_policy(target, hook_spec, dispatch_result, error_message)
return
if execution_result.modified_kwargs is not None:
if hook_spec.allow_kwargs_mutation:
dispatch_result.kwargs = dict(execution_result.modified_kwargs)
else:
error_message = (
f"Hook {dispatch_result.hook_name} 不允许修改 kwargs"
f"已忽略 {target.entry.full_name} 的 modified_kwargs"
)
logger.warning(error_message)
dispatch_result.errors.append(error_message)
if execution_result.action == "abort":
if hook_spec.allow_abort:
dispatch_result.aborted = True
dispatch_result.stopped_by = target.entry.full_name
logger.info(f"HookHandler {target.entry.full_name} 中止了 Hook {dispatch_result.hook_name}")
else:
error_message = (
f"Hook {dispatch_result.hook_name} 不允许 abort"
f"已忽略 {target.entry.full_name} 的 abort 请求"
)
logger.warning(error_message)
dispatch_result.errors.append(error_message)
def _apply_error_policy(
self,
target: _HookInvocationTarget,
hook_spec: HookSpec,
dispatch_result: HookDispatchResult,
error_message: str,
) -> None:
"""根据错误策略处理阻塞处理器失败。
Args:
target: 触发错误的处理器目标。
hook_spec: 当前 Hook 规格。
dispatch_result: 当前聚合结果对象。
error_message: 需要记录的错误描述。
"""
if target.entry.error_policy != "abort":
return
if not hook_spec.allow_abort:
logger.warning(
"Hook %s 禁止 abort已将 %s 的错误策略按 skip 处理",
dispatch_result.hook_name,
target.entry.full_name,
)
return
dispatch_result.aborted = True
dispatch_result.stopped_by = target.entry.full_name
logger.warning(
"HookHandler %s 因错误策略 abort 中止了 Hook %s: %s",
target.entry.full_name,
dispatch_result.hook_name,
error_message,
)
def _schedule_observe_handler(
self,
hook_name: str,
hook_spec: HookSpec,
target: _HookInvocationTarget,
kwargs: Dict[str, Any],
) -> None:
"""后台调度观察型处理器。
Args:
hook_name: 当前 Hook 名称。
hook_spec: 当前 Hook 规格。
target: 当前观察型处理器目标。
kwargs: 调用参数快照。
"""
if not hook_spec.allow_observe:
logger.warning("Hook %s 不允许 observe 处理器,已跳过 %s", hook_name, target.entry.full_name)
return
task = asyncio.create_task(
self._run_observe_handler(
hook_name=hook_name,
hook_spec=hook_spec,
target=target,
kwargs=dict(kwargs),
)
)
self._background_tasks.add(task)
task.add_done_callback(self._handle_background_task_done)
async def _run_observe_handler(
self,
hook_name: str,
hook_spec: HookSpec,
target: _HookInvocationTarget,
kwargs: Dict[str, Any],
) -> None:
"""执行观察型处理器并吞掉控制流副作用。
Args:
hook_name: 当前 Hook 名称。
hook_spec: 当前 Hook 规格。
target: 当前观察型处理器目标。
kwargs: 调用参数快照。
"""
execution_result = await self._invoke_handler(
hook_name=hook_name,
hook_spec=hook_spec,
target=target,
kwargs=kwargs,
)
if not execution_result.success:
logger.warning(
"观察型 HookHandler %s 执行失败: %s",
target.entry.full_name,
execution_result.error_message or "未知错误",
)
return
if execution_result.modified_kwargs is not None:
logger.warning(
"观察型 HookHandler %s 返回了 modified_kwargs已忽略", target.entry.full_name
)
if execution_result.action == "abort":
logger.warning(
"观察型 HookHandler %s 请求 abort已忽略", target.entry.full_name
)
def _handle_background_task_done(self, task: asyncio.Task[Any]) -> None:
"""处理观察任务完成回调。
Args:
task: 已完成的后台任务。
"""
self._background_tasks.discard(task)
with contextlib.suppress(asyncio.CancelledError):
exception = task.exception()
if exception is not None:
logger.error(f"观察型 Hook 后台任务执行失败: {exception}")

View File

@@ -49,7 +49,7 @@ from .api_registry import APIRegistry
from .capability_service import CapabilityService from .capability_service import CapabilityService
from .component_registry import ComponentRegistry from .component_registry import ComponentRegistry
from .event_dispatcher import EventDispatcher from .event_dispatcher import EventDispatcher
from .hook_dispatcher import HookDispatcher from .hook_dispatcher import HookDispatchResult, HookDispatcher
from .logger_bridge import RunnerLogBridge from .logger_bridge import RunnerLogBridge
from .message_gateway import MessageGateway from .message_gateway import MessageGateway
from .rpc_server import RPCServer from .rpc_server import RPCServer
@@ -80,6 +80,7 @@ class PluginRunnerSupervisor:
def __init__( def __init__(
self, self,
plugin_dirs: Optional[List[Path]] = None, plugin_dirs: Optional[List[Path]] = None,
group_name: str = "third_party",
socket_path: Optional[str] = None, socket_path: Optional[str] = None,
health_check_interval_sec: Optional[float] = None, health_check_interval_sec: Optional[float] = None,
max_restart_attempts: Optional[int] = None, max_restart_attempts: Optional[int] = None,
@@ -89,12 +90,14 @@ class PluginRunnerSupervisor:
Args: Args:
plugin_dirs: 由当前 Runner 负责加载的插件目录列表。 plugin_dirs: 由当前 Runner 负责加载的插件目录列表。
group_name: 当前 Supervisor 所属运行时分组名称。
socket_path: 自定义 IPC 地址;留空时由传输层自动生成。 socket_path: 自定义 IPC 地址;留空时由传输层自动生成。
health_check_interval_sec: 健康检查间隔,单位秒。 health_check_interval_sec: 健康检查间隔,单位秒。
max_restart_attempts: 自动重启 Runner 的最大次数。 max_restart_attempts: 自动重启 Runner 的最大次数。
runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。 runner_spawn_timeout_sec: 等待 Runner 建连并就绪的超时时间,单位秒。
""" """
runtime_config = global_config.plugin_runtime runtime_config = global_config.plugin_runtime
self._group_name: str = str(group_name or "third_party").strip() or "third_party"
self._plugin_dirs: List[Path] = plugin_dirs or [] self._plugin_dirs: List[Path] = plugin_dirs or []
self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0 self._health_interval: float = health_check_interval_sec or runtime_config.health_check_interval_sec or 30.0
self._runner_spawn_timeout: float = ( self._runner_spawn_timeout: float = (
@@ -108,7 +111,7 @@ class PluginRunnerSupervisor:
self._api_registry = APIRegistry() self._api_registry = APIRegistry()
self._component_registry = ComponentRegistry() self._component_registry = ComponentRegistry()
self._event_dispatcher = EventDispatcher(self._component_registry) self._event_dispatcher = EventDispatcher(self._component_registry)
self._hook_dispatcher = HookDispatcher(self._component_registry) self._hook_dispatcher = HookDispatcher(lambda: [self])
self._message_gateway = MessageGateway(self._component_registry) self._message_gateway = MessageGateway(self._component_registry)
self._log_bridge = RunnerLogBridge() self._log_bridge = RunnerLogBridge()
@@ -133,6 +136,12 @@ class PluginRunnerSupervisor:
"""返回授权管理器。""" """返回授权管理器。"""
return self._authorization return self._authorization
@property
def group_name(self) -> str:
"""返回当前 Supervisor 的运行时分组名称。"""
return self._group_name
@property @property
def capability_service(self) -> CapabilityService: def capability_service(self) -> CapabilityService:
"""返回能力服务。""" """返回能力服务。"""
@@ -243,17 +252,18 @@ class PluginRunnerSupervisor:
""" """
return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args) return await self._event_dispatcher.dispatch_event(event_type, self, message, extra_args)
async def dispatch_hook(self, stage: str, **kwargs: Any) -> Dict[str, Any]: async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult:
"""分发 Hook 到已注册的 Hook 处理器 """在当前 Supervisor 内触发一次命名 Hook 调用
Args: Args:
stage: Hook 阶段名称。 hook_name: 本次触发的 Hook 名称。
**kwargs: 传递给 Hook 的关键字参数。 **kwargs: 传递给 Hook 处理器的关键字参数。
Returns: Returns:
Dict[str, Any]: 经 Hook 修改后的参数字典 HookDispatchResult: 聚合后的 Hook 调用结果
""" """
return await self._hook_dispatcher.hook_dispatch(stage, self, **kwargs)
return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs)
async def send_message_to_external( async def send_message_to_external(
self, self,

View File

@@ -3,8 +3,9 @@
提供 PluginRuntimeManager 单例,负责: 提供 PluginRuntimeManager 单例,负责:
1. 管理双 PluginSupervisor 的生命周期(内置插件 / 第三方插件各一个子进程) 1. 管理双 PluginSupervisor 的生命周期(内置插件 / 第三方插件各一个子进程)
2. 将 EventType 桥接到运行时的 event dispatch 2. 将 EventType 桥接到运行时的 event dispatch
3. 在运行时的 ComponentRegistry 中查找命令 3. 触发跨 Supervisor 的命名 Hook 调用
4. 提供统一的能力实现注册接口,使插件可以调用主程序功能 4. 在运行时的 ComponentRegistry 中查找命令
5. 提供统一的能力实现注册接口,使插件可以调用主程序功能
""" """
from pathlib import Path from pathlib import Path
@@ -24,6 +25,7 @@ from src.plugin_runtime.capabilities import (
RuntimeDataCapabilityMixin, RuntimeDataCapabilityMixin,
) )
from src.plugin_runtime.capabilities.registry import register_capability_impls from src.plugin_runtime.capabilities.registry import register_capability_impls
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult, HookDispatcher, HookSpec
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
from src.plugin_runtime.runner.manifest_validator import ManifestValidator from src.plugin_runtime.runner.manifest_validator import ManifestValidator
@@ -72,6 +74,7 @@ class PluginRuntimeManager(
self._manifest_validator: ManifestValidator = ManifestValidator() self._manifest_validator: ManifestValidator = ManifestValidator()
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
self._config_reload_callback_registered: bool = False self._config_reload_callback_registered: bool = False
self._hook_dispatcher: HookDispatcher = HookDispatcher(lambda: self.supervisors)
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None: async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
"""接收 Platform IO 审核后的入站消息并送入主消息链。 """接收 Platform IO 审核后的入站消息并送入主消息链。
@@ -182,6 +185,7 @@ class PluginRuntimeManager(
if builtin_dirs: if builtin_dirs:
self._builtin_supervisor = PluginSupervisor( self._builtin_supervisor = PluginSupervisor(
plugin_dirs=builtin_dirs, plugin_dirs=builtin_dirs,
group_name="builtin",
socket_path=builtin_socket, socket_path=builtin_socket,
) )
self._register_capability_impls(self._builtin_supervisor) self._register_capability_impls(self._builtin_supervisor)
@@ -189,6 +193,7 @@ class PluginRuntimeManager(
if third_party_dirs: if third_party_dirs:
self._third_party_supervisor = PluginSupervisor( self._third_party_supervisor = PluginSupervisor(
plugin_dirs=third_party_dirs, plugin_dirs=third_party_dirs,
group_name="third_party",
socket_path=third_party_socket, socket_path=third_party_socket,
) )
self._register_capability_impls(self._third_party_supervisor) self._register_capability_impls(self._third_party_supervisor)
@@ -235,6 +240,7 @@ class PluginRuntimeManager(
await platform_io_manager.stop() await platform_io_manager.stop()
except Exception as platform_io_exc: except Exception as platform_io_exc:
logger.warning(f"Platform IO 停止失败: {platform_io_exc}") logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
await self._hook_dispatcher.stop()
self._started = False self._started = False
self._builtin_supervisor = None self._builtin_supervisor = None
self._third_party_supervisor = None self._third_party_supervisor = None
@@ -274,6 +280,7 @@ class PluginRuntimeManager(
else: else:
logger.info("插件运行时已停止") logger.info("插件运行时已停止")
finally: finally:
await self._hook_dispatcher.stop()
self._started = False self._started = False
self._builtin_supervisor = None self._builtin_supervisor = None
self._third_party_supervisor = None self._third_party_supervisor = None
@@ -284,11 +291,41 @@ class PluginRuntimeManager(
"""返回插件运行时是否处于启动状态。""" """返回插件运行时是否处于启动状态。"""
return self._started return self._started
@property
def hook_dispatcher(self) -> HookDispatcher:
"""返回跨 Supervisor 的命名 Hook 分发器。"""
return self._hook_dispatcher
@property
def invoke_dispatcher(self) -> HookDispatcher:
"""返回命名 Hook 分发器的兼容别名。"""
return self._hook_dispatcher
@property @property
def supervisors(self) -> List["PluginSupervisor"]: def supervisors(self) -> List["PluginSupervisor"]:
"""获取所有活跃的 Supervisor""" """获取所有活跃的 Supervisor"""
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None] return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
def register_hook_spec(self, spec: HookSpec) -> None:
"""注册单个命名 Hook 规格。
Args:
spec: 需要注册的 Hook 规格。
"""
self._hook_dispatcher.register_hook_spec(spec)
def register_hook_specs(self, specs: Sequence[HookSpec]) -> None:
"""批量注册命名 Hook 规格。
Args:
specs: 需要注册的 Hook 规格序列。
"""
self._hook_dispatcher.register_hook_specs(specs)
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]: def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
"""根据当前已注册插件构建全局依赖图。""" """根据当前已注册插件构建全局依赖图。"""
@@ -588,6 +625,19 @@ class PluginRuntimeManager(
return True, modified return True, modified
async def invoke_hook(self, hook_name: str, **kwargs: Any) -> HookDispatchResult:
"""触发一次跨 Supervisor 的命名 Hook 调用。
Args:
hook_name: 本次触发的 Hook 名称。
**kwargs: 传递给 Hook 处理器的关键字参数。
Returns:
HookDispatchResult: 聚合后的 Hook 调用结果。
"""
return await self._hook_dispatcher.invoke_hook(hook_name, **kwargs)
# ─── 命令查找 ────────────────────────────────────────────── # ─── 命令查找 ──────────────────────────────────────────────
def find_command_by_text(self, text: str) -> Optional[Dict[str, Any]]: def find_command_by_text(self, text: str) -> Optional[Dict[str, Any]]:

View File

@@ -330,7 +330,6 @@ class PluginRunner:
self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
self._rpc_client.register_method("plugin.invoke_workflow_step", self._handle_workflow_step)
self._rpc_client.register_method("plugin.health", self._handle_health) self._rpc_client.register_method("plugin.health", self._handle_health)
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown) self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
@@ -1053,73 +1052,28 @@ class PluginRunner:
) )
except Exception as exc: except Exception as exc:
logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True) logger.error(f"插件 {plugin_id} hook_handler {component_name} 执行异常: {exc}", exc_info=True)
return envelope.make_response(payload={"success": False, "continue_processing": True}) return envelope.make_response(
payload={
"success": False,
"action": "continue",
"error_message": str(exc),
}
)
if raw is None: if raw is None:
result = {"success": True, "continue_processing": True} result = {"success": True, "action": "continue"}
elif isinstance(raw, dict): elif isinstance(raw, dict):
result = { result = {
"success": True, "success": True,
"continue_processing": raw.get("continue_processing", True), "action": str(raw.get("action", "continue") or "continue").strip().lower() or "continue",
"modified_kwargs": raw.get("modified_kwargs"), "modified_kwargs": raw.get("modified_kwargs"),
"custom_result": raw.get("custom_result"), "custom_result": raw.get("custom_result"),
} }
else: else:
result = {"success": True, "continue_processing": True, "custom_result": raw} result = {"success": True, "action": "continue", "custom_result": raw}
return envelope.make_response(payload=result) return envelope.make_response(payload=result)
async def _handle_workflow_step(self, envelope: Envelope) -> Envelope:
"""处理 WorkflowStep 调用请求
与通用 invoke 不同,会将返回值规范化为
{hook_result, modified_message, stage_output} 格式。
"""
try:
invoke = InvokePayload.model_validate(envelope.payload)
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id)
if meta is None:
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
f"插件 {plugin_id} 未加载",
)
component_name = invoke.component_name
handler_method = self._resolve_component_handler(meta, component_name)
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"插件 {plugin_id} 无组件: {component_name}",
)
try:
raw = (
await handler_method(**invoke.args)
if inspect.iscoroutinefunction(handler_method)
else handler_method(**invoke.args)
)
# 规范化返回值
if isinstance(raw, str):
result = {"hook_result": raw}
elif isinstance(raw, dict):
result = raw
result.setdefault("hook_result", "continue")
else:
result = {"hook_result": "continue"}
resp_payload = InvokeResultPayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except Exception as e:
logger.error(f"插件 {plugin_id} workflow_step {component_name} 执行异常: {e}", exc_info=True)
resp_payload = InvokeResultPayload(success=False, result=str(e))
return envelope.make_response(payload=resp_payload.model_dump())
async def _handle_health(self, envelope: Envelope) -> Envelope: async def _handle_health(self, envelope: Envelope) -> Envelope:
"""处理健康检查""" """处理健康检查"""
uptime_ms = int((time.monotonic() - self._start_time) * 1000) uptime_ms = int((time.monotonic() - self._start_time) * 1000)