feat(plugin-runtime): add plugin isolation IPC infrastructure
- Protocol layer: Envelope model with Pydantic schema, MsgPack/JSON codecs, unified error codes - Transport layer: cross-platform IPC abstraction with 4-byte length-prefixed framing (UDS + TCP fallback) - Host: RPC server, policy engine, circuit breaker, capability service, supervisor with hot-reload - Runner: RPC client, plugin loader, process entry point - Tests: 16 passing tests covering protocol, transport, host, and E2E handshake
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -354,4 +354,5 @@ MaiBot.code-workspace
|
||||
*.lock
|
||||
actionlint
|
||||
.sisyphus/
|
||||
dist-electron/
|
||||
dist-electron/
|
||||
packages/
|
||||
427
pytests/test_plugin_runtime.py
Normal file
427
pytests/test_plugin_runtime.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""插件运行时框架基础测试
|
||||
|
||||
验证协议层、传输层、RPC 通信链路的正确性。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
# 确保项目根目录在 sys.path 中
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
# SDK 包路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "packages", "maibot-plugin-sdk"))
|
||||
|
||||
|
||||
# ─── 协议层测试 ───────────────────────────────────────────
|
||||
|
||||
class TestProtocol:
|
||||
"""协议层测试"""
|
||||
|
||||
def test_envelope_create_and_serialize(self):
|
||||
"""Envelope 创建与序列化"""
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
env = Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="plugin.invoke_command",
|
||||
plugin_id="test_plugin",
|
||||
payload={"component_name": "greet", "args": {}},
|
||||
)
|
||||
|
||||
assert env.request_id == 1
|
||||
assert env.is_request()
|
||||
assert env.method == "plugin.invoke_command"
|
||||
|
||||
# 测试 make_response
|
||||
resp = env.make_response(payload={"success": True})
|
||||
assert resp.is_response()
|
||||
assert resp.request_id == 1
|
||||
assert resp.payload["success"] is True
|
||||
|
||||
def test_envelope_make_error_response(self):
|
||||
"""错误响应生成"""
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
env = Envelope(
|
||||
request_id=42,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="cap.request",
|
||||
)
|
||||
|
||||
err_resp = env.make_error_response("E_UNAUTHORIZED", "没有权限")
|
||||
assert err_resp.error is not None
|
||||
assert err_resp.error["code"] == "E_UNAUTHORIZED"
|
||||
assert err_resp.error["message"] == "没有权限"
|
||||
|
||||
def test_msgpack_codec(self):
|
||||
"""MsgPack 编解码"""
|
||||
from src.plugin_runtime.protocol.codec import MsgPackCodec
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
codec = MsgPackCodec()
|
||||
env = Envelope(
|
||||
request_id=100,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="test.method",
|
||||
payload={"key": "value", "number": 42},
|
||||
)
|
||||
|
||||
# 编码
|
||||
data = codec.encode_envelope(env)
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
# 解码
|
||||
decoded = codec.decode_envelope(data)
|
||||
assert decoded.request_id == 100
|
||||
assert decoded.method == "test.method"
|
||||
assert decoded.payload["key"] == "value"
|
||||
assert decoded.payload["number"] == 42
|
||||
|
||||
def test_json_codec(self):
|
||||
"""JSON 编解码"""
|
||||
from src.plugin_runtime.protocol.codec import JsonCodec
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, MessageType
|
||||
|
||||
codec = JsonCodec()
|
||||
env = Envelope(
|
||||
request_id=200,
|
||||
message_type=MessageType.EVENT,
|
||||
method="plugin.config_updated",
|
||||
payload={"config_version": "2.0"},
|
||||
)
|
||||
|
||||
data = codec.encode_envelope(env)
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
decoded = codec.decode_envelope(data)
|
||||
assert decoded.request_id == 200
|
||||
assert decoded.is_event()
|
||||
|
||||
def test_request_id_generator(self):
|
||||
"""请求 ID 生成器单调递增"""
|
||||
from src.plugin_runtime.protocol.envelope import RequestIdGenerator
|
||||
|
||||
gen = RequestIdGenerator()
|
||||
ids = [gen.next() for _ in range(100)]
|
||||
assert ids == list(range(1, 101))
|
||||
|
||||
def test_error_codes(self):
|
||||
"""错误码枚举"""
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
|
||||
err = RPCError(ErrorCode.E_TIMEOUT, "请求超时")
|
||||
assert err.code == ErrorCode.E_TIMEOUT
|
||||
assert "E_TIMEOUT" in str(err)
|
||||
|
||||
# 序列化/反序列化
|
||||
d = err.to_dict()
|
||||
err2 = RPCError.from_dict(d)
|
||||
assert err2.code == ErrorCode.E_TIMEOUT
|
||||
|
||||
|
||||
# ─── 传输层测试 ───────────────────────────────────────────
|
||||
|
||||
class TestTransport:
|
||||
"""传输层测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uds_connection_framing(self):
|
||||
"""UDS 分帧协议测试"""
|
||||
from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient
|
||||
|
||||
server = UDSTransportServer()
|
||||
received = asyncio.Event()
|
||||
received_data = []
|
||||
|
||||
async def handler(conn):
|
||||
data = await conn.recv_frame()
|
||||
received_data.append(data)
|
||||
await conn.send_frame(b"pong")
|
||||
received.set()
|
||||
|
||||
await server.start(handler)
|
||||
address = server.get_address()
|
||||
|
||||
client = UDSTransportClient(address)
|
||||
conn = await client.connect()
|
||||
await conn.send_frame(b"ping")
|
||||
|
||||
# 等待服务端处理
|
||||
await asyncio.wait_for(received.wait(), timeout=5.0)
|
||||
assert received_data[0] == b"ping"
|
||||
|
||||
# 接收服务端回复
|
||||
resp = await conn.recv_frame()
|
||||
assert resp == b"pong"
|
||||
|
||||
await conn.close()
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_connection_framing(self):
|
||||
"""TCP 分帧协议测试"""
|
||||
from src.plugin_runtime.transport.tcp import TCPTransportServer, TCPTransportClient
|
||||
|
||||
server = TCPTransportServer()
|
||||
received = asyncio.Event()
|
||||
received_data = []
|
||||
|
||||
async def handler(conn):
|
||||
data = await conn.recv_frame()
|
||||
received_data.append(data)
|
||||
await conn.send_frame(b"tcp_pong")
|
||||
received.set()
|
||||
|
||||
await server.start(handler)
|
||||
address = server.get_address()
|
||||
host, port = address.split(":")
|
||||
|
||||
client = TCPTransportClient(host, int(port))
|
||||
conn = await client.connect()
|
||||
await conn.send_frame(b"tcp_ping")
|
||||
|
||||
await asyncio.wait_for(received.wait(), timeout=5.0)
|
||||
assert received_data[0] == b"tcp_ping"
|
||||
|
||||
resp = await conn.recv_frame()
|
||||
assert resp == b"tcp_pong"
|
||||
|
||||
await conn.close()
|
||||
await server.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transport_factory(self):
|
||||
"""传输工厂测试"""
|
||||
from src.plugin_runtime.transport.factory import create_transport_server, create_transport_client
|
||||
|
||||
server = create_transport_server()
|
||||
assert server is not None
|
||||
|
||||
# UDS 路径
|
||||
client = create_transport_client("/tmp/test.sock")
|
||||
assert client is not None
|
||||
|
||||
# TCP 地址
|
||||
client = create_transport_client("127.0.0.1:9999")
|
||||
assert client is not None
|
||||
|
||||
|
||||
# ─── Host 层测试 ──────────────────────────────────────────
|
||||
|
||||
class TestHost:
|
||||
"""Host 端基础设施测试"""
|
||||
|
||||
def test_policy_engine(self):
|
||||
"""策略引擎测试"""
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
|
||||
engine = PolicyEngine()
|
||||
|
||||
# 注册插件
|
||||
token = engine.register_plugin(
|
||||
plugin_id="test_plugin",
|
||||
generation=1,
|
||||
capabilities=["send.text", "db.query"],
|
||||
limits={"qps": 10, "burst": 20},
|
||||
)
|
||||
|
||||
assert token.plugin_id == "test_plugin"
|
||||
assert "send.text" in token.capabilities
|
||||
|
||||
# 能力检查
|
||||
ok, _ = engine.check_capability("test_plugin", "send.text")
|
||||
assert ok
|
||||
|
||||
ok, reason = engine.check_capability("test_plugin", "llm.generate")
|
||||
assert not ok
|
||||
assert "未获授权" in reason
|
||||
|
||||
# 未注册插件
|
||||
ok, reason = engine.check_capability("unknown", "send.text")
|
||||
assert not ok
|
||||
|
||||
def test_circuit_breaker(self):
|
||||
"""熔断器测试"""
|
||||
from src.plugin_runtime.host.circuit_breaker import CircuitBreaker, CircuitState
|
||||
|
||||
breaker = CircuitBreaker(failure_threshold=3)
|
||||
|
||||
# 初始状态:关闭
|
||||
assert breaker.state == CircuitState.CLOSED
|
||||
assert breaker.allow_request()
|
||||
|
||||
# 连续失败
|
||||
breaker.record_failure()
|
||||
breaker.record_failure()
|
||||
assert breaker.allow_request() # 还没到阈值
|
||||
|
||||
breaker.record_failure() # 第3次,触发熔断
|
||||
assert breaker.state == CircuitState.OPEN
|
||||
assert not breaker.allow_request()
|
||||
|
||||
# 重置
|
||||
breaker.reset()
|
||||
assert breaker.state == CircuitState.CLOSED
|
||||
|
||||
def test_circuit_breaker_registry(self):
|
||||
"""熔断器注册表测试"""
|
||||
from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry
|
||||
|
||||
registry = CircuitBreakerRegistry(failure_threshold=2)
|
||||
|
||||
b1 = registry.get("plugin_a")
|
||||
b2 = registry.get("plugin_b")
|
||||
assert b1 is not b2
|
||||
assert registry.get("plugin_a") is b1 # 同一个
|
||||
|
||||
|
||||
# ─── SDK 测试 ─────────────────────────────────────────────
|
||||
|
||||
class TestSDK:
|
||||
"""SDK 框架测试"""
|
||||
|
||||
def test_component_decorators(self):
|
||||
"""组件装饰器测试"""
|
||||
from maibot_sdk import MaiBotPlugin, Action, Command, Tool, EventHandler
|
||||
from maibot_sdk.types import ActivationType, EventType
|
||||
|
||||
class TestPlugin(MaiBotPlugin):
|
||||
@Action("greet", activation_type=ActivationType.KEYWORD, activation_keywords=["hi"])
|
||||
async def handle_greet(self, **kwargs):
|
||||
return True, "ok"
|
||||
|
||||
@Command("echo", pattern=r"^/echo")
|
||||
async def handle_echo(self, **kwargs):
|
||||
return True, "echoed", 2
|
||||
|
||||
@Tool("search", parameters={"query": {"type": "string"}})
|
||||
async def handle_search(self, **kwargs):
|
||||
return {"result": "found"}
|
||||
|
||||
@EventHandler("on_start", event_type=EventType.ON_START)
|
||||
async def handle_start(self, **kwargs):
|
||||
return True, False, "started"
|
||||
|
||||
plugin = TestPlugin()
|
||||
components = plugin.get_components()
|
||||
|
||||
assert len(components) == 4
|
||||
|
||||
names = {c["name"] for c in components}
|
||||
assert "greet" in names
|
||||
assert "echo" in names
|
||||
assert "search" in names
|
||||
assert "on_start" in names
|
||||
|
||||
types = {c["type"] for c in components}
|
||||
assert "action" in types
|
||||
assert "command" in types
|
||||
assert "tool" in types
|
||||
assert "event_handler" in types
|
||||
|
||||
def test_plugin_context_not_initialized(self):
|
||||
"""未初始化上下文时应报错"""
|
||||
from maibot_sdk import MaiBotPlugin
|
||||
|
||||
plugin = MaiBotPlugin()
|
||||
with pytest.raises(RuntimeError, match="尚未初始化"):
|
||||
_ = plugin.ctx
|
||||
|
||||
def test_plugin_context_injection(self):
|
||||
"""上下文注入测试"""
|
||||
from maibot_sdk import MaiBotPlugin
|
||||
from maibot_sdk.context import PluginContext
|
||||
|
||||
plugin = MaiBotPlugin()
|
||||
ctx = PluginContext(plugin_id="test")
|
||||
plugin._set_context(ctx)
|
||||
|
||||
assert plugin.ctx.plugin_id == "test"
|
||||
assert plugin.ctx.send is not None
|
||||
assert plugin.ctx.db is not None
|
||||
assert plugin.ctx.llm is not None
|
||||
assert plugin.ctx.config is not None
|
||||
|
||||
|
||||
# ─── 端到端集成测试 ────────────────────────────────────────
|
||||
|
||||
class TestE2E:
|
||||
"""端到端集成测试(Host + Runner 通信)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handshake(self):
|
||||
"""Host-Runner 握手流程测试"""
|
||||
from src.plugin_runtime.protocol.codec import create_codec
|
||||
from src.plugin_runtime.protocol.envelope import Envelope, HelloPayload, HelloResponsePayload, MessageType
|
||||
from src.plugin_runtime.transport.uds import UDSTransportServer, UDSTransportClient
|
||||
|
||||
import secrets
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-test-{os.getpid()}.sock")
|
||||
session_token = secrets.token_hex(16)
|
||||
codec = create_codec()
|
||||
handshake_done = asyncio.Event()
|
||||
server_result = {}
|
||||
|
||||
async def server_handler(conn):
|
||||
# 接收握手
|
||||
data = await conn.recv_frame()
|
||||
env = codec.decode_envelope(data)
|
||||
assert env.method == "runner.hello"
|
||||
|
||||
hello = HelloPayload.model_validate(env.payload)
|
||||
assert hello.session_token == session_token
|
||||
|
||||
# 发送响应
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=True,
|
||||
host_version="1.0",
|
||||
assigned_generation=1,
|
||||
)
|
||||
resp = env.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(codec.encode_envelope(resp))
|
||||
|
||||
server_result["runner_id"] = hello.runner_id
|
||||
handshake_done.set()
|
||||
|
||||
# 保持连接一会儿
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
server = UDSTransportServer(socket_path=socket_path)
|
||||
await server.start(server_handler)
|
||||
|
||||
# 客户端握手
|
||||
client = UDSTransportClient(socket_path)
|
||||
conn = await client.connect()
|
||||
|
||||
hello = HelloPayload(
|
||||
runner_id="test-runner",
|
||||
sdk_version="1.0.0",
|
||||
session_token=session_token,
|
||||
)
|
||||
env = Envelope(
|
||||
request_id=1,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="runner.hello",
|
||||
payload=hello.model_dump(),
|
||||
)
|
||||
await conn.send_frame(codec.encode_envelope(env))
|
||||
|
||||
resp_data = await conn.recv_frame()
|
||||
resp = codec.decode_envelope(resp_data)
|
||||
resp_payload = HelloResponsePayload.model_validate(resp.payload)
|
||||
|
||||
assert resp_payload.accepted
|
||||
assert resp_payload.assigned_generation == 1
|
||||
|
||||
await asyncio.wait_for(handshake_done.wait(), timeout=5.0)
|
||||
assert server_result["runner_id"] == "test-runner"
|
||||
|
||||
await conn.close()
|
||||
await server.stop()
|
||||
2
src/plugin_runtime/__init__.py
Normal file
2
src/plugin_runtime/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# MaiBot Plugin Runtime - 插件隔离运行时基础设施
|
||||
# 本模块实现 Host-Runner 进程分离架构,提供 IPC 通信、策略引擎与生命周期管理
|
||||
1
src/plugin_runtime/host/__init__.py
Normal file
1
src/plugin_runtime/host/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Host 端 - Supervisor、RPC Server、策略引擎、路由
|
||||
108
src/plugin_runtime/host/capability_service.py
Normal file
108
src/plugin_runtime/host/capability_service.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""能力服务层
|
||||
|
||||
Host 端实现的能力服务,处理来自插件的 cap.* 请求。
|
||||
每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
import logging
|
||||
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
CapabilityRequestPayload,
|
||||
CapabilityResponsePayload,
|
||||
Envelope,
|
||||
)
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.host.capability_service")
|
||||
|
||||
# 能力实现函数类型: (plugin_id, capability, args) -> result
|
||||
CapabilityImpl = Callable[[str, str, dict[str, Any]], Awaitable[Any]]
|
||||
|
||||
|
||||
class CapabilityService:
|
||||
"""能力服务
|
||||
|
||||
负责:
|
||||
1. 注册能力实现
|
||||
2. 接收插件的能力调用请求
|
||||
3. 通过策略引擎校验权限和限流
|
||||
4. 执行实际操作并返回结果
|
||||
"""
|
||||
|
||||
def __init__(self, policy_engine: PolicyEngine):
|
||||
self._policy = policy_engine
|
||||
# capability_name -> implementation
|
||||
self._implementations: dict[str, CapabilityImpl] = {}
|
||||
|
||||
def register_capability(self, name: str, impl: CapabilityImpl) -> None:
|
||||
"""注册一个能力实现
|
||||
|
||||
Args:
|
||||
name: 能力名称,如 "send.text", "db.query", "llm.generate"
|
||||
impl: 实现函数
|
||||
"""
|
||||
self._implementations[name] = impl
|
||||
logger.debug(f"注册能力实现: {name}")
|
||||
|
||||
async def handle_capability_request(self, envelope: Envelope) -> Envelope:
|
||||
"""处理能力调用请求(作为 RPC Server 的 method handler)
|
||||
|
||||
从 envelope 中提取 capability 名称和参数,
|
||||
校验权限后调用对应实现。
|
||||
"""
|
||||
plugin_id = envelope.plugin_id
|
||||
|
||||
try:
|
||||
req = CapabilityRequestPayload.model_validate(envelope.payload)
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_BAD_PAYLOAD.value,
|
||||
f"能力调用 payload 格式错误: {e}",
|
||||
)
|
||||
|
||||
capability = req.capability
|
||||
|
||||
# 1. 权限校验
|
||||
allowed, reason = self._policy.check_capability(plugin_id, capability)
|
||||
if not allowed:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_CAPABILITY_DENIED.value,
|
||||
reason,
|
||||
)
|
||||
|
||||
# 2. 限流校验
|
||||
allowed, reason = self._policy.check_rate_limit(plugin_id)
|
||||
if not allowed:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_BACKPRESSURE.value,
|
||||
reason,
|
||||
)
|
||||
|
||||
# 3. 查找实现
|
||||
impl = self._implementations.get(capability)
|
||||
if impl is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的能力: {capability}",
|
||||
)
|
||||
|
||||
# 4. 执行
|
||||
try:
|
||||
result = await impl(plugin_id, capability, req.args)
|
||||
resp_payload = CapabilityResponsePayload(success=True, result=result)
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
except RPCError as e:
|
||||
return envelope.make_error_response(e.code.value, e.message, e.details)
|
||||
except Exception as e:
|
||||
logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_CAPABILITY_FAILED.value,
|
||||
str(e),
|
||||
)
|
||||
|
||||
def list_capabilities(self) -> list[str]:
|
||||
"""列出所有已注册的能力"""
|
||||
return list(self._implementations.keys())
|
||||
105
src/plugin_runtime/host/circuit_breaker.py
Normal file
105
src/plugin_runtime/host/circuit_breaker.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""熔断器
|
||||
|
||||
为每个插件提供熔断保护,连续失败超过阈值后临时禁用。
|
||||
支持指数退避恢复。
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class CircuitState(str, Enum):
|
||||
CLOSED = "closed" # 正常工作
|
||||
OPEN = "open" # 熔断(拒绝所有调用)
|
||||
HALF_OPEN = "half_open" # 探测恢复
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""单个插件的熔断器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
failure_threshold: int = 5,
|
||||
recovery_timeout_sec: float = 30.0,
|
||||
max_recovery_timeout_sec: float = 300.0,
|
||||
):
|
||||
self.failure_threshold = failure_threshold
|
||||
self.base_recovery_timeout = recovery_timeout_sec
|
||||
self.max_recovery_timeout = max_recovery_timeout_sec
|
||||
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._last_failure_time = 0.0
|
||||
self._consecutive_opens = 0 # 用于指数退避
|
||||
|
||||
@property
|
||||
def state(self) -> CircuitState:
|
||||
if self._state == CircuitState.OPEN:
|
||||
# 检查是否可以进入半开状态
|
||||
elapsed = time.monotonic() - self._last_failure_time
|
||||
recovery_timeout = min(
|
||||
self.base_recovery_timeout * (2 ** self._consecutive_opens),
|
||||
self.max_recovery_timeout,
|
||||
)
|
||||
if elapsed >= recovery_timeout:
|
||||
self._state = CircuitState.HALF_OPEN
|
||||
return self._state
|
||||
|
||||
def allow_request(self) -> bool:
|
||||
"""是否允许通过请求"""
|
||||
state = self.state
|
||||
if state == CircuitState.CLOSED:
|
||||
return True
|
||||
if state == CircuitState.HALF_OPEN:
|
||||
return True # 允许一次试探
|
||||
return False # OPEN 状态拒绝
|
||||
|
||||
def record_success(self) -> None:
|
||||
"""记录一次成功调用"""
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
# 半开状态成功 -> 关闭熔断
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._consecutive_opens = 0
|
||||
elif self._state == CircuitState.CLOSED:
|
||||
self._failure_count = 0
|
||||
|
||||
def record_failure(self) -> None:
|
||||
"""记录一次失败调用"""
|
||||
self._failure_count += 1
|
||||
self._last_failure_time = time.monotonic()
|
||||
|
||||
if self._state == CircuitState.HALF_OPEN:
|
||||
# 半开状态失败 -> 重新开启熔断
|
||||
self._state = CircuitState.OPEN
|
||||
self._consecutive_opens += 1
|
||||
elif self._failure_count >= self.failure_threshold:
|
||||
self._state = CircuitState.OPEN
|
||||
self._consecutive_opens += 1
|
||||
|
||||
def reset(self) -> None:
|
||||
"""重置熔断器"""
|
||||
self._state = CircuitState.CLOSED
|
||||
self._failure_count = 0
|
||||
self._consecutive_opens = 0
|
||||
|
||||
|
||||
class CircuitBreakerRegistry:
|
||||
"""熔断器注册表,为每个插件维护独立的熔断器"""
|
||||
|
||||
def __init__(self, **default_kwargs):
|
||||
self._breakers: dict[str, CircuitBreaker] = {}
|
||||
self._default_kwargs = default_kwargs
|
||||
|
||||
def get(self, plugin_id: str) -> CircuitBreaker:
|
||||
if plugin_id not in self._breakers:
|
||||
self._breakers[plugin_id] = CircuitBreaker(**self._default_kwargs)
|
||||
return self._breakers[plugin_id]
|
||||
|
||||
def remove(self, plugin_id: str) -> None:
|
||||
self._breakers.pop(plugin_id, None)
|
||||
|
||||
def reset_all(self) -> None:
|
||||
for breaker in self._breakers.values():
|
||||
breaker.reset()
|
||||
125
src/plugin_runtime/host/policy_engine.py
Normal file
125
src/plugin_runtime/host/policy_engine.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""策略引擎
|
||||
|
||||
负责能力授权校验、限流、配额管理。
|
||||
每个插件在 manifest 中声明能力需求,Host 启动时签发能力令牌。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import time
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapabilityToken:
|
||||
"""能力令牌
|
||||
|
||||
描述某个插件在当前会话中被授予的能力和资源限制。
|
||||
"""
|
||||
plugin_id: str
|
||||
generation: int
|
||||
capabilities: set[str] = field(default_factory=set)
|
||||
qps_limit: int = 20
|
||||
burst_limit: int = 50
|
||||
daily_token_limit: int = 200000
|
||||
max_payload_kb: int = 256
|
||||
|
||||
# 运行时统计
|
||||
_call_count: int = field(default=0, init=False, repr=False)
|
||||
_window_start: float = field(default_factory=time.monotonic, init=False, repr=False)
|
||||
_window_calls: int = field(default=0, init=False, repr=False)
|
||||
|
||||
|
||||
class PolicyEngine:
|
||||
"""策略引擎
|
||||
|
||||
管理所有插件的能力令牌,提供授权校验与限流决策。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# plugin_id -> CapabilityToken
|
||||
self._tokens: dict[str, CapabilityToken] = {}
|
||||
|
||||
def register_plugin(
|
||||
self,
|
||||
plugin_id: str,
|
||||
generation: int,
|
||||
capabilities: list[str],
|
||||
limits: dict | None = None,
|
||||
) -> CapabilityToken:
|
||||
"""为插件签发能力令牌"""
|
||||
limits = limits or {}
|
||||
token = CapabilityToken(
|
||||
plugin_id=plugin_id,
|
||||
generation=generation,
|
||||
capabilities=set(capabilities),
|
||||
qps_limit=limits.get("qps", 20),
|
||||
burst_limit=limits.get("burst", 50),
|
||||
daily_token_limit=limits.get("daily_tokens", 200000),
|
||||
max_payload_kb=limits.get("max_payload_kb", 256),
|
||||
)
|
||||
self._tokens[plugin_id] = token
|
||||
return token
|
||||
|
||||
def revoke_plugin(self, plugin_id: str) -> None:
|
||||
"""撤销插件的能力令牌"""
|
||||
self._tokens.pop(plugin_id, None)
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str) -> tuple[bool, str]:
|
||||
"""检查插件是否有权调用某项能力
|
||||
|
||||
Returns:
|
||||
(allowed, reason)
|
||||
"""
|
||||
token = self._tokens.get(plugin_id)
|
||||
if token is None:
|
||||
return False, f"插件 {plugin_id} 未注册能力令牌"
|
||||
|
||||
if capability not in token.capabilities:
|
||||
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
|
||||
|
||||
return True, ""
|
||||
|
||||
def check_rate_limit(self, plugin_id: str) -> tuple[bool, str]:
|
||||
"""检查插件是否超过调用频率限制(滑动窗口)
|
||||
|
||||
Returns:
|
||||
(allowed, reason)
|
||||
"""
|
||||
token = self._tokens.get(plugin_id)
|
||||
if token is None:
|
||||
return False, f"插件 {plugin_id} 未注册"
|
||||
|
||||
now = time.monotonic()
|
||||
elapsed = now - token._window_start
|
||||
|
||||
# 每秒重置窗口
|
||||
if elapsed >= 1.0:
|
||||
token._window_start = now
|
||||
token._window_calls = 0
|
||||
|
||||
token._window_calls += 1
|
||||
|
||||
if token._window_calls > token.burst_limit:
|
||||
return False, f"插件 {plugin_id} 超过突发限制 ({token.burst_limit}/s)"
|
||||
|
||||
return True, ""
|
||||
|
||||
def check_payload_size(self, plugin_id: str, payload_size_bytes: int) -> tuple[bool, str]:
|
||||
"""检查 payload 大小是否在限制内"""
|
||||
token = self._tokens.get(plugin_id)
|
||||
if token is None:
|
||||
return False, f"插件 {plugin_id} 未注册"
|
||||
|
||||
max_bytes = token.max_payload_kb * 1024
|
||||
if payload_size_bytes > max_bytes:
|
||||
return False, f"payload 大小 {payload_size_bytes} 超过限制 {max_bytes}"
|
||||
|
||||
return True, ""
|
||||
|
||||
def get_token(self, plugin_id: str) -> CapabilityToken | None:
|
||||
"""获取插件的能力令牌"""
|
||||
return self._tokens.get(plugin_id)
|
||||
|
||||
def list_plugins(self) -> list[str]:
|
||||
"""列出所有已注册的插件"""
|
||||
return list(self._tokens.keys())
|
||||
357
src/plugin_runtime/host/rpc_server.py
Normal file
357
src/plugin_runtime/host/rpc_server.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""Host 端 RPC Server
|
||||
|
||||
负责:
|
||||
1. 监听 Runner 连接
|
||||
2. 处理握手(runner.hello)
|
||||
3. 分发调用请求给 Runner / 处理 Runner 的能力调用
|
||||
4. 请求-响应关联与超时管理
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from src.plugin_runtime.protocol.codec import Codec, create_codec
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
PROTOCOL_VERSION,
|
||||
MIN_SDK_VERSION,
|
||||
MAX_SDK_VERSION,
|
||||
Envelope,
|
||||
HelloPayload,
|
||||
HelloResponsePayload,
|
||||
MessageType,
|
||||
RequestIdGenerator,
|
||||
)
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
from src.plugin_runtime.transport.base import Connection, TransportServer
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.host.rpc_server")
|
||||
|
||||
# RPC 方法处理器类型
|
||||
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
|
||||
|
||||
|
||||
class RPCServer:
|
||||
"""Host 端 RPC 服务器
|
||||
|
||||
管理与 Runner 的 IPC 连接,处理双向 RPC 调用。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: TransportServer,
|
||||
session_token: str | None = None,
|
||||
codec: Codec | None = None,
|
||||
send_queue_size: int = 128,
|
||||
):
|
||||
self._transport = transport
|
||||
self._session_token = session_token or secrets.token_hex(32)
|
||||
self._codec = codec or create_codec()
|
||||
self._send_queue_size = send_queue_size
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Connection | None = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: str | None = None
|
||||
self._runner_generation: int = 0
|
||||
|
||||
# 方法处理器注册表
|
||||
self._method_handlers: dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: dict[int, asyncio.Future] = {}
|
||||
|
||||
# 发送队列(背压控制)
|
||||
self._send_queue: asyncio.Queue | None = None
|
||||
|
||||
# 运行状态
|
||||
self._running = False
|
||||
self._tasks: list[asyncio.Task] = []
|
||||
|
||||
@property
|
||||
def session_token(self) -> str:
|
||||
return self._session_token
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册 RPC 方法处理器"""
|
||||
self._method_handlers[method] = handler
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 RPC 服务器"""
|
||||
self._running = True
|
||||
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
|
||||
await self._transport.start(self._handle_connection)
|
||||
logger.info(f"RPC Server 已启动,监听地址: {self._transport.get_address()}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 RPC 服务器"""
|
||||
self._running = False
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for req_id, future in self._pending_requests.items():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
|
||||
self._pending_requests.clear()
|
||||
|
||||
# 取消后台任务
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
self._tasks.clear()
|
||||
|
||||
# 关闭连接
|
||||
if self._connection:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
await self._transport.stop()
|
||||
logger.info("RPC Server 已停止")
|
||||
|
||||
async def send_request(
|
||||
self,
|
||||
method: str,
|
||||
plugin_id: str = "",
|
||||
payload: dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""向 Runner 发送 RPC 请求并等待响应
|
||||
|
||||
Args:
|
||||
method: RPC 方法名
|
||||
plugin_id: 目标插件 ID
|
||||
payload: 请求数据
|
||||
timeout_ms: 超时时间(ms)
|
||||
|
||||
Returns:
|
||||
响应 Envelope
|
||||
|
||||
Raises:
|
||||
RPCError: 调用失败
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._runner_generation,
|
||||
timeout_ms=timeout_ms,
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
# 背压检查
|
||||
if self._send_queue and self._send_queue.full():
|
||||
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满")
|
||||
|
||||
# 注册 pending future
|
||||
loop = asyncio.get_event_loop()
|
||||
future: asyncio.Future[Envelope] = loop.create_future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
|
||||
# 等待响应
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
response = await asyncio.wait_for(future, timeout=timeout_sec)
|
||||
return response
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)")
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
if isinstance(e, RPCError):
|
||||
raise
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e))
|
||||
|
||||
async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None:
|
||||
"""向 Runner 发送单向事件(不等待响应)"""
|
||||
if not self.is_connected:
|
||||
return
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.EVENT,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._runner_generation,
|
||||
payload=payload or {},
|
||||
)
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
|
||||
async def _handle_connection(self, conn: Connection) -> None:
|
||||
"""处理新的 Runner 连接"""
|
||||
logger.info("收到 Runner 连接")
|
||||
|
||||
# 第一条消息必须是 runner.hello 握手
|
||||
try:
|
||||
handshake_ok = await self._handle_handshake(conn)
|
||||
if not handshake_ok:
|
||||
await conn.close()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"握手失败: {e}")
|
||||
await conn.close()
|
||||
return
|
||||
|
||||
# 握手成功,保存连接
|
||||
self._connection = conn
|
||||
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
|
||||
|
||||
# 启动消息接收循环
|
||||
try:
|
||||
await self._recv_loop(conn)
|
||||
except Exception as e:
|
||||
logger.error(f"连接异常断开: {e}")
|
||||
finally:
|
||||
self._connection = None
|
||||
self._runner_id = None
|
||||
|
||||
async def _handle_handshake(self, conn: Connection) -> bool:
|
||||
"""处理 runner.hello 握手"""
|
||||
# 接收握手请求
|
||||
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
|
||||
envelope = self._codec.decode_envelope(data)
|
||||
|
||||
if envelope.method != "runner.hello":
|
||||
logger.error(f"期望 runner.hello,收到 {envelope.method}")
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_PROTOCOL_MISMATCH.value,
|
||||
"首条消息必须为 runner.hello",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
return False
|
||||
|
||||
# 解析握手 payload
|
||||
hello = HelloPayload.model_validate(envelope.payload)
|
||||
|
||||
# 校验会话令牌
|
||||
if hello.session_token != self._session_token:
|
||||
logger.error("会话令牌不匹配")
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=False,
|
||||
reason="会话令牌无效",
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
|
||||
# 校验 SDK 版本
|
||||
if not self._check_sdk_version(hello.sdk_version):
|
||||
logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=False,
|
||||
reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]",
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
return False
|
||||
|
||||
# 握手成功
|
||||
self._runner_id = hello.runner_id
|
||||
self._runner_generation += 1
|
||||
|
||||
resp_payload = HelloResponsePayload(
|
||||
accepted=True,
|
||||
host_version=PROTOCOL_VERSION,
|
||||
assigned_generation=self._runner_generation,
|
||||
)
|
||||
resp = envelope.make_response(payload=resp_payload.model_dump())
|
||||
await conn.send_frame(self._codec.encode_envelope(resp))
|
||||
|
||||
return True
|
||||
|
||||
async def _recv_loop(self, conn: Connection) -> None:
|
||||
"""消息接收主循环"""
|
||||
while self._running and not conn.is_closed:
|
||||
try:
|
||||
data = await conn.recv_frame()
|
||||
except (asyncio.IncompleteReadError, ConnectionError):
|
||||
logger.info("Runner 连接已断开")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"接收帧失败: {e}")
|
||||
break
|
||||
|
||||
try:
|
||||
envelope = self._codec.decode_envelope(data)
|
||||
except Exception as e:
|
||||
logger.error(f"解码消息失败: {e}")
|
||||
continue
|
||||
|
||||
# 分发消息
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
# 异步处理请求(Runner 发来的能力调用)
|
||||
task = asyncio.create_task(self._handle_request(envelope, conn))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
elif envelope.is_event():
|
||||
task = asyncio.create_task(self._handle_event(envelope))
|
||||
self._tasks.append(task)
|
||||
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
|
||||
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的响应"""
|
||||
future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if future and not future.done():
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
|
||||
"""处理来自 Runner 的请求(通常是能力调用 cap.*)"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的方法: {envelope.method}",
|
||||
)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
return
|
||||
|
||||
try:
|
||||
response = await handler(envelope)
|
||||
await conn.send_frame(self._codec.encode_envelope(response))
|
||||
except RPCError as e:
|
||||
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
except Exception as e:
|
||||
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
await conn.send_frame(self._codec.encode_envelope(error_resp))
|
||||
|
||||
async def _handle_event(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Runner 的事件"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler:
|
||||
try:
|
||||
await handler(envelope)
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
|
||||
def _check_sdk_version(sdk_version: str) -> bool:
|
||||
"""检查 SDK 版本是否在支持范围内"""
|
||||
try:
|
||||
sdk_parts = [int(x) for x in sdk_version.split(".")]
|
||||
min_parts = [int(x) for x in MIN_SDK_VERSION.split(".")]
|
||||
max_parts = [int(x) for x in MAX_SDK_VERSION.split(".")]
|
||||
return min_parts <= sdk_parts <= max_parts
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
315
src/plugin_runtime/host/supervisor.py
Normal file
315
src/plugin_runtime/host/supervisor.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Supervisor - 插件生命周期管理
|
||||
|
||||
负责:
|
||||
1. 拉起 Runner 子进程
|
||||
2. 健康检查
|
||||
3. 熔断与恢复
|
||||
4. 代码热重载(generation 切换)
|
||||
5. 优雅关停
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from src.plugin_runtime.host.capability_service import CapabilityService
|
||||
from src.plugin_runtime.host.circuit_breaker import CircuitBreakerRegistry
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
from src.plugin_runtime.host.rpc_server import RPCServer
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
Envelope,
|
||||
HealthPayload,
|
||||
RegisterComponentsPayload,
|
||||
ShutdownPayload,
|
||||
)
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
from src.plugin_runtime.transport.factory import create_transport_server
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.host.supervisor")
|
||||
|
||||
|
||||
class PluginSupervisor:
|
||||
"""插件 Supervisor
|
||||
|
||||
Host 端的核心管理器,负责整个插件 Runner 进程的生命周期。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugin_dirs: list[str] | None = None,
|
||||
socket_path: str | None = None,
|
||||
health_check_interval_sec: float = 30.0,
|
||||
use_json_codec: bool = False,
|
||||
):
|
||||
self._plugin_dirs = plugin_dirs or []
|
||||
self._health_interval = health_check_interval_sec
|
||||
|
||||
# 基础设施
|
||||
self._transport = create_transport_server(socket_path=socket_path)
|
||||
self._policy = PolicyEngine()
|
||||
self._breakers = CircuitBreakerRegistry()
|
||||
self._capability_service = CapabilityService(self._policy)
|
||||
|
||||
# 编解码
|
||||
from src.plugin_runtime.protocol.codec import create_codec
|
||||
codec = create_codec(use_json=use_json_codec)
|
||||
|
||||
self._rpc_server = RPCServer(
|
||||
transport=self._transport,
|
||||
codec=codec,
|
||||
)
|
||||
|
||||
# Runner 子进程
|
||||
self._runner_process: asyncio.subprocess.Process | None = None
|
||||
self._runner_generation: int = 0
|
||||
|
||||
# 已注册的插件组件信息
|
||||
self._registered_plugins: dict[str, RegisterComponentsPayload] = {}
|
||||
|
||||
# 后台任务
|
||||
self._health_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
|
||||
# 注册内部 RPC 方法
|
||||
self._register_internal_methods()
|
||||
|
||||
@property
|
||||
def policy_engine(self) -> PolicyEngine:
|
||||
return self._policy
|
||||
|
||||
@property
|
||||
def capability_service(self) -> CapabilityService:
|
||||
return self._capability_service
|
||||
|
||||
@property
|
||||
def rpc_server(self) -> RPCServer:
|
||||
return self._rpc_server
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 Supervisor
|
||||
|
||||
1. 启动 RPC Server
|
||||
2. 拉起 Runner 子进程
|
||||
3. 启动健康检查
|
||||
"""
|
||||
self._running = True
|
||||
|
||||
# 启动 RPC Server
|
||||
await self._rpc_server.start()
|
||||
|
||||
# 拉起 Runner 进程
|
||||
await self._spawn_runner()
|
||||
|
||||
# 启动健康检查
|
||||
self._health_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
logger.info("PluginSupervisor 已启动")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止 Supervisor"""
|
||||
self._running = False
|
||||
|
||||
# 停止健康检查
|
||||
if self._health_task:
|
||||
self._health_task.cancel()
|
||||
self._health_task = None
|
||||
|
||||
# 优雅关停 Runner
|
||||
await self._shutdown_runner()
|
||||
|
||||
# 停止 RPC Server
|
||||
await self._rpc_server.stop()
|
||||
|
||||
logger.info("PluginSupervisor 已停止")
|
||||
|
||||
async def invoke_plugin(
|
||||
self,
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""调用插件组件
|
||||
|
||||
由主进程业务逻辑调用,通过 RPC 转发给 Runner。
|
||||
"""
|
||||
# 熔断检查
|
||||
breaker = self._breakers.get(plugin_id)
|
||||
if not breaker.allow_request():
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, f"插件 {plugin_id} 已被熔断")
|
||||
|
||||
try:
|
||||
response = await self._rpc_server.send_request(
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
payload={
|
||||
"component_name": component_name,
|
||||
"args": args or {},
|
||||
},
|
||||
timeout_ms=timeout_ms,
|
||||
)
|
||||
breaker.record_success()
|
||||
return response
|
||||
except RPCError:
|
||||
breaker.record_failure()
|
||||
raise
|
||||
|
||||
async def reload_plugins(self, reason: str = "manual") -> None:
|
||||
"""热重载所有插件(进程级 generation 切换)
|
||||
|
||||
1. 拉起新 Runner
|
||||
2. 等待新 Runner 完成注册和健康检查
|
||||
3. 关停旧 Runner
|
||||
"""
|
||||
logger.info(f"开始热重载插件,原因: {reason}")
|
||||
|
||||
# 保存旧进程引用
|
||||
old_process = self._runner_process
|
||||
|
||||
# 拉起新 Runner
|
||||
await self._spawn_runner()
|
||||
|
||||
# 等待新 Runner 连接并完成握手
|
||||
for _ in range(30): # 最多等待 30 秒
|
||||
if self._rpc_server.is_connected:
|
||||
break
|
||||
await asyncio.sleep(1.0)
|
||||
else:
|
||||
logger.error("新 Runner 连接超时,回滚")
|
||||
# 回滚:终止新进程
|
||||
if self._runner_process and self._runner_process != old_process:
|
||||
self._runner_process.terminate()
|
||||
self._runner_process = old_process
|
||||
return
|
||||
|
||||
# 健康检查
|
||||
try:
|
||||
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
|
||||
health = HealthPayload.model_validate(resp.payload)
|
||||
if not health.healthy:
|
||||
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "新 Runner 健康检查失败")
|
||||
except Exception as e:
|
||||
logger.error(f"新 Runner 健康检查失败: {e},回滚")
|
||||
if self._runner_process and self._runner_process != old_process:
|
||||
self._runner_process.terminate()
|
||||
self._runner_process = old_process
|
||||
return
|
||||
|
||||
# 关停旧 Runner
|
||||
if old_process and old_process.returncode is None:
|
||||
try:
|
||||
old_process.terminate()
|
||||
await asyncio.wait_for(old_process.wait(), timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
old_process.kill()
|
||||
|
||||
logger.info("热重载完成")
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
|
||||
def _register_internal_methods(self) -> None:
|
||||
"""注册 Host 端的 RPC 方法处理器"""
|
||||
# Runner -> Host 的能力调用统一走 capability_service
|
||||
self._rpc_server.register_method("cap.request", self._capability_service.handle_capability_request)
|
||||
# 插件注册
|
||||
self._rpc_server.register_method("plugin.register_components", self._handle_register_components)
|
||||
|
||||
async def _handle_register_components(self, envelope: Envelope) -> Envelope:
|
||||
"""处理插件组件注册请求"""
|
||||
try:
|
||||
reg = RegisterComponentsPayload.model_validate(envelope.payload)
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
||||
|
||||
# 记录注册信息
|
||||
self._registered_plugins[reg.plugin_id] = reg
|
||||
|
||||
# 在策略引擎中注册插件
|
||||
self._policy.register_plugin(
|
||||
plugin_id=reg.plugin_id,
|
||||
generation=envelope.generation,
|
||||
capabilities=reg.capabilities_required,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"插件 {reg.plugin_id} v{reg.plugin_version} 注册成功,"
|
||||
f"组件数: {len(reg.components)}, 能力需求: {reg.capabilities_required}"
|
||||
)
|
||||
|
||||
return envelope.make_response(payload={"accepted": True})
|
||||
|
||||
async def _spawn_runner(self) -> None:
|
||||
"""拉起 Runner 子进程"""
|
||||
runner_module = "src.plugin_runtime.runner.runner_main"
|
||||
address = self._transport.get_address()
|
||||
token = self._rpc_server.session_token
|
||||
|
||||
env = os.environ.copy()
|
||||
env["MAIBOT_IPC_ADDRESS"] = address
|
||||
env["MAIBOT_SESSION_TOKEN"] = token
|
||||
env["MAIBOT_PLUGIN_DIRS"] = os.pathsep.join(self._plugin_dirs)
|
||||
|
||||
self._runner_process = await asyncio.create_subprocess_exec(
|
||||
sys.executable, "-m", runner_module,
|
||||
env=env,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
self._runner_generation += 1
|
||||
logger.info(f"Runner 子进程已启动: pid={self._runner_process.pid}, generation={self._runner_generation}")
|
||||
|
||||
async def _shutdown_runner(self) -> None:
|
||||
"""优雅关停 Runner"""
|
||||
if not self._runner_process or self._runner_process.returncode is not None:
|
||||
return
|
||||
|
||||
# 发送 prepare_shutdown
|
||||
try:
|
||||
if self._rpc_server.is_connected:
|
||||
shutdown_payload = ShutdownPayload(reason="host_shutdown", drain_timeout_ms=5000)
|
||||
await self._rpc_server.send_request(
|
||||
"plugin.prepare_shutdown",
|
||||
payload=shutdown_payload.model_dump(),
|
||||
timeout_ms=5000,
|
||||
)
|
||||
await self._rpc_server.send_request(
|
||||
"plugin.shutdown",
|
||||
payload=shutdown_payload.model_dump(),
|
||||
timeout_ms=5000,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"发送关停命令失败: {e}")
|
||||
|
||||
# 等待进程退出
|
||||
try:
|
||||
await asyncio.wait_for(self._runner_process.wait(), timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Runner 未在超时内退出,强制终止")
|
||||
self._runner_process.kill()
|
||||
await self._runner_process.wait()
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""周期性健康检查"""
|
||||
while self._running:
|
||||
await asyncio.sleep(self._health_interval)
|
||||
|
||||
if not self._rpc_server.is_connected:
|
||||
logger.warning("Runner 未连接,跳过健康检查")
|
||||
continue
|
||||
|
||||
try:
|
||||
resp = await self._rpc_server.send_request("plugin.health", timeout_ms=5000)
|
||||
health = HealthPayload.model_validate(resp.payload)
|
||||
if not health.healthy:
|
||||
logger.warning(f"Runner 健康检查异常: {health}")
|
||||
except RPCError as e:
|
||||
logger.error(f"健康检查失败: {e}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查异常: {e}")
|
||||
1
src/plugin_runtime/protocol/__init__.py
Normal file
1
src/plugin_runtime/protocol/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Protocol 层 - RPC 消息模型、编解码、错误码
|
||||
80
src/plugin_runtime/protocol/codec.py
Normal file
80
src/plugin_runtime/protocol/codec.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""MsgPack / JSON 编解码器
|
||||
|
||||
提供统一的消息编解码接口,生产环境默认使用 MsgPack,
|
||||
开发调试模式可切换为 JSON(仅编解码切换,传输层不变)。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
|
||||
import msgpack
|
||||
|
||||
from .envelope import Envelope
|
||||
|
||||
|
||||
class Codec:
|
||||
"""消息编解码器基类"""
|
||||
|
||||
def encode_envelope(self, envelope: Envelope) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
def decode_envelope(self, data: bytes) -> Envelope:
|
||||
raise NotImplementedError
|
||||
|
||||
def encode(self, obj: dict[str, Any]) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
def decode(self, data: bytes) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MsgPackCodec(Codec):
|
||||
"""MsgPack 编解码器(生产默认)"""
|
||||
|
||||
def encode(self, obj: dict[str, Any]) -> bytes:
|
||||
return msgpack.packb(obj, use_bin_type=True)
|
||||
|
||||
def decode(self, data: bytes) -> dict[str, Any]:
|
||||
result = msgpack.unpackb(data, raw=False)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError(f"期望解码为 dict,实际为 {type(result)}")
|
||||
return result
|
||||
|
||||
def encode_envelope(self, envelope: Envelope) -> bytes:
|
||||
return self.encode(envelope.model_dump())
|
||||
|
||||
def decode_envelope(self, data: bytes) -> Envelope:
|
||||
raw = self.decode(data)
|
||||
return Envelope.model_validate(raw)
|
||||
|
||||
|
||||
class JsonCodec(Codec):
|
||||
"""JSON 编解码器(开发调试用)"""
|
||||
|
||||
def encode(self, obj: dict[str, Any]) -> bytes:
|
||||
return json.dumps(obj, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
def decode(self, data: bytes) -> dict[str, Any]:
|
||||
result = json.loads(data.decode("utf-8"))
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError(f"期望解码为 dict,实际为 {type(result)}")
|
||||
return result
|
||||
|
||||
def encode_envelope(self, envelope: Envelope) -> bytes:
|
||||
return self.encode(envelope.model_dump())
|
||||
|
||||
def decode_envelope(self, data: bytes) -> Envelope:
|
||||
raw = self.decode(data)
|
||||
return Envelope.model_validate(raw)
|
||||
|
||||
|
||||
def create_codec(use_json: bool = False) -> Codec:
|
||||
"""创建编解码器实例
|
||||
|
||||
Args:
|
||||
use_json: 是否使用 JSON(开发模式)。默认使用 MsgPack。
|
||||
"""
|
||||
if use_json:
|
||||
return JsonCodec()
|
||||
return MsgPackCodec()
|
||||
187
src/plugin_runtime/protocol/envelope.py
Normal file
187
src/plugin_runtime/protocol/envelope.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""RPC Envelope 消息模型
|
||||
|
||||
定义 Host 与 Runner 之间所有 RPC 消息的统一信封格式。
|
||||
使用 Pydantic 进行 schema 定义与校验。
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import time
|
||||
|
||||
|
||||
# ─── 协议常量 ──────────────────────────────────────────────────────
|
||||
|
||||
PROTOCOL_VERSION = "1.0"
|
||||
|
||||
# 支持的 SDK 版本范围(Host 在握手时校验)
|
||||
MIN_SDK_VERSION = "1.0.0"
|
||||
MAX_SDK_VERSION = "1.99.99"
|
||||
|
||||
|
||||
# ─── 消息类型 ──────────────────────────────────────────────────────
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""RPC 消息类型"""
|
||||
REQUEST = "request"
|
||||
RESPONSE = "response"
|
||||
EVENT = "event"
|
||||
|
||||
|
||||
# ─── 请求 ID 生成器 ───────────────────────────────────────────────
|
||||
|
||||
class RequestIdGenerator:
|
||||
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio)"""
|
||||
|
||||
def __init__(self, start: int = 1):
|
||||
self._counter = start
|
||||
|
||||
def next(self) -> int:
|
||||
current = self._counter
|
||||
self._counter += 1
|
||||
return current
|
||||
|
||||
|
||||
# ─── Envelope 模型 ─────────────────────────────────────────────────
|
||||
|
||||
class Envelope(BaseModel):
|
||||
"""RPC 统一信封
|
||||
|
||||
所有 Host <-> Runner 消息均封装为此格式。
|
||||
序列化流程:Envelope -> .model_dump() -> MsgPack encode
|
||||
反序列化流程:MsgPack decode -> Envelope.model_validate(data)
|
||||
"""
|
||||
|
||||
protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本")
|
||||
request_id: int = Field(description="单调递增请求 ID")
|
||||
message_type: MessageType = Field(description="消息类型")
|
||||
method: str = Field(default="", description="RPC 方法名")
|
||||
plugin_id: str = Field(default="", description="目标插件 ID")
|
||||
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
|
||||
timeout_ms: int = Field(default=30000, description="相对超时(ms)")
|
||||
generation: int = Field(default=0, description="Runner generation 编号")
|
||||
payload: dict[str, Any] = Field(default_factory=dict, description="业务数据")
|
||||
error: dict[str, Any] | None = Field(default=None, description="错误信息(仅 response)")
|
||||
|
||||
def is_request(self) -> bool:
|
||||
return self.message_type == MessageType.REQUEST
|
||||
|
||||
def is_response(self) -> bool:
|
||||
return self.message_type == MessageType.RESPONSE
|
||||
|
||||
def is_event(self) -> bool:
|
||||
return self.message_type == MessageType.EVENT
|
||||
|
||||
def make_response(self, payload: dict[str, Any] | None = None, error: dict[str, Any] | None = None) -> "Envelope":
|
||||
"""基于当前请求创建对应的响应信封"""
|
||||
return Envelope(
|
||||
protocol_version=self.protocol_version,
|
||||
request_id=self.request_id,
|
||||
message_type=MessageType.RESPONSE,
|
||||
method=self.method,
|
||||
plugin_id=self.plugin_id,
|
||||
generation=self.generation,
|
||||
payload=payload or {},
|
||||
error=error,
|
||||
)
|
||||
|
||||
def make_error_response(self, code: str, message: str = "", details: dict | None = None) -> "Envelope":
|
||||
"""基于当前请求创建错误响应"""
|
||||
return self.make_response(
|
||||
error={
|
||||
"code": code,
|
||||
"message": message,
|
||||
"details": details or {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ─── 握手消息 ──────────────────────────────────────────────────────
|
||||
|
||||
class HelloPayload(BaseModel):
|
||||
"""runner.hello 握手请求 payload"""
|
||||
runner_id: str = Field(description="Runner 进程唯一标识")
|
||||
sdk_version: str = Field(description="SDK 版本号")
|
||||
session_token: str = Field(description="一次性会话令牌")
|
||||
|
||||
|
||||
class HelloResponsePayload(BaseModel):
|
||||
"""runner.hello 握手响应 payload"""
|
||||
accepted: bool = Field(description="是否接受连接")
|
||||
host_version: str = Field(default="", description="Host 版本号")
|
||||
assigned_generation: int = Field(default=0, description="分配的 generation 编号")
|
||||
reason: str = Field(default="", description="拒绝原因(若 accepted=False)")
|
||||
|
||||
|
||||
# ─── 组件注册消息 ──────────────────────────────────────────────────
|
||||
|
||||
class ComponentDeclaration(BaseModel):
|
||||
"""单个组件声明"""
|
||||
name: str = Field(description="组件名称")
|
||||
component_type: str = Field(description="组件类型: action/command/tool/event_handler")
|
||||
plugin_id: str = Field(description="所属插件 ID")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="组件元数据")
|
||||
|
||||
|
||||
class RegisterComponentsPayload(BaseModel):
|
||||
"""plugin.register_components 请求 payload"""
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
plugin_version: str = Field(default="1.0.0", description="插件版本")
|
||||
components: list[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
||||
capabilities_required: list[str] = Field(default_factory=list, description="所需能力列表")
|
||||
|
||||
|
||||
# ─── 调用消息 ──────────────────────────────────────────────────────
|
||||
|
||||
class InvokePayload(BaseModel):
|
||||
"""plugin.invoke_* 请求 payload"""
|
||||
component_name: str = Field(description="要调用的组件名称")
|
||||
args: dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
|
||||
|
||||
class InvokeResultPayload(BaseModel):
|
||||
"""plugin.invoke_* 响应 payload"""
|
||||
success: bool = Field(description="是否成功")
|
||||
result: Any = Field(default=None, description="返回值")
|
||||
|
||||
|
||||
# ─── 能力调用消息 ──────────────────────────────────────────────────
|
||||
|
||||
class CapabilityRequestPayload(BaseModel):
|
||||
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
|
||||
capability: str = Field(description="能力名称,如 send.text, db.query")
|
||||
args: dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
|
||||
|
||||
class CapabilityResponsePayload(BaseModel):
|
||||
"""cap.* 响应 payload"""
|
||||
success: bool = Field(description="是否成功")
|
||||
result: Any = Field(default=None, description="返回值")
|
||||
|
||||
|
||||
# ─── 健康检查 ──────────────────────────────────────────────────────
|
||||
|
||||
class HealthPayload(BaseModel):
|
||||
"""plugin.health 响应 payload"""
|
||||
healthy: bool = Field(description="是否健康")
|
||||
loaded_plugins: list[str] = Field(default_factory=list, description="已加载的插件列表")
|
||||
uptime_ms: int = Field(default=0, description="运行时长(ms)")
|
||||
|
||||
|
||||
# ─── 配置更新 ──────────────────────────────────────────────────────
|
||||
|
||||
class ConfigUpdatedPayload(BaseModel):
|
||||
"""plugin.config_updated 事件 payload"""
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
config_version: str = Field(description="新配置版本")
|
||||
config_data: dict[str, Any] = Field(default_factory=dict, description="配置内容")
|
||||
|
||||
|
||||
# ─── 关停 ──────────────────────────────────────────────────────────
|
||||
|
||||
class ShutdownPayload(BaseModel):
|
||||
"""plugin.shutdown / plugin.prepare_shutdown payload"""
|
||||
reason: str = Field(default="normal", description="关停原因")
|
||||
drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
|
||||
61
src/plugin_runtime/protocol/errors.py
Normal file
61
src/plugin_runtime/protocol/errors.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""RPC 错误码定义
|
||||
|
||||
所有 Host 与 Runner 之间的 RPC 通信使用统一的错误码体系。
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
|
||||
"""RPC 错误码枚举"""
|
||||
|
||||
# 通用
|
||||
OK = "OK"
|
||||
E_UNKNOWN = "E_UNKNOWN"
|
||||
|
||||
# 协议层
|
||||
E_TIMEOUT = "E_TIMEOUT"
|
||||
E_BAD_PAYLOAD = "E_BAD_PAYLOAD"
|
||||
E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH"
|
||||
|
||||
# 权限与策略
|
||||
E_UNAUTHORIZED = "E_UNAUTHORIZED"
|
||||
E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED"
|
||||
E_BACKPRESSURE = "E_BACKPRESSURE"
|
||||
E_HOST_OVERLOADED = "E_HOST_OVERLOADED"
|
||||
|
||||
# 插件生命周期
|
||||
E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED"
|
||||
E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND"
|
||||
E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH"
|
||||
E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS"
|
||||
|
||||
# 能力调用
|
||||
E_CAPABILITY_DENIED = "E_CAPABILITY_DENIED"
|
||||
E_CAPABILITY_FAILED = "E_CAPABILITY_FAILED"
|
||||
|
||||
|
||||
class RPCError(Exception):
|
||||
"""RPC 调用异常"""
|
||||
|
||||
def __init__(self, code: ErrorCode, message: str = "", details: dict | None = None):
|
||||
self.code = code
|
||||
self.message = message or code.value
|
||||
self.details = details or {}
|
||||
super().__init__(f"[{code.value}] {self.message}")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"code": self.code.value,
|
||||
"message": self.message,
|
||||
"details": self.details,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "RPCError":
|
||||
code = ErrorCode(data.get("code", "E_UNKNOWN"))
|
||||
return cls(
|
||||
code=code,
|
||||
message=data.get("message", ""),
|
||||
details=data.get("details", {}),
|
||||
)
|
||||
1
src/plugin_runtime/runner/__init__.py
Normal file
1
src/plugin_runtime/runner/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Runner 端 - 插件加载与执行进程
|
||||
127
src/plugin_runtime/runner/plugin_loader.py
Normal file
127
src/plugin_runtime/runner/plugin_loader.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""插件加载器
|
||||
|
||||
在 Runner 进程中负责发现和加载插件。
|
||||
插件通过 SDK 编写,不再 import src.*。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.runner.plugin_loader")
|
||||
|
||||
|
||||
class PluginMeta:
|
||||
"""加载后的插件元数据"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugin_id: str,
|
||||
plugin_dir: str,
|
||||
plugin_instance: Any,
|
||||
manifest: dict[str, Any],
|
||||
):
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_dir = plugin_dir
|
||||
self.instance = plugin_instance
|
||||
self.manifest = manifest
|
||||
self.version = manifest.get("version", "1.0.0")
|
||||
self.capabilities_required = manifest.get("capabilities", [])
|
||||
|
||||
|
||||
class PluginLoader:
|
||||
"""插件加载器
|
||||
|
||||
扫描插件目录,加载符合 SDK 规范的插件。
|
||||
每个插件目录须包含:
|
||||
- _manifest.json: 插件元数据
|
||||
- plugin.py: 插件入口模块(导出 create_plugin 工厂函数)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._loaded_plugins: dict[str, PluginMeta] = {}
|
||||
|
||||
def discover_and_load(self, plugin_dirs: list[str]) -> list[PluginMeta]:
|
||||
"""扫描多个目录并加载所有插件
|
||||
|
||||
Args:
|
||||
plugin_dirs: 插件目录列表
|
||||
|
||||
Returns:
|
||||
成功加载的插件元数据列表
|
||||
"""
|
||||
results = []
|
||||
for base_dir in plugin_dirs:
|
||||
if not os.path.isdir(base_dir):
|
||||
logger.warning(f"插件目录不存在: {base_dir}")
|
||||
continue
|
||||
|
||||
for entry in os.listdir(base_dir):
|
||||
plugin_dir = os.path.join(base_dir, entry)
|
||||
if not os.path.isdir(plugin_dir):
|
||||
continue
|
||||
|
||||
manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
||||
plugin_path = os.path.join(plugin_dir, "plugin.py")
|
||||
|
||||
if not os.path.exists(manifest_path) or not os.path.exists(plugin_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
meta = self._load_single_plugin(plugin_dir, manifest_path, plugin_path)
|
||||
if meta:
|
||||
self._loaded_plugins[meta.plugin_id] = meta
|
||||
results.append(meta)
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件失败 [{plugin_dir}]: {e}", exc_info=True)
|
||||
|
||||
return results
|
||||
|
||||
def get_plugin(self, plugin_id: str) -> PluginMeta | None:
|
||||
"""获取已加载的插件"""
|
||||
return self._loaded_plugins.get(plugin_id)
|
||||
|
||||
def list_plugins(self) -> list[str]:
|
||||
"""列出所有已加载的插件 ID"""
|
||||
return list(self._loaded_plugins.keys())
|
||||
|
||||
def _load_single_plugin(self, plugin_dir: str, manifest_path: str, plugin_path: str) -> PluginMeta | None:
|
||||
"""加载单个插件"""
|
||||
# 1. 读取 manifest
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
|
||||
plugin_id = os.path.basename(plugin_dir)
|
||||
|
||||
# 2. 动态导入插件模块
|
||||
module_name = f"_maibot_plugin_{plugin_id}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.error(f"无法创建模块 spec: {plugin_path}")
|
||||
return None
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# 3. 调用工厂函数创建插件实例
|
||||
create_plugin = getattr(module, "create_plugin", None)
|
||||
if create_plugin is None:
|
||||
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数")
|
||||
return None
|
||||
|
||||
instance = create_plugin()
|
||||
|
||||
logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
|
||||
|
||||
return PluginMeta(
|
||||
plugin_id=plugin_id,
|
||||
plugin_dir=plugin_dir,
|
||||
plugin_instance=instance,
|
||||
manifest=manifest,
|
||||
)
|
||||
257
src/plugin_runtime/runner/rpc_client.py
Normal file
257
src/plugin_runtime/runner/rpc_client.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Runner 端 RPC Client
|
||||
|
||||
负责:
|
||||
1. 连接 Host RPC Server
|
||||
2. 发送握手(runner.hello)
|
||||
3. 发送组件注册请求
|
||||
4. 接收并分发 Host 的调用请求
|
||||
5. 发送能力调用请求到 Host
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from src.plugin_runtime.protocol.codec import Codec, create_codec
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
PROTOCOL_VERSION,
|
||||
Envelope,
|
||||
HelloPayload,
|
||||
HelloResponsePayload,
|
||||
MessageType,
|
||||
RequestIdGenerator,
|
||||
)
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
from src.plugin_runtime.transport.base import Connection
|
||||
from src.plugin_runtime.transport.factory import create_transport_client
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.runner.rpc_client")
|
||||
|
||||
# RPC 方法处理器类型
|
||||
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
|
||||
|
||||
SDK_VERSION = "1.0.0"
|
||||
|
||||
|
||||
class RPCClient:
|
||||
"""Runner 端 RPC 客户端
|
||||
|
||||
管理与 Host 的 IPC 连接,支持双向 RPC 调用。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_address: str,
|
||||
session_token: str,
|
||||
codec: Codec | None = None,
|
||||
):
|
||||
self._host_address = host_address
|
||||
self._session_token = session_token
|
||||
self._codec = codec or create_codec()
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Connection | None = None
|
||||
self._runner_id = str(uuid.uuid4())
|
||||
self._generation: int = 0
|
||||
|
||||
# 方法处理器注册表(Host 发来的调用)
|
||||
self._method_handlers: dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: dict[int, asyncio.Future] = {}
|
||||
|
||||
# 运行状态
|
||||
self._running = False
|
||||
self._recv_task: asyncio.Task | None = None
|
||||
|
||||
@property
|
||||
def generation(self) -> int:
|
||||
return self._generation
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connection is not None and not self._connection.is_closed
|
||||
|
||||
def register_method(self, method: str, handler: MethodHandler) -> None:
|
||||
"""注册方法处理器(处理 Host 发来的请求)"""
|
||||
self._method_handlers[method] = handler
|
||||
|
||||
async def connect_and_handshake(self) -> bool:
|
||||
"""连接 Host 并完成握手
|
||||
|
||||
Returns:
|
||||
是否握手成功
|
||||
"""
|
||||
client = create_transport_client(self._host_address)
|
||||
self._connection = await client.connect()
|
||||
|
||||
# 发送 runner.hello
|
||||
hello = HelloPayload(
|
||||
runner_id=self._runner_id,
|
||||
sdk_version=SDK_VERSION,
|
||||
session_token=self._session_token,
|
||||
)
|
||||
request_id = self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
method="runner.hello",
|
||||
payload=hello.model_dump(),
|
||||
)
|
||||
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
|
||||
# 接收握手响应
|
||||
resp_data = await asyncio.wait_for(self._connection.recv_frame(), timeout=10.0)
|
||||
resp = self._codec.decode_envelope(resp_data)
|
||||
|
||||
resp_payload = HelloResponsePayload.model_validate(resp.payload)
|
||||
if not resp_payload.accepted:
|
||||
logger.error(f"握手被拒绝: {resp_payload.reason}")
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
return False
|
||||
|
||||
self._generation = resp_payload.assigned_generation
|
||||
logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}")
|
||||
|
||||
# 启动消息接收循环
|
||||
self._running = True
|
||||
self._recv_task = asyncio.create_task(self._recv_loop())
|
||||
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""断开连接"""
|
||||
self._running = False
|
||||
if self._recv_task:
|
||||
self._recv_task.cancel()
|
||||
try:
|
||||
await self._recv_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._recv_task = None
|
||||
|
||||
# 取消所有 pending 请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭"))
|
||||
self._pending_requests.clear()
|
||||
|
||||
if self._connection:
|
||||
await self._connection.close()
|
||||
self._connection = None
|
||||
|
||||
async def send_request(
|
||||
self,
|
||||
method: str,
|
||||
plugin_id: str = "",
|
||||
payload: dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""向 Host 发送 RPC 请求并等待响应"""
|
||||
if not self.is_connected:
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
|
||||
|
||||
request_id = self._id_gen.next()
|
||||
envelope = Envelope(
|
||||
request_id=request_id,
|
||||
message_type=MessageType.REQUEST,
|
||||
method=method,
|
||||
plugin_id=plugin_id,
|
||||
generation=self._generation,
|
||||
timeout_ms=timeout_ms,
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
future: asyncio.Future[Envelope] = loop.create_future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
data = self._codec.encode_envelope(envelope)
|
||||
await self._connection.send_frame(data)
|
||||
|
||||
timeout_sec = timeout_ms / 1000.0
|
||||
response = await asyncio.wait_for(future, timeout=timeout_sec)
|
||||
return response
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)")
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
if isinstance(e, RPCError):
|
||||
raise
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e))
|
||||
|
||||
# ─── 内部方法 ──────────────────────────────────────────────
|
||||
|
||||
async def _recv_loop(self) -> None:
|
||||
"""消息接收主循环"""
|
||||
while self._running and self._connection and not self._connection.is_closed:
|
||||
try:
|
||||
data = await self._connection.recv_frame()
|
||||
except (asyncio.IncompleteReadError, ConnectionError):
|
||||
logger.info("Host 连接已断开")
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"接收帧失败: {e}")
|
||||
break
|
||||
|
||||
try:
|
||||
envelope = self._codec.decode_envelope(data)
|
||||
except Exception as e:
|
||||
logger.error(f"解码消息失败: {e}")
|
||||
continue
|
||||
|
||||
if envelope.is_response():
|
||||
self._handle_response(envelope)
|
||||
elif envelope.is_request():
|
||||
asyncio.create_task(self._handle_request(envelope))
|
||||
elif envelope.is_event():
|
||||
asyncio.create_task(self._handle_event(envelope))
|
||||
|
||||
def _handle_response(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的响应"""
|
||||
future = self._pending_requests.pop(envelope.request_id, None)
|
||||
if future and not future.done():
|
||||
if envelope.error:
|
||||
future.set_exception(RPCError.from_dict(envelope.error))
|
||||
else:
|
||||
future.set_result(envelope)
|
||||
|
||||
async def _handle_request(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的请求(调用插件组件)"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler is None:
|
||||
error_resp = envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"未注册的方法: {envelope.method}",
|
||||
)
|
||||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
return
|
||||
|
||||
try:
|
||||
response = await handler(envelope)
|
||||
await self._connection.send_frame(self._codec.encode_envelope(response))
|
||||
except RPCError as e:
|
||||
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
|
||||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
except Exception as e:
|
||||
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
|
||||
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
await self._connection.send_frame(self._codec.encode_envelope(error_resp))
|
||||
|
||||
async def _handle_event(self, envelope: Envelope) -> None:
|
||||
"""处理来自 Host 的事件"""
|
||||
handler = self._method_handlers.get(envelope.method)
|
||||
if handler:
|
||||
try:
|
||||
await handler(envelope)
|
||||
except Exception as e:
|
||||
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
|
||||
246
src/plugin_runtime/runner/runner_main.py
Normal file
246
src/plugin_runtime/runner/runner_main.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""Runner 主循环
|
||||
|
||||
作为独立子进程运行,负责:
|
||||
1. 从环境变量读取 IPC 地址和会话令牌
|
||||
2. 连接 Host 并完成握手
|
||||
3. 加载所有插件
|
||||
4. 注册组件到 Host
|
||||
5. 处理 Host 的调用请求
|
||||
6. 转发插件的能力调用到 Host
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
ComponentDeclaration,
|
||||
Envelope,
|
||||
HealthPayload,
|
||||
InvokePayload,
|
||||
InvokeResultPayload,
|
||||
RegisterComponentsPayload,
|
||||
ShutdownPayload,
|
||||
)
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode
|
||||
from src.plugin_runtime.runner.plugin_loader import PluginLoader
|
||||
from src.plugin_runtime.runner.rpc_client import RPCClient
|
||||
|
||||
logger = logging.getLogger("plugin_runtime.runner.main")
|
||||
|
||||
|
||||
class PluginRunner:
|
||||
"""插件 Runner
|
||||
|
||||
运行在独立子进程中,管理所有插件的执行。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host_address: str,
|
||||
session_token: str,
|
||||
plugin_dirs: list[str],
|
||||
):
|
||||
self._host_address = host_address
|
||||
self._session_token = session_token
|
||||
self._plugin_dirs = plugin_dirs
|
||||
|
||||
self._rpc_client = RPCClient(host_address, session_token)
|
||||
self._loader = PluginLoader()
|
||||
self._start_time = time.monotonic()
|
||||
self._shutting_down = False
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Runner 主入口"""
|
||||
# 1. 连接 Host
|
||||
logger.info(f"Runner 启动,连接 Host: {self._host_address}")
|
||||
ok = await self._rpc_client.connect_and_handshake()
|
||||
if not ok:
|
||||
logger.error("握手失败,退出")
|
||||
return
|
||||
|
||||
# 2. 注册方法处理器
|
||||
self._register_handlers()
|
||||
|
||||
# 3. 加载插件
|
||||
plugins = self._loader.discover_and_load(self._plugin_dirs)
|
||||
logger.info(f"已加载 {len(plugins)} 个插件")
|
||||
|
||||
# 4. 向 Host 注册所有插件的组件
|
||||
for meta in plugins:
|
||||
await self._register_plugin(meta)
|
||||
|
||||
# 5. 等待直到收到关停信号
|
||||
try:
|
||||
while not self._shutting_down:
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 6. 断开连接
|
||||
await self._rpc_client.disconnect()
|
||||
logger.info("Runner 已退出")
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""注册方法处理器"""
|
||||
self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.emit_event", self._handle_invoke)
|
||||
self._rpc_client.register_method("plugin.health", self._handle_health)
|
||||
self._rpc_client.register_method("plugin.prepare_shutdown", self._handle_prepare_shutdown)
|
||||
self._rpc_client.register_method("plugin.shutdown", self._handle_shutdown)
|
||||
self._rpc_client.register_method("plugin.config_updated", self._handle_config_updated)
|
||||
|
||||
async def _register_plugin(self, meta) -> None:
|
||||
"""向 Host 注册单个插件"""
|
||||
# 收集插件组件声明
|
||||
components = []
|
||||
instance = meta.instance
|
||||
|
||||
# 从插件实例获取组件声明(SDK 插件须实现 get_components 方法)
|
||||
if hasattr(instance, "get_components"):
|
||||
for comp_info in instance.get_components():
|
||||
components.append(ComponentDeclaration(
|
||||
name=comp_info.get("name", ""),
|
||||
component_type=comp_info.get("type", ""),
|
||||
plugin_id=meta.plugin_id,
|
||||
metadata=comp_info.get("metadata", {}),
|
||||
))
|
||||
|
||||
reg_payload = RegisterComponentsPayload(
|
||||
plugin_id=meta.plugin_id,
|
||||
plugin_version=meta.version,
|
||||
components=components,
|
||||
capabilities_required=meta.capabilities_required,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await self._rpc_client.send_request(
|
||||
"plugin.register_components",
|
||||
plugin_id=meta.plugin_id,
|
||||
payload=reg_payload.model_dump(),
|
||||
timeout_ms=10000,
|
||||
)
|
||||
logger.info(f"插件 {meta.plugin_id} 注册完成")
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {meta.plugin_id} 注册失败: {e}")
|
||||
|
||||
async def _handle_invoke(self, envelope: Envelope) -> Envelope:
|
||||
"""处理组件调用请求"""
|
||||
try:
|
||||
invoke = InvokePayload.model_validate(envelope.payload)
|
||||
except Exception as e:
|
||||
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(e))
|
||||
|
||||
plugin_id = envelope.plugin_id
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta is None:
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||
f"插件 {plugin_id} 未加载",
|
||||
)
|
||||
|
||||
# 调用插件实例的组件方法
|
||||
instance = meta.instance
|
||||
component_name = invoke.component_name
|
||||
|
||||
handler_method = getattr(instance, f"handle_{component_name}", None)
|
||||
if handler_method is None:
|
||||
handler_method = getattr(instance, component_name, None)
|
||||
|
||||
if handler_method is None or not callable(handler_method):
|
||||
return envelope.make_error_response(
|
||||
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||
f"插件 {plugin_id} 无组件: {component_name}",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await handler_method(**invoke.args) if asyncio.iscoroutinefunction(handler_method) else handler_method(**invoke.args)
|
||||
resp_payload = InvokeResultPayload(success=True, result=result)
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {plugin_id} 组件 {component_name} 执行异常: {e}", exc_info=True)
|
||||
resp_payload = InvokeResultPayload(success=False, result=str(e))
|
||||
return envelope.make_response(payload=resp_payload.model_dump())
|
||||
|
||||
async def _handle_health(self, envelope: Envelope) -> Envelope:
|
||||
"""处理健康检查"""
|
||||
uptime_ms = int((time.monotonic() - self._start_time) * 1000)
|
||||
health = HealthPayload(
|
||||
healthy=True,
|
||||
loaded_plugins=self._loader.list_plugins(),
|
||||
uptime_ms=uptime_ms,
|
||||
)
|
||||
return envelope.make_response(payload=health.model_dump())
|
||||
|
||||
async def _handle_prepare_shutdown(self, envelope: Envelope) -> Envelope:
|
||||
"""处理准备关停"""
|
||||
logger.info("收到 prepare_shutdown 信号")
|
||||
return envelope.make_response(payload={"acknowledged": True})
|
||||
|
||||
async def _handle_shutdown(self, envelope: Envelope) -> Envelope:
|
||||
"""处理关停"""
|
||||
logger.info("收到 shutdown 信号,准备退出")
|
||||
self._shutting_down = True
|
||||
return envelope.make_response(payload={"acknowledged": True})
|
||||
|
||||
async def _handle_config_updated(self, envelope: Envelope) -> Envelope:
|
||||
"""处理配置更新事件"""
|
||||
plugin_id = envelope.plugin_id
|
||||
meta = self._loader.get_plugin(plugin_id)
|
||||
if meta and hasattr(meta.instance, "on_config_update"):
|
||||
try:
|
||||
config_data = envelope.payload.get("config_data", {})
|
||||
config_version = envelope.payload.get("config_version", "")
|
||||
await meta.instance.on_config_update(config_data, config_version)
|
||||
except Exception as e:
|
||||
logger.error(f"插件 {plugin_id} 配置更新失败: {e}")
|
||||
return envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
|
||||
return envelope.make_response(payload={"acknowledged": True})
|
||||
|
||||
def request_capability(self) -> RPCClient:
|
||||
"""获取 RPC 客户端(供 SDK 使用,发起能力调用)"""
|
||||
return self._rpc_client
|
||||
|
||||
|
||||
# ─── 进程入口 ──────────────────────────────────────────────
|
||||
|
||||
async def _async_main() -> None:
|
||||
"""异步主入口"""
|
||||
host_address = os.environ.get("MAIBOT_IPC_ADDRESS", "")
|
||||
session_token = os.environ.get("MAIBOT_SESSION_TOKEN", "")
|
||||
plugin_dirs_str = os.environ.get("MAIBOT_PLUGIN_DIRS", "")
|
||||
|
||||
if not host_address or not session_token:
|
||||
logger.error("缺少必要的环境变量: MAIBOT_IPC_ADDRESS, MAIBOT_SESSION_TOKEN")
|
||||
sys.exit(1)
|
||||
|
||||
plugin_dirs = [d for d in plugin_dirs_str.split(os.pathsep) if d]
|
||||
|
||||
runner = PluginRunner(host_address, session_token, plugin_dirs)
|
||||
|
||||
# 注册信号处理
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, lambda: setattr(runner, "_shutting_down", True))
|
||||
|
||||
await runner.run()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""进程入口(python -m src.plugin_runtime.runner.runner_main)"""
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
)
|
||||
asyncio.run(_async_main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
src/plugin_runtime/transport/__init__.py
Normal file
1
src/plugin_runtime/transport/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Transport 层 - 跨平台本地 IPC 传输抽象
|
||||
116
src/plugin_runtime/transport/base.py
Normal file
116
src/plugin_runtime/transport/base.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""传输层抽象基类
|
||||
|
||||
定义 TransportServer 和 TransportClient 的统一接口。
|
||||
所有传输后端(UDS、Named Pipe、TCP 回退)必须实现此接口。
|
||||
业务层仅依赖此抽象,禁止直接使用具体传输实现的细节。
|
||||
|
||||
分帧协议:4-byte big-endian length prefix + payload
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator, Callable, Awaitable
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
|
||||
# 分帧常量
|
||||
FRAME_HEADER_SIZE = 4 # 4 字节长度前缀
|
||||
MAX_FRAME_SIZE = 16 * 1024 * 1024 # 16 MB 最大帧大小
|
||||
|
||||
|
||||
class ConnectionClosed(Exception):
|
||||
"""连接已关闭"""
|
||||
pass
|
||||
|
||||
|
||||
class Connection(ABC):
|
||||
"""单个连接的抽象
|
||||
|
||||
封装了底层 StreamReader/StreamWriter,提供分帧读写能力。
|
||||
"""
|
||||
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
self._reader = reader
|
||||
self._writer = writer
|
||||
self._closed = False
|
||||
|
||||
async def send_frame(self, data: bytes) -> None:
|
||||
"""发送一帧数据(4-byte length prefix + payload)"""
|
||||
if self._closed:
|
||||
raise ConnectionClosed("连接已关闭")
|
||||
length = len(data)
|
||||
if length > MAX_FRAME_SIZE:
|
||||
raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}")
|
||||
header = struct.pack(">I", length)
|
||||
self._writer.write(header + data)
|
||||
await self._writer.drain()
|
||||
|
||||
async def recv_frame(self) -> bytes:
|
||||
"""接收一帧数据"""
|
||||
if self._closed:
|
||||
raise ConnectionClosed("连接已关闭")
|
||||
# 读取 4 字节长度头
|
||||
header = await self._reader.readexactly(FRAME_HEADER_SIZE)
|
||||
(length,) = struct.unpack(">I", header)
|
||||
if length > MAX_FRAME_SIZE:
|
||||
raise ValueError(f"帧大小 {length} 超过最大限制 {MAX_FRAME_SIZE}")
|
||||
# 读取 payload
|
||||
payload = await self._reader.readexactly(length)
|
||||
return payload
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭连接"""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
self._writer.close()
|
||||
await self._writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
|
||||
# 连接回调类型:收到新连接时调用
|
||||
ConnectionHandler = Callable[[Connection], Awaitable[None]]
|
||||
|
||||
|
||||
class TransportServer(ABC):
|
||||
"""传输服务端抽象
|
||||
|
||||
Host 端使用,监听来自 Runner 的连接。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
"""启动服务端,开始监听连接
|
||||
|
||||
Args:
|
||||
handler: 新连接到来时的回调函数
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""停止服务端"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_address(self) -> str:
|
||||
"""获取监听地址(供 Runner 连接用)"""
|
||||
...
|
||||
|
||||
|
||||
class TransportClient(ABC):
|
||||
"""传输客户端抽象
|
||||
|
||||
Runner 端使用,主动连接 Host。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> Connection:
|
||||
"""建立到 Host 的连接"""
|
||||
...
|
||||
46
src/plugin_runtime/transport/factory.py
Normal file
46
src/plugin_runtime/transport/factory.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""传输层工厂
|
||||
|
||||
根据运行平台自动选择最优传输实现。
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
from .base import TransportClient, TransportServer
|
||||
|
||||
|
||||
def create_transport_server(socket_path: str | None = None) -> TransportServer:
|
||||
"""创建传输服务端
|
||||
|
||||
Linux/macOS 使用 UDS,Windows 使用 TCP 回退。
|
||||
|
||||
Args:
|
||||
socket_path: UDS socket 路径(仅 Linux/macOS 有效)
|
||||
"""
|
||||
if sys.platform != "win32":
|
||||
from .uds import UDSTransportServer
|
||||
return UDSTransportServer(socket_path=socket_path)
|
||||
else:
|
||||
# Windows 回退到 TCP(后续可改为 Named Pipe)
|
||||
from .tcp import TCPTransportServer
|
||||
return TCPTransportServer()
|
||||
|
||||
|
||||
def create_transport_client(address: str) -> TransportClient:
|
||||
"""创建传输客户端
|
||||
|
||||
根据地址格式自动判断传输类型:
|
||||
- 包含 '/' 或 '.sock' -> UDS
|
||||
- 包含 ':' -> TCP
|
||||
|
||||
Args:
|
||||
address: Host 端监听地址
|
||||
"""
|
||||
if "/" in address or address.endswith(".sock"):
|
||||
from .uds import UDSTransportClient
|
||||
return UDSTransportClient(socket_path=address)
|
||||
elif ":" in address:
|
||||
from .tcp import TCPTransportClient
|
||||
host, port_str = address.rsplit(":", 1)
|
||||
return TCPTransportClient(host=host, port=int(port_str))
|
||||
else:
|
||||
raise ValueError(f"无法识别的传输地址格式: {address}")
|
||||
59
src/plugin_runtime/transport/tcp.py
Normal file
59
src/plugin_runtime/transport/tcp.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""TCP 传输实现(回退方案)
|
||||
|
||||
仅当 UDS / Named Pipe 不可用时启用。
|
||||
绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
|
||||
|
||||
|
||||
class TCPConnection(Connection):
|
||||
"""基于 TCP 的连接"""
|
||||
pass
|
||||
|
||||
|
||||
class TCPTransportServer(TransportServer):
|
||||
"""TCP 传输服务端(回退方案)"""
|
||||
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 0):
|
||||
self._host = host
|
||||
self._port = port # 0 表示自动分配
|
||||
self._server: asyncio.AbstractServer | None = None
|
||||
self._actual_port: int = 0
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
conn = TCPConnection(reader, writer)
|
||||
try:
|
||||
await handler(conn)
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
self._server = await asyncio.start_server(_on_connect, self._host, self._port)
|
||||
|
||||
# 获取实际分配的端口
|
||||
addr = self._server.sockets[0].getsockname()
|
||||
self._actual_port = addr[1]
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
|
||||
def get_address(self) -> str:
|
||||
return f"{self._host}:{self._actual_port}"
|
||||
|
||||
|
||||
class TCPTransportClient(TransportClient):
|
||||
"""TCP 传输客户端"""
|
||||
|
||||
def __init__(self, host: str, port: int):
|
||||
self._host = host
|
||||
self._port = port
|
||||
|
||||
async def connect(self) -> Connection:
|
||||
reader, writer = await asyncio.open_connection(self._host, self._port)
|
||||
return TCPConnection(reader, writer)
|
||||
71
src/plugin_runtime/transport/uds.py
Normal file
71
src/plugin_runtime/transport/uds.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Unix Domain Socket 传输实现
|
||||
|
||||
适用于 Linux / macOS 平台。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
|
||||
|
||||
|
||||
class UDSConnection(Connection):
|
||||
"""基于 UDS 的连接"""
|
||||
pass # 直接复用 Connection 基类的分帧读写
|
||||
|
||||
|
||||
class UDSTransportServer(TransportServer):
|
||||
"""UDS 传输服务端"""
|
||||
|
||||
def __init__(self, socket_path: str | None = None):
|
||||
if socket_path is None:
|
||||
# 默认放在临时目录
|
||||
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}.sock")
|
||||
self._socket_path = socket_path
|
||||
self._server: asyncio.AbstractServer | None = None
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
# 清理残留 socket 文件
|
||||
if os.path.exists(self._socket_path):
|
||||
os.unlink(self._socket_path)
|
||||
|
||||
# 确保父目录存在
|
||||
Path(self._socket_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def _on_connect(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
conn = UDSConnection(reader, writer)
|
||||
try:
|
||||
await handler(conn)
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
self._server = await asyncio.start_unix_server(_on_connect, path=self._socket_path)
|
||||
|
||||
# 设置文件权限为仅当前用户可访问
|
||||
os.chmod(self._socket_path, 0o600)
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
# 清理 socket 文件
|
||||
if os.path.exists(self._socket_path):
|
||||
os.unlink(self._socket_path)
|
||||
|
||||
def get_address(self) -> str:
|
||||
return self._socket_path
|
||||
|
||||
|
||||
class UDSTransportClient(TransportClient):
|
||||
"""UDS 传输客户端"""
|
||||
|
||||
def __init__(self, socket_path: str):
|
||||
self._socket_path = socket_path
|
||||
|
||||
async def connect(self) -> Connection:
|
||||
reader, writer = await asyncio.open_unix_connection(self._socket_path)
|
||||
return UDSConnection(reader, writer)
|
||||
Reference in New Issue
Block a user