feat: 实现 RPC 服务器的发送队列和背压控制机制

This commit is contained in:
DrSmoothl
2026-03-13 16:05:51 +08:00
parent d92aa800a3
commit 7e2b509bf0
2 changed files with 132 additions and 7 deletions

View File

@@ -10,6 +10,7 @@
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
import asyncio
import contextlib
import secrets
from src.common.logger import get_logger
@@ -67,7 +68,8 @@ class RPCServer:
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
# 发送队列(背压控制)
self._send_queue: Optional[asyncio.Queue[bytes]] = None
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
self._send_worker_task: Optional[asyncio.Task] = None
# 运行状态
self._running: bool = False
@@ -164,6 +166,7 @@ class RPCServer:
"""启动 RPC 服务器"""
self._running = True
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
self._send_worker_task = asyncio.create_task(self._send_loop())
await self._transport.start(self._handle_connection)
logger.info(f"RPC Server 已启动,监听地址: {self._transport.get_address()}")
@@ -177,6 +180,14 @@ class RPCServer:
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
self._pending_requests.clear()
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
if self._send_worker_task:
self._send_worker_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._send_worker_task
self._send_worker_task = None
# 取消后台任务
for task in self._tasks:
task.cancel()
@@ -232,10 +243,6 @@ class RPCServer:
payload=payload or {},
)
# 背压检查
if self._send_queue and self._send_queue.full():
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满")
# 注册 pending future
loop = asyncio.get_running_loop()
future: asyncio.Future[Envelope] = loop.create_future()
@@ -244,7 +251,7 @@ class RPCServer:
try:
# 发送请求
data = self._codec.encode_envelope(envelope)
await conn.send_frame(data)
await self._enqueue_send(conn, data)
# 等待响应
timeout_sec = timeout_ms / 1000.0
@@ -273,7 +280,7 @@ class RPCServer:
payload=payload or {},
)
data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data)
await self._enqueue_send(self._connection, data)
# ─── 内部方法 ──────────────────────────────────────────────
@@ -464,6 +471,59 @@ class RPCServer:
else:
future.set_result(envelope)
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
    """Send ``data`` on ``conn`` through the bounded send queue.

    The frame is placed on the send queue together with a future; this
    coroutine then waits on that future, which the background send loop
    resolves once the frame has been written (or failed). When no send
    queue exists, the frame is written directly.

    Args:
        conn: Connection the frame is written to.
        data: Already-encoded frame bytes.

    Raises:
        RPCError: ``E_PLUGIN_CRASHED`` if ``conn`` is already closed;
            ``E_BACKPRESSURE`` if the send queue is full;
            ``E_TIMEOUT`` if a queue exists but no send worker is
            running to drain it; or whatever error the send loop
            attaches to the future.
    """
    if conn.is_closed:
        raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")

    queue = self._send_queue
    if queue is None:
        # No queue has been created: write directly.
        await conn.send_frame(data)
        return

    # A queue without a live consumer would accept the item and then leave
    # the caller awaiting a future nobody will ever resolve. Fail fast with
    # the same error that queued sends receive on shutdown.
    worker = self._send_worker_task
    if worker is None or worker.done():
        raise RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")

    send_future: asyncio.Future[None] = asyncio.get_running_loop().create_future()
    try:
        # Non-blocking put: a full queue is reported to the caller instead
        # of making it wait.
        queue.put_nowait((conn, data, send_future))
    except asyncio.QueueFull:
        raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None

    await send_future
async def _send_loop(self) -> None:
    """Background send loop: drain the send queue and write frames serially.

    Each queue item is a ``(conn, data, send_future)`` tuple. The frame is
    written with ``conn.send_frame`` and the outcome is reported through
    ``send_future``: ``None`` on success, an ``RPCError`` on failure.
    Returns immediately when no send queue exists.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled while
            a frame is being written (after failing that frame's future).
            Cancellation while idle on ``get()`` ends the loop quietly.
    """
    # Bind the queue once so the loop (and its ``task_done`` bookkeeping)
    # keeps working on the same object even if the attribute is reassigned.
    queue = self._send_queue
    if queue is None:
        return

    while True:
        try:
            conn, data, send_future = await queue.get()
        except asyncio.CancelledError:
            # Cancelled while idle: nothing is in flight, just stop.
            break

        try:
            if send_future.done():
                # The sender stopped waiting (e.g. it was cancelled) before
                # its turn came; do not write a frame nobody is waiting on.
                continue
            if conn.is_closed:
                raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
            await conn.send_frame(data)
            if not send_future.done():
                send_future.set_result(None)
        except asyncio.CancelledError:
            # Cancelled mid-write: unblock the waiting sender, then propagate.
            if not send_future.done():
                send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
            raise
        except Exception as e:
            send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
            if send_future.done():
                # The sender gave up during the write; surface the failure
                # in the log instead of dropping it silently.
                logger.warning(f"发送失败且调用方已放弃等待: {send_error}")
            else:
                send_future.set_exception(send_error)
        finally:
            # Runs for every dequeued item, including skipped ones.
            queue.task_done()
@staticmethod
def _normalize_send_exception(error: Exception) -> RPCError:
    """Convert an arbitrary send failure into an ``RPCError``.

    Args:
        error: Exception raised while writing a frame.

    Returns:
        ``error`` itself if it already is an ``RPCError``; otherwise a new
        ``RPCError`` with code ``E_PLUGIN_CRASHED`` for ``ConnectionError``
        (and subclasses) or ``E_UNKNOWN`` for anything else. The original
        exception is attached as ``__cause__``.
    """
    if isinstance(error, RPCError):
        return error

    # Exceptions raised without arguments (e.g. ``TimeoutError()``) stringify
    # to ""; fall back to the class name so the message is never empty.
    message = str(error) or type(error).__name__
    code = (
        ErrorCode.E_PLUGIN_CRASHED
        if isinstance(error, ConnectionError)
        else ErrorCode.E_UNKNOWN
    )
    rpc_error = RPCError(code, message)
    # Keep the underlying exception reachable for tracebacks and debugging.
    rpc_error.__cause__ = error
    return rpc_error
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*"""
handler = self._method_handlers.get(envelope.method)
@@ -530,3 +590,21 @@ class RPCServer:
stale_count += 1
self._pending_requests.pop(request_id, None)
return stale_count
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
    """Empty the send queue, failing every send that is still pending.

    Each queued item is removed and acknowledged with ``task_done()``. Items
    whose future is not yet done get ``RPCError(error_code, message)`` set
    as their exception.

    Args:
        error_code: Error code for the ``RPCError`` set on each future.
        message: Error message for the ``RPCError`` set on each future.

    Returns:
        Number of futures that were failed. Items whose future was already
        done are drained but not counted. ``0`` when no send queue exists.
    """
    pending = self._send_queue
    if pending is None:
        return 0

    rejected = 0
    # No await in this loop, so empty() and get_nowait() cannot disagree.
    while not pending.empty():
        _, _, waiter = pending.get_nowait()
        if not waiter.done():
            waiter.set_exception(RPCError(error_code, message))
            rejected += 1
        pending.task_done()
    return rejected