补充 uds 的类型注解
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
根据运行平台自动选择最优传输实现。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import sys
|
||||
@@ -21,7 +22,7 @@ def create_transport_server(socket_path: Optional[str] = None) -> TransportServe
|
||||
if sys.platform != "win32":
|
||||
from .uds import UDSTransportServer
|
||||
|
||||
return UDSTransportServer(socket_path=socket_path)
|
||||
return UDSTransportServer(socket_path=Path(socket_path) if socket_path is not None else None)
|
||||
else:
|
||||
from .named_pipe import NamedPipeTransportServer
|
||||
|
||||
@@ -46,7 +47,7 @@ def create_transport_client(address: str) -> TransportClient:
|
||||
if "/" in address or address.endswith(".sock"):
|
||||
from .uds import UDSTransportClient
|
||||
|
||||
return UDSTransportClient(socket_path=address)
|
||||
return UDSTransportClient(socket_path=Path(address))
|
||||
elif ":" in address:
|
||||
from .tcp import TCPTransportClient
|
||||
|
||||
|
||||
@@ -27,44 +27,41 @@ _UDS_PATH_MAX = 104
|
||||
class UDSTransportServer(TransportServer):
|
||||
"""UDS 传输服务端"""
|
||||
|
||||
def __init__(self, socket_path: Optional[str] = None) -> None:
|
||||
def __init__(self, socket_path: Optional[Path] = None) -> None:
|
||||
if socket_path is None:
|
||||
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
||||
import uuid
|
||||
|
||||
socket_path = os.path.join(
|
||||
tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
|
||||
)
|
||||
socket_path = Path(tempfile.gettempdir()) / f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
|
||||
|
||||
# 如果路径超出 UDS 限制,回退到更短的路径
|
||||
if len(socket_path.encode()) > _UDS_PATH_MAX:
|
||||
socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
|
||||
if len(str(socket_path).encode()) > _UDS_PATH_MAX:
|
||||
socket_path = Path("/tmp") / f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
|
||||
if len(str(socket_path).encode()) > _UDS_PATH_MAX:
|
||||
raise OSError(f"UDS socket 路径过长 ({len(str(socket_path).encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
|
||||
|
||||
if len(socket_path.encode()) > _UDS_PATH_MAX:
|
||||
raise OSError(f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
|
||||
|
||||
self._socket_path = socket_path
|
||||
self._socket_path: Path = socket_path
|
||||
self._server: Optional[asyncio.AbstractServer] = None
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
# 清理残留 socket 文件
|
||||
if os.path.exists(self._socket_path):
|
||||
os.unlink(self._socket_path)
|
||||
if self._socket_path.exists():
|
||||
self._socket_path.unlink()
|
||||
|
||||
# 确保父目录存在
|
||||
Path(self._socket_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._socket_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
conn = UDSConnection(reader, writer)
|
||||
try:
|
||||
await handler(conn)
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
self._server = await asyncio.start_unix_server(_on_connect, path=self._socket_path)
|
||||
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
|
||||
|
||||
# 设置文件权限为仅当前用户可访问
|
||||
os.chmod(self._socket_path, 0o600)
|
||||
self._socket_path.chmod(0o600)
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server:
|
||||
@@ -72,19 +69,19 @@ class UDSTransportServer(TransportServer):
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
# 清理 socket 文件
|
||||
if os.path.exists(self._socket_path):
|
||||
os.unlink(self._socket_path)
|
||||
if self._socket_path.exists():
|
||||
self._socket_path.unlink()
|
||||
|
||||
def get_address(self) -> str:
|
||||
return self._socket_path
|
||||
return str(self._socket_path)
|
||||
|
||||
|
||||
class UDSTransportClient(TransportClient):
|
||||
"""UDS 传输客户端"""
|
||||
|
||||
def __init__(self, socket_path: str) -> None:
|
||||
self._socket_path = socket_path
|
||||
def __init__(self, socket_path: Path) -> None:
|
||||
self._socket_path: Path = socket_path
|
||||
|
||||
async def connect(self) -> Connection:
|
||||
reader, writer = await asyncio.open_unix_connection(self._socket_path)
|
||||
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
|
||||
return UDSConnection(reader, writer)
|
||||
|
||||
Reference in New Issue
Block a user