feat: 实现插件注册的分阶段接入与切换机制,优化 RPC 连接管理

This commit is contained in:
DrSmoothl
2026-03-13 15:21:40 +08:00
parent 5d30b3a908
commit 44a9e9ecd7
5 changed files with 393 additions and 87 deletions

View File

@@ -135,6 +135,7 @@ class PluginSupervisor:
# 已注册的插件组件信息
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {}
# 后台任务
self._health_task: Optional[asyncio.Task] = None
@@ -319,6 +320,10 @@ class PluginSupervisor:
old_session_token = self._rpc_server.session_token
expected_generation = self._rpc_server.runner_generation + 1
# 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接
self._rpc_server.begin_staged_takeover()
self._staged_registered_plugins.clear()
# 重新生成 session token,防止被终止的旧 Runner 重连
self._rpc_server.reset_session_token()
@@ -330,27 +335,35 @@ class PluginSupervisor:
# 拉起新 Runner
try:
await self._spawn_runner()
await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout)
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
await self._wait_for_runner_generation(
expected_generation,
timeout_sec=self._runner_spawn_timeout,
allow_staged=True,
)
resp = await self._rpc_server.send_request(
"plugin.health",
timeout_ms=5000,
target_generation=expected_generation,
)
health = HealthPayload.model_validate(resp.payload)
if not health.healthy:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
await self._rpc_server.commit_staged_takeover()
except Exception as e:
logger.error(f"新 Runner 健康检查失败: {e},回滚")
await self._terminate_process(self._runner_process, old_process)
await self._rpc_server.rollback_staged_takeover()
self._runner_process = old_process
# 恢复旧 session token,使旧 Runner 的连接仍可正常工作
self._rpc_server.restore_session_token(old_session_token)
self._staged_registered_plugins.clear()
self._registered_plugins = dict(old_registered_plugins)
self._rebuild_runtime_state()
return
# 新 Runner 健康且已完成组件注册,现在清理旧的幽灵组件
# 只移除不再存在于新注册表中的旧插件组件
for old_pid in list(old_registered_plugins.keys()):
if old_pid not in self._registered_plugins:
self._component_registry.remove_components_by_plugin(old_pid)
self._policy.revoke_plugin(old_pid)
self._runner_generation = self._rpc_server.runner_generation
self._registered_plugins = dict(self._staged_registered_plugins)
self._staged_registered_plugins.clear()
self._rebuild_runtime_state()
# 关停旧 Runner
if old_process and old_process.returncode is None:
@@ -380,13 +393,22 @@ class PluginSupervisor:
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
if envelope.generation != self._rpc_server.runner_generation:
active_generation = self._rpc_server.runner_generation
staged_generation = self._rpc_server.staged_generation
if envelope.generation not in {active_generation, staged_generation}:
return envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
f"组件注册 generation 过期: {envelope.generation} 不在已知代际中",
)
# 记录注册信息
if envelope.generation == staged_generation and staged_generation != 0:
self._staged_registered_plugins[reg.plugin_id] = reg
logger.info(
f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功,"
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
)
return envelope.make_response(payload={"accepted": True, "staged": True})
self._registered_plugins[reg.plugin_id] = reg
# 在策略引擎中注册插件
@@ -396,7 +418,8 @@ class PluginSupervisor:
capabilities=reg.capabilities_required or [],
)
# 在 ComponentRegistry 中注册组件
# 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件
self._component_registry.remove_components_by_plugin(reg.plugin_id)
self._component_registry.register_plugin_components(
plugin_id=reg.plugin_id,
components=[c.model_dump() for c in reg.components],
@@ -518,10 +541,17 @@ class PluginSupervisor:
except Exception as e:
logger.error(f"健康检查异常: {e}")
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
async def _wait_for_runner_generation(
self,
expected_generation: int,
timeout_sec: float,
allow_staged: bool = False,
) -> None:
"""等待指定代际的 Runner 完成连接。"""
deadline = asyncio.get_running_loop().time() + timeout_sec
while asyncio.get_running_loop().time() < deadline:
if allow_staged and self._rpc_server.has_generation(expected_generation):
return
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
self._runner_generation = self._rpc_server.runner_generation
return
@@ -533,6 +563,7 @@ class PluginSupervisor:
self._component_registry.clear()
self._policy.clear()
self._registered_plugins.clear()
self._staged_registered_plugins.clear()
def _rebuild_runtime_state(self) -> None:
"""根据已记录的插件注册信息重建运行时状态。"""