From 9dea6b0e6fdeae1be119eaeb0a1449ac64d10900 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Mon, 23 Mar 2026 17:18:05 +0800 Subject: [PATCH] feat: implement dedicated API registry and enhance API handling capabilities - Added APIEntry and APIRegistry classes for managing plugin APIs. - Updated PluginRunnerSupervisor to include API registry and methods for invoking APIs. - Enhanced PluginRuntimeManager to support API registration and invocation. - Created tests for API registration, invocation, and visibility between plugins. - Refactored component handling to distinguish between runtime components and APIs. --- pytests/test_plugin_runtime.py | 9 +- pytests/test_plugin_runtime_api.py | 294 +++++++++++++++ src/plugin_runtime/capabilities/components.py | 344 +++++++++++++++++- src/plugin_runtime/capabilities/registry.py | 4 + src/plugin_runtime/host/api_registry.py | 290 +++++++++++++++ src/plugin_runtime/host/supervisor.py | 81 ++++- src/plugin_runtime/runner/runner_main.py | 1 + 7 files changed, 1012 insertions(+), 11 deletions(-) create mode 100644 pytests/test_plugin_runtime_api.py create mode 100644 src/plugin_runtime/host/api_registry.py diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index 20cceb82..9dfc34d8 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -2152,8 +2152,11 @@ class TestIntegration: self.supervisors = [FakeSupervisor("plugin_a"), FakeSupervisor("plugin_b")] monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) + manager = integration_module.PluginRuntimeManager() + manager._builtin_supervisor = FakeSupervisor("plugin_a") + manager._third_party_supervisor = FakeSupervisor("plugin_b") - result = await integration_module.PluginRuntimeManager._cap_component_enable( + result = await manager._cap_component_enable( "plugin_a", "component.enable", {"name": "shared", "component_type": "tool", "scope": "global", "stream_id": ""}, @@ -2182,8 +2185,10 @@ class TestIntegration: self.supervisors = [FakeSupervisor()] monkeypatch.setattr(integration_module, "get_plugin_runtime_manager", lambda: FakeManager()) + manager = integration_module.PluginRuntimeManager() + manager._builtin_supervisor = FakeSupervisor() - result = await integration_module.PluginRuntimeManager._cap_component_disable( + result = await manager._cap_component_disable( "plugin_a", "component.disable", {"name": "plugin_a.handler", "component_type": "tool", "scope": "stream", "stream_id": "s1"}, diff --git a/pytests/test_plugin_runtime_api.py b/pytests/test_plugin_runtime_api.py new file mode 100644 index 00000000..fca7736a --- /dev/null +++ b/pytests/test_plugin_runtime_api.py @@ -0,0 +1,294 @@ +"""插件 API 注册与调用测试。""" + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from src.plugin_runtime.integration import PluginRuntimeManager +from src.plugin_runtime.host.supervisor import PluginSupervisor +from src.plugin_runtime.protocol.envelope import ( + ComponentDeclaration, + Envelope, + MessageType, + RegisterPluginPayload, + UnregisterPluginPayload, +) + + +def _build_manager(*supervisors: PluginSupervisor) -> PluginRuntimeManager: + """构造一个最小可用的插件运行时管理器。 + + Args: + *supervisors: 需要挂载的监督器列表。 + + Returns: + PluginRuntimeManager: 已注入监督器的运行时管理器。 + """ + + manager = PluginRuntimeManager() + if supervisors: + manager._builtin_supervisor = supervisors[0] + if len(supervisors) > 1: + manager._third_party_supervisor = supervisors[1] + return manager + + +async def _register_plugin( + supervisor: PluginSupervisor, + plugin_id: str, + components: List[Dict[str, Any]], +) -> Envelope: + """通过 Supervisor 注册测试插件。 + + Args: + supervisor: 目标监督器。 + plugin_id: 测试插件 ID。 + components: 组件声明列表。 + + Returns: + Envelope: 注册响应信封。 + """ + + payload = RegisterPluginPayload( + plugin_id=plugin_id, + plugin_version="1.0.0", + components=[ + ComponentDeclaration( + name=str(component.get("name", "") or ""), + component_type=str(component.get("component_type", "") or ""), + plugin_id=plugin_id, + metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}, + ) + for component in components + ], + ) + return await supervisor._handle_register_plugin( + Envelope( + request_id=1, + message_type=MessageType.REQUEST, + method="plugin.register_components", + plugin_id=plugin_id, + payload=payload.model_dump(), + ) + ) + + +async def _unregister_plugin(supervisor: PluginSupervisor, plugin_id: str) -> Envelope: + """通过 Supervisor 注销测试插件。 + + Args: + supervisor: 目标监督器。 + plugin_id: 测试插件 ID。 + + Returns: + Envelope: 注销响应信封。 + """ + + payload = UnregisterPluginPayload(plugin_id=plugin_id, reason="test") + return await supervisor._handle_unregister_plugin( + Envelope( + request_id=2, + message_type=MessageType.REQUEST, + method="plugin.unregister", + plugin_id=plugin_id, + payload=payload.model_dump(), + ) + ) + + +@pytest.mark.asyncio +async def test_register_plugin_syncs_dedicated_api_registry() -> None: + """插件注册时应将 API 同步到独立注册表,而不是通用组件表。""" + + supervisor = PluginSupervisor(plugin_dirs=[]) + response = await _register_plugin( + supervisor, + "provider", + [ + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML", + "version": "1", + "public": True, + }, + } + ], + ) + + assert response.payload["accepted"] is True + assert response.payload["registered_components"] == 0 + assert response.payload["registered_apis"] == 1 + assert supervisor.api_registry.get_api("provider", "render_html") is not None + assert supervisor.component_registry.get_component("provider.render_html") is None + + unregister_response = await _unregister_plugin(supervisor, "provider") + assert unregister_response.payload["removed_apis"] == 1 + assert supervisor.api_registry.get_api("provider", "render_html") is None + + +@pytest.mark.asyncio +async def test_api_call_allows_public_api_between_plugins(monkeypatch: pytest.MonkeyPatch) -> None: + """公开 API 应允许其他插件通过 Host 转发调用。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML", + "version": "1", + "public": True, + }, + } + ], + ) + await _register_plugin(consumer_supervisor, "consumer", []) + + captured: Dict[str, Any] = {} + + async def fake_invoke_api( + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟 API RPC 调用。""" + + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}}) + + monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api) + + manager = _build_manager(provider_supervisor, consumer_supervisor) + result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.render_html", + "version": "1", + "args": {"html": "
Hello
"}, + }, + ) + + assert result == {"success": True, "result": {"image": "ok"}} + assert captured["plugin_id"] == "provider" + assert captured["component_name"] == "render_html" + assert captured["args"] == {"html": "
Hello
"} + + +@pytest.mark.asyncio +async def test_api_call_rejects_private_api_between_plugins() -> None: + """未公开的 API 默认不允许跨插件调用。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "secret_api", + "component_type": "API", + "metadata": { + "description": "私有 API", + "version": "1", + "public": False, + }, + } + ], + ) + await _register_plugin(consumer_supervisor, "consumer", []) + + manager = _build_manager(provider_supervisor, consumer_supervisor) + result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.secret_api", + "args": {}, + }, + ) + + assert result["success"] is False + assert "未公开" in str(result["error"]) + + +@pytest.mark.asyncio +async def test_api_list_and_component_toggle_use_dedicated_registry() -> None: + """API 列表与组件启停应直接作用于独立 API 注册表。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "public_api", + "component_type": "API", + "metadata": {"version": "1", "public": True}, + }, + { + "name": "private_api", + "component_type": "API", + "metadata": {"version": "1", "public": False}, + }, + ], + ) + await _register_plugin( + consumer_supervisor, + "consumer", + [ + { + "name": "self_private_api", + "component_type": "API", + "metadata": {"version": "1", "public": False}, + } + ], + ) + + manager = _build_manager(provider_supervisor, consumer_supervisor) + list_result = await manager._cap_api_list("consumer", "api.list", {}) + + assert list_result["success"] is True + api_names = {(item["plugin_id"], item["name"]) for item in list_result["apis"]} + assert ("provider", "public_api") in api_names + assert ("provider", "private_api") not in api_names + assert ("consumer", "self_private_api") in api_names + + disable_result = await manager._cap_component_disable( + "consumer", + "component.disable", + { + "name": "provider.public_api", + "component_type": "API", + "scope": "global", + "stream_id": "", + }, + ) + assert disable_result["success"] is True + assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is None + + enable_result = await manager._cap_component_enable( + "consumer", + "component.enable", + { + "name": "provider.public_api", + "component_type": "API", + "scope": "global", + "stream_id": "", + }, + ) + assert enable_result["success"] is True + assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index 4223525f..2eede108 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -6,7 +6,8 @@ from src.common.logger import get_logger logger = get_logger("plugin_runtime.integration") if TYPE_CHECKING: - from src.plugin_runtime.host.component_registry import RegisteredComponent + from src.plugin_runtime.host.api_registry import APIEntry + from src.plugin_runtime.host.component_registry import ComponentEntry from src.plugin_runtime.host.supervisor import PluginSupervisor @@ -18,7 +19,7 @@ class _RuntimeComponentManagerProtocol(Protocol): def _resolve_component_toggle_target( self, name: str, component_type: str - ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: ... + ) -> tuple[Optional["ComponentEntry"], Optional[str]]: ... def _find_duplicate_plugin_ids(self, plugin_dirs: List[Path]) -> Dict[str, List[Path]]: ... @@ -26,6 +27,203 @@ class _RuntimeComponentManagerProtocol(Protocol): class RuntimeComponentCapabilityMixin: + @staticmethod + def _normalize_component_type(component_type: str) -> str: + """规范化组件类型名称。 + + Args: + component_type: 原始组件类型。 + + Returns: + str: 统一转为大写后的组件类型名。 + """ + + return str(component_type or "").strip().upper() + + @classmethod + def _is_api_component_type(cls, component_type: str) -> bool: + """判断组件类型是否为 API。 + + Args: + component_type: 原始组件类型。 + + Returns: + bool: 是否为 API 组件类型。 + """ + + return cls._normalize_component_type(component_type) == "API" + + @staticmethod + def _serialize_api_entry(entry: "APIEntry") -> Dict[str, Any]: + """将 API 组件条目序列化为能力返回值。 + + Args: + entry: API 组件条目。 + + Returns: + Dict[str, Any]: 适合通过能力层返回给插件的 API 元信息。 + """ + + return { + "name": entry.name, + "full_name": entry.full_name, + "plugin_id": entry.plugin_id, + "description": entry.description, + "version": entry.version, + "public": entry.public, + "enabled": entry.enabled, + "metadata": dict(entry.metadata), + } + + @classmethod + def _serialize_api_component_entry(cls, entry: "APIEntry") -> Dict[str, Any]: + """将 API 条目序列化为通用组件视图。 + + Args: + entry: API 组件条目。 + + Returns: + Dict[str, Any]: 适合 ``component.get_all_plugins`` 返回的组件结构。 + """ + + serialized_entry = cls._serialize_api_entry(entry) + return { + "name": serialized_entry["name"], + "full_name": serialized_entry["full_name"], + "type": "API", + "enabled": serialized_entry["enabled"], + "metadata": serialized_entry["metadata"], + } + + @staticmethod + def _is_api_visible_to_plugin(entry: "APIEntry", caller_plugin_id: str) -> bool: + """判断某个 API 是否对调用方可见。 + + Args: + entry: 目标 API 组件条目。 + caller_plugin_id: 调用方插件 ID。 + + Returns: + bool: 是否允许当前插件可见并调用。 + """ + + return entry.plugin_id == caller_plugin_id or entry.public + + def _resolve_api_target( + self: _RuntimeComponentManagerProtocol, + caller_plugin_id: str, + api_name: str, + version: str = "", + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: + """解析 API 名称到唯一可调用的目标组件。 + + Args: + caller_plugin_id: 调用方插件 ID。 + api_name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。 + version: 可选的 API 版本。 + + Returns: + tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]: + 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 + """ + + normalized_api_name = str(api_name or "").strip() + normalized_version = str(version or "").strip() + if not normalized_api_name: + return None, None, "缺少必要参数 api_name" + + if "." in normalized_api_name: + target_plugin_id, target_api_name = normalized_api_name.split(".", 1) + try: + supervisor = self._get_supervisor_for_plugin(target_plugin_id) + except RuntimeError as exc: + return None, None, str(exc) + + if supervisor is None: + return None, None, f"未找到 API 提供方插件: {target_plugin_id}" + + entry = supervisor.api_registry.get_api( + plugin_id=target_plugin_id, + name=target_api_name, + enabled_only=True, + ) + if entry is None: + return None, None, f"未找到 API: {normalized_api_name}" + if normalized_version and entry.version != normalized_version: + return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" + if not self._is_api_visible_to_plugin(entry, caller_plugin_id): + return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" + return supervisor, entry, None + + visible_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + hidden_match_exists = False + for supervisor in self.supervisors: + for entry in supervisor.api_registry.get_apis(name=normalized_api_name, enabled_only=True): + if normalized_version and entry.version != normalized_version: + continue + if self._is_api_visible_to_plugin(entry, caller_plugin_id): + visible_matches.append((supervisor, entry)) + else: + hidden_match_exists = True + + if len(visible_matches) == 1: + return visible_matches[0][0], visible_matches[0][1], None + if len(visible_matches) > 1: + return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name" + if hidden_match_exists: + return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" + if normalized_version: + return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" + return None, None, f"未找到 API: {normalized_api_name}" + + def _resolve_api_toggle_target( + self: _RuntimeComponentManagerProtocol, + name: str, + ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: + """解析需要启用或禁用的 API 组件。 + + Args: + name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。 + + Returns: + tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]: + 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 + """ + + normalized_name = str(name or "").strip() + if not normalized_name: + return None, None, "缺少必要参数 name" + + if "." in normalized_name: + plugin_id, api_name = normalized_name.split(".", 1) + try: + supervisor = self._get_supervisor_for_plugin(plugin_id) + except RuntimeError as exc: + return None, None, str(exc) + + if supervisor is None: + return None, None, f"未找到 API 提供方插件: {plugin_id}" + + entry = supervisor.api_registry.get_api( + plugin_id=plugin_id, + name=api_name, + enabled_only=False, + ) + if entry is None: + return None, None, f"未找到 API: {normalized_name}" + return supervisor, entry, None + + matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + for supervisor in self.supervisors: + for entry in supervisor.api_registry.get_apis(name=normalized_name, enabled_only=False): + matches.append((supervisor, entry)) + + if len(matches) == 1: + return matches[0][0], matches[0][1], None + if len(matches) > 1: + return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name" + return None, None, f"未找到 API: {normalized_name}" + async def _cap_component_get_all_plugins( self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] ) -> Any: @@ -46,6 +244,10 @@ class RuntimeComponentCapabilityMixin: } for component in comps ] + components_list.extend( + self._serialize_api_component_entry(entry) + for entry in sv.api_registry.get_apis(plugin_id=pid, enabled_only=False) + ) result[pid] = { "name": pid, "version": reg.plugin_version, @@ -96,24 +298,28 @@ class RuntimeComponentCapabilityMixin: def _resolve_component_toggle_target( self: _RuntimeComponentManagerProtocol, name: str, component_type: str - ) -> tuple[Optional["RegisteredComponent"], Optional[str]]: - short_name_matches: List["RegisteredComponent"] = [] + ) -> tuple[Optional["ComponentEntry"], Optional[str]]: + normalized_component_type = self._normalize_component_type(component_type) + short_name_matches: List["ComponentEntry"] = [] for sv in self.supervisors: comp = sv.component_registry.get_component(name) - if comp is not None and comp.component_type == component_type: + if comp is not None and comp.component_type == normalized_component_type: return comp, None short_name_matches.extend( candidate - for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False) + for candidate in sv.component_registry.get_components_by_type( + normalized_component_type, + enabled_only=False, + ) if candidate.name == name ) if len(short_name_matches) == 1: return short_name_matches[0], None if len(short_name_matches) > 1: - return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name" - return None, f"未找到组件: {name} ({component_type})" + return None, f"组件名不唯一: {name} ({normalized_component_type}),请使用完整名 plugin_id.component_name" + return None, f"未找到组件: {name} ({normalized_component_type})" async def _cap_component_enable( self: _RuntimeComponentManagerProtocol, plugin_id: str, capability: str, args: Dict[str, Any] @@ -127,6 +333,13 @@ class RuntimeComponentCapabilityMixin: if scope != "global" or stream_id: return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"} + if self._is_api_component_type(component_type): + supervisor, api_entry, error = self._resolve_api_toggle_target(name) + if supervisor is None or api_entry is None: + return {"success": False, "error": error or f"未找到 API: {name}"} + supervisor.api_registry.toggle_api_status(api_entry.full_name, True) + return {"success": True} + comp, error = self._resolve_component_toggle_target(name, component_type) if comp is None: return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} @@ -146,6 +359,13 @@ class RuntimeComponentCapabilityMixin: if scope != "global" or stream_id: return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"} + if self._is_api_component_type(component_type): + supervisor, api_entry, error = self._resolve_api_toggle_target(name) + if supervisor is None or api_entry is None: + return {"success": False, "error": error or f"未找到 API: {name}"} + supervisor.api_registry.toggle_api_status(api_entry.full_name, False) + return {"success": True} + comp, error = self._resolve_component_toggle_target(name, component_type) if comp is None: return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"} @@ -239,3 +459,111 @@ class RuntimeComponentCapabilityMixin: logger.error(f"[cap.component.reload_plugin] 热重载失败: {e}") return {"success": False, "error": str(e)} return {"success": False, "error": f"未找到插件: {plugin_name}"} + + async def _cap_api_call( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """调用其他插件公开的 API。 + + Args: + plugin_id: 当前调用方插件 ID。 + capability: 能力名称。 + args: 能力参数。 + + Returns: + Any: API 调用结果。 + """ + + del capability + api_name = str(args.get("api_name", "") or "").strip() + version = str(args.get("version", "") or "").strip() + api_args = args.get("args", {}) + if not isinstance(api_args, dict): + return {"success": False, "error": "参数 args 必须为字典"} + + supervisor, entry, error = self._resolve_api_target(plugin_id, api_name, version) + if supervisor is None or entry is None: + return {"success": False, "error": error or "API 解析失败"} + + try: + response = await supervisor.invoke_api( + plugin_id=entry.plugin_id, + component_name=entry.name, + args=api_args, + timeout_ms=30000, + ) + except Exception as exc: + logger.error(f"[cap.api.call] 调用 API {entry.full_name} 失败: {exc}", exc_info=True) + return {"success": False, "error": str(exc)} + + if response.error: + return {"success": False, "error": response.error.get("message", "API 调用失败")} + + payload = response.payload if isinstance(response.payload, dict) else {} + if not bool(payload.get("success", False)): + result = payload.get("result") + return {"success": False, "error": "" if result is None else str(result)} + return {"success": True, "result": payload.get("result")} + + async def _cap_api_get( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """获取当前插件可见的单个 API 元信息。 + + Args: + plugin_id: 当前调用方插件 ID。 + capability: 能力名称。 + args: 能力参数。 + + Returns: + Any: API 元信息或 ``None``。 + """ + + del capability + api_name = str(args.get("api_name", "") or "").strip() + version = str(args.get("version", "") or "").strip() + if not api_name: + return {"success": False, "error": "缺少必要参数 api_name"} + + supervisor, entry, _error = self._resolve_api_target(plugin_id, api_name, version) + if supervisor is None or entry is None: + return {"success": True, "api": None} + return {"success": True, "api": self._serialize_api_entry(entry)} + + async def _cap_api_list( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """列出当前插件可见的 API 列表。 + + Args: + plugin_id: 当前调用方插件 ID。 + capability: 能力名称。 + args: 能力参数。 + + Returns: + Any: API 元信息列表。 + """ + + del capability + target_plugin_id = str(args.get("plugin_id", "") or "").strip() + apis: List[Dict[str, Any]] = [] + for supervisor in self.supervisors: + for entry in supervisor.api_registry.get_apis( + plugin_id=target_plugin_id or None, + enabled_only=True, + ): + if not self._is_api_visible_to_plugin(entry, plugin_id): + continue + apis.append(self._serialize_api_entry(entry)) + + apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"]))) + return {"success": True, "apis": apis} diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py index 96b190b4..31693833 100644 --- a/src/plugin_runtime/capabilities/registry.py +++ b/src/plugin_runtime/capabilities/registry.py @@ -74,6 +74,10 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi _register("tool.get_definitions", manager._cap_tool_get_definitions) + _register("api.call", manager._cap_api_call) + _register("api.get", manager._cap_api_get) + _register("api.list", manager._cap_api_list) + _register("component.get_all_plugins", manager._cap_component_get_all_plugins) _register("component.get_plugin_info", manager._cap_component_get_plugin_info) _register("component.list_loaded_plugins", manager._cap_component_list_loaded_plugins) diff --git a/src/plugin_runtime/host/api_registry.py b/src/plugin_runtime/host/api_registry.py new file mode 100644 index 00000000..84578ca5 --- /dev/null +++ b/src/plugin_runtime/host/api_registry.py @@ -0,0 +1,290 @@ +"""Host 侧插件 API 动态注册表。""" + +from typing import Any, Dict, List, Optional, Set + +from src.common.logger import get_logger + +logger = get_logger("plugin_runtime.host.api_registry") + + +class APIEntry: + """API 组件条目。""" + + __slots__ = ( + "description", + "disabled_session", + "enabled", + "full_name", + "metadata", + "name", + "plugin_id", + "public", + "version", + ) + + def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None: + """初始化 API 组件条目。 + + Args: + name: API 名称。 + plugin_id: 所属插件 ID。 + metadata: API 元数据。 + """ + + self.name: str = name + self.full_name: str = f"{plugin_id}.{name}" + self.plugin_id: str = plugin_id + self.description: str = str(metadata.get("description", "") or "") + self.version: str = str(metadata.get("version", "1") or "1").strip() or "1" + self.public: bool = bool(metadata.get("public", False)) + self.metadata: Dict[str, Any] = dict(metadata) + self.enabled: bool = bool(metadata.get("enabled", True)) + self.disabled_session: Set[str] = set() + + +class APIRegistry: + """Host 侧插件 API 动态注册表。 + + 该注册表不直接面向 Runner,而是复用插件组件注册/卸载事件, + 维护面向 API 调用场景的专用索引。 + """ + + def __init__(self) -> None: + """初始化 API 注册表。""" + + self._apis: Dict[str, APIEntry] = {} + self._by_plugin: Dict[str, List[APIEntry]] = {} + self._by_name: Dict[str, List[APIEntry]] = {} + + def clear(self) -> None: + """清空全部 API 注册状态。""" + + self._apis.clear() + self._by_plugin.clear() + self._by_name.clear() + + @staticmethod + def _is_api_component(component_type: Any) -> bool: + """判断组件声明是否属于 API。 + + Args: + component_type: 原始组件类型值。 + + Returns: + bool: 是否为 API 组件。 + """ + + return str(component_type or "").strip().upper() == "API" + + @staticmethod + def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool: + """判断 API 条目当前是否处于启用状态。 + + Args: + entry: 待检查的 API 条目。 + session_id: 可选的会话 ID。 + + Returns: + bool: 当前是否可用。 + """ + + if session_id and session_id in entry.disabled_session: + return False + return entry.enabled + + def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool: + """注册单个 API 条目。 + + Args: + name: API 名称。 + plugin_id: 所属插件 ID。 + metadata: API 元数据。 + + Returns: + bool: 是否成功注册。 + """ + + normalized_name = str(name or "").strip() + if not normalized_name: + logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略") + return False + + entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata) + if entry.full_name in self._apis: + logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目") + self._remove_entry(self._apis[entry.full_name]) + + self._apis[entry.full_name] = entry + self._by_plugin.setdefault(plugin_id, []).append(entry) + self._by_name.setdefault(entry.name, []).append(entry) + return True + + def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int: + """批量注册某个插件声明的全部 API。 + + Args: + plugin_id: 插件 ID。 + components: 插件组件声明列表。 + + Returns: + int: 成功注册的 API 数量。 + """ + + count = 0 + for component in components: + if not self._is_api_component(component.get("component_type")): + continue + if self.register_api( + name=str(component.get("name", "") or ""), + plugin_id=plugin_id, + metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}, + ): + count += 1 + return count + + def _remove_entry(self, entry: APIEntry) -> None: + """从全部索引中移除单个 API 条目。 + + Args: + entry: 待移除的 API 条目。 + """ + + self._apis.pop(entry.full_name, None) + plugin_entries = self._by_plugin.get(entry.plugin_id) + if plugin_entries is not None: + self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry] + if not self._by_plugin[entry.plugin_id]: + self._by_plugin.pop(entry.plugin_id, None) + + name_entries = self._by_name.get(entry.name) + if name_entries is not None: + self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry] + if not self._by_name[entry.name]: + self._by_name.pop(entry.name, None) + + def remove_apis_by_plugin(self, plugin_id: str) -> int: + """移除某个插件的全部 API。 + + Args: + plugin_id: 目标插件 ID。 + + Returns: + int: 被移除的 API 数量。 + """ + + entries = list(self._by_plugin.get(plugin_id, [])) + for entry in entries: + self._remove_entry(entry) + return len(entries) + + def get_api_by_full_name( + self, + full_name: str, + *, + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> Optional[APIEntry]: + """按完整名查询单个 API。 + + Args: + full_name: API 完整名,格式为 ``plugin_id.api_name``。 + enabled_only: 是否仅返回启用状态的 API。 + session_id: 可选的会话 ID。 + + Returns: + Optional[APIEntry]: 命中时返回 API 条目。 + """ + + entry = self._apis.get(full_name) + if entry is None: + return None + if enabled_only and not self.check_api_enabled(entry, session_id): + return None + return entry + + def get_api( + self, + plugin_id: str, + name: str, + *, + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> Optional[APIEntry]: + """按插件 ID 和短名查询单个 API。 + + Args: + plugin_id: 提供方插件 ID。 + name: API 短名。 + enabled_only: 是否仅返回启用状态的 API。 + session_id: 可选的会话 ID。 + + Returns: + Optional[APIEntry]: 命中时返回 API 条目。 + """ + + return self.get_api_by_full_name( + f"{plugin_id}.{name}", + enabled_only=enabled_only, + session_id=session_id, + ) + + def get_apis( + self, + *, + plugin_id: Optional[str] = None, + name: str = "", + enabled_only: bool = True, + session_id: Optional[str] = None, + ) -> List[APIEntry]: + """查询 API 列表。 + + Args: + plugin_id: 可选的插件 ID 过滤条件。 + name: 可选的 API 名称过滤条件。 + enabled_only: 是否仅返回启用状态的 API。 + session_id: 可选的会话 ID。 + + Returns: + List[APIEntry]: 符合条件的 API 条目列表。 + """ + + normalized_name = str(name or "").strip() + if plugin_id: + candidates = list(self._by_plugin.get(plugin_id, [])) + elif normalized_name: + candidates = list(self._by_name.get(normalized_name, [])) + else: + candidates = list(self._apis.values()) + + filtered_entries: List[APIEntry] = [] + for entry in candidates: + if normalized_name and entry.name != normalized_name: + continue + if enabled_only and not self.check_api_enabled(entry, session_id): + continue + filtered_entries.append(entry) + return filtered_entries + + def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: + """设置指定 API 的启用状态。 + + Args: + full_name: API 完整名。 + enabled: 目标启用状态。 + session_id: 可选的会话 ID,仅对该会话生效。 + + Returns: + bool: 是否设置成功。 + """ + + entry = self._apis.get(full_name) + if entry is None: + return False + if session_id: + if enabled: + entry.disabled_session.discard(session_id) + else: + entry.disabled_session.add(session_id) + else: + entry.enabled = enabled + return True diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 4a9885f8..1add64c6 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -34,6 +34,7 @@ from src.plugin_runtime.protocol.errors import ErrorCode, RPCError from src.plugin_runtime.transport.factory import create_transport_server from .authorization import AuthorizationManager +from .api_registry import APIRegistry from .capability_service import CapabilityService from .component_registry import ComponentRegistry from .event_dispatcher import EventDispatcher @@ -93,6 +94,7 @@ class PluginRunnerSupervisor: self._transport = create_transport_server(socket_path=socket_path) self._authorization = AuthorizationManager() self._capability_service = CapabilityService(self._authorization) + self._api_registry = APIRegistry() self._component_registry = ComponentRegistry() self._event_dispatcher = EventDispatcher(self._component_registry) self._hook_dispatcher = HookDispatcher(self._component_registry) @@ -124,6 +126,11 @@ class PluginRunnerSupervisor: """返回能力服务。""" return self._capability_service + @property + def api_registry(self) -> APIRegistry: + """返回 API 专用注册表。""" + return self._api_registry + @property def component_registry(self) -> ComponentRegistry: """返回组件注册表。""" @@ -310,6 +317,33 @@ class PluginRunnerSupervisor: timeout_ms=timeout_ms, ) + async def invoke_api( + self, + plugin_id: str, + component_name: str, + args: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Envelope: + """调用插件声明的 API 方法。 + + Args: + plugin_id: 目标插件 ID。 + component_name: API 组件名称。 + args: 传递给 API 方法的关键字参数。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: Runner 返回的响应信封。 + """ + + return await self.invoke_plugin( + method="plugin.invoke_api", + plugin_id=plugin_id, + component_name=component_name, + args=args, + timeout_ms=timeout_ms, + ) + async def reload_plugin(self, plugin_id: str, reason: str = "manual") -> bool: """按插件 ID 触发精确重载。 @@ -507,13 +541,17 @@ class PluginRunnerSupervisor: except Exception as exc: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + component_declarations = [component.model_dump() for component in payload.components] + runtime_components, api_components = self._split_component_declarations(component_declarations) self._component_registry.remove_components_by_plugin(payload.plugin_id) + self._api_registry.remove_apis_by_plugin(payload.plugin_id) await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) registered_count = self._component_registry.register_plugin_components( payload.plugin_id, - [component.model_dump() for component in payload.components], + runtime_components, ) + registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components) self._registered_plugins[payload.plugin_id] = payload self._message_gateway_states[payload.plugin_id] = {} @@ -522,6 +560,7 @@ class PluginRunnerSupervisor: "accepted": True, "plugin_id": payload.plugin_id, "registered_components": registered_count, + "registered_apis": registered_api_count, "message_gateways": len( self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False) ), @@ -543,6 +582,7 @@ class PluginRunnerSupervisor: return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) + removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id) self._authorization.revoke_permission_token(payload.plugin_id) removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) @@ -554,10 +594,48 @@ class PluginRunnerSupervisor: "plugin_id": payload.plugin_id, "reason": payload.reason, "removed_components": removed_components, + "removed_apis": removed_apis, "removed_registration": removed_registration, } ) + @staticmethod + def _is_api_component(component: Dict[str, Any]) -> bool: + """判断组件声明是否属于 API。 + + Args: + component: 原始组件声明字典。 + + Returns: + bool: 是否为 API 组件。 + """ + + return str(component.get("component_type", "") or "").strip().upper() == "API" + + def _split_component_declarations( + self, + components: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """拆分通用组件声明和 API 声明。 + + Args: + components: Runner 上报的原始组件声明列表。 + + Returns: + Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + 第一个列表为需要进入通用组件表的声明, + 第二个列表为需要进入 API 专用表的声明。 + """ + + runtime_components: List[Dict[str, Any]] = [] + api_components: List[Dict[str, Any]] = [] + for component in components: + if self._is_api_component(component): + api_components.append(component) + else: + runtime_components.append(component) + return runtime_components, api_components + @staticmethod def _build_message_gateway_driver_id(plugin_id: str, gateway_name: str) -> str: """构造消息网关驱动 ID。 @@ -1172,6 +1250,7 @@ class PluginRunnerSupervisor: def _clear_runner_state(self) -> None: """清理当前 Runner 对应的 Host 侧注册状态。""" self._authorization.clear() + self._api_registry.clear() self._component_registry.clear() self._registered_plugins.clear() self._message_gateway_states.clear() diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 3a50e2f7..4bee714c 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -291,6 +291,7 @@ class PluginRunner: """注册 Host -> Runner 的方法处理器。""" self._rpc_client.register_method("plugin.invoke_command", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_action", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)