feat: Enhance API and Outbound Tracking Functionality

- Add test for fallback to bot account in platform IO route metadata when context message is absent.
- Improve PlatformIOManager to avoid duplicate driver entries and streamline fallback driver handling.
- Refactor OutboundTracker to support tracking by both internal message ID and driver ID, enhancing the uniqueness of pending records.
- Introduce dynamic API capabilities in RuntimeComponent, allowing plugins to replace their dynamic API lists.
- Update APIRegistry to manage dynamic APIs more effectively, including registration and toggling of API statuses.
- Implement authorization checks for dynamic API capabilities to ensure proper permissions.
- Restrict direct calls to certain host RPC methods from plugins for enhanced security.
- Refactor send_service to ensure fallback to current platform account when no context message is available.
This commit is contained in:
DrSmoothl
2026-03-23 21:01:55 +08:00
parent d13767ee21
commit 7a304ba549
11 changed files with 771 additions and 200 deletions

View File

@@ -72,6 +72,8 @@ class RuntimeComponentCapabilityMixin:
"version": entry.version,
"public": entry.public,
"enabled": entry.enabled,
"dynamic": entry.dynamic,
"offline_reason": entry.offline_reason,
"metadata": dict(entry.metadata),
}
@@ -109,6 +111,32 @@ class RuntimeComponentCapabilityMixin:
return entry.plugin_id == caller_plugin_id or entry.public
@staticmethod
def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]:
"""规范化 API 名称与版本参数。
支持在 ``api_name`` 中直接携带 ``@version`` 后缀。
"""
normalized_api_name = str(api_name or "").strip()
normalized_version = str(version or "").strip()
if normalized_api_name and not normalized_version and "@" in normalized_api_name:
candidate_name, candidate_version = normalized_api_name.rsplit("@", 1)
candidate_name = candidate_name.strip()
candidate_version = candidate_version.strip()
if candidate_name and candidate_version:
normalized_api_name = candidate_name
normalized_version = candidate_version
return normalized_api_name, normalized_version
@staticmethod
def _build_api_unavailable_error(entry: "APIEntry") -> str:
"""构造 API 当前不可用时的错误信息。"""
if entry.offline_reason:
return entry.offline_reason
return f"API {entry.registry_key} 当前不可用"
def _resolve_api_target(
self: _RuntimeComponentManagerProtocol,
caller_plugin_id: str,
@@ -127,8 +155,7 @@ class RuntimeComponentCapabilityMixin:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
"""
normalized_api_name = str(api_name or "").strip()
normalized_version = str(version or "").strip()
normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version)
if not normalized_api_name:
return None, None, "缺少必要参数 api_name"
@@ -142,34 +169,61 @@ class RuntimeComponentCapabilityMixin:
if supervisor is None:
return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
entry = supervisor.api_registry.get_api(
entries = supervisor.api_registry.get_apis(
plugin_id=target_plugin_id,
name=target_api_name,
enabled_only=True,
version=normalized_version,
enabled_only=False,
)
if entry is None:
return None, None, f"未找到 API: {normalized_api_name}"
if normalized_version and entry.version != normalized_version:
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
if not self._is_api_visible_to_plugin(entry, caller_plugin_id):
visible_enabled_entries = [
entry
for entry in entries
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
]
visible_disabled_entries = [
entry
for entry in entries
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled
]
if len(visible_enabled_entries) == 1:
return supervisor, visible_enabled_entries[0], None
if len(visible_enabled_entries) > 1:
return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version"
if visible_disabled_entries:
if len(visible_disabled_entries) == 1:
return None, None, self._build_api_unavailable_error(visible_disabled_entries[0])
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries):
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
return supervisor, entry, None
if normalized_version:
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
return None, None, f"未找到 API: {normalized_api_name}"
visible_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
hidden_match_exists = False
for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(name=normalized_api_name, enabled_only=True):
if normalized_version and entry.version != normalized_version:
continue
for entry in supervisor.api_registry.get_apis(
name=normalized_api_name,
version=normalized_version,
enabled_only=False,
):
if self._is_api_visible_to_plugin(entry, caller_plugin_id):
visible_matches.append((supervisor, entry))
if entry.enabled:
visible_enabled_matches.append((supervisor, entry))
else:
visible_disabled_matches.append((supervisor, entry))
else:
hidden_match_exists = True
if len(visible_matches) == 1:
return visible_matches[0][0], visible_matches[0][1], None
if len(visible_matches) > 1:
return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name"
if len(visible_enabled_matches) == 1:
return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None
if len(visible_enabled_matches) > 1:
return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version"
if visible_disabled_matches:
if len(visible_disabled_matches) == 1:
return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1])
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version"
if hidden_match_exists:
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
if normalized_version:
@@ -179,18 +233,20 @@ class RuntimeComponentCapabilityMixin:
def _resolve_api_toggle_target(
self: _RuntimeComponentManagerProtocol,
name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
"""解析需要启用或禁用的 API 组件。
Args:
name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
version: 可选的 API 版本。
Returns:
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
"""
normalized_name = str(name or "").strip()
normalized_name, normalized_version = self._normalize_api_reference(name, version)
if not normalized_name:
return None, None, "缺少必要参数 name"
@@ -204,24 +260,31 @@ class RuntimeComponentCapabilityMixin:
if supervisor is None:
return None, None, f"未找到 API 提供方插件: {plugin_id}"
entry = supervisor.api_registry.get_api(
entries = supervisor.api_registry.get_apis(
plugin_id=plugin_id,
name=api_name,
version=normalized_version,
enabled_only=False,
)
if entry is None:
if not entries:
return None, None, f"未找到 API: {normalized_name}"
return supervisor, entry, None
if len(entries) > 1:
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
return supervisor, entries[0], None
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(name=normalized_name, enabled_only=False):
for entry in supervisor.api_registry.get_apis(
name=normalized_name,
version=normalized_version,
enabled_only=False,
):
matches.append((supervisor, entry))
if len(matches) == 1:
return matches[0][0], matches[0][1], None
if len(matches) > 1:
return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name"
return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version"
return None, None, f"未找到 API: {normalized_name}"
async def _cap_component_get_all_plugins(
@@ -326,6 +389,7 @@ class RuntimeComponentCapabilityMixin:
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -334,10 +398,10 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type):
supervisor, api_entry, error = self._resolve_api_toggle_target(name)
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"}
supervisor.api_registry.toggle_api_status(api_entry.full_name, True)
supervisor.api_registry.toggle_api_status(api_entry.registry_key, True)
return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type)
@@ -352,6 +416,7 @@ class RuntimeComponentCapabilityMixin:
) -> Any:
name: str = args.get("name", "")
component_type: str = args.get("component_type", "")
version: str = args.get("version", "")
scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "")
if not name or not component_type:
@@ -360,10 +425,10 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type):
supervisor, api_entry, error = self._resolve_api_toggle_target(name)
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"}
supervisor.api_registry.toggle_api_status(api_entry.full_name, False)
supervisor.api_registry.toggle_api_status(api_entry.registry_key, False)
return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type)
@@ -488,11 +553,17 @@ class RuntimeComponentCapabilityMixin:
if supervisor is None or entry is None:
return {"success": False, "error": error or "API 解析失败"}
invoke_args = dict(api_args)
if entry.dynamic:
invoke_args.setdefault("__maibot_api_name__", entry.name)
invoke_args.setdefault("__maibot_api_full_name__", entry.full_name)
invoke_args.setdefault("__maibot_api_version__", entry.version)
try:
response = await supervisor.invoke_api(
plugin_id=entry.plugin_id,
component_name=entry.name,
args=api_args,
component_name=entry.handler_name,
args=invoke_args,
timeout_ms=30000,
)
except Exception as exc:
@@ -555,10 +626,16 @@ class RuntimeComponentCapabilityMixin:
del capability
target_plugin_id = str(args.get("plugin_id", "") or "").strip()
api_name, version = self._normalize_api_reference(
str(args.get("api_name", args.get("name", "")) or ""),
str(args.get("version", "") or ""),
)
apis: List[Dict[str, Any]] = []
for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(
plugin_id=target_plugin_id or None,
name=api_name,
version=version,
enabled_only=True,
):
if not self._is_api_visible_to_plugin(entry, plugin_id):
@@ -567,3 +644,75 @@ class RuntimeComponentCapabilityMixin:
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
return {"success": True, "apis": apis}
async def _cap_api_replace_dynamic(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""替换插件自行维护的动态 API 列表。"""
del capability
raw_apis = args.get("apis", [])
offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线"
if not isinstance(raw_apis, list):
return {"success": False, "error": "参数 apis 必须为列表"}
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
if supervisor is None:
return {"success": False, "error": f"未找到插件: {plugin_id}"}
normalized_components: List[Dict[str, Any]] = []
seen_registry_keys: set[str] = set()
for index, raw_api in enumerate(raw_apis):
if not isinstance(raw_api, dict):
return {"success": False, "error": f"apis[{index}] 必须为字典"}
api_name = str(raw_api.get("name", "") or "").strip()
component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip()
if not api_name:
return {"success": False, "error": f"apis[{index}] 缺少 name"}
if not self._is_api_component_type(component_type):
return {"success": False, "error": f"apis[{index}] 不是 API 组件"}
metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {}
normalized_metadata = dict(metadata)
normalized_metadata["dynamic"] = True
version = str(normalized_metadata.get("version", "1") or "1").strip() or "1"
registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version)
if registry_key in seen_registry_keys:
return {"success": False, "error": f"动态 API 重复声明: {registry_key}"}
seen_registry_keys.add(registry_key)
existing_entry = supervisor.api_registry.get_api(
plugin_id,
api_name,
version=version,
enabled_only=False,
)
if existing_entry is not None and not existing_entry.dynamic:
return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"}
normalized_components.append(
{
"name": api_name,
"component_type": "API",
"metadata": normalized_metadata,
}
)
registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis(
plugin_id,
normalized_components,
offline_reason=offline_reason,
)
return {
"success": True,
"count": registered_count,
"offlined": offlined_count,
}

