feat: 添加 Windows Named Pipe 传输实现,支持异步连接和数据传输,修复 Windows 平台插件系统导入隔离误把 DLLs 加进去的 bug

This commit is contained in:
DrSmoothl
2026-03-15 15:44:14 +08:00
parent a9969ad361
commit 49b9401709
6 changed files with 193 additions and 16 deletions

View File

@@ -181,6 +181,36 @@ class TestTransport:
await conn.close()
await server.stop()
@pytest.mark.asyncio
@pytest.mark.skipif(sys.platform != "win32", reason="Windows only")
async def test_named_pipe_connection_framing(self):
"""Windows Named Pipe 分帧协议测试"""
from src.plugin_runtime.transport.named_pipe import NamedPipeTransportClient, NamedPipeTransportServer
server = NamedPipeTransportServer()
received = asyncio.Event()
received_data = []
async def handler(conn):
data = await conn.recv_frame()
received_data.append(data)
await conn.send_frame(b"pipe_pong")
received.set()
await server.start(handler)
client = NamedPipeTransportClient(server.get_address())
conn = await client.connect()
await conn.send_frame(b"pipe_ping")
await asyncio.wait_for(received.wait(), timeout=5.0)
assert received_data[0] == b"pipe_ping"
resp = await conn.recv_frame()
assert resp == b"pipe_pong"
await conn.close()
await server.stop()
@pytest.mark.asyncio
async def test_transport_factory(self):
"""传输工厂测试"""
@@ -193,6 +223,10 @@ class TestTransport:
client = create_transport_client("/tmp/test.sock")
assert client is not None
# Windows Named Pipe 地址
client = create_transport_client(r"\\.\pipe\maibot-test")
assert client is not None
# TCP 地址
client = create_transport_client("127.0.0.1:9999")
assert client is not None
@@ -585,13 +619,10 @@ class TestE2E:
"""Host-Runner 握手流程测试"""
from src.plugin_runtime.protocol.codec import MsgPackCodec
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient
from src.plugin_runtime.transport.factory import create_transport_client, create_transport_server
import secrets
import tempfile
import os
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock")
session_token = secrets.token_hex(16)
codec = MsgPackCodec()
handshake_done = asyncio.Event()
@@ -621,11 +652,11 @@ class TestE2E:
# 保持连接一会儿
await asyncio.sleep(1.0)
server = UDSTransportServer(socket_path=socket_path)
server = create_transport_server()
await server.start(server_handler)
# 客户端握手
client = UDSTransportClient(socket_path)
client = create_transport_client(server.get_address())
conn = await client.connect()
hello = HelloPayload(

View File

@@ -586,11 +586,18 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
if path := sysconfig.get_path(key):
stdlib_paths.add(os.path.normpath(path))
runtime_paths = set(stdlib_paths)
if os.name == "nt":
# Windows 的部分平台扩展模块和依赖会通过 <prefix>/DLLs 暴露在 sys.path 中。
for prefix in {sys.prefix, sys.exec_prefix, sys.base_prefix, sys.base_exec_prefix}:
if prefix:
runtime_paths.add(os.path.normpath(os.path.join(prefix, "DLLs")))
allowed = set()
for p in sys.path:
norm = os.path.normpath(p)
# 保留标准库和 site-packages
if any(norm.startswith(sp) for sp in stdlib_paths):
if any(norm.startswith(runtime_path) for runtime_path in runtime_paths):
allowed.add(p)
# 保留 site-packages第三方库 + SDK
if "site-packages" in norm or "dist-packages" in norm:

View File

@@ -1,7 +1,7 @@
"""传输层抽象基类
定义 TransportServer 和 TransportClient 的统一接口。
所有传输后端UDS、Named Pipe、TCP 回退)必须实现此接口。
所有传输后端UDS、Named Pipe、显式 TCP必须实现此接口。
业务层仅依赖此抽象,禁止直接使用具体传输实现的细节。
分帧协议4-byte big-endian length prefix + payload

View File

@@ -13,32 +13,36 @@ from .base import TransportClient, TransportServer
def create_transport_server(socket_path: Optional[str] = None) -> TransportServer:
"""创建传输服务端
Linux/macOS 使用 UDSWindows 使用 TCP 回退
Linux/macOS 使用 UDSWindows 使用 Named Pipe
Args:
socket_path: UDS socket 路径(仅 Linux/macOS 有效)
socket_path: UDS socket 路径或 Windows pipe 名称
"""
if sys.platform != "win32":
from .uds import UDSTransportServer
return UDSTransportServer(socket_path=socket_path)
else:
# Windows 回退到 TCP后续可改为 Named Pipe
from .tcp import TCPTransportServer
from .named_pipe import NamedPipeTransportServer
return TCPTransportServer()
return NamedPipeTransportServer(pipe_name=socket_path)
def create_transport_client(address: str) -> TransportClient:
"""创建传输客户端
根据地址格式自动判断传输类型:
- 以 '\\\\.\\pipe\\' 开头 -> Windows Named Pipe
- 包含 '/''.sock' -> UDS
- 包含 ':' -> TCP
Args:
address: Host 端监听地址
"""
if address.startswith("\\\\.\\pipe\\"):
from .named_pipe import NamedPipeTransportClient
return NamedPipeTransportClient(address)
if "/" in address or address.endswith(".sock"):
from .uds import UDSTransportClient

View File

@@ -0,0 +1,135 @@
"""Windows Named Pipe 传输实现。
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
"""
from typing import Any, Optional, Protocol, cast
import asyncio
import os
import re
import sys
import uuid
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
_PIPE_PREFIX = "\\\\.\\pipe\\"
_DEFAULT_PIPE_PREFIX = "maibot-plugin"
class _NamedPipeEventLoop(Protocol):
async def start_serving_pipe(self, protocol_factory: Any, address: str) -> list[Any]: ...
async def create_pipe_connection(self, protocol_factory: Any, address: str) -> tuple[Any, Any]: ...
def call_exception_handler(self, context: dict[str, Any]) -> None: ...
def create_task(self, coro: Any) -> asyncio.Task[None]: ...
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
return pipe_name
if pipe_name:
sanitized_name = re.sub(r"[^0-9A-Za-z._-]+", "-", pipe_name).strip("-.")
else:
sanitized_name = f"{_DEFAULT_PIPE_PREFIX}-{os.getpid()}-{uuid.uuid4().hex[:8]}"
if not sanitized_name:
sanitized_name = f"{_DEFAULT_PIPE_PREFIX}-{os.getpid()}-{uuid.uuid4().hex[:8]}"
return f"{_PIPE_PREFIX}{sanitized_name}"
class NamedPipeConnection(Connection):
"""基于 Windows Named Pipe 的连接。"""
pass
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
self._reader = asyncio.StreamReader()
super().__init__(self._reader)
self._handler = handler
self._loop = loop
self._handler_task: Optional[asyncio.Task[None]] = None
def connection_made(self, transport: asyncio.BaseTransport) -> None:
super().connection_made(transport)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
connection = NamedPipeConnection(self._reader, writer)
self._handler_task = self._loop.create_task(self._run_handler(connection))
self._handler_task.add_done_callback(self._on_handler_done)
async def _run_handler(self, connection: NamedPipeConnection) -> None:
try:
await self._handler(connection)
finally:
await connection.close()
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
if task.cancelled():
return
if exc := task.exception():
self._loop.call_exception_handler(
{
"message": "Named pipe 连接处理失败",
"exception": exc,
"protocol": self,
}
)
class NamedPipeTransportServer(TransportServer):
"""Windows Named Pipe 传输服务端。"""
def __init__(self, pipe_name: Optional[str] = None) -> None:
self._address = _normalize_pipe_address(pipe_name)
self._servers: list[Any] = []
async def start(self, handler: ConnectionHandler) -> None:
if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows")
loop = asyncio.get_running_loop()
if not hasattr(loop, "start_serving_pipe"):
raise RuntimeError("当前事件循环不支持 Windows named pipe")
pipe_loop = cast(_NamedPipeEventLoop, loop)
self._servers = await pipe_loop.start_serving_pipe(
lambda: _NamedPipeServerProtocol(handler, loop),
self._address,
)
async def stop(self) -> None:
for server in self._servers:
server.close()
self._servers.clear()
await asyncio.sleep(0)
def get_address(self) -> str:
return self._address
class NamedPipeTransportClient(TransportClient):
"""Windows Named Pipe 传输客户端。"""
def __init__(self, address: str) -> None:
self._address = _normalize_pipe_address(address)
async def connect(self) -> Connection:
if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows")
loop = asyncio.get_running_loop()
if not hasattr(loop, "create_pipe_connection"):
raise RuntimeError("当前事件循环不支持 Windows named pipe")
pipe_loop = cast(_NamedPipeEventLoop, loop)
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
return NamedPipeConnection(reader, writer)

View File

@@ -1,6 +1,6 @@
"""TCP 传输实现(回退方案)
"""TCP 传输实现
仅当 UDS / Named Pipe 不可用时启用
用于显式 TCP 地址场景或调试场景
绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。
"""
@@ -18,7 +18,7 @@ class TCPConnection(Connection):
class TCPTransportServer(TransportServer):
"""TCP 传输服务端(回退方案)"""
"""TCP 传输服务端"""
def __init__(self, host: str = "127.0.0.1", port: int = 0) -> None:
self._host = host