From 44a9e9ecd752041ecac4710109b79f1bc1e2e05a Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 13 Mar 2026 15:21:40 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E6=B3=A8=E5=86=8C=E7=9A=84=E5=88=86=E9=98=B6=E6=AE=B5=E6=8E=A5?= =?UTF-8?q?=E5=85=A5=E4=B8=8E=E5=88=87=E6=8D=A2=E6=9C=BA=E5=88=B6=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=20RPC=20=E8=BF=9E=E6=8E=A5=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_plugin_runtime.py | 173 ++++++++++++++++-- src/plugin_runtime/host/rpc_server.py | 225 ++++++++++++++++++------ src/plugin_runtime/host/supervisor.py | 59 +++++-- src/plugin_runtime/integration.py | 6 +- src/plugin_runtime/runner/rpc_client.py | 17 +- 5 files changed, 393 insertions(+), 87 deletions(-) diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 7553b54c..db210bc6 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -677,6 +677,28 @@ class TestComponentRegistry: assert removed == 2 assert reg.get_stats()["total"] == 1 + def test_reregister_same_plugin_replaces_component_set(self): + from src.plugin_runtime.host.component_registry import ComponentRegistry + + reg = ComponentRegistry() + reg.register_plugin_components( + "p1", + [ + {"name": "a1", "component_type": "action", "metadata": {}}, + {"name": "a2", "component_type": "action", "metadata": {}}, + ], + ) + reg.remove_components_by_plugin("p1") + reg.register_plugin_components( + "p1", + [ + {"name": "a1", "component_type": "action", "metadata": {}}, + ], + ) + + assert reg.get_component("p1.a1") is not None + assert reg.get_component("p1.a2") is None + def test_event_handlers_sorted_by_weight(self): from src.plugin_runtime.host.component_registry import ComponentRegistry @@ -1257,7 +1279,7 @@ class TestRPCServer: loop = asyncio.new_event_loop() try: future = loop.create_future() - server._pending_requests[1] = future + server._pending_requests[1] = (future, 2) stale_response = Envelope( request_id=1, @@ -1274,23 +1296,53 @@ class TestRPCServer: loop.close() +class TestRPCClient: + """Runner RPCClient 后台任务生命周期测试""" + + @pytest.mark.asyncio + async def test_background_tasks_retained_and_cancelled_on_disconnect(self): + from src.plugin_runtime.runner.rpc_client import RPCClient + + client = RPCClient(host_address="dummy", session_token="token") + release = asyncio.Event() + + async def pending_task(): + await release.wait() + + task = asyncio.create_task(pending_task()) + client._track_background_task(task) + + assert task in client._background_tasks + + await asyncio.sleep(0) + assert task in client._background_tasks + + await client.disconnect() + + assert task.cancelled() is True + assert not client._background_tasks + + class TestSupervisor: """Supervisor 生命周期边界测试""" @staticmethod - def _build_register_payload(plugin_id: str = "plugin_a"): + def _build_register_payload(plugin_id: str = "plugin_a", component_names=None): from src.plugin_runtime.protocol.envelope import ComponentDeclaration, RegisterComponentsPayload + component_names = component_names or ["handler"] + return RegisterComponentsPayload( plugin_id=plugin_id, plugin_version="1.0.0", components=[ ComponentDeclaration( - name="handler", + name=name, component_type="event_handler", plugin_id=plugin_id, metadata={"event_type": "on_message"}, ) + for name in component_names ], capabilities_required=["send.text"], ) @@ -1331,8 +1383,11 @@ class TestSupervisor: class FakeRPCServer: def __init__(self): self.runner_generation = 1 + self.staged_generation = 0 self.is_connected = True self.session_token = "fake-token" + self.committed = False + self.staging_started = False def reset_session_token(self): self.session_token = "new-fake-token" @@ -1341,8 +1396,23 @@ class TestSupervisor: def restore_session_token(self, token): self.session_token = token - async def send_request(self, method, timeout_ms=5000, **kwargs): - assert self.runner_generation == 2 + def begin_staged_takeover(self): + self.staging_started = True + self.staged_generation = 2 + + async def commit_staged_takeover(self): + self.runner_generation = self.staged_generation + self.staged_generation = 0 + self.committed = True + + async def rollback_staged_takeover(self): + self.staged_generation = 0 + + def has_generation(self, generation): + return generation in {self.runner_generation, self.staged_generation} + + async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): + assert target_generation == 2 return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump()) supervisor._rpc_server = FakeRPCServer() @@ -1350,18 +1420,14 @@ class TestSupervisor: async def fake_spawn_runner(): supervisor._runner_process = new_process - - async def advance_generation(): - await asyncio.sleep(0.01) - supervisor._rpc_server.runner_generation = 2 - - asyncio.create_task(advance_generation()) + supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a") monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) await supervisor.reload_plugins("test") assert supervisor._runner_process is new_process + assert supervisor._rpc_server.committed is True assert old_process.terminated is True @pytest.mark.asyncio @@ -1380,6 +1446,69 @@ class TestSupervisor: class FakeRPCServer: def __init__(self): self.runner_generation = 1 + self.staged_generation = 0 + self.is_connected = True + self.session_token = "fake-token" + self.rolled_back = False + + def reset_session_token(self): + self.session_token = "new-fake-token" + return self.session_token + + def restore_session_token(self, token): + self.session_token = token + + def begin_staged_takeover(self): + self.staged_generation = 2 + + async def commit_staged_takeover(self): + self.runner_generation = self.staged_generation + self.staged_generation = 0 + + async def rollback_staged_takeover(self): + self.rolled_back = True + self.staged_generation = 0 + + def has_generation(self, generation): + return generation in {self.runner_generation, self.staged_generation} + + async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): + raise RuntimeError("new runner unhealthy") + + supervisor._rpc_server = FakeRPCServer() + + async def fake_spawn_runner(): + supervisor._runner_process = new_process + supervisor._staged_registered_plugins["plugin_a"] = self._build_register_payload("plugin_a") + + monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) + + await supervisor.reload_plugins("test") + + assert supervisor._runner_process is old_process + assert supervisor._rpc_server.rolled_back is True + assert old_reg.plugin_id in supervisor._registered_plugins + assert supervisor.component_registry.get_component("plugin_a.handler") is not None + + @pytest.mark.asyncio + async def test_reload_rebuilds_exact_component_set(self, monkeypatch): + from src.plugin_runtime.host.supervisor import PluginSupervisor + from src.plugin_runtime.protocol.envelope import HealthPayload + + supervisor = PluginSupervisor(plugin_dirs=[]) + old_process = self._make_process(1) + new_process = self._make_process(2) + old_reg = self._build_register_payload("plugin_a", component_names=["handler", "obsolete"]) + new_reg = self._build_register_payload("plugin_a", component_names=["handler"]) + + supervisor._runner_process = old_process + supervisor._registered_plugins[old_reg.plugin_id] = old_reg + supervisor._rebuild_runtime_state() + + class FakeRPCServer: + def __init__(self): + self.runner_generation = 1 + self.staged_generation = 0 self.is_connected = True self.session_token = "fake-token" @@ -1390,22 +1519,34 @@ class TestSupervisor: def restore_session_token(self, token): self.session_token = token - async def send_request(self, method, timeout_ms=5000, **kwargs): - raise RuntimeError("new runner unhealthy") + def begin_staged_takeover(self): + self.staged_generation = 2 + + async def commit_staged_takeover(self): + self.runner_generation = self.staged_generation + self.staged_generation = 0 + + async def rollback_staged_takeover(self): + self.staged_generation = 0 + + def has_generation(self, generation): + return generation in {self.runner_generation, self.staged_generation} + + async def send_request(self, method, timeout_ms=5000, target_generation=None, **kwargs): + return SimpleNamespace(payload=HealthPayload(healthy=True).model_dump()) supervisor._rpc_server = FakeRPCServer() async def fake_spawn_runner(): supervisor._runner_process = new_process - supervisor._rpc_server.runner_generation = 2 + supervisor._staged_registered_plugins[new_reg.plugin_id] = new_reg monkeypatch.setattr(supervisor, "_spawn_runner", fake_spawn_runner) await supervisor.reload_plugins("test") - assert supervisor._runner_process is old_process - assert old_reg.plugin_id in supervisor._registered_plugins assert supervisor.component_registry.get_component("plugin_a.handler") is not None + assert supervisor.component_registry.get_component("plugin_a.obsolete") is None @pytest.mark.asyncio async def test_attach_stderr_drain_drains_stream(self): diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 6b2933d5..f71fa4b0 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -7,7 +7,7 @@ 4. 请求-响应关联与超时管理 """ -from typing import Any, Awaitable, Callable, Dict, List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple import asyncio import secrets @@ -55,12 +55,16 @@ class RPCServer: self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接 self._runner_id: Optional[str] = None self._runner_generation: int = 0 + self._staged_connection: Optional[Connection] = None + self._staged_runner_id: Optional[str] = None + self._staged_runner_generation: int = 0 + self._staging_takeover: bool = False # 方法处理器注册表 self._method_handlers: Dict[str, MethodHandler] = {} - # 等待响应的 pending 请求: request_id -> Future - self._pending_requests: Dict[int, asyncio.Future] = {} + # 等待响应的 pending 请求: request_id -> (Future, target_generation) + self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {} # 发送队列(背压控制) self._send_queue: Optional[asyncio.Queue[bytes]] = None @@ -86,10 +90,72 @@ class RPCServer: def runner_generation(self) -> int: return self._runner_generation + @property + def staged_generation(self) -> int: + return self._staged_runner_generation + @property def is_connected(self) -> bool: return self._connection is not None and not self._connection.is_closed + def has_generation(self, generation: int) -> bool: + return generation == self._runner_generation or ( + self._staged_connection is not None + and not self._staged_connection.is_closed + and generation == self._staged_runner_generation + ) + + def begin_staged_takeover(self) -> None: + """允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。""" + self._staging_takeover = True + + async def commit_staged_takeover(self) -> None: + """提交 staged Runner,原活跃连接在提交后被关闭。""" + if self._staged_connection is None or self._staged_connection.is_closed: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接") + + old_connection = self._connection + old_generation = self._runner_generation + + self._connection = self._staged_connection + self._runner_id = self._staged_runner_id + self._runner_generation = self._staged_runner_generation + + self._staged_connection = None + self._staged_runner_id = None + self._staged_runner_generation = 0 + self._staging_takeover = False + + stale_count = self._fail_pending_requests( + ErrorCode.E_PLUGIN_CRASHED, + "Runner 连接已被新 generation 接管", + generation=old_generation, + ) + if stale_count: + logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求") + + if old_connection and old_connection is not self._connection and not old_connection.is_closed: + await old_connection.close() + + async def rollback_staged_takeover(self) -> None: + """放弃 staged Runner,保留当前活跃连接。""" + staged_connection = self._staged_connection + staged_generation = self._staged_runner_generation + + self._staged_connection = None + self._staged_runner_id = None + self._staged_runner_generation = 0 + self._staging_takeover = False + + self._fail_pending_requests( + ErrorCode.E_PLUGIN_CRASHED, + "新 Runner 预热失败,已回滚", + generation=staged_generation, + ) + + if staged_connection and not staged_connection.is_closed: + await staged_connection.close() + def register_method(self, method: str, handler: MethodHandler) -> None: """注册 RPC 方法处理器""" self._method_handlers[method] = handler @@ -106,7 +172,7 @@ class RPCServer: self._running = False # 取消所有 pending 请求 - for future in self._pending_requests.values(): + for future, _generation in self._pending_requests.values(): if not future.done(): future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) self._pending_requests.clear() @@ -121,6 +187,10 @@ class RPCServer: await self._connection.close() self._connection = None + if self._staged_connection: + await self._staged_connection.close() + self._staged_connection = None + await self._transport.stop() logger.info("RPC Server 已停止") @@ -130,6 +200,7 @@ class RPCServer: plugin_id: str = "", payload: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, + target_generation: Optional[int] = None, ) -> Envelope: """向 Runner 发送 RPC 请求并等待响应 @@ -145,7 +216,9 @@ class RPCServer: Raises: RPCError: 调用失败 """ - if not self.is_connected: + generation = target_generation or self._runner_generation + conn = self._get_connection_for_generation(generation) + if conn is None or conn.is_closed: raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") request_id = self._id_gen.next() @@ -154,7 +227,7 @@ class RPCServer: message_type=MessageType.REQUEST, method=method, plugin_id=plugin_id, - generation=self._runner_generation, + generation=generation, timeout_ms=timeout_ms, payload=payload or {}, ) @@ -166,12 +239,12 @@ class RPCServer: # 注册 pending future loop = asyncio.get_running_loop() future: asyncio.Future[Envelope] = loop.create_future() - self._pending_requests[request_id] = future + self._pending_requests[request_id] = (future, generation) try: # 发送请求 data = self._codec.encode_envelope(envelope) - await self._connection.send_frame(data) + await conn.send_frame(data) # 等待响应 timeout_sec = timeout_ms / 1000.0 @@ -207,11 +280,13 @@ class RPCServer: async def _handle_connection(self, conn: Connection) -> None: """处理新的 Runner 连接""" logger.info("收到 Runner 连接") + previous_connection = self._connection + previous_generation = self._runner_generation # 第一条消息必须是 runner.hello 握手 try: - handshake_ok = await self._handle_handshake(conn) - if not handshake_ok: + role = await self._handle_handshake(conn) + if role is None: await conn.close() return except Exception as e: @@ -219,41 +294,52 @@ class RPCServer: await conn.close() return - old_connection = self._connection - self._connection = conn - logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}") + if role == "staged": + expected_generation = self._staged_runner_generation + logger.info( + f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}" + ) + else: + self._connection = conn + expected_generation = self._runner_generation + logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}") - if old_connection and old_connection is not conn and not old_connection.is_closed: - logger.info("检测到新 Runner 已接管连接,关闭旧连接") - # 新连接接管后,旧 Runner 的 in-flight 请求不会再收到响应 - # (过期 generation 响应会被 _handle_response 丢弃), - # 在此处立即 fail-fast 所有 pending 请求,避免挂到超时 - stale_count = 0 - for _req_id, future in list(self._pending_requests.items()): - if not future.done(): - future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已被新 generation 接管")) - stale_count += 1 - self._pending_requests.clear() - if stale_count: - logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求") - await old_connection.close() + if previous_connection and previous_connection is not conn and not previous_connection.is_closed: + logger.info("检测到新 Runner 已接管连接,关闭旧连接") + stale_count = self._fail_pending_requests( + ErrorCode.E_PLUGIN_CRASHED, + "Runner 连接已被新 generation 接管", + generation=previous_generation, + ) + if stale_count: + logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求") + await previous_connection.close() # 启动消息接收循环 try: - await self._recv_loop(conn) + await self._recv_loop(conn, expected_generation=expected_generation) except Exception as e: logger.error(f"连接异常断开: {e}") finally: if self._connection is conn: self._connection = None self._runner_id = None - # 连接断开时,立即让所有等待中的请求失败,避免挂起至超时 - for _req_id, future in list(self._pending_requests.items()): - if not future.done(): - future.set_exception(RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")) - self._pending_requests.clear() + self._fail_pending_requests( + ErrorCode.E_PLUGIN_CRASHED, + "Runner 连接已断开", + generation=expected_generation, + ) + elif self._staged_connection is conn: + self._staged_connection = None + self._staged_runner_id = None + self._staged_runner_generation = 0 + self._fail_pending_requests( + ErrorCode.E_PLUGIN_CRASHED, + "Staged Runner 连接已断开", + generation=expected_generation, + ) - async def _handle_handshake(self, conn: Connection) -> bool: + async def _handle_handshake(self, conn: Connection) -> Optional[str]: """处理 runner.hello 握手""" # 接收握手请求 data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0) @@ -266,7 +352,7 @@ class RPCServer: "首条消息必须为 runner.hello", ) await conn.send_frame(self._codec.encode_envelope(error_resp)) - return False + return None # 解析握手 payload hello = HelloPayload.model_validate(envelope.payload) @@ -280,7 +366,7 @@ class RPCServer: ) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) - return False + return None # 校验 SDK 版本 if not self._check_sdk_version(hello.sdk_version): @@ -291,23 +377,31 @@ class RPCServer: ) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) - return False + return None # 握手成功 - self._runner_id = hello.runner_id - self._runner_generation += 1 + role = "active" + assigned_generation = self._runner_generation + 1 + if self._staging_takeover and self.is_connected: + role = "staged" + self._staged_connection = conn + self._staged_runner_id = hello.runner_id + self._staged_runner_generation = assigned_generation + else: + self._runner_id = hello.runner_id + self._runner_generation = assigned_generation resp_payload = HelloResponsePayload( accepted=True, host_version=PROTOCOL_VERSION, - assigned_generation=self._runner_generation, + assigned_generation=assigned_generation, ) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) - return True + return role - async def _recv_loop(self, conn: Connection) -> None: + async def _recv_loop(self, conn: Connection, expected_generation: int) -> None: """消息接收主循环""" while self._running and not conn.is_closed: try: @@ -329,10 +423,10 @@ class RPCServer: if envelope.is_response(): self._handle_response(envelope) elif envelope.is_request(): - if not self._is_current_generation(envelope): + if envelope.generation != expected_generation: error_resp = envelope.make_error_response( ErrorCode.E_GENERATION_MISMATCH.value, - f"过期 generation: {envelope.generation} != {self._runner_generation}", + f"过期 generation: {envelope.generation} != {expected_generation}", ) await conn.send_frame(self._codec.encode_envelope(error_resp)) continue @@ -341,9 +435,9 @@ class RPCServer: self._tasks.append(task) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) elif envelope.is_event(): - if not self._is_current_generation(envelope): + if envelope.generation != expected_generation: logger.warning( - f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {self._runner_generation}" + f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}" ) continue task = asyncio.create_task(self._handle_event(envelope)) @@ -352,22 +446,24 @@ class RPCServer: def _handle_response(self, envelope: Envelope) -> None: """处理来自 Runner 的响应""" - if not self._is_current_generation(envelope): + pending = self._pending_requests.get(envelope.request_id) + if pending is None: + return + + future, expected_generation = pending + if envelope.generation != expected_generation: logger.warning( - f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {self._runner_generation}" + f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}" ) return - future = self._pending_requests.pop(envelope.request_id, None) - if future and not future.done(): + self._pending_requests.pop(envelope.request_id, None) + if not future.done(): if envelope.error: future.set_exception(RPCError.from_dict(envelope.error)) else: future.set_result(envelope) - def _is_current_generation(self, envelope: Envelope) -> bool: - return envelope.generation == self._runner_generation - async def _handle_request(self, envelope: Envelope, conn: Connection) -> None: """处理来自 Runner 的请求(通常是能力调用 cap.*)""" handler = self._method_handlers.get(envelope.method) @@ -411,3 +507,26 @@ class RPCServer: return min_parts <= sdk_parts <= max_parts except (ValueError, AttributeError): return False + + def _get_connection_for_generation(self, generation: int) -> Optional[Connection]: + if generation == self._runner_generation: + return self._connection + if generation == self._staged_runner_generation: + return self._staged_connection + return None + + def _fail_pending_requests( + self, + error_code: ErrorCode, + message: str, + generation: Optional[int] = None, + ) -> int: + stale_count = 0 + for request_id, (future, request_generation) in list(self._pending_requests.items()): + if generation is not None and request_generation != generation: + continue + if not future.done(): + future.set_exception(RPCError(error_code, message)) + stale_count += 1 + self._pending_requests.pop(request_id, None) + return stale_count diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 164d2641..1ed44e5a 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -135,6 +135,7 @@ class PluginSupervisor: # 已注册的插件组件信息 self._registered_plugins: Dict[str, RegisterComponentsPayload] = {} + self._staged_registered_plugins: Dict[str, RegisterComponentsPayload] = {} # 后台任务 self._health_task: Optional[asyncio.Task] = None @@ -319,6 +320,10 @@ class PluginSupervisor: old_session_token = self._rpc_server.session_token expected_generation = self._rpc_server.runner_generation + 1 + # 允许新 Runner 以 staged 方式接入,验证通过后再切换活跃连接 + self._rpc_server.begin_staged_takeover() + self._staged_registered_plugins.clear() + # 重新生成 session token,防止被终止的旧 Runner 重连 self._rpc_server.reset_session_token() @@ -330,27 +335,35 @@ class PluginSupervisor: # 拉起新 Runner try: await self._spawn_runner() - await self._wait_for_runner_generation(expected_generation, timeout_sec=self._runner_spawn_timeout) - resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000) + await self._wait_for_runner_generation( + expected_generation, + timeout_sec=self._runner_spawn_timeout, + allow_staged=True, + ) + resp = await self._rpc_server.send_request( + "plugin.health", + timeout_ms=5000, + target_generation=expected_generation, + ) health = HealthPayload.model_validate(resp.payload) if not health.healthy: raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败") + await self._rpc_server.commit_staged_takeover() except Exception as e: logger.error(f"新 Runner 健康检查失败: {e},回滚") await self._terminate_process(self._runner_process, old_process) + await self._rpc_server.rollback_staged_takeover() self._runner_process = old_process - # 恢复旧 session token,使旧 Runner 的连接仍可正常工作 self._rpc_server.restore_session_token(old_session_token) + self._staged_registered_plugins.clear() self._registered_plugins = dict(old_registered_plugins) self._rebuild_runtime_state() return - # 新 Runner 健康且已完成组件注册,现在清理旧的幽灵组件 - # 只移除不再存在于新注册表中的旧插件组件 - for old_pid in list(old_registered_plugins.keys()): - if old_pid not in self._registered_plugins: - self._component_registry.remove_components_by_plugin(old_pid) - self._policy.revoke_plugin(old_pid) + self._runner_generation = self._rpc_server.runner_generation + self._registered_plugins = dict(self._staged_registered_plugins) + self._staged_registered_plugins.clear() + self._rebuild_runtime_state() # 关停旧 Runner if old_process and old_process.returncode is None: @@ -380,13 +393,22 @@ class PluginSupervisor: except Exception as e: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) - if envelope.generation != self._rpc_server.runner_generation: + active_generation = self._rpc_server.runner_generation + staged_generation = self._rpc_server.staged_generation + if envelope.generation not in {active_generation, staged_generation}: return envelope.make_error_response( ErrorCode.E_GENERATION_MISMATCH.value, - f"组件注册 generation 过期: {envelope.generation} != {self._rpc_server.runner_generation}", + f"组件注册 generation 过期: {envelope.generation} 不在已知代际中", ) - # 记录注册信息 + if envelope.generation == staged_generation and staged_generation != 0: + self._staged_registered_plugins[reg.plugin_id] = reg + logger.info( + f"插件 {reg.plugin_id} v{reg.plugin_version} staged 注册成功," + f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}" + ) + return envelope.make_response(payload={"accepted": True, "staged": True}) + self._registered_plugins[reg.plugin_id] = reg # 在策略引擎中注册插件 @@ -396,7 +418,8 @@ class PluginSupervisor: capabilities=reg.capabilities_required or [], ) - # 在 ComponentRegistry 中注册组件 + # 同 generation 下重新注册时,以本次声明为准,避免残留幽灵组件 + self._component_registry.remove_components_by_plugin(reg.plugin_id) self._component_registry.register_plugin_components( plugin_id=reg.plugin_id, components=[c.model_dump() for c in reg.components], @@ -518,10 +541,17 @@ class PluginSupervisor: except Exception as e: logger.error(f"健康检查异常: {e}") - async def _wait_for_runner_generation(self, expected_generation: int, timeout_sec: float) -> None: + async def _wait_for_runner_generation( + self, + expected_generation: int, + timeout_sec: float, + allow_staged: bool = False, + ) -> None: """等待指定代际的 Runner 完成连接。""" deadline = asyncio.get_running_loop().time() + timeout_sec while asyncio.get_running_loop().time() < deadline: + if allow_staged and self._rpc_server.has_generation(expected_generation): + return if self._rpc_server.is_connected and self._rpc_server.runner_generation >= expected_generation: self._runner_generation = self._rpc_server.runner_generation return @@ -533,6 +563,7 @@ class PluginSupervisor: self._component_registry.clear() self._policy.clear() self._registered_plugins.clear() + self._staged_registered_plugins.clear() def _rebuild_runtime_state(self) -> None: """根据已记录的插件注册信息重建运行时状态。""" diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 4272e315..d4e02b6b 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -430,8 +430,10 @@ class PluginRuntimeManager: """ from src.services import send_service as send_api - message_type: str = args.get("message_type", "") - content = args.get("content", "") + message_type: str = args.get("message_type", "") or args.get("custom_type", "") + content = args.get("content") + if content is None: + content = args.get("data", "") stream_id: str = args.get("stream_id", "") if not message_type or not stream_id: return {"success": False, "error": "缺少必要参数 message_type 或 stream_id"} diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py index 46a9e9af..88cd4fee 100644 --- a/src/plugin_runtime/runner/rpc_client.py +++ b/src/plugin_runtime/runner/rpc_client.py @@ -76,6 +76,7 @@ class RPCClient: # 运行状态 self._running = False self._recv_task: Optional[asyncio.Task] = None + self._background_tasks: set[asyncio.Task] = set() @property def generation(self) -> int: @@ -144,6 +145,13 @@ class RPCClient: await self._recv_task self._recv_task = None + for task in list(self._background_tasks): + task.cancel() + if self._background_tasks: + with contextlib.suppress(Exception): + await asyncio.gather(*self._background_tasks, return_exceptions=True) + self._background_tasks.clear() + # 取消所有 pending 请求 for future in self._pending_requests.values(): if not future.done(): @@ -248,9 +256,9 @@ class RPCClient: if envelope.is_response(): self._handle_response(envelope) elif envelope.is_request(): - asyncio.create_task(self._handle_request(envelope)) + self._track_background_task(asyncio.create_task(self._handle_request(envelope))) elif envelope.is_event(): - asyncio.create_task(self._handle_event(envelope)) + self._track_background_task(asyncio.create_task(self._handle_event(envelope))) def _handle_response(self, envelope: Envelope) -> None: """处理来自 Host 的响应""" @@ -290,3 +298,8 @@ class RPCClient: await handler(envelope) except Exception as e: logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) + + def _track_background_task(self, task: asyncio.Task) -> None: + """保持后台任务强引用,直到其完成或被取消。""" + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard)