feat: Enhance plugin runtime with new component registry and workflow executor

- Introduced `ComponentRegistry` for managing plugin components with support for registration, enabling/disabling, and querying by type and plugin.
- Added `EventDispatcher` to handle event distribution to registered event handlers, supporting both blocking and non-blocking execution.
- Implemented `WorkflowExecutor` to manage a linear workflow execution across multiple stages, including command routing and error handling.
- Created `ManifestValidator` for validating plugin manifests against required fields and version compatibility.
- Updated `RPCClient` to use `MsgPackCodec` for message encoding.
- Enhanced `PluginRunner` to support lifecycle hooks for plugins, including `on_load` and `on_unload`.
- Added sys.path isolation to restrict plugin access to only necessary directories.
This commit is contained in:
DrSmoothl
2026-03-06 11:55:59 +08:00
parent 61dc15a513
commit 2f21cd00bc
19 changed files with 1970 additions and 318 deletions

View File

@@ -1 +1 @@
# Host 端 - Supervisor、RPC Server、策略引擎、路由

View File

@@ -73,15 +73,7 @@ class CapabilityService:
reason,
)
# 2. 限流校验
allowed, reason = self._policy.check_rate_limit(plugin_id)
if not allowed:
return envelope.make_error_response(
ErrorCode.E_BACKPRESSURE.value,
reason,
)
# 3. 查找实现
# 2. 查找实现
impl = self._implementations.get(capability)
if impl is None:
return envelope.make_error_response(
@@ -89,7 +81,7 @@ class CapabilityService:
f"未注册的能力: {capability}",
)
# 4. 执行
# 3. 执行
try:
result = await impl(plugin_id, capability, req.args)
resp_payload = CapabilityResponsePayload(success=True, result=result)

View File

