feat(plugin-runtime): add plugin isolation IPC infrastructure

- Protocol layer: Envelope model with Pydantic schema, MsgPack/JSON codecs, unified error codes
- Transport layer: cross-platform IPC abstraction with 4-byte length-prefixed framing (UDS + TCP fallback)
- Host: RPC server, policy engine, circuit breaker, capability service, supervisor with hot-reload
- Runner: RPC client, plugin loader, process entry point
- Tests: 16 passing tests covering protocol, transport, host, and E2E handshake
This commit is contained in:
DrSmoothl
2026-03-06 02:01:30 +08:00
parent 10d5c81268
commit 61dc15a513
22 changed files with 2695 additions and 1 deletions

3
.gitignore vendored
View File

@@ -354,4 +354,5 @@ MaiBot.code-workspace
*.lock
actionlint
.sisyphus/
dist-electron/
dist-electron/
packages/

View File

@@ -0,0 +1,427 @@
"""插件运行时框架基础测试
验证协议层、传输层、RPC 通信链路的正确性。
"""
import asyncio
import sys
import os
import pytest
# 确保项目根目录在 sys.path 中
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# SDK 包路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk"))
# ─── 协议层测试 ───────────────────────────────────────────
class TestProtocol:
    """Protocol-layer tests: envelopes, codecs, request ids and error codes."""

    def test_envelope_create_and_serialize(self):
        """An Envelope keeps its fields and can derive a matching response."""
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        request = Envelope(
            request_id=1,
            message_type=MessageType.REQUEST,
            method="plugin.invoke_command",
            plugin_id="test_plugin",
            payload={"component_name": "greet", "args": {}},
        )
        assert request.request_id == 1
        assert request.is_request()
        assert request.method == "plugin.invoke_command"

        # make_response must echo the request id and carry the new payload.
        response = request.make_response(payload={"success": True})
        assert response.is_response()
        assert response.request_id == 1
        assert response.payload["success"] is True

    def test_envelope_make_error_response(self):
        """make_error_response populates the error code and message."""
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        request = Envelope(
            request_id=42,
            message_type=MessageType.REQUEST,
            method="cap.request",
        )
        failure = request.make_error_response("E_UNAUTHORIZED", "没有权限")
        assert failure.error is not None
        assert failure.error["code"] == "E_UNAUTHORIZED"
        assert failure.error["message"] == "没有权限"

    def test_msgpack_codec(self):
        """MsgPack encode/decode round-trips an envelope."""
        from src.plugin_runtime.protocol.codec import MsgPackCodec
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        codec = MsgPackCodec()
        original = Envelope(
            request_id=100,
            message_type=MessageType.REQUEST,
            method="test.method",
            payload={"key": "value", "number": 42},
        )

        wire = codec.encode_envelope(original)
        assert isinstance(wire, bytes)

        restored = codec.decode_envelope(wire)
        assert restored.request_id == 100
        assert restored.method == "test.method"
        assert restored.payload["key"] == "value"
        assert restored.payload["number"] == 42

    def test_json_codec(self):
        """JSON encode/decode round-trips an event envelope."""
        from src.plugin_runtime.protocol.codec import JsonCodec
        from src.plugin_runtime.protocol.envelope import Envelope, MessageType

        codec = JsonCodec()
        original = Envelope(
            request_id=200,
            message_type=MessageType.EVENT,
            method="plugin.config_updated",
            payload={"config_version": "2.0"},
        )

        wire = codec.encode_envelope(original)
        assert isinstance(wire, bytes)

        restored = codec.decode_envelope(wire)
        assert restored.request_id == 200
        assert restored.is_event()

    def test_request_id_generator(self):
        """The request id generator counts up from 1 without gaps."""
        from src.plugin_runtime.protocol.envelope import RequestIdGenerator

        generator = RequestIdGenerator()
        produced = [generator.next() for _ in range(100)]
        assert produced == list(range(1, 101))

    def test_error_codes(self):
        """RPCError exposes its code and survives a dict round-trip."""
        from src.plugin_runtime.protocol.errors import ErrorCode, RPCError

        error = RPCError(ErrorCode.E_TIMEOUT, "请求超时")
        assert error.code == ErrorCode.E_TIMEOUT
        assert "E_TIMEOUT" in str(error)

        # Serialise to a dict and back again.
        rebuilt = RPCError.from_dict(error.to_dict())
        assert rebuilt.code == ErrorCode.E_TIMEOUT
# ─── 传输层测试 ───────────────────────────────────────────
class TestTransport:
    """Transport-layer tests: length-prefixed framing over UDS and TCP."""

    @pytest.mark.asyncio
    async def test_uds_connection_framing(self):
        """A frame sent over UDS arrives intact and the reply comes back."""
        from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient

        server = UDSTransportServer()
        received = asyncio.Event()
        received_data = []

        async def handler(conn):
            data = await conn.recv_frame()
            received_data.append(data)
            await conn.send_frame(b"pong")
            received.set()

        await server.start(handler)
        # Always stop the server / close the connection, even when an
        # assertion fails, so a failing test does not leak sockets or tasks.
        try:
            client = UDSTransportClient(server.get_address())
            conn = await client.connect()
            try:
                await conn.send_frame(b"ping")

                # Wait for the server side to process the frame.
                await asyncio.wait_for(received.wait(), timeout=5.0)
                assert received_data[0] == b"ping"

                # Bounded wait: a broken server must fail the test, not hang it.
                resp = await asyncio.wait_for(conn.recv_frame(), timeout=5.0)
                assert resp == b"pong"
            finally:
                await conn.close()
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_tcp_connection_framing(self):
        """A frame sent over TCP arrives intact and the reply comes back."""
        from src.plugin_runtime.transport.tcp import TCPTransportServer, TCPTransportClient

        server = TCPTransportServer()
        received = asyncio.Event()
        received_data = []

        async def handler(conn):
            data = await conn.recv_frame()
            received_data.append(data)
            await conn.send_frame(b"tcp_pong")
            received.set()

        await server.start(handler)
        try:
            # rsplit keeps working if the host part itself contains ':'.
            host, port = server.get_address().rsplit(":", 1)
            client = TCPTransportClient(host, int(port))
            conn = await client.connect()
            try:
                await conn.send_frame(b"tcp_ping")

                await asyncio.wait_for(received.wait(), timeout=5.0)
                assert received_data[0] == b"tcp_ping"

                resp = await asyncio.wait_for(conn.recv_frame(), timeout=5.0)
                assert resp == b"tcp_pong"
            finally:
                await conn.close()
        finally:
            await server.stop()

    @pytest.mark.asyncio
    async def test_transport_factory(self):
        """The factory builds a server and clients for both address styles."""
        from src.plugin_runtime.transport.factory import create_transport_server, create_transport_client

        server = create_transport_server()
        assert server is not None

        # UDS-style address (filesystem path).
        client = create_transport_client("/tmp/test.sock")
        assert client is not None

        # TCP-style address (host:port).
        client = create_transport_client("127.0.0.1:9999")
        assert client is not None
# ─── Host 层测试 ──────────────────────────────────────────
class TestHost:
    """Tests for the Host-side infrastructure (policy engine, circuit breaker)."""

    def test_policy_engine(self):
        """Capability checks honour the issued token and reject unknown plugins."""
        from src.plugin_runtime.host.policy_engine import PolicyEngine

        policy = PolicyEngine()

        # Register a plugin and inspect the issued token.
        issued = policy.register_plugin(
            plugin_id="test_plugin",
            generation=1,
            capabilities=["send.text", "db.query"],
            limits={"qps": 10, "burst": 20},
        )
        assert issued.plugin_id == "test_plugin"
        assert "send.text" in issued.capabilities

        # A granted capability passes.
        granted, _ = policy.check_capability("test_plugin", "send.text")
        assert granted

        # A capability that was not granted is refused with a reason.
        granted, why = policy.check_capability("test_plugin", "llm.generate")
        assert not granted
        assert "未获授权" in why

        # A plugin that never registered is refused.
        granted, why = policy.check_capability("unknown", "send.text")
        assert not granted

    def test_circuit_breaker(self):
        """The breaker opens on the third failure and closes again on reset."""
        from src.plugin_runtime.host.circuit_breaker import CircuitBreaker, CircuitState

        cb = CircuitBreaker(failure_threshold=3)

        # Starts closed and lets requests through.
        assert cb.state == CircuitState.CLOSED
        assert cb.allow_request()

        # Two failures are still below the threshold.
        for _ in range(2):
            cb.record_failure()
        assert cb.allow_request()

        # The third failure trips the breaker.
        cb.record_failure()
        assert cb.state == CircuitState.OPEN
        assert not cb.allow_request()

        # reset() returns it to the closed state.
        cb.reset()
        assert cb.state == CircuitState.CLOSED

    def test_circuit_breaker_registry(self):
        """The registry hands out one stable breaker per plugin id."""
        from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry

        breakers = CircuitBreakerRegistry(failure_threshold=2)
        first = breakers.get("plugin_a")
        second = breakers.get("plugin_b")
        assert first is not second
        # Asking again for the same id returns the same object.
        assert breakers.get("plugin_a") is first
# ─── SDK 测试 ─────────────────────────────────────────────
class TestSDK:
    """Tests for the plugin SDK surface."""

    def test_component_decorators(self):
        """Each decorator registers one component with the right name and type."""
        from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
        from maibot_sdk.types import ActivationType, EventType

        class TestPlugin(MaiBotPlugin):
            @Action("greet", activation_type=ActivationType.KEYWORD, activation_keywords=["hi"])
            async def handle_greet(self, **kwargs):
                return True, "ok"

            @Command("echo", pattern=r"^/echo")
            async def handle_echo(self, **kwargs):
                return True, "echoed", 2

            @Tool("search", parameters={"query": {"type": "string"}})
            async def handle_search(self, **kwargs):
                return {"result": "found"}

            @EventHandler("on_start", event_type=EventType.ON_START)
            async def handle_start(self, **kwargs):
                return True, False, "started"

        registered = TestPlugin().get_components()
        assert len(registered) == 4

        seen_names = {entry["name"] for entry in registered}
        for expected_name in ("greet", "echo", "search", "on_start"):
            assert expected_name in seen_names

        seen_types = {entry["type"] for entry in registered}
        for expected_type in ("action", "command", "tool", "event_handler"):
            assert expected_type in seen_types

    def test_plugin_context_not_initialized(self):
        """Reading ``ctx`` before a context is injected raises RuntimeError."""
        from maibot_sdk import MaiBotPlugin

        bare_plugin = MaiBotPlugin()
        with pytest.raises(RuntimeError, match="尚未初始化"):
            _ = bare_plugin.ctx

    def test_plugin_context_injection(self):
        """An injected context is exposed through ``ctx`` with its sub-services."""
        from maibot_sdk import MaiBotPlugin
        from maibot_sdk.context import PluginContext

        plugin = MaiBotPlugin()
        plugin._set_context(PluginContext(plugin_id="test"))

        injected = plugin.ctx
        assert injected.plugin_id == "test"
        assert plugin.ctx.send is not None
        assert plugin.ctx.db is not None
        assert plugin.ctx.llm is not None
        assert plugin.ctx.config is not None
