feat: 实现插件注册的分阶段接入与切换机制，优化 RPC 连接管理

This commit is contained in:
DrSmoothl
2026-03-13 15:21:40 +08:00
parent 5d30b3a908
commit 44a9e9ecd7
5 changed files with 393 additions and 87 deletions

View File

@@ -7,7 +7,7 @@
4. 请求-响应关联与超时管理
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
import asyncio
import secrets
@@ -55,12 +55,16 @@ class RPCServer:
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
self._runner_id: Optional[str] = None
self._runner_generation: int = 0
self._staged_connection: Optional[Connection] = None
self._staged_runner_id: Optional[str] = None
self._staged_runner_generation: int = 0
self._staging_takeover: bool = False
# 方法处理器注册表
self._method_handlers: Dict[str, MethodHandler] = {}
# 等待响应的 pending 请求: request_id -> Future
self._pending_requests: Dict[int, asyncio.Future] = {}
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
# 发送队列(背压控制)
self._send_queue: Optional[asyncio.Queue[bytes]] = None
@@ -86,10 +90,72 @@ class RPCServer:
def runner_generation(self) -> int:
return self._runner_generation
@property
def staged_generation(self) -> int:
    """Generation number recorded for the staged Runner.

    The value is 0 whenever no Runner is staged: it is reset to 0 on
    commit, on rollback, and when the staged connection drops.
    """
    return self._staged_runner_generation
@property
def is_connected(self) -> bool:
    """Whether an active Runner connection exists and has not been closed."""
    active = self._connection
    if active is None:
        return False
    return not active.is_closed
def has_generation(self, generation: int) -> bool:
    """Report whether a live connection currently serves ``generation``.

    Both the active slot and the staged slot are checked, and each one only
    counts when its connection object exists and is not closed.

    The liveness requirement on the active slot matters because the
    handshake bumps ``_runner_generation`` before the connection is stored
    in ``_connection``; matching on the number alone would report a
    generation as available while a request targeted at it would still be
    rejected as "not connected".

    Args:
        generation: Generation number to look for.

    Returns:
        True if the active or the staged connection is open and carries
        exactly this generation, False otherwise.
    """
    slots = (
        (self._runner_generation, self._connection),
        (self._staged_runner_generation, self._staged_connection),
    )
    return any(
        generation == slot_generation
        and slot_conn is not None
        and not slot_conn.is_closed
        for slot_generation, slot_conn in slots
    )
def begin_staged_takeover(self) -> None:
    """Arm staged-takeover mode.

    While the flag is set, a newly connecting Runner may be accepted as a
    staged connection instead of replacing the active one; it only becomes
    the active connection once the Supervisor has validated it and commits
    the takeover.
    """
    self._staging_takeover = True
async def commit_staged_takeover(self) -> None:
    """Promote the staged Runner to active and retire the previous one.

    The staged connection, runner id and generation are moved into the
    active slots, the staged slots and the staging flag are reset, every
    pending request that targeted the previous generation is failed, and
    finally the previous connection is closed.

    Raises:
        RPCError: ``E_PLUGIN_CRASHED`` when there is no staged connection or
            it is already closed.

    NOTE(review): the handshake only stages a Runner when an active
    connection is live at that moment; otherwise the new Runner is
    installed as active directly and this method raises because nothing is
    staged. Confirm callers account for that path.
    """
    promoted = self._staged_connection
    if promoted is None or promoted.is_closed:
        raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")

    # Remember what is being replaced before the slots are overwritten.
    retired = self._connection
    retired_generation = self._runner_generation

    # Swap staged -> active.
    self._connection = promoted
    self._runner_id = self._staged_runner_id
    self._runner_generation = self._staged_runner_generation

    # Clear the staging area.
    self._staged_connection = None
    self._staged_runner_id = None
    self._staged_runner_generation = 0
    self._staging_takeover = False

    # Requests addressed to the retired generation will never be answered.
    cleared = self._fail_pending_requests(
        ErrorCode.E_PLUGIN_CRASHED,
        "Runner 连接已被新 generation 接管",
        generation=retired_generation,
    )
    if cleared:
        logger.info(f"已清理 {cleared} 个旧 Runner 的 pending 请求")

    if retired and retired is not self._connection and not retired.is_closed:
        await retired.close()
