Ruff Format
This commit is contained in:
@@ -67,7 +67,9 @@ class CapabilityService:
|
||||
# 1. 权限校验
|
||||
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
|
||||
if not allowed:
|
||||
error_code = ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
|
||||
error_code = (
|
||||
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
|
||||
)
|
||||
return envelope.make_error_response(
|
||||
error_code.value,
|
||||
reason,
|
||||
|
||||
@@ -22,8 +22,13 @@ class RegisteredComponent:
|
||||
"""已注册的组件条目"""
|
||||
|
||||
__slots__ = (
|
||||
"name", "full_name", "component_type", "plugin_id",
|
||||
"metadata", "enabled", "_compiled_pattern",
|
||||
"name",
|
||||
"full_name",
|
||||
"component_type",
|
||||
"plugin_id",
|
||||
"metadata",
|
||||
"enabled",
|
||||
"_compiled_pattern",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -165,18 +170,14 @@ class ComponentRegistry:
|
||||
"""按全名查询。"""
|
||||
return self._components.get(full_name)
|
||||
|
||||
def get_components_by_type(
|
||||
self, component_type: str, *, enabled_only: bool = True
|
||||
) -> List[RegisteredComponent]:
|
||||
def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""按类型查询。"""
|
||||
type_dict = self._by_type.get(component_type, {})
|
||||
if enabled_only:
|
||||
return [c for c in type_dict.values() if c.enabled]
|
||||
return list(type_dict.values())
|
||||
|
||||
def get_components_by_plugin(
|
||||
self, plugin_id: str, *, enabled_only: bool = True
|
||||
) -> List[RegisteredComponent]:
|
||||
def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""按插件查询。"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
return [c for c in comps if c.enabled] if enabled_only else list(comps)
|
||||
@@ -200,9 +201,7 @@ class ComponentRegistry:
|
||||
return comp, {}
|
||||
return None
|
||||
|
||||
def get_event_handlers(
|
||||
self, event_type: str, *, enabled_only: bool = True
|
||||
) -> List[RegisteredComponent]:
|
||||
def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
|
||||
handlers = []
|
||||
for comp in self._by_type.get("event_handler", {}).values():
|
||||
@@ -213,9 +212,7 @@ class ComponentRegistry:
|
||||
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
|
||||
return handlers
|
||||
|
||||
def get_workflow_steps(
|
||||
self, stage: str, *, enabled_only: bool = True
|
||||
) -> List[RegisteredComponent]:
|
||||
def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
|
||||
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
|
||||
steps = []
|
||||
for comp in self._by_type.get("workflow_step", {}).values():
|
||||
|
||||
@@ -22,6 +22,7 @@ InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
|
||||
class EventResult:
|
||||
"""单个 EventHandler 的执行结果"""
|
||||
|
||||
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
|
||||
|
||||
def __init__(
|
||||
@@ -107,9 +108,7 @@ class EventDispatcher:
|
||||
modified_message = result.modified_message
|
||||
else:
|
||||
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
|
||||
task = asyncio.create_task(
|
||||
self._invoke_handler(invoke_fn, handler, args, event_type)
|
||||
)
|
||||
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Dict, List, Optional, Set, Tuple
|
||||
@dataclass
|
||||
class CapabilityToken:
|
||||
"""能力令牌"""
|
||||
|
||||
plugin_id: str
|
||||
generation: int
|
||||
capabilities: Set[str] = field(default_factory=set)
|
||||
|
||||
@@ -231,9 +231,7 @@ class RPCServer:
|
||||
stale_count = 0
|
||||
for _req_id, future in list(self._pending_requests.items()):
|
||||
if not future.done():
|
||||
future.set_exception(
|
||||
RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管")
|
||||
)
|
||||
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管"))
|
||||
stale_count += 1
|
||||
self._pending_requests.clear()
|
||||
if stale_count:
|
||||
@@ -399,9 +397,7 @@ class RPCServer:
|
||||
result = await handler(envelope)
|
||||
# 检查 handler 返回的信封是否包含错误信息
|
||||
if result is not None and isinstance(result, Envelope) and result.error:
|
||||
logger.warning(
|
||||
f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}"
|
||||
)
|
||||
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ logger = get_logger("plugin_runtime.host.supervisor")
|
||||
|
||||
# ─── 日志桥 ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RunnerLogBridge:
|
||||
"""将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
|
||||
|
||||
@@ -80,9 +81,7 @@ class RunnerLogBridge:
|
||||
|
||||
stdlib_logging.getLogger(entry.logger_name).handle(record)
|
||||
|
||||
return envelope.make_response(
|
||||
payload={"accepted": True, "count": len(batch.entries)}
|
||||
)
|
||||
return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})
|
||||
|
||||
|
||||
class PluginSupervisor:
|
||||
@@ -101,8 +100,12 @@ class PluginSupervisor:
|
||||
):
|
||||
_cfg = global_config.plugin_runtime
|
||||
self._plugin_dirs = plugin_dirs or []
|
||||
self._health_interval = health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec
|
||||
self._runner_spawn_timeout = runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec
|
||||
self._health_interval = (
|
||||
health_check_interval_sec if health_check_interval_sec is not None else _cfg.health_check_interval_sec
|
||||
)
|
||||
self._runner_spawn_timeout = (
|
||||
runner_spawn_timeout_sec if runner_spawn_timeout_sec is not None else _cfg.runner_spawn_timeout_sec
|
||||
)
|
||||
|
||||
# 基础设施
|
||||
self._transport = create_transport_server(socket_path=socket_path)
|
||||
@@ -114,6 +117,7 @@ class PluginSupervisor:
|
||||
|
||||
# 编解码
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
|
||||
codec = MsgPackCodec()
|
||||
|
||||
self._rpc_server = RPCServer(
|
||||
@@ -124,7 +128,9 @@ class PluginSupervisor:
|
||||
# Runner 子进程
|
||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||
self._runner_generation: int = 0
|
||||
self._max_restart_attempts: int = max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts
|
||||
self._max_restart_attempts: int = (
|
||||
max_restart_attempts if max_restart_attempts is not None else _cfg.max_restart_attempts
|
||||
)
|
||||
self._restart_count: int = 0
|
||||
|
||||
# 已注册的插件组件信息
|
||||
@@ -173,6 +179,7 @@ class PluginSupervisor:
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
"""分发事件到所有对应 handler 的快捷方法。"""
|
||||
|
||||
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
resp = await self.invoke_plugin(
|
||||
method="plugin.emit_event",
|
||||
@@ -196,6 +203,7 @@ class PluginSupervisor:
|
||||
context: Optional[WorkflowContext] = None,
|
||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
||||
"""执行 Workflow Pipeline 的快捷方法。"""
|
||||
|
||||
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
resp = await self.invoke_plugin(
|
||||
method="plugin.invoke_workflow_step",
|
||||
@@ -415,7 +423,9 @@ class PluginSupervisor:
|
||||
env[ENV_PLUGIN_DIRS] = os.pathsep.join(self._plugin_dirs)
|
||||
|
||||
self._runner_process = await asyncio.create_subprocess_exec(
|
||||
sys.executable, "-m", runner_module,
|
||||
sys.executable,
|
||||
"-m",
|
||||
runner_module,
|
||||
env=env,
|
||||
# stdout 不捕获:Runner 的日志均通过 IPC 传㛹(RunnerIPCLogHandler)
|
||||
stdout=None,
|
||||
@@ -557,9 +567,7 @@ class PluginSupervisor:
|
||||
)
|
||||
self._stderr_drain_task = task
|
||||
task.add_done_callback(
|
||||
lambda done_task: None
|
||||
if self._stderr_drain_task is not done_task
|
||||
else self._clear_stderr_drain_task()
|
||||
lambda done_task: None if self._stderr_drain_task is not done_task else self._clear_stderr_drain_task()
|
||||
)
|
||||
|
||||
def _clear_stderr_drain_task(self) -> None:
|
||||
|
||||
@@ -44,6 +44,7 @@ HOOK_CONTINUE = "continue"
|
||||
HOOK_SKIP_STAGE = "skip_stage"
|
||||
HOOK_ABORT = "abort"
|
||||
|
||||
|
||||
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
|
||||
# 从配置文件读取,允许用户调整
|
||||
def _get_blocking_timeout() -> float:
|
||||
@@ -52,6 +53,7 @@ def _get_blocking_timeout() -> float:
|
||||
|
||||
class ModificationRecord:
|
||||
"""消息修改记录"""
|
||||
|
||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
|
||||
@@ -141,9 +143,7 @@ class WorkflowExecutor:
|
||||
try:
|
||||
# PLAN 阶段: 先做 Command 路由
|
||||
if stage == "plan" and current_message:
|
||||
cmd_result = await self._route_command(
|
||||
command_invoke_fn or invoke_fn, current_message, ctx
|
||||
)
|
||||
cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
|
||||
if cmd_result is not None:
|
||||
# 命令匹配成功,跳过 PLAN 阶段的 hook,直接存结果进 stage_outputs
|
||||
ctx.set_stage_output("plan", "command_result", cmd_result)
|
||||
@@ -195,10 +195,10 @@ class WorkflowExecutor:
|
||||
|
||||
# 更新消息(仅 blocking hook 有权修改)
|
||||
if modified:
|
||||
changed_fields = _diff_keys(current_message, modified) if current_message else list(modified.keys())
|
||||
ctx.modification_log.append(
|
||||
ModificationRecord(stage, step.full_name, changed_fields)
|
||||
changed_fields = (
|
||||
_diff_keys(current_message, modified) if current_message else list(modified.keys())
|
||||
)
|
||||
ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
|
||||
current_message = modified
|
||||
|
||||
if hook_result == HOOK_ABORT:
|
||||
@@ -222,9 +222,7 @@ class WorkflowExecutor:
|
||||
if nonblocking_steps and not skip_stage:
|
||||
nb_tasks = [
|
||||
asyncio.create_task(
|
||||
self._invoke_step_fire_and_forget(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
)
|
||||
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
|
||||
)
|
||||
for step in nonblocking_steps
|
||||
]
|
||||
@@ -314,12 +312,16 @@ class WorkflowExecutor:
|
||||
step_start = time.perf_counter()
|
||||
|
||||
try:
|
||||
coro = invoke_fn(step.plugin_id, step.name, {
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
})
|
||||
coro = invoke_fn(
|
||||
step.plugin_id,
|
||||
step.name,
|
||||
{
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
},
|
||||
)
|
||||
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||
ctx.timings[step_key] = time.perf_counter() - step_start
|
||||
|
||||
@@ -355,12 +357,16 @@ class WorkflowExecutor:
|
||||
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
|
||||
|
||||
try:
|
||||
coro = invoke_fn(step.plugin_id, step.name, {
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
})
|
||||
coro = invoke_fn(
|
||||
step.plugin_id,
|
||||
step.name,
|
||||
{
|
||||
"stage": stage,
|
||||
"trace_id": ctx.trace_id,
|
||||
"message": message,
|
||||
"stage_outputs": ctx.stage_outputs,
|
||||
},
|
||||
)
|
||||
await asyncio.wait_for(coro, timeout=timeout_sec)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
|
||||
@@ -393,12 +399,16 @@ class WorkflowExecutor:
|
||||
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
|
||||
|
||||
try:
|
||||
return await invoke_fn(matched.plugin_id, matched.name, {
|
||||
"text": plain_text,
|
||||
"message": message,
|
||||
"trace_id": ctx.trace_id,
|
||||
"matched_groups": matched_groups,
|
||||
})
|
||||
return await invoke_fn(
|
||||
matched.plugin_id,
|
||||
matched.name,
|
||||
{
|
||||
"text": plain_text,
|
||||
"message": message,
|
||||
"trace_id": ctx.trace_id,
|
||||
"matched_groups": matched_groups,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
|
||||
ctx.errors.append(f"command:{matched.full_name}: {e}")
|
||||
|
||||
Reference in New Issue
Block a user