Files
mai-bot/src/plugin_runtime/runner/rpc_client.py
DrSmoothl 69219e36f7 Refactor protocol and transport modules to use type hints for improved clarity and consistency
- Updated Codec class to use abstract methods for encoding and decoding envelopes.
- Changed Envelope class to use Dict and Optional for payload and error fields.
- Refined error handling in RPCError class with Optional type hints for details.
- Enhanced manifest validation logic with type hints for better type safety.
- Improved plugin loading mechanism with consistent type annotations.
- Updated RPCClient to utilize Optional for codec and connection attributes.
- Refactored transport classes to use Optional for server attributes and socket paths.
2026-03-11 00:07:13 +08:00

254 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Runner-side RPC client.

Responsibilities:
1. Connect to the Host RPC server.
2. Send the handshake (runner.hello).
3. Send component registration requests.
4. Receive and dispatch call requests coming from the Host.
5. Send capability-call requests to the Host.
"""
import asyncio
import contextlib
import uuid
from typing import Any, Awaitable, Callable, Coroutine, Dict, Optional, Set

from src.common.logger import get_logger
from src.plugin_runtime.protocol.codec import Codec, MsgPackCodec
from src.plugin_runtime.protocol.envelope import (
    Envelope,
    HelloPayload,
    HelloResponsePayload,
    MessageType,
    RequestIdGenerator,
)
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
from src.plugin_runtime.transport.base import Connection
from src.plugin_runtime.transport.factory import create_transport_client
# SDK version string reported to the Host in the runner.hello payload.
SDK_VERSION = "1.0.0"

# Signature of an RPC method handler: receives the incoming envelope and
# returns the envelope to send back.
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]

logger = get_logger("plugin_runtime.runner.rpc_client")
class RPCClient:
    """Runner-side RPC client.

    Manages the IPC connection to the Host and supports RPC in both directions:
    the Runner sends requests to the Host via ``send_request``, and requests or
    events sent by the Host are dispatched to handlers registered with
    ``register_method``.
    """

    def __init__(
        self,
        host_address: str,
        session_token: str,
        codec: Optional[Codec] = None,
    ):
        """
        Args:
            host_address: Address handed to ``create_transport_client`` to reach the Host.
            session_token: Token sent to the Host in the ``runner.hello`` payload.
            codec: Envelope codec; defaults to ``MsgPackCodec``.
        """
        self._host_address = host_address
        self._session_token = session_token
        self._codec = codec or MsgPackCodec()
        self._id_gen = RequestIdGenerator()
        self._connection: Optional[Connection] = None
        self._runner_id = str(uuid.uuid4())
        self._generation: int = 0
        # Method handler registry (handlers for calls initiated by the Host).
        self._method_handlers: Dict[str, MethodHandler] = {}
        # Outgoing requests awaiting a response: request_id -> Future.
        self._pending_requests: Dict[int, asyncio.Future] = {}
        # Strong references to in-flight request/event handler tasks. The event
        # loop only keeps weak references to tasks, so untracked tasks can be
        # garbage-collected before they finish.
        self._background_tasks: Set[asyncio.Task] = set()
        # Run state.
        self._running = False
        self._recv_task: Optional[asyncio.Task] = None

    @property
    def generation(self) -> int:
        """Generation assigned by the Host in the last successful handshake (0 before any)."""
        return self._generation

    @property
    def is_connected(self) -> bool:
        """True when a connection exists and is not closed."""
        return self._connection is not None and not self._connection.is_closed

    def register_method(self, method: str, handler: MethodHandler) -> None:
        """Register the handler for ``method`` (used for both Host requests and Host events).

        Registering the same method again replaces the previous handler.
        """
        self._method_handlers[method] = handler

    async def connect_and_handshake(self, handshake_timeout: float = 10.0) -> bool:
        """Connect to the Host and perform the ``runner.hello`` handshake.

        On success the generation assigned by the Host is stored and the
        background receive loop is started. On any failure the freshly opened
        connection is closed before returning or raising.

        Args:
            handshake_timeout: Seconds to wait for the Host's handshake reply.

        Returns:
            True if the Host accepted the handshake, False if it rejected it or
            answered with an error response.

        Raises:
            asyncio.TimeoutError: No handshake reply arrived within ``handshake_timeout``.
            Any exception raised by the transport, the codec, or payload validation.
        """
        client = create_transport_client(self._host_address)
        self._connection = await client.connect()
        try:
            accepted = await self._handshake(handshake_timeout)
        except BaseException:
            # Don't leak the socket when the handshake times out, fails to decode,
            # fails validation, or is cancelled.
            await self._close_connection()
            raise
        if not accepted:
            await self._close_connection()
            return False

        # Start the message receive loop.
        self._running = True
        self._recv_task = asyncio.create_task(self._recv_loop())
        return True

    async def disconnect(self) -> None:
        """Stop the receive loop, fail all pending requests, and close the connection."""
        self._running = False
        if self._recv_task:
            self._recv_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._recv_task
            self._recv_task = None
        self._fail_pending_requests()
        await self._close_connection()

    async def send_request(
        self,
        method: str,
        plugin_id: str = "",
        payload: Optional[Dict[str, Any]] = None,
        timeout_ms: int = 30000,
    ) -> Envelope:
        """Send an RPC request to the Host and wait for its response.

        Args:
            method: RPC method name.
            plugin_id: Value placed in the request envelope's ``plugin_id`` field.
            payload: Request payload; an empty dict is sent when omitted.
            timeout_ms: Maximum time to wait for the response, in milliseconds.

        Returns:
            The Host's response envelope.

        Raises:
            RPCError: ``E_UNKNOWN`` when not connected or when sending fails;
                ``E_TIMEOUT`` when no response arrives in time or the
                connection is lost first; otherwise the error carried by the
                Host's error response.
        """
        if not self.is_connected:
            raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")

        request_id = self._id_gen.next()
        envelope = Envelope(
            request_id=request_id,
            message_type=MessageType.REQUEST,
            method=method,
            plugin_id=plugin_id,
            generation=self._generation,
            timeout_ms=timeout_ms,
            payload=payload or {},
        )
        future: asyncio.Future[Envelope] = asyncio.get_running_loop().create_future()
        # Register before sending so a fast response cannot arrive unmatched.
        self._pending_requests[request_id] = future
        try:
            await self._connection.send_frame(self._codec.encode_envelope(envelope))
            return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
        except asyncio.TimeoutError:
            raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
        except RPCError:
            raise
        except Exception as e:
            raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
        finally:
            # Always drop the entry, including when the caller is cancelled
            # (CancelledError is not an Exception), so the table cannot leak.
            self._pending_requests.pop(request_id, None)

    # ─── Internal methods ──────────────────────────────────────────────

    async def _handshake(self, timeout: float) -> bool:
        """Send ``runner.hello`` on the current connection and validate the Host's reply.

        Returns True (and stores the assigned generation) if accepted, else False.
        """
        connection = self._connection
        hello = HelloPayload(
            runner_id=self._runner_id,
            sdk_version=SDK_VERSION,
            session_token=self._session_token,
        )
        envelope = Envelope(
            request_id=self._id_gen.next(),
            message_type=MessageType.REQUEST,
            method="runner.hello",
            payload=hello.model_dump(),
        )
        await connection.send_frame(self._codec.encode_envelope(envelope))

        # Receive the handshake response.
        resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=timeout)
        resp = self._codec.decode_envelope(resp_data)
        if resp.error:
            # An error response carries no HelloResponsePayload to validate.
            logger.error(f"握手失败: {resp.error}")
            return False

        resp_payload = HelloResponsePayload.model_validate(resp.payload)
        if not resp_payload.accepted:
            logger.error(f"握手被拒绝: {resp_payload.reason}")
            return False

        self._generation = resp_payload.assigned_generation
        logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}")
        return True

    async def _close_connection(self) -> None:
        """Detach and close the current connection, if there is one."""
        connection, self._connection = self._connection, None
        if connection is not None:
            await connection.close()

    def _fail_pending_requests(self) -> None:
        """Fail every still-pending outgoing request with an ``RPCError`` and clear the table."""
        futures = list(self._pending_requests.values())
        self._pending_requests.clear()
        for future in futures:
            if not future.done():
                # NOTE(review): connection loss is reported as E_TIMEOUT; confirm
                # whether a dedicated error code exists for this case.
                future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭"))

    async def _recv_loop(self) -> None:
        """Read frames from the Host and dispatch them until stopped or disconnected.

        When the loop ends for any reason, pending requests are failed at once
        instead of each waiting out its own timeout.
        """
        try:
            while self._running and self._connection and not self._connection.is_closed:
                try:
                    data = await self._connection.recv_frame()
                except (asyncio.IncompleteReadError, ConnectionError):
                    logger.info("Host 连接已断开")
                    break
                except asyncio.CancelledError:
                    break
                except Exception as e:
                    logger.error(f"接收帧失败: {e}")
                    break

                try:
                    envelope = self._codec.decode_envelope(data)
                except Exception as e:
                    logger.error(f"解码消息失败: {e}")
                    continue

                self._dispatch(envelope)
        finally:
            self._fail_pending_requests()

    def _dispatch(self, envelope: Envelope) -> None:
        """Route one decoded envelope to the response, request, or event handler."""
        if envelope.is_response():
            self._handle_response(envelope)
        elif envelope.is_request():
            # Handlers run as background tasks so a slow handler does not block receiving.
            self._spawn(self._handle_request(envelope))
        elif envelope.is_event():
            self._spawn(self._handle_event(envelope))
        else:
            logger.warning(f"忽略未知类型的消息: {envelope.message_type}")

    def _spawn(self, coro: Coroutine[Any, Any, None]) -> None:
        """Run ``coro`` as a task, holding a strong reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def _handle_response(self, envelope: Envelope) -> None:
        """Resolve the pending request matching a Host response.

        Responses for unknown or already-finished requests (e.g. ones that
        already timed out) are ignored.
        """
        future = self._pending_requests.pop(envelope.request_id, None)
        if future and not future.done():
            if envelope.error:
                future.set_exception(RPCError.from_dict(envelope.error))
            else:
                future.set_result(envelope)

    async def _handle_request(self, envelope: Envelope) -> None:
        """Handle a Host request (a plugin component call) and send back the response."""
        response = await self._invoke_handler(envelope)
        try:
            data = self._codec.encode_envelope(response)
        except Exception as e:
            logger.error(f"编码响应 {envelope.method} 失败: {e}", exc_info=True)
            data = self._codec.encode_envelope(
                envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
            )
        await self._send_frame(data, envelope.method)

    async def _invoke_handler(self, envelope: Envelope) -> Envelope:
        """Run the registered handler for a request, converting failures into error responses."""
        handler = self._method_handlers.get(envelope.method)
        if handler is None:
            return envelope.make_error_response(
                ErrorCode.E_METHOD_NOT_ALLOWED.value,
                f"未注册的方法: {envelope.method}",
            )
        try:
            return await handler(envelope)
        except RPCError as e:
            return envelope.make_error_response(e.code.value, e.message, e.details)
        except Exception as e:
            logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
            return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))

    async def _send_frame(self, data, method: str) -> None:
        """Send an encoded response frame, logging instead of raising on transport failure.

        Handlers may finish after ``disconnect()``, so the connection can be gone.
        """
        connection = self._connection
        if connection is None or connection.is_closed:
            logger.warning(f"连接已关闭, 无法发送 {method} 的响应")
            return
        try:
            await connection.send_frame(data)
        except Exception as e:
            logger.error(f"发送 {method} 的响应失败: {e}")

    async def _handle_event(self, envelope: Envelope) -> None:
        """Invoke the registered handler for a Host event, if any.

        No reply is sent; the handler's return value is ignored and its
        exceptions are logged.
        """
        if handler := self._method_handlers.get(envelope.method):
            try:
                await handler(envelope)
            except Exception as e:
                logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)