Refactor protocol and transport modules to use type hints for improved clarity and consistency
- Updated Codec class to use abstract methods for encoding and decoding envelopes. - Changed Envelope class to use Dict and Optional for payload and error fields. - Refined error handling in RPCError class with Optional type hints for details. - Enhanced manifest validation logic with type hints for better type safety. - Improved plugin loading mechanism with consistent type annotations. - Updated RPCClient to utilize Optional for codec and connection attributes. - Refactored transport classes to use Optional for server attributes and socket paths.
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
- modification_log: 消息修改审计
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
@@ -29,7 +29,7 @@ from src.plugin_runtime.host.component_registry import ComponentRegistry, Regist
|
||||
logger = get_logger("plugin_runtime.host.workflow_executor")
|
||||
|
||||
# 阶段顺序
|
||||
STAGE_SEQUENCE: list[str] = [
|
||||
STAGE_SEQUENCE: List[str] = [
|
||||
"ingress",
|
||||
"pre_process",
|
||||
"plan",
|
||||
@@ -48,7 +48,7 @@ class ModificationRecord:
|
||||
"""消息修改记录"""
|
||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: list[str]):
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]):
|
||||
self.stage = stage
|
||||
self.hook_name = hook_name
|
||||
self.timestamp = time.perf_counter()
|
||||
@@ -58,17 +58,17 @@ class ModificationRecord:
|
||||
class WorkflowContext:
|
||||
"""Workflow 执行上下文"""
|
||||
|
||||
def __init__(self, trace_id: str | None = None, stream_id: str | None = None):
|
||||
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None):
|
||||
self.trace_id = trace_id or uuid.uuid4().hex
|
||||
self.stream_id = stream_id
|
||||
self.timings: dict[str, float] = {}
|
||||
self.errors: list[str] = []
|
||||
self.timings: Dict[str, float] = {}
|
||||
self.errors: List[str] = []
|
||||
# 阶段间数据传递(按 stage 命名空间隔离)
|
||||
self.stage_outputs: dict[str, dict[str, Any]] = {}
|
||||
self.stage_outputs: Dict[str, Dict[str, Any]] = {}
|
||||
# 消息修改审计日志
|
||||
self.modification_log: list[ModificationRecord] = []
|
||||
self.modification_log: List[ModificationRecord] = []
|
||||
# PLAN 阶段命令匹配结果
|
||||
self.matched_command: str | None = None
|
||||
self.matched_command: Optional[str] = None
|
||||
|
||||
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
|
||||
self.stage_outputs.setdefault(stage, {})[key] = value
|
||||
@@ -85,7 +85,7 @@ class WorkflowResult:
|
||||
status: str = "completed", # completed / aborted / failed
|
||||
return_message: str = "",
|
||||
stopped_at: str = "",
|
||||
diagnostics: dict[str, Any] | None = None,
|
||||
diagnostics: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.status = status
|
||||
self.return_message = return_message
|
||||
@@ -94,7 +94,7 @@ class WorkflowResult:
|
||||
|
||||
|
||||
# invoke_fn 签名
|
||||
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]]
|
||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
@@ -109,10 +109,10 @@ class WorkflowExecutor:
|
||||
async def execute(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any] | None = None,
|
||||
stream_id: str | None = None,
|
||||
context: WorkflowContext | None = None,
|
||||
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
||||
"""执行 workflow pipeline。
|
||||
|
||||
Returns:
|
||||
@@ -120,6 +120,8 @@ class WorkflowExecutor:
|
||||
"""
|
||||
ctx = context or WorkflowContext(stream_id=stream_id)
|
||||
current_message = dict(message) if message else None
|
||||
# 保持非阻塞任务引用,防止被 GC 回收
|
||||
background_tasks: List[asyncio.Task] = []
|
||||
|
||||
for stage in STAGE_SEQUENCE:
|
||||
stage_start = time.perf_counter()
|
||||
@@ -207,14 +209,15 @@ class WorkflowExecutor:
|
||||
# 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
|
||||
if nonblocking_steps and not skip_stage:
|
||||
nb_tasks = [
|
||||
self._invoke_step_fire_and_forget(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
asyncio.create_task(
|
||||
self._invoke_step_fire_and_forget(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
)
|
||||
)
|
||||
for step in nonblocking_steps
|
||||
]
|
||||
# 并发执行但不阻塞 pipeline
|
||||
for task in [asyncio.create_task(t) for t in nb_tasks]:
|
||||
task.add_done_callback(lambda _: None)
|
||||
# 保持任务引用以防止被 GC 回收
|
||||
background_tasks.extend(nb_tasks)
|
||||
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
|
||||
@@ -247,9 +250,9 @@ class WorkflowExecutor:
|
||||
|
||||
def _pre_filter(
|
||||
self,
|
||||
steps: list[RegisteredComponent],
|
||||
message: dict[str, Any] | None,
|
||||
) -> list[RegisteredComponent]:
|
||||
steps: List[RegisteredComponent],
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> List[RegisteredComponent]:
|
||||
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
|
||||
if not message:
|
||||
return steps
|
||||
@@ -265,7 +268,7 @@ class WorkflowExecutor:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _match_filter(filter_cond: dict[str, Any], message: dict[str, Any]) -> bool:
|
||||
def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
|
||||
"""简单 key-value 匹配过滤。
|
||||
|
||||
filter 中的每个 key 必须在 message 中存在且值相等,
|
||||
@@ -285,8 +288,8 @@ class WorkflowExecutor:
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: dict[str, Any] | None,
|
||||
) -> tuple[str, dict[str, Any] | None, str | None]:
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""调用单个 blocking hook。
|
||||
|
||||
Returns:
|
||||
@@ -331,7 +334,7 @@ class WorkflowExecutor:
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: dict[str, Any] | None,
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""Non-blocking hook 调用,只读,忽略结果。"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
@@ -354,9 +357,9 @@ class WorkflowExecutor:
|
||||
async def _route_command(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any],
|
||||
message: Dict[str, Any],
|
||||
ctx: WorkflowContext,
|
||||
) -> dict[str, Any] | None:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""PLAN 阶段内置 Command 路由。
|
||||
|
||||
在 registry 中查找匹配的 command 组件,
|
||||
@@ -386,6 +389,6 @@ class WorkflowExecutor:
|
||||
return None
|
||||
|
||||
|
||||
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
|
||||
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
|
||||
"""返回 new 中与 old 不同的 key 列表。"""
|
||||
return [k for k, v in new.items() if k not in old or old[k] != v]
|
||||
|
||||
Reference in New Issue
Block a user