merge: sync upstream/r-dev and resolve real conflicts

This commit is contained in:
A-Dawn
2026-03-24 15:36:26 +08:00
114 changed files with 15841 additions and 5236 deletions

View File

@@ -16,3 +16,9 @@ ENV_PLUGIN_DIRS = "MAIBOT_PLUGIN_DIRS"
ENV_HOST_VERSION = "MAIBOT_HOST_VERSION"
"""Runner 读取的 Host 应用版本号,用于 manifest 兼容性校验"""
ENV_EXTERNAL_PLUGIN_IDS = "MAIBOT_EXTERNAL_PLUGIN_IDS"
"""Runner 启动时可视为已满足的外部插件依赖版本映射JSON 对象)"""
ENV_GLOBAL_CONFIG_SNAPSHOT = "MAIBOT_GLOBAL_CONFIG_SNAPSHOT"
"""Runner 启动时注入的全局配置快照JSON 对象)"""

View File

@@ -1,12 +1,13 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol, Sequence
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.integration")
if TYPE_CHECKING:
from src.plugin_runtime.host.component_registry import RegisteredComponent
from src.plugin_runtime.host.api_registry import APIEntry
from src.plugin_runtime.host.component_registry import ComponentEntry
from src.plugin_runtime.host.supervisor import PluginSupervisor
@@ -14,18 +15,311 @@ class _RuntimeComponentManagerProtocol(Protocol):
@property
def supervisors(self) -> List["PluginSupervisor"]: ...
def _normalize_component_type(self, component_type: str) -> str: ...
def _is_api_component_type(self, component_type: str) -> bool: ...
def _serialize_api_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
def _serialize_api_component_entry(self, entry: "APIEntry") -> Dict[str, Any]: ...
def _is_api_visible_to_plugin(self, entry: "APIEntry", caller_plugin_id: str) -> bool: ...
def _normalize_api_reference(self, api_name: str, version: str = "") -> tuple[str, str]: ...
def _build_api_unavailable_error(self, entry: "APIEntry") -> str: ...
def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]: ...
def _resolve_api_target(
self,
caller_plugin_id: str,
api_name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
def _resolve_api_toggle_target(
self,
name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ...
def _resolve_component_toggle_target(
self, name: str, component_type: str
) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ...
) -> tuple[Optional["ComponentEntry"], Optional[str]]: ...
def _find_duplicate_plugin_ids(self, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: ...
def _iter_plugin_dirs(self) -> Iterable[Path]: ...
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool: ...
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool: ...
class RuntimeComponentCapabilityMixin:
@staticmethod
def _normalize_component_type(component_type: str) -> str:
"""规范化组件类型名称。
Args:
component_type: 原始组件类型。
Returns:
str: 统一转为大写后的组件类型名。
"""
return str(component_type or "").strip().upper()
@classmethod
def _is_api_component_type(cls, component_type: str) -> bool:
"""判断组件类型是否为 API。
Args:
component_type: 原始组件类型。
Returns:
bool: 是否为 API 组件类型。
"""
return cls._normalize_component_type(component_type) == "API"
@staticmethod
def _serialize_api_entry(entry: "APIEntry") -> Dict[str, Any]:
"""将 API 组件条目序列化为能力返回值。
Args:
entry: API 组件条目。
Returns:
Dict[str, Any]: 适合通过能力层返回给插件的 API 元信息。
"""
return {
"name": entry.name,
"full_name": entry.full_name,
"plugin_id": entry.plugin_id,
"description": entry.description,
"version": entry.version,
"public": entry.public,
"enabled": entry.enabled,
"dynamic": entry.dynamic,
"offline_reason": entry.offline_reason,
"metadata": dict(entry.metadata),
}
@classmethod
def _serialize_api_component_entry(cls, entry: "APIEntry") -> Dict[str, Any]:
"""将 API 条目序列化为通用组件视图。
Args:
entry: API 组件条目。
Returns:
Dict[str, Any]: 适合 ``component.get_all_plugins`` 返回的组件结构。
"""
serialized_entry = cls._serialize_api_entry(entry)
return {
"name": serialized_entry["name"],
"full_name": serialized_entry["full_name"],
"type": "API",
"enabled": serialized_entry["enabled"],
"metadata": serialized_entry["metadata"],
}
@staticmethod
def _is_api_visible_to_plugin(entry: "APIEntry", caller_plugin_id: str) -> bool:
"""判断某个 API 是否对调用方可见。
Args:
entry: 目标 API 组件条目。
caller_plugin_id: 调用方插件 ID。
Returns:
bool: 是否允许当前插件可见并调用。
"""
return entry.plugin_id == caller_plugin_id or entry.public
@staticmethod
def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]:
"""规范化 API 名称与版本参数。
支持在 ``api_name`` 中直接携带 ``@version`` 后缀。
"""
normalized_api_name = str(api_name or "").strip()
normalized_version = str(version or "").strip()
if normalized_api_name and not normalized_version and "@" in normalized_api_name:
candidate_name, candidate_version = normalized_api_name.rsplit("@", 1)
candidate_name = candidate_name.strip()
candidate_version = candidate_version.strip()
if candidate_name and candidate_version:
normalized_api_name = candidate_name
normalized_version = candidate_version
return normalized_api_name, normalized_version
@staticmethod
def _build_api_unavailable_error(entry: "APIEntry") -> str:
"""构造 API 当前不可用时的错误信息。"""
if entry.offline_reason:
return entry.offline_reason
return f"API {entry.registry_key} 当前不可用"
def _resolve_api_target(
self: _RuntimeComponentManagerProtocol,
caller_plugin_id: str,
api_name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
"""解析 API 名称到唯一可调用的目标组件。
Args:
caller_plugin_id: 调用方插件 ID。
api_name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
version: 可选的 API 版本。
Returns:
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
"""
normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version)
if not normalized_api_name:
return None, None, "缺少必要参数 api_name"
if "." in normalized_api_name:
target_plugin_id, target_api_name = normalized_api_name.rsplit(".", 1)
try:
supervisor = self._get_supervisor_for_plugin(target_plugin_id)
except RuntimeError as exc:
return None, None, str(exc)
if supervisor is None:
return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
entries = supervisor.api_registry.get_apis(
plugin_id=target_plugin_id,
name=target_api_name,
version=normalized_version,
enabled_only=False,
)
visible_enabled_entries = [
entry
for entry in entries
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
]
visible_disabled_entries = [
entry
for entry in entries
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled
]
if len(visible_enabled_entries) == 1:
return supervisor, visible_enabled_entries[0], None
if len(visible_enabled_entries) > 1:
return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version"
if visible_disabled_entries:
if len(visible_disabled_entries) == 1:
return None, None, self._build_api_unavailable_error(visible_disabled_entries[0])
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries):
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
if normalized_version:
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
return None, None, f"未找到 API: {normalized_api_name}"
visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
hidden_match_exists = False
for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(
name=normalized_api_name,
version=normalized_version,
enabled_only=False,
):
if self._is_api_visible_to_plugin(entry, caller_plugin_id):
if entry.enabled:
visible_enabled_matches.append((supervisor, entry))
else:
visible_disabled_matches.append((supervisor, entry))
else:
hidden_match_exists = True
if len(visible_enabled_matches) == 1:
return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None
if len(visible_enabled_matches) > 1:
return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version"
if visible_disabled_matches:
if len(visible_disabled_matches) == 1:
return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1])
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version"
if hidden_match_exists:
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
if normalized_version:
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
return None, None, f"未找到 API: {normalized_api_name}"
def _resolve_api_toggle_target(
self: _RuntimeComponentManagerProtocol,
name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
"""解析需要启用或禁用的 API 组件。
Args:
name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
version: 可选的 API 版本。
Returns:
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
"""
normalized_name, normalized_version = self._normalize_api_reference(name, version)
if not normalized_name:
return None, None, "缺少必要参数 name"
if "." in normalized_name:
plugin_id, api_name = normalized_name.rsplit(".", 1)
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
return None, None, str(exc)
if supervisor is None:
return None, None, f"未找到 API 提供方插件: {plugin_id}"
entries = supervisor.api_registry.get_apis(
plugin_id=plugin_id,
name=api_name,
version=normalized_version,
enabled_only=False,
)
if len(entries) == 1:
return supervisor, entries[0], None
if entries:
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
return None, None, f"未找到 API: {normalized_name}"
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
for supervisor in self.supervisors:
matches.extend(
(supervisor, entry)
for entry in supervisor.api_registry.get_apis(
name=normalized_name,
version=normalized_version,
enabled_only=False,
)
)
if len(matches) == 1:
return matches[0][0], matches[0][1], None
if len(matches) > 1:
return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version"
return None, None, f"未找到 API: {normalized_name}"
async def _cap_component_get_all_plugins(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
@@ -46,6 +340,10 @@ class RuntimeComponentCapabilityMixin:
}
for component in comps
]
components_list.extend(
self._serialize_api_component_entry(entry)
for entry in sv.api_registry.get_apis(plugin_id=pid, enabled_only=False)
)
result[pid] = {
"name": pid,
"version": reg.plugin_version,
@@ -96,30 +394,35 @@ class RuntimeComponentCapabilityMixin:
def _resolve_component_toggle_target(
self: _RuntimeComponentManagerProtocol, name: str, component_type: str
) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
short_name_matches: List["RegisteredComponent"] = []
) -> tuple[Optional["ComponentEntry"], Optional[str]]:
normalized_component_type = self._normalize_component_type(component_type)
short_name_matches: List["ComponentEntry"] = []
for sv in self.supervisors:
comp = sv.component_registry.get_component(name)
if comp is not None and comp.component_type == component_type:
if comp is not None and comp.component_type == normalized_component_type:
return comp, None
short_name_matches.extend(
candidate
for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False)
for candidate in sv.component_registry.get_components_by_type(
normalized_component_type,
enabled_only=False,
)
if candidate.name == name
)
if len(short_name_matches) == 1:
return short_name_matches[0], None
if len(short_name_matches) > 1:
return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
return None, f"未找到组件: {name} ({component_type})"
return None, f"组件名不唯一: {name} ({normalized_component_type}),请使用完整名 plugin_id.component_name"
return None, f"未找到组件: {name} ({normalized_component_type})"
async def _cap_component_enable(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -127,6 +430,13 @@ class RuntimeComponentCapabilityMixin:
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type):
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"}
supervisor.api_registry.toggle_api_status(api_entry.registry_key, True)
return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
@@ -139,6 +449,7 @@ class RuntimeComponentCapabilityMixin:
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -146,6 +457,13 @@ class RuntimeComponentCapabilityMixin:
if scope != "global" or stream_id:
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type):
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"}
supervisor.api_registry.toggle_api_status(api_entry.registry_key, False)
return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type)
if comp is None:
return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
@@ -168,33 +486,14 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID拒绝热重载: {details}"}
try:
registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
loaded = await self.load_plugin_globally(plugin_name, reason=f"load {plugin_name}")
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
if registered_supervisor is not None:
try:
reloaded = await registered_supervisor.reload_plugins(reason=f"load {plugin_name}")
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
for sv in self.supervisors:
for pdir in sv._plugin_dirs:
if (pdir / plugin_name).is_dir():
try:
reloaded = await sv.reload_plugins(reason=f"load {plugin_name}")
if reloaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.load_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
if loaded:
return {"success": True, "count": 1}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
async def _cap_component_unload_plugin(
self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any]
@@ -216,17 +515,204 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": f"检测到重复插件 ID拒绝热重载: {details}"}
try:
sv = self._get_supervisor_for_plugin(plugin_name)
reloaded = await self.reload_plugins_globally([plugin_name], reason=f"reload {plugin_name}")
except Exception as e:
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
if reloaded:
return {"success": True}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败"}
async def _cap_api_call(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""调用其他插件公开的 API。
Args:
plugin_id: 当前调用方插件 ID。
capability: 能力名称。
args: 能力参数。
Returns:
Any: API 调用结果。
"""
del capability
api_name = str(args.get("api_name", "") or "").strip()
version = str(args.get("version", "") or "").strip()
api_args = args.get("args", {})
if not isinstance(api_args, dict):
return {"success": False, "error": "参数 args 必须为字典"}
supervisor, entry, error = self._resolve_api_target(plugin_id, api_name, version)
if supervisor is None or entry is None:
return {"success": False, "error": error or "API 解析失败"}
invoke_args = dict(api_args)
if entry.dynamic:
invoke_args.setdefault("__maibot_api_name__", entry.name)
invoke_args.setdefault("__maibot_api_full_name__", entry.full_name)
invoke_args.setdefault("__maibot_api_version__", entry.version)
try:
response = await supervisor.invoke_api(
plugin_id=entry.plugin_id,
component_name=entry.handler_name,
args=invoke_args,
timeout_ms=30000,
)
except Exception as exc:
logger.error(f"[cap.api.call] 调用 API {entry.full_name} 失败: {exc}", exc_info=True)
return {"success": False, "error": str(exc)}
if response.error:
return {"success": False, "error": response.error.get("message", "API 调用失败")}
payload = response.payload if isinstance(response.payload, dict) else {}
if not bool(payload.get("success", False)):
result = payload.get("result")
return {"success": False, "error": "" if result is None else str(result)}
return {"success": True, "result": payload.get("result")}
async def _cap_api_get(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""获取当前插件可见的单个 API 元信息。
Args:
plugin_id: 当前调用方插件 ID。
capability: 能力名称。
args: 能力参数。
Returns:
Any: API 元信息或 ``None``。
"""
del capability
api_name = str(args.get("api_name", "") or "").strip()
version = str(args.get("version", "") or "").strip()
if not api_name:
return {"success": False, "error": "缺少必要参数 api_name"}
supervisor, entry, _error = self._resolve_api_target(plugin_id, api_name, version)
if supervisor is None or entry is None:
return {"success": True, "api": None}
return {"success": True, "api": self._serialize_api_entry(entry)}
async def _cap_api_list(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""列出当前插件可见的 API 列表。
Args:
plugin_id: 当前调用方插件 ID。
capability: 能力名称。
args: 能力参数。
Returns:
Any: API 元信息列表。
"""
del capability
target_plugin_id = str(args.get("plugin_id", "") or "").strip()
api_name, version = self._normalize_api_reference(
str(args.get("api_name", args.get("name", "")) or ""),
str(args.get("version", "") or ""),
)
apis: List[Dict[str, Any]] = []
for supervisor in self.supervisors:
apis.extend(
self._serialize_api_entry(entry)
for entry in supervisor.api_registry.get_apis(
plugin_id=target_plugin_id or None,
name=api_name,
version=version,
enabled_only=True,
)
if self._is_api_visible_to_plugin(entry, plugin_id)
)
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
return {"success": True, "apis": apis}
async def _cap_api_replace_dynamic(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""替换插件自行维护的动态 API 列表。"""
del capability
raw_apis = args.get("apis", [])
offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线"
if not isinstance(raw_apis, list):
return {"success": False, "error": "参数 apis 必须为列表"}
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
if sv is not None:
try:
reloaded = await sv.reload_plugins(reason=f"reload {plugin_name}")
if reloaded:
return {"success": True}
return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}
except Exception as e:
logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}")
return {"success": False, "error": str(e)}
return {"success": False, "error": f"未找到插件: {plugin_name}"}
if supervisor is None:
return {"success": False, "error": f"未找到插件: {plugin_id}"}
normalized_components: List[Dict[str, Any]] = []
seen_registry_keys: set[str] = set()
for index, raw_api in enumerate(raw_apis):
if not isinstance(raw_api, dict):
return {"success": False, "error": f"apis[{index}] 必须为字典"}
api_name = str(raw_api.get("name", "") or "").strip()
component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip()
if not api_name:
return {"success": False, "error": f"apis[{index}] 缺少 name"}
if not self._is_api_component_type(component_type):
return {"success": False, "error": f"apis[{index}] 不是 API 组件"}
metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {}
normalized_metadata = dict(metadata)
normalized_metadata["dynamic"] = True
version = str(normalized_metadata.get("version", "1") or "1").strip() or "1"
registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version)
if registry_key in seen_registry_keys:
return {"success": False, "error": f"动态 API 重复声明: {registry_key}"}
seen_registry_keys.add(registry_key)
existing_entry = supervisor.api_registry.get_api(
plugin_id,
api_name,
version=version,
enabled_only=False,
)
if existing_entry is not None and not existing_entry.dynamic:
return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"}
normalized_components.append(
{
"name": api_name,
"component_type": "API",
"metadata": normalized_metadata,
}
)
registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis(
plugin_id,
normalized_components,
offline_reason=offline_reason,
)
return {
"success": True,
"count": registered_count,
"offlined": offlined_count,
}

View File