async def rollback_staged_takeover(self) -> None:
    """Discard the staged Runner and leave the active connection untouched.

    The staged slots and the staging flag are reset, pending requests that
    targeted the staged generation are failed, and the staged connection is
    closed if it is still open.
    """
    discarded = self._staged_connection
    discarded_generation = self._staged_runner_generation

    # Reset the staging area before any await so no new request can target it.
    self._staged_connection = None
    self._staged_runner_id = None
    self._staged_runner_generation = 0
    self._staging_takeover = False

    self._fail_pending_requests(
        ErrorCode.E_PLUGIN_CRASHED,
        "新 Runner 预热失败,已回滚",
        generation=discarded_generation,
    )

    if not discarded or discarded.is_closed:
        return
    await discarded.close()
def register_method(self, method: str, handler: MethodHandler) -> None:
"""注册 RPC 方法处理器"""
self._method_handlers[method] = handler
@@ -106,7 +172,7 @@ class RPCServer:
self._running = False
# 取消所有 pending 请求
for future in self._pending_requests.values():
for future, _generation in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
self._pending_requests.clear()
@@ -121,6 +187,10 @@ class RPCServer:
await self._connection.close()
self._connection = None
if self._staged_connection:
await self._staged_connection.close()
self._staged_connection = None
await self._transport.stop()
logger.info("RPC Server 已停止")
@@ -130,6 +200,7 @@ class RPCServer:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
target_generation: Optional[int] = None,
) -> Envelope:
"""向 Runner 发送 RPC 请求并等待响应
@@ -145,7 +216,9 @@ class RPCServer:
Raises:
RPCError: 调用失败
"""
if not self.is_connected:
generation = target_generation or self._runner_generation
conn = self._get_connection_for_generation(generation)
if conn is None or conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
request_id = self._id_gen.next()
@@ -154,7 +227,7 @@ class RPCServer:
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
generation=self._runner_generation,
generation=generation,
timeout_ms=timeout_ms,
payload=payload or {},
)
@@ -166,12 +239,12 @@ class RPCServer:
# 注册 pending future
loop = asyncio.get_running_loop()
future: asyncio.Future[Envelope] = loop.create_future()
self._pending_requests[request_id] = future
self._pending_requests[request_id] = (future, generation)
try:
# 发送请求
data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data)
await conn.send_frame(data)
# 等待响应
timeout_sec = timeout_ms / 1000.0
@@ -207,11 +280,13 @@ class RPCServer:
async def _handle_connection(self, conn: Connection) -> None:
"""处理新的 Runner 连接"""
logger.info("收到 Runner 连接")
previous_connection = self._connection
previous_generation = self._runner_generation
# 第一条消息必须是 runner.hello 握手
try:
handshake_ok = await self._handle_handshake(conn)
if not handshake_ok:
role = await self._handle_handshake(conn)
if role is None:
await conn.close()
return
except Exception as e:
@@ -219,41 +294,52 @@ class RPCServer:
await conn.close()
return
old_connection = self._connection
self._connection = conn
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
if role == "staged":
expected_generation = self._staged_runner_generation
logger.info(
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
)
else:
self._connection = conn
expected_generation = self._runner_generation
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
if old_connection and old_connection is not conn and not old_connection.is_closed:
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
# 新连接接管后,旧 Runner 的 in-flight 请求不会再收到响应
# (过期 generation 响应会被 _handle_response 丢弃),
# 在此处立即 fail-fast 所有 pending 请求,避免挂到超时
stale_count = 0
for _req_id, future in list(self._pending_requests.items()):
if not future.done():
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管"))
stale_count += 1
self._pending_requests.clear()
if stale_count:
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
await old_connection.close()
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
stale_count = self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已被新 generation 接管",
generation=previous_generation,
)
if stale_count:
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
await previous_connection.close()
# 启动消息接收循环
try:
await self._recv_loop(conn)
await self._recv_loop(conn, expected_generation=expected_generation)
except Exception as e:
logger.error(f"连接异常断开: {e}")
finally:
if self._connection is conn:
self._connection = None
self._runner_id = None
# 连接断开时,立即让所有等待中的请求失败,避免挂起至超时
for _req_id, future in list(self._pending_requests.items()):
if not future.done():
future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开"))
self._pending_requests.clear()
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已断开",
generation=expected_generation,
)
elif self._staged_connection is conn:
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Staged Runner 连接已断开",
generation=expected_generation,
)
async def _handle_handshake(self, conn: Connection) -> bool:
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
"""处理 runner.hello 握手"""
# 接收握手请求
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
@@ -266,7 +352,7 @@ class RPCServer:
"首条消息必须为 runner.hello",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
return False
return None
# 解析握手 payload
hello = HelloPayload.model_validate(envelope.payload)
@@ -280,7 +366,7 @@ class RPCServer:
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return False
return None
# 校验 SDK 版本
if not self._check_sdk_version(hello.sdk_version):
@@ -291,23 +377,31 @@ class RPCServer:
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return False
return None
# 握手成功
self._runner_id = hello.runner_id
self._runner_generation += 1
role = "active"
assigned_generation = self._runner_generation + 1
if self._staging_takeover and self.is_connected:
role = "staged"
self._staged_connection = conn
self._staged_runner_id = hello.runner_id
self._staged_runner_generation = assigned_generation
else:
self._runner_id = hello.runner_id
self._runner_generation = assigned_generation
resp_payload = HelloResponsePayload(
accepted=True,
host_version=PROTOCOL_VERSION,
assigned_generation=self._runner_generation,
assigned_generation=assigned_generation,
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return True
return role
async def _recv_loop(self, conn: Connection) -> None:
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
"""消息接收主循环"""
while self._running and not conn.is_closed:
try:
@@ -329,10 +423,10 @@ class RPCServer:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
if not self._is_current_generation(envelope):
if envelope.generation != expected_generation:
error_resp = envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"过期 generation: {envelope.generation} != {self._runner_generation}",
f"过期 generation: {envelope.generation} != {expected_generation}",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
continue
@@ -341,9 +435,9 @@ class RPCServer:
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
elif envelope.is_event():
if not self._is_current_generation(envelope):
if envelope.generation != expected_generation:
logger.warning(
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}"
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
)
continue
task = asyncio.create_task(self._handle_event(envelope))
@@ -352,22 +446,24 @@ class RPCServer:
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Runner 的响应"""
if not self._is_current_generation(envelope):
pending = self._pending_requests.get(envelope.request_id)
if pending is None:
return
future, expected_generation = pending
if envelope.generation != expected_generation:
logger.warning(
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}"
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
)
return
future = self._pending_requests.pop(envelope.request_id, None)
if future and not future.done():
self._pending_requests.pop(envelope.request_id, None)
if not future.done():
if envelope.error:
future.set_exception(RPCError.from_dict(envelope.error))
else:
future.set_result(envelope)
def _is_current_generation(self, envelope: Envelope) -> bool:
return envelope.generation == self._runner_generation
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*"""
handler = self._method_handlers.get(envelope.method)
@@ -411,3 +507,26 @@ class RPCServer:
return min_parts <= sdk_parts <= max_parts
except (ValueError, AttributeError):
return False
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
if generation == self._runner_generation:
return self._connection
if generation == self._staged_runner_generation:
return self._staged_connection
return None
def _fail_pending_requests(
self,
error_code: ErrorCode,
message: str,
generation: Optional[int] = None,
) -> int:
stale_count = 0
for request_id, (future, request_generation) in list(self._pending_requests.items()):
if generation is not None and request_generation != generation:
continue
if not future.done():
future.set_exception(RPCError(error_code, message))
stale_count += 1
self._pending_requests.pop(request_id, None)
return stale_count

