fix: (AI) 更robust的传输
This commit is contained in:
committed by
DrSmoothl
parent
3419075599
commit
e1b2ecb5b1
@@ -1,6 +1,9 @@
|
|||||||
"""Windows Named Pipe 传输实现。
|
"""Windows Named Pipe 传输实现。
|
||||||
|
|
||||||
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
|
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
|
||||||
|
|
||||||
|
注意:Named Pipe 是 Windows 特有的 IPC 机制,
|
||||||
|
在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
|
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
|
||||||
@@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin"
|
|||||||
|
|
||||||
|
|
||||||
class _NamedPipeServerHandle(Protocol):
|
class _NamedPipeServerHandle(Protocol):
|
||||||
|
"""Named Pipe 服务端句柄的协议定义。"""
|
||||||
def close(self) -> None: ...
|
def close(self) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class _NamedPipeEventLoop(Protocol):
|
class _NamedPipeEventLoop(Protocol):
|
||||||
|
"""ProactorEventLoop 的协议定义,提供 named pipe 相关方法。"""
|
||||||
async def start_serving_pipe(
|
async def start_serving_pipe(
|
||||||
self,
|
self,
|
||||||
protocol_factory: Callable[[], asyncio.BaseProtocol],
|
protocol_factory: Callable[[], asyncio.BaseProtocol],
|
||||||
@@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
|
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
|
||||||
|
"""规范化 Named Pipe 地址。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用,
|
||||||
|
否则会自动添加前缀。如果为 None 则生成随机名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
规范化的管道地址(格式:\\\\.\\pipe\\name)
|
||||||
|
"""
|
||||||
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
|
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
|
||||||
return pipe_name
|
return pipe_name
|
||||||
|
|
||||||
@@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
|
|||||||
|
|
||||||
|
|
||||||
class NamedPipeConnection(Connection):
|
class NamedPipeConnection(Connection):
|
||||||
"""基于 Windows Named Pipe 的连接。"""
|
"""基于 Windows Named Pipe 的连接。
|
||||||
|
|
||||||
|
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
||||||
|
"""
|
||||||
|
|
||||||
pass
|
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||||
|
super().__init__(reader, writer)
|
||||||
|
|
||||||
|
|
||||||
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
|
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
|
||||||
|
"""Named Pipe 服务端协议实现。
|
||||||
|
|
||||||
|
处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
|
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
|
||||||
self._reader: asyncio.StreamReader = asyncio.StreamReader()
|
self._reader: asyncio.StreamReader = asyncio.StreamReader()
|
||||||
super().__init__(self._reader)
|
super().__init__(self._reader)
|
||||||
@@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
|
|||||||
self._handler_task: Optional[asyncio.Task[None]] = None
|
self._handler_task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
|
"""连接建立时的回调。"""
|
||||||
super().connection_made(transport)
|
super().connection_made(transport)
|
||||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
|
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
|
||||||
connection = NamedPipeConnection(self._reader, writer)
|
connection = NamedPipeConnection(self._reader, writer)
|
||||||
self._handler_task = self._loop.create_task(self._run_handler(connection))
|
# 使用 asyncio.create_task 确保任务正确调度
|
||||||
|
self._handler_task = asyncio.create_task(self._run_handler(connection))
|
||||||
self._handler_task.add_done_callback(self._on_handler_done)
|
self._handler_task.add_done_callback(self._on_handler_done)
|
||||||
|
|
||||||
async def _run_handler(self, connection: NamedPipeConnection) -> None:
|
async def _run_handler(self, connection: NamedPipeConnection) -> None:
|
||||||
|
"""运行连接处理器。"""
|
||||||
try:
|
try:
|
||||||
await self._handler(connection)
|
await self._handler(connection)
|
||||||
finally:
|
finally:
|
||||||
await connection.close()
|
await connection.close()
|
||||||
|
|
||||||
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
|
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
|
||||||
|
"""连接处理器完成时的回调。"""
|
||||||
if task.cancelled():
|
if task.cancelled():
|
||||||
return
|
return
|
||||||
if exc := task.exception():
|
if exc := task.exception():
|
||||||
self._loop.call_exception_handler(
|
try:
|
||||||
{
|
self._loop.call_exception_handler(
|
||||||
"message": "Named pipe 连接处理失败",
|
{
|
||||||
"exception": exc,
|
"message": "Named pipe 连接处理失败",
|
||||||
"protocol": self,
|
"exception": exc,
|
||||||
}
|
"protocol": self,
|
||||||
)
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# 如果 loop 已经关闭,忽略异常
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NamedPipeTransportServer(TransportServer):
|
class NamedPipeTransportServer(TransportServer):
|
||||||
"""Windows Named Pipe 传输服务端。"""
|
"""Windows Named Pipe 传输服务端。
|
||||||
|
|
||||||
|
使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, pipe_name: Optional[str] = None) -> None:
|
def __init__(self, pipe_name: Optional[str] = None) -> None:
|
||||||
self._address: str = _normalize_pipe_address(pipe_name)
|
self._address: str = _normalize_pipe_address(pipe_name)
|
||||||
self._servers: List[_NamedPipeServerHandle] = []
|
self._servers: List[_NamedPipeServerHandle] = []
|
||||||
|
|
||||||
async def start(self, handler: ConnectionHandler) -> None:
|
async def start(self, handler: ConnectionHandler) -> None:
|
||||||
|
"""启动 Named Pipe 服务端。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handler: 新连接到来时的回调函数
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 当在非 Windows 平台或事件循环不支持时
|
||||||
|
"""
|
||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
raise RuntimeError("Named pipe 仅支持 Windows")
|
raise RuntimeError("Named pipe 仅支持 Windows")
|
||||||
|
|
||||||
@@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
|
"""停止 Named Pipe 服务端并清理资源。"""
|
||||||
for server in self._servers:
|
for server in self._servers:
|
||||||
server.close()
|
server.close()
|
||||||
|
# 等待所有服务器句柄完全关闭
|
||||||
|
await asyncio.gather(
|
||||||
|
*[asyncio.sleep(0.1) for _ in self._servers],
|
||||||
|
return_exceptions=True
|
||||||
|
)
|
||||||
self._servers.clear()
|
self._servers.clear()
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
def get_address(self) -> str:
|
def get_address(self) -> str:
|
||||||
return self._address
|
return self._address
|
||||||
|
|
||||||
|
|
||||||
class NamedPipeTransportClient(TransportClient):
|
class NamedPipeTransportClient(TransportClient):
|
||||||
"""Windows Named Pipe 传输客户端。"""
|
"""Windows Named Pipe 传输客户端。
|
||||||
|
|
||||||
|
用于主动连接到 Named Pipe 服务端。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, address: str) -> None:
|
def __init__(self, address: str) -> None:
|
||||||
self._address: str = _normalize_pipe_address(address)
|
self._address: str = _normalize_pipe_address(address)
|
||||||
|
|
||||||
async def connect(self) -> Connection:
|
async def connect(self) -> Connection:
|
||||||
|
"""建立到 Named Pipe 服务端的连接。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NamedPipeConnection: 连接对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: 当在非 Windows 平台或事件循环不支持时
|
||||||
|
"""
|
||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
raise RuntimeError("Named pipe 仅支持 Windows")
|
raise NotImplementedError("Named pipe 仅支持 Windows")
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
if not hasattr(loop, "create_pipe_connection"):
|
if not hasattr(loop, "create_pipe_connection"):
|
||||||
raise RuntimeError("当前事件循环不支持 Windows named pipe")
|
raise NotImplementedError("当前事件循环不支持 Windows named pipe")
|
||||||
pipe_loop = cast(_NamedPipeEventLoop, loop)
|
pipe_loop = cast(_NamedPipeEventLoop, loop)
|
||||||
|
|
||||||
reader = asyncio.StreamReader()
|
reader = asyncio.StreamReader()
|
||||||
protocol = asyncio.StreamReaderProtocol(reader)
|
protocol = asyncio.StreamReaderProtocol(reader)
|
||||||
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
|
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
|
||||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
|
# 使用返回的 protocol 创建 StreamWriter
|
||||||
|
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop)
|
||||||
return NamedPipeConnection(reader, writer)
|
return NamedPipeConnection(reader, writer)
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
"""Unix Domain Socket 传输实现
|
"""Unix Domain Socket 传输实现
|
||||||
|
|
||||||
适用于 Linux / macOS 平台。
|
适用于 Linux / macOS 平台。
|
||||||
|
|
||||||
|
注意:UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制,
|
||||||
|
在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -8,20 +11,30 @@ from typing import Optional
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
|
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
|
||||||
|
|
||||||
|
|
||||||
class UDSConnection(Connection):
|
class UDSConnection(Connection):
|
||||||
"""基于 UDS 的连接"""
|
"""基于 UDS 的连接
|
||||||
|
|
||||||
|
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
||||||
|
"""
|
||||||
|
|
||||||
pass # 直接复用 Connection 基类的分帧读写
|
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||||
|
super().__init__(reader, writer)
|
||||||
|
|
||||||
|
|
||||||
# Unix domain socket 路径的系统限制(sun_path 字段长度)
|
# Unix domain socket 路径的系统限制(sun_path 字段长度)
|
||||||
# Linux: 108 字节, macOS: 104 字节
|
# Linux: 108 字节,macOS: 104 字节,其他 Unix: 通常 104 字节
|
||||||
_UDS_PATH_MAX = 104
|
if sys.platform == "linux":
|
||||||
|
_UDS_PATH_MAX = 108
|
||||||
|
elif sys.platform == "darwin": # macOS
|
||||||
|
_UDS_PATH_MAX = 104
|
||||||
|
else:
|
||||||
|
_UDS_PATH_MAX = 104 # 保守默认值
|
||||||
|
|
||||||
|
|
||||||
class UDSTransportServer(TransportServer):
|
class UDSTransportServer(TransportServer):
|
||||||
@@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer):
|
|||||||
self._server: Optional[asyncio.AbstractServer] = None
|
self._server: Optional[asyncio.AbstractServer] = None
|
||||||
|
|
||||||
async def start(self, handler: ConnectionHandler) -> None:
|
async def start(self, handler: ConnectionHandler) -> None:
|
||||||
|
"""启动 UDS 服务端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handler: 新连接到来时的回调函数
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
|
||||||
|
"""
|
||||||
|
# 平台检查:UDS 仅在 Unix-like 系统上可用
|
||||||
|
if sys.platform == "win32":
|
||||||
|
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
|
||||||
|
|
||||||
# 清理残留 socket 文件
|
# 清理残留 socket 文件
|
||||||
if self._socket_path.exists():
|
if self._socket_path.exists():
|
||||||
self._socket_path.unlink()
|
self._socket_path.unlink()
|
||||||
@@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer):
|
|||||||
finally:
|
finally:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
|
try:
|
||||||
|
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
|
||||||
|
|
||||||
# 设置文件权限为仅当前用户可访问
|
# 设置文件权限为仅当前用户可访问
|
||||||
self._socket_path.chmod(0o600)
|
self._socket_path.chmod(0o600)
|
||||||
|
except Exception:
|
||||||
|
# 启动失败时清理可能创建的目录和 socket 文件
|
||||||
|
if self._socket_path.exists():
|
||||||
|
self._socket_path.unlink()
|
||||||
|
raise
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
if self._server:
|
if self._server:
|
||||||
@@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer):
|
|||||||
|
|
||||||
|
|
||||||
class UDSTransportClient(TransportClient):
|
class UDSTransportClient(TransportClient):
|
||||||
"""UDS 传输客户端"""
|
"""UDS 传输客户端
|
||||||
|
|
||||||
|
用于主动连接到 UDS 服务端。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, socket_path: Path) -> None:
|
def __init__(self, socket_path: Path) -> None:
|
||||||
self._socket_path: Path = socket_path
|
self._socket_path: Path = socket_path
|
||||||
|
|
||||||
async def connect(self) -> Connection:
|
async def connect(self) -> Connection:
|
||||||
|
"""建立到 UDS 服务端的连接
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UDSConnection: 连接对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 当在非 Unix 平台(如 Windows)上调用时
|
||||||
|
"""
|
||||||
|
# 平台检查:UDS 仅在 Unix-like 系统上可用
|
||||||
|
if sys.platform == "win32":
|
||||||
|
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
|
||||||
|
|
||||||
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
|
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
|
||||||
return UDSConnection(reader, writer)
|
return UDSConnection(reader, writer)
|
||||||
|
|||||||
Reference in New Issue
Block a user