feat: Enhance Hook System with HookHandler and Dispatcher

- Introduced HookHandlerEntry to manage hook processing with attributes like hook name, mode, order, timeout, and error policy.
- Implemented normalization methods for hook attributes to ensure valid configurations.
- Updated ComponentRegistry to support retrieval of hook handlers based on hook names, with sorting by mode and order.
- Refactored HookDispatcher to handle invocation of hooks, separating blocking and non-blocking handlers, and managing execution results.
- Added support for registering hook specifications and invoking hooks across supervisors in PluginRuntimeManager.
- Removed deprecated workflow step handling from PluginRunner, streamlining hook invocation responses.
This commit is contained in:
DrSmoothl
2026-03-24 19:04:05 +08:00
parent 865e4916e3
commit 0b0f47a444
6 changed files with 1247 additions and 523 deletions

View File

@@ -1,7 +1,7 @@
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
- 按类型注册组件action / command / tool / event_handler / workflow_handler / message_gateway
- 按类型注册组件action / command / tool / event_handler / hook_handler / message_gateway
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
@@ -106,14 +106,129 @@ class EventHandlerEntry(ComponentEntry):
class HookHandlerEntry(ComponentEntry):
"""WorkflowHandler 组件条目"""
"""HookHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.stage: str = metadata.get("stage", "")
self.priority: int = metadata.get("priority", 0)
self.blocking: bool = metadata.get("blocking", False)
self.hook: str = self._normalize_hook_name(metadata.get("hook", ""))
self.mode: str = self._normalize_mode(metadata.get("mode", "blocking"))
self.order: str = self._normalize_order(metadata.get("order", "normal"))
self.timeout_ms: int = self._normalize_timeout_ms(metadata.get("timeout_ms", 0))
self.error_policy: str = self._normalize_error_policy(metadata.get("error_policy", "skip"))
super().__init__(name, component_type, plugin_id, metadata)
@staticmethod
def _normalize_error_policy(raw_value: Any) -> str:
"""规范化 Hook 异常处理策略。
Args:
raw_value: 原始异常处理策略值。
Returns:
str: 规范化后的异常处理策略。
Raises:
ValueError: 当异常处理策略不受支持时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
normalized_value = str(normalized_source or "").strip().lower() or "skip"
if normalized_value not in {"abort", "skip", "log"}:
raise ValueError(f"HookHandler 异常处理策略不合法: {raw_value}")
return normalized_value
@staticmethod
def _normalize_hook_name(raw_value: Any) -> str:
"""规范化命名 Hook 名称。
Args:
raw_value: 原始 Hook 名称。
Returns:
str: 去空白后的 Hook 名称。
Raises:
ValueError: 当 Hook 名称为空时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
if not (normalized_value := str(normalized_source or "").strip()):
raise ValueError("HookHandler 的 hook 名称不能为空")
return normalized_value
@staticmethod
def _normalize_mode(raw_value: Any) -> str:
"""规范化 Hook 处理模式。
Args:
raw_value: 原始模式值。
Returns:
str: 规范化后的模式。
Raises:
ValueError: 当模式不受支持时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
normalized_value = str(normalized_source or "").strip().lower() or "blocking"
if normalized_value not in {"blocking", "observe"}:
raise ValueError(f"HookHandler 模式不合法: {raw_value}")
return normalized_value
@staticmethod
def _normalize_order(raw_value: Any) -> str:
"""规范化 Hook 顺序槽位。
Args:
raw_value: 原始顺序值。
Returns:
str: 规范化后的顺序槽位。
Raises:
ValueError: 当顺序值不受支持时抛出。
"""
normalized_source = getattr(raw_value, "value", raw_value)
normalized_value = str(normalized_source or "").strip().lower() or "normal"
if normalized_value not in {"early", "normal", "late"}:
raise ValueError(f"HookHandler 顺序槽位不合法: {raw_value}")
return normalized_value
@staticmethod
def _normalize_timeout_ms(raw_value: Any) -> int:
"""规范化 Hook 超时配置。
Args:
raw_value: 原始超时值。
Returns:
int: 规范化后的超时毫秒数。
Raises:
ValueError: 当超时值为负数或无法转换为整数时抛出。
"""
try:
timeout_ms = int(raw_value or 0)
except (TypeError, ValueError) as exc:
raise ValueError(f"HookHandler 超时配置不合法: {raw_value}") from exc
if timeout_ms < 0:
raise ValueError(f"HookHandler 超时配置不能为负数: {raw_value}")
return timeout_ms
@property
def is_blocking(self) -> bool:
"""返回当前 Hook 是否为阻塞模式。"""
return self.mode == "blocking"
@property
def is_observe(self) -> bool:
"""返回当前 Hook 是否为观察模式。"""
return self.mode == "observe"
class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目"""
@@ -454,16 +569,17 @@ class ComponentRegistry:
return handlers
def get_hook_handlers(
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
self, hook_name: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[HookHandlerEntry]:
"""获取特定 hook 阶段的所有步骤,按 priority 降序
"""获取订阅指定命名 Hook 的全部处理器
Args:
stage: hook 名称
enabled_only: 是否仅返回启用的组件
session_id: 可选的会话ID若提供则考虑会话禁用状态
hook_name: 目标 Hook 名称
enabled_only: 是否仅返回启用的组件
session_id: 可选的会话 ID若提供则考虑会话禁用状态
Returns:
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
List[HookHandlerEntry]: 符合条件的 HookHandler 组件列表
"""
handlers: List[HookHandlerEntry] = []
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
@@ -471,11 +587,37 @@ class ComponentRegistry:
continue
if not isinstance(comp, HookHandlerEntry):
continue
if comp.stage == stage:
if comp.hook == hook_name:
handlers.append(comp)
handlers.sort(key=lambda c: c.priority, reverse=True)
handlers.sort(key=lambda comp: (self._get_hook_mode_rank(comp.mode), self._get_hook_order_rank(comp.order), comp.plugin_id, comp.name))
return handlers
@staticmethod
def _get_hook_mode_rank(mode: str) -> int:
"""返回 Hook 模式的排序权重。
Args:
mode: Hook 模式字符串。
Returns:
int: 越小表示越靠前。
"""
return {"blocking": 0, "observe": 1}.get(mode, 99)
@staticmethod
def _get_hook_order_rank(order: str) -> int:
"""返回 Hook 顺序槽位的排序权重。
Args:
order: Hook 顺序槽位字符串。
Returns:
int: 越小表示越靠前。
"""
return {"early": 0, "normal": 1, "late": 2}.get(order, 99)
def get_message_gateway(
self,
plugin_id: str,
@@ -566,8 +708,13 @@ class ComponentRegistry:
Returns:
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
"""
stats: StatusDict = {"total": len(self._components)} # type: ignore
for comp_type, type_dict in self._by_type.items():
stats[comp_type.value.lower()] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats
return StatusDict(
total=len(self._components),
action=len(self._by_type[ComponentTypes.ACTION]),
command=len(self._by_type[ComponentTypes.COMMAND]),
tool=len(self._by_type[ComponentTypes.TOOL]),
event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]),
hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]),
message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]),
plugins=len(self._by_plugin),
)