refactor: 移除generation;添加新的ErrorCode;修改ErrorCode的一个名称

This commit is contained in:
UnCLAS-Prommer
2026-03-17 20:00:19 +08:00
committed by DrSmoothl
parent 49b620219d
commit 84a6524bd9
4 changed files with 138 additions and 352 deletions

View File

@@ -40,6 +40,7 @@ class AuthorizationManager:
self._permission_tokens.clear() self._permission_tokens.clear()
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]: def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
# sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression
"""检查插件是否有权调用某项能力 """检查插件是否有权调用某项能力
Returns: Returns:

View File

@@ -7,11 +7,7 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_runtime.protocol.envelope import ( from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope
CapabilityRequestPayload,
CapabilityResponsePayload,
Envelope,
)
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -59,31 +55,19 @@ class CapabilityService:
try: try:
req = CapabilityRequestPayload.model_validate(envelope.payload) req = CapabilityRequestPayload.model_validate(envelope.payload)
except Exception as e: except Exception as e:
return envelope.make_error_response( return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 格式错误: {e}")
ErrorCode.E_BAD_PAYLOAD.value,
f"能力调用 payload 格式错误: {e}",
)
capability = req.capability capability = req.capability
# 1. 权限校验 # 1. 权限校验
allowed, reason = self._authorization.check_capability(plugin_id, capability) allowed, reason = self._authorization.check_capability(plugin_id, capability)
if not allowed: if not allowed:
error_code = ( return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason)
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
)
return envelope.make_error_response(
error_code.value,
reason,
)
# 2. 查找实现 # 2. 查找实现
impl = self._implementations.get(capability) impl = self._implementations.get(capability)
if impl is None: if impl is None:
return envelope.make_error_response( return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}")
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的能力: {capability}",
)
# 3. 执行 # 3. 执行
try: try:
@@ -94,10 +78,7 @@ class CapabilityService:
return envelope.make_error_response(e.code.value, e.message, e.details) return envelope.make_error_response(e.code.value, e.message, e.details)
except Exception as e: except Exception as e:
logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True) logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
return envelope.make_error_response( return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e))
ErrorCode.E_CAPABILITY_FAILED.value,
str(e),
)
def list_capabilities(self) -> List[str]: def list_capabilities(self) -> List[str]:
"""列出所有已注册的能力""" """列出所有已注册的能力"""

View File