# ─── 端到端集成测试 ────────────────────────────────────────
class TestE2E:
    """End-to-end integration tests (Host + Runner communication)."""

    @pytest.mark.asyncio
    async def test_handshake(self):
        """A runner.hello request is validated by the host and accepted."""
        from src.plugin_runtime.protocol.codec import create_codec
        from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
        from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient
        import secrets
        import tempfile

        # `os` comes from the module-level import; no local re-import needed.
        socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock")
        session_token = secrets.token_hex(16)
        codec = create_codec()
        handshake_done = asyncio.Event()
        server_result = {}

        async def server_handler(conn):
            # Receive and validate the handshake request.
            data = await conn.recv_frame()
            env = codec.decode_envelope(data)
            assert env.method == "runner.hello"
            hello = HelloPayload.model_validate(env.payload)
            assert hello.session_token == session_token

            # Send the acceptance response.
            resp_payload = HelloResponsePayload(
                accepted=True,
                host_version="1.0",
                assigned_generation=1,
            )
            resp = env.make_response(payload=resp_payload.model_dump())
            await conn.send_frame(codec.encode_envelope(resp))
            server_result["runner_id"] = hello.runner_id
            handshake_done.set()

            # Keep the connection open for a moment.
            await asyncio.sleep(1.0)

        server = UDSTransportServer(socket_path=socket_path)
        await server.start(server_handler)
        # Clean up in `finally` so a failed assertion does not leak the server
        # or the client connection.
        try:
            # Client side of the handshake.
            client = UDSTransportClient(socket_path)
            conn = await client.connect()
            try:
                hello = HelloPayload(
                    runner_id="test-runner",
                    sdk_version="1.0.0",
                    session_token=session_token,
                )
                env = Envelope(
                    request_id=1,
                    message_type=MessageType.REQUEST,
                    method="runner.hello",
                    payload=hello.model_dump(),
                )
                await conn.send_frame(codec.encode_envelope(env))

                # Bounded wait: if an assertion inside server_handler fails, no
                # response is ever sent and an unbounded recv would hang the test.
                resp_data = await asyncio.wait_for(conn.recv_frame(), timeout=5.0)
                resp = codec.decode_envelope(resp_data)
                resp_payload = HelloResponsePayload.model_validate(resp.payload)
                assert resp_payload.accepted
                assert resp_payload.assigned_generation == 1

                await asyncio.wait_for(handshake_done.wait(), timeout=5.0)
                assert server_result["runner_id"] == "test-runner"
            finally:
                await conn.close()
        finally:
            await server.stop()

View File

@@ -0,0 +1,2 @@
# MaiBot Plugin Runtime - 插件隔离运行时基础设施
# 本模块实现 Host-Runner 进程分离架构,提供 IPC 通信、策略引擎与生命周期管理

View File

@@ -0,0 +1 @@
# Host 端 - Supervisor、RPC Server、策略引擎、路由

View File

@@ -0,0 +1,108 @@
"""能力服务层
Host 端实现的能力服务,处理来自插件的 cap.* 请求。
每个能力方法被注册到 RPC Server，接收 Runner 转发的请求并执行实际操作。
"""
from typing import Any, Callable, Awaitable
import logging
from src.plugin_runtime.protocol.envelope import (
CapabilityRequestPayload,
CapabilityResponsePayload,
Envelope,
)
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
from src.plugin_runtime.host.policy_engine import PolicyEngine
logger = logging.getLogger("plugin_runtime.host.capability_service")
# 能力实现函数类型: (plugin_id, capability, args) -> result
CapabilityImpl = Callable[[str, str, dict[str, Any]], Awaitable[Any]]
class CapabilityService:
    """Host-side capability service.

    Responsibilities:
    1. Keep a registry of capability implementations.
    2. Receive capability call requests coming from plugins.
    3. Ask the policy engine to check authorisation and rate limits.
    4. Run the implementation and wrap its result in a response envelope.
    """

    def __init__(self, policy_engine: PolicyEngine):
        # capability name -> implementation coroutine function
        self._implementations: dict[str, CapabilityImpl] = {}
        self._policy = policy_engine

    def register_capability(self, name: str, impl: CapabilityImpl) -> None:
        """Register (or replace) the implementation of one capability.

        Args:
            name: Capability name, e.g. "send.text", "db.query", "llm.generate".
            impl: Async callable invoked as ``impl(plugin_id, capability, args)``.
        """
        self._implementations[name] = impl
        logger.debug(f"注册能力实现: {name}")

    async def handle_capability_request(self, envelope: Envelope) -> Envelope:
        """Handle one capability call (used as an RPC Server method handler).

        Extracts the capability name and arguments from ``envelope.payload``,
        runs the policy checks, then calls the registered implementation.
        Every failure is reported as an error response; nothing is raised.
        """
        plugin_id = envelope.plugin_id

        try:
            request = CapabilityRequestPayload.model_validate(envelope.payload)
        except Exception as e:
            return envelope.make_error_response(
                ErrorCode.E_BAD_PAYLOAD.value,
                f"能力调用 payload 格式错误: {e}",
            )
        capability = request.capability

        # Policy gates, evaluated lazily and in order: 1. authorisation,
        # 2. rate limit. The rate limiter is only consulted once the
        # authorisation check has passed.
        gates = (
            (lambda: self._policy.check_capability(plugin_id, capability), ErrorCode.E_CAPABILITY_DENIED),
            (lambda: self._policy.check_rate_limit(plugin_id), ErrorCode.E_BACKPRESSURE),
        )
        for run_check, failure_code in gates:
            passed, reason = run_check()
            if not passed:
                return envelope.make_error_response(failure_code.value, reason)

        # 3. Look up the implementation.
        implementation = self._implementations.get(capability)
        if implementation is None:
            return envelope.make_error_response(
                ErrorCode.E_METHOD_NOT_ALLOWED.value,
                f"未注册的能力: {capability}",
            )

        # 4. Execute. Building the response stays inside the try block so a
        # failure while wrapping the result is reported like any other error.
        try:
            outcome = await implementation(plugin_id, capability, request.args)
            wrapped = CapabilityResponsePayload(success=True, result=outcome)
            return envelope.make_response(payload=wrapped.model_dump())
        except RPCError as e:
            return envelope.make_error_response(e.code.value, e.message, e.details)
        except Exception as e:
            logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
            return envelope.make_error_response(
                ErrorCode.E_CAPABILITY_FAILED.value,
                str(e),
            )

    def list_capabilities(self) -> list[str]:
        """Return the names of all registered capabilities."""
        return list(self._implementations)

View File

@@ -0,0 +1,105 @@
"""熔断器
为每个插件提供熔断保护,连续失败超过阈值后临时禁用。
支持指数退避恢复。
"""
from enum import Enum
import time
class CircuitState(str, Enum):
    """Lifecycle states of a circuit breaker."""

    CLOSED = "closed"  # normal operation
    OPEN = "open"  # tripped: every call is rejected
    HALF_OPEN = "half_open"  # probing whether the plugin has recovered


class CircuitBreaker:
    """Circuit breaker for a single plugin.

    Trips to OPEN after ``failure_threshold`` failures without an intervening
    success. After a recovery timeout it moves to HALF_OPEN; a success then
    closes it, a failure re-opens it. The recovery timeout doubles on every
    consecutive open, starting at ``recovery_timeout_sec`` and capped at
    ``max_recovery_timeout_sec``.
    """

    # 2.0 ** 1024 overflows a float; cap the exponent so the backoff
    # computation can never raise, however long a plugin stays broken.
    _MAX_BACKOFF_EXPONENT = 1023

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout_sec: float = 30.0,
        max_recovery_timeout_sec: float = 300.0,
    ):
        self.failure_threshold = failure_threshold
        self.base_recovery_timeout = recovery_timeout_sec
        self.max_recovery_timeout = max_recovery_timeout_sec

        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time = 0.0
        self._consecutive_opens = 0  # drives the exponential backoff

    def _recovery_timeout(self) -> float:
        """Return the wait (seconds) before the current open may be probed."""
        # The first open waits exactly the base timeout, the second 2x, ...
        exponent = min(max(self._consecutive_opens - 1, 0), self._MAX_BACKOFF_EXPONENT)
        return min(self.base_recovery_timeout * (2.0 ** exponent), self.max_recovery_timeout)

    @property
    def state(self) -> CircuitState:
        """Current state; reading it promotes OPEN to HALF_OPEN once the
        recovery timeout has elapsed since the last recorded failure."""
        if self._state == CircuitState.OPEN:
            elapsed = time.monotonic() - self._last_failure_time
            if elapsed >= self._recovery_timeout():
                self._state = CircuitState.HALF_OPEN
        return self._state

    def allow_request(self) -> bool:
        """Return True when a request may pass (CLOSED or HALF_OPEN).

        Note: HALF_OPEN does not limit the number of concurrent probes.
        """
        return self.state != CircuitState.OPEN

    def record_success(self) -> None:
        """Record a successful call."""
        if self._state == CircuitState.HALF_OPEN:
            # A successful probe closes the breaker and clears the backoff.
            self._state = CircuitState.CLOSED
            self._failure_count = 0
            self._consecutive_opens = 0
        elif self._state == CircuitState.CLOSED:
            self._failure_count = 0

    def record_failure(self) -> None:
        """Record a failed call."""
        self._failure_count += 1
        self._last_failure_time = time.monotonic()

        if self._state == CircuitState.HALF_OPEN:
            # A failed probe re-opens the breaker with a longer backoff.
            self._state = CircuitState.OPEN
            self._consecutive_opens += 1
        elif self._state == CircuitState.CLOSED and self._failure_count >= self.failure_threshold:
            # Only a CLOSED breaker can trip. Late failures that arrive while
            # it is already OPEN must not count as additional opens.
            self._state = CircuitState.OPEN
            self._consecutive_opens += 1

    def reset(self) -> None:
        """Force the breaker back to CLOSED and clear all counters."""
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._consecutive_opens = 0
class CircuitBreakerRegistry:
    """Registry that keeps one independent circuit breaker per plugin."""

    def __init__(self, **default_kwargs):
        # Keyword arguments forwarded to every CircuitBreaker created here.
        self._default_kwargs = default_kwargs
        self._breakers: dict[str, CircuitBreaker] = {}

    def get(self, plugin_id: str) -> CircuitBreaker:
        """Return the breaker for ``plugin_id``, creating it on first use."""
        existing = self._breakers.get(plugin_id)
        if existing is None:
            existing = CircuitBreaker(**self._default_kwargs)
            self._breakers[plugin_id] = existing
        return existing

    def remove(self, plugin_id: str) -> None:
        """Drop the breaker for ``plugin_id``; a missing id is ignored."""
        self._breakers.pop(plugin_id, None)

    def reset_all(self) -> None:
        """Reset every registered breaker."""
        for each in self._breakers.values():
            each.reset()

View File

@@ -0,0 +1,125 @@
"""策略引擎
负责能力授权校验、限流、配额管理。
每个插件在 manifest 中声明能力需求，Host 启动时签发能力令牌。
"""
from dataclasses import dataclass, field
import time
@dataclass
class CapabilityToken:
"""能力令牌
描述某个插件在当前会话中被授予的能力和资源限制。
"""
plugin_id: str
generation: int
capabilities: set[str] = field(default_factory=set)
qps_limit: int = 20
burst_limit: int = 50
daily_token_limit: int = 200000
max_payload_kb: int = 256
# 运行时统计
_call_count: int = field(default=0, init=False, repr=False)
_window_start: float = field(default_factory=time.monotonic, init=False, repr=False)
_window_calls: int = field(default=0, init=False, repr=False)
class PolicyEngine:
"""策略引擎
管理所有插件的能力令牌,提供授权校验与限流决策。
"""
def __init__(self):
# plugin_id -> CapabilityToken
self._tokens: dict[str, CapabilityToken] = {}
def register_plugin(
self,
plugin_id: str,
generation: int,
capabilities: list[str],
limits: dict | None = None,
) -> CapabilityToken:
"""为插件签发能力令牌"""
limits = limits or {}
token = CapabilityToken(
plugin_id=plugin_id,
generation=generation,
capabilities=set(capabilities),
qps_limit=limits.get("qps", 20),
burst_limit=limits.get("burst", 50),
daily_token_limit=limits.get("daily_tokens", 200000),
max_payload_kb=limits.get("max_payload_kb", 256),
)
self._tokens[plugin_id] = token
return token
def revoke_plugin(self, plugin_id: str) -> None:
"""撤销插件的能力令牌"""
self._tokens.pop(plugin_id, None)
def check_capability(self, plugin_id: str, capability: str) -> tuple[bool, str]:
"""检查插件是否有权调用某项能力
Returns:
(allowed, reason)
"""
token = self._tokens.get(plugin_id)
if token is None:
return False, f"插件 {plugin_id} 未注册能力令牌"
if capability not in token.capabilities:
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
return True, ""
def check_rate_limit(self, plugin_id: str) -> tuple[bool, str]:
"""检查插件是否超过调用频率限制(滑动窗口)
Returns:
(allowed, reason)
"""
token = self._tokens.get(plugin_id)
if token is None:
return False, f"插件 {plugin_id} 未注册"
now = time.monotonic()
elapsed = now - token._window_start
# 每秒重置窗口
if elapsed >= 1.0:
token._window_start = now
token._window_calls = 0
token._window_calls += 1
if token._window_calls > token.burst_limit:
return False, f"插件 {plugin_id} 超过突发限制 ({token.burst_limit}/s)"
return True, ""
def check_payload_size(self, plugin_id: str, payload_size_bytes: int) -> tuple[bool, str]:
"""检查 payload 大小是否在限制内"""
token = self._tokens.get(plugin_id)
if token is None:
return False, f"插件 {plugin_id} 未注册"
max_bytes = token.max_payload_kb * 1024
if payload_size_bytes > max_bytes:
return False, f"payload 大小 {payload_size_bytes} 超过限制 {max_bytes}"
return True, ""
def get_token(self, plugin_id: str) -> CapabilityToken | None:
"""获取插件的能力令牌"""
return self._tokens.get(plugin_id)
def list_plugins(self) -> list[str]:
"""列出所有已注册的插件"""
return list(self._tokens.keys())

