fix: (AI) 更robust的传输
This commit is contained in:
committed by
DrSmoothl
parent
3419075599
commit
e1b2ecb5b1
@@ -1,6 +1,9 @@
|
||||
"""Windows Named Pipe 传输实现。
|
||||
|
||||
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
|
||||
|
||||
注意:Named Pipe 是 Windows 特有的 IPC 机制,
|
||||
在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
|
||||
@@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin"
|
||||
|
||||
|
||||
class _NamedPipeServerHandle(Protocol):
|
||||
"""Named Pipe 服务端句柄的协议定义。"""
|
||||
def close(self) -> None: ...
|
||||
|
||||
|
||||
class _NamedPipeEventLoop(Protocol):
|
||||
"""ProactorEventLoop 的协议定义,提供 named pipe 相关方法。"""
|
||||
async def start_serving_pipe(
|
||||
self,
|
||||
protocol_factory: Callable[[], asyncio.BaseProtocol],
|
||||
@@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol):
|
||||
|
||||
|
||||
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
|
||||
"""规范化 Named Pipe 地址。
|
||||
|
||||
Args:
|
||||
pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用,
|
||||
否则会自动添加前缀。如果为 None 则生成随机名称。
|
||||
|
||||
Returns:
|
||||
规范化的管道地址(格式:\\\\.\\pipe\\name)
|
||||
"""
|
||||
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
|
||||
return pipe_name
|
||||
|
||||
@@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
|
||||
|
||||
|
||||
class NamedPipeConnection(Connection):
|
||||
"""基于 Windows Named Pipe 的连接。"""
|
||||
"""基于 Windows Named Pipe 的连接。
|
||||
|
||||
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
||||
"""
|
||||
|
||||
pass
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
super().__init__(reader, writer)
|
||||
|
||||
|
||||
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
|
||||
"""Named Pipe 服务端协议实现。
|
||||
|
||||
处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。
|
||||
"""
|
||||
|
||||
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._reader: asyncio.StreamReader = asyncio.StreamReader()
|
||||
super().__init__(self._reader)
|
||||
@@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
|
||||
self._handler_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""连接建立时的回调。"""
|
||||
super().connection_made(transport)
|
||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
|
||||
connection = NamedPipeConnection(self._reader, writer)
|
||||
self._handler_task = self._loop.create_task(self._run_handler(connection))
|
||||
# 使用 asyncio.create_task 确保任务正确调度
|
||||
self._handler_task = asyncio.create_task(self._run_handler(connection))
|
||||
self._handler_task.add_done_callback(self._on_handler_done)
|
||||
|
||||
async def _run_handler(self, connection: NamedPipeConnection) -> None:
|
||||
"""运行连接处理器。"""
|
||||
try:
|
||||
await self._handler(connection)
|
||||
finally:
|
||||
await connection.close()
|
||||
|
||||
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
|
||||
"""连接处理器完成时的回调。"""
|
||||
if task.cancelled():
|
||||
return
|
||||
if exc := task.exception():
|
||||
self._loop.call_exception_handler(
|
||||
{
|
||||
"message": "Named pipe 连接处理失败",
|
||||
"exception": exc,
|
||||
"protocol": self,
|
||||
}
|
||||
)
|
||||
try:
|
||||
self._loop.call_exception_handler(
|
||||
{
|
||||
"message": "Named pipe 连接处理失败",
|
||||
"exception": exc,
|
||||
"protocol": self,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# 如果 loop 已经关闭,忽略异常
|
||||
pass
|
||||
|
||||
|
||||
class NamedPipeTransportServer(TransportServer):
|
||||
"""Windows Named Pipe 传输服务端。"""
|
||||
"""Windows Named Pipe 传输服务端。
|
||||
|
||||
使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。
|
||||
"""
|
||||
|
||||
def __init__(self, pipe_name: Optional[str] = None) -> None:
|
||||
self._address: str = _normalize_pipe_address(pipe_name)
|
||||
self._servers: List[_NamedPipeServerHandle] = []
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
"""启动 Named Pipe 服务端。
|
||||
|
||||
Args:
|
||||
handler: 新连接到来时的回调函数
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当在非 Windows 平台或事件循环不支持时
|
||||
"""
|
||||
if sys.platform != "win32":
|
||||
raise RuntimeError("Named pipe 仅支持 Windows")
|
||||
|
||||
@@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer):
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 Named Pipe 服务端并清理资源。"""
|
||||
for server in self._servers:
|
||||
server.close()
|
||||
# 等待所有服务器句柄完全关闭
|
||||
await asyncio.gather(
|
||||
*[asyncio.sleep(0.1) for _ in self._servers],
|
||||
return_exceptions=True
|
||||
)
|
||||
self._servers.clear()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def get_address(self) -> str:
|
||||
return self._address
|
||||
|
||||
|
||||
class NamedPipeTransportClient(TransportClient):
|
||||
"""Windows Named Pipe 传输客户端。"""
|
||||
"""Windows Named Pipe 传输客户端。
|
||||
|
||||
用于主动连接到 Named Pipe 服务端。
|
||||
"""
|
||||
|
||||
def __init__(self, address: str) -> None:
|
||||
self._address: str = _normalize_pipe_address(address)
|
||||
|
||||
async def connect(self) -> Connection:
|
||||
"""建立到 Named Pipe 服务端的连接。
|
||||
|
||||
Returns:
|
||||
NamedPipeConnection: 连接对象
|
||||
|
||||
Raises:
|
||||
NotImplementedError: 当在非 Windows 平台或事件循环不支持时
|
||||
"""
|
||||
if sys.platform != "win32":
|
||||
raise RuntimeError("Named pipe 仅支持 Windows")
|
||||
raise NotImplementedError("Named pipe 仅支持 Windows")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
if not hasattr(loop, "create_pipe_connection"):
|
||||
raise RuntimeError("当前事件循环不支持 Windows named pipe")
|
||||
raise NotImplementedError("当前事件循环不支持 Windows named pipe")
|
||||
pipe_loop = cast(_NamedPipeEventLoop, loop)
|
||||
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
|
||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
|
||||
# 使用返回的 protocol 创建 StreamWriter
|
||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop)
|
||||
return NamedPipeConnection(reader, writer)
|
||||
@@ -1,6 +1,9 @@
|
||||
"""Unix Domain Socket 传输实现
|
||||
|
||||
适用于 Linux / macOS 平台。
|
||||
|
||||
注意:UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制,
|
||||
在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -8,20 +11,30 @@ from typing import Optional
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
|
||||
|
||||
|
||||
class UDSConnection(Connection):
|
||||
"""基于 UDS 的连接"""
|
||||
"""基于 UDS 的连接
|
||||
|
||||
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
||||
"""
|
||||
|
||||
pass # 直接复用 Connection 基类的分帧读写
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
super().__init__(reader, writer)
|
||||
|
||||
|
||||
# Unix domain socket 路径的系统限制(sun_path 字段长度)
|
||||
# Linux: 108 字节, macOS: 104 字节
|
||||
_UDS_PATH_MAX = 104
|
||||
# Linux: 108 字节,macOS: 104 字节,其他 Unix: 通常 104 字节
|
||||
if sys.platform == "linux":
|
||||
_UDS_PATH_MAX = 108
|
||||
elif sys.platform == "darwin": # macOS
|
||||
_UDS_PATH_MAX = 104
|
||||
else:
|
||||
_UDS_PATH_MAX = 104 # 保守默认值
|
||||
|
||||
|
||||
class UDSTransportServer(TransportServer):
|
||||
@@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer):
|
||||
self._server: Optional[asyncio.AbstractServer] = None
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
"""启动 UDS 服务端
|
||||
|
||||
Args:
|
||||
handler: 新连接到来时的回调函数
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
|
||||
"""
|
||||
# 平台检查:UDS 仅在 Unix-like 系统上可用
|
||||
if sys.platform == "win32":
|
||||
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
|
||||
|
||||
# 清理残留 socket 文件
|
||||
if self._socket_path.exists():
|
||||
self._socket_path.unlink()
|
||||
@@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer):
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
|
||||
try:
|
||||
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
|
||||
|
||||
# 设置文件权限为仅当前用户可访问
|
||||
self._socket_path.chmod(0o600)
|
||||
# 设置文件权限为仅当前用户可访问
|
||||
self._socket_path.chmod(0o600)
|
||||
except Exception:
|
||||
# 启动失败时清理可能创建的目录和 socket 文件
|
||||
if self._socket_path.exists():
|
||||
self._socket_path.unlink()
|
||||
raise
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server:
|
||||
@@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer):
|
||||
|
||||
|
||||
class UDSTransportClient(TransportClient):
|
||||
"""UDS 传输客户端"""
|
||||
"""UDS 传输客户端
|
||||
|
||||
用于主动连接到 UDS 服务端。
|
||||
"""
|
||||
|
||||
def __init__(self, socket_path: Path) -> None:
|
||||
self._socket_path: Path = socket_path
|
||||
|
||||
async def connect(self) -> Connection:
|
||||
"""建立到 UDS 服务端的连接
|
||||
|
||||
Returns:
|
||||
UDSConnection: 连接对象
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
|
||||
"""
|
||||
# 平台检查:UDS 仅在 Unix-like 系统上可用
|
||||
if sys.platform == "win32":
|
||||
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
|
||||
|
||||
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
|
||||
return UDSConnection(reader, writer)
|
||||
|
||||
Reference in New Issue
Block a user