diff --git a/src/plugin_runtime/transport/named_pipe.py b/src/plugin_runtime/transport/named_pipe.py index a759507d..7fd39bc9 100644 --- a/src/plugin_runtime/transport/named_pipe.py +++ b/src/plugin_runtime/transport/named_pipe.py @@ -1,6 +1,9 @@ """Windows Named Pipe 传输实现。 适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。 + +注意:Named Pipe 是 Windows 特有的 IPC 机制, +在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。 """ from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast @@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin" class _NamedPipeServerHandle(Protocol): + """Named Pipe 服务端句柄的协议定义。""" def close(self) -> None: ... class _NamedPipeEventLoop(Protocol): + """ProactorEventLoop 的协议定义,提供 named pipe 相关方法。""" async def start_serving_pipe( self, protocol_factory: Callable[[], asyncio.BaseProtocol], @@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol): def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str: + """规范化 Named Pipe 地址。 + + Args: + pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用, + 否则会自动添加前缀。如果为 None 则生成随机名称。 + + Returns: + 规范化的管道地址(格式:\\\\.\\pipe\\name) + """ if pipe_name and pipe_name.startswith(_PIPE_PREFIX): return pipe_name @@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str: class NamedPipeConnection(Connection): - """基于 Windows Named Pipe 的连接。""" + """基于 Windows Named Pipe 的连接。 + + 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。 + """ - pass + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + super().__init__(reader, writer) class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol): + """Named Pipe 服务端协议实现。 + + 处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。 + """ + def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None: self._reader: asyncio.StreamReader = asyncio.StreamReader() super().__init__(self._reader) @@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol): self._handler_task: Optional[asyncio.Task[None]] = None def connection_made(self, transport: asyncio.BaseTransport) -> None: + """连接建立时的回调。""" super().connection_made(transport) writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop) connection = NamedPipeConnection(self._reader, writer) - self._handler_task = self._loop.create_task(self._run_handler(connection)) + # 使用 asyncio.create_task 确保任务正确调度 + self._handler_task = asyncio.create_task(self._run_handler(connection)) self._handler_task.add_done_callback(self._on_handler_done) async def _run_handler(self, connection: NamedPipeConnection) -> None: + """运行连接处理器。""" try: await self._handler(connection) finally: await connection.close() def _on_handler_done(self, task: asyncio.Task[None]) -> None: + """连接处理器完成时的回调。""" if task.cancelled(): return if exc := task.exception(): - self._loop.call_exception_handler( - { - "message": "Named pipe 连接处理失败", - "exception": exc, - "protocol": self, - } - ) + try: + self._loop.call_exception_handler( + { + "message": "Named pipe 连接处理失败", + "exception": exc, + "protocol": self, + } + ) + except Exception: + # 如果 loop 已经关闭,忽略异常 + pass class NamedPipeTransportServer(TransportServer): - """Windows Named Pipe 传输服务端。""" + """Windows Named Pipe 传输服务端。 + + 使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。 + """ def __init__(self, pipe_name: Optional[str] = None) -> None: self._address: str = _normalize_pipe_address(pipe_name) self._servers: List[_NamedPipeServerHandle] = [] async def start(self, handler: ConnectionHandler) -> None: + """启动 Named Pipe 服务端。 + + Args: + handler: 新连接到来时的回调函数 + + Raises: + RuntimeError: 当在非 Windows 平台或事件循环不支持时 + """ if sys.platform != "win32": raise RuntimeError("Named pipe 仅支持 Windows") @@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer): ) async def stop(self) -> None: + """停止 Named Pipe 服务端并清理资源。""" for server in self._servers: server.close() + # 等待所有服务器句柄完全关闭 + await asyncio.gather( + *[asyncio.sleep(0.1) for _ in self._servers], + return_exceptions=True + ) self._servers.clear() - await asyncio.sleep(0) def get_address(self) -> str: return self._address class NamedPipeTransportClient(TransportClient): - """Windows Named Pipe 传输客户端。""" + """Windows Named Pipe 传输客户端。 + + 用于主动连接到 Named Pipe 服务端。 + """ def __init__(self, address: str) -> None: self._address: str = _normalize_pipe_address(address) async def connect(self) -> Connection: + """建立到 Named Pipe 服务端的连接。 + + Returns: + NamedPipeConnection: 连接对象 + + Raises: + NotImplementedError: 当在非 Windows 平台或事件循环不支持时 + """ if sys.platform != "win32": - raise RuntimeError("Named pipe 仅支持 Windows") + raise NotImplementedError("Named pipe 仅支持 Windows") loop = asyncio.get_running_loop() if not hasattr(loop, "create_pipe_connection"): - raise RuntimeError("当前事件循环不支持 Windows named pipe") + raise NotImplementedError("当前事件循环不支持 Windows named pipe") pipe_loop = cast(_NamedPipeEventLoop, loop) reader = asyncio.StreamReader() protocol = asyncio.StreamReaderProtocol(reader) transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address) - writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop) + # 使用返回的 protocol 创建 StreamWriter + writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop) return NamedPipeConnection(reader, writer) \ No newline at end of file diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py index 47bf033b..af71ea5d 100644 --- a/src/plugin_runtime/transport/uds.py +++ b/src/plugin_runtime/transport/uds.py @@ -1,6 +1,9 @@ """Unix Domain Socket 传输实现 适用于 Linux / macOS 平台。 + +注意:UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制, +在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。 """ from pathlib import Path @@ -8,20 +11,30 @@ from typing import Optional import asyncio import os +import sys import tempfile from .base import Connection, ConnectionHandler, TransportClient, TransportServer class UDSConnection(Connection): - """基于 UDS 的连接""" + """基于 UDS 的连接 + + 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。 + """ - pass # 直接复用 Connection 基类的分帧读写 + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + super().__init__(reader, writer) # Unix domain socket 路径的系统限制(sun_path 字段长度) -# Linux: 108 字节, macOS: 104 字节 -_UDS_PATH_MAX = 104 +# Linux: 108 字节,macOS: 104 字节,其他 Unix: 通常 104 字节 +if sys.platform == "linux": + _UDS_PATH_MAX = 108 +elif sys.platform == "darwin": # macOS + _UDS_PATH_MAX = 104 +else: + _UDS_PATH_MAX = 104 # 保守默认值 class UDSTransportServer(TransportServer): @@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer): self._server: Optional[asyncio.AbstractServer] = None async def start(self, handler: ConnectionHandler) -> None: + """启动 UDS 服务端 + + Args: + handler: 新连接到来时的回调函数 + + Raises: + RuntimeError: 当在非 Unix 平台(如 Windows)上调用时 + """ + # 平台检查:UDS 仅在 Unix-like 系统上可用 + if sys.platform == "win32": + raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe") + # 清理残留 socket 文件 if self._socket_path.exists(): self._socket_path.unlink() @@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer): finally: await conn.close() - self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path)) + try: + self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path)) - # 设置文件权限为仅当前用户可访问 - self._socket_path.chmod(0o600) + # 设置文件权限为仅当前用户可访问 + self._socket_path.chmod(0o600) + except Exception: + # 启动失败时清理可能创建的目录和 socket 文件 + if self._socket_path.exists(): + self._socket_path.unlink() + raise async def stop(self) -> None: if self._server: @@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer): class UDSTransportClient(TransportClient): - """UDS 传输客户端""" + """UDS 传输客户端 + + 用于主动连接到 UDS 服务端。 + """ def __init__(self, socket_path: Path) -> None: self._socket_path: Path = socket_path async def connect(self) -> Connection: + """建立到 UDS 服务端的连接 + + Returns: + UDSConnection: 连接对象 + + Raises: + RuntimeError: 当在非 Unix 平台(如 Windows)上调用时 + """ + # 平台检查:UDS 仅在 Unix-like 系统上可用 + if sys.platform == "win32": + raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe") + reader, writer = await asyncio.open_unix_connection(str(self._socket_path)) return UDSConnection(reader, writer)