补充 uds 的类型注解

This commit is contained in:
DrSmoothl
2026-03-16 08:16:01 +08:00
parent 7420f84fd0
commit 7136deac93
2 changed files with 22 additions and 24 deletions

View File

@@ -3,6 +3,7 @@
根据运行平台自动选择最优传输实现。 根据运行平台自动选择最优传输实现。
""" """
from pathlib import Path
from typing import Optional from typing import Optional
import sys import sys
@@ -21,7 +22,7 @@ def create_transport_server(socket_path: Optional[str] = None) -> TransportServe
if sys.platform != "win32": if sys.platform != "win32":
from .uds import UDSTransportServer from .uds import UDSTransportServer
return UDSTransportServer(socket_path=socket_path) return UDSTransportServer(socket_path=Path(socket_path) if socket_path is not None else None)
else: else:
from .named_pipe import NamedPipeTransportServer from .named_pipe import NamedPipeTransportServer
@@ -46,7 +47,7 @@ def create_transport_client(address: str) -> TransportClient:
if "/" in address or address.endswith(".sock"): if "/" in address or address.endswith(".sock"):
from .uds import UDSTransportClient from .uds import UDSTransportClient
return UDSTransportClient(socket_path=address) return UDSTransportClient(socket_path=Path(address))
elif ":" in address: elif ":" in address:
from .tcp import TCPTransportClient from .tcp import TCPTransportClient

View File

@@ -27,44 +27,41 @@ _UDS_PATH_MAX = 104
class UDSTransportServer(TransportServer): class UDSTransportServer(TransportServer):
"""UDS 传输服务端""" """UDS 传输服务端"""
def __init__(self, socket_path: Optional[str] = None) -> None: def __init__(self, socket_path: Optional[Path] = None) -> None:
if socket_path is None: if socket_path is None:
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞 # 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
import uuid import uuid
socket_path = os.path.join( socket_path = Path(tempfile.gettempdir()) / f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
)
# 如果路径超出 UDS 限制,回退到更短的路径 # 如果路径超出 UDS 限制,回退到更短的路径
if len(socket_path.encode()) > _UDS_PATH_MAX: if len(str(socket_path).encode()) > _UDS_PATH_MAX:
socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock") socket_path = Path("/tmp") / f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
if len(str(socket_path).encode()) > _UDS_PATH_MAX:
raise OSError(f"UDS socket 路径过长 ({len(str(socket_path).encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
if len(socket_path.encode()) > _UDS_PATH_MAX: self._socket_path: Path = socket_path
raise OSError(f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
self._socket_path = socket_path
self._server: Optional[asyncio.AbstractServer] = None self._server: Optional[asyncio.AbstractServer] = None
async def start(self, handler: ConnectionHandler) -> None: async def start(self, handler: ConnectionHandler) -> None:
# 清理残留 socket 文件 # 清理残留 socket 文件
if os.path.exists(self._socket_path): if self._socket_path.exists():
os.unlink(self._socket_path) self._socket_path.unlink()
# 确保父目录存在 # 确保父目录存在
Path(self._socket_path).parent.mkdir(parents=True, exist_ok=True) self._socket_path.parent.mkdir(parents=True, exist_ok=True)
async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
conn = UDSConnection(reader, writer) conn = UDSConnection(reader, writer)
try: try:
await handler(conn) await handler(conn)
finally: finally:
await conn.close() await conn.close()
self._server = await asyncio.start_unix_server(_on_connect, path=self._socket_path) self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
# 设置文件权限为仅当前用户可访问 # 设置文件权限为仅当前用户可访问
os.chmod(self._socket_path, 0o600) self._socket_path.chmod(0o600)
async def stop(self) -> None: async def stop(self) -> None:
if self._server: if self._server:
@@ -72,19 +69,19 @@ class UDSTransportServer(TransportServer):
await self._server.wait_closed() await self._server.wait_closed()
self._server = None self._server = None
# 清理 socket 文件 # 清理 socket 文件
if os.path.exists(self._socket_path): if self._socket_path.exists():
os.unlink(self._socket_path) self._socket_path.unlink()
def get_address(self) -> str: def get_address(self) -> str:
return self._socket_path return str(self._socket_path)
class UDSTransportClient(TransportClient): class UDSTransportClient(TransportClient):
"""UDS 传输客户端""" """UDS 传输客户端"""
def __init__(self, socket_path: str) -> None: def __init__(self, socket_path: Path) -> None:
self._socket_path = socket_path self._socket_path: Path = socket_path
async def connect(self) -> Connection: async def connect(self) -> Connection:
reader, writer = await asyncio.open_unix_connection(self._socket_path) reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
return UDSConnection(reader, writer) return UDSConnection(reader, writer)