feat: 添加 Windows Named Pipe 传输实现,支持异步连接和数据传输,修复 Windows 平台插件系统导入隔离误把 DLLs 加进去的 bug
This commit is contained in:
@@ -181,6 +181,36 @@ class TestTransport:
|
||||
await conn.close()
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(sys.platform != "win32", reason="Windows only")
|
||||
async def test_named_pipe_connection_framing(self):
|
||||
"""Windows Named Pipe 分帧协议测试"""
|
||||
from src.plugin_runtime.transport.named_pipe import NamedPipeTransportClient, NamedPipeTransportServer
|
||||
|
||||
server = NamedPipeTransportServer()
|
||||
received = asyncio.Event()
|
||||
received_data = []
|
||||
|
||||
async def handler(conn):
|
||||
data = await conn.recv_frame()
|
||||
received_data.append(data)
|
||||
await conn.send_frame(b"pipe_pong")
|
||||
received.set()
|
||||
|
||||
await server.start(handler)
|
||||
client = NamedPipeTransportClient(server.get_address())
|
||||
conn = await client.connect()
|
||||
await conn.send_frame(b"pipe_ping")
|
||||
|
||||
await asyncio.wait_for(received.wait(), timeout=5.0)
|
||||
assert received_data[0] == b"pipe_ping"
|
||||
|
||||
resp = await conn.recv_frame()
|
||||
assert resp == b"pipe_pong"
|
||||
|
||||
await conn.close()
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_factory(self):
|
||||
"""传输工厂测试"""
|
||||
@@ -193,6 +223,10 @@ class TestTransport:
|
||||
client = create_transport_client("/tmp/test.sock")
|
||||
assert client is not None
|
||||
|
||||
# Windows Named Pipe 地址
|
||||
client = create_transport_client(r"\\.\pipe\maibot-test")
|
||||
assert client is not None
|
||||
|
||||
# TCP 地址
|
||||
client = create_transport_client("127.0.0.1:9999")
|
||||
assert client is not None
|
||||
@@ -585,13 +619,10 @@ class TestE2E:
|
||||
"""Host-Runner 握手流程测试"""
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
|
||||
from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient
|
||||
from src.plugin_runtime.transport.factory import create_transport_client, create_transport_server
|
||||
|
||||
import secrets
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock")
|
||||
session_token = secrets.token_hex(16)
|
||||
codec = MsgPackCodec()
|
||||
handshake_done = asyncio.Event()
|
||||
@@ -621,11 +652,11 @@ class TestE2E:
|
||||
# 保持连接一会儿
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
server = UDSTransportServer(socket_path=socket_path)
|
||||
server = create_transport_server()
|
||||
await server.start(server_handler)
|
||||
|
||||
# 客户端握手
|
||||
client = UDSTransportClient(socket_path)
|
||||
client = create_transport_client(server.get_address())
|
||||
conn = await client.connect()
|
||||
|
||||
hello = HelloPayload(
|
||||
|
||||
@@ -586,11 +586,18 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
||||
if path := sysconfig.get_path(key):
|
||||
stdlib_paths.add(os.path.normpath(path))
|
||||
|
||||
runtime_paths = set(stdlib_paths)
|
||||
if os.name == "nt":
|
||||
# Windows 的部分平台扩展模块和依赖会通过 <prefix>/DLLs 暴露在 sys.path 中。
|
||||
for prefix in {sys.prefix, sys.exec_prefix, sys.base_prefix, sys.base_exec_prefix}:
|
||||
if prefix:
|
||||
runtime_paths.add(os.path.normpath(os.path.join(prefix, "DLLs")))
|
||||
|
||||
allowed = set()
|
||||
for p in sys.path:
|
||||
norm = os.path.normpath(p)
|
||||
# 保留标准库和 site-packages
|
||||
if any(norm.startswith(sp) for sp in stdlib_paths):
|
||||
if any(norm.startswith(runtime_path) for runtime_path in runtime_paths):
|
||||
allowed.add(p)
|
||||
# 保留 site-packages(第三方库 + SDK)
|
||||
if "site-packages" in norm or "dist-packages" in norm:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""传输层抽象基类
|
||||
|
||||
定义 TransportServer 和 TransportClient 的统一接口。
|
||||
所有传输后端(UDS、Named Pipe、TCP 回退)必须实现此接口。
|
||||
所有传输后端(UDS、Named Pipe、显式 TCP)必须实现此接口。
|
||||
业务层仅依赖此抽象,禁止直接使用具体传输实现的细节。
|
||||
|
||||
分帧协议:4-byte big-endian length prefix + payload
|
||||
|
||||
@@ -13,32 +13,36 @@ from .base import TransportClient, TransportServer
|
||||
def create_transport_server(socket_path: Optional[str] = None) -> TransportServer:
|
||||
"""创建传输服务端
|
||||
|
||||
Linux/macOS 使用 UDS,Windows 使用 TCP 回退。
|
||||
Linux/macOS 使用 UDS,Windows 使用 Named Pipe。
|
||||
|
||||
Args:
|
||||
socket_path: UDS socket 路径(仅 Linux/macOS 有效)
|
||||
socket_path: UDS socket 路径或 Windows pipe 名称
|
||||
"""
|
||||
if sys.platform != "win32":
|
||||
from .uds import UDSTransportServer
|
||||
|
||||
return UDSTransportServer(socket_path=socket_path)
|
||||
else:
|
||||
# Windows 回退到 TCP(后续可改为 Named Pipe)
|
||||
from .tcp import TCPTransportServer
|
||||
from .named_pipe import NamedPipeTransportServer
|
||||
|
||||
return TCPTransportServer()
|
||||
return NamedPipeTransportServer(pipe_name=socket_path)
|
||||
|
||||
|
||||
def create_transport_client(address: str) -> TransportClient:
|
||||
"""创建传输客户端
|
||||
|
||||
根据地址格式自动判断传输类型:
|
||||
- 以 '\\\\.\\pipe\\' 开头 -> Windows Named Pipe
|
||||
- 包含 '/' 或 '.sock' -> UDS
|
||||
- 包含 ':' -> TCP
|
||||
|
||||
Args:
|
||||
address: Host 端监听地址
|
||||
"""
|
||||
if address.startswith("\\\\.\\pipe\\"):
|
||||
from .named_pipe import NamedPipeTransportClient
|
||||
|
||||
return NamedPipeTransportClient(address)
|
||||
if "/" in address or address.endswith(".sock"):
|
||||
from .uds import UDSTransportClient
|
||||
|
||||
|
||||
135
src/plugin_runtime/transport/named_pipe.py
Normal file
135
src/plugin_runtime/transport/named_pipe.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Windows Named Pipe 传输实现。
|
||||
|
||||
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, Protocol, cast
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
|
||||
|
||||
_PIPE_PREFIX = "\\\\.\\pipe\\"
|
||||
_DEFAULT_PIPE_PREFIX = "maibot-plugin"
|
||||
|
||||
|
||||
class _NamedPipeEventLoop(Protocol):
|
||||
async def start_serving_pipe(self, protocol_factory: Any, address: str) -> list[Any]: ...
|
||||
|
||||
async def create_pipe_connection(self, protocol_factory: Any, address: str) -> tuple[Any, Any]: ...
|
||||
|
||||
def call_exception_handler(self, context: dict[str, Any]) -> None: ...
|
||||
|
||||
def create_task(self, coro: Any) -> asyncio.Task[None]: ...
|
||||
|
||||
|
||||
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
|
||||
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
|
||||
return pipe_name
|
||||
|
||||
if pipe_name:
|
||||
sanitized_name = re.sub(r"[^0-9A-Za-z._-]+", "-", pipe_name).strip("-.")
|
||||
else:
|
||||
sanitized_name = f"{_DEFAULT_PIPE_PREFIX}-{os.getpid()}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
if not sanitized_name:
|
||||
sanitized_name = f"{_DEFAULT_PIPE_PREFIX}-{os.getpid()}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
return f"{_PIPE_PREFIX}{sanitized_name}"
|
||||
|
||||
|
||||
class NamedPipeConnection(Connection):
|
||||
"""基于 Windows Named Pipe 的连接。"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
|
||||
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._reader = asyncio.StreamReader()
|
||||
super().__init__(self._reader)
|
||||
self._handler = handler
|
||||
self._loop = loop
|
||||
self._handler_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
super().connection_made(transport)
|
||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
|
||||
connection = NamedPipeConnection(self._reader, writer)
|
||||
self._handler_task = self._loop.create_task(self._run_handler(connection))
|
||||
self._handler_task.add_done_callback(self._on_handler_done)
|
||||
|
||||
async def _run_handler(self, connection: NamedPipeConnection) -> None:
|
||||
try:
|
||||
await self._handler(connection)
|
||||
finally:
|
||||
await connection.close()
|
||||
|
||||
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
|
||||
if task.cancelled():
|
||||
return
|
||||
if exc := task.exception():
|
||||
self._loop.call_exception_handler(
|
||||
{
|
||||
"message": "Named pipe 连接处理失败",
|
||||
"exception": exc,
|
||||
"protocol": self,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class NamedPipeTransportServer(TransportServer):
|
||||
"""Windows Named Pipe 传输服务端。"""
|
||||
|
||||
def __init__(self, pipe_name: Optional[str] = None) -> None:
|
||||
self._address = _normalize_pipe_address(pipe_name)
|
||||
self._servers: list[Any] = []
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
if sys.platform != "win32":
|
||||
raise RuntimeError("Named pipe 仅支持 Windows")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
if not hasattr(loop, "start_serving_pipe"):
|
||||
raise RuntimeError("当前事件循环不支持 Windows named pipe")
|
||||
pipe_loop = cast(_NamedPipeEventLoop, loop)
|
||||
|
||||
self._servers = await pipe_loop.start_serving_pipe(
|
||||
lambda: _NamedPipeServerProtocol(handler, loop),
|
||||
self._address,
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
for server in self._servers:
|
||||
server.close()
|
||||
self._servers.clear()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def get_address(self) -> str:
|
||||
return self._address
|
||||
|
||||
|
||||
class NamedPipeTransportClient(TransportClient):
|
||||
"""Windows Named Pipe 传输客户端。"""
|
||||
|
||||
def __init__(self, address: str) -> None:
|
||||
self._address = _normalize_pipe_address(address)
|
||||
|
||||
async def connect(self) -> Connection:
|
||||
if sys.platform != "win32":
|
||||
raise RuntimeError("Named pipe 仅支持 Windows")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
if not hasattr(loop, "create_pipe_connection"):
|
||||
raise RuntimeError("当前事件循环不支持 Windows named pipe")
|
||||
pipe_loop = cast(_NamedPipeEventLoop, loop)
|
||||
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
|
||||
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
|
||||
return NamedPipeConnection(reader, writer)
|
||||
@@ -1,6 +1,6 @@
|
||||
"""TCP 传输实现(回退方案)
|
||||
"""TCP 传输实现。
|
||||
|
||||
仅当 UDS / Named Pipe 不可用时启用。
|
||||
用于显式 TCP 地址场景或调试场景。
|
||||
绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。
|
||||
"""
|
||||
|
||||
@@ -18,7 +18,7 @@ class TCPConnection(Connection):
|
||||
|
||||
|
||||
class TCPTransportServer(TransportServer):
|
||||
"""TCP 传输服务端(回退方案)"""
|
||||
"""TCP 传输服务端。"""
|
||||
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 0) -> None:
|
||||
self._host = host
|
||||
|
||||
Reference in New Issue
Block a user