pylance fix x2

This commit is contained in:
DrSmoothl
2026-03-14 02:08:50 +08:00
parent 10ff7a01c2
commit 2e080e437a
2 changed files with 36 additions and 16 deletions

View File

@@ -8,7 +8,7 @@
5. 发送能力调用请求到 Host 5. 发送能力调用请求到 Host
""" """
from typing import Any, Awaitable, Callable, Dict, Optional from typing import Any, Awaitable, Callable, Dict, Optional, cast
import asyncio import asyncio
import contextlib import contextlib
@@ -90,6 +90,13 @@ class RPCClient:
"""注册方法处理器(处理 Host 发来的请求)""" """注册方法处理器(处理 Host 发来的请求)"""
self._method_handlers[method] = handler self._method_handlers[method] = handler
def _require_connection(self) -> Connection:
"""返回当前可用连接;若连接不可用则抛出 RPCError。"""
connection = self._connection
if connection is None or connection.is_closed:
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
return cast(Connection, connection)
async def connect_and_handshake(self) -> bool: async def connect_and_handshake(self) -> bool:
"""连接 Host 并完成握手 """连接 Host 并完成握手
@@ -98,6 +105,7 @@ class RPCClient:
""" """
client = create_transport_client(self._host_address) client = create_transport_client(self._host_address)
self._connection = await client.connect() self._connection = await client.connect()
connection = self._require_connection()
# 发送 runner.hello # 发送 runner.hello
hello = HelloPayload( hello = HelloPayload(
@@ -114,10 +122,10 @@ class RPCClient:
) )
data = self._codec.encode_envelope(envelope) data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data) await connection.send_frame(data)
# 接收握手响应 # 接收握手响应
resp_data = await asyncio.wait_for(self._connection.recv_frame(), timeout=10.0) resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
resp = self._codec.decode_envelope(resp_data) resp = self._codec.decode_envelope(resp_data)
resp_payload = HelloResponsePayload.model_validate(resp.payload) resp_payload = HelloResponsePayload.model_validate(resp.payload)
@@ -170,8 +178,7 @@ class RPCClient:
timeout_ms: int = 30000, timeout_ms: int = 30000,
) -> Envelope: ) -> Envelope:
"""向 Host 发送 RPC 请求并等待响应""" """向 Host 发送 RPC 请求并等待响应"""
if not self.is_connected: connection = self._require_connection()
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
request_id = self._id_gen.next() request_id = self._id_gen.next()
envelope = Envelope( envelope = Envelope(
@@ -190,7 +197,7 @@ class RPCClient:
try: try:
data = self._codec.encode_envelope(envelope) data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data) await connection.send_frame(data)
timeout_sec = timeout_ms / 1000.0 timeout_sec = timeout_ms / 1000.0
return await asyncio.wait_for(future, timeout=timeout_sec) return await asyncio.wait_for(future, timeout=timeout_sec)
@@ -221,6 +228,8 @@ class RPCClient:
if not self.is_connected: if not self.is_connected:
return return
connection = self._require_connection()
request_id = self._id_gen.next() request_id = self._id_gen.next()
envelope = Envelope( envelope = Envelope(
request_id=request_id, request_id=request_id,
@@ -231,7 +240,7 @@ class RPCClient:
payload=payload or {}, payload=payload or {},
) )
data = self._codec.encode_envelope(envelope) data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data) await connection.send_frame(data)
async def _recv_loop(self) -> None: async def _recv_loop(self) -> None:
"""消息接收主循环""" """消息接收主循环"""
@@ -271,25 +280,31 @@ class RPCClient:
async def _handle_request(self, envelope: Envelope) -> None: async def _handle_request(self, envelope: Envelope) -> None:
"""处理来自 Host 的请求(调用插件组件)""" """处理来自 Host 的请求(调用插件组件)"""
connection = self._connection
if connection is None or connection.is_closed:
logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应")
return
connection = cast(Connection, connection)
handler = self._method_handlers.get(envelope.method) handler = self._method_handlers.get(envelope.method)
if handler is None: if handler is None:
error_resp = envelope.make_error_response( error_resp = envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value, ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的方法: {envelope.method}", f"未注册的方法: {envelope.method}",
) )
await self._connection.send_frame(self._codec.encode_envelope(error_resp)) await connection.send_frame(self._codec.encode_envelope(error_resp))
return return
try: try:
response = await handler(envelope) response = await handler(envelope)
await self._connection.send_frame(self._codec.encode_envelope(response)) await connection.send_frame(self._codec.encode_envelope(response))
except RPCError as e: except RPCError as e:
error_resp = envelope.make_error_response(e.code.value, e.message, e.details) error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
await self._connection.send_frame(self._codec.encode_envelope(error_resp)) await connection.send_frame(self._codec.encode_envelope(error_resp))
except Exception as e: except Exception as e:
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True) logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
await self._connection.send_frame(self._codec.encode_envelope(error_resp)) await connection.send_frame(self._codec.encode_envelope(error_resp))
async def _handle_event(self, envelope: Envelope) -> None: async def _handle_event(self, envelope: Envelope) -> None:
"""处理来自 Host 的事件""" """处理来自 Host 的事件"""

View File

@@ -9,7 +9,7 @@
6. 转发插件的能力调用到 Host 6. 转发插件的能力调用到 Host
""" """
from typing import Any, List, Optional from typing import Any, List, Optional, Protocol, cast
from pathlib import Path from pathlib import Path
@@ -43,6 +43,10 @@ from src.plugin_runtime.runner.rpc_client import RPCClient
logger = get_logger("plugin_runtime.runner.main") logger = get_logger("plugin_runtime.runner.main")
class _ContextAwarePlugin(Protocol):
def _set_context(self, context: Any) -> None: ...
def _disable_runner_console_logging() -> None: def _disable_runner_console_logging() -> None:
"""关闭 Runner 的控制台日志输出,避免被 Host 从 stderr 二次包装。""" """关闭 Runner 的控制台日志输出,避免被 Host 从 stderr 二次包装。"""
root_logger = stdlib_logging.getLogger() root_logger = stdlib_logging.getLogger()
@@ -204,7 +208,9 @@ class PluginRunner:
rpc_client = self._rpc_client rpc_client = self._rpc_client
bound_plugin_id = plugin_id bound_plugin_id = plugin_id
async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any: async def _rpc_call(
method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None
) -> Any:
"""桥接 PluginContext.call_capability → RPCClient.send_request。 """桥接 PluginContext.call_capability → RPCClient.send_request。
无论调用方传入何种 plugin_id实际发往 Host 的 plugin_id 无论调用方传入何种 plugin_id实际发往 Host 的 plugin_id
@@ -225,7 +231,7 @@ class PluginRunner:
return resp.payload.get("result") return resp.payload.get("result")
ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call) ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
instance._set_context(ctx) cast(_ContextAwarePlugin, instance)._set_context(ctx)
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext") logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None: def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None:
@@ -543,8 +549,7 @@ class PluginRunner:
async def _handle_config_updated(self, envelope: Envelope) -> Envelope: async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
"""处理配置更新事件""" """处理配置更新事件"""
plugin_id = envelope.plugin_id plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id) if meta := self._loader.get_plugin(plugin_id):
if meta:
try: try:
config_data = envelope.payload.get("config_data", {}) config_data = envelope.payload.get("config_data", {})
config_version = envelope.payload.get("config_version", "") config_version = envelope.payload.get("config_version", "")