@@ -7,7 +7,7 @@
4. 请求-响应关联与超时管理 4. 请求-响应关联与超时管理
""" """
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine
import asyncio import asyncio
import contextlib import contextlib
@@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer
logger = get_logger("plugin_runtime.host.rpc_server") logger = get_logger("plugin_runtime.host.rpc_server")
# RPC 方法处理器类型 # RPC 方法处理器类型
MethodHandler = Callable[[Envelope], Awaitable[Envelope]] MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]]
class RPCServer: class RPCServer:
@@ -55,109 +55,29 @@ class RPCServer:
self._id_gen = RequestIdGenerator() self._id_gen = RequestIdGenerator()
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接 self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
self._runner_id: Optional[str] = None
self._runner_generation: int = 0
self._staged_connection: Optional[Connection] = None
self._staged_runner_id: Optional[str] = None
self._staged_runner_generation: int = 0
self._staging_takeover: bool = False
# 方法处理器注册表 # 方法处理器注册表
self._method_handlers: Dict[str, MethodHandler] = {} self._method_handlers: Dict[str, MethodHandler] = {}
# 等待响应的 pending 请求: request_id -> (Future, target_generation) # 等待响应的 pending 请求: request_id -> Future
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {} self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
# 发送队列(背压控制) # 发送队列(背压控制)
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
self._send_worker_task: Optional[asyncio.Task] = None self._send_worker_task: Optional[asyncio.Task[None]] = None
# 运行状态 # 运行状态
self._running: bool = False self._running: bool = False
self._tasks: List[asyncio.Task] = [] self._tasks: List[asyncio.Task[None]] = []
@property @property
def session_token(self) -> str: def session_token(self) -> str:
return self._session_token return self._session_token
def reset_session_token(self) -> str:
"""重新生成会话令牌(热重载时调用,防止旧 Runner 重连)"""
self._session_token = secrets.token_hex(32)
return self._session_token
def restore_session_token(self, token: str) -> None:
"""恢复指定的会话令牌(热重载回滚时调用)"""
self._session_token = token
@property
def runner_generation(self) -> int:
return self._runner_generation
@property
def staged_generation(self) -> int:
return self._staged_runner_generation
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self._connection is not None and not self._connection.is_closed return self._connection is not None and not self._connection.is_closed
def has_generation(self, generation: int) -> bool:
return generation == self._runner_generation or (
self._staged_connection is not None
and not self._staged_connection.is_closed
and generation == self._staged_runner_generation
)
def begin_staged_takeover(self) -> None:
"""允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。"""
self._staging_takeover = True
async def commit_staged_takeover(self) -> None:
"""提交 staged Runner原活跃连接在提交后被关闭。"""
if self._staged_connection is None or self._staged_connection.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
old_connection = self._connection
old_generation = self._runner_generation
self._connection = self._staged_connection
self._runner_id = self._staged_runner_id
self._runner_generation = self._staged_runner_generation
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._staging_takeover = False
if stale_count := self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已被新 generation 接管",
generation=old_generation,
):
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
if old_connection and old_connection is not self._connection and not old_connection.is_closed:
await old_connection.close()
async def rollback_staged_takeover(self) -> None:
"""放弃 staged Runner保留当前活跃连接。"""
staged_connection = self._staged_connection
staged_generation = self._staged_runner_generation
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._staging_takeover = False
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"新 Runner 预热失败,已回滚",
generation=staged_generation,
)
if staged_connection and not staged_connection.is_closed:
await staged_connection.close()
def register_method(self, method: str, handler: MethodHandler) -> None: def register_method(self, method: str, handler: MethodHandler) -> None:
"""注册 RPC 方法处理器""" """注册 RPC 方法处理器"""
self._method_handlers[method] = handler self._method_handlers[method] = handler
@@ -173,14 +93,8 @@ class RPCServer:
async def stop(self) -> None: async def stop(self) -> None:
"""停止 RPC 服务器""" """停止 RPC 服务器"""
self._running = False self._running = False
self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
# 取消所有 pending 请求 self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
for future, _generation in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
self._pending_requests.clear()
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
if self._send_worker_task: if self._send_worker_task:
self._send_worker_task.cancel() self._send_worker_task.cancel()
@@ -198,10 +112,6 @@ class RPCServer:
await self._connection.close() await self._connection.close()
self._connection = None self._connection = None
if self._staged_connection:
await self._staged_connection.close()
self._staged_connection = None
await self._transport.stop() await self._transport.stop()
logger.info("RPC Server 已停止") logger.info("RPC Server 已停止")
@@ -211,7 +121,6 @@ class RPCServer:
plugin_id: str = "", plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None, payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000, timeout_ms: int = 30000,
target_generation: Optional[int] = None,
) -> Envelope: ) -> Envelope:
"""向 Runner 发送 RPC 请求并等待响应 """向 Runner 发送 RPC 请求并等待响应
@@ -227,18 +136,14 @@ class RPCServer:
Raises: Raises:
RPCError: 调用失败 RPCError: 调用失败
""" """
generation = target_generation or self._runner_generation if not self._connection or self._connection.is_closed:
conn = self._get_connection_for_generation(generation)
if conn is None or conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
request_id = await self._id_gen.next()
request_id = self._id_gen.next()
envelope = Envelope( envelope = Envelope(
request_id=request_id, request_id=request_id,
message_type=MessageType.REQUEST, message_type=MessageType.REQUEST,
method=method, method=method,
plugin_id=plugin_id, plugin_id=plugin_id,
generation=generation,
timeout_ms=timeout_ms, timeout_ms=timeout_ms,
payload=payload or {}, payload=payload or {},
) )
@@ -246,12 +151,12 @@ class RPCServer:
# 注册 pending future # 注册 pending future
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future: asyncio.Future[Envelope] = loop.create_future() future: asyncio.Future[Envelope] = loop.create_future()
self._pending_requests[request_id] = (future, generation) self._pending_requests[request_id] = future
try: try:
# 发送请求 # 发送请求
data = self._codec.encode_envelope(envelope) data = self._codec.encode_envelope(envelope)
await self._enqueue_send(conn, data) await self._enqueue_send(self._connection, data)
# 等待响应 # 等待响应
timeout_sec = timeout_ms / 1000.0 timeout_sec = timeout_ms / 1000.0
@@ -265,93 +170,66 @@ class RPCServer:
raise raise
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None: # ============ 内部方法 ============
"""向 Runner 发送单向事件(不等待响应)""" # ========= 发送循环 =========
conn = self._connection async def _send_loop(self) -> None:
if conn is None or conn.is_closed: """后台发送循环:串行消费发送队列,统一执行连接写入。"""
return if self._send_queue is None:
raise RuntimeError("没有消息队列")
request_id = self._id_gen.next() while True:
envelope = Envelope( try:
request_id=request_id, conn, data, send_future = await self._send_queue.get()
message_type=MessageType.EVENT, except asyncio.CancelledError:
method=method, break
plugin_id=plugin_id,
generation=self._runner_generation,
payload=payload or {},
)
data = self._codec.encode_envelope(envelope)
await self._enqueue_send(conn, data)
# ─── 内部方法 ────────────────────────────────────────────── try:
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
await conn.send_frame(data)
if not send_future.done():
send_future.set_result(None)
except asyncio.CancelledError:
if not send_future.done():
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
raise
except Exception as e:
send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED})
if not send_future.done():
send_future.set_exception(send_error)
finally:
self._send_queue.task_done()
# ====== 发送循环方法 ======
async def _handle_connection(self, conn: Connection) -> None: async def _handle_connection(self, conn: Connection) -> None:
"""处理新的 Runner 连接""" """处理新的 Runner 连接"""
logger.info("收到 Runner 连接") logger.info("收到 Runner 连接")
previous_connection = self._connection
previous_generation = self._runner_generation
# 第一条消息必须是 runner.hello 握手 # 第一条消息必须是 runner.hello 握手
try: try:
role = await self._handle_handshake(conn) success = await self._handle_handshake(conn)
if role is None: if not success:
await conn.close() await conn.close()
return return
except Exception as e: except Exception as e:
logger.error(f"握手失败: {e}") logger.error(f"握手失败: {e}")
await conn.close() await conn.close()
return return
logger.info("Runner staged 握手成功")
if role == "staged": self._connection = conn
expected_generation = self._staged_runner_generation
logger.info(
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
)
else:
self._connection = conn
expected_generation = self._runner_generation
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
if stale_count := self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已被新 generation 接管",
generation=previous_generation,
):
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
await previous_connection.close()
# 启动消息接收循环 # 启动消息接收循环
try: try:
await self._recv_loop(conn, expected_generation=expected_generation) await self._recv_loop(conn)
except Exception as e: except Exception as e:
logger.error(f"连接异常断开: {e}") logger.error(f"连接异常断开: {e}")
finally: finally:
if self._connection is conn: self._connection = None
self._connection = None self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
self._runner_id = None
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已断开",
generation=expected_generation,
)
elif self._staged_connection is conn:
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Staged Runner 连接已断开",
generation=expected_generation,
)
async def _handle_handshake(self, conn: Connection) -> Optional[str]: async def _handle_handshake(self, conn: Connection) -> bool:
"""处理 runner.hello 握手""" """处理 runner.hello 握手"""
# 接收握手请求 # 接收握手请求
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0) data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
envelope = self._codec.decode_envelope(data) envelope = self._codec.decode_envelope(data)
if envelope.method != "runner.hello": if envelope.method != "runner.hello":
logger.error(f"期望 runner.hello收到 {envelope.method}") logger.error(f"期望 runner.hello收到 {envelope.method}")
error_resp = envelope.make_error_response( error_resp = envelope.make_error_response(
@@ -359,21 +237,17 @@ class RPCServer:
"首条消息必须为 runner.hello", "首条消息必须为 runner.hello",
) )
await conn.send_frame(self._codec.encode_envelope(error_resp)) await conn.send_frame(self._codec.encode_envelope(error_resp))
return None return False
# 解析握手 payload # 解析握手 payload
hello = HelloPayload.model_validate(envelope.payload) hello = HelloPayload.model_validate(envelope.payload)
# 校验会话令牌 # 校验会话令牌
if hello.session_token != self._session_token: if hello.session_token != self._session_token:
logger.error("会话令牌不匹配") logger.error("会话令牌不匹配")
resp_payload = HelloResponsePayload( resp_payload = HelloResponsePayload(accepted=False, reason="会话令牌无效")
accepted=False,
reason="会话令牌无效",
)
resp = envelope.make_response(payload=resp_payload.model_dump()) resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp)) await conn.send_frame(self._codec.encode_envelope(resp))
return None return False
# 校验 SDK 版本 # 校验 SDK 版本
if not self._check_sdk_version(hello.sdk_version): if not self._check_sdk_version(hello.sdk_version):
@@ -384,31 +258,26 @@ class RPCServer:
) )
resp = envelope.make_response(payload=resp_payload.model_dump()) resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp)) await conn.send_frame(self._codec.encode_envelope(resp))
return None return False
# 握手成功 # 发送响应
role = "active" resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
assigned_generation = self._runner_generation + 1
if self._staging_takeover and self.is_connected:
role = "staged"
self._staged_connection = conn
self._staged_runner_id = hello.runner_id
self._staged_runner_generation = assigned_generation
else:
self._runner_id = hello.runner_id
self._runner_generation = assigned_generation
resp_payload = HelloResponsePayload(
accepted=True,
host_version=PROTOCOL_VERSION,
assigned_generation=assigned_generation,
)
resp = envelope.make_response(payload=resp_payload.model_dump()) resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp)) await conn.send_frame(self._codec.encode_envelope(resp))
return True
return role def _check_sdk_version(self, sdk_version: str) -> bool:
"""检查 SDK 版本是否在支持范围内"""
try:
sdk_parts = _parse_version_tuple(sdk_version)
min_parts = _parse_version_tuple(MIN_SDK_VERSION)
max_parts = _parse_version_tuple(MAX_SDK_VERSION)
return min_parts <= sdk_parts <= max_parts
except (ValueError, AttributeError):
return False
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None: # ========= 接收循环 =========
async def _recv_loop(self, conn: Connection) -> None:
"""消息接收主循环""" """消息接收主循环"""
while self._running and not conn.is_closed: while self._running and not conn.is_closed:
try: try:
@@ -430,109 +299,40 @@ class RPCServer:
if envelope.is_response(): if envelope.is_response():
self._handle_response(envelope) self._handle_response(envelope)
elif envelope.is_request(): elif envelope.is_request():
if envelope.generation != expected_generation:
error_resp = envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"过期 generation: {envelope.generation} != {expected_generation}",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
continue
# 异步处理请求Runner 发来的能力调用) # 异步处理请求Runner 发来的能力调用)
task = asyncio.create_task(self._handle_request(envelope, conn)) task = asyncio.create_task(self._handle_request(envelope, conn))
self._tasks.append(task) self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
elif envelope.is_event(): elif envelope.is_broadcast():
if envelope.generation != expected_generation: task = asyncio.create_task(self._handle_broadcast(envelope))
logger.warning(
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
)
continue
task = asyncio.create_task(self._handle_event(envelope))
self._tasks.append(task) self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
else:
logger.warning(f"未知的消息类型: {envelope.message_type}")
continue
# ====== 接收循环内部方法 ======
def _handle_response(self, envelope: Envelope) -> None: def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Runner 的响应""" """处理来自 Runner 的响应"""
pending = self._pending_requests.get(envelope.request_id) pending_future = self._pending_requests.pop(envelope.request_id, None)
if pending is None: if pending_future is None:
return return
if not pending_future.done():
future, expected_generation = pending
if envelope.generation != expected_generation:
logger.warning(
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
)
return
self._pending_requests.pop(envelope.request_id, None)
if not future.done():
if envelope.error: if envelope.error:
future.set_exception(RPCError.from_dict(envelope.error)) pending_future.set_exception(RPCError.from_dict(envelope.error))
else: else:
future.set_result(envelope) pending_future.set_result(envelope)
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
"""通过发送队列串行发送消息,提供真实背压。"""
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
if self._send_queue is None:
await conn.send_frame(data)
return
loop = asyncio.get_running_loop()
send_future: asyncio.Future[None] = loop.create_future()
try:
self._send_queue.put_nowait((conn, data, send_future))
except asyncio.QueueFull:
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None
await send_future
async def _send_loop(self) -> None:
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
if self._send_queue is None:
return
while True:
try:
conn, data, send_future = await self._send_queue.get()
except asyncio.CancelledError:
break
try:
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
await conn.send_frame(data)
if not send_future.done():
send_future.set_result(None)
except asyncio.CancelledError:
if not send_future.done():
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
raise
except Exception as e:
send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
if not send_future.done():
send_future.set_exception(send_error)
finally:
self._send_queue.task_done()
@staticmethod
def _normalize_send_exception(error: Exception) -> RPCError:
if isinstance(error, ConnectionError):
return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error))
return RPCError(ErrorCode.E_UNKNOWN, str(error))
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None: async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*""" """处理来自 Runner 的请求(通常是能力调用 cap.*"""
handler = self._method_handlers.get(envelope.method) target_method = envelope.method
if handler is None: handler = self._method_handlers.get(target_method)
error_resp = envelope.make_error_response( if not handler:
error_response = envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value, ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的方法: {envelope.method}", f"未注册的方法: {envelope.method}",
) )
await conn.send_frame(self._codec.encode_envelope(error_resp)) await conn.send_frame(self._codec.encode_envelope(error_response))
return return
try: try:
@@ -546,59 +346,25 @@ class RPCServer:
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
await conn.send_frame(self._codec.encode_envelope(error_resp)) await conn.send_frame(self._codec.encode_envelope(error_resp))
async def _handle_event(self, envelope: Envelope) -> None: async def _handle_broadcast(self, envelope: Envelope) -> None:
"""处理来自 Runner 的事件"""
if handler := self._method_handlers.get(envelope.method): if handler := self._method_handlers.get(envelope.method):
try: try:
result = await handler(envelope) result = await handler(envelope)
# 检查 handler 返回的信封是否包含错误信息 # 检查 handler 返回的信封是否包含错误信息
if result is not None and isinstance(result, Envelope) and result.error: if result.error:
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}") logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
except Exception as e: except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
@staticmethod def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int:
def _check_sdk_version(sdk_version: str) -> bool: """失败所有等待中的请求(如连接断开时)"""
"""检查 SDK 版本是否在支持范围内""" aborted_request_count = 0
try: for future in self._pending_requests.values():
sdk_parts = RPCServer._parse_version_tuple(sdk_version)
min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION)
max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION)
return min_parts <= sdk_parts <= max_parts
except (ValueError, AttributeError):
return False
@staticmethod
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
base_version = base_version.split("+", 1)[0]
parts = [part for part in base_version.split(".") if part != ""]
while len(parts) < 3:
parts.append("0")
return (int(parts[0]), int(parts[1]), int(parts[2]))
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
if generation == self._runner_generation:
return self._connection
if generation == self._staged_runner_generation:
return self._staged_connection
return None
def _fail_pending_requests(
self,
error_code: ErrorCode,
message: str,
generation: Optional[int] = None,
) -> int:
stale_count = 0
for request_id, (future, request_generation) in list(self._pending_requests.items()):
if generation is not None and request_generation != generation:
continue
if not future.done(): if not future.done():
future.set_exception(RPCError(error_code, message)) future.set_exception(RPCError(error_code, message))
stale_count += 1 aborted_request_count += 1
self._pending_requests.pop(request_id, None) self._pending_requests.clear()
return stale_count return aborted_request_count
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int: def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
if self._send_queue is None: if self._send_queue is None:
@@ -617,3 +383,31 @@ class RPCServer:
self._send_queue.task_done() self._send_queue.task_done()
return failed_count return failed_count
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
"""通过发送队列串行发送消息,提供真实背压。"""
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
if self._send_queue is None:
await conn.send_frame(data)
return
loop = asyncio.get_running_loop()
send_future: asyncio.Future[None] = loop.create_future()
try:
self._send_queue.put_nowait((conn, data, send_future))
except asyncio.QueueFull:
raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None
await send_future
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
base_version = base_version.split("+", 1)[0]
parts = [part for part in base_version.split(".") if part != ""]
while len(parts) < 3:
parts.append("0")
return (int(parts[0]), int(parts[1]), int(parts[2]))