View File

@@ -77,6 +77,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
_register("api.call", manager._cap_api_call)
_register("api.get", manager._cap_api_get)
_register("api.list", manager._cap_api_list)
_register("api.replace_dynamic", manager._cap_api_replace_dynamic)
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)

View File

@@ -1,45 +1,60 @@
"""Host 侧插件 API 动态注册表。"""
from typing import Any, Dict, List, Optional, Set
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.api_registry")
@dataclass(slots=True)
class APIEntry:
"""API 组件条目。"""
__slots__ = (
"description",
"disabled_session",
"enabled",
"full_name",
"metadata",
"name",
"plugin_id",
"public",
"version",
)
name: str
plugin_id: str
description: str = ""
version: str = "1"
public: bool = False
metadata: Dict[str, Any] = field(default_factory=dict)
enabled: bool = True
handler_name: str = ""
dynamic: bool = False
offline_reason: str = ""
disabled_session: Set[str] = field(default_factory=set)
full_name: str = field(init=False)
registry_key: str = field(init=False)
def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
"""初始化 API 组件条目。
def __post_init__(self) -> None:
"""规范化 API 条目字段。"""
Args:
name: API 名称。
plugin_id: 所属插件 ID。
metadata: API 元数据。
"""
self.name = str(self.name or "").strip()
self.plugin_id = str(self.plugin_id or "").strip()
self.description = str(self.description or "").strip()
self.version = str(self.version or "1").strip() or "1"
self.handler_name = str(self.handler_name or self.name).strip() or self.name
self.offline_reason = str(self.offline_reason or "").strip()
self.full_name = f"{self.plugin_id}.{self.name}"
self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version)
self.name: str = name
self.full_name: str = f"{plugin_id}.{name}"
self.plugin_id: str = plugin_id
self.description: str = str(metadata.get("description", "") or "")
self.version: str = str(metadata.get("version", "1") or "1").strip() or "1"
self.public: bool = bool(metadata.get("public", False))
self.metadata: Dict[str, Any] = dict(metadata)
self.enabled: bool = bool(metadata.get("enabled", True))
self.disabled_session: Set[str] = set()
@classmethod
def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry":
"""根据 Runner 上报的元数据构造 API 条目。"""
safe_metadata = dict(metadata)
return cls(
name=name,
plugin_id=plugin_id,
description=str(safe_metadata.get("description", "") or ""),
version=str(safe_metadata.get("version", "1") or "1"),
public=bool(safe_metadata.get("public", False)),
metadata=safe_metadata,
enabled=bool(safe_metadata.get("enabled", True)),
handler_name=str(safe_metadata.get("handler_name", name) or name),
dynamic=bool(safe_metadata.get("dynamic", False)),
offline_reason=str(safe_metadata.get("offline_reason", "") or ""),
)
class APIRegistry:
@@ -53,6 +68,7 @@ class APIRegistry:
"""初始化 API 注册表。"""
self._apis: Dict[str, APIEntry] = {}
self._by_full_name: Dict[str, List[APIEntry]] = {}
self._by_plugin: Dict[str, List[APIEntry]] = {}
self._by_name: Dict[str, List[APIEntry]] = {}
@@ -60,75 +76,75 @@ class APIRegistry:
"""清空全部 API 注册状态。"""
self._apis.clear()
self._by_full_name.clear()
self._by_plugin.clear()
self._by_name.clear()
@staticmethod
def _is_api_component(component_type: Any) -> bool:
"""判断组件声明是否属于 API。
Args:
component_type: 原始组件类型值。
Returns:
bool: 是否为 API 组件。
"""
"""判断组件声明是否属于 API。"""
return str(component_type or "").strip().upper() == "API"
@staticmethod
def _normalize_query_version(version: Any) -> str:
"""规范化查询使用的版本字符串。"""
return str(version or "").strip()
@classmethod
def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]:
"""解析可能带 ``@version`` 后缀的 API 引用。"""
normalized_reference = str(reference or "").strip()
normalized_version = cls._normalize_query_version(version)
if normalized_reference and not normalized_version and "@" in normalized_reference:
candidate_reference, candidate_version = normalized_reference.rsplit("@", 1)
candidate_reference = candidate_reference.strip()
candidate_version = candidate_version.strip()
if candidate_reference and candidate_version:
normalized_reference = candidate_reference
normalized_version = candidate_version
return normalized_reference, normalized_version
@staticmethod
def build_registry_key(plugin_id: str, name: str, version: str) -> str:
"""构造 API 注册表唯一键。"""
normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}"
normalized_version = str(version or "1").strip() or "1"
return f"{normalized_full_name}@{normalized_version}"
@staticmethod
def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
"""判断 API 条目当前是否处于启用状态。
Args:
entry: 待检查的 API 条目。
session_id: 可选的会话 ID。
Returns:
bool: 当前是否可用。
"""
"""判断 API 条目当前是否处于启用状态。"""
if session_id and session_id in entry.disabled_session:
return False
return entry.enabled
def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个 API 条目。
Args:
name: API 名称。
plugin_id: 所属插件 ID。
metadata: API 元数据。
Returns:
bool: 是否成功注册。
"""
"""注册单个 API 条目。"""
normalized_name = str(name or "").strip()
if not normalized_name:
logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
return False
entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
if entry.full_name in self._apis:
logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目")
self._remove_entry(self._apis[entry.full_name])
entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
existing_entry = self._apis.get(entry.registry_key)
if existing_entry is not None:
logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目")
self._remove_entry(existing_entry)
self._apis[entry.full_name] = entry
self._apis[entry.registry_key] = entry
self._by_full_name.setdefault(entry.full_name, []).append(entry)
self._by_plugin.setdefault(plugin_id, []).append(entry)
self._by_name.setdefault(entry.name, []).append(entry)
return True
def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册某个插件声明的全部 API。
Args:
plugin_id: 插件 ID。
components: 插件组件声明列表。
Returns:
int: 成功注册的 API 数量。
"""
"""批量注册某个插件声明的全部 API。"""
count = 0
for component in components:
@@ -142,14 +158,60 @@ class APIRegistry:
count += 1
return count
def replace_plugin_dynamic_apis(
self,
plugin_id: str,
components: List[Dict[str, Any]],
*,
offline_reason: str = "动态 API 已下线",
) -> Tuple[int, int]:
"""替换指定插件当前声明的动态 API 集合。"""
normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线"
desired_registry_keys: Set[str] = set()
registered_count = 0
for component in components:
if not self._is_api_component(component.get("component_type")):
continue
metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}
dynamic_metadata = dict(metadata)
dynamic_metadata["dynamic"] = True
dynamic_metadata.pop("offline_reason", None)
entry = APIEntry.from_metadata(
name=str(component.get("name", "") or ""),
plugin_id=plugin_id,
metadata=dynamic_metadata,
)
desired_registry_keys.add(entry.registry_key)
if self.register_api(entry.name, plugin_id, dynamic_metadata):
registered_count += 1
offlined_count = 0
for entry in list(self._by_plugin.get(plugin_id, [])):
if not entry.dynamic or entry.registry_key in desired_registry_keys:
continue
entry.enabled = False
entry.offline_reason = normalized_offline_reason
entry.metadata["offline_reason"] = normalized_offline_reason
offlined_count += 1
return registered_count, offlined_count
def _remove_entry(self, entry: APIEntry) -> None:
"""从全部索引中移除单个 API 条目。
"""从全部索引中移除单个 API 条目。"""
Args:
entry: 待移除的 API 条目。
"""
self._apis.pop(entry.registry_key, None)
full_name_entries = self._by_full_name.get(entry.full_name)
if full_name_entries is not None:
self._by_full_name[entry.full_name] = [
candidate for candidate in full_name_entries if candidate is not entry
]
if not self._by_full_name[entry.full_name]:
self._by_full_name.pop(entry.full_name, None)
self._apis.pop(entry.full_name, None)
plugin_entries = self._by_plugin.get(entry.plugin_id)
if plugin_entries is not None:
self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
@@ -163,14 +225,7 @@ class APIRegistry:
self._by_name.pop(entry.name, None)
def remove_apis_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的全部 API。
Args:
plugin_id: 目标插件 ID。
Returns:
int: 被移除的 API 数量。
"""
"""移除某个插件的全部 API。"""
entries = list(self._by_plugin.get(plugin_id, []))
for entry in entries:
@@ -181,49 +236,48 @@ class APIRegistry:
self,
full_name: str,
*,
version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
"""按完整名查询单个 API。
"""按完整名查询单个 API。"""
Args:
full_name: API 完整名,格式为 ``plugin_id.api_name``。
enabled_only: 是否仅返回启用状态的 API。
session_id: 可选的会话 ID。
Returns:
Optional[APIEntry]: 命中时返回 API 条目。
"""
entry = self._apis.get(full_name)
if entry is None:
normalized_full_name, normalized_version = self._split_reference(full_name, version)
if not normalized_full_name:
return None
if enabled_only and not self.check_api_enabled(entry, session_id):
if normalized_version:
entry = self._apis.get(f"{normalized_full_name}@{normalized_version}")
if entry is None:
return None
if enabled_only and not self.check_api_enabled(entry, session_id):
return None
return entry
candidates = list(self._by_full_name.get(normalized_full_name, []))
filtered_entries = [
entry
for entry in candidates
if not enabled_only or self.check_api_enabled(entry, session_id)
]
if len(filtered_entries) != 1:
return None
return entry
return filtered_entries[0]
def get_api(
self,
plugin_id: str,
name: str,
*,
version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
"""按插件 ID 和短名查询单个 API。
Args:
plugin_id: 提供方插件 ID。
name: API 短名。
enabled_only: 是否仅返回启用状态的 API。
session_id: 可选的会话 ID。
Returns:
Optional[APIEntry]: 命中时返回 API 条目。
"""
"""按插件 ID、短名与版本查询单个 API。"""
return self.get_api_by_full_name(
f"{plugin_id}.{name}",
version=version,
enabled_only=enabled_only,
session_id=session_id,
)
@@ -233,22 +287,15 @@ class APIRegistry:
*,
plugin_id: Optional[str] = None,
name: str = "",
version: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> List[APIEntry]:
"""查询 API 列表。
Args:
plugin_id: 可选的插件 ID 过滤条件。
name: 可选的 API 名称过滤条件。
enabled_only: 是否仅返回启用状态的 API。
session_id: 可选的会话 ID。
Returns:
List[APIEntry]: 符合条件的 API 条目列表。
"""
"""查询 API 列表。"""
normalized_name = str(name or "").strip()
normalized_version = self._normalize_query_version(version)
if plugin_id:
candidates = list(self._by_plugin.get(plugin_id, []))
elif normalized_name:
@@ -258,26 +305,35 @@ class APIRegistry:
filtered_entries: List[APIEntry] = []
for entry in candidates:
if plugin_id and entry.plugin_id != plugin_id:
continue
if normalized_name and entry.name != normalized_name:
continue
if normalized_version and entry.version != normalized_version:
continue
if enabled_only and not self.check_api_enabled(entry, session_id):
continue
filtered_entries.append(entry)
filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version))
return filtered_entries
def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
"""设置指定 API 的启用状态。
def toggle_api_status(
self,
full_name: str,
enabled: bool,
*,
version: str = "",
session_id: Optional[str] = None,
) -> bool:
"""设置指定 API 的启用状态。"""
Args:
full_name: API 完整名。
enabled: 目标启用状态。
session_id: 可选的会话 ID仅对该会话生效。
Returns:
bool: 是否设置成功。
"""
entry = self._apis.get(full_name)
entry = self.get_api_by_full_name(
full_name,
version=version,
enabled_only=False,
session_id=session_id,
)
if entry is None:
return False
if session_id:
@@ -287,4 +343,7 @@ class APIRegistry:
entry.disabled_session.add(session_id)
else:
entry.enabled = enabled
if enabled:
entry.offline_reason = ""
entry.metadata.pop("offline_reason", None)
return True

View File

@@ -7,6 +7,8 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"})
@dataclass
class CapabilityPermissionToken:
@@ -46,6 +48,9 @@ class AuthorizationManager:
Returns:
return (bool, str): (是否有此能力, 原因)
"""
if capability in _ALWAYS_ALLOWED_CAPABILITIES:
return True, ""
token = self._permission_tokens.get(plugin_id)
if not token:
return False, f"插件 {plugin_id} 未注册能力令牌"

View File

@@ -45,6 +45,14 @@ from src.plugin_runtime.runner.rpc_client import RPCClient
logger = get_logger("plugin_runtime.runner.main")
_PLUGIN_ALLOWED_RAW_HOST_METHODS = frozenset(
{
"cap.call",
"host.route_message",
"host.update_message_gateway_state",
}
)
class _ContextAwarePlugin(Protocol):
"""支持注入运行时上下文的插件协议。
@@ -247,8 +255,14 @@ class PluginRunner:
logger.warning(
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC已强制绑定回自身身份"
)
normalized_method = str(method or "").strip()
if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS:
raise PermissionError(
f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: "
f"{normalized_method or '<empty>'}"
)
resp = await rpc_client.send_request(
method=method,
method=normalized_method,
plugin_id=bound_plugin_id,
payload=payload or {},
)