feat: Enhance API and Outbound Tracking Functionality

- Add test for fallback to bot account in platform IO route metadata when context message is absent.
- Improve PlatformIOManager to avoid duplicate driver entries and streamline fallback driver handling.
- Refactor OutboundTracker to support tracking by both internal message ID and driver ID, enhancing the uniqueness of pending records.
- Introduce dynamic API capabilities in RuntimeComponent, allowing plugins to replace their dynamic API lists.
- Update APIRegistry to manage dynamic APIs more effectively, including registration and toggling of API statuses.
- Implement authorization checks for dynamic API capabilities to ensure proper permissions.
- Restrict direct calls to certain host RPC methods from plugins for enhanced security.
- Refactor send_service to ensure fallback to current platform account when no context message is available.
This commit is contained in:
DrSmoothl
2026-03-23 21:01:55 +08:00
parent d13767ee21
commit 7a304ba549
11 changed files with 771 additions and 200 deletions

View File

@@ -82,7 +82,61 @@ async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
) )
explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq")) explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender"] assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"]
finally:
await manager.stop()
@pytest.mark.asyncio
async def test_platform_io_broadcasts_to_plugin_and_legacy_driver(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""同一路由命中插件驱动与 legacy driver 时,应同时广播发送。"""
manager = PlatformIOManager()
legacy_calls: list[dict[str, Any]] = []
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
"""记录 legacy driver 调用。"""
legacy_calls.append({"message": message, "show_log": show_log})
return True
monkeypatch.setattr(
uni_message_sender,
"send_prepared_message_to_platform",
_fake_send_prepared_message_to_platform,
)
try:
await manager.ensure_send_pipeline_ready()
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
await manager.add_driver(plugin_driver)
manager.bind_send_route(
RouteBinding(
route_key=RouteKey(platform="qq"),
driver_id=plugin_driver.driver_id,
driver_kind=plugin_driver.descriptor.kind,
)
)
message = type("FakeMessage", (), {"message_id": "message-1"})()
batch = await manager.send_message(
message=message,
route_key=RouteKey(platform="qq"),
metadata={"show_log": False},
)
assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [
"legacy.send.qq",
"plugin.qq.sender",
]
assert batch.failed_receipts == []
assert len(legacy_calls) == 1
assert legacy_calls[0]["message"] is message
assert legacy_calls[0]["show_log"] is False
finally: finally:
await manager.stop() await manager.stop()

View File

@@ -292,3 +292,233 @@ async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
) )
assert enable_result["success"] is True assert enable_result["success"] is True
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
@pytest.mark.asyncio
async def test_api_registry_supports_multiple_versions_with_distinct_handlers(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""同名 API 不同版本应可并存,并按版本路由到不同处理器。"""
provider_supervisor = PluginSupervisor(plugin_dirs=[])
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(
provider_supervisor,
"provider",
[
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML v1",
"version": "1",
"public": True,
"handler_name": "handle_render_html_v1",
},
},
{
"name": "render_html",
"component_type": "API",
"metadata": {
"description": "渲染 HTML v2",
"version": "2",
"public": True,
"handler_name": "handle_render_html_v2",
},
},
],
)
await _register_plugin(consumer_supervisor, "consumer", [])
captured: Dict[str, Any] = {}
async def fake_invoke_api(
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟多版本 API 调用。"""
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
manager = _build_manager(provider_supervisor, consumer_supervisor)
ambiguous_result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.render_html",
"args": {"html": "<div>Hello</div>"},
},
)
assert ambiguous_result["success"] is False
assert "多个版本" in str(ambiguous_result["error"])
disable_ambiguous_result = await manager._cap_component_disable(
"consumer",
"component.disable",
{
"name": "provider.render_html",
"component_type": "API",
"scope": "global",
"stream_id": "",
},
)
assert disable_ambiguous_result["success"] is False
assert "多个版本" in str(disable_ambiguous_result["error"])
disable_v1_result = await manager._cap_component_disable(
"consumer",
"component.disable",
{
"name": "provider.render_html",
"component_type": "API",
"scope": "global",
"stream_id": "",
"version": "1",
},
)
assert disable_v1_result["success"] is True
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None
result = await manager._cap_api_call(
"consumer",
"api.call",
{
"api_name": "provider.render_html",
"version": "2",
"args": {"html": "<div>Hello</div>"},
},
)
assert result == {"success": True, "result": {"image": "ok"}}
assert captured["plugin_id"] == "provider"
assert captured["component_name"] == "handle_render_html_v2"
assert captured["args"] == {"html": "<div>Hello</div>"}
@pytest.mark.asyncio
async def test_api_replace_dynamic_can_offline_removed_entries(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""动态 API 替换后,被移除的 API 应返回明确下线错误。"""
supervisor = PluginSupervisor(plugin_dirs=[])
await _register_plugin(supervisor, "provider", [])
manager = _build_manager(supervisor)
captured: Dict[str, Any] = {}
async def fake_invoke_api(
plugin_id: str,
component_name: str,
args: Dict[str, Any] | None = None,
timeout_ms: int = 30000,
) -> Any:
"""模拟动态 API 调用。"""
captured["plugin_id"] = plugin_id
captured["component_name"] = component_name
captured["args"] = args or {}
captured["timeout_ms"] = timeout_ms
return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}})
monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api)
replace_result = await manager._cap_api_replace_dynamic(
"provider",
"api.replace_dynamic",
{
"apis": [
{
"name": "mcp.search",
"type": "API",
"metadata": {
"version": "1",
"public": True,
"handler_name": "dynamic_search",
},
},
{
"name": "mcp.read",
"type": "API",
"metadata": {
"version": "1",
"public": True,
"handler_name": "dynamic_read",
},
},
],
"offline_reason": "MCP 服务器已关闭",
},
)
assert replace_result["success"] is True
assert replace_result["count"] == 2
list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
assert {(item["name"], item["version"]) for item in list_result["apis"]} == {
("mcp.read", "1"),
("mcp.search", "1"),
}
call_result = await manager._cap_api_call(
"provider",
"api.call",
{
"api_name": "provider.mcp.search",
"version": "1",
"args": {"query": "hello"},
},
)
assert call_result == {"success": True, "result": {"ok": True}}
assert captured["component_name"] == "dynamic_search"
assert captured["args"]["query"] == "hello"
assert captured["args"]["__maibot_api_name__"] == "mcp.search"
assert captured["args"]["__maibot_api_version__"] == "1"
second_replace_result = await manager._cap_api_replace_dynamic(
"provider",
"api.replace_dynamic",
{
"apis": [
{
"name": "mcp.read",
"type": "API",
"metadata": {
"version": "1",
"public": True,
"handler_name": "dynamic_read",
},
}
],
"offline_reason": "MCP 服务器已关闭",
},
)
assert second_replace_result["success"] is True
assert second_replace_result["count"] == 1
assert second_replace_result["offlined"] == 1
offlined_call_result = await manager._cap_api_call(
"provider",
"api.call",
{
"api_name": "provider.mcp.search",
"version": "1",
"args": {},
},
)
assert offlined_call_result["success"] is False
assert "MCP 服务器已关闭" in str(offlined_call_result["error"])
list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == {
("mcp.read", "1"),
}

