Refactor protocol and transport modules to use type hints for improved clarity and consistency

- Updated Codec class to use abstract methods for encoding and decoding envelopes.
- Changed Envelope class to use Dict and Optional for payload and error fields.
- Refined error handling in RPCError class with Optional type hints for details.
- Enhanced manifest validation logic with type hints for better type safety.
- Improved plugin loading mechanism with consistent type annotations.
- Updated RPCClient to utilize Optional for codec and connection attributes.
- Refactored transport classes to use Optional for server attributes and socket paths.
This commit is contained in:
DrSmoothl
2026-03-11 00:07:13 +08:00
parent 7f1e79ea28
commit 69219e36f7
19 changed files with 273 additions and 253 deletions

View File

@@ -4,22 +4,21 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
每个能力方法被注册到 RPC Server接收 Runner 转发的请求并执行实际操作。 每个能力方法被注册到 RPC Server接收 Runner 转发的请求并执行实际操作。
""" """
from typing import Any, Callable, Awaitable from typing import Any, Awaitable, Callable, Dict, List
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_runtime.host.policy_engine import PolicyEngine
from src.plugin_runtime.protocol.envelope import ( from src.plugin_runtime.protocol.envelope import (
CapabilityRequestPayload, CapabilityRequestPayload,
CapabilityResponsePayload, CapabilityResponsePayload,
Envelope, Envelope,
) )
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
from src.plugin_runtime.host.policy_engine import PolicyEngine
logger = get_logger("plugin_runtime.host.capability_service") logger = get_logger("plugin_runtime.host.capability_service")
# 能力实现函数类型: (plugin_id, capability, args) -> result # 能力实现函数类型: (plugin_id, capability, args) -> result
CapabilityImpl = Callable[[str, str, dict[str, Any]], Awaitable[Any]] CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
class CapabilityService: class CapabilityService:
@@ -35,7 +34,7 @@ class CapabilityService:
def __init__(self, policy_engine: PolicyEngine): def __init__(self, policy_engine: PolicyEngine):
self._policy = policy_engine self._policy = policy_engine
# capability_name -> implementation # capability_name -> implementation
self._implementations: dict[str, CapabilityImpl] = {} self._implementations: Dict[str, CapabilityImpl] = {}
def register_capability(self, name: str, impl: CapabilityImpl) -> None: def register_capability(self, name: str, impl: CapabilityImpl) -> None:
"""注册一个能力实现 """注册一个能力实现
@@ -95,6 +94,6 @@ class CapabilityService:
str(e), str(e),
) )
def list_capabilities(self) -> list[str]: def list_capabilities(self) -> List[str]:
"""列出所有已注册的能力""" """列出所有已注册的能力"""
return list(self._implementations.keys()) return list(self._implementations.keys())

View File

@@ -9,9 +9,10 @@
- 注册统计 - 注册统计
""" """
from typing import Any from typing import Any, Dict, List, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
import re import re
logger = get_logger("plugin_runtime.host.component_registry") logger = get_logger("plugin_runtime.host.component_registry")
@@ -30,7 +31,7 @@ class RegisteredComponent:
name: str, name: str,
component_type: str, component_type: str,
plugin_id: str, plugin_id: str,
metadata: dict[str, Any], metadata: Dict[str, Any],
): ):
self.name = name self.name = name
self.full_name = f"{plugin_id}.{name}" self.full_name = f"{plugin_id}.{name}"
@@ -40,7 +41,7 @@ class RegisteredComponent:
self.enabled = metadata.get("enabled", True) self.enabled = metadata.get("enabled", True)
# 预编译命令正则(仅 command 类型) # 预编译命令正则(仅 command 类型)
self._compiled_pattern: re.Pattern | None = None self._compiled_pattern: Optional[re.Pattern] = None
if component_type == "command": if component_type == "command":
if pattern := metadata.get("command_pattern", ""): if pattern := metadata.get("command_pattern", ""):
try: try:
@@ -58,10 +59,10 @@ class ComponentRegistry:
def __init__(self): def __init__(self):
# 全量索引 # 全量索引
self._components: dict[str, RegisteredComponent] = {} # full_name -> comp self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
# 按类型索引 # 按类型索引
self._by_type: dict[str, dict[str, RegisteredComponent]] = { self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
"action": {}, "action": {},
"command": {}, "command": {},
"tool": {}, "tool": {},
@@ -70,7 +71,7 @@ class ComponentRegistry:
} }
# 按插件索引 # 按插件索引
self._by_plugin: dict[str, list[RegisteredComponent]] = {} self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
# ──── 注册 / 注销 ───────────────────────────────────────── # ──── 注册 / 注销 ─────────────────────────────────────────
@@ -79,7 +80,7 @@ class ComponentRegistry:
name: str, name: str,
component_type: str, component_type: str,
plugin_id: str, plugin_id: str,
metadata: dict[str, Any], metadata: Dict[str, Any],
) -> bool: ) -> bool:
"""注册单个组件。""" """注册单个组件。"""
comp = RegisteredComponent(name, component_type, plugin_id, metadata) comp = RegisteredComponent(name, component_type, plugin_id, metadata)
@@ -99,7 +100,7 @@ class ComponentRegistry:
def register_plugin_components( def register_plugin_components(
self, self,
plugin_id: str, plugin_id: str,
components: list[dict[str, Any]], components: List[Dict[str, Any]],
) -> int: ) -> int:
"""批量注册一个插件的所有组件,返回成功注册数。""" """批量注册一个插件的所有组件,返回成功注册数。"""
count = 0 count = 0
@@ -142,13 +143,13 @@ class ComponentRegistry:
# ──── 查询方法 ───────────────────────────────────────────── # ──── 查询方法 ─────────────────────────────────────────────
def get_component(self, full_name: str) -> RegisteredComponent | None: def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
"""按全名查询。""" """按全名查询。"""
return self._components.get(full_name) return self._components.get(full_name)
def get_components_by_type( def get_components_by_type(
self, component_type: str, *, enabled_only: bool = True self, component_type: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]: ) -> List[RegisteredComponent]:
"""按类型查询。""" """按类型查询。"""
type_dict = self._by_type.get(component_type, {}) type_dict = self._by_type.get(component_type, {})
if enabled_only: if enabled_only:
@@ -157,12 +158,12 @@ class ComponentRegistry:
def get_components_by_plugin( def get_components_by_plugin(
self, plugin_id: str, *, enabled_only: bool = True self, plugin_id: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]: ) -> List[RegisteredComponent]:
"""按插件查询。""" """按插件查询。"""
comps = self._by_plugin.get(plugin_id, []) comps = self._by_plugin.get(plugin_id, [])
return [c for c in comps if c.enabled] if enabled_only else list(comps) return [c for c in comps if c.enabled] if enabled_only else list(comps)
def find_command_by_text(self, text: str) -> RegisteredComponent | None: def find_command_by_text(self, text: str) -> Optional[RegisteredComponent]:
"""通过文本匹配命令正则,返回第一个匹配的 command 组件。""" """通过文本匹配命令正则,返回第一个匹配的 command 组件。"""
for comp in self._by_type.get("command", {}).values(): for comp in self._by_type.get("command", {}).values():
if not comp.enabled: if not comp.enabled:
@@ -178,7 +179,7 @@ class ComponentRegistry:
def get_event_handlers( def get_event_handlers(
self, event_type: str, *, enabled_only: bool = True self, event_type: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]: ) -> List[RegisteredComponent]:
"""获取特定事件类型的所有 event_handler按 weight 降序排列。""" """获取特定事件类型的所有 event_handler按 weight 降序排列。"""
handlers = [] handlers = []
for comp in self._by_type.get("event_handler", {}).values(): for comp in self._by_type.get("event_handler", {}).values():
@@ -191,7 +192,7 @@ class ComponentRegistry:
def get_workflow_steps( def get_workflow_steps(
self, stage: str, *, enabled_only: bool = True self, stage: str, *, enabled_only: bool = True
) -> list[RegisteredComponent]: ) -> List[RegisteredComponent]:
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。""" """获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
steps = [] steps = []
for comp in self._by_type.get("workflow_step", {}).values(): for comp in self._by_type.get("workflow_step", {}).values():
@@ -202,11 +203,11 @@ class ComponentRegistry:
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True) steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
return steps return steps
def get_tools_for_llm(self, *, enabled_only: bool = True) -> list[dict[str, Any]]: def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
"""获取可供 LLM 使用的工具列表openai function-calling 格式预览)。""" """获取可供 LLM 使用的工具列表openai function-calling 格式预览)。"""
result = [] result: List[Dict[str, Any]] = []
for comp in self.get_components_by_type("tool", enabled_only=enabled_only): for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
tool_def: dict[str, Any] = { tool_def: Dict[str, Any] = {
"name": comp.full_name, "name": comp.full_name,
"description": comp.metadata.get("description", ""), "description": comp.metadata.get("description", ""),
} }
@@ -222,9 +223,9 @@ class ComponentRegistry:
# ──── 统计 ───────────────────────────────────────────────── # ──── 统计 ─────────────────────────────────────────────────
def get_stats(self) -> dict[str, int]: def get_stats(self) -> Dict[str, int]:
"""获取注册统计。""" """获取注册统计。"""
stats: dict[str, int] = {"total": len(self._components)} stats: Dict[str, int] = {"total": len(self._components)}
for comp_type, type_dict in self._by_type.items(): for comp_type, type_dict in self._by_type.items():
stats[comp_type] = len(type_dict) stats[comp_type] = len(type_dict)
stats["plugins"] = len(self._by_plugin) stats["plugins"] = len(self._by_plugin)

View File

