Files
mai-bot/src/plugin_runtime/host/workflow_executor.py

423 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Host-side WorkflowExecutor
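Six stages run linearly: INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS

Execution order within each stage:
1. Host-side pre-filter: drop hooks whose declared filter conditions do not match
2. Order hooks by priority, highest first
3. Run blocking hooks serially (they may modify the message and return a HookResult)
4. Run non-blocking hooks concurrently (read-only)
5. Check whether any hook requested SKIP_STAGE or ABORT
6. The PLAN stage has built-in command matching and routing

Supports:
- HookResult: CONTINUE / SKIP_STAGE / ABORT
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
- stage_outputs: namespaced data passing between stages
- modification_log: audit trail of message modifications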
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
import asyncio
import time
import uuid
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
# Module-level logger used by every log call in this file.
logger = get_logger("plugin_runtime.host.workflow_executor")
# Ordered list of workflow stages; the executor iterates over it front to back.
STAGE_SEQUENCE: List[str] = ["ingress", "pre_process", "plan", "tool_execute", "post_process", "egress"]

# HookResult values as plain strings (they correspond to the SDK's HookResult enum values).
HOOK_CONTINUE = "continue"
HOOK_SKIP_STAGE = "skip_stage"
HOOK_ABORT = "abort"
# Global timeout (seconds) for hooks: a hook that declares timeout_ms=0 still
# never waits forever. The value is read from the config file on every call so
# that users can tune it.
def _get_blocking_timeout() -> float:
    """Return the configured global hook timeout, in seconds."""
    runtime_cfg = global_config.plugin_runtime
    return runtime_cfg.workflow_blocking_timeout_sec
class ModificationRecord:
    """One audit entry describing a message modification made by a hook.

    Attributes:
        stage: Name of the workflow stage in which the change happened.
        hook_name: Name of the hook that produced the modified message.
        timestamp: Wall-clock time (seconds since the epoch) at which the
            record was created.
        fields_changed: Names of the message fields that were changed.
    """

    __slots__ = ("stage", "hook_name", "timestamp", "fields_changed")

    def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
        self.stage = stage
        self.hook_name = hook_name
        # time.perf_counter() has an undefined reference point, so a single
        # stored reading cannot be interpreted as "when" something happened.
        # An audit record needs wall-clock time instead.
        self.timestamp = time.time()
        self.fields_changed = fields_changed

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(stage={self.stage!r}, hook_name={self.hook_name!r}, "
            f"timestamp={self.timestamp!r}, fields_changed={self.fields_changed!r})"
        )
class WorkflowContext:
    """Mutable state carried through a single workflow run."""

    def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
        # A falsy trace_id (None or "") is replaced by a freshly generated hex id.
        self.trace_id = trace_id if trace_id else uuid.uuid4().hex
        self.stream_id = stream_id
        self.timings: Dict[str, float] = {}
        self.errors: List[str] = []
        # Data handed from stage to stage, namespaced by stage name.
        self.stage_outputs: Dict[str, Dict[str, Any]] = {}
        # Audit log of message modifications.
        self.modification_log: List[ModificationRecord] = []
        # Full name of the command matched during the PLAN stage, if any.
        self.matched_command: Optional[str] = None

    def set_stage_output(self, stage: str, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` inside the namespace of ``stage``."""
        if stage not in self.stage_outputs:
            self.stage_outputs[stage] = {}
        self.stage_outputs[stage][key] = value

    def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
        """Return the value stored for ``stage``/``key``, or ``default`` when absent."""
        if stage not in self.stage_outputs:
            return default
        return self.stage_outputs[stage].get(key, default)
class WorkflowResult:
    """Outcome of a workflow run.

    ``status`` is one of "completed", "aborted" or "failed"; ``stopped_at``
    names the stage at which the run ended early (empty when it did not).
    """

    def __init__(
        self,
        status: str = "completed",  # completed / aborted / failed
        return_message: str = "",
        stopped_at: str = "",
        diagnostics: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.status = status
        self.return_message = return_message
        self.stopped_at = stopped_at
        # A missing or empty diagnostics mapping is replaced by a new dict.
        self.diagnostics = diagnostics if diagnostics else {}
# Signature of the invoke callback. In this file it is called as
# invoke_fn(plugin_id, component_name, payload_dict) and awaited; the awaited
# value is treated as a response dict.
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
class WorkflowExecutor:
    """Host-side workflow executor.

    Implements a stage-based pipeline with a per-stage, priority-ordered hook
    chain and early return: blocking hooks run one after another and may
    replace the message, abort the workflow or skip the rest of the stage;
    non-blocking hooks are started as background tasks and their results are
    ignored.
    """

    def __init__(self, registry: ComponentRegistry) -> None:
        self._registry = registry
        # Strong references to in-flight non-blocking hook tasks so they are
        # not garbage-collected before they finish.
        self._background_tasks: Set[asyncio.Task] = set()

    async def execute(
        self,
        invoke_fn: InvokeFn,
        message: Optional[Dict[str, Any]] = None,
        stream_id: Optional[str] = None,
        context: Optional[WorkflowContext] = None,
        command_invoke_fn: Optional[InvokeFn] = None,
    ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
        """Run the workflow pipeline over all stages in ``STAGE_SEQUENCE``.

        Args:
            invoke_fn: Callback used to invoke workflow-step hooks.
            message: Incoming message; a shallow copy is made, so the caller's
                dict itself is never rebound. A falsy message becomes ``None``.
            stream_id: Stream id stored on a newly created context.
            context: Existing context to reuse; a new one is created when falsy.
            command_invoke_fn: Callback used to invoke commands; falls back to
                ``invoke_fn`` when not given.

        Returns:
            ``(result, final_message, context)``. ``result.status`` is
            "completed", "aborted" (a hook returned ABORT) or "failed" (a hook
            with error policy "abort" failed, or a stage raised).
        """
        ctx = context or WorkflowContext(stream_id=stream_id)
        current_message = dict(message) if message else None

        for stage in STAGE_SEQUENCE:
            stage_start = time.perf_counter()
            try:
                # PLAN stage: try command routing before running any hook.
                if stage == "plan" and current_message:
                    cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
                    if cmd_result is not None:
                        # A command handled the message: skip the PLAN hooks and
                        # publish the command result through stage_outputs.
                        ctx.set_stage_output("plan", "command_result", cmd_result)
                        continue

                # All hooks registered for this stage (the registry is expected
                # to return them already sorted by priority, highest first).
                all_steps = self._registry.get_workflow_steps(stage)
                if not all_steps:
                    continue

                # 1. Pre-filter hooks whose filter does not match the message.
                filtered_steps = self._pre_filter(all_steps, current_message)

                # 2. Split into blocking (the default) and non-blocking hooks.
                blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
                nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]

                # 3. Run blocking hooks serially.
                skip_stage = False
                for step in blocking_steps:
                    hook_result, modified, step_error = await self._invoke_step(
                        invoke_fn, step, stage, ctx, current_message
                    )

                    if step_error:
                        error_policy = step.metadata.get("error_policy", "abort")
                        ctx.errors.append(f"{step.full_name}: {step_error}")
                        if error_policy == "abort":
                            return (
                                WorkflowResult(
                                    status="failed",
                                    return_message=step_error,
                                    stopped_at=stage,
                                    diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
                                ),
                                current_message,
                                ctx,
                            )
                        elif error_policy == "skip":
                            logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
                            continue
                        else:  # "log" (and any unrecognised policy)
                            logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
                            continue

                    # Only blocking hooks may replace the message.
                    if modified:
                        changed_fields = (
                            _diff_keys(current_message, modified) if current_message else list(modified.keys())
                        )
                        ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
                        current_message = modified

                    if hook_result == HOOK_ABORT:
                        return (
                            WorkflowResult(
                                status="aborted",
                                return_message=f"aborted by {step.full_name}",
                                stopped_at=stage,
                                diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
                            ),
                            current_message,
                            ctx,
                        )

                    if hook_result == HOOK_SKIP_STAGE:
                        skip_stage = True
                        break

                # 4. Start non-blocking hooks as background tasks; they are
                #    read-only, so whatever they return is ignored.
                if nonblocking_steps and not skip_stage:
                    for step in nonblocking_steps:
                        self._track_background_task(
                            asyncio.create_task(
                                self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
                            )
                        )

            except Exception as e:
                ctx.errors.append(f"{stage}: {e}")
                logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
                return (
                    WorkflowResult(
                        status="failed",
                        return_message=str(e),
                        stopped_at=stage,
                        diagnostics={"trace_id": ctx.trace_id},
                    ),
                    current_message,
                    ctx,
                )
            finally:
                # Every way of leaving the stage (continue, return, normal end,
                # exception) records how long the stage took.
                ctx.timings[stage] = time.perf_counter() - stage_start

        return (
            WorkflowResult(
                status="completed",
                return_message="workflow completed",
                diagnostics={"trace_id": ctx.trace_id},
            ),
            current_message,
            ctx,
        )

    def _track_background_task(self, task: asyncio.Task) -> None:
        """Hold a strong reference to a non-blocking hook task until it finishes."""
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    # ─── Internal helpers ──────────────────────────────────────

    @staticmethod
    def _resolve_timeout_sec(step: RegisteredComponent) -> float:
        """Return the effective timeout, in seconds, for invoking ``step``.

        A positive numeric ``timeout_ms`` in the hook metadata is used as is
        (it is NOT capped by the global value). Anything else (0, negative,
        missing, ``None`` or a non-numeric value) falls back to the global
        timeout from the configuration, so a hook never waits without a limit.
        """
        declared_ms = step.metadata.get("timeout_ms", 0)
        if isinstance(declared_ms, (int, float)) and declared_ms > 0:
            return declared_ms / 1000
        return _get_blocking_timeout()

    @staticmethod
    def _build_hook_payload(
        stage: str,
        ctx: WorkflowContext,
        message: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Build the argument dict passed to a workflow-step hook."""
        return {
            "stage": stage,
            "trace_id": ctx.trace_id,
            "message": message,
            "stage_outputs": ctx.stage_outputs,
        }

    def _pre_filter(
        self,
        steps: List[RegisteredComponent],
        message: Optional[Dict[str, Any]],
    ) -> List[RegisteredComponent]:
        """Drop hooks whose declared ``filter`` does not match ``message``.

        This avoids invoking hooks that would not be interested in the message.
        Without a message nothing can be filtered, so ``steps`` is returned
        unchanged. Hooks without a filter (or with an empty one) always pass.
        """
        if not message:
            return steps
        kept = []
        for step in steps:
            filter_cond = step.metadata.get("filter", {})
            if not filter_cond or self._match_filter(filter_cond, message):
                kept.append(step)
        return kept

    @staticmethod
    def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
        """Simple key/value filter matching; every condition must hold.

        For each key in ``filter_cond``: when the expected value is a list, the
        message value must be one of its items; otherwise the message value
        must equal the expected value. A key missing from ``message`` is read
        as ``None``.
        """
        for key, expected in filter_cond.items():
            actual = message.get(key)
            if isinstance(expected, list):
                if actual not in expected:
                    return False
            elif actual != expected:
                return False
        return True

    async def _invoke_step(
        self,
        invoke_fn: InvokeFn,
        step: RegisteredComponent,
        stage: str,
        ctx: WorkflowContext,
        message: Optional[Dict[str, Any]],
    ) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
        """Invoke a single blocking hook.

        Any ``stage_output`` dict in the response is merged into the context
        under ``stage``, and the elapsed time is stored in ``ctx.timings``
        under ``"<stage>:<hook full name>"``.

        Returns:
            ``(hook_result, modified_message, error)``. ``error`` is ``None``
            on success; on timeout or any other exception it is a description
            and the other two items are ``HOOK_CONTINUE`` and ``None``.
        """
        timeout_sec = self._resolve_timeout_sec(step)
        step_key = f"{stage}:{step.full_name}"
        step_start = time.perf_counter()
        try:
            resp = await asyncio.wait_for(
                invoke_fn(step.plugin_id, step.name, self._build_hook_payload(stage, ctx, message)),
                timeout=timeout_sec,
            )
            hook_result = resp.get("hook_result", HOOK_CONTINUE)
            modified_message = resp.get("modified_message")
            # Persist the hook's stage output, if it provided one.
            stage_out = resp.get("stage_output")
            if isinstance(stage_out, dict):
                for k, v in stage_out.items():
                    ctx.set_stage_output(stage, k, v)
            return hook_result, modified_message, None
        except asyncio.TimeoutError:
            # Report the timeout that was actually applied; the declared value
            # is 0 whenever the global fallback was used.
            return HOOK_CONTINUE, None, f"timeout after {timeout_sec * 1000:.0f}ms"
        except Exception as e:
            return HOOK_CONTINUE, None, str(e)
        finally:
            ctx.timings[step_key] = time.perf_counter() - step_start

    async def _invoke_step_fire_and_forget(
        self,
        invoke_fn: InvokeFn,
        step: RegisteredComponent,
        stage: str,
        ctx: WorkflowContext,
        message: Optional[Dict[str, Any]],
    ) -> None:
        """Invoke a non-blocking hook; it is read-only and its result is ignored.

        The call is still bounded by a timeout (the hook's own, or the global
        fallback) so a stuck hook cannot keep its background task alive
        forever. Timeouts and other exceptions are logged, never raised.
        """
        timeout_sec = self._resolve_timeout_sec(step)
        try:
            await asyncio.wait_for(
                invoke_fn(step.plugin_id, step.name, self._build_hook_payload(stage, ctx, message)),
                timeout=timeout_sec,
            )
        except asyncio.TimeoutError:
            logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
        except Exception as e:
            logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")

    async def _route_command(
        self,
        invoke_fn: InvokeFn,
        message: Dict[str, Any],
        ctx: WorkflowContext,
    ) -> Optional[Dict[str, Any]]:
        """Built-in command routing for the PLAN stage.

        Looks up a command component matching the message's ``plain_text`` in
        the registry. On a match the command handler is invoked and its result
        returned. ``None`` is returned when there is no text, no match, or the
        handler raised (the error is logged and recorded in ``ctx.errors``);
        the PLAN hooks then run as usual. ``ctx.matched_command`` is set as
        soon as a command matches, even if its handler later fails.
        """
        plain_text = message.get("plain_text", "")
        if not plain_text:
            return None

        match_result = self._registry.find_command_by_text(plain_text)
        if match_result is None:
            return None

        matched, matched_groups = match_result
        ctx.matched_command = matched.full_name
        logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
        try:
            return await invoke_fn(
                matched.plugin_id,
                matched.name,
                {
                    "text": plain_text,
                    "message": message,
                    "trace_id": ctx.trace_id,
                    "matched_groups": matched_groups,
                },
            )
        except Exception as e:
            logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
            ctx.errors.append(f"command:{matched.full_name}: {e}")
            return None
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
"""返回 new 中与 old 不同的 key 列表。"""
return [k for k, v in new.items() if k not in old or old[k] != v]