Refactor protocol and transport modules to use type hints for improved clarity and consistency
- Updated Codec class to use abstract methods for encoding and decoding envelopes. - Changed Envelope class to use Dict and Optional for payload and error fields. - Refined error handling in RPCError class with Optional type hints for details. - Enhanced manifest validation logic with type hints for better type safety. - Improved plugin loading mechanism with consistent type annotations. - Updated RPCClient to utilize Optional for codec and connection attributes. - Refactored transport classes to use Optional for server attributes and socket paths.
This commit is contained in:
@@ -4,22 +4,21 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
|
||||
每个能力方法被注册到 RPC Server,接收 Runner 转发的请求并执行实际操作。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
from typing import Any, Awaitable, Callable, Dict, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
from src.plugin_runtime.protocol.envelope import (
|
||||
CapabilityRequestPayload,
|
||||
CapabilityResponsePayload,
|
||||
Envelope,
|
||||
)
|
||||
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
|
||||
from src.plugin_runtime.host.policy_engine import PolicyEngine
|
||||
|
||||
logger = get_logger("plugin_runtime.host.capability_service")
|
||||
|
||||
# 能力实现函数类型: (plugin_id, capability, args) -> result
|
||||
CapabilityImpl = Callable[[str, str, dict[str, Any]], Awaitable[Any]]
|
||||
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
|
||||
|
||||
|
||||
class CapabilityService:
|
||||
@@ -35,7 +34,7 @@ class CapabilityService:
|
||||
def __init__(self, policy_engine: PolicyEngine):
|
||||
self._policy = policy_engine
|
||||
# capability_name -> implementation
|
||||
self._implementations: dict[str, CapabilityImpl] = {}
|
||||
self._implementations: Dict[str, CapabilityImpl] = {}
|
||||
|
||||
def register_capability(self, name: str, impl: CapabilityImpl) -> None:
|
||||
"""注册一个能力实现
|
||||
@@ -95,6 +94,6 @@ class CapabilityService:
|
||||
str(e),
|
||||
)
|
||||
|
||||
def list_capabilities(self) -> list[str]:
|
||||
def list_capabilities(self) -> List[str]:
|
||||
"""列出所有已注册的能力"""
|
||||
return list(self._implementations.keys())
|
||||
|
||||
@@ -9,9 +9,10 @@
|
||||
- 注册统计
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
import re
|
||||
|
||||
logger = get_logger("plugin_runtime.host.component_registry")
|
||||
@@ -30,7 +31,7 @@ class RegisteredComponent:
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
):
|
||||
self.name = name
|
||||
self.full_name = f"{plugin_id}.{name}"
|
||||
@@ -40,7 +41,7 @@ class RegisteredComponent:
|
||||
self.enabled = metadata.get("enabled", True)
|
||||
|
||||
# 预编译命令正则(仅 command 类型)
|
||||
self._compiled_pattern: re.Pattern | None = None
|
||||
self._compiled_pattern: Optional[re.Pattern] = None
|
||||
if component_type == "command":
|
||||
if pattern := metadata.get("command_pattern", ""):
|
||||
try:
|
||||
@@ -58,10 +59,10 @@ class ComponentRegistry:
|
||||
|
||||
def __init__(self):
|
||||
# 全量索引
|
||||
self._components: dict[str, RegisteredComponent] = {} # full_name -> comp
|
||||
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
|
||||
|
||||
# 按类型索引
|
||||
self._by_type: dict[str, dict[str, RegisteredComponent]] = {
|
||||
self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
|
||||
"action": {},
|
||||
"command": {},
|
||||
"tool": {},
|
||||
@@ -70,7 +71,7 @@ class ComponentRegistry:
|
||||
}
|
||||
|
||||
# 按插件索引
|
||||
self._by_plugin: dict[str, list[RegisteredComponent]] = {}
|
||||
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
|
||||
|
||||
# ──── 注册 / 注销 ─────────────────────────────────────────
|
||||
|
||||
@@ -79,7 +80,7 @@ class ComponentRegistry:
|
||||
name: str,
|
||||
component_type: str,
|
||||
plugin_id: str,
|
||||
metadata: dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
) -> bool:
|
||||
"""注册单个组件。"""
|
||||
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
|
||||
@@ -99,7 +100,7 @@ class ComponentRegistry:
|
||||
def register_plugin_components(
|
||||
self,
|
||||
plugin_id: str,
|
||||
components: list[dict[str, Any]],
|
||||
components: List[Dict[str, Any]],
|
||||
) -> int:
|
||||
"""批量注册一个插件的所有组件,返回成功注册数。"""
|
||||
count = 0
|
||||
@@ -142,13 +143,13 @@ class ComponentRegistry:
|
||||
|
||||
# ──── 查询方法 ─────────────────────────────────────────────
|
||||
|
||||
def get_component(self, full_name: str) -> RegisteredComponent | None:
|
||||
def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
|
||||
"""按全名查询。"""
|
||||
return self._components.get(full_name)
|
||||
|
||||
def get_components_by_type(
|
||||
self, component_type: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
) -> List[RegisteredComponent]:
|
||||
"""按类型查询。"""
|
||||
type_dict = self._by_type.get(component_type, {})
|
||||
if enabled_only:
|
||||
@@ -157,12 +158,12 @@ class ComponentRegistry:
|
||||
|
||||
def get_components_by_plugin(
|
||||
self, plugin_id: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
) -> List[RegisteredComponent]:
|
||||
"""按插件查询。"""
|
||||
comps = self._by_plugin.get(plugin_id, [])
|
||||
return [c for c in comps if c.enabled] if enabled_only else list(comps)
|
||||
|
||||
def find_command_by_text(self, text: str) -> RegisteredComponent | None:
|
||||
def find_command_by_text(self, text: str) -> Optional[RegisteredComponent]:
|
||||
"""通过文本匹配命令正则,返回第一个匹配的 command 组件。"""
|
||||
for comp in self._by_type.get("command", {}).values():
|
||||
if not comp.enabled:
|
||||
@@ -178,7 +179,7 @@ class ComponentRegistry:
|
||||
|
||||
def get_event_handlers(
|
||||
self, event_type: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
) -> List[RegisteredComponent]:
|
||||
"""获取特定事件类型的所有 event_handler,按 weight 降序排列。"""
|
||||
handlers = []
|
||||
for comp in self._by_type.get("event_handler", {}).values():
|
||||
@@ -191,7 +192,7 @@ class ComponentRegistry:
|
||||
|
||||
def get_workflow_steps(
|
||||
self, stage: str, *, enabled_only: bool = True
|
||||
) -> list[RegisteredComponent]:
|
||||
) -> List[RegisteredComponent]:
|
||||
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
|
||||
steps = []
|
||||
for comp in self._by_type.get("workflow_step", {}).values():
|
||||
@@ -202,11 +203,11 @@ class ComponentRegistry:
|
||||
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
|
||||
return steps
|
||||
|
||||
def get_tools_for_llm(self, *, enabled_only: bool = True) -> list[dict[str, Any]]:
|
||||
def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
|
||||
"""获取可供 LLM 使用的工具列表(openai function-calling 格式预览)。"""
|
||||
result = []
|
||||
result: List[Dict[str, Any]] = []
|
||||
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
|
||||
tool_def: dict[str, Any] = {
|
||||
tool_def: Dict[str, Any] = {
|
||||
"name": comp.full_name,
|
||||
"description": comp.metadata.get("description", ""),
|
||||
}
|
||||
@@ -222,9 +223,9 @@ class ComponentRegistry:
|
||||
|
||||
# ──── 统计 ─────────────────────────────────────────────────
|
||||
|
||||
def get_stats(self) -> dict[str, int]:
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""获取注册统计。"""
|
||||
stats: dict[str, int] = {"total": len(self._components)}
|
||||
stats: Dict[str, int] = {"total": len(self._components)}
|
||||
for comp_type, type_dict in self._by_type.items():
|
||||
stats[comp_type] = len(type_dict)
|
||||
stats["plugins"] = len(self._by_plugin)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 事件结果历史记录
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -17,7 +17,7 @@ from src.plugin_runtime.host.component_registry import ComponentRegistry, Regist
|
||||
logger = get_logger("plugin_runtime.host.event_dispatcher")
|
||||
|
||||
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
|
||||
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]]
|
||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class EventResult:
|
||||
@@ -29,7 +29,7 @@ class EventResult:
|
||||
handler_name: str,
|
||||
success: bool = True,
|
||||
continue_processing: bool = True,
|
||||
modified_message: dict[str, Any] | None = None,
|
||||
modified_message: Optional[Dict[str, Any]] = None,
|
||||
custom_result: Any = None,
|
||||
):
|
||||
self.handler_name = handler_name
|
||||
@@ -49,14 +49,14 @@ class EventDispatcher:
|
||||
|
||||
def __init__(self, registry: ComponentRegistry) -> None:
|
||||
self._registry: ComponentRegistry = registry
|
||||
self._result_history: dict[str, list[EventResult]] = {}
|
||||
self._history_enabled: set[str] = set()
|
||||
self._result_history: Dict[str, List[EventResult]] = {}
|
||||
self._history_enabled: Set[str] = set()
|
||||
|
||||
def enable_history(self, event_type: str) -> None:
|
||||
self._history_enabled.add(event_type)
|
||||
self._result_history.setdefault(event_type, [])
|
||||
|
||||
def get_history(self, event_type: str) -> list[EventResult]:
|
||||
def get_history(self, event_type: str) -> List[EventResult]:
|
||||
return self._result_history.get(event_type, [])
|
||||
|
||||
def clear_history(self, event_type: str) -> None:
|
||||
@@ -67,9 +67,9 @@ class EventDispatcher:
|
||||
self,
|
||||
event_type: str,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any] | None = None,
|
||||
extra_args: dict[str, Any] | None = None,
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
"""分发事件到所有对应 handler。
|
||||
|
||||
Args:
|
||||
@@ -86,8 +86,8 @@ class EventDispatcher:
|
||||
return True, None
|
||||
|
||||
should_continue = True
|
||||
modified_message: dict[str, Any] | None = None
|
||||
fire_and_forget_tasks: list[asyncio.Task] = []
|
||||
modified_message: Optional[Dict[str, Any]] = None
|
||||
fire_and_forget_tasks: List[asyncio.Task] = []
|
||||
|
||||
for handler in handlers:
|
||||
intercept = handler.metadata.get("intercept_message", False)
|
||||
@@ -122,9 +122,9 @@ class EventDispatcher:
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
handler: RegisteredComponent,
|
||||
args: dict[str, Any],
|
||||
args: Dict[str, Any],
|
||||
event_type: str,
|
||||
) -> EventResult | None:
|
||||
) -> Optional[EventResult]:
|
||||
"""调用单个 handler 并收集结果。"""
|
||||
try:
|
||||
resp = await invoke_fn(handler.plugin_id, handler.name, args)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -12,7 +13,7 @@ class CapabilityToken:
|
||||
"""能力令牌"""
|
||||
plugin_id: str
|
||||
generation: int
|
||||
capabilities: set[str] = field(default_factory=set)
|
||||
capabilities: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class PolicyEngine:
|
||||
@@ -22,13 +23,13 @@ class PolicyEngine:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tokens: dict[str, CapabilityToken] = {}
|
||||
self._tokens: Dict[str, CapabilityToken] = {}
|
||||
|
||||
def register_plugin(
|
||||
self,
|
||||
plugin_id: str,
|
||||
generation: int,
|
||||
capabilities: list[str],
|
||||
capabilities: List[str],
|
||||
) -> CapabilityToken:
|
||||
"""为插件签发能力令牌"""
|
||||
token = CapabilityToken(
|
||||
@@ -43,7 +44,7 @@ class PolicyEngine:
|
||||
"""撤销插件的能力令牌"""
|
||||
self._tokens.pop(plugin_id, None)
|
||||
|
||||
def check_capability(self, plugin_id: str, capability: str) -> tuple[bool, str]:
|
||||
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
|
||||
"""检查插件是否有权调用某项能力
|
||||
|
||||
Returns:
|
||||
@@ -58,10 +59,10 @@ class PolicyEngine:
|
||||
|
||||
return True, ""
|
||||
|
||||
def get_token(self, plugin_id: str) -> CapabilityToken | None:
|
||||
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
|
||||
"""获取插件的能力令牌"""
|
||||
return self._tokens.get(plugin_id)
|
||||
|
||||
def list_plugins(self) -> list[str]:
|
||||
def list_plugins(self) -> List[str]:
|
||||
"""列出所有已注册的插件"""
|
||||
return list(self._tokens.keys())
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 请求-响应关联与超时管理
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import secrets
|
||||
@@ -42,8 +42,8 @@ class RPCServer:
|
||||
def __init__(
|
||||
self,
|
||||
transport: TransportServer,
|
||||
session_token: str | None = None,
|
||||
codec: Codec | None = None,
|
||||
session_token: Optional[str] = None,
|
||||
codec: Optional[Codec] = None,
|
||||
send_queue_size: int = 128,
|
||||
):
|
||||
self._transport = transport
|
||||
@@ -52,22 +52,22 @@ class RPCServer:
|
||||
self._send_queue_size = send_queue_size
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Connection | None = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: str | None = None
|
||||
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
|
||||
self._runner_id: Optional[str] = None
|
||||
self._runner_generation: int = 0
|
||||
|
||||
# 方法处理器注册表
|
||||
self._method_handlers: dict[str, MethodHandler] = {}
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: dict[int, asyncio.Future] = {}
|
||||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
||||
|
||||
# 发送队列(背压控制)
|
||||
self._send_queue: asyncio.Queue[bytes] | None = None
|
||||
self._send_queue: Optional[asyncio.Queue[bytes]] = None
|
||||
|
||||
# 运行状态
|
||||
self._running: bool = False
|
||||
self._tasks: list[asyncio.Task] = []
|
||||
self._tasks: List[asyncio.Task] = []
|
||||
|
||||
@property
|
||||
def session_token(self) -> str:
|
||||
@@ -115,7 +115,7 @@ class RPCServer:
|
||||
self,
|
||||
method: str,
|
||||
plugin_id: str = "",
|
||||
payload: dict[str, Any] | None = None,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""向 Runner 发送 RPC 请求并等待响应
|
||||
@@ -172,7 +172,7 @@ class RPCServer:
|
||||
raise
|
||||
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
|
||||
|
||||
async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None:
|
||||
async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""向 Runner 发送单向事件(不等待响应)"""
|
||||
if not self.is_connected:
|
||||
return
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
4. 优雅关停
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
@@ -40,8 +40,8 @@ class PluginSupervisor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
plugin_dirs: list[str] | None = None,
|
||||
socket_path: str | None = None,
|
||||
plugin_dirs: Optional[List[str]] = None,
|
||||
socket_path: Optional[str] = None,
|
||||
health_check_interval_sec: float = 30.0,
|
||||
):
|
||||
self._plugin_dirs = plugin_dirs or []
|
||||
@@ -65,16 +65,16 @@ class PluginSupervisor:
|
||||
)
|
||||
|
||||
# Runner 子进程
|
||||
self._runner_process: asyncio.subprocess.Process | None = None
|
||||
self._runner_process: Optional[asyncio.subprocess.Process] = None
|
||||
self._runner_generation: int = 0
|
||||
self._max_restart_attempts: int = 3
|
||||
self._restart_count: int = 0
|
||||
|
||||
# 已注册的插件组件信息
|
||||
self._registered_plugins: dict[str, RegisterComponentsPayload] = {}
|
||||
self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
|
||||
|
||||
# 后台任务
|
||||
self._health_task: asyncio.Task | None = None
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
# 注册内部 RPC 方法
|
||||
@@ -107,11 +107,11 @@ class PluginSupervisor:
|
||||
async def dispatch_event(
|
||||
self,
|
||||
event_type: str,
|
||||
message: dict[str, Any] | None = None,
|
||||
extra_args: dict[str, Any] | None = None,
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
"""分发事件到所有对应 handler 的快捷方法。"""
|
||||
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
|
||||
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
resp = await self.invoke_plugin(
|
||||
method="plugin.emit_event",
|
||||
plugin_id=plugin_id,
|
||||
@@ -129,12 +129,12 @@ class PluginSupervisor:
|
||||
|
||||
async def execute_workflow(
|
||||
self,
|
||||
message: dict[str, Any] | None = None,
|
||||
stream_id: str | None = None,
|
||||
context: WorkflowContext | None = None,
|
||||
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
||||
"""执行 Workflow Pipeline 的快捷方法。"""
|
||||
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]:
|
||||
async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
resp = await self.invoke_plugin(
|
||||
method="plugin.invoke_workflow_step",
|
||||
plugin_id=plugin_id,
|
||||
@@ -196,7 +196,7 @@ class PluginSupervisor:
|
||||
method: str,
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: dict[str, Any] | None = None,
|
||||
args: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""调用插件组件
|
||||
@@ -225,6 +225,12 @@ class PluginSupervisor:
|
||||
# 保存旧进程引用
|
||||
old_process = self._runner_process
|
||||
|
||||
# 清理旧的组件注册,防止幽灵组件残留
|
||||
for plugin_id in list(self._registered_plugins.keys()):
|
||||
self._component_registry.remove_components_by_plugin(plugin_id)
|
||||
self._policy.revoke_plugin(plugin_id)
|
||||
self._registered_plugins.clear()
|
||||
|
||||
# 拉起新 Runner
|
||||
await self._spawn_runner()
|
||||
|
||||
@@ -286,7 +292,7 @@ class PluginSupervisor:
|
||||
# 在策略引擎中注册插件
|
||||
self._policy.register_plugin(
|
||||
plugin_id=reg.plugin_id,
|
||||
generation=envelope.generation,
|
||||
generation=self._runner_generation,
|
||||
capabilities=reg.capabilities_required or [],
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
- modification_log: 消息修改审计
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
@@ -29,7 +29,7 @@ from src.plugin_runtime.host.component_registry import ComponentRegistry, Regist
|
||||
logger = get_logger("plugin_runtime.host.workflow_executor")
|
||||
|
||||
# 阶段顺序
|
||||
STAGE_SEQUENCE: list[str] = [
|
||||
STAGE_SEQUENCE: List[str] = [
|
||||
"ingress",
|
||||
"pre_process",
|
||||
"plan",
|
||||
@@ -48,7 +48,7 @@ class ModificationRecord:
|
||||
"""消息修改记录"""
|
||||
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
|
||||
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: list[str]):
|
||||
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]):
|
||||
self.stage = stage
|
||||
self.hook_name = hook_name
|
||||
self.timestamp = time.perf_counter()
|
||||
@@ -58,17 +58,17 @@ class ModificationRecord:
|
||||
class WorkflowContext:
|
||||
"""Workflow 执行上下文"""
|
||||
|
||||
def __init__(self, trace_id: str | None = None, stream_id: str | None = None):
|
||||
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None):
|
||||
self.trace_id = trace_id or uuid.uuid4().hex
|
||||
self.stream_id = stream_id
|
||||
self.timings: dict[str, float] = {}
|
||||
self.errors: list[str] = []
|
||||
self.timings: Dict[str, float] = {}
|
||||
self.errors: List[str] = []
|
||||
# 阶段间数据传递(按 stage 命名空间隔离)
|
||||
self.stage_outputs: dict[str, dict[str, Any]] = {}
|
||||
self.stage_outputs: Dict[str, Dict[str, Any]] = {}
|
||||
# 消息修改审计日志
|
||||
self.modification_log: list[ModificationRecord] = []
|
||||
self.modification_log: List[ModificationRecord] = []
|
||||
# PLAN 阶段命令匹配结果
|
||||
self.matched_command: str | None = None
|
||||
self.matched_command: Optional[str] = None
|
||||
|
||||
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
|
||||
self.stage_outputs.setdefault(stage, {})[key] = value
|
||||
@@ -85,7 +85,7 @@ class WorkflowResult:
|
||||
status: str = "completed", # completed / aborted / failed
|
||||
return_message: str = "",
|
||||
stopped_at: str = "",
|
||||
diagnostics: dict[str, Any] | None = None,
|
||||
diagnostics: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.status = status
|
||||
self.return_message = return_message
|
||||
@@ -94,7 +94,7 @@ class WorkflowResult:
|
||||
|
||||
|
||||
# invoke_fn 签名
|
||||
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]]
|
||||
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
@@ -109,10 +109,10 @@ class WorkflowExecutor:
|
||||
async def execute(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any] | None = None,
|
||||
stream_id: str | None = None,
|
||||
context: WorkflowContext | None = None,
|
||||
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]:
|
||||
message: Optional[Dict[str, Any]] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
context: Optional[WorkflowContext] = None,
|
||||
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
|
||||
"""执行 workflow pipeline。
|
||||
|
||||
Returns:
|
||||
@@ -120,6 +120,8 @@ class WorkflowExecutor:
|
||||
"""
|
||||
ctx = context or WorkflowContext(stream_id=stream_id)
|
||||
current_message = dict(message) if message else None
|
||||
# 保持非阻塞任务引用,防止被 GC 回收
|
||||
background_tasks: List[asyncio.Task] = []
|
||||
|
||||
for stage in STAGE_SEQUENCE:
|
||||
stage_start = time.perf_counter()
|
||||
@@ -207,14 +209,15 @@ class WorkflowExecutor:
|
||||
# 4. 并发执行 non-blocking hook(只读,忽略返回值中的 modified_message)
|
||||
if nonblocking_steps and not skip_stage:
|
||||
nb_tasks = [
|
||||
self._invoke_step_fire_and_forget(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
asyncio.create_task(
|
||||
self._invoke_step_fire_and_forget(
|
||||
invoke_fn, step, stage, ctx, current_message
|
||||
)
|
||||
)
|
||||
for step in nonblocking_steps
|
||||
]
|
||||
# 并发执行但不阻塞 pipeline
|
||||
for task in [asyncio.create_task(t) for t in nb_tasks]:
|
||||
task.add_done_callback(lambda _: None)
|
||||
# 保持任务引用以防止被 GC 回收
|
||||
background_tasks.extend(nb_tasks)
|
||||
|
||||
ctx.timings[stage] = time.perf_counter() - stage_start
|
||||
|
||||
@@ -247,9 +250,9 @@ class WorkflowExecutor:
|
||||
|
||||
def _pre_filter(
|
||||
self,
|
||||
steps: list[RegisteredComponent],
|
||||
message: dict[str, Any] | None,
|
||||
) -> list[RegisteredComponent]:
|
||||
steps: List[RegisteredComponent],
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> List[RegisteredComponent]:
|
||||
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
|
||||
if not message:
|
||||
return steps
|
||||
@@ -265,7 +268,7 @@ class WorkflowExecutor:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _match_filter(filter_cond: dict[str, Any], message: dict[str, Any]) -> bool:
|
||||
def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
|
||||
"""简单 key-value 匹配过滤。
|
||||
|
||||
filter 中的每个 key 必须在 message 中存在且值相等,
|
||||
@@ -285,8 +288,8 @@ class WorkflowExecutor:
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: dict[str, Any] | None,
|
||||
) -> tuple[str, dict[str, Any] | None, str | None]:
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""调用单个 blocking hook。
|
||||
|
||||
Returns:
|
||||
@@ -331,7 +334,7 @@ class WorkflowExecutor:
|
||||
step: RegisteredComponent,
|
||||
stage: str,
|
||||
ctx: WorkflowContext,
|
||||
message: dict[str, Any] | None,
|
||||
message: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""Non-blocking hook 调用,只读,忽略结果。"""
|
||||
timeout_ms = step.metadata.get("timeout_ms", 0)
|
||||
@@ -354,9 +357,9 @@ class WorkflowExecutor:
|
||||
async def _route_command(
|
||||
self,
|
||||
invoke_fn: InvokeFn,
|
||||
message: dict[str, Any],
|
||||
message: Dict[str, Any],
|
||||
ctx: WorkflowContext,
|
||||
) -> dict[str, Any] | None:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""PLAN 阶段内置 Command 路由。
|
||||
|
||||
在 registry 中查找匹配的 command 组件,
|
||||
@@ -386,6 +389,6 @@ class WorkflowExecutor:
|
||||
return None
|
||||
|
||||
|
||||
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]:
|
||||
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
|
||||
"""返回 new 中与 old 不同的 key 列表。"""
|
||||
return [k for k, v in new.items() if k not in old or old[k] != v]
|
||||
|
||||
@@ -7,9 +7,7 @@
|
||||
4. 提供统一的能力实现注册接口,使插件可以调用主程序功能
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
@@ -19,7 +17,7 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("plugin_runtime.integration")
|
||||
|
||||
# 旧系统 EventType -> 新系统 event_type 字符串映射
|
||||
_EVENT_TYPE_MAP: dict[str, str] = {
|
||||
_EVENT_TYPE_MAP: Dict[str, str] = {
|
||||
"on_start": "on_start",
|
||||
"on_stop": "on_stop",
|
||||
"on_message_pre_process": "on_message_pre_process",
|
||||
@@ -42,20 +40,20 @@ class PluginRuntimeManager:
|
||||
def __init__(self) -> None:
|
||||
from src.plugin_runtime.host.supervisor import PluginSupervisor
|
||||
|
||||
self._builtin_supervisor: PluginSupervisor | None = None
|
||||
self._thirdparty_supervisor: PluginSupervisor | None = None
|
||||
self._builtin_supervisor: Optional[PluginSupervisor] = None
|
||||
self._thirdparty_supervisor: Optional[PluginSupervisor] = None
|
||||
self._started: bool = False
|
||||
|
||||
# ─── 插件目录 ─────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _get_builtin_plugin_dirs() -> list[str]:
|
||||
def _get_builtin_plugin_dirs() -> List[str]:
|
||||
"""内置插件目录: src/plugins/built_in/"""
|
||||
candidate = os.path.abspath(os.path.join("src", "plugins", "built_in"))
|
||||
return [candidate] if os.path.isdir(candidate) else []
|
||||
|
||||
@staticmethod
|
||||
def _get_thirdparty_plugin_dirs() -> list[str]:
|
||||
def _get_thirdparty_plugin_dirs() -> List[str]:
|
||||
"""第三方插件目录: plugins/"""
|
||||
candidate = os.path.abspath("plugins")
|
||||
return [candidate] if os.path.isdir(candidate) else []
|
||||
@@ -136,7 +134,7 @@ class PluginRuntimeManager:
|
||||
return self._started
|
||||
|
||||
@property
|
||||
def supervisors(self) -> list[Any]:
|
||||
def supervisors(self) -> List[Any]:
|
||||
"""获取所有活跃的 Supervisor"""
|
||||
return [s for s in (self._builtin_supervisor, self._thirdparty_supervisor) if s is not None]
|
||||
|
||||
@@ -145,9 +143,9 @@ class PluginRuntimeManager:
|
||||
async def bridge_event(
|
||||
self,
|
||||
event_type_value: str,
|
||||
message_dict: dict[str, Any] | None = None,
|
||||
extra_args: dict[str, Any] | None = None,
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
message_dict: Optional[Dict[str, Any]] = None,
|
||||
extra_args: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
"""将事件分发到所有 Supervisor
|
||||
|
||||
Returns:
|
||||
@@ -157,7 +155,7 @@ class PluginRuntimeManager:
|
||||
return True, None
|
||||
|
||||
new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value)
|
||||
modified: dict[str, Any] | None = None
|
||||
modified: Optional[Dict[str, Any]] = None
|
||||
|
||||
for sv in self.supervisors:
|
||||
try:
|
||||
@@ -177,7 +175,7 @@ class PluginRuntimeManager:
|
||||
|
||||
# ─── 命令查找 ──────────────────────────────────────────────
|
||||
|
||||
def find_command_by_text(self, text: str) -> dict[str, Any] | None:
|
||||
def find_command_by_text(self, text: str) -> Optional[Dict[str, Any]]:
|
||||
"""在所有 Supervisor 的 ComponentRegistry 中查找命令"""
|
||||
if not self._started:
|
||||
return None
|
||||
@@ -281,7 +279,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_send_text(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_send_text(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""发送文本消息
|
||||
|
||||
args: text, stream_id, typing?, set_reply?, storage_message?
|
||||
@@ -307,7 +305,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_send_emoji(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_send_emoji(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""发送表情
|
||||
|
||||
args: emoji_base64, stream_id, storage_message?
|
||||
@@ -331,7 +329,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_send_image(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_send_image(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""发送图片
|
||||
|
||||
args: image_base64, stream_id, storage_message?
|
||||
@@ -355,7 +353,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_send_command(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_send_command(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""发送命令
|
||||
|
||||
args: command, stream_id, storage_message?, display_message?
|
||||
@@ -380,7 +378,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_send_custom(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_send_custom(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""发送自定义类型消息
|
||||
|
||||
args: message_type, content, stream_id, display_message?, typing?, storage_message?
|
||||
@@ -408,7 +406,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_send_forward(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_send_forward(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""发送转发消息
|
||||
|
||||
args: messages, stream_id
|
||||
@@ -431,7 +429,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_send_hybrid(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_send_hybrid(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""发送混合消息(图文混合)
|
||||
|
||||
args: segments, stream_id
|
||||
@@ -458,7 +456,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_llm_generate(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_llm_generate(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""LLM 生成
|
||||
|
||||
args: prompt, model_name?, temperature?, max_tokens?
|
||||
@@ -501,7 +499,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_llm_generate_with_tools(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_llm_generate_with_tools(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""LLM 带工具生成
|
||||
|
||||
args: prompt, model_name?, tool_options?, temperature?, max_tokens?
|
||||
@@ -554,7 +552,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_llm_get_available_models(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_llm_get_available_models(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取可用模型列表"""
|
||||
from src.services import llm_service as llm_api
|
||||
|
||||
@@ -570,7 +568,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_config_get(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_config_get(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""读取全局配置
|
||||
|
||||
args: key, default?
|
||||
@@ -590,7 +588,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "value": None, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_config_get_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_config_get_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""读取插件配置
|
||||
|
||||
args: key, default?, plugin_name?
|
||||
@@ -617,7 +615,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "value": default, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_config_get_all(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_config_get_all(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取当前插件的全部配置"""
|
||||
from src.core.component_registry import component_registry as core_registry
|
||||
|
||||
@@ -636,7 +634,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_database_query(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_database_query(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""数据库查询
|
||||
|
||||
args: model_name, query_type?, filters?, limit?, order_by?, data?, single_result?
|
||||
@@ -670,7 +668,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_database_save(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_database_save(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""数据库保存
|
||||
|
||||
args: model_name, data, key_field?, key_value?
|
||||
@@ -678,7 +676,7 @@ class PluginRuntimeManager:
|
||||
from src.services import database_service as database_api
|
||||
|
||||
model_name: str = args.get("model_name", "")
|
||||
data: dict[str, Any] | None = args.get("data")
|
||||
data: Optional[Dict[str, Any]] = args.get("data")
|
||||
if not model_name or not data:
|
||||
return {"success": False, "error": "缺少必要参数 model_name 或 data"}
|
||||
|
||||
@@ -701,7 +699,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_database_get(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_database_get(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""数据库简单查询
|
||||
|
||||
args: model_name, filters?, limit?, order_by?, single_result?
|
||||
@@ -732,7 +730,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_database_delete(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_database_delete(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""数据库删除
|
||||
|
||||
args: model_name, filters
|
||||
@@ -763,7 +761,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_database_count(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_database_count(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""数据库计数
|
||||
|
||||
args: model_name, filters?
|
||||
@@ -795,7 +793,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
def _serialize_stream(stream: Any) -> dict[str, Any]:
|
||||
def _serialize_stream(stream: Any) -> Dict[str, Any]:
|
||||
"""将 BotChatSession 序列化为可通过 RPC 传输的字典"""
|
||||
return {
|
||||
"session_id": getattr(stream, "session_id", ""),
|
||||
@@ -806,7 +804,7 @@ class PluginRuntimeManager:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_chat_get_all_streams(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_chat_get_all_streams(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取所有聊天流
|
||||
|
||||
args: platform?
|
||||
@@ -825,7 +823,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_chat_get_group_streams(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_chat_get_group_streams(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取所有群聊流
|
||||
|
||||
args: platform?
|
||||
@@ -844,7 +842,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_chat_get_private_streams(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_chat_get_private_streams(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取所有私聊流
|
||||
|
||||
args: platform?
|
||||
@@ -863,7 +861,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_chat_get_stream_by_group_id(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_chat_get_stream_by_group_id(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""按群 ID 查找聊天流
|
||||
|
||||
args: group_id, platform?
|
||||
@@ -885,7 +883,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_chat_get_stream_by_user_id(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_chat_get_stream_by_user_id(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""按用户 ID 查找私聊流
|
||||
|
||||
args: user_id, platform?
|
||||
@@ -911,9 +909,9 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
def _serialize_messages(messages: list) -> list[dict[str, Any]]:
|
||||
def _serialize_messages(messages: list) -> List[Dict[str, Any]]:
|
||||
"""将 DatabaseMessages 列表序列化为 dict 列表"""
|
||||
result: list[dict[str, Any]] = []
|
||||
result: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, "model_dump"):
|
||||
result.append(msg.model_dump())
|
||||
@@ -924,7 +922,7 @@ class PluginRuntimeManager:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def _cap_message_get_by_time(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_message_get_by_time(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""按时间范围查询消息
|
||||
|
||||
args: start_time, end_time, limit?, filter_mai?
|
||||
@@ -948,7 +946,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_message_get_by_time_in_chat(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_message_get_by_time_in_chat(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""按时间范围查询指定聊天消息
|
||||
|
||||
args: chat_id, start_time, end_time, limit?, filter_mai?, filter_command?
|
||||
@@ -975,7 +973,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_message_get_recent(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_message_get_recent(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取最近的消息
|
||||
|
||||
args: chat_id, hours?, limit?, filter_mai?
|
||||
@@ -1000,7 +998,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_message_count_new(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_message_count_new(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""统计新消息数量
|
||||
|
||||
args: chat_id, start_time?, end_time?
|
||||
@@ -1023,7 +1021,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_message_build_readable(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_message_build_readable(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""将消息列表构建成可读字符串
|
||||
|
||||
args: chat_id, start_time, end_time, limit?, replace_bot_name?, timestamp_mode?
|
||||
@@ -1057,7 +1055,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_person_get_id(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_person_get_id(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取 person_id
|
||||
|
||||
args: platform, user_id
|
||||
@@ -1077,7 +1075,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_person_get_value(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_person_get_value(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取用户字段值
|
||||
|
||||
args: person_id, field_name, default?
|
||||
@@ -1101,7 +1099,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_person_get_id_by_name(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_person_get_id_by_name(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""根据用户名获取 person_id
|
||||
|
||||
args: person_name
|
||||
@@ -1124,7 +1122,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_emoji_get_by_description(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_emoji_get_by_description(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""根据描述获取表情包
|
||||
|
||||
args: description
|
||||
@@ -1153,7 +1151,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_emoji_get_random(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_emoji_get_random(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""随机获取表情包
|
||||
|
||||
args: count?
|
||||
@@ -1172,7 +1170,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_emoji_get_count(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_emoji_get_count(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取表情包数量"""
|
||||
from src.services import emoji_service as emoji_api
|
||||
|
||||
@@ -1183,7 +1181,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_emoji_get_emotions(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_emoji_get_emotions(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取所有情绪标签"""
|
||||
from src.services import emoji_service as emoji_api
|
||||
|
||||
@@ -1194,7 +1192,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_emoji_get_all(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_emoji_get_all(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取所有表情包"""
|
||||
from src.services import emoji_service as emoji_api
|
||||
|
||||
@@ -1209,7 +1207,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_emoji_get_info(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_emoji_get_info(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取表情包统计信息"""
|
||||
from src.services import emoji_service as emoji_api
|
||||
|
||||
@@ -1220,7 +1218,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_emoji_register(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_emoji_register(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""注册表情包
|
||||
|
||||
args: emoji_base64
|
||||
@@ -1239,7 +1237,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_emoji_delete(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_emoji_delete(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""删除表情包
|
||||
|
||||
args: emoji_hash
|
||||
@@ -1262,7 +1260,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_frequency_get_current_talk_value(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_frequency_get_current_talk_value(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取当前说话频率值
|
||||
|
||||
args: chat_id
|
||||
@@ -1281,7 +1279,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_frequency_set_adjust(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_frequency_set_adjust(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""设置说话频率调整值
|
||||
|
||||
args: chat_id, value
|
||||
@@ -1301,7 +1299,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_frequency_get_adjust(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_frequency_get_adjust(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取说话频率调整值
|
||||
|
||||
args: chat_id
|
||||
@@ -1324,7 +1322,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_tool_get_definitions(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_tool_get_definitions(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取 LLM 可用的工具定义列表"""
|
||||
from src.core.component_registry import component_registry as core_registry
|
||||
|
||||
@@ -1343,10 +1341,10 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_component_get_all_plugins(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_component_get_all_plugins(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取所有插件信息(汇总所有 Supervisor 的注册信息)"""
|
||||
mgr = get_plugin_runtime_manager()
|
||||
result: dict[str, Any] = {}
|
||||
result: Dict[str, Any] = {}
|
||||
for sv in mgr.supervisors:
|
||||
for pid, reg in sv._registered_plugins.items():
|
||||
result[pid] = {
|
||||
@@ -1359,7 +1357,7 @@ class PluginRuntimeManager:
|
||||
return {"success": True, "plugins": result}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_component_get_plugin_info(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_component_get_plugin_info(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""获取指定插件信息
|
||||
|
||||
args: plugin_name
|
||||
@@ -1382,25 +1380,25 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": f"未找到插件: {plugin_name}"}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_component_list_loaded_plugins(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_component_list_loaded_plugins(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""列出已加载的插件"""
|
||||
mgr = get_plugin_runtime_manager()
|
||||
plugins: list[str] = []
|
||||
plugins: List[str] = []
|
||||
for sv in mgr.supervisors:
|
||||
plugins.extend(sv._registered_plugins.keys())
|
||||
return {"success": True, "plugins": plugins}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_component_list_registered_plugins(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_component_list_registered_plugins(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""列出已注册的插件(同 list_loaded)"""
|
||||
mgr = get_plugin_runtime_manager()
|
||||
plugins: list[str] = []
|
||||
plugins: List[str] = []
|
||||
for sv in mgr.supervisors:
|
||||
plugins.extend(sv._registered_plugins.keys())
|
||||
return {"success": True, "plugins": plugins}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_component_enable(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_component_enable(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""启用组件
|
||||
|
||||
args: name, component_type
|
||||
@@ -1419,7 +1417,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": f"未找到组件: {name}"}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_component_disable(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_component_disable(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""禁用组件
|
||||
|
||||
args: name, component_type
|
||||
@@ -1438,7 +1436,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": f"未找到组件: {name}"}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_component_load_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_component_load_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""加载插件(在新运行时中通过热重载实现)
|
||||
|
||||
args: plugin_name
|
||||
@@ -1457,7 +1455,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": f"无法加载插件: {plugin_name}"}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_component_unload_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_component_unload_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""卸载插件(在新运行时中不支持单独卸载)
|
||||
|
||||
args: plugin_name
|
||||
@@ -1465,7 +1463,7 @@ class PluginRuntimeManager:
|
||||
return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"}
|
||||
|
||||
@staticmethod
|
||||
async def _cap_component_reload_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_component_reload_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""重新加载插件(触发对应 Supervisor 的热重载)
|
||||
|
||||
args: plugin_name
|
||||
@@ -1490,7 +1488,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_knowledge_search(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_knowledge_search(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""从 LPMM 知识库搜索知识
|
||||
|
||||
args: query, limit?
|
||||
@@ -1526,7 +1524,7 @@ class PluginRuntimeManager:
|
||||
# ═════════════════════════════════════════════════════════
|
||||
|
||||
@staticmethod
|
||||
async def _cap_logging_log(plugin_id: str, capability: str, args: dict[str, Any]) -> Any:
|
||||
async def _cap_logging_log(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
"""插件日志记录
|
||||
|
||||
args: level?, message
|
||||
@@ -1544,7 +1542,7 @@ class PluginRuntimeManager:
|
||||
|
||||
# ─── 单例 ──────────────────────────────────────────────────
|
||||
|
||||
_manager: PluginRuntimeManager | None = None
|
||||
_manager: Optional[PluginRuntimeManager] = None
|
||||
|
||||
|
||||
def get_plugin_runtime_manager() -> PluginRuntimeManager:
|
||||
|
||||
@@ -1,35 +1,40 @@
|
||||
"""MsgPack 编解码器"""
|
||||
|
||||
from typing import Any
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
|
||||
import msgpack
|
||||
|
||||
from .envelope import Envelope
|
||||
|
||||
|
||||
class Codec:
|
||||
class Codec(ABC):
|
||||
"""消息编解码器基类"""
|
||||
|
||||
@abstractmethod
|
||||
def encode_envelope(self, envelope: Envelope) -> bytes:
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def decode_envelope(self, data: bytes) -> Envelope:
|
||||
raise NotImplementedError
|
||||
...
|
||||
|
||||
def encode(self, obj: dict[str, Any]) -> bytes:
|
||||
raise NotImplementedError
|
||||
@abstractmethod
|
||||
def encode(self, obj: Dict[str, Any]) -> bytes:
|
||||
...
|
||||
|
||||
def decode(self, data: bytes) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
@abstractmethod
|
||||
def decode(self, data: bytes) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
|
||||
class MsgPackCodec(Codec):
|
||||
"""MsgPack 编解码器"""
|
||||
|
||||
def encode(self, obj: dict[str, Any]) -> bytes:
|
||||
def encode(self, obj: Dict[str, Any]) -> bytes:
|
||||
return msgpack.packb(obj, use_bin_type=True)
|
||||
|
||||
def decode(self, data: bytes) -> dict[str, Any]:
|
||||
def decode(self, data: bytes) -> Dict[str, Any]:
|
||||
result = msgpack.unpackb(data, raw=False)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError(f"期望解码为 dict,实际为 {type(result)}")
|
||||
|
||||
@@ -5,9 +5,8 @@
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import time
|
||||
|
||||
@@ -62,8 +61,8 @@ class Envelope(BaseModel):
|
||||
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
|
||||
timeout_ms: int = Field(default=30000, description="相对超时(ms)")
|
||||
generation: int = Field(default=0, description="Runner generation 编号")
|
||||
payload: dict[str, Any] = Field(default_factory=dict, description="业务数据")
|
||||
error: dict[str, Any] | None = Field(default=None, description="错误信息(仅 response)")
|
||||
payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据")
|
||||
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)")
|
||||
|
||||
def is_request(self) -> bool:
|
||||
return self.message_type == MessageType.REQUEST
|
||||
@@ -74,7 +73,7 @@ class Envelope(BaseModel):
|
||||
def is_event(self) -> bool:
|
||||
return self.message_type == MessageType.EVENT
|
||||
|
||||
def make_response(self, payload: dict[str, Any] | None = None, error: dict[str, Any] | None = None) -> "Envelope":
|
||||
def make_response(self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None) -> "Envelope":
|
||||
"""基于当前请求创建对应的响应信封"""
|
||||
return Envelope(
|
||||
protocol_version=self.protocol_version,
|
||||
@@ -87,7 +86,7 @@ class Envelope(BaseModel):
|
||||
error=error,
|
||||
)
|
||||
|
||||
def make_error_response(self, code: str, message: str = "", details: dict | None = None) -> "Envelope":
|
||||
def make_error_response(self, code: str, message: str = "", details: Optional[Dict[str, Any]] = None) -> "Envelope":
|
||||
"""基于当前请求创建错误响应"""
|
||||
return self.make_response(
|
||||
error={
|
||||
@@ -122,15 +121,15 @@ class ComponentDeclaration(BaseModel):
|
||||
name: str = Field(description="组件名称")
|
||||
component_type: str = Field(description="组件类型: action/command/tool/event_handler")
|
||||
plugin_id: str = Field(description="所属插件 ID")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="组件元数据")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
|
||||
|
||||
|
||||
class RegisterComponentsPayload(BaseModel):
|
||||
"""plugin.register_components 请求 payload"""
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
plugin_version: str = Field(default="1.0.0", description="插件版本")
|
||||
components: list[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
||||
capabilities_required: list[str] = Field(default_factory=list, description="所需能力列表")
|
||||
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||
|
||||
|
||||
# ─── 调用消息 ──────────────────────────────────────────────────────
|
||||
@@ -138,7 +137,7 @@ class RegisterComponentsPayload(BaseModel):
|
||||
class InvokePayload(BaseModel):
|
||||
"""plugin.invoke_* 请求 payload"""
|
||||
component_name: str = Field(description="要调用的组件名称")
|
||||
args: dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
|
||||
|
||||
class InvokeResultPayload(BaseModel):
|
||||
@@ -152,7 +151,7 @@ class InvokeResultPayload(BaseModel):
|
||||
class CapabilityRequestPayload(BaseModel):
|
||||
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
|
||||
capability: str = Field(description="能力名称,如 send.text, db.query")
|
||||
args: dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
|
||||
|
||||
|
||||
class CapabilityResponsePayload(BaseModel):
|
||||
@@ -166,7 +165,7 @@ class CapabilityResponsePayload(BaseModel):
|
||||
class HealthPayload(BaseModel):
|
||||
"""plugin.health 响应 payload"""
|
||||
healthy: bool = Field(description="是否健康")
|
||||
loaded_plugins: list[str] = Field(default_factory=list, description="已加载的插件列表")
|
||||
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
|
||||
uptime_ms: int = Field(default=0, description="运行时长(ms)")
|
||||
|
||||
|
||||
@@ -176,7 +175,7 @@ class ConfigUpdatedPayload(BaseModel):
|
||||
"""plugin.config_updated 事件 payload"""
|
||||
plugin_id: str = Field(description="插件 ID")
|
||||
config_version: str = Field(description="新配置版本")
|
||||
config_data: dict[str, Any] = Field(default_factory=dict, description="配置内容")
|
||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
|
||||
|
||||
|
||||
# ─── 关停 ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
|
||||
@@ -38,13 +39,13 @@ class ErrorCode(str, Enum):
|
||||
class RPCError(Exception):
|
||||
"""RPC 调用异常"""
|
||||
|
||||
def __init__(self, code: ErrorCode, message: str = "", details: dict | None = None):
|
||||
def __init__(self, code: ErrorCode, message: str = "", details: Optional[Dict[str, Any]] = None):
|
||||
self.code = code
|
||||
self.message = message or code.value
|
||||
self.details = details or {}
|
||||
super().__init__(f"[{code.value}] {self.message}")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"code": self.code.value,
|
||||
"message": self.message,
|
||||
@@ -52,7 +53,7 @@ class RPCError(Exception):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "RPCError":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "RPCError":
|
||||
code = ErrorCode(data.get("code", "E_UNKNOWN"))
|
||||
return cls(
|
||||
code=code,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
适配新 plugin_runtime 的 _manifest.json 格式。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import re
|
||||
|
||||
@@ -29,7 +29,7 @@ class VersionComparator:
|
||||
return ".".join(parts[:3])
|
||||
|
||||
@staticmethod
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
def parse_version(version: str) -> Tuple[int, int, int]:
|
||||
normalized = VersionComparator.normalize_version(version)
|
||||
try:
|
||||
parts = normalized.split(".")
|
||||
@@ -48,7 +48,7 @@ class VersionComparator:
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def is_in_range(version: str, min_version: str = "", max_version: str = "") -> tuple[bool, str]:
|
||||
def is_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
|
||||
if not min_version and not max_version:
|
||||
return True, ""
|
||||
vn = VersionComparator.normalize_version(version)
|
||||
@@ -72,10 +72,10 @@ class ManifestValidator:
|
||||
|
||||
def __init__(self, host_version: str = ""):
|
||||
self._host_version = host_version
|
||||
self.errors: list[str] = []
|
||||
self.warnings: list[str] = []
|
||||
self.errors: List[str] = []
|
||||
self.warnings: List[str] = []
|
||||
|
||||
def validate(self, manifest: dict[str, Any]) -> bool:
|
||||
def validate(self, manifest: Dict[str, Any]) -> bool:
|
||||
"""校验 manifest 数据,返回是否通过(errors 为空即通过)。"""
|
||||
self.errors.clear()
|
||||
self.warnings.clear()
|
||||
@@ -95,21 +95,21 @@ class ManifestValidator:
|
||||
|
||||
return len(self.errors) == 0
|
||||
|
||||
def _check_required_fields(self, manifest: dict[str, Any]) -> None:
|
||||
def _check_required_fields(self, manifest: Dict[str, Any]) -> None:
|
||||
for field in self.REQUIRED_FIELDS:
|
||||
if field not in manifest:
|
||||
self.errors.append(f"缺少必需字段: {field}")
|
||||
elif not manifest[field]:
|
||||
self.errors.append(f"必需字段不能为空: {field}")
|
||||
|
||||
def _check_manifest_version(self, manifest: dict[str, Any]) -> None:
|
||||
def _check_manifest_version(self, manifest: Dict[str, Any]) -> None:
|
||||
mv = manifest.get("manifest_version")
|
||||
if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
|
||||
self.errors.append(
|
||||
f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}"
|
||||
)
|
||||
|
||||
def _check_author(self, manifest: dict[str, Any]) -> None:
|
||||
def _check_author(self, manifest: Dict[str, Any]) -> None:
|
||||
author = manifest.get("author")
|
||||
if author is None:
|
||||
return
|
||||
@@ -122,7 +122,7 @@ class ManifestValidator:
|
||||
else:
|
||||
self.errors.append("author 应为字符串或 {name, url} 对象")
|
||||
|
||||
def _check_host_compatibility(self, manifest: dict[str, Any]) -> None:
|
||||
def _check_host_compatibility(self, manifest: Dict[str, Any]) -> None:
|
||||
host_app = manifest.get("host_application")
|
||||
if not isinstance(host_app, dict) or not self._host_version:
|
||||
return
|
||||
@@ -132,7 +132,7 @@ class ManifestValidator:
|
||||
if not ok:
|
||||
self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})")
|
||||
|
||||
def _check_recommended(self, manifest: dict[str, Any]) -> None:
|
||||
def _check_recommended(self, manifest: Dict[str, Any]) -> None:
|
||||
for field in self.RECOMMENDED_FIELDS:
|
||||
if field not in manifest or not manifest[field]:
|
||||
self.warnings.append(f"建议填写字段: {field}")
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
@@ -29,7 +29,7 @@ class PluginMeta:
|
||||
plugin_id: str,
|
||||
plugin_dir: str,
|
||||
plugin_instance: Any,
|
||||
manifest: dict[str, Any],
|
||||
manifest: Dict[str, Any],
|
||||
):
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_dir = plugin_dir
|
||||
@@ -37,12 +37,12 @@ class PluginMeta:
|
||||
self.manifest = manifest
|
||||
self.version = manifest.get("version", "1.0.0")
|
||||
self.capabilities_required = manifest.get("capabilities", [])
|
||||
self.dependencies: list[str] = self._extract_dependencies(manifest)
|
||||
self.dependencies: List[str] = self._extract_dependencies(manifest)
|
||||
|
||||
@staticmethod
|
||||
def _extract_dependencies(manifest: dict[str, Any]) -> list[str]:
|
||||
def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]:
|
||||
raw = manifest.get("dependencies", [])
|
||||
result: list[str] = []
|
||||
result: List[str] = []
|
||||
for dep in raw:
|
||||
if isinstance(dep, str):
|
||||
result.append(dep.strip())
|
||||
@@ -62,12 +62,12 @@ class PluginLoader:
|
||||
"""
|
||||
|
||||
def __init__(self, host_version: str = ""):
|
||||
self._loaded_plugins: dict[str, PluginMeta] = {}
|
||||
self._failed_plugins: dict[str, str] = {}
|
||||
self._loaded_plugins: Dict[str, PluginMeta] = {}
|
||||
self._failed_plugins: Dict[str, str] = {}
|
||||
self._manifest_validator = ManifestValidator(host_version=host_version)
|
||||
self._compat_hook_installed = False
|
||||
|
||||
def discover_and_load(self, plugin_dirs: list[str]) -> list[PluginMeta]:
|
||||
def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
|
||||
"""扫描多个目录并加载所有插件(含依赖排序和 manifest 校验)
|
||||
|
||||
Args:
|
||||
@@ -77,7 +77,7 @@ class PluginLoader:
|
||||
成功加载的插件元数据列表(按依赖顺序)
|
||||
"""
|
||||
# 第一阶段:发现并校验 manifest
|
||||
candidates: dict[str, tuple[str, dict[str, Any], str]] = {} # id -> (dir, manifest, plugin_path)
|
||||
candidates: Dict[str, Tuple[str, Dict[str, Any], str]] = {} # id -> (dir, manifest, plugin_path)
|
||||
for base_dir in plugin_dirs:
|
||||
if not os.path.isdir(base_dir):
|
||||
logger.warning(f"插件目录不存在: {base_dir}")
|
||||
@@ -131,33 +131,33 @@ class PluginLoader:
|
||||
|
||||
return results
|
||||
|
||||
def get_plugin(self, plugin_id: str) -> PluginMeta | None:
|
||||
def get_plugin(self, plugin_id: str) -> Optional[PluginMeta]:
|
||||
"""获取已加载的插件"""
|
||||
return self._loaded_plugins.get(plugin_id)
|
||||
|
||||
def list_plugins(self) -> list[str]:
|
||||
def list_plugins(self) -> List[str]:
|
||||
"""列出所有已加载的插件 ID"""
|
||||
return list(self._loaded_plugins.keys())
|
||||
|
||||
@property
|
||||
def failed_plugins(self) -> dict[str, str]:
|
||||
def failed_plugins(self) -> Dict[str, str]:
|
||||
return dict(self._failed_plugins)
|
||||
|
||||
# ──── 依赖解析 ────────────────────────────────────────────
|
||||
|
||||
def _resolve_dependencies(
|
||||
self,
|
||||
candidates: dict[str, tuple[str, dict[str, Any], str]],
|
||||
) -> tuple[list[str], dict[str, str]]:
|
||||
candidates: Dict[str, Tuple[str, Dict[str, Any], str]],
|
||||
) -> Tuple[List[str], Dict[str, str]]:
|
||||
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
|
||||
available = set(candidates.keys())
|
||||
dep_graph: dict[str, set[str]] = {}
|
||||
failed: dict[str, str] = {}
|
||||
dep_graph: Dict[str, Set[str]] = {}
|
||||
failed: Dict[str, str] = {}
|
||||
|
||||
for pid, (_, manifest, _) in candidates.items():
|
||||
raw_deps = manifest.get("dependencies", [])
|
||||
resolved: set[str] = set()
|
||||
missing: list[str] = []
|
||||
resolved: Set[str] = set()
|
||||
missing: List[str] = []
|
||||
for dep in raw_deps:
|
||||
dep_name = dep if isinstance(dep, str) else str(dep.get("name", ""))
|
||||
dep_name = dep_name.strip()
|
||||
@@ -177,14 +177,14 @@ class PluginLoader:
|
||||
|
||||
# Kahn 拓扑排序
|
||||
indegree = {pid: len(deps) for pid, deps in dep_graph.items()}
|
||||
reverse: dict[str, set[str]] = {pid: set() for pid in dep_graph}
|
||||
reverse: Dict[str, Set[str]] = {pid: set() for pid in dep_graph}
|
||||
for pid, deps in dep_graph.items():
|
||||
for d in deps:
|
||||
if d in reverse:
|
||||
reverse[d].add(pid)
|
||||
|
||||
queue = deque(sorted(pid for pid, deg in indegree.items() if deg == 0))
|
||||
sorted_order: list[str] = []
|
||||
sorted_order: List[str] = []
|
||||
|
||||
while queue:
|
||||
current = queue.popleft()
|
||||
@@ -206,9 +206,9 @@ class PluginLoader:
|
||||
self,
|
||||
plugin_id: str,
|
||||
plugin_dir: str,
|
||||
manifest: dict[str, Any],
|
||||
manifest: Dict[str, Any],
|
||||
plugin_path: str,
|
||||
) -> PluginMeta | None:
|
||||
) -> Optional[PluginMeta]:
|
||||
"""加载单个插件"""
|
||||
# 确保兼容层导入钩子已安装(旧版插件可能 import src.plugin_system)
|
||||
self._ensure_compat_hook()
|
||||
@@ -267,7 +267,7 @@ class PluginLoader:
|
||||
logger.debug("maibot_sdk.compat 不可用,跳过导入钩子安装")
|
||||
|
||||
@staticmethod
|
||||
def _try_load_legacy_plugin(module: Any, plugin_id: str) -> Any | None:
|
||||
def _try_load_legacy_plugin(module: Any, plugin_id: str) -> Optional[Any]:
|
||||
"""尝试从模块中发现旧版 BasePlugin 子类并包装为 LegacyPluginAdapter"""
|
||||
# 方式 1: @register_plugin 装饰器设置的标记
|
||||
legacy_cls = getattr(module, "_legacy_plugin_class", None)
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
5. 发送能力调用请求到 Host
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Awaitable
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
@@ -45,26 +45,26 @@ class RPCClient:
|
||||
self,
|
||||
host_address: str,
|
||||
session_token: str,
|
||||
codec: Codec | None = None,
|
||||
codec: Optional[Codec] = None,
|
||||
):
|
||||
self._host_address = host_address
|
||||
self._session_token = session_token
|
||||
self._codec = codec or MsgPackCodec()
|
||||
|
||||
self._id_gen = RequestIdGenerator()
|
||||
self._connection: Connection | None = None
|
||||
self._connection: Optional[Connection] = None
|
||||
self._runner_id = str(uuid.uuid4())
|
||||
self._generation: int = 0
|
||||
|
||||
# 方法处理器注册表(Host 发来的调用)
|
||||
self._method_handlers: dict[str, MethodHandler] = {}
|
||||
self._method_handlers: Dict[str, MethodHandler] = {}
|
||||
|
||||
# 等待响应的 pending 请求: request_id -> Future
|
||||
self._pending_requests: dict[int, asyncio.Future] = {}
|
||||
self._pending_requests: Dict[int, asyncio.Future] = {}
|
||||
|
||||
# 运行状态
|
||||
self._running = False
|
||||
self._recv_task: asyncio.Task | None = None
|
||||
self._recv_task: Optional[asyncio.Task] = None
|
||||
|
||||
@property
|
||||
def generation(self) -> int:
|
||||
@@ -147,7 +147,7 @@ class RPCClient:
|
||||
self,
|
||||
method: str,
|
||||
plugin_id: str = "",
|
||||
payload: dict[str, Any] | None = None,
|
||||
payload: Optional[Dict[str, Any]] = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Envelope:
|
||||
"""向 Host 发送 RPC 请求并等待响应"""
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
6. 转发插件的能力调用到 Host
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import inspect
|
||||
@@ -43,7 +45,7 @@ class PluginRunner:
|
||||
self,
|
||||
host_address: str,
|
||||
session_token: str,
|
||||
plugin_dirs: list[str],
|
||||
plugin_dirs: List[str],
|
||||
) -> None:
|
||||
self._host_address: str = host_address
|
||||
self._session_token: str = session_token
|
||||
@@ -114,7 +116,7 @@ class PluginRunner:
|
||||
async def _register_plugin(self, meta: PluginMeta) -> None:
|
||||
"""向 Host 注册单个插件"""
|
||||
# 收集插件组件声明
|
||||
components: list[ComponentDeclaration] = []
|
||||
components: List[ComponentDeclaration] = []
|
||||
instance = meta.instance
|
||||
|
||||
# 从插件实例获取组件声明(SDK 插件须实现 get_components 方法)
|
||||
@@ -284,7 +286,7 @@ class PluginRunner:
|
||||
|
||||
# ─── sys.path 隔离 ────────────────────────────────────────
|
||||
|
||||
def _isolate_sys_path(plugin_dirs: list[str]) -> None:
|
||||
def _isolate_sys_path(plugin_dirs: List[str]) -> None:
|
||||
"""清理 sys.path,限制 Runner 子进程只能访问标准库、SDK 和插件目录。
|
||||
|
||||
防止插件代码 import 主程序模块读取运行时数据。
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterator, Callable, Awaitable
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
根据运行平台自动选择最优传输实现。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import sys
|
||||
|
||||
from .base import TransportClient, TransportServer
|
||||
|
||||
|
||||
def create_transport_server(socket_path: str | None = None) -> TransportServer:
|
||||
def create_transport_server(socket_path: Optional[str] = None) -> TransportServer:
|
||||
"""创建传输服务端
|
||||
|
||||
Linux/macOS 使用 UDS,Windows 使用 TCP 回退。
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
|
||||
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
|
||||
@@ -20,7 +22,7 @@ class TCPTransportServer(TransportServer):
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 0):
|
||||
self._host = host
|
||||
self._port = port # 0 表示自动分配
|
||||
self._server: asyncio.AbstractServer | None = None
|
||||
self._server: Optional[asyncio.AbstractServer] = None
|
||||
self._actual_port: int = 0
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
@@ -20,13 +21,13 @@ class UDSConnection(Connection):
|
||||
class UDSTransportServer(TransportServer):
|
||||
"""UDS 传输服务端"""
|
||||
|
||||
def __init__(self, socket_path: str | None = None):
|
||||
def __init__(self, socket_path: Optional[str] = None):
|
||||
if socket_path is None:
|
||||
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
|
||||
import uuid
|
||||
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
|
||||
self._socket_path = socket_path
|
||||
self._server: asyncio.AbstractServer | None = None
|
||||
self._server: Optional[asyncio.AbstractServer] = None
|
||||
|
||||
async def start(self, handler: ConnectionHandler) -> None:
|
||||
# 清理残留 socket 文件
|
||||
|
||||
Reference in New Issue
Block a user