View File

@@ -7,7 +7,7 @@ from enum import Enum
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
class ErrorCode(str, Enum): class ErrorCode(Enum):
"""RPC 错误码枚举""" """RPC 错误码枚举"""
# 通用 # 通用
@@ -18,17 +18,17 @@ class ErrorCode(str, Enum):
E_TIMEOUT = "E_TIMEOUT" E_TIMEOUT = "E_TIMEOUT"
E_BAD_PAYLOAD = "E_BAD_PAYLOAD" E_BAD_PAYLOAD = "E_BAD_PAYLOAD"
E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH" E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH"
E_SHUTTING_DOWN = "E_SHUTTING_DOWN"
# 权限与策略 # 权限与策略
E_UNAUTHORIZED = "E_UNAUTHORIZED" E_UNAUTHORIZED = "E_UNAUTHORIZED"
E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED" E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED"
E_BACKPRESSURE = "E_BACKPRESSURE" E_BACK_PRESSURE = "E_BACK_PRESSURE"
E_HOST_OVERLOADED = "E_HOST_OVERLOADED" E_HOST_OVERLOADED = "E_HOST_OVERLOADED"
# 插件生命周期 # 插件生命周期
E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED" E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED"
E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND" E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND"
E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH"
E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS" E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS"
# 能力调用 # 能力调用
@@ -65,3 +65,13 @@ class RPCError(Exception):
message=data.get("message", ""), message=data.get("message", ""),
details=data.get("details", {}), details=data.get("details", {}),
) )
@classmethod
def from_exception(cls, exception: Exception, code_mapping: Optional[Dict[type[Exception], ErrorCode]] = None):
if isinstance(exception, cls):
return exception
if code_mapping:
for exception_type, code in code_mapping.items():
if isinstance(exception, exception_type):
return cls(code=code, message=str(exception))
return cls(ErrorCode.E_UNKNOWN, str(exception))