feat: 实现 RPC 服务器的发送队列和背压控制机制

This commit is contained in:
DrSmoothl
2026-03-13 16:05:51 +08:00
parent d92aa800a3
commit 7e2b509bf0
2 changed files with 132 additions and 7 deletions

View File

@@ -1372,6 +1372,53 @@ class TestRPCServer:
finally: finally:
loop.close() loop.close()
@pytest.mark.asyncio
async def test_send_queue_backpressure_is_enforced(self):
    """A saturated send queue must reject further sends with E_BACKPRESSURE.

    The server is built with ``send_queue_size=1`` and a connection whose
    ``send_frame`` blocks until ``release`` is set. Two sends are started and
    left pending (each is given one event-loop turn to get scheduled); the
    third send is then expected to fail immediately with ``E_BACKPRESSURE``.

    Cleanup runs in ``finally`` blocks so that a failing assertion does not
    leave the blocked send tasks pending or the server running.
    """
    from src.plugin_runtime.host.rpc_server import RPCServer
    from src.plugin_runtime.protocol.errors import ErrorCode, RPCError

    class DummyTransport:
        """Transport stub that never listens on anything."""

        async def start(self, handler):
            return None

        async def stop(self):
            return None

        def get_address(self):
            return "dummy"

    class BlockingConnection:
        """Connection stub whose writes stall until ``release`` is set."""

        def __init__(self):
            self.is_closed = False
            self.release = asyncio.Event()

        async def send_frame(self, data):
            await self.release.wait()

        async def close(self):
            self.is_closed = True

    server = RPCServer(transport=DummyTransport(), send_queue_size=1)
    await server.start()
    try:
        conn = BlockingConnection()
        server._connection = conn
        server._runner_generation = 1

        # Yield to the loop after each create_task so the send gets scheduled
        # before the next one is issued.
        first_send = asyncio.create_task(server.send_event("runner.log_batch"))
        await asyncio.sleep(0)
        second_send = asyncio.create_task(server.send_event("runner.log_batch"))
        await asyncio.sleep(0)

        try:
            with pytest.raises(RPCError) as exc_info:
                await server.send_event("runner.log_batch")
            assert exc_info.value.code == ErrorCode.E_BACKPRESSURE
        finally:
            # Always unblock the connection and collect the pending sends,
            # even if the assertion above failed.
            conn.release.set()
            await asyncio.gather(first_send, second_send)
    finally:
        await server.stop()
class TestRPCClient: class TestRPCClient:
"""Runner RPCClient 后台任务生命周期测试""" """Runner RPCClient 后台任务生命周期测试"""

View File