View File

@@ -0,0 +1,357 @@
"""Host 端 RPC Server
负责:
1. 监听 Runner 连接
2. 处理握手（runner.hello）
3. 分发调用请求给 Runner / 处理 Runner 的能力调用
4. 请求-响应关联与超时管理
"""
from typing import Any, Callable, Awaitable
import asyncio
import logging
import secrets
from src.plugin_runtime.protocol.codec import Codec, create_codec
from src.plugin_runtime.protocol.envelope import (
PROTOCOL_VERSION,
MIN_SDK_VERSION,
MAX_SDK_VERSION,
Envelope,
HelloPayload,
HelloResponsePayload,
MessageType,
RequestIdGenerator,
)
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
from src.plugin_runtime.transport.base import Connection, TransportServer
logger = logging.getLogger("plugin_runtime.host.rpc_server")
# RPC 方法处理器类型
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
class RPCServer:
    """Host-side RPC server.

    Owns the IPC link to the Runner and carries RPC traffic in both
    directions: Host -> Runner requests and events, and Runner -> Host
    requests/events that are dispatched to handlers registered with
    :meth:`register_method`.
    """

    def __init__(
        self,
        transport: TransportServer,
        session_token: str | None = None,
        codec: Codec | None = None,
        send_queue_size: int = 128,
    ):
        self._transport = transport
        self._session_token = session_token or secrets.token_hex(32)
        self._codec = codec or create_codec()
        self._send_queue_size = send_queue_size

        self._id_gen = RequestIdGenerator()
        self._connection: Connection | None = None  # currently active Runner connection
        self._runner_id: str | None = None
        self._runner_generation: int = 0

        # Registered method handlers.
        self._method_handlers: dict[str, MethodHandler] = {}

        # Requests awaiting a response: request_id -> Future.
        self._pending_requests: dict[int, asyncio.Future] = {}

        # Send queue intended for back-pressure control.
        # NOTE(review): nothing in this class ever puts items on the queue, so
        # the `full()` check in send_request can never trigger as written.
        self._send_queue: asyncio.Queue | None = None

        # Run state.
        self._running = False
        self._tasks: list[asyncio.Task] = []

    @property
    def session_token(self) -> str:
        """Session token a Runner must present in its handshake."""
        return self._session_token

    @property
    def is_connected(self) -> bool:
        """True while a Runner connection is registered and not closed."""
        return self._connection is not None and not self._connection.is_closed

    def register_method(self, method: str, handler: MethodHandler) -> None:
        """Register (or replace) the handler for an RPC method name."""
        self._method_handlers[method] = handler

    async def start(self) -> None:
        """Start listening for Runner connections."""
        self._running = True
        self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
        await self._transport.start(self._handle_connection)
        logger.info(f"RPC Server 已启动，监听地址: {self._transport.get_address()}")

    async def stop(self) -> None:
        """Stop the server: fail pending requests, cancel tasks, close the link."""
        self._running = False

        # Wake up every caller still waiting for a response.
        self._fail_pending(ErrorCode.E_TIMEOUT, "服务器关闭")

        # Cancel background handler tasks (iterate a copy: done-callbacks
        # remove entries from the list).
        for task in list(self._tasks):
            task.cancel()
        self._tasks.clear()

        # Close the active connection.
        if self._connection:
            await self._connection.close()
            self._connection = None

        await self._transport.stop()
        logger.info("RPC Server 已停止")

    async def send_request(
        self,
        method: str,
        plugin_id: str = "",
        payload: dict[str, Any] | None = None,
        timeout_ms: int = 30000,
    ) -> Envelope:
        """Send an RPC request to the Runner and wait for its response.

        Args:
            method: RPC method name.
            plugin_id: Target plugin id.
            payload: Request data.
            timeout_ms: Timeout in milliseconds.

        Returns:
            The response Envelope.

        Raises:
            RPCError: The Runner is not connected, the request timed out, the
                Runner answered with an error, or sending failed.
        """
        conn = self._connection
        if conn is None or conn.is_closed:
            raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")

        request_id = self._id_gen.next()
        envelope = Envelope(
            request_id=request_id,
            message_type=MessageType.REQUEST,
            method=method,
            plugin_id=plugin_id,
            generation=self._runner_generation,
            timeout_ms=timeout_ms,
            payload=payload or {},
        )

        # Back-pressure check.
        if self._send_queue and self._send_queue.full():
            raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满")

        # Register the pending future before sending so a fast response
        # cannot arrive before we are ready for it.
        future: asyncio.Future[Envelope] = asyncio.get_running_loop().create_future()
        self._pending_requests[request_id] = future

        try:
            await conn.send_frame(self._codec.encode_envelope(envelope))
            return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
        except asyncio.TimeoutError:
            raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
        except RPCError:
            raise
        except Exception as e:
            raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
        finally:
            # Also runs on cancellation (CancelledError is not an Exception),
            # so an abandoned request never leaks its pending entry.
            self._pending_requests.pop(request_id, None)

    async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None:
        """Send a one-way event to the Runner; silently dropped if not connected."""
        conn = self._connection
        if conn is None or conn.is_closed:
            return

        envelope = Envelope(
            request_id=self._id_gen.next(),
            message_type=MessageType.EVENT,
            method=method,
            plugin_id=plugin_id,
            generation=self._runner_generation,
            payload=payload or {},
        )
        await conn.send_frame(self._codec.encode_envelope(envelope))

    # ─── Internals ─────────────────────────────────────────────

    def _fail_pending(self, code: ErrorCode, message: str) -> None:
        """Complete every pending request with an RPCError and forget them."""
        pending, self._pending_requests = self._pending_requests, {}
        for future in pending.values():
            if not future.done():
                future.set_exception(RPCError(code, message))

    def _spawn_task(self, coro) -> None:
        """Run *coro* as a tracked background task (cancelled by stop())."""
        task = asyncio.create_task(coro)
        self._tasks.append(task)
        task.add_done_callback(self._forget_task)

    def _forget_task(self, task: asyncio.Task) -> None:
        """Done-callback: drop a finished task from the tracking list."""
        try:
            self._tasks.remove(task)
        except ValueError:
            pass  # already removed, e.g. by stop()

    async def _handle_connection(self, conn: Connection) -> None:
        """Serve one Runner connection: handshake, then the receive loop."""
        logger.info("收到 Runner 连接")

        # The first message must be the runner.hello handshake.
        try:
            handshake_ok = await self._handle_handshake(conn)
            if not handshake_ok:
                await conn.close()
                return
        except Exception as e:
            logger.error(f"握手失败: {e}")
            await conn.close()
            return

        # Handshake succeeded: this becomes the active connection.
        self._connection = conn
        logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")

        try:
            await self._recv_loop(conn)
        except Exception as e:
            logger.error(f"连接异常断开: {e}")
        finally:
            # Only tear down shared state if this is still the active
            # connection. During a hot reload a newer Runner may already have
            # replaced it, and its state must not be wiped when the old one
            # goes away.
            if self._connection is conn:
                self._connection = None
                self._runner_id = None
                # Nobody can answer outstanding requests any more; fail them
                # now instead of letting callers wait for their timeout.
                self._fail_pending(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")

    async def _handle_handshake(self, conn: Connection) -> bool:
        """Process the runner.hello handshake.

        Returns:
            True when the Runner was accepted, False when it was rejected
            (a rejection response has already been sent in that case).
        """
        data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
        envelope = self._codec.decode_envelope(data)

        if envelope.method != "runner.hello":
            logger.error(f"期望 runner.hello收到 {envelope.method}")
            error_resp = envelope.make_error_response(
                ErrorCode.E_PROTOCOL_MISMATCH.value,
                "首条消息必须为 runner.hello",
            )
            await conn.send_frame(self._codec.encode_envelope(error_resp))
            return False

        hello = HelloPayload.model_validate(envelope.payload)

        # Validate the session token in constant time; bytes are compared so
        # non-ASCII input from the peer cannot make compare_digest raise.
        token_ok = secrets.compare_digest(
            str(hello.session_token).encode("utf-8"),
            self._session_token.encode("utf-8"),
        )
        if not token_ok:
            logger.error("会话令牌不匹配")
            await self._reject_handshake(conn, envelope, "会话令牌无效")
            return False

        # Validate the SDK version.
        if not self._check_sdk_version(hello.sdk_version):
            logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
            await self._reject_handshake(
                conn,
                envelope,
                f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]",
            )
            return False

        # Accepted: record the runner and hand out a new generation.
        self._runner_id = hello.runner_id
        self._runner_generation += 1

        resp_payload = HelloResponsePayload(
            accepted=True,
            host_version=PROTOCOL_VERSION,
            assigned_generation=self._runner_generation,
        )
        resp = envelope.make_response(payload=resp_payload.model_dump())
        await conn.send_frame(self._codec.encode_envelope(resp))
        return True

    async def _reject_handshake(self, conn: Connection, envelope: Envelope, reason: str) -> None:
        """Send a ``accepted=False`` handshake response carrying *reason*."""
        resp_payload = HelloResponsePayload(accepted=False, reason=reason)
        resp = envelope.make_response(payload=resp_payload.model_dump())
        await conn.send_frame(self._codec.encode_envelope(resp))

    async def _recv_loop(self, conn: Connection) -> None:
        """Main receive loop: read frames and dispatch them by message type."""
        while self._running and not conn.is_closed:
            try:
                data = await conn.recv_frame()
            except (asyncio.IncompleteReadError, ConnectionError):
                logger.info("Runner 连接已断开")
                break
            except Exception as e:
                logger.error(f"接收帧失败: {e}")
                break

            try:
                envelope = self._codec.decode_envelope(data)
            except Exception as e:
                # A single undecodable frame does not end the connection.
                logger.error(f"解码消息失败: {e}")
                continue

            if envelope.is_response():
                self._handle_response(envelope)
            elif envelope.is_request():
                # Requests from the Runner (capability calls) are handled
                # concurrently so a slow handler cannot stall the loop.
                self._spawn_task(self._handle_request(envelope, conn))
            elif envelope.is_event():
                self._spawn_task(self._handle_event(envelope))

    def _handle_response(self, envelope: Envelope) -> None:
        """Resolve the pending request that this Runner response belongs to."""
        future = self._pending_requests.pop(envelope.request_id, None)
        if future and not future.done():
            if envelope.error:
                future.set_exception(RPCError.from_dict(envelope.error))
            else:
                future.set_result(envelope)

    async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
        """Handle a request from the Runner (usually a cap.* capability call)."""
        handler = self._method_handlers.get(envelope.method)
        if handler is None:
            reply = envelope.make_error_response(
                ErrorCode.E_METHOD_NOT_ALLOWED.value,
                f"未注册的方法: {envelope.method}",
            )
        else:
            try:
                reply = await handler(envelope)
                # Encode inside the try block: an unserialisable handler
                # result must still produce an error reply for the Runner.
                self._codec.encode_envelope(reply)
            except RPCError as e:
                reply = envelope.make_error_response(e.code.value, e.message, e.details)
            except Exception as e:
                logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
                reply = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))

        try:
            await conn.send_frame(self._codec.encode_envelope(reply))
        except Exception as e:
            # The connection may be gone by the time the handler finishes;
            # log it rather than leaving an unretrieved task exception.
            logger.error(f"处理请求 {envelope.method} 异常: {e}")

    async def _handle_event(self, envelope: Envelope) -> None:
        """Handle a one-way event from the Runner; handler errors are logged."""
        handler = self._method_handlers.get(envelope.method)
        if handler:
            try:
                await handler(envelope)
            except Exception as e:
                logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)

    @staticmethod
    def _check_sdk_version(sdk_version: str) -> bool:
        """Return True if *sdk_version* lies within [MIN_SDK_VERSION, MAX_SDK_VERSION].

        Versions are dotted integers. Shorter versions are zero-padded before
        comparison so that "1.0" and "1.0.0" compare equal. Anything that is
        not a dotted-integer string is rejected.
        """
        try:
            sdk, low, high = (
                [int(part) for part in version.split(".")]
                for version in (sdk_version, MIN_SDK_VERSION, MAX_SDK_VERSION)
            )
        except (ValueError, AttributeError):
            return False

        width = max(len(sdk), len(low), len(high))
        sdk, low, high = (parts + [0] * (width - len(parts)) for parts in (sdk, low, high))
        return low <= sdk <= high