View File

@@ -73,6 +73,19 @@ def _build_target_stream() -> BotChatSession:
) )
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
assert metadata["platform_io_account_id"] == "bot-qq"
assert metadata["platform_io_target_user_id"] == "target-user"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None: async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
"""send service 应将发送职责统一交给 Platform IO。""" """send service 应将发送职责统一交给 Platform IO。"""

View File

@@ -394,23 +394,22 @@ class PlatformIOManager:
""" """
drivers: List[PlatformIODriver] = [] drivers: List[PlatformIODriver] = []
seen_driver_ids: set[str] = set()
for binding in self._send_route_table.resolve_bindings(route_key): for binding in self._send_route_table.resolve_bindings(route_key):
driver = self._driver_registry.get(binding.driver_id) driver = self._driver_registry.get(binding.driver_id)
if driver is not None: if driver is not None and driver.driver_id not in seen_driver_ids:
drivers.append(driver) drivers.append(driver)
if drivers: seen_driver_ids.add(driver.driver_id)
return drivers
fallback_driver = self._legacy_send_drivers.get(route_key.platform) fallback_driver = self._legacy_send_drivers.get(route_key.platform)
if fallback_driver is None: if fallback_driver is not None:
return [] descriptor = fallback_driver.descriptor
account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id)
scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope)
if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids:
drivers.append(fallback_driver)
descriptor = fallback_driver.descriptor return drivers
if descriptor.account_id is not None and route_key.account_id not in (None, descriptor.account_id):
return []
if descriptor.scope is not None and route_key.scope not in (None, descriptor.scope):
return []
return [fallback_driver]
@staticmethod @staticmethod
def build_route_key_from_message(message: "SessionMessage") -> RouteKey: def build_route_key_from_message(message: "SessionMessage") -> RouteKey:

View File