@@ -10,6 +10,7 @@
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
import asyncio import asyncio
import contextlib
import secrets import secrets
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -67,7 +68,8 @@ class RPCServer:
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {} self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
# 发送队列(背压控制) # 发送队列(背压控制)
self._send_queue: Optional[asyncio.Queue[bytes]] = None self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
self._send_worker_task: Optional[asyncio.Task] = None
# 运行状态 # 运行状态
self._running: bool = False self._running: bool = False
@@ -164,6 +166,7 @@ class RPCServer:
"""启动 RPC 服务器""" """启动 RPC 服务器"""
self._running = True self._running = True
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size) self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
self._send_worker_task = asyncio.create_task(self._send_loop())
await self._transport.start(self._handle_connection) await self._transport.start(self._handle_connection)
logger.info(f"RPC Server 已启动,监听地址: {self._transport.get_address()}") logger.info(f"RPC Server 已启动,监听地址: {self._transport.get_address()}")
@@ -177,6 +180,14 @@ class RPCServer:
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
self._pending_requests.clear() self._pending_requests.clear()
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
if self._send_worker_task:
self._send_worker_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._send_worker_task
self._send_worker_task = None
# 取消后台任务 # 取消后台任务
for task in self._tasks: for task in self._tasks:
task.cancel() task.cancel()
@@ -232,10 +243,6 @@ class RPCServer:
payload=payload or {}, payload=payload or {},
) )
# 背压检查
if self._send_queue and self._send_queue.full():
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满")
# 注册 pending future # 注册 pending future
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future: asyncio.Future[Envelope] = loop.create_future() future: asyncio.Future[Envelope] = loop.create_future()
@@ -244,7 +251,7 @@ class RPCServer:
try: try:
# 发送请求 # 发送请求
data = self._codec.encode_envelope(envelope) data = self._codec.encode_envelope(envelope)
await conn.send_frame(data) await self._enqueue_send(conn, data)
# 等待响应 # 等待响应
timeout_sec = timeout_ms / 1000.0 timeout_sec = timeout_ms / 1000.0
@@ -273,7 +280,7 @@ class RPCServer:
payload=payload or {}, payload=payload or {},
) )
data = self._codec.encode_envelope(envelope) data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data) await self._enqueue_send(self._connection, data)
# ─── 内部方法 ────────────────────────────────────────────── # ─── 内部方法 ──────────────────────────────────────────────
@@ -464,6 +471,59 @@ class RPCServer:
else: else:
future.set_result(envelope) future.set_result(envelope)
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
    """Hand one encoded frame to the send worker and wait until it is written.

    Frames go through the bounded send queue so that a full queue is reported
    immediately instead of buffering without limit. When no queue exists, the
    frame is written to ``conn`` directly.

    Args:
        conn: Connection the frame should be written to.
        data: Encoded frame bytes.

    Raises:
        RPCError: ``E_PLUGIN_CRASHED`` if ``conn`` is already closed,
            ``E_BACKPRESSURE`` if the queue has no free slot. Any exception
            recorded on the completion future is re-raised here as well.
    """
    if conn.is_closed:
        raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")

    queue = self._send_queue
    if queue is None:
        await conn.send_frame(data)
        return

    completion: asyncio.Future[None] = asyncio.get_running_loop().create_future()
    try:
        queue.put_nowait((conn, data, completion))
    except asyncio.QueueFull:
        # ``from None`` keeps the internal QueueFull out of the caller's traceback.
        raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None

    # Resolved (or failed) by whoever consumes the queue entry.
    await completion
async def _send_loop(self) -> None:
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
if self._send_queue is None:
return
while True:
try:
conn, data, send_future = await self._send_queue.get()
except asyncio.CancelledError:
break
try:
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
await conn.send_frame(data)
if not send_future.done():
send_future.set_result(None)
except asyncio.CancelledError:
if not send_future.done():
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
raise
except Exception as e:
send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
if not send_future.done():
send_future.set_exception(send_error)
finally:
self._send_queue.task_done()
@staticmethod
def _normalize_send_exception(error: Exception) -> RPCError:
    """Wrap a raw send failure in an ``RPCError``.

    ``ConnectionError`` (and its subclasses) maps to ``E_PLUGIN_CRASHED``;
    every other exception maps to ``E_UNKNOWN``. The message is ``str(error)``
    in both cases.
    """
    code = ErrorCode.E_PLUGIN_CRASHED if isinstance(error, ConnectionError) else ErrorCode.E_UNKNOWN
    return RPCError(code, str(error))
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None: async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*""" """处理来自 Runner 的请求(通常是能力调用 cap.*"""
handler = self._method_handlers.get(envelope.method) handler = self._method_handlers.get(envelope.method)
@@ -530,3 +590,21 @@ class RPCServer:
stale_count += 1 stale_count += 1
self._pending_requests.pop(request_id, None) self._pending_requests.pop(request_id, None)
return stale_count return stale_count
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
    """Empty the send queue, failing every send that is still waiting.

    Each queued entry is removed and marked as processed. If its future is
    not done yet, it is completed with ``RPCError(error_code, message)``.

    Args:
        error_code: Error code placed on each failed send.
        message: Error message placed on each failed send.

    Returns:
        The number of futures that were failed by this call; ``0`` when no
        send queue exists.
    """
    queue = self._send_queue
    if queue is None:
        return 0

    rejected = 0
    # get_nowait() raises QueueEmpty exactly when empty() is true, so looping
    # on empty() drains the same entries.
    while not queue.empty():
        _conn, _data, pending = queue.get_nowait()
        if not pending.done():
            pending.set_exception(RPCError(error_code, message))
            rejected += 1
        queue.task_done()
    return rejected