pylance fix x2
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
5. 发送能力调用请求到 Host
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -90,6 +90,13 @@ class RPCClient:
|
||||
"""注册方法处理器(处理 Host 发来的请求)"""
|
||||
self._method_handlers[method] = handler
|
||||
|
||||
def _require_connection(self) -> Connection:
|
||||
"""返回当前可用连接;若连接不可用则抛出 RPCError。"""
|
||||
connection = self._connection
|
||||
if connection is None or connection.is_closed:
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
|
||||
return cast(Connection, connection)
|
||||
|
||||
async def connect_and_handshake(self) -> bool:
|
||||
"""连接 Host 并完成握手
|
||||
|
||||
@@ -98,6 +105,7 @@ class RPCClient:
|
||||
"""
|
||||
client = create_transport_client(self._host_address)
|
||||
self._connection = await client.connect()
|
||||
connection = self._require_connection()
|
||||
|
||||
# 发送 runner.hello
|
||||
hello = HelloPayload(
|
||||
@@ -114,10 +122,10 @@ class RPCClient:
|
||||
)
|
||||
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
await connection.send_frame(data)
|
||||
|
||||
# 接收握手响应
|
||||
resp_data = await asyncio.wait_for(self._connection.recv_frame(), timeout=10.0)
|
||||
resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
|
||||
resp = self._codec.decode_envelope(resp_data)
|
||||
|
||||
resp_payload = HelloResponsePayload.model_validate(resp.payload)
|
||||
@@ -170,8 +178,7 @@ class RPCClient:
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""向 Host 发送 RPC 请求并等待响应"""
|
||||
if not self.is_connected:
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
|
||||
connection = self._require_connection()
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
@@ -190,7 +197,7 @@ class RPCClient:
|
||||
|
||||
try:
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
await connection.send_frame(data)
|
||||
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
return await asyncio.wait_for(future, timeout=timeout_sec)
|
||||
@@ -221,6 +228,8 @@ class RPCClient:
|
||||
if not self.is_connected:
|
||||
return
|
||||
|
||||
connection = self._require_connection()
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
@@ -231,7 +240,7 @@ class RPCClient:
|
||||
payload=payload or {},
|
||||
)
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
await connection.send_frame(data)
|
||||
|
||||
async def _recv_loop(self) -> None:
|
||||
"""消息接收主循环"""
|
||||
@@ -271,25 +280,31 @@ class RPCClient:
|
||||
|
||||
async def _handle_request(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的请求(调用插件组件)"""
|
||||
connection = self._connection
|
||||
if connection is None or connection.is_closed:
|
||||
logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应")
|
||||
return
|
||||
connection = cast(Connection, connection)
|
||||
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的方法: {envelope.method}",
|
||||
)
|
||||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
return
|
||||
|
||||
try:
|
||||
response = await handler(envelope)
|
||||
await self._connection.send_frame(self._codec.encode_envelope(response))
|
||||
await connection.send_frame(self._codec.encode_envelope(response))
|
||||
except RPCError as e:
|
||||
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
|
||||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
except Exception as e:
|
||||
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
|
||||
async def _handle_event(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的事件"""
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
6. 转发插件的能力调用到 Host
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Protocol, cast
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@@ -43,6 +43,10 @@ from src.plugin_runtime.runner.rpc_client import RPCClient
|
||||
logger = get_logger("plugin_runtime.runner.main")
|
||||
|
||||
|
||||
class _ContextAwarePlugin(Protocol):
|
||||
def _set_context(self, context: Any) -> None: ...
|
||||
|
||||
|
||||
def _disable_runner_console_logging() -> None:
|
||||
"""关闭 Runner 的控制台日志输出,避免被 Host 从 stderr 二次包装。"""
|
||||
root_logger = stdlib_logging.getLogger()
|
||||
@@ -204,7 +208,9 @@ class PluginRunner:
|
||||
rpc_client = self._rpc_client
|
||||
bound_plugin_id = plugin_id
|
||||
|
||||
async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any:
|
||||
async def _rpc_call(
|
||||
method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None
|
||||
) -> Any:
|
||||
"""桥接 PluginContext.call_capability → RPCClient.send_request。
|
||||
|
||||
无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id
|
||||
@@ -225,7 +231,7 @@ class PluginRunner:
|
||||
return resp.payload.get("result")
|
||||
|
||||
ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
|
||||
instance._set_context(ctx)
|
||||
cast(_ContextAwarePlugin, instance)._set_context(ctx)
|
||||
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
|
||||
|
||||
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None:
|
||||
@@ -543,8 +549,7 @@ class PluginRunner:
|
||||
async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
|
||||
"""处理配置更新事件"""
|
||||
plugin_id = envelope.plugin_id
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta:
|
||||
if meta := self._loader.get_plugin(plugin_id):
|
||||
try:
|
||||
config_data = envelope.payload.get("config_data", {})
|
||||
config_version = envelope.payload.get("config_version", "")
|
||||
|
||||
Reference in New Issue
Block a user