@@ -92,11 +92,24 @@ class OutboundTracker:
raise ValueError("ttl_seconds 必须大于 0") raise ValueError("ttl_seconds 必须大于 0")
self._ttl_seconds = ttl_seconds self._ttl_seconds = ttl_seconds
self._pending: Dict[str, PendingOutboundRecord] = {} self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {}
self._pending_expire_heap: List[Tuple[float, str]] = [] self._pending_expire_heap: List[Tuple[float, str, str]] = []
self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {} self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {}
self._receipt_expire_heap: List[Tuple[float, str]] = [] self._receipt_expire_heap: List[Tuple[float, str]] = []
@staticmethod
def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]:
"""构造单条出站跟踪记录的唯一键。
Args:
internal_message_id: 内部消息 ID。
driver_id: 负责当前投递的驱动 ID。
Returns:
Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。
"""
return internal_message_id, driver_id
def begin_tracking( def begin_tracking(
self, self,
internal_message_id: str, internal_message_id: str,
@@ -116,13 +129,15 @@ class OutboundTracker:
PendingOutboundRecord: 新创建的待完成记录。 PendingOutboundRecord: 新创建的待完成记录。
Raises: Raises:
ValueError: 当同一个 ``internal_message_id`` 已经存在未完成记录时抛出。 ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在
未完成记录时抛出。
""" """
now = time.monotonic() now = time.monotonic()
self._cleanup_expired(now) self._cleanup_expired(now)
pending_key = self._build_pending_key(internal_message_id, driver_id)
if internal_message_id in self._pending: if pending_key in self._pending:
raise ValueError(f"消息 {internal_message_id} 已存在未完成的出站跟踪记录") raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id}已存在未完成的出站跟踪记录")
expires_at = now + self._ttl_seconds expires_at = now + self._ttl_seconds
record = PendingOutboundRecord( record = PendingOutboundRecord(
@@ -133,8 +148,8 @@ class OutboundTracker:
expires_at=expires_at, expires_at=expires_at,
metadata=metadata or {}, metadata=metadata or {},
) )
self._pending[internal_message_id] = record self._pending[pending_key] = record
heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id)) heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id))
return record return record
def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]: def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]:
@@ -149,7 +164,19 @@ class OutboundTracker:
now = time.monotonic() now = time.monotonic()
self._cleanup_expired(now) self._cleanup_expired(now)
pending_record = self._pending.pop(receipt.internal_message_id, None) pending_record: Optional[PendingOutboundRecord] = None
if receipt.driver_id:
pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id)
pending_record = self._pending.pop(pending_key, None)
else:
matched_records = [
key
for key, record in self._pending.items()
if record.internal_message_id == receipt.internal_message_id
]
if len(matched_records) == 1:
pending_record = self._pending.pop(matched_records[0], None)
if receipt.external_message_id: if receipt.external_message_id:
expires_at = now + self._ttl_seconds expires_at = now + self._ttl_seconds
self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt( self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt(
@@ -160,17 +187,33 @@ class OutboundTracker:
heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id)) heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id))
return pending_record return pending_record
def get_pending(self, internal_message_id: str) -> Optional[PendingOutboundRecord]: def get_pending(
self,
internal_message_id: str,
driver_id: Optional[str] = None,
) -> Optional[PendingOutboundRecord]:
"""根据内部消息 ID 查询待完成记录。 """根据内部消息 ID 查询待完成记录。
Args: Args:
internal_message_id: 要查询的内部消息 ID。 internal_message_id: 要查询的内部消息 ID。
driver_id: 可选的驱动 ID提供后仅返回该驱动上的待完成记录。
Returns: Returns:
Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。 Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。
""" """
self._cleanup_expired(time.monotonic()) self._cleanup_expired(time.monotonic())
return self._pending.get(internal_message_id)
if driver_id:
return self._pending.get(self._build_pending_key(internal_message_id, driver_id))
matched_records = [
record
for record in self._pending.values()
if record.internal_message_id == internal_message_id
]
if len(matched_records) == 1:
return matched_records[0]
return None
def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]: def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]:
"""根据外部平台消息 ID 查询已完成回执。 """根据外部平台消息 ID 查询已完成回执。
@@ -213,13 +256,14 @@ class OutboundTracker:
``expires_at`` 对比,跳过这类旧节点。 ``expires_at`` 对比,跳过这类旧节点。
""" """
while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now: while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now:
expires_at, internal_message_id = heapq.heappop(self._pending_expire_heap) expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap)
current_record = self._pending.get(internal_message_id) pending_key = self._build_pending_key(internal_message_id, driver_id)
current_record = self._pending.get(pending_key)
if current_record is None: if current_record is None:
continue continue
if current_record.expires_at != expires_at: if current_record.expires_at != expires_at:
continue continue
self._pending.pop(internal_message_id, None) self._pending.pop(pending_key, None)
def _cleanup_expired_receipts(self, now: float) -> None: def _cleanup_expired_receipts(self, now: float) -> None:
"""清理已经过期的回执索引。 """清理已经过期的回执索引。

View File

