feat: 增强插件能力检查,支持 generation 校验并添加清理功能
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
@@ -75,6 +76,7 @@ class PluginSupervisor:
|
||||
|
||||
# 后台任务
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
self._runner_output_tasks: List[asyncio.Task] = []
|
||||
self._running = False
|
||||
|
||||
# 注册内部 RPC 方法
|
||||
@@ -224,40 +226,26 @@ class PluginSupervisor:
|
||||
|
||||
# 保存旧进程引用
|
||||
old_process = self._runner_process
|
||||
old_registered_plugins = dict(self._registered_plugins)
|
||||
expected_generation = self._rpc_server.runner_generation + 1
|
||||
|
||||
# 清理旧的组件注册,防止幽灵组件残留
|
||||
for plugin_id in list(self._registered_plugins.keys()):
|
||||
self._component_registry.remove_components_by_plugin(plugin_id)
|
||||
self._policy.revoke_plugin(plugin_id)
|
||||
self._registered_plugins.clear()
|
||||
self._clear_runtime_state()
|
||||
|
||||
# 拉起新 Runner
|
||||
await self._spawn_runner()
|
||||
|
||||
# 等待新 Runner 连接并完成握手
|
||||
for _ in range(30): # 最多等待 30 秒
|
||||
if self._rpc_server.is_connected:
|
||||
break
|
||||
await asyncio.sleep(1.0)
|
||||
else:
|
||||
logger.error("新 Runner 连接超时,回滚")
|
||||
# 回滚:终止新进程
|
||||
if self._runner_process and self._runner_process != old_process:
|
||||
self._runner_process.terminate()
|
||||
self._runner_process = old_process
|
||||
return
|
||||
|
||||
# 健康检查
|
||||
try:
|
||||
await self._spawn_runner()
|
||||
await self._wait_for_runner_generation(expected_generation, timeout_sec=30.0)
|
||||
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
|
||||
health = HealthPayload.model_validate(resp.payload)
|
||||
if not health.healthy:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
|
||||
except Exception as e:
|
||||
logger.error(f"新 Runner 健康检查失败: {e},回滚")
|
||||
if self._runner_process and self._runner_process != old_process:
|
||||
self._runner_process.terminate()
|
||||
await self._terminate_process(self._runner_process, old_process)
|
||||
self._runner_process = old_process
|
||||
self._registered_plugins = dict(old_registered_plugins)
|
||||
self._rebuild_runtime_state()
|
||||
return
|
||||
|
||||
# 关停旧 Runner
|
||||
@@ -286,13 +274,19 @@ class PluginSupervisor:
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
||||
|
||||
if envelope.generation != self._rpc_server.runner_generation:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
|
||||
)
|
||||
|
||||
# 记录注册信息
|
||||
self._registered_plugins[reg.plugin_id] = reg
|
||||
|
||||
# 在策略引擎中注册插件
|
||||
self._policy.register_plugin(
|
||||
plugin_id=reg.plugin_id,
|
||||
generation=self._runner_generation,
|
||||
generation=envelope.generation,
|
||||
capabilities=reg.capabilities_required or [],
|
||||
)
|
||||
|
||||
@@ -329,7 +323,8 @@ class PluginSupervisor:
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
self._runner_generation += 1
|
||||
self._attach_runner_output_tasks(self._runner_process)
|
||||
self._runner_generation = self._rpc_server.runner_generation
|
||||
logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}")
|
||||
|
||||
async def _shutdown_runner(self) -> None:
|
||||
@@ -362,6 +357,8 @@ class PluginSupervisor:
|
||||
self._runner_process.kill()
|
||||
await self._runner_process.wait()
|
||||
|
||||
await self._cleanup_runner_output_tasks()
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""周期性健康检查 + 崩溃自动重启"""
|
||||
while self._running:
|
||||
@@ -382,6 +379,7 @@ class PluginSupervisor:
|
||||
self._registered_plugins.clear()
|
||||
|
||||
try:
|
||||
self._clear_runtime_state()
|
||||
await self._spawn_runner()
|
||||
except Exception as e:
|
||||
logger.error(f"Runner 重启失败: {e}", exc_info=True)
|
||||
@@ -407,3 +405,98 @@ class PluginSupervisor:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查异常: {e}")
|
||||
|
||||
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
|
||||
"""等待指定代际的 Runner 完成连接。"""
|
||||
deadline = asyncio.get_running_loop().time() + timeout_sec
|
||||
while asyncio.get_running_loop().time() < deadline:
|
||||
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
|
||||
self._runner_generation = self._rpc_server.runner_generation
|
||||
return
|
||||
await asyncio.sleep(0.1)
|
||||
raise TimeoutError(f"等待 Runner generation {expected_generation} 超时")
|
||||
|
||||
def _clear_runtime_state(self) -> None:
|
||||
"""清空当前插件注册态。"""
|
||||
self._component_registry.clear()
|
||||
self._policy.clear()
|
||||
self._registered_plugins.clear()
|
||||
|
||||
def _rebuild_runtime_state(self) -> None:
|
||||
"""根据已记录的插件注册信息重建运行时状态。"""
|
||||
self._component_registry.clear()
|
||||
self._policy.clear()
|
||||
for reg in self._registered_plugins.values():
|
||||
self._policy.register_plugin(
|
||||
plugin_id=reg.plugin_id,
|
||||
generation=self._rpc_server.runner_generation,
|
||||
capabilities=reg.capabilities_required or [],
|
||||
)
|
||||
self._component_registry.register_plugin_components(
|
||||
plugin_id=reg.plugin_id,
|
||||
components=[c.model_dump() for c in reg.components],
|
||||
)
|
||||
|
||||
def _attach_runner_output_tasks(self, process: asyncio.subprocess.Process) -> None:
|
||||
"""为 Runner 输出流创建排空任务,避免 PIPE 填满阻塞子进程。"""
|
||||
streams = (
|
||||
(process.stdout, "stdout"),
|
||||
(process.stderr, "stderr"),
|
||||
)
|
||||
for stream, stream_name in streams:
|
||||
if stream is None:
|
||||
continue
|
||||
task = asyncio.create_task(self._drain_runner_stream(stream, stream_name, process.pid))
|
||||
self._runner_output_tasks.append(task)
|
||||
task.add_done_callback(
|
||||
lambda done_task: self._runner_output_tasks.remove(done_task)
|
||||
if done_task in self._runner_output_tasks
|
||||
else None
|
||||
)
|
||||
|
||||
async def _drain_runner_stream(
|
||||
self,
|
||||
stream: asyncio.StreamReader,
|
||||
stream_name: str,
|
||||
pid: int,
|
||||
) -> None:
|
||||
"""持续消费 Runner 输出,避免 PIPE 回压导致子进程阻塞。"""
|
||||
try:
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
if not line:
|
||||
break
|
||||
message = line.decode(errors="replace").rstrip()
|
||||
if message:
|
||||
logger.debug(f"[runner:{pid}:{stream_name}] {message}")
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug(f"读取 Runner {stream_name} 失败: {e}")
|
||||
|
||||
async def _cleanup_runner_output_tasks(self) -> None:
|
||||
"""等待并清理 Runner 输出任务。"""
|
||||
tasks = list(self._runner_output_tasks)
|
||||
self._runner_output_tasks.clear()
|
||||
for task in tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
if tasks:
|
||||
with contextlib.suppress(Exception):
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@staticmethod
|
||||
async def _terminate_process(
|
||||
process: Optional[asyncio.subprocess.Process],
|
||||
keep_process: Optional[asyncio.subprocess.Process] = None,
|
||||
) -> None:
|
||||
"""终止指定进程,但跳过需要保留的旧进程引用。"""
|
||||
if process is None or process is keep_process or process.returncode is not None:
|
||||
return
|
||||
|
||||
process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
|
||||
Reference in New Issue
Block a user