Refactor protocol and transport modules to use type hints for improved clarity and consistency
- Updated Codec class to use abstract methods for encoding and decoding envelopes. - Changed Envelope class to use Dict and Optional for payload and error fields. - Refined error handling in RPCError class with Optional type hints for details. - Enhanced manifest validation logic with type hints for better type safety. - Improved plugin loading mechanism with consistent type annotations. - Updated RPCClient to utilize Optional for codec and connection attributes. - Refactored transport classes to use Optional for server attributes and socket paths.
This commit is contained in:
@@ -4,22 +4,21 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
|
||||
每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
from typing import Any, Awaitable, Callable, Dict, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
CapabilityRequestPayload,
|
||||
CapabilityResponsePayload,
|
||||
Envelope,
|
||||
)
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
|
||||
logger = get_logger("plugin_runtime.host.capability_service")
|
||||
|
||||
# 能力实现函数类型: (plugin_id, capability, args) -> result
|
||||
CapabilityImpl = Callable[[str, str, dict[str, Any]], Awaitable[Any]]
|
||||
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
|
||||
|
||||
|
||||
class CapabilityService:
|
||||
@@ -35,7 +34,7 @@ class CapabilityService:
|
||||
def __init__(self, policy_engine: PolicyEngine):
|
||||
self._policy = policy_engine
|
||||
# capability_name -> implementation
|
||||
self._implementations: dict[str, CapabilityImpl] = {}
|
||||
self._implementations: Dict[str, CapabilityImpl] = {}
|
||||
|
||||
def register_capability(self, name: str, impl: CapabilityImpl) -> None:
|
||||
"""注册一个能力实现
|
||||
@@ -95,6 +94,6 @@ class CapabilityService:
|
||||
str(e),
|
||||
)
|
||||
|
||||
def list_capabilities(self) -> list[str]:
|
||||
def list_capabilities(self) -> List[str]:
|
||||
"""列出所有已注册的能力"""
|
||||
return list(self._implementations.keys())
|
||||
|
||||
@@ -9,9 +9,10 @@
|
||||
- 注册统计
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
import re
|
||||
|
||||
logger = get_logger("plugin_runtime.host.component_registry")
|
||||
@@ -30,7 +31,7 @@ class RegisteredComponent:
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
):
|
||||
self.name = name
|
||||
self.full_name = f"{plugin_id}.{name}"
|
||||
@@ -40,7 +41,7 @@ class RegisteredComponent:
|
||||
self.enabled = metadata.get("enabled", True)
|
||||
|
||||
# 预编译命令正则(仅 command 类型)
|
||||
self._compiled_pattern: re.Pattern | None = None
|
||||
self._compiled_pattern: Optional[re.Pattern] = None
|
||||
if component_type == "command":
|
||||
if pattern := metadata.get("command_pattern", ""):
|
||||
try:
|
||||
@@ -58,10 +59,10 @@ class ComponentRegistry:
|
||||
|
||||
def __init__(self):
|
||||
# 全量索引
|
||||
self._components: dict[str, RegisteredComponent] = {} # full_name -> comp
|
||||
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
|
||||
|
||||
# 按类型索引
|
||||
self._by_type: dict[str, dict[str, RegisteredComponent]] = {
|
||||
self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
|
||||
"action": {},
|
||||
"command": {},
|
||||
"tool": {},
|
||||
@@ -70,7 +71,7 @@ class ComponentRegistry:
|
||||
}
|
||||
|
||||
# 按插件索引
|
||||
self._by_plugin: dict[str, list[RegisteredComponent]] = {}
|
||||
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
|
||||
|
||||
# ──── 注册 / 注销 ─────────────────────────────────────────
|
||||
|
||||
@@ -79,7 +80,7 @@ class ComponentRegistry:
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
) -> bool:
|
||||
"""注册单个组件。"""
|
||||
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
|
||||
@@ -99,7 +100,7 @@ class ComponentRegistry:
|
||||
def register_plugin_components(
|
||||
self,
|
||||
plugin_id: str,
|
||||
components: list[dict[str, Any]],
|
||||
components: List[Dict[str, Any]],
|
||||
) -> int:
|
||||
"""批量注册一个插件的所有组件,返回成功注册数。"""
|
||||
count = 0
|
||||
@@ -142,13 +143,13 @@ class ComponentRegistry:
|
||||
|
||||
# ──── 查询方法 ─────────────────────────────────────────────
|
||||
|
||||
def get_component(self, full_name: str) -> RegisteredComponent | None:
|
||||
def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
|
||||
"""按全名查询。"""
|
||||
return self._components.get(full_name)
|
||||
|
||||
def get_components_by_type(
|
||||
self, component_type: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
) -> List[RegisteredComponent]:
|
||||
"""按类型查询。"""
|
||||
type_dict = self._by_type.get(component_type, {})
|
||||
if enabled_only:
|
||||
@@ -157,12 +158,12 @@ class ComponentRegistry:
|
||||
|
||||
def get_components_by_plugin(
|
||||
self, plugin_id: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
) -> List[RegisteredComponent]:
|
||||
"""按插件查询。"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
return [c for c in comps if c.enabled] if enabled_only else list(comps)
|
||||
|
||||
def find_command_by_text(self, text: str) -> RegisteredComponent | None:
|
||||
def find_command_by_text(self, text: str) -> Optional[RegisteredComponent]:
|
||||
"""通过文本匹配命令正则,返回第一个匹配的 command 组件。"""
|
||||
for comp in self._by_type.get("command", {}).values():
|
||||
if not comp.enabled:
|
||||
@@ -178,7 +179,7 @@ class ComponentRegistry:
|
||||
|
||||
def get_event_handlers(
|
||||
self, event_type: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
) -> List[RegisteredComponent]:
|
||||
"""获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
|
||||
handlers = []
|
||||
for comp in self._by_type.get("event_handler", {}).values():
|
||||
@@ -191,7 +192,7 @@ class ComponentRegistry:
|
||||
|
||||
def get_workflow_steps(
|
||||
self, stage: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
) -> List[RegisteredComponent]:
|
||||
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
|
||||
steps = []
|
||||
for comp in self._by_type.get("workflow_step", {}).values():
|
||||
@@ -202,11 +203,11 @@ class ComponentRegistry:
|
||||
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
|
||||
return steps
|
||||
|
||||
def get_tools_for_llm(self, *, enabled_only: bool = True) -> list[dict[str, Any]]:
|
||||
def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
|
||||
"""获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。"""
|
||||
result = []
|
||||
result: List[Dict[str, Any]] = []
|
||||
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
|
||||
tool_def: dict[str, Any] = {
|
||||
tool_def: Dict[str, Any] = {
|
||||
"name": comp.full_name,
|
||||
"description": comp.metadata.get("description", ""),
|
||||
}
|
||||
@@ -222,9 +223,9 @@ class ComponentRegistry:
|
||||
|
||||
# ──── 统计 ─────────────────────────────────────────────────
|
||||
|
||||
def get_stats(self) -> dict[str, int]:
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""获取注册统计。"""
|
||||
stats: dict[str, int] = {"total": len(self._components)}
|
||||
stats: Dict[str, int] = {"total": len(self._components)}
|
||||
for comp_type, type_dict in self._by_type.items():
|
||||
stats[comp_type] = len(type_dict)
|
||||
stats["plugins"] = len(self._by_plugin)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 事件结果历史记录
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -17,7 +17,7 @@ from src.plugin_runtime.host.component_registry import ComponentRegistry, Regist
|
||||
logger = get_logger("plugin_runtime.host.event_dispatcher")
|
||||
|
||||
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
|
||||
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]]
|
||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class EventResult:
|
||||
@@ -29,7 +29,7 @@ class EventResult:
|
||||
handler_name: str,
|
||||
success: bool = True,
|
||||
continue_processing: bool = True,
|
||||
modified_message: dict[str, Any] | None = None,
|
||||
modified_message: Optional[Dict[str, Any]] = None,
|
||||
custom_result: Any = None,
|
||||
):
|
||||
self.handler_name = handler_name
|
||||
@@ -49,14 +49,14 @@ class EventDispatcher:
|
||||
|
||||
def __init__(self, registry: ComponentRegistry) -> None:
|
||||
self._registry: ComponentRegistry = registry
|
||||
self._result_history: dict[str, list[EventResult]] = {}
|
||||
self._history_enabled: set[str] = set()
|
||||
self._result_history: Dict[str, List[EventResult]] = {}
|
||||
self._history_enabled: Set[str] = set()
|
||||
|
||||
def enable_history(self, event_type: str) -> None:
|
||||
self._history_enabled.add(event_type)
|
||||
self._result_history.setdefault(event_type, [])
|
||||
|
||||
def get_history(self, event_type: str) -> list[EventResult]:
|
||||
def get_history(self, event_type: str) -> List[EventResult]:
|
||||
return self._result_history.get(event_type, [])
|
||||
|
||||
def clear_history(self, event_type: str) -> None:
|
||||
@@ -67,9 +67,9 @@ class EventDispatcher:
|
||||
self,
|
||||
event_type: str,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any] | None = None,
|
||||
extra_args: dict[str, Any] | None = None,
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
"""分发事件到所有对应 handler。
|
||||
|
||||
Args:
|
||||
@@ -86,8 +86,8 @@ class EventDispatcher:
|
||||
return True, None
|
||||
|
||||
should_continue = True
|
||||
modified_message: dict[str, Any] | None = None
|
||||
fire_and_forget_tasks: list[asyncio.Task] = []
|
||||
modified_message: Optional[Dict[str, Any]] = None
|
||||
fire_and_forget_tasks: List[asyncio.Task] = []
|
||||
|
||||
for handler in handlers:
|
||||
intercept = handler.metadata.get("intercept_message", False)
|
||||
@@ -122,9 +122,9 @@ class EventDispatcher:
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
handler: RegisteredComponent,
|
||||
args: dict[str, Any],
|
||||
args: Dict[str, Any],
|
||||
event_type: str,
|
||||
) -> EventResult | None:
|
||||
) -> Optional[EventResult]:
|
||||
"""调用单个 handler 并收集结果。"""
|
||||
try:
|
||||
resp = await invoke_fn(handler.plugin_id, handler.name, args)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -12,7 +13,7 @@ class CapabilityToken:
|
||||
"""能力令牌"""
|
||||
plugin_id: str
|
||||
generation: int
|
||||
capabilities: set[str] = field(default_factory=set)
|
||||
capabilities: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class PolicyEngine:
|
||||
@@ -22,13 +23,13 @@ class PolicyEngine:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tokens: dict[str, CapabilityToken] = {}
|
||||
self._tokens: Dict[str, CapabilityToken] = {}
|
||||
|
||||
def register_plugin(
|
||||
self,
|
||||
plugin_id: str,
|
||||
generation: int,
|
||||
capabilities: list[str],
|
||||
capabilities: List[str],
|
||||
) -> CapabilityToken:
|
||||
"""为插件签发能力令牌"""
|
||||
token = CapabilityToken(
|
||||
@@ -43,7 +44,7 @@ class PolicyEngine:
|
||||
"""撤销插件的能力令牌"""
|
||||
self._tokens.pop(plugin_id, None)
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str) -> tuple[bool, str]:
|
||||
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
|
||||
"""检查插件是否有权调用某项能力
|
||||
|
||||
Returns:
|
||||
@@ -58,10 +59,10 @@ class PolicyEngine:
|
||||
|
||||
return True, ""
|
||||
|
||||
def get_token(self, plugin_id: str) -> CapabilityToken | None:
|
||||
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
|
||||
"""获取插件的能力令牌"""
|
||||
return self._tokens.get(plugin_id)
|
||||
|
||||
def list_plugins(self) -> list[str]:
|
||||
def list_plugins(self) -> List[str]:
|
||||
"""列出所有已注册的插件"""
|
||||
return list(self._tokens.keys())
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 请求-响应关联与超时管理
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import secrets
|
||||
@@ -42,8 +42,8 @@ class RPCServer:
|
||||
def __init__(
|
||||
self,
|
||||
transport: TransportServer,
|
||||
session_token: str | None = None,
|
||||
codec: Codec | None = None,
|
||||
session_token: Optional[str] = None,
|
||||
codec: Optional[Codec] = None,
|
||||
send_queue_size: int = 128,
|
||||
):
|
||||
self._transport = transport
|
||||
@@ -52,22 +52,22 @@ class RPCServer:
|
||||
self._send_queue_size = send_queue_size
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Connection | None = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: str | None = None
|
||||
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: Optional[str] = None
|
||||
self._runner_generation: int = 0
|
||||
|
||||
# 方法处理器注册表
|
||||
self._method_handlers: dict[str, MethodHandler] = {}
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: dict[int, asyncio.Future] = {}
|
||||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
||||
|
||||
# 发送队列(背压控制)
|
||||
self._send_queue: asyncio.Queue[bytes] | None = None
|
||||
self._send_queue: Optional[asyncio.Queue[bytes]] = None
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._tasks: list[asyncio.Task] = []
|
||||
self._tasks: List[asyncio.Task] = []
|
||||
|
||||
@property
|
||||
def session_token(self) -> str:
|
||||
@@ -115,7 +115,7 @@ class RPCServer:
|
||||
self,
|
||||
method: str,
|
||||
plugin_id: str = "",
|
||||
payload: dict[str, Any] | None = None,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""向 Runner 发送 RPC 请求并等待响应
|
||||
@@ -172,7 +172,7 @@ class RPCServer:
|
||||
raise
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||||
|
||||
async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None:
|
||||
async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""向 Runner 发送单向事件(不等待响应)"""
|
||||
if not self.is_connected:
|
||||
return
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 优雅关停
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
@@ -40,8 +40,8 @@ class PluginSupervisor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugin_dirs: list[str] | None = None,
|
||||
socket_path: str | None = None,
|
||||
plugin_dirs: Optional[List[str]] = None,
|
||||
socket_path: Optional[str] = None,
|
||||
health_check_interval_sec: float = 30.0,
|
||||
):
|
||||
self._plugin_dirs = plugin_dirs or []
|
||||
@@ -65,16 +65,16 @@ class PluginSupervisor:
|
||||
)
|
||||
|
||||
# Runner 子进程
|
||||
self._runner_process: asyncio.subprocess.Process | None = None
|
||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||
self._runner_generation: int = 0
|
||||
self._max_restart_attempts: int = 3
|
||||
self._restart_count: int = 0
|
||||
|
||||
# 已注册的插件组件信息
|
||||
self._registered_plugins: dict[str, RegisterComponentsPayload] = {}
|
||||
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
|
||||
|
||||
# 后台任务
|
||||
self._health_task: asyncio.Task | None = None
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
# 注册内部 RPC 方法
|
||||
@@ -107,11 +107,11 @@ class PluginSupervisor:
|
||||
async def dispatch_event(
|
||||
self,
|
||||
event_type: str,
|
||||
message: dict[str, Any] | None = None,
|
||||
extra_args: dict[str, Any] | None = None,
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
"""分发事件到所有对应 handler 的快捷方法。"""
|
||||
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
|
||||
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
resp = await self.invoke_plugin(
|
||||
method="plugin.emit_event",
|
||||
plugin_id=plugin_id,
|
||||
@@ -129,12 +129,12 @@ class PluginSupervisor:
|
||||
|
||||
async def execute_workflow(
|
||||
self,
|
||||
message: dict[str, Any] | None = None,
|
||||
stream_id: str | None = None,
|
||||
context: WorkflowContext | None = None,
|
||||
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
||||
"""执行 Workflow Pipeline 的快捷方法。"""
|
||||
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
|
||||
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
resp = await self.invoke_plugin(
|
||||
method="plugin.invoke_workflow_step",
|
||||
plugin_id=plugin_id,
|
||||
@@ -196,7 +196,7 @@ class PluginSupervisor:
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""调用插件组件
|
||||
@@ -225,6 +225,12 @@ class PluginSupervisor:
|
||||
# 保存旧进程引用
|
||||
old_process = self._runner_process
|
||||
|
||||
# 清理旧的组件注册,防止幽灵组件残留
|
||||
for plugin_id in list(self._registered_plugins.keys()):
|
||||
self._component_registry.remove_components_by_plugin(plugin_id)
|
||||
self._policy.revoke_plugin(plugin_id)
|
||||
self._registered_plugins.clear()
|
||||
|
||||
# 拉起新 Runner
|
||||
await self._spawn_runner()
|
||||
|
||||
@@ -286,7 +292,7 @@ class PluginSupervisor:
|
||||
# 在策略引擎中注册插件
|
||||
self._policy.register_plugin(
|
||||
plugin_id=reg.plugin_id,
|
||||
generation=envelope.generation,
|
||||
generation=self._runner_generation,
|
||||
capabilities=reg.capabilities_required or [],
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
- modification_log: 消息修改审计
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
@@ -29,7 +29,7 @@ from src.plugin_runtime.host.component_registry import ComponentRegistry, Regist
|
||||
logger = get_logger("plugin_runtime.host.workflow_executor")
|
||||
|
||||
# 阶段顺序
|
||||
STAGE_SEQUENCE: list[str] = [
|
||||
STAGE_SEQUENCE: List[str] = [
|
||||
"ingress",
|
||||
"pre_process",
|
||||
"plan",
|
||||
@@ -48,7 +48,7 @@ class ModificationRecord:
|
||||
"""消息修改记录"""
|
||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: list[str]):
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]):
|
||||
self.stage = stage
|
||||
self.hook_name = hook_name
|
||||
self.timestamp = time.perf_counter()
|
||||
@@ -58,17 +58,17 @@ class ModificationRecord:
|
||||
class WorkflowContext:
|
||||
"""Workflow 执行上下文"""
|
||||
|
||||
def __init__(self, trace_id: str | None = None, stream_id: str | None = None):
|
||||
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None):
|
||||
self.trace_id = trace_id or uuid.uuid4().hex
|
||||
self.stream_id = stream_id
|
||||
self.timings: dict[str, float] = {}
|
||||
self.errors: list[str] = []
|
||||
self.timings: Dict[str, float] = {}
|
||||
self.errors: List[str] = []
|
||||
# 阶段间数据传递(按 stage 命名空间隔离)
|
||||
self.stage_outputs: dict[str, dict[str, Any]] = {}
|
||||
self.stage_outputs: Dict[str, Dict[str, Any]] = {}
|
||||
# 消息修改审计日志
|
||||
self.modification_log: list[ModificationRecord] = []
|
||||
self.modification_log: List[ModificationRecord] = []
|
||||
# PLAN 阶段命令匹配结果
|
||||
self.matched_command: str | None = None
|
||||
self.matched_command: Optional[str] = None
|
||||
|
||||
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
|
||||
self.stage_outputs.setdefault(stage, {})[key] = value
|
||||
@@ -85,7 +85,7 @@ class WorkflowResult:
|
||||
status: str = "completed", # completed / aborted / failed
|
||||
return_message: str = "",
|
||||
stopped_at: str = "",
|
||||
diagnostics: dict[str, Any] | None = None,
|
||||
diagnostics: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.status = status
|
||||
self.return_message = return_message
|
||||
@@ -94,7 +94,7 @@ class WorkflowResult:
|
||||
|
||||
|
||||
# invoke_fn 签名
|
||||
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]]
|
||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
@@ -109,10 +109,10 @@ class WorkflowExecutor:
|
||||
async def execute(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any] | None = None,
|
||||
stream_id: str | None = None,
|
||||
context: WorkflowContext | None = None,
|
||||
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
||||
"""执行 workflow pipeline。
|
||||
|
||||
Returns:
|
||||
@@ -120,6 +120,8 @@ class WorkflowExecutor:
|
||||
"""
|
||||
ctx = context or WorkflowContext(stream_id=stream_id)
|
||||
current_message = dict(message) if message else None
|
||||
# 保持非阻塞任务引用,防止被 GC 回收
|
||||
background_tasks: List[asyncio.Task] = []
|
||||
|
||||
for stage in STAGE_SEQUENCE:
|
||||
stage_start = time.perf_counter()
|
||||
@@ -207,14 +209,15 @@ class WorkflowExecutor:
|
||||
# 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
|
||||
if nonblocking_steps and not skip_stage:
|
||||
nb_tasks = [
|
||||
self._invoke_step_fire_and_forget(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
asyncio.create_task(
|
||||
self._invoke_step_fire_and_forget(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
)
|
||||
)
|
||||
for step in nonblocking_steps
|
||||
]
|
||||
# 并发执行但不阻塞 pipeline
|
||||
for task in [asyncio.create_task(t) for t in nb_tasks]:
|
||||
task.add_done_callback(lambda _: None)
|
||||
# 保持任务引用以防止被 GC 回收
|
||||
background_tasks.extend(nb_tasks)
|
||||
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
|
||||
@@ -247,9 +250,9 @@ class WorkflowExecutor:
|
||||
|
||||
def _pre_filter(
|
||||
self,
|
||||
steps: list[RegisteredComponent],
|
||||
message: dict[str, Any] | None,
|
||||
) -> list[RegisteredComponent]:
|
||||
steps: List[RegisteredComponent],
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> List[RegisteredComponent]:
|
||||
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
|
||||
if not message:
|
||||
return steps
|
||||
@@ -265,7 +268,7 @@ class WorkflowExecutor:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _match_filter(filter_cond: dict[str, Any], message: dict[str, Any]) -> bool:
|
||||
def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
|
||||
"""简单 key-value 匹配过滤。
|
||||
|
||||
filter 中的每个 key 必须在 message 中存在且值相等,
|
||||
@@ -285,8 +288,8 @@ class WorkflowExecutor:
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: dict[str, Any] | None,
|
||||
) -> tuple[str, dict[str, Any] | None, str | None]:
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""调用单个 blocking hook。
|
||||
|
||||
Returns:
|
||||
@@ -331,7 +334,7 @@ class WorkflowExecutor:
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: dict[str, Any] | None,
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""Non-blocking hook 调用,只读,忽略结果。"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
@@ -354,9 +357,9 @@ class WorkflowExecutor:
|
||||
async def _route_command(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any],
|
||||
message: Dict[str, Any],
|
||||
ctx: WorkflowContext,
|
||||
) -> dict[str, Any] | None:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""PLAN 阶段内置 Command 路由。
|
||||
|
||||
在 registry 中查找匹配的 command 组件,
|
||||
@@ -386,6 +389,6 @@ class WorkflowExecutor:
|
||||
return None
|
||||
|
||||
|
||||
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
|
||||
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
|
||||
"""返回 new 中与 old 不同的 key 列表。"""
|
||||
return [k for k, v in new.items() if k not in old or old[k] != v]
|
||||
|
||||
Reference in New Issue
Block a user