diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index ac1f5b9e..9c71bcea 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -1372,6 +1372,53 @@ class TestRPCServer: finally: loop.close() + @pytest.mark.asyncio + async def test_send_queue_backpressure_is_enforced(self): + from src.plugin_runtime.host.rpc_server import RPCServer + from src.plugin_runtime.protocol.errors import ErrorCode, RPCError + + class DummyTransport: + async def start(self, handler): + return None + + async def stop(self): + return None + + def get_address(self): + return "dummy" + + class BlockingConnection: + def __init__(self): + self.is_closed = False + self.release = asyncio.Event() + + async def send_frame(self, data): + await self.release.wait() + + async def close(self): + self.is_closed = True + + server = RPCServer(transport=DummyTransport(), send_queue_size=1) + await server.start() + + conn = BlockingConnection() + server._connection = conn + server._runner_generation = 1 + + first_send = asyncio.create_task(server.send_event("runner.log_batch")) + await asyncio.sleep(0) + second_send = asyncio.create_task(server.send_event("runner.log_batch")) + await asyncio.sleep(0) + + with pytest.raises(RPCError) as exc_info: + await server.send_event("runner.log_batch") + + assert exc_info.value.code == ErrorCode.E_BACKPRESSURE + + conn.release.set() + await asyncio.gather(first_send, second_send) + await server.stop() + class TestRPCClient: """Runner RPCClient 后台任务生命周期测试""" diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index f71fa4b0..ec27ca88 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -10,6 +10,7 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple import asyncio +import contextlib import secrets from src.common.logger import get_logger @@ -67,7 +68,8 @@ class RPCServer: self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {} # 发送队列(背压控制) - self._send_queue: Optional[asyncio.Queue[bytes]] = None + self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None + self._send_worker_task: Optional[asyncio.Task] = None # 运行状态 self._running: bool = False @@ -164,6 +166,7 @@ class RPCServer: """启动 RPC 服务器""" self._running = True self._send_queue = asyncio.Queue(maxsize=self._send_queue_size) + self._send_worker_task = asyncio.create_task(self._send_loop()) await self._transport.start(self._handle_connection) logger.info(f"RPC Server 已启动,监听地址: {self._transport.get_address()}") @@ -177,6 +180,14 @@ class RPCServer: future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) self._pending_requests.clear() + self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭") + + if self._send_worker_task: + self._send_worker_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._send_worker_task + self._send_worker_task = None + # 取消后台任务 for task in self._tasks: task.cancel() @@ -232,10 +243,6 @@ class RPCServer: payload=payload or {}, ) - # 背压检查 - if self._send_queue and self._send_queue.full(): - raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") - # 注册 pending future loop = asyncio.get_running_loop() future: asyncio.Future[Envelope] = loop.create_future() @@ -244,7 +251,7 @@ class RPCServer: try: # 发送请求 data = self._codec.encode_envelope(envelope) - await conn.send_frame(data) + await self._enqueue_send(conn, data) # 等待响应 timeout_sec = timeout_ms / 1000.0 @@ -273,7 +280,7 @@ class RPCServer: payload=payload or {}, ) data = self._codec.encode_envelope(envelope) - await self._connection.send_frame(data) + await self._enqueue_send(self._connection, data) # ─── 内部方法 ────────────────────────────────────────────── @@ -464,6 +471,59 @@ class RPCServer: else: future.set_result(envelope) + async def _enqueue_send(self, conn: Connection, data: bytes) -> None: + """通过发送队列串行发送消息,提供真实背压。""" + if conn.is_closed: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") + + if self._send_queue is None: + await conn.send_frame(data) + return + + loop = asyncio.get_running_loop() + send_future: asyncio.Future[None] = loop.create_future() + + try: + self._send_queue.put_nowait((conn, data, send_future)) + except asyncio.QueueFull: + raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None + + await send_future + + async def _send_loop(self) -> None: + """后台发送循环:串行消费发送队列,统一执行连接写入。""" + if self._send_queue is None: + return + + while True: + try: + conn, data, send_future = await self._send_queue.get() + except asyncio.CancelledError: + break + + try: + if conn.is_closed: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") + await conn.send_frame(data) + if not send_future.done(): + send_future.set_result(None) + except asyncio.CancelledError: + if not send_future.done(): + send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) + raise + except Exception as e: + send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e) + if not send_future.done(): + send_future.set_exception(send_error) + finally: + self._send_queue.task_done() + + @staticmethod + def _normalize_send_exception(error: Exception) -> RPCError: + if isinstance(error, ConnectionError): + return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error)) + return RPCError(ErrorCode.E_UNKNOWN, str(error)) + async def _handle_request(self, envelope: Envelope, conn: Connection) -> None: """处理来自 Runner 的请求(通常是能力调用 cap.*)""" handler = self._method_handlers.get(envelope.method) @@ -530,3 +590,21 @@ class RPCServer: stale_count += 1 self._pending_requests.pop(request_id, None) return stale_count + + def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int: + if self._send_queue is None: + return 0 + + failed_count = 0 + while True: + try: + _conn, _data, send_future = self._send_queue.get_nowait() + except asyncio.QueueEmpty: + break + + if not send_future.done(): + send_future.set_exception(RPCError(error_code, message)) + failed_count += 1 + self._send_queue.task_done() + + return failed_count