423 lines
16 KiB
Python
423 lines
16 KiB
Python
"""Host-side WorkflowExecutor
|
||
|
||
6 阶段线性流转(INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS)
|
||
|
||
每个阶段执行顺序:
|
||
1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
|
||
2. 按 priority 降序排列
|
||
3. 串行执行 blocking hook(可修改 message,返回 HookResult)
|
||
4. 并发执行 non-blocking hook(只读)
|
||
5. 检查是否有 SKIP_STAGE 或 ABORT
|
||
6. PLAN 阶段内置 Command 匹配路由
|
||
|
||
支持:
|
||
- HookResult: CONTINUE / SKIP_STAGE / ABORT
|
||
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
|
||
- stage_outputs: 阶段间带命名空间的数据传递
|
||
- modification_log: 消息修改审计
|
||
"""
|
||
|
||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||
|
||
import asyncio
|
||
import time
|
||
import uuid
|
||
|
||
from src.common.logger import get_logger
|
||
from src.config.config import global_config
|
||
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
|
||
|
||
logger = get_logger("plugin_runtime.host.workflow_executor")

# Stage order: the executor walks these stages from first to last.
STAGE_SEQUENCE: List[str] = [
    "ingress",
    "pre_process",
    "plan",
    "tool_execute",
    "post_process",
    "egress",
]

# HookResult constants (values correspond to the SDK HookResult enum).
HOOK_CONTINUE = "continue"
HOOK_SKIP_STAGE = "skip_stage"
HOOK_ABORT = "abort"
|
||
|
||
|
||
# Global timeout (in seconds) for blocking hooks: a hook that declares
# timeout_ms=0 still does not wait forever.
# The value is read from the configuration so that users can adjust it.
def _get_blocking_timeout() -> float:
    """Return the configured global hook timeout, in seconds."""
    runtime_config = global_config.plugin_runtime
    return runtime_config.workflow_blocking_timeout_sec
|
||
|
||
|
||
class ModificationRecord:
    """One entry of the message-modification audit log.

    Attributes are limited to the names in ``__slots__``:
    the stage and hook that produced the change, the list of changed field
    names, and a ``time.perf_counter()`` reading taken at construction time.
    """

    __slots__ = ("stage", "hook_name", "timestamp", "fields_changed")

    def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
        """Record that *hook_name* changed *fields_changed* during *stage*."""
        self.stage, self.hook_name = stage, hook_name
        # perf_counter has an undefined reference point, so this value is only
        # meaningful relative to other perf_counter readings.
        self.timestamp = time.perf_counter()
        self.fields_changed = fields_changed
|
||
|
||
|
||
class WorkflowContext:
    """Mutable state carried through one workflow run."""

    def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
        """Create a context.

        Args:
            trace_id: Identifier for this run; a random hex UUID is generated
                when it is missing or empty.
            stream_id: Optional stream identifier, stored as given.
        """
        self.trace_id = trace_id if trace_id else uuid.uuid4().hex
        self.stream_id = stream_id
        # Elapsed seconds keyed by name.
        self.timings: Dict[str, float] = {}
        # Human-readable error strings collected during the run.
        self.errors: List[str] = []
        # Data passed between stages, namespaced by stage name.
        self.stage_outputs: Dict[str, Dict[str, Any]] = {}
        # Audit log of message modifications.
        self.modification_log: List[ModificationRecord] = []
        # Full name of the command matched in the PLAN stage, if any.
        self.matched_command: Optional[str] = None

    def set_stage_output(self, stage: str, key: str, value: Any) -> None:
        """Store *value* under *key* in the namespace of *stage*."""
        if stage not in self.stage_outputs:
            self.stage_outputs[stage] = {}
        self.stage_outputs[stage][key] = value

    def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
        """Return the value stored for *stage*/*key*, or *default* if absent."""
        if stage not in self.stage_outputs:
            return default
        return self.stage_outputs[stage].get(key, default)
|
||
|
||
|
||
class WorkflowResult:
    """Outcome of one workflow run."""

    def __init__(
        self,
        status: str = "completed",  # completed / aborted / failed
        return_message: str = "",
        stopped_at: str = "",
        diagnostics: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store the outcome fields.

        Args:
            status: One of ``"completed"``, ``"aborted"`` or ``"failed"``.
            return_message: Human-readable summary or error text.
            stopped_at: Stage name at which the run stopped, if it stopped early.
            diagnostics: Extra diagnostic data; a missing or empty value is
                replaced by a new empty dict.
        """
        self.status = status
        self.return_message = return_message
        self.stopped_at = stopped_at
        self.diagnostics = diagnostics if diagnostics else {}
|
||
|
||
|
||
# Signature of invoke_fn: it is called with (plugin_id, component_name, args)
# and awaited to obtain a response dict.
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||
|
||
|
||
class WorkflowExecutor:
    """Host-side workflow executor.

    Runs the stages listed in ``STAGE_SEQUENCE`` in order. For each stage the
    registered hooks are pre-filtered against the message, blocking hooks are
    awaited one at a time (they may replace the message, skip the rest of the
    stage, or abort the workflow), and non-blocking hooks are launched as
    background tasks whose results are ignored. The PLAN stage first tries to
    route the message to a registered command.
    """

    def __init__(self, registry: ComponentRegistry) -> None:
        """Create an executor that looks up hooks and commands in *registry*."""
        self._registry = registry
        # Strong references to in-flight non-blocking hook tasks.
        self._background_tasks: Set[asyncio.Task] = set()

    async def execute(
        self,
        invoke_fn: InvokeFn,
        message: Optional[Dict[str, Any]] = None,
        stream_id: Optional[str] = None,
        context: Optional[WorkflowContext] = None,
        command_invoke_fn: Optional[InvokeFn] = None,
    ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
        """Run the workflow pipeline.

        Args:
            invoke_fn: Callback used to invoke workflow-step hooks.
            message: Initial message. A shallow copy is taken; a missing or
                empty message is treated as ``None``.
            stream_id: Stream id stored on a newly created context; not used
                when *context* is supplied.
            context: Existing context to use; a new one is created when omitted.
            command_invoke_fn: Callback used to invoke commands; *invoke_fn* is
                reused when omitted.

        Returns:
            ``(result, final_message, context)``. ``result.status`` is
            ``"completed"``, ``"aborted"`` (a blocking hook returned
            ``HOOK_ABORT``) or ``"failed"`` (a hook error under the ``abort``
            error policy, or an unexpected exception inside a stage).
        """
        ctx = context or WorkflowContext(stream_id=stream_id)
        current_message = dict(message) if message else None

        for stage in STAGE_SEQUENCE:
            stage_start = time.perf_counter()

            try:
                # PLAN stage: try command routing first.
                if stage == "plan" and current_message:
                    cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
                    if cmd_result is not None:
                        # A command handled the message: skip the PLAN hooks and
                        # expose the command result through stage_outputs.
                        ctx.set_stage_output("plan", "command_result", cmd_result)
                        continue

                # Hooks for this stage; the registry is expected to return them
                # already sorted by descending priority.
                all_steps = self._registry.get_workflow_steps(stage)
                if not all_steps:
                    continue

                # 1. Pre-filter.
                filtered_steps = self._pre_filter(all_steps, current_message)

                # 2. Split into blocking (the default) and non-blocking hooks.
                blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
                nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]

                # 3. Run blocking hooks one after another.
                skip_stage = False
                for step in blocking_steps:
                    hook_result, modified, step_error = await self._invoke_step(
                        invoke_fn, step, stage, ctx, current_message
                    )

                    # Compare against None so that a failure can never be
                    # mistaken for success because of a falsy error string.
                    if step_error is not None:
                        error_policy = step.metadata.get("error_policy", "abort")
                        ctx.errors.append(f"{step.full_name}: {step_error}")

                        if error_policy == "abort":
                            return (
                                WorkflowResult(
                                    status="failed",
                                    return_message=step_error,
                                    stopped_at=stage,
                                    diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
                                ),
                                current_message,
                                ctx,
                            )

                        # "skip" and every other policy ("log") both carry on
                        # with the next hook; only the log tag differs.
                        policy_tag = "skip" if error_policy == "skip" else "log"
                        logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常({policy_tag}): {step_error}")
                        continue

                    # Only blocking hooks may replace the message.
                    if modified:
                        changed_fields = (
                            _diff_keys(current_message, modified) if current_message else list(modified.keys())
                        )
                        ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
                        current_message = modified

                    if hook_result == HOOK_ABORT:
                        return (
                            WorkflowResult(
                                status="aborted",
                                return_message=f"aborted by {step.full_name}",
                                stopped_at=stage,
                                diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
                            ),
                            current_message,
                            ctx,
                        )

                    if hook_result == HOOK_SKIP_STAGE:
                        skip_stage = True
                        break

                # 4. Launch non-blocking hooks concurrently; they are read-only
                #    and their results (including any modified_message) are ignored.
                if not skip_stage:
                    for step in nonblocking_steps:
                        self._track_background_task(
                            asyncio.create_task(
                                self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
                            )
                        )

            except Exception as e:
                ctx.errors.append(f"{stage}: {e}")
                logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
                return (
                    WorkflowResult(
                        status="failed",
                        return_message=str(e),
                        stopped_at=stage,
                        diagnostics={"trace_id": ctx.trace_id},
                    ),
                    current_message,
                    ctx,
                )

            finally:
                # Record the stage duration on every exit path (normal
                # completion, early continue, early return, or exception).
                ctx.timings[stage] = time.perf_counter() - stage_start

        return (
            WorkflowResult(
                status="completed",
                return_message="workflow completed",
                diagnostics={"trace_id": ctx.trace_id},
            ),
            current_message,
            ctx,
        )

    def _track_background_task(self, task: asyncio.Task) -> None:
        """Keep a strong reference to a non-blocking hook task until it finishes."""
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    # ─── Internal helpers ──────────────────────────────────────

    def _pre_filter(
        self,
        steps: List[RegisteredComponent],
        message: Optional[Dict[str, Any]],
    ) -> List[RegisteredComponent]:
        """Drop hooks whose declared ``filter`` does not match *message*.

        This avoids invoking hooks that are not interested in the message.
        When there is no message, *steps* is returned unchanged. Hooks without
        a filter (or with an empty one) always pass.
        """
        if not message:
            return steps

        kept = []
        for step in steps:
            filter_cond = step.metadata.get("filter", {})
            if not filter_cond or self._match_filter(filter_cond, message):
                kept.append(step)
        return kept

    @staticmethod
    def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
        """Return True when *message* satisfies every entry of *filter_cond*.

        For each key, a list value means "the message value must be one of
        these"; any other value must compare equal to the message value. A key
        that is missing from *message* is compared as ``None``.
        """
        for key, expected in filter_cond.items():
            actual = message.get(key)
            if isinstance(expected, list):
                if actual not in expected:
                    return False
            elif actual != expected:
                return False
        return True

    @staticmethod
    def _resolve_timeout_sec(step: RegisteredComponent) -> float:
        """Return the timeout, in seconds, to apply when invoking *step*.

        A positive numeric ``timeout_ms`` in the hook metadata is used as
        declared. A missing, non-positive or non-numeric value falls back to
        the global timeout, so a malformed declaration cannot raise here.

        NOTE(review): the original comment on the blocking path said a declared
        timeout should not exceed the global limit, but no cap has ever been
        applied. Behaviour is kept as-is; confirm whether clamping is intended.
        """
        timeout_ms = step.metadata.get("timeout_ms", 0)
        if isinstance(timeout_ms, (int, float)) and timeout_ms > 0:
            return timeout_ms / 1000
        return _get_blocking_timeout()

    @staticmethod
    def _build_step_args(
        stage: str,
        ctx: WorkflowContext,
        message: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Build the argument dict that is passed to a workflow-step hook."""
        return {
            "stage": stage,
            "trace_id": ctx.trace_id,
            "message": message,
            "stage_outputs": ctx.stage_outputs,
        }

    async def _invoke_step(
        self,
        invoke_fn: InvokeFn,
        step: RegisteredComponent,
        stage: str,
        ctx: WorkflowContext,
        message: Optional[Dict[str, Any]],
    ) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
        """Invoke a single blocking hook.

        The elapsed time is stored in ``ctx.timings`` under
        ``"<stage>:<hook full name>"``. If the response carries a dict
        ``stage_output``, its entries are stored in the stage's namespace.

        Returns:
            ``(hook_result, modified_message, error)``. On success *error* is
            ``None``. On failure it is a non-empty string, *hook_result* is
            ``HOOK_CONTINUE`` and *modified_message* is ``None``.
        """
        timeout_sec = self._resolve_timeout_sec(step)
        step_key = f"{stage}:{step.full_name}"
        step_start = time.perf_counter()

        try:
            coro = invoke_fn(step.plugin_id, step.name, self._build_step_args(stage, ctx, message))
            resp = await asyncio.wait_for(coro, timeout=timeout_sec)

            hook_result = resp.get("hook_result", HOOK_CONTINUE)
            modified_message = resp.get("modified_message")
            # A replacement message must be a dict. Report anything else as a
            # hook error so the hook's own error policy decides what happens.
            if modified_message is not None and not isinstance(modified_message, dict):
                return (
                    HOOK_CONTINUE,
                    None,
                    f"invalid modified_message type: {type(modified_message).__name__}",
                )

            # Store the stage output, if the hook provided one.
            stage_out = resp.get("stage_output")
            if isinstance(stage_out, dict):
                for k, v in stage_out.items():
                    ctx.set_stage_output(stage, k, v)

            return hook_result, modified_message, None

        except asyncio.TimeoutError:
            # Report the timeout that was actually applied; the declared
            # timeout_ms is 0 or absent when the global fallback was used.
            return HOOK_CONTINUE, None, f"timeout after {timeout_sec * 1000:.0f}ms"

        except Exception as e:
            # Some exceptions stringify to "", so fall back to the class name
            # to keep the error string non-empty.
            return HOOK_CONTINUE, None, str(e) or type(e).__name__

        finally:
            ctx.timings[step_key] = time.perf_counter() - step_start

    async def _invoke_step_fire_and_forget(
        self,
        invoke_fn: InvokeFn,
        step: RegisteredComponent,
        stage: str,
        ctx: WorkflowContext,
        message: Optional[Dict[str, Any]],
    ) -> None:
        """Invoke a non-blocking hook; it is read-only and its result is ignored.

        A timeout is always applied (falling back to the global one when the
        hook declares none) so that a background task cannot hang forever.
        Timeouts are logged as warnings, other exceptions at debug level.
        """
        timeout_sec = self._resolve_timeout_sec(step)

        try:
            coro = invoke_fn(step.plugin_id, step.name, self._build_step_args(stage, ctx, message))
            await asyncio.wait_for(coro, timeout=timeout_sec)
        except asyncio.TimeoutError:
            logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
        except Exception as e:
            logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")

    async def _route_command(
        self,
        invoke_fn: InvokeFn,
        message: Dict[str, Any],
        ctx: WorkflowContext,
    ) -> Optional[Dict[str, Any]]:
        """Built-in command routing for the PLAN stage.

        Looks up a command matching the message's ``plain_text`` in the
        registry. On a match, ``ctx.matched_command`` is set and the command
        handler is invoked; its result is returned.

        Returns:
            The command result, or ``None`` when the message has no plain
            text, no command matches, or the handler raised. In the last case
            the error is logged and appended to ``ctx.errors``. ``None`` lets
            the PLAN hooks run as usual.
        """
        plain_text = message.get("plain_text", "")
        if not plain_text:
            return None

        match_result = self._registry.find_command_by_text(plain_text)
        if match_result is None:
            return None

        matched, matched_groups = match_result

        ctx.matched_command = matched.full_name
        logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")

        try:
            return await invoke_fn(
                matched.plugin_id,
                matched.name,
                {
                    "text": plain_text,
                    "message": message,
                    "trace_id": ctx.trace_id,
                    "matched_groups": matched_groups,
                },
            )
        except Exception as e:
            logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
            ctx.errors.append(f"command:{matched.full_name}: {e}")
            return None
|
||
|
||
|
||
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
|
||
"""返回 new 中与 old 不同的 key 列表。"""
|
||
return [k for k, v in new.items() if k not in old or old[k] != v]
|