feat: 实现 RPC 服务器的发送队列和背压控制机制
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import secrets
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -67,7 +68,8 @@ class RPCServer:
|
||||
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
|
||||
|
||||
# 发送队列(背压控制)
|
||||
self._send_queue: Optional[asyncio.Queue[bytes]] = None
|
||||
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
|
||||
self._send_worker_task: Optional[asyncio.Task] = None
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
@@ -164,6 +166,7 @@ class RPCServer:
|
||||
"""启动 RPC 服务器"""
|
||||
self._running = True
|
||||
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
|
||||
self._send_worker_task = asyncio.create_task(self._send_loop())
|
||||
await self._transport.start(self._handle_connection)
|
||||
logger.info(f"RPC Server 已启动,监听地址: {self._transport.get_address()}")
|
||||
|
||||
@@ -177,6 +180,14 @@ class RPCServer:
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
self._pending_requests.clear()
|
||||
|
||||
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
|
||||
|
||||
if self._send_worker_task:
|
||||
self._send_worker_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._send_worker_task
|
||||
self._send_worker_task = None
|
||||
|
||||
# 取消后台任务
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
@@ -232,10 +243,6 @@ class RPCServer:
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
# 背压检查
|
||||
if self._send_queue and self._send_queue.full():
|
||||
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满")
|
||||
|
||||
# 注册 pending future
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[Envelope] = loop.create_future()
|
||||
@@ -244,7 +251,7 @@ class RPCServer:
|
||||
try:
|
||||
# 发送请求
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await conn.send_frame(data)
|
||||
await self._enqueue_send(conn, data)
|
||||
|
||||
# 等待响应
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
@@ -273,7 +280,7 @@ class RPCServer:
|
||||
payload=payload or {},
|
||||
)
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
await self._enqueue_send(self._connection, data)
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
|
||||
@@ -464,6 +471,59 @@ class RPCServer:
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
    """Send one encoded frame through the serialized send queue.

    The frame is put on the bounded send queue without blocking, and the
    coroutine then waits until the send worker reports the outcome of the
    write. A full queue is reported immediately as backpressure instead of
    waiting for room.

    Args:
        conn: Connection the frame should be written to.
        data: Already-encoded frame bytes.

    Raises:
        RPCError: ``E_PLUGIN_CRASHED`` when the connection is closed,
            ``E_BACKPRESSURE`` when the send queue is full, or the error
            recorded for this frame by the send worker.
    """
    if conn.is_closed:
        raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")

    queue = self._send_queue
    worker = self._send_worker_task

    # The queued future is only ever resolved by the send worker. If there
    # is no queue, or the worker is gone (never started, cancelled by
    # stop(), or finished), queueing would leave the caller waiting
    # forever, so write the frame directly instead.
    if queue is None or worker is None or worker.done():
        await conn.send_frame(data)
        return

    loop = asyncio.get_running_loop()
    send_future: asyncio.Future[None] = loop.create_future()

    try:
        queue.put_nowait((conn, data, send_future))
    except asyncio.QueueFull:
        # Fail fast: a full queue is the backpressure signal for callers.
        raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None

    # Resolved (or failed) by _send_loop / _fail_queued_sends.
    await send_future
|
||||
|
||||
async def _send_loop(self) -> None:
    """Background worker that drains the send queue one frame at a time.

    Each queue entry is ``(conn, data, send_future)``. The worker writes
    ``data`` to ``conn`` and reports the outcome through ``send_future``:
    ``None`` on success, an ``RPCError`` on failure. The loop ends when the
    task is cancelled; a frame that is in flight at that moment has its
    future failed with ``E_TIMEOUT`` before the cancellation propagates.
    """
    # Bind the queue once so task_done() always targets the queue the item
    # was taken from, even if the attribute is reassigned later.
    queue = self._send_queue
    if queue is None:
        return

    while True:
        try:
            conn, data, send_future = await queue.get()
        except asyncio.CancelledError:
            # Cancelled while idle: nothing is in flight, just stop.
            break

        try:
            if send_future.done():
                # The caller already gave up on this frame (cancelled or
                # timed out) or it was failed elsewhere; do not write a
                # frame nobody is waiting for.
                continue

            if conn.is_closed:
                raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")

            await conn.send_frame(data)

            # The caller may have been cancelled while the write was in
            # progress, so re-check before resolving.
            if not send_future.done():
                send_future.set_result(None)
        except asyncio.CancelledError:
            # Worker cancelled mid-write: unblock the waiting caller, then
            # let the cancellation propagate.
            if not send_future.done():
                send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
            raise
        except Exception as e:
            # Report every failure as an RPCError; the loop keeps running
            # so one bad write does not stall the remaining frames.
            send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
            if not send_future.done():
                send_future.set_exception(send_error)
        finally:
            # Runs for the skip, success, failure and cancellation paths.
            queue.task_done()
|
||||
|
||||
@staticmethod
def _normalize_send_exception(error: Exception) -> RPCError:
    """Wrap an arbitrary send failure in an ``RPCError``.

    Args:
        error: Exception raised while writing a frame.

    Returns:
        An ``RPCError`` carrying ``str(error)`` as its message, with code
        ``E_PLUGIN_CRASHED`` for ``ConnectionError`` instances and
        ``E_UNKNOWN`` for everything else.
    """
    code = (
        ErrorCode.E_PLUGIN_CRASHED
        if isinstance(error, ConnectionError)
        else ErrorCode.E_UNKNOWN
    )
    return RPCError(code, str(error))
|
||||
|
||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
@@ -530,3 +590,21 @@ class RPCServer:
|
||||
stale_count += 1
|
||||
self._pending_requests.pop(request_id, None)
|
||||
return stale_count
|
||||
|
||||
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
    """Drain the send queue and fail every send that is still waiting.

    Args:
        error_code: Error code for the ``RPCError`` set on each pending
            send future.
        message: Message for that ``RPCError``.

    Returns:
        Number of futures that were failed here. Entries whose future was
        already done are removed from the queue but not counted. Returns 0
        when there is no send queue.
    """
    pending = self._send_queue
    if pending is None:
        return 0

    failed = 0
    # No await inside the loop, so the queue cannot change between the
    # empty() check and get_nowait().
    while not pending.empty():
        _, _, waiter = pending.get_nowait()
        if not waiter.done():
            waiter.set_exception(RPCError(error_code, message))
            failed += 1
        pending.task_done()

    return failed
|
||||
|
||||
Reference in New Issue
Block a user