diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py index d746c4d2..3fb48c6a 100644 --- a/src/plugin_runtime/host/authorization.py +++ b/src/plugin_runtime/host/authorization.py @@ -40,6 +40,7 @@ class AuthorizationManager: self._permission_tokens.clear() def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]: + # sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression """检查插件是否有权调用某项能力 Returns: diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py index e0c56c2b..98366a07 100644 --- a/src/plugin_runtime/host/capability_service.py +++ b/src/plugin_runtime/host/capability_service.py @@ -7,11 +7,7 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。 from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING from src.common.logger import get_logger -from src.plugin_runtime.protocol.envelope import ( - CapabilityRequestPayload, - CapabilityResponsePayload, - Envelope, -) +from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope from src.plugin_runtime.protocol.errors import ErrorCode, RPCError if TYPE_CHECKING: @@ -59,31 +55,19 @@ class CapabilityService: try: req = CapabilityRequestPayload.model_validate(envelope.payload) except Exception as e: - return envelope.make_error_response( - ErrorCode.E_BAD_PAYLOAD.value, - f"能力调用 payload 格式错误: {e}", - ) + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 格式错误: {e}") capability = req.capability # 1. 权限校验 allowed, reason = self._authorization.check_capability(plugin_id, capability) if not allowed: - error_code = ( - ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED - ) - return envelope.make_error_response( - error_code.value, - reason, - ) + return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason) # 2. 查找实现 impl = self._implementations.get(capability) if impl is None: - return envelope.make_error_response( - ErrorCode.E_METHOD_NOT_ALLOWED.value, - f"未注册的能力: {capability}", - ) + return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}") # 3. 执行 try: @@ -94,10 +78,7 @@ class CapabilityService: return envelope.make_error_response(e.code.value, e.message, e.details) except Exception as e: logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True) - return envelope.make_error_response( - ErrorCode.E_CAPABILITY_FAILED.value, - str(e), - ) + return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e)) def list_capabilities(self) -> List[str]: """列出所有已注册的能力""" diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py index 79fe0d9a..75ef9b2a 100644 --- a/src/plugin_runtime/host/rpc_server.py +++ b/src/plugin_runtime/host/rpc_server.py @@ -7,7 +7,7 @@ 4. 请求-响应关联与超时管理 """ -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine import asyncio import contextlib @@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer logger = get_logger("plugin_runtime.host.rpc_server") # RPC 方法处理器类型 -MethodHandler = Callable[[Envelope], Awaitable[Envelope]] +MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]] class RPCServer: @@ -55,109 +55,29 @@ class RPCServer: self._id_gen = RequestIdGenerator() self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接 - self._runner_id: Optional[str] = None - self._runner_generation: int = 0 - self._staged_connection: Optional[Connection] = None - self._staged_runner_id: Optional[str] = None - self._staged_runner_generation: int = 0 - self._staging_takeover: bool = False # 方法处理器注册表 self._method_handlers: Dict[str, MethodHandler] = {} - # 等待响应的 pending 请求: request_id -> (Future, target_generation) - self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {} + # 等待响应的 pending 请求: request_id -> Future + self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {} # 发送队列(背压控制) self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None - self._send_worker_task: Optional[asyncio.Task] = None + self._send_worker_task: Optional[asyncio.Task[None]] = None # 运行状态 self._running: bool = False - self._tasks: List[asyncio.Task] = [] + self._tasks: List[asyncio.Task[None]] = [] @property def session_token(self) -> str: return self._session_token - def reset_session_token(self) -> str: - """重新生成会话令牌(热重载时调用,防止旧 Runner 重连)""" - self._session_token = secrets.token_hex(32) - return self._session_token - - def restore_session_token(self, token: str) -> None: - """恢复指定的会话令牌(热重载回滚时调用)""" - self._session_token = token - - @property - def runner_generation(self) -> int: - return self._runner_generation - - @property - def staged_generation(self) -> int: - return self._staged_runner_generation - @property def is_connected(self) -> bool: return self._connection is not None and not self._connection.is_closed - def has_generation(self, generation: int) -> bool: - return generation == self._runner_generation or ( - self._staged_connection is not None - and not self._staged_connection.is_closed - and generation == self._staged_runner_generation - ) - - def begin_staged_takeover(self) -> None: - """允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接。""" - self._staging_takeover = True - - async def commit_staged_takeover(self) -> None: - """提交 staged Runner,原活跃连接在提交后被关闭。""" - if self._staged_connection is None or self._staged_connection.is_closed: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接") - - old_connection = self._connection - old_generation = self._runner_generation - - self._connection = self._staged_connection - self._runner_id = self._staged_runner_id - self._runner_generation = self._staged_runner_generation - - self._staged_connection = None - self._staged_runner_id = None - self._staged_runner_generation = 0 - self._staging_takeover = False - - if stale_count := self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Runner 连接已被新 generation 接管", - generation=old_generation, - ): - logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求") - - if old_connection and old_connection is not self._connection and not old_connection.is_closed: - await old_connection.close() - - async def rollback_staged_takeover(self) -> None: - """放弃 staged Runner,保留当前活跃连接。""" - staged_connection = self._staged_connection - staged_generation = self._staged_runner_generation - - self._staged_connection = None - self._staged_runner_id = None - self._staged_runner_generation = 0 - self._staging_takeover = False - - self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "新 Runner 预热失败,已回滚", - generation=staged_generation, - ) - - if staged_connection and not staged_connection.is_closed: - await staged_connection.close() - def register_method(self, method: str, handler: MethodHandler) -> None: """注册 RPC 方法处理器""" self._method_handlers[method] = handler @@ -173,14 +93,8 @@ class RPCServer: async def stop(self) -> None: """停止 RPC 服务器""" self._running = False - - # 取消所有 pending 请求 - for future, _generation in self._pending_requests.values(): - if not future.done(): - future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) - self._pending_requests.clear() - - self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭") + self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") + self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭") if self._send_worker_task: self._send_worker_task.cancel() @@ -198,10 +112,6 @@ class RPCServer: await self._connection.close() self._connection = None - if self._staged_connection: - await self._staged_connection.close() - self._staged_connection = None - await self._transport.stop() logger.info("RPC Server 已停止") @@ -211,7 +121,6 @@ class RPCServer: plugin_id: str = "", payload: Optional[Dict[str, Any]] = None, timeout_ms: int = 30000, - target_generation: Optional[int] = None, ) -> Envelope: """向 Runner 发送 RPC 请求并等待响应 @@ -227,18 +136,14 @@ class RPCServer: Raises: RPCError: 调用失败 """ - generation = target_generation or self._runner_generation - conn = self._get_connection_for_generation(generation) - if conn is None or conn.is_closed: + if not self._connection or self._connection.is_closed: raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") - - request_id = self._id_gen.next() + request_id = await self._id_gen.next() envelope = Envelope( request_id=request_id, message_type=MessageType.REQUEST, method=method, plugin_id=plugin_id, - generation=generation, timeout_ms=timeout_ms, payload=payload or {}, ) @@ -246,12 +151,12 @@ class RPCServer: # 注册 pending future loop = asyncio.get_running_loop() future: asyncio.Future[Envelope] = loop.create_future() - self._pending_requests[request_id] = (future, generation) + self._pending_requests[request_id] = future try: # 发送请求 data = self._codec.encode_envelope(envelope) - await self._enqueue_send(conn, data) + await self._enqueue_send(self._connection, data) # 等待响应 timeout_sec = timeout_ms / 1000.0 @@ -265,93 +170,66 @@ class RPCServer: raise raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e - async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None: - """向 Runner 发送单向事件(不等待响应)""" - conn = self._connection - if conn is None or conn.is_closed: - return + # ============ 内部方法 ============ + # ========= 发送循环 ========= + async def _send_loop(self) -> None: + """后台发送循环:串行消费发送队列,统一执行连接写入。""" + if self._send_queue is None: + raise RuntimeError("没有消息队列") - request_id = self._id_gen.next() - envelope = Envelope( - request_id=request_id, - message_type=MessageType.EVENT, - method=method, - plugin_id=plugin_id, - generation=self._runner_generation, - payload=payload or {}, - ) - data = self._codec.encode_envelope(envelope) - await self._enqueue_send(conn, data) + while True: + try: + conn, data, send_future = await self._send_queue.get() + except asyncio.CancelledError: + break - # ─── 内部方法 ────────────────────────────────────────────── + try: + if conn.is_closed: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") + await conn.send_frame(data) + if not send_future.done(): + send_future.set_result(None) + except asyncio.CancelledError: + if not send_future.done(): + send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) + raise + except Exception as e: + send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED}) + if not send_future.done(): + send_future.set_exception(send_error) + finally: + self._send_queue.task_done() + # ====== 发送循环方法 ====== async def _handle_connection(self, conn: Connection) -> None: """处理新的 Runner 连接""" logger.info("收到 Runner 连接") - previous_connection = self._connection - previous_generation = self._runner_generation - # 第一条消息必须是 runner.hello 握手 try: - role = await self._handle_handshake(conn) - if role is None: + success = await self._handle_handshake(conn) + if not success: await conn.close() return except Exception as e: logger.error(f"握手失败: {e}") await conn.close() return - - if role == "staged": - expected_generation = self._staged_runner_generation - logger.info( - f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}" - ) - else: - self._connection = conn - expected_generation = self._runner_generation - logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}") - - if previous_connection and previous_connection is not conn and not previous_connection.is_closed: - logger.info("检测到新 Runner 已接管连接,关闭旧连接") - if stale_count := self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Runner 连接已被新 generation 接管", - generation=previous_generation, - ): - logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求") - await previous_connection.close() - + logger.info("Runner staged 握手成功") + self._connection = conn # 启动消息接收循环 try: - await self._recv_loop(conn, expected_generation=expected_generation) + await self._recv_loop(conn) except Exception as e: logger.error(f"连接异常断开: {e}") finally: - if self._connection is conn: - self._connection = None - self._runner_id = None - self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Runner 连接已断开", - generation=expected_generation, - ) - elif self._staged_connection is conn: - self._staged_connection = None - self._staged_runner_id = None - self._staged_runner_generation = 0 - self._fail_pending_requests( - ErrorCode.E_PLUGIN_CRASHED, - "Staged Runner 连接已断开", - generation=expected_generation, - ) + self._connection = None + self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开") - async def _handle_handshake(self, conn: Connection) -> Optional[str]: + async def _handle_handshake(self, conn: Connection) -> bool: """处理 runner.hello 握手""" # 接收握手请求 data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0) envelope = self._codec.decode_envelope(data) - if envelope.method != "runner.hello": logger.error(f"期望 runner.hello,收到 {envelope.method}") error_resp = envelope.make_error_response( @@ -359,21 +237,17 @@ class RPCServer: "首条消息必须为 runner.hello", ) await conn.send_frame(self._codec.encode_envelope(error_resp)) - return None + return False # 解析握手 payload hello = HelloPayload.model_validate(envelope.payload) - # 校验会话令牌 if hello.session_token != self._session_token: logger.error("会话令牌不匹配") - resp_payload = HelloResponsePayload( - accepted=False, - reason="会话令牌无效", - ) + resp_payload = HelloResponsePayload(accepted=False, reason="会话令牌无效") resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) - return None + return False # 校验 SDK 版本 if not self._check_sdk_version(hello.sdk_version): @@ -384,31 +258,26 @@ class RPCServer: ) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) - return None + return False - # 握手成功 - role = "active" - assigned_generation = self._runner_generation + 1 - if self._staging_takeover and self.is_connected: - role = "staged" - self._staged_connection = conn - self._staged_runner_id = hello.runner_id - self._staged_runner_generation = assigned_generation - else: - self._runner_id = hello.runner_id - self._runner_generation = assigned_generation - - resp_payload = HelloResponsePayload( - accepted=True, - host_version=PROTOCOL_VERSION, - assigned_generation=assigned_generation, - ) + # 发送响应 + resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION) resp = envelope.make_response(payload=resp_payload.model_dump()) await conn.send_frame(self._codec.encode_envelope(resp)) + return True - return role + def _check_sdk_version(self, sdk_version: str) -> bool: + """检查 SDK 版本是否在支持范围内""" + try: + sdk_parts = _parse_version_tuple(sdk_version) + min_parts = _parse_version_tuple(MIN_SDK_VERSION) + max_parts = _parse_version_tuple(MAX_SDK_VERSION) + return min_parts <= sdk_parts <= max_parts + except (ValueError, AttributeError): + return False - async def _recv_loop(self, conn: Connection, expected_generation: int) -> None: + # ========= 接收循环 ========= + async def _recv_loop(self, conn: Connection) -> None: """消息接收主循环""" while self._running and not conn.is_closed: try: @@ -430,109 +299,40 @@ class RPCServer: if envelope.is_response(): self._handle_response(envelope) elif envelope.is_request(): - if envelope.generation != expected_generation: - error_resp = envelope.make_error_response( - ErrorCode.E_GENERATION_MISMATCH.value, - f"过期 generation: {envelope.generation} != {expected_generation}", - ) - await conn.send_frame(self._codec.encode_envelope(error_resp)) - continue # 异步处理请求(Runner 发来的能力调用) task = asyncio.create_task(self._handle_request(envelope, conn)) self._tasks.append(task) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) - elif envelope.is_event(): - if envelope.generation != expected_generation: - logger.warning( - f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}" - ) - continue - task = asyncio.create_task(self._handle_event(envelope)) + elif envelope.is_broadcast(): + task = asyncio.create_task(self._handle_broadcast(envelope)) self._tasks.append(task) task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) + else: + logger.warning(f"未知的消息类型: {envelope.message_type}") + continue + # ====== 接收循环内部方法 ====== def _handle_response(self, envelope: Envelope) -> None: """处理来自 Runner 的响应""" - pending = self._pending_requests.get(envelope.request_id) - if pending is None: + pending_future = self._pending_requests.pop(envelope.request_id, None) + if pending_future is None: return - - future, expected_generation = pending - if envelope.generation != expected_generation: - logger.warning( - f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}" - ) - return - - self._pending_requests.pop(envelope.request_id, None) - if not future.done(): + if not pending_future.done(): if envelope.error: - future.set_exception(RPCError.from_dict(envelope.error)) + pending_future.set_exception(RPCError.from_dict(envelope.error)) else: - future.set_result(envelope) - - async def _enqueue_send(self, conn: Connection, data: bytes) -> None: - """通过发送队列串行发送消息,提供真实背压。""" - if conn.is_closed: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") - - if self._send_queue is None: - await conn.send_frame(data) - return - - loop = asyncio.get_running_loop() - send_future: asyncio.Future[None] = loop.create_future() - - try: - self._send_queue.put_nowait((conn, data, send_future)) - except asyncio.QueueFull: - raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None - - await send_future - - async def _send_loop(self) -> None: - """后台发送循环:串行消费发送队列,统一执行连接写入。""" - if self._send_queue is None: - return - - while True: - try: - conn, data, send_future = await self._send_queue.get() - except asyncio.CancelledError: - break - - try: - if conn.is_closed: - raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") - await conn.send_frame(data) - if not send_future.done(): - send_future.set_result(None) - except asyncio.CancelledError: - if not send_future.done(): - send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) - raise - except Exception as e: - send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e) - if not send_future.done(): - send_future.set_exception(send_error) - finally: - self._send_queue.task_done() - - @staticmethod - def _normalize_send_exception(error: Exception) -> RPCError: - if isinstance(error, ConnectionError): - return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error)) - return RPCError(ErrorCode.E_UNKNOWN, str(error)) + pending_future.set_result(envelope) async def _handle_request(self, envelope: Envelope, conn: Connection) -> None: """处理来自 Runner 的请求(通常是能力调用 cap.*)""" - handler = self._method_handlers.get(envelope.method) - if handler is None: - error_resp = envelope.make_error_response( + target_method = envelope.method + handler = self._method_handlers.get(target_method) + if not handler: + error_response = envelope.make_error_response( ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的方法: {envelope.method}", ) - await conn.send_frame(self._codec.encode_envelope(error_resp)) + await conn.send_frame(self._codec.encode_envelope(error_response)) return try: @@ -546,59 +346,25 @@ class RPCServer: error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) await conn.send_frame(self._codec.encode_envelope(error_resp)) - async def _handle_event(self, envelope: Envelope) -> None: - """处理来自 Runner 的事件""" + async def _handle_broadcast(self, envelope: Envelope) -> None: if handler := self._method_handlers.get(envelope.method): try: result = await handler(envelope) # 检查 handler 返回的信封是否包含错误信息 - if result is not None and isinstance(result, Envelope) and result.error: + if result.error: logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}") except Exception as e: logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) - @staticmethod - def _check_sdk_version(sdk_version: str) -> bool: - """检查 SDK 版本是否在支持范围内""" - try: - sdk_parts = RPCServer._parse_version_tuple(sdk_version) - min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION) - max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION) - return min_parts <= sdk_parts <= max_parts - except (ValueError, AttributeError): - return False - - @staticmethod - def _parse_version_tuple(version: str) -> Tuple[int, int, int]: - base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0] - base_version = base_version.split("+", 1)[0] - parts = [part for part in base_version.split(".") if part != ""] - while len(parts) < 3: - parts.append("0") - return (int(parts[0]), int(parts[1]), int(parts[2])) - - def _get_connection_for_generation(self, generation: int) -> Optional[Connection]: - if generation == self._runner_generation: - return self._connection - if generation == self._staged_runner_generation: - return self._staged_connection - return None - - def _fail_pending_requests( - self, - error_code: ErrorCode, - message: str, - generation: Optional[int] = None, - ) -> int: - stale_count = 0 - for request_id, (future, request_generation) in list(self._pending_requests.items()): - if generation is not None and request_generation != generation: - continue + def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int: + """失败所有等待中的请求(如连接断开时)""" + aborted_request_count = 0 + for future in self._pending_requests.values(): if not future.done(): future.set_exception(RPCError(error_code, message)) - stale_count += 1 - self._pending_requests.pop(request_id, None) - return stale_count + aborted_request_count += 1 + self._pending_requests.clear() + return aborted_request_count def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int: if self._send_queue is None: @@ -617,3 +383,31 @@ class RPCServer: self._send_queue.task_done() return failed_count + + async def _enqueue_send(self, conn: Connection, data: bytes) -> None: + """通过发送队列串行发送消息,提供真实背压。""" + if conn.is_closed: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") + + if self._send_queue is None: + await conn.send_frame(data) + return + + loop = asyncio.get_running_loop() + send_future: asyncio.Future[None] = loop.create_future() + + try: + self._send_queue.put_nowait((conn, data, send_future)) + except asyncio.QueueFull: + raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None + + await send_future + + +def _parse_version_tuple(version: str) -> Tuple[int, int, int]: + base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0] + base_version = base_version.split("+", 1)[0] + parts = [part for part in base_version.split(".") if part != ""] + while len(parts) < 3: + parts.append("0") + return (int(parts[0]), int(parts[1]), int(parts[2])) diff --git a/src/plugin_runtime/protocol/errors.py b/src/plugin_runtime/protocol/errors.py index dcae6b8f..ed19760d 100644 --- a/src/plugin_runtime/protocol/errors.py +++ b/src/plugin_runtime/protocol/errors.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Any, Dict, Optional -class ErrorCode(str, Enum): +class ErrorCode(Enum): """RPC 错误码枚举""" # 通用 @@ -18,17 +18,17 @@ class ErrorCode(str, Enum): E_TIMEOUT = "E_TIMEOUT" E_BAD_PAYLOAD = "E_BAD_PAYLOAD" E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH" + E_SHUTTING_DOWN = "E_SHUTTING_DOWN" # 权限与策略 E_UNAUTHORIZED = "E_UNAUTHORIZED" E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED" - E_BACKPRESSURE = "E_BACKPRESSURE" + E_BACK_PRESSURE = "E_BACK_PRESSURE" E_HOST_OVERLOADED = "E_HOST_OVERLOADED" # 插件生命周期 E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED" E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND" - E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH" E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS" # 能力调用 @@ -65,3 +65,13 @@ class RPCError(Exception): message=data.get("message", ""), details=data.get("details", {}), ) + + @classmethod + def from_exception(cls, exception: Exception, code_mapping: Optional[Dict[type[Exception], ErrorCode]] = None): + if isinstance(exception, cls): + return exception + if code_mapping: + for exception_type, code in code_mapping.items(): + if isinstance(exception, exception_type): + return cls(code=code, message=str(exception)) + return cls(ErrorCode.E_UNKNOWN, str(exception))