@@ -1,105 +0,0 @@
"""熔断器
为每个插件提供熔断保护,连续失败超过阈值后临时禁用。
支持指数退避恢复。
"""
from enum import Enum
import time
class CircuitState(str, Enum):
CLOSED = "closed" # 正常工作
OPEN = "open" # 熔断(拒绝所有调用)
HALF_OPEN = "half_open" # 探测恢复
class CircuitBreaker:
"""单个插件的熔断器"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout_sec: float = 30.0,
max_recovery_timeout_sec: float = 300.0,
):
self.failure_threshold = failure_threshold
self.base_recovery_timeout = recovery_timeout_sec
self.max_recovery_timeout = max_recovery_timeout_sec
self._state = CircuitState.CLOSED
self._failure_count = 0
self._last_failure_time = 0.0
self._consecutive_opens = 0 # 用于指数退避
@property
def state(self) -> CircuitState:
if self._state == CircuitState.OPEN:
# 检查是否可以进入半开状态
elapsed = time.monotonic() - self._last_failure_time
recovery_timeout = min(
self.base_recovery_timeout * (2 ** self._consecutive_opens),
self.max_recovery_timeout,
)
if elapsed >= recovery_timeout:
self._state = CircuitState.HALF_OPEN
return self._state
def allow_request(self) -> bool:
"""是否允许通过请求"""
state = self.state
if state == CircuitState.CLOSED:
return True
if state == CircuitState.HALF_OPEN:
return True # 允许一次试探
return False # OPEN 状态拒绝
def record_success(self) -> None:
"""记录一次成功调用"""
if self._state == CircuitState.HALF_OPEN:
# 半开状态成功 -> 关闭熔断
self._state = CircuitState.CLOSED
self._failure_count = 0
self._consecutive_opens = 0
elif self._state == CircuitState.CLOSED:
self._failure_count = 0
def record_failure(self) -> None:
"""记录一次失败调用"""
self._failure_count += 1
self._last_failure_time = time.monotonic()
if self._state == CircuitState.HALF_OPEN:
# 半开状态失败 -> 重新开启熔断
self._state = CircuitState.OPEN
self._consecutive_opens += 1
elif self._failure_count >= self.failure_threshold:
self._state = CircuitState.OPEN
self._consecutive_opens += 1
def reset(self) -> None:
"""重置熔断器"""
self._state = CircuitState.CLOSED
self._failure_count = 0
self._consecutive_opens = 0
class CircuitBreakerRegistry:
"""熔断器注册表,为每个插件维护独立的熔断器"""
def __init__(self, **default_kwargs):
self._breakers: dict[str, CircuitBreaker] = {}
self._default_kwargs = default_kwargs
def get(self, plugin_id: str) -> CircuitBreaker:
if plugin_id not in self._breakers:
self._breakers[plugin_id] = CircuitBreaker(**self._default_kwargs)
return self._breakers[plugin_id]
def remove(self, plugin_id: str) -> None:
self._breakers.pop(plugin_id, None)
def reset_all(self) -> None:
for breaker in self._breakers.values():
breaker.reset()

View File

@@ -0,0 +1,235 @@
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
- 按类型注册组件action / command / tool / event_handler / workflow_step
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
- 多维度查询(按名称、类型、插件)
- 注册统计
"""
from typing import Any
import logging
import re
logger = logging.getLogger("plugin_runtime.host.component_registry")
class RegisteredComponent:
"""已注册的组件条目"""
__slots__ = (
"name", "full_name", "component_type", "plugin_id",
"metadata", "enabled", "_compiled_pattern",
)
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: dict[str, Any],
):
self.name = name
self.full_name = f"{plugin_id}.{name}"
self.component_type = component_type
self.plugin_id = plugin_id
self.metadata = metadata
self.enabled = metadata.get("enabled", True)
# 预编译命令正则(仅 command 类型)
self._compiled_pattern: re.Pattern | None = None
if component_type == "command":
pattern = metadata.get("command_pattern", "")
if pattern:
try:
self._compiled_pattern = re.compile(pattern)
except re.error as e:
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
class ComponentRegistry:
"""Host-side 组件注册表
由 Supervisor 在收到 plugin.register_components 时调用。
供业务层查询可用组件、匹配命令、调度 action/event 等。
"""
def __init__(self):
# 全量索引
self._components: dict[str, RegisteredComponent] = {} # full_name -> comp
# 按类型索引
self._by_type: dict[str, dict[str, RegisteredComponent]] = {
"action": {},
"command": {},
"tool": {},
"event_handler": {},
"workflow_step": {},
}
# 按插件索引
self._by_plugin: dict[str, list[RegisteredComponent]] = {}
# ──── 注册 / 注销 ─────────────────────────────────────────
def register_component(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: dict[str, Any],
) -> bool:
"""注册单个组件。"""
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
self._components[comp.full_name] = comp
if component_type not in self._by_type:
self._by_type[component_type] = {}
self._by_type[component_type][comp.full_name] = comp
self._by_plugin.setdefault(plugin_id, []).append(comp)
return True
def register_plugin_components(
self,
plugin_id: str,
components: list[dict[str, Any]],
) -> int:
"""批量注册一个插件的所有组件,返回成功注册数。"""
count = 0
for comp_data in components:
ok = self.register_component(
name=comp_data.get("name", ""),
component_type=comp_data.get("component_type", ""),
plugin_id=plugin_id,
metadata=comp_data.get("metadata", {}),
)
if ok:
count += 1
return count
def remove_components_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的所有组件,返回移除数量。"""
comps = self._by_plugin.pop(plugin_id, [])
for comp in comps:
self._components.pop(comp.full_name, None)
type_dict = self._by_type.get(comp.component_type)
if type_dict:
type_dict.pop(comp.full_name, None)
return len(comps)
# ──── 启用 / 禁用 ─────────────────────────────────────────
def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
"""启用或禁用指定组件。"""
comp = self._components.get(full_name)
if comp is None:
return False
comp.enabled = enabled
return True
def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
"""批量启用或禁用某插件的所有组件。"""
comps = self._by_plugin.get(plugin_id, [])
for comp in comps:
comp.enabled = enabled
return len(comps)
# ──── 查询方法 ─────────────────────────────────────────────
def get_component(self, full_name: str) -> RegisteredComponent | None:
"""按全名查询。"""
return self._components.get(full_name)
def get_components_by_type(
self, component_type: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]:
"""按类型查询。"""
type_dict = self._by_type.get(component_type, {})
if enabled_only:
return [c for c in type_dict.values() if c.enabled]
return list(type_dict.values())
def get_components_by_plugin(
self, plugin_id: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]:
"""按插件查询。"""
comps = self._by_plugin.get(plugin_id, [])
if enabled_only:
return [c for c in comps if c.enabled]
return list(comps)
def find_command_by_text(self, text: str) -> RegisteredComponent | None:
"""通过文本匹配命令正则,返回第一个匹配的 command 组件。"""
for comp in self._by_type.get("command", {}).values():
if not comp.enabled:
continue
if comp._compiled_pattern and comp._compiled_pattern.search(text):
return comp
# 别名匹配
aliases = comp.metadata.get("aliases", [])
for alias in aliases:
if text.startswith(alias):
return comp
return None
def get_event_handlers(
self, event_type: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]:
"""获取特定事件类型的所有 event_handler按 weight 降序排列。"""
handlers = []
for comp in self._by_type.get("event_handler", {}).values():
if enabled_only and not comp.enabled:
continue
if comp.metadata.get("event_type") == event_type:
handlers.append(comp)
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
return handlers
def get_workflow_steps(
self, stage: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]:
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
steps = []
for comp in self._by_type.get("workflow_step", {}).values():
if enabled_only and not comp.enabled:
continue
if comp.metadata.get("stage") == stage:
steps.append(comp)
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
return steps
def get_tools_for_llm(self, *, enabled_only: bool = True) -> list[dict[str, Any]]:
"""获取可供 LLM 使用的工具列表openai function-calling 格式预览)。"""
result = []
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
tool_def: dict[str, Any] = {
"name": comp.full_name,
"description": comp.metadata.get("description", ""),
}
# 从结构化参数或原始参数构建 parameters
params = comp.metadata.get("parameters", [])
params_raw = comp.metadata.get("parameters_raw", {})
if params:
tool_def["parameters"] = params
elif params_raw:
tool_def["parameters"] = params_raw
result.append(tool_def)
return result
# ──── 统计 ─────────────────────────────────────────────────
def get_stats(self) -> dict[str, int]:
"""获取注册统计。"""
stats: dict[str, int] = {"total": len(self._components)}
for comp_type, type_dict in self._by_type.items():
stats[comp_type] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats

View File

@@ -0,0 +1,146 @@
"""Host-side EventDispatcher
负责:
1. 按事件类型查询已注册的 event_handler通过 ComponentRegistry
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
3. 支持阻塞intercept_message和非阻塞分发
4. 事件结果历史记录
"""
from typing import Any, Optional
import asyncio
import logging
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
logger = logging.getLogger("plugin_runtime.host.event_dispatcher")
class EventResult:
"""单个 EventHandler 的执行结果"""
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
def __init__(
self,
handler_name: str,
success: bool = True,
continue_processing: bool = True,
modified_message: dict[str, Any] | None = None,
custom_result: Any = None,
):
self.handler_name = handler_name
self.success = success
self.continue_processing = continue_processing
self.modified_message = modified_message
self.custom_result = custom_result
class EventDispatcher:
"""Host-side 事件分发器
由业务层调用 dispatch_event()
内部通过 ComponentRegistry 查询 handler
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
def __init__(self, registry: ComponentRegistry):
self._registry = registry
self._result_history: dict[str, list[EventResult]] = {}
self._history_enabled: set[str] = set()
def enable_history(self, event_type: str) -> None:
self._history_enabled.add(event_type)
self._result_history.setdefault(event_type, [])
def get_history(self, event_type: str) -> list[EventResult]:
return self._result_history.get(event_type, [])
def clear_history(self, event_type: str) -> None:
if event_type in self._result_history:
self._result_history[event_type] = []
async def dispatch_event(
self,
event_type: str,
invoke_fn, # async (plugin_id, component_name, args) -> dict — Supervisor.invoke_plugin wrapper
message: dict[str, Any] | None = None,
extra_args: dict[str, Any] | None = None,
) -> tuple[bool, Optional[dict[str, Any]]]:
"""分发事件到所有对应 handler。
Args:
event_type: 事件类型字符串
invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
message: MaiMessages 序列化后的 dict可选
extra_args: 额外参数
Returns:
(should_continue, modified_message_dict)
"""
handlers = self._registry.get_event_handlers(event_type)
if not handlers:
return True, None
should_continue = True
modified_message: dict[str, Any] | None = None
fire_and_forget_tasks: list[asyncio.Task] = []
for handler in handlers:
intercept = handler.metadata.get("intercept_message", False)
args = {
"event_type": event_type,
"message": modified_message or message,
**(extra_args or {}),
}
if intercept:
# 阻塞执行
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
if result and not result.continue_processing:
should_continue = False
if result and result.modified_message:
modified_message = result.modified_message
else:
# 非阻塞
task = asyncio.create_task(
self._invoke_handler(invoke_fn, handler, args, event_type)
)
fire_and_forget_tasks.append(task)
# 不等待 fire-and-forget 任务(但不丢弃引用以防 GC
if fire_and_forget_tasks:
for t in fire_and_forget_tasks:
t.add_done_callback(lambda _t: None)
return should_continue, modified_message
async def _invoke_handler(
self,
invoke_fn,
handler: RegisteredComponent,
args: dict[str, Any],
event_type: str,
) -> EventResult | None:
"""调用单个 handler 并收集结果。"""
try:
resp = await invoke_fn(handler.plugin_id, handler.name, args)
result = EventResult(
handler_name=handler.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_message=resp.get("modified_message"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
result = EventResult(
handler_name=handler.full_name,
success=False,
continue_processing=True,
)
if event_type in self._history_enabled:
self._result_history.setdefault(event_type, []).append(result)
return result

View File

@@ -1,42 +1,27 @@
"""策略引擎
负责能力授权校验、限流、配额管理
负责能力授权校验。
每个插件在 manifest 中声明能力需求Host 启动时签发能力令牌。
"""
from dataclasses import dataclass, field
import time
@dataclass
class CapabilityToken:
"""能力令牌
描述某个插件在当前会话中被授予的能力和资源限制。
"""
"""能力令牌"""
plugin_id: str
generation: int
capabilities: set[str] = field(default_factory=set)
qps_limit: int = 20
burst_limit: int = 50
daily_token_limit: int = 200000
max_payload_kb: int = 256
# 运行时统计
_call_count: int = field(default=0, init=False, repr=False)
_window_start: float = field(default_factory=time.monotonic, init=False, repr=False)
_window_calls: int = field(default=0, init=False, repr=False)
class PolicyEngine:
"""策略引擎
管理所有插件的能力令牌,提供授权校验与限流决策
管理所有插件的能力令牌,提供授权校验。
"""
def __init__(self):
# plugin_id -> CapabilityToken
self._tokens: dict[str, CapabilityToken] = {}
def register_plugin(
@@ -44,18 +29,12 @@ class PolicyEngine:
plugin_id: str,
generation: int,
capabilities: list[str],
limits: dict | None = None,
) -> CapabilityToken:
"""为插件签发能力令牌"""
limits = limits or {}
token = CapabilityToken(
plugin_id=plugin_id,
generation=generation,
capabilities=set(capabilities),
qps_limit=limits.get("qps", 20),
burst_limit=limits.get("burst", 50),
daily_token_limit=limits.get("daily_tokens", 200000),
max_payload_kb=limits.get("max_payload_kb", 256),
)
self._tokens[plugin_id] = token
return token
@@ -79,43 +58,6 @@ class PolicyEngine:
return True, ""
def check_rate_limit(self, plugin_id: str) -> tuple[bool, str]:
"""检查插件是否超过调用频率限制(滑动窗口)
Returns:
(allowed, reason)
"""
token = self._tokens.get(plugin_id)
if token is None:
return False, f"插件 {plugin_id} 未注册"
now = time.monotonic()
elapsed = now - token._window_start
# 每秒重置窗口
if elapsed >= 1.0:
token._window_start = now
token._window_calls = 0
token._window_calls += 1
if token._window_calls > token.burst_limit:
return False, f"插件 {plugin_id} 超过突发限制 ({token.burst_limit}/s)"
return True, ""
def check_payload_size(self, plugin_id: str, payload_size_bytes: int) -> tuple[bool, str]:
"""检查 payload 大小是否在限制内"""
token = self._tokens.get(plugin_id)
if token is None:
return False, f"插件 {plugin_id} 未注册"
max_bytes = token.max_payload_kb * 1024
if payload_size_bytes > max_bytes:
return False, f"payload 大小 {payload_size_bytes} 超过限制 {max_bytes}"
return True, ""
def get_token(self, plugin_id: str) -> CapabilityToken | None:
"""获取插件的能力令牌"""
return self._tokens.get(plugin_id)

View File

@@ -13,7 +13,7 @@ import asyncio
import logging
import secrets
from src.plugin_runtime.protocol.codec import Codec, create_codec
from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
from src.plugin_runtime.protocol.envelope import (
PROTOCOL_VERSION,
MIN_SDK_VERSION,
@@ -48,7 +48,7 @@ class RPCServer:
):
self._transport = transport
self._session_token = session_token or secrets.token_hex(32)
self._codec = codec or create_codec()
self._codec = codec or MsgPackCodec()
self._send_queue_size = send_queue_size
self._id_gen = RequestIdGenerator()

View File

@@ -2,10 +2,9 @@
负责:
1. 拉起 Runner 子进程
2. 健康检查
3. 熔断与恢复
4. 代码热重载generation 切换)
5. 优雅关停
2. 健康检查 + 崩溃自动重启
3. 代码热重载generation 切换)
4. 优雅关停
"""
from typing import Any
@@ -16,9 +15,11 @@ import os
import sys
from src.plugin_runtime.host.capability_service import CapabilityService
from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry
from src.plugin_runtime.host.component_registry import ComponentRegistry
from src.plugin_runtime.host.event_dispatcher import EventDispatcher
from src.plugin_runtime.host.policy_engine import PolicyEngine
from src.plugin_runtime.host.rpc_server import RPCServer
from src.plugin_runtime.host.workflow_executor import WorkflowExecutor, WorkflowContext, WorkflowResult
from src.plugin_runtime.protocol.envelope import (
Envelope,
HealthPayload,
@@ -42,7 +43,6 @@ class PluginSupervisor:
plugin_dirs: list[str] | None = None,
socket_path: str | None = None,
health_check_interval_sec: float = 30.0,
use_json_codec: bool = False,
):
self._plugin_dirs = plugin_dirs or []
self._health_interval = health_check_interval_sec
@@ -50,12 +50,14 @@ class PluginSupervisor:
# 基础设施
self._transport = create_transport_server(socket_path=socket_path)
self._policy = PolicyEngine()
self._breakers = CircuitBreakerRegistry()
self._capability_service = CapabilityService(self._policy)
self._component_registry = ComponentRegistry()
self._event_dispatcher = EventDispatcher(self._component_registry)
self._workflow_executor = WorkflowExecutor(self._component_registry)
# 编解码
from src.plugin_runtime.protocol.codec import create_codec
codec = create_codec(use_json=use_json_codec)
from src.plugin_runtime.protocol.codec import MsgPackCodec
codec = MsgPackCodec()
self._rpc_server = RPCServer(
transport=self._transport,
@@ -65,6 +67,8 @@ class PluginSupervisor:
# Runner 子进程
self._runner_process: asyncio.subprocess.Process | None = None
self._runner_generation: int = 0
self._max_restart_attempts: int = 3
self._restart_count: int = 0
# 已注册的插件组件信息
self._registered_plugins: dict[str, RegisterComponentsPayload] = {}
@@ -84,10 +88,72 @@ class PluginSupervisor:
def capability_service(self) -> CapabilityService:
return self._capability_service
@property
def component_registry(self) -> ComponentRegistry:
return self._component_registry
@property
def event_dispatcher(self) -> EventDispatcher:
return self._event_dispatcher
@property
def workflow_executor(self) -> WorkflowExecutor:
return self._workflow_executor
@property
def rpc_server(self) -> RPCServer:
return self._rpc_server
async def dispatch_event(
self,
event_type: str,
message: dict[str, Any] | None = None,
extra_args: dict[str, Any] | None = None,
) -> tuple[bool, dict[str, Any] | None]:
"""分发事件到所有对应 handler 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
resp = await self.invoke_plugin(
method="plugin.emit_event",
plugin_id=plugin_id,
component_name=component_name,
args=args,
)
return resp.payload
return await self._event_dispatcher.dispatch_event(
event_type=event_type,
invoke_fn=_invoke,
message=message,
extra_args=extra_args,
)
async def execute_workflow(
self,
message: dict[str, Any] | None = None,
stream_id: str | None = None,
context: WorkflowContext | None = None,
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
"""执行 Workflow Pipeline 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
resp = await self.invoke_plugin(
method="plugin.invoke_workflow_step",
plugin_id=plugin_id,
component_name=component_name,
args=args,
)
payload = resp.payload
if payload.get("success"):
result = payload.get("result")
return result if isinstance(result, dict) else {}
raise RuntimeError(payload.get("result", "workflow step invoke failed"))
return await self._workflow_executor.execute(
invoke_fn=_invoke,
message=message,
stream_id=stream_id,
context=context,
)
async def start(self) -> None:
"""启动 Supervisor
@@ -137,11 +203,6 @@ class PluginSupervisor:
由主进程业务逻辑调用,通过 RPC 转发给 Runner。
"""
# 熔断检查
breaker = self._breakers.get(plugin_id)
if not breaker.allow_request():
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, f"插件 {plugin_id} 已被熔断")
try:
response = await self._rpc_server.send_request(
method=method,
@@ -152,10 +213,8 @@ class PluginSupervisor:
},
timeout_ms=timeout_ms,
)
breaker.record_success()
return response
except RPCError:
breaker.record_failure()
raise
async def reload_plugins(self, reason: str = "manual") -> None:
@@ -232,12 +291,20 @@ class PluginSupervisor:
self._policy.register_plugin(
plugin_id=reg.plugin_id,
generation=envelope.generation,
capabilities=reg.capabilities_required,
capabilities=reg.capabilities_required or [],
)
# 在 ComponentRegistry 中注册组件
self._component_registry.register_plugin_components(
plugin_id=reg.plugin_id,
components=[c.model_dump() for c in reg.components],
)
stats = self._component_registry.get_stats()
logger.info(
f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功,"
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
f"注册表总计: {stats}"
)
return envelope.make_response(payload={"accepted": True})
@@ -294,10 +361,32 @@ class PluginSupervisor:
await self._runner_process.wait()
async def _health_check_loop(self) -> None:
"""周期性健康检查"""
"""周期性健康检查 + 崩溃自动重启"""
while self._running:
await asyncio.sleep(self._health_interval)
# 检查 Runner 进程是否意外退出
if self._runner_process and self._runner_process.returncode is not None:
exit_code = self._runner_process.returncode
logger.warning(f"Runner 进程已退出 (exit_code={exit_code})")
if self._restart_count < self._max_restart_attempts:
self._restart_count += 1
logger.info(f"尝试重启 Runner ({self._restart_count}/{self._max_restart_attempts})")
# 清理旧的组件注册
for plugin_id in list(self._registered_plugins.keys()):
self._component_registry.remove_components_by_plugin(plugin_id)
self._policy.revoke_plugin(plugin_id)
self._registered_plugins.clear()
try:
await self._spawn_runner()
except Exception as e:
logger.error(f"Runner 重启失败: {e}", exc_info=True)
else:
logger.error(f"Runner 连续崩溃 {self._max_restart_attempts} 次,停止重启")
continue
if not self._rpc_server.is_connected:
logger.warning("Runner 未连接,跳过健康检查")
continue
@@ -307,6 +396,9 @@ class PluginSupervisor:
health = HealthPayload.model_validate(resp.payload)
if not health.healthy:
logger.warning(f"Runner 健康检查异常: {health}")
else:
# 健康检查成功,重置重启计数
self._restart_count = 0
except RPCError as e:
logger.error(f"健康检查失败: {e}")
except asyncio.CancelledError:

View File

@@ -0,0 +1,397 @@
"""Host-side WorkflowExecutor
6 阶段线性流转INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS
每个阶段执行顺序:
1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
2. 按 priority 降序排列
3. 串行执行 blocking hook可修改 message返回 HookResult
4. 并发执行 non-blocking hook只读
5. 检查是否有 SKIP_STAGE 或 ABORT
6. PLAN 阶段内置 Command 匹配路由
支持:
- HookResult: CONTINUE / SKIP_STAGE / ABORT
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
- stage_outputs: 阶段间带命名空间的数据传递
- modification_log: 消息修改审计
"""
from typing import Any, Callable, Awaitable, Optional
import asyncio
import logging
import time
import uuid
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
logger = logging.getLogger("plugin_runtime.host.workflow_executor")
# 阶段顺序
STAGE_SEQUENCE: list[str] = [
"ingress",
"pre_process",
"plan",
"tool_execute",
"post_process",
"egress",
]
# HookResult 常量(与 SDK HookResult enum 值对应)
HOOK_CONTINUE = "continue"
HOOK_SKIP_STAGE = "skip_stage"
HOOK_ABORT = "abort"
class ModificationRecord:
"""消息修改记录"""
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
def __init__(self, stage: str, hook_name: str, fields_changed: list[str]):
self.stage = stage
self.hook_name = hook_name
self.timestamp = time.perf_counter()
self.fields_changed = fields_changed
class WorkflowContext:
"""Workflow 执行上下文"""
def __init__(self, trace_id: str | None = None, stream_id: str | None = None):
self.trace_id = trace_id or uuid.uuid4().hex
self.stream_id = stream_id
self.timings: dict[str, float] = {}
self.errors: list[str] = []
# 阶段间数据传递(按 stage 命名空间隔离)
self.stage_outputs: dict[str, dict[str, Any]] = {}
# 消息修改审计日志
self.modification_log: list[ModificationRecord] = []
# PLAN 阶段命令匹配结果
self.matched_command: str | None = None
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
self.stage_outputs.setdefault(stage, {})[key] = value
def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
return self.stage_outputs.get(stage, {}).get(key, default)
class WorkflowResult:
"""Workflow 执行结果"""
def __init__(
self,
status: str = "completed", # completed / aborted / failed
return_message: str = "",
stopped_at: str = "",
diagnostics: dict[str, Any] | None = None,
):
self.status = status
self.return_message = return_message
self.stopped_at = stopped_at
self.diagnostics = diagnostics or {}
# invoke_fn 签名
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]]
class WorkflowExecutor:
"""Host-side Workflow 执行器
实现 stage-based pipeline + per-stage hook chain with priority + early return。
"""
def __init__(self, registry: ComponentRegistry):
self._registry = registry
async def execute(
self,
invoke_fn: InvokeFn,
message: dict[str, Any] | None = None,
stream_id: str | None = None,
context: WorkflowContext | None = None,
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
"""执行 workflow pipeline。
Returns:
(result, final_message, context)
"""
ctx = context or WorkflowContext(stream_id=stream_id)
current_message = dict(message) if message else None
for stage in STAGE_SEQUENCE:
stage_start = time.perf_counter()
try:
# PLAN 阶段: 先做 Command 路由
if stage == "plan" and current_message:
cmd_result = await self._route_command(
invoke_fn, current_message, ctx
)
if cmd_result is not None:
# 命令匹配成功,跳过 PLAN 阶段的 hook直接存结果进 stage_outputs
ctx.set_stage_output("plan", "command_result", cmd_result)
ctx.timings[stage] = time.perf_counter() - stage_start
continue
# 获取该阶段所有 hook已按 priority 降序排列)
all_steps = self._registry.get_workflow_steps(stage)
if not all_steps:
ctx.timings[stage] = time.perf_counter() - stage_start
continue
# 1. Pre-filter
filtered_steps = self._pre_filter(all_steps, current_message)
# 2. 分离 blocking 和 non-blocking
blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
# 3. 串行执行 blocking hook
skip_stage = False
for step in blocking_steps:
hook_result, modified, step_error = await self._invoke_step(
invoke_fn, step, stage, ctx, current_message
)
if step_error:
error_policy = step.metadata.get("error_policy", "abort")
ctx.errors.append(f"{step.full_name}: {step_error}")
if error_policy == "abort":
ctx.timings[stage] = time.perf_counter() - stage_start
return (
WorkflowResult(
status="failed",
return_message=step_error,
stopped_at=stage,
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
),
current_message,
ctx,
)
elif error_policy == "skip":
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
continue
else: # log
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
continue
# 更新消息(仅 blocking hook 有权修改)
if modified:
changed_fields = _diff_keys(current_message, modified) if current_message else list(modified.keys())
ctx.modification_log.append(
ModificationRecord(stage, step.full_name, changed_fields)
)
current_message = modified
if hook_result == HOOK_ABORT:
ctx.timings[stage] = time.perf_counter() - stage_start
return (
WorkflowResult(
status="aborted",
return_message=f"aborted by {step.full_name}",
stopped_at=stage,
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
),
current_message,
ctx,
)
if hook_result == HOOK_SKIP_STAGE:
skip_stage = True
break
# 4. 并发执行 non-blocking hook只读忽略返回值中的 modified_message
if nonblocking_steps and not skip_stage:
nb_tasks = [
self._invoke_step_fire_and_forget(
invoke_fn, step, stage, ctx, current_message
)
for step in nonblocking_steps
]
# 并发执行但不阻塞 pipeline
for task in [asyncio.create_task(t) for t in nb_tasks]:
task.add_done_callback(lambda _: None)
ctx.timings[stage] = time.perf_counter() - stage_start
except Exception as e:
ctx.timings[stage] = time.perf_counter() - stage_start
ctx.errors.append(f"{stage}: {e}")
logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
return (
WorkflowResult(
status="failed",
return_message=str(e),
stopped_at=stage,
diagnostics={"trace_id": ctx.trace_id},
),
current_message,
ctx,
)
return (
WorkflowResult(
status="completed",
return_message="workflow completed",
diagnostics={"trace_id": ctx.trace_id},
),
current_message,
ctx,
)
# ─── 内部方法 ──────────────────────────────────────────────
def _pre_filter(
self,
steps: list[RegisteredComponent],
message: dict[str, Any] | None,
) -> list[RegisteredComponent]:
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
if not message:
return steps
result = []
for step in steps:
filter_cond = step.metadata.get("filter", {})
if not filter_cond:
result.append(step)
continue
if self._match_filter(filter_cond, message):
result.append(step)
return result
@staticmethod
def _match_filter(filter_cond: dict[str, Any], message: dict[str, Any]) -> bool:
"""简单 key-value 匹配过滤。
filter 中的每个 key 必须在 message 中存在且值相等,
全部匹配才通过。
"""
for key, expected in filter_cond.items():
actual = message.get(key)
if isinstance(expected, list):
if actual not in expected:
return False
elif actual != expected:
return False
return True
async def _invoke_step(
self,
invoke_fn: InvokeFn,
step: RegisteredComponent,
stage: str,
ctx: WorkflowContext,
message: dict[str, Any] | None,
) -> tuple[str, dict[str, Any] | None, str | None]:
"""调用单个 blocking hook。
Returns:
(hook_result, modified_message, error_string_or_None)
"""
timeout_ms = step.metadata.get("timeout_ms", 0)
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None
step_key = f"{stage}:{step.full_name}"
step_start = time.perf_counter()
try:
coro = invoke_fn(step.plugin_id, step.name, {
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
})
resp = await asyncio.wait_for(coro, timeout=timeout_sec) if timeout_sec else await coro
ctx.timings[step_key] = time.perf_counter() - step_start
hook_result = resp.get("hook_result", HOOK_CONTINUE)
modified_message = resp.get("modified_message")
# 存 stage output如果 hook 提供了)
stage_out = resp.get("stage_output")
if isinstance(stage_out, dict):
for k, v in stage_out.items():
ctx.set_stage_output(stage, k, v)
return hook_result, modified_message, None
except asyncio.TimeoutError:
ctx.timings[step_key] = time.perf_counter() - step_start
return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
except Exception as e:
ctx.timings[step_key] = time.perf_counter() - step_start
return HOOK_CONTINUE, None, str(e)
async def _invoke_step_fire_and_forget(
self,
invoke_fn: InvokeFn,
step: RegisteredComponent,
stage: str,
ctx: WorkflowContext,
message: dict[str, Any] | None,
) -> None:
"""Non-blocking hook 调用,只读,忽略结果。"""
timeout_ms = step.metadata.get("timeout_ms", 0)
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else None
try:
coro = invoke_fn(step.plugin_id, step.name, {
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
})
if timeout_sec:
await asyncio.wait_for(coro, timeout=timeout_sec)
else:
await coro
except Exception as e:
logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
async def _route_command(
self,
invoke_fn: InvokeFn,
message: dict[str, Any],
ctx: WorkflowContext,
) -> dict[str, Any] | None:
"""PLAN 阶段内置 Command 路由。
在 registry 中查找匹配的 command 组件,
匹配到则直接路由到对应 command handler返回执行结果。
不匹配则返回 None让 PLAN 阶段的 hook 继续执行。
"""
plain_text = message.get("plain_text", "")
if not plain_text:
return None
matched = self._registry.find_command_by_text(plain_text)
if matched is None:
return None
ctx.matched_command = matched.full_name
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
try:
resp = await invoke_fn(matched.plugin_id, matched.name, {
"text": plain_text,
"message": message,
"trace_id": ctx.trace_id,
})
return resp
except Exception as e:
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
ctx.errors.append(f"command:{matched.full_name}: {e}")
return None
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
"""返回 new 中与 old 不同的 key 列表。"""
changed = []
for k in new:
if k not in old or old[k] != new[k]:
changed.append(k)
return changed