@@ -238,14 +238,14 @@ class RuntimeCoreCapabilityMixin:
return {"success": False, "value": None, "error": str(e)}
async def _cap_config_get_plugin(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.core.component_registry import component_registry as core_registry
from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
key: str = args.get("key", "")
default = args.get("default")
try:
config = core_registry.get_plugin_config(plugin_name)
config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": False, "value": default, "error": f"未找到插件 {plugin_name} 的配置"}
@@ -258,11 +258,11 @@ class RuntimeCoreCapabilityMixin:
return {"success": False, "value": default, "error": str(e)}
async def _cap_config_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.core.component_registry import component_registry as core_registry
from src.plugin_runtime.component_query import component_query_service
plugin_name: str = args.get("plugin_name", plugin_id)
try:
config = core_registry.get_plugin_config(plugin_name)
config = component_query_service.get_plugin_config(plugin_name)
if config is None:
return {"success": True, "value": {}}
return {"success": True, "value": config}

View File

@@ -648,10 +648,10 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": str(e)}
async def _cap_tool_get_definitions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
from src.core.component_registry import component_registry as core_registry
from src.plugin_runtime.component_query import component_query_service
try:
tools = core_registry.get_llm_available_tools()
tools = component_query_service.get_llm_available_tools()
return {
"success": True,
"tools": [{"name": name, "definition": info.get_llm_definition()} for name, info in tools.items()],

View File

@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_runtime.host.capability_service import CapabilityImpl
from src.plugin_runtime.host.supervisor import PluginSupervisor
if TYPE_CHECKING:
@@ -13,66 +14,80 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
"""向指定 Supervisor 注册主程序提供的能力实现。"""
cap_service = supervisor.capability_service
cap_service.register_capability("send.text", manager._cap_send_text)
cap_service.register_capability("send.emoji", manager._cap_send_emoji)
cap_service.register_capability("send.image", manager._cap_send_image)
cap_service.register_capability("send.command", manager._cap_send_command)
cap_service.register_capability("send.custom", manager._cap_send_custom)
def _register(name: str, impl: CapabilityImpl) -> None:
"""注册单个能力实现。
cap_service.register_capability("llm.generate", manager._cap_llm_generate)
cap_service.register_capability("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
cap_service.register_capability("llm.get_available_models", manager._cap_llm_get_available_models)
Args:
name: 能力名称。
impl: 能力实现函数。
"""
cap_service.register_capability(name, impl)
cap_service.register_capability("config.get", manager._cap_config_get)
cap_service.register_capability("config.get_plugin", manager._cap_config_get_plugin)
cap_service.register_capability("config.get_all", manager._cap_config_get_all)
_register("send.text", manager._cap_send_text)
_register("send.emoji", manager._cap_send_emoji)
_register("send.image", manager._cap_send_image)
_register("send.command", manager._cap_send_command)
_register("send.custom", manager._cap_send_custom)
cap_service.register_capability("database.query", manager._cap_database_query)
cap_service.register_capability("database.save", manager._cap_database_save)
cap_service.register_capability("database.get", manager._cap_database_get)
cap_service.register_capability("database.delete", manager._cap_database_delete)
cap_service.register_capability("database.count", manager._cap_database_count)
_register("llm.generate", manager._cap_llm_generate)
_register("llm.generate_with_tools", manager._cap_llm_generate_with_tools)
_register("llm.get_available_models", manager._cap_llm_get_available_models)
cap_service.register_capability("chat.get_all_streams", manager._cap_chat_get_all_streams)
cap_service.register_capability("chat.get_group_streams", manager._cap_chat_get_group_streams)
cap_service.register_capability("chat.get_private_streams", manager._cap_chat_get_private_streams)
cap_service.register_capability("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
cap_service.register_capability("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
_register("config.get", manager._cap_config_get)
_register("config.get_plugin", manager._cap_config_get_plugin)
_register("config.get_all", manager._cap_config_get_all)
cap_service.register_capability("message.get_by_time", manager._cap_message_get_by_time)
cap_service.register_capability("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
cap_service.register_capability("message.get_recent", manager._cap_message_get_recent)
cap_service.register_capability("message.count_new", manager._cap_message_count_new)
cap_service.register_capability("message.build_readable", manager._cap_message_build_readable)
_register("database.query", manager._cap_database_query)
_register("database.save", manager._cap_database_save)
_register("database.get", manager._cap_database_get)
_register("database.delete", manager._cap_database_delete)
_register("database.count", manager._cap_database_count)
cap_service.register_capability("person.get_id", manager._cap_person_get_id)
cap_service.register_capability("person.get_value", manager._cap_person_get_value)
cap_service.register_capability("person.get_id_by_name", manager._cap_person_get_id_by_name)
_register("chat.get_all_streams", manager._cap_chat_get_all_streams)
_register("chat.get_group_streams", manager._cap_chat_get_group_streams)
_register("chat.get_private_streams", manager._cap_chat_get_private_streams)
_register("chat.get_stream_by_group_id", manager._cap_chat_get_stream_by_group_id)
_register("chat.get_stream_by_user_id", manager._cap_chat_get_stream_by_user_id)
cap_service.register_capability("emoji.get_by_description", manager._cap_emoji_get_by_description)
cap_service.register_capability("emoji.get_random", manager._cap_emoji_get_random)
cap_service.register_capability("emoji.get_count", manager._cap_emoji_get_count)
cap_service.register_capability("emoji.get_emotions", manager._cap_emoji_get_emotions)
cap_service.register_capability("emoji.get_all", manager._cap_emoji_get_all)
cap_service.register_capability("emoji.get_info", manager._cap_emoji_get_info)
cap_service.register_capability("emoji.register", manager._cap_emoji_register)
cap_service.register_capability("emoji.delete", manager._cap_emoji_delete)
_register("message.get_by_time", manager._cap_message_get_by_time)
_register("message.get_by_time_in_chat", manager._cap_message_get_by_time_in_chat)
_register("message.get_recent", manager._cap_message_get_recent)
_register("message.count_new", manager._cap_message_count_new)
_register("message.build_readable", manager._cap_message_build_readable)
cap_service.register_capability("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
cap_service.register_capability("frequency.set_adjust", manager._cap_frequency_set_adjust)
cap_service.register_capability("frequency.get_adjust", manager._cap_frequency_get_adjust)
_register("person.get_id", manager._cap_person_get_id)
_register("person.get_value", manager._cap_person_get_value)
_register("person.get_id_by_name", manager._cap_person_get_id_by_name)
cap_service.register_capability("tool.get_definitions", manager._cap_tool_get_definitions)
_register("emoji.get_by_description", manager._cap_emoji_get_by_description)
_register("emoji.get_random", manager._cap_emoji_get_random)
_register("emoji.get_count", manager._cap_emoji_get_count)
_register("emoji.get_emotions", manager._cap_emoji_get_emotions)
_register("emoji.get_all", manager._cap_emoji_get_all)
_register("emoji.get_info", manager._cap_emoji_get_info)
_register("emoji.register", manager._cap_emoji_register)
_register("emoji.delete", manager._cap_emoji_delete)
cap_service.register_capability("component.get_all_plugins", manager._cap_component_get_all_plugins)
cap_service.register_capability("component.get_plugin_info", manager._cap_component_get_plugin_info)
cap_service.register_capability("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
cap_service.register_capability("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
cap_service.register_capability("component.enable", manager._cap_component_enable)
cap_service.register_capability("component.disable", manager._cap_component_disable)
cap_service.register_capability("component.load_plugin", manager._cap_component_load_plugin)
cap_service.register_capability("component.unload_plugin", manager._cap_component_unload_plugin)
cap_service.register_capability("component.reload_plugin", manager._cap_component_reload_plugin)
_register("frequency.get_current_talk_value", manager._cap_frequency_get_current_talk_value)
_register("frequency.set_adjust", manager._cap_frequency_set_adjust)
_register("frequency.get_adjust", manager._cap_frequency_get_adjust)
cap_service.register_capability("knowledge.search", manager._cap_knowledge_search)
_register("tool.get_definitions", manager._cap_tool_get_definitions)
_register("api.call", manager._cap_api_call)
_register("api.get", manager._cap_api_get)
_register("api.list", manager._cap_api_list)
_register("api.replace_dynamic", manager._cap_api_replace_dynamic)
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
_register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins)
_register("component.list_registered_plugins", manager._cap_component_list_registered_plugins)
_register("component.enable", manager._cap_component_enable)
_register("component.disable", manager._cap_component_disable)
_register("component.load_plugin", manager._cap_component_load_plugin)
_register("component.unload_plugin", manager._cap_component_unload_plugin)
_register("component.reload_plugin", manager._cap_component_reload_plugin)
_register("knowledge.search", manager._cap_knowledge_search)
logger.debug("已注册全部主程序能力实现")

View File

@@ -0,0 +1,709 @@
"""插件运行时统一组件查询服务。
该模块统一从插件运行时的 Host ComponentRegistry 中聚合只读视图,
供 HFC/PFC、Planner、ToolExecutor 和运行时能力层查询与调用。
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple
from src.common.logger import get_logger
from src.core.types import ActionActivationType, ActionInfo, CommandInfo, ComponentInfo, ComponentType, ToolInfo
from src.llm_models.payload_content.tool_option import ToolParamType
if TYPE_CHECKING:
from src.plugin_runtime.host.component_registry import ActionEntry, CommandEntry, ComponentEntry, ToolEntry
from src.plugin_runtime.host.supervisor import PluginSupervisor
from src.plugin_runtime.integration import PluginRuntimeManager
logger = get_logger("plugin_runtime.component_query")
ActionExecutor = Callable[..., Awaitable[Any]]
CommandExecutor = Callable[..., Awaitable[Tuple[bool, Optional[str], bool]]]
ToolExecutor = Callable[..., Awaitable[Any]]
_HOST_COMPONENT_TYPE_MAP: Dict[ComponentType, str] = {
ComponentType.ACTION: "ACTION",
ComponentType.COMMAND: "COMMAND",
ComponentType.TOOL: "TOOL",
}
_TOOL_PARAM_TYPE_MAP: Dict[str, ToolParamType] = {
"string": ToolParamType.STRING,
"integer": ToolParamType.INTEGER,
"float": ToolParamType.FLOAT,
"boolean": ToolParamType.BOOLEAN,
"bool": ToolParamType.BOOLEAN,
}
class ComponentQueryService:
"""插件运行时统一组件查询服务。
该对象不维护独立状态,只读取插件系统中的注册结果。
所有注册、删除、配置写入等写操作都被显式禁用。
"""
@staticmethod
def _get_runtime_manager() -> "PluginRuntimeManager":
"""获取插件运行时管理器单例。
Returns:
PluginRuntimeManager: 当前全局插件运行时管理器。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _iter_supervisors(self) -> list["PluginSupervisor"]:
"""获取当前所有活跃的插件运行时监督器。
Returns:
list[PluginSupervisor]: 当前运行中的监督器列表。
"""
runtime_manager = self._get_runtime_manager()
return list(runtime_manager.supervisors)
def _iter_component_entries(
self,
component_type: ComponentType,
*,
enabled_only: bool = True,
) -> list[tuple["PluginSupervisor", "ComponentEntry"]]:
"""遍历指定类型的全部组件条目。
Args:
component_type: 目标组件类型。
enabled_only: 是否仅返回启用状态的组件。
Returns:
list[tuple[PluginSupervisor, ComponentEntry]]: ``(监督器, 组件条目)`` 列表。
"""
host_component_type = _HOST_COMPONENT_TYPE_MAP.get(component_type)
if host_component_type is None:
return []
collected_entries: list[tuple["PluginSupervisor", "ComponentEntry"]] = []
for supervisor in self._iter_supervisors():
for component in supervisor.component_registry.get_components_by_type(
host_component_type,
enabled_only=enabled_only,
):
collected_entries.append((supervisor, component))
return collected_entries
@staticmethod
def _coerce_action_activation_type(raw_value: Any) -> ActionActivationType:
"""规范化动作激活类型。
Args:
raw_value: 原始激活类型值。
Returns:
ActionActivationType: 规范化后的激活类型枚举。
"""
normalized_value = str(raw_value or "").strip().lower()
if normalized_value == ActionActivationType.NEVER.value:
return ActionActivationType.NEVER
if normalized_value == ActionActivationType.RANDOM.value:
return ActionActivationType.RANDOM
if normalized_value == ActionActivationType.KEYWORD.value:
return ActionActivationType.KEYWORD
return ActionActivationType.ALWAYS
@staticmethod
def _coerce_float(value: Any, default: float = 0.0) -> float:
"""将任意值安全转换为浮点数。
Args:
value: 待转换的输入值。
default: 转换失败时返回的默认值。
Returns:
float: 转换后的浮点结果。
"""
try:
return float(value)
except (TypeError, ValueError):
return default
@staticmethod
def _build_action_info(entry: "ActionEntry") -> ActionInfo:
"""将运行时 Action 条目转换为核心动作信息。
Args:
entry: 插件运行时中的 Action 条目。
Returns:
ActionInfo: 供核心 Planner 使用的动作信息。
"""
metadata = dict(entry.metadata)
raw_action_parameters = metadata.get("action_parameters")
action_parameters = (
{
str(param_name): str(param_description)
for param_name, param_description in raw_action_parameters.items()
}
if isinstance(raw_action_parameters, dict)
else {}
)
action_require = [
str(item)
for item in (metadata.get("action_require") or [])
if item is not None and str(item).strip()
]
associated_types = [
str(item)
for item in (metadata.get("associated_types") or [])
if item is not None and str(item).strip()
]
activation_keywords = [
str(item)
for item in (metadata.get("activation_keywords") or [])
if item is not None and str(item).strip()
]
return ActionInfo(
name=entry.name,
component_type=ComponentType.ACTION,
description=str(metadata.get("description", "") or ""),
enabled=bool(entry.enabled),
plugin_name=entry.plugin_id,
metadata=metadata,
action_parameters=action_parameters,
action_require=action_require,
associated_types=associated_types,
activation_type=ComponentQueryService._coerce_action_activation_type(metadata.get("activation_type")),
random_activation_probability=ComponentQueryService._coerce_float(
metadata.get("activation_probability"),
0.0,
),
activation_keywords=activation_keywords,
parallel_action=bool(metadata.get("parallel_action", False)),
)
@staticmethod
def _build_command_info(entry: "CommandEntry") -> CommandInfo:
"""将运行时 Command 条目转换为核心命令信息。
Args:
entry: 插件运行时中的 Command 条目。
Returns:
CommandInfo: 供核心命令链使用的命令信息。
"""
metadata = dict(entry.metadata)
return CommandInfo(
name=entry.name,
component_type=ComponentType.COMMAND,
description=str(metadata.get("description", "") or ""),
enabled=bool(entry.enabled),
plugin_name=entry.plugin_id,
metadata=metadata,
command_pattern=str(metadata.get("command_pattern", "") or ""),
)
@staticmethod
def _coerce_tool_param_type(raw_value: Any) -> ToolParamType:
"""规范化工具参数类型。
Args:
raw_value: 原始工具参数类型值。
Returns:
ToolParamType: 规范化后的工具参数类型。
"""
normalized_value = str(raw_value or "").strip().lower()
return _TOOL_PARAM_TYPE_MAP.get(normalized_value, ToolParamType.STRING)
@staticmethod
def _build_tool_parameters(entry: "ToolEntry") -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
"""将运行时工具参数元数据转换为核心 ToolInfo 参数列表。
Args:
entry: 插件运行时中的 Tool 条目。
Returns:
list[tuple[str, ToolParamType, str, bool, list[str] | None]]: 转换后的参数列表。
"""
structured_parameters = entry.parameters if isinstance(entry.parameters, list) else []
if not structured_parameters and isinstance(entry.parameters_raw, dict):
structured_parameters = [
{"name": key, **value}
for key, value in entry.parameters_raw.items()
if isinstance(value, dict)
]
normalized_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
for parameter in structured_parameters:
if not isinstance(parameter, dict):
continue
parameter_name = str(parameter.get("name", "") or "").strip()
if not parameter_name:
continue
enum_values = parameter.get("enum")
normalized_enum_values = (
[str(item) for item in enum_values if item is not None]
if isinstance(enum_values, list)
else None
)
normalized_parameters.append(
(
parameter_name,
ComponentQueryService._coerce_tool_param_type(parameter.get("param_type") or parameter.get("type")),
str(parameter.get("description", "") or ""),
bool(parameter.get("required", True)),
normalized_enum_values,
)
)
return normalized_parameters
@staticmethod
def _build_tool_info(entry: "ToolEntry") -> ToolInfo:
"""将运行时 Tool 条目转换为核心工具信息。
Args:
entry: 插件运行时中的 Tool 条目。
Returns:
ToolInfo: 供 ToolExecutor 与能力层使用的工具信息。
"""
return ToolInfo(
name=entry.name,
component_type=ComponentType.TOOL,
description=entry.description,
enabled=bool(entry.enabled),
plugin_name=entry.plugin_id,
metadata=dict(entry.metadata),
tool_parameters=ComponentQueryService._build_tool_parameters(entry),
tool_description=entry.description,
)
@staticmethod
def _log_duplicate_component(component_type: ComponentType, component_name: str) -> None:
"""记录重复组件名称冲突。
Args:
component_type: 组件类型。
component_name: 发生冲突的组件名称。
"""
logger.warning(f"检测到重复{component_type.value}名称 {component_name},将只保留首个匹配项")
def _get_unique_component_entry(
self,
component_type: ComponentType,
name: str,
) -> Optional[tuple["PluginSupervisor", "ComponentEntry"]]:
"""按组件短名解析唯一条目。
Args:
component_type: 目标组件类型。
name: 组件短名。
Returns:
Optional[tuple[PluginSupervisor, ComponentEntry]]: 唯一命中的组件条目。
"""
matched_entries = [
(supervisor, entry)
for supervisor, entry in self._iter_component_entries(component_type)
if entry.name == name
]
if not matched_entries:
return None
if len(matched_entries) > 1:
self._log_duplicate_component(component_type, name)
return matched_entries[0]
def _collect_unique_component_infos(
self,
component_type: ComponentType,
) -> Dict[str, ComponentInfo]:
"""收集某类组件的唯一信息视图。
Args:
component_type: 目标组件类型。
Returns:
Dict[str, ComponentInfo]: 组件名到核心组件信息的映射。
"""
collected_components: Dict[str, ComponentInfo] = {}
for _supervisor, entry in self._iter_component_entries(component_type):
if entry.name in collected_components:
self._log_duplicate_component(component_type, entry.name)
continue
if component_type == ComponentType.ACTION:
collected_components[entry.name] = self._build_action_info(entry) # type: ignore[arg-type]
elif component_type == ComponentType.COMMAND:
collected_components[entry.name] = self._build_command_info(entry) # type: ignore[arg-type]
elif component_type == ComponentType.TOOL:
collected_components[entry.name] = self._build_tool_info(entry) # type: ignore[arg-type]
return collected_components
@staticmethod
def _extract_stream_id_from_action_kwargs(kwargs: Dict[str, Any]) -> str:
"""从旧 ActionManager 参数中提取聊天流 ID。
Args:
kwargs: 旧动作执行器收到的关键字参数。
Returns:
str: 提取出的 ``stream_id``。
"""
chat_stream = kwargs.get("chat_stream")
if chat_stream is not None:
try:
return str(chat_stream.session_id)
except AttributeError:
pass
return str(kwargs.get("stream_id", "") or "")
@staticmethod
def _build_action_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ActionExecutor:
"""构造动作执行 RPC 闭包。
Args:
supervisor: 负责该组件的监督器。
plugin_id: 插件 ID。
component_name: 组件名称。
Returns:
ActionExecutor: 兼容旧 Planner 的异步执行器。
"""
async def _executor(**kwargs: Any) -> tuple[bool, str]:
"""将核心动作调用桥接到插件运行时。
Args:
**kwargs: 旧 ActionManager 传入的上下文参数。
Returns:
tuple[bool, str]: ``(是否成功, 结果说明)``。
"""
invoke_args: Dict[str, Any] = {}
action_data = kwargs.get("action_data")
if isinstance(action_data, dict):
invoke_args.update(action_data)
stream_id = ComponentQueryService._extract_stream_id_from_action_kwargs(kwargs)
invoke_args["action_data"] = action_data if isinstance(action_data, dict) else {}
invoke_args["stream_id"] = stream_id
invoke_args["chat_id"] = stream_id
invoke_args["reasoning"] = str(kwargs.get("action_reasoning", "") or "")
if (thinking_id := kwargs.get("thinking_id")) is not None:
invoke_args["thinking_id"] = str(thinking_id)
if isinstance(kwargs.get("cycle_timers"), dict):
invoke_args["cycle_timers"] = kwargs["cycle_timers"]
if isinstance(kwargs.get("plugin_config"), dict):
invoke_args["plugin_config"] = kwargs["plugin_config"]
if isinstance(kwargs.get("log_prefix"), str):
invoke_args["log_prefix"] = kwargs["log_prefix"]
if isinstance(kwargs.get("shutting_down"), bool):
invoke_args["shutting_down"] = kwargs["shutting_down"]
try:
response = await supervisor.invoke_plugin(
method="plugin.invoke_action",
plugin_id=plugin_id,
component_name=component_name,
args=invoke_args,
timeout_ms=30000,
)
except Exception as exc:
logger.error(f"运行时 Action {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
return False, str(exc)
payload = response.payload if isinstance(response.payload, dict) else {}
success = bool(payload.get("success", False))
result = payload.get("result")
if isinstance(result, (list, tuple)):
if len(result) >= 2:
return bool(result[0]), "" if result[1] is None else str(result[1])
if len(result) == 1:
return bool(result[0]), ""
if success:
return True, "" if result is None else str(result)
return False, "" if result is None else str(result)
return _executor
@staticmethod
def _build_command_executor(
supervisor: "PluginSupervisor",
plugin_id: str,
component_name: str,
metadata: Dict[str, Any],
) -> CommandExecutor:
"""构造命令执行 RPC 闭包。
Args:
supervisor: 负责该组件的监督器。
plugin_id: 插件 ID。
component_name: 组件名称。
metadata: 命令组件元数据。
Returns:
CommandExecutor: 兼容旧消息命令链的执行器。
"""
async def _executor(**kwargs: Any) -> tuple[bool, Optional[str], bool]:
"""将核心命令调用桥接到插件运行时。
Args:
**kwargs: 命令执行上下文参数。
Returns:
tuple[bool, Optional[str], bool]: ``(是否成功, 返回文本, 是否拦截后续消息)``。
"""
message = kwargs.get("message")
matched_groups = kwargs.get("matched_groups")
plugin_config = kwargs.get("plugin_config")
invoke_args: Dict[str, Any] = {
"text": str(getattr(message, "processed_plain_text", "") or ""),
"stream_id": str(getattr(message, "session_id", "") or ""),
"matched_groups": matched_groups if isinstance(matched_groups, dict) else {},
}
if isinstance(plugin_config, dict):
invoke_args["plugin_config"] = plugin_config
try:
response = await supervisor.invoke_plugin(
method="plugin.invoke_command",
plugin_id=plugin_id,
component_name=component_name,
args=invoke_args,
timeout_ms=30000,
)
except Exception as exc:
logger.error(f"运行时 Command {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
return False, str(exc), True
payload = response.payload if isinstance(response.payload, dict) else {}
success = bool(payload.get("success", False))
result = payload.get("result")
intercept = bool(metadata.get("intercept_message_level", 0))
response_text: Optional[str]
if isinstance(result, (list, tuple)) and len(result) >= 3:
response_text = None if result[1] is None else str(result[1])
intercept = bool(result[2])
else:
response_text = None if result is None else str(result)
return success, response_text, intercept
return _executor
@staticmethod
def _build_tool_executor(supervisor: "PluginSupervisor", plugin_id: str, component_name: str) -> ToolExecutor:
"""构造工具执行 RPC 闭包。
Args:
supervisor: 负责该组件的监督器。
plugin_id: 插件 ID。
component_name: 组件名称。
Returns:
ToolExecutor: 兼容旧 ToolExecutor 的异步执行器。
"""
async def _executor(function_args: Dict[str, Any]) -> Any:
"""将核心工具调用桥接到插件运行时。
Args:
function_args: 工具调用参数。
Returns:
Any: 插件工具返回结果;若结果不是字典,则会包装为 ``{"content": ...}``。
"""
try:
response = await supervisor.invoke_plugin(
method="plugin.invoke_tool",
plugin_id=plugin_id,
component_name=component_name,
args=function_args,
timeout_ms=30000,
)
except Exception as exc:
logger.error(f"运行时 Tool {plugin_id}.{component_name} 执行失败: {exc}", exc_info=True)
return {"content": f"工具 {component_name} 执行失败: {exc}"}
payload = response.payload if isinstance(response.payload, dict) else {}
result = payload.get("result")
if isinstance(result, dict):
return result
return {"content": "" if result is None else str(result)}
return _executor
def get_action_info(self, name: str) -> Optional[ActionInfo]:
"""获取指定动作的信息。
Args:
name: 动作名称。
Returns:
Optional[ActionInfo]: 匹配到的动作信息。
"""
matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
if matched_entry is None:
return None
_supervisor, entry = matched_entry
return self._build_action_info(entry) # type: ignore[arg-type]
def get_action_executor(self, name: str) -> Optional[ActionExecutor]:
"""获取指定动作的执行器。
Args:
name: 动作名称。
Returns:
Optional[ActionExecutor]: 运行时 RPC 执行闭包。
"""
matched_entry = self._get_unique_component_entry(ComponentType.ACTION, name)
if matched_entry is None:
return None
supervisor, entry = matched_entry
return self._build_action_executor(supervisor, entry.plugin_id, entry.name)
def get_default_actions(self) -> Dict[str, ActionInfo]:
"""获取当前默认启用的动作集合。
Returns:
Dict[str, ActionInfo]: 动作名到动作信息的映射。
"""
action_infos = self._collect_unique_component_infos(ComponentType.ACTION)
return {name: info for name, info in action_infos.items() if isinstance(info, ActionInfo) and info.enabled}
def find_command_by_text(self, text: str) -> Optional[Tuple[CommandExecutor, dict, CommandInfo]]:
"""根据文本查找匹配的命令。
Args:
text: 待匹配的文本内容。
Returns:
Optional[Tuple[CommandExecutor, dict, CommandInfo]]: 匹配结果。
"""
for supervisor in self._iter_supervisors():
match_result = supervisor.component_registry.find_command_by_text(text)
if match_result is None:
continue
entry, matched_groups = match_result
command_info = self._build_command_info(entry) # type: ignore[arg-type]
command_executor = self._build_command_executor(
supervisor,
entry.plugin_id,
entry.name,
dict(entry.metadata),
)
return command_executor, matched_groups, command_info
return None
def get_tool_info(self, name: str) -> Optional[ToolInfo]:
"""获取指定工具的信息。
Args:
name: 工具名称。
Returns:
Optional[ToolInfo]: 匹配到的工具信息。
"""
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
if matched_entry is None:
return None
_supervisor, entry = matched_entry
return self._build_tool_info(entry) # type: ignore[arg-type]
def get_tool_executor(self, name: str) -> Optional[ToolExecutor]:
"""获取指定工具的执行器。
Args:
name: 工具名称。
Returns:
Optional[ToolExecutor]: 运行时 RPC 执行闭包。
"""
matched_entry = self._get_unique_component_entry(ComponentType.TOOL, name)
if matched_entry is None:
return None
supervisor, entry = matched_entry
return self._build_tool_executor(supervisor, entry.plugin_id, entry.name)
def get_llm_available_tools(self) -> Dict[str, ToolInfo]:
"""获取当前可供 LLM 选择的工具集合。
Returns:
Dict[str, ToolInfo]: 工具名到工具信息的映射。
"""
tool_infos = self._collect_unique_component_infos(ComponentType.TOOL)
return {name: info for name, info in tool_infos.items() if isinstance(info, ToolInfo) and info.enabled}
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
"""获取某类组件的全部信息。
Args:
component_type: 组件类型。
Returns:
Dict[str, ComponentInfo]: 组件名到组件信息的映射。
"""
return self._collect_unique_component_infos(component_type)
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
"""读取指定插件的配置文件内容。
Args:
plugin_name: 插件名称。
Returns:
Optional[dict]: 读取成功时返回配置字典;未找到时返回 ``None``。
"""
runtime_manager = self._get_runtime_manager()
try:
supervisor = runtime_manager._get_supervisor_for_plugin(plugin_name)
except RuntimeError as exc:
logger.error(f"读取插件配置失败: {exc}")
return None
if supervisor is None:
return None
try:
return runtime_manager._load_plugin_config_for_supervisor(supervisor, plugin_name)
except Exception as exc:
logger.error(f"读取插件 {plugin_name} 配置失败: {exc}", exc_info=True)
return None
component_query_service = ComponentQueryService()

View File

@@ -0,0 +1,349 @@
"""Host 侧插件 API 动态注册表。"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.api_registry")
@dataclass(slots=True)
class APIEntry:
"""API 组件条目。"""
name: str
plugin_id: str
description: str = ""
version: str = "1"
public: bool = False
metadata: Dict[str, Any] = field(default_factory=dict)
enabled: bool = True
handler_name: str = ""
dynamic: bool = False
offline_reason: str = ""
disabled_session: Set[str] = field(default_factory=set)
full_name: str = field(init=False)
registry_key: str = field(init=False)
def __post_init__(self) -> None:
"""规范化 API 条目字段。"""
self.name = str(self.name or "").strip()
self.plugin_id = str(self.plugin_id or "").strip()
self.description = str(self.description or "").strip()
self.version = str(self.version or "1").strip() or "1"
self.handler_name = str(self.handler_name or self.name).strip() or self.name
self.offline_reason = str(self.offline_reason or "").strip()
self.full_name = f"{self.plugin_id}.{self.name}"
self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version)
@classmethod
def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry":
"""根据 Runner 上报的元数据构造 API 条目。"""
safe_metadata = dict(metadata)
return cls(
name=name,
plugin_id=plugin_id,
description=str(safe_metadata.get("description", "") or ""),
version=str(safe_metadata.get("version", "1") or "1"),
public=bool(safe_metadata.get("public", False)),
metadata=safe_metadata,
enabled=bool(safe_metadata.get("enabled", True)),
handler_name=str(safe_metadata.get("handler_name", name) or name),
dynamic=bool(safe_metadata.get("dynamic", False)),
offline_reason=str(safe_metadata.get("offline_reason", "") or ""),
)
class APIRegistry:
"""Host 侧插件 API 动态注册表。
该注册表不直接面向 Runner而是复用插件组件注册/卸载事件,
维护面向 API 调用场景的专用索引。
"""
def __init__(self) -> None:
"""初始化 API 注册表。"""
self._apis: Dict[str, APIEntry] = {}
self._by_full_name: Dict[str, List[APIEntry]] = {}
self._by_plugin: Dict[str, List[APIEntry]] = {}
self._by_name: Dict[str, List[APIEntry]] = {}
def clear(self) -> None:
"""清空全部 API 注册状态。"""
self._apis.clear()
self._by_full_name.clear()
self._by_plugin.clear()
self._by_name.clear()
@staticmethod
def _is_api_component(component_type: Any) -> bool:
"""判断组件声明是否属于 API。"""
return str(component_type or "").strip().upper() == "API"
@staticmethod
def _normalize_query_version(version: Any) -> str:
"""规范化查询使用的版本字符串。"""
return str(version or "").strip()
@classmethod
def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]:
"""解析可能带 ``@version`` 后缀的 API 引用。"""
normalized_reference = str(reference or "").strip()
normalized_version = cls._normalize_query_version(version)
if normalized_reference and not normalized_version and "@" in normalized_reference:
candidate_reference, candidate_version = normalized_reference.rsplit("@", 1)
candidate_reference = candidate_reference.strip()
candidate_version = candidate_version.strip()
if candidate_reference and candidate_version:
normalized_reference = candidate_reference
normalized_version = candidate_version
return normalized_reference, normalized_version
@staticmethod
def build_registry_key(plugin_id: str, name: str, version: str) -> str:
"""构造 API 注册表唯一键。"""
normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}"
normalized_version = str(version or "1").strip() or "1"
return f"{normalized_full_name}@{normalized_version}"
@staticmethod
def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
"""判断 API 条目当前是否处于启用状态。"""
if session_id and session_id in entry.disabled_session:
return False
return entry.enabled
def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个 API 条目。"""
normalized_name = str(name or "").strip()
if not normalized_name:
logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
return False
entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
existing_entry = self._apis.get(entry.registry_key)
if existing_entry is not None:
logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目")
self._remove_entry(existing_entry)
self._apis[entry.registry_key] = entry
self._by_full_name.setdefault(entry.full_name, []).append(entry)
self._by_plugin.setdefault(plugin_id, []).append(entry)
self._by_name.setdefault(entry.name, []).append(entry)
return True
def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册某个插件声明的全部 API。"""
count = 0
for component in components:
if not self._is_api_component(component.get("component_type")):
continue
if self.register_api(
name=str(component.get("name", "") or ""),
plugin_id=plugin_id,
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
):
count += 1
return count
def replace_plugin_dynamic_apis(
self,
plugin_id: str,
components: List[Dict[str, Any]],
*,
offline_reason: str = "动态 API 已下线",
) -> Tuple[int, int]:
"""替换指定插件当前声明的动态 API 集合。"""
normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线"
desired_registry_keys: Set[str] = set()
registered_count = 0
for component in components:
if not self._is_api_component(component.get("component_type")):
continue
metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}
dynamic_metadata = dict(metadata)
dynamic_metadata["dynamic"] = True
dynamic_metadata.pop("offline_reason", None)
entry = APIEntry.from_metadata(
name=str(component.get("name", "") or ""),
plugin_id=plugin_id,
metadata=dynamic_metadata,
)
desired_registry_keys.add(entry.registry_key)
if self.register_api(entry.name, plugin_id, dynamic_metadata):
registered_count += 1
offlined_count = 0
for entry in list(self._by_plugin.get(plugin_id, [])):
if not entry.dynamic or entry.registry_key in desired_registry_keys:
continue
entry.enabled = False
entry.offline_reason = normalized_offline_reason
entry.metadata["offline_reason"] = normalized_offline_reason
offlined_count += 1
return registered_count, offlined_count
def _remove_entry(self, entry: APIEntry) -> None:
"""从全部索引中移除单个 API 条目。"""
self._apis.pop(entry.registry_key, None)
full_name_entries = self._by_full_name.get(entry.full_name)
if full_name_entries is not None:
self._by_full_name[entry.full_name] = [
candidate for candidate in full_name_entries if candidate is not entry
]
if not self._by_full_name[entry.full_name]:
self._by_full_name.pop(entry.full_name, None)
plugin_entries = self._by_plugin.get(entry.plugin_id)
if plugin_entries is not None:
self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
if not self._by_plugin[entry.plugin_id]:
self._by_plugin.pop(entry.plugin_id, None)
name_entries = self._by_name.get(entry.name)
if name_entries is not None:
self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry]
if not self._by_name[entry.name]:
self._by_name.pop(entry.name, None)
def remove_apis_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的全部 API。"""
entries = list(self._by_plugin.get(plugin_id, []))
for entry in entries:
self._remove_entry(entry)
return len(entries)
def get_api_by_full_name(
self,
full_name: str,
*,
version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
"""按完整名查询单个 API。"""
normalized_full_name, normalized_version = self._split_reference(full_name, version)
if not normalized_full_name:
return None
if normalized_version:
entry = self._apis.get(f"{normalized_full_name}@{normalized_version}")
if entry is None:
return None
if enabled_only and not self.check_api_enabled(entry, session_id):
return None
return entry
candidates = list(self._by_full_name.get(normalized_full_name, []))
filtered_entries = [
entry
for entry in candidates
if not enabled_only or self.check_api_enabled(entry, session_id)
]
if len(filtered_entries) != 1:
return None
return filtered_entries[0]
def get_api(
self,
plugin_id: str,
name: str,
*,
version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
"""按插件 ID、短名与版本查询单个 API。"""
return self.get_api_by_full_name(
f"{plugin_id}.{name}",
version=version,
enabled_only=enabled_only,
session_id=session_id,
)
def get_apis(
self,
*,
plugin_id: Optional[str] = None,
name: str = "",
version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> List[APIEntry]:
"""查询 API 列表。"""
normalized_name = str(name or "").strip()
normalized_version = self._normalize_query_version(version)
if plugin_id:
candidates = list(self._by_plugin.get(plugin_id, []))
elif normalized_name:
candidates = list(self._by_name.get(normalized_name, []))
else:
candidates = list(self._apis.values())
filtered_entries: List[APIEntry] = []
for entry in candidates:
if plugin_id and entry.plugin_id != plugin_id:
continue
if normalized_name and entry.name != normalized_name:
continue
if normalized_version and entry.version != normalized_version:
continue
if enabled_only and not self.check_api_enabled(entry, session_id):
continue
filtered_entries.append(entry)
filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version))
return filtered_entries
def toggle_api_status(
self,
full_name: str,
enabled: bool,
*,
version: str = "",
session_id: Optional[str] = None,
) -> bool:
"""设置指定 API 的启用状态。"""
entry = self.get_api_by_full_name(
full_name,
version=version,
enabled_only=False,
session_id=session_id,
)
if entry is None:
return False
if session_id:
if enabled:
entry.disabled_session.discard(session_id)
else:
entry.disabled_session.add(session_id)
else:
entry.enabled = enabled
if enabled:
entry.offline_reason = ""
entry.metadata.pop("offline_reason", None)
return True

View File

@@ -0,0 +1,67 @@
"""授权管理器
负责管理插件的能力授权以及校验
每个插件在 manifest 中声明能力需求Host 启动时签发能力令牌。
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"})
@dataclass
class CapabilityPermissionToken:
"""能力令牌"""
plugin_id: str
capabilities: Set[str] = field(default_factory=set)
class AuthorizationManager:
"""授权管理器
管理所有插件的能力令牌,提供授权校验。
"""
def __init__(self) -> None:
self._permission_tokens: Dict[str, CapabilityPermissionToken] = {}
def register_plugin(self, plugin_id: str, capabilities: List[str]) -> CapabilityPermissionToken:
"""为插件签发能力令牌"""
token = CapabilityPermissionToken(plugin_id=plugin_id, capabilities=set(capabilities))
self._permission_tokens[plugin_id] = token
return token
def revoke_permission_token(self, plugin_id: str):
"""移除插件的能力令牌。"""
self._permission_tokens.pop(plugin_id, None)
def clear(self) -> None:
"""清空所有能力令牌。"""
self._permission_tokens.clear()
def check_capability(self, plugin_id: str, capability: str) -> Tuple[bool, str]:
# sourcery skip: assign-if-exp, reintroduce-else, swap-if-else-branches, use-named-expression
"""检查插件是否有权调用某项能力
Returns:
return (bool, str): (是否有此能力, 原因)
"""
if capability in _ALWAYS_ALLOWED_CAPABILITIES:
return True, ""
token = self._permission_tokens.get(plugin_id)
if not token:
return False, f"插件 {plugin_id} 未注册能力令牌"
if capability not in token.capabilities:
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
return True, ""
def get_token(self, plugin_id: str) -> Optional[CapabilityPermissionToken]:
"""获取插件的能力令牌"""
return self._permission_tokens.get(plugin_id)
def list_plugins(self) -> List[str]:
"""列出所有已注册的插件"""
return list(self._permission_tokens.keys())

View File

@@ -4,21 +4,19 @@ Host 端实现的能力服务,处理来自插件的 cap.* 请求。
每个能力方法被注册到 RPC Server接收 Runner 转发的请求并执行实际操作。
"""
from typing import Any, Awaitable, Callable, Dict, List
from typing import Any, Callable, Dict, List, Coroutine, TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_runtime.host.policy_engine import PolicyEngine
from src.plugin_runtime.protocol.envelope import (
CapabilityRequestPayload,
CapabilityResponsePayload,
Envelope,
)
from src.plugin_runtime.protocol.envelope import CapabilityRequestPayload, CapabilityResponsePayload, Envelope
from src.plugin_runtime.protocol.errors import ErrorCode, RPCError
if TYPE_CHECKING:
from src.plugin_runtime.host.authorization import AuthorizationManager
logger = get_logger("plugin_runtime.host.capability_service")
# 能力实现函数类型: (plugin_id, capability, args) -> result
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Awaitable[Any]]
CapabilityImpl = Callable[[str, str, Dict[str, Any]], Coroutine[Any, Any, Any]]
class CapabilityService:
@@ -31,8 +29,13 @@ class CapabilityService:
4. 执行实际操作并返回结果
"""
def __init__(self, policy_engine: PolicyEngine) -> None:
self._policy = policy_engine
def __init__(self, authorization: "AuthorizationManager") -> None:
"""初始化能力服务。
Args:
authorization: 能力授权管理器。
"""
self._authorization = authorization
# capability_name -> implementation
self._implementations: Dict[str, CapabilityImpl] = {}
@@ -56,46 +59,32 @@ class CapabilityService:
try:
req = CapabilityRequestPayload.model_validate(envelope.payload)
except Exception as e:
return envelope.make_error_response(
ErrorCode.E_BAD_PAYLOAD.value,
f"能力调用 payload 格式错误: {e}",
)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, f"能力调用 payload 非法: {exc}")
capability = req.capability
args = req.args
# 1. 权限校验
allowed, reason = self._policy.check_capability(plugin_id, capability, envelope.generation)
allowed, reason = self._authorization.check_capability(plugin_id, capability)
if not allowed:
error_code = (
ErrorCode.E_GENERATION_MISMATCH if "generation 不匹配" in reason else ErrorCode.E_CAPABILITY_DENIED
)
return envelope.make_error_response(
error_code.value,
reason,
)
return envelope.make_error_response(ErrorCode.E_CAPABILITY_DENIED.value, reason)
# 2. 查找实现
impl = self._implementations.get(capability)
if impl is None:
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的能力: {capability}",
)
return envelope.make_error_response(ErrorCode.E_METHOD_NOT_ALLOWED.value, f"未注册的能力: {capability}")
# 3. 执行
try:
result = await impl(plugin_id, capability, req.args)
result = await impl(plugin_id, capability, args)
resp_payload = CapabilityResponsePayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except RPCError as e:
return envelope.make_error_response(e.code.value, e.message, e.details)
except Exception as e:
logger.error(f"能力 {capability} 执行异常: {e}", exc_info=True)
return envelope.make_error_response(
ErrorCode.E_CAPABILITY_FAILED.value,
str(e),
)
return envelope.make_error_response(ErrorCode.E_CAPABILITY_FAILED.value, str(e))
def list_capabilities(self) -> List[str]:
"""列出所有已注册的能力"""

View File

@@ -1,7 +1,7 @@
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
- 按类型注册组件action / command / tool / event_handler / workflow_step
- 按类型注册组件action / command / tool / event_handler / workflow_handler / message_gateway
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
@@ -9,8 +9,10 @@
- 注册统计
"""
from typing import Any, Dict, List, Optional
from enum import Enum
from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
import contextlib
import re
from src.common.logger import get_logger
@@ -18,8 +20,28 @@ from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.component_registry")
class RegisteredComponent:
"""已注册的组件条目"""
class ComponentTypes(str, Enum):
ACTION = "ACTION"
COMMAND = "COMMAND"
TOOL = "TOOL"
EVENT_HANDLER = "EVENT_HANDLER"
HOOK_HANDLER = "HOOK_HANDLER"
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
class StatusDict(TypedDict):
total: int
action: int
command: int
tool: int
event_handler: int
hook_handler: int
message_gateway: int
plugins: int
class ComponentEntry:
"""组件条目"""
__slots__ = (
"name",
@@ -28,31 +50,120 @@ class RegisteredComponent:
"plugin_id",
"metadata",
"enabled",
"_compiled_pattern",
"compiled_pattern",
"disabled_session",
)
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
) -> None:
self.name = name
self.full_name = f"{plugin_id}.{name}"
self.component_type = component_type
self.plugin_id = plugin_id
self.metadata = metadata
self.enabled = metadata.get("enabled", True)
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.name: str = name
self.full_name: str = f"{plugin_id}.{name}"
self.component_type: ComponentTypes = ComponentTypes(component_type)
self.plugin_id: str = plugin_id
self.metadata: Dict[str, Any] = metadata
self.enabled: bool = metadata.get("enabled", True)
self.disabled_session: Set[str] = set()
# 预编译命令正则(仅 command 类型)
self._compiled_pattern: Optional[re.Pattern] = None
if component_type == "command":
if pattern := metadata.get("command_pattern", ""):
try:
self._compiled_pattern = re.compile(pattern)
except re.error as e:
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
class ActionEntry(ComponentEntry):
"""Action 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
super().__init__(name, component_type, plugin_id, metadata)
class CommandEntry(ComponentEntry):
"""Command 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
super().__init__(name, component_type, plugin_id, metadata)
self.aliases: List[str] = metadata.get("aliases", [])
self.compiled_pattern: Optional[re.Pattern] = None
if pattern := metadata.get("command_pattern", ""):
try:
self.compiled_pattern = re.compile(pattern)
except (re.error, TypeError) as e:
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
class ToolEntry(ComponentEntry):
"""Tool 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.description: str = metadata.get("description", "")
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
super().__init__(name, component_type, plugin_id, metadata)
class EventHandlerEntry(ComponentEntry):
"""EventHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.event_type: str = metadata.get("event_type", "")
self.weight: int = metadata.get("weight", 0)
self.intercept_message: bool = metadata.get("intercept_message", False)
super().__init__(name, component_type, plugin_id, metadata)
class HookHandlerEntry(ComponentEntry):
"""WorkflowHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.stage: str = metadata.get("stage", "")
self.priority: int = metadata.get("priority", 0)
self.blocking: bool = metadata.get("blocking", False)
super().__init__(name, component_type, plugin_id, metadata)
class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
self.platform: str = str(metadata.get("platform", "") or "").strip()
self.protocol: str = str(metadata.get("protocol", "") or "").strip()
self.account_id: str = str(metadata.get("account_id", "") or "").strip()
self.scope: str = str(metadata.get("scope", "") or "").strip()
super().__init__(name, component_type, plugin_id, metadata)
@staticmethod
def _normalize_route_type(raw_value: Any) -> str:
"""规范化消息网关路由类型。
Args:
raw_value: 原始路由类型值。
Returns:
str: 规范化后的路由类型。
Raises:
ValueError: 当路由类型不受支持时抛出。
"""
normalized_value = str(raw_value or "").strip().lower()
route_type_aliases = {
"send": "send",
"receive": "receive",
"recv": "receive",
"recive": "receive",
"duplex": "duplex",
}
route_type = route_type_aliases.get(normalized_value)
if route_type is None:
raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}")
return route_type
@property
def supports_send(self) -> bool:
"""返回当前网关是否支持出站。"""
return self.route_type in {"send", "duplex"}
@property
def supports_receive(self) -> bool:
"""返回当前网关是否支持入站。"""
return self.route_type in {"receive", "duplex"}
class ComponentRegistry:
@@ -64,19 +175,32 @@ class ComponentRegistry:
def __init__(self) -> None:
# 全量索引
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
# 按类型索引
self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
"action": {},
"command": {},
"tool": {},
"event_handler": {},
"workflow_step": {},
}
self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = {
comp_type: {} for comp_type in ComponentTypes
} # component_type -> (full_name -> comp)
# 按插件索引
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
@staticmethod
def _normalize_component_type(component_type: str) -> ComponentTypes:
"""规范化组件类型输入。
Args:
component_type: 原始组件类型字符串。
Returns:
ComponentTypes: 规范化后的组件类型枚举。
Raises:
ValueError: 当组件类型不受支持时抛出。
"""
normalized_value = str(component_type or "").strip().upper()
return ComponentTypes(normalized_value)
def clear(self) -> None:
"""清空全部组件注册状态。"""
@@ -85,47 +209,64 @@ class ComponentRegistry:
type_dict.clear()
self._by_plugin.clear()
# ──── 注册 / 注销 ─────────────────────────────────────────
# ====== 注册 / 注销 ======
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个组件
Args:
name: 组件名称不含插件id前缀
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
plugin_id: 插件id
metadata: 组件元数据
Returns:
success (bool): 是否成功注册(失败原因通常是组件类型无效)
"""
try:
normalized_type = self._normalize_component_type(component_type)
if normalized_type == ComponentTypes.ACTION:
comp = ActionEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.COMMAND:
comp = CommandEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.TOOL:
comp = ToolEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.EVENT_HANDLER:
comp = EventHandlerEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.HOOK_HANDLER:
comp = HookHandlerEntry(name, normalized_type.value, plugin_id, metadata)
elif normalized_type == ComponentTypes.MESSAGE_GATEWAY:
comp = MessageGatewayEntry(name, normalized_type.value, plugin_id, metadata)
else:
raise ValueError(f"组件类型 {component_type} 不存在")
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
return False
def register_component(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
) -> bool:
"""注册单个组件。"""
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
old_comp = self._components[comp.full_name]
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
old_list = self._by_plugin.get(old_comp.plugin_id)
if old_list is not None:
try:
with contextlib.suppress(ValueError):
old_list.remove(old_comp)
except ValueError:
pass
# 从旧类型索引中移除,防止类型变更时幽灵残留
if old_type_dict := self._by_type.get(old_comp.component_type):
old_type_dict.pop(comp.full_name, None)
self._components[comp.full_name] = comp
if component_type not in self._by_type:
self._by_type[component_type] = {}
self._by_type[component_type][comp.full_name] = comp
self._by_type[comp.component_type][comp.full_name] = comp
self._by_plugin.setdefault(plugin_id, []).append(comp)
return True
def register_plugin_components(
self,
plugin_id: str,
components: List[Dict[str, Any]],
) -> int:
"""批量注册一个插件的所有组件,返回成功注册数。"""
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册一个插件的所有组件,返回成功注册数。
Args:
plugin_id (str): 插件id
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
Returns:
count (int): 成功注册的组件数量
"""
count = 0
for comp_data in components:
ok = self.register_component(
@@ -139,7 +280,13 @@ class ComponentRegistry:
return count
def remove_components_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的所有组件,返回移除数量。"""
"""移除某个插件的所有组件,返回移除数量。
Args:
plugin_id (str): 插件id
Returns:
count (int): 移除的组件数量
"""
comps = self._by_plugin.pop(plugin_id, [])
for comp in comps:
self._components.pop(comp.full_name, None)
@@ -147,106 +294,280 @@ class ComponentRegistry:
type_dict.pop(comp.full_name, None)
return len(comps)
# ──── 启用 / 禁用 ─────────────────────────────────────────
# ====== 启用 / 禁用 ======
def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
if session_id and session_id in component.disabled_session:
return False
return component.enabled
def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
"""启用或禁用指定组件。"""
def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
"""启用或禁用指定组件。
Args:
full_name (str): 组件全名
enabled (bool): 使能情况
session_id (Optional[str]): 可选的会话ID仅对该会话禁用如果提供
Returns:
success (bool): 是否成功设置(失败原因通常是组件不存在)
"""
comp = self._components.get(full_name)
if comp is None:
return False
comp.enabled = enabled
if session_id:
if enabled:
comp.disabled_session.discard(session_id)
else:
comp.disabled_session.add(session_id)
else:
comp.enabled = enabled
return True
def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
"""批量启用或禁用某插件的所有组件。"""
def set_component_enabled(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
"""设置指定组件的启用状态。
Args:
full_name: 组件全名。
enabled: 目标启用状态。
session_id: 可选的会话 ID仅对该会话生效。
Returns:
bool: 是否设置成功。
"""
return self.toggle_component_status(full_name, enabled, session_id=session_id)
def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
"""批量启用或禁用某插件的所有组件。
Args:
plugin_id (str): 插件id
enabled (bool): 使能情况
session_id (Optional[str]): 可选的会话ID仅对该会话禁用如果提供
Returns:
count (int): 成功设置的组件数量(失败原因通常是插件不存在)
"""
comps = self._by_plugin.get(plugin_id, [])
for comp in comps:
comp.enabled = enabled
if session_id:
if enabled:
comp.disabled_session.discard(session_id)
else:
comp.disabled_session.add(session_id)
else:
comp.enabled = enabled
return len(comps)
# ──── 查询方法 ─────────────────────────────────────────────
def get_component(self, full_name: str) -> Optional[ComponentEntry]:
"""按全名查询。
def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
"""按全名查询。"""
Args:
full_name (str): 组件全名
Returns:
component (Optional[ComponentEntry]): 组件条目,未找到时为 None
"""
return self._components.get(full_name)
def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""按类型查询。"""
type_dict = self._by_type.get(component_type, {})
def get_components_by_type(
self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[ComponentEntry]:
"""按类型查询组件
Args:
component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等)
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
components (List[ComponentEntry]): 组件条目列表
"""
try:
comp_type = self._normalize_component_type(component_type)
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
raise
type_dict = self._by_type.get(comp_type, {})
if enabled_only:
return [c for c in type_dict.values() if c.enabled]
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
return list(type_dict.values())
def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""按插件查询。"""
comps = self._by_plugin.get(plugin_id, [])
return [c for c in comps if c.enabled] if enabled_only else list(comps)
def get_components_by_plugin(
self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[ComponentEntry]:
"""按插件查询组件。
def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]:
Args:
plugin_id (str): 插件ID
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
components (List[ComponentEntry]): 组件条目列表
"""
comps = self._by_plugin.get(plugin_id, [])
return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps)
def find_command_by_text(
self, text: str, session_id: Optional[str] = None
) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]:
"""通过文本匹配命令正则,返回 (组件, matched_groups) 元组。
matched_groups 为正则命名捕获组 dict别名匹配时为空 dict。
Args:
text (str): 待匹配文本
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None
"""
for comp in self._by_type.get("command", {}).values():
if not comp.enabled:
for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values():
if not self.check_component_enabled(comp, session_id):
continue
if comp._compiled_pattern:
m = comp._compiled_pattern.search(text)
if m:
if not isinstance(comp, CommandEntry):
continue
if comp.compiled_pattern:
if m := comp.compiled_pattern.search(text):
return comp, m.groupdict()
# 别名匹配
aliases = comp.metadata.get("aliases", [])
for alias in aliases:
for alias in comp.aliases:
if text.startswith(alias):
return comp, {}
return None
def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""获取特定事件类型的所有 event_handler按 weight 降序排列。"""
handlers = []
for comp in self._by_type.get("event_handler", {}).values():
if enabled_only and not comp.enabled:
def get_event_handlers(
self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[EventHandlerEntry]:
"""查询指定事件类型的事件处理器组件。
Args:
event_type (str): 事件类型
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序
"""
handlers: List[EventHandlerEntry] = []
for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if comp.metadata.get("event_type") == event_type:
if not isinstance(comp, EventHandlerEntry):
continue
if comp.event_type == event_type:
handlers.append(comp)
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
handlers.sort(key=lambda c: c.weight, reverse=True)
return handlers
def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
steps = []
for comp in self._by_type.get("workflow_step", {}).values():
if enabled_only and not comp.enabled:
def get_hook_handlers(
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[HookHandlerEntry]:
"""获取特定 hook 阶段的所有步骤,按 priority 降序。
Args:
stage: hook 名称
enabled_only: 是否仅返回启用的组件
session_id: 可选的会话ID若提供则考虑会话禁用状态
Returns:
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
"""
handlers: List[HookHandlerEntry] = []
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if comp.metadata.get("stage") == stage:
steps.append(comp)
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
return steps
if not isinstance(comp, HookHandlerEntry):
continue
if comp.stage == stage:
handlers.append(comp)
handlers.sort(key=lambda c: c.priority, reverse=True)
return handlers
def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
"""获取可供 LLM 使用的工具列表openai function-calling 格式预览)。"""
result: List[Dict[str, Any]] = []
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
tool_def: Dict[str, Any] = {
"name": comp.full_name,
"description": comp.metadata.get("description", ""),
}
# 从结构化参数或原始参数构建 parameters
params = comp.metadata.get("parameters", [])
params_raw = comp.metadata.get("parameters_raw", {})
if params:
tool_def["parameters"] = params
elif params_raw:
tool_def["parameters"] = params_raw
result.append(tool_def)
return result
def get_message_gateway(
self,
plugin_id: str,
name: str,
*,
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[MessageGatewayEntry]:
"""按插件和组件名获取单个消息网关。
# ──── 统计 ─────────────────────────────────────────────────
Args:
plugin_id: 插件 ID。
name: 网关组件名称。
enabled_only: 是否仅返回启用的组件。
session_id: 可选的会话 ID。
def get_stats(self) -> Dict[str, int]:
"""获取注册统计。"""
stats: Dict[str, int] = {"total": len(self._components)}
Returns:
Optional[MessageGatewayEntry]: 若存在则返回消息网关条目。
"""
component = self._components.get(f"{plugin_id}.{name}")
if not isinstance(component, MessageGatewayEntry):
return None
if enabled_only and not self.check_component_enabled(component, session_id):
return None
return component
def get_message_gateways(
self,
*,
plugin_id: Optional[str] = None,
platform: str = "",
route_type: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> List[MessageGatewayEntry]:
"""查询消息网关组件列表。
Args:
plugin_id: 可选的插件 ID 过滤条件。
platform: 可选的平台过滤条件。
route_type: 可选的路由类型过滤条件。
enabled_only: 是否仅返回启用的组件。
session_id: 可选的会话 ID。
Returns:
List[MessageGatewayEntry]: 符合条件的消息网关组件列表。
"""
normalized_platform = str(platform or "").strip()
normalized_route_type = str(route_type or "").strip().lower()
gateways: List[MessageGatewayEntry] = []
for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values():
if not isinstance(comp, MessageGatewayEntry):
continue
if plugin_id and comp.plugin_id != plugin_id:
continue
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if normalized_platform and comp.platform != normalized_platform:
continue
if normalized_route_type and comp.route_type != normalized_route_type:
continue
gateways.append(comp)
return gateways
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
"""查询所有工具组件。
Args:
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
tools (List[ToolEntry]): 符合条件的 Tool 组件列表
"""
tools: List[ToolEntry] = []
for comp in self._by_type.get(ComponentTypes.TOOL, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if isinstance(comp, ToolEntry):
tools.append(comp)
return tools
# ====== 统计信息 ======
def get_stats(self) -> StatusDict:
"""获取注册统计。
Returns:
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
"""
stats: StatusDict = {"total": len(self._components)} # type: ignore
for comp_type, type_dict in self._by_type.items():
stats[comp_type] = len(type_dict)
stats[comp_type.value.lower()] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats

View File

@@ -4,40 +4,40 @@
1. 按事件类型查询已注册的 event_handler通过 ComponentRegistry
2. 按 weight 排序,依次通过 RPC 调用 Runner 中的处理器
3. 支持阻塞intercept_message和非阻塞分发
4. 事件结果历史记录
4. 事件结果历史记录(有上限)
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
import asyncio
from src.common.logger import get_logger
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
from .message_utils import PluginMessageUtils, MessageDict
if TYPE_CHECKING:
from .supervisor import PluginRunnerSupervisor
from .component_registry import ComponentRegistry, EventHandlerEntry
from src.chat.message_receive.message import SessionMessage
logger = get_logger("plugin_runtime.host.event_dispatcher")
# invoke_fn 类型: async (plugin_id, component_name, args) -> response_payload dict
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
# 每个事件类型的最大历史记录数量,防止内存无限增长
_MAX_HISTORY_LENGTH = 100
@dataclass
class EventResult:
"""单个 EventHandler 的执行结果"""
__slots__ = ("handler_name", "success", "continue_processing", "modified_message", "custom_result")
def __init__(
self,
handler_name: str,
success: bool = True,
continue_processing: bool = True,
modified_message: Optional[Dict[str, Any]] = None,
custom_result: Any = None,
):
self.handler_name = handler_name
self.success = success
self.continue_processing = continue_processing
self.modified_message = modified_message
self.custom_result = custom_result
handler_name: str
success: bool = field(default=True)
continue_processing: bool = field(default=True)
modified_message: Optional[MessageDict] = field(default=None)
custom_result: Any = field(default=None)
class EventDispatcher:
@@ -48,17 +48,20 @@ class EventDispatcher:
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
def __init__(self, registry: ComponentRegistry) -> None:
self._registry: ComponentRegistry = registry
def __init__(self, component_registry: "ComponentRegistry") -> None:
self._component_registry: "ComponentRegistry" = component_registry
self._result_history: Dict[str, List[EventResult]] = {}
self._history_enabled: Set[str] = set()
# 保持 fire-and-forget task 的强引用,防止被 GC 回收
self._background_tasks: Set[asyncio.Task] = set()
def enable_history(self, event_type: str) -> None:
self._history_enabled.add(event_type)
self._result_history.setdefault(event_type, [])
def disable_history(self, event_type: str) -> None:
self._history_enabled.discard(event_type)
self._result_history.pop(event_type, None)
def get_history(self, event_type: str) -> List[EventResult]:
return self._result_history.get(event_type, [])
@@ -66,47 +69,58 @@ class EventDispatcher:
if event_type in self._result_history:
self._result_history[event_type] = []
async def stop(self):
"""停止 EventDispatcher取消所有未完成的后台任务"""
for task in self._background_tasks:
task.cancel()
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
async def dispatch_event(
self,
event_type: str,
invoke_fn: InvokeFn,
message: Optional[Dict[str, Any]] = None,
supervisor: "PluginRunnerSupervisor",
message: Optional["SessionMessage"] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""分发事件到所有对应 handler。
) -> Tuple[bool, Optional["SessionMessage"]]:
"""分发事件到所有对应 handler 的便捷方法
内置了通过 PluginSupervisor.invoke_plugin 调用 plugin.emit_event 的逻辑,
无需调用方手动构造 invoke_fn 闭包。
Args:
event_type: 事件类型字符串
invoke_fn: 异步回调,签名 (plugin_id, component_name, args) -> response_payload dict
supervisor: PluginSupervisor 实例,用于调用 invoke_plugin
message: MaiMessages 序列化后的 dict可选
extra_args: 额外参数
Returns:
(should_continue, modified_message_dict)
(should_continue, modified_message_dict) (bool, SessionMessage | None): (是否继续后续执行, 可选的修改后的消息)
"""
handlers = self._registry.get_event_handlers(event_type)
if not handlers:
handler_entries = self._component_registry.get_event_handlers(event_type)
if not handler_entries:
return True, None
should_continue = True
modified_message: Optional[Dict[str, Any]] = None
intercept_handlers: List[RegisteredComponent] = []
async_handlers: List[RegisteredComponent] = []
modified_message: Optional[MessageDict] = (
PluginMessageUtils._session_message_to_dict(message) if message else None
)
intercept_handlers: List["EventHandlerEntry"] = []
non_blocking_handlers: List["EventHandlerEntry"] = []
for handler in handlers:
if handler.metadata.get("intercept_message", False):
intercept_handlers.append(handler)
for entry in handler_entries:
if entry.intercept_message:
intercept_handlers.append(entry)
else:
async_handlers.append(handler)
non_blocking_handlers.append(entry)
for handler in intercept_handlers:
for entry in intercept_handlers:
args = {
"event_type": event_type,
"message": modified_message or message,
"message": modified_message,
**(extra_args or {}),
}
result = await self._invoke_handler(invoke_fn, handler, args, event_type)
result = await self._invoke_handler(supervisor, entry, args, event_type)
if result and not result.continue_processing:
should_continue = False
break
@@ -114,47 +128,57 @@ class EventDispatcher:
modified_message = result.modified_message
if should_continue:
final_message = modified_message or message
for handler in async_handlers:
async_message = final_message.copy() if isinstance(final_message, dict) else final_message
final_message = modified_message
for entry in non_blocking_handlers:
async_message = final_message.copy() if final_message else final_message
args = {
"event_type": event_type,
"message": async_message,
**(extra_args or {}),
}
# 非阻塞:保持实例级强引用,防止 task 被 GC 回收
task = asyncio.create_task(self._invoke_handler(invoke_fn, handler, args, event_type))
task = asyncio.create_task(self._invoke_handler(supervisor, entry, args, event_type))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return should_continue, modified_message
try:
modified_message_obj = (
PluginMessageUtils._build_session_message_from_dict(modified_message) if modified_message else None # type: ignore
)
except Exception as e:
logger.error(f"构建修改后的 SessionMessage 失败: {e}")
modified_message_obj = None
return should_continue, modified_message_obj
async def _invoke_handler(
self,
invoke_fn: InvokeFn,
handler: RegisteredComponent,
supervisor: "PluginRunnerSupervisor",
handler_entry: "EventHandlerEntry",
args: Dict[str, Any],
event_type: str,
) -> Optional[EventResult]:
"""调用单个 handler 并收集结果。"""
try:
resp = await invoke_fn(handler.plugin_id, handler.name, args)
resp_envelope = await supervisor.invoke_plugin(
"plugin.emit_event", handler_entry.plugin_id, handler_entry.name, args
)
resp = resp_envelope.payload
result = EventResult(
handler_name=handler.full_name,
handler_name=handler_entry.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_message=resp.get("modified_message"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
logger.error(f"EventHandler {handler.full_name} 执行失败: {e}", exc_info=True)
result = EventResult(
handler_name=handler.full_name,
success=False,
continue_processing=True,
)
logger.error(f"EventHandler {handler_entry.full_name} 执行失败: {e}", exc_info=True)
result = EventResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
if event_type in self._history_enabled:
self._result_history.setdefault(event_type, []).append(result)
history_list = self._result_history.setdefault(event_type, [])
history_list.append(result)
# 自动清理超出限制的旧记录,防止内存无限增长
if len(history_list) > _MAX_HISTORY_LENGTH:
# 保留最新的 _MAX_HISTORY_LENGTH 条记录
self._result_history[event_type] = history_list[-_MAX_HISTORY_LENGTH:]
return result

View File

@@ -0,0 +1,166 @@
"""
Hook Dispatch 系统
插件可以注册自己的Hook当特定函数被调用时Hook Dispatch系统会将调用转发给插件的Hook处理函数。
每个Hook的参数随Hook点位确定因此参数是易变的。插件开发者需要根据Hook点位的定义来编写Hook处理函数。
在参数/返回值匹配的情况下允许修改参数/返回值。
HookDispatcher 负责:
1. 按 stage 查询已注册的 hook_handler通过 ComponentRegistry
2. 按 priority 排序,区分 blocking 和非 blocking 模式
3. blocking 模式:依次同步调用,支持修改参数/提前终止
4. 非 blocking 模式:异步调用,不阻塞主流程
5. 支持通过 global_config.plugin_runtime.hook_blocking_timeout_sec 设置超时上限
"""
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
from src.common.logger import get_logger
from src.config.config import global_config
if TYPE_CHECKING:
from .supervisor import PluginRunnerSupervisor
from .component_registry import ComponentRegistry, HookHandlerEntry
logger = get_logger("plugin_runtime.host.hook_dispatcher")
@dataclass
class HookResult:
"""单个 HookHandler 的执行结果"""
handler_name: str
success: bool = field(default=True)
continue_processing: bool = field(default=True)
modified_kwargs: Optional[Dict[str, Any]] = field(default=None)
custom_result: Any = field(default=None)
class HookDispatcher:
"""Host-side Hook 分发器
由业务层调用 hook_dispatch()
内部通过 ComponentRegistry 查询 handler
再通过提供的 invoke_fn 回调 RPC 到 Runner 执行。
"""
def __init__(self, component_registry: "ComponentRegistry") -> None:
"""初始化 HookDispatcher
Args:
component_registry: ComponentRegistry 实例,用于查询已注册的 hook_handler
"""
self._component_registry: "ComponentRegistry" = component_registry
self._background_tasks: Set[asyncio.Task] = set()
async def stop(self) -> None:
"""停止 HookDispatcher取消所有未完成的后台任务"""
for task in self._background_tasks:
task.cancel()
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
async def hook_dispatch(
self,
stage: str,
supervisor: "PluginRunnerSupervisor",
**kwargs: Any,
) -> Dict[str, Any]:
"""分发 hook 到所有对应 handler 的便捷方法。
内置了通过 PluginRunnerSupervisor.invoke_plugin 调用 plugin 的逻辑,
无需调用方手动构造 invoke_fn 闭包。
Args:
stage: hook 名称
supervisor: PluginRunnerSupervisor 实例,用于调用 invoke_plugin
**kwargs: 关键字参数,会展开传递给 handler
Returns:
modified_kwargs (Dict[str, Any]): 经过所有 handler 修改后的关键字参数
"""
handler_entries = self._component_registry.get_hook_handlers(stage)
if not handler_entries:
return kwargs
current_kwargs = kwargs.copy()
blocking_handlers: List["HookHandlerEntry"] = []
non_blocking_handlers: List["HookHandlerEntry"] = []
# 分离 blocking 和非 blocking handler
for entry in handler_entries:
if entry.blocking:
blocking_handlers.append(entry)
else:
non_blocking_handlers.append(entry)
# 处理 blocking handlers同步调用支持修改参数/提前终止)
timeout = global_config.plugin_runtime.hook_blocking_timeout_sec or 30.0
for entry in blocking_handlers:
hook_args = {"stage": stage, **current_kwargs}
try:
# 应用超时控制
result = await asyncio.wait_for(
self._invoke_handler(supervisor, entry, hook_args),
timeout=timeout,
)
except asyncio.TimeoutError:
logger.error(f"Blocking HookHandler {entry.full_name} 执行超时 (>{timeout}秒),跳过")
result = HookResult(handler_name=entry.full_name, success=False, continue_processing=True)
if result:
if result.modified_kwargs is not None:
current_kwargs = result.modified_kwargs
if not result.continue_processing:
logger.info(f"HookHandler {entry.full_name} 终止了后续处理")
break
# 处理 non-blocking handlers异步调用不阻塞主流程
for entry in non_blocking_handlers:
async_kwargs = current_kwargs.copy()
hook_args = {"stage": stage, **async_kwargs}
task = asyncio.create_task(
asyncio.wait_for(self._invoke_handler(supervisor, entry, hook_args), timeout=timeout)
)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return current_kwargs
async def _invoke_handler(
self,
supervisor: "PluginRunnerSupervisor",
handler_entry: "HookHandlerEntry",
args: Dict[str, Any],
) -> Optional[HookResult]:
"""调用单个 handler 并收集结果。
Args:
supervisor: PluginRunnerSupervisor 实例
handler_entry: HookHandlerEntry 实例
args: 传递给 handler 的参数字典
stage: hook 名称
Returns:
Optional[HookResult]: 执行结果,如果执行失败则返回 None
"""
try:
resp_envelope = await supervisor.invoke_plugin(
"plugin.invoke_hook", handler_entry.plugin_id, handler_entry.name, args
)
resp = resp_envelope.payload
result = HookResult(
handler_name=handler_entry.full_name,
success=resp.get("success", True),
continue_processing=resp.get("continue_processing", True),
modified_kwargs=resp.get("modified_kwargs"),
custom_result=resp.get("custom_result"),
)
except Exception as e:
logger.error(f"HookHandler {handler_entry.full_name} 执行失败:{e}", exc_info=True)
result = HookResult(handler_name=handler_entry.full_name, success=False, continue_processing=True)
return result

View File

@@ -0,0 +1,45 @@
import logging as stdlib_logging
from src.plugin_runtime.protocol.errors import ErrorCode
from src.plugin_runtime.protocol.envelope import Envelope, LogBatchPayload
class RunnerLogBridge:
"""将 Runner 进程上报的批量日志重放到主进程的 Logger 中。
Runner 通过 ``runner.log_batch`` IPC 事件批量到达。
每条 LogEntry 被重建为一个真实的 :class:`logging.LogRecord` 并直接
调用 ``logging.getLogger(entry.logger_name).handle(record)``
从而接入主进程已配置好的 structlog Handler 链。
"""
async def handle_log_batch(self, envelope: Envelope) -> Envelope:
"""IPC 事件处理器:解析批量日志并重放到主进程 Logger。
Args:
envelope: 方法名为 ``runner.log_batch`` 的 IPC 事件信封。
Returns:
空响应信封(事件模式下将被忽略)。
"""
try:
batch = LogBatchPayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
for entry in batch.entries:
# 重建一个与原始日志尽量相符的 LogRecord
record = stdlib_logging.LogRecord(
name=entry.logger_name,
level=entry.level,
pathname="<runner>",
lineno=0,
msg=entry.message,
args=(),
exc_info=None,
)
record.created = entry.timestamp_ms / 1000.0
record.msecs = entry.timestamp_ms % 1000
if entry.exception_text:
record.exc_text = entry.exception_text
stdlib_logging.getLogger(entry.logger_name).handle(record)
return envelope.make_response(payload={"accepted": True, "count": len(batch.entries)})

View File

@@ -0,0 +1,112 @@
"""Host 侧消息网关包装器。"""
from typing import TYPE_CHECKING, Any, Dict
from src.common.logger import get_logger
from src.platform_io import get_platform_io_manager
from .message_utils import PluginMessageUtils
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
from .component_registry import ComponentRegistry
from .supervisor import PluginRunnerSupervisor
logger = get_logger("plugin_runtime.host.message_gateway")
class MessageGateway:
"""Host 侧消息网关包装器。"""
def __init__(self, component_registry: "ComponentRegistry") -> None:
"""初始化消息网关。
Args:
component_registry: 组件注册表。
"""
self._component_registry = component_registry
def build_session_message(self, external_message: Dict[str, Any]) -> "SessionMessage":
"""将标准消息字典转换为 ``SessionMessage``。
Args:
external_message: 外部消息的字典格式数据。
Returns:
SessionMessage: 转换后的内部消息对象。
Raises:
ValueError: 消息字典不合法时抛出。
"""
return PluginMessageUtils._build_session_message_from_dict(external_message)
def build_message_dict(self, internal_message: "SessionMessage") -> Dict[str, Any]:
"""将 ``SessionMessage`` 转换为标准消息字典。
Args:
internal_message: 内部消息对象。
Returns:
Dict[str, Any]: 供消息网关插件消费的标准消息字典。
"""
return dict(PluginMessageUtils._session_message_to_dict(internal_message))
async def receive_external_message(self, external_message: Dict[str, Any]) -> None:
"""接收外部消息并送入主消息链。
Args:
external_message: 外部消息的字典格式数据。
"""
try:
session_message = self.build_session_message(external_message)
except Exception as e:
logger.error(f"转换外部消息失败: {e}")
return
from src.chat.message_receive.bot import chat_bot
await chat_bot.receive_message(session_message)
async def send_message_to_external(
self,
internal_message: "SessionMessage",
supervisor: "PluginRunnerSupervisor",
*,
enabled_only: bool = True,
save_to_db: bool = True,
) -> bool:
"""将内部消息通过 Platform IO 发送到外部平台。
Args:
internal_message: 系统内部的 ``SessionMessage`` 对象。
supervisor: 当前持有该消息网关的 Supervisor。
enabled_only: 兼容旧签名的保留参数,当前未使用。
save_to_db: 发送成功后是否写入数据库。
Returns:
bool: 是否发送成功。
"""
del enabled_only
del supervisor
platform_io_manager = get_platform_io_manager()
if not platform_io_manager.is_started:
logger.warning("Platform IO 尚未启动,无法通过适配器链路发送消息")
return False
route_key = platform_io_manager.build_route_key_from_message(internal_message)
delivery_batch = await platform_io_manager.send_message(internal_message, route_key)
if not delivery_batch.has_success:
logger.warning("通过消息网关链路发送消息失败: 未命中任何成功回执")
return False
first_successful_receipt = delivery_batch.sent_receipts[0]
internal_message.message_id = first_successful_receipt.external_message_id or internal_message.message_id
if save_to_db:
try:
from src.common.utils.utils_message import MessageUtils
MessageUtils.store_message_to_db(internal_message)
except Exception as e:
logger.error(f"保存消息到数据库失败: {e}")
return True

View File

@@ -0,0 +1,487 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict
import base64
import hashlib
from src.common.logger import get_logger
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import (
AtComponent,
DictComponent,
EmojiComponent,
ForwardComponent,
ForwardNodeComponent,
ImageComponent,
MessageSequence,
ReplyComponent,
StandardMessageComponents,
TextComponent,
VoiceComponent,
)
logger = get_logger("plugin_runtime.host.message_utils")
class UserInfoDict(TypedDict, total=False):
user_id: str
user_nickname: str
user_cardname: Optional[str]
class GroupInfoDict(TypedDict, total=False):
group_id: str
group_name: str
class MessageInfoDict(TypedDict, total=False):
user_info: UserInfoDict
group_info: Optional[GroupInfoDict]
additional_config: Dict[str, Any]
class MessageDict(TypedDict, total=False):
message_id: str
timestamp: str
platform: str
message_info: MessageInfoDict
raw_message: List[Dict[str, Any]]
is_mentioned: bool
is_at: bool
is_emoji: bool
is_picture: bool
is_command: bool
is_notify: bool
session_id: str
reply_to: Optional[str]
processed_plain_text: Optional[str]
display_message: Optional[str]
class PluginMessageUtils:
@staticmethod
def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
"""将消息组件序列转换为插件运行时使用的字典结构。
Args:
message_sequence: 待转换的消息组件序列。
Returns:
List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。
"""
return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components]
@staticmethod
def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]:
"""将单个消息组件转换为插件运行时字典结构。
Args:
component: 待转换的消息组件。
Returns:
Dict[str, Any]: 序列化后的消息组件字典。
"""
if isinstance(component, TextComponent):
return {"type": "text", "data": component.text}
if isinstance(component, ImageComponent):
serialized = {
"type": "image",
"data": component.content,
"hash": component.binary_hash,
}
if component.binary_data:
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
return serialized
if isinstance(component, EmojiComponent):
serialized = {
"type": "emoji",
"data": component.content,
"hash": component.binary_hash,
}
if component.binary_data:
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
return serialized
if isinstance(component, VoiceComponent):
serialized = {
"type": "voice",
"data": component.content,
"hash": component.binary_hash,
}
if component.binary_data:
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
return serialized
if isinstance(component, AtComponent):
return {
"type": "at",
"data": {
"target_user_id": component.target_user_id,
"target_user_nickname": component.target_user_nickname,
"target_user_cardname": component.target_user_cardname,
},
}
if isinstance(component, ReplyComponent):
return {
"type": "reply",
"data": {
"target_message_id": component.target_message_id,
"target_message_content": component.target_message_content,
"target_message_sender_id": component.target_message_sender_id,
"target_message_sender_nickname": component.target_message_sender_nickname,
"target_message_sender_cardname": component.target_message_sender_cardname,
},
}
if isinstance(component, ForwardNodeComponent):
return {
"type": "forward",
"data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components],
}
return {"type": "dict", "data": component.data}
@staticmethod
def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]:
"""将单个转发节点组件转换为字典结构。
Args:
component: 待转换的转发节点组件。
Returns:
Dict[str, Any]: 序列化后的转发节点字典。
"""
return {
"user_id": component.user_id,
"user_nickname": component.user_nickname,
"user_cardname": component.user_cardname,
"message_id": component.message_id,
"content": [PluginMessageUtils._component_to_dict(item) for item in component.content],
}
@staticmethod
def _message_sequence_from_dict(raw_message_data: List[Dict[str, Any]]) -> MessageSequence:
"""从插件运行时字典结构恢复消息组件序列。
Args:
raw_message_data: 插件运行时消息段字典列表。
Returns:
MessageSequence: 恢复后的消息组件序列。
"""
components = [PluginMessageUtils._component_from_dict(item) for item in raw_message_data]
return MessageSequence(components=components)
@staticmethod
def _component_from_dict(item: Dict[str, Any]) -> StandardMessageComponents:
"""从插件运行时字典结构恢复单个消息组件。
Args:
item: 单个消息组件的字典表示。
Returns:
StandardMessageComponents: 恢复后的内部消息组件对象。
"""
item_type = str(item.get("type") or "").strip()
if item_type == "text":
return TextComponent(text=str(item.get("data") or ""))
if item_type == "image":
return PluginMessageUtils._build_binary_component(ImageComponent, item)
if item_type == "emoji":
return PluginMessageUtils._build_binary_component(EmojiComponent, item)
if item_type == "voice":
return PluginMessageUtils._build_binary_component(VoiceComponent, item)
if item_type == "at":
item_data = item.get("data", {})
if not isinstance(item_data, dict):
item_data = {}
return AtComponent(
target_user_id=str(item_data.get("target_user_id") or ""),
target_user_nickname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_nickname")),
target_user_cardname=PluginMessageUtils._normalize_optional_string(item_data.get("target_user_cardname")),
)
if item_type == "reply":
reply_data = item.get("data")
if isinstance(reply_data, dict):
return ReplyComponent(
target_message_id=str(reply_data.get("target_message_id") or ""),
target_message_content=PluginMessageUtils._normalize_optional_string(
reply_data.get("target_message_content")
),
target_message_sender_id=PluginMessageUtils._normalize_optional_string(
reply_data.get("target_message_sender_id")
),
target_message_sender_nickname=PluginMessageUtils._normalize_optional_string(
reply_data.get("target_message_sender_nickname")
),
target_message_sender_cardname=PluginMessageUtils._normalize_optional_string(
reply_data.get("target_message_sender_cardname")
),
)
return ReplyComponent(target_message_id=str(reply_data or ""))
if item_type == "forward":
forward_nodes: List[ForwardComponent] = []
raw_forward_nodes = item.get("data", [])
if isinstance(raw_forward_nodes, list):
for node in raw_forward_nodes:
if not isinstance(node, dict):
continue
raw_content = node.get("content", [])
node_components: List[StandardMessageComponents] = []
if isinstance(raw_content, list):
node_components = [
PluginMessageUtils._component_from_dict(content)
for content in raw_content
if isinstance(content, dict)
]
if not node_components:
node_components = [TextComponent(text="[empty forward node]")]
forward_nodes.append(
ForwardComponent(
user_nickname=str(node.get("user_nickname") or "未知用户"),
user_id=PluginMessageUtils._normalize_optional_string(node.get("user_id")),
user_cardname=PluginMessageUtils._normalize_optional_string(node.get("user_cardname")),
message_id=str(node.get("message_id") or ""),
content=node_components,
)
)
if not forward_nodes:
return DictComponent(data={"type": "forward", "data": item.get("data", [])})
return ForwardNodeComponent(forward_components=forward_nodes)
component_data = item.get("data")
if isinstance(component_data, dict):
return DictComponent(data=component_data)
return DictComponent(data=item)
@staticmethod
def _build_binary_component(component_cls: Any, item: Dict[str, Any]) -> StandardMessageComponents:
"""从字典构造带二进制负载的消息组件。
Args:
component_cls: 目标组件类型。
item: 消息组件字典。
Returns:
StandardMessageComponents: 构造后的组件对象。
"""
content = str(item.get("data") or "")
binary_hash = str(item.get("hash") or "")
raw_binary_base64 = item.get("binary_data_base64")
binary_data = b""
if isinstance(raw_binary_base64, str) and raw_binary_base64:
try:
binary_data = base64.b64decode(raw_binary_base64)
except Exception:
binary_data = b""
if not binary_hash and binary_data:
binary_hash = hashlib.sha256(binary_data).hexdigest()
return component_cls(binary_hash=binary_hash, content=content, binary_data=binary_data)
@staticmethod
def _normalize_optional_string(value: Any) -> Optional[str]:
"""将任意值规范化为可选字符串。
Args:
value: 待规范化的值。
Returns:
Optional[str]: 规范化后的字符串;若值为空则返回 ``None``。
"""
if value is None:
return None
normalized_value = str(value)
return normalized_value if normalized_value else None
@staticmethod
def _message_info_to_dict(message_info: MessageInfo) -> MessageInfoDict:
"""
将 MessageInfo 对象转换为字典格式
Args:
message_info: MessageInfo 对象
Returns:
字典格式的消息信息
"""
user_info_dict = UserInfoDict(
user_id=message_info.user_info.user_id,
user_nickname=message_info.user_info.user_nickname,
user_cardname=message_info.user_info.user_cardname,
)
group_info_dict: Optional[GroupInfoDict] = None
if message_info.group_info:
group_info_dict = GroupInfoDict(
group_id=message_info.group_info.group_id,
group_name=message_info.group_info.group_name,
)
return MessageInfoDict(
user_info=user_info_dict,
group_info=group_info_dict,
additional_config=message_info.additional_config,
)
@staticmethod
def _session_message_to_dict(session_message: SessionMessage) -> MessageDict:
"""
将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
Args:
session_message: SessionMessage 对象
Returns:
字典格式的消息
"""
# 转换基本信息
message_dict = MessageDict(
message_id=session_message.message_id,
timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
platform=session_message.platform,
message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info),
raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message),
is_mentioned=session_message.is_mentioned,
is_at=session_message.is_at,
is_emoji=session_message.is_emoji,
is_picture=session_message.is_picture,
is_command=session_message.is_command,
is_notify=session_message.is_notify,
session_id=session_message.session_id,
)
# 添加可选字段
if session_message.reply_to is not None:
message_dict["reply_to"] = session_message.reply_to
if session_message.processed_plain_text is not None:
message_dict["processed_plain_text"] = session_message.processed_plain_text
if session_message.display_message is not None:
message_dict["display_message"] = session_message.display_message
return message_dict
@staticmethod
def _build_message_info_from_dict(message_info_dict: Dict[str, Any]) -> MessageInfo:
"""
从字典构建 MessageInfo 对象
Args:
message_info_dict: 包含消息信息的字典
Returns:
MessageInfo 对象
"""
# 构建用户信息
user_info_dict = message_info_dict.get("user_info")
if not user_info_dict or not isinstance(user_info_dict, dict):
raise ValueError("消息字典中 'user_info' 字段无效")
user_id = user_info_dict.get("user_id")
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
if not isinstance(user_id, str) or not isinstance(user_nickname, str) or not user_id or not user_nickname:
raise ValueError("消息字典中 'user_info' 字段缺少有效的 'user_id''user_nickname'")
user_cardname = str(user_cardname) if user_cardname is not None else None
user_info = UserInfo(user_id=user_id, user_nickname=user_nickname, user_cardname=user_cardname)
# 构建群信息
if group_info_dict := message_info_dict.get("group_info"):
group_id = group_info_dict.get("group_id")
group_name = group_info_dict.get("group_name")
if not isinstance(group_id, str) or not isinstance(group_name, str) or not group_id or not group_name:
raise ValueError("消息字典中 'group_info' 字段缺少有效的 'group_id''group_name'")
group_info = GroupInfo(group_id=group_id, group_name=group_name)
else:
group_info = None
# 获取额外配置
additional_config: Dict[str, Any] = message_info_dict.get("additional_config", {})
return MessageInfo(user_info=user_info, group_info=group_info, additional_config=additional_config)
@staticmethod
def _build_session_message_from_dict(message_dict: Dict[str, Any]) -> SessionMessage:
"""
从字典构建 SessionMessage 对象(递归处理消息组件)
Args:
message_dict: 包含消息完整信息的字典
Returns:
SessionMessage 对象
"""
# 提取基本信息
message_id = message_dict["message_id"]
timestamp_str: str = message_dict.get("timestamp", "")
platform = message_dict["platform"]
if not isinstance(message_id, str) or not message_id:
raise ValueError("消息字典中缺少有效的 'message_id' 字段")
if not isinstance(platform, str) or not platform:
raise ValueError("消息字典中缺少有效的 'platform' 字段")
# 解析时间戳
try:
timestamp_float = float(timestamp_str)
timestamp = datetime.fromtimestamp(timestamp_float)
except (ValueError, TypeError):
timestamp = datetime.now() # 如果解析失败,使用当前时间
# 创建 SessionMessage 实例
session_message = SessionMessage(message_id=message_id, timestamp=timestamp, platform=platform)
# 构建消息信息
session_message.message_info = PluginMessageUtils._build_message_info_from_dict(message_dict["message_info"])
# 构建原始消息组件序列(复用 MessageSequence.from_dict 方法)
raw_message_data = message_dict["raw_message"]
if isinstance(raw_message_data, list):
session_message.raw_message = PluginMessageUtils._message_sequence_from_dict(raw_message_data)
else:
raise ValueError("消息字典中 'raw_message' 字段必须是一个列表")
# 设置其他可选属性
session_message.is_mentioned = message_dict.get("is_mentioned", False)
if not isinstance(session_message.is_mentioned, bool):
session_message.is_mentioned = False
session_message.is_at = message_dict.get("is_at", False)
if not isinstance(session_message.is_at, bool):
session_message.is_at = False
session_message.is_emoji = message_dict.get("is_emoji", False)
if not isinstance(session_message.is_emoji, bool):
session_message.is_emoji = False
session_message.is_picture = message_dict.get("is_picture", False)
if not isinstance(session_message.is_picture, bool):
session_message.is_picture = False
session_message.is_command = message_dict.get("is_command", False)
if not isinstance(session_message.is_command, bool):
session_message.is_command = False
session_message.is_notify = message_dict.get("is_notify", False)
if not isinstance(session_message.is_notify, bool):
session_message.is_notify = False
session_message.session_id = message_dict.get("session_id", "")
if not isinstance(session_message.session_id, str):
session_message.session_id = ""
session_message.reply_to = message_dict.get("reply_to")
if session_message.reply_to is not None and not isinstance(session_message.reply_to, str):
session_message.reply_to = None
session_message.processed_plain_text = message_dict.get("processed_plain_text")
if session_message.processed_plain_text is not None and not isinstance(
session_message.processed_plain_text, str
):
session_message.processed_plain_text = None
session_message.display_message = message_dict.get("display_message")
if session_message.display_message is not None and not isinstance(session_message.display_message, str):
session_message.display_message = None
return session_message

View File

@@ -1,97 +0,0 @@
"""策略引擎
负责能力授权校验。
每个插件在 manifest 中声明能力需求Host 启动时签发能力令牌。
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
@dataclass
class CapabilityToken:
"""能力令牌"""
plugin_id: str
generation: int
capabilities: Set[str] = field(default_factory=set)
class PolicyEngine:
"""策略引擎
管理所有插件的能力令牌,提供授权校验。
"""
def __init__(self) -> None:
self._tokens: Dict[str, Dict[int, CapabilityToken]] = {}
def register_plugin(
self,
plugin_id: str,
generation: int,
capabilities: List[str],
) -> CapabilityToken:
"""为插件签发能力令牌"""
token = CapabilityToken(
plugin_id=plugin_id,
generation=generation,
capabilities=set(capabilities),
)
self._tokens.setdefault(plugin_id, {})[generation] = token
return token
def revoke_plugin(self, plugin_id: str, generation: Optional[int] = None) -> None:
"""撤销插件的能力令牌。"""
if generation is None:
self._tokens.pop(plugin_id, None)
return
generations = self._tokens.get(plugin_id)
if generations is None:
return
generations.pop(generation, None)
if not generations:
self._tokens.pop(plugin_id, None)
def clear(self) -> None:
"""清空所有能力令牌。"""
self._tokens.clear()
def check_capability(self, plugin_id: str, capability: str, generation: Optional[int] = None) -> Tuple[bool, str]:
"""检查插件是否有权调用某项能力
Returns:
(allowed, reason)
"""
generations = self._tokens.get(plugin_id)
if not generations:
return False, f"插件 {plugin_id} 未注册能力令牌"
if generation is None:
token = generations[max(generations)]
else:
token = generations.get(generation)
if token is None:
active_generation = max(generations)
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {active_generation}"
if capability not in token.capabilities:
return False, f"插件 {plugin_id} 未获授权能力: {capability}"
if generation is not None and token.generation != generation:
return False, f"插件 {plugin_id} generation 不匹配: {generation} != {token.generation}"
return True, ""
def get_token(self, plugin_id: str) -> Optional[CapabilityToken]:
"""获取插件的能力令牌"""
generations = self._tokens.get(plugin_id)
if not generations:
return None
return generations[max(generations)]
def list_plugins(self) -> List[str]:
"""列出所有已注册的插件"""
return list(self._tokens.keys())

View File

@@ -7,7 +7,7 @@
4. 请求-响应关联与超时管理
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Coroutine
import asyncio
import contextlib
@@ -32,7 +32,7 @@ from src.plugin_runtime.transport.base import Connection, TransportServer
logger = get_logger("plugin_runtime.host.rpc_server")
# RPC 方法处理器类型
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
MethodHandler = Callable[[Envelope], Coroutine[Any, Any, Envelope]]
class RPCServer:
@@ -55,108 +55,39 @@ class RPCServer:
self._id_gen = RequestIdGenerator()
self._connection: Optional[Connection] = None # 当前活跃的 Runner 连接
self._runner_id: Optional[str] = None
self._runner_generation: int = 0
self._staged_connection: Optional[Connection] = None
self._staged_runner_id: Optional[str] = None
self._staged_runner_generation: int = 0
self._staging_takeover: bool = False
# 方法处理器注册表
self._method_handlers: Dict[str, MethodHandler] = {}
# 等待响应的 pending 请求: request_id -> (Future, target_generation)
self._pending_requests: Dict[int, Tuple[asyncio.Future, int]] = {}
# 等待响应的 pending 请求: request_id -> Future
self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
# 发送队列(背压控制)
self._send_queue: Optional[asyncio.Queue[Tuple[Connection, bytes, asyncio.Future[None]]]] = None
self._send_worker_task: Optional[asyncio.Task] = None
self._send_worker_task: Optional[asyncio.Task[None]] = None
# 运行状态
self._running: bool = False
self._tasks: List[asyncio.Task] = []
self._tasks: List[asyncio.Task[None]] = []
self._last_handshake_rejection_reason: str = ""
self._connection_lock: asyncio.Lock = asyncio.Lock()
@property
def session_token(self) -> str:
return self._session_token
def reset_session_token(self) -> str:
"""重新生成会话令牌(热重载时调用,防止旧 Runner 重连)"""
self._session_token = secrets.token_hex(32)
return self._session_token
def restore_session_token(self, token: str) -> None:
"""恢复指定的会话令牌(热重载回滚时调用)"""
self._session_token = token
@property
def runner_generation(self) -> int:
return self._runner_generation
@property
def staged_generation(self) -> int:
return self._staged_runner_generation
@property
def is_connected(self) -> bool:
return self._connection is not None and not self._connection.is_closed
def has_generation(self, generation: int) -> bool:
return generation == self._runner_generation or (
self._staged_connection is not None
and not self._staged_connection.is_closed
and generation == self._staged_runner_generation
)
@property
def last_handshake_rejection_reason(self) -> str:
"""返回最近一次握手被拒绝的原因。"""
return self._last_handshake_rejection_reason
def begin_staged_takeover(self) -> None:
"""允许新 Runner 以 staged 方式接入,待 Supervisor 验证后再切换为活跃连接"""
self._staging_takeover = True
async def commit_staged_takeover(self) -> None:
"""提交 staged Runner原活跃连接在提交后被关闭。"""
if self._staged_connection is None or self._staged_connection.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "没有可提交的新 Runner 连接")
old_connection = self._connection
old_generation = self._runner_generation
self._connection = self._staged_connection
self._runner_id = self._staged_runner_id
self._runner_generation = self._staged_runner_generation
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._staging_takeover = False
if stale_count := self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已被新 generation 接管",
generation=old_generation,
):
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
if old_connection and old_connection is not self._connection and not old_connection.is_closed:
await old_connection.close()
async def rollback_staged_takeover(self) -> None:
"""放弃 staged Runner保留当前活跃连接。"""
staged_connection = self._staged_connection
staged_generation = self._staged_runner_generation
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._staging_takeover = False
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"新 Runner 预热失败,已回滚",
generation=staged_generation,
)
if staged_connection and not staged_connection.is_closed:
await staged_connection.close()
def clear_handshake_state(self) -> None:
"""清空最近一次握手拒绝状态"""
self._last_handshake_rejection_reason = ""
def register_method(self, method: str, handler: MethodHandler) -> None:
"""注册 RPC 方法处理器"""
@@ -165,6 +96,7 @@ class RPCServer:
async def start(self) -> None:
"""启动 RPC 服务器"""
self._running = True
self.clear_handshake_state()
self._send_queue = asyncio.Queue(maxsize=self._send_queue_size)
self._send_worker_task = asyncio.create_task(self._send_loop())
await self._transport.start(self._handle_connection)
@@ -173,14 +105,9 @@ class RPCServer:
async def stop(self) -> None:
"""停止 RPC 服务器"""
self._running = False
# 取消所有 pending 请求
for future, _generation in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
self._pending_requests.clear()
self._fail_queued_sends(ErrorCode.E_TIMEOUT, "服务器关闭")
self.clear_handshake_state()
self._fail_pending_requests(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
self._fail_queued_sends(ErrorCode.E_SHUTTING_DOWN, "服务器正在关闭")
if self._send_worker_task:
self._send_worker_task.cancel()
@@ -198,10 +125,6 @@ class RPCServer:
await self._connection.close()
self._connection = None
if self._staged_connection:
await self._staged_connection.close()
self._staged_connection = None
await self._transport.stop()
logger.info("RPC Server 已停止")
@@ -211,7 +134,6 @@ class RPCServer:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
target_generation: Optional[int] = None,
) -> Envelope:
"""向 Runner 发送 RPC 请求并等待响应
@@ -227,18 +149,14 @@ class RPCServer:
Raises:
RPCError: 调用失败
"""
generation = target_generation or self._runner_generation
conn = self._get_connection_for_generation(generation)
if conn is None or conn.is_closed:
if not self._connection or self._connection.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
request_id = self._id_gen.next()
request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
generation=generation,
timeout_ms=timeout_ms,
payload=payload or {},
)
@@ -246,12 +164,12 @@ class RPCServer:
# 注册 pending future
loop = asyncio.get_running_loop()
future: asyncio.Future[Envelope] = loop.create_future()
self._pending_requests[request_id] = (future, generation)
self._pending_requests[request_id] = future
try:
# 发送请求
data = self._codec.encode_envelope(envelope)
await self._enqueue_send(conn, data)
await self._enqueue_send(self._connection, data)
# 等待响应
timeout_sec = timeout_ms / 1000.0
@@ -265,150 +183,136 @@ class RPCServer:
raise
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
async def send_event(self, method: str, plugin_id: str = "", payload: Optional[Dict[str, Any]] = None) -> None:
"""向 Runner 发送单向事件(不等待响应)"""
conn = self._connection
if conn is None or conn.is_closed:
return
# ============ 内部方法 ============
# ========= 发送循环 =========
async def _send_loop(self) -> None:
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
if self._send_queue is None:
raise RuntimeError("没有消息队列")
request_id = self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.EVENT,
method=method,
plugin_id=plugin_id,
generation=self._runner_generation,
payload=payload or {},
)
data = self._codec.encode_envelope(envelope)
await self._enqueue_send(conn, data)
while True:
try:
conn, data, send_future = await self._send_queue.get()
except asyncio.CancelledError:
break
# ─── 内部方法 ──────────────────────────────────────────────
try:
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
await conn.send_frame(data)
if not send_future.done():
send_future.set_result(None)
except asyncio.CancelledError:
if not send_future.done():
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
raise
except Exception as e:
send_error = RPCError.from_exception(e, {ConnectionError: ErrorCode.E_PLUGIN_CRASHED})
if not send_future.done():
send_future.set_exception(send_error)
finally:
self._send_queue.task_done()
# ====== 发送循环方法 ======
async def _handle_connection(self, conn: Connection) -> None:
"""处理新的 Runner 连接"""
logger.info("收到 Runner 连接")
previous_connection = self._connection
previous_generation = self._runner_generation
# 第一条消息必须是 runner.hello 握手
try:
role = await self._handle_handshake(conn)
if role is None:
await conn.close()
return
async with self._connection_lock:
self.clear_handshake_state()
success = await self._handle_handshake(conn)
if not success:
await conn.close()
return
logger.info("Runner staged 握手成功")
self._connection = conn
except Exception as e:
logger.error(f"握手失败: {e}")
await conn.close()
return
if role == "staged":
expected_generation = self._staged_runner_generation
logger.info(
f"Runner staged 握手成功: runner_id={self._staged_runner_id}, generation={self._staged_runner_generation}"
)
else:
self._connection = conn
expected_generation = self._runner_generation
logger.info(f"Runner 握手成功: runner_id={self._runner_id}, generation={self._runner_generation}")
if previous_connection and previous_connection is not conn and not previous_connection.is_closed:
logger.info("检测到新 Runner 已接管连接,关闭旧连接")
if stale_count := self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已被新 generation 接管",
generation=previous_generation,
):
logger.info(f"已清理 {stale_count} 个旧 Runner 的 pending 请求")
await previous_connection.close()
# 启动消息接收循环
try:
await self._recv_loop(conn, expected_generation=expected_generation)
await self._recv_loop(conn)
except Exception as e:
logger.error(f"连接异常断开: {e}")
finally:
if self._connection is conn:
self._connection = None
self._runner_id = None
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Runner 连接已断开",
generation=expected_generation,
)
elif self._staged_connection is conn:
self._staged_connection = None
self._staged_runner_id = None
self._staged_runner_generation = 0
self._fail_pending_requests(
ErrorCode.E_PLUGIN_CRASHED,
"Staged Runner 连接已断开",
generation=expected_generation,
)
should_fail_pending_requests = False
async with self._connection_lock:
if self._connection is conn:
self._connection = None
should_fail_pending_requests = True
if should_fail_pending_requests:
self._fail_pending_requests(ErrorCode.E_PLUGIN_CRASHED, "Runner 连接已断开")
async def _handle_handshake(self, conn: Connection) -> Optional[str]:
async def _handle_handshake(self, conn: Connection) -> bool:
"""处理 runner.hello 握手"""
# 接收握手请求
data = await asyncio.wait_for(conn.recv_frame(), timeout=10.0)
envelope = self._codec.decode_envelope(data)
if envelope.method != "runner.hello":
logger.error(f"期望 runner.hello收到 {envelope.method}")
self._last_handshake_rejection_reason = "首条消息必须为 runner.hello"
error_resp = envelope.make_error_response(
ErrorCode.E_PROTOCOL_MISMATCH.value,
"首条消息必须为 runner.hello",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
return None
return False
# 解析握手 payload
hello = HelloPayload.model_validate(envelope.payload)
# 校验会话令牌
if hello.session_token != self._session_token:
logger.error("会话令牌不匹配")
resp_payload = HelloResponsePayload(
accepted=False,
reason="会话令牌无效",
)
self._last_handshake_rejection_reason = "会话令牌无效"
resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return None
return False
# 若已有活跃连接,直接拒绝新的握手,避免后来的连接抢占当前通道。
if self.is_connected:
logger.warning("拒绝新的 Runner 连接:已有活跃连接")
self._last_handshake_rejection_reason = "已有活跃 Runner 连接,拒绝新的握手"
resp_payload = HelloResponsePayload(accepted=False, reason=self._last_handshake_rejection_reason)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return False
# 校验 SDK 版本
if not self._check_sdk_version(hello.sdk_version):
logger.error(f"SDK 版本不兼容: {hello.sdk_version}")
self._last_handshake_rejection_reason = (
f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]"
)
resp_payload = HelloResponsePayload(
accepted=False,
reason=f"SDK 版本 {hello.sdk_version} 不在支持范围 [{MIN_SDK_VERSION}, {MAX_SDK_VERSION}]",
reason=self._last_handshake_rejection_reason,
)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return None
return False
# 握手成功
role = "active"
assigned_generation = self._runner_generation + 1
if self._staging_takeover and self.is_connected:
role = "staged"
self._staged_connection = conn
self._staged_runner_id = hello.runner_id
self._staged_runner_generation = assigned_generation
else:
self._runner_id = hello.runner_id
self._runner_generation = assigned_generation
resp_payload = HelloResponsePayload(
accepted=True,
host_version=PROTOCOL_VERSION,
assigned_generation=assigned_generation,
)
# 发送响应
self.clear_handshake_state()
resp_payload = HelloResponsePayload(accepted=True, host_version=PROTOCOL_VERSION)
resp = envelope.make_response(payload=resp_payload.model_dump())
await conn.send_frame(self._codec.encode_envelope(resp))
return True
return role
def _check_sdk_version(self, sdk_version: str) -> bool:
"""检查 SDK 版本是否在支持范围内"""
try:
sdk_parts = _parse_version_tuple(sdk_version)
min_parts = _parse_version_tuple(MIN_SDK_VERSION)
max_parts = _parse_version_tuple(MAX_SDK_VERSION)
return min_parts <= sdk_parts <= max_parts
except (ValueError, AttributeError):
return False
async def _recv_loop(self, conn: Connection, expected_generation: int) -> None:
# ========= 接收循环 =========
async def _recv_loop(self, conn: Connection) -> None:
"""消息接收主循环"""
while self._running and not conn.is_closed:
try:
@@ -430,109 +334,40 @@ class RPCServer:
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
if envelope.generation != expected_generation:
error_resp = envelope.make_error_response(
ErrorCode.E_GENERATION_MISMATCH.value,
f"过期 generation: {envelope.generation} != {expected_generation}",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
continue
# 异步处理请求Runner 发来的能力调用)
task = asyncio.create_task(self._handle_request(envelope, conn))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
elif envelope.is_event():
if envelope.generation != expected_generation:
logger.warning(
f"忽略过期 generation 事件 {envelope.method}: {envelope.generation} != {expected_generation}"
)
continue
task = asyncio.create_task(self._handle_event(envelope))
elif envelope.is_broadcast():
task = asyncio.create_task(self._handle_broadcast(envelope))
self._tasks.append(task)
task.add_done_callback(lambda t: self._tasks.remove(t) if t in self._tasks else None)
else:
logger.warning(f"未知的消息类型: {envelope.message_type}")
continue
# ====== 接收循环内部方法 ======
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Runner 的响应"""
pending = self._pending_requests.get(envelope.request_id)
if pending is None:
pending_future = self._pending_requests.pop(envelope.request_id, None)
if pending_future is None:
return
future, expected_generation = pending
if envelope.generation != expected_generation:
logger.warning(
f"忽略过期 generation 响应 {envelope.method}: {envelope.generation} != {expected_generation}"
)
return
self._pending_requests.pop(envelope.request_id, None)
if not future.done():
if not pending_future.done():
if envelope.error:
future.set_exception(RPCError.from_dict(envelope.error))
pending_future.set_exception(RPCError.from_dict(envelope.error))
else:
future.set_result(envelope)
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
"""通过发送队列串行发送消息,提供真实背压。"""
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
if self._send_queue is None:
await conn.send_frame(data)
return
loop = asyncio.get_running_loop()
send_future: asyncio.Future[None] = loop.create_future()
try:
self._send_queue.put_nowait((conn, data, send_future))
except asyncio.QueueFull:
raise RPCError(ErrorCode.E_BACKPRESSURE, "发送队列已满") from None
await send_future
async def _send_loop(self) -> None:
"""后台发送循环:串行消费发送队列,统一执行连接写入。"""
if self._send_queue is None:
return
while True:
try:
conn, data, send_future = await self._send_queue.get()
except asyncio.CancelledError:
break
try:
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
await conn.send_frame(data)
if not send_future.done():
send_future.set_result(None)
except asyncio.CancelledError:
if not send_future.done():
send_future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "服务器关闭"))
raise
except Exception as e:
send_error = e if isinstance(e, RPCError) else self._normalize_send_exception(e)
if not send_future.done():
send_future.set_exception(send_error)
finally:
self._send_queue.task_done()
@staticmethod
def _normalize_send_exception(error: Exception) -> RPCError:
if isinstance(error, ConnectionError):
return RPCError(ErrorCode.E_PLUGIN_CRASHED, str(error))
return RPCError(ErrorCode.E_UNKNOWN, str(error))
pending_future.set_result(envelope)
async def _handle_request(self, envelope: Envelope, conn: Connection) -> None:
"""处理来自 Runner 的请求(通常是能力调用 cap.*"""
handler = self._method_handlers.get(envelope.method)
if handler is None:
error_resp = envelope.make_error_response(
target_method = envelope.method
handler = self._method_handlers.get(target_method)
if not handler:
error_response = envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"未注册的方法: {envelope.method}",
)
await conn.send_frame(self._codec.encode_envelope(error_resp))
await conn.send_frame(self._codec.encode_envelope(error_response))
return
try:
@@ -546,59 +381,25 @@ class RPCServer:
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
await conn.send_frame(self._codec.encode_envelope(error_resp))
async def _handle_event(self, envelope: Envelope) -> None:
"""处理来自 Runner 的事件"""
async def _handle_broadcast(self, envelope: Envelope) -> None:
if handler := self._method_handlers.get(envelope.method):
try:
result = await handler(envelope)
# 检查 handler 返回的信封是否包含错误信息
if result is not None and isinstance(result, Envelope) and result.error:
if result.error:
logger.warning(f"事件 {envelope.method} handler 返回错误: {result.error.get('message', '')}")
except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
@staticmethod
def _check_sdk_version(sdk_version: str) -> bool:
"""检查 SDK 版本是否在支持范围内"""
try:
sdk_parts = RPCServer._parse_version_tuple(sdk_version)
min_parts = RPCServer._parse_version_tuple(MIN_SDK_VERSION)
max_parts = RPCServer._parse_version_tuple(MAX_SDK_VERSION)
return min_parts <= sdk_parts <= max_parts
except (ValueError, AttributeError):
return False
@staticmethod
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
base_version = base_version.split("+", 1)[0]
parts = [part for part in base_version.split(".") if part != ""]
while len(parts) < 3:
parts.append("0")
return (int(parts[0]), int(parts[1]), int(parts[2]))
def _get_connection_for_generation(self, generation: int) -> Optional[Connection]:
if generation == self._runner_generation:
return self._connection
if generation == self._staged_runner_generation:
return self._staged_connection
return None
def _fail_pending_requests(
self,
error_code: ErrorCode,
message: str,
generation: Optional[int] = None,
) -> int:
stale_count = 0
for request_id, (future, request_generation) in list(self._pending_requests.items()):
if generation is not None and request_generation != generation:
continue
def _fail_pending_requests(self, error_code: ErrorCode, message: str) -> int:
"""失败所有等待中的请求(如连接断开时)"""
aborted_request_count = 0
for future in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(error_code, message))
stale_count += 1
self._pending_requests.pop(request_id, None)
return stale_count
aborted_request_count += 1
self._pending_requests.clear()
return aborted_request_count
def _fail_queued_sends(self, error_code: ErrorCode, message: str) -> int:
if self._send_queue is None:
@@ -617,3 +418,31 @@ class RPCServer:
self._send_queue.task_done()
return failed_count
async def _enqueue_send(self, conn: Connection, data: bytes) -> None:
"""通过发送队列串行发送消息,提供真实背压。"""
if conn.is_closed:
raise RPCError(ErrorCode.E_PLUGIN_CRASHED, "Runner 未连接")
if self._send_queue is None:
await conn.send_frame(data)
return
loop = asyncio.get_running_loop()
send_future: asyncio.Future[None] = loop.create_future()
try:
self._send_queue.put_nowait((conn, data, send_future))
except asyncio.QueueFull:
raise RPCError(ErrorCode.E_BACK_PRESSURE, "发送队列已满") from None
await send_future
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version or "", flags=re.IGNORECASE)[0]
base_version = base_version.split("+", 1)[0]
parts = [part for part in base_version.split(".") if part != ""]
while len(parts) < 3:
parts.append("0")
return (int(parts[0]), int(parts[1]), int(parts[2]))

File diff suppressed because it is too large Load Diff

View File

@@ -1,422 +0,0 @@
"""Host-side WorkflowExecutor
6 阶段线性流转INGRESS → PRE_PROCESS → PLAN → TOOL_EXECUTE → POST_PROCESS → EGRESS
每个阶段执行顺序:
1. Host-side pre-filter: 根据 hook filter 条件过滤不相关的 hook
2. 按 priority 降序排列
3. 串行执行 blocking hook可修改 message返回 HookResult
4. 并发执行 non-blocking hook只读
5. 检查是否有 SKIP_STAGE 或 ABORT
6. PLAN 阶段内置 Command 匹配路由
支持:
- HookResult: CONTINUE / SKIP_STAGE / ABORT
- ErrorPolicy: ABORT / SKIP / LOG (per-hook)
- stage_outputs: 阶段间带命名空间的数据传递
- modification_log: 消息修改审计
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple
import asyncio
import time
import uuid
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_runtime.host.component_registry import ComponentRegistry, RegisteredComponent
logger = get_logger("plugin_runtime.host.workflow_executor")
# 阶段顺序
STAGE_SEQUENCE: List[str] = [
"ingress",
"pre_process",
"plan",
"tool_execute",
"post_process",
"egress",
]
# HookResult 常量(与 SDK HookResult enum 值对应)
HOOK_CONTINUE = "continue"
HOOK_SKIP_STAGE = "skip_stage"
HOOK_ABORT = "abort"
# blocking hook 全局最大超时(秒):即使 hook 声明 timeout_ms=0 也不会无限等待
# 从配置文件读取,允许用户调整
def _get_blocking_timeout() -> float:
return global_config.plugin_runtime.workflow_blocking_timeout_sec
class ModificationRecord:
"""消息修改记录"""
__slots__ = ("stage", "hook_name", "timestamp", "fields_changed")
def __init__(self, stage: str, hook_name: str, fields_changed: List[str]) -> None:
self.stage = stage
self.hook_name = hook_name
self.timestamp = time.perf_counter()
self.fields_changed = fields_changed
class WorkflowContext:
"""Workflow 执行上下文"""
def __init__(self, trace_id: Optional[str] = None, stream_id: Optional[str] = None) -> None:
self.trace_id = trace_id or uuid.uuid4().hex
self.stream_id = stream_id
self.timings: Dict[str, float] = {}
self.errors: List[str] = []
# 阶段间数据传递(按 stage 命名空间隔离)
self.stage_outputs: Dict[str, Dict[str, Any]] = {}
# 消息修改审计日志
self.modification_log: List[ModificationRecord] = []
# PLAN 阶段命令匹配结果
self.matched_command: Optional[str] = None
def set_stage_output(self, stage: str, key: str, value: Any) -> None:
self.stage_outputs.setdefault(stage, {})[key] = value
def get_stage_output(self, stage: str, key: str, default: Any = None) -> Any:
return self.stage_outputs.get(stage, {}).get(key, default)
class WorkflowResult:
"""Workflow 执行结果"""
def __init__(
self,
status: str = "completed", # completed / aborted / failed
return_message: str = "",
stopped_at: str = "",
diagnostics: Optional[Dict[str, Any]] = None,
) -> None:
self.status = status
self.return_message = return_message
self.stopped_at = stopped_at
self.diagnostics = diagnostics or {}
# invoke_fn 签名
InvokeFn = Callable[[str, str, Dict[str, Any]], Awaitable[Dict[str, Any]]]
class WorkflowExecutor:
"""Host-side Workflow 执行器
实现 stage-based pipeline + per-stage hook chain with priority + early return。
"""
def __init__(self, registry: ComponentRegistry) -> None:
self._registry = registry
self._background_tasks: Set[asyncio.Task] = set()
async def execute(
self,
invoke_fn: InvokeFn,
message: Optional[Dict[str, Any]] = None,
stream_id: Optional[str] = None,
context: Optional[WorkflowContext] = None,
command_invoke_fn: Optional[InvokeFn] = None,
) -> Tuple[WorkflowResult, Optional[Dict[str, Any]], WorkflowContext]:
"""执行 workflow pipeline。
Args:
invoke_fn: 用于 workflow_step 的回调
command_invoke_fn: 用于 command 的回调(走 plugin.invoke_command
未传则复用 invoke_fn
Returns:
(result, final_message, context)
"""
ctx = context or WorkflowContext(stream_id=stream_id)
current_message = dict(message) if message else None
for stage in STAGE_SEQUENCE:
stage_start = time.perf_counter()
try:
# PLAN 阶段: 先做 Command 路由
if stage == "plan" and current_message:
cmd_result = await self._route_command(command_invoke_fn or invoke_fn, current_message, ctx)
if cmd_result is not None:
# 命令匹配成功,跳过 PLAN 阶段的 hook直接存结果进 stage_outputs
ctx.set_stage_output("plan", "command_result", cmd_result)
ctx.timings[stage] = time.perf_counter() - stage_start
continue
# 获取该阶段所有 hook已按 priority 降序排列)
all_steps = self._registry.get_workflow_steps(stage)
if not all_steps:
ctx.timings[stage] = time.perf_counter() - stage_start
continue
# 1. Pre-filter
filtered_steps = self._pre_filter(all_steps, current_message)
# 2. 分离 blocking 和 non-blocking
blocking_steps = [s for s in filtered_steps if s.metadata.get("blocking", True)]
nonblocking_steps = [s for s in filtered_steps if not s.metadata.get("blocking", True)]
# 3. 串行执行 blocking hook
skip_stage = False
for step in blocking_steps:
hook_result, modified, step_error = await self._invoke_step(
invoke_fn, step, stage, ctx, current_message
)
if step_error:
error_policy = step.metadata.get("error_policy", "abort")
ctx.errors.append(f"{step.full_name}: {step_error}")
if error_policy == "abort":
ctx.timings[stage] = time.perf_counter() - stage_start
return (
WorkflowResult(
status="failed",
return_message=step_error,
stopped_at=stage,
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
),
current_message,
ctx,
)
elif error_policy == "skip":
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(skip): {step_error}")
continue
else: # log
logger.warning(f"[{ctx.trace_id}] hook {step.full_name} 异常(log): {step_error}")
continue
# 更新消息(仅 blocking hook 有权修改)
if modified:
changed_fields = (
_diff_keys(current_message, modified) if current_message else list(modified.keys())
)
ctx.modification_log.append(ModificationRecord(stage, step.full_name, changed_fields))
current_message = modified
if hook_result == HOOK_ABORT:
ctx.timings[stage] = time.perf_counter() - stage_start
return (
WorkflowResult(
status="aborted",
return_message=f"aborted by {step.full_name}",
stopped_at=stage,
diagnostics={"step": step.full_name, "trace_id": ctx.trace_id},
),
current_message,
ctx,
)
if hook_result == HOOK_SKIP_STAGE:
skip_stage = True
break
# 4. 并发执行 non-blocking hook只读忽略返回值中的 modified_message
if nonblocking_steps and not skip_stage:
for step in nonblocking_steps:
self._track_background_task(
asyncio.create_task(
self._invoke_step_fire_and_forget(invoke_fn, step, stage, ctx, current_message)
)
)
ctx.timings[stage] = time.perf_counter() - stage_start
except Exception as e:
ctx.timings[stage] = time.perf_counter() - stage_start
ctx.errors.append(f"{stage}: {e}")
logger.error(f"[{ctx.trace_id}] 阶段 {stage} 未捕获异常: {e}", exc_info=True)
return (
WorkflowResult(
status="failed",
return_message=str(e),
stopped_at=stage,
diagnostics={"trace_id": ctx.trace_id},
),
current_message,
ctx,
)
return (
WorkflowResult(
status="completed",
return_message="workflow completed",
diagnostics={"trace_id": ctx.trace_id},
),
current_message,
ctx,
)
def _track_background_task(self, task: asyncio.Task) -> None:
"""保持 non-blocking workflow task 的强引用,直到任务结束。"""
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
# ─── 内部方法 ──────────────────────────────────────────────
def _pre_filter(
self,
steps: List[RegisteredComponent],
message: Optional[Dict[str, Any]],
) -> List[RegisteredComponent]:
"""根据 hook 声明的 filter 条件预过滤,避免无意义的 IPC 调用。"""
if not message:
return steps
result = []
for step in steps:
filter_cond = step.metadata.get("filter", {})
if not filter_cond:
result.append(step)
continue
if self._match_filter(filter_cond, message):
result.append(step)
return result
@staticmethod
def _match_filter(filter_cond: Dict[str, Any], message: Dict[str, Any]) -> bool:
"""简单 key-value 匹配过滤。
filter 中的每个 key 必须在 message 中存在且值相等,
全部匹配才通过。
"""
for key, expected in filter_cond.items():
actual = message.get(key)
if (isinstance(expected, list) and actual not in expected) or (
not isinstance(expected, list) and actual != expected
):
return False
return True
async def _invoke_step(
self,
invoke_fn: InvokeFn,
step: RegisteredComponent,
stage: str,
ctx: WorkflowContext,
message: Optional[Dict[str, Any]],
) -> Tuple[str, Optional[Dict[str, Any]], Optional[str]]:
"""调用单个 blocking hook。
Returns:
(hook_result, modified_message, error_string_or_None)
"""
timeout_ms = step.metadata.get("timeout_ms", 0)
# 使用 hook 声明的超时,但不超过全局安全阀
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
step_key = f"{stage}:{step.full_name}"
step_start = time.perf_counter()
try:
coro = invoke_fn(
step.plugin_id,
step.name,
{
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
},
)
resp = await asyncio.wait_for(coro, timeout=timeout_sec)
ctx.timings[step_key] = time.perf_counter() - step_start
hook_result = resp.get("hook_result", HOOK_CONTINUE)
modified_message = resp.get("modified_message")
# 存 stage output如果 hook 提供了)
stage_out = resp.get("stage_output")
if isinstance(stage_out, dict):
for k, v in stage_out.items():
ctx.set_stage_output(stage, k, v)
return hook_result, modified_message, None
except asyncio.TimeoutError:
ctx.timings[step_key] = time.perf_counter() - step_start
return HOOK_CONTINUE, None, f"timeout after {timeout_ms}ms"
except Exception as e:
ctx.timings[step_key] = time.perf_counter() - step_start
return HOOK_CONTINUE, None, str(e)
async def _invoke_step_fire_and_forget(
self,
invoke_fn: InvokeFn,
step: RegisteredComponent,
stage: str,
ctx: WorkflowContext,
message: Optional[Dict[str, Any]],
) -> None:
"""Non-blocking hook 调用,只读,忽略结果。"""
timeout_ms = step.metadata.get("timeout_ms", 0)
# 使用 hook 声明的超时,但无声明时回退到全局安全阀,防止 task 泄漏
timeout_sec = timeout_ms / 1000 if timeout_ms > 0 else _get_blocking_timeout()
try:
coro = invoke_fn(
step.plugin_id,
step.name,
{
"stage": stage,
"trace_id": ctx.trace_id,
"message": message,
"stage_outputs": ctx.stage_outputs,
},
)
await asyncio.wait_for(coro, timeout=timeout_sec)
except asyncio.TimeoutError:
logger.warning(f"[{ctx.trace_id}] non-blocking hook {step.full_name} 超时 ({timeout_sec}s)")
except Exception as e:
logger.debug(f"[{ctx.trace_id}] non-blocking hook {step.full_name}: {e}")
async def _route_command(
self,
invoke_fn: InvokeFn,
message: Dict[str, Any],
ctx: WorkflowContext,
) -> Optional[Dict[str, Any]]:
"""PLAN 阶段内置 Command 路由。
在 registry 中查找匹配的 command 组件,
匹配到则直接路由到对应 command handler返回执行结果。
不匹配则返回 None让 PLAN 阶段的 hook 继续执行。
"""
plain_text = message.get("plain_text", "")
if not plain_text:
return None
match_result = self._registry.find_command_by_text(plain_text)
if match_result is None:
return None
matched, matched_groups = match_result
ctx.matched_command = matched.full_name
logger.info(f"[{ctx.trace_id}] 命令匹配: {matched.full_name}")
try:
return await invoke_fn(
matched.plugin_id,
matched.name,
{
"text": plain_text,
"message": message,
"trace_id": ctx.trace_id,
"matched_groups": matched_groups,
},
)
except Exception as e:
logger.error(f"[{ctx.trace_id}] 命令 {matched.full_name} 执行失败: {e}", exc_info=True)
ctx.errors.append(f"command:{matched.full_name}: {e}")
return None
def _diff_keys(old: Dict[str, Any], new: Dict[str, Any]) -> List[str]:
"""返回 new 中与 old 不同的 key 列表。"""
return [k for k, v in new.items() if k not in old or old[k] != v]

View File

@@ -8,23 +8,27 @@
"""
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence, Set, Tuple
import asyncio
import json
import tomlkit
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import config_manager
from src.config.file_watcher import FileChange, FileWatcher
from src.platform_io import DeliveryBatch, InboundMessageEnvelope, get_platform_io_manager
from src.plugin_runtime.capabilities import (
RuntimeComponentCapabilityMixin,
RuntimeCoreCapabilityMixin,
RuntimeDataCapabilityMixin,
)
from src.plugin_runtime.capabilities.registry import register_capability_impls
from src.plugin_runtime.host.message_utils import MessageDict, PluginMessageUtils
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
from src.plugin_runtime.host.supervisor import PluginSupervisor
logger = get_logger("plugin_runtime.integration")
@@ -55,6 +59,7 @@ class PluginRuntimeManager(
"""
def __init__(self) -> None:
"""初始化插件运行时管理器。"""
from src.plugin_runtime.host.supervisor import PluginSupervisor
self._builtin_supervisor: Optional[PluginSupervisor] = None
@@ -63,6 +68,26 @@ class PluginRuntimeManager(
self._plugin_file_watcher: Optional[FileWatcher] = None
self._plugin_source_watcher_subscription_id: Optional[str] = None
self._plugin_config_watcher_subscriptions: Dict[str, Tuple[Path, str]] = {}
self._plugin_path_cache: Dict[str, Path] = {}
self._manifest_validator: ManifestValidator = ManifestValidator()
self._config_reload_callback: Callable[[Sequence[str]], Awaitable[None]] = self._handle_main_config_reload
self._config_reload_callback_registered: bool = False
async def _dispatch_platform_inbound(self, envelope: InboundMessageEnvelope) -> None:
"""接收 Platform IO 审核后的入站消息并送入主消息链。
Args:
envelope: Platform IO 产出的入站封装。
"""
session_message = envelope.session_message
if session_message is None and envelope.payload is not None:
session_message = PluginMessageUtils._build_session_message_from_dict(dict(envelope.payload))
if session_message is None:
raise ValueError("Platform IO 入站封装缺少可用的 SessionMessage 或 payload")
from src.chat.message_receive.bot import chat_bot
await chat_bot.receive_message(session_message)
# ─── 插件目录 ─────────────────────────────────────────────
@@ -78,6 +103,42 @@ class PluginRuntimeManager(
candidate = Path("plugins").resolve()
return [candidate] if candidate.is_dir() else []
@classmethod
def _discover_plugin_dependency_map(cls, plugin_dirs: Iterable[Path]) -> Dict[str, List[str]]:
"""扫描指定插件目录集合,返回 ``plugin_id -> dependencies`` 映射。"""
validator = ManifestValidator()
return validator.build_plugin_dependency_map(plugin_dirs)
@classmethod
def _build_group_start_order(
cls,
builtin_dirs: Sequence[Path],
third_party_dirs: Sequence[Path],
) -> List[str]:
"""根据跨 Supervisor 依赖关系决定 Runner 启动顺序。"""
builtin_dependencies = cls._discover_plugin_dependency_map(builtin_dirs)
third_party_dependencies = cls._discover_plugin_dependency_map(third_party_dirs)
builtin_plugin_ids = set(builtin_dependencies)
third_party_plugin_ids = set(third_party_dependencies)
builtin_needs_third_party = any(
dependency in third_party_plugin_ids
for dependencies in builtin_dependencies.values()
for dependency in dependencies
)
third_party_needs_builtin = any(
dependency in builtin_plugin_ids
for dependencies in third_party_dependencies.values()
for dependency in dependencies
)
if builtin_needs_third_party and third_party_needs_builtin:
raise RuntimeError("检测到跨 Supervisor 循环依赖,当前无法安全启动独立 Runner")
if builtin_needs_third_party:
return ["third_party", "builtin"]
return ["builtin", "third_party"]
# ─── 生命周期 ─────────────────────────────────────────────
async def start(self) -> None:
@@ -86,7 +147,7 @@ class PluginRuntimeManager(
logger.warning("PluginRuntimeManager 已在运行中,跳过重复启动")
return
_cfg = global_config.plugin_runtime
_cfg = config_manager.get_global_config().plugin_runtime
if not _cfg.enabled:
logger.info("插件运行时已在配置中禁用,跳过启动")
return
@@ -108,6 +169,8 @@ class PluginRuntimeManager(
logger.info("未找到任何插件目录,跳过插件运行时启动")
return
platform_io_manager = get_platform_io_manager()
# 从配置读取自定义 IPC socket 路径(留空则自动生成)
socket_path_base = _cfg.ipc_socket_path or None
@@ -132,19 +195,46 @@ class PluginRuntimeManager(
started_supervisors: List[PluginSupervisor] = []
try:
if self._builtin_supervisor:
await self._builtin_supervisor.start()
started_supervisors.append(self._builtin_supervisor)
if self._third_party_supervisor:
await self._third_party_supervisor.start()
started_supervisors.append(self._third_party_supervisor)
platform_io_manager.set_inbound_dispatcher(self._dispatch_platform_inbound)
await platform_io_manager.ensure_send_pipeline_ready()
supervisor_groups: Dict[str, Optional[PluginSupervisor]] = {
"builtin": self._builtin_supervisor,
"third_party": self._third_party_supervisor,
}
start_order = self._build_group_start_order(builtin_dirs, third_party_dirs)
for group_name in start_order:
supervisor = supervisor_groups.get(group_name)
if supervisor is None:
continue
external_plugin_versions = {
plugin_id: plugin_version
for started_supervisor in started_supervisors
for plugin_id, plugin_version in started_supervisor.get_loaded_plugin_versions().items()
}
supervisor.set_external_available_plugins(external_plugin_versions)
await supervisor.start()
started_supervisors.append(supervisor)
await self._start_plugin_file_watcher()
config_manager.register_reload_callback(self._config_reload_callback)
self._config_reload_callback_registered = True
self._started = True
logger.info(f"插件运行时已启动 — 内置: {builtin_dirs or ''}, 第三方: {third_party_dirs or ''}")
except Exception as e:
logger.error(f"插件运行时启动失败: {e}", exc_info=True)
await self._stop_plugin_file_watcher()
if self._config_reload_callback_registered:
config_manager.unregister_reload_callback(self._config_reload_callback)
self._config_reload_callback_registered = False
await asyncio.gather(*(sv.stop() for sv in started_supervisors), return_exceptions=True)
platform_io_manager.clear_inbound_dispatcher()
try:
await platform_io_manager.stop()
except Exception as platform_io_exc:
logger.warning(f"Platform IO 停止失败: {platform_io_exc}")
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
@@ -154,7 +244,11 @@ class PluginRuntimeManager(
if not self._started:
return
platform_io_manager = get_platform_io_manager()
await self._stop_plugin_file_watcher()
if self._config_reload_callback_registered:
config_manager.unregister_reload_callback(self._config_reload_callback)
self._config_reload_callback_registered = False
coroutines: List[Coroutine[Any, Any, None]] = []
if self._builtin_supervisor:
@@ -162,18 +256,32 @@ class PluginRuntimeManager(
if self._third_party_supervisor:
coroutines.append(self._third_party_supervisor.stop())
stop_errors: List[str] = []
try:
await asyncio.gather(*coroutines, return_exceptions=True)
logger.info("插件运行时已停止")
except Exception as e:
logger.error(f"插件运行时停止失败: {e}", exc_info=True)
results = await asyncio.gather(*coroutines, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
stop_errors.append(str(result))
platform_io_manager.clear_inbound_dispatcher()
try:
await platform_io_manager.stop()
except Exception as exc:
stop_errors.append(f"Platform IO: {exc}")
if stop_errors:
logger.error(f"插件运行时停止过程中存在错误: {'; '.join(stop_errors)}")
else:
logger.info("插件运行时已停止")
finally:
self._started = False
self._builtin_supervisor = None
self._third_party_supervisor = None
self._plugin_path_cache.clear()
@property
def is_running(self) -> bool:
"""返回插件运行时是否处于启动状态。"""
return self._started
@property
@@ -181,11 +289,176 @@ class PluginRuntimeManager(
"""获取所有活跃的 Supervisor"""
return [s for s in (self._builtin_supervisor, self._third_party_supervisor) if s is not None]
def _build_registered_dependency_map(self) -> Dict[str, Set[str]]:
"""根据当前已注册插件构建全局依赖图。"""
dependency_map: Dict[str, Set[str]] = {}
for supervisor in self.supervisors:
for plugin_id, registration in getattr(supervisor, "_registered_plugins", {}).items():
dependency_map[plugin_id] = {
str(dependency or "").strip()
for dependency in getattr(registration, "dependencies", [])
if str(dependency or "").strip()
}
return dependency_map
@staticmethod
def _collect_reverse_dependents(
plugin_ids: Set[str],
dependency_map: Dict[str, Set[str]],
) -> Set[str]:
"""根据依赖图收集反向依赖闭包。"""
impacted_plugins: Set[str] = set(plugin_ids)
changed = True
while changed:
changed = False
for registered_plugin_id, dependencies in dependency_map.items():
if registered_plugin_id in impacted_plugins:
continue
if dependencies & impacted_plugins:
impacted_plugins.add(registered_plugin_id)
changed = True
return impacted_plugins
def _build_registered_supervisor_map(self) -> Dict[str, "PluginSupervisor"]:
"""构建当前已注册插件到所属 Supervisor 的映射。"""
return {
plugin_id: supervisor
for supervisor in self.supervisors
for plugin_id in supervisor.get_loaded_plugin_ids()
}
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
"""收集某个 Supervisor 可用的外部插件版本映射。"""
external_plugin_versions: Dict[str, str] = {}
for supervisor in self.supervisors:
if supervisor is target_supervisor:
continue
external_plugin_versions.update(supervisor.get_loaded_plugin_versions())
return external_plugin_versions
def _find_supervisor_by_plugin_directory(self, plugin_id: str) -> Optional["PluginSupervisor"]:
"""根据插件目录推断应负责该插件重载的 Supervisor。"""
for supervisor in self.supervisors:
if self._get_plugin_path_for_supervisor(supervisor, plugin_id) is not None:
return supervisor
return None
def _warn_skipped_cross_supervisor_reload(
self,
requested_loaded_plugin_ids: Set[str],
dependency_map: Dict[str, Set[str]],
supervisor_by_plugin: Dict[str, "PluginSupervisor"],
) -> None:
"""记录因跨 Supervisor 边界而未参与联动重载的插件。"""
if not requested_loaded_plugin_ids:
return
handled_plugin_ids: Set[str] = set()
for supervisor in self.supervisors:
local_requested_plugin_ids = {
plugin_id
for plugin_id in requested_loaded_plugin_ids
if supervisor_by_plugin.get(plugin_id) is supervisor
}
if not local_requested_plugin_ids:
continue
local_plugin_ids = set(supervisor.get_loaded_plugin_ids())
local_dependency_map = {
plugin_id: {
dependency
for dependency in dependency_map.get(plugin_id, set())
if dependency in local_plugin_ids
}
for plugin_id in local_plugin_ids
}
handled_plugin_ids.update(
self._collect_reverse_dependents(local_requested_plugin_ids, local_dependency_map)
)
impacted_plugin_ids = self._collect_reverse_dependents(requested_loaded_plugin_ids, dependency_map)
skipped_plugin_ids = sorted(impacted_plugin_ids - handled_plugin_ids)
if not skipped_plugin_ids:
return
logger.warning(
f"插件 {', '.join(sorted(requested_loaded_plugin_ids))} 存在跨 Supervisor 依赖方未联动重载: "
f"{', '.join(skipped_plugin_ids)}。当前仅在单个 Supervisor 内执行联动重载;"
"跨 Supervisor API 调用仍然可用。如需联动重载,请将相关插件放在同一个 Supervisor 内。"
)
async def reload_plugins_globally(self, plugin_ids: Sequence[str], reason: str = "manual") -> bool:
"""按 Supervisor 分组执行精确重载。
仅在单个 Supervisor 内执行依赖联动;跨 Supervisor 依赖方仅记录告警,
不再自动参与本次热重载。
"""
normalized_plugin_ids = [
normalized_plugin_id
for plugin_id in plugin_ids
if (normalized_plugin_id := str(plugin_id or "").strip())
]
if not normalized_plugin_ids:
return True
dependency_map = self._build_registered_dependency_map()
supervisor_by_plugin = self._build_registered_supervisor_map()
supervisor_roots: Dict["PluginSupervisor", List[str]] = {}
requested_loaded_plugin_ids: Set[str] = set()
missing_plugin_ids: List[str] = []
for plugin_id in normalized_plugin_ids:
supervisor = supervisor_by_plugin.get(plugin_id)
if supervisor is not None:
requested_loaded_plugin_ids.add(plugin_id)
else:
supervisor = self._find_supervisor_by_plugin_directory(plugin_id)
if supervisor is None:
missing_plugin_ids.append(plugin_id)
continue
if plugin_id not in supervisor_roots.setdefault(supervisor, []):
supervisor_roots[supervisor].append(plugin_id)
if missing_plugin_ids:
logger.warning(f"以下插件未找到可重载的 Supervisor已跳过: {', '.join(sorted(missing_plugin_ids))}")
self._warn_skipped_cross_supervisor_reload(
requested_loaded_plugin_ids=requested_loaded_plugin_ids,
dependency_map=dependency_map,
supervisor_by_plugin=supervisor_by_plugin,
)
success = True
for supervisor, root_plugin_ids in supervisor_roots.items():
if not root_plugin_ids:
continue
reloaded = await supervisor.reload_plugins(
plugin_ids=root_plugin_ids,
reason=reason,
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
)
success = success and reloaded
return success and not missing_plugin_ids
async def notify_plugin_config_updated(
self,
plugin_id: str,
config_data: Optional[Dict[str, Any]] = None,
config_version: str = "",
config_scope: str = "self",
) -> bool:
"""向拥有该插件的 Supervisor 推送配置更新事件。
@@ -193,6 +466,7 @@ class PluginRuntimeManager(
plugin_id: 插件 ID
config_data: 可选的配置数据(如果为 None 则由 Supervisor 从磁盘加载)
config_version: 可选的配置版本字符串,供 Supervisor 进行版本控制
config_scope: 配置变更范围。
"""
if not self._started:
return False
@@ -209,23 +483,78 @@ class PluginRuntimeManager(
config_payload = (
config_data
if config_data is not None
else self._load_plugin_config_for_supervisor(plugin_id, plugin_dirs=sv._plugin_dirs)
else self._load_plugin_config_for_supervisor(sv, plugin_id)
)
await sv.notify_plugin_config_updated(
return await sv.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=config_payload,
config_version=config_version,
config_scope=config_scope,
)
return True
@staticmethod
def _normalize_config_reload_scopes(changed_scopes: Sequence[str]) -> tuple[str, ...]:
"""规范化配置热重载范围列表。
Args:
changed_scopes: 原始配置热重载范围列表。
Returns:
tuple[str, ...]: 去重后的有效配置范围元组。
"""
normalized_scopes: list[str] = []
for scope in changed_scopes:
normalized_scope = str(scope or "").strip().lower()
if normalized_scope not in {"bot", "model"}:
continue
if normalized_scope not in normalized_scopes:
normalized_scopes.append(normalized_scope)
return tuple(normalized_scopes)
async def _broadcast_config_reload(self, scope: str, config_data: Dict[str, Any]) -> None:
"""向订阅指定范围的插件广播配置热重载。
Args:
scope: 配置变更范围,仅支持 ``bot`` 或 ``model``。
config_data: 最新配置数据。
"""
for supervisor in self.supervisors:
for plugin_id in supervisor.get_config_reload_subscribers(scope):
delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=config_data,
config_version="",
config_scope=scope,
)
if not delivered:
logger.warning(f"向插件 {plugin_id} 广播 {scope} 配置热重载失败")
async def _handle_main_config_reload(self, changed_scopes: Sequence[str]) -> None:
"""处理 bot/model 主配置热重载广播。
Args:
changed_scopes: 本次热重载命中的配置范围列表。
"""
if not self._started:
return
normalized_scopes = self._normalize_config_reload_scopes(changed_scopes)
if "bot" in normalized_scopes:
await self._broadcast_config_reload("bot", config_manager.get_global_config().model_dump(mode="json"))
if "model" in normalized_scopes:
await self._broadcast_config_reload("model", config_manager.get_model_config().model_dump(mode="json"))
# ─── 事件桥接 ──────────────────────────────────────────────
async def bridge_event(
self,
event_type_value: str,
message_dict: Optional[Dict[str, Any]] = None,
message_dict: Optional[MessageDict] = None,
extra_args: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]]]:
) -> Tuple[bool, Optional[MessageDict]]:
"""将事件分发到所有 Supervisor
Returns:
@@ -235,17 +564,23 @@ class PluginRuntimeManager(
return True, None
new_event_type: str = _EVENT_TYPE_MAP.get(event_type_value, event_type_value)
modified: Optional[Dict[str, Any]] = None
modified: Optional[MessageDict] = None
current_message: Optional["SessionMessage"] = (
PluginMessageUtils._build_session_message_from_dict(dict(message_dict))
if message_dict is not None
else None
)
for sv in self.supervisors:
try:
cont, mod = await sv.dispatch_event(
event_type=new_event_type,
message=modified or message_dict,
message=current_message,
extra_args=extra_args,
)
if mod is not None:
modified = mod
current_message = mod
modified = PluginMessageUtils._session_message_to_dict(mod)
if not cont:
return False, modified
except Exception as e:
@@ -295,6 +630,37 @@ class PluginRuntimeManager(
timeout_ms=timeout_ms,
)
async def try_send_message_via_platform_io(
self,
message: "SessionMessage",
) -> Optional[DeliveryBatch]:
"""尝试通过 Platform IO 中间层发送消息。
Args:
message: 待发送的内部会话消息。
Returns:
Optional[DeliveryBatch]: 若当前消息命中了至少一条发送路由,则返回
实际发送结果;若没有可用路由或 Platform IO 尚未启动,则返回 ``None``。
"""
if not self._started:
return None
platform_io_manager = get_platform_io_manager()
if not platform_io_manager.is_started:
return None
try:
route_key = platform_io_manager.build_route_key_from_message(message)
except Exception as exc:
logger.warning(f"根据消息构造 Platform IO 路由键失败: {exc}")
return None
if not platform_io_manager.resolve_drivers(route_key):
return None
return await platform_io_manager.send_message(message, route_key)
def _get_supervisors_for_plugin(self, plugin_id: str) -> List["PluginSupervisor"]:
"""返回当前持有指定插件的所有 Supervisor。
@@ -314,30 +680,38 @@ class PluginRuntimeManager(
raise RuntimeError(f"插件 {plugin_id} 同时存在于多个 Supervisor 中,无法安全路由")
return matches[0] if matches else None
@staticmethod
def _find_duplicate_plugin_ids(plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
async def load_plugin_globally(self, plugin_id: str, reason: str = "manual") -> bool:
"""加载或重载单个插件,并为其补齐跨 Supervisor 外部依赖。"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
return False
try:
registered_supervisor = self._get_supervisor_for_plugin(normalized_plugin_id)
except RuntimeError:
return False
if registered_supervisor is not None:
return await self.reload_plugins_globally([normalized_plugin_id], reason=reason)
supervisor = self._find_supervisor_by_plugin_directory(normalized_plugin_id)
if supervisor is None:
return False
return await supervisor.reload_plugins(
plugin_ids=[normalized_plugin_id],
reason=reason,
external_available_plugins=self._build_external_available_plugins_for_supervisor(supervisor),
)
@classmethod
def _find_duplicate_plugin_ids(cls, plugin_dirs: List[Path]) -> Dict[str, List[Path]]:
"""扫描插件目录,找出被多个目录重复声明的插件 ID。"""
plugin_locations: Dict[str, List[Path]] = {}
for base_dir in plugin_dirs:
if not base_dir.is_dir():
continue
for entry in base_dir.iterdir():
if not entry.is_dir():
continue
manifest_path = entry / "_manifest.json"
plugin_path = entry / "plugin.py"
if not manifest_path.exists() or not plugin_path.exists():
continue
plugin_id = entry.name
try:
with open(manifest_path, "r", encoding="utf-8") as manifest_file:
manifest = json.load(manifest_file)
plugin_id = str(manifest.get("name", entry.name)).strip() or entry.name
except Exception:
continue
plugin_locations.setdefault(plugin_id, []).append(entry)
validator = ManifestValidator()
for plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs):
plugin_locations.setdefault(manifest.id, []).append(plugin_path)
return {
plugin_id: sorted(dict.fromkeys(paths), key=lambda p: str(p))
@@ -370,6 +744,7 @@ class PluginRuntimeManager(
async def _stop_plugin_file_watcher(self) -> None:
"""停止插件文件监视器,并清理所有已注册订阅。"""
if self._plugin_file_watcher is None:
self._plugin_path_cache.clear()
return
for _plugin_id, (_config_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
self._plugin_file_watcher.unsubscribe(subscription_id)
@@ -379,12 +754,79 @@ class PluginRuntimeManager(
self._plugin_source_watcher_subscription_id = None
await self._plugin_file_watcher.stop()
self._plugin_file_watcher = None
self._plugin_path_cache.clear()
def _iter_plugin_dirs(self) -> Iterable[Path]:
"""迭代所有 Supervisor 当前管理的插件根目录。"""
for supervisor in self.supervisors:
yield from getattr(supervisor, "_plugin_dirs", [])
@staticmethod
def _iter_candidate_plugin_paths(plugin_dirs: Iterable[Path]) -> Iterable[Path]:
"""迭代所有可能的插件目录路径。
Args:
plugin_dirs: 一个或多个插件根目录。
Yields:
Path: 单个插件目录路径。
"""
for plugin_dir in plugin_dirs:
plugin_root = Path(plugin_dir).resolve()
if not plugin_root.is_dir():
continue
for entry in plugin_root.iterdir():
if entry.is_dir():
yield entry.resolve()
def _read_plugin_id_from_plugin_path(self, plugin_path: Path) -> Optional[str]:
"""从单个插件目录中读取 manifest 声明的插件 ID。
Args:
plugin_path: 单个插件目录路径。
Returns:
Optional[str]: 解析成功时返回插件 ID否则返回 ``None``。
"""
return self._manifest_validator.read_plugin_id_from_plugin_path(plugin_path)
def _iter_discovered_plugin_paths(self, plugin_dirs: Iterable[Path]) -> Iterable[Tuple[str, Path]]:
"""迭代目录中可解析到的插件 ID 与实际目录路径。
Args:
plugin_dirs: 一个或多个插件根目录。
Yields:
Tuple[str, Path]: ``(plugin_id, plugin_path)`` 二元组。
"""
for plugin_path in self._iter_candidate_plugin_paths(plugin_dirs):
if plugin_id := self._read_plugin_id_from_plugin_path(plugin_path):
yield plugin_id, plugin_path
def _get_plugin_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
"""为指定 Supervisor 定位某个插件的实际目录。
Args:
supervisor: 目标 Supervisor。
plugin_id: 插件 ID。
Returns:
Optional[Path]: 插件目录路径;未找到时返回 ``None``。
"""
cached_path = self._plugin_path_cache.get(plugin_id)
if cached_path is not None:
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
if self._plugin_dir_matches(cached_path, Path(plugin_dir)):
return cached_path
for candidate_plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
if candidate_plugin_id != plugin_id:
continue
self._plugin_path_cache[plugin_id] = plugin_path
return plugin_path
return None
def _refresh_plugin_config_watch_subscriptions(self) -> None:
"""按当前已注册插件集合刷新 config.toml 的单插件订阅。
@@ -394,7 +836,11 @@ class PluginRuntimeManager(
if self._plugin_file_watcher is None:
return
desired_config_paths = dict(self._iter_registered_plugin_config_paths())
desired_plugin_paths = dict(self._iter_registered_plugin_paths())
self._plugin_path_cache = desired_plugin_paths.copy()
desired_config_paths = {
plugin_id: plugin_path / "config.toml" for plugin_id, plugin_path in desired_plugin_paths.items()
}
for plugin_id, (_old_path, subscription_id) in list(self._plugin_config_watcher_subscriptions.items()):
if desired_config_paths.get(plugin_id) == self._plugin_config_watcher_subscriptions[plugin_id][0]:
@@ -418,28 +864,35 @@ class PluginRuntimeManager(
"""为指定插件生成配置文件变更回调。"""
async def _callback(changes: Sequence[FileChange]) -> None:
"""将 watcher 事件转发到指定插件的配置处理逻辑。
Args:
changes: 当前批次收集到的文件变更列表。
"""
await self._handle_plugin_config_changes(plugin_id, changes)
return _callback
def _iter_registered_plugin_config_paths(self) -> Iterable[Tuple[str, Path]]:
"""迭代当前所有已注册插件的 config.toml 路径。"""
def _iter_registered_plugin_paths(self) -> Iterable[Tuple[str, Path]]:
"""迭代当前所有已注册插件的实际目录路径。"""
for supervisor in self.supervisors:
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
if config_path := self._get_plugin_config_path_for_supervisor(supervisor, plugin_id):
yield plugin_id, config_path
if plugin_path := self._get_plugin_path_for_supervisor(supervisor, plugin_id):
yield plugin_id, plugin_path
def _get_plugin_config_path_for_supervisor(self, supervisor: Any, plugin_id: str) -> Optional[Path]:
"""从指定 Supervisor 的插件目录中定位某个插件的 config.toml。"""
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
plugin_dir = Path(plugin_dir)
plugin_path = plugin_dir.resolve() / plugin_id
if plugin_path.is_dir():
return plugin_path / "config.toml"
return None
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
return None if plugin_path is None else plugin_path / "config.toml"
async def _handle_plugin_config_changes(self, plugin_id: str, changes: Sequence[FileChange]) -> None:
"""处理单个插件配置文件变化,并仅向目标插件推送配置更新。"""
"""处理单个插件配置文件变化,并定向派发自配置更新。
Args:
plugin_id: 发生配置变更的插件 ID。
changes: 当前批次收集到的配置文件变更列表。
"""
if not self._started or not changes:
return
@@ -453,18 +906,24 @@ class PluginRuntimeManager(
return
try:
await supervisor.notify_plugin_config_updated(
config_payload = self._load_plugin_config_for_supervisor(supervisor, plugin_id)
delivered = await supervisor.notify_plugin_config_updated(
plugin_id=plugin_id,
config_data=self._load_plugin_config_for_supervisor(plugin_id, getattr(supervisor, "_plugin_dirs", [])),
config_data=config_payload,
config_version="",
config_scope="self",
)
if not delivered:
logger.warning(f"插件 {plugin_id} 配置文件变更后通知失败")
except Exception as exc:
logger.warning(f"插件 {plugin_id} 配置热更新通知失败: {exc}")
logger.warning(f"插件 {plugin_id} 配置文件变更处理失败: {exc}")
async def _handle_plugin_source_changes(self, changes: Sequence[FileChange]) -> None:
"""处理插件源码相关变化。
这里仅负责源码、清单等会影响插件装载状态的文件;配置文件的变化会由
单独的 per-plugin watcher 处理,避免把单插件配置更新放大成全量 reload。
单独的 per-plugin watcher 处理,并定向派发给目标插件的
``on_config_update()``,避免放大成不必要的跨插件 reload。
"""
if not self._started or not changes:
return
@@ -477,7 +936,7 @@ class PluginRuntimeManager(
logger.error(f"检测到重复插件 ID跳过本次插件热重载: {details}")
return
reload_supervisors: List[Any] = []
changed_plugin_ids: List[str] = []
changed_paths = [change.path.resolve() for change in changes]
for supervisor in self.supervisors:
@@ -485,13 +944,12 @@ class PluginRuntimeManager(
plugin_id = self._match_plugin_id_for_supervisor(supervisor, path)
if plugin_id is None:
continue
if (path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py") and supervisor not in reload_supervisors:
reload_supervisors.append(supervisor)
if path.name in {"plugin.py", "_manifest.json"} or path.suffix == ".py":
if plugin_id not in changed_plugin_ids:
changed_plugin_ids.append(plugin_id)
for supervisor in reload_supervisors:
await supervisor.reload_plugins(reason="file_watcher")
if reload_supervisors:
if changed_plugin_ids:
await self.reload_plugins_globally(changed_plugin_ids, reason="file_watcher")
self._refresh_plugin_config_watch_subscriptions()
@staticmethod
@@ -502,36 +960,47 @@ class PluginRuntimeManager(
def _match_plugin_id_for_supervisor(self, supervisor: Any, path: Path) -> Optional[str]:
"""根据变更路径为指定 Supervisor 推断受影响的插件 ID。"""
for plugin_id, _reg in getattr(supervisor, "_registered_plugins", {}).items():
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
plugin_dir = Path(plugin_dir)
candidate_dir = plugin_dir.resolve() / plugin_id
if path == candidate_dir or path.is_relative_to(candidate_dir):
return plugin_id
resolved_path = path.resolve()
for plugin_id in getattr(supervisor, "_registered_plugins", {}).keys():
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
if plugin_path is not None and (resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path)):
return plugin_id
for plugin_id, plugin_path in self._plugin_path_cache.items():
if not any(self._plugin_dir_matches(plugin_path, Path(plugin_dir)) for plugin_dir in getattr(supervisor, "_plugin_dirs", [])):
continue
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
return plugin_id
for plugin_id, plugin_path in self._iter_discovered_plugin_paths(getattr(supervisor, "_plugin_dirs", [])):
if resolved_path == plugin_path or resolved_path.is_relative_to(plugin_path):
self._plugin_path_cache[plugin_id] = plugin_path
return plugin_id
for plugin_dir in getattr(supervisor, "_plugin_dirs", []):
plugin_dir = Path(plugin_dir)
plugin_root = plugin_dir.resolve()
if self._plugin_dir_matches(path, plugin_dir) and (relative_parts := path.relative_to(plugin_root).parts):
return relative_parts[0]
return None
@staticmethod
def _load_plugin_config_for_supervisor(plugin_id: str, plugin_dirs: Iterable[Path]) -> Dict[str, Any]:
def _load_plugin_config_for_supervisor(self, supervisor: Any, plugin_id: str) -> Dict[str, Any]:
"""从给定插件目录集合中读取目标插件的配置内容。"""
for plugin_dir in plugin_dirs:
plugin_path = plugin_dir.resolve() / plugin_id
if plugin_path.is_dir():
config_path = plugin_path / "config.toml"
if not config_path.exists():
return {}
with open(config_path, "r", encoding="utf-8") as handle:
return tomlkit.load(handle).unwrap()
return {}
plugin_path = self._get_plugin_path_for_supervisor(supervisor, plugin_id)
if plugin_path is None:
return {}
config_path = plugin_path / "config.toml"
if not config_path.exists():
return {}
with open(config_path, "r", encoding="utf-8") as handle:
return tomlkit.load(handle).unwrap()
# ─── 能力实现注册 ──────────────────────────────────────────
def _register_capability_impls(self, supervisor: "PluginSupervisor") -> None:
"""向指定 Supervisor 注册主程序能力实现。
Args:
supervisor: 需要注册能力实现的目标 Supervisor。
"""
register_capability_impls(self, supervisor)

View File

@@ -7,52 +7,52 @@
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
import logging as stdlib_logging
import time
from pydantic import BaseModel, Field
# ─── 协议常量 ──────────────────────────────────────────────────────
PROTOCOL_VERSION = "1.0"
# ====== 协议常量 ======
PROTOCOL_VERSION = "1.0.0"
# 支持的 SDK 版本范围Host 在握手时校验)
MIN_SDK_VERSION = "1.0.0"
MAX_SDK_VERSION = "1.99.99"
# ─── 消息类型 ──────────────────────────────────────────────────────
MAX_SDK_VERSION = "2.99.99"
# ====== 消息类型 ======
class MessageType(str, Enum):
"""RPC 消息类型"""
REQUEST = "request"
RESPONSE = "response"
EVENT = "event"
BROADCAST = "broadcast"
# ─── 请求 ID 生成器 ───────────────────────────────────────────────
class ConfigReloadScope(str, Enum):
"""配置热重载范围。"""
SELF = "self"
BOT = "bot"
MODEL = "model"
# ====== 请求 ID 生成器 ======
class RequestIdGenerator:
"""单调递增 int64 请求 ID 生成器(线程安全由调用方保证或使用 asyncio"""
"""单调递增 int64 请求 ID 生成器"""
def __init__(self, start: int = 1) -> None:
self._counter = start
def next(self) -> int:
async def next(self) -> int:
current = self._counter
self._counter += 1
return current
# ─── Envelope 模型 ─────────────────────────────────────────────────
# ====== Envelope 模型 ======
class Envelope(BaseModel):
"""RPC 统一信封
"""RPC 统一消息封装
所有 Host <-> Runner 消息均封装为此格式。
序列化流程Envelope -> .model_dump() -> MsgPack encode
@@ -60,15 +60,23 @@ class Envelope(BaseModel):
"""
protocol_version: str = Field(default=PROTOCOL_VERSION, description="协议版本")
"""协议版本"""
request_id: int = Field(description="单调递增请求 ID")
"""单调递增请求 ID"""
message_type: MessageType = Field(description="消息类型")
"""消息类型"""
method: str = Field(default="", description="RPC 方法名")
"""RPC 方法名"""
plugin_id: str = Field(default="", description="目标插件 ID")
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳(ms)")
timeout_ms: int = Field(default=30000, description="相对超时(ms)")
generation: int = Field(default=0, description="Runner generation 编号")
"""目标插件 ID"""
timestamp_ms: int = Field(default_factory=lambda: int(time.time() * 1000), description="发送时间戳 (ms)")
"""发送时间戳 (ms)"""
timeout_ms: int = Field(default=30000, description="相对超时 (ms)")
"""相对超时 (ms)"""
payload: Dict[str, Any] = Field(default_factory=dict, description="业务数据")
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息(仅 response)")
"""业务数据"""
error: Optional[Dict[str, Any]] = Field(default=None, description="错误信息 (仅 response)")
"""错误信息 (仅 response)"""
def is_request(self) -> bool:
return self.message_type == MessageType.REQUEST
@@ -76,8 +84,8 @@ class Envelope(BaseModel):
def is_response(self) -> bool:
return self.message_type == MessageType.RESPONSE
def is_event(self) -> bool:
return self.message_type == MessageType.EVENT
def is_broadcast(self) -> bool:
return self.message_type == MessageType.BROADCAST
def make_response(
self, payload: Optional[Dict[str, Any]] = None, error: Optional[Dict[str, Any]] = None
@@ -89,7 +97,6 @@ class Envelope(BaseModel):
message_type=MessageType.RESPONSE,
method=self.method,
plugin_id=self.plugin_id,
generation=self.generation,
payload=payload or {},
error=error,
)
@@ -105,153 +112,302 @@ class Envelope(BaseModel):
)
# ─── 握手消息 ──────────────────────────────────────────────────────
# ====== 握手请求与响应 ======
class HelloPayload(BaseModel):
"""runner.hello 握手请求 payload"""
runner_id: str = Field(description="Runner 进程唯一标识")
"""Runner 进程唯一标识"""
sdk_version: str = Field(description="SDK 版本号")
"""SDK 版本号"""
session_token: str = Field(description="一次性会话令牌")
"""一次性会话令牌"""
class HelloResponsePayload(BaseModel):
"""runner.hello 握手响应 payload"""
accepted: bool = Field(description="是否接受连接")
"""是否接受连接"""
host_version: str = Field(default="", description="Host 版本号")
assigned_generation: int = Field(default=0, description="分配的 generation 编")
reason: str = Field(default="", description="拒绝原因(若 accepted=False)")
# ─── 组件注册消息 ──────────────────────────────────────────────────
"""Host 版本"""
reason: str = Field(default="", description="拒绝原因 (若 accepted=False)")
"""拒绝原因 (若 `accepted`=`False`)"""
# ====== 组件注册消息 ======
class ComponentDeclaration(BaseModel):
"""单个组件声明"""
name: str = Field(description="组件名称")
component_type: str = Field(description="组件类型: action/command/tool/event_handler")
"""组件名称"""
component_type: str = Field(
description="组件类型action/command/tool/event_handler/hook_handler/message_gateway"
)
"""组件类型:`action`/`command`/`tool`/`event_handler`/`hook_handler`/`message_gateway`"""
plugin_id: str = Field(description="所属插件 ID")
"""所属插件 ID"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="组件元数据")
"""组件元数据"""
class RegisterComponentsPayload(BaseModel):
"""plugin.register_components 请求 payload"""
class RegisterPluginPayload(BaseModel):
"""插件组件注册请求载荷。
该模型同时用于 ``plugin.register_components`` 与兼容旧命名的
``plugin.register_plugin`` 请求。
"""
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
plugin_version: str = Field(default="1.0.0", description="插件版本")
"""插件版本"""
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
"""组件列表"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
"""插件级依赖插件 ID 列表"""
config_reload_subscriptions: List[str] = Field(default_factory=list, description="订阅的全局配置热重载范围")
"""订阅的全局配置热重载范围"""
class BootstrapPluginPayload(BaseModel):
"""plugin.bootstrap 请求 payload"""
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
plugin_version: str = Field(default="1.0.0", description="插件版本")
"""插件版本"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
# ─── 调用消息 ──────────────────────────────────────────────────────
# ====== 插件调用请求和响应 ======
class InvokePayload(BaseModel):
"""plugin.invoke_* 请求 payload"""
"""plugin.invoke.* 请求 payload"""
component_name: str = Field(description="要调用的组件名称")
"""要调用的组件名称"""
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
"""调用参数"""
class InvokeResultPayload(BaseModel):
"""plugin.invoke_* 响应 payload"""
"""plugin.invoke.* 响应 payload"""
success: bool = Field(description="是否成功")
"""是否成功"""
result: Any = Field(default=None, description="返回值")
"""返回值"""
# ─── 能力调用消息 ──────────────────────────────────────────────────
# ====== 能力调用消息 ======
class CapabilityRequestPayload(BaseModel):
"""cap.* 请求 payload插件 -> Host 能力调用)"""
capability: str = Field(description="能力名称,如 send.text, db.query")
"""能力名称,如 send.text, db.query"""
args: Dict[str, Any] = Field(default_factory=dict, description="调用参数")
"""调用参数"""
class CapabilityResponsePayload(BaseModel):
"""cap.* 响应 payload"""
success: bool = Field(description="是否成功")
"""是否成功"""
result: Any = Field(default=None, description="返回值")
"""返回值"""
# ─── 健康检查 ──────────────────────────────────────────────────────
# ====== 健康检查 ======
class HealthPayload(BaseModel):
"""plugin.health 响应 payload"""
healthy: bool = Field(description="是否健康")
"""是否健康"""
loaded_plugins: List[str] = Field(default_factory=list, description="已加载的插件列表")
uptime_ms: int = Field(default=0, description="运行时长(ms)")
"""已加载的插件列表"""
uptime_ms: int = Field(default=0, description="运行时长 (ms)")
"""运行时长 (ms)"""
class RunnerReadyPayload(BaseModel):
"""runner.ready 请求 payload"""
loaded_plugins: List[str] = Field(default_factory=list, description="已完成初始化的插件列表")
"""已完成初始化的插件列表"""
failed_plugins: List[str] = Field(default_factory=list, description="初始化失败的插件列表")
"""初始化失败的插件列表"""
# ─── 配置更新 ──────────────────────────────────────────────────────
# Host 侧现已支持配置更新推送:
# - 总配置热重载完成后PluginRuntimeManager 会向已加载插件推送配置更新事件。
# - 插件目录下的 config.toml 变化由现有 FileWatcher 监听并转发为 plugin.config_updated。
# ====== 配置更新 ======
class ConfigUpdatedPayload(BaseModel):
"""plugin.config_updated 事件 payload"""
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
config_scope: ConfigReloadScope = Field(description="配置变更范围")
"""配置变更范围"""
config_version: str = Field(description="新配置版本")
"""新配置版本"""
config_data: Dict[str, Any] = Field(default_factory=dict, description="配置内容")
"""配置内容"""
# ─── 关停 ──────────────────────────────────────────────────────────
# ====== 关停 ======
class ShutdownPayload(BaseModel):
"""plugin.shutdown / plugin.prepare_shutdown payload"""
reason: str = Field(default="normal", description="关停原因")
drain_timeout_ms: int = Field(default=5000, description="排空超时(ms)")
"""关停原因"""
drain_timeout_ms: int = Field(default=5000, description="排空超时 (ms)")
"""排空超时 (ms)"""
# ─── 日志传输 ──────────────────────────────────────────────────────
class UnregisterPluginPayload(BaseModel):
"""插件注销请求载荷。"""
plugin_id: str = Field(description="插件 ID")
"""插件 ID"""
reason: str = Field(default="manual", description="注销原因")
"""注销原因"""
class ReloadPluginPayload(BaseModel):
"""插件重载请求载荷。"""
plugin_id: str = Field(description="目标插件 ID")
"""目标插件 ID"""
reason: str = Field(default="manual", description="重载原因")
"""重载原因"""
external_available_plugins: Dict[str, str] = Field(
default_factory=dict,
description="可视为已满足的外部依赖插件版本映射",
)
"""可视为已满足的外部依赖插件版本映射"""
class ReloadPluginsPayload(BaseModel):
"""批量插件重载请求载荷。"""
plugin_ids: List[str] = Field(default_factory=list, description="目标插件 ID 列表")
"""目标插件 ID 列表"""
reason: str = Field(default="manual", description="重载原因")
"""重载原因"""
external_available_plugins: Dict[str, str] = Field(
default_factory=dict,
description="可视为已满足的外部依赖插件版本映射",
)
"""可视为已满足的外部依赖插件版本映射"""
class ReloadPluginResultPayload(BaseModel):
"""插件重载结果载荷。"""
success: bool = Field(description="是否重载成功")
"""是否重载成功"""
requested_plugin_id: str = Field(description="请求重载的插件 ID")
"""请求重载的插件 ID"""
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
"""成功完成重载的插件列表"""
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
"""本次已卸载的插件列表"""
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
"""重载失败的插件及原因"""
class ReloadPluginsResultPayload(BaseModel):
"""批量插件重载结果载荷。"""
success: bool = Field(description="是否重载成功")
"""是否重载成功"""
requested_plugin_ids: List[str] = Field(default_factory=list, description="请求重载的插件 ID 列表")
"""请求重载的插件 ID 列表"""
reloaded_plugins: List[str] = Field(default_factory=list, description="成功完成重载的插件列表")
"""成功完成重载的插件列表"""
unloaded_plugins: List[str] = Field(default_factory=list, description="本次已卸载的插件列表")
"""本次已卸载的插件列表"""
failed_plugins: Dict[str, str] = Field(default_factory=dict, description="重载失败的插件及原因")
"""重载失败的插件及原因"""
class MessageGatewayStateUpdatePayload(BaseModel):
"""消息网关运行时状态更新载荷。"""
gateway_name: str = Field(description="消息网关组件名称")
"""消息网关组件名称"""
ready: bool = Field(description="当前链路是否已经就绪")
"""当前链路是否已经就绪"""
platform: str = Field(default="", description="当前链路负责的平台名称")
"""当前链路负责的平台名称"""
account_id: str = Field(default="", description="当前链路对应的账号 ID 或 self_id")
"""当前链路对应的账号 ID 或 self_id"""
scope: str = Field(default="", description="当前链路对应的可选路由作用域")
"""当前链路对应的可选路由作用域"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的运行时状态元数据")
"""可选的运行时状态元数据"""
class MessageGatewayStateUpdateResultPayload(BaseModel):
"""消息网关运行时状态更新结果载荷。"""
accepted: bool = Field(description="Host 是否接受了本次状态更新")
"""Host 是否接受了本次状态更新"""
ready: bool = Field(description="Host 记录的当前就绪状态")
"""Host 记录的当前就绪状态"""
route_key: Dict[str, Any] = Field(default_factory=dict, description="当前生效的路由键")
"""当前生效的路由键"""
class RouteMessagePayload(BaseModel):
"""消息网关向 Host 路由外部消息的请求载荷。"""
gateway_name: str = Field(description="接收消息的网关组件名称")
"""接收消息的网关组件名称"""
message: Dict[str, Any] = Field(description="符合 MessageDict 结构的标准消息字典")
"""符合 MessageDict 结构的标准消息字典"""
route_metadata: Dict[str, Any] = Field(default_factory=dict, description="可选的路由辅助元数据")
"""可选的路由辅助元数据"""
external_message_id: str = Field(default="", description="可选的外部平台消息 ID")
"""可选的外部平台消息 ID"""
dedupe_key: str = Field(default="", description="可选的显式去重键")
"""可选的显式去重键"""
class ReceiveExternalMessageResultPayload(BaseModel):
"""外部消息注入结果载荷。"""
accepted: bool = Field(description="Host 是否接受了本次消息注入")
"""Host 是否接受了本次消息注入"""
route_key: Dict[str, Any] = Field(default_factory=dict, description="本次消息使用的归一路由键")
"""本次消息使用的归一路由键"""
RegisterPluginPayload.model_rebuild()
# ====== 日志传输 ======
class LogEntry(BaseModel):
"""单条日志记录Runner → Host 传输格式)"""
timestamp_ms: int = Field(
description="日志时间戳Unix epoch 毫秒",
)
level: int = Field(
description=("stdlib logging 整数级别: 10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"),
)
logger_name: str = Field(
description="Logger 名称,如 plugin.my_plugin.submodule",
)
message: str = Field(
description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)",
)
timestamp_ms: int = Field(description="日志时间戳Unix epoch 毫秒")
"""日志时间戳Unix epoch 毫秒"""
level: int = Field(description="stdlib logging 整数级别10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL")
"""stdlib logging 整数级别10=DEBUG, 20=INFO, 30=WARNING, 40=ERROR, 50=CRITICAL"""
logger_name: str = Field(description="Logger 名称,如 plugin.my_plugin.submodule")
"""Logger 名称,如 plugin.my_plugin.submodule"""
message: str = Field(description="经 Formatter 格式化后的完整日志消息(含 exc_info 文本)")
"""经 Formatter 格式化后的完整日志消息(含 exc_info 文本)"""
exception_text: str = Field(
default="",
description="原始异常摘要exc_text供结构化消费已嵌入 message 中",
)
"""原始异常摘要exc_text供结构化消费已嵌入 message 中"""
log_color_in_hex: Optional[str] = Field(default=None, description="日志颜色的十六进制字符串(如 #RRGGBB")
@property
def levelname(self) -> str:
@@ -262,6 +418,5 @@ class LogEntry(BaseModel):
class LogBatchPayload(BaseModel):
"""runner.log_batch 事件 payloadRunner 端向 Host 批量推送日志记录"""
entries: List[LogEntry] = Field(
description="本批次日志记录列表,按时间升序排列",
)
entries: List[LogEntry] = Field(description="本批次日志记录列表,按时间升序排列")
"""本批次日志记录列表,按时间升序排列"""

View File

@@ -18,17 +18,17 @@ class ErrorCode(str, Enum):
E_TIMEOUT = "E_TIMEOUT"
E_BAD_PAYLOAD = "E_BAD_PAYLOAD"
E_PROTOCOL_MISMATCH = "E_PROTOCOL_MISMATCH"
E_SHUTTING_DOWN = "E_SHUTTING_DOWN"
# 权限与策略
E_UNAUTHORIZED = "E_UNAUTHORIZED"
E_METHOD_NOT_ALLOWED = "E_METHOD_NOT_ALLOWED"
E_BACKPRESSURE = "E_BACKPRESSURE"
E_BACK_PRESSURE = "E_BACK_PRESSURE"
E_HOST_OVERLOADED = "E_HOST_OVERLOADED"
# 插件生命周期
E_PLUGIN_CRASHED = "E_PLUGIN_CRASHED"
E_PLUGIN_NOT_FOUND = "E_PLUGIN_NOT_FOUND"
E_GENERATION_MISMATCH = "E_GENERATION_MISMATCH"
E_RELOAD_IN_PROGRESS = "E_RELOAD_IN_PROGRESS"
# 能力调用
@@ -65,3 +65,13 @@ class RPCError(Exception):
message=data.get("message", ""),
details=data.get("details", {}),
)
@classmethod
def from_exception(cls, exception: Exception, code_mapping: Optional[Dict[type[Exception], ErrorCode]] = None):
if isinstance(exception, cls):
return exception
if code_mapping:
for exception_type, code in code_mapping.items():
if isinstance(exception, exception_type):
return cls(code=code, message=str(exception))
return cls(ErrorCode.E_UNKNOWN, str(exception))

View File

@@ -66,6 +66,12 @@ class RunnerIPCLogHandler(logging.Handler):
ALLOWED_LOGGER_PREFIXES: tuple[str, ...] = ("plugin.", "plugin_runtime.", "_maibot_plugin_")
def __init__(self) -> None:
"""初始化 Runner 端日志转发处理器。
创建有界日志缓冲区,并准备与 RPC 客户端绑定的后台刷新任务。
此时不会启动任何异步任务;真正开始转发要等到 :meth:`start`
被调用后才会发生。
"""
super().__init__()
# deque(maxlen=N): append/popleft 在 CPython GIL 保护下线程安全
self._buffer: collections.deque[LogEntry] = collections.deque(maxlen=self.QUEUE_MAX)

File diff suppressed because it is too large Load Diff

View File

@@ -13,16 +13,16 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
import contextlib
import importlib
import importlib.util
import json
import os
import re
import sys
from src.common.logger import get_logger
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
from src.plugin_runtime.runner.manifest_validator import ManifestValidator, PluginManifest
logger = get_logger("plugin_runtime.runner.plugin_loader")
PluginCandidate = Tuple[Path, Dict[str, Any], Path]
PluginCandidate = Tuple[Path, PluginManifest, Path]
class PluginMeta:
@@ -32,28 +32,28 @@ class PluginMeta:
self,
plugin_id: str,
plugin_dir: str,
module_name: str,
plugin_instance: Any,
manifest: Dict[str, Any],
manifest: PluginManifest,
) -> None:
"""初始化插件元数据。
Args:
plugin_id: 插件 ID。
plugin_dir: 插件目录绝对路径。
module_name: 插件入口模块名。
plugin_instance: 插件实例对象。
manifest: 解析后的强类型 Manifest。
"""
self.plugin_id = plugin_id
self.plugin_dir = plugin_dir
self.module_name = module_name
self.instance = plugin_instance
self.manifest = manifest
self.version = manifest.get("version", "1.0.0")
self.capabilities_required = manifest.get("capabilities", [])
self.dependencies: List[str] = self._extract_dependencies(manifest)
@staticmethod
def _extract_dependencies(manifest: Dict[str, Any]) -> List[str]:
raw = manifest.get("dependencies", [])
result: List[str] = []
for dep in raw:
if isinstance(dep, str):
result.append(dep.strip())
elif isinstance(dep, dict):
if name := str(dep.get("name", "")).strip():
result.append(name)
return result
self.version = manifest.version
self.capabilities_required = list(manifest.capabilities)
self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
self.component_handlers: Dict[str, str] = {}
class PluginLoader:
@@ -66,30 +66,52 @@ class PluginLoader:
"""
def __init__(self, host_version: str = "") -> None:
"""初始化插件加载器。
Args:
host_version: Host 版本号,用于 manifest 兼容性校验。
"""
self._loaded_plugins: Dict[str, PluginMeta] = {}
self._failed_plugins: Dict[str, str] = {}
self._manifest_validator = ManifestValidator(host_version=host_version)
self._compat_hook_installed = False
def discover_and_load(self, plugin_dirs: List[str]) -> List[PluginMeta]:
"""扫描多个目录并加载所有插件(含依赖排序和 manifest 校验)
def discover_and_load(
self,
plugin_dirs: List[str],
extra_available: Optional[Dict[str, str]] = None,
) -> List[PluginMeta]:
"""扫描多个目录并加载所有插件。
Args:
plugin_dirs: 插件目录列表
plugin_dirs: 插件目录列表
extra_available: 额外视为已满足的外部依赖插件版本映射。
Returns:
成功加载的插件元数据列表按依赖顺序
List[PluginMeta]: 成功加载的插件元数据列表按依赖顺序排列。
"""
candidates, duplicate_candidates = self._discover_candidates(plugin_dirs)
self._record_duplicate_candidates(duplicate_candidates)
# 第二阶段:依赖解析(拓扑排序)
load_order, failed_deps = self._resolve_dependencies(candidates)
load_order, failed_deps = self._resolve_dependencies(candidates, extra_available=extra_available)
self._record_failed_dependencies(failed_deps)
# 第三阶段:按依赖顺序加载
return self._load_plugins_in_order(load_order, candidates)
def discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
"""扫描插件目录并返回候选插件。
Args:
plugin_dirs: 需要扫描的插件根目录列表。
Returns:
Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
候选插件映射和重复插件 ID 冲突映射。
"""
return self._discover_candidates(plugin_dirs)
def _discover_candidates(self, plugin_dirs: List[str]) -> Tuple[Dict[str, PluginCandidate], Dict[str, List[Path]]]:
"""扫描插件目录并收集候选插件。"""
candidates: Dict[str, PluginCandidate] = {}
@@ -123,26 +145,17 @@ class PluginLoader:
def _discover_single_candidate(self, plugin_dir: Path) -> Optional[Tuple[str, PluginCandidate]]:
"""发现并校验单个插件目录。"""
manifest_path = plugin_dir / "_manifest.json"
plugin_path = plugin_dir / "plugin.py"
if not manifest_path.exists() or not plugin_path.exists():
if not plugin_path.exists():
return None
try:
with manifest_path.open("r", encoding="utf-8") as manifest_file:
manifest: Dict[str, Any] = json.load(manifest_file)
except Exception as e:
self._failed_plugins[plugin_dir.name] = f"manifest 解析失败: {e}"
logger.error(f"插件 {plugin_dir.name} manifest 解析失败: {e}")
return None
if not self._manifest_validator.validate(manifest):
manifest = self._manifest_validator.load_from_plugin_path(plugin_dir)
if manifest is None:
errors = "; ".join(self._manifest_validator.errors)
self._failed_plugins[plugin_dir.name] = f"manifest 校验失败: {errors}"
return None
plugin_id = str(manifest.get("name", plugin_dir.name)).strip() or plugin_dir.name
plugin_id = manifest.id
return plugin_id, (plugin_dir, manifest, plugin_path)
def _record_duplicate_candidates(self, duplicate_candidates: Dict[str, List[Path]]) -> None:
@@ -170,7 +183,6 @@ class PluginLoader:
plugin_dir, manifest, plugin_path = candidates[plugin_id]
try:
if meta := self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path):
self._loaded_plugins[meta.plugin_id] = meta
results.append(meta)
except Exception as e:
self._failed_plugins[plugin_id] = str(e)
@@ -182,45 +194,193 @@ class PluginLoader:
"""获取已加载的插件"""
return self._loaded_plugins.get(plugin_id)
def set_loaded_plugin(self, meta: PluginMeta) -> None:
"""登记一个已经完成初始化的插件。
Args:
meta: 待登记的插件元数据。
"""
self._loaded_plugins[meta.plugin_id] = meta
def remove_loaded_plugin(self, plugin_id: str) -> Optional[PluginMeta]:
"""移除一个已加载插件的元数据。
Args:
plugin_id: 待移除的插件 ID。
Returns:
Optional[PluginMeta]: 被移除的插件元数据;不存在时返回 ``None``。
"""
return self._loaded_plugins.pop(plugin_id, None)
def purge_plugin_modules(self, plugin_id: str, plugin_dir: str) -> List[str]:
"""清理指定插件目录下的模块缓存。
Args:
plugin_id: 插件 ID。
plugin_dir: 插件目录绝对路径。
Returns:
List[str]: 已从 ``sys.modules`` 中移除的模块名列表。
"""
removed_modules: List[str] = []
plugin_path = Path(plugin_dir).resolve()
synthetic_module_name = self._build_safe_module_name(plugin_id)
for module_name, module in list(sys.modules.items()):
if module_name == synthetic_module_name:
removed_modules.append(module_name)
sys.modules.pop(module_name, None)
continue
module_file = getattr(module, "__file__", None)
if module_file is None:
continue
try:
module_path = Path(module_file).resolve()
except Exception:
continue
if module_path.is_relative_to(plugin_path):
removed_modules.append(module_name)
sys.modules.pop(module_name, None)
importlib.invalidate_caches()
return removed_modules
@staticmethod
def _build_safe_module_name(plugin_id: str) -> str:
"""将插件 ID 转换为可用于动态导入的安全模块名。
Args:
plugin_id: 原始插件 ID。
Returns:
str: 仅包含字母、数字和下划线的合成模块名。
"""
normalized_plugin_id = re.sub(r"[^0-9A-Za-z_]", "_", str(plugin_id or "").strip())
if normalized_plugin_id and normalized_plugin_id[0].isdigit():
normalized_plugin_id = f"_{normalized_plugin_id}"
return f"_maibot_plugin_{normalized_plugin_id or 'plugin'}"
def list_plugins(self) -> List[str]:
"""列出所有已加载的插件 ID"""
return list(self._loaded_plugins.keys())
@property
def failed_plugins(self) -> Dict[str, str]:
"""返回当前记录的失败插件原因映射。"""
return dict(self._failed_plugins)
@property
def manifest_validator(self) -> ManifestValidator:
"""返回当前加载器持有的 Manifest 校验器。
Returns:
ManifestValidator: 当前使用的 Manifest 校验器实例。
"""
return self._manifest_validator
# ──── 依赖解析 ────────────────────────────────────────────
def resolve_dependencies(
self,
candidates: Dict[str, PluginCandidate],
extra_available: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
"""解析候选插件的依赖顺序。
Args:
candidates: 待加载的候选插件集合。
extra_available: 视为已满足的外部依赖插件版本映射。
Returns:
Tuple[List[str], Dict[str, str]]: 可加载顺序和失败原因映射。
"""
return self._resolve_dependencies(candidates, extra_available=extra_available)
def load_candidate(self, plugin_id: str, candidate: PluginCandidate) -> Optional[PluginMeta]:
"""加载单个候选插件模块。
Args:
plugin_id: 插件 ID。
candidate: 候选插件三元组。
Returns:
Optional[PluginMeta]: 加载成功的插件元数据;失败时返回 ``None``。
"""
plugin_dir, manifest, plugin_path = candidate
return self._load_single_plugin(plugin_id, plugin_dir, manifest, plugin_path)
def _resolve_dependencies(
self,
candidates: Dict[str, PluginCandidate],
extra_available: Optional[Dict[str, str]] = None,
) -> Tuple[List[str], Dict[str, str]]:
"""拓扑排序解析加载顺序,返回 (有序列表, 失败项 {id: reason})。"""
available = set(candidates.keys())
satisfied_dependencies = {
str(plugin_id or "").strip(): str(plugin_version or "").strip()
for plugin_id, plugin_version in (extra_available or {}).items()
if str(plugin_id or "").strip() and str(plugin_version or "").strip()
}
dep_graph: Dict[str, Set[str]] = {}
failed: Dict[str, str] = {}
for pid, (_, manifest, _) in candidates.items():
raw_deps = manifest.get("dependencies", [])
resolved: Set[str] = set()
missing: List[str] = []
for dep in raw_deps:
dep_name = dep if isinstance(dep, str) else str(dep.get("name", ""))
dep_name = dep_name.strip()
if not dep_name or dep_name == pid:
missing_or_incompatible: List[str] = []
for dependency in manifest.plugin_dependencies:
dependency_id = dependency.id
if dependency_id in available:
dependency_manifest = candidates[dependency_id][1]
if not self._manifest_validator.is_plugin_dependency_satisfied(
dependency,
dependency_manifest.version,
):
missing_or_incompatible.append(
f"{dependency_id} (需要 {dependency.version_spec},当前 {dependency_manifest.version})"
)
continue
resolved.add(dependency_id)
continue
if dep_name in available:
resolved.add(dep_name)
else:
missing.append(dep_name)
if missing:
failed[pid] = f"缺少依赖: {', '.join(missing)}"
external_dependency_version = satisfied_dependencies.get(dependency_id)
if external_dependency_version is None:
missing_or_incompatible.append(f"{dependency_id} (未找到依赖插件)")
continue
if not self._manifest_validator.is_plugin_dependency_satisfied(
dependency,
external_dependency_version,
):
missing_or_incompatible.append(
f"{dependency_id} (需要 {dependency.version_spec},当前 {external_dependency_version})"
)
if missing_or_incompatible:
failed[pid] = f"依赖未满足: {', '.join(missing_or_incompatible)}"
dep_graph[pid] = resolved
# 移除失败项
for pid in failed:
dep_graph.pop(pid, None)
# 迭代传播“依赖自身加载失败”到上游依赖方,避免误报为循环依赖
changed = True
while changed:
changed = False
failed_plugin_ids = set(failed)
for pid, dependencies in list(dep_graph.items()):
if pid in failed:
dep_graph.pop(pid, None)
continue
failed_dependencies = sorted(dependency for dependency in dependencies if dependency in failed_plugin_ids)
if not failed_dependencies:
continue
failed[pid] = f"依赖未满足: {', '.join(f'{dependency} (依赖插件加载失败)' for dependency in failed_dependencies)}"
dep_graph.pop(pid, None)
changed = True
# Kahn 拓扑排序
indegree = {pid: len(deps) for pid, deps in dep_graph.items()}
@@ -253,7 +413,7 @@ class PluginLoader:
self,
plugin_id: str,
plugin_dir: Path,
manifest: Dict[str, Any],
manifest: PluginManifest,
plugin_path: Path,
) -> Optional[PluginMeta]:
"""加载单个插件"""
@@ -261,8 +421,12 @@ class PluginLoader:
self._ensure_compat_hook()
# 动态导入插件模块
module_name = f"_maibot_plugin_{plugin_id}"
spec = importlib.util.spec_from_file_location(module_name, str(plugin_path))
module_name = self._build_safe_module_name(plugin_id)
spec = importlib.util.spec_from_file_location(
module_name,
str(plugin_path),
submodule_search_locations=[str(plugin_dir)],
)
if spec is None or spec.loader is None:
logger.error(f"无法创建模块 spec: {plugin_path}")
return None
@@ -271,37 +435,73 @@ class PluginLoader:
sys.modules[module_name] = module
plugin_parent_dir = plugin_dir.parent
with self._temporary_sys_path_entry(plugin_parent_dir):
spec.loader.exec_module(module)
try:
with self._temporary_sys_path_entry(plugin_parent_dir):
spec.loader.exec_module(module)
# 优先使用新版 create_plugin 工厂函数
create_plugin = getattr(module, "create_plugin", None)
if create_plugin is not None:
instance = create_plugin()
logger.info(f"插件 {plugin_id} v{manifest.get('version', '?')} 加载成功")
return PluginMeta(
plugin_id=plugin_id,
plugin_dir=str(plugin_dir),
plugin_instance=instance,
manifest=manifest,
)
# 优先使用新版 create_plugin 工厂函数
create_plugin = getattr(module, "create_plugin", None)
if create_plugin is not None:
instance = create_plugin()
self._validate_sdk_plugin_contract(plugin_id, instance)
logger.info(f"插件 {plugin_id} v{manifest.version} 加载成功")
return PluginMeta(
plugin_id=plugin_id,
plugin_dir=str(plugin_dir),
module_name=module_name,
plugin_instance=instance,
manifest=manifest,
)
# 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
instance = self._try_load_legacy_plugin(module, plugin_id)
if instance is not None:
logger.info(
f"插件 {plugin_id} v{manifest.get('version', '?')} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk"
)
return PluginMeta(
plugin_id=plugin_id,
plugin_dir=str(plugin_dir),
plugin_instance=instance,
manifest=manifest,
)
# 回退:检测旧版 @register_plugin 标记的 BasePlugin 子类
instance = self._try_load_legacy_plugin(module, plugin_id)
if instance is not None:
logger.info(
f"插件 {plugin_id} v{manifest.version} 通过旧版兼容层加载成功(请尽快迁移到 maibot_sdk"
)
return PluginMeta(
plugin_id=plugin_id,
plugin_dir=str(plugin_dir),
module_name=module_name,
plugin_instance=instance,
manifest=manifest,
)
except Exception:
sys.modules.pop(module_name, None)
raise
logger.error(f"插件 {plugin_id} 缺少 create_plugin 工厂函数且未检测到旧版 BasePlugin")
return None
@staticmethod
def _validate_sdk_plugin_contract(plugin_id: str, instance: Any) -> None:
"""校验 SDK 插件的基础契约。
Args:
plugin_id: 当前插件 ID。
instance: ``create_plugin()`` 返回的插件实例。
Raises:
TypeError: 当插件未覆盖必需生命周期方法或订阅声明不合法时抛出。
"""
try:
from maibot_sdk.plugin import MaiBotPlugin
except ImportError:
return
if not isinstance(instance, MaiBotPlugin):
return
if type(instance).on_load is MaiBotPlugin.on_load:
raise TypeError(f"插件 {plugin_id} 必须实现 on_load()")
if type(instance).on_unload is MaiBotPlugin.on_unload:
raise TypeError(f"插件 {plugin_id} 必须实现 on_unload()")
if type(instance).on_config_update is MaiBotPlugin.on_config_update:
raise TypeError(f"插件 {plugin_id} 必须实现 on_config_update()")
instance.get_config_reload_subscriptions()
@staticmethod
@contextlib.contextmanager
def _temporary_sys_path_entry(path: Path) -> Iterator[None]:

View File

@@ -1,14 +1,6 @@
"""Runner 端 RPC Client
"""Runner 端 RPC 客户端。"""
负责:
1. 连接 Host RPC Server
2. 发送握手runner.hello
3. 发送组件注册请求
4. 接收并分发 Host 的调用请求
5. 发送能力调用请求到 Host
"""
from typing import Any, Awaitable, Callable, Dict, Optional, cast
from typing import Any, Awaitable, Callable, Dict, Optional, Set, cast
import asyncio
import contextlib
@@ -29,12 +21,15 @@ from src.plugin_runtime.transport.factory import create_transport_client
logger = get_logger("plugin_runtime.runner.rpc_client")
# RPC 方法处理器类型
MethodHandler = Callable[[Envelope], Awaitable[Envelope]]
def _get_sdk_version() -> str:
"""从 maibot_sdk 包元数据中读取实际版本号,失败时回退到 1.0.0。"""
"""读取 SDK 版本号。
Returns:
str: 已安装的 SDK 版本;读取失败时回退到 ``1.0.0``。
"""
try:
from importlib.metadata import version
@@ -47,73 +42,78 @@ SDK_VERSION = _get_sdk_version()
class RPCClient:
"""Runner 端 RPC 客户端
管理与 Host 的 IPC 连接,支持双向 RPC 调用。
"""
"""Runner 端 RPC 客户端"""
def __init__(
self,
host_address: str,
session_token: str,
codec: Optional[Codec] = None,
):
self._host_address = host_address
self._session_token = session_token
self._codec = codec or MsgPackCodec()
) -> None:
"""初始化 RPC 客户端。
Args:
host_address: Host 的 IPC 地址。
session_token: 握手用会话令牌。
codec: 可选的编解码器实现。
"""
self._host_address: str = host_address
self._session_token: str = session_token
self._codec: Codec = codec or MsgPackCodec()
self._id_gen = RequestIdGenerator()
self._connection: Optional[Connection] = None
self._runner_id = str(uuid.uuid4())
self._generation: int = 0
# 方法处理器注册表Host 发来的调用)
self._runner_id: str = str(uuid.uuid4())
self._method_handlers: Dict[str, MethodHandler] = {}
# 等待响应的 pending 请求: request_id -> Future
self._pending_requests: Dict[int, asyncio.Future] = {}
# 运行状态
self._running = False
self._recv_task: Optional[asyncio.Task] = None
self._background_tasks: set[asyncio.Task] = set()
@property
def generation(self) -> int:
return self._generation
self._pending_requests: Dict[int, asyncio.Future[Envelope]] = {}
self._running: bool = False
self._recv_task: Optional[asyncio.Task[None]] = None
self._background_tasks: Set[asyncio.Task[Any]] = set()
@property
def is_connected(self) -> bool:
"""返回当前连接是否可用。"""
return self._connection is not None and not self._connection.is_closed
def register_method(self, method: str, handler: MethodHandler) -> None:
"""注册方法处理器(处理 Host 发来的请求)"""
"""注册 Host -> Runner 的 RPC 处理器。
Args:
method: RPC 方法名。
handler: 方法处理函数。
"""
self._method_handlers[method] = handler
def _require_connection(self) -> Connection:
"""返回当前可用连接;若连接不可用则抛出 RPCError。"""
"""返回当前可用连接
Returns:
Connection: 当前连接对象。
Raises:
RPCError: 当前未连接到 Host。
"""
connection = self._connection
if connection is None or connection.is_closed:
raise RPCError(ErrorCode.E_UNKNOWN, "未连接到 Host")
return cast(Connection, connection)
async def connect_and_handshake(self) -> bool:
"""连接 Host 并完成握手
"""连接 Host 并完成握手
Returns:
是否握手成功
bool: 是否握手成功
"""
client = create_transport_client(self._host_address)
self._connection = await client.connect()
connection = self._require_connection()
# 发送 runner.hello
hello = HelloPayload(
runner_id=self._runner_id,
sdk_version=SDK_VERSION,
session_token=self._session_token,
)
request_id = self._id_gen.next()
request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
@@ -121,33 +121,27 @@ class RPCClient:
payload=hello.model_dump(),
)
data = self._codec.encode_envelope(envelope)
await connection.send_frame(data)
await connection.send_frame(self._codec.encode_envelope(envelope))
# 接收握手响应
resp_data = await asyncio.wait_for(connection.recv_frame(), timeout=10.0)
resp = self._codec.decode_envelope(resp_data)
response = self._codec.decode_envelope(resp_data)
resp_payload = HelloResponsePayload.model_validate(response.payload)
resp_payload = HelloResponsePayload.model_validate(resp.payload)
if not resp_payload.accepted:
logger.error(f"握手被拒绝: {resp_payload.reason}")
await self._connection.close()
self._connection = None
await self.disconnect()
return False
self._generation = resp_payload.assigned_generation
logger.info(f"握手成功: generation={self._generation}, host_version={resp_payload.host_version}")
# 启动消息接收循环
logger.info(f"握手成功: host_version={resp_payload.host_version}")
self._running = True
self._recv_task = asyncio.create_task(self._recv_loop())
self._recv_task = asyncio.create_task(self._recv_loop(), name="RPCClient.recv")
return True
async def disconnect(self) -> None:
"""断开连接"""
"""断开与 Host 的连接并清理状态。"""
self._running = False
if self._recv_task:
if self._recv_task is not None:
self._recv_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._recv_task
@@ -160,13 +154,12 @@ class RPCClient:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
# 取消所有 pending 请求
for future in self._pending_requests.values():
if not future.done():
future.set_exception(RPCError(ErrorCode.E_TIMEOUT, "连接关闭"))
self._pending_requests.clear()
if self._connection:
if self._connection is not None:
await self._connection.close()
self._connection = None
@@ -177,16 +170,27 @@ class RPCClient:
payload: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Envelope:
"""向 Host 发送 RPC 请求并等待响应"""
connection = self._require_connection()
"""向 Host 发送 RPC 请求并等待响应
request_id = self._id_gen.next()
Args:
method: RPC 方法名。
plugin_id: 目标插件 ID。
payload: 请求载荷。
timeout_ms: 超时时间,单位毫秒。
Returns:
Envelope: Host 返回的响应信封。
Raises:
RPCError: 发送失败、超时或连接异常。
"""
connection = self._require_connection()
request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.REQUEST,
method=method,
plugin_id=plugin_id,
generation=self._generation,
timeout_ms=timeout_ms,
payload=payload or {},
)
@@ -196,21 +200,16 @@ class RPCClient:
self._pending_requests[request_id] = future
try:
data = self._codec.encode_envelope(envelope)
await connection.send_frame(data)
timeout_sec = timeout_ms / 1000.0
return await asyncio.wait_for(future, timeout=timeout_sec)
await connection.send_frame(self._codec.encode_envelope(envelope))
return await asyncio.wait_for(future, timeout=timeout_ms / 1000.0)
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise RPCError(ErrorCode.E_TIMEOUT, f"请求 {method} 超时 ({timeout_ms}ms)") from None
except Exception as e:
except Exception as exc:
self._pending_requests.pop(request_id, None)
if isinstance(e, RPCError):
if isinstance(exc, RPCError):
raise
raise RPCError(ErrorCode.E_UNKNOWN, str(e)) from e
# ─── 内部方法 ──────────────────────────────────────────────
raise RPCError(ErrorCode.E_UNKNOWN, str(exc)) from exc
async def send_event(
self,
@@ -218,33 +217,30 @@ class RPCClient:
plugin_id: str = "",
payload: Optional[Dict[str, Any]] = None,
) -> None:
"""向 Host 发送单向事件fire-and-forget不等待响应
"""向 Host 发送单向广播消息
Args:
method: RPC 方法名,如 "runner.log_batch"
plugin_id: 目标插件 ID(可为空,表示 Runner 级消息)
payload: 事件数据
method: RPC 方法名。
plugin_id: 目标插件 ID。
payload: 广播载荷
"""
if not self.is_connected:
return
connection = self._require_connection()
request_id = self._id_gen.next()
request_id = await self._id_gen.next()
envelope = Envelope(
request_id=request_id,
message_type=MessageType.EVENT,
message_type=MessageType.BROADCAST,
method=method,
plugin_id=plugin_id,
generation=self._generation,
payload=payload or {},
)
data = self._codec.encode_envelope(envelope)
await connection.send_frame(data)
await connection.send_frame(self._codec.encode_envelope(envelope))
async def _recv_loop(self) -> None:
"""消息接收主循环"""
while self._running and self._connection and not self._connection.is_closed:
"""持续接收 Host 发来的消息并分发。"""
while self._running and self._connection is not None and not self._connection.is_closed:
try:
data = await self._connection.recv_frame()
except (asyncio.IncompleteReadError, ConnectionError):
@@ -252,39 +248,47 @@ class RPCClient:
break
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"接收帧失败: {e}")
except Exception as exc:
logger.error(f"接收帧失败: {exc}")
break
try:
envelope = self._codec.decode_envelope(data)
except Exception as e:
logger.error(f"解码消息失败: {e}")
except Exception as exc:
logger.error(f"解码消息失败: {exc}")
continue
if envelope.is_response():
self._handle_response(envelope)
elif envelope.is_request():
self._track_background_task(asyncio.create_task(self._handle_request(envelope)))
elif envelope.is_event():
self._track_background_task(asyncio.create_task(self._handle_event(envelope)))
elif envelope.is_broadcast():
self._track_background_task(asyncio.create_task(self._handle_broadcast(envelope)))
def _handle_response(self, envelope: Envelope) -> None:
"""处理来自 Host 的响应"""
"""处理 Host 返回的响应
Args:
envelope: 响应信封。
"""
future = self._pending_requests.pop(envelope.request_id, None)
if future and not future.done():
if envelope.error:
future.set_exception(RPCError.from_dict(envelope.error))
else:
future.set_result(envelope)
if future is None or future.done():
return
if envelope.error:
future.set_exception(RPCError.from_dict(envelope.error))
else:
future.set_result(envelope)
async def _handle_request(self, envelope: Envelope) -> None:
"""处理来自 Host 的请求(调用插件组件)"""
"""处理 Host 发来的请求
Args:
envelope: 请求信封。
"""
connection = self._connection
if connection is None or connection.is_closed:
logger.warning(f"处理请求 {envelope.method} 时连接已关闭,跳过响应")
return
connection = cast(Connection, connection)
handler = self._method_handlers.get(envelope.method)
if handler is None:
@@ -298,23 +302,34 @@ class RPCClient:
try:
response = await handler(envelope)
await connection.send_frame(self._codec.encode_envelope(response))
except RPCError as e:
error_resp = envelope.make_error_response(e.code.value, e.message, e.details)
except RPCError as exc:
error_resp = envelope.make_error_response(exc.code.value, exc.message, exc.details)
await connection.send_frame(self._codec.encode_envelope(error_resp))
except Exception as e:
logger.error(f"处理请求 {envelope.method} 异常: {e}", exc_info=True)
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(e))
except Exception as exc:
logger.error(f"处理请求 {envelope.method} 异常: {exc}", exc_info=True)
error_resp = envelope.make_error_response(ErrorCode.E_UNKNOWN.value, str(exc))
await connection.send_frame(self._codec.encode_envelope(error_resp))
async def _handle_event(self, envelope: Envelope) -> None:
"""处理来自 Host 事件"""
if handler := self._method_handlers.get(envelope.method):
try:
await handler(envelope)
except Exception as e:
logger.error(f"处理事件 {envelope.method} 异常: {e}", exc_info=True)
async def _handle_broadcast(self, envelope: Envelope) -> None:
"""处理 Host 发来的广播事件
def _track_background_task(self, task: asyncio.Task) -> None:
"""保持后台任务强引用,直到其完成或被取消。"""
Args:
envelope: 广播信封。
"""
handler = self._method_handlers.get(envelope.method)
if handler is None:
return
try:
await handler(envelope)
except Exception as exc:
logger.error(f"处理广播 {envelope.method} 异常: {exc}", exc_info=True)
def _track_background_task(self, task: asyncio.Task[Any]) -> None:
"""持有后台任务强引用直到其结束。
Args:
task: 后台任务。
"""
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,9 @@
"""Windows Named Pipe 传输实现。
适用于 Windows 平台,使用 asyncio ProactorEventLoop 的 named pipe 支持。
注意Named Pipe 是 Windows 特有的 IPC 机制,
在 Linux/macOS 平台上不可用。Unix-like 平台请使用 UDS 传输。
"""
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, cast
@@ -18,10 +21,12 @@ _DEFAULT_PIPE_PREFIX = "maibot-plugin"
class _NamedPipeServerHandle(Protocol):
"""Named Pipe 服务端句柄的协议定义。"""
def close(self) -> None: ...
class _NamedPipeEventLoop(Protocol):
"""ProactorEventLoop 的协议定义,提供 named pipe 相关方法。"""
async def start_serving_pipe(
self,
protocol_factory: Callable[[], asyncio.BaseProtocol],
@@ -40,6 +45,15 @@ class _NamedPipeEventLoop(Protocol):
def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
"""规范化 Named Pipe 地址。
Args:
pipe_name: 管道名称。如果以 '\\\\.\\pipe\\' 开头则直接使用,
否则会自动添加前缀。如果为 None 则生成随机名称。
Returns:
规范化的管道地址(格式:\\\\.\\pipe\\name
"""
if pipe_name and pipe_name.startswith(_PIPE_PREFIX):
return pipe_name
@@ -55,12 +69,21 @@ def _normalize_pipe_address(pipe_name: Optional[str] = None) -> str:
class NamedPipeConnection(Connection):
"""基于 Windows Named Pipe 的连接。"""
"""基于 Windows Named Pipe 的连接。
封装了底层 StreamReader/StreamWriter提供分帧读写能力。
"""
pass
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
super().__init__(reader, writer)
class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
"""Named Pipe 服务端协议实现。
处理客户端连接的生命周期,包括连接建立、数据处理和连接关闭。
"""
def __init__(self, handler: ConnectionHandler, loop: asyncio.AbstractEventLoop) -> None:
self._reader: asyncio.StreamReader = asyncio.StreamReader()
super().__init__(self._reader)
@@ -69,39 +92,58 @@ class _NamedPipeServerProtocol(asyncio.StreamReaderProtocol):
self._handler_task: Optional[asyncio.Task[None]] = None
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""连接建立时的回调。"""
super().connection_made(transport)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), self, self._reader, self._loop)
connection = NamedPipeConnection(self._reader, writer)
self._handler_task = self._loop.create_task(self._run_handler(connection))
# 使用 asyncio.create_task 确保任务正确调度
self._handler_task = asyncio.create_task(self._run_handler(connection))
self._handler_task.add_done_callback(self._on_handler_done)
async def _run_handler(self, connection: NamedPipeConnection) -> None:
"""运行连接处理器。"""
try:
await self._handler(connection)
finally:
await connection.close()
def _on_handler_done(self, task: asyncio.Task[None]) -> None:
"""连接处理器完成时的回调。"""
if task.cancelled():
return
if exc := task.exception():
self._loop.call_exception_handler(
{
"message": "Named pipe 连接处理失败",
"exception": exc,
"protocol": self,
}
)
try:
self._loop.call_exception_handler(
{
"message": "Named pipe 连接处理失败",
"exception": exc,
"protocol": self,
}
)
except Exception:
# 如果 loop 已经关闭,忽略异常
pass
class NamedPipeTransportServer(TransportServer):
"""Windows Named Pipe 传输服务端。"""
"""Windows Named Pipe 传输服务端。
使用 ProactorEventLoop 的 start_serving_pipe 方法监听客户端连接。
"""
def __init__(self, pipe_name: Optional[str] = None) -> None:
self._address: str = _normalize_pipe_address(pipe_name)
self._servers: List[_NamedPipeServerHandle] = []
async def start(self, handler: ConnectionHandler) -> None:
"""启动 Named Pipe 服务端。
Args:
handler: 新连接到来时的回调函数
Raises:
RuntimeError: 当在非 Windows 平台或事件循环不支持时
"""
if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows")
@@ -116,32 +158,49 @@ class NamedPipeTransportServer(TransportServer):
)
async def stop(self) -> None:
"""停止 Named Pipe 服务端并清理资源。"""
for server in self._servers:
server.close()
# 等待所有服务器句柄完全关闭
await asyncio.gather(
*[asyncio.sleep(0.1) for _ in self._servers],
return_exceptions=True
)
self._servers.clear()
await asyncio.sleep(0)
def get_address(self) -> str:
return self._address
class NamedPipeTransportClient(TransportClient):
"""Windows Named Pipe 传输客户端。"""
"""Windows Named Pipe 传输客户端。
用于主动连接到 Named Pipe 服务端。
"""
def __init__(self, address: str) -> None:
self._address: str = _normalize_pipe_address(address)
async def connect(self) -> Connection:
"""建立到 Named Pipe 服务端的连接。
Returns:
NamedPipeConnection: 连接对象
Raises:
NotImplementedError: 当在非 Windows 平台或事件循环不支持时
"""
if sys.platform != "win32":
raise RuntimeError("Named pipe 仅支持 Windows")
raise NotImplementedError("Named pipe 仅支持 Windows")
loop = asyncio.get_running_loop()
if not hasattr(loop, "create_pipe_connection"):
raise RuntimeError("当前事件循环不支持 Windows named pipe")
raise NotImplementedError("当前事件循环不支持 Windows named pipe")
pipe_loop = cast(_NamedPipeEventLoop, loop)
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
transport, _protocol = await pipe_loop.create_pipe_connection(lambda: protocol, self._address)
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), protocol, reader, loop)
# 使用返回的 protocol 创建 StreamWriter
writer = asyncio.StreamWriter(cast(asyncio.WriteTransport, transport), _protocol, reader, loop)
return NamedPipeConnection(reader, writer)

View File

@@ -1,6 +1,9 @@
"""Unix Domain Socket 传输实现
适用于 Linux / macOS 平台。
注意UDS (Unix Domain Socket) 是 Unix-like 系统特有的 IPC 机制,
在 Windows 平台上不可用。Windows 平台请使用 Named Pipe 传输。
"""
from pathlib import Path
@@ -8,20 +11,30 @@ from typing import Optional
import asyncio
import os
import sys
import tempfile
from .base import Connection, ConnectionHandler, TransportClient, TransportServer
class UDSConnection(Connection):
"""基于 UDS 的连接"""
"""基于 UDS 的连接
封装了底层 StreamReader/StreamWriter提供分帧读写能力。
"""
pass # 直接复用 Connection 基类的分帧读写
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
super().__init__(reader, writer)
# Unix domain socket 路径的系统限制sun_path 字段长度)
# Linux: 108 字节, macOS: 104 字节
_UDS_PATH_MAX = 104
# Linux: 108 字节macOS: 104 字节,其他 Unix: 通常 104 字节
if sys.platform == "linux":
_UDS_PATH_MAX = 108
elif sys.platform == "darwin": # macOS
_UDS_PATH_MAX = 104
else:
_UDS_PATH_MAX = 104 # 保守默认值
class UDSTransportServer(TransportServer):
@@ -44,6 +57,18 @@ class UDSTransportServer(TransportServer):
self._server: Optional[asyncio.AbstractServer] = None
async def start(self, handler: ConnectionHandler) -> None:
"""启动 UDS 服务端
Args:
handler: 新连接到来时的回调函数
Raises:
RuntimeError: 当在非 Unix 平台(如 Windows上调用时
"""
# 平台检查UDS 仅在 Unix-like 系统上可用
if sys.platform == "win32":
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
# 清理残留 socket 文件
if self._socket_path.exists():
self._socket_path.unlink()
@@ -58,10 +83,16 @@ class UDSTransportServer(TransportServer):
finally:
await conn.close()
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
try:
self._server = await asyncio.start_unix_server(_on_connect, path=str(self._socket_path))
# 设置文件权限为仅当前用户可访问
self._socket_path.chmod(0o600)
# 设置文件权限为仅当前用户可访问
self._socket_path.chmod(0o600)
except Exception:
# 启动失败时清理可能创建的目录和 socket 文件
if self._socket_path.exists():
self._socket_path.unlink()
raise
async def stop(self) -> None:
if self._server:
@@ -77,11 +108,26 @@ class UDSTransportServer(TransportServer):
class UDSTransportClient(TransportClient):
"""UDS 传输客户端"""
"""UDS 传输客户端
用于主动连接到 UDS 服务端。
"""
def __init__(self, socket_path: Path) -> None:
self._socket_path: Path = socket_path
async def connect(self) -> Connection:
"""建立到 UDS 服务端的连接
Returns:
UDSConnection: 连接对象
Raises:
RuntimeError: 当在非 Unix 平台(如 Windows上调用时
"""
# 平台检查UDS 仅在 Unix-like 系统上可用
if sys.platform == "win32":
raise RuntimeError("UDS 不支持 Windows 平台,请使用 Named Pipe")
reader, writer = await asyncio.open_unix_connection(str(self._socket_path))
return UDSConnection(reader, writer)