View File

@@ -135,6 +135,7 @@ class PluginSupervisor:
# 已注册的插件组件信息
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {}
# 后台任务
self._health_task: Optional[asyncio.Task] = None
@@ -319,6 +320,10 @@ class PluginSupervisor:
old_session_token = self._rpc_server.session_token
expected_generation = self._rpc_server.runner_generation + 1
# 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接
self._rpc_server.begin_staged_takeover()
self._staged_registered_plugins.clear()
# 重新生成 session token，防止被终止的旧 Runner 重连
self._rpc_server.reset_session_token()
@@ -330,27 +335,35 @@ class PluginSupervisor:
# 拉起新 Runner
try:
await self._spawn_runner()
await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout)
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
await self._wait_for_runner_generation(
expected_generation,
timeout_sec=self._runner_spawn_timeout,
allow_staged=True,
)
resp = await self._rpc_server.send_request(
"plugin.health",
timeout_ms=5000,
target_generation=expected_generation,
)
health = HealthPayload.model_validate(resp.payload)
if not health.healthy:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
await self._rpc_server.commit_staged_takeover()
except Exception as e:
logger.error(f"新 Runner 健康检查失败: {e},回滚")
await self._terminate_process(self._runner_process, old_process)
await self._rpc_server.rollback_staged_takeover()
self._runner_process = old_process
# 恢复旧 session token，使旧 Runner 的连接仍可正常工作
self._rpc_server.restore_session_token(old_session_token)
self._staged_registered_plugins.clear()
self._registered_plugins = dict(old_registered_plugins)
self._rebuild_runtime_state()
return
# 新 Runner 健康且已完成组件注册,现在清理旧的幽灵组件
# 只移除不再存在于新注册表中的旧插件组件
for old_pid in list(old_registered_plugins.keys()):
if old_pid not in self._registered_plugins:
self._component_registry.remove_components_by_plugin(old_pid)
self._policy.revoke_plugin(old_pid)
self._runner_generation = self._rpc_server.runner_generation
self._registered_plugins = dict(self._staged_registered_plugins)
self._staged_registered_plugins.clear()
self._rebuild_runtime_state()
# 关停旧 Runner
if old_process and old_process.returncode is None:
@@ -380,13 +393,22 @@ class PluginSupervisor:
except Exception as e:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
if envelope.generation != self._rpc_server.runner_generation:
active_generation = self._rpc_server.runner_generation
staged_generation = self._rpc_server.staged_generation
if envelope.generation not in {active_generation, staged_generation}:
return envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}",
f"组件注册 generation 过期: {envelope.generation} 不在已知代际中",
)
# 记录注册信息
if envelope.generation == staged_generation and staged_generation != 0:
self._staged_registered_plugins[reg.plugin_id] = reg
logger.info(
f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功,"
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
)
return envelope.make_response(payload={"accepted": True, "staged": True})
self._registered_plugins[reg.plugin_id] = reg
# 在策略引擎中注册插件
@@ -396,7 +418,8 @@ class PluginSupervisor:
capabilities=reg.capabilities_required or [],
)
# 在 ComponentRegistry 中注册组件
# 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件
self._component_registry.remove_components_by_plugin(reg.plugin_id)
self._component_registry.register_plugin_components(
plugin_id=reg.plugin_id,
components=[c.model_dump() for c in reg.components],
@@ -518,10 +541,17 @@ class PluginSupervisor:
except Exception as e:
logger.error(f"健康检查异常: {e}")
async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None:
async def _wait_for_runner_generation(
self,
expected_generation: int,
timeout_sec: float,
allow_staged: bool = False,
) -> None:
"""等待指定代际的 Runner 完成连接。"""
deadline = asyncio.get_running_loop().time() + timeout_sec
while asyncio.get_running_loop().time() < deadline:
if allow_staged and self._rpc_server.has_generation(expected_generation):
return
if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation:
self._runner_generation = self._rpc_server.runner_generation
return
@@ -533,6 +563,7 @@ class PluginSupervisor:
self._component_registry.clear()
self._policy.clear()
self._registered_plugins.clear()
self._staged_registered_plugins.clear()
def _rebuild_runtime_state(self) -> None:
"""根据已记录的插件注册信息重建运行时状态。"""

View File

@@ -430,8 +430,10 @@ class PluginRuntimeManager:
"""
from src.services import send_service as send_api
message_type: str = args.get("message_type", "")
content = args.get("content", "")
message_type: str = args.get("message_type", "") or args.get("custom_type", "")
content = args.get("content")
if content is None:
content = args.get("data", "")
stream_id: str = args.get("stream_id", "")
if not message_type or not stream_id:
return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"}

View File

@@ -76,6 +76,7 @@ class RPCClient:
# 运行状态
self._running = False
self._recv_task: Optional[asyncio.Task] = None
self._background_tasks: set[asyncio.Task] = set()
@property
def generation(self) -> int:
@@ -144,6 +145,13 @@ class RPCClient:
await self._recv_task
self._recv_task = None
for task in list(self._background_tasks):
task.cancel()
if self._background_tasks:
with contextlib.suppress(Exception):
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
# 取消所有 pending 请求
for future in self._pending_requests.values():
if not future.done():
@@ -248,9 +256,9 @@ class RPCClient:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
asyncio.create_task(self._handle_request(envelope))
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
elif envelope.is_event():
asyncio.create_task(self._handle_event(envelope))
self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Host 的响应"""
@@ -290,3 +298,8 @@ class RPCClient:
await handler(envelope)
except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
def _track_background_task(self, task: asyncio.Task) -> None:
"""保持后台任务强引用,直到其完成或被取消。"""
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)