From 7136deac93679888c23b4715b9b2199eb578a131 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 16 Mar 2026 08:16:01 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E5=85=85=20uds=20=E7=9A=84=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugin_runtime/transport/factory.py | 5 +-- src/plugin_runtime/transport/uds.py | 41 ++++++++++++------------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/src/plugin_runtime/transport/factory.py b/src/plugin_runtime/transport/factory.py index a2733b0a..64eaa4f3 100644 --- a/src/plugin_runtime/transport/factory.py +++ b/src/plugin_runtime/transport/factory.py @@ -3,6 +3,7 @@ 根据运行平台自动选择最优传输实现。 """ +from pathlib import Path from typing import Optional import sys @@ -21,7 +22,7 @@ def create_transport_server(socket_path: Optional[str] = None) -> TransportServe if sys.platform != "win32": from .uds import UDSTransportServer - return UDSTransportServer(socket_path=socket_path) + return UDSTransportServer(socket_path=Path(socket_path) if socket_path is not None else None) else: from .named_pipe import NamedPipeTransportServer @@ -46,7 +47,7 @@ def create_transport_client(address: str) -> TransportClient: if "/" in address or address.endswith(".sock"): from .uds import UDSTransportClient - return UDSTransportClient(socket_path=address) + return UDSTransportClient(socket_path=Path(address)) elif ":" in address: from .tcp import TCPTransportClient diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py index da10f0d4..47bf033b 100644 --- a/src/plugin_runtime/transport/uds.py +++ b/src/plugin_runtime/transport/uds.py @@ -27,44 +27,41 @@ _UDS_PATH_MAX = 104 class UDSTransportServer(TransportServer): """UDS 传输服务端""" - def __init__(self, socket_path: Optional[str] = None) -> None: + def __init__(self, socket_path: Optional[Path] = None) -> None: if socket_path is None: # 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞 import uuid - socket_path = os.path.join( - tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock" - ) + socket_path = Path(tempfile.gettempdir()) / f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock" # 如果路径超出 UDS 限制,回退到更短的路径 - if len(socket_path.encode()) > _UDS_PATH_MAX: - socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock") + if len(str(socket_path).encode()) > _UDS_PATH_MAX: + socket_path = Path("/tmp") / f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock" + if len(str(socket_path).encode()) > _UDS_PATH_MAX: + raise OSError(f"UDS socket 路径过长 ({len(str(socket_path).encode())} > {_UDS_PATH_MAX} 字节): {socket_path}") - if len(socket_path.encode()) > _UDS_PATH_MAX: - raise OSError(f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}") - - self._socket_path = socket_path + self._socket_path: Path = socket_path self._server: Optional[asyncio.AbstractServer] = None async def start(self, handler: ConnectionHandler) -> None: # 清理残留 socket 文件 - if os.path.exists(self._socket_path): - os.unlink(self._socket_path) + if self._socket_path.exists(): + self._socket_path.unlink() # 确保父目录存在 - Path(self._socket_path).parent.mkdir(parents=True, exist_ok=True) + self._socket_path.parent.mkdir(parents=True, exist_ok=True) - async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: conn = UDSConnection(reader, writer) try: await handler(conn) finally: await conn.close() - self._server = await asyncio.start_unix_server(_on_connect, path=self._socket_path) + self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path)) # 设置文件权限为仅当前用户可访问 - os.chmod(self._socket_path, 0o600) + self._socket_path.chmod(0o600) async def stop(self) -> None: if self._server: @@ -72,19 +69,19 @@ class UDSTransportServer(TransportServer): await self._server.wait_closed() self._server = None # 清理 socket 文件 - if os.path.exists(self._socket_path): - os.unlink(self._socket_path) + if self._socket_path.exists(): + self._socket_path.unlink() def get_address(self) -> str: - return self._socket_path + return str(self._socket_path) class UDSTransportClient(TransportClient): """UDS 传输客户端""" - def __init__(self, socket_path: str) -> None: - self._socket_path = socket_path + def __init__(self, socket_path: Path) -> None: + self._socket_path: Path = socket_path async def connect(self) -> Connection: - reader, writer = await asyncio.open_unix_connection(self._socket_path) + reader, writer = await asyncio.open_unix_connection(str(self._socket_path)) return UDSConnection(reader, writer)