feat: 实现 RPC 服务器的发送队列和背压控制机制
This commit is contained in:
@@ -1372,6 +1372,53 @@ class TestRPCServer:
|
|||||||
finally:
|
finally:
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_queue_backpressure_is_enforced(self):
|
||||||
|
from src.plugin_runtime.host.rpc_server import RPCServer
|
||||||
|
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||||
|
|
||||||
|
class DummyTransport:
|
||||||
|
async def start(self, handler):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_address(self):
|
||||||
|
return "dummy"
|
||||||
|
|
||||||
|
class BlockingConnection:
|
||||||
|
def __init__(self):
|
||||||
|
self.is_closed = False
|
||||||
|
self.release = asyncio.Event()
|
||||||
|
|
||||||
|
async def send_frame(self, data):
|
||||||
|
await self.release.wait()
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
self.is_closed = True
|
||||||
|
|
||||||
|
server = RPCServer(transport=DummyTransport(), send_queue_size=1)
|
||||||
|
await server.start()
|
||||||
|
|
||||||
|
conn = BlockingConnection()
|
||||||
|
server._connection = conn
|
||||||
|
server._runner_generation = 1
|
||||||
|
|
||||||
|
first_send = asyncio.create_task(server.send_event("runner.log_batch"))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
second_send = asyncio.create_task(server.send_event("runner.log_batch"))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
with pytest.raises(RPCError) as exc_info:
|
||||||
|
await server.send_event("runner.log_batch")
|
||||||
|
|
||||||
|
assert exc_info.value.code == ErrorCode.E_BACKPRESSURE
|
||||||
|
|
||||||
|
conn.release.set()
|
||||||
|
await asyncio.gather(first_send, second_send)
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
|
|
||||||
class TestRPCClient:
|
class TestRPCClient:
|
||||||
"""Runner RPCClient 后台任务生命周期测试"""
|
"""Runner RPCClient 后台任务生命周期测试"""
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -67,7 +68,8 @@ class RPCServer:
|
|||||||
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
|
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
|
||||||
|
|
||||||
# 发送队列(背压控制)
|
# 发送队列(背压控制)
|
||||||
self._send_queue: Optional[asyncio.Queue[bytes]] = None
|
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
|
||||||
|
self._send_worker_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
# 运行状态
|
# 运行状态
|
||||||
self._running: bool = False
|
self._running: bool = False
|
||||||
@@ -164,6 +166,7 @@ class RPCServer:
|
|||||||
"""启动 RPC 服务器"""
|
"""启动 RPC 服务器"""
|
||||||
self._running = True
|
self._running = True
|
||||||
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
|
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
|
||||||
|
self._send_worker_task = asyncio.create_task(self._send_loop())
|
||||||
await self._transport.start(self._handle_connection)
|
await self._transport.start(self._handle_connection)
|
||||||
logger.info(f"RPC Server 已启动,监听地址: {self._transport.get_address()}")
|
logger.info(f"RPC Server 已启动,监听地址: {self._transport.get_address()}")
|
||||||
|
|
||||||
@@ -177,6 +180,14 @@ class RPCServer:
|
|||||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||||
self._pending_requests.clear()
|
self._pending_requests.clear()
|
||||||
|
|
||||||
|
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
|
||||||
|
|
||||||
|
if self._send_worker_task:
|
||||||
|
self._send_worker_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._send_worker_task
|
||||||
|
self._send_worker_task = None
|
||||||
|
|
||||||
# 取消后台任务
|
# 取消后台任务
|
||||||
for task in self._tasks:
|
for task in self._tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
@@ -232,10 +243,6 @@ class RPCServer:
|
|||||||
payload=payload or {},
|
payload=payload or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
# 背压检查
|
|
||||||
if self._send_queue and self._send_queue.full():
|
|
||||||
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满")
|
|
||||||
|
|
||||||
# 注册 pending future
|
# 注册 pending future
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future: asyncio.Future[Envelope] = loop.create_future()
|
future: asyncio.Future[Envelope] = loop.create_future()
|
||||||
@@ -244,7 +251,7 @@ class RPCServer:
|
|||||||
try:
|
try:
|
||||||
# 发送请求
|
# 发送请求
|
||||||
data = self._codec.encode_envelope(envelope)
|
data = self._codec.encode_envelope(envelope)
|
||||||
await conn.send_frame(data)
|
await self._enqueue_send(conn, data)
|
||||||
|
|
||||||
# 等待响应
|
# 等待响应
|
||||||
timeout_sec = timeout_ms / 1000.0
|
timeout_sec = timeout_ms / 1000.0
|
||||||
@@ -273,7 +280,7 @@ class RPCServer:
|
|||||||
payload=payload or {},
|
payload=payload or {},
|
||||||
)
|
)
|
||||||
data = self._codec.encode_envelope(envelope)
|
data = self._codec.encode_envelope(envelope)
|
||||||
await self._connection.send_frame(data)
|
await self._enqueue_send(self._connection, data)
|
||||||
|
|
||||||
# ─── 内部方法 ──────────────────────────────────────────────
|
# ─── 内部方法 ──────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -464,6 +471,59 @@ class RPCServer:
|
|||||||
else:
|
else:
|
||||||
future.set_result(envelope)
|
future.set_result(envelope)
|
||||||
|
|
||||||
|
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
|
||||||
|
"""通过发送队列串行发送消息,提供真实背压。"""
|
||||||
|
if conn.is_closed:
|
||||||
|
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||||
|
|
||||||
|
if self._send_queue is None:
|
||||||
|
await conn.send_frame(data)
|
||||||
|
return
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
send_future: asyncio.Future[None] = loop.create_future()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._send_queue.put_nowait((conn, data, send_future))
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None
|
||||||
|
|
||||||
|
await send_future
|
||||||
|
|
||||||
|
async def _send_loop(self) -> None:
|
||||||
|
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
|
||||||
|
if self._send_queue is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
conn, data, send_future = await self._send_queue.get()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
if conn.is_closed:
|
||||||
|
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||||
|
await conn.send_frame(data)
|
||||||
|
if not send_future.done():
|
||||||
|
send_future.set_result(None)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
if not send_future.done():
|
||||||
|
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
|
||||||
|
if not send_future.done():
|
||||||
|
send_future.set_exception(send_error)
|
||||||
|
finally:
|
||||||
|
self._send_queue.task_done()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_send_exception(error: Exception) -> RPCError:
|
||||||
|
if isinstance(error, ConnectionError):
|
||||||
|
return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error))
|
||||||
|
return RPCError(ErrorCode.E_UNKNOWN, str(error))
|
||||||
|
|
||||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||||
handler = self._method_handlers.get(envelope.method)
|
handler = self._method_handlers.get(envelope.method)
|
||||||
@@ -530,3 +590,21 @@ class RPCServer:
|
|||||||
stale_count += 1
|
stale_count += 1
|
||||||
self._pending_requests.pop(request_id, None)
|
self._pending_requests.pop(request_id, None)
|
||||||
return stale_count
|
return stale_count
|
||||||
|
|
||||||
|
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
|
||||||
|
if self._send_queue is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
failed_count = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
_conn, _data, send_future = self._send_queue.get_nowait()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not send_future.done():
|
||||||
|
send_future.set_exception(RPCError(error_code, message))
|
||||||
|
failed_count += 1
|
||||||
|
self._send_queue.task_done()
|
||||||
|
|
||||||
|
return failed_count
|
||||||
|
|||||||
Reference in New Issue
Block a user