@@ -72,6 +72,8 @@ class RuntimeComponentCapabilityMixin:
"version": entry.version, "version": entry.version,
"public": entry.public, "public": entry.public,
"enabled": entry.enabled, "enabled": entry.enabled,
"dynamic": entry.dynamic,
"offline_reason": entry.offline_reason,
"metadata": dict(entry.metadata), "metadata": dict(entry.metadata),
} }
@@ -109,6 +111,32 @@ class RuntimeComponentCapabilityMixin:
return entry.plugin_id == caller_plugin_id or entry.public return entry.plugin_id == caller_plugin_id or entry.public
@staticmethod
def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]:
"""规范化 API 名称与版本参数。
支持在 ``api_name`` 中直接携带 ``@version`` 后缀。
"""
normalized_api_name = str(api_name or "").strip()
normalized_version = str(version or "").strip()
if normalized_api_name and not normalized_version and "@" in normalized_api_name:
candidate_name, candidate_version = normalized_api_name.rsplit("@", 1)
candidate_name = candidate_name.strip()
candidate_version = candidate_version.strip()
if candidate_name and candidate_version:
normalized_api_name = candidate_name
normalized_version = candidate_version
return normalized_api_name, normalized_version
@staticmethod
def _build_api_unavailable_error(entry: "APIEntry") -> str:
"""构造 API 当前不可用时的错误信息。"""
if entry.offline_reason:
return entry.offline_reason
return f"API {entry.registry_key} 当前不可用"
def _resolve_api_target( def _resolve_api_target(
self: _RuntimeComponentManagerProtocol, self: _RuntimeComponentManagerProtocol,
caller_plugin_id: str, caller_plugin_id: str,
@@ -127,8 +155,7 @@ class RuntimeComponentCapabilityMixin:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
""" """
normalized_api_name = str(api_name or "").strip() normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version)
normalized_version = str(version or "").strip()
if not normalized_api_name: if not normalized_api_name:
return None, None, "缺少必要参数 api_name" return None, None, "缺少必要参数 api_name"
@@ -142,34 +169,61 @@ class RuntimeComponentCapabilityMixin:
if supervisor is None: if supervisor is None:
return None, None, f"未找到 API 提供方插件: {target_plugin_id}" return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
entry = supervisor.api_registry.get_api( entries = supervisor.api_registry.get_apis(
plugin_id=target_plugin_id, plugin_id=target_plugin_id,
name=target_api_name, name=target_api_name,
enabled_only=True, version=normalized_version,
enabled_only=False,
) )
if entry is None: visible_enabled_entries = [
return None, None, f"未找到 API: {normalized_api_name}" entry
if normalized_version and entry.version != normalized_version: for entry in entries
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}" if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
if not self._is_api_visible_to_plugin(entry, caller_plugin_id): ]
visible_disabled_entries = [
entry
for entry in entries
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled
]
if len(visible_enabled_entries) == 1:
return supervisor, visible_enabled_entries[0], None
if len(visible_enabled_entries) > 1:
return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version"
if visible_disabled_entries:
if len(visible_disabled_entries) == 1:
return None, None, self._build_api_unavailable_error(visible_disabled_entries[0])
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries):
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
return supervisor, entry, None if normalized_version:
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
return None, None, f"未找到 API: {normalized_api_name}"
visible_matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
hidden_match_exists = False hidden_match_exists = False
for supervisor in self.supervisors: for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(name=normalized_api_name, enabled_only=True): for entry in supervisor.api_registry.get_apis(
if normalized_version and entry.version != normalized_version: name=normalized_api_name,
continue version=normalized_version,
enabled_only=False,
):
if self._is_api_visible_to_plugin(entry, caller_plugin_id): if self._is_api_visible_to_plugin(entry, caller_plugin_id):
visible_matches.append((supervisor, entry)) if entry.enabled:
visible_enabled_matches.append((supervisor, entry))
else:
visible_disabled_matches.append((supervisor, entry))
else: else:
hidden_match_exists = True hidden_match_exists = True
if len(visible_matches) == 1: if len(visible_enabled_matches) == 1:
return visible_matches[0][0], visible_matches[0][1], None return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None
if len(visible_matches) > 1: if len(visible_enabled_matches) > 1:
return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name" return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version"
if visible_disabled_matches:
if len(visible_disabled_matches) == 1:
return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1])
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version"
if hidden_match_exists: if hidden_match_exists:
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用" return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
if normalized_version: if normalized_version:
@@ -179,18 +233,20 @@ class RuntimeComponentCapabilityMixin:
def _resolve_api_toggle_target( def _resolve_api_toggle_target(
self: _RuntimeComponentManagerProtocol, self: _RuntimeComponentManagerProtocol,
name: str, name: str,
version: str = "",
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]: ) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
"""解析需要启用或禁用的 API 组件。 """解析需要启用或禁用的 API 组件。
Args: Args:
name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。 name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
version: 可选的 API 版本。
Returns: Returns:
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]: tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。 解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
""" """
normalized_name = str(name or "").strip() normalized_name, normalized_version = self._normalize_api_reference(name, version)
if not normalized_name: if not normalized_name:
return None, None, "缺少必要参数 name" return None, None, "缺少必要参数 name"
@@ -204,24 +260,31 @@ class RuntimeComponentCapabilityMixin:
if supervisor is None: if supervisor is None:
return None, None, f"未找到 API 提供方插件: {plugin_id}" return None, None, f"未找到 API 提供方插件: {plugin_id}"
entry = supervisor.api_registry.get_api( entries = supervisor.api_registry.get_apis(
plugin_id=plugin_id, plugin_id=plugin_id,
name=api_name, name=api_name,
version=normalized_version,
enabled_only=False, enabled_only=False,
) )
if entry is None: if not entries:
return None, None, f"未找到 API: {normalized_name}" return None, None, f"未找到 API: {normalized_name}"
return supervisor, entry, None if len(entries) > 1:
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
return supervisor, entries[0], None
matches: List[tuple["PluginSupervisor", "APIEntry"]] = [] matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
for supervisor in self.supervisors: for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis(name=normalized_name, enabled_only=False): for entry in supervisor.api_registry.get_apis(
name=normalized_name,
version=normalized_version,
enabled_only=False,
):
matches.append((supervisor, entry)) matches.append((supervisor, entry))
if len(matches) == 1: if len(matches) == 1:
return matches[0][0], matches[0][1], None return matches[0][0], matches[0][1], None
if len(matches) > 1: if len(matches) > 1:
return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name" return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version"
return None, None, f"未找到 API: {normalized_name}" return None, None, f"未找到 API: {normalized_name}"
async def _cap_component_get_all_plugins( async def _cap_component_get_all_plugins(
@@ -326,6 +389,7 @@ class RuntimeComponentCapabilityMixin:
) -> Any: ) -> Any:
name: str = args.get("name", "") name: str = args.get("name", "")
component_type: str = args.get("component_type", "") component_type: str = args.get("component_type", "")
version: str = args.get("version", "")
scope: str = args.get("scope", "global") scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "") stream_id: str = args.get("stream_id", "")
if not name or not component_type: if not name or not component_type:
@@ -334,10 +398,10 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"} return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type): if self._is_api_component_type(component_type):
supervisor, api_entry, error = self._resolve_api_toggle_target(name) supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None: if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"} return {"success": False, "error": error or f"未找到 API: {name}"}
supervisor.api_registry.toggle_api_status(api_entry.full_name, True) supervisor.api_registry.toggle_api_status(api_entry.registry_key, True)
return {"success": True} return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type) comp, error = self._resolve_component_toggle_target(name, component_type)
@@ -352,6 +416,7 @@ class RuntimeComponentCapabilityMixin:
) -> Any: ) -> Any:
name: str = args.get("name", "") name: str = args.get("name", "")
component_type: str = args.get("component_type", "") component_type: str = args.get("component_type", "")
version: str = args.get("version", "")
scope: str = args.get("scope", "global") scope: str = args.get("scope", "global")
stream_id: str = args.get("stream_id", "") stream_id: str = args.get("stream_id", "")
if not name or not component_type: if not name or not component_type:
@@ -360,10 +425,10 @@ class RuntimeComponentCapabilityMixin:
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"} return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
if self._is_api_component_type(component_type): if self._is_api_component_type(component_type):
supervisor, api_entry, error = self._resolve_api_toggle_target(name) supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
if supervisor is None or api_entry is None: if supervisor is None or api_entry is None:
return {"success": False, "error": error or f"未找到 API: {name}"} return {"success": False, "error": error or f"未找到 API: {name}"}
supervisor.api_registry.toggle_api_status(api_entry.full_name, False) supervisor.api_registry.toggle_api_status(api_entry.registry_key, False)
return {"success": True} return {"success": True}
comp, error = self._resolve_component_toggle_target(name, component_type) comp, error = self._resolve_component_toggle_target(name, component_type)
@@ -488,11 +553,17 @@ class RuntimeComponentCapabilityMixin:
if supervisor is None or entry is None: if supervisor is None or entry is None:
return {"success": False, "error": error or "API 解析失败"} return {"success": False, "error": error or "API 解析失败"}
invoke_args = dict(api_args)
if entry.dynamic:
invoke_args.setdefault("__maibot_api_name__", entry.name)
invoke_args.setdefault("__maibot_api_full_name__", entry.full_name)
invoke_args.setdefault("__maibot_api_version__", entry.version)
try: try:
response = await supervisor.invoke_api( response = await supervisor.invoke_api(
plugin_id=entry.plugin_id, plugin_id=entry.plugin_id,
component_name=entry.name, component_name=entry.handler_name,
args=api_args, args=invoke_args,
timeout_ms=30000, timeout_ms=30000,
) )
except Exception as exc: except Exception as exc:
@@ -555,10 +626,16 @@ class RuntimeComponentCapabilityMixin:
del capability del capability
target_plugin_id = str(args.get("plugin_id", "") or "").strip() target_plugin_id = str(args.get("plugin_id", "") or "").strip()
api_name, version = self._normalize_api_reference(
str(args.get("api_name", args.get("name", "")) or ""),
str(args.get("version", "") or ""),
)
apis: List[Dict[str, Any]] = [] apis: List[Dict[str, Any]] = []
for supervisor in self.supervisors: for supervisor in self.supervisors:
for entry in supervisor.api_registry.get_apis( for entry in supervisor.api_registry.get_apis(
plugin_id=target_plugin_id or None, plugin_id=target_plugin_id or None,
name=api_name,
version=version,
enabled_only=True, enabled_only=True,
): ):
if not self._is_api_visible_to_plugin(entry, plugin_id): if not self._is_api_visible_to_plugin(entry, plugin_id):
@@ -567,3 +644,75 @@ class RuntimeComponentCapabilityMixin:
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"]))) apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
return {"success": True, "apis": apis} return {"success": True, "apis": apis}
async def _cap_api_replace_dynamic(
self: _RuntimeComponentManagerProtocol,
plugin_id: str,
capability: str,
args: Dict[str, Any],
) -> Any:
"""替换插件自行维护的动态 API 列表。"""
del capability
raw_apis = args.get("apis", [])
offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线"
if not isinstance(raw_apis, list):
return {"success": False, "error": "参数 apis 必须为列表"}
try:
supervisor = self._get_supervisor_for_plugin(plugin_id)
except RuntimeError as exc:
return {"success": False, "error": str(exc)}
if supervisor is None:
return {"success": False, "error": f"未找到插件: {plugin_id}"}
normalized_components: List[Dict[str, Any]] = []
seen_registry_keys: set[str] = set()
for index, raw_api in enumerate(raw_apis):
if not isinstance(raw_api, dict):
return {"success": False, "error": f"apis[{index}] 必须为字典"}
api_name = str(raw_api.get("name", "") or "").strip()
component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip()
if not api_name:
return {"success": False, "error": f"apis[{index}] 缺少 name"}
if not self._is_api_component_type(component_type):
return {"success": False, "error": f"apis[{index}] 不是 API 组件"}
metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {}
normalized_metadata = dict(metadata)
normalized_metadata["dynamic"] = True
version = str(normalized_metadata.get("version", "1") or "1").strip() or "1"
registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version)
if registry_key in seen_registry_keys:
return {"success": False, "error": f"动态 API 重复声明: {registry_key}"}
seen_registry_keys.add(registry_key)
existing_entry = supervisor.api_registry.get_api(
plugin_id,
api_name,
version=version,
enabled_only=False,
)
if existing_entry is not None and not existing_entry.dynamic:
return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"}
normalized_components.append(
{
"name": api_name,
"component_type": "API",
"metadata": normalized_metadata,
}
)
registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis(
plugin_id,
normalized_components,
offline_reason=offline_reason,
)
return {
"success": True,
"count": registered_count,
"offlined": offlined_count,
}

