fix: (AI) 更robust的传输

This commit is contained in:
UnCLAS-Prommer
2026-03-16 19:37:58 +08:00
committed by DrSmoothl
parent 3419075599
commit e1b2ecb5b1
2 changed files with 129 additions and 24 deletions

View File

@@ -1,6 +1,9 @@
"""Windows Named Pipe 传输实现。 """Windows Named Pipe 传输实现。
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。 适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
注意Named Pipe 是 Windows 特有的 IPC 机制,
在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。
""" """
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
@@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin"
class _NamedPipeServerHandle(Protocol): class _NamedPipeServerHandle(Protocol):
"""Named Pipe 服务端句柄的协议定义。"""
def close(self) -> None: ... def close(self) -> None: ...
class _NamedPipeEventLoop(Protocol): class _NamedPipeEventLoop(Protocol):
"""ProactorEventLoop 的协议定义,提供 named pipe 相关方法。"""
async def start_serving_pipe( async def start_serving_pipe(
self, self,
protocol_factory: Callable[[], asyncio.BaseProtocol], protocol_factory: Callable[[], asyncio.BaseProtocol],
@@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol):
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str: def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
"""规范化 Named Pipe 地址。
Args:
pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用,
否则会自动添加前缀。如果为 None 则生成随机名称。
Returns:
规范化的管道地址(格式:\\\\.\\pipe\\name
"""
if pipe_name and pipe_name.startswith(_PIPE_PREFIX): if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
return pipe_name return pipe_name
@@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
class NamedPipeConnection(Connection): class NamedPipeConnection(Connection):
"""基于 Windows Named Pipe 的连接。""" """基于 Windows Named Pipe 的连接。
pass 封装了底层 StreamReader/StreamWriter提供分帧读写能力。
"""
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
super().__init__(reader, writer)
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol): class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
"""Named Pipe 服务端协议实现。
处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。
"""
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None: def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
self._reader: asyncio.StreamReader = asyncio.StreamReader() self._reader: asyncio.StreamReader = asyncio.StreamReader()
super().__init__(self._reader) super().__init__(self._reader)
@@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
self._handler_task: Optional[asyncio.Task[None]] = None self._handler_task: Optional[asyncio.Task[None]] = None
def connection_made(self, transport: asyncio.BaseTransport) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""连接建立时的回调。"""
super().connection_made(transport) super().connection_made(transport)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop) writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
connection = NamedPipeConnection(self._reader, writer) connection = NamedPipeConnection(self._reader, writer)
self._handler_task = self._loop.create_task(self._run_handler(connection)) # 使用 asyncio.create_task 确保任务正确调度
self._handler_task = asyncio.create_task(self._run_handler(connection))
self._handler_task.add_done_callback(self._on_handler_done) self._handler_task.add_done_callback(self._on_handler_done)
async def _run_handler(self, connection: NamedPipeConnection) -> None: async def _run_handler(self, connection: NamedPipeConnection) -> None:
"""运行连接处理器。"""
try: try:
await self._handler(connection) await self._handler(connection)
finally: finally:
await connection.close() await connection.close()
def _on_handler_done(self, task: asyncio.Task[None]) -> None: def _on_handler_done(self, task: asyncio.Task[None]) -> None:
"""连接处理器完成时的回调。"""
if task.cancelled(): if task.cancelled():
return return
if exc := task.exception(): if exc := task.exception():
self._loop.call_exception_handler( try:
{ self._loop.call_exception_handler(
"message": "Named pipe 连接处理失败", {
"exception": exc, "message": "Named pipe 连接处理失败",
"protocol": self, "exception": exc,
} "protocol": self,
) }
)
except Exception:
# 如果 loop 已经关闭,忽略异常
pass
class NamedPipeTransportServer(TransportServer): class NamedPipeTransportServer(TransportServer):
"""Windows Named Pipe 传输服务端。""" """Windows Named Pipe 传输服务端。
使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。
"""
def __init__(self, pipe_name: Optional[str] = None) -> None: def __init__(self, pipe_name: Optional[str] = None) -> None:
self._address: str = _normalize_pipe_address(pipe_name) self._address: str = _normalize_pipe_address(pipe_name)
self._servers: List[_NamedPipeServerHandle] = [] self._servers: List[_NamedPipeServerHandle] = []
async def start(self, handler: ConnectionHandler) -> None: async def start(self, handler: ConnectionHandler) -> None:
"""启动 Named Pipe 服务端。
Args:
handler: 新连接到来时的回调函数
Raises:
RuntimeError: 当在非 Windows 平台或事件循环不支持时
"""
if sys.platform != "win32": if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows") raise RuntimeError("Named pipe 仅支持 Windows")
@@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer):
) )
async def stop(self) -> None: async def stop(self) -> None:
"""停止 Named Pipe 服务端并清理资源。"""
for server in self._servers: for server in self._servers:
server.close() server.close()
# 等待所有服务器句柄完全关闭
await asyncio.gather(
*[asyncio.sleep(0.1) for _ in self._servers],
return_exceptions=True
)
self._servers.clear() self._servers.clear()
await asyncio.sleep(0)
def get_address(self) -> str: def get_address(self) -> str:
return self._address return self._address
class NamedPipeTransportClient(TransportClient): class NamedPipeTransportClient(TransportClient):
"""Windows Named Pipe 传输客户端。""" """Windows Named Pipe 传输客户端。
用于主动连接到 Named Pipe 服务端。
"""
def __init__(self, address: str) -> None: def __init__(self, address: str) -> None:
self._address: str = _normalize_pipe_address(address) self._address: str = _normalize_pipe_address(address)
async def connect(self) -> Connection: async def connect(self) -> Connection:
"""建立到 Named Pipe 服务端的连接。
Returns:
NamedPipeConnection: 连接对象
Raises:
NotImplementedError: 当在非 Windows 平台或事件循环不支持时
"""
if sys.platform != "win32": if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows") raise NotImplementedError("Named pipe 仅支持 Windows")
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
if not hasattr(loop, "create_pipe_connection"): if not hasattr(loop, "create_pipe_connection"):
raise RuntimeError("当前事件循环不支持 Windows named pipe") raise NotImplementedError("当前事件循环不支持 Windows named pipe")
pipe_loop = cast(_NamedPipeEventLoop, loop) pipe_loop = cast(_NamedPipeEventLoop, loop)
reader = asyncio.StreamReader() reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader) protocol = asyncio.StreamReaderProtocol(reader)
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address) transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop) # 使用返回的 protocol 创建 StreamWriter
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop)
return NamedPipeConnection(reader, writer) return NamedPipeConnection(reader, writer)

View File

@@ -1,6 +1,9 @@
"""Unix Domain Socket 传输实现 """Unix Domain Socket 传输实现
适用于 Linux / macOS 平台。 适用于 Linux / macOS 平台。
注意UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制,
在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。
""" """
from pathlib import Path from pathlib import Path
@@ -8,20 +11,30 @@ from typing import Optional
import asyncio import asyncio
import os import os
import sys
import tempfile import tempfile
from .base import Connection, ConnectionHandler, TransportClient, TransportServer from .base import Connection, ConnectionHandler, TransportClient, TransportServer
class UDSConnection(Connection): class UDSConnection(Connection):
"""基于 UDS 的连接""" """基于 UDS 的连接
pass # 直接复用 Connection 基类的分帧读写 封装了底层 StreamReader/StreamWriter提供分帧读写能力。
"""
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
super().__init__(reader, writer)
# Unix domain socket 路径的系统限制sun_path 字段长度) # Unix domain socket 路径的系统限制sun_path 字段长度)
# Linux: 108 字节, macOS: 104 字节 # Linux: 108 字节macOS: 104 字节,其他 Unix: 通常 104 字节
_UDS_PATH_MAX = 104 if sys.platform == "linux":
_UDS_PATH_MAX = 108
elif sys.platform == "darwin": # macOS
_UDS_PATH_MAX = 104
else:
_UDS_PATH_MAX = 104 # 保守默认值
class UDSTransportServer(TransportServer): class UDSTransportServer(TransportServer):
@@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer):
self._server: Optional[asyncio.AbstractServer] = None self._server: Optional[asyncio.AbstractServer] = None
async def start(self, handler: ConnectionHandler) -> None: async def start(self, handler: ConnectionHandler) -> None:
"""启动 UDS 服务端
Args:
handler: 新连接到来时的回调函数
Raises:
RuntimeError: 当在非 Unix 平台(如 Windows上调用时
"""
# 平台检查UDS 仅在 Unix-like 系统上可用
if sys.platform == "win32":
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
# 清理残留 socket 文件 # 清理残留 socket 文件
if self._socket_path.exists(): if self._socket_path.exists():
self._socket_path.unlink() self._socket_path.unlink()
@@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer):
finally: finally:
await conn.close() await conn.close()
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path)) try:
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
# 设置文件权限为仅当前用户可访问 # 设置文件权限为仅当前用户可访问
self._socket_path.chmod(0o600) self._socket_path.chmod(0o600)
except Exception:
# 启动失败时清理可能创建的目录和 socket 文件
if self._socket_path.exists():
self._socket_path.unlink()
raise
async def stop(self) -> None: async def stop(self) -> None:
if self._server: if self._server:
@@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer):
class UDSTransportClient(TransportClient): class UDSTransportClient(TransportClient):
"""UDS 传输客户端""" """UDS 传输客户端
用于主动连接到 UDS 服务端。
"""
def __init__(self, socket_path: Path) -> None: def __init__(self, socket_path: Path) -> None:
self._socket_path: Path = socket_path self._socket_path: Path = socket_path
async def connect(self) -> Connection: async def connect(self) -> Connection:
"""建立到 UDS 服务端的连接
Returns:
UDSConnection: 连接对象
Raises:
RuntimeError: 当在非 Unix 平台(如 Windows上调用时
"""
# 平台检查UDS 仅在 Unix-like 系统上可用
if sys.platform == "win32":
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
reader, writer = await asyncio.open_unix_connection(str(self._socket_path)) reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
return UDSConnection(reader, writer) return UDSConnection(reader, writer)