补充 uds 的类型注解
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
根据运行平台自动选择最优传输实现。
|
根据运行平台自动选择最优传输实现。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
@@ -21,7 +22,7 @@ def create_transport_server(socket_path: Optional[str] = None) -> TransportServe
|
|||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
from .uds import UDSTransportServer
|
from .uds import UDSTransportServer
|
||||||
|
|
||||||
return UDSTransportServer(socket_path=socket_path)
|
return UDSTransportServer(socket_path=Path(socket_path) if socket_path is not None else None)
|
||||||
else:
|
else:
|
||||||
from .named_pipe import NamedPipeTransportServer
|
from .named_pipe import NamedPipeTransportServer
|
||||||
|
|
||||||
@@ -46,7 +47,7 @@ def create_transport_client(address: str) -> TransportClient:
|
|||||||
if "/" in address or address.endswith(".sock"):
|
if "/" in address or address.endswith(".sock"):
|
||||||
from .uds import UDSTransportClient
|
from .uds import UDSTransportClient
|
||||||
|
|
||||||
return UDSTransportClient(socket_path=address)
|
return UDSTransportClient(socket_path=Path(address))
|
||||||
elif ":" in address:
|
elif ":" in address:
|
||||||
from .tcp import TCPTransportClient
|
from .tcp import TCPTransportClient
|
||||||
|
|
||||||
|
|||||||
@@ -27,44 +27,41 @@ _UDS_PATH_MAX = 104
|
|||||||
class UDSTransportServer(TransportServer):
|
class UDSTransportServer(TransportServer):
|
||||||
"""UDS 传输服务端"""
|
"""UDS 传输服务端"""
|
||||||
|
|
||||||
def __init__(self, socket_path: Optional[str] = None) -> None:
|
def __init__(self, socket_path: Optional[Path] = None) -> None:
|
||||||
if socket_path is None:
|
if socket_path is None:
|
||||||
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
socket_path = os.path.join(
|
socket_path = Path(tempfile.gettempdir()) / f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
|
||||||
tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果路径超出 UDS 限制,回退到更短的路径
|
# 如果路径超出 UDS 限制,回退到更短的路径
|
||||||
if len(socket_path.encode()) > _UDS_PATH_MAX:
|
if len(str(socket_path).encode()) > _UDS_PATH_MAX:
|
||||||
socket_path = os.path.join("/tmp", f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
|
socket_path = Path("/tmp") / f"mb-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock"
|
||||||
|
if len(str(socket_path).encode()) > _UDS_PATH_MAX:
|
||||||
|
raise OSError(f"UDS socket 路径过长 ({len(str(socket_path).encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
|
||||||
|
|
||||||
if len(socket_path.encode()) > _UDS_PATH_MAX:
|
self._socket_path: Path = socket_path
|
||||||
raise OSError(f"UDS socket 路径过长 ({len(socket_path.encode())} > {_UDS_PATH_MAX} 字节): {socket_path}")
|
|
||||||
|
|
||||||
self._socket_path = socket_path
|
|
||||||
self._server: Optional[asyncio.AbstractServer] = None
|
self._server: Optional[asyncio.AbstractServer] = None
|
||||||
|
|
||||||
async def start(self, handler: ConnectionHandler) -> None:
|
async def start(self, handler: ConnectionHandler) -> None:
|
||||||
# 清理残留 socket 文件
|
# 清理残留 socket 文件
|
||||||
if os.path.exists(self._socket_path):
|
if self._socket_path.exists():
|
||||||
os.unlink(self._socket_path)
|
self._socket_path.unlink()
|
||||||
|
|
||||||
# 确保父目录存在
|
# 确保父目录存在
|
||||||
Path(self._socket_path).parent.mkdir(parents=True, exist_ok=True)
|
self._socket_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||||
conn = UDSConnection(reader, writer)
|
conn = UDSConnection(reader, writer)
|
||||||
try:
|
try:
|
||||||
await handler(conn)
|
await handler(conn)
|
||||||
finally:
|
finally:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
self._server = await asyncio.start_unix_server(_on_connect, path=self._socket_path)
|
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
|
||||||
|
|
||||||
# 设置文件权限为仅当前用户可访问
|
# 设置文件权限为仅当前用户可访问
|
||||||
os.chmod(self._socket_path, 0o600)
|
self._socket_path.chmod(0o600)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
if self._server:
|
if self._server:
|
||||||
@@ -72,19 +69,19 @@ class UDSTransportServer(TransportServer):
|
|||||||
await self._server.wait_closed()
|
await self._server.wait_closed()
|
||||||
self._server = None
|
self._server = None
|
||||||
# 清理 socket 文件
|
# 清理 socket 文件
|
||||||
if os.path.exists(self._socket_path):
|
if self._socket_path.exists():
|
||||||
os.unlink(self._socket_path)
|
self._socket_path.unlink()
|
||||||
|
|
||||||
def get_address(self) -> str:
|
def get_address(self) -> str:
|
||||||
return self._socket_path
|
return str(self._socket_path)
|
||||||
|
|
||||||
|
|
||||||
class UDSTransportClient(TransportClient):
|
class UDSTransportClient(TransportClient):
|
||||||
"""UDS 传输客户端"""
|
"""UDS 传输客户端"""
|
||||||
|
|
||||||
def __init__(self, socket_path: str) -> None:
|
def __init__(self, socket_path: Path) -> None:
|
||||||
self._socket_path = socket_path
|
self._socket_path: Path = socket_path
|
||||||
|
|
||||||
async def connect(self) -> Connection:
|
async def connect(self) -> Connection:
|
||||||
reader, writer = await asyncio.open_unix_connection(self._socket_path)
|
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
|
||||||
return UDSConnection(reader, writer)
|
return UDSConnection(reader, writer)
|
||||||
|
|||||||
Reference in New Issue
Block a user