View File

@@ -0,0 +1,315 @@
"""Supervisor - 插件生命周期管理
负责:
1. 拉起 Runner 子进程
2. 健康检查
3. 熔断与恢复
4. 代码热重载（generation 切换）
5. 优雅关停
"""
from typing import Any
import asyncio
import logging
import os
import sys
from src.plugin_runtime.host.capability_service import CapabilityService
from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry
from src.plugin_runtime.host.policy_engine import PolicyEngine
from src.plugin_runtime.host.rpc_server import RPCServer
from src.plugin_runtime.protocol.envelope import (
Envelope,
HealthPayload,
RegisterComponentsPayload,
ShutdownPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
from src.plugin_runtime.transport.factory import create_transport_server
logger = logging.getLogger("plugin_runtime.host.supervisor")
class PluginSupervisor:
"""插件 Supervisor
Host 端的核心管理器,负责整个插件 Runner 进程的生命周期。
"""
def __init__(
    self,
    plugin_dirs: list[str] | None = None,
    socket_path: str | None = None,
    health_check_interval_sec: float = 30.0,
    use_json_codec: bool = False,
):
    """Wire up the Host-side infrastructure; nothing is started here.

    Args:
        plugin_dirs: Plugin directories; stored as given (empty list if None).
        socket_path: Passed through to ``create_transport_server``.
        health_check_interval_sec: Stored as the health-check interval.
        use_json_codec: Passed to ``create_codec(use_json=...)``.
    """
    self._plugin_dirs = plugin_dirs or []
    self._health_interval = health_check_interval_sec

    # Infrastructure
    self._transport = create_transport_server(socket_path=socket_path)
    self._policy = PolicyEngine()
    self._breakers = CircuitBreakerRegistry()
    self._capability_service = CapabilityService(self._policy)

    # Codec
    from src.plugin_runtime.protocol.codec import create_codec
    codec = create_codec(use_json=use_json_codec)

    self._rpc_server = RPCServer(
        transport=self._transport,
        codec=codec,
    )

    # Runner child process
    self._runner_process: asyncio.subprocess.Process | None = None
    self._runner_generation: int = 0

    # Component information of registered plugins
    self._registered_plugins: dict[str, RegisterComponentsPayload] = {}

    # Background tasks
    self._health_task: asyncio.Task | None = None
    self._running = False

    # Register internal RPC methods (needs self._rpc_server and
    # self._capability_service, so it must run last).
    self._register_internal_methods()
@property
def policy_engine(self) -> PolicyEngine:
    """The PolicyEngine instance created by this supervisor."""
    return self._policy
@property
def capability_service(self) -> CapabilityService:
    """The CapabilityService instance created by this supervisor."""
    return self._capability_service
@property
def rpc_server(self) -> RPCServer:
    """The RPCServer instance created by this supervisor."""
    return self._rpc_server
async def start(self) -> None:
    """Start the supervisor.

    1. Start the RPC server.
    2. Spawn the Runner child process.
    3. Start the health-check loop as a background task.
    """
    self._running = True

    # Start the RPC server first so the Runner has something to connect to.
    await self._rpc_server.start()

    # Spawn the Runner process.
    await self._spawn_runner()

    # Start the health check.
    self._health_task = asyncio.create_task(self._health_check_loop())

    logger.info("PluginSupervisor 已启动")
async def stop(self) -> None:
    """Stop the supervisor: health check, then the Runner, then the RPC server."""
    self._running = False

    # Stop the health check (the task is cancelled but not awaited).
    if self._health_task:
        self._health_task.cancel()
        self._health_task = None

    # Shut the Runner down gracefully.
    await self._shutdown_runner()

    # Stop the RPC server.
    await self._rpc_server.stop()

    logger.info("PluginSupervisor 已停止")
async def invoke_plugin(
    self,
    method: str,
    plugin_id: str,
    component_name: str,
    args: dict[str, Any] | None = None,
    timeout_ms: int = 30000,
) -> Envelope:
    """Invoke a plugin component.

    Called by main-process business logic; the call is forwarded to the
    Runner over RPC, guarded by the plugin's circuit breaker.

    Raises:
        RPCError: The breaker rejected the call, or the RPC itself failed
            (a failed RPC is also recorded on the breaker).
    """
    plugin_breaker = self._breakers.get(plugin_id)
    if not plugin_breaker.allow_request():
        raise RPCError(ErrorCode.E_PLUGIN_CRASHED, f"插件 {plugin_id} 已被熔断")

    call_payload = {
        "component_name": component_name,
        "args": args or {},
    }
    try:
        reply = await self._rpc_server.send_request(
            method=method,
            plugin_id=plugin_id,
            payload=call_payload,
            timeout_ms=timeout_ms,
        )
    except RPCError:
        plugin_breaker.record_failure()
        raise

    plugin_breaker.record_success()
    return reply
async def reload_plugins(self, reason: str = "manual") -> None:
    """Hot-reload all plugins by switching to a new Runner process.

    1. Spawn a new Runner.
    2. Wait (up to about 30 seconds) for a Runner connection, then run one
       health check.
    3. On success, terminate the old Runner. On any failure, stop the new
       Runner, restore the old process reference and return.

    Every process this method stops is waited for, so no unreaped child is
    left behind.

    Args:
        reason: Free-form text, only used in the log message.
    """
    logger.info(f"开始热重载插件,原因: {reason}")
    old_process = self._runner_process

    async def _reap(process, grace: float = 10.0) -> None:
        """Terminate ``process``, kill it after ``grace`` seconds, wait for exit."""
        if process is None or process.returncode is not None:
            return
        try:
            process.terminate()
        except ProcessLookupError:
            pass  # exited between the returncode check and the signal
        try:
            await asyncio.wait_for(process.wait(), timeout=grace)
        except asyncio.TimeoutError:
            try:
                process.kill()
            except ProcessLookupError:
                pass
            await process.wait()

    async def _rollback() -> None:
        """Restore the old process reference and stop the new Runner."""
        new_process = self._runner_process
        self._runner_process = old_process
        if new_process is not None and new_process is not old_process:
            await _reap(new_process)

    await self._spawn_runner()
    new_runner = self._runner_process

    # NOTE(review): is_connected may already be true because the old Runner
    # is still attached; whether it reflects the *new* Runner depends on the
    # RPC server implementation - confirm.
    connected = False
    for _ in range(30):  # wait at most ~30 seconds
        if self._rpc_server.is_connected:
            connected = True
            break
        if new_runner is not None and new_runner.returncode is not None:
            break  # the new Runner already exited; waiting longer is pointless
        await asyncio.sleep(1.0)
    if not connected:
        logger.error("新 Runner 连接超时,回滚")
        await _rollback()
        return

    try:
        resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
        health = HealthPayload.model_validate(resp.payload)
        if not health.healthy:
            raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
    except Exception as e:
        logger.error(f"新 Runner 健康检查失败: {e},回滚")
        await _rollback()
        return

    # The new generation is healthy: retire the old Runner.
    await _reap(old_process)
    logger.info("热重载完成")
# ─── 内部方法 ──────────────────────────────────────────────
def _register_internal_methods(self) -> None:
    """Register the Host-side RPC method handlers on the RPC server."""
    routes = (
        # Capability calls coming from the Runner all go through the capability service.
        ("cap.request", self._capability_service.handle_capability_request),
        # Plugin component registration.
        ("plugin.register_components", self._handle_register_components),
    )
    for rpc_name, callback in routes:
        self._rpc_server.register_method(rpc_name, callback)
async def _handle_register_components(self, envelope: Envelope) -> Envelope:
    """Handle a ``plugin.register_components`` request.

    Validates the payload, stores it in ``self._registered_plugins`` under
    its plugin id, and registers the plugin with the policy engine using the
    envelope's generation and the declared capabilities.

    Returns:
        ``{"accepted": True}`` on success, or an ``E_BAD_PAYLOAD`` error
        response when the payload does not validate.
    """
    try:
        registration = RegisterComponentsPayload.model_validate(envelope.payload)
    except Exception as exc:
        return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))

    pid = registration.plugin_id
    required = registration.capabilities_required

    self._registered_plugins[pid] = registration
    self._policy.register_plugin(
        plugin_id=pid,
        generation=envelope.generation,
        capabilities=required,
    )
    logger.info(
        f"插件 {pid} v{registration.plugin_version} 注册成功,"
        f"组件数: {len(registration.components)}, 能力需求: {required}"
    )
    return envelope.make_response(payload={"accepted": True})
async def _spawn_runner(self) -> None:
    """Spawn the Runner subprocess and bump the generation counter.

    The child is started as ``python -m src.plugin_runtime.runner.runner_main``
    with the IPC address, the session token and the plugin directories passed
    through environment variables.

    The child inherits the Host's stdout/stderr. Nothing visible in this
    class reads from the child's pipes, so redirecting them to
    ``subprocess.PIPE`` would let the child block once the OS pipe buffer
    fills and make ``Process.wait()`` deadlock.
    """
    runner_module = "src.plugin_runtime.runner.runner_main"

    child_env = os.environ.copy()
    child_env["MAIBOT_IPC_ADDRESS"] = self._transport.get_address()
    child_env["MAIBOT_SESSION_TOKEN"] = self._rpc_server.session_token
    child_env["MAIBOT_PLUGIN_DIRS"] = os.pathsep.join(self._plugin_dirs)

    self._runner_process = await asyncio.create_subprocess_exec(
        sys.executable, "-m", runner_module,
        env=child_env,
    )
    self._runner_generation += 1
    logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}")
async def _shutdown_runner(self) -> None:
    """Shut the Runner down gracefully, escalating when necessary.

    1. If the RPC link is up, send ``plugin.prepare_shutdown`` followed by
       ``plugin.shutdown``.
    2. If those requests could not be delivered, send a terminate signal
       instead of waiting idly.
    3. Wait up to 10 seconds for the process to exit, then kill it and wait
       for it to be reaped.

    Does nothing when there is no Runner process or it has already exited.
    """
    process = self._runner_process
    if not process or process.returncode is not None:
        return

    notified = False
    try:
        if self._rpc_server.is_connected:
            shutdown_payload = ShutdownPayload(reason="host_shutdown", drain_timeout_ms=5000)
            await self._rpc_server.send_request(
                "plugin.prepare_shutdown",
                payload=shutdown_payload.model_dump(),
                timeout_ms=5000,
            )
            await self._rpc_server.send_request(
                "plugin.shutdown",
                payload=shutdown_payload.model_dump(),
                timeout_ms=5000,
            )
            notified = True
    except Exception as e:
        logger.warning(f"发送关停命令失败: {e}")

    if not notified:
        # The Runner never received the shutdown request, so it will not exit
        # on its own; ask it to stop via a terminate signal.
        try:
            process.terminate()
        except ProcessLookupError:
            pass  # already gone; wait() below just reaps it

    try:
        await asyncio.wait_for(process.wait(), timeout=10.0)
    except asyncio.TimeoutError:
        logger.warning("Runner 未在超时内退出,强制终止")
        try:
            process.kill()
        except ProcessLookupError:
            pass
        await process.wait()
