feat: 添加 Windows Named Pipe 传输实现,支持异步连接和数据传输,修复 Windows 平台插件系统导入隔离误把 DLLs 加进去的 bug
This commit is contained in:
@@ -181,6 +181,36 @@ class TestTransport:
|
|||||||
await conn.close()
|
await conn.close()
|
||||||
await server.stop()
|
await server.stop()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skipif(sys.platform != "win32", reason="Windows only")
|
||||||
|
async def test_named_pipe_connection_framing(self):
|
||||||
|
"""Windows Named Pipe 分帧协议测试"""
|
||||||
|
from src.plugin_runtime.transport.named_pipe import NamedPipeTransportClient, NamedPipeTransportServer
|
||||||
|
|
||||||
|
server = NamedPipeTransportServer()
|
||||||
|
received = asyncio.Event()
|
||||||
|
received_data = []
|
||||||
|
|
||||||
|
async def handler(conn):
|
||||||
|
data = await conn.recv_frame()
|
||||||
|
received_data.append(data)
|
||||||
|
await conn.send_frame(b"pipe_pong")
|
||||||
|
received.set()
|
||||||
|
|
||||||
|
await server.start(handler)
|
||||||
|
client = NamedPipeTransportClient(server.get_address())
|
||||||
|
conn = await client.connect()
|
||||||
|
await conn.send_frame(b"pipe_ping")
|
||||||
|
|
||||||
|
await asyncio.wait_for(received.wait(), timeout=5.0)
|
||||||
|
assert received_data[0] == b"pipe_ping"
|
||||||
|
|
||||||
|
resp = await conn.recv_frame()
|
||||||
|
assert resp == b"pipe_pong"
|
||||||
|
|
||||||
|
await conn.close()
|
||||||
|
await server.stop()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transport_factory(self):
|
async def test_transport_factory(self):
|
||||||
"""传输工厂测试"""
|
"""传输工厂测试"""
|
||||||
@@ -193,6 +223,10 @@ class TestTransport:
|
|||||||
client = create_transport_client("/tmp/test.sock")
|
client = create_transport_client("/tmp/test.sock")
|
||||||
assert client is not None
|
assert client is not None
|
||||||
|
|
||||||
|
# Windows Named Pipe 地址
|
||||||
|
client = create_transport_client(r"\\.\pipe\maibot-test")
|
||||||
|
assert client is not None
|
||||||
|
|
||||||
# TCP 地址
|
# TCP 地址
|
||||||
client = create_transport_client("127.0.0.1:9999")
|
client = create_transport_client("127.0.0.1:9999")
|
||||||
assert client is not None
|
assert client is not None
|
||||||
@@ -585,13 +619,10 @@ class TestE2E:
|
|||||||
"""Host-Runner 握手流程测试"""
|
"""Host-Runner 握手流程测试"""
|
||||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||||
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
|
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
|
||||||
from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient
|
from src.plugin_runtime.transport.factory import create_transport_client, create_transport_server
|
||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
import tempfile
|
|
||||||
import os
|
|
||||||
|
|
||||||
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock")
|
|
||||||
session_token = secrets.token_hex(16)
|
session_token = secrets.token_hex(16)
|
||||||
codec = MsgPackCodec()
|
codec = MsgPackCodec()
|
||||||
handshake_done = asyncio.Event()
|
handshake_done = asyncio.Event()
|
||||||
@@ -621,11 +652,11 @@ class TestE2E:
|
|||||||
# 保持连接一会儿
|
# 保持连接一会儿
|
||||||
await asyncio.sleep(1.0)
|
await asyncio.sleep(1.0)
|
||||||
|
|
||||||
server = UDSTransportServer(socket_path=socket_path)
|
server = create_transport_server()
|
||||||
await server.start(server_handler)
|
await server.start(server_handler)
|
||||||
|
|
||||||
# 客户端握手
|
# 客户端握手
|
||||||
client = UDSTransportClient(socket_path)
|
client = create_transport_client(server.get_address())
|
||||||
conn = await client.connect()
|
conn = await client.connect()
|
||||||
|
|
||||||
hello = HelloPayload(
|
hello = HelloPayload(
|
||||||
|
|||||||
@@ -586,11 +586,18 @@ def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
|||||||
if path := sysconfig.get_path(key):
|
if path := sysconfig.get_path(key):
|
||||||
stdlib_paths.add(os.path.normpath(path))
|
stdlib_paths.add(os.path.normpath(path))
|
||||||
|
|
||||||
|
runtime_paths = set(stdlib_paths)
|
||||||
|
if os.name == "nt":
|
||||||
|
# Windows 的部分平台扩展模块和依赖会通过 <prefix>/DLLs 暴露在 sys.path 中。
|
||||||
|
for prefix in {sys.prefix, sys.exec_prefix, sys.base_prefix, sys.base_exec_prefix}:
|
||||||
|
if prefix:
|
||||||
|
runtime_paths.add(os.path.normpath(os.path.join(prefix, "DLLs")))
|
||||||
|
|
||||||
allowed = set()
|
allowed = set()
|
||||||
for p in sys.path:
|
for p in sys.path:
|
||||||
norm = os.path.normpath(p)
|
norm = os.path.normpath(p)
|
||||||
# 保留标准库和 site-packages
|
# 保留标准库和 site-packages
|
||||||
if any(norm.startswith(sp) for sp in stdlib_paths):
|
if any(norm.startswith(runtime_path) for runtime_path in runtime_paths):
|
||||||
allowed.add(p)
|
allowed.add(p)
|
||||||
# 保留 site-packages(第三方库 + SDK)
|
# 保留 site-packages(第三方库 + SDK)
|
||||||
if "site-packages" in norm or "dist-packages" in norm:
|
if "site-packages" in norm or "dist-packages" in norm:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""传输层抽象基类
|
"""传输层抽象基类
|
||||||
|
|
||||||
定义 TransportServer 和 TransportClient 的统一接口。
|
定义 TransportServer 和 TransportClient 的统一接口。
|
||||||
所有传输后端(UDS、Named Pipe、TCP 回退)必须实现此接口。
|
所有传输后端(UDS、Named Pipe、显式 TCP)必须实现此接口。
|
||||||
业务层仅依赖此抽象,禁止直接使用具体传输实现的细节。
|
业务层仅依赖此抽象,禁止直接使用具体传输实现的细节。
|
||||||
|
|
||||||
分帧协议:4-byte big-endian length prefix + payload
|
分帧协议:4-byte big-endian length prefix + payload
|
||||||
|
|||||||
@@ -13,32 +13,36 @@ from .base import TransportClient, TransportServer
|
|||||||
def create_transport_server(socket_path: Optional[str] = None) -> TransportServer:
|
def create_transport_server(socket_path: Optional[str] = None) -> TransportServer:
|
||||||
"""创建传输服务端
|
"""创建传输服务端
|
||||||
|
|
||||||
Linux/macOS 使用 UDS,Windows 使用 TCP 回退。
|
Linux/macOS 使用 UDS,Windows 使用 Named Pipe。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
socket_path: UDS socket 路径(仅 Linux/macOS 有效)
|
socket_path: UDS socket 路径或 Windows pipe 名称
|
||||||
"""
|
"""
|
||||||
if sys.platform != "win32":
|
if sys.platform != "win32":
|
||||||
from .uds import UDSTransportServer
|
from .uds import UDSTransportServer
|
||||||
|
|
||||||
return UDSTransportServer(socket_path=socket_path)
|
return UDSTransportServer(socket_path=socket_path)
|
||||||
else:
|
else:
|
||||||
# Windows 回退到 TCP(后续可改为 Named Pipe)
|
from .named_pipe import NamedPipeTransportServer
|
||||||
from .tcp import TCPTransportServer
|
|
||||||
|
|
||||||
return TCPTransportServer()
|
return NamedPipeTransportServer(pipe_name=socket_path)
|
||||||
|
|
||||||
|
|
||||||
def create_transport_client(address: str) -> TransportClient:
|
def create_transport_client(address: str) -> TransportClient:
|
||||||
"""创建传输客户端
|
"""创建传输客户端
|
||||||
|
|
||||||
根据地址格式自动判断传输类型:
|
根据地址格式自动判断传输类型:
|
||||||
|
- 以 '\\\\.\\pipe\\' 开头 -> Windows Named Pipe
|
||||||
- 包含 '/' 或 '.sock' -> UDS
|
- 包含 '/' 或 '.sock' -> UDS
|
||||||
- 包含 ':' -> TCP
|
- 包含 ':' -> TCP
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
address: Host 端监听地址
|
address: Host 端监听地址
|
||||||
"""
|
"""
|
||||||
|
if address.startswith("\\\\.\\pipe\\"):
|
||||||
|
from .named_pipe import NamedPipeTransportClient
|
||||||
|
|
||||||
|
return NamedPipeTransportClient(address)
|
||||||
if "/" in address or address.endswith(".sock"):
|
if "/" in address or address.endswith(".sock"):
|
||||||
from .uds import UDSTransportClient
|
from .uds import UDSTransportClient
|
||||||
|
|
||||||
|
|||||||
135
src/plugin_runtime/transport/named_pipe.py
Normal file
135
src/plugin_runtime/transport/named_pipe.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""Windows Named Pipe 传输实现。
|
||||||
|
|
||||||
|
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Optional, Protocol, cast
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
|
||||||
|
|
||||||
|
_PIPE_PREFIX = "\\\\.\\pipe\\"
|
||||||
|
_DEFAULT_PIPE_PREFIX = "maibot-plugin"
|
||||||
|
|
||||||
|
|
||||||
|
class _NamedPipeEventLoop(Protocol):
|
||||||
|
async def start_serving_pipe(self, protocol_factory: Any, address: str) -> list[Any]: ...
|
||||||
|
|
||||||
|
async def create_pipe_connection(self, protocol_factory: Any, address: str) -> tuple[Any, Any]: ...
|
||||||
|
|
||||||
|
def call_exception_handler(self, context: dict[str, Any]) -> None: ...
|
||||||
|
|
||||||
|
def create_task(self, coro: Any) -> asyncio.Task[None]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
|
||||||
|
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
|
||||||
|
return pipe_name
|
||||||
|
|
||||||
|
if pipe_name:
|
||||||
|
sanitized_name = re.sub(r"[^0-9A-Za-z._-]+", "-", pipe_name).strip("-.")
|
||||||
|
else:
|
||||||
|
sanitized_name = f"{_DEFAULT_PIPE_PREFIX}-{os.getpid()}-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
if not sanitized_name:
|
||||||
|
sanitized_name = f"{_DEFAULT_PIPE_PREFIX}-{os.getpid()}-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
return f"{_PIPE_PREFIX}{sanitized_name}"
|
||||||
|
|
||||||
|
|
||||||
|
class NamedPipeConnection(Connection):
|
||||||
|
"""基于 Windows Named Pipe 的连接。"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
|
||||||
|
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
|
||||||
|
self._reader = asyncio.StreamReader()
|
||||||
|
super().__init__(self._reader)
|
||||||
|
self._handler = handler
|
||||||
|
self._loop = loop
|
||||||
|
self._handler_task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
|
super().connection_made(transport)
|
||||||
|
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
|
||||||
|
connection = NamedPipeConnection(self._reader, writer)
|
||||||
|
self._handler_task = self._loop.create_task(self._run_handler(connection))
|
||||||
|
self._handler_task.add_done_callback(self._on_handler_done)
|
||||||
|
|
||||||
|
async def _run_handler(self, connection: NamedPipeConnection) -> None:
|
||||||
|
try:
|
||||||
|
await self._handler(connection)
|
||||||
|
finally:
|
||||||
|
await connection.close()
|
||||||
|
|
||||||
|
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
if exc := task.exception():
|
||||||
|
self._loop.call_exception_handler(
|
||||||
|
{
|
||||||
|
"message": "Named pipe 连接处理失败",
|
||||||
|
"exception": exc,
|
||||||
|
"protocol": self,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NamedPipeTransportServer(TransportServer):
|
||||||
|
"""Windows Named Pipe 传输服务端。"""
|
||||||
|
|
||||||
|
def __init__(self, pipe_name: Optional[str] = None) -> None:
|
||||||
|
self._address = _normalize_pipe_address(pipe_name)
|
||||||
|
self._servers: list[Any] = []
|
||||||
|
|
||||||
|
async def start(self, handler: ConnectionHandler) -> None:
|
||||||
|
if sys.platform != "win32":
|
||||||
|
raise RuntimeError("Named pipe 仅支持 Windows")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
if not hasattr(loop, "start_serving_pipe"):
|
||||||
|
raise RuntimeError("当前事件循环不支持 Windows named pipe")
|
||||||
|
pipe_loop = cast(_NamedPipeEventLoop, loop)
|
||||||
|
|
||||||
|
self._servers = await pipe_loop.start_serving_pipe(
|
||||||
|
lambda: _NamedPipeServerProtocol(handler, loop),
|
||||||
|
self._address,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
for server in self._servers:
|
||||||
|
server.close()
|
||||||
|
self._servers.clear()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
def get_address(self) -> str:
|
||||||
|
return self._address
|
||||||
|
|
||||||
|
|
||||||
|
class NamedPipeTransportClient(TransportClient):
|
||||||
|
"""Windows Named Pipe 传输客户端。"""
|
||||||
|
|
||||||
|
def __init__(self, address: str) -> None:
|
||||||
|
self._address = _normalize_pipe_address(address)
|
||||||
|
|
||||||
|
async def connect(self) -> Connection:
|
||||||
|
if sys.platform != "win32":
|
||||||
|
raise RuntimeError("Named pipe 仅支持 Windows")
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
if not hasattr(loop, "create_pipe_connection"):
|
||||||
|
raise RuntimeError("当前事件循环不支持 Windows named pipe")
|
||||||
|
pipe_loop = cast(_NamedPipeEventLoop, loop)
|
||||||
|
|
||||||
|
reader = asyncio.StreamReader()
|
||||||
|
protocol = asyncio.StreamReaderProtocol(reader)
|
||||||
|
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
|
||||||
|
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
|
||||||
|
return NamedPipeConnection(reader, writer)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""TCP 传输实现(回退方案)
|
"""TCP 传输实现。
|
||||||
|
|
||||||
仅当 UDS / Named Pipe 不可用时启用。
|
用于显式 TCP 地址场景或调试场景。
|
||||||
绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。
|
绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -18,7 +18,7 @@ class TCPConnection(Connection):
|
|||||||
|
|
||||||
|
|
||||||
class TCPTransportServer(TransportServer):
|
class TCPTransportServer(TransportServer):
|
||||||
"""TCP 传输服务端(回退方案)"""
|
"""TCP 传输服务端。"""
|
||||||
|
|
||||||
def __init__(self, host: str = "127.0.0.1", port: int = 0) -> None:
|
def __init__(self, host: str = "127.0.0.1", port: int = 0) -> None:
|
||||||
self._host = host
|
self._host = host
|
||||||
|
|||||||
Reference in New Issue
Block a user