pylance fix x2
This commit is contained in:
@@ -8,7 +8,7 @@
|
|||||||
5. 发送能力调用请求到 Host
|
5. 发送能力调用请求到 Host
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
from typing import Any, Awaitable, Callable, Dict, Optional, cast
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
@@ -90,6 +90,13 @@ class RPCClient:
|
|||||||
"""注册方法处理器(处理 Host 发来的请求)"""
|
"""注册方法处理器(处理 Host 发来的请求)"""
|
||||||
self._method_handlers[method] = handler
|
self._method_handlers[method] = handler
|
||||||
|
|
||||||
|
def _require_connection(self) -> Connection:
|
||||||
|
"""返回当前可用连接;若连接不可用则抛出 RPCError。"""
|
||||||
|
connection = self._connection
|
||||||
|
if connection is None or connection.is_closed:
|
||||||
|
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
|
||||||
|
return cast(Connection, connection)
|
||||||
|
|
||||||
async def connect_and_handshake(self) -> bool:
|
async def connect_and_handshake(self) -> bool:
|
||||||
"""连接 Host 并完成握手
|
"""连接 Host 并完成握手
|
||||||
|
|
||||||
@@ -98,6 +105,7 @@ class RPCClient:
|
|||||||
"""
|
"""
|
||||||
client = create_transport_client(self._host_address)
|
client = create_transport_client(self._host_address)
|
||||||
self._connection = await client.connect()
|
self._connection = await client.connect()
|
||||||
|
connection = self._require_connection()
|
||||||
|
|
||||||
# 发送 runner.hello
|
# 发送 runner.hello
|
||||||
hello = HelloPayload(
|
hello = HelloPayload(
|
||||||
@@ -114,10 +122,10 @@ class RPCClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
data = self._codec.encode_envelope(envelope)
|
data = self._codec.encode_envelope(envelope)
|
||||||
await self._connection.send_frame(data)
|
await connection.send_frame(data)
|
||||||
|
|
||||||
# 接收握手响应
|
# 接收握手响应
|
||||||
resp_data = await asyncio.wait_for(self._connection.recv_frame(), timeout=10.0)
|
resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
|
||||||
resp = self._codec.decode_envelope(resp_data)
|
resp = self._codec.decode_envelope(resp_data)
|
||||||
|
|
||||||
resp_payload = HelloResponsePayload.model_validate(resp.payload)
|
resp_payload = HelloResponsePayload.model_validate(resp.payload)
|
||||||
@@ -170,8 +178,7 @@ class RPCClient:
|
|||||||
timeout_ms: int = 30000,
|
timeout_ms: int = 30000,
|
||||||
) -> Envelope:
|
) -> Envelope:
|
||||||
"""向 Host 发送 RPC 请求并等待响应"""
|
"""向 Host 发送 RPC 请求并等待响应"""
|
||||||
if not self.is_connected:
|
connection = self._require_connection()
|
||||||
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
|
|
||||||
|
|
||||||
request_id = self._id_gen.next()
|
request_id = self._id_gen.next()
|
||||||
envelope = Envelope(
|
envelope = Envelope(
|
||||||
@@ -190,7 +197,7 @@ class RPCClient:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
data = self._codec.encode_envelope(envelope)
|
data = self._codec.encode_envelope(envelope)
|
||||||
await self._connection.send_frame(data)
|
await connection.send_frame(data)
|
||||||
|
|
||||||
timeout_sec = timeout_ms / 1000.0
|
timeout_sec = timeout_ms / 1000.0
|
||||||
return await asyncio.wait_for(future, timeout=timeout_sec)
|
return await asyncio.wait_for(future, timeout=timeout_sec)
|
||||||
@@ -221,6 +228,8 @@ class RPCClient:
|
|||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
connection = self._require_connection()
|
||||||
|
|
||||||
request_id = self._id_gen.next()
|
request_id = self._id_gen.next()
|
||||||
envelope = Envelope(
|
envelope = Envelope(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
@@ -231,7 +240,7 @@ class RPCClient:
|
|||||||
payload=payload or {},
|
payload=payload or {},
|
||||||
)
|
)
|
||||||
data = self._codec.encode_envelope(envelope)
|
data = self._codec.encode_envelope(envelope)
|
||||||
await self._connection.send_frame(data)
|
await connection.send_frame(data)
|
||||||
|
|
||||||
async def _recv_loop(self) -> None:
|
async def _recv_loop(self) -> None:
|
||||||
"""消息接收主循环"""
|
"""消息接收主循环"""
|
||||||
@@ -271,25 +280,31 @@ class RPCClient:
|
|||||||
|
|
||||||
async def _handle_request(self, envelope: Envelope) -> None:
|
async def _handle_request(self, envelope: Envelope) -> None:
|
||||||
"""处理来自 Host 的请求(调用插件组件)"""
|
"""处理来自 Host 的请求(调用插件组件)"""
|
||||||
|
connection = self._connection
|
||||||
|
if connection is None or connection.is_closed:
|
||||||
|
logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应")
|
||||||
|
return
|
||||||
|
connection = cast(Connection, connection)
|
||||||
|
|
||||||
handler = self._method_handlers.get(envelope.method)
|
handler = self._method_handlers.get(envelope.method)
|
||||||
if handler is None:
|
if handler is None:
|
||||||
error_resp = envelope.make_error_response(
|
error_resp = envelope.make_error_response(
|
||||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||||
f"未注册的方法: {envelope.method}",
|
f"未注册的方法: {envelope.method}",
|
||||||
)
|
)
|
||||||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await handler(envelope)
|
response = await handler(envelope)
|
||||||
await self._connection.send_frame(self._codec.encode_envelope(response))
|
await connection.send_frame(self._codec.encode_envelope(response))
|
||||||
except RPCError as e:
|
except RPCError as e:
|
||||||
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
|
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
|
||||||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
|
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
|
||||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
await connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||||
|
|
||||||
async def _handle_event(self, envelope: Envelope) -> None:
|
async def _handle_event(self, envelope: Envelope) -> None:
|
||||||
"""处理来自 Host 的事件"""
|
"""处理来自 Host 的事件"""
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
6. 转发插件的能力调用到 Host
|
6. 转发插件的能力调用到 Host
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional, Protocol, cast
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -43,6 +43,10 @@ from src.plugin_runtime.runner.rpc_client import RPCClient
|
|||||||
logger = get_logger("plugin_runtime.runner.main")
|
logger = get_logger("plugin_runtime.runner.main")
|
||||||
|
|
||||||
|
|
||||||
|
class _ContextAwarePlugin(Protocol):
|
||||||
|
def _set_context(self, context: Any) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
def _disable_runner_console_logging() -> None:
|
def _disable_runner_console_logging() -> None:
|
||||||
"""关闭 Runner 的控制台日志输出,避免被 Host 从 stderr 二次包装。"""
|
"""关闭 Runner 的控制台日志输出,避免被 Host 从 stderr 二次包装。"""
|
||||||
root_logger = stdlib_logging.getLogger()
|
root_logger = stdlib_logging.getLogger()
|
||||||
@@ -204,7 +208,9 @@ class PluginRunner:
|
|||||||
rpc_client = self._rpc_client
|
rpc_client = self._rpc_client
|
||||||
bound_plugin_id = plugin_id
|
bound_plugin_id = plugin_id
|
||||||
|
|
||||||
async def _rpc_call(method: str, plugin_id: str = "", payload: dict = None) -> Any:
|
async def _rpc_call(
|
||||||
|
method: str, plugin_id: str = "", payload: Optional[dict[str, Any]] = None
|
||||||
|
) -> Any:
|
||||||
"""桥接 PluginContext.call_capability → RPCClient.send_request。
|
"""桥接 PluginContext.call_capability → RPCClient.send_request。
|
||||||
|
|
||||||
无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id
|
无论调用方传入何种 plugin_id,实际发往 Host 的 plugin_id
|
||||||
@@ -225,7 +231,7 @@ class PluginRunner:
|
|||||||
return resp.payload.get("result")
|
return resp.payload.get("result")
|
||||||
|
|
||||||
ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
|
ctx = PluginContext(plugin_id=plugin_id, rpc_call=_rpc_call)
|
||||||
instance._set_context(ctx)
|
cast(_ContextAwarePlugin, instance)._set_context(ctx)
|
||||||
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
|
logger.debug(f"已为插件 {plugin_id} 注入 PluginContext")
|
||||||
|
|
||||||
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None:
|
def _apply_plugin_config(self, meta: PluginMeta, config_data: Optional[dict[str, Any]] = None) -> None:
|
||||||
@@ -543,8 +549,7 @@ class PluginRunner:
|
|||||||
async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
|
async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
|
||||||
"""处理配置更新事件"""
|
"""处理配置更新事件"""
|
||||||
plugin_id = envelope.plugin_id
|
plugin_id = envelope.plugin_id
|
||||||
meta = self._loader.get_plugin(plugin_id)
|
if meta := self._loader.get_plugin(plugin_id):
|
||||||
if meta:
|
|
||||||
try:
|
try:
|
||||||
config_data = envelope.payload.get("config_data", {})
|
config_data = envelope.payload.get("config_data", {})
|
||||||
config_version = envelope.payload.get("config_version", "")
|
config_version = envelope.payload.get("config_version", "")
|
||||||
|
|||||||
Reference in New Issue
Block a user