- Updated Codec class to use abstract methods for encoding and decoding envelopes. - Changed Envelope class to use Dict and Optional for payload and error fields. - Refined error handling in RPCError class with Optional type hints for details. - Enhanced manifest validation logic with type hints for better type safety. - Improved plugin loading mechanism with consistent type annotations. - Updated RPCClient to utilize Optional for codec and connection attributes. - Refactored transport classes to use Optional for server attributes and socket paths.
254 lines
8.9 KiB
Python
254 lines
8.9 KiB
Python
"""Runner 端 RPC Client
|
||
|
||
负责:
|
||
1. 连接 Host RPC Server
|
||
2. 发送握手(runner.hello)
|
||
3. 发送组件注册请求
|
||
4. 接收并分发 Host 的调用请求
|
||
5. 发送能力调用请求到 Host
|
||
"""
|
||
|
||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||
|
||
import asyncio
|
||
import contextlib
|
||
import uuid
|
||
|
||
from src.common.logger import get_logger
|
||
from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
|
||
from src.plugin_runtime.protocol.envelope import (
|
||
Envelope,
|
||
HelloPayload,
|
||
HelloResponsePayload,
|
||
MessageType,
|
||
RequestIdGenerator,
|
||
)
|
||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||
from src.plugin_runtime.transport.base import Connection
|
||
from src.plugin_runtime.transport.factory import create_transport_client
|
||
|
||
logger = get_logger("plugin_runtime.runner.rpc_client")
|
||
|
||
# RPC 方法处理器类型
|
||
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
|
||
|
||
SDK_VERSION = "1.0.0"
|
||
|
||
|
||
class RPCClient:
|
||
"""Runner 端 RPC 客户端
|
||
|
||
管理与 Host 的 IPC 连接,支持双向 RPC 调用。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
host_address: str,
|
||
session_token: str,
|
||
codec: Optional[Codec] = None,
|
||
):
|
||
self._host_address = host_address
|
||
self._session_token = session_token
|
||
self._codec = codec or MsgPackCodec()
|
||
|
||
self._id_gen = RequestIdGenerator()
|
||
self._connection: Optional[Connection] = None
|
||
self._runner_id = str(uuid.uuid4())
|
||
self._generation: int = 0
|
||
|
||
# 方法处理器注册表(Host 发来的调用)
|
||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||
|
||
# 等待响应的 pending 请求: request_id -> Future
|
||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
||
|
||
# 运行状态
|
||
self._running = False
|
||
self._recv_task: Optional[asyncio.Task] = None
|
||
|
||
@property
|
||
def generation(self) -> int:
|
||
return self._generation
|
||
|
||
@property
|
||
def is_connected(self) -> bool:
|
||
return self._connection is not None and not self._connection.is_closed
|
||
|
||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||
"""注册方法处理器(处理 Host 发来的请求)"""
|
||
self._method_handlers[method] = handler
|
||
|
||
async def connect_and_handshake(self) -> bool:
|
||
"""连接 Host 并完成握手
|
||
|
||
Returns:
|
||
是否握手成功
|
||
"""
|
||
client = create_transport_client(self._host_address)
|
||
self._connection = await client.connect()
|
||
|
||
# 发送 runner.hello
|
||
hello = HelloPayload(
|
||
runner_id=self._runner_id,
|
||
sdk_version=SDK_VERSION,
|
||
session_token=self._session_token,
|
||
)
|
||
request_id = self._id_gen.next()
|
||
envelope = Envelope(
|
||
request_id=request_id,
|
||
message_type=MessageType.REQUEST,
|
||
method="runner.hello",
|
||
payload=hello.model_dump(),
|
||
)
|
||
|
||
data = self._codec.encode_envelope(envelope)
|
||
await self._connection.send_frame(data)
|
||
|
||
# 接收握手响应
|
||
resp_data = await asyncio.wait_for(self._connection.recv_frame(), timeout=10.0)
|
||
resp = self._codec.decode_envelope(resp_data)
|
||
|
||
resp_payload = HelloResponsePayload.model_validate(resp.payload)
|
||
if not resp_payload.accepted:
|
||
logger.error(f"握手被拒绝: {resp_payload.reason}")
|
||
await self._connection.close()
|
||
self._connection = None
|
||
return False
|
||
|
||
self._generation = resp_payload.assigned_generation
|
||
logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}")
|
||
|
||
# 启动消息接收循环
|
||
self._running = True
|
||
self._recv_task = asyncio.create_task(self._recv_loop())
|
||
|
||
return True
|
||
|
||
async def disconnect(self) -> None:
|
||
"""断开连接"""
|
||
self._running = False
|
||
if self._recv_task:
|
||
self._recv_task.cancel()
|
||
with contextlib.suppress(asyncio.CancelledError):
|
||
await self._recv_task
|
||
self._recv_task = None
|
||
|
||
# 取消所有 pending 请求
|
||
for future in self._pending_requests.values():
|
||
if not future.done():
|
||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭"))
|
||
self._pending_requests.clear()
|
||
|
||
if self._connection:
|
||
await self._connection.close()
|
||
self._connection = None
|
||
|
||
async def send_request(
|
||
self,
|
||
method: str,
|
||
plugin_id: str = "",
|
||
payload: Optional[Dict[str, Any]] = None,
|
||
timeout_ms: int = 30000,
|
||
) -> Envelope:
|
||
"""向 Host 发送 RPC 请求并等待响应"""
|
||
if not self.is_connected:
|
||
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
|
||
|
||
request_id = self._id_gen.next()
|
||
envelope = Envelope(
|
||
request_id=request_id,
|
||
message_type=MessageType.REQUEST,
|
||
method=method,
|
||
plugin_id=plugin_id,
|
||
generation=self._generation,
|
||
timeout_ms=timeout_ms,
|
||
payload=payload or {},
|
||
)
|
||
|
||
loop = asyncio.get_event_loop()
|
||
future: asyncio.Future[Envelope] = loop.create_future()
|
||
self._pending_requests[request_id] = future
|
||
|
||
try:
|
||
data = self._codec.encode_envelope(envelope)
|
||
await self._connection.send_frame(data)
|
||
|
||
timeout_sec = timeout_ms / 1000.0
|
||
return await asyncio.wait_for(future, timeout=timeout_sec)
|
||
except asyncio.TimeoutError:
|
||
self._pending_requests.pop(request_id, None)
|
||
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
|
||
except Exception as e:
|
||
self._pending_requests.pop(request_id, None)
|
||
if isinstance(e, RPCError):
|
||
raise
|
||
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||
|
||
# ─── 内部方法 ──────────────────────────────────────────────
|
||
|
||
async def _recv_loop(self) -> None:
|
||
"""消息接收主循环"""
|
||
while self._running and self._connection and not self._connection.is_closed:
|
||
try:
|
||
data = await self._connection.recv_frame()
|
||
except (asyncio.IncompleteReadError, ConnectionError):
|
||
logger.info("Host 连接已断开")
|
||
break
|
||
except asyncio.CancelledError:
|
||
break
|
||
except Exception as e:
|
||
logger.error(f"接收帧失败: {e}")
|
||
break
|
||
|
||
try:
|
||
envelope = self._codec.decode_envelope(data)
|
||
except Exception as e:
|
||
logger.error(f"解码消息失败: {e}")
|
||
continue
|
||
|
||
if envelope.is_response():
|
||
self._handle_response(envelope)
|
||
elif envelope.is_request():
|
||
asyncio.create_task(self._handle_request(envelope))
|
||
elif envelope.is_event():
|
||
asyncio.create_task(self._handle_event(envelope))
|
||
|
||
def _handle_response(self, envelope: Envelope) -> None:
|
||
"""处理来自 Host 的响应"""
|
||
future = self._pending_requests.pop(envelope.request_id, None)
|
||
if future and not future.done():
|
||
if envelope.error:
|
||
future.set_exception(RPCError.from_dict(envelope.error))
|
||
else:
|
||
future.set_result(envelope)
|
||
|
||
async def _handle_request(self, envelope: Envelope) -> None:
|
||
"""处理来自 Host 的请求(调用插件组件)"""
|
||
handler = self._method_handlers.get(envelope.method)
|
||
if handler is None:
|
||
error_resp = envelope.make_error_response(
|
||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||
f"未注册的方法: {envelope.method}",
|
||
)
|
||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
||
return
|
||
|
||
try:
|
||
response = await handler(envelope)
|
||
await self._connection.send_frame(self._codec.encode_envelope(response))
|
||
except RPCError as e:
|
||
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
|
||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
||
except Exception as e:
|
||
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
|
||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
||
|
||
async def _handle_event(self, envelope: Envelope) -> None:
|
||
"""处理来自 Host 的事件"""
|
||
if handler := self._method_handlers.get(envelope.method):
|
||
try:
|
||
await handler(envelope)
|
||
except Exception as e:
|
||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|