refactor: hook_dispatcher相关的修改
This commit is contained in:
committed by
DrSmoothl
parent
593400c0aa
commit
310d7798ba
@@ -58,7 +58,7 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
|
|||||||
"plugin_runtime.host.component_registry": ("#ffaf00", None, False),
|
"plugin_runtime.host.component_registry": ("#ffaf00", None, False),
|
||||||
"plugin_runtime.host.capability_service": ("#ffd700", None, False),
|
"plugin_runtime.host.capability_service": ("#ffd700", None, False),
|
||||||
"plugin_runtime.host.event_dispatcher": ("#87d700", None, False),
|
"plugin_runtime.host.event_dispatcher": ("#87d700", None, False),
|
||||||
"plugin_runtime.host.workflow_executor": ("#5fd7af", None, False),
|
"plugin_runtime.host.hook_dispatcher": ("#5fd7af", None, False),
|
||||||
"plugin_runtime.runner.main": ("#d787ff", None, False),
|
"plugin_runtime.runner.main": ("#d787ff", None, False),
|
||||||
"plugin_runtime.runner.rpc_client": ("#8787ff", None, False),
|
"plugin_runtime.runner.rpc_client": ("#8787ff", None, False),
|
||||||
"plugin_runtime.runner.manifest_validator": ("#5fafff", None, False),
|
"plugin_runtime.runner.manifest_validator": ("#5fafff", None, False),
|
||||||
|
|||||||
@@ -1642,14 +1642,14 @@ class PluginRuntimeConfig(ConfigBase):
|
|||||||
)
|
)
|
||||||
"""等待 Runner 子进程启动并注册的超时时间(秒)"""
|
"""等待 Runner 子进程启动并注册的超时时间(秒)"""
|
||||||
|
|
||||||
workflow_blocking_timeout_sec: float = Field(
|
hook_blocking_timeout_sec: float = Field(
|
||||||
default=120.0,
|
default=30,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"x-widget": "number",
|
"x-widget": "number",
|
||||||
"x-icon": "timer",
|
"x-icon": "timer",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
"""Workflow 阻塞步骤的全局超时上限(秒)"""
|
"""Hook 阻塞步骤的全局超时上限(秒)"""
|
||||||
|
|
||||||
ipc_socket_path: str = Field(
|
ipc_socket_path: str = Field(
|
||||||
default="",
|
default="",
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class ComponentTypes(str, Enum):
|
|||||||
COMMAND = "COMMAND"
|
COMMAND = "COMMAND"
|
||||||
TOOL = "TOOL"
|
TOOL = "TOOL"
|
||||||
EVENT_HANDLER = "EVENT_HANDLER"
|
EVENT_HANDLER = "EVENT_HANDLER"
|
||||||
WORKFLOW_HANDLER = "WORKFLOW_HANDLER"
|
HOOK_HANDLER = "HOOK_HANDLER"
|
||||||
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
|
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
|
||||||
|
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ class StatusDict(TypedDict):
|
|||||||
COMMAND: int
|
COMMAND: int
|
||||||
TOOL: int
|
TOOL: int
|
||||||
EVENT_HANDLER: int
|
EVENT_HANDLER: int
|
||||||
WORKFLOW_HANDLER: int
|
HOOK_HANDLER: int
|
||||||
MESSAGE_GATEWAY: int
|
MESSAGE_GATEWAY: int
|
||||||
plugins: int
|
plugins: int
|
||||||
|
|
||||||
@@ -105,12 +105,13 @@ class EventHandlerEntry(ComponentEntry):
|
|||||||
super().__init__(name, component_type, plugin_id, metadata)
|
super().__init__(name, component_type, plugin_id, metadata)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowHandlerEntry(ComponentEntry):
|
class HookHandlerEntry(ComponentEntry):
|
||||||
"""WorkflowHandler 组件条目"""
|
"""WorkflowHandler 组件条目"""
|
||||||
|
|
||||||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||||
self.stage: str = metadata.get("stage", "")
|
self.stage: str = metadata.get("stage", "")
|
||||||
self.priority: int = metadata.get("priority", 0)
|
self.priority: int = metadata.get("priority", 0)
|
||||||
|
self.blocking: bool = metadata.get("blocking", False)
|
||||||
super().__init__(name, component_type, plugin_id, metadata)
|
super().__init__(name, component_type, plugin_id, metadata)
|
||||||
|
|
||||||
|
|
||||||
@@ -172,8 +173,8 @@ class ComponentRegistry:
|
|||||||
comp = ToolEntry(name, component_type, plugin_id, metadata)
|
comp = ToolEntry(name, component_type, plugin_id, metadata)
|
||||||
elif component_type == ComponentTypes.EVENT_HANDLER:
|
elif component_type == ComponentTypes.EVENT_HANDLER:
|
||||||
comp = EventHandlerEntry(name, component_type, plugin_id, metadata)
|
comp = EventHandlerEntry(name, component_type, plugin_id, metadata)
|
||||||
elif component_type == ComponentTypes.WORKFLOW_HANDLER:
|
elif component_type == ComponentTypes.HOOK_HANDLER:
|
||||||
comp = WorkflowHandlerEntry(name, component_type, plugin_id, metadata)
|
comp = HookHandlerEntry(name, component_type, plugin_id, metadata)
|
||||||
elif component_type == ComponentTypes.MESSAGE_GATEWAY:
|
elif component_type == ComponentTypes.MESSAGE_GATEWAY:
|
||||||
comp = MessageGatewayEntry(name, component_type, plugin_id, metadata)
|
comp = MessageGatewayEntry(name, component_type, plugin_id, metadata)
|
||||||
else:
|
else:
|
||||||
@@ -380,23 +381,23 @@ class ComponentRegistry:
|
|||||||
handlers.sort(key=lambda c: c.weight, reverse=True)
|
handlers.sort(key=lambda c: c.weight, reverse=True)
|
||||||
return handlers
|
return handlers
|
||||||
|
|
||||||
def get_workflow_handlers(
|
def get_hook_handlers(
|
||||||
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||||||
) -> List[WorkflowHandlerEntry]:
|
) -> List[HookHandlerEntry]:
|
||||||
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。
|
"""获取特定 hook 阶段的所有步骤,按 priority 降序。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stage: workflow 阶段名称
|
stage: hook 名称
|
||||||
enabled_only: 是否仅返回启用的组件
|
enabled_only: 是否仅返回启用的组件
|
||||||
session_id: 可选的会话ID,若提供则考虑会话禁用状态
|
session_id: 可选的会话ID,若提供则考虑会话禁用状态
|
||||||
Returns:
|
Returns:
|
||||||
handlers (List[WorkflowHandlerEntry]): 符合条件的 WorkflowHandler 组件列表,按 priority 降序排序
|
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
|
||||||
"""
|
"""
|
||||||
handlers: List[WorkflowHandlerEntry] = []
|
handlers: List[HookHandlerEntry] = []
|
||||||
for comp in self._by_type.get(ComponentTypes.WORKFLOW_HANDLER, {}).values():
|
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
|
||||||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||||||
continue
|
continue
|
||||||
if not isinstance(comp, WorkflowHandlerEntry):
|
if not isinstance(comp, HookHandlerEntry):
|
||||||
continue
|
continue
|
||||||
if comp.stage == stage:
|
if comp.stage == stage:
|
||||||
handlers.append(comp)
|
handlers.append(comp)
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ class EventDispatcher:
|
|||||||
self._component_registry: "ComponentRegistry" = component_registry
|
self._component_registry: "ComponentRegistry" = component_registry
|
||||||
self._result_history: Dict[str, List[EventResult]] = {}
|
self._result_history: Dict[str, List[EventResult]] = {}
|
||||||
self._history_enabled: Set[str] = set()
|
self._history_enabled: Set[str] = set()
|
||||||
# 保持 fire-and-forget task 的强引用,防止被 GC 回收
|
|
||||||
self._background_tasks: Set[asyncio.Task] = set()
|
self._background_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
def enable_history(self, event_type: str) -> None:
|
def enable_history(self, event_type: str) -> None:
|
||||||
@@ -68,6 +67,13 @@ class EventDispatcher:
|
|||||||
if event_type in self._result_history:
|
if event_type in self._result_history:
|
||||||
self._result_history[event_type] = []
|
self._result_history[event_type] = []
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止 EventDispatcher,取消所有未完成的后台任务"""
|
||||||
|
for task in self._background_tasks:
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||||
|
self._background_tasks.clear()
|
||||||
|
|
||||||
async def dispatch_event(
|
async def dispatch_event(
|
||||||
self,
|
self,
|
||||||
event_type: str,
|
event_type: str,
|
||||||
|
|||||||
166
src/plugin_runtime/host/hook_dispatcher.py
Normal file
166
src/plugin_runtime/host/hook_dispatcher.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
Hook Dispatch 系统
|
||||||
|
|
||||||
|
插件可以注册自己的Hook,当特定函数被调用时,Hook Dispatch系统会将调用转发给插件的Hook处理函数。
|
||||||
|
每个Hook的参数随Hook点位确定,因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。
|
||||||
|
在参数/返回值匹配的情况下允许修改参数/返回值。
|
||||||
|
|
||||||
|
HookDispatcher 负责:
|
||||||
|
1. 按 stage 查询已注册的 hook_handler(通过 ComponentRegistry)
|
||||||
|
2. 按 priority 排序,区分 blocking 和非 blocking 模式
|
||||||
|
3. blocking 模式:依次同步调用,支持修改参数/提前终止
|
||||||
|
4. 非 blocking 模式:异步调用,不阻塞主流程
|
||||||
|
5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .supervisor import PluginRunnerSupervisor
|
||||||
|
from .component_registry import ComponentRegistry, HookHandlerEntry
|
||||||
|
|
||||||
|
logger = get_logger("plugin_runtime.host.hook_dispatcher")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HookResult:
|
||||||
|
"""单个 HookHandler 的执行结果"""
|
||||||
|
|
||||||
|
handler_name: str
|
||||||
|
success: bool = field(default=True)
|
||||||
|
continue_processing: bool = field(default=True)
|
||||||
|
modified_kwargs: Optional[Dict[str, Any]] = field(default=None)
|
||||||
|
custom_result: Any = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class HookDispatcher:
|
||||||
|
"""Host-side Hook 分发器
|
||||||
|
|
||||||
|
由业务层调用 hook_dispatch(),
|
||||||
|
内部通过 ComponentRegistry 查询 handler,
|
||||||
|
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, component_registry: "ComponentRegistry") -> None:
|
||||||
|
"""初始化 HookDispatcher
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler
|
||||||
|
"""
|
||||||
|
self._component_registry: "ComponentRegistry" = component_registry
|
||||||
|
self._background_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""停止 HookDispatcher,取消所有未完成的后台任务"""
|
||||||
|
for task in self._background_tasks:
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||||
|
self._background_tasks.clear()
|
||||||
|
|
||||||
|
async def hook_dispatch(
|
||||||
|
self,
|
||||||
|
stage: str,
|
||||||
|
supervisor: "PluginRunnerSupervisor",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""分发 hook 到所有对应 handler 的便捷方法。
|
||||||
|
|
||||||
|
内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
|
||||||
|
无需调用方手动构造 invoke_fn 闭包。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stage: hook 名称
|
||||||
|
supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin
|
||||||
|
**kwargs: 关键字参数,会展开传递给 handler
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数
|
||||||
|
"""
|
||||||
|
handler_entries = self._component_registry.get_hook_handlers(stage)
|
||||||
|
if not handler_entries:
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
current_kwargs = kwargs.copy()
|
||||||
|
blocking_handlers: List["HookHandlerEntry"] = []
|
||||||
|
non_blocking_handlers: List["HookHandlerEntry"] = []
|
||||||
|
|
||||||
|
# 分离 blocking 和非 blocking handler
|
||||||
|
for entry in handler_entries:
|
||||||
|
if entry.blocking:
|
||||||
|
blocking_handlers.append(entry)
|
||||||
|
else:
|
||||||
|
non_blocking_handlers.append(entry)
|
||||||
|
|
||||||
|
# 处理 blocking handlers(同步调用,支持修改参数/提前终止)
|
||||||
|
timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0
|
||||||
|
for entry in blocking_handlers:
|
||||||
|
hook_args = {"stage": stage, **current_kwargs}
|
||||||
|
try:
|
||||||
|
# 应用超时控制
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self._invoke_handler(supervisor, entry, hook_args),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
|
||||||
|
result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
if result.modified_kwargs is not None:
|
||||||
|
current_kwargs = result.modified_kwargs
|
||||||
|
if not result.continue_processing:
|
||||||
|
logger.info(f"HookHandler {entry.full_name} 终止了后续处理")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 处理 non-blocking handlers(异步调用,不阻塞主流程)
|
||||||
|
for entry in non_blocking_handlers:
|
||||||
|
async_kwargs = current_kwargs.copy()
|
||||||
|
hook_args = {"stage": stage, **async_kwargs}
|
||||||
|
task = asyncio.create_task(
|
||||||
|
asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout)
|
||||||
|
)
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
|
return current_kwargs
|
||||||
|
|
||||||
|
async def _invoke_handler(
|
||||||
|
self,
|
||||||
|
supervisor: "PluginRunnerSupervisor",
|
||||||
|
handler_entry: "HookHandlerEntry",
|
||||||
|
args: Dict[str, Any],
|
||||||
|
) -> Optional[HookResult]:
|
||||||
|
"""调用单个 handler 并收集结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
supervisor: PluginRunnerSupervisor 实例
|
||||||
|
handler_entry: HookHandlerEntry 实例
|
||||||
|
args: 传递给 handler 的参数字典
|
||||||
|
stage: hook 名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[HookResult]: 执行结果,如果执行失败则返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
resp_envelope = await supervisor.invoke_plugin(
|
||||||
|
"plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
|
||||||
|
)
|
||||||
|
resp = resp_envelope.payload
|
||||||
|
result = HookResult(
|
||||||
|
handler_name=handler_entry.full_name,
|
||||||
|
success=resp.get("success", True),
|
||||||
|
continue_processing=resp.get("continue_processing", True),
|
||||||
|
modified_kwargs=resp.get("modified_kwargs"),
|
||||||
|
custom_result=resp.get("custom_result"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
|
||||||
|
result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
|
||||||
|
|
||||||
|
return result
|
||||||
@@ -1,422 +0,0 @@
|
|||||||
"""Host-side WorkflowExecutor
|
|
||||||
|
|
||||||
6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS)
|
|
||||||
|
|
||||||
每个阶段执行顺序:
|
|
||||||
1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
|
|
||||||
2. 按 priority 降序排列
|
|
||||||
3. 串行执行 blocking hook(可修改 message,返回 HookResult)
|
|
||||||
4. 并发执行 non-blocking hook(只读)
|
|
||||||
5. 检查是否有 SKIP_STAGE 或 ABORT
|
|
||||||
6. PLAN 阶段内置 Command 匹配路由
|
|
||||||
|
|
||||||
支持:
|
|
||||||
- HookResult: CONTINUE / SKIP_STAGE / ABORT
|
|
||||||
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
|
|
||||||
- stage_outputs: 阶段间带命名空间的数据传递
|
|
||||||
- modification_log: 消息修改审计
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
|
|
||||||
|
|
||||||
logger = get_logger("plugin_runtime.host.workflow_executor")
|
|
||||||
|
|
||||||
# 阶段顺序
|
|
||||||
STAGE_SEQUENCE: List[str] = [
|
|
||||||
"ingress",
|
|
||||||
"pre_process",
|
|
||||||
"plan",
|
|
||||||
"tool_execute",
|
|
||||||
"post_process",
|
|
||||||
"egress",
|
|
||||||
]
|
|
||||||
|
|
||||||
# HookResult 常量(与 SDK HookResult enum 值对应)
|
|
||||||
HOOK_CONTINUE = "continue"
|
|
||||||
HOOK_SKIP_STAGE = "skip_stage"
|
|
||||||
HOOK_ABORT = "abort"
|
|
||||||
|
|
||||||
|
|
||||||
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
|
|
||||||
# 从配置文件读取,允许用户调整
|
|
||||||
def _get_blocking_timeout() -> float:
|
|
||||||
return global_config.plugin_runtime.workflow_blocking_timeout_sec
|
|
||||||
|
|
||||||
|
|
||||||
class ModificationRecord:
|
|
||||||
"""消息修改记录"""
|
|
||||||
|
|
||||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
|
||||||
|
|
||||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
|
|
||||||
self.stage = stage
|
|
||||||
self.hook_name = hook_name
|
|
||||||
self.timestamp = time.perf_counter()
|
|
||||||
self.fields_changed = fields_changed
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowContext:
|
|
||||||
"""Workflow 执行上下文"""
|
|
||||||
|
|
||||||
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
|
|
||||||
self.trace_id = trace_id or uuid.uuid4().hex
|
|
||||||
self.stream_id = stream_id
|
|
||||||
self.timings: Dict[str, float] = {}
|
|
||||||
self.errors: List[str] = []
|
|
||||||
# 阶段间数据传递(按 stage 命名空间隔离)
|
|
||||||
self.stage_outputs: Dict[str, Dict[str, Any]] = {}
|
|
||||||
# 消息修改审计日志
|
|
||||||
self.modification_log: List[ModificationRecord] = []
|
|
||||||
# PLAN 阶段命令匹配结果
|
|
||||||
self.matched_command: Optional[str] = None
|
|
||||||
|
|
||||||
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
|
|
||||||
self.stage_outputs.setdefault(stage, {})[key] = value
|
|
||||||
|
|
||||||
def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
|
|
||||||
return self.stage_outputs.get(stage, {}).get(key, default)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowResult:
|
|
||||||
"""Workflow 执行结果"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
status: str = "completed", # completed / aborted / failed
|
|
||||||
return_message: str = "",
|
|
||||||
stopped_at: str = "",
|
|
||||||
diagnostics: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> None:
|
|
||||||
self.status = status
|
|
||||||
self.return_message = return_message
|
|
||||||
self.stopped_at = stopped_at
|
|
||||||
self.diagnostics = diagnostics or {}
|
|
||||||
|
|
||||||
|
|
||||||
# invoke_fn 签名
|
|
||||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowExecutor:
|
|
||||||
"""Host-side Workflow 执行器
|
|
||||||
|
|
||||||
实现 stage-based pipeline + per-stage hook chain with priority + early return。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, registry: ComponentRegistry) -> None:
|
|
||||||
self._registry = registry
|
|
||||||
self._background_tasks: Set[asyncio.Task] = set()
|
|
||||||
|
|
||||||
async def execute(
|
|
||||||
self,
|
|
||||||
invoke_fn: InvokeFn,
|
|
||||||
message: Optional[Dict[str, Any]] = None,
|
|
||||||
stream_id: Optional[str] = None,
|
|
||||||
context: Optional[WorkflowContext] = None,
|
|
||||||
command_invoke_fn: Optional[InvokeFn] = None,
|
|
||||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
|
||||||
"""执行 workflow pipeline。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
invoke_fn: 用于 workflow_step 的回调
|
|
||||||
command_invoke_fn: 用于 command 的回调(走 plugin.invoke_command),
|
|
||||||
未传则复用 invoke_fn
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(result, final_message, context)
|
|
||||||
"""
|
|
||||||
ctx = context or WorkflowContext(stream_id=stream_id)
|
|
||||||
current_message = dict(message) if message else None
|
|
||||||
|
|
||||||
for stage in STAGE_SEQUENCE:
|
|
||||||
stage_start = time.perf_counter()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# PLAN 阶段: 先做 Command 路由
|
|
||||||
if stage == "plan" and current_message:
|
|
||||||
cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
|
|
||||||
if cmd_result is not None:
|
|
||||||
# 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs
|
|
||||||
ctx.set_stage_output("plan", "command_result", cmd_result)
|
|
||||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 获取该阶段所有 hook(已按 priority 降序排列)
|
|
||||||
all_steps = self._registry.get_workflow_steps(stage)
|
|
||||||
if not all_steps:
|
|
||||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 1. Pre-filter
|
|
||||||
filtered_steps = self._pre_filter(all_steps, current_message)
|
|
||||||
|
|
||||||
# 2. 分离 blocking 和 non-blocking
|
|
||||||
blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
|
|
||||||
nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
|
|
||||||
|
|
||||||
# 3. 串行执行 blocking hook
|
|
||||||
skip_stage = False
|
|
||||||
for step in blocking_steps:
|
|
||||||
hook_result, modified, step_error = await self._invoke_step(
|
|
||||||
invoke_fn, step, stage, ctx, current_message
|
|
||||||
)
|
|
||||||
|
|
||||||
if step_error:
|
|
||||||
error_policy = step.metadata.get("error_policy", "abort")
|
|
||||||
ctx.errors.append(f"{step.full_name}: {step_error}")
|
|
||||||
|
|
||||||
if error_policy == "abort":
|
|
||||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
|
||||||
return (
|
|
||||||
WorkflowResult(
|
|
||||||
status="failed",
|
|
||||||
return_message=step_error,
|
|
||||||
stopped_at=stage,
|
|
||||||
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
|
|
||||||
),
|
|
||||||
current_message,
|
|
||||||
ctx,
|
|
||||||
)
|
|
||||||
elif error_policy == "skip":
|
|
||||||
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
|
|
||||||
continue
|
|
||||||
else: # log
|
|
||||||
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 更新消息(仅 blocking hook 有权修改)
|
|
||||||
if modified:
|
|
||||||
changed_fields = (
|
|
||||||
_diff_keys(current_message, modified) if current_message else list(modified.keys())
|
|
||||||
)
|
|
||||||
ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
|
|
||||||
current_message = modified
|
|
||||||
|
|
||||||
if hook_result == HOOK_ABORT:
|
|
||||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
|
||||||
return (
|
|
||||||
WorkflowResult(
|
|
||||||
status="aborted",
|
|
||||||
return_message=f"aborted by {step.full_name}",
|
|
||||||
stopped_at=stage,
|
|
||||||
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
|
|
||||||
),
|
|
||||||
current_message,
|
|
||||||
ctx,
|
|
||||||
)
|
|
||||||
|
|
||||||
if hook_result == HOOK_SKIP_STAGE:
|
|
||||||
skip_stage = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
|
|
||||||
if nonblocking_steps and not skip_stage:
|
|
||||||
for step in nonblocking_steps:
|
|
||||||
self._track_background_task(
|
|
||||||
asyncio.create_task(
|
|
||||||
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
|
||||||
ctx.errors.append(f"{stage}: {e}")
|
|
||||||
logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
|
|
||||||
return (
|
|
||||||
WorkflowResult(
|
|
||||||
status="failed",
|
|
||||||
return_message=str(e),
|
|
||||||
stopped_at=stage,
|
|
||||||
diagnostics={"trace_id": ctx.trace_id},
|
|
||||||
),
|
|
||||||
current_message,
|
|
||||||
ctx,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
WorkflowResult(
|
|
||||||
status="completed",
|
|
||||||
return_message="workflow completed",
|
|
||||||
diagnostics={"trace_id": ctx.trace_id},
|
|
||||||
),
|
|
||||||
current_message,
|
|
||||||
ctx,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _track_background_task(self, task: asyncio.Task) -> None:
|
|
||||||
"""保持 non-blocking workflow task 的强引用,直到任务结束。"""
|
|
||||||
self._background_tasks.add(task)
|
|
||||||
task.add_done_callback(self._background_tasks.discard)
|
|
||||||
|
|
||||||
# ─── 内部方法 ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _pre_filter(
|
|
||||||
self,
|
|
||||||
steps: List[RegisteredComponent],
|
|
||||||
message: Optional[Dict[str, Any]],
|
|
||||||
) -> List[RegisteredComponent]:
|
|
||||||
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
|
|
||||||
if not message:
|
|
||||||
return steps
|
|
||||||
|
|
||||||
result = []
|
|
||||||
for step in steps:
|
|
||||||
filter_cond = step.metadata.get("filter", {})
|
|
||||||
if not filter_cond:
|
|
||||||
result.append(step)
|
|
||||||
continue
|
|
||||||
if self._match_filter(filter_cond, message):
|
|
||||||
result.append(step)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
|
|
||||||
"""简单 key-value 匹配过滤。
|
|
||||||
|
|
||||||
filter 中的每个 key 必须在 message 中存在且值相等,
|
|
||||||
全部匹配才通过。
|
|
||||||
"""
|
|
||||||
for key, expected in filter_cond.items():
|
|
||||||
actual = message.get(key)
|
|
||||||
if (isinstance(expected, list) and actual not in expected) or (
|
|
||||||
not isinstance(expected, list) and actual != expected
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _invoke_step(
|
|
||||||
self,
|
|
||||||
invoke_fn: InvokeFn,
|
|
||||||
step: RegisteredComponent,
|
|
||||||
stage: str,
|
|
||||||
ctx: WorkflowContext,
|
|
||||||
message: Optional[Dict[str, Any]],
|
|
||||||
) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
|
|
||||||
"""调用单个 blocking hook。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(hook_result, modified_message, error_string_or_None)
|
|
||||||
"""
|
|
||||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
|
||||||
# 使用 hook 声明的超时,但不超过全局安全阀
|
|
||||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
|
|
||||||
step_key = f"{stage}:{step.full_name}"
|
|
||||||
step_start = time.perf_counter()
|
|
||||||
|
|
||||||
try:
|
|
||||||
coro = invoke_fn(
|
|
||||||
step.plugin_id,
|
|
||||||
step.name,
|
|
||||||
{
|
|
||||||
"stage": stage,
|
|
||||||
"trace_id": ctx.trace_id,
|
|
||||||
"message": message,
|
|
||||||
"stage_outputs": ctx.stage_outputs,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
|
|
||||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
|
||||||
|
|
||||||
hook_result = resp.get("hook_result", HOOK_CONTINUE)
|
|
||||||
modified_message = resp.get("modified_message")
|
|
||||||
# 存 stage output(如果 hook 提供了)
|
|
||||||
stage_out = resp.get("stage_output")
|
|
||||||
if isinstance(stage_out, dict):
|
|
||||||
for k, v in stage_out.items():
|
|
||||||
ctx.set_stage_output(stage, k, v)
|
|
||||||
|
|
||||||
return hook_result, modified_message, None
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
|
||||||
return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
|
||||||
return HOOK_CONTINUE, None, str(e)
|
|
||||||
|
|
||||||
async def _invoke_step_fire_and_forget(
|
|
||||||
self,
|
|
||||||
invoke_fn: InvokeFn,
|
|
||||||
step: RegisteredComponent,
|
|
||||||
stage: str,
|
|
||||||
ctx: WorkflowContext,
|
|
||||||
message: Optional[Dict[str, Any]],
|
|
||||||
) -> None:
|
|
||||||
"""Non-blocking hook 调用,只读,忽略结果。"""
|
|
||||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
|
||||||
# 使用 hook 声明的超时,但无声明时回退到全局安全阀,防止 task 泄漏
|
|
||||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
|
|
||||||
|
|
||||||
try:
|
|
||||||
coro = invoke_fn(
|
|
||||||
step.plugin_id,
|
|
||||||
step.name,
|
|
||||||
{
|
|
||||||
"stage": stage,
|
|
||||||
"trace_id": ctx.trace_id,
|
|
||||||
"message": message,
|
|
||||||
"stage_outputs": ctx.stage_outputs,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await asyncio.wait_for(coro, timeout=timeout_sec)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
|
|
||||||
|
|
||||||
async def _route_command(
|
|
||||||
self,
|
|
||||||
invoke_fn: InvokeFn,
|
|
||||||
message: Dict[str, Any],
|
|
||||||
ctx: WorkflowContext,
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
"""PLAN 阶段内置 Command 路由。
|
|
||||||
|
|
||||||
在 registry 中查找匹配的 command 组件,
|
|
||||||
匹配到则直接路由到对应 command handler,返回执行结果。
|
|
||||||
不匹配则返回 None,让 PLAN 阶段的 hook 继续执行。
|
|
||||||
"""
|
|
||||||
plain_text = message.get("plain_text", "")
|
|
||||||
if not plain_text:
|
|
||||||
return None
|
|
||||||
|
|
||||||
match_result = self._registry.find_command_by_text(plain_text)
|
|
||||||
if match_result is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
matched, matched_groups = match_result
|
|
||||||
|
|
||||||
ctx.matched_command = matched.full_name
|
|
||||||
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await invoke_fn(
|
|
||||||
matched.plugin_id,
|
|
||||||
matched.name,
|
|
||||||
{
|
|
||||||
"text": plain_text,
|
|
||||||
"message": message,
|
|
||||||
"trace_id": ctx.trace_id,
|
|
||||||
"matched_groups": matched_groups,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
|
|
||||||
ctx.errors.append(f"command:{matched.full_name}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
|
|
||||||
"""返回 new 中与 old 不同的 key 列表。"""
|
|
||||||
return [k for k, v in new.items() if k not in old or old[k] != v]
|
|
||||||
@@ -134,9 +134,9 @@ class ComponentDeclaration(BaseModel):
|
|||||||
name: str = Field(description="组件名称")
|
name: str = Field(description="组件名称")
|
||||||
"""组件名称"""
|
"""组件名称"""
|
||||||
component_type: str = Field(
|
component_type: str = Field(
|
||||||
description="组件类型:action/command/tool/event_handler/workflow_handler/message_gateway"
|
description="组件类型:action/command/tool/event_handler/hook_handler/message_gateway"
|
||||||
)
|
)
|
||||||
"""组件类型:`action`/`command`/`tool`/`event_handler`/`workflow_handler`/`message_gateway`"""
|
"""组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
|
||||||
plugin_id: str = Field(description="所属插件 ID")
|
plugin_id: str = Field(description="所属插件 ID")
|
||||||
"""所属插件 ID"""
|
"""所属插件 ID"""
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
|
||||||
|
|||||||
Reference in New Issue
Block a user