补充 uds 的类型注解

This commit is contained in:
DrSmoothl
2026-03-16 08:16:01 +08:00
parent 7420f84fd0
commit 7136deac93
2 changed files with 22 additions and 24 deletions

View File

@@ -3,6 +3,7 @@
根据运行平台自动选择最优传输实现。
"""
from pathlib import Path
from typing import Optional
import sys
@@ -21,7 +22,7 @@ def create_transport_server(socket_path: Optional[str] = None) -> TransportServe
if sys.platform != "win32":
from .uds import UDSTransportServer
return UDSTransportServer(socket_path=socket_path)
return UDSTransportServer(socket_path=Path(socket_path) if socket_path is not None else None)
else:
from .named_pipe import NamedPipeTransportServer
@@ -46,7 +47,7 @@ def create_transport_client(address: str) -> TransportClient:
if "/" in address or address.endswith(".sock"):
from .uds import UDSTransportClient
return UDSTransportClient(socket_path=address)
return UDSTransportClient(socket_path=Path(address))
elif ":" in address:
from .tcp import TCPTransportClient

View File

@@ -27,44 +27,41 @@ _UDS_PATH_MAX = 104
class UDSTransportServer(TransportServer):
"""UDS 传输服务端"""
def __init__(self, socket_path: Optional[str] = None) -> None:
def __init__(self, socket_path: Optional[Path] = None) -> None:
if socket_path is None:
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
import uuid
socket_path = os.path.join(
tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
)
socket_path = Path(tempfile.gettempdir()) / f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
# 如果路径超出 UDS 限制,回退到更短的路径
if len(socket_path.encode()) > _UDS_PATH_MAX:
socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
if len(str(socket_path).encode()) > _UDS_PATH_MAX:
socket_path = Path("/tmp") / f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
if len(str(socket_path).encode()) > _UDS_PATH_MAX:
raise OSError(f"UDS socket 路径过长 ({len(str(socket_path).encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
if len(socket_path.encode()) > _UDS_PATH_MAX:
raise OSError(f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
self._socket_path = socket_path
self._socket_path: Path = socket_path
self._server: Optional[asyncio.AbstractServer] = None
async def start(self, handler: ConnectionHandler) -> None:
# 清理残留 socket 文件
if os.path.exists(self._socket_path):
os.unlink(self._socket_path)
if self._socket_path.exists():
self._socket_path.unlink()
# 确保父目录存在
Path(self._socket_path).parent.mkdir(parents=True, exist_ok=True)
self._socket_path.parent.mkdir(parents=True, exist_ok=True)
async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
conn = UDSConnection(reader, writer)
try:
await handler(conn)
finally:
await conn.close()
self._server = await asyncio.start_unix_server(_on_connect, path=self._socket_path)
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
# 设置文件权限为仅当前用户可访问
os.chmod(self._socket_path, 0o600)
self._socket_path.chmod(0o600)
async def stop(self) -> None:
if self._server:
@@ -72,19 +69,19 @@ class UDSTransportServer(TransportServer):
await self._server.wait_closed()
self._server = None
# 清理 socket 文件
if os.path.exists(self._socket_path):
os.unlink(self._socket_path)
if self._socket_path.exists():
self._socket_path.unlink()
def get_address(self) -> str:
return self._socket_path
return str(self._socket_path)
class UDSTransportClient(TransportClient):
"""UDS 传输客户端"""
def __init__(self, socket_path: str) -> None:
self._socket_path = socket_path
def __init__(self, socket_path: Path) -> None:
self._socket_path: Path = socket_path
async def connect(self) -> Connection:
reader, writer = await asyncio.open_unix_connection(self._socket_path)
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
return UDSConnection(reader, writer)