diff --git a/pytests/test_platform_io_legacy_driver.py b/pytests/test_platform_io_legacy_driver.py index 2e94c1fc..76f14d8f 100644 --- a/pytests/test_platform_io_legacy_driver.py +++ b/pytests/test_platform_io_legacy_driver.py @@ -82,7 +82,61 @@ async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route( ) explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq")) - assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender"] + assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"] + finally: + await manager.stop() + + +@pytest.mark.asyncio +async def test_platform_io_broadcasts_to_plugin_and_legacy_driver( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """同一路由命中插件驱动与 legacy driver 时,应同时广播发送。""" + + manager = PlatformIOManager() + legacy_calls: list[dict[str, Any]] = [] + monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"}) + + async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool: + """记录 legacy driver 调用。""" + + legacy_calls.append({"message": message, "show_log": show_log}) + return True + + monkeypatch.setattr( + uni_message_sender, + "send_prepared_message_to_platform", + _fake_send_prepared_message_to_platform, + ) + + try: + await manager.ensure_send_pipeline_ready() + + plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq") + await manager.add_driver(plugin_driver) + manager.bind_send_route( + RouteBinding( + route_key=RouteKey(platform="qq"), + driver_id=plugin_driver.driver_id, + driver_kind=plugin_driver.descriptor.kind, + ) + ) + + message = type("FakeMessage", (), {"message_id": "message-1"})() + batch = await manager.send_message( + message=message, + route_key=RouteKey(platform="qq"), + metadata={"show_log": False}, + ) + + assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [ + "legacy.send.qq", + "plugin.qq.sender", + ] + assert batch.failed_receipts == [] + assert len(legacy_calls) == 1 + assert legacy_calls[0]["message"] is message + assert legacy_calls[0]["show_log"] is False finally: await manager.stop() diff --git a/pytests/test_plugin_runtime_api.py b/pytests/test_plugin_runtime_api.py index fca7736a..58a8e6ba 100644 --- a/pytests/test_plugin_runtime_api.py +++ b/pytests/test_plugin_runtime_api.py @@ -292,3 +292,233 @@ async def test_api_list_and_component_toggle_use_dedicated_registry() -> None: ) assert enable_result["success"] is True assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None + + +@pytest.mark.asyncio +async def test_api_registry_supports_multiple_versions_with_distinct_handlers( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """同名 API 不同版本应可并存,并按版本路由到不同处理器。""" + + provider_supervisor = PluginSupervisor(plugin_dirs=[]) + consumer_supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin( + provider_supervisor, + "provider", + [ + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML v1", + "version": "1", + "public": True, + "handler_name": "handle_render_html_v1", + }, + }, + { + "name": "render_html", + "component_type": "API", + "metadata": { + "description": "渲染 HTML v2", + "version": "2", + "public": True, + "handler_name": "handle_render_html_v2", + }, + }, + ], + ) + await _register_plugin(consumer_supervisor, "consumer", []) + + captured: Dict[str, Any] = {} + + async def fake_invoke_api( + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟多版本 API 调用。""" + + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}}) + + monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api) + manager = _build_manager(provider_supervisor, consumer_supervisor) + + ambiguous_result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.render_html", + "args": {"html": "
Hello
"}, + }, + ) + assert ambiguous_result["success"] is False + assert "多个版本" in str(ambiguous_result["error"]) + + disable_ambiguous_result = await manager._cap_component_disable( + "consumer", + "component.disable", + { + "name": "provider.render_html", + "component_type": "API", + "scope": "global", + "stream_id": "", + }, + ) + assert disable_ambiguous_result["success"] is False + assert "多个版本" in str(disable_ambiguous_result["error"]) + + disable_v1_result = await manager._cap_component_disable( + "consumer", + "component.disable", + { + "name": "provider.render_html", + "component_type": "API", + "scope": "global", + "stream_id": "", + "version": "1", + }, + ) + assert disable_v1_result["success"] is True + assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None + assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None + + result = await manager._cap_api_call( + "consumer", + "api.call", + { + "api_name": "provider.render_html", + "version": "2", + "args": {"html": "
Hello
"}, + }, + ) + + assert result == {"success": True, "result": {"image": "ok"}} + assert captured["plugin_id"] == "provider" + assert captured["component_name"] == "handle_render_html_v2" + assert captured["args"] == {"html": "
Hello
"} + + +@pytest.mark.asyncio +async def test_api_replace_dynamic_can_offline_removed_entries( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """动态 API 替换后,被移除的 API 应返回明确下线错误。""" + + supervisor = PluginSupervisor(plugin_dirs=[]) + await _register_plugin(supervisor, "provider", []) + manager = _build_manager(supervisor) + + captured: Dict[str, Any] = {} + + async def fake_invoke_api( + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None = None, + timeout_ms: int = 30000, + ) -> Any: + """模拟动态 API 调用。""" + + captured["plugin_id"] = plugin_id + captured["component_name"] = component_name + captured["args"] = args or {} + captured["timeout_ms"] = timeout_ms + return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}}) + + monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api) + + replace_result = await manager._cap_api_replace_dynamic( + "provider", + "api.replace_dynamic", + { + "apis": [ + { + "name": "mcp.search", + "type": "API", + "metadata": { + "version": "1", + "public": True, + "handler_name": "dynamic_search", + }, + }, + { + "name": "mcp.read", + "type": "API", + "metadata": { + "version": "1", + "public": True, + "handler_name": "dynamic_read", + }, + }, + ], + "offline_reason": "MCP 服务器已关闭", + }, + ) + + assert replace_result["success"] is True + assert replace_result["count"] == 2 + list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"}) + assert {(item["name"], item["version"]) for item in list_result["apis"]} == { + ("mcp.read", "1"), + ("mcp.search", "1"), + } + + call_result = await manager._cap_api_call( + "provider", + "api.call", + { + "api_name": "provider.mcp.search", + "version": "1", + "args": {"query": "hello"}, + }, + ) + assert call_result == {"success": True, "result": {"ok": True}} + assert captured["component_name"] == "dynamic_search" + assert captured["args"]["query"] == "hello" + assert captured["args"]["__maibot_api_name__"] == "mcp.search" + assert captured["args"]["__maibot_api_version__"] == "1" + + second_replace_result = await manager._cap_api_replace_dynamic( + "provider", + "api.replace_dynamic", + { + "apis": [ + { + "name": "mcp.read", + "type": "API", + "metadata": { + "version": "1", + "public": True, + "handler_name": "dynamic_read", + }, + } + ], + "offline_reason": "MCP 服务器已关闭", + }, + ) + + assert second_replace_result["success"] is True + assert second_replace_result["count"] == 1 + assert second_replace_result["offlined"] == 1 + + offlined_call_result = await manager._cap_api_call( + "provider", + "api.call", + { + "api_name": "provider.mcp.search", + "version": "1", + "args": {}, + }, + ) + assert offlined_call_result["success"] is False + assert "MCP 服务器已关闭" in str(offlined_call_result["error"]) + + list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"}) + assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == { + ("mcp.read", "1"), + } diff --git a/pytests/test_send_service.py b/pytests/test_send_service.py index 4ddd4fa1..16aad080 100644 --- a/pytests/test_send_service.py +++ b/pytests/test_send_service.py @@ -73,6 +73,19 @@ def _build_target_stream() -> BotChatSession: ) +def test_inherit_platform_io_route_metadata_falls_back_to_bot_account( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """没有上下文消息时,也应回填当前平台账号用于账号级路由命中。""" + + monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "") + + metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream()) + + assert metadata["platform_io_account_id"] == "bot-qq" + assert metadata["platform_io_target_user_id"] == "target-user" + + @pytest.mark.asyncio async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None: """send service 应将发送职责统一交给 Platform IO。""" diff --git a/src/platform_io/manager.py b/src/platform_io/manager.py index be03e35d..ab1b11e5 100644 --- a/src/platform_io/manager.py +++ b/src/platform_io/manager.py @@ -394,23 +394,22 @@ class PlatformIOManager: """ drivers: List[PlatformIODriver] = [] + seen_driver_ids: set[str] = set() for binding in self._send_route_table.resolve_bindings(route_key): driver = self._driver_registry.get(binding.driver_id) - if driver is not None: + if driver is not None and driver.driver_id not in seen_driver_ids: drivers.append(driver) - if drivers: - return drivers + seen_driver_ids.add(driver.driver_id) fallback_driver = self._legacy_send_drivers.get(route_key.platform) - if fallback_driver is None: - return [] + if fallback_driver is not None: + descriptor = fallback_driver.descriptor + account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id) + scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope) + if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids: + drivers.append(fallback_driver) - descriptor = fallback_driver.descriptor - if descriptor.account_id is not None and route_key.account_id not in (None, descriptor.account_id): - return [] - if descriptor.scope is not None and route_key.scope not in (None, descriptor.scope): - return [] - return [fallback_driver] + return drivers @staticmethod def build_route_key_from_message(message: "SessionMessage") -> RouteKey: diff --git a/src/platform_io/outbound_tracker.py b/src/platform_io/outbound_tracker.py index 438aa566..3725691f 100644 --- a/src/platform_io/outbound_tracker.py +++ b/src/platform_io/outbound_tracker.py @@ -92,11 +92,24 @@ class OutboundTracker: raise ValueError("ttl_seconds 必须大于 0") self._ttl_seconds = ttl_seconds - self._pending: Dict[str, PendingOutboundRecord] = {} - self._pending_expire_heap: List[Tuple[float, str]] = [] + self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {} + self._pending_expire_heap: List[Tuple[float, str, str]] = [] self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {} self._receipt_expire_heap: List[Tuple[float, str]] = [] + @staticmethod + def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]: + """构造单条出站跟踪记录的唯一键。 + + Args: + internal_message_id: 内部消息 ID。 + driver_id: 负责当前投递的驱动 ID。 + + Returns: + Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。 + """ + return internal_message_id, driver_id + def begin_tracking( self, internal_message_id: str, @@ -116,13 +129,15 @@ class OutboundTracker: PendingOutboundRecord: 新创建的待完成记录。 Raises: - ValueError: 当同一个 ``internal_message_id`` 已经存在未完成记录时抛出。 + ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在 + 未完成记录时抛出。 """ now = time.monotonic() self._cleanup_expired(now) + pending_key = self._build_pending_key(internal_message_id, driver_id) - if internal_message_id in self._pending: - raise ValueError(f"消息 {internal_message_id} 已存在未完成的出站跟踪记录") + if pending_key in self._pending: + raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id} 上已存在未完成的出站跟踪记录") expires_at = now + self._ttl_seconds record = PendingOutboundRecord( @@ -133,8 +148,8 @@ class OutboundTracker: expires_at=expires_at, metadata=metadata or {}, ) - self._pending[internal_message_id] = record - heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id)) + self._pending[pending_key] = record + heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id)) return record def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]: @@ -149,7 +164,19 @@ class OutboundTracker: now = time.monotonic() self._cleanup_expired(now) - pending_record = self._pending.pop(receipt.internal_message_id, None) + pending_record: Optional[PendingOutboundRecord] = None + if receipt.driver_id: + pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id) + pending_record = self._pending.pop(pending_key, None) + else: + matched_records = [ + key + for key, record in self._pending.items() + if record.internal_message_id == receipt.internal_message_id + ] + if len(matched_records) == 1: + pending_record = self._pending.pop(matched_records[0], None) + if receipt.external_message_id: expires_at = now + self._ttl_seconds self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt( @@ -160,17 +187,33 @@ class OutboundTracker: heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id)) return pending_record - def get_pending(self, internal_message_id: str) -> Optional[PendingOutboundRecord]: + def get_pending( + self, + internal_message_id: str, + driver_id: Optional[str] = None, + ) -> Optional[PendingOutboundRecord]: """根据内部消息 ID 查询待完成记录。 Args: internal_message_id: 要查询的内部消息 ID。 + driver_id: 可选的驱动 ID;提供后仅返回该驱动上的待完成记录。 Returns: Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。 """ self._cleanup_expired(time.monotonic()) - return self._pending.get(internal_message_id) + + if driver_id: + return self._pending.get(self._build_pending_key(internal_message_id, driver_id)) + + matched_records = [ + record + for record in self._pending.values() + if record.internal_message_id == internal_message_id + ] + if len(matched_records) == 1: + return matched_records[0] + return None def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]: """根据外部平台消息 ID 查询已完成回执。 @@ -213,13 +256,14 @@ class OutboundTracker: ``expires_at`` 对比,跳过这类旧节点。 """ while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now: - expires_at, internal_message_id = heapq.heappop(self._pending_expire_heap) - current_record = self._pending.get(internal_message_id) + expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap) + pending_key = self._build_pending_key(internal_message_id, driver_id) + current_record = self._pending.get(pending_key) if current_record is None: continue if current_record.expires_at != expires_at: continue - self._pending.pop(internal_message_id, None) + self._pending.pop(pending_key, None) def _cleanup_expired_receipts(self, now: float) -> None: """清理已经过期的回执索引。 diff --git a/src/plugin_runtime/capabilities/components.py b/src/plugin_runtime/capabilities/components.py index 2eede108..67033fdd 100644 --- a/src/plugin_runtime/capabilities/components.py +++ b/src/plugin_runtime/capabilities/components.py @@ -72,6 +72,8 @@ class RuntimeComponentCapabilityMixin: "version": entry.version, "public": entry.public, "enabled": entry.enabled, + "dynamic": entry.dynamic, + "offline_reason": entry.offline_reason, "metadata": dict(entry.metadata), } @@ -109,6 +111,32 @@ class RuntimeComponentCapabilityMixin: return entry.plugin_id == caller_plugin_id or entry.public + @staticmethod + def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]: + """规范化 API 名称与版本参数。 + + 支持在 ``api_name`` 中直接携带 ``@version`` 后缀。 + """ + + normalized_api_name = str(api_name or "").strip() + normalized_version = str(version or "").strip() + if normalized_api_name and not normalized_version and "@" in normalized_api_name: + candidate_name, candidate_version = normalized_api_name.rsplit("@", 1) + candidate_name = candidate_name.strip() + candidate_version = candidate_version.strip() + if candidate_name and candidate_version: + normalized_api_name = candidate_name + normalized_version = candidate_version + return normalized_api_name, normalized_version + + @staticmethod + def _build_api_unavailable_error(entry: "APIEntry") -> str: + """构造 API 当前不可用时的错误信息。""" + + if entry.offline_reason: + return entry.offline_reason + return f"API {entry.registry_key} 当前不可用" + def _resolve_api_target( self: _RuntimeComponentManagerProtocol, caller_plugin_id: str, @@ -127,8 +155,7 @@ class RuntimeComponentCapabilityMixin: 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 """ - normalized_api_name = str(api_name or "").strip() - normalized_version = str(version or "").strip() + normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version) if not normalized_api_name: return None, None, "缺少必要参数 api_name" @@ -142,34 +169,61 @@ class RuntimeComponentCapabilityMixin: if supervisor is None: return None, None, f"未找到 API 提供方插件: {target_plugin_id}" - entry = supervisor.api_registry.get_api( + entries = supervisor.api_registry.get_apis( plugin_id=target_plugin_id, name=target_api_name, - enabled_only=True, + version=normalized_version, + enabled_only=False, ) - if entry is None: - return None, None, f"未找到 API: {normalized_api_name}" - if normalized_version and entry.version != normalized_version: - return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" - if not self._is_api_visible_to_plugin(entry, caller_plugin_id): + visible_enabled_entries = [ + entry + for entry in entries + if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled + ] + visible_disabled_entries = [ + entry + for entry in entries + if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled + ] + if len(visible_enabled_entries) == 1: + return supervisor, visible_enabled_entries[0], None + if len(visible_enabled_entries) > 1: + return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version" + if visible_disabled_entries: + if len(visible_disabled_entries) == 1: + return None, None, self._build_api_unavailable_error(visible_disabled_entries[0]) + return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version" + if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries): return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" - return supervisor, entry, None + if normalized_version: + return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" + return None, None, f"未找到 API: {normalized_api_name}" - visible_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] + visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] hidden_match_exists = False for supervisor in self.supervisors: - for entry in supervisor.api_registry.get_apis(name=normalized_api_name, enabled_only=True): - if normalized_version and entry.version != normalized_version: - continue + for entry in supervisor.api_registry.get_apis( + name=normalized_api_name, + version=normalized_version, + enabled_only=False, + ): if self._is_api_visible_to_plugin(entry, caller_plugin_id): - visible_matches.append((supervisor, entry)) + if entry.enabled: + visible_enabled_matches.append((supervisor, entry)) + else: + visible_disabled_matches.append((supervisor, entry)) else: hidden_match_exists = True - if len(visible_matches) == 1: - return visible_matches[0][0], visible_matches[0][1], None - if len(visible_matches) > 1: - return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name" + if len(visible_enabled_matches) == 1: + return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None + if len(visible_enabled_matches) > 1: + return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version" + if visible_disabled_matches: + if len(visible_disabled_matches) == 1: + return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1]) + return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version" if hidden_match_exists: return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" if normalized_version: @@ -179,18 +233,20 @@ class RuntimeComponentCapabilityMixin: def _resolve_api_toggle_target( self: _RuntimeComponentManagerProtocol, name: str, + version: str = "", ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: """解析需要启用或禁用的 API 组件。 Args: name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。 + version: 可选的 API 版本。 Returns: tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]: 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 """ - normalized_name = str(name or "").strip() + normalized_name, normalized_version = self._normalize_api_reference(name, version) if not normalized_name: return None, None, "缺少必要参数 name" @@ -204,24 +260,31 @@ class RuntimeComponentCapabilityMixin: if supervisor is None: return None, None, f"未找到 API 提供方插件: {plugin_id}" - entry = supervisor.api_registry.get_api( + entries = supervisor.api_registry.get_apis( plugin_id=plugin_id, name=api_name, + version=normalized_version, enabled_only=False, ) - if entry is None: + if not entries: return None, None, f"未找到 API: {normalized_name}" - return supervisor, entry, None + if len(entries) > 1: + return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version" + return supervisor, entries[0], None matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] for supervisor in self.supervisors: - for entry in supervisor.api_registry.get_apis(name=normalized_name, enabled_only=False): + for entry in supervisor.api_registry.get_apis( + name=normalized_name, + version=normalized_version, + enabled_only=False, + ): matches.append((supervisor, entry)) if len(matches) == 1: return matches[0][0], matches[0][1], None if len(matches) > 1: - return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name" + return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version" return None, None, f"未找到 API: {normalized_name}" async def _cap_component_get_all_plugins( @@ -326,6 +389,7 @@ class RuntimeComponentCapabilityMixin: ) -> Any: name: str = args.get("name", "") component_type: str = args.get("component_type", "") + version: str = args.get("version", "") scope: str = args.get("scope", "global") stream_id: str = args.get("stream_id", "") if not name or not component_type: @@ -334,10 +398,10 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"} if self._is_api_component_type(component_type): - supervisor, api_entry, error = self._resolve_api_toggle_target(name) + supervisor, api_entry, error = self._resolve_api_toggle_target(name, version) if supervisor is None or api_entry is None: return {"success": False, "error": error or f"未找到 API: {name}"} - supervisor.api_registry.toggle_api_status(api_entry.full_name, True) + supervisor.api_registry.toggle_api_status(api_entry.registry_key, True) return {"success": True} comp, error = self._resolve_component_toggle_target(name, component_type) @@ -352,6 +416,7 @@ class RuntimeComponentCapabilityMixin: ) -> Any: name: str = args.get("name", "") component_type: str = args.get("component_type", "") + version: str = args.get("version", "") scope: str = args.get("scope", "global") stream_id: str = args.get("stream_id", "") if not name or not component_type: @@ -360,10 +425,10 @@ class RuntimeComponentCapabilityMixin: return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"} if self._is_api_component_type(component_type): - supervisor, api_entry, error = self._resolve_api_toggle_target(name) + supervisor, api_entry, error = self._resolve_api_toggle_target(name, version) if supervisor is None or api_entry is None: return {"success": False, "error": error or f"未找到 API: {name}"} - supervisor.api_registry.toggle_api_status(api_entry.full_name, False) + supervisor.api_registry.toggle_api_status(api_entry.registry_key, False) return {"success": True} comp, error = self._resolve_component_toggle_target(name, component_type) @@ -488,11 +553,17 @@ class RuntimeComponentCapabilityMixin: if supervisor is None or entry is None: return {"success": False, "error": error or "API 解析失败"} + invoke_args = dict(api_args) + if entry.dynamic: + invoke_args.setdefault("__maibot_api_name__", entry.name) + invoke_args.setdefault("__maibot_api_full_name__", entry.full_name) + invoke_args.setdefault("__maibot_api_version__", entry.version) + try: response = await supervisor.invoke_api( plugin_id=entry.plugin_id, - component_name=entry.name, - args=api_args, + component_name=entry.handler_name, + args=invoke_args, timeout_ms=30000, ) except Exception as exc: @@ -555,10 +626,16 @@ class RuntimeComponentCapabilityMixin: del capability target_plugin_id = str(args.get("plugin_id", "") or "").strip() + api_name, version = self._normalize_api_reference( + str(args.get("api_name", args.get("name", "")) or ""), + str(args.get("version", "") or ""), + ) apis: List[Dict[str, Any]] = [] for supervisor in self.supervisors: for entry in supervisor.api_registry.get_apis( plugin_id=target_plugin_id or None, + name=api_name, + version=version, enabled_only=True, ): if not self._is_api_visible_to_plugin(entry, plugin_id): @@ -567,3 +644,75 @@ class RuntimeComponentCapabilityMixin: apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"]))) return {"success": True, "apis": apis} + + async def _cap_api_replace_dynamic( + self: _RuntimeComponentManagerProtocol, + plugin_id: str, + capability: str, + args: Dict[str, Any], + ) -> Any: + """替换插件自行维护的动态 API 列表。""" + + del capability + raw_apis = args.get("apis", []) + offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线" + if not isinstance(raw_apis, list): + return {"success": False, "error": "参数 apis 必须为列表"} + + try: + supervisor = self._get_supervisor_for_plugin(plugin_id) + except RuntimeError as exc: + return {"success": False, "error": str(exc)} + + if supervisor is None: + return {"success": False, "error": f"未找到插件: {plugin_id}"} + + normalized_components: List[Dict[str, Any]] = [] + seen_registry_keys: set[str] = set() + for index, raw_api in enumerate(raw_apis): + if not isinstance(raw_api, dict): + return {"success": False, "error": f"apis[{index}] 必须为字典"} + + api_name = str(raw_api.get("name", "") or "").strip() + component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip() + if not api_name: + return {"success": False, "error": f"apis[{index}] 缺少 name"} + if not self._is_api_component_type(component_type): + return {"success": False, "error": f"apis[{index}] 不是 API 组件"} + + metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {} + normalized_metadata = dict(metadata) + normalized_metadata["dynamic"] = True + version = str(normalized_metadata.get("version", "1") or "1").strip() or "1" + registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version) + if registry_key in seen_registry_keys: + return {"success": False, "error": f"动态 API 重复声明: {registry_key}"} + seen_registry_keys.add(registry_key) + + existing_entry = supervisor.api_registry.get_api( + plugin_id, + api_name, + version=version, + enabled_only=False, + ) + if existing_entry is not None and not existing_entry.dynamic: + return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"} + + normalized_components.append( + { + "name": api_name, + "component_type": "API", + "metadata": normalized_metadata, + } + ) + + registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis( + plugin_id, + normalized_components, + offline_reason=offline_reason, + ) + return { + "success": True, + "count": registered_count, + "offlined": offlined_count, + } diff --git a/src/plugin_runtime/capabilities/registry.py b/src/plugin_runtime/capabilities/registry.py index 31693833..7f87604d 100644 --- a/src/plugin_runtime/capabilities/registry.py +++ b/src/plugin_runtime/capabilities/registry.py @@ -77,6 +77,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi _register("api.call", manager._cap_api_call) _register("api.get", manager._cap_api_get) _register("api.list", manager._cap_api_list) + _register("api.replace_dynamic", manager._cap_api_replace_dynamic) _register("component.get_all_plugins", manager._cap_component_get_all_plugins) _register("component.get_plugin_info", manager._cap_component_get_plugin_info) diff --git a/src/plugin_runtime/host/api_registry.py b/src/plugin_runtime/host/api_registry.py index 84578ca5..1cbc05f6 100644 --- a/src/plugin_runtime/host/api_registry.py +++ b/src/plugin_runtime/host/api_registry.py @@ -1,45 +1,60 @@ """Host 侧插件 API 动态注册表。""" -from typing import Any, Dict, List, Optional, Set +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Tuple from src.common.logger import get_logger logger = get_logger("plugin_runtime.host.api_registry") +@dataclass(slots=True) class APIEntry: """API 组件条目。""" - __slots__ = ( - "description", - "disabled_session", - "enabled", - "full_name", - "metadata", - "name", - "plugin_id", - "public", - "version", - ) + name: str + plugin_id: str + description: str = "" + version: str = "1" + public: bool = False + metadata: Dict[str, Any] = field(default_factory=dict) + enabled: bool = True + handler_name: str = "" + dynamic: bool = False + offline_reason: str = "" + disabled_session: Set[str] = field(default_factory=set) + full_name: str = field(init=False) + registry_key: str = field(init=False) - def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None: - """初始化 API 组件条目。 + def __post_init__(self) -> None: + """规范化 API 条目字段。""" - Args: - name: API 名称。 - plugin_id: 所属插件 ID。 - metadata: API 元数据。 - """ + self.name = str(self.name or "").strip() + self.plugin_id = str(self.plugin_id or "").strip() + self.description = str(self.description or "").strip() + self.version = str(self.version or "1").strip() or "1" + self.handler_name = str(self.handler_name or self.name).strip() or self.name + self.offline_reason = str(self.offline_reason or "").strip() + self.full_name = f"{self.plugin_id}.{self.name}" + self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version) - self.name: str = name - self.full_name: str = f"{plugin_id}.{name}" - self.plugin_id: str = plugin_id - self.description: str = str(metadata.get("description", "") or "") - self.version: str = str(metadata.get("version", "1") or "1").strip() or "1" - self.public: bool = bool(metadata.get("public", False)) - self.metadata: Dict[str, Any] = dict(metadata) - self.enabled: bool = bool(metadata.get("enabled", True)) - self.disabled_session: Set[str] = set() + @classmethod + def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry": + """根据 Runner 上报的元数据构造 API 条目。""" + + safe_metadata = dict(metadata) + return cls( + name=name, + plugin_id=plugin_id, + description=str(safe_metadata.get("description", "") or ""), + version=str(safe_metadata.get("version", "1") or "1"), + public=bool(safe_metadata.get("public", False)), + metadata=safe_metadata, + enabled=bool(safe_metadata.get("enabled", True)), + handler_name=str(safe_metadata.get("handler_name", name) or name), + dynamic=bool(safe_metadata.get("dynamic", False)), + offline_reason=str(safe_metadata.get("offline_reason", "") or ""), + ) class APIRegistry: @@ -53,6 +68,7 @@ class APIRegistry: """初始化 API 注册表。""" self._apis: Dict[str, APIEntry] = {} + self._by_full_name: Dict[str, List[APIEntry]] = {} self._by_plugin: Dict[str, List[APIEntry]] = {} self._by_name: Dict[str, List[APIEntry]] = {} @@ -60,75 +76,75 @@ class APIRegistry: """清空全部 API 注册状态。""" self._apis.clear() + self._by_full_name.clear() self._by_plugin.clear() self._by_name.clear() @staticmethod def _is_api_component(component_type: Any) -> bool: - """判断组件声明是否属于 API。 - - Args: - component_type: 原始组件类型值。 - - Returns: - bool: 是否为 API 组件。 - """ + """判断组件声明是否属于 API。""" return str(component_type or "").strip().upper() == "API" + @staticmethod + def _normalize_query_version(version: Any) -> str: + """规范化查询使用的版本字符串。""" + + return str(version or "").strip() + + @classmethod + def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]: + """解析可能带 ``@version`` 后缀的 API 引用。""" + + normalized_reference = str(reference or "").strip() + normalized_version = cls._normalize_query_version(version) + if normalized_reference and not normalized_version and "@" in normalized_reference: + candidate_reference, candidate_version = normalized_reference.rsplit("@", 1) + candidate_reference = candidate_reference.strip() + candidate_version = candidate_version.strip() + if candidate_reference and candidate_version: + normalized_reference = candidate_reference + normalized_version = candidate_version + return normalized_reference, normalized_version + + @staticmethod + def build_registry_key(plugin_id: str, name: str, version: str) -> str: + """构造 API 注册表唯一键。""" + + normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}" + normalized_version = str(version or "1").strip() or "1" + return f"{normalized_full_name}@{normalized_version}" + @staticmethod def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool: - """判断 API 条目当前是否处于启用状态。 - - Args: - entry: 待检查的 API 条目。 - session_id: 可选的会话 ID。 - - Returns: - bool: 当前是否可用。 - """ + """判断 API 条目当前是否处于启用状态。""" if session_id and session_id in entry.disabled_session: return False return entry.enabled def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool: - """注册单个 API 条目。 - - Args: - name: API 名称。 - plugin_id: 所属插件 ID。 - metadata: API 元数据。 - - Returns: - bool: 是否成功注册。 - """ + """注册单个 API 条目。""" normalized_name = str(name or "").strip() if not normalized_name: logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略") return False - entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata) - if entry.full_name in self._apis: - logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目") - self._remove_entry(self._apis[entry.full_name]) + entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata) + existing_entry = self._apis.get(entry.registry_key) + if existing_entry is not None: + logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目") + self._remove_entry(existing_entry) - self._apis[entry.full_name] = entry + self._apis[entry.registry_key] = entry + self._by_full_name.setdefault(entry.full_name, []).append(entry) self._by_plugin.setdefault(plugin_id, []).append(entry) self._by_name.setdefault(entry.name, []).append(entry) return True def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int: - """批量注册某个插件声明的全部 API。 - - Args: - plugin_id: 插件 ID。 - components: 插件组件声明列表。 - - Returns: - int: 成功注册的 API 数量。 - """ + """批量注册某个插件声明的全部 API。""" count = 0 for component in components: @@ -142,14 +158,60 @@ class APIRegistry: count += 1 return count + def replace_plugin_dynamic_apis( + self, + plugin_id: str, + components: List[Dict[str, Any]], + *, + offline_reason: str = "动态 API 已下线", + ) -> Tuple[int, int]: + """替换指定插件当前声明的动态 API 集合。""" + + normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线" + desired_registry_keys: Set[str] = set() + registered_count = 0 + + for component in components: + if not self._is_api_component(component.get("component_type")): + continue + metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {} + dynamic_metadata = dict(metadata) + dynamic_metadata["dynamic"] = True + dynamic_metadata.pop("offline_reason", None) + + entry = APIEntry.from_metadata( + name=str(component.get("name", "") or ""), + plugin_id=plugin_id, + metadata=dynamic_metadata, + ) + desired_registry_keys.add(entry.registry_key) + if self.register_api(entry.name, plugin_id, dynamic_metadata): + registered_count += 1 + + offlined_count = 0 + for entry in list(self._by_plugin.get(plugin_id, [])): + if not entry.dynamic or entry.registry_key in desired_registry_keys: + continue + entry.enabled = False + entry.offline_reason = normalized_offline_reason + entry.metadata["offline_reason"] = normalized_offline_reason + offlined_count += 1 + + return registered_count, offlined_count + def _remove_entry(self, entry: APIEntry) -> None: - """从全部索引中移除单个 API 条目。 + """从全部索引中移除单个 API 条目。""" - Args: - entry: 待移除的 API 条目。 - """ + self._apis.pop(entry.registry_key, None) + + full_name_entries = self._by_full_name.get(entry.full_name) + if full_name_entries is not None: + self._by_full_name[entry.full_name] = [ + candidate for candidate in full_name_entries if candidate is not entry + ] + if not self._by_full_name[entry.full_name]: + self._by_full_name.pop(entry.full_name, None) - self._apis.pop(entry.full_name, None) plugin_entries = self._by_plugin.get(entry.plugin_id) if plugin_entries is not None: self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry] @@ -163,14 +225,7 @@ class APIRegistry: self._by_name.pop(entry.name, None) def remove_apis_by_plugin(self, plugin_id: str) -> int: - """移除某个插件的全部 API。 - - Args: - plugin_id: 目标插件 ID。 - - Returns: - int: 被移除的 API 数量。 - """ + """移除某个插件的全部 API。""" entries = list(self._by_plugin.get(plugin_id, [])) for entry in entries: @@ -181,49 +236,48 @@ class APIRegistry: self, full_name: str, *, + version: str = "", enabled_only: bool = True, session_id: Optional[str] = None, ) -> Optional[APIEntry]: - """按完整名查询单个 API。 + """按完整名查询单个 API。""" - Args: - full_name: API 完整名,格式为 ``plugin_id.api_name``。 - enabled_only: 是否仅返回启用状态的 API。 - session_id: 可选的会话 ID。 - - Returns: - Optional[APIEntry]: 命中时返回 API 条目。 - """ - - entry = self._apis.get(full_name) - if entry is None: + normalized_full_name, normalized_version = self._split_reference(full_name, version) + if not normalized_full_name: return None - if enabled_only and not self.check_api_enabled(entry, session_id): + + if normalized_version: + entry = self._apis.get(f"{normalized_full_name}@{normalized_version}") + if entry is None: + return None + if enabled_only and not self.check_api_enabled(entry, session_id): + return None + return entry + + candidates = list(self._by_full_name.get(normalized_full_name, [])) + filtered_entries = [ + entry + for entry in candidates + if not enabled_only or self.check_api_enabled(entry, session_id) + ] + if len(filtered_entries) != 1: return None - return entry + return filtered_entries[0] def get_api( self, plugin_id: str, name: str, *, + version: str = "", enabled_only: bool = True, session_id: Optional[str] = None, ) -> Optional[APIEntry]: - """按插件 ID 和短名查询单个 API。 - - Args: - plugin_id: 提供方插件 ID。 - name: API 短名。 - enabled_only: 是否仅返回启用状态的 API。 - session_id: 可选的会话 ID。 - - Returns: - Optional[APIEntry]: 命中时返回 API 条目。 - """ + """按插件 ID、短名与版本查询单个 API。""" return self.get_api_by_full_name( f"{plugin_id}.{name}", + version=version, enabled_only=enabled_only, session_id=session_id, ) @@ -233,22 +287,15 @@ class APIRegistry: *, plugin_id: Optional[str] = None, name: str = "", + version: str = "", enabled_only: bool = True, session_id: Optional[str] = None, ) -> List[APIEntry]: - """查询 API 列表。 - - Args: - plugin_id: 可选的插件 ID 过滤条件。 - name: 可选的 API 名称过滤条件。 - enabled_only: 是否仅返回启用状态的 API。 - session_id: 可选的会话 ID。 - - Returns: - List[APIEntry]: 符合条件的 API 条目列表。 - """ + """查询 API 列表。""" normalized_name = str(name or "").strip() + normalized_version = self._normalize_query_version(version) + if plugin_id: candidates = list(self._by_plugin.get(plugin_id, [])) elif normalized_name: @@ -258,26 +305,35 @@ class APIRegistry: filtered_entries: List[APIEntry] = [] for entry in candidates: + if plugin_id and entry.plugin_id != plugin_id: + continue if normalized_name and entry.name != normalized_name: continue + if normalized_version and entry.version != normalized_version: + continue if enabled_only and not self.check_api_enabled(entry, session_id): continue filtered_entries.append(entry) + + filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version)) return filtered_entries - def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: - """设置指定 API 的启用状态。 + def toggle_api_status( + self, + full_name: str, + enabled: bool, + *, + version: str = "", + session_id: Optional[str] = None, + ) -> bool: + """设置指定 API 的启用状态。""" - Args: - full_name: API 完整名。 - enabled: 目标启用状态。 - session_id: 可选的会话 ID,仅对该会话生效。 - - Returns: - bool: 是否设置成功。 - """ - - entry = self._apis.get(full_name) + entry = self.get_api_by_full_name( + full_name, + version=version, + enabled_only=False, + session_id=session_id, + ) if entry is None: return False if session_id: @@ -287,4 +343,7 @@ class APIRegistry: entry.disabled_session.add(session_id) else: entry.enabled = enabled + if enabled: + entry.offline_reason = "" + entry.metadata.pop("offline_reason", None) return True diff --git a/src/plugin_runtime/host/authorization.py b/src/plugin_runtime/host/authorization.py index 3fb48c6a..70593768 100644 --- a/src/plugin_runtime/host/authorization.py +++ b/src/plugin_runtime/host/authorization.py @@ -7,6 +7,8 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional, Set, Tuple +_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"}) + @dataclass class CapabilityPermissionToken: @@ -46,6 +48,9 @@ class AuthorizationManager: Returns: return (bool, str): (是否有此能力, 原因) """ + if capability in _ALWAYS_ALLOWED_CAPABILITIES: + return True, "" + token = self._permission_tokens.get(plugin_id) if not token: return False, f"插件 {plugin_id} 未注册能力令牌" diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index b94b01d1..b38946d6 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -45,6 +45,14 @@ from src.plugin_runtime.runner.rpc_client import RPCClient logger = get_logger("plugin_runtime.runner.main") +_PLUGIN_ALLOWED_RAW_HOST_METHODS = frozenset( + { + "cap.call", + "host.route_message", + "host.update_message_gateway_state", + } +) + class _ContextAwarePlugin(Protocol): """支持注入运行时上下文的插件协议。 @@ -247,8 +255,14 @@ class PluginRunner: logger.warning( f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份" ) + normalized_method = str(method or "").strip() + if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS: + raise PermissionError( + f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: " + f"{normalized_method or ''}" + ) resp = await rpc_client.send_request( - method=method, + method=normalized_method, plugin_id=bound_plugin_id, payload=payload or {}, ) diff --git a/src/services/send_service.py b/src/services/send_service.py index 54f2a9de..134fb15e 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -57,20 +57,23 @@ def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[s inherited_metadata: Dict[str, object] = {} context_message = target_stream.context.message if target_stream.context else None - if context_message is None: - return inherited_metadata + if context_message is not None: + additional_config = context_message.message_info.additional_config + if isinstance(additional_config, dict): + for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS): + value = additional_config.get(key) + if value is None: + continue + normalized_value = str(value).strip() + if normalized_value: + inherited_metadata[key] = value - additional_config = context_message.message_info.additional_config - if not isinstance(additional_config, dict): - return inherited_metadata - - for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS): - value = additional_config.get(key) - if value is None: - continue - normalized_value = str(value).strip() - if normalized_value: - inherited_metadata[key] = value + # 当目标会话没有可继承的上下文消息时,至少补齐当前平台账号, + # 让按 ``platform + account_id`` 绑定的路由仍有机会命中。 + if not RouteKeyFactory.extract_components(inherited_metadata)[0]: + bot_account = get_bot_account(target_stream.platform) + if bot_account: + inherited_metadata["platform_io_account_id"] = bot_account if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()): inherited_metadata["platform_io_target_group_id"] = normalized_group_id