View File

@@ -77,6 +77,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
_register("api.call", manager._cap_api_call) _register("api.call", manager._cap_api_call)
_register("api.get", manager._cap_api_get) _register("api.get", manager._cap_api_get)
_register("api.list", manager._cap_api_list) _register("api.list", manager._cap_api_list)
_register("api.replace_dynamic", manager._cap_api_replace_dynamic)
_register("component.get_all_plugins", manager._cap_component_get_all_plugins) _register("component.get_all_plugins", manager._cap_component_get_all_plugins)
_register("component.get_plugin_info", manager._cap_component_get_plugin_info) _register("component.get_plugin_info", manager._cap_component_get_plugin_info)

View File

@@ -1,45 +1,60 @@
"""Host 侧插件 API 动态注册表。""" """Host 侧插件 API 动态注册表。"""
from typing import Any, Dict, List, Optional, Set from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.api_registry") logger = get_logger("plugin_runtime.host.api_registry")
@dataclass(slots=True)
class APIEntry: class APIEntry:
"""API 组件条目。""" """API 组件条目。"""
__slots__ = ( name: str
"description", plugin_id: str
"disabled_session", description: str = ""
"enabled", version: str = "1"
"full_name", public: bool = False
"metadata", metadata: Dict[str, Any] = field(default_factory=dict)
"name", enabled: bool = True
"plugin_id", handler_name: str = ""
"public", dynamic: bool = False
"version", offline_reason: str = ""
) disabled_session: Set[str] = field(default_factory=set)
full_name: str = field(init=False)
registry_key: str = field(init=False)
def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None: def __post_init__(self) -> None:
"""初始化 API 组件条目。 """规范化 API 条目字段。"""
Args: self.name = str(self.name or "").strip()
name: API 名称。 self.plugin_id = str(self.plugin_id or "").strip()
plugin_id: 所属插件 ID。 self.description = str(self.description or "").strip()
metadata: API 元数据。 self.version = str(self.version or "1").strip() or "1"
""" self.handler_name = str(self.handler_name or self.name).strip() or self.name
self.offline_reason = str(self.offline_reason or "").strip()
self.full_name = f"{self.plugin_id}.{self.name}"
self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version)
self.name: str = name @classmethod
self.full_name: str = f"{plugin_id}.{name}" def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry":
self.plugin_id: str = plugin_id """根据 Runner 上报的元数据构造 API 条目。"""
self.description: str = str(metadata.get("description", "") or "")
self.version: str = str(metadata.get("version", "1") or "1").strip() or "1" safe_metadata = dict(metadata)
self.public: bool = bool(metadata.get("public", False)) return cls(
self.metadata: Dict[str, Any] = dict(metadata) name=name,
self.enabled: bool = bool(metadata.get("enabled", True)) plugin_id=plugin_id,
self.disabled_session: Set[str] = set() description=str(safe_metadata.get("description", "") or ""),
version=str(safe_metadata.get("version", "1") or "1"),
public=bool(safe_metadata.get("public", False)),
metadata=safe_metadata,
enabled=bool(safe_metadata.get("enabled", True)),
handler_name=str(safe_metadata.get("handler_name", name) or name),
dynamic=bool(safe_metadata.get("dynamic", False)),
offline_reason=str(safe_metadata.get("offline_reason", "") or ""),
)
class APIRegistry: class APIRegistry:
@@ -53,6 +68,7 @@ class APIRegistry:
"""初始化 API 注册表。""" """初始化 API 注册表。"""
self._apis: Dict[str, APIEntry] = {} self._apis: Dict[str, APIEntry] = {}
self._by_full_name: Dict[str, List[APIEntry]] = {}
self._by_plugin: Dict[str, List[APIEntry]] = {} self._by_plugin: Dict[str, List[APIEntry]] = {}
self._by_name: Dict[str, List[APIEntry]] = {} self._by_name: Dict[str, List[APIEntry]] = {}
@@ -60,75 +76,75 @@ class APIRegistry:
"""清空全部 API 注册状态。""" """清空全部 API 注册状态。"""
self._apis.clear() self._apis.clear()
self._by_full_name.clear()
self._by_plugin.clear() self._by_plugin.clear()
self._by_name.clear() self._by_name.clear()
@staticmethod @staticmethod
def _is_api_component(component_type: Any) -> bool: def _is_api_component(component_type: Any) -> bool:
"""判断组件声明是否属于 API。 """判断组件声明是否属于 API。"""
Args:
component_type: 原始组件类型值。
Returns:
bool: 是否为 API 组件。
"""
return str(component_type or "").strip().upper() == "API" return str(component_type or "").strip().upper() == "API"
@staticmethod
def _normalize_query_version(version: Any) -> str:
"""规范化查询使用的版本字符串。"""
return str(version or "").strip()
@classmethod
def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]:
"""解析可能带 ``@version`` 后缀的 API 引用。"""
normalized_reference = str(reference or "").strip()
normalized_version = cls._normalize_query_version(version)
if normalized_reference and not normalized_version and "@" in normalized_reference:
candidate_reference, candidate_version = normalized_reference.rsplit("@", 1)
candidate_reference = candidate_reference.strip()
candidate_version = candidate_version.strip()
if candidate_reference and candidate_version:
normalized_reference = candidate_reference
normalized_version = candidate_version
return normalized_reference, normalized_version
@staticmethod
def build_registry_key(plugin_id: str, name: str, version: str) -> str:
"""构造 API 注册表唯一键。"""
normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}"
normalized_version = str(version or "1").strip() or "1"
return f"{normalized_full_name}@{normalized_version}"
@staticmethod @staticmethod
def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool: def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
"""判断 API 条目当前是否处于启用状态。 """判断 API 条目当前是否处于启用状态。"""
Args:
entry: 待检查的 API 条目。
session_id: 可选的会话 ID。
Returns:
bool: 当前是否可用。
"""
if session_id and session_id in entry.disabled_session: if session_id and session_id in entry.disabled_session:
return False return False
return entry.enabled return entry.enabled
def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool: def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个 API 条目。 """注册单个 API 条目。"""
Args:
name: API 名称。
plugin_id: 所属插件 ID。
metadata: API 元数据。
Returns:
bool: 是否成功注册。
"""
normalized_name = str(name or "").strip() normalized_name = str(name or "").strip()
if not normalized_name: if not normalized_name:
logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略") logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
return False return False
entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata) entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
if entry.full_name in self._apis: existing_entry = self._apis.get(entry.registry_key)
logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目") if existing_entry is not None:
self._remove_entry(self._apis[entry.full_name]) logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目")
self._remove_entry(existing_entry)
self._apis[entry.full_name] = entry self._apis[entry.registry_key] = entry
self._by_full_name.setdefault(entry.full_name, []).append(entry)
self._by_plugin.setdefault(plugin_id, []).append(entry) self._by_plugin.setdefault(plugin_id, []).append(entry)
self._by_name.setdefault(entry.name, []).append(entry) self._by_name.setdefault(entry.name, []).append(entry)
return True return True
def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int: def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册某个插件声明的全部 API。 """批量注册某个插件声明的全部 API。"""
Args:
plugin_id: 插件 ID。
components: 插件组件声明列表。
Returns:
int: 成功注册的 API 数量。
"""
count = 0 count = 0
for component in components: for component in components:
@@ -142,14 +158,60 @@ class APIRegistry:
count += 1 count += 1
return count return count
def replace_plugin_dynamic_apis(
self,
plugin_id: str,
components: List[Dict[str, Any]],
*,
offline_reason: str = "动态 API 已下线",
) -> Tuple[int, int]:
"""替换指定插件当前声明的动态 API 集合。"""
normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线"
desired_registry_keys: Set[str] = set()
registered_count = 0
for component in components:
if not self._is_api_component(component.get("component_type")):
continue
metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}
dynamic_metadata = dict(metadata)
dynamic_metadata["dynamic"] = True
dynamic_metadata.pop("offline_reason", None)
entry = APIEntry.from_metadata(
name=str(component.get("name", "") or ""),
plugin_id=plugin_id,
metadata=dynamic_metadata,
)
desired_registry_keys.add(entry.registry_key)
if self.register_api(entry.name, plugin_id, dynamic_metadata):
registered_count += 1
offlined_count = 0
for entry in list(self._by_plugin.get(plugin_id, [])):
if not entry.dynamic or entry.registry_key in desired_registry_keys:
continue
entry.enabled = False
entry.offline_reason = normalized_offline_reason
entry.metadata["offline_reason"] = normalized_offline_reason
offlined_count += 1
return registered_count, offlined_count
def _remove_entry(self, entry: APIEntry) -> None: def _remove_entry(self, entry: APIEntry) -> None:
"""从全部索引中移除单个 API 条目。 """从全部索引中移除单个 API 条目。"""
Args: self._apis.pop(entry.registry_key, None)
entry: 待移除的 API 条目。
""" full_name_entries = self._by_full_name.get(entry.full_name)
if full_name_entries is not None:
self._by_full_name[entry.full_name] = [
candidate for candidate in full_name_entries if candidate is not entry
]
if not self._by_full_name[entry.full_name]:
self._by_full_name.pop(entry.full_name, None)
self._apis.pop(entry.full_name, None)
plugin_entries = self._by_plugin.get(entry.plugin_id) plugin_entries = self._by_plugin.get(entry.plugin_id)
if plugin_entries is not None: if plugin_entries is not None:
self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry] self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
@@ -163,14 +225,7 @@ class APIRegistry:
self._by_name.pop(entry.name, None) self._by_name.pop(entry.name, None)
def remove_apis_by_plugin(self, plugin_id: str) -> int: def remove_apis_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的全部 API。 """移除某个插件的全部 API。"""
Args:
plugin_id: 目标插件 ID。
Returns:
int: 被移除的 API 数量。
"""
entries = list(self._by_plugin.get(plugin_id, [])) entries = list(self._by_plugin.get(plugin_id, []))
for entry in entries: for entry in entries:
@@ -181,49 +236,48 @@ class APIRegistry:
self, self,
full_name: str, full_name: str,
*, *,
version: str = "",
enabled_only: bool = True, enabled_only: bool = True,
session_id: Optional[str] = None, session_id: Optional[str] = None,
) -> Optional[APIEntry]: ) -> Optional[APIEntry]:
"""按完整名查询单个 API。 """按完整名查询单个 API。"""
Args: normalized_full_name, normalized_version = self._split_reference(full_name, version)
full_name: API 完整名,格式为 ``plugin_id.api_name``。 if not normalized_full_name:
enabled_only: 是否仅返回启用状态的 API。
session_id: 可选的会话 ID。
Returns:
Optional[APIEntry]: 命中时返回 API 条目。
"""
entry = self._apis.get(full_name)
if entry is None:
return None return None
if enabled_only and not self.check_api_enabled(entry, session_id):
if normalized_version:
entry = self._apis.get(f"{normalized_full_name}@{normalized_version}")
if entry is None:
return None
if enabled_only and not self.check_api_enabled(entry, session_id):
return None
return entry
candidates = list(self._by_full_name.get(normalized_full_name, []))
filtered_entries = [
entry
for entry in candidates
if not enabled_only or self.check_api_enabled(entry, session_id)
]
if len(filtered_entries) != 1:
return None return None
return entry return filtered_entries[0]
def get_api( def get_api(
self, self,
plugin_id: str, plugin_id: str,
name: str, name: str,
*, *,
version: str = "",
enabled_only: bool = True, enabled_only: bool = True,
session_id: Optional[str] = None, session_id: Optional[str] = None,
) -> Optional[APIEntry]: ) -> Optional[APIEntry]:
"""按插件 ID 和短名查询单个 API。 """按插件 ID、短名与版本查询单个 API。"""
Args:
plugin_id: 提供方插件 ID。
name: API 短名。
enabled_only: 是否仅返回启用状态的 API。
session_id: 可选的会话 ID。
Returns:
Optional[APIEntry]: 命中时返回 API 条目。
"""
return self.get_api_by_full_name( return self.get_api_by_full_name(
f"{plugin_id}.{name}", f"{plugin_id}.{name}",
version=version,
enabled_only=enabled_only, enabled_only=enabled_only,
session_id=session_id, session_id=session_id,
) )
@@ -233,22 +287,15 @@ class APIRegistry:
*, *,
plugin_id: Optional[str] = None, plugin_id: Optional[str] = None,
name: str = "", name: str = "",
version: str = "",
enabled_only: bool = True, enabled_only: bool = True,
session_id: Optional[str] = None, session_id: Optional[str] = None,
) -> List[APIEntry]: ) -> List[APIEntry]:
"""查询 API 列表。 """查询 API 列表。"""
Args:
plugin_id: 可选的插件 ID 过滤条件。
name: 可选的 API 名称过滤条件。
enabled_only: 是否仅返回启用状态的 API。
session_id: 可选的会话 ID。
Returns:
List[APIEntry]: 符合条件的 API 条目列表。
"""
normalized_name = str(name or "").strip() normalized_name = str(name or "").strip()
normalized_version = self._normalize_query_version(version)
if plugin_id: if plugin_id:
candidates = list(self._by_plugin.get(plugin_id, [])) candidates = list(self._by_plugin.get(plugin_id, []))
elif normalized_name: elif normalized_name:
@@ -258,26 +305,35 @@ class APIRegistry:
filtered_entries: List[APIEntry] = [] filtered_entries: List[APIEntry] = []
for entry in candidates: for entry in candidates:
if plugin_id and entry.plugin_id != plugin_id:
continue
if normalized_name and entry.name != normalized_name: if normalized_name and entry.name != normalized_name:
continue continue
if normalized_version and entry.version != normalized_version:
continue
if enabled_only and not self.check_api_enabled(entry, session_id): if enabled_only and not self.check_api_enabled(entry, session_id):
continue continue
filtered_entries.append(entry) filtered_entries.append(entry)
filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version))
return filtered_entries return filtered_entries
def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool: def toggle_api_status(
"""设置指定 API 的启用状态。 self,
full_name: str,
enabled: bool,
*,
version: str = "",
session_id: Optional[str] = None,
) -> bool:
"""设置指定 API 的启用状态。"""
Args: entry = self.get_api_by_full_name(
full_name: API 完整名。 full_name,
enabled: 目标启用状态。 version=version,
session_id: 可选的会话 ID仅对该会话生效。 enabled_only=False,
session_id=session_id,
Returns: )
bool: 是否设置成功。
"""
entry = self._apis.get(full_name)
if entry is None: if entry is None:
return False return False
if session_id: if session_id:
@@ -287,4 +343,7 @@ class APIRegistry:
entry.disabled_session.add(session_id) entry.disabled_session.add(session_id)
else: else:
entry.enabled = enabled entry.enabled = enabled
if enabled:
entry.offline_reason = ""
entry.metadata.pop("offline_reason", None)
return True return True

View File

@@ -7,6 +7,8 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"})
@dataclass @dataclass
class CapabilityPermissionToken: class CapabilityPermissionToken:
@@ -46,6 +48,9 @@ class AuthorizationManager:
Returns: Returns:
return (bool, str): (是否有此能力, 原因) return (bool, str): (是否有此能力, 原因)
""" """
if capability in _ALWAYS_ALLOWED_CAPABILITIES:
return True, ""
token = self._permission_tokens.get(plugin_id) token = self._permission_tokens.get(plugin_id)
if not token: if not token:
return False, f"插件 {plugin_id} 未注册能力令牌" return False, f"插件 {plugin_id} 未注册能力令牌"

View File

@@ -45,6 +45,14 @@ from src.plugin_runtime.runner.rpc_client import RPCClient
logger = get_logger("plugin_runtime.runner.main") logger = get_logger("plugin_runtime.runner.main")
_PLUGIN_ALLOWED_RAW_HOST_METHODS = frozenset(
{
"cap.call",
"host.route_message",
"host.update_message_gateway_state",
}
)
class _ContextAwarePlugin(Protocol): class _ContextAwarePlugin(Protocol):
"""支持注入运行时上下文的插件协议。 """支持注入运行时上下文的插件协议。
@@ -247,8 +255,14 @@ class PluginRunner:
logger.warning( logger.warning(
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC已强制绑定回自身身份" f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC已强制绑定回自身身份"
) )
normalized_method = str(method or "").strip()
if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS:
raise PermissionError(
f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: "
f"{normalized_method or '<empty>'}"
)
resp = await rpc_client.send_request( resp = await rpc_client.send_request(
method=method, method=normalized_method,
plugin_id=bound_plugin_id, plugin_id=bound_plugin_id,
payload=payload or {}, payload=payload or {},
) )

View File

@@ -57,20 +57,23 @@ def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[s
inherited_metadata: Dict[str, object] = {} inherited_metadata: Dict[str, object] = {}
context_message = target_stream.context.message if target_stream.context else None context_message = target_stream.context.message if target_stream.context else None
if context_message is None: if context_message is not None:
return inherited_metadata additional_config = context_message.message_info.additional_config
if isinstance(additional_config, dict):
for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS):
value = additional_config.get(key)
if value is None:
continue
normalized_value = str(value).strip()
if normalized_value:
inherited_metadata[key] = value
additional_config = context_message.message_info.additional_config # 当目标会话没有可继承的上下文消息时,至少补齐当前平台账号,
if not isinstance(additional_config, dict): # 让按 ``platform + account_id`` 绑定的路由仍有机会命中。
return inherited_metadata if not RouteKeyFactory.extract_components(inherited_metadata)[0]:
bot_account = get_bot_account(target_stream.platform)
for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS): if bot_account:
value = additional_config.get(key) inherited_metadata["platform_io_account_id"] = bot_account
if value is None:
continue
normalized_value = str(value).strip()
if normalized_value:
inherited_metadata[key] = value
if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()): if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()):
inherited_metadata["platform_io_target_group_id"] = normalized_group_id inherited_metadata["platform_io_target_group_id"] = normalized_group_id