pylance fix x2

This commit is contained in:
DrSmoothl
2026-03-14 02:08:50 +08:00
parent 10ff7a01c2
commit 2e080e437a
2 changed files with 36 additions and 16 deletions

View File

@@ -8,7 +8,7 @@
5. 发送能力调用请求到 Host
"""
from typing import Any, Awaitable, Callable, Dict, Optional
from typing import Any, Awaitable, Callable, Dict, Optional, cast
import asyncio
import contextlib
@@ -90,6 +90,13 @@ class RPCClient:
"""注册方法处理器(处理 Host 发来的请求)"""
self._method_handlers[method] = handler
def _require_connection(self) -> Connection:
"""返回当前可用连接;若连接不可用则抛出 RPCError。"""
connection = self._connection
if connection is None or connection.is_closed:
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
return cast(Connection, connection)
async def connect_and_handshake(self) -> bool:
"""连接 Host 并完成握手
@@ -98,6 +105,7 @@ class RPCClient:
"""
client = create_transport_client(self._host_address)
self._connection = await client.connect()
connection = self._require_connection()
# 发送 runner.hello
hello = HelloPayload(
@@ -114,10 +122,10 @@ class RPCClient:
)
data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data)
await connection.send_frame(data)
# 接收握手响应
resp_data = await asyncio.wait_for(self._connection.recv_frame(), timeout=10.0)
resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
resp = self._codec.decode_envelope(resp_data)
resp_payload = HelloResponsePayload.model_validate(resp.payload)
@@ -170,8 +178,7 @@ class RPCClient:
timeout_ms: int = 30000,
) -> Envelope:
"""向 Host 发送 RPC 请求并等待响应"""
if not self.is_connected:
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
connection = self._require_connection()
request_id = self._id_gen.next()
envelope = Envelope(
@@ -190,7 +197,7 @@ class RPCClient:
try:
data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data)
await connection.send_frame(data)
timeout_sec = timeout_ms / 1000.0
return await asyncio.wait_for(future, timeout=timeout_sec)
@@ -221,6 +228,8 @@ class RPCClient:
if not self.is_connected:
return
connection = self._require_connection()
request_id = self._id_gen.next()
envelope = Envelope(
request_id=request_id,
@@ -231,7 +240,7 @@ class RPCClient:
payload=payload or {},
)
data = self._codec.encode_envelope(envelope)
await self._connection.send_frame(data)
await connection.send_frame(data)
async def _recv_loop(self) -> None:
"""消息接收主循环"""
@@ -271,25 +280,31 @@ class RPCClient:
async def _handle_request(self, envelope: Envelope) -> None:
"""处理来自 Host 的请求(调用插件组件)"""
connection = self._connection
if connection is None or connection.is_closed:
logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应")
return
connection = cast(Connection, connection)
handler = self._method_handlers.get(envelope.method)
if handler is None:
error_resp = envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的方法: {envelope.method}",
)
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
await connection.send_frame(self._codec.encode_envelope(error_resp))
return
try:
response = await handler(envelope)
await self._connection.send_frame(self._codec.encode_envelope(response))
await connection.send_frame(self._codec.encode_envelope(response))
except RPCError as e:
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
await connection.send_frame(self._codec.encode_envelope(error_resp))
except Exception as e:
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
await connection.send_frame(self._codec.encode_envelope(error_resp))
async def _handle_event(self, envelope: Envelope) -> None:
"""处理来自 Host 的事件"""

View File

@@ -9,7 +9,7 @@
6. 转发插件的能力调用到 Host
"""
from typing import Any, List, Optional
from typing import Any, List, Optional, Protocol, cast
from pathlib import Path
@@ -43,6 +43,10 @@ from src.plugin_runtime.runner.rpc_client import RPCClient
logger = get_logger("plugin_runtime.runner.main")
class _ContextAwarePlugin(Protocol):
def _set_context(self, context: Any) -> None: ...
def _disable_runner_console_logging() -> None:
"""关闭 Runner 的控制台日志输出,避免被 Host 从 stderr 二次包装。"""
root_logger = stdlib_logging.getLogger()
@@ -204,7 +208,9 @@ class PluginRunner:
rpc_client = self._rpc_client
bound_plugin_id = plugin_id
async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any:
async def _rpc_call(
method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None
) -> Any:
"""桥接 PluginContext.call_capability → RPCClient.send_request。
无论调用方传入何种 plugin_id实际发往 Host 的 plugin_id
@@ -225,7 +231,7 @@ class PluginRunner:
return resp.payload.get("result")
ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
instance._set_context(ctx)
cast(_ContextAwarePlugin, instance)._set_context(ctx)
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None:
@@ -543,8 +549,7 @@ class PluginRunner:
async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
"""处理配置更新事件"""
plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id)
if meta:
if meta := self._loader.get_plugin(plugin_id):
try:
config_data = envelope.payload.get("config_data", {})
config_version = envelope.payload.get("config_version", "")