async def _health_check_loop(self) -> None:
    """Periodically send ``plugin.health`` to the Runner and log the outcome.

    Runs while ``self._running`` is true, sleeping ``self._health_interval``
    seconds before each round. Failures are only logged.
    """
    while self._running:
        await asyncio.sleep(self._health_interval)

        if self._rpc_server.is_connected:
            try:
                reply = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
                health = HealthPayload.model_validate(reply.payload)
                if not health.healthy:
                    logger.warning(f"Runner 健康检查异常: {health}")
            except asyncio.CancelledError:
                # Cancelled while a check was in flight: leave the loop.
                break
            except RPCError as e:
                logger.error(f"健康检查失败: {e}")
            except Exception as e:
                logger.error(f"健康检查异常: {e}")
        else:
            logger.warning("Runner 未连接,跳过健康检查")

View File

@@ -0,0 +1 @@
# Protocol 层 - RPC 消息模型、编解码、错误码

View File

@@ -0,0 +1,80 @@
"""MsgPack / JSON 编解码器
提供统一的消息编解码接口，生产环境默认使用 MsgPack，
开发调试模式可切换为 JSON（仅编解码切换，传输层不变）。
"""
from typing import Any
import json
import msgpack
from .envelope import Envelope
class Codec:
    """Base class for message codecs.

    Subclasses must implement :meth:`encode` and :meth:`decode`. The
    envelope helpers are implemented here on top of those two methods, so a
    subclass only needs to override them when it wants different behaviour.
    """

    def encode(self, obj: dict[str, Any]) -> bytes:
        """Serialize a dict to bytes. Must be overridden."""
        raise NotImplementedError

    def decode(self, data: bytes) -> dict[str, Any]:
        """Deserialize bytes to a dict. Must be overridden."""
        raise NotImplementedError

    def encode_envelope(self, envelope: Envelope) -> bytes:
        """Serialize an envelope by encoding its ``model_dump()`` dict."""
        return self.encode(envelope.model_dump())

    def decode_envelope(self, data: bytes) -> Envelope:
        """Decode bytes and validate the resulting dict as an ``Envelope``."""
        return Envelope.model_validate(self.decode(data))
class MsgPackCodec(Codec):
    """MsgPack codec (the production default)."""

    def encode(self, obj: dict[str, Any]) -> bytes:
        """Pack ``obj`` with MsgPack, using the bin type for bytes values."""
        packed = msgpack.packb(obj, use_bin_type=True)
        return packed

    def decode(self, data: bytes) -> dict[str, Any]:
        """Unpack ``data``; raise ``ValueError`` unless the result is a dict."""
        unpacked = msgpack.unpackb(data, raw=False)
        if isinstance(unpacked, dict):
            return unpacked
        raise ValueError(f"期望解码为 dict实际为 {type(unpacked)}")

    def encode_envelope(self, envelope: Envelope) -> bytes:
        """Dump the envelope to a dict and pack it."""
        as_dict = envelope.model_dump()
        return self.encode(as_dict)

    def decode_envelope(self, data: bytes) -> Envelope:
        """Unpack ``data`` and validate it as an ``Envelope``."""
        return Envelope.model_validate(self.decode(data))
class JsonCodec(Codec):
    """JSON codec (for development and debugging)."""

    def encode(self, obj: dict[str, Any]) -> bytes:
        """Dump ``obj`` to JSON text (non-ASCII kept as-is) and encode as UTF-8."""
        text = json.dumps(obj, ensure_ascii=False)
        return text.encode("utf-8")

    def decode(self, data: bytes) -> dict[str, Any]:
        """Parse UTF-8 JSON; raise ``ValueError`` unless the result is a dict."""
        parsed = json.loads(data.decode("utf-8"))
        if isinstance(parsed, dict):
            return parsed
        raise ValueError(f"期望解码为 dict实际为 {type(parsed)}")

    def encode_envelope(self, envelope: Envelope) -> bytes:
        """Dump the envelope to a dict and encode it."""
        as_dict = envelope.model_dump()
        return self.encode(as_dict)

    def decode_envelope(self, data: bytes) -> Envelope:
        """Decode ``data`` and validate it as an ``Envelope``."""
        return Envelope.model_validate(self.decode(data))
def create_codec(use_json: bool = False) -> Codec:
    """Create a codec instance.

    Args:
        use_json: When true, return a ``JsonCodec`` (development mode);
            otherwise return a ``MsgPackCodec``, which is the default.
    """
    return JsonCodec() if use_json else MsgPackCodec()

View File

@@ -0,0 +1,187 @@
"""RPC Envelope 消息模型
定义 Host 与 Runner 之间所有 RPC 消息的统一信封格式。
使用 Pydantic 进行 schema 定义与校验。
"""
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
import time
# ─── Protocol constants ────────────────────────────────────────────
# Default value of Envelope.protocol_version.
PROTOCOL_VERSION = "1.0"
# Supported SDK version range (per the original comment, checked by the Host during the handshake).
MIN_SDK_VERSION = "1.0.0"
MAX_SDK_VERSION = "1.99.99"
# ─── 消息类型 ──────────────────────────────────────────────────────
class MessageType(str, Enum):
    """RPC message type carried in ``Envelope.message_type``."""

    REQUEST = "request"
    RESPONSE = "response"
    EVENT = "event"
# ─── 请求 ID 生成器 ───────────────────────────────────────────────
class RequestIdGenerator:
    """Monotonically increasing request-id generator.

    No locking is done here; callers must serialise access themselves (for
    example by only using it from a single asyncio event loop).
    """

    def __init__(self, start: int = 1):
        # The next id to hand out.
        self._counter = start

    def next(self) -> int:
        """Return the current id and advance the counter by one."""
        issued, self._counter = self._counter, self._counter + 1
        return issued
# ─── Envelope 模型 ─────────────────────────────────────────────────
class Envelope(BaseModel):
    """Unified RPC envelope.

    Every Host <-> Runner message is wrapped in this model.
    Serialization: Envelope -> ``.model_dump()`` -> MsgPack encode.
    Deserialization: MsgPack decode -> ``Envelope.model_validate(data)``.
    """

    protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本")
    request_id: int = Field(description="单调递增请求 ID")
    message_type: MessageType = Field(description="消息类型")
    method: str = Field(default="", description="RPC 方法名")
    plugin_id: str = Field(default="", description="目标插件 ID")
    timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
    timeout_ms: int = Field(default=30000, description="相对超时(ms)")
    generation: int = Field(default=0, description="Runner generation 编号")
    payload: dict[str, Any] = Field(default_factory=dict, description="业务数据")
    error: dict[str, Any] | None = Field(default=None, description="错误信息(仅 response)")

    def is_request(self) -> bool:
        """Return True when this envelope is a request."""
        return self.message_type == MessageType.REQUEST

    def is_response(self) -> bool:
        """Return True when this envelope is a response."""
        return self.message_type == MessageType.RESPONSE

    def is_event(self) -> bool:
        """Return True when this envelope is an event."""
        return self.message_type == MessageType.EVENT

    def make_response(self, payload: dict[str, Any] | None = None, error: dict[str, Any] | None = None) -> "Envelope":
        """Build the response envelope that answers this request.

        The protocol version, request id, method, plugin id and generation
        are copied from this envelope; the timestamp and timeout take their
        field defaults. A missing or empty ``payload`` becomes ``{}``.
        """
        reply_fields: dict[str, Any] = {
            "protocol_version": self.protocol_version,
            "request_id": self.request_id,
            "message_type": MessageType.RESPONSE,
            "method": self.method,
            "plugin_id": self.plugin_id,
            "generation": self.generation,
            "payload": payload or {},
            "error": error,
        }
        return Envelope(**reply_fields)

    def make_error_response(self, code: str, message: str = "", details: dict | None = None) -> "Envelope":
        """Build an error response with ``code``, ``message`` and ``details`` in ``error``."""
        error_body = {
            "code": code,
            "message": message,
            "details": details or {},
        }
        return self.make_response(error=error_body)
# ─── 握手消息 ──────────────────────────────────────────────────────
class HelloPayload(BaseModel):
    """Payload of the ``runner.hello`` handshake request."""

    runner_id: str = Field(description="Runner 进程唯一标识")
    sdk_version: str = Field(description="SDK 版本号")
    session_token: str = Field(description="一次性会话令牌")
class HelloResponsePayload(BaseModel):
    """Payload of the ``runner.hello`` handshake response."""

    accepted: bool = Field(description="是否接受连接")
    host_version: str = Field(default="", description="Host 版本号")
    assigned_generation: int = Field(default=0, description="分配的 generation 编号")
    reason: str = Field(default="", description="拒绝原因(若 accepted=False)")
# ─── 组件注册消息 ──────────────────────────────────────────────────
class ComponentDeclaration(BaseModel):
    """Declaration of a single plugin component."""

    name: str = Field(description="组件名称")
    component_type: str = Field(description="组件类型: action/command/tool/event_handler")
    plugin_id: str = Field(description="所属插件 ID")
    metadata: dict[str, Any] = Field(default_factory=dict, description="组件元数据")
class RegisterComponentsPayload(BaseModel):
    """Payload of the ``plugin.register_components`` request."""

    plugin_id: str = Field(description="插件 ID")
    plugin_version: str = Field(default="1.0.0", description="插件版本")
    components: list[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
    capabilities_required: list[str] = Field(default_factory=list, description="所需能力列表")
# ─── 调用消息 ──────────────────────────────────────────────────────
class InvokePayload(BaseModel):
    """Payload of a ``plugin.invoke_*`` request."""

    component_name: str = Field(description="要调用的组件名称")
    args: dict[str, Any] = Field(default_factory=dict, description="调用参数")
class InvokeResultPayload(BaseModel):
    """Payload of a ``plugin.invoke_*`` response."""

    success: bool = Field(description="是否成功")
    result: Any = Field(default=None, description="返回值")
# ─── 能力调用消息 ──────────────────────────────────────────────────
class CapabilityRequestPayload(BaseModel):
    """Payload of a ``cap.*`` request (capability call from a plugin to the Host)."""

    capability: str = Field(description="能力名称,如 send.text, db.query")
    args: dict[str, Any] = Field(default_factory=dict, description="调用参数")
class CapabilityResponsePayload(BaseModel):
    """Payload of a ``cap.*`` response."""

    success: bool = Field(description="是否成功")
    result: Any = Field(default=None, description="返回值")
# ─── 健康检查 ──────────────────────────────────────────────────────
class HealthPayload(BaseModel):
    """Payload of the ``plugin.health`` response."""

    healthy: bool = Field(description="是否健康")
    loaded_plugins: list[str] = Field(default_factory=list, description="已加载的插件列表")
    uptime_ms: int = Field(default=0, description="运行时长(ms)")
# ─── 配置更新 ──────────────────────────────────────────────────────
class ConfigUpdatedPayload(BaseModel):
    """Payload of the ``plugin.config_updated`` event."""

    plugin_id: str = Field(description="插件 ID")
    config_version: str = Field(description="新配置版本")
    config_data: dict[str, Any] = Field(default_factory=dict, description="配置内容")
# ─── 关停 ──────────────────────────────────────────────────────────
class ShutdownPayload(BaseModel):
    """Payload of ``plugin.shutdown`` / ``plugin.prepare_shutdown``."""

    reason: str = Field(default="normal", description="关停原因")
    drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")

View File

@@ -0,0 +1,61 @@
"""RPC 错误码定义
所有 Host 与 Runner 之间的 RPC 通信使用统一的错误码体系。
"""
from enum import Enum
class ErrorCode(str, Enum):
    """Enumeration of RPC error codes; each value equals its member name."""

    # General
    OK = "OK"
    E_UNKNOWN = "E_UNKNOWN"
    # Protocol layer
    E_TIMEOUT = "E_TIMEOUT"
    E_BAD_PAYLOAD = "E_BAD_PAYLOAD"
    E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH"
    # Authorization and policy
    E_UNAUTHORIZED = "E_UNAUTHORIZED"
    E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED"
    E_BACKPRESSURE = "E_BACKPRESSURE"
    E_HOST_OVERLOADED = "E_HOST_OVERLOADED"
    # Plugin lifecycle
    E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED"
    E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND"
    E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH"
    E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS"
    # Capability calls
    E_CAPABILITY_DENIED = "E_CAPABILITY_DENIED"
    E_CAPABILITY_FAILED = "E_CAPABILITY_FAILED"
class RPCError(Exception):
    """Exception raised for a failed RPC call.

    Attributes are ``code`` (an ``ErrorCode``), ``message`` (falls back to
    the code's value when empty) and ``details`` (a dict, never ``None``).
    """

    def __init__(self, code: ErrorCode, message: str = "", details: dict | None = None):
        self.code = code
        self.message = message or code.value
        self.details = details or {}
        super().__init__(f"[{code.value}] {self.message}")

    def to_dict(self) -> dict:
        """Return the wire form: ``code`` (string), ``message`` and ``details``."""
        return {
            "code": self.code.value,
            "message": self.message,
            "details": self.details,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "RPCError":
        """Rebuild an ``RPCError`` from its wire form.

        A missing code maps to ``E_UNKNOWN``. A code that is not a known
        ``ErrorCode`` (for example one sent by a newer peer) also maps to
        ``E_UNKNOWN`` instead of raising ``ValueError``; the unrecognised
        value is preserved in ``details["remote_code"]``.
        """
        raw_code = data.get("code", ErrorCode.E_UNKNOWN.value)
        details = data.get("details") or {}
        try:
            code = ErrorCode(raw_code)
        except ValueError:
            code = ErrorCode.E_UNKNOWN
            details = {**details, "remote_code": raw_code}
        return cls(
            code=code,
            message=data.get("message") or "",
            details=details,
        )

View File

@@ -0,0 +1 @@
# Runner 端 - 插件加载与执行进程

View File

@@ -0,0 +1,127 @@
"""插件加载器
在 Runner 进程中负责发现和加载插件。
插件通过 SDK 编写,不再 import src.*。
"""
from typing import Any
import importlib
import importlib.util
import json
import logging
import os
import sys
# Module-level logger for the plugin loader.
logger = logging.getLogger("plugin_runtime.runner.plugin_loader")
class PluginMeta:
    """Metadata of a plugin after it has been loaded.

    ``version`` and ``capabilities_required`` are read from the manifest keys
    ``"version"`` (default ``"1.0.0"``) and ``"capabilities"`` (default ``[]``).
    """

    def __init__(
        self,
        plugin_id: str,
        plugin_dir: str,
        plugin_instance: Any,
        manifest: dict[str, Any],
    ):
        self.plugin_id, self.plugin_dir = plugin_id, plugin_dir
        self.instance = plugin_instance
        self.manifest = manifest
        # Values derived from the manifest.
        declared_version = manifest.get("version", "1.0.0")
        declared_capabilities = manifest.get("capabilities", [])
        self.version = declared_version
        self.capabilities_required = declared_capabilities
class PluginLoader:
    """Discover and load SDK plugins.

    Each plugin lives in its own directory, which must contain:
    - ``_manifest.json``: plugin metadata
    - ``plugin.py``: entry module exporting a ``create_plugin`` factory
    """

    def __init__(self):
        # plugin_id -> metadata of every plugin loaded so far
        self._loaded_plugins: dict[str, PluginMeta] = {}

    def discover_and_load(self, plugin_dirs: list[str]) -> list[PluginMeta]:
        """Scan several directories and load every plugin found in them.

        A failure to load one plugin is logged and does not stop the scan.
        Within each base directory, entries are visited in sorted order so
        that the load order is deterministic.

        Args:
            plugin_dirs: Base directories whose sub-directories are plugins.

        Returns:
            Metadata of the plugins loaded successfully by this call.
        """
        results: list[PluginMeta] = []
        for base_dir in plugin_dirs:
            if not os.path.isdir(base_dir):
                logger.warning(f"插件目录不存在: {base_dir}")
                continue
            # os.listdir returns entries in arbitrary order; sort for determinism.
            for entry in sorted(os.listdir(base_dir)):
                plugin_dir = os.path.join(base_dir, entry)
                if not os.path.isdir(plugin_dir):
                    continue
                manifest_path = os.path.join(plugin_dir, "_manifest.json")
                plugin_path = os.path.join(plugin_dir, "plugin.py")
                if not os.path.exists(manifest_path) or not os.path.exists(plugin_path):
                    continue
                try:
                    meta = self._load_single_plugin(plugin_dir, manifest_path, plugin_path)
                    if meta:
                        self._loaded_plugins[meta.plugin_id] = meta
                        results.append(meta)
                except Exception as e:
                    logger.error(f"加载插件失败 [{plugin_dir}]: {e}", exc_info=True)
        return results

    def get_plugin(self, plugin_id: str) -> PluginMeta | None:
        """Return the loaded plugin with this id, or ``None``."""
        return self._loaded_plugins.get(plugin_id)

    def list_plugins(self) -> list[str]:
        """Return the ids of all loaded plugins."""
        return list(self._loaded_plugins)

    def _load_single_plugin(self, plugin_dir: str, manifest_path: str, plugin_path: str) -> PluginMeta | None:
        """Load one plugin.

        The plugin id is the directory's base name. The entry module is
        imported under the name ``_maibot_plugin_<plugin_id>``.

        Returns:
            The plugin metadata, or ``None`` when no module spec can be
            created or the module has no ``create_plugin`` factory.

        Raises:
            ValueError: When the manifest is not a JSON object.
            Exception: Anything raised while reading the manifest, executing
                the module or calling the factory.
        """
        # 1. Read the manifest.
        with open(manifest_path, "r", encoding="utf-8") as f:
            manifest = json.load(f)
        if not isinstance(manifest, dict):
            raise ValueError(f"_manifest.json 顶层必须是对象: {manifest_path}")
        plugin_id = os.path.basename(plugin_dir)

        # 2. Import the plugin module dynamically.
        module_name = f"_maibot_plugin_{plugin_id}"
        spec = importlib.util.spec_from_file_location(module_name, plugin_path)
        if spec is None or spec.loader is None:
            logger.error(f"无法创建模块 spec: {plugin_path}")
            return None
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module

        loaded = False
        try:
            spec.loader.exec_module(module)
            # 3. Create the plugin instance via the factory function.
            create_plugin = getattr(module, "create_plugin", None)
            if create_plugin is None:
                logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数")
                return None
            instance = create_plugin()
            loaded = True
        finally:
            if not loaded:
                # Do not leave an unusable module registered in sys.modules.
                sys.modules.pop(module_name, None)

        logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
        return PluginMeta(
            plugin_id=plugin_id,
            plugin_dir=plugin_dir,
            plugin_instance=instance,
            manifest=manifest,
        )

View File

@@ -0,0 +1,257 @@
"""Runner 端 RPC Client
负责:
1. 连接 Host RPC Server
2. 发送握手（runner.hello）
3. 发送组件注册请求
4. 接收并分发 Host 的调用请求
5. 发送能力调用请求到 Host
"""
from typing import Any, Callable, Awaitable
import asyncio
import logging
import uuid
from src.plugin_runtime.protocol.codec import Codec, create_codec
from src.plugin_runtime.protocol.envelope import (
PROTOCOL_VERSION,
Envelope,
HelloPayload,
HelloResponsePayload,
MessageType,
RequestIdGenerator,
)
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
from src.plugin_runtime.transport.base import Connection
from src.plugin_runtime.transport.factory import create_transport_client
logger = logging.getLogger("plugin_runtime.runner.rpc_client")
# Type of an RPC method handler: an async callable that takes the incoming
# Envelope and returns an Envelope.
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
# Sent as ``sdk_version`` in the runner.hello handshake payload.
SDK_VERSION = "1.0.0"
class RPCClient:
    """Runner-side RPC client.

    Manages the IPC connection to the Host and supports RPC calls in both
    directions: requests sent to the Host, and requests/events received from
    the Host that are dispatched to registered handlers.
    """

    def __init__(
        self,
        host_address: str,
        session_token: str,
        codec: Codec | None = None,
    ):
        self._host_address = host_address
        self._session_token = session_token
        self._codec = codec or create_codec()
        self._id_gen = RequestIdGenerator()
        self._connection: Connection | None = None
        self._runner_id = str(uuid.uuid4())
        self._generation: int = 0
        # Handlers for calls coming from the Host, keyed by method name.
        self._method_handlers: dict[str, MethodHandler] = {}
        # Requests awaiting a response: request_id -> Future.
        self._pending_requests: dict[int, asyncio.Future] = {}
        # Strong references to in-flight handler tasks. The event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._background_tasks: set[asyncio.Task] = set()
        # Run state.
        self._running = False
        self._recv_task: asyncio.Task | None = None

    @property
    def generation(self) -> int:
        """Generation number assigned by the Host during the handshake."""
        return self._generation

    @property
    def is_connected(self) -> bool:
        """True while a connection object exists and is not closed."""
        return self._connection is not None and not self._connection.is_closed

    def register_method(self, method: str, handler: MethodHandler) -> None:
        """Register the handler for requests/events the Host sends for ``method``."""
        self._method_handlers[method] = handler

    async def connect_and_handshake(self) -> bool:
        """Connect to the Host and perform the ``runner.hello`` handshake.

        Returns:
            True when the Host accepted the handshake; False when it rejected
            it (the connection is closed in that case).

        Raises:
            Any connection, timeout, decoding or validation error. The
            connection is closed before the error propagates.
        """
        client = create_transport_client(self._host_address)
        connection = await client.connect()
        try:
            hello = HelloPayload(
                runner_id=self._runner_id,
                sdk_version=SDK_VERSION,
                session_token=self._session_token,
            )
            envelope = Envelope(
                request_id=self._id_gen.next(),
                message_type=MessageType.REQUEST,
                method="runner.hello",
                payload=hello.model_dump(),
            )
            await connection.send_frame(self._codec.encode_envelope(envelope))

            # Wait for the handshake response.
            resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
            resp = self._codec.decode_envelope(resp_data)
            if resp.error:
                logger.error(f"握手被拒绝: {resp.error}")
                await connection.close()
                return False
            resp_payload = HelloResponsePayload.model_validate(resp.payload)
        except BaseException:
            # Do not leak the socket when the handshake fails half-way.
            await connection.close()
            raise

        if not resp_payload.accepted:
            logger.error(f"握手被拒绝: {resp_payload.reason}")
            await connection.close()
            return False

        self._connection = connection
        self._generation = resp_payload.assigned_generation
        logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}")

        # Start the receive loop.
        self._running = True
        self._recv_task = asyncio.create_task(self._recv_loop())
        return True

    async def disconnect(self) -> None:
        """Stop the receive loop, fail pending requests and close the connection."""
        self._running = False
        if self._recv_task:
            self._recv_task.cancel()
            try:
                await self._recv_task
            except asyncio.CancelledError:
                pass
            self._recv_task = None
        self._fail_pending("连接关闭")
        if self._connection:
            await self._connection.close()
            self._connection = None

    async def send_request(
        self,
        method: str,
        plugin_id: str = "",
        payload: dict[str, Any] | None = None,
        timeout_ms: int = 30000,
    ) -> Envelope:
        """Send an RPC request to the Host and wait for its response.

        Raises:
            RPCError: ``E_UNKNOWN`` when not connected or on an unexpected
                failure, ``E_TIMEOUT`` when no response arrives in time, or
                the error carried by the Host's response.
        """
        if not self.is_connected:
            raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
        connection = self._connection
        request_id = self._id_gen.next()
        envelope = Envelope(
            request_id=request_id,
            message_type=MessageType.REQUEST,
            method=method,
            plugin_id=plugin_id,
            generation=self._generation,
            timeout_ms=timeout_ms,
            payload=payload or {},
        )
        future: asyncio.Future[Envelope] = asyncio.get_running_loop().create_future()
        self._pending_requests[request_id] = future
        try:
            data = self._codec.encode_envelope(envelope)
            await connection.send_frame(data)
            return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
        except asyncio.TimeoutError:
            raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
        except RPCError:
            raise
        except Exception as e:
            raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
        finally:
            # Also runs on cancellation, so the pending table cannot leak entries.
            self._pending_requests.pop(request_id, None)

    # ─── Internals ─────────────────────────────────────────────

    def _fail_pending(self, reason: str) -> None:
        """Fail every unfinished pending request with an ``E_TIMEOUT`` RPCError."""
        for future in self._pending_requests.values():
            if not future.done():
                future.set_exception(RPCError(ErrorCode.E_TIMEOUT, reason))
        self._pending_requests.clear()

    def _spawn(self, coro) -> None:
        """Run ``coro`` as a task and hold a reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _send_frame(self, frame: bytes) -> None:
        """Send an already-encoded frame; log instead of raising on failure."""
        connection = self._connection
        if connection is None or connection.is_closed:
            logger.warning("连接已关闭，丢弃待发送的响应")
            return
        try:
            await connection.send_frame(frame)
        except Exception as e:
            logger.error(f"发送响应失败: {e}")

    async def _recv_loop(self) -> None:
        """Main receive loop: read frames, decode them and dispatch by type.

        When the loop ends for any reason other than ``disconnect()``, the
        pending requests are failed immediately and the connection is closed
        so that ``is_connected`` reports the loss.
        """
        while self._running and self._connection and not self._connection.is_closed:
            try:
                data = await self._connection.recv_frame()
            except (asyncio.IncompleteReadError, ConnectionError):
                logger.info("Host 连接已断开")
                break
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"接收帧失败: {e}")
                break
            try:
                envelope = self._codec.decode_envelope(data)
            except Exception as e:
                logger.error(f"解码消息失败: {e}")
                continue
            if envelope.is_response():
                try:
                    self._handle_response(envelope)
                except Exception as e:
                    # A single bad response must not take the loop down.
                    logger.error(f"处理响应失败: {e}", exc_info=True)
            elif envelope.is_request():
                self._spawn(self._handle_request(envelope))
            elif envelope.is_event():
                self._spawn(self._handle_event(envelope))

        if self._running:
            # Not a deliberate disconnect(): the link to the Host is gone.
            self._running = False
            self._fail_pending("Host 连接已断开")
            if self._connection is not None:
                await self._connection.close()

    def _handle_response(self, envelope: Envelope) -> None:
        """Resolve the pending future that matches a response from the Host."""
        future = self._pending_requests.pop(envelope.request_id, None)
        if future is None or future.done():
            return
        if envelope.error:
            try:
                error = RPCError.from_dict(envelope.error)
            except Exception:
                # Malformed or unrecognised error dict: still fail the caller.
                error = RPCError(ErrorCode.E_UNKNOWN, str(envelope.error))
            future.set_exception(error)
        else:
            future.set_result(envelope)

    async def _handle_request(self, envelope: Envelope) -> None:
        """Dispatch a request from the Host and send back the response.

        An unknown method yields ``E_METHOD_NOT_ALLOWED``; an ``RPCError``
        from the handler is echoed back; any other failure, including a
        response that cannot be encoded, yields ``E_UNKNOWN``.
        """
        handler = self._method_handlers.get(envelope.method)
        if handler is None:
            error_resp = envelope.make_error_response(
                ErrorCode.E_METHOD_NOT_ALLOWED.value,
                f"未注册的方法: {envelope.method}",
            )
            frame = self._codec.encode_envelope(error_resp)
        else:
            try:
                response = await handler(envelope)
                frame = self._codec.encode_envelope(response)
            except RPCError as e:
                error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
                frame = self._codec.encode_envelope(error_resp)
            except Exception as e:
                logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
                error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
                frame = self._codec.encode_envelope(error_resp)
        await self._send_frame(frame)

    async def _handle_event(self, envelope: Envelope) -> None:
        """Dispatch an event from the Host; handler errors are only logged."""
        handler = self._method_handlers.get(envelope.method)
        if handler:
            try:
                await handler(envelope)
            except Exception as e:
                logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)

View File

@@ -0,0 +1,246 @@
"""Runner 主循环
作为独立子进程运行,负责:
1. 从环境变量读取 IPC 地址和会话令牌
2. 连接 Host 并完成握手
3. 加载所有插件
4. 注册组件到 Host
5. 处理 Host 的调用请求
6. 转发插件的能力调用到 Host
"""
from typing import Any
import asyncio
import logging
import os
import signal
import sys
import time
from src.plugin_runtime.protocol.envelope import (
ComponentDeclaration,
Envelope,
HealthPayload,
InvokePayload,
InvokeResultPayload,
RegisterComponentsPayload,
ShutdownPayload,
)
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.runner.plugin_loader import PluginLoader
from src.plugin_runtime.runner.rpc_client import RPCClient
# Module-level logger for the Runner entry point.
logger = logging.getLogger("plugin_runtime.runner.main")
class PluginRunner:
    """Plugin Runner.

    Runs in a separate subprocess: connects to the Host, loads the plugins,
    registers their components and serves the Host's requests.
    """

    def __init__(
        self,
        host_address: str,
        session_token: str,
        plugin_dirs: list[str],
    ):
        self._host_address = host_address
        self._session_token = session_token
        self._plugin_dirs = plugin_dirs
        self._rpc_client = RPCClient(host_address, session_token)
        self._loader = PluginLoader()
        self._start_time = time.monotonic()
        # Set by the shutdown handler (or a signal handler) to end run().
        self._shutting_down = False

    async def run(self) -> None:
        """Main entry point of the Runner.

        Handlers are registered before connecting, because the client starts
        receiving as soon as the handshake succeeds. The connection is always
        released on the way out, and the wait loop also ends when the client
        reports that it is no longer connected.
        """
        # 1. Register the method handlers.
        self._register_handlers()

        # 2. Connect to the Host.
        logger.info(f"Runner 启动,连接 Host: {self._host_address}")
        ok = await self._rpc_client.connect_and_handshake()
        if not ok:
            logger.error("握手失败,退出")
            return

        try:
            # 3. Load the plugins.
            plugins = self._loader.discover_and_load(self._plugin_dirs)
            logger.info(f"已加载 {len(plugins)} 个插件")

            # 4. Register every plugin's components with the Host.
            for meta in plugins:
                await self._register_plugin(meta)

            # 5. Wait until shutdown is requested or the Host link is gone.
            try:
                while not self._shutting_down and self._rpc_client.is_connected:
                    await asyncio.sleep(1.0)
            except asyncio.CancelledError:
                pass
        finally:
            # 6. Disconnect.
            await self._rpc_client.disconnect()
            logger.info("Runner 已退出")

    def _register_handlers(self) -> None:
        """Register the handlers for every method the Host may call."""
        for invoke_method in (
            "plugin.invoke_command",
            "plugin.invoke_action",
            "plugin.invoke_tool",
            "plugin.emit_event",
        ):
            self._rpc_client.register_method(invoke_method, self._handle_invoke)
        self._rpc_client.register_method("plugin.health", self._handle_health)
        self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
        self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
        self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)

    @staticmethod
    async def _resolve(value: Any) -> Any:
        """Await ``value`` when it is awaitable; otherwise return it unchanged."""
        import inspect  # local import: only needed here

        if inspect.isawaitable(value):
            return await value
        return value

    async def _register_plugin(self, meta: Any) -> None:
        """Register one plugin with the Host.

        Component declarations come from the plugin instance's optional
        ``get_components()`` method. Any failure, whether while collecting
        the declarations or while sending the request, is logged and does
        not stop the Runner.
        """
        try:
            components = []
            instance = meta.instance
            if hasattr(instance, "get_components"):
                for comp_info in instance.get_components():
                    components.append(ComponentDeclaration(
                        name=comp_info.get("name", ""),
                        component_type=comp_info.get("type", ""),
                        plugin_id=meta.plugin_id,
                        metadata=comp_info.get("metadata", {}),
                    ))
            reg_payload = RegisterComponentsPayload(
                plugin_id=meta.plugin_id,
                plugin_version=meta.version,
                components=components,
                capabilities_required=meta.capabilities_required,
            )
            await self._rpc_client.send_request(
                "plugin.register_components",
                plugin_id=meta.plugin_id,
                payload=reg_payload.model_dump(),
                timeout_ms=10000,
            )
            logger.info(f"插件 {meta.plugin_id} 注册完成")
        except Exception as e:
            logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")

    async def _handle_invoke(self, envelope: Envelope) -> Envelope:
        """Handle a component invocation request.

        The target is ``handle_<component_name>`` on the plugin instance,
        falling back to an attribute named ``component_name``. The call
        result is awaited when it is awaitable. An exception raised by the
        component is reported as ``success=False`` with the error text as
        the result.
        """
        try:
            invoke = InvokePayload.model_validate(envelope.payload)
        except Exception as e:
            return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))

        plugin_id = envelope.plugin_id
        meta = self._loader.get_plugin(plugin_id)
        if meta is None:
            return envelope.make_error_response(
                ErrorCode.E_PLUGIN_NOT_FOUND.value,
                f"插件 {plugin_id} 未加载",
            )

        instance = meta.instance
        component_name = invoke.component_name
        handler_method = getattr(instance, f"handle_{component_name}", None)
        if handler_method is None:
            handler_method = getattr(instance, component_name, None)
        if handler_method is None or not callable(handler_method):
            return envelope.make_error_response(
                ErrorCode.E_METHOD_NOT_ALLOWED.value,
                f"插件 {plugin_id} 无组件: {component_name}",
            )

        try:
            result = await self._resolve(handler_method(**invoke.args))
            resp_payload = InvokeResultPayload(success=True, result=result)
            return envelope.make_response(payload=resp_payload.model_dump())
        except Exception as e:
            logger.error(f"插件 {plugin_id} 组件 {component_name} 执行异常: {e}", exc_info=True)
            resp_payload = InvokeResultPayload(success=False, result=str(e))
            return envelope.make_response(payload=resp_payload.model_dump())

    async def _handle_health(self, envelope: Envelope) -> Envelope:
        """Handle a health check: always healthy, with plugin list and uptime."""
        uptime_ms = int((time.monotonic() - self._start_time) * 1000)
        health = HealthPayload(
            healthy=True,
            loaded_plugins=self._loader.list_plugins(),
            uptime_ms=uptime_ms,
        )
        return envelope.make_response(payload=health.model_dump())

    async def _handle_prepare_shutdown(self, envelope: Envelope) -> Envelope:
        """Acknowledge ``plugin.prepare_shutdown``."""
        logger.info("收到 prepare_shutdown 信号")
        return envelope.make_response(payload={"acknowledged": True})

    async def _handle_shutdown(self, envelope: Envelope) -> Envelope:
        """Acknowledge ``plugin.shutdown`` and ask the main loop to exit."""
        logger.info("收到 shutdown 信号,准备退出")
        self._shutting_down = True
        return envelope.make_response(payload={"acknowledged": True})

    async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
        """Forward a config update to the plugin's optional ``on_config_update``.

        Both synchronous and asynchronous ``on_config_update``
        implementations are supported.
        """
        plugin_id = envelope.plugin_id
        meta = self._loader.get_plugin(plugin_id)
        if meta and hasattr(meta.instance, "on_config_update"):
            try:
                config_data = envelope.payload.get("config_data", {})
                config_version = envelope.payload.get("config_version", "")
                await self._resolve(meta.instance.on_config_update(config_data, config_version))
            except Exception as e:
                logger.error(f"插件 {plugin_id} 配置更新失败: {e}")
                return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
        return envelope.make_response(payload={"acknowledged": True})

    def request_capability(self) -> RPCClient:
        """Return the RPC client (for the SDK to issue capability calls)."""
        return self._rpc_client
# ─── 进程入口 ──────────────────────────────────────────────
async def _async_main() -> None:
    """Asynchronous entry point of the Runner process.

    Reads ``MAIBOT_IPC_ADDRESS``, ``MAIBOT_SESSION_TOKEN`` and
    ``MAIBOT_PLUGIN_DIRS`` (separated by ``os.pathsep``) from the
    environment, installs SIGTERM/SIGINT handlers that request a shutdown,
    and runs the ``PluginRunner``. Exits with status 1 when the address or
    the token is missing.
    """
    host_address = os.environ.get("MAIBOT_IPC_ADDRESS", "")
    session_token = os.environ.get("MAIBOT_SESSION_TOKEN", "")
    plugin_dirs_str = os.environ.get("MAIBOT_PLUGIN_DIRS", "")
    if not host_address or not session_token:
        logger.error("缺少必要的环境变量: MAIBOT_IPC_ADDRESS, MAIBOT_SESSION_TOKEN")
        sys.exit(1)
    plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
    runner = PluginRunner(host_address, session_token, plugin_dirs)

    def _request_shutdown(*_signal_args: object) -> None:
        # Accepts the (signum, frame) pair passed by signal.signal as well as
        # the zero-argument call made by the event loop.
        runner._shutting_down = True

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        try:
            loop.add_signal_handler(sig, _request_shutdown)
        except NotImplementedError:
            # loop.add_signal_handler is not implemented on every platform
            # (for example Windows event loops); fall back to signal.signal.
            signal.signal(sig, _request_shutdown)
    await runner.run()
def main() -> None:
    """Process entry point: ``python -m src.plugin_runtime.runner.runner_main``."""
    log_format = "%(asctime)s [%(name)s] %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    asyncio.run(_async_main())


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1 @@
# Transport 层 - 跨平台本地 IPC 传输抽象

View File

@@ -0,0 +1,116 @@
"""传输层抽象基类
定义 TransportServer 和 TransportClient 的统一接口。
所有传输后端（UDS、Named Pipe、TCP 回退）必须实现此接口。
业务层仅依赖此抽象,禁止直接使用具体传输实现的细节。
分帧协议：4-byte big-endian length prefix + payload
"""
from abc import ABC, abstractmethod
from typing import AsyncIterator, Callable, Awaitable
import asyncio
import struct
# Framing constants
FRAME_HEADER_SIZE = 4  # size of the length prefix, in bytes
MAX_FRAME_SIZE = 16 * 1024 * 1024  # maximum frame payload size: 16 MB
class ConnectionClosed(Exception):
    """Raised when a frame is sent or received on a connection already closed."""

    pass
class Connection(ABC):
    """A single connection with framed reads and writes.

    Wraps an ``asyncio.StreamReader``/``StreamWriter`` pair. Each frame is a
    4-byte big-endian length prefix followed by that many payload bytes.
    """

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self._reader = reader
        self._writer = writer
        self._closed = False

    async def send_frame(self, data: bytes) -> None:
        """Send one frame (length prefix plus payload) and drain the writer.

        Raises:
            ConnectionClosed: When ``close()`` was already called.
            ValueError: When ``data`` is longer than ``MAX_FRAME_SIZE``.
        """
        if self._closed:
            raise ConnectionClosed("连接已关闭")
        size = len(data)
        if size > MAX_FRAME_SIZE:
            raise ValueError(f"帧大小 {size} 超过最大限制 {MAX_FRAME_SIZE}")
        # Header and payload go out in a single write call.
        self._writer.write(struct.pack(">I", size) + data)
        await self._writer.drain()

    async def recv_frame(self) -> bytes:
        """Receive one frame and return its payload.

        Raises:
            ConnectionClosed: When ``close()`` was already called.
            ValueError: When the announced length exceeds ``MAX_FRAME_SIZE``.
        """
        if self._closed:
            raise ConnectionClosed("连接已关闭")
        raw_header = await self._reader.readexactly(FRAME_HEADER_SIZE)
        size = struct.unpack(">I", raw_header)[0]
        if size > MAX_FRAME_SIZE:
            raise ValueError(f"帧大小 {size} 超过最大限制 {MAX_FRAME_SIZE}")
        return await self._reader.readexactly(size)

    async def close(self) -> None:
        """Close the connection; later calls are no-ops.

        Errors raised while closing the writer are deliberately ignored.
        """
        if not self._closed:
            self._closed = True
            try:
                self._writer.close()
                await self._writer.wait_closed()
            except Exception:
                pass

    @property
    def is_closed(self) -> bool:
        """True once ``close()`` has been called."""
        return self._closed
# Connection callback type: called with the new Connection when one arrives.
ConnectionHandler = Callable[[Connection], Awaitable[None]]
class TransportServer(ABC):
    """Abstract transport server.

    Used on the Host side to listen for connections from the Runner.
    """

    @abstractmethod
    async def start(self, handler: ConnectionHandler) -> None:
        """Start the server and begin listening for connections.

        Args:
            handler: Callback invoked for every new connection.
        """

    @abstractmethod
    async def stop(self) -> None:
        """Stop the server."""

    @abstractmethod
    def get_address(self) -> str:
        """Return the listening address (for the Runner to connect to)."""
class TransportClient(ABC):
    """Abstract transport client.

    Used on the Runner side to actively connect to the Host.
    """

    @abstractmethod
    async def connect(self) -> Connection:
        """Establish a connection to the Host."""

View File

@@ -0,0 +1,46 @@
"""传输层工厂
根据运行平台自动选择最优传输实现。
"""
import sys
from .base import TransportClient, TransportServer
def create_transport_server(socket_path: str | None = None) -> TransportServer:
    """Create the transport server suited to the current platform.

    Windows falls back to TCP; every other platform uses a Unix domain socket.

    Args:
        socket_path: UDS socket path. Only forwarded to the UDS server, so it
            has no effect on Windows.
    """
    if sys.platform == "win32":
        # Windows falls back to TCP; this may later be replaced by a Named Pipe.
        from .tcp import TCPTransportServer

        return TCPTransportServer()

    from .uds import UDSTransportServer

    return UDSTransportServer(socket_path=socket_path)
def create_transport_client(address: str) -> TransportClient:
    """Create a transport client for the given Host address.

    The transport type is inferred from the address format:

    - contains ``/`` or ends with ``.sock`` -> Unix domain socket path
    - otherwise contains ``:`` -> TCP ``host:port`` (split on the last colon);
      a bracketed IPv6 host such as ``[::1]:8000`` is accepted

    Args:
        address: Address the Host is listening on.

    Raises:
        ValueError: If the address matches neither format, or the TCP port
            part is not an integer.
    """
    if "/" in address or address.endswith(".sock"):
        from .uds import UDSTransportClient

        return UDSTransportClient(socket_path=address)

    if ":" in address:
        host, port_str = address.rsplit(":", 1)
        try:
            port = int(port_str)
        except ValueError as exc:
            # Report the whole address instead of the bare int() parsing error.
            raise ValueError(f"传输地址中的端口无效: {address}") from exc

        # "[::1]:8000" is the usual way to write an IPv6 host with a port, but
        # name resolution expects the bare literal without the brackets.
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]

        from .tcp import TCPTransportClient

        return TCPTransportClient(host=host, port=port)

    raise ValueError(f"无法识别的传输地址格式: {address}")

View File

@@ -0,0 +1,59 @@
"""TCP 传输实现(回退方案)
仅当 UDS / Named Pipe 不可用时启用。
绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。
"""
import asyncio
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
class TCPConnection(Connection):
    """Connection carried over TCP; adds nothing to the ``Connection`` base class."""
class TCPTransportServer(TransportServer):
    """TCP transport server (fallback option).

    Defaults to ``127.0.0.1`` with port ``0``; the port actually bound is
    recorded once :meth:`start` has run.
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 0):
        self._host = host
        self._port = port  # 0 means "assign a port automatically"
        self._server: asyncio.AbstractServer | None = None
        self._actual_port: int = 0

    async def start(self, handler: ConnectionHandler) -> None:
        """Start listening and pass every accepted connection to ``handler``."""

        async def _serve(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
            accepted = TCPConnection(reader, writer)
            try:
                await handler(accepted)
            finally:
                # Always release the connection, even if the handler raised.
                await accepted.close()

        self._server = await asyncio.start_server(_serve, self._host, self._port)
        # Record the port that was actually bound (differs from the request when it was 0).
        self._actual_port = self._server.sockets[0].getsockname()[1]

    async def stop(self) -> None:
        """Close the listening server, if one is running."""
        server = self._server
        if not server:
            return
        server.close()
        await server.wait_closed()
        self._server = None

    def get_address(self) -> str:
        """Return ``host:port`` using the port that was actually bound."""
        return f"{self._host}:{self._actual_port}"
class TCPTransportClient(TransportClient):
    """TCP transport client."""

    def __init__(self, host: str, port: int):
        self._host, self._port = host, port

    async def connect(self) -> Connection:
        """Open a TCP connection to the configured host/port and wrap it in a ``TCPConnection``."""
        streams = await asyncio.open_connection(self._host, self._port)
        return TCPConnection(*streams)

View File

@@ -0,0 +1,71 @@
"""Unix Domain Socket 传输实现
适用于 Linux / macOS 平台。
"""
from pathlib import Path
import asyncio
import os
import tempfile
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
class UDSConnection(Connection):
    """Connection carried over a Unix domain socket.

    Reuses the framed read/write implementation of the ``Connection`` base
    class without changes.
    """
class UDSTransportServer(TransportServer):
    """Unix domain socket transport server."""

    def __init__(self, socket_path: str | None = None):
        """Remember the socket path to listen on.

        Args:
            socket_path: Filesystem path of the socket. When ``None``, a
                per-process file in the system temporary directory is used.
        """
        if socket_path is None:
            # Default location: the temp directory, named after the current PID.
            socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}.sock")
        self._socket_path = socket_path
        self._server: asyncio.AbstractServer | None = None

    async def start(self, handler: ConnectionHandler) -> None:
        """Start listening on the socket path and pass each connection to ``handler``."""
        socket_file = Path(self._socket_path)

        # Remove a leftover socket file. Unlinking directly (instead of
        # checking for existence first) avoids a check-then-act race and also
        # clears a dangling symlink that an existence check would not see.
        socket_file.unlink(missing_ok=True)

        # Make sure the parent directory exists.
        socket_file.parent.mkdir(parents=True, exist_ok=True)

        async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
            conn = UDSConnection(reader, writer)
            try:
                await handler(conn)
            finally:
                # Always release the connection, even if the handler raised.
                await conn.close()

        self._server = await asyncio.start_unix_server(_on_connect, path=self._socket_path)

        # Restrict the socket file to the current user only.
        os.chmod(self._socket_path, 0o600)

    async def stop(self) -> None:
        """Stop the server (if running) and remove the socket file."""
        if self._server:
            self._server.close()
            await self._server.wait_closed()
            self._server = None

        # The file may already be gone; a missing file is not an error.
        Path(self._socket_path).unlink(missing_ok=True)

    def get_address(self) -> str:
        """Return the socket path that clients connect to."""
        return self._socket_path
class UDSTransportClient(TransportClient):
    """Unix domain socket transport client."""

    def __init__(self, socket_path: str):
        self._socket_path = socket_path

    async def connect(self) -> Connection:
        """Connect to the socket path and wrap the streams in a ``UDSConnection``."""
        streams = await asyncio.open_unix_connection(self._socket_path)
        return UDSConnection(*streams)