From 2e080e437a1261f5069ed22d4a0cf002d7917d8a Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 14 Mar 2026 02:08:50 +0800 Subject: [PATCH] pylance fix x2 --- src/plugin_runtime/runner/rpc_client.py | 37 +++++++++++++++++------- src/plugin_runtime/runner/runner_main.py | 15 ++++++---- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py index 88cd4fee..6a1d59d5 100644 --- a/src/plugin_runtime/runner/rpc_client.py +++ b/src/plugin_runtime/runner/rpc_client.py @@ -8,7 +8,7 @@ 5. 发送能力调用请求到 Host """ -from typing import Any, Awaitable, Callable, Dict, Optional +from typing import Any, Awaitable, Callable, Dict, Optional, cast import asyncio import contextlib @@ -90,6 +90,13 @@ class RPCClient: """注册方法处理器(处理 Host 发来的请求)""" self._method_handlers[method] = handler + def _require_connection(self) -> Connection: + """返回当前可用连接;若连接不可用则抛出 RPCError。""" + connection = self._connection + if connection is None or connection.is_closed: + raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host") + return cast(Connection, connection) + async def connect_and_handshake(self) -> bool: """连接 Host 并完成握手 @@ -98,6 +105,7 @@ class RPCClient: """ client = create_transport_client(self._host_address) self._connection = await client.connect() + connection = self._require_connection() # 发送 runner.hello hello = HelloPayload( @@ -114,10 +122,10 @@ class RPCClient: ) data = self._codec.encode_envelope(envelope) - await self._connection.send_frame(data) + await connection.send_frame(data) # 接收握手响应 - resp_data = await asyncio.wait_for(self._connection.recv_frame(), timeout=10.0) + resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0) resp = self._codec.decode_envelope(resp_data) resp_payload = HelloResponsePayload.model_validate(resp.payload) @@ -170,8 +178,7 @@ class RPCClient: timeout_ms: int = 30000, ) -> Envelope: """向 Host 发送 RPC 请求并等待响应""" - if not self.is_connected: - raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host") + connection = self._require_connection() request_id = self._id_gen.next() envelope = Envelope( @@ -190,7 +197,7 @@ class RPCClient: try: data = self._codec.encode_envelope(envelope) - await self._connection.send_frame(data) + await connection.send_frame(data) timeout_sec = timeout_ms / 1000.0 return await asyncio.wait_for(future, timeout=timeout_sec) @@ -221,6 +228,8 @@ class RPCClient: if not self.is_connected: return + connection = self._require_connection() + request_id = self._id_gen.next() envelope = Envelope( request_id=request_id, @@ -231,7 +240,7 @@ class RPCClient: payload=payload or {}, ) data = self._codec.encode_envelope(envelope) - await self._connection.send_frame(data) + await connection.send_frame(data) async def _recv_loop(self) -> None: """消息接收主循环""" @@ -271,25 +280,31 @@ class RPCClient: async def _handle_request(self, envelope: Envelope) -> None: """处理来自 Host 的请求(调用插件组件)""" + connection = self._connection + if connection is None or connection.is_closed: + logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应") + return + connection = cast(Connection, connection) + handler = self._method_handlers.get(envelope.method) if handler is None: error_resp = envelope.make_error_response( ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的方法: {envelope.method}", ) - await self._connection.send_frame(self._codec.encode_envelope(error_resp)) + await connection.send_frame(self._codec.encode_envelope(error_resp)) return try: response = await handler(envelope) - await self._connection.send_frame(self._codec.encode_envelope(response)) + await connection.send_frame(self._codec.encode_envelope(response)) except RPCError as e: error_resp = envelope.make_error_response(e.code.value, e.message, e.details) - await self._connection.send_frame(self._codec.encode_envelope(error_resp)) + await connection.send_frame(self._codec.encode_envelope(error_resp)) except Exception as e: logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True) error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) - await self._connection.send_frame(self._codec.encode_envelope(error_resp)) + await connection.send_frame(self._codec.encode_envelope(error_resp)) async def _handle_event(self, envelope: Envelope) -> None: """处理来自 Host 的事件""" diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index e06f76ef..8db330e4 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -9,7 +9,7 @@ 6. 转发插件的能力调用到 Host """ -from typing import Any, List, Optional +from typing import Any, List, Optional, Protocol, cast from pathlib import Path @@ -43,6 +43,10 @@ from src.plugin_runtime.runner.rpc_client import RPCClient logger = get_logger("plugin_runtime.runner.main") +class _ContextAwarePlugin(Protocol): + def _set_context(self, context: Any) -> None: ... + + def _disable_runner_console_logging() -> None: """关闭 Runner 的控制台日志输出,避免被 Host 从 stderr 二次包装。""" root_logger = stdlib_logging.getLogger() @@ -204,7 +208,9 @@ class PluginRunner: rpc_client = self._rpc_client bound_plugin_id = plugin_id - async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any: + async def _rpc_call( + method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None + ) -> Any: """桥接 PluginContext.call_capability → RPCClient.send_request。 无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id @@ -225,7 +231,7 @@ class PluginRunner: return resp.payload.get("result") ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call) - instance._set_context(ctx) + cast(_ContextAwarePlugin, instance)._set_context(ctx) logger.debug(f"已为插件 {plugin_id} 注入 PluginContext") def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None: @@ -543,8 +549,7 @@ class PluginRunner: async def _handle_config_updated(self, envelope: Envelope) -> Envelope: """处理配置更新事件""" plugin_id = envelope.plugin_id - meta = self._loader.get_plugin(plugin_id) - if meta: + if meta := self._loader.get_plugin(plugin_id): try: config_data = envelope.payload.get("config_data", {}) config_version = envelope.payload.get("config_version", "")