refactor: 移除generation;添加新的ErrorCode;修改ErrorCode的一个名称
This commit is contained in:
committed by
DrSmoothl
parent
49b620219d
commit
84a6524bd9
@@ -40,6 +40,7 @@ class AuthorizationManager:
|
||||
self._permission_tokens.clear()
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
"""检查插件是否有权调用某项能力
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -7,11 +7,7 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
|
||||
from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
CapabilityRequestPayload,
|
||||
CapabilityResponsePayload,
|
||||
Envelope,
|
||||
)
|
||||
from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -59,31 +55,19 @@ class CapabilityService:
|
||||
try:
|
||||
req = CapabilityRequestPayload.model_validate(envelope.payload)
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_BAD_PAYLOAD.value,
|
||||
f"能力调用 payload 格式错误: {e}",
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 格式错误: {e}")
|
||||
|
||||
capability = req.capability
|
||||
|
||||
# 1. 权限校验
|
||||
allowed, reason = self._authorization.check_capability(plugin_id, capability)
|
||||
if not allowed:
|
||||
error_code = (
|
||||
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
|
||||
)
|
||||
return envelope.make_error_response(
|
||||
error_code.value,
|
||||
reason,
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason)
|
||||
|
||||
# 2. 查找实现
|
||||
impl = self._implementations.get(capability)
|
||||
if impl is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的能力: {capability}",
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}")
|
||||
|
||||
# 3. 执行
|
||||
try:
|
||||
@@ -94,10 +78,7 @@ class CapabilityService:
|
||||
return envelope.make_error_response(e.code.value, e.message, e.details)
|
||||
except Exception as e:
|
||||
logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_CAPABILITY_FAILED.value,
|
||||
str(e),
|
||||
)
|
||||
return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e))
|
||||
|
||||
def list_capabilities(self) -> List[str]:
|
||||
"""列出所有已注册的能力"""
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 请求-响应关联与超时管理
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer
|
||||
logger = get_logger("plugin_runtime.host.rpc_server")
|
||||
|
||||
# RPC 方法处理器类型
|
||||
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
|
||||
MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]]
|
||||
|
||||
|
||||
class RPCServer:
|
||||
@@ -55,109 +55,29 @@ class RPCServer:
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: Optional[str] = None
|
||||
self._runner_generation: int = 0
|
||||
self._staged_connection: Optional[Connection] = None
|
||||
self._staged_runner_id: Optional[str] = None
|
||||
self._staged_runner_generation: int = 0
|
||||
self._staging_takeover: bool = False
|
||||
|
||||
# 方法处理器注册表
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
|
||||
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
|
||||
|
||||
# 发送队列(背压控制)
|
||||
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
|
||||
self._send_worker_task: Optional[asyncio.Task] = None
|
||||
self._send_worker_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._tasks: List[asyncio.Task] = []
|
||||
self._tasks: List[asyncio.Task[None]] = []
|
||||
|
||||
@property
|
||||
def session_token(self) -> str:
|
||||
return self._session_token
|
||||
|
||||
def reset_session_token(self) -> str:
|
||||
"""重新生成会话令牌(热重载时调用,防止旧 Runner 重连)"""
|
||||
self._session_token = secrets.token_hex(32)
|
||||
return self._session_token
|
||||
|
||||
def restore_session_token(self, token: str) -> None:
|
||||
"""恢复指定的会话令牌(热重载回滚时调用)"""
|
||||
self._session_token = token
|
||||
|
||||
@property
|
||||
def runner_generation(self) -> int:
|
||||
return self._runner_generation
|
||||
|
||||
@property
|
||||
def staged_generation(self) -> int:
|
||||
return self._staged_runner_generation
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
def has_generation(self, generation: int) -> bool:
|
||||
return generation == self._runner_generation or (
|
||||
self._staged_connection is not None
|
||||
and not self._staged_connection.is_closed
|
||||
and generation == self._staged_runner_generation
|
||||
)
|
||||
|
||||
def begin_staged_takeover(self) -> None:
|
||||
"""允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。"""
|
||||
self._staging_takeover = True
|
||||
|
||||
async def commit_staged_takeover(self) -> None:
|
||||
"""提交 staged Runner,原活跃连接在提交后被关闭。"""
|
||||
if self._staged_connection is None or self._staged_connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
|
||||
|
||||
old_connection = self._connection
|
||||
old_generation = self._runner_generation
|
||||
|
||||
self._connection = self._staged_connection
|
||||
self._runner_id = self._staged_runner_id
|
||||
self._runner_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
if stale_count := self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=old_generation,
|
||||
):
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
|
||||
if old_connection and old_connection is not self._connection and not old_connection.is_closed:
|
||||
await old_connection.close()
|
||||
|
||||
async def rollback_staged_takeover(self) -> None:
|
||||
"""放弃 staged Runner,保留当前活跃连接。"""
|
||||
staged_connection = self._staged_connection
|
||||
staged_generation = self._staged_runner_generation
|
||||
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._staging_takeover = False
|
||||
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"新 Runner 预热失败,已回滚",
|
||||
generation=staged_generation,
|
||||
)
|
||||
|
||||
if staged_connection and not staged_connection.is_closed:
|
||||
await staged_connection.close()
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册 RPC 方法处理器"""
|
||||
self._method_handlers[method] = handler
|
||||
@@ -173,14 +93,8 @@ class RPCServer:
|
||||
async def stop(self) -> None:
|
||||
"""停止 RPC 服务器"""
|
||||
self._running = False
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future, _generation in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
self._pending_requests.clear()
|
||||
|
||||
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
|
||||
self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
|
||||
self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
|
||||
|
||||
if self._send_worker_task:
|
||||
self._send_worker_task.cancel()
|
||||
@@ -198,10 +112,6 @@ class RPCServer:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
if self._staged_connection:
|
||||
await self._staged_connection.close()
|
||||
self._staged_connection = None
|
||||
|
||||
await self._transport.stop()
|
||||
logger.info("RPC Server 已停止")
|
||||
|
||||
@@ -211,7 +121,6 @@ class RPCServer:
|
||||
plugin_id: str = "",
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
target_generation: Optional[int] = None,
|
||||
) -> Envelope:
|
||||
"""向 Runner 发送 RPC 请求并等待响应
|
||||
|
||||
@@ -227,18 +136,14 @@ class RPCServer:
|
||||
Raises:
|
||||
RPCError: 调用失败
|
||||
"""
|
||||
generation = target_generation or self._runner_generation
|
||||
conn = self._get_connection_for_generation(generation)
|
||||
if conn is None or conn.is_closed:
|
||||
if not self._connection or self._connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
request_id = await self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=generation,
|
||||
timeout_ms=timeout_ms,
|
||||
payload=payload or {},
|
||||
)
|
||||
@@ -246,12 +151,12 @@ class RPCServer:
|
||||
# 注册 pending future
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[Envelope] = loop.create_future()
|
||||
self._pending_requests[request_id] = (future, generation)
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._enqueue_send(conn, data)
|
||||
await self._enqueue_send(self._connection, data)
|
||||
|
||||
# 等待响应
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
@@ -265,93 +170,66 @@ class RPCServer:
|
||||
raise
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||||
|
||||
async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""向 Runner 发送单向事件(不等待响应)"""
|
||||
conn = self._connection
|
||||
if conn is None or conn.is_closed:
|
||||
return
|
||||
# ============ 内部方法 ============
|
||||
# ========= 发送循环 =========
|
||||
async def _send_loop(self) -> None:
|
||||
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
|
||||
if self._send_queue is None:
|
||||
raise RuntimeError("没有消息队列")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.EVENT,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._runner_generation,
|
||||
payload=payload or {},
|
||||
)
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._enqueue_send(conn, data)
|
||||
while True:
|
||||
try:
|
||||
conn, data, send_future = await self._send_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
try:
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
await conn.send_frame(data)
|
||||
if not send_future.done():
|
||||
send_future.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
if not send_future.done():
|
||||
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
raise
|
||||
except Exception as e:
|
||||
send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED})
|
||||
if not send_future.done():
|
||||
send_future.set_exception(send_error)
|
||||
finally:
|
||||
self._send_queue.task_done()
|
||||
|
||||
# ====== 发送循环方法 ======
|
||||
async def _handle_connection(self, conn: Connection) -> None:
|
||||
"""处理新的 Runner 连接"""
|
||||
logger.info("收到 Runner 连接")
|
||||
previous_connection = self._connection
|
||||
previous_generation = self._runner_generation
|
||||
|
||||
# 第一条消息必须是 runner.hello 握手
|
||||
try:
|
||||
role = await self._handle_handshake(conn)
|
||||
if role is None:
|
||||
success = await self._handle_handshake(conn)
|
||||
if not success:
|
||||
await conn.close()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"握手失败: {e}")
|
||||
await conn.close()
|
||||
return
|
||||
|
||||
if role == "staged":
|
||||
expected_generation = self._staged_runner_generation
|
||||
logger.info(
|
||||
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
|
||||
)
|
||||
else:
|
||||
self._connection = conn
|
||||
expected_generation = self._runner_generation
|
||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||
|
||||
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
|
||||
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
|
||||
if stale_count := self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已被新 generation 接管",
|
||||
generation=previous_generation,
|
||||
):
|
||||
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
|
||||
await previous_connection.close()
|
||||
|
||||
logger.info("Runner staged 握手成功")
|
||||
self._connection = conn
|
||||
# 启动消息接收循环
|
||||
try:
|
||||
await self._recv_loop(conn, expected_generation=expected_generation)
|
||||
await self._recv_loop(conn)
|
||||
except Exception as e:
|
||||
logger.error(f"连接异常断开: {e}")
|
||||
finally:
|
||||
if self._connection is conn:
|
||||
self._connection = None
|
||||
self._runner_id = None
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
elif self._staged_connection is conn:
|
||||
self._staged_connection = None
|
||||
self._staged_runner_id = None
|
||||
self._staged_runner_generation = 0
|
||||
self._fail_pending_requests(
|
||||
ErrorCode.E_PLUGIN_CRASHED,
|
||||
"Staged Runner 连接已断开",
|
||||
generation=expected_generation,
|
||||
)
|
||||
self._connection = None
|
||||
self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
|
||||
|
||||
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
|
||||
async def _handle_handshake(self, conn: Connection) -> bool:
|
||||
"""处理 runner.hello 握手"""
|
||||
# 接收握手请求
|
||||
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
|
||||
envelope = self._codec.decode_envelope(data)
|
||||
|
||||
if envelope.method != "runner.hello":
|
||||
logger.error(f"期望 runner.hello,收到 {envelope.method}")
|
||||
error_resp = envelope.make_error_response(
|
||||
@@ -359,21 +237,17 @@ class RPCServer:
|
||||
"首条消息必须为 runner.hello",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
return None
|
||||
return False
|
||||
|
||||
# 解析握手 payload
|
||||
hello = HelloPayload.model_validate(envelope.payload)
|
||||
|
||||
# 校验会话令牌
|
||||
if hello.session_token != self._session_token:
|
||||
logger.error("会话令牌不匹配")
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=False,
|
||||
reason="会话令牌无效",
|
||||
)
|
||||
resp_payload = HelloResponsePayload(accepted=False, reason="会话令牌无效")
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return None
|
||||
return False
|
||||
|
||||
# 校验 SDK 版本
|
||||
if not self._check_sdk_version(hello.sdk_version):
|
||||
@@ -384,31 +258,26 @@ class RPCServer:
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return None
|
||||
return False
|
||||
|
||||
# 握手成功
|
||||
role = "active"
|
||||
assigned_generation = self._runner_generation + 1
|
||||
if self._staging_takeover and self.is_connected:
|
||||
role = "staged"
|
||||
self._staged_connection = conn
|
||||
self._staged_runner_id = hello.runner_id
|
||||
self._staged_runner_generation = assigned_generation
|
||||
else:
|
||||
self._runner_id = hello.runner_id
|
||||
self._runner_generation = assigned_generation
|
||||
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=True,
|
||||
host_version=PROTOCOL_VERSION,
|
||||
assigned_generation=assigned_generation,
|
||||
)
|
||||
# 发送响应
|
||||
resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return True
|
||||
|
||||
return role
|
||||
def _check_sdk_version(self, sdk_version: str) -> bool:
|
||||
"""检查 SDK 版本是否在支持范围内"""
|
||||
try:
|
||||
sdk_parts = _parse_version_tuple(sdk_version)
|
||||
min_parts = _parse_version_tuple(MIN_SDK_VERSION)
|
||||
max_parts = _parse_version_tuple(MAX_SDK_VERSION)
|
||||
return min_parts <= sdk_parts <= max_parts
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
|
||||
# ========= 接收循环 =========
|
||||
async def _recv_loop(self, conn: Connection) -> None:
|
||||
"""消息接收主循环"""
|
||||
while self._running and not conn.is_closed:
|
||||
try:
|
||||
@@ -430,109 +299,40 @@ class RPCServer:
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
if envelope.generation != expected_generation:
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_GENERATION_MISMATCH.value,
|
||||
f"过期 generation: {envelope.generation} != {expected_generation}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
continue
|
||||
# 异步处理请求(Runner 发来的能力调用)
|
||||
task = asyncio.create_task(self._handle_request(envelope, conn))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
elif envelope.is_event():
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
continue
|
||||
task = asyncio.create_task(self._handle_event(envelope))
|
||||
elif envelope.is_broadcast():
|
||||
task = asyncio.create_task(self._handle_broadcast(envelope))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
else:
|
||||
logger.warning(f"未知的消息类型: {envelope.message_type}")
|
||||
continue
|
||||
|
||||
# ====== 接收循环内部方法 ======
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的响应"""
|
||||
pending = self._pending_requests.get(envelope.request_id)
|
||||
if pending is None:
|
||||
pending_future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if pending_future is None:
|
||||
return
|
||||
|
||||
future, expected_generation = pending
|
||||
if envelope.generation != expected_generation:
|
||||
logger.warning(
|
||||
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
|
||||
)
|
||||
return
|
||||
|
||||
self._pending_requests.pop(envelope.request_id, None)
|
||||
if not future.done():
|
||||
if not pending_future.done():
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
pending_future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
|
||||
"""通过发送队列串行发送消息,提供真实背压。"""
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
if self._send_queue is None:
|
||||
await conn.send_frame(data)
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
send_future: asyncio.Future[None] = loop.create_future()
|
||||
|
||||
try:
|
||||
self._send_queue.put_nowait((conn, data, send_future))
|
||||
except asyncio.QueueFull:
|
||||
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None
|
||||
|
||||
await send_future
|
||||
|
||||
async def _send_loop(self) -> None:
|
||||
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
|
||||
if self._send_queue is None:
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
conn, data, send_future = await self._send_queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
try:
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
await conn.send_frame(data)
|
||||
if not send_future.done():
|
||||
send_future.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
if not send_future.done():
|
||||
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
raise
|
||||
except Exception as e:
|
||||
send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
|
||||
if not send_future.done():
|
||||
send_future.set_exception(send_error)
|
||||
finally:
|
||||
self._send_queue.task_done()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_send_exception(error: Exception) -> RPCError:
|
||||
if isinstance(error, ConnectionError):
|
||||
return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error))
|
||||
return RPCError(ErrorCode.E_UNKNOWN, str(error))
|
||||
pending_future.set_result(envelope)
|
||||
|
||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
error_resp = envelope.make_error_response(
|
||||
target_method = envelope.method
|
||||
handler = self._method_handlers.get(target_method)
|
||||
if not handler:
|
||||
error_response = envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的方法: {envelope.method}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
await conn.send_frame(self._codec.encode_envelope(error_response))
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -546,59 +346,25 @@ class RPCServer:
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
|
||||
async def _handle_event(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的事件"""
|
||||
async def _handle_broadcast(self, envelope: Envelope) -> None:
|
||||
if handler := self._method_handlers.get(envelope.method):
|
||||
try:
|
||||
result = await handler(envelope)
|
||||
# 检查 handler 返回的信封是否包含错误信息
|
||||
if result is not None and isinstance(result, Envelope) and result.error:
|
||||
if result.error:
|
||||
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _check_sdk_version(sdk_version: str) -> bool:
|
||||
"""检查 SDK 版本是否在支持范围内"""
|
||||
try:
|
||||
sdk_parts = RPCServer._parse_version_tuple(sdk_version)
|
||||
min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION)
|
||||
max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION)
|
||||
return min_parts <= sdk_parts <= max_parts
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
|
||||
base_version = base_version.split("+", 1)[0]
|
||||
parts = [part for part in base_version.split(".") if part != ""]
|
||||
while len(parts) < 3:
|
||||
parts.append("0")
|
||||
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
||||
|
||||
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
|
||||
if generation == self._runner_generation:
|
||||
return self._connection
|
||||
if generation == self._staged_runner_generation:
|
||||
return self._staged_connection
|
||||
return None
|
||||
|
||||
def _fail_pending_requests(
|
||||
self,
|
||||
error_code: ErrorCode,
|
||||
message: str,
|
||||
generation: Optional[int] = None,
|
||||
) -> int:
|
||||
stale_count = 0
|
||||
for request_id, (future, request_generation) in list(self._pending_requests.items()):
|
||||
if generation is not None and request_generation != generation:
|
||||
continue
|
||||
def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int:
|
||||
"""失败所有等待中的请求(如连接断开时)"""
|
||||
aborted_request_count = 0
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(error_code, message))
|
||||
stale_count += 1
|
||||
self._pending_requests.pop(request_id, None)
|
||||
return stale_count
|
||||
aborted_request_count += 1
|
||||
self._pending_requests.clear()
|
||||
return aborted_request_count
|
||||
|
||||
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
|
||||
if self._send_queue is None:
|
||||
@@ -617,3 +383,31 @@ class RPCServer:
|
||||
self._send_queue.task_done()
|
||||
|
||||
return failed_count
|
||||
|
||||
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
|
||||
"""通过发送队列串行发送消息,提供真实背压。"""
|
||||
if conn.is_closed:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
if self._send_queue is None:
|
||||
await conn.send_frame(data)
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
send_future: asyncio.Future[None] = loop.create_future()
|
||||
|
||||
try:
|
||||
self._send_queue.put_nowait((conn, data, send_future))
|
||||
except asyncio.QueueFull:
|
||||
raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None
|
||||
|
||||
await send_future
|
||||
|
||||
|
||||
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
|
||||
base_version = base_version.split("+", 1)[0]
|
||||
parts = [part for part in base_version.split(".") if part != ""]
|
||||
while len(parts) < 3:
|
||||
parts.append("0")
|
||||
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
||||
|
||||
Reference in New Issue
Block a user