@@ -7,7 +7,7 @@
4. 事件结果历史记录 4. 事件结果历史记录
""" """
from typing import Any, Awaitable, Callable from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
import asyncio import asyncio
@@ -17,7 +17,7 @@ from src.plugin_runtime.host.component_registry import ComponentRegistry, Regist
logger = get_logger("plugin_runtime.host.event_dispatcher") logger = get_logger("plugin_runtime.host.event_dispatcher")
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict # invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]] InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
class EventResult: class EventResult:
@@ -29,7 +29,7 @@ class EventResult:
handler_name: str, handler_name: str,
success: bool = True, success: bool = True,
continue_processing: bool = True, continue_processing: bool = True,
modified_message: dict[str, Any] | None = None, modified_message: Optional[Dict[str, Any]] = None,
custom_result: Any = None, custom_result: Any = None,
): ):
self.handler_name = handler_name self.handler_name = handler_name
@@ -49,14 +49,14 @@ class EventDispatcher:
def __init__(self, registry: ComponentRegistry) -> None: def __init__(self, registry: ComponentRegistry) -> None:
self._registry: ComponentRegistry = registry self._registry: ComponentRegistry = registry
self._result_history: dict[str, list[EventResult]] = {} self._result_history: Dict[str, List[EventResult]] = {}
self._history_enabled: set[str] = set() self._history_enabled: Set[str] = set()
def enable_history(self, event_type: str) -> None: def enable_history(self, event_type: str) -> None:
self._history_enabled.add(event_type) self._history_enabled.add(event_type)
self._result_history.setdefault(event_type, []) self._result_history.setdefault(event_type, [])
def get_history(self, event_type: str) -> list[EventResult]: def get_history(self, event_type: str) -> List[EventResult]:
return self._result_history.get(event_type, []) return self._result_history.get(event_type, [])
def clear_history(self, event_type: str) -> None: def clear_history(self, event_type: str) -> None:
@@ -67,9 +67,9 @@ class EventDispatcher:
self, self,
event_type: str, event_type: str,
invoke_fn: InvokeFn, invoke_fn: InvokeFn,
message: dict[str, Any] | None = None, message: Optional[Dict[str, Any]] = None,
extra_args: dict[str, Any] | None = None, extra_args: Optional[Dict[str, Any]] = None,
) -> tuple[bool, dict[str, Any] | None]: ) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""分发事件到所有对应 handler。 """分发事件到所有对应 handler。
Args: Args:
@@ -86,8 +86,8 @@ class EventDispatcher:
return True, None return True, None
should_continue = True should_continue = True
modified_message: dict[str, Any] | None = None modified_message: Optional[Dict[str, Any]] = None
fire_and_forget_tasks: list[asyncio.Task] = [] fire_and_forget_tasks: List[asyncio.Task] = []
for handler in handlers: for handler in handlers:
intercept = handler.metadata.get("intercept_message", False) intercept = handler.metadata.get("intercept_message", False)
@@ -122,9 +122,9 @@ class EventDispatcher:
self, self,
invoke_fn: InvokeFn, invoke_fn: InvokeFn,
handler: RegisteredComponent, handler: RegisteredComponent,
args: dict[str, Any], args: Dict[str, Any],
event_type: str, event_type: str,
) -> EventResult | None: ) -> Optional[EventResult]:
"""调用单个 handler 并收集结果。""" """调用单个 handler 并收集结果。"""
try: try:
resp = await invoke_fn(handler.plugin_id, handler.name, args) resp = await invoke_fn(handler.plugin_id, handler.name, args)

View File

@@ -5,6 +5,7 @@
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
@dataclass @dataclass
@@ -12,7 +13,7 @@ class CapabilityToken:
"""能力令牌""" """能力令牌"""
plugin_id: str plugin_id: str
generation: int generation: int
capabilities: set[str] = field(default_factory=set) capabilities: Set[str] = field(default_factory=set)
class PolicyEngine: class PolicyEngine:
@@ -22,13 +23,13 @@ class PolicyEngine:
""" """
def __init__(self): def __init__(self):
self._tokens: dict[str, CapabilityToken] = {} self._tokens: Dict[str, CapabilityToken] = {}
def register_plugin( def register_plugin(
self, self,
plugin_id: str, plugin_id: str,
generation: int, generation: int,
capabilities: list[str], capabilities: List[str],
) -> CapabilityToken: ) -> CapabilityToken:
"""为插件签发能力令牌""" """为插件签发能力令牌"""
token = CapabilityToken( token = CapabilityToken(
@@ -43,7 +44,7 @@ class PolicyEngine:
"""撤销插件的能力令牌""" """撤销插件的能力令牌"""
self._tokens.pop(plugin_id, None) self._tokens.pop(plugin_id, None)
def check_capability(self, plugin_id: str, capability: str) -> tuple[bool, str]: def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
"""检查插件是否有权调用某项能力 """检查插件是否有权调用某项能力
Returns: Returns:
@@ -58,10 +59,10 @@ class PolicyEngine:
return True, "" return True, ""
def get_token(self, plugin_id: str) -> CapabilityToken | None: def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
"""获取插件的能力令牌""" """获取插件的能力令牌"""
return self._tokens.get(plugin_id) return self._tokens.get(plugin_id)
def list_plugins(self) -> list[str]: def list_plugins(self) -> List[str]:
"""列出所有已注册的插件""" """列出所有已注册的插件"""
return list(self._tokens.keys()) return list(self._tokens.keys())

View File

@@ -7,7 +7,7 @@
4. 请求-响应关联与超时管理 4. 请求-响应关联与超时管理
""" """
from typing import Any, Callable, Awaitable from typing import Any, Awaitable, Callable, Dict, List, Optional
import asyncio import asyncio
import secrets import secrets
@@ -42,8 +42,8 @@ class RPCServer:
def __init__( def __init__(
self, self,
transport: TransportServer, transport: TransportServer,
session_token: str | None = None, session_token: Optional[str] = None,
codec: Codec | None = None, codec: Optional[Codec] = None,
send_queue_size: int = 128, send_queue_size: int = 128,
): ):
self._transport = transport self._transport = transport
@@ -52,22 +52,22 @@ class RPCServer:
self._send_queue_size = send_queue_size self._send_queue_size = send_queue_size
self._id_gen = RequestIdGenerator() self._id_gen = RequestIdGenerator()
self._connection: Connection | None = None # 当前活跃的 Runner 连接 self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
self._runner_id: str | None = None self._runner_id: Optional[str] = None
self._runner_generation: int = 0 self._runner_generation: int = 0
# 方法处理器注册表 # 方法处理器注册表
self._method_handlers: dict[str, MethodHandler] = {} self._method_handlers: Dict[str, MethodHandler] = {}
# 等待响应的 pending 请求: request_id -> Future # 等待响应的 pending 请求: request_id -> Future
self._pending_requests: dict[int, asyncio.Future] = {} self._pending_requests: Dict[int, asyncio.Future] = {}
# 发送队列(背压控制) # 发送队列(背压控制)
self._send_queue: asyncio.Queue[bytes] | None = None self._send_queue: Optional[asyncio.Queue[bytes]] = None
# 运行状态 # 运行状态
self._running: bool = False self._running: bool = False
self._tasks: list[asyncio.Task] = [] self._tasks: List[asyncio.Task] = []
@property @property
def session_token(self) -> str: def session_token(self) -> str:
@@ -115,7 +115,7 @@ class RPCServer:
self, self,
method: str, method: str,
plugin_id: str = "", plugin_id: str = "",
payload: dict[str, Any] | None = None, payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000, timeout_ms: int = 30000,
) -> Envelope: ) -> Envelope:
"""向 Runner 发送 RPC 请求并等待响应 """向 Runner 发送 RPC 请求并等待响应
@@ -172,7 +172,7 @@ class RPCServer:
raise raise
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
async def send_event(self, method: str, plugin_id: str = "", payload: dict[str, Any] | None = None) -> None: async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
"""向 Runner 发送单向事件(不等待响应)""" """向 Runner 发送单向事件(不等待响应)"""
if not self.is_connected: if not self.is_connected:
return return

View File

@@ -7,7 +7,7 @@
4. 优雅关停 4. 优雅关停
""" """
from typing import Any from typing import Any, Dict, List, Optional, Tuple
import asyncio import asyncio
import os import os
@@ -40,8 +40,8 @@ class PluginSupervisor:
def __init__( def __init__(
self, self,
plugin_dirs: list[str] | None = None, plugin_dirs: Optional[List[str]] = None,
socket_path: str | None = None, socket_path: Optional[str] = None,
health_check_interval_sec: float = 30.0, health_check_interval_sec: float = 30.0,
): ):
self._plugin_dirs = plugin_dirs or [] self._plugin_dirs = plugin_dirs or []
@@ -65,16 +65,16 @@ class PluginSupervisor:
) )
# Runner 子进程 # Runner 子进程
self._runner_process: asyncio.subprocess.Process | None = None self._runner_process: Optional[asyncio.subprocess.Process] = None
self._runner_generation: int = 0 self._runner_generation: int = 0
self._max_restart_attempts: int = 3 self._max_restart_attempts: int = 3
self._restart_count: int = 0 self._restart_count: int = 0
# 已注册的插件组件信息 # 已注册的插件组件信息
self._registered_plugins: dict[str, RegisterComponentsPayload] = {} self._registered_plugins: Dict[str, RegisterComponentsPayload] = {}
# 后台任务 # 后台任务
self._health_task: asyncio.Task | None = None self._health_task: Optional[asyncio.Task] = None
self._running = False self._running = False
# 注册内部 RPC 方法 # 注册内部 RPC 方法
@@ -107,11 +107,11 @@ class PluginSupervisor:
async def dispatch_event( async def dispatch_event(
self, self,
event_type: str, event_type: str,
message: dict[str, Any] | None = None, message: Optional[Dict[str, Any]] = None,
extra_args: dict[str, Any] | None = None, extra_args: Optional[Dict[str, Any]] = None,
) -> tuple[bool, dict[str, Any] | None]: ) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""分发事件到所有对应 handler 的快捷方法。""" """分发事件到所有对应 handler 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]: async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
resp = await self.invoke_plugin( resp = await self.invoke_plugin(
method="plugin.emit_event", method="plugin.emit_event",
plugin_id=plugin_id, plugin_id=plugin_id,
@@ -129,12 +129,12 @@ class PluginSupervisor:
async def execute_workflow( async def execute_workflow(
self, self,
message: dict[str, Any] | None = None, message: Optional[Dict[str, Any]] = None,
stream_id: str | None = None, stream_id: Optional[str] = None,
context: WorkflowContext | None = None, context: Optional[WorkflowContext] = None,
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]: ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
"""执行 Workflow Pipeline 的快捷方法。""" """执行 Workflow Pipeline 的快捷方法。"""
async def _invoke(plugin_id: str, component_name: str, args: dict[str, Any]) -> dict[str, Any]: async def _invoke(plugin_id: str, component_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
resp = await self.invoke_plugin( resp = await self.invoke_plugin(
method="plugin.invoke_workflow_step", method="plugin.invoke_workflow_step",
plugin_id=plugin_id, plugin_id=plugin_id,
@@ -196,7 +196,7 @@ class PluginSupervisor:
method: str, method: str,
plugin_id: str, plugin_id: str,
component_name: str, component_name: str,
args: dict[str, Any] | None = None, args: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000, timeout_ms: int = 30000,
) -> Envelope: ) -> Envelope:
"""调用插件组件 """调用插件组件
@@ -225,6 +225,12 @@ class PluginSupervisor:
# 保存旧进程引用 # 保存旧进程引用
old_process = self._runner_process old_process = self._runner_process
# 清理旧的组件注册,防止幽灵组件残留
for plugin_id in list(self._registered_plugins.keys()):
self._component_registry.remove_components_by_plugin(plugin_id)
self._policy.revoke_plugin(plugin_id)
self._registered_plugins.clear()
# 拉起新 Runner # 拉起新 Runner
await self._spawn_runner() await self._spawn_runner()
@@ -286,7 +292,7 @@ class PluginSupervisor:
# 在策略引擎中注册插件 # 在策略引擎中注册插件
self._policy.register_plugin( self._policy.register_plugin(
plugin_id=reg.plugin_id, plugin_id=reg.plugin_id,
generation=envelope.generation, generation=self._runner_generation,
capabilities=reg.capabilities_required or [], capabilities=reg.capabilities_required or [],
) )

View File

@@ -17,7 +17,7 @@
- modification_log: 消息修改审计 - modification_log: 消息修改审计
""" """
from typing import Any, Awaitable, Callable from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
import asyncio import asyncio
import time import time
@@ -29,7 +29,7 @@ from src.plugin_runtime.host.component_registry import ComponentRegistry, Regist
logger = get_logger("plugin_runtime.host.workflow_executor") logger = get_logger("plugin_runtime.host.workflow_executor")
# 阶段顺序 # 阶段顺序
STAGE_SEQUENCE: list[str] = [ STAGE_SEQUENCE: List[str] = [
"ingress", "ingress",
"pre_process", "pre_process",
"plan", "plan",
@@ -48,7 +48,7 @@ class ModificationRecord:
"""消息修改记录""" """消息修改记录"""
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed") __slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
def __init__(self, stage: str, hook_name: str, fields_changed: list[str]): def __init__(self, stage: str, hook_name: str, fields_changed: List[str]):
self.stage = stage self.stage = stage
self.hook_name = hook_name self.hook_name = hook_name
self.timestamp = time.perf_counter() self.timestamp = time.perf_counter()
@@ -58,17 +58,17 @@ class ModificationRecord:
class WorkflowContext: class WorkflowContext:
"""Workflow 执行上下文""" """Workflow 执行上下文"""
def __init__(self, trace_id: str | None = None, stream_id: str | None = None): def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None):
self.trace_id = trace_id or uuid.uuid4().hex self.trace_id = trace_id or uuid.uuid4().hex
self.stream_id = stream_id self.stream_id = stream_id
self.timings: dict[str, float] = {} self.timings: Dict[str, float] = {}
self.errors: list[str] = [] self.errors: List[str] = []
# 阶段间数据传递(按 stage 命名空间隔离) # 阶段间数据传递(按 stage 命名空间隔离)
self.stage_outputs: dict[str, dict[str, Any]] = {} self.stage_outputs: Dict[str, Dict[str, Any]] = {}
# 消息修改审计日志 # 消息修改审计日志
self.modification_log: list[ModificationRecord] = [] self.modification_log: List[ModificationRecord] = []
# PLAN 阶段命令匹配结果 # PLAN 阶段命令匹配结果
self.matched_command: str | None = None self.matched_command: Optional[str] = None
def set_stage_output(self, stage: str, key: str, value: Any) -> None: def set_stage_output(self, stage: str, key: str, value: Any) -> None:
self.stage_outputs.setdefault(stage, {})[key] = value self.stage_outputs.setdefault(stage, {})[key] = value
@@ -85,7 +85,7 @@ class WorkflowResult:
status: str = "completed", # completed / aborted / failed status: str = "completed", # completed / aborted / failed
return_message: str = "", return_message: str = "",
stopped_at: str = "", stopped_at: str = "",
diagnostics: dict[str, Any] | None = None, diagnostics: Optional[Dict[str, Any]] = None,
): ):
self.status = status self.status = status
self.return_message = return_message self.return_message = return_message
@@ -94,7 +94,7 @@ class WorkflowResult:
# invoke_fn 签名 # invoke_fn 签名
InvokeFn = Callable[[str, str, dict[str, Any]], Awaitable[dict[str, Any]]] InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
class WorkflowExecutor: class WorkflowExecutor:
@@ -109,10 +109,10 @@ class WorkflowExecutor:
async def execute( async def execute(
self, self,
invoke_fn: InvokeFn, invoke_fn: InvokeFn,
message: dict[str, Any] | None = None, message: Optional[Dict[str, Any]] = None,
stream_id: str | None = None, stream_id: Optional[str] = None,
context: WorkflowContext | None = None, context: Optional[WorkflowContext] = None,
) -> tuple[WorkflowResult, dict[str, Any] | None, WorkflowContext]: ) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
"""执行 workflow pipeline。 """执行 workflow pipeline。
Returns: Returns:
@@ -120,6 +120,8 @@ class WorkflowExecutor:
""" """
ctx = context or WorkflowContext(stream_id=stream_id) ctx = context or WorkflowContext(stream_id=stream_id)
current_message = dict(message) if message else None current_message = dict(message) if message else None
# 保持非阻塞任务引用,防止被 GC 回收
background_tasks: List[asyncio.Task] = []
for stage in STAGE_SEQUENCE: for stage in STAGE_SEQUENCE:
stage_start = time.perf_counter() stage_start = time.perf_counter()
@@ -207,14 +209,15 @@ class WorkflowExecutor:
# 4. 并发执行 non-blocking hook只读忽略返回值中的 modified_message # 4. 并发执行 non-blocking hook只读忽略返回值中的 modified_message
if nonblocking_steps and not skip_stage: if nonblocking_steps and not skip_stage:
nb_tasks = [ nb_tasks = [
self._invoke_step_fire_and_forget( asyncio.create_task(
invoke_fn, step, stage, ctx, current_message self._invoke_step_fire_and_forget(
invoke_fn, step, stage, ctx, current_message
)
) )
for step in nonblocking_steps for step in nonblocking_steps
] ]
# 并发执行但不阻塞 pipeline # 保持任务引用以防止被 GC 回收
for task in [asyncio.create_task(t) for t in nb_tasks]: background_tasks.extend(nb_tasks)
task.add_done_callback(lambda _: None)
ctx.timings[stage] = time.perf_counter() - stage_start ctx.timings[stage] = time.perf_counter() - stage_start
@@ -247,9 +250,9 @@ class WorkflowExecutor:
def _pre_filter( def _pre_filter(
self, self,
steps: list[RegisteredComponent], steps: List[RegisteredComponent],
message: dict[str, Any] | None, message: Optional[Dict[str, Any]],
) -> list[RegisteredComponent]: ) -> List[RegisteredComponent]:
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。""" """根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
if not message: if not message:
return steps return steps
@@ -265,7 +268,7 @@ class WorkflowExecutor:
return result return result
@staticmethod @staticmethod
def _match_filter(filter_cond: dict[str, Any], message: dict[str, Any]) -> bool: def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
"""简单 key-value 匹配过滤。 """简单 key-value 匹配过滤。
filter 中的每个 key 必须在 message 中存在且值相等, filter 中的每个 key 必须在 message 中存在且值相等,
@@ -285,8 +288,8 @@ class WorkflowExecutor:
step: RegisteredComponent, step: RegisteredComponent,
stage: str, stage: str,
ctx: WorkflowContext, ctx: WorkflowContext,
message: dict[str, Any] | None, message: Optional[Dict[str, Any]],
) -> tuple[str, dict[str, Any] | None, str | None]: ) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
"""调用单个 blocking hook。 """调用单个 blocking hook。
Returns: Returns:
@@ -331,7 +334,7 @@ class WorkflowExecutor:
step: RegisteredComponent, step: RegisteredComponent,
stage: str, stage: str,
ctx: WorkflowContext, ctx: WorkflowContext,
message: dict[str, Any] | None, message: Optional[Dict[str, Any]],
) -> None: ) -> None:
"""Non-blocking hook 调用,只读,忽略结果。""" """Non-blocking hook 调用,只读,忽略结果。"""
timeout_ms = step.metadata.get("timeout_ms", 0) timeout_ms = step.metadata.get("timeout_ms", 0)
@@ -354,9 +357,9 @@ class WorkflowExecutor:
async def _route_command( async def _route_command(
self, self,
invoke_fn: InvokeFn, invoke_fn: InvokeFn,
message: dict[str, Any], message: Dict[str, Any],
ctx: WorkflowContext, ctx: WorkflowContext,
) -> dict[str, Any] | None: ) -> Optional[Dict[str, Any]]:
"""PLAN 阶段内置 Command 路由。 """PLAN 阶段内置 Command 路由。
在 registry 中查找匹配的 command 组件, 在 registry 中查找匹配的 command 组件,
@@ -386,6 +389,6 @@ class WorkflowExecutor:
return None return None
def _diff_keys(old: dict[str, Any], new: dict[str, Any]) -> list[str]: def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
"""返回 new 中与 old 不同的 key 列表。""" """返回 new 中与 old 不同的 key 列表。"""
return [k for k, v in new.items() if k not in old or old[k] != v] return [k for k, v in new.items() if k not in old or old[k] != v]

View File

@@ -7,9 +7,7 @@
4. 提供统一的能力实现注册接口,使插件可以调用主程序功能 4. 提供统一的能力实现注册接口,使插件可以调用主程序功能
""" """
from __future__ import annotations from typing import Any, Dict, List, Optional, Tuple
from typing import Any
import asyncio import asyncio
import os import os
@@ -19,7 +17,7 @@ from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration") logger = get_logger("plugin_runtime.integration")
# 旧系统 EventType -> 新系统 event_type 字符串映射 # 旧系统 EventType -> 新系统 event_type 字符串映射
_EVENT_TYPE_MAP: dict[str, str] = { _EVENT_TYPE_MAP: Dict[str, str] = {
"on_start": "on_start", "on_start": "on_start",
"on_stop": "on_stop", "on_stop": "on_stop",
"on_message_pre_process": "on_message_pre_process", "on_message_pre_process": "on_message_pre_process",
@@ -42,20 +40,20 @@ class PluginRuntimeManager:
def __init__(self) -> None: def __init__(self) -> None:
from src.plugin_runtime.host.supervisor import PluginSupervisor from src.plugin_runtime.host.supervisor import PluginSupervisor
self._builtin_supervisor: PluginSupervisor | None = None self._builtin_supervisor: Optional[PluginSupervisor] = None
self._thirdparty_supervisor: PluginSupervisor | None = None self._thirdparty_supervisor: Optional[PluginSupervisor] = None
self._started: bool = False self._started: bool = False
# ─── 插件目录 ───────────────────────────────────────────── # ─── 插件目录 ─────────────────────────────────────────────
@staticmethod @staticmethod
def _get_builtin_plugin_dirs() -> list[str]: def _get_builtin_plugin_dirs() -> List[str]:
"""内置插件目录: src/plugins/built_in/""" """内置插件目录: src/plugins/built_in/"""
candidate = os.path.abspath(os.path.join("src", "plugins", "built_in")) candidate = os.path.abspath(os.path.join("src", "plugins", "built_in"))
return [candidate] if os.path.isdir(candidate) else [] return [candidate] if os.path.isdir(candidate) else []
@staticmethod @staticmethod
def _get_thirdparty_plugin_dirs() -> list[str]: def _get_thirdparty_plugin_dirs() -> List[str]:
"""第三方插件目录: plugins/""" """第三方插件目录: plugins/"""
candidate = os.path.abspath("plugins") candidate = os.path.abspath("plugins")
return [candidate] if os.path.isdir(candidate) else [] return [candidate] if os.path.isdir(candidate) else []
@@ -136,7 +134,7 @@ class PluginRuntimeManager:
return self._started return self._started
@property @property
def supervisors(self) -> list[Any]: def supervisors(self) -> List[Any]:
"""获取所有活跃的 Supervisor""" """获取所有活跃的 Supervisor"""
return [s for s in (self._builtin_supervisor, self._thirdparty_supervisor) if s is not None] return [s for s in (self._builtin_supervisor, self._thirdparty_supervisor) if s is not None]
@@ -145,9 +143,9 @@ class PluginRuntimeManager:
async def bridge_event( async def bridge_event(
self, self,
event_type_value: str, event_type_value: str,
message_dict: dict[str, Any] | None = None, message_dict: Optional[Dict[str, Any]] = None,
extra_args: dict[str, Any] | None = None, extra_args: Optional[Dict[str, Any]] = None,
) -> tuple[bool, dict[str, Any] | None]: ) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""将事件分发到所有 Supervisor """将事件分发到所有 Supervisor
Returns: Returns:
@@ -157,7 +155,7 @@ class PluginRuntimeManager:
return True, None return True, None
new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value) new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value)
modified: dict[str, Any] | None = None modified: Optional[Dict[str, Any]] = None
for sv in self.supervisors: for sv in self.supervisors:
try: try:
@@ -177,7 +175,7 @@ class PluginRuntimeManager:
# ─── 命令查找 ────────────────────────────────────────────── # ─── 命令查找 ──────────────────────────────────────────────
def find_command_by_text(self, text: str) -> dict[str, Any] | None: def find_command_by_text(self, text: str) -> Optional[Dict[str, Any]]:
"""在所有 Supervisor 的 ComponentRegistry 中查找命令""" """在所有 Supervisor 的 ComponentRegistry 中查找命令"""
if not self._started: if not self._started:
return None return None
@@ -281,7 +279,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_send_text(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_send_text(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""发送文本消息 """发送文本消息
args: text, stream_id, typing?, set_reply?, storage_message? args: text, stream_id, typing?, set_reply?, storage_message?
@@ -307,7 +305,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_send_emoji(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_send_emoji(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""发送表情 """发送表情
args: emoji_base64, stream_id, storage_message? args: emoji_base64, stream_id, storage_message?
@@ -331,7 +329,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_send_image(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_send_image(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""发送图片 """发送图片
args: image_base64, stream_id, storage_message? args: image_base64, stream_id, storage_message?
@@ -355,7 +353,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_send_command(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_send_command(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""发送命令 """发送命令
args: command, stream_id, storage_message?, display_message? args: command, stream_id, storage_message?, display_message?
@@ -380,7 +378,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_send_custom(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_send_custom(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""发送自定义类型消息 """发送自定义类型消息
args: message_type, content, stream_id, display_message?, typing?, storage_message? args: message_type, content, stream_id, display_message?, typing?, storage_message?
@@ -408,7 +406,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_send_forward(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_send_forward(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""发送转发消息 """发送转发消息
args: messages, stream_id args: messages, stream_id
@@ -431,7 +429,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_send_hybrid(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_send_hybrid(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""发送混合消息(图文混合) """发送混合消息(图文混合)
args: segments, stream_id args: segments, stream_id
@@ -458,7 +456,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_llm_generate(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_llm_generate(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""LLM 生成 """LLM 生成
args: prompt, model_name?, temperature?, max_tokens? args: prompt, model_name?, temperature?, max_tokens?
@@ -501,7 +499,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_llm_generate_with_tools(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_llm_generate_with_tools(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""LLM 带工具生成 """LLM 带工具生成
args: prompt, model_name?, tool_options?, temperature?, max_tokens? args: prompt, model_name?, tool_options?, temperature?, max_tokens?
@@ -554,7 +552,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_llm_get_available_models(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_llm_get_available_models(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取可用模型列表""" """获取可用模型列表"""
from src.services import llm_service as llm_api from src.services import llm_service as llm_api
@@ -570,7 +568,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_config_get(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_config_get(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""读取全局配置 """读取全局配置
args: key, default? args: key, default?
@@ -590,7 +588,7 @@ class PluginRuntimeManager:
return {"success": False, "value": None, "error": str(e)} return {"success": False, "value": None, "error": str(e)}
@staticmethod @staticmethod
async def _cap_config_get_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_config_get_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""读取插件配置 """读取插件配置
args: key, default?, plugin_name? args: key, default?, plugin_name?
@@ -617,7 +615,7 @@ class PluginRuntimeManager:
return {"success": False, "value": default, "error": str(e)} return {"success": False, "value": default, "error": str(e)}
@staticmethod @staticmethod
async def _cap_config_get_all(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_config_get_all(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取当前插件的全部配置""" """获取当前插件的全部配置"""
from src.core.component_registry import component_registry as core_registry from src.core.component_registry import component_registry as core_registry
@@ -636,7 +634,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_database_query(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_database_query(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""数据库查询 """数据库查询
args: model_name, query_type?, filters?, limit?, order_by?, data?, single_result? args: model_name, query_type?, filters?, limit?, order_by?, data?, single_result?
@@ -670,7 +668,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_database_save(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_database_save(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""数据库保存 """数据库保存
args: model_name, data, key_field?, key_value? args: model_name, data, key_field?, key_value?
@@ -678,7 +676,7 @@ class PluginRuntimeManager:
from src.services import database_service as database_api from src.services import database_service as database_api
model_name: str = args.get("model_name", "") model_name: str = args.get("model_name", "")
data: dict[str, Any] | None = args.get("data") data: Optional[Dict[str, Any]] = args.get("data")
if not model_name or not data: if not model_name or not data:
return {"success": False, "error": "缺少必要参数 model_name 或 data"} return {"success": False, "error": "缺少必要参数 model_name 或 data"}
@@ -701,7 +699,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_database_get(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_database_get(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""数据库简单查询 """数据库简单查询
args: model_name, filters?, limit?, order_by?, single_result? args: model_name, filters?, limit?, order_by?, single_result?
@@ -732,7 +730,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_database_delete(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_database_delete(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""数据库删除 """数据库删除
args: model_name, filters args: model_name, filters
@@ -763,7 +761,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_database_count(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_database_count(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""数据库计数 """数据库计数
args: model_name, filters? args: model_name, filters?
@@ -795,7 +793,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
def _serialize_stream(stream: Any) -> dict[str, Any]: def _serialize_stream(stream: Any) -> Dict[str, Any]:
"""将 BotChatSession 序列化为可通过 RPC 传输的字典""" """将 BotChatSession 序列化为可通过 RPC 传输的字典"""
return { return {
"session_id": getattr(stream, "session_id", ""), "session_id": getattr(stream, "session_id", ""),
@@ -806,7 +804,7 @@ class PluginRuntimeManager:
} }
@staticmethod @staticmethod
async def _cap_chat_get_all_streams(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_chat_get_all_streams(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取所有聊天流 """获取所有聊天流
args: platform? args: platform?
@@ -825,7 +823,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_chat_get_group_streams(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_chat_get_group_streams(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取所有群聊流 """获取所有群聊流
args: platform? args: platform?
@@ -844,7 +842,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_chat_get_private_streams(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_chat_get_private_streams(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取所有私聊流 """获取所有私聊流
args: platform? args: platform?
@@ -863,7 +861,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_chat_get_stream_by_group_id(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_chat_get_stream_by_group_id(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""按群 ID 查找聊天流 """按群 ID 查找聊天流
args: group_id, platform? args: group_id, platform?
@@ -885,7 +883,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_chat_get_stream_by_user_id(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_chat_get_stream_by_user_id(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""按用户 ID 查找私聊流 """按用户 ID 查找私聊流
args: user_id, platform? args: user_id, platform?
@@ -911,9 +909,9 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
def _serialize_messages(messages: list) -> list[dict[str, Any]]: def _serialize_messages(messages: list) -> List[Dict[str, Any]]:
"""将 DatabaseMessages 列表序列化为 dict 列表""" """将 DatabaseMessages 列表序列化为 dict 列表"""
result: list[dict[str, Any]] = [] result: List[Dict[str, Any]] = []
for msg in messages: for msg in messages:
if hasattr(msg, "model_dump"): if hasattr(msg, "model_dump"):
result.append(msg.model_dump()) result.append(msg.model_dump())
@@ -924,7 +922,7 @@ class PluginRuntimeManager:
return result return result
@staticmethod @staticmethod
async def _cap_message_get_by_time(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_message_get_by_time(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""按时间范围查询消息 """按时间范围查询消息
args: start_time, end_time, limit?, filter_mai? args: start_time, end_time, limit?, filter_mai?
@@ -948,7 +946,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_message_get_by_time_in_chat(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_message_get_by_time_in_chat(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""按时间范围查询指定聊天消息 """按时间范围查询指定聊天消息
args: chat_id, start_time, end_time, limit?, filter_mai?, filter_command? args: chat_id, start_time, end_time, limit?, filter_mai?, filter_command?
@@ -975,7 +973,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_message_get_recent(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_message_get_recent(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取最近的消息 """获取最近的消息
args: chat_id, hours?, limit?, filter_mai? args: chat_id, hours?, limit?, filter_mai?
@@ -1000,7 +998,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_message_count_new(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_message_count_new(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""统计新消息数量 """统计新消息数量
args: chat_id, start_time?, end_time? args: chat_id, start_time?, end_time?
@@ -1023,7 +1021,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_message_build_readable(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_message_build_readable(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""将消息列表构建成可读字符串 """将消息列表构建成可读字符串
args: chat_id, start_time, end_time, limit?, replace_bot_name?, timestamp_mode? args: chat_id, start_time, end_time, limit?, replace_bot_name?, timestamp_mode?
@@ -1057,7 +1055,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_person_get_id(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_person_get_id(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取 person_id """获取 person_id
args: platform, user_id args: platform, user_id
@@ -1077,7 +1075,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_person_get_value(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_person_get_value(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取用户字段值 """获取用户字段值
args: person_id, field_name, default? args: person_id, field_name, default?
@@ -1101,7 +1099,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_person_get_id_by_name(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_person_get_id_by_name(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""根据用户名获取 person_id """根据用户名获取 person_id
args: person_name args: person_name
@@ -1124,7 +1122,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_emoji_get_by_description(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_emoji_get_by_description(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""根据描述获取表情包 """根据描述获取表情包
args: description args: description
@@ -1153,7 +1151,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_emoji_get_random(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_emoji_get_random(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""随机获取表情包 """随机获取表情包
args: count? args: count?
@@ -1172,7 +1170,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_emoji_get_count(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_emoji_get_count(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取表情包数量""" """获取表情包数量"""
from src.services import emoji_service as emoji_api from src.services import emoji_service as emoji_api
@@ -1183,7 +1181,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_emoji_get_emotions(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_emoji_get_emotions(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取所有情绪标签""" """获取所有情绪标签"""
from src.services import emoji_service as emoji_api from src.services import emoji_service as emoji_api
@@ -1194,7 +1192,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_emoji_get_all(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_emoji_get_all(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取所有表情包""" """获取所有表情包"""
from src.services import emoji_service as emoji_api from src.services import emoji_service as emoji_api
@@ -1209,7 +1207,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_emoji_get_info(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_emoji_get_info(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取表情包统计信息""" """获取表情包统计信息"""
from src.services import emoji_service as emoji_api from src.services import emoji_service as emoji_api
@@ -1220,7 +1218,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_emoji_register(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_emoji_register(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""注册表情包 """注册表情包
args: emoji_base64 args: emoji_base64
@@ -1239,7 +1237,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_emoji_delete(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_emoji_delete(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""删除表情包 """删除表情包
args: emoji_hash args: emoji_hash
@@ -1262,7 +1260,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_frequency_get_current_talk_value(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_frequency_get_current_talk_value(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取当前说话频率值 """获取当前说话频率值
args: chat_id args: chat_id
@@ -1281,7 +1279,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_frequency_set_adjust(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_frequency_set_adjust(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""设置说话频率调整值 """设置说话频率调整值
args: chat_id, value args: chat_id, value
@@ -1301,7 +1299,7 @@ class PluginRuntimeManager:
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@staticmethod @staticmethod
async def _cap_frequency_get_adjust(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_frequency_get_adjust(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取说话频率调整值 """获取说话频率调整值
args: chat_id args: chat_id
@@ -1324,7 +1322,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_tool_get_definitions(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_tool_get_definitions(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取 LLM 可用的工具定义列表""" """获取 LLM 可用的工具定义列表"""
from src.core.component_registry import component_registry as core_registry from src.core.component_registry import component_registry as core_registry
@@ -1343,10 +1341,10 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_component_get_all_plugins(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_component_get_all_plugins(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取所有插件信息(汇总所有 Supervisor 的注册信息)""" """获取所有插件信息(汇总所有 Supervisor 的注册信息)"""
mgr = get_plugin_runtime_manager() mgr = get_plugin_runtime_manager()
result: dict[str, Any] = {} result: Dict[str, Any] = {}
for sv in mgr.supervisors: for sv in mgr.supervisors:
for pid, reg in sv._registered_plugins.items(): for pid, reg in sv._registered_plugins.items():
result[pid] = { result[pid] = {
@@ -1359,7 +1357,7 @@ class PluginRuntimeManager:
return {"success": True, "plugins": result} return {"success": True, "plugins": result}
@staticmethod @staticmethod
async def _cap_component_get_plugin_info(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_component_get_plugin_info(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""获取指定插件信息 """获取指定插件信息
args: plugin_name args: plugin_name
@@ -1382,25 +1380,25 @@ class PluginRuntimeManager:
return {"success": False, "error": f"未找到插件: {plugin_name}"} return {"success": False, "error": f"未找到插件: {plugin_name}"}
@staticmethod @staticmethod
async def _cap_component_list_loaded_plugins(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_component_list_loaded_plugins(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""列出已加载的插件""" """列出已加载的插件"""
mgr = get_plugin_runtime_manager() mgr = get_plugin_runtime_manager()
plugins: list[str] = [] plugins: List[str] = []
for sv in mgr.supervisors: for sv in mgr.supervisors:
plugins.extend(sv._registered_plugins.keys()) plugins.extend(sv._registered_plugins.keys())
return {"success": True, "plugins": plugins} return {"success": True, "plugins": plugins}
@staticmethod @staticmethod
async def _cap_component_list_registered_plugins(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_component_list_registered_plugins(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""列出已注册的插件(同 list_loaded""" """列出已注册的插件(同 list_loaded"""
mgr = get_plugin_runtime_manager() mgr = get_plugin_runtime_manager()
plugins: list[str] = [] plugins: List[str] = []
for sv in mgr.supervisors: for sv in mgr.supervisors:
plugins.extend(sv._registered_plugins.keys()) plugins.extend(sv._registered_plugins.keys())
return {"success": True, "plugins": plugins} return {"success": True, "plugins": plugins}
@staticmethod @staticmethod
async def _cap_component_enable(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_component_enable(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""启用组件 """启用组件
args: name, component_type args: name, component_type
@@ -1419,7 +1417,7 @@ class PluginRuntimeManager:
return {"success": False, "error": f"未找到组件: {name}"} return {"success": False, "error": f"未找到组件: {name}"}
@staticmethod @staticmethod
async def _cap_component_disable(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_component_disable(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""禁用组件 """禁用组件
args: name, component_type args: name, component_type
@@ -1438,7 +1436,7 @@ class PluginRuntimeManager:
return {"success": False, "error": f"未找到组件: {name}"} return {"success": False, "error": f"未找到组件: {name}"}
@staticmethod @staticmethod
async def _cap_component_load_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_component_load_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""加载插件(在新运行时中通过热重载实现) """加载插件(在新运行时中通过热重载实现)
args: plugin_name args: plugin_name
@@ -1457,7 +1455,7 @@ class PluginRuntimeManager:
return {"success": False, "error": f"无法加载插件: {plugin_name}"} return {"success": False, "error": f"无法加载插件: {plugin_name}"}
@staticmethod @staticmethod
async def _cap_component_unload_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_component_unload_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""卸载插件(在新运行时中不支持单独卸载) """卸载插件(在新运行时中不支持单独卸载)
args: plugin_name args: plugin_name
@@ -1465,7 +1463,7 @@ class PluginRuntimeManager:
return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"} return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"}
@staticmethod @staticmethod
async def _cap_component_reload_plugin(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_component_reload_plugin(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""重新加载插件(触发对应 Supervisor 的热重载) """重新加载插件(触发对应 Supervisor 的热重载)
args: plugin_name args: plugin_name
@@ -1490,7 +1488,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_knowledge_search(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_knowledge_search(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""从 LPMM 知识库搜索知识 """从 LPMM 知识库搜索知识
args: query, limit? args: query, limit?
@@ -1526,7 +1524,7 @@ class PluginRuntimeManager:
# ═════════════════════════════════════════════════════════ # ═════════════════════════════════════════════════════════
@staticmethod @staticmethod
async def _cap_logging_log(plugin_id: str, capability: str, args: dict[str, Any]) -> Any: async def _cap_logging_log(plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
"""插件日志记录 """插件日志记录
args: level?, message args: level?, message
@@ -1544,7 +1542,7 @@ class PluginRuntimeManager:
# ─── 单例 ────────────────────────────────────────────────── # ─── 单例 ──────────────────────────────────────────────────
_manager: PluginRuntimeManager | None = None _manager: Optional[PluginRuntimeManager] = None
def get_plugin_runtime_manager() -> PluginRuntimeManager: def get_plugin_runtime_manager() -> PluginRuntimeManager:

View File

@@ -1,35 +1,40 @@
"""MsgPack 编解码器""" """MsgPack 编解码器"""
from typing import Any from abc import ABC, abstractmethod
from typing import Any, Dict
import msgpack import msgpack
from .envelope import Envelope from .envelope import Envelope
class Codec: class Codec(ABC):
"""消息编解码器基类""" """消息编解码器基类"""
@abstractmethod
def encode_envelope(self, envelope: Envelope) -> bytes: def encode_envelope(self, envelope: Envelope) -> bytes:
raise NotImplementedError ...
@abstractmethod
def decode_envelope(self, data: bytes) -> Envelope: def decode_envelope(self, data: bytes) -> Envelope:
raise NotImplementedError ...
def encode(self, obj: dict[str, Any]) -> bytes: @abstractmethod
raise NotImplementedError def encode(self, obj: Dict[str, Any]) -> bytes:
...
def decode(self, data: bytes) -> dict[str, Any]: @abstractmethod
raise NotImplementedError def decode(self, data: bytes) -> Dict[str, Any]:
...
class MsgPackCodec(Codec): class MsgPackCodec(Codec):
"""MsgPack 编解码器""" """MsgPack 编解码器"""
def encode(self, obj: dict[str, Any]) -> bytes: def encode(self, obj: Dict[str, Any]) -> bytes:
return msgpack.packb(obj, use_bin_type=True) return msgpack.packb(obj, use_bin_type=True)
def decode(self, data: bytes) -> dict[str, Any]: def decode(self, data: bytes) -> Dict[str, Any]:
result = msgpack.unpackb(data, raw=False) result = msgpack.unpackb(data, raw=False)
if not isinstance(result, dict): if not isinstance(result, dict):
raise ValueError(f"期望解码为 dict实际为 {type(result)}") raise ValueError(f"期望解码为 dict实际为 {type(result)}")

View File

@@ -5,9 +5,8 @@
""" """
from enum import Enum from enum import Enum
from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional
import time import time
@@ -62,8 +61,8 @@ class Envelope(BaseModel):
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)") timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
timeout_ms: int = Field(default=30000, description="相对超时(ms)") timeout_ms: int = Field(default=30000, description="相对超时(ms)")
generation: int = Field(default=0, description="Runner generation 编号") generation: int = Field(default=0, description="Runner generation 编号")
payload: dict[str, Any] = Field(default_factory=dict, description="业务数据") payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据")
error: dict[str, Any] | None = Field(default=None, description="错误信息(仅 response)") error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)")
def is_request(self) -> bool: def is_request(self) -> bool:
return self.message_type == MessageType.REQUEST return self.message_type == MessageType.REQUEST
@@ -74,7 +73,7 @@ class Envelope(BaseModel):
def is_event(self) -> bool: def is_event(self) -> bool:
return self.message_type == MessageType.EVENT return self.message_type == MessageType.EVENT
def make_response(self, payload: dict[str, Any] | None = None, error: dict[str, Any] | None = None) -> "Envelope": def make_response(self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None) -> "Envelope":
"""基于当前请求创建对应的响应信封""" """基于当前请求创建对应的响应信封"""
return Envelope( return Envelope(
protocol_version=self.protocol_version, protocol_version=self.protocol_version,
@@ -87,7 +86,7 @@ class Envelope(BaseModel):
error=error, error=error,
) )
def make_error_response(self, code: str, message: str = "", details: dict | None = None) -> "Envelope": def make_error_response(self, code: str, message: str = "", details: Optional[Dict[str, Any]] = None) -> "Envelope":
"""基于当前请求创建错误响应""" """基于当前请求创建错误响应"""
return self.make_response( return self.make_response(
error={ error={
@@ -122,15 +121,15 @@ class ComponentDeclaration(BaseModel):
name: str = Field(description="组件名称") name: str = Field(description="组件名称")
component_type: str = Field(description="组件类型: action/command/tool/event_handler") component_type: str = Field(description="组件类型: action/command/tool/event_handler")
plugin_id: str = Field(description="所属插件 ID") plugin_id: str = Field(description="所属插件 ID")
metadata: dict[str, Any] = Field(default_factory=dict, description="组件元数据") metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
class RegisterComponentsPayload(BaseModel): class RegisterComponentsPayload(BaseModel):
"""plugin.register_components 请求 payload""" """plugin.register_components 请求 payload"""
plugin_id: str = Field(description="插件 ID") plugin_id: str = Field(description="插件 ID")
plugin_version: str = Field(default="1.0.0", description="插件版本") plugin_version: str = Field(default="1.0.0", description="插件版本")
components: list[ComponentDeclaration] = Field(default_factory=list, description="组件列表") components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
capabilities_required: list[str] = Field(default_factory=list, description="所需能力列表") capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
# ─── 调用消息 ────────────────────────────────────────────────────── # ─── 调用消息 ──────────────────────────────────────────────────────
@@ -138,7 +137,7 @@ class RegisterComponentsPayload(BaseModel):
class InvokePayload(BaseModel): class InvokePayload(BaseModel):
"""plugin.invoke_* 请求 payload""" """plugin.invoke_* 请求 payload"""
component_name: str = Field(description="要调用的组件名称") component_name: str = Field(description="要调用的组件名称")
args: dict[str, Any] = Field(default_factory=dict, description="调用参数") args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
class InvokeResultPayload(BaseModel): class InvokeResultPayload(BaseModel):
@@ -152,7 +151,7 @@ class InvokeResultPayload(BaseModel):
class CapabilityRequestPayload(BaseModel): class CapabilityRequestPayload(BaseModel):
"""cap.* 请求 payload插件 -> Host 能力调用)""" """cap.* 请求 payload插件 -> Host 能力调用)"""
capability: str = Field(description="能力名称,如 send.text, db.query") capability: str = Field(description="能力名称,如 send.text, db.query")
args: dict[str, Any] = Field(default_factory=dict, description="调用参数") args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
class CapabilityResponsePayload(BaseModel): class CapabilityResponsePayload(BaseModel):
@@ -166,7 +165,7 @@ class CapabilityResponsePayload(BaseModel):
class HealthPayload(BaseModel): class HealthPayload(BaseModel):
"""plugin.health 响应 payload""" """plugin.health 响应 payload"""
healthy: bool = Field(description="是否健康") healthy: bool = Field(description="是否健康")
loaded_plugins: list[str] = Field(default_factory=list, description="已加载的插件列表") loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
uptime_ms: int = Field(default=0, description="运行时长(ms)") uptime_ms: int = Field(default=0, description="运行时长(ms)")
@@ -176,7 +175,7 @@ class ConfigUpdatedPayload(BaseModel):
"""plugin.config_updated 事件 payload""" """plugin.config_updated 事件 payload"""
plugin_id: str = Field(description="插件 ID") plugin_id: str = Field(description="插件 ID")
config_version: str = Field(description="新配置版本") config_version: str = Field(description="新配置版本")
config_data: dict[str, Any] = Field(default_factory=dict, description="配置内容") config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
# ─── 关停 ────────────────────────────────────────────────────────── # ─── 关停 ──────────────────────────────────────────────────────────

View File

@@ -4,6 +4,7 @@
""" """
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional
class ErrorCode(str, Enum): class ErrorCode(str, Enum):
@@ -38,13 +39,13 @@ class ErrorCode(str, Enum):
class RPCError(Exception): class RPCError(Exception):
"""RPC 调用异常""" """RPC 调用异常"""
def __init__(self, code: ErrorCode, message: str = "", details: dict | None = None): def __init__(self, code: ErrorCode, message: str = "", details: Optional[Dict[str, Any]] = None):
self.code = code self.code = code
self.message = message or code.value self.message = message or code.value
self.details = details or {} self.details = details or {}
super().__init__(f"[{code.value}] {self.message}") super().__init__(f"[{code.value}] {self.message}")
def to_dict(self) -> dict: def to_dict(self) -> Dict[str, Any]:
return { return {
"code": self.code.value, "code": self.code.value,
"message": self.message, "message": self.message,
@@ -52,7 +53,7 @@ class RPCError(Exception):
} }
@classmethod @classmethod
def from_dict(cls, data: dict) -> "RPCError": def from_dict(cls, data: Dict[str, Any]) -> "RPCError":
code = ErrorCode(data.get("code", "E_UNKNOWN")) code = ErrorCode(data.get("code", "E_UNKNOWN"))
return cls( return cls(
code=code, code=code,

View File

@@ -4,7 +4,7 @@
适配新 plugin_runtime 的 _manifest.json 格式。 适配新 plugin_runtime 的 _manifest.json 格式。
""" """
from typing import Any from typing import Any, Dict, List, Tuple
import re import re
@@ -29,7 +29,7 @@ class VersionComparator:
return ".".join(parts[:3]) return ".".join(parts[:3])
@staticmethod @staticmethod
def parse_version(version: str) -> tuple[int, int, int]: def parse_version(version: str) -> Tuple[int, int, int]:
normalized = VersionComparator.normalize_version(version) normalized = VersionComparator.normalize_version(version)
try: try:
parts = normalized.split(".") parts = normalized.split(".")
@@ -48,7 +48,7 @@ class VersionComparator:
return 0 return 0
@staticmethod @staticmethod
def is_in_range(version: str, min_version: str = "", max_version: str = "") -> tuple[bool, str]: def is_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
if not min_version and not max_version: if not min_version and not max_version:
return True, "" return True, ""
vn = VersionComparator.normalize_version(version) vn = VersionComparator.normalize_version(version)
@@ -72,10 +72,10 @@ class ManifestValidator:
def __init__(self, host_version: str = ""): def __init__(self, host_version: str = ""):
self._host_version = host_version self._host_version = host_version
self.errors: list[str] = [] self.errors: List[str] = []
self.warnings: list[str] = [] self.warnings: List[str] = []
def validate(self, manifest: dict[str, Any]) -> bool: def validate(self, manifest: Dict[str, Any]) -> bool:
"""校验 manifest 数据返回是否通过errors 为空即通过)。""" """校验 manifest 数据返回是否通过errors 为空即通过)。"""
self.errors.clear() self.errors.clear()
self.warnings.clear() self.warnings.clear()
@@ -95,21 +95,21 @@ class ManifestValidator:
return len(self.errors) == 0 return len(self.errors) == 0
def _check_required_fields(self, manifest: dict[str, Any]) -> None: def _check_required_fields(self, manifest: Dict[str, Any]) -> None:
for field in self.REQUIRED_FIELDS: for field in self.REQUIRED_FIELDS:
if field not in manifest: if field not in manifest:
self.errors.append(f"缺少必需字段: {field}") self.errors.append(f"缺少必需字段: {field}")
elif not manifest[field]: elif not manifest[field]:
self.errors.append(f"必需字段不能为空: {field}") self.errors.append(f"必需字段不能为空: {field}")
def _check_manifest_version(self, manifest: dict[str, Any]) -> None: def _check_manifest_version(self, manifest: Dict[str, Any]) -> None:
mv = manifest.get("manifest_version") mv = manifest.get("manifest_version")
if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS: if mv is not None and mv not in self.SUPPORTED_MANIFEST_VERSIONS:
self.errors.append( self.errors.append(
f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}" f"不支持的 manifest_version: {mv},支持: {self.SUPPORTED_MANIFEST_VERSIONS}"
) )
def _check_author(self, manifest: dict[str, Any]) -> None: def _check_author(self, manifest: Dict[str, Any]) -> None:
author = manifest.get("author") author = manifest.get("author")
if author is None: if author is None:
return return
@@ -122,7 +122,7 @@ class ManifestValidator:
else: else:
self.errors.append("author 应为字符串或 {name, url} 对象") self.errors.append("author 应为字符串或 {name, url} 对象")
def _check_host_compatibility(self, manifest: dict[str, Any]) -> None: def _check_host_compatibility(self, manifest: Dict[str, Any]) -> None:
host_app = manifest.get("host_application") host_app = manifest.get("host_application")
if not isinstance(host_app, dict) or not self._host_version: if not isinstance(host_app, dict) or not self._host_version:
return return
@@ -132,7 +132,7 @@ class ManifestValidator:
if not ok: if not ok:
self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})") self.errors.append(f"Host 版本不兼容: {msg} (当前 Host: {self._host_version})")
def _check_recommended(self, manifest: dict[str, Any]) -> None: def _check_recommended(self, manifest: Dict[str, Any]) -> None:
for field in self.RECOMMENDED_FIELDS: for field in self.RECOMMENDED_FIELDS:
if field not in manifest or not manifest[field]: if field not in manifest or not manifest[field]:
self.warnings.append(f"建议填写字段: {field}") self.warnings.append(f"建议填写字段: {field}")

View File

@@ -7,7 +7,7 @@
""" """
from collections import deque from collections import deque
from typing import Any from typing import Any, Dict, List, Optional, Set, Tuple
import importlib import importlib
import importlib.util import importlib.util
@@ -29,7 +29,7 @@ class PluginMeta:
plugin_id: str, plugin_id: str,
plugin_dir: str, plugin_dir: str,
plugin_instance: Any, plugin_instance: Any,
manifest: dict[str, Any], manifest: Dict[str, Any],
): ):
self.plugin_id = plugin_id self.plugin_id = plugin_id
self.plugin_dir = plugin_dir self.plugin_dir = plugin_dir
@@ -37,12 +37,12 @@ class PluginMeta:
self.manifest = manifest self.manifest = manifest
self.version = manifest.get("version", "1.0.0") self.version = manifest.get("version", "1.0.0")
self.capabilities_required = manifest.get("capabilities", []) self.capabilities_required = manifest.get("capabilities", [])
self.dependencies: list[str] = self._extract_dependencies(manifest) self.dependencies: List[str] = self._extract_dependencies(manifest)
@staticmethod @staticmethod
def _extract_dependencies(manifest: dict[str, Any]) -> list[str]: def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]:
raw = manifest.get("dependencies", []) raw = manifest.get("dependencies", [])
result: list[str] = [] result: List[str] = []
for dep in raw: for dep in raw:
if isinstance(dep, str): if isinstance(dep, str):
result.append(dep.strip()) result.append(dep.strip())
@@ -62,12 +62,12 @@ class PluginLoader:
""" """
def __init__(self, host_version: str = ""): def __init__(self, host_version: str = ""):
self._loaded_plugins: dict[str, PluginMeta] = {} self._loaded_plugins: Dict[str, PluginMeta] = {}
self._failed_plugins: dict[str, str] = {} self._failed_plugins: Dict[str, str] = {}
self._manifest_validator = ManifestValidator(host_version=host_version) self._manifest_validator = ManifestValidator(host_version=host_version)
self._compat_hook_installed = False self._compat_hook_installed = False
def discover_and_load(self, plugin_dirs: list[str]) -> list[PluginMeta]: def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
"""扫描多个目录并加载所有插件(含依赖排序和 manifest 校验) """扫描多个目录并加载所有插件(含依赖排序和 manifest 校验)
Args: Args:
@@ -77,7 +77,7 @@ class PluginLoader:
成功加载的插件元数据列表(按依赖顺序) 成功加载的插件元数据列表(按依赖顺序)
""" """
# 第一阶段:发现并校验 manifest # 第一阶段:发现并校验 manifest
candidates: dict[str, tuple[str, dict[str, Any], str]] = {} # id -> (dir, manifest, plugin_path) candidates: Dict[str, Tuple[str, Dict[str, Any], str]] = {} # id -> (dir, manifest, plugin_path)
for base_dir in plugin_dirs: for base_dir in plugin_dirs:
if not os.path.isdir(base_dir): if not os.path.isdir(base_dir):
logger.warning(f"插件目录不存在: {base_dir}") logger.warning(f"插件目录不存在: {base_dir}")
@@ -131,33 +131,33 @@ class PluginLoader:
return results return results
def get_plugin(self, plugin_id: str) -> PluginMeta | None: def get_plugin(self, plugin_id: str) -> Optional[PluginMeta]:
"""获取已加载的插件""" """获取已加载的插件"""
return self._loaded_plugins.get(plugin_id) return self._loaded_plugins.get(plugin_id)
def list_plugins(self) -> list[str]: def list_plugins(self) -> List[str]:
"""列出所有已加载的插件 ID""" """列出所有已加载的插件 ID"""
return list(self._loaded_plugins.keys()) return list(self._loaded_plugins.keys())
@property @property
def failed_plugins(self) -> dict[str, str]: def failed_plugins(self) -> Dict[str, str]:
return dict(self._failed_plugins) return dict(self._failed_plugins)
# ──── 依赖解析 ──────────────────────────────────────────── # ──── 依赖解析 ────────────────────────────────────────────
def _resolve_dependencies( def _resolve_dependencies(
self, self,
candidates: dict[str, tuple[str, dict[str, Any], str]], candidates: Dict[str, Tuple[str, Dict[str, Any], str]],
) -> tuple[list[str], dict[str, str]]: ) -> Tuple[List[str], Dict[str, str]]:
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。""" """拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
available = set(candidates.keys()) available = set(candidates.keys())
dep_graph: dict[str, set[str]] = {} dep_graph: Dict[str, Set[str]] = {}
failed: dict[str, str] = {} failed: Dict[str, str] = {}
for pid, (_, manifest, _) in candidates.items(): for pid, (_, manifest, _) in candidates.items():
raw_deps = manifest.get("dependencies", []) raw_deps = manifest.get("dependencies", [])
resolved: set[str] = set() resolved: Set[str] = set()
missing: list[str] = [] missing: List[str] = []
for dep in raw_deps: for dep in raw_deps:
dep_name = dep if isinstance(dep, str) else str(dep.get("name", "")) dep_name = dep if isinstance(dep, str) else str(dep.get("name", ""))
dep_name = dep_name.strip() dep_name = dep_name.strip()
@@ -177,14 +177,14 @@ class PluginLoader:
# Kahn 拓扑排序 # Kahn 拓扑排序
indegree = {pid: len(deps) for pid, deps in dep_graph.items()} indegree = {pid: len(deps) for pid, deps in dep_graph.items()}
reverse: dict[str, set[str]] = {pid: set() for pid in dep_graph} reverse: Dict[str, Set[str]] = {pid: set() for pid in dep_graph}
for pid, deps in dep_graph.items(): for pid, deps in dep_graph.items():
for d in deps: for d in deps:
if d in reverse: if d in reverse:
reverse[d].add(pid) reverse[d].add(pid)
queue = deque(sorted(pid for pid, deg in indegree.items() if deg == 0)) queue = deque(sorted(pid for pid, deg in indegree.items() if deg == 0))
sorted_order: list[str] = [] sorted_order: List[str] = []
while queue: while queue:
current = queue.popleft() current = queue.popleft()
@@ -206,9 +206,9 @@ class PluginLoader:
self, self,
plugin_id: str, plugin_id: str,
plugin_dir: str, plugin_dir: str,
manifest: dict[str, Any], manifest: Dict[str, Any],
plugin_path: str, plugin_path: str,
) -> PluginMeta | None: ) -> Optional[PluginMeta]:
"""加载单个插件""" """加载单个插件"""
# 确保兼容层导入钩子已安装(旧版插件可能 import src.plugin_system # 确保兼容层导入钩子已安装(旧版插件可能 import src.plugin_system
self._ensure_compat_hook() self._ensure_compat_hook()
@@ -267,7 +267,7 @@ class PluginLoader:
logger.debug("maibot_sdk.compat 不可用,跳过导入钩子安装") logger.debug("maibot_sdk.compat 不可用,跳过导入钩子安装")
@staticmethod @staticmethod
def _try_load_legacy_plugin(module: Any, plugin_id: str) -> Any | None: def _try_load_legacy_plugin(module: Any, plugin_id: str) -> Optional[Any]:
"""尝试从模块中发现旧版 BasePlugin 子类并包装为 LegacyPluginAdapter""" """尝试从模块中发现旧版 BasePlugin 子类并包装为 LegacyPluginAdapter"""
# 方式 1: @register_plugin 装饰器设置的标记 # 方式 1: @register_plugin 装饰器设置的标记
legacy_cls = getattr(module, "_legacy_plugin_class", None) legacy_cls = getattr(module, "_legacy_plugin_class", None)

View File

@@ -8,7 +8,7 @@
5. 发送能力调用请求到 Host 5. 发送能力调用请求到 Host
""" """
from typing import Any, Callable, Awaitable from typing import Any, Awaitable, Callable, Dict, Optional
import asyncio import asyncio
import contextlib import contextlib
@@ -45,26 +45,26 @@ class RPCClient:
self, self,
host_address: str, host_address: str,
session_token: str, session_token: str,
codec: Codec | None = None, codec: Optional[Codec] = None,
): ):
self._host_address = host_address self._host_address = host_address
self._session_token = session_token self._session_token = session_token
self._codec = codec or MsgPackCodec() self._codec = codec or MsgPackCodec()
self._id_gen = RequestIdGenerator() self._id_gen = RequestIdGenerator()
self._connection: Connection | None = None self._connection: Optional[Connection] = None
self._runner_id = str(uuid.uuid4()) self._runner_id = str(uuid.uuid4())
self._generation: int = 0 self._generation: int = 0
# 方法处理器注册表Host 发来的调用) # 方法处理器注册表Host 发来的调用)
self._method_handlers: dict[str, MethodHandler] = {} self._method_handlers: Dict[str, MethodHandler] = {}
# 等待响应的 pending 请求: request_id -> Future # 等待响应的 pending 请求: request_id -> Future
self._pending_requests: dict[int, asyncio.Future] = {} self._pending_requests: Dict[int, asyncio.Future] = {}
# 运行状态 # 运行状态
self._running = False self._running = False
self._recv_task: asyncio.Task | None = None self._recv_task: Optional[asyncio.Task] = None
@property @property
def generation(self) -> int: def generation(self) -> int:
@@ -147,7 +147,7 @@ class RPCClient:
self, self,
method: str, method: str,
plugin_id: str = "", plugin_id: str = "",
payload: dict[str, Any] | None = None, payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000, timeout_ms: int = 30000,
) -> Envelope: ) -> Envelope:
"""向 Host 发送 RPC 请求并等待响应""" """向 Host 发送 RPC 请求并等待响应"""

View File

@@ -9,6 +9,8 @@
6. 转发插件的能力调用到 Host 6. 转发插件的能力调用到 Host
""" """
from typing import List
import asyncio import asyncio
import contextlib import contextlib
import inspect import inspect
@@ -43,7 +45,7 @@ class PluginRunner:
self, self,
host_address: str, host_address: str,
session_token: str, session_token: str,
plugin_dirs: list[str], plugin_dirs: List[str],
) -> None: ) -> None:
self._host_address: str = host_address self._host_address: str = host_address
self._session_token: str = session_token self._session_token: str = session_token
@@ -114,7 +116,7 @@ class PluginRunner:
async def _register_plugin(self, meta: PluginMeta) -> None: async def _register_plugin(self, meta: PluginMeta) -> None:
"""向 Host 注册单个插件""" """向 Host 注册单个插件"""
# 收集插件组件声明 # 收集插件组件声明
components: list[ComponentDeclaration] = [] components: List[ComponentDeclaration] = []
instance = meta.instance instance = meta.instance
# 从插件实例获取组件声明SDK 插件须实现 get_components 方法) # 从插件实例获取组件声明SDK 插件须实现 get_components 方法)
@@ -284,7 +286,7 @@ class PluginRunner:
# ─── sys.path 隔离 ──────────────────────────────────────── # ─── sys.path 隔离 ────────────────────────────────────────
def _isolate_sys_path(plugin_dirs: list[str]) -> None: def _isolate_sys_path(plugin_dirs: List[str]) -> None:
"""清理 sys.path限制 Runner 子进程只能访问标准库、SDK 和插件目录。 """清理 sys.path限制 Runner 子进程只能访问标准库、SDK 和插件目录。
防止插件代码 import 主程序模块读取运行时数据。 防止插件代码 import 主程序模块读取运行时数据。

View File

@@ -8,7 +8,7 @@
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import AsyncIterator, Callable, Awaitable from typing import Awaitable, Callable
import asyncio import asyncio
import struct import struct

View File

@@ -3,12 +3,14 @@
根据运行平台自动选择最优传输实现。 根据运行平台自动选择最优传输实现。
""" """
from typing import Optional
import sys import sys
from .base import TransportClient, TransportServer from .base import TransportClient, TransportServer
def create_transport_server(socket_path: str | None = None) -> TransportServer: def create_transport_server(socket_path: Optional[str] = None) -> TransportServer:
"""创建传输服务端 """创建传输服务端
Linux/macOS 使用 UDSWindows 使用 TCP 回退。 Linux/macOS 使用 UDSWindows 使用 TCP 回退。

View File

@@ -4,6 +4,8 @@
绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。 绑定到 127.0.0.1 避免远程访问,但仍需会话令牌做身份校验。
""" """
from typing import Optional
import asyncio import asyncio
from .base import Connection, ConnectionHandler, TransportClient, TransportServer from .base import Connection, ConnectionHandler, TransportClient, TransportServer
@@ -20,7 +22,7 @@ class TCPTransportServer(TransportServer):
def __init__(self, host: str = "127.0.0.1", port: int = 0): def __init__(self, host: str = "127.0.0.1", port: int = 0):
self._host = host self._host = host
self._port = port # 0 表示自动分配 self._port = port # 0 表示自动分配
self._server: asyncio.AbstractServer | None = None self._server: Optional[asyncio.AbstractServer] = None
self._actual_port: int = 0 self._actual_port: int = 0
async def start(self, handler: ConnectionHandler) -> None: async def start(self, handler: ConnectionHandler) -> None:

View File

@@ -4,6 +4,7 @@
""" """
from pathlib import Path from pathlib import Path
from typing import Optional
import asyncio import asyncio
import os import os
@@ -20,13 +21,13 @@ class UDSConnection(Connection):
class UDSTransportServer(TransportServer): class UDSTransportServer(TransportServer):
"""UDS 传输服务端""" """UDS 传输服务端"""
def __init__(self, socket_path: str | None = None): def __init__(self, socket_path: Optional[str] = None):
if socket_path is None: if socket_path is None:
# 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞 # 默认放在临时目录,使用 uuid 确保同一进程多实例不碰撞
import uuid import uuid
socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock") socket_path = os.path.join(tempfile.gettempdir(), f"maibot-plugin-{os.getpid()}-{uuid.uuid4().hex[:8]}.sock")
self._socket_path = socket_path self._socket_path = socket_path
self._server: asyncio.AbstractServer | None = None self._server: Optional[asyncio.AbstractServer] = None
async def start(self, handler: ConnectionHandler) -> None: async def start(self, handler: ConnectionHandler) -> None:
# 清理残留 socket 文件 # 清理残留 socket 文件