fix: (AI) 更robust的传输

This commit is contained in:
UnCLAS-Prommer
2026-03-16 19:37:58 +08:00
committed by DrSmoothl
parent 3419075599
commit e1b2ecb5b1
2 changed files with 129 additions and 24 deletions

View File

@@ -1,6 +1,9 @@
"""Windows Named Pipe 传输实现。
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
注意Named Pipe 是 Windows 特有的 IPC 机制,
在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。
"""
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
@@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin"
class _NamedPipeServerHandle(Protocol):
"""Named Pipe 服务端句柄的协议定义。"""
def close(self) -> None: ...
class _NamedPipeEventLoop(Protocol):
"""ProactorEventLoop 的协议定义,提供 named pipe 相关方法。"""
async def start_serving_pipe(
self,
protocol_factory: Callable[[], asyncio.BaseProtocol],
@@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol):
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
"""规范化 Named Pipe 地址。
Args:
pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用,
否则会自动添加前缀。如果为 None 则生成随机名称。
Returns:
规范化的管道地址(格式:\\\\.\\pipe\\name
"""
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
return pipe_name
@@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
class NamedPipeConnection(Connection):
"""基于 Windows Named Pipe 的连接。"""
"""基于 Windows Named Pipe 的连接。
封装了底层 StreamReader/StreamWriter提供分帧读写能力。
"""
pass
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
super().__init__(reader, writer)
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
"""Named Pipe 服务端协议实现。
处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。
"""
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
self._reader: asyncio.StreamReader = asyncio.StreamReader()
super().__init__(self._reader)
@@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
self._handler_task: Optional[asyncio.Task[None]] = None
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""连接建立时的回调。"""
super().connection_made(transport)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
connection = NamedPipeConnection(self._reader, writer)
self._handler_task = self._loop.create_task(self._run_handler(connection))
# 使用 asyncio.create_task 确保任务正确调度
self._handler_task = asyncio.create_task(self._run_handler(connection))
self._handler_task.add_done_callback(self._on_handler_done)
async def _run_handler(self, connection: NamedPipeConnection) -> None:
"""运行连接处理器。"""
try:
await self._handler(connection)
finally:
await connection.close()
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
"""连接处理器完成时的回调。"""
if task.cancelled():
return
if exc := task.exception():
self._loop.call_exception_handler(
{
"message": "Named pipe 连接处理失败",
"exception": exc,
"protocol": self,
}
)
try:
self._loop.call_exception_handler(
{
"message": "Named pipe 连接处理失败",
"exception": exc,
"protocol": self,
}
)
except Exception:
# 如果 loop 已经关闭,忽略异常
pass
class NamedPipeTransportServer(TransportServer):
"""Windows Named Pipe 传输服务端。"""
"""Windows Named Pipe 传输服务端。
使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。
"""
def __init__(self, pipe_name: Optional[str] = None) -> None:
self._address: str = _normalize_pipe_address(pipe_name)
self._servers: List[_NamedPipeServerHandle] = []
async def start(self, handler: ConnectionHandler) -> None:
"""启动 Named Pipe 服务端。
Args:
handler: 新连接到来时的回调函数
Raises:
RuntimeError: 当在非 Windows 平台或事件循环不支持时
"""
if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows")
@@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer):
)
async def stop(self) -> None:
"""停止 Named Pipe 服务端并清理资源。"""
for server in self._servers:
server.close()
# 等待所有服务器句柄完全关闭
await asyncio.gather(
*[asyncio.sleep(0.1) for _ in self._servers],
return_exceptions=True
)
self._servers.clear()
await asyncio.sleep(0)
def get_address(self) -> str:
return self._address
class NamedPipeTransportClient(TransportClient):
"""Windows Named Pipe 传输客户端。"""
"""Windows Named Pipe 传输客户端。
用于主动连接到 Named Pipe 服务端。
"""
def __init__(self, address: str) -> None:
self._address: str = _normalize_pipe_address(address)
async def connect(self) -> Connection:
"""建立到 Named Pipe 服务端的连接。
Returns:
NamedPipeConnection: 连接对象
Raises:
NotImplementedError: 当在非 Windows 平台或事件循环不支持时
"""
if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows")
raise NotImplementedError("Named pipe 仅支持 Windows")
loop = asyncio.get_running_loop()
if not hasattr(loop, "create_pipe_connection"):
raise RuntimeError("当前事件循环不支持 Windows named pipe")
raise NotImplementedError("当前事件循环不支持 Windows named pipe")
pipe_loop = cast(_NamedPipeEventLoop, loop)
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
# 使用返回的 protocol 创建 StreamWriter
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop)
return NamedPipeConnection(reader, writer)

View File

@@ -1,6 +1,9 @@
"""Unix Domain Socket 传输实现
适用于 Linux / macOS 平台。
注意UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制,
在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。
"""
from pathlib import Path
@@ -8,20 +11,30 @@ from typing import Optional
import asyncio
import os
import sys
import tempfile
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
class UDSConnection(Connection):
"""基于 UDS 的连接"""
"""基于 UDS 的连接
封装了底层 StreamReader/StreamWriter提供分帧读写能力。
"""
pass # 直接复用 Connection 基类的分帧读写
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
super().__init__(reader, writer)
# Unix domain socket 路径的系统限制sun_path 字段长度)
# Linux: 108 字节, macOS: 104 字节
_UDS_PATH_MAX = 104
# Linux: 108 字节macOS: 104 字节,其他 Unix: 通常 104 字节
if sys.platform == "linux":
_UDS_PATH_MAX = 108
elif sys.platform == "darwin": # macOS
_UDS_PATH_MAX = 104
else:
_UDS_PATH_MAX = 104 # 保守默认值
class UDSTransportServer(TransportServer):
@@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer):
self._server: Optional[asyncio.AbstractServer] = None
async def start(self, handler: ConnectionHandler) -> None:
"""启动 UDS 服务端
Args:
handler: 新连接到来时的回调函数
Raises:
RuntimeError: 当在非 Unix 平台(如 Windows上调用时
"""
# 平台检查UDS 仅在 Unix-like 系统上可用
if sys.platform == "win32":
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
# 清理残留 socket 文件
if self._socket_path.exists():
self._socket_path.unlink()
@@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer):
finally:
await conn.close()
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
try:
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
# 设置文件权限为仅当前用户可访问
self._socket_path.chmod(0o600)
# 设置文件权限为仅当前用户可访问
self._socket_path.chmod(0o600)
except Exception:
# 启动失败时清理可能创建的目录和 socket 文件
if self._socket_path.exists():
self._socket_path.unlink()
raise
async def stop(self) -> None:
if self._server:
@@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer):
class UDSTransportClient(TransportClient):
"""UDS 传输客户端"""
"""UDS 传输客户端
用于主动连接到 UDS 服务端。
"""
def __init__(self, socket_path: Path) -> None:
self._socket_path: Path = socket_path
async def connect(self) -> Connection:
"""建立到 UDS 服务端的连接
Returns:
UDSConnection: 连接对象
Raises:
RuntimeError: 当在非 Unix 平台(如 Windows上调用时
"""
# 平台检查UDS 仅在 Unix-like 系统上可用
if sys.platform == "win32":
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
return UDSConnection(reader, writer)