diff --git a/.gitignore b/.gitignore index b4b922ed..9267bfaf 100644 --- a/.gitignore +++ b/.gitignore @@ -354,4 +354,5 @@ MaiBot.code-workspace *.lock actionlint .sisyphus/ -dist-electron/ \ No newline at end of file +dist-electron/ +packages/ \ No newline at end of file diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py new file mode 100644 index 00000000..bae8f6e4 --- /dev/null +++ b/pytests/test_plugin_runtime.py @@ -0,0 +1,427 @@ +"""插件运行时框架基础测试 + +验证协议层、传输层、RPC 通信链路的正确性。 +""" + +import asyncio +import sys +import os + +import pytest + +# 确保项目根目录在 sys.path 中 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +# SDK 包路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk")) + + +# ─── 协议层测试 ─────────────────────────────────────────── + +class TestProtocol: + """协议层测试""" + + def test_envelope_create_and_serialize(self): + """Envelope 创建与序列化""" + from src.plugin_runtime.protocol.envelope import Envelope, MessageType + + env = Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="plugin.invoke_command", + plugin_id="test_plugin", + payload={"component_name": "greet", "args": {}}, + ) + + assert env.request_id == 1 + assert env.is_request() + assert env.method == "plugin.invoke_command" + + # 测试 make_response + resp = env.make_response(payload={"success": True}) + assert resp.is_response() + assert resp.request_id == 1 + assert resp.payload["success"] is True + + def test_envelope_make_error_response(self): + """错误响应生成""" + from src.plugin_runtime.protocol.envelope import Envelope, MessageType + + env = Envelope( + request_id=42, + message_type=MessageType.REQUEST, + method="cap.request", + ) + + err_resp = env.make_error_response("E_UNAUTHORIZED", "没有权限") + assert err_resp.error is not None + assert err_resp.error["code"] == "E_UNAUTHORIZED" + assert err_resp.error["message"] == "没有权限" + + def test_msgpack_codec(self): + """MsgPack 编解码""" + from src.plugin_runtime.protocol.codec import MsgPackCodec + from src.plugin_runtime.protocol.envelope import Envelope, MessageType + + codec = MsgPackCodec() + env = Envelope( + request_id=100, + message_type=MessageType.REQUEST, + method="test.method", + payload={"key": "value", "number": 42}, + ) + + # 编码 + data = codec.encode_envelope(env) + assert isinstance(data, bytes) + + # 解码 + decoded = codec.decode_envelope(data) + assert decoded.request_id == 100 + assert decoded.method == "test.method" + assert decoded.payload["key"] == "value" + assert decoded.payload["number"] == 42 + + def test_json_codec(self): + """JSON 编解码""" + from src.plugin_runtime.protocol.codec import JsonCodec + from src.plugin_runtime.protocol.envelope import Envelope, MessageType + + codec = JsonCodec() + env = Envelope( + request_id=200, + message_type=MessageType.EVENT, + method="plugin.config_updated", + payload={"config_version": "2.0"}, + ) + + data = codec.encode_envelope(env) + assert isinstance(data, bytes) + + decoded = codec.decode_envelope(data) + assert decoded.request_id == 200 + assert decoded.is_event() + + def test_request_id_generator(self): + """请求 ID 生成器单调递增""" + from src.plugin_runtime.protocol.envelope import RequestIdGenerator + + gen = RequestIdGenerator() + ids = [gen.next() for _ in range(100)] + assert ids == list(range(1, 101)) + + def test_error_codes(self): + """错误码枚举""" + from src.plugin_runtime.protocol.errors import ErrorCode, RPCError + + err = RPCError(ErrorCode.E_TIMEOUT, "请求超时") + assert err.code == ErrorCode.E_TIMEOUT + assert "E_TIMEOUT" in str(err) + + # 序列化/反序列化 + d = err.to_dict() + err2 = RPCError.from_dict(d) + assert err2.code == ErrorCode.E_TIMEOUT + + +# ─── 传输层测试 ─────────────────────────────────────────── + +class TestTransport: + """传输层测试""" + + @pytest.mark.asyncio + async def test_uds_connection_framing(self): + """UDS 分帧协议测试""" + from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient + + server = UDSTransportServer() + received = asyncio.Event() + received_data = [] + + async def handler(conn): + data = await conn.recv_frame() + received_data.append(data) + await conn.send_frame(b"pong") + received.set() + + await server.start(handler) + address = server.get_address() + + client = UDSTransportClient(address) + conn = await client.connect() + await conn.send_frame(b"ping") + + # 等待服务端处理 + await asyncio.wait_for(received.wait(), timeout=5.0) + assert received_data[0] == b"ping" + + # 接收服务端回复 + resp = await conn.recv_frame() + assert resp == b"pong" + + await conn.close() + await server.stop() + + @pytest.mark.asyncio + async def test_tcp_connection_framing(self): + """TCP 分帧协议测试""" + from src.plugin_runtime.transport.tcp import TCPTransportServer, TCPTransportClient + + server = TCPTransportServer() + received = asyncio.Event() + received_data = [] + + async def handler(conn): + data = await conn.recv_frame() + received_data.append(data) + await conn.send_frame(b"tcp_pong") + received.set() + + await server.start(handler) + address = server.get_address() + host, port = address.split(":") + + client = TCPTransportClient(host, int(port)) + conn = await client.connect() + await conn.send_frame(b"tcp_ping") + + await asyncio.wait_for(received.wait(), timeout=5.0) + assert received_data[0] == b"tcp_ping" + + resp = await conn.recv_frame() + assert resp == b"tcp_pong" + + await conn.close() + await server.stop() + + @pytest.mark.asyncio + async def test_transport_factory(self): + """传输工厂测试""" + from src.plugin_runtime.transport.factory import create_transport_server, create_transport_client + + server = create_transport_server() + assert server is not None + + # UDS 路径 + client = create_transport_client("/tmp/test.sock") + assert client is not None + + # TCP 地址 + client = create_transport_client("127.0.0.1:9999") + assert client is not None + + +# ─── Host 层测试 ────────────────────────────────────────── + +class TestHost: + """Host 端基础设施测试""" + + def test_policy_engine(self): + """策略引擎测试""" + from src.plugin_runtime.host.policy_engine import PolicyEngine + + engine = PolicyEngine() + + # 注册插件 + token = engine.register_plugin( + plugin_id="test_plugin", + generation=1, + capabilities=["send.text", "db.query"], + limits={"qps": 10, "burst": 20}, + ) + + assert token.plugin_id == "test_plugin" + assert "send.text" in token.capabilities + + # 能力检查 + ok, _ = engine.check_capability("test_plugin", "send.text") + assert ok + + ok, reason = engine.check_capability("test_plugin", "llm.generate") + assert not ok + assert "未获授权" in reason + + # 未注册插件 + ok, reason = engine.check_capability("unknown", "send.text") + assert not ok + + def test_circuit_breaker(self): + """熔断器测试""" + from src.plugin_runtime.host.circuit_breaker import CircuitBreaker, CircuitState + + breaker = CircuitBreaker(failure_threshold=3) + + # 初始状态:关闭 + assert breaker.state == CircuitState.CLOSED + assert breaker.allow_request() + + # 连续失败 + breaker.record_failure() + breaker.record_failure() + assert breaker.allow_request() # 还没到阈值 + + breaker.record_failure() # 第3次,触发熔断 + assert breaker.state == CircuitState.OPEN + assert not breaker.allow_request() + + # 重置 + breaker.reset() + assert breaker.state == CircuitState.CLOSED + + def test_circuit_breaker_registry(self): + """熔断器注册表测试""" + from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry + + registry = CircuitBreakerRegistry(failure_threshold=2) + + b1 = registry.get("plugin_a") + b2 = registry.get("plugin_b") + assert b1 is not b2 + assert registry.get("plugin_a") is b1 # 同一个 + + +# ─── SDK 测试 ───────────────────────────────────────────── + +class TestSDK: + """SDK 框架测试""" + + def test_component_decorators(self): + """组件装饰器测试""" + from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler + from maibot_sdk.types import ActivationType, EventType + + class TestPlugin(MaiBotPlugin): + @Action("greet", activation_type=ActivationType.KEYWORD, activation_keywords=["hi"]) + async def handle_greet(self, **kwargs): + return True, "ok" + + @Command("echo", pattern=r"^/echo") + async def handle_echo(self, **kwargs): + return True, "echoed", 2 + + @Tool("search", parameters={"query": {"type": "string"}}) + async def handle_search(self, **kwargs): + return {"result": "found"} + + @EventHandler("on_start", event_type=EventType.ON_START) + async def handle_start(self, **kwargs): + return True, False, "started" + + plugin = TestPlugin() + components = plugin.get_components() + + assert len(components) == 4 + + names = {c["name"] for c in components} + assert "greet" in names + assert "echo" in names + assert "search" in names + assert "on_start" in names + + types = {c["type"] for c in components} + assert "action" in types + assert "command" in types + assert "tool" in types + assert "event_handler" in types + + def test_plugin_context_not_initialized(self): + """未初始化上下文时应报错""" + from maibot_sdk import MaiBotPlugin + + plugin = MaiBotPlugin() + with pytest.raises(RuntimeError, match="尚未初始化"): + _ = plugin.ctx + + def test_plugin_context_injection(self): + """上下文注入测试""" + from maibot_sdk import MaiBotPlugin + from maibot_sdk.context import PluginContext + + plugin = MaiBotPlugin() + ctx = PluginContext(plugin_id="test") + plugin._set_context(ctx) + + assert plugin.ctx.plugin_id == "test" + assert plugin.ctx.send is not None + assert plugin.ctx.db is not None + assert plugin.ctx.llm is not None + assert plugin.ctx.config is not None + + +# ─── 端到端集成测试 ──────────────────────────────────────── + +class TestE2E: + """端到端集成测试(Host + Runner 通信)""" + + @pytest.mark.asyncio + async def test_handshake(self): + """Host-Runner 握手流程测试""" + from src.plugin_runtime.protocol.codec import create_codec + from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType + from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient + + import secrets + import tempfile + import os + + socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock") + session_token = secrets.token_hex(16) + codec = create_codec() + handshake_done = asyncio.Event() + server_result = {} + + async def server_handler(conn): + # 接收握手 + data = await conn.recv_frame() + env = codec.decode_envelope(data) + assert env.method == "runner.hello" + + hello = HelloPayload.model_validate(env.payload) + assert hello.session_token == session_token + + # 发送响应 + resp_payload = HelloResponsePayload( + accepted=True, + host_version="1.0", + assigned_generation=1, + ) + resp = env.make_response(payload=resp_payload.model_dump()) + await conn.send_frame(codec.encode_envelope(resp)) + + server_result["runner_id"] = hello.runner_id + handshake_done.set() + + # 保持连接一会儿 + await asyncio.sleep(1.0) + + server = UDSTransportServer(socket_path=socket_path) + await server.start(server_handler) + + # 客户端握手 + client = UDSTransportClient(socket_path) + conn = await client.connect() + + hello = HelloPayload( + runner_id="test-runner", + sdk_version="1.0.0", + session_token=session_token, + ) + env = Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="runner.hello", + payload=hello.model_dump(), + ) + await conn.send_frame(codec.encode_envelope(env)) + + resp_data = await conn.recv_frame() + resp = codec.decode_envelope(resp_data) + resp_payload = HelloResponsePayload.model_validate(resp.payload) + + assert resp_payload.accepted + assert resp_payload.assigned_generation == 1 + + await asyncio.wait_for(handshake_done.wait(), timeout=5.0) + assert server_result["runner_id"] == "test-runner" + + await conn.close() + await server.stop() diff --git a/src/plugin_runtime/__init__.py b/src/plugin_runtime/__init__.py new file mode 100644 index 00000000..59925016 --- /dev/null +++ b/src/plugin_runtime/__init__.py @@ -0,0 +1,2 @@ +# MaiBot Plugin Runtime - 插件隔离运行时基础设施 +# 本模块实现 Host-Runner 进程分离架构,提供 IPC 通信、策略引擎与生命周期管理 diff --git a/src/plugin_runtime/host/__init__.py b/src/plugin_runtime/host/__init__.py new file mode 100644 index 00000000..8b983d9d --- /dev/null +++ b/src/plugin_runtime/host/__init__.py @@ -0,0 +1 @@ +# Host 端 - Supervisor、RPC Server、策略引擎、路由 diff --git a/src/plugin_runtime/host/capability_service.py b/src/plugin_runtime/host/capability_service.py new file mode 100644 index 00000000..7f36f5f6 --- /dev/null +++ b/src/plugin_runtime/host/capability_service.py @@ -0,0 +1,108 @@ +"""能力服务层 + +Host 端实现的能力服务,处理来自插件的 cap.* 请求。 +每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。 +""" + +from typing import Any, Callable, Awaitable + +import logging + +from src.plugin_runtime.protocol.envelope import ( + CapabilityRequestPayload, + CapabilityResponsePayload, + Envelope, +) +from src.plugin_runtime.protocol.errors import ErrorCode, RPCError +from src.plugin_runtime.host.policy_engine import PolicyEngine + +logger = logging.getLogger("plugin_runtime.host.capability_service") + +# 能力实现函数类型: (plugin_id, capability, args) -> result +CapabilityImpl = Callable[[str, str, dict[str, Any]], Awaitable[Any]] + + +class CapabilityService: + """能力服务 + + 负责: + 1. 注册能力实现 + 2. 接收插件的能力调用请求 + 3. 通过策略引擎校验权限和限流 + 4. 执行实际操作并返回结果 + """ + + def __init__(self, policy_engine: PolicyEngine): + self._policy = policy_engine + # capability_name -> implementation + self._implementations: dict[str, CapabilityImpl] = {} + + def register_capability(self, name: str, impl: CapabilityImpl) -> None: + """注册一个能力实现 + + Args: + name: 能力名称,如 "send.text", "db.query", "llm.generate" + impl: 实现函数 + """ + self._implementations[name] = impl + logger.debug(f"注册能力实现: {name}") + + async def handle_capability_request(self, envelope: Envelope) -> Envelope: + """处理能力调用请求(作为 RPC Server 的 method handler) + + 从 envelope 中提取 capability 名称和参数, + 校验权限后调用对应实现。 + """ + plugin_id = envelope.plugin_id + + try: + req = CapabilityRequestPayload.model_validate(envelope.payload) + except Exception as e: + return envelope.make_error_response( + ErrorCode.E_BAD_PAYLOAD.value, + f"能力调用 payload 格式错误: {e}", + ) + + capability = req.capability + + # 1. 权限校验 + allowed, reason = self._policy.check_capability(plugin_id, capability) + if not allowed: + return envelope.make_error_response( + ErrorCode.E_CAPABILITY_DENIED.value, + reason, + ) + + # 2. 限流校验 + allowed, reason = self._policy.check_rate_limit(plugin_id) + if not allowed: + return envelope.make_error_response( + ErrorCode.E_BACKPRESSURE.value, + reason, + ) + + # 3. 查找实现 + impl = self._implementations.get(capability) + if impl is None: + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"未注册的能力: {capability}", + ) + + # 4. 执行 + try: + result = await impl(plugin_id, capability, req.args) + resp_payload = CapabilityResponsePayload(success=True, result=result) + return envelope.make_response(payload=resp_payload.model_dump()) + except RPCError as e: + return envelope.make_error_response(e.code.value, e.message, e.details) + except Exception as e: + logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True) + return envelope.make_error_response( + ErrorCode.E_CAPABILITY_FAILED.value, + str(e), + ) + + def list_capabilities(self) -> list[str]: + """列出所有已注册的能力""" + return list(self._implementations.keys()) diff --git a/src/plugin_runtime/host/circuit_breaker.py b/src/plugin_runtime/host/circuit_breaker.py new file mode 100644 index 00000000..f598029d --- /dev/null +++ b/src/plugin_runtime/host/circuit_breaker.py @@ -0,0 +1,105 @@ +"""熔断器 + +为每个插件提供熔断保护,连续失败超过阈值后临时禁用。 +支持指数退避恢复。 +""" + +from enum import Enum + +import time + + +class CircuitState(str, Enum): + CLOSED = "closed" # 正常工作 + OPEN = "open" # 熔断(拒绝所有调用) + HALF_OPEN = "half_open" # 探测恢复 + + +class CircuitBreaker: + """单个插件的熔断器""" + + def __init__( + self, + failure_threshold: int = 5, + recovery_timeout_sec: float = 30.0, + max_recovery_timeout_sec: float = 300.0, + ): + self.failure_threshold = failure_threshold + self.base_recovery_timeout = recovery_timeout_sec + self.max_recovery_timeout = max_recovery_timeout_sec + + self._state = CircuitState.CLOSED + self._failure_count = 0 + self._last_failure_time = 0.0 + self._consecutive_opens = 0 # 用于指数退避 + + @property + def state(self) -> CircuitState: + if self._state == CircuitState.OPEN: + # 检查是否可以进入半开状态 + elapsed = time.monotonic() - self._last_failure_time + recovery_timeout = min( + self.base_recovery_timeout * (2 ** self._consecutive_opens), + self.max_recovery_timeout, + ) + if elapsed >= recovery_timeout: + self._state = CircuitState.HALF_OPEN + return self._state + + def allow_request(self) -> bool: + """是否允许通过请求""" + state = self.state + if state == CircuitState.CLOSED: + return True + if state == CircuitState.HALF_OPEN: + return True # 允许一次试探 + return False # OPEN 状态拒绝 + + def record_success(self) -> None: + """记录一次成功调用""" + if self._state == CircuitState.HALF_OPEN: + # 半开状态成功 -> 关闭熔断 + self._state = CircuitState.CLOSED + self._failure_count = 0 + self._consecutive_opens = 0 + elif self._state == CircuitState.CLOSED: + self._failure_count = 0 + + def record_failure(self) -> None: + """记录一次失败调用""" + self._failure_count += 1 + self._last_failure_time = time.monotonic() + + if self._state == CircuitState.HALF_OPEN: + # 半开状态失败 -> 重新开启熔断 + self._state = CircuitState.OPEN + self._consecutive_opens += 1 + elif self._failure_count >= self.failure_threshold: + self._state = CircuitState.OPEN + self._consecutive_opens += 1 + + def reset(self) -> None: + """重置熔断器""" + self._state = CircuitState.CLOSED + self._failure_count = 0 + self._consecutive_opens = 0 + + +class CircuitBreakerRegistry: + """熔断器注册表,为每个插件维护独立的熔断器""" + + def __init__(self, **default_kwargs): + self._breakers: dict[str, CircuitBreaker] = {} + self._default_kwargs = default_kwargs + + def get(self, plugin_id: str) -> CircuitBreaker: + if plugin_id not in self._breakers: + self._breakers[plugin_id] = CircuitBreaker(**self._default_kwargs) + return self._breakers[plugin_id] + + def remove(self, plugin_id: str) -> None: + self._breakers.pop(plugin_id, None) + + def reset_all(self) -> None: + for breaker in self._breakers.values(): + breaker.reset() diff --git a/src/plugin_runtime/host/policy_engine.py b/src/plugin_runtime/host/policy_engine.py new file mode 100644 index 00000000..7c889c2f --- /dev/null +++ b/src/plugin_runtime/host/policy_engine.py @@ -0,0 +1,125 @@ +"""策略引擎 + +负责能力授权校验、限流、配额管理。 +每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。 +""" + +from dataclasses import dataclass, field + +import time + + +@dataclass +class CapabilityToken: + """能力令牌 + + 描述某个插件在当前会话中被授予的能力和资源限制。 + """ + plugin_id: str + generation: int + capabilities: set[str] = field(default_factory=set) + qps_limit: int = 20 + burst_limit: int = 50 + daily_token_limit: int = 200000 + max_payload_kb: int = 256 + + # 运行时统计 + _call_count: int = field(default=0, init=False, repr=False) + _window_start: float = field(default_factory=time.monotonic, init=False, repr=False) + _window_calls: int = field(default=0, init=False, repr=False) + + +class PolicyEngine: + """策略引擎 + + 管理所有插件的能力令牌,提供授权校验与限流决策。 + """ + + def __init__(self): + # plugin_id -> CapabilityToken + self._tokens: dict[str, CapabilityToken] = {} + + def register_plugin( + self, + plugin_id: str, + generation: int, + capabilities: list[str], + limits: dict | None = None, + ) -> CapabilityToken: + """为插件签发能力令牌""" + limits = limits or {} + token = CapabilityToken( + plugin_id=plugin_id, + generation=generation, + capabilities=set(capabilities), + qps_limit=limits.get("qps", 20), + burst_limit=limits.get("burst", 50), + daily_token_limit=limits.get("daily_tokens", 200000), + max_payload_kb=limits.get("max_payload_kb", 256), + ) + self._tokens[plugin_id] = token + return token + + def revoke_plugin(self, plugin_id: str) -> None: + """撤销插件的能力令牌""" + self._tokens.pop(plugin_id, None) + + def check_capability(self, plugin_id: str, capability: str) -> tuple[bool, str]: + """检查插件是否有权调用某项能力 + + Returns: + (allowed, reason) + """ + token = self._tokens.get(plugin_id) + if token is None: + return False, f"插件 {plugin_id} 未注册能力令牌" + + if capability not in token.capabilities: + return False, f"插件 {plugin_id} 未获授权能力: {capability}" + + return True, "" + + def check_rate_limit(self, plugin_id: str) -> tuple[bool, str]: + """检查插件是否超过调用频率限制(滑动窗口) + + Returns: + (allowed, reason) + """ + token = self._tokens.get(plugin_id) + if token is None: + return False, f"插件 {plugin_id} 未注册" + + now = time.monotonic() + elapsed = now - token._window_start + + # 每秒重置窗口 + if elapsed >= 1.0: + token._window_start = now + token._window_calls = 0 + + token._window_calls += 1 + + if token._window_calls > token.burst_limit: + return False, f"插件 {plugin_id} 超过突发限制 ({token.burst_limit}/s)" + + return True, "" + + def check_payload_size(self, plugin_id: str, payload_size_bytes: int) -> tuple[bool, str]: + """检查 payload 大小是否在限制内""" + token = self._tokens.get(plugin_id) + if token is None: + return False, f"插件 {plugin_id} 未注册" + + max_bytes = token.max_payload_kb * 1024 + if payload_size_bytes > max_bytes: + return False, f"payload 大小 {payload_size_bytes} 超过限制 {max_bytes}" + + return True, "" + + def get_token(self, plugin_id: str) -> CapabilityToken | None: + """获取插件的能力令牌""" + return self._tokens.get(plugin_id) + + def list_plugins(self) -> list[str]: + """列出所有已注册的插件""" + return list(self._tokens.keys()) diff --git a/src/plugin_runtime/host/rpc_server.py b/src/plugin_runtime/host/rpc_server.py new file mode 100644 index 00000000..40575742 --- /dev/null +++ b/src/plugin_runtime/host/rpc_server.py @@ -0,0 +1,357 @@ +"""Host 端 RPC Server + +负责: +1. 监听 Runner 连接 +2. 处理握手(runner.hello) +3. 分发调用请求给 Runner / 处理 Runner 的能力调用 +4. 请求-响应关联与超时管理 +""" + +from typing import Any, Callable, Awaitable + +import asyncio +import logging +import secrets + +from src.plugin_runtime.protocol.codec import Codec, create_codec +from src.plugin_runtime.protocol.envelope import ( + PROTOCOL_VERSION, + MIN_SDK_VERSION, + MAX_SDK_VERSION, + Envelope, + HelloPayload, + HelloResponsePayload, + MessageType, + RequestIdGenerator, +) +from src.plugin_runtime.protocol.errors import ErrorCode, RPCError +from src.plugin_runtime.transport.base import Connection, TransportServer + +logger = logging.getLogger("plugin_runtime.host.rpc_server") + +# RPC 方法处理器类型 +MethodHandler = Callable[[Envelope], Awaitable[Envelope]] + + +class RPCServer: + """Host 端 RPC 服务器 + + 管理与 Runner 的 IPC 连接,处理双向 RPC 调用。 + """ + + def __init__( + self, + transport: TransportServer, + session_token: str | None = None, + codec: Codec | None = None, + send_queue_size: int = 128, + ): + self._transport = transport + self._session_token = session_token or secrets.token_hex(32) + self._codec = codec or create_codec() + self._send_queue_size = send_queue_size + + self._id_gen = RequestIdGenerator() + self._connection: Connection | None = None # 当前活跃的 Runner 连接 + self._runner_id: str | None = None + self._runner_generation: int = 0 + + # 方法处理器注册表 + self._method_handlers: dict[str, MethodHandler] = {} + + # 等待响应的 pending 请求: request_id -> Future + self._pending_requests: dict[int, asyncio.Future] = {} + + # 发送队列(背压控制) + self._send_queue: asyncio.Queue | None = None + + # 运行状态 + self._running = False + self._tasks: list[asyncio.Task] = [] + + @property + def session_token(self) -> str: + return self._session_token + + @property + def is_connected(self) -> bool: + return self._connection is not None and not self._connection.is_closed + + def register_method(self, method: str, handler: MethodHandler) -> None: + """注册 RPC 方法处理器""" + self._method_handlers[method] = handler + + async def start(self) -> None: + """启动 RPC 服务器""" + self._running = True + self._send_queue = asyncio.Queue(maxsize=self._send_queue_size) + await self._transport.start(self._handle_connection) + logger.info(f"RPC Server 已启动,监听地址: {self._transport.get_address()}") + + async def stop(self) -> None: + """停止 RPC 服务器""" + self._running = False + + # 取消所有 pending 请求 + for req_id, future in self._pending_requests.items(): + if not future.done(): + future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭")) + self._pending_requests.clear() + + # 取消后台任务 + for task in self._tasks: + task.cancel() + self._tasks.clear() + + # 关闭连接 + if self._connection: + await self._connection.close() + self._connection = None + + await self._transport.stop() + logger.info("RPC Server 已停止") + + async def send_request( + self, + method: str, + plugin_id: str = "", + payload: dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Envelope: + """向 Runner 发送 RPC 请求并等待响应 + + Args: + method: RPC 方法名 + plugin_id: 目标插件 ID + payload: 请求数据 + timeout_ms: 超时时间(ms) + + Returns: + 响应 Envelope + + Raises: + RPCError: 调用失败 + """ + if not self.is_connected: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接") + + request_id = self._id_gen.next() + envelope = Envelope( + request_id=request_id, + message_type=MessageType.REQUEST, + method=method, + plugin_id=plugin_id, + generation=self._runner_generation, + timeout_ms=timeout_ms, + payload=payload or {}, + ) + + # 背压检查 + if self._send_queue and self._send_queue.full(): + raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") + + # 注册 pending future + loop = asyncio.get_event_loop() + future: asyncio.Future[Envelope] = loop.create_future() + self._pending_requests[request_id] = future + + try: + # 发送请求 + data = self._codec.encode_envelope(envelope) + await self._connection.send_frame(data) + + # 等待响应 + timeout_sec = timeout_ms / 1000.0 + response = await asyncio.wait_for(future, timeout=timeout_sec) + return response + except asyncio.TimeoutError: + self._pending_requests.pop(request_id, None) + raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") + except Exception as e: + self._pending_requests.pop(request_id, None) + if isinstance(e, RPCError): + raise + raise RPCError(ErrorCode.E_UNKNOWN, str(e)) + + async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None: + """向 Runner 发送单向事件(不等待响应)""" + if not self.is_connected: + return + + request_id = self._id_gen.next() + envelope = Envelope( + request_id=request_id, + message_type=MessageType.EVENT, + method=method, + plugin_id=plugin_id, + generation=self._runner_generation, + payload=payload or {}, + ) + data = self._codec.encode_envelope(envelope) + await self._connection.send_frame(data) + + # ─── 内部方法 ────────────────────────────────────────────── + + async def _handle_connection(self, conn: Connection) -> None: + """处理新的 Runner 连接""" + logger.info("收到 Runner 连接") + + # 第一条消息必须是 runner.hello 握手 + try: + handshake_ok = await self._handle_handshake(conn) + if not handshake_ok: + await conn.close() + return + except Exception as e: + logger.error(f"握手失败: {e}") + await conn.close() + return + + # 握手成功,保存连接 + self._connection = conn + logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}") + + # 启动消息接收循环 + try: + await self._recv_loop(conn) + except Exception as e: + logger.error(f"连接异常断开: {e}") + finally: + self._connection = None + self._runner_id = None + + async def _handle_handshake(self, conn: Connection) -> bool: + """处理 runner.hello 握手""" + # 接收握手请求 + data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0) + envelope = self._codec.decode_envelope(data) + + if envelope.method != "runner.hello": + logger.error(f"期望 runner.hello,收到 {envelope.method}") + error_resp = envelope.make_error_response( + ErrorCode.E_PROTOCOL_MISMATCH.value, + "首条消息必须为 runner.hello", + ) + await conn.send_frame(self._codec.encode_envelope(error_resp)) + return False + + # 解析握手 payload + hello = HelloPayload.model_validate(envelope.payload) + + # 校验会话令牌 + if hello.session_token != self._session_token: + logger.error("会话令牌不匹配") + resp_payload = HelloResponsePayload( + accepted=False, + reason="会话令牌无效", + ) + resp = envelope.make_response(payload=resp_payload.model_dump()) + await conn.send_frame(self._codec.encode_envelope(resp)) + return False + + # 校验 SDK 版本 + if not self._check_sdk_version(hello.sdk_version): + logger.error(f"SDK 版本不兼容: {hello.sdk_version}") + resp_payload = HelloResponsePayload( + accepted=False, + reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]", + ) + resp = envelope.make_response(payload=resp_payload.model_dump()) + await conn.send_frame(self._codec.encode_envelope(resp)) + return False + + # 握手成功 + self._runner_id = hello.runner_id + self._runner_generation += 1 + + resp_payload = HelloResponsePayload( + accepted=True, + host_version=PROTOCOL_VERSION, + assigned_generation=self._runner_generation, + ) + resp = envelope.make_response(payload=resp_payload.model_dump()) + await conn.send_frame(self._codec.encode_envelope(resp)) + + return True + + async def _recv_loop(self, conn: Connection) -> None: + """消息接收主循环""" + while self._running and not conn.is_closed: + try: + data = await conn.recv_frame() + except (asyncio.IncompleteReadError, ConnectionError): + logger.info("Runner 连接已断开") + break + except Exception as e: + logger.error(f"接收帧失败: {e}") + break + + try: + envelope = self._codec.decode_envelope(data) + except Exception as e: + logger.error(f"解码消息失败: {e}") + continue + + # 分发消息 + if envelope.is_response(): + self._handle_response(envelope) + elif envelope.is_request(): + # 异步处理请求(Runner 发来的能力调用) + task = asyncio.create_task(self._handle_request(envelope, conn)) + self._tasks.append(task) + task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) + elif envelope.is_event(): + task = asyncio.create_task(self._handle_event(envelope)) + self._tasks.append(task) + task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None) + + def _handle_response(self, envelope: Envelope) -> None: + """处理来自 Runner 的响应""" + future = self._pending_requests.pop(envelope.request_id, None) + if future and not future.done(): + if envelope.error: + future.set_exception(RPCError.from_dict(envelope.error)) + else: + future.set_result(envelope) + + async def _handle_request(self, envelope: Envelope, conn: Connection) -> None: + """处理来自 Runner 的请求(通常是能力调用 cap.*)""" + handler = self._method_handlers.get(envelope.method) + if handler is None: + error_resp = envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"未注册的方法: {envelope.method}", + ) + await conn.send_frame(self._codec.encode_envelope(error_resp)) + return + + try: + response = await handler(envelope) + await conn.send_frame(self._codec.encode_envelope(response)) + except RPCError as e: + error_resp = envelope.make_error_response(e.code.value, e.message, e.details) + await conn.send_frame(self._codec.encode_envelope(error_resp)) + except Exception as e: + logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True) + error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) + await conn.send_frame(self._codec.encode_envelope(error_resp)) + + async def _handle_event(self, envelope: Envelope) -> None: + """处理来自 Runner 的事件""" + handler = self._method_handlers.get(envelope.method) + if handler: + try: + await handler(envelope) + except Exception as e: + logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) + + @staticmethod + def _check_sdk_version(sdk_version: str) -> bool: + """检查 SDK 版本是否在支持范围内""" + try: + sdk_parts = [int(x) for x in sdk_version.split(".")] + min_parts = [int(x) for x in MIN_SDK_VERSION.split(".")] + max_parts = [int(x) for x in MAX_SDK_VERSION.split(".")] + return min_parts <= sdk_parts <= max_parts + except (ValueError, AttributeError): + return False diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py new file mode 100644 index 00000000..9de664a7 --- /dev/null +++ b/src/plugin_runtime/host/supervisor.py @@ -0,0 +1,315 @@ +"""Supervisor - 插件生命周期管理 + +负责: +1. 拉起 Runner 子进程 +2. 健康检查 +3. 熔断与恢复 +4. 代码热重载(generation 切换) +5. 优雅关停 +""" + +from typing import Any + +import asyncio +import logging +import os +import sys + +from src.plugin_runtime.host.capability_service import CapabilityService +from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry +from src.plugin_runtime.host.policy_engine import PolicyEngine +from src.plugin_runtime.host.rpc_server import RPCServer +from src.plugin_runtime.protocol.envelope import ( + Envelope, + HealthPayload, + RegisterComponentsPayload, + ShutdownPayload, +) +from src.plugin_runtime.protocol.errors import ErrorCode, RPCError +from src.plugin_runtime.transport.factory import create_transport_server + +logger = logging.getLogger("plugin_runtime.host.supervisor") + + +class PluginSupervisor: + """插件 Supervisor + + Host 端的核心管理器,负责整个插件 Runner 进程的生命周期。 + """ + + def __init__( + self, + plugin_dirs: list[str] | None = None, + socket_path: str | None = None, + health_check_interval_sec: float = 30.0, + use_json_codec: bool = False, + ): + self._plugin_dirs = plugin_dirs or [] + self._health_interval = health_check_interval_sec + + # 基础设施 + self._transport = create_transport_server(socket_path=socket_path) + self._policy = PolicyEngine() + self._breakers = CircuitBreakerRegistry() + self._capability_service = CapabilityService(self._policy) + + # 编解码 + from src.plugin_runtime.protocol.codec import create_codec + codec = create_codec(use_json=use_json_codec) + + self._rpc_server = RPCServer( + transport=self._transport, + codec=codec, + ) + + # Runner 子进程 + self._runner_process: asyncio.subprocess.Process | None = None + self._runner_generation: int = 0 + + # 已注册的插件组件信息 + self._registered_plugins: dict[str, RegisterComponentsPayload] = {} + + # 后台任务 + self._health_task: asyncio.Task | None = None + self._running = False + + # 注册内部 RPC 方法 + self._register_internal_methods() + + @property + def policy_engine(self) -> PolicyEngine: + return self._policy + + @property + def capability_service(self) -> CapabilityService: + return self._capability_service + + @property + def rpc_server(self) -> RPCServer: + return self._rpc_server + + async def start(self) -> None: + """启动 Supervisor + + 1. 启动 RPC Server + 2. 拉起 Runner 子进程 + 3. 启动健康检查 + """ + self._running = True + + # 启动 RPC Server + await self._rpc_server.start() + + # 拉起 Runner 进程 + await self._spawn_runner() + + # 启动健康检查 + self._health_task = asyncio.create_task(self._health_check_loop()) + + logger.info("PluginSupervisor 已启动") + + async def stop(self) -> None: + """停止 Supervisor""" + self._running = False + + # 停止健康检查 + if self._health_task: + self._health_task.cancel() + self._health_task = None + + # 优雅关停 Runner + await self._shutdown_runner() + + # 停止 RPC Server + await self._rpc_server.stop() + + logger.info("PluginSupervisor 已停止") + + async def invoke_plugin( + self, + method: str, + plugin_id: str, + component_name: str, + args: dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Envelope: + """调用插件组件 + + 由主进程业务逻辑调用,通过 RPC 转发给 Runner。 + """ + # 熔断检查 + breaker = self._breakers.get(plugin_id) + if not breaker.allow_request(): + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, f"插件 {plugin_id} 已被熔断") + + try: + response = await self._rpc_server.send_request( + method=method, + plugin_id=plugin_id, + payload={ + "component_name": component_name, + "args": args or {}, + }, + timeout_ms=timeout_ms, + ) + breaker.record_success() + return response + except RPCError: + breaker.record_failure() + raise + + async def reload_plugins(self, reason: str = "manual") -> None: + """热重载所有插件(进程级 generation 切换) + + 1. 拉起新 Runner + 2. 等待新 Runner 完成注册和健康检查 + 3. 关停旧 Runner + """ + logger.info(f"开始热重载插件,原因: {reason}") + + # 保存旧进程引用 + old_process = self._runner_process + + # 拉起新 Runner + await self._spawn_runner() + + # 等待新 Runner 连接并完成握手 + for _ in range(30): # 最多等待 30 秒 + if self._rpc_server.is_connected: + break + await asyncio.sleep(1.0) + else: + logger.error("新 Runner 连接超时,回滚") + # 回滚:终止新进程 + if self._runner_process and self._runner_process != old_process: + self._runner_process.terminate() + self._runner_process = old_process + return + + # 健康检查 + try: + resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000) + health = HealthPayload.model_validate(resp.payload) + if not health.healthy: + raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败") + except Exception as e: + logger.error(f"新 Runner 健康检查失败: {e},回滚") + if self._runner_process and self._runner_process != old_process: + self._runner_process.terminate() + self._runner_process = old_process + return + + # 关停旧 Runner + if old_process and old_process.returncode is None: + try: + old_process.terminate() + await asyncio.wait_for(old_process.wait(), timeout=10.0) + except asyncio.TimeoutError: + old_process.kill() + + logger.info("热重载完成") + + # ─── 内部方法 ────────────────────────────────────────────── + + def _register_internal_methods(self) -> None: + """注册 Host 端的 RPC 方法处理器""" + # Runner -> Host 的能力调用统一走 capability_service + self._rpc_server.register_method("cap.request", self._capability_service.handle_capability_request) + # 插件注册 + self._rpc_server.register_method("plugin.register_components", self._handle_register_components) + + async def _handle_register_components(self, envelope: Envelope) -> Envelope: + """处理插件组件注册请求""" + try: + reg = RegisterComponentsPayload.model_validate(envelope.payload) + except Exception as e: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) + + # 记录注册信息 + self._registered_plugins[reg.plugin_id] = reg + + # 在策略引擎中注册插件 + self._policy.register_plugin( + plugin_id=reg.plugin_id, + generation=envelope.generation, + capabilities=reg.capabilities_required, + ) + + logger.info( + f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功," + f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}" + ) + + return envelope.make_response(payload={"accepted": True}) + + async def _spawn_runner(self) -> None: + """拉起 Runner 子进程""" + runner_module = "src.plugin_runtime.runner.runner_main" + address = self._transport.get_address() + token = self._rpc_server.session_token + + env = os.environ.copy() + env["MAIBOT_IPC_ADDRESS"] = address + env["MAIBOT_SESSION_TOKEN"] = token + env["MAIBOT_PLUGIN_DIRS"] = os.pathsep.join(self._plugin_dirs) + + self._runner_process = await asyncio.create_subprocess_exec( + sys.executable, "-m", runner_module, + env=env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + self._runner_generation += 1 + logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}") + + async def _shutdown_runner(self) -> None: + """优雅关停 Runner""" + if not self._runner_process or self._runner_process.returncode is not None: + return + + # 发送 prepare_shutdown + try: + if self._rpc_server.is_connected: + shutdown_payload = ShutdownPayload(reason="host_shutdown", drain_timeout_ms=5000) + await self._rpc_server.send_request( + "plugin.prepare_shutdown", + payload=shutdown_payload.model_dump(), + timeout_ms=5000, + ) + await self._rpc_server.send_request( + "plugin.shutdown", + payload=shutdown_payload.model_dump(), + timeout_ms=5000, + ) + except Exception as e: + logger.warning(f"发送关停命令失败: {e}") + + # 等待进程退出 + try: + await asyncio.wait_for(self._runner_process.wait(), timeout=10.0) + except asyncio.TimeoutError: + logger.warning("Runner 未在超时内退出,强制终止") + self._runner_process.kill() + await self._runner_process.wait() + + async def _health_check_loop(self) -> None: + """周期性健康检查""" + while self._running: + await asyncio.sleep(self._health_interval) + + if not self._rpc_server.is_connected: + logger.warning("Runner 未连接,跳过健康检查") + continue + + try: + resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000) + health = HealthPayload.model_validate(resp.payload) + if not health.healthy: + logger.warning(f"Runner 健康检查异常: {health}") + except RPCError as e: + logger.error(f"健康检查失败: {e}") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"健康检查异常: {e}") diff --git a/src/plugin_runtime/protocol/__init__.py b/src/plugin_runtime/protocol/__init__.py new file mode 100644 index 00000000..c39bce06 --- /dev/null +++ b/src/plugin_runtime/protocol/__init__.py @@ -0,0 +1 @@ +# Protocol 层 - RPC 消息模型、编解码、错误码 diff --git a/src/plugin_runtime/protocol/codec.py b/src/plugin_runtime/protocol/codec.py new file mode 100644 index 00000000..8e388511 --- /dev/null +++ b/src/plugin_runtime/protocol/codec.py @@ -0,0 +1,80 @@ +"""MsgPack / JSON 编解码器 + +提供统一的消息编解码接口,生产环境默认使用 MsgPack, +开发调试模式可切换为 JSON(仅编解码切换,传输层不变)。 +""" + +from typing import Any + +import json + +import msgpack + +from .envelope import Envelope + + +class Codec: + """消息编解码器基类""" + + def encode_envelope(self, envelope: Envelope) -> bytes: + raise NotImplementedError + + def decode_envelope(self, data: bytes) -> Envelope: + raise NotImplementedError + + def encode(self, obj: dict[str, Any]) -> bytes: + raise NotImplementedError + + def decode(self, data: bytes) -> dict[str, Any]: + raise NotImplementedError + + +class MsgPackCodec(Codec): + """MsgPack 编解码器(生产默认)""" + + def encode(self, obj: dict[str, Any]) -> bytes: + return msgpack.packb(obj, use_bin_type=True) + + def decode(self, data: bytes) -> dict[str, Any]: + result = msgpack.unpackb(data, raw=False) + if not isinstance(result, dict): + raise ValueError(f"期望解码为 dict,实际为 {type(result)}") + return result + + def encode_envelope(self, envelope: Envelope) -> bytes: + return self.encode(envelope.model_dump()) + + def decode_envelope(self, data: bytes) -> Envelope: + raw = self.decode(data) + return Envelope.model_validate(raw) + + +class JsonCodec(Codec): + """JSON 编解码器(开发调试用)""" + + def encode(self, obj: dict[str, Any]) -> bytes: + return json.dumps(obj, ensure_ascii=False).encode("utf-8") + + def decode(self, data: bytes) -> dict[str, Any]: + result = json.loads(data.decode("utf-8")) + if not isinstance(result, dict): + raise ValueError(f"期望解码为 dict,实际为 {type(result)}") + return result + + def encode_envelope(self, envelope: Envelope) -> bytes: + return self.encode(envelope.model_dump()) + + def decode_envelope(self, data: bytes) -> Envelope: + raw = self.decode(data) + return Envelope.model_validate(raw) + + +def create_codec(use_json: bool = False) -> Codec: + """创建编解码器实例 + + Args: + use_json: 是否使用 JSON(开发模式)。默认使用 MsgPack。 + """ + if use_json: + return JsonCodec() + return MsgPackCodec() diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py new file mode 100644 index 00000000..062a30ee --- /dev/null +++ b/src/plugin_runtime/protocol/envelope.py @@ -0,0 +1,187 @@ +"""RPC Envelope 消息模型 + +定义 Host 与 Runner 之间所有 RPC 消息的统一信封格式。 +使用 Pydantic 进行 schema 定义与校验。 +""" + +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + +import time + + +# ─── 协议常量 ────────────────────────────────────────────────────── + +PROTOCOL_VERSION = "1.0" + +# 支持的 SDK 版本范围(Host 在握手时校验) +MIN_SDK_VERSION = "1.0.0" +MAX_SDK_VERSION = "1.99.99" + + +# ─── 消息类型 ────────────────────────────────────────────────────── + +class MessageType(str, Enum): + """RPC 消息类型""" + REQUEST = "request" + RESPONSE = "response" + EVENT = "event" + + +# ─── 请求 ID 生成器 ─────────────────────────────────────────────── + +class RequestIdGenerator: + """单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)""" + + def __init__(self, start: int = 1): + self._counter = start + + def next(self) -> int: + current = self._counter + self._counter += 1 + return current + + +# ─── Envelope 模型 ───────────────────────────────────────────────── + +class Envelope(BaseModel): + """RPC 统一信封 + + 所有 Host <-> Runner 消息均封装为此格式。 + 序列化流程:Envelope -> .model_dump() -> MsgPack encode + 反序列化流程:MsgPack decode -> Envelope.model_validate(data) + """ + + protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本") + request_id: int = Field(description="单调递增请求 ID") + message_type: MessageType = Field(description="消息类型") + method: str = Field(default="", description="RPC 方法名") + plugin_id: str = Field(default="", description="目标插件 ID") + timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)") + timeout_ms: int = Field(default=30000, description="相对超时(ms)") + generation: int = Field(default=0, description="Runner generation 编号") + payload: dict[str, Any] = Field(default_factory=dict, description="业务数据") + error: dict[str, Any] | None = Field(default=None, description="错误信息(仅 response)") + + def is_request(self) -> bool: + return self.message_type == MessageType.REQUEST + + def is_response(self) -> bool: + return self.message_type == MessageType.RESPONSE + + def is_event(self) -> bool: + return self.message_type == MessageType.EVENT + + def make_response(self, payload: dict[str, Any] | None = None, error: dict[str, Any] | None = None) -> "Envelope": + """基于当前请求创建对应的响应信封""" + return Envelope( + protocol_version=self.protocol_version, + request_id=self.request_id, + message_type=MessageType.RESPONSE, + method=self.method, + plugin_id=self.plugin_id, + generation=self.generation, + payload=payload or {}, + error=error, + ) + + def make_error_response(self, code: str, message: str = "", details: dict | None = None) -> "Envelope": + """基于当前请求创建错误响应""" + return self.make_response( + error={ + "code": code, + "message": message, + "details": details or {}, + } + ) + + +# ─── 握手消息 ────────────────────────────────────────────────────── + +class HelloPayload(BaseModel): + """runner.hello 握手请求 payload""" + runner_id: str = Field(description="Runner 进程唯一标识") + sdk_version: str = Field(description="SDK 版本号") + session_token: str = Field(description="一次性会话令牌") + + +class HelloResponsePayload(BaseModel): + """runner.hello 握手响应 payload""" + accepted: bool = Field(description="是否接受连接") + host_version: str = Field(default="", description="Host 版本号") + assigned_generation: int = Field(default=0, description="分配的 generation 编号") + reason: str = Field(default="", description="拒绝原因(若 accepted=False)") + + +# ─── 组件注册消息 ────────────────────────────────────────────────── + +class ComponentDeclaration(BaseModel): + """单个组件声明""" + name: str = Field(description="组件名称") + component_type: str = Field(description="组件类型: action/command/tool/event_handler") + plugin_id: str = Field(description="所属插件 ID") + metadata: dict[str, Any] = Field(default_factory=dict, description="组件元数据") + + +class RegisterComponentsPayload(BaseModel): + """plugin.register_components 请求 payload""" + plugin_id: str = Field(description="插件 ID") + plugin_version: str = Field(default="1.0.0", description="插件版本") + components: list[ComponentDeclaration] = Field(default_factory=list, description="组件列表") + capabilities_required: list[str] = Field(default_factory=list, description="所需能力列表") + + +# ─── 调用消息 ────────────────────────────────────────────────────── + +class InvokePayload(BaseModel): + """plugin.invoke_* 请求 payload""" + component_name: str = Field(description="要调用的组件名称") + args: dict[str, Any] = Field(default_factory=dict, description="调用参数") + + +class InvokeResultPayload(BaseModel): + """plugin.invoke_* 响应 payload""" + success: bool = Field(description="是否成功") + result: Any = Field(default=None, description="返回值") + + +# ─── 能力调用消息 ────────────────────────────────────────────────── + +class CapabilityRequestPayload(BaseModel): + """cap.* 请求 payload(插件 -> Host 能力调用)""" + capability: str = Field(description="能力名称,如 send.text, db.query") + args: dict[str, Any] = Field(default_factory=dict, description="调用参数") + + +class CapabilityResponsePayload(BaseModel): + """cap.* 响应 payload""" + success: bool = Field(description="是否成功") + result: Any = Field(default=None, description="返回值") + + +# ─── 健康检查 ────────────────────────────────────────────────────── + +class HealthPayload(BaseModel): + """plugin.health 响应 payload""" + healthy: bool = Field(description="是否健康") + loaded_plugins: list[str] = Field(default_factory=list, description="已加载的插件列表") + uptime_ms: int = Field(default=0, description="运行时长(ms)") + + +# ─── 配置更新 ────────────────────────────────────────────────────── + +class ConfigUpdatedPayload(BaseModel): + """plugin.config_updated 事件 payload""" + plugin_id: str = Field(description="插件 ID") + config_version: str = Field(description="新配置版本") + config_data: dict[str, Any] = Field(default_factory=dict, description="配置内容") + + +# ─── 关停 ────────────────────────────────────────────────────────── + +class ShutdownPayload(BaseModel): + """plugin.shutdown / plugin.prepare_shutdown payload""" + reason: str = Field(default="normal", description="关停原因") + drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)") diff --git a/src/plugin_runtime/protocol/errors.py b/src/plugin_runtime/protocol/errors.py new file mode 100644 index 00000000..53d437ef --- /dev/null +++ b/src/plugin_runtime/protocol/errors.py @@ -0,0 +1,61 @@ +"""RPC 错误码定义 + +所有 Host 与 Runner 之间的 RPC 通信使用统一的错误码体系。 +""" + +from enum import Enum + + +class ErrorCode(str, Enum): + """RPC 错误码枚举""" + + # 通用 + OK = "OK" + E_UNKNOWN = "E_UNKNOWN" + + # 协议层 + E_TIMEOUT = "E_TIMEOUT" + E_BAD_PAYLOAD = "E_BAD_PAYLOAD" + E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH" + + # 权限与策略 + E_UNAUTHORIZED = "E_UNAUTHORIZED" + E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED" + E_BACKPRESSURE = "E_BACKPRESSURE" + E_HOST_OVERLOADED = "E_HOST_OVERLOADED" + + # 插件生命周期 + E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED" + E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND" + E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH" + E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS" + + # 能力调用 + E_CAPABILITY_DENIED = "E_CAPABILITY_DENIED" + E_CAPABILITY_FAILED = "E_CAPABILITY_FAILED" + + +class RPCError(Exception): + """RPC 调用异常""" + + def __init__(self, code: ErrorCode, message: str = "", details: dict | None = None): + self.code = code + self.message = message or code.value + self.details = details or {} + super().__init__(f"[{code.value}] {self.message}") + + def to_dict(self) -> dict: + return { + "code": self.code.value, + "message": self.message, + "details": self.details, + } + + @classmethod + def from_dict(cls, data: dict) -> "RPCError": + code = ErrorCode(data.get("code", "E_UNKNOWN")) + return cls( + code=code, + message=data.get("message", ""), + details=data.get("details", {}), + ) diff --git a/src/plugin_runtime/runner/__init__.py b/src/plugin_runtime/runner/__init__.py new file mode 100644 index 00000000..44dde8af --- /dev/null +++ b/src/plugin_runtime/runner/__init__.py @@ -0,0 +1 @@ +# Runner 端 - 插件加载与执行进程 diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py new file mode 100644 index 00000000..60451b37 --- /dev/null +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -0,0 +1,127 @@ +"""插件加载器 + +在 Runner 进程中负责发现和加载插件。 +插件通过 SDK 编写,不再 import src.*。 +""" + +from typing import Any + +import importlib +import importlib.util +import json +import logging +import os +import sys + +logger = logging.getLogger("plugin_runtime.runner.plugin_loader") + + +class PluginMeta: + """加载后的插件元数据""" + + def __init__( + self, + plugin_id: str, + plugin_dir: str, + plugin_instance: Any, + manifest: dict[str, Any], + ): + self.plugin_id = plugin_id + self.plugin_dir = plugin_dir + self.instance = plugin_instance + self.manifest = manifest + self.version = manifest.get("version", "1.0.0") + self.capabilities_required = manifest.get("capabilities", []) + + +class PluginLoader: + """插件加载器 + + 扫描插件目录,加载符合 SDK 规范的插件。 + 每个插件目录须包含: + - _manifest.json: 插件元数据 + - plugin.py: 插件入口模块(导出 create_plugin 工厂函数) + """ + + def __init__(self): + self._loaded_plugins: dict[str, PluginMeta] = {} + + def discover_and_load(self, plugin_dirs: list[str]) -> list[PluginMeta]: + """扫描多个目录并加载所有插件 + + Args: + plugin_dirs: 插件目录列表 + + Returns: + 成功加载的插件元数据列表 + """ + results = [] + for base_dir in plugin_dirs: + if not os.path.isdir(base_dir): + logger.warning(f"插件目录不存在: {base_dir}") + continue + + for entry in os.listdir(base_dir): + plugin_dir = os.path.join(base_dir, entry) + if not os.path.isdir(plugin_dir): + continue + + manifest_path = os.path.join(plugin_dir, "_manifest.json") + plugin_path = os.path.join(plugin_dir, "plugin.py") + + if not os.path.exists(manifest_path) or not os.path.exists(plugin_path): + continue + + try: + meta = self._load_single_plugin(plugin_dir, manifest_path, plugin_path) + if meta: + self._loaded_plugins[meta.plugin_id] = meta + results.append(meta) + except Exception as e: + logger.error(f"加载插件失败 [{plugin_dir}]: {e}", exc_info=True) + + return results + + def get_plugin(self, plugin_id: str) -> PluginMeta | None: + """获取已加载的插件""" + return self._loaded_plugins.get(plugin_id) + + def list_plugins(self) -> list[str]: + """列出所有已加载的插件 ID""" + return list(self._loaded_plugins.keys()) + + def _load_single_plugin(self, plugin_dir: str, manifest_path: str, plugin_path: str) -> PluginMeta | None: + """加载单个插件""" + # 1. 读取 manifest + with open(manifest_path, "r", encoding="utf-8") as f: + manifest = json.load(f) + + plugin_id = os.path.basename(plugin_dir) + + # 2. 动态导入插件模块 + module_name = f"_maibot_plugin_{plugin_id}" + spec = importlib.util.spec_from_file_location(module_name, plugin_path) + if spec is None or spec.loader is None: + logger.error(f"无法创建模块 spec: {plugin_path}") + return None + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + # 3. 调用工厂函数创建插件实例 + create_plugin = getattr(module, "create_plugin", None) + if create_plugin is None: + logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数") + return None + + instance = create_plugin() + + logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功") + + return PluginMeta( + plugin_id=plugin_id, + plugin_dir=plugin_dir, + plugin_instance=instance, + manifest=manifest, + ) diff --git a/src/plugin_runtime/runner/rpc_client.py b/src/plugin_runtime/runner/rpc_client.py new file mode 100644 index 00000000..e900a5e8 --- /dev/null +++ b/src/plugin_runtime/runner/rpc_client.py @@ -0,0 +1,257 @@ +"""Runner 端 RPC Client + +负责: +1. 连接 Host RPC Server +2. 发送握手(runner.hello) +3. 发送组件注册请求 +4. 接收并分发 Host 的调用请求 +5. 发送能力调用请求到 Host +""" + +from typing import Any, Callable, Awaitable + +import asyncio +import logging +import uuid + +from src.plugin_runtime.protocol.codec import Codec, create_codec +from src.plugin_runtime.protocol.envelope import ( + PROTOCOL_VERSION, + Envelope, + HelloPayload, + HelloResponsePayload, + MessageType, + RequestIdGenerator, +) +from src.plugin_runtime.protocol.errors import ErrorCode, RPCError +from src.plugin_runtime.transport.base import Connection +from src.plugin_runtime.transport.factory import create_transport_client + +logger = logging.getLogger("plugin_runtime.runner.rpc_client") + +# RPC 方法处理器类型 +MethodHandler = Callable[[Envelope], Awaitable[Envelope]] + +SDK_VERSION = "1.0.0" + + +class RPCClient: + """Runner 端 RPC 客户端 + + 管理与 Host 的 IPC 连接,支持双向 RPC 调用。 + """ + + def __init__( + self, + host_address: str, + session_token: str, + codec: Codec | None = None, + ): + self._host_address = host_address + self._session_token = session_token + self._codec = codec or create_codec() + + self._id_gen = RequestIdGenerator() + self._connection: Connection | None = None + self._runner_id = str(uuid.uuid4()) + self._generation: int = 0 + + # 方法处理器注册表(Host 发来的调用) + self._method_handlers: dict[str, MethodHandler] = {} + + # 等待响应的 pending 请求: request_id -> Future + self._pending_requests: dict[int, asyncio.Future] = {} + + # 运行状态 + self._running = False + self._recv_task: asyncio.Task | None = None + + @property + def generation(self) -> int: + return self._generation + + @property + def is_connected(self) -> bool: + return self._connection is not None and not self._connection.is_closed + + def register_method(self, method: str, handler: MethodHandler) -> None: + """注册方法处理器(处理 Host 发来的请求)""" + self._method_handlers[method] = handler + + async def connect_and_handshake(self) -> bool: + """连接 Host 并完成握手 + + Returns: + 是否握手成功 + """ + client = create_transport_client(self._host_address) + self._connection = await client.connect() + + # 发送 runner.hello + hello = HelloPayload( + runner_id=self._runner_id, + sdk_version=SDK_VERSION, + session_token=self._session_token, + ) + request_id = self._id_gen.next() + envelope = Envelope( + request_id=request_id, + message_type=MessageType.REQUEST, + method="runner.hello", + payload=hello.model_dump(), + ) + + data = self._codec.encode_envelope(envelope) + await self._connection.send_frame(data) + + # 接收握手响应 + resp_data = await asyncio.wait_for(self._connection.recv_frame(), timeout=10.0) + resp = self._codec.decode_envelope(resp_data) + + resp_payload = HelloResponsePayload.model_validate(resp.payload) + if not resp_payload.accepted: + logger.error(f"握手被拒绝: {resp_payload.reason}") + await self._connection.close() + self._connection = None + return False + + self._generation = resp_payload.assigned_generation + logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}") + + # 启动消息接收循环 + self._running = True + self._recv_task = asyncio.create_task(self._recv_loop()) + + return True + + async def disconnect(self) -> None: + """断开连接""" + self._running = False + if self._recv_task: + self._recv_task.cancel() + try: + await self._recv_task + except asyncio.CancelledError: + pass + self._recv_task = None + + # 取消所有 pending 请求 + for future in self._pending_requests.values(): + if not future.done(): + future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭")) + self._pending_requests.clear() + + if self._connection: + await self._connection.close() + self._connection = None + + async def send_request( + self, + method: str, + plugin_id: str = "", + payload: dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Envelope: + """向 Host 发送 RPC 请求并等待响应""" + if not self.is_connected: + raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host") + + request_id = self._id_gen.next() + envelope = Envelope( + request_id=request_id, + message_type=MessageType.REQUEST, + method=method, + plugin_id=plugin_id, + generation=self._generation, + timeout_ms=timeout_ms, + payload=payload or {}, + ) + + loop = asyncio.get_event_loop() + future: asyncio.Future[Envelope] = loop.create_future() + self._pending_requests[request_id] = future + + try: + data = self._codec.encode_envelope(envelope) + await self._connection.send_frame(data) + + timeout_sec = timeout_ms / 1000.0 + response = await asyncio.wait_for(future, timeout=timeout_sec) + return response + except asyncio.TimeoutError: + self._pending_requests.pop(request_id, None) + raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") + except Exception as e: + self._pending_requests.pop(request_id, None) + if isinstance(e, RPCError): + raise + raise RPCError(ErrorCode.E_UNKNOWN, str(e)) + + # ─── 内部方法 ────────────────────────────────────────────── + + async def _recv_loop(self) -> None: + """消息接收主循环""" + while self._running and self._connection and not self._connection.is_closed: + try: + data = await self._connection.recv_frame() + except (asyncio.IncompleteReadError, ConnectionError): + logger.info("Host 连接已断开") + break + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"接收帧失败: {e}") + break + + try: + envelope = self._codec.decode_envelope(data) + except Exception as e: + logger.error(f"解码消息失败: {e}") + continue + + if envelope.is_response(): + self._handle_response(envelope) + elif envelope.is_request(): + asyncio.create_task(self._handle_request(envelope)) + elif envelope.is_event(): + asyncio.create_task(self._handle_event(envelope)) + + def _handle_response(self, envelope: Envelope) -> None: + """处理来自 Host 的响应""" + future = self._pending_requests.pop(envelope.request_id, None) + if future and not future.done(): + if envelope.error: + future.set_exception(RPCError.from_dict(envelope.error)) + else: + future.set_result(envelope) + + async def _handle_request(self, envelope: Envelope) -> None: + """处理来自 Host 的请求(调用插件组件)""" + handler = self._method_handlers.get(envelope.method) + if handler is None: + error_resp = envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"未注册的方法: {envelope.method}", + ) + await self._connection.send_frame(self._codec.encode_envelope(error_resp)) + return + + try: + response = await handler(envelope) + await self._connection.send_frame(self._codec.encode_envelope(response)) + except RPCError as e: + error_resp = envelope.make_error_response(e.code.value, e.message, e.details) + await self._connection.send_frame(self._codec.encode_envelope(error_resp)) + except Exception as e: + logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True) + error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) + await self._connection.send_frame(self._codec.encode_envelope(error_resp)) + + async def _handle_event(self, envelope: Envelope) -> None: + """处理来自 Host 的事件""" + handler = self._method_handlers.get(envelope.method) + if handler: + try: + await handler(envelope) + except Exception as e: + logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True) diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py new file mode 100644 index 00000000..a658fe7d --- /dev/null +++ b/src/plugin_runtime/runner/runner_main.py @@ -0,0 +1,246 @@ +"""Runner 主循环 + +作为独立子进程运行,负责: +1. 从环境变量读取 IPC 地址和会话令牌 +2. 连接 Host 并完成握手 +3. 加载所有插件 +4. 注册组件到 Host +5. 处理 Host 的调用请求 +6. 转发插件的能力调用到 Host +""" + +from typing import Any + +import asyncio +import logging +import os +import signal +import sys +import time + +from src.plugin_runtime.protocol.envelope import ( + ComponentDeclaration, + Envelope, + HealthPayload, + InvokePayload, + InvokeResultPayload, + RegisterComponentsPayload, + ShutdownPayload, +) +from src.plugin_runtime.protocol.errors import ErrorCode +from src.plugin_runtime.runner.plugin_loader import PluginLoader +from src.plugin_runtime.runner.rpc_client import RPCClient + +logger = logging.getLogger("plugin_runtime.runner.main") + + +class PluginRunner: + """插件 Runner + + 运行在独立子进程中,管理所有插件的执行。 + """ + + def __init__( + self, + host_address: str, + session_token: str, + plugin_dirs: list[str], + ): + self._host_address = host_address + self._session_token = session_token + self._plugin_dirs = plugin_dirs + + self._rpc_client = RPCClient(host_address, session_token) + self._loader = PluginLoader() + self._start_time = time.monotonic() + self._shutting_down = False + + async def run(self) -> None: + """Runner 主入口""" + # 1. 连接 Host + logger.info(f"Runner 启动,连接 Host: {self._host_address}") + ok = await self._rpc_client.connect_and_handshake() + if not ok: + logger.error("握手失败,退出") + return + + # 2. 注册方法处理器 + self._register_handlers() + + # 3. 加载插件 + plugins = self._loader.discover_and_load(self._plugin_dirs) + logger.info(f"已加载 {len(plugins)} 个插件") + + # 4. 向 Host 注册所有插件的组件 + for meta in plugins: + await self._register_plugin(meta) + + # 5. 等待直到收到关停信号 + try: + while not self._shutting_down: + await asyncio.sleep(1.0) + except asyncio.CancelledError: + pass + + # 6. 断开连接 + await self._rpc_client.disconnect() + logger.info("Runner 已退出") + + def _register_handlers(self) -> None: + """注册方法处理器""" + self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) + self._rpc_client.register_method("plugin.emit_event", self._handle_invoke) + self._rpc_client.register_method("plugin.health", self._handle_health) + self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown) + self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown) + self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated) + + async def _register_plugin(self, meta) -> None: + """向 Host 注册单个插件""" + # 收集插件组件声明 + components = [] + instance = meta.instance + + # 从插件实例获取组件声明(SDK 插件须实现 get_components 方法) + if hasattr(instance, "get_components"): + for comp_info in instance.get_components(): + components.append(ComponentDeclaration( + name=comp_info.get("name", ""), + component_type=comp_info.get("type", ""), + plugin_id=meta.plugin_id, + metadata=comp_info.get("metadata", {}), + )) + + reg_payload = RegisterComponentsPayload( + plugin_id=meta.plugin_id, + plugin_version=meta.version, + components=components, + capabilities_required=meta.capabilities_required, + ) + + try: + resp = await self._rpc_client.send_request( + "plugin.register_components", + plugin_id=meta.plugin_id, + payload=reg_payload.model_dump(), + timeout_ms=10000, + ) + logger.info(f"插件 {meta.plugin_id} 注册完成") + except Exception as e: + logger.error(f"插件 {meta.plugin_id} 注册失败: {e}") + + async def _handle_invoke(self, envelope: Envelope) -> Envelope: + """处理组件调用请求""" + try: + invoke = InvokePayload.model_validate(envelope.payload) + except Exception as e: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e)) + + plugin_id = envelope.plugin_id + meta = self._loader.get_plugin(plugin_id) + if meta is None: + return envelope.make_error_response( + ErrorCode.E_PLUGIN_NOT_FOUND.value, + f"插件 {plugin_id} 未加载", + ) + + # 调用插件实例的组件方法 + instance = meta.instance + component_name = invoke.component_name + + handler_method = getattr(instance, f"handle_{component_name}", None) + if handler_method is None: + handler_method = getattr(instance, component_name, None) + + if handler_method is None or not callable(handler_method): + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {plugin_id} 无组件: {component_name}", + ) + + try: + result = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args) + resp_payload = InvokeResultPayload(success=True, result=result) + return envelope.make_response(payload=resp_payload.model_dump()) + except Exception as e: + logger.error(f"插件 {plugin_id} 组件 {component_name} 执行异常: {e}", exc_info=True) + resp_payload = InvokeResultPayload(success=False, result=str(e)) + return envelope.make_response(payload=resp_payload.model_dump()) + + async def _handle_health(self, envelope: Envelope) -> Envelope: + """处理健康检查""" + uptime_ms = int((time.monotonic() - self._start_time) * 1000) + health = HealthPayload( + healthy=True, + loaded_plugins=self._loader.list_plugins(), + uptime_ms=uptime_ms, + ) + return envelope.make_response(payload=health.model_dump()) + + async def _handle_prepare_shutdown(self, envelope: Envelope) -> Envelope: + """处理准备关停""" + logger.info("收到 prepare_shutdown 信号") + return envelope.make_response(payload={"acknowledged": True}) + + async def _handle_shutdown(self, envelope: Envelope) -> Envelope: + """处理关停""" + logger.info("收到 shutdown 信号,准备退出") + self._shutting_down = True + return envelope.make_response(payload={"acknowledged": True}) + + async def _handle_config_updated(self, envelope: Envelope) -> Envelope: + """处理配置更新事件""" + plugin_id = envelope.plugin_id + meta = self._loader.get_plugin(plugin_id) + if meta and hasattr(meta.instance, "on_config_update"): + try: + config_data = envelope.payload.get("config_data", {}) + config_version = envelope.payload.get("config_version", "") + await meta.instance.on_config_update(config_data, config_version) + except Exception as e: + logger.error(f"插件 {plugin_id} 配置更新失败: {e}") + return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e)) + return envelope.make_response(payload={"acknowledged": True}) + + def request_capability(self) -> RPCClient: + """获取 RPC 客户端(供 SDK 使用,发起能力调用)""" + return self._rpc_client + + +# ─── 进程入口 ────────────────────────────────────────────── + +async def _async_main() -> None: + """异步主入口""" + host_address = os.environ.get("MAIBOT_IPC_ADDRESS", "") + session_token = os.environ.get("MAIBOT_SESSION_TOKEN", "") + plugin_dirs_str = os.environ.get("MAIBOT_PLUGIN_DIRS", "") + + if not host_address or not session_token: + logger.error("缺少必要的环境变量: MAIBOT_IPC_ADDRESS, MAIBOT_SESSION_TOKEN") + sys.exit(1) + + plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d] + + runner = PluginRunner(host_address, session_token, plugin_dirs) + + # 注册信号处理 + loop = asyncio.get_event_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, lambda: setattr(runner, "_shutting_down", True)) + + await runner.run() + + +def main() -> None: + """进程入口(python -m src.plugin_runtime.runner.runner_main)""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(name)s] %(levelname)s: %(message)s", + ) + asyncio.run(_async_main()) + + +if __name__ == "__main__": + main() diff --git a/src/plugin_runtime/transport/__init__.py b/src/plugin_runtime/transport/__init__.py new file mode 100644 index 00000000..24759ac0 --- /dev/null +++ b/src/plugin_runtime/transport/__init__.py @@ -0,0 +1 @@ +# Transport 层 - 跨平台本地 IPC 传输抽象 diff --git a/src/plugin_runtime/transport/base.py b/src/plugin_runtime/transport/base.py new file mode 100644 index 00000000..82c6c985 --- /dev/null +++ b/src/plugin_runtime/transport/base.py @@ -0,0 +1,116 @@ +"""传输层抽象基类 + +定义 TransportServer 和 TransportClient 的统一接口。 +所有传输后端(UDS、Named Pipe、TCP 回退)必须实现此接口。 +业务层仅依赖此抽象,禁止直接使用具体传输实现的细节。 + +分帧协议:4-byte big-endian length prefix + payload +""" + +from abc import ABC, abstractmethod +from typing import AsyncIterator, Callable, Awaitable + +import asyncio +import struct + +# 分帧常量 +FRAME_HEADER_SIZE = 4 # 4 字节长度前缀 +MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小 + + +class ConnectionClosed(Exception): + """连接已关闭""" + pass + + +class Connection(ABC): + """单个连接的抽象 + + 封装了底层 StreamReader/StreamWriter,提供分帧读写能力。 + """ + + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + self._reader = reader + self._writer = writer + self._closed = False + + async def send_frame(self, data: bytes) -> None: + """发送一帧数据(4-byte length prefix + payload)""" + if self._closed: + raise ConnectionClosed("连接已关闭") + length = len(data) + if length > MAX_FRAME_SIZE: + raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}") + header = struct.pack(">I", length) + self._writer.write(header + data) + await self._writer.drain() + + async def recv_frame(self) -> bytes: + """接收一帧数据""" + if self._closed: + raise ConnectionClosed("连接已关闭") + # 读取 4 字节长度头 + header = await self._reader.readexactly(FRAME_HEADER_SIZE) + (length,) = struct.unpack(">I", header) + if length > MAX_FRAME_SIZE: + raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}") + # 读取 payload + payload = await self._reader.readexactly(length) + return payload + + async def close(self) -> None: + """关闭连接""" + if self._closed: + return + self._closed = True + try: + self._writer.close() + await self._writer.wait_closed() + except Exception: + pass + + @property + def is_closed(self) -> bool: + return self._closed + + +# 连接回调类型:收到新连接时调用 +ConnectionHandler = Callable[[Connection], Awaitable[None]] + + +class TransportServer(ABC): + """传输服务端抽象 + + Host 端使用,监听来自 Runner 的连接。 + """ + + @abstractmethod + async def start(self, handler: ConnectionHandler) -> None: + """启动服务端,开始监听连接 + + Args: + handler: 新连接到来时的回调函数 + """ + ... + + @abstractmethod + async def stop(self) -> None: + """停止服务端""" + ... + + @abstractmethod + def get_address(self) -> str: + """获取监听地址(供 Runner 连接用)""" + ... + + +class TransportClient(ABC): + """传输客户端抽象 + + Runner 端使用,主动连接 Host。 + """ + + @abstractmethod + async def connect(self) -> Connection: + """建立到 Host 的连接""" + ... diff --git a/src/plugin_runtime/transport/factory.py b/src/plugin_runtime/transport/factory.py new file mode 100644 index 00000000..b3b01852 --- /dev/null +++ b/src/plugin_runtime/transport/factory.py @@ -0,0 +1,46 @@ +"""传输层工厂 + +根据运行平台自动选择最优传输实现。 +""" + +import sys + +from .base import TransportClient, TransportServer + + +def create_transport_server(socket_path: str | None = None) -> TransportServer: + """创建传输服务端 + + Linux/macOS 使用 UDS,Windows 使用 TCP 回退。 + + Args: + socket_path: UDS socket 路径(仅 Linux/macOS 有效) + """ + if sys.platform != "win32": + from .uds import UDSTransportServer + return UDSTransportServer(socket_path=socket_path) + else: + # Windows 回退到 TCP(后续可改为 Named Pipe) + from .tcp import TCPTransportServer + return TCPTransportServer() + + +def create_transport_client(address: str) -> TransportClient: + """创建传输客户端 + + 根据地址格式自动判断传输类型: + - 包含 '/' 或 '.sock' -> UDS + - 包含 ':' -> TCP + + Args: + address: Host 端监听地址 + """ + if "/" in address or address.endswith(".sock"): + from .uds import UDSTransportClient + return UDSTransportClient(socket_path=address) + elif ":" in address: + from .tcp import TCPTransportClient + host, port_str = address.rsplit(":", 1) + return TCPTransportClient(host=host, port=int(port_str)) + else: + raise ValueError(f"无法识别的传输地址格式: {address}") diff --git a/src/plugin_runtime/transport/tcp.py b/src/plugin_runtime/transport/tcp.py new file mode 100644 index 00000000..002a949c --- /dev/null +++ b/src/plugin_runtime/transport/tcp.py @@ -0,0 +1,59 @@ +"""TCP 传输实现(回退方案) + +仅当 UDS / Named Pipe 不可用时启用。 +绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。 +""" + +import asyncio + +from .base import Connection, ConnectionHandler, TransportClient, TransportServer + + +class TCPConnection(Connection): + """基于 TCP 的连接""" + pass + + +class TCPTransportServer(TransportServer): + """TCP 传输服务端(回退方案)""" + + def __init__(self, host: str = "127.0.0.1", port: int = 0): + self._host = host + self._port = port # 0 表示自动分配 + self._server: asyncio.AbstractServer | None = None + self._actual_port: int = 0 + + async def start(self, handler: ConnectionHandler) -> None: + async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + conn = TCPConnection(reader, writer) + try: + await handler(conn) + finally: + await conn.close() + + self._server = await asyncio.start_server(_on_connect, self._host, self._port) + + # 获取实际分配的端口 + addr = self._server.sockets[0].getsockname() + self._actual_port = addr[1] + + async def stop(self) -> None: + if self._server: + self._server.close() + await self._server.wait_closed() + self._server = None + + def get_address(self) -> str: + return f"{self._host}:{self._actual_port}" + + +class TCPTransportClient(TransportClient): + """TCP 传输客户端""" + + def __init__(self, host: str, port: int): + self._host = host + self._port = port + + async def connect(self) -> Connection: + reader, writer = await asyncio.open_connection(self._host, self._port) + return TCPConnection(reader, writer) diff --git a/src/plugin_runtime/transport/uds.py b/src/plugin_runtime/transport/uds.py new file mode 100644 index 00000000..461f3197 --- /dev/null +++ b/src/plugin_runtime/transport/uds.py @@ -0,0 +1,71 @@ +"""Unix Domain Socket 传输实现 + +适用于 Linux / macOS 平台。 +""" + +from pathlib import Path + +import asyncio +import os +import tempfile + +from .base import Connection, ConnectionHandler, TransportClient, TransportServer + + +class UDSConnection(Connection): + """基于 UDS 的连接""" + pass # 直接复用 Connection 基类的分帧读写 + + +class UDSTransportServer(TransportServer): + """UDS 传输服务端""" + + def __init__(self, socket_path: str | None = None): + if socket_path is None: + # 默认放在临时目录 + socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}.sock") + self._socket_path = socket_path + self._server: asyncio.AbstractServer | None = None + + async def start(self, handler: ConnectionHandler) -> None: + # 清理残留 socket 文件 + if os.path.exists(self._socket_path): + os.unlink(self._socket_path) + + # 确保父目录存在 + Path(self._socket_path).parent.mkdir(parents=True, exist_ok=True) + + async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + conn = UDSConnection(reader, writer) + try: + await handler(conn) + finally: + await conn.close() + + self._server = await asyncio.start_unix_server(_on_connect, path=self._socket_path) + + # 设置文件权限为仅当前用户可访问 + os.chmod(self._socket_path, 0o600) + + async def stop(self) -> None: + if self._server: + self._server.close() + await self._server.wait_closed() + self._server = None + # 清理 socket 文件 + if os.path.exists(self._socket_path): + os.unlink(self._socket_path) + + def get_address(self) -> str: + return self._socket_path + + +class UDSTransportClient(TransportClient): + """UDS 传输客户端""" + + def __init__(self, socket_path: str): + self._socket_path = socket_path + + async def connect(self) -> Connection: + reader, writer = await asyncio.open_unix_connection(self._socket_path) + return UDSConnection(reader, writer)