feat: Enhance API and Outbound Tracking Functionality
- Add test for fallback to bot account in platform IO route metadata when context message is absent. - Improve PlatformIOManager to avoid duplicate driver entries and streamline fallback driver handling. - Refactor OutboundTracker to support tracking by both internal message ID and driver ID, enhancing the uniqueness of pending records. - Introduce dynamic API capabilities in RuntimeComponent, allowing plugins to replace their dynamic API lists. - Update APIRegistry to manage dynamic APIs more effectively, including registration and toggling of API statuses. - Implement authorization checks for dynamic API capabilities to ensure proper permissions. - Restrict direct calls to certain host RPC methods from plugins for enhanced security. - Refactor send_service to ensure fallback to current platform account when no context message is available.
This commit is contained in:
@@ -82,7 +82,61 @@ async def test_platform_io_uses_legacy_driver_when_no_explicit_send_route(
|
||||
)
|
||||
|
||||
explicit_drivers = manager.resolve_drivers(RouteKey(platform="qq"))
|
||||
assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender"]
|
||||
assert [driver.driver_id for driver in explicit_drivers] == ["plugin.qq.sender", "legacy.send.qq"]
|
||||
finally:
|
||||
await manager.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_io_broadcasts_to_plugin_and_legacy_driver(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""同一路由命中插件驱动与 legacy driver 时,应同时广播发送。"""
|
||||
|
||||
manager = PlatformIOManager()
|
||||
legacy_calls: list[dict[str, Any]] = []
|
||||
monkeypatch.setattr(chat_utils, "get_all_bot_accounts", lambda: {"qq": "bot-qq"})
|
||||
|
||||
async def _fake_send_prepared_message_to_platform(message: Any, show_log: bool = True) -> bool:
|
||||
"""记录 legacy driver 调用。"""
|
||||
|
||||
legacy_calls.append({"message": message, "show_log": show_log})
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(
|
||||
uni_message_sender,
|
||||
"send_prepared_message_to_platform",
|
||||
_fake_send_prepared_message_to_platform,
|
||||
)
|
||||
|
||||
try:
|
||||
await manager.ensure_send_pipeline_ready()
|
||||
|
||||
plugin_driver = _PluginDriver(driver_id="plugin.qq.sender", platform="qq")
|
||||
await manager.add_driver(plugin_driver)
|
||||
manager.bind_send_route(
|
||||
RouteBinding(
|
||||
route_key=RouteKey(platform="qq"),
|
||||
driver_id=plugin_driver.driver_id,
|
||||
driver_kind=plugin_driver.descriptor.kind,
|
||||
)
|
||||
)
|
||||
|
||||
message = type("FakeMessage", (), {"message_id": "message-1"})()
|
||||
batch = await manager.send_message(
|
||||
message=message,
|
||||
route_key=RouteKey(platform="qq"),
|
||||
metadata={"show_log": False},
|
||||
)
|
||||
|
||||
assert sorted(receipt.driver_id for receipt in batch.sent_receipts) == [
|
||||
"legacy.send.qq",
|
||||
"plugin.qq.sender",
|
||||
]
|
||||
assert batch.failed_receipts == []
|
||||
assert len(legacy_calls) == 1
|
||||
assert legacy_calls[0]["message"] is message
|
||||
assert legacy_calls[0]["show_log"] is False
|
||||
finally:
|
||||
await manager.stop()
|
||||
|
||||
|
||||
@@ -292,3 +292,233 @@ async def test_api_list_and_component_toggle_use_dedicated_registry() -> None:
|
||||
)
|
||||
assert enable_result["success"] is True
|
||||
assert provider_supervisor.api_registry.get_api("provider", "public_api", enabled_only=True) is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_registry_supports_multiple_versions_with_distinct_handlers(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""同名 API 不同版本应可并存,并按版本路由到不同处理器。"""
|
||||
|
||||
provider_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
consumer_supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(
|
||||
provider_supervisor,
|
||||
"provider",
|
||||
[
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML v1",
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "handle_render_html_v1",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "render_html",
|
||||
"component_type": "API",
|
||||
"metadata": {
|
||||
"description": "渲染 HTML v2",
|
||||
"version": "2",
|
||||
"public": True,
|
||||
"handler_name": "handle_render_html_v2",
|
||||
},
|
||||
},
|
||||
],
|
||||
)
|
||||
await _register_plugin(consumer_supervisor, "consumer", [])
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_invoke_api(
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟多版本 API 调用。"""
|
||||
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"image": "ok"}})
|
||||
|
||||
monkeypatch.setattr(provider_supervisor, "invoke_api", fake_invoke_api)
|
||||
manager = _build_manager(provider_supervisor, consumer_supervisor)
|
||||
|
||||
ambiguous_result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.render_html",
|
||||
"args": {"html": "<div>Hello</div>"},
|
||||
},
|
||||
)
|
||||
assert ambiguous_result["success"] is False
|
||||
assert "多个版本" in str(ambiguous_result["error"])
|
||||
|
||||
disable_ambiguous_result = await manager._cap_component_disable(
|
||||
"consumer",
|
||||
"component.disable",
|
||||
{
|
||||
"name": "provider.render_html",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
},
|
||||
)
|
||||
assert disable_ambiguous_result["success"] is False
|
||||
assert "多个版本" in str(disable_ambiguous_result["error"])
|
||||
|
||||
disable_v1_result = await manager._cap_component_disable(
|
||||
"consumer",
|
||||
"component.disable",
|
||||
{
|
||||
"name": "provider.render_html",
|
||||
"component_type": "API",
|
||||
"scope": "global",
|
||||
"stream_id": "",
|
||||
"version": "1",
|
||||
},
|
||||
)
|
||||
assert disable_v1_result["success"] is True
|
||||
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="1", enabled_only=True) is None
|
||||
assert provider_supervisor.api_registry.get_api("provider", "render_html", version="2", enabled_only=True) is not None
|
||||
|
||||
result = await manager._cap_api_call(
|
||||
"consumer",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.render_html",
|
||||
"version": "2",
|
||||
"args": {"html": "<div>Hello</div>"},
|
||||
},
|
||||
)
|
||||
|
||||
assert result == {"success": True, "result": {"image": "ok"}}
|
||||
assert captured["plugin_id"] == "provider"
|
||||
assert captured["component_name"] == "handle_render_html_v2"
|
||||
assert captured["args"] == {"html": "<div>Hello</div>"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_replace_dynamic_can_offline_removed_entries(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""动态 API 替换后,被移除的 API 应返回明确下线错误。"""
|
||||
|
||||
supervisor = PluginSupervisor(plugin_dirs=[])
|
||||
await _register_plugin(supervisor, "provider", [])
|
||||
manager = _build_manager(supervisor)
|
||||
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
async def fake_invoke_api(
|
||||
plugin_id: str,
|
||||
component_name: str,
|
||||
args: Dict[str, Any] | None = None,
|
||||
timeout_ms: int = 30000,
|
||||
) -> Any:
|
||||
"""模拟动态 API 调用。"""
|
||||
|
||||
captured["plugin_id"] = plugin_id
|
||||
captured["component_name"] = component_name
|
||||
captured["args"] = args or {}
|
||||
captured["timeout_ms"] = timeout_ms
|
||||
return SimpleNamespace(error=None, payload={"success": True, "result": {"ok": True}})
|
||||
|
||||
monkeypatch.setattr(supervisor, "invoke_api", fake_invoke_api)
|
||||
|
||||
replace_result = await manager._cap_api_replace_dynamic(
|
||||
"provider",
|
||||
"api.replace_dynamic",
|
||||
{
|
||||
"apis": [
|
||||
{
|
||||
"name": "mcp.search",
|
||||
"type": "API",
|
||||
"metadata": {
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "dynamic_search",
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "mcp.read",
|
||||
"type": "API",
|
||||
"metadata": {
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "dynamic_read",
|
||||
},
|
||||
},
|
||||
],
|
||||
"offline_reason": "MCP 服务器已关闭",
|
||||
},
|
||||
)
|
||||
|
||||
assert replace_result["success"] is True
|
||||
assert replace_result["count"] == 2
|
||||
list_result = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
|
||||
assert {(item["name"], item["version"]) for item in list_result["apis"]} == {
|
||||
("mcp.read", "1"),
|
||||
("mcp.search", "1"),
|
||||
}
|
||||
|
||||
call_result = await manager._cap_api_call(
|
||||
"provider",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.mcp.search",
|
||||
"version": "1",
|
||||
"args": {"query": "hello"},
|
||||
},
|
||||
)
|
||||
assert call_result == {"success": True, "result": {"ok": True}}
|
||||
assert captured["component_name"] == "dynamic_search"
|
||||
assert captured["args"]["query"] == "hello"
|
||||
assert captured["args"]["__maibot_api_name__"] == "mcp.search"
|
||||
assert captured["args"]["__maibot_api_version__"] == "1"
|
||||
|
||||
second_replace_result = await manager._cap_api_replace_dynamic(
|
||||
"provider",
|
||||
"api.replace_dynamic",
|
||||
{
|
||||
"apis": [
|
||||
{
|
||||
"name": "mcp.read",
|
||||
"type": "API",
|
||||
"metadata": {
|
||||
"version": "1",
|
||||
"public": True,
|
||||
"handler_name": "dynamic_read",
|
||||
},
|
||||
}
|
||||
],
|
||||
"offline_reason": "MCP 服务器已关闭",
|
||||
},
|
||||
)
|
||||
|
||||
assert second_replace_result["success"] is True
|
||||
assert second_replace_result["count"] == 1
|
||||
assert second_replace_result["offlined"] == 1
|
||||
|
||||
offlined_call_result = await manager._cap_api_call(
|
||||
"provider",
|
||||
"api.call",
|
||||
{
|
||||
"api_name": "provider.mcp.search",
|
||||
"version": "1",
|
||||
"args": {},
|
||||
},
|
||||
)
|
||||
assert offlined_call_result["success"] is False
|
||||
assert "MCP 服务器已关闭" in str(offlined_call_result["error"])
|
||||
|
||||
list_after_replace = await manager._cap_api_list("provider", "api.list", {"plugin_id": "provider"})
|
||||
assert {(item["name"], item["version"]) for item in list_after_replace["apis"]} == {
|
||||
("mcp.read", "1"),
|
||||
}
|
||||
|
||||
@@ -73,6 +73,19 @@ def _build_target_stream() -> BotChatSession:
|
||||
)
|
||||
|
||||
|
||||
def test_inherit_platform_io_route_metadata_falls_back_to_bot_account(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""没有上下文消息时,也应回填当前平台账号用于账号级路由命中。"""
|
||||
|
||||
monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq" if platform == "qq" else "")
|
||||
|
||||
metadata = send_service._inherit_platform_io_route_metadata(_build_target_stream())
|
||||
|
||||
assert metadata["platform_io_account_id"] == "bot-qq"
|
||||
assert metadata["platform_io_target_user_id"] == "target-user"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""send service 应将发送职责统一交给 Platform IO。"""
|
||||
|
||||
@@ -394,23 +394,22 @@ class PlatformIOManager:
|
||||
"""
|
||||
|
||||
drivers: List[PlatformIODriver] = []
|
||||
seen_driver_ids: set[str] = set()
|
||||
for binding in self._send_route_table.resolve_bindings(route_key):
|
||||
driver = self._driver_registry.get(binding.driver_id)
|
||||
if driver is not None:
|
||||
if driver is not None and driver.driver_id not in seen_driver_ids:
|
||||
drivers.append(driver)
|
||||
if drivers:
|
||||
return drivers
|
||||
seen_driver_ids.add(driver.driver_id)
|
||||
|
||||
fallback_driver = self._legacy_send_drivers.get(route_key.platform)
|
||||
if fallback_driver is None:
|
||||
return []
|
||||
if fallback_driver is not None:
|
||||
descriptor = fallback_driver.descriptor
|
||||
account_matches = descriptor.account_id is None or route_key.account_id in (None, descriptor.account_id)
|
||||
scope_matches = descriptor.scope is None or route_key.scope in (None, descriptor.scope)
|
||||
if account_matches and scope_matches and fallback_driver.driver_id not in seen_driver_ids:
|
||||
drivers.append(fallback_driver)
|
||||
|
||||
descriptor = fallback_driver.descriptor
|
||||
if descriptor.account_id is not None and route_key.account_id not in (None, descriptor.account_id):
|
||||
return []
|
||||
if descriptor.scope is not None and route_key.scope not in (None, descriptor.scope):
|
||||
return []
|
||||
return [fallback_driver]
|
||||
return drivers
|
||||
|
||||
@staticmethod
|
||||
def build_route_key_from_message(message: "SessionMessage") -> RouteKey:
|
||||
|
||||
@@ -92,11 +92,24 @@ class OutboundTracker:
|
||||
raise ValueError("ttl_seconds 必须大于 0")
|
||||
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._pending: Dict[str, PendingOutboundRecord] = {}
|
||||
self._pending_expire_heap: List[Tuple[float, str]] = []
|
||||
self._pending: Dict[Tuple[str, str], PendingOutboundRecord] = {}
|
||||
self._pending_expire_heap: List[Tuple[float, str, str]] = []
|
||||
self._receipts_by_external_id: Dict[str, StoredDeliveryReceipt] = {}
|
||||
self._receipt_expire_heap: List[Tuple[float, str]] = []
|
||||
|
||||
@staticmethod
|
||||
def _build_pending_key(internal_message_id: str, driver_id: str) -> Tuple[str, str]:
|
||||
"""构造单条出站跟踪记录的唯一键。
|
||||
|
||||
Args:
|
||||
internal_message_id: 内部消息 ID。
|
||||
driver_id: 负责当前投递的驱动 ID。
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: ``(internal_message_id, driver_id)`` 组合键。
|
||||
"""
|
||||
return internal_message_id, driver_id
|
||||
|
||||
def begin_tracking(
|
||||
self,
|
||||
internal_message_id: str,
|
||||
@@ -116,13 +129,15 @@ class OutboundTracker:
|
||||
PendingOutboundRecord: 新创建的待完成记录。
|
||||
|
||||
Raises:
|
||||
ValueError: 当同一个 ``internal_message_id`` 已经存在未完成记录时抛出。
|
||||
ValueError: 当同一个 ``internal_message_id`` 与 ``driver_id`` 组合已经存在
|
||||
未完成记录时抛出。
|
||||
"""
|
||||
now = time.monotonic()
|
||||
self._cleanup_expired(now)
|
||||
pending_key = self._build_pending_key(internal_message_id, driver_id)
|
||||
|
||||
if internal_message_id in self._pending:
|
||||
raise ValueError(f"消息 {internal_message_id} 已存在未完成的出站跟踪记录")
|
||||
if pending_key in self._pending:
|
||||
raise ValueError(f"消息 {internal_message_id} 在驱动 {driver_id} 上已存在未完成的出站跟踪记录")
|
||||
|
||||
expires_at = now + self._ttl_seconds
|
||||
record = PendingOutboundRecord(
|
||||
@@ -133,8 +148,8 @@ class OutboundTracker:
|
||||
expires_at=expires_at,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self._pending[internal_message_id] = record
|
||||
heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id))
|
||||
self._pending[pending_key] = record
|
||||
heapq.heappush(self._pending_expire_heap, (expires_at, internal_message_id, driver_id))
|
||||
return record
|
||||
|
||||
def finish_tracking(self, receipt: DeliveryReceipt) -> Optional[PendingOutboundRecord]:
|
||||
@@ -149,7 +164,19 @@ class OutboundTracker:
|
||||
now = time.monotonic()
|
||||
self._cleanup_expired(now)
|
||||
|
||||
pending_record = self._pending.pop(receipt.internal_message_id, None)
|
||||
pending_record: Optional[PendingOutboundRecord] = None
|
||||
if receipt.driver_id:
|
||||
pending_key = self._build_pending_key(receipt.internal_message_id, receipt.driver_id)
|
||||
pending_record = self._pending.pop(pending_key, None)
|
||||
else:
|
||||
matched_records = [
|
||||
key
|
||||
for key, record in self._pending.items()
|
||||
if record.internal_message_id == receipt.internal_message_id
|
||||
]
|
||||
if len(matched_records) == 1:
|
||||
pending_record = self._pending.pop(matched_records[0], None)
|
||||
|
||||
if receipt.external_message_id:
|
||||
expires_at = now + self._ttl_seconds
|
||||
self._receipts_by_external_id[receipt.external_message_id] = StoredDeliveryReceipt(
|
||||
@@ -160,17 +187,33 @@ class OutboundTracker:
|
||||
heapq.heappush(self._receipt_expire_heap, (expires_at, receipt.external_message_id))
|
||||
return pending_record
|
||||
|
||||
def get_pending(self, internal_message_id: str) -> Optional[PendingOutboundRecord]:
|
||||
def get_pending(
|
||||
self,
|
||||
internal_message_id: str,
|
||||
driver_id: Optional[str] = None,
|
||||
) -> Optional[PendingOutboundRecord]:
|
||||
"""根据内部消息 ID 查询待完成记录。
|
||||
|
||||
Args:
|
||||
internal_message_id: 要查询的内部消息 ID。
|
||||
driver_id: 可选的驱动 ID;提供后仅返回该驱动上的待完成记录。
|
||||
|
||||
Returns:
|
||||
Optional[PendingOutboundRecord]: 若记录仍存在,则返回对应待完成记录。
|
||||
"""
|
||||
self._cleanup_expired(time.monotonic())
|
||||
return self._pending.get(internal_message_id)
|
||||
|
||||
if driver_id:
|
||||
return self._pending.get(self._build_pending_key(internal_message_id, driver_id))
|
||||
|
||||
matched_records = [
|
||||
record
|
||||
for record in self._pending.values()
|
||||
if record.internal_message_id == internal_message_id
|
||||
]
|
||||
if len(matched_records) == 1:
|
||||
return matched_records[0]
|
||||
return None
|
||||
|
||||
def get_receipt_by_external_id(self, external_message_id: str) -> Optional[DeliveryReceipt]:
|
||||
"""根据外部平台消息 ID 查询已完成回执。
|
||||
@@ -213,13 +256,14 @@ class OutboundTracker:
|
||||
``expires_at`` 对比,跳过这类旧节点。
|
||||
"""
|
||||
while self._pending_expire_heap and self._pending_expire_heap[0][0] <= now:
|
||||
expires_at, internal_message_id = heapq.heappop(self._pending_expire_heap)
|
||||
current_record = self._pending.get(internal_message_id)
|
||||
expires_at, internal_message_id, driver_id = heapq.heappop(self._pending_expire_heap)
|
||||
pending_key = self._build_pending_key(internal_message_id, driver_id)
|
||||
current_record = self._pending.get(pending_key)
|
||||
if current_record is None:
|
||||
continue
|
||||
if current_record.expires_at != expires_at:
|
||||
continue
|
||||
self._pending.pop(internal_message_id, None)
|
||||
self._pending.pop(pending_key, None)
|
||||
|
||||
def _cleanup_expired_receipts(self, now: float) -> None:
|
||||
"""清理已经过期的回执索引。
|
||||
|
||||
@@ -72,6 +72,8 @@ class RuntimeComponentCapabilityMixin:
|
||||
"version": entry.version,
|
||||
"public": entry.public,
|
||||
"enabled": entry.enabled,
|
||||
"dynamic": entry.dynamic,
|
||||
"offline_reason": entry.offline_reason,
|
||||
"metadata": dict(entry.metadata),
|
||||
}
|
||||
|
||||
@@ -109,6 +111,32 @@ class RuntimeComponentCapabilityMixin:
|
||||
|
||||
return entry.plugin_id == caller_plugin_id or entry.public
|
||||
|
||||
@staticmethod
|
||||
def _normalize_api_reference(api_name: str, version: str = "") -> tuple[str, str]:
|
||||
"""规范化 API 名称与版本参数。
|
||||
|
||||
支持在 ``api_name`` 中直接携带 ``@version`` 后缀。
|
||||
"""
|
||||
|
||||
normalized_api_name = str(api_name or "").strip()
|
||||
normalized_version = str(version or "").strip()
|
||||
if normalized_api_name and not normalized_version and "@" in normalized_api_name:
|
||||
candidate_name, candidate_version = normalized_api_name.rsplit("@", 1)
|
||||
candidate_name = candidate_name.strip()
|
||||
candidate_version = candidate_version.strip()
|
||||
if candidate_name and candidate_version:
|
||||
normalized_api_name = candidate_name
|
||||
normalized_version = candidate_version
|
||||
return normalized_api_name, normalized_version
|
||||
|
||||
@staticmethod
|
||||
def _build_api_unavailable_error(entry: "APIEntry") -> str:
|
||||
"""构造 API 当前不可用时的错误信息。"""
|
||||
|
||||
if entry.offline_reason:
|
||||
return entry.offline_reason
|
||||
return f"API {entry.registry_key} 当前不可用"
|
||||
|
||||
def _resolve_api_target(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
caller_plugin_id: str,
|
||||
@@ -127,8 +155,7 @@ class RuntimeComponentCapabilityMixin:
|
||||
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
|
||||
"""
|
||||
|
||||
normalized_api_name = str(api_name or "").strip()
|
||||
normalized_version = str(version or "").strip()
|
||||
normalized_api_name, normalized_version = self._normalize_api_reference(api_name, version)
|
||||
if not normalized_api_name:
|
||||
return None, None, "缺少必要参数 api_name"
|
||||
|
||||
@@ -142,34 +169,61 @@ class RuntimeComponentCapabilityMixin:
|
||||
if supervisor is None:
|
||||
return None, None, f"未找到 API 提供方插件: {target_plugin_id}"
|
||||
|
||||
entry = supervisor.api_registry.get_api(
|
||||
entries = supervisor.api_registry.get_apis(
|
||||
plugin_id=target_plugin_id,
|
||||
name=target_api_name,
|
||||
enabled_only=True,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
)
|
||||
if entry is None:
|
||||
return None, None, f"未找到 API: {normalized_api_name}"
|
||||
if normalized_version and entry.version != normalized_version:
|
||||
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
|
||||
if not self._is_api_visible_to_plugin(entry, caller_plugin_id):
|
||||
visible_enabled_entries = [
|
||||
entry
|
||||
for entry in entries
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and entry.enabled
|
||||
]
|
||||
visible_disabled_entries = [
|
||||
entry
|
||||
for entry in entries
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id) and not entry.enabled
|
||||
]
|
||||
if len(visible_enabled_entries) == 1:
|
||||
return supervisor, visible_enabled_entries[0], None
|
||||
if len(visible_enabled_entries) > 1:
|
||||
return None, None, f"API {normalized_api_name} 存在多个版本,请显式指定 version"
|
||||
if visible_disabled_entries:
|
||||
if len(visible_disabled_entries) == 1:
|
||||
return None, None, self._build_api_unavailable_error(visible_disabled_entries[0])
|
||||
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请显式指定 version"
|
||||
if any(not self._is_api_visible_to_plugin(entry, caller_plugin_id) for entry in entries):
|
||||
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
|
||||
return supervisor, entry, None
|
||||
if normalized_version:
|
||||
return None, None, f"未找到版本为 {normalized_version} 的 API: {normalized_api_name}"
|
||||
return None, None, f"未找到 API: {normalized_api_name}"
|
||||
|
||||
visible_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
visible_enabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
visible_disabled_matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
hidden_match_exists = False
|
||||
for supervisor in self.supervisors:
|
||||
for entry in supervisor.api_registry.get_apis(name=normalized_api_name, enabled_only=True):
|
||||
if normalized_version and entry.version != normalized_version:
|
||||
continue
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
name=normalized_api_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
):
|
||||
if self._is_api_visible_to_plugin(entry, caller_plugin_id):
|
||||
visible_matches.append((supervisor, entry))
|
||||
if entry.enabled:
|
||||
visible_enabled_matches.append((supervisor, entry))
|
||||
else:
|
||||
visible_disabled_matches.append((supervisor, entry))
|
||||
else:
|
||||
hidden_match_exists = True
|
||||
|
||||
if len(visible_matches) == 1:
|
||||
return visible_matches[0][0], visible_matches[0][1], None
|
||||
if len(visible_matches) > 1:
|
||||
return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name"
|
||||
if len(visible_enabled_matches) == 1:
|
||||
return visible_enabled_matches[0][0], visible_enabled_matches[0][1], None
|
||||
if len(visible_enabled_matches) > 1:
|
||||
return None, None, f"API 名称不唯一: {normalized_api_name},请使用 plugin_id.api_name 或显式指定 version"
|
||||
if visible_disabled_matches:
|
||||
if len(visible_disabled_matches) == 1:
|
||||
return None, None, self._build_api_unavailable_error(visible_disabled_matches[0][1])
|
||||
return None, None, f"API {normalized_api_name} 存在多个已下线版本,请使用 plugin_id.api_name@version"
|
||||
if hidden_match_exists:
|
||||
return None, None, f"API {normalized_api_name} 未公开,禁止跨插件调用"
|
||||
if normalized_version:
|
||||
@@ -179,18 +233,20 @@ class RuntimeComponentCapabilityMixin:
|
||||
def _resolve_api_toggle_target(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
name: str,
|
||||
version: str = "",
|
||||
) -> tuple[Optional["PluginSupervisor"], Optional["APIEntry"], Optional[str]]:
|
||||
"""解析需要启用或禁用的 API 组件。
|
||||
|
||||
Args:
|
||||
name: API 名称,支持 ``plugin_id.api_name`` 或唯一短名。
|
||||
version: 可选的 API 版本。
|
||||
|
||||
Returns:
|
||||
tuple[Optional[PluginSupervisor], Optional[APIEntry], Optional[str]]:
|
||||
解析成功时返回 ``(监督器, API 条目, None)``,失败时返回错误信息。
|
||||
"""
|
||||
|
||||
normalized_name = str(name or "").strip()
|
||||
normalized_name, normalized_version = self._normalize_api_reference(name, version)
|
||||
if not normalized_name:
|
||||
return None, None, "缺少必要参数 name"
|
||||
|
||||
@@ -204,24 +260,31 @@ class RuntimeComponentCapabilityMixin:
|
||||
if supervisor is None:
|
||||
return None, None, f"未找到 API 提供方插件: {plugin_id}"
|
||||
|
||||
entry = supervisor.api_registry.get_api(
|
||||
entries = supervisor.api_registry.get_apis(
|
||||
plugin_id=plugin_id,
|
||||
name=api_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
)
|
||||
if entry is None:
|
||||
if not entries:
|
||||
return None, None, f"未找到 API: {normalized_name}"
|
||||
return supervisor, entry, None
|
||||
if len(entries) > 1:
|
||||
return None, None, f"API {normalized_name} 存在多个版本,请显式指定 version"
|
||||
return supervisor, entries[0], None
|
||||
|
||||
matches: List[tuple["PluginSupervisor", "APIEntry"]] = []
|
||||
for supervisor in self.supervisors:
|
||||
for entry in supervisor.api_registry.get_apis(name=normalized_name, enabled_only=False):
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
name=normalized_name,
|
||||
version=normalized_version,
|
||||
enabled_only=False,
|
||||
):
|
||||
matches.append((supervisor, entry))
|
||||
|
||||
if len(matches) == 1:
|
||||
return matches[0][0], matches[0][1], None
|
||||
if len(matches) > 1:
|
||||
return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name"
|
||||
return None, None, f"API 名称不唯一: {normalized_name},请使用 plugin_id.api_name 或显式指定 version"
|
||||
return None, None, f"未找到 API: {normalized_name}"
|
||||
|
||||
async def _cap_component_get_all_plugins(
|
||||
@@ -326,6 +389,7 @@ class RuntimeComponentCapabilityMixin:
|
||||
) -> Any:
|
||||
name: str = args.get("name", "")
|
||||
component_type: str = args.get("component_type", "")
|
||||
version: str = args.get("version", "")
|
||||
scope: str = args.get("scope", "global")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
if not name or not component_type:
|
||||
@@ -334,10 +398,10 @@ class RuntimeComponentCapabilityMixin:
|
||||
return {"success": False, "error": "当前仅支持全局组件启用,不支持 scope/stream_id 定位"}
|
||||
|
||||
if self._is_api_component_type(component_type):
|
||||
supervisor, api_entry, error = self._resolve_api_toggle_target(name)
|
||||
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
|
||||
if supervisor is None or api_entry is None:
|
||||
return {"success": False, "error": error or f"未找到 API: {name}"}
|
||||
supervisor.api_registry.toggle_api_status(api_entry.full_name, True)
|
||||
supervisor.api_registry.toggle_api_status(api_entry.registry_key, True)
|
||||
return {"success": True}
|
||||
|
||||
comp, error = self._resolve_component_toggle_target(name, component_type)
|
||||
@@ -352,6 +416,7 @@ class RuntimeComponentCapabilityMixin:
|
||||
) -> Any:
|
||||
name: str = args.get("name", "")
|
||||
component_type: str = args.get("component_type", "")
|
||||
version: str = args.get("version", "")
|
||||
scope: str = args.get("scope", "global")
|
||||
stream_id: str = args.get("stream_id", "")
|
||||
if not name or not component_type:
|
||||
@@ -360,10 +425,10 @@ class RuntimeComponentCapabilityMixin:
|
||||
return {"success": False, "error": "当前仅支持全局组件禁用,不支持 scope/stream_id 定位"}
|
||||
|
||||
if self._is_api_component_type(component_type):
|
||||
supervisor, api_entry, error = self._resolve_api_toggle_target(name)
|
||||
supervisor, api_entry, error = self._resolve_api_toggle_target(name, version)
|
||||
if supervisor is None or api_entry is None:
|
||||
return {"success": False, "error": error or f"未找到 API: {name}"}
|
||||
supervisor.api_registry.toggle_api_status(api_entry.full_name, False)
|
||||
supervisor.api_registry.toggle_api_status(api_entry.registry_key, False)
|
||||
return {"success": True}
|
||||
|
||||
comp, error = self._resolve_component_toggle_target(name, component_type)
|
||||
@@ -488,11 +553,17 @@ class RuntimeComponentCapabilityMixin:
|
||||
if supervisor is None or entry is None:
|
||||
return {"success": False, "error": error or "API 解析失败"}
|
||||
|
||||
invoke_args = dict(api_args)
|
||||
if entry.dynamic:
|
||||
invoke_args.setdefault("__maibot_api_name__", entry.name)
|
||||
invoke_args.setdefault("__maibot_api_full_name__", entry.full_name)
|
||||
invoke_args.setdefault("__maibot_api_version__", entry.version)
|
||||
|
||||
try:
|
||||
response = await supervisor.invoke_api(
|
||||
plugin_id=entry.plugin_id,
|
||||
component_name=entry.name,
|
||||
args=api_args,
|
||||
component_name=entry.handler_name,
|
||||
args=invoke_args,
|
||||
timeout_ms=30000,
|
||||
)
|
||||
except Exception as exc:
|
||||
@@ -555,10 +626,16 @@ class RuntimeComponentCapabilityMixin:
|
||||
|
||||
del capability
|
||||
target_plugin_id = str(args.get("plugin_id", "") or "").strip()
|
||||
api_name, version = self._normalize_api_reference(
|
||||
str(args.get("api_name", args.get("name", "")) or ""),
|
||||
str(args.get("version", "") or ""),
|
||||
)
|
||||
apis: List[Dict[str, Any]] = []
|
||||
for supervisor in self.supervisors:
|
||||
for entry in supervisor.api_registry.get_apis(
|
||||
plugin_id=target_plugin_id or None,
|
||||
name=api_name,
|
||||
version=version,
|
||||
enabled_only=True,
|
||||
):
|
||||
if not self._is_api_visible_to_plugin(entry, plugin_id):
|
||||
@@ -567,3 +644,75 @@ class RuntimeComponentCapabilityMixin:
|
||||
|
||||
apis.sort(key=lambda item: (str(item["plugin_id"]), str(item["name"]), str(item["version"])))
|
||||
return {"success": True, "apis": apis}
|
||||
|
||||
async def _cap_api_replace_dynamic(
|
||||
self: _RuntimeComponentManagerProtocol,
|
||||
plugin_id: str,
|
||||
capability: str,
|
||||
args: Dict[str, Any],
|
||||
) -> Any:
|
||||
"""替换插件自行维护的动态 API 列表。"""
|
||||
|
||||
del capability
|
||||
raw_apis = args.get("apis", [])
|
||||
offline_reason = str(args.get("offline_reason", "") or "").strip() or "动态 API 已下线"
|
||||
if not isinstance(raw_apis, list):
|
||||
return {"success": False, "error": "参数 apis 必须为列表"}
|
||||
|
||||
try:
|
||||
supervisor = self._get_supervisor_for_plugin(plugin_id)
|
||||
except RuntimeError as exc:
|
||||
return {"success": False, "error": str(exc)}
|
||||
|
||||
if supervisor is None:
|
||||
return {"success": False, "error": f"未找到插件: {plugin_id}"}
|
||||
|
||||
normalized_components: List[Dict[str, Any]] = []
|
||||
seen_registry_keys: set[str] = set()
|
||||
for index, raw_api in enumerate(raw_apis):
|
||||
if not isinstance(raw_api, dict):
|
||||
return {"success": False, "error": f"apis[{index}] 必须为字典"}
|
||||
|
||||
api_name = str(raw_api.get("name", "") or "").strip()
|
||||
component_type = str(raw_api.get("component_type", raw_api.get("type", "API")) or "").strip()
|
||||
if not api_name:
|
||||
return {"success": False, "error": f"apis[{index}] 缺少 name"}
|
||||
if not self._is_api_component_type(component_type):
|
||||
return {"success": False, "error": f"apis[{index}] 不是 API 组件"}
|
||||
|
||||
metadata = raw_api.get("metadata", {}) if isinstance(raw_api.get("metadata"), dict) else {}
|
||||
normalized_metadata = dict(metadata)
|
||||
normalized_metadata["dynamic"] = True
|
||||
version = str(normalized_metadata.get("version", "1") or "1").strip() or "1"
|
||||
registry_key = supervisor.api_registry.build_registry_key(plugin_id, api_name, version)
|
||||
if registry_key in seen_registry_keys:
|
||||
return {"success": False, "error": f"动态 API 重复声明: {registry_key}"}
|
||||
seen_registry_keys.add(registry_key)
|
||||
|
||||
existing_entry = supervisor.api_registry.get_api(
|
||||
plugin_id,
|
||||
api_name,
|
||||
version=version,
|
||||
enabled_only=False,
|
||||
)
|
||||
if existing_entry is not None and not existing_entry.dynamic:
|
||||
return {"success": False, "error": f"动态 API 不能覆盖静态 API: {registry_key}"}
|
||||
|
||||
normalized_components.append(
|
||||
{
|
||||
"name": api_name,
|
||||
"component_type": "API",
|
||||
"metadata": normalized_metadata,
|
||||
}
|
||||
)
|
||||
|
||||
registered_count, offlined_count = supervisor.api_registry.replace_plugin_dynamic_apis(
|
||||
plugin_id,
|
||||
normalized_components,
|
||||
offline_reason=offline_reason,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"count": registered_count,
|
||||
"offlined": offlined_count,
|
||||
}
|
||||
|
||||
@@ -77,6 +77,7 @@ def register_capability_impls(manager: "PluginRuntimeManager", supervisor: Plugi
|
||||
_register("api.call", manager._cap_api_call)
|
||||
_register("api.get", manager._cap_api_get)
|
||||
_register("api.list", manager._cap_api_list)
|
||||
_register("api.replace_dynamic", manager._cap_api_replace_dynamic)
|
||||
|
||||
_register("component.get_all_plugins", manager._cap_component_get_all_plugins)
|
||||
_register("component.get_plugin_info", manager._cap_component_get_plugin_info)
|
||||
|
||||
@@ -1,45 +1,60 @@
|
||||
"""Host 侧插件 API 动态注册表。"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_runtime.host.api_registry")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class APIEntry:
|
||||
"""API 组件条目。"""
|
||||
|
||||
__slots__ = (
|
||||
"description",
|
||||
"disabled_session",
|
||||
"enabled",
|
||||
"full_name",
|
||||
"metadata",
|
||||
"name",
|
||||
"plugin_id",
|
||||
"public",
|
||||
"version",
|
||||
)
|
||||
name: str
|
||||
plugin_id: str
|
||||
description: str = ""
|
||||
version: str = "1"
|
||||
public: bool = False
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
enabled: bool = True
|
||||
handler_name: str = ""
|
||||
dynamic: bool = False
|
||||
offline_reason: str = ""
|
||||
disabled_session: Set[str] = field(default_factory=set)
|
||||
full_name: str = field(init=False)
|
||||
registry_key: str = field(init=False)
|
||||
|
||||
def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||||
"""初始化 API 组件条目。
|
||||
def __post_init__(self) -> None:
|
||||
"""规范化 API 条目字段。"""
|
||||
|
||||
Args:
|
||||
name: API 名称。
|
||||
plugin_id: 所属插件 ID。
|
||||
metadata: API 元数据。
|
||||
"""
|
||||
self.name = str(self.name or "").strip()
|
||||
self.plugin_id = str(self.plugin_id or "").strip()
|
||||
self.description = str(self.description or "").strip()
|
||||
self.version = str(self.version or "1").strip() or "1"
|
||||
self.handler_name = str(self.handler_name or self.name).strip() or self.name
|
||||
self.offline_reason = str(self.offline_reason or "").strip()
|
||||
self.full_name = f"{self.plugin_id}.{self.name}"
|
||||
self.registry_key = APIRegistry.build_registry_key(self.plugin_id, self.name, self.version)
|
||||
|
||||
self.name: str = name
|
||||
self.full_name: str = f"{plugin_id}.{name}"
|
||||
self.plugin_id: str = plugin_id
|
||||
self.description: str = str(metadata.get("description", "") or "")
|
||||
self.version: str = str(metadata.get("version", "1") or "1").strip() or "1"
|
||||
self.public: bool = bool(metadata.get("public", False))
|
||||
self.metadata: Dict[str, Any] = dict(metadata)
|
||||
self.enabled: bool = bool(metadata.get("enabled", True))
|
||||
self.disabled_session: Set[str] = set()
|
||||
@classmethod
|
||||
def from_metadata(cls, name: str, plugin_id: str, metadata: Dict[str, Any]) -> "APIEntry":
|
||||
"""根据 Runner 上报的元数据构造 API 条目。"""
|
||||
|
||||
safe_metadata = dict(metadata)
|
||||
return cls(
|
||||
name=name,
|
||||
plugin_id=plugin_id,
|
||||
description=str(safe_metadata.get("description", "") or ""),
|
||||
version=str(safe_metadata.get("version", "1") or "1"),
|
||||
public=bool(safe_metadata.get("public", False)),
|
||||
metadata=safe_metadata,
|
||||
enabled=bool(safe_metadata.get("enabled", True)),
|
||||
handler_name=str(safe_metadata.get("handler_name", name) or name),
|
||||
dynamic=bool(safe_metadata.get("dynamic", False)),
|
||||
offline_reason=str(safe_metadata.get("offline_reason", "") or ""),
|
||||
)
|
||||
|
||||
|
||||
class APIRegistry:
|
||||
@@ -53,6 +68,7 @@ class APIRegistry:
|
||||
"""初始化 API 注册表。"""
|
||||
|
||||
self._apis: Dict[str, APIEntry] = {}
|
||||
self._by_full_name: Dict[str, List[APIEntry]] = {}
|
||||
self._by_plugin: Dict[str, List[APIEntry]] = {}
|
||||
self._by_name: Dict[str, List[APIEntry]] = {}
|
||||
|
||||
@@ -60,75 +76,75 @@ class APIRegistry:
|
||||
"""清空全部 API 注册状态。"""
|
||||
|
||||
self._apis.clear()
|
||||
self._by_full_name.clear()
|
||||
self._by_plugin.clear()
|
||||
self._by_name.clear()
|
||||
|
||||
@staticmethod
|
||||
def _is_api_component(component_type: Any) -> bool:
|
||||
"""判断组件声明是否属于 API。
|
||||
|
||||
Args:
|
||||
component_type: 原始组件类型值。
|
||||
|
||||
Returns:
|
||||
bool: 是否为 API 组件。
|
||||
"""
|
||||
"""判断组件声明是否属于 API。"""
|
||||
|
||||
return str(component_type or "").strip().upper() == "API"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_query_version(version: Any) -> str:
|
||||
"""规范化查询使用的版本字符串。"""
|
||||
|
||||
return str(version or "").strip()
|
||||
|
||||
@classmethod
|
||||
def _split_reference(cls, reference: str, version: Any = "") -> Tuple[str, str]:
|
||||
"""解析可能带 ``@version`` 后缀的 API 引用。"""
|
||||
|
||||
normalized_reference = str(reference or "").strip()
|
||||
normalized_version = cls._normalize_query_version(version)
|
||||
if normalized_reference and not normalized_version and "@" in normalized_reference:
|
||||
candidate_reference, candidate_version = normalized_reference.rsplit("@", 1)
|
||||
candidate_reference = candidate_reference.strip()
|
||||
candidate_version = candidate_version.strip()
|
||||
if candidate_reference and candidate_version:
|
||||
normalized_reference = candidate_reference
|
||||
normalized_version = candidate_version
|
||||
return normalized_reference, normalized_version
|
||||
|
||||
@staticmethod
|
||||
def build_registry_key(plugin_id: str, name: str, version: str) -> str:
|
||||
"""构造 API 注册表唯一键。"""
|
||||
|
||||
normalized_full_name = f"{str(plugin_id or '').strip()}.{str(name or '').strip()}"
|
||||
normalized_version = str(version or "1").strip() or "1"
|
||||
return f"{normalized_full_name}@{normalized_version}"
|
||||
|
||||
@staticmethod
|
||||
def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
|
||||
"""判断 API 条目当前是否处于启用状态。
|
||||
|
||||
Args:
|
||||
entry: 待检查的 API 条目。
|
||||
session_id: 可选的会话 ID。
|
||||
|
||||
Returns:
|
||||
bool: 当前是否可用。
|
||||
"""
|
||||
"""判断 API 条目当前是否处于启用状态。"""
|
||||
|
||||
if session_id and session_id in entry.disabled_session:
|
||||
return False
|
||||
return entry.enabled
|
||||
|
||||
def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""注册单个 API 条目。
|
||||
|
||||
Args:
|
||||
name: API 名称。
|
||||
plugin_id: 所属插件 ID。
|
||||
metadata: API 元数据。
|
||||
|
||||
Returns:
|
||||
bool: 是否成功注册。
|
||||
"""
|
||||
"""注册单个 API 条目。"""
|
||||
|
||||
normalized_name = str(name or "").strip()
|
||||
if not normalized_name:
|
||||
logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
|
||||
return False
|
||||
|
||||
entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
|
||||
if entry.full_name in self._apis:
|
||||
logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目")
|
||||
self._remove_entry(self._apis[entry.full_name])
|
||||
entry = APIEntry.from_metadata(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
|
||||
existing_entry = self._apis.get(entry.registry_key)
|
||||
if existing_entry is not None:
|
||||
logger.warning(f"API {entry.registry_key} 已存在,覆盖旧条目")
|
||||
self._remove_entry(existing_entry)
|
||||
|
||||
self._apis[entry.full_name] = entry
|
||||
self._apis[entry.registry_key] = entry
|
||||
self._by_full_name.setdefault(entry.full_name, []).append(entry)
|
||||
self._by_plugin.setdefault(plugin_id, []).append(entry)
|
||||
self._by_name.setdefault(entry.name, []).append(entry)
|
||||
return True
|
||||
|
||||
def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
|
||||
"""批量注册某个插件声明的全部 API。
|
||||
|
||||
Args:
|
||||
plugin_id: 插件 ID。
|
||||
components: 插件组件声明列表。
|
||||
|
||||
Returns:
|
||||
int: 成功注册的 API 数量。
|
||||
"""
|
||||
"""批量注册某个插件声明的全部 API。"""
|
||||
|
||||
count = 0
|
||||
for component in components:
|
||||
@@ -142,14 +158,60 @@ class APIRegistry:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def replace_plugin_dynamic_apis(
|
||||
self,
|
||||
plugin_id: str,
|
||||
components: List[Dict[str, Any]],
|
||||
*,
|
||||
offline_reason: str = "动态 API 已下线",
|
||||
) -> Tuple[int, int]:
|
||||
"""替换指定插件当前声明的动态 API 集合。"""
|
||||
|
||||
normalized_offline_reason = str(offline_reason or "").strip() or "动态 API 已下线"
|
||||
desired_registry_keys: Set[str] = set()
|
||||
registered_count = 0
|
||||
|
||||
for component in components:
|
||||
if not self._is_api_component(component.get("component_type")):
|
||||
continue
|
||||
metadata = component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {}
|
||||
dynamic_metadata = dict(metadata)
|
||||
dynamic_metadata["dynamic"] = True
|
||||
dynamic_metadata.pop("offline_reason", None)
|
||||
|
||||
entry = APIEntry.from_metadata(
|
||||
name=str(component.get("name", "") or ""),
|
||||
plugin_id=plugin_id,
|
||||
metadata=dynamic_metadata,
|
||||
)
|
||||
desired_registry_keys.add(entry.registry_key)
|
||||
if self.register_api(entry.name, plugin_id, dynamic_metadata):
|
||||
registered_count += 1
|
||||
|
||||
offlined_count = 0
|
||||
for entry in list(self._by_plugin.get(plugin_id, [])):
|
||||
if not entry.dynamic or entry.registry_key in desired_registry_keys:
|
||||
continue
|
||||
entry.enabled = False
|
||||
entry.offline_reason = normalized_offline_reason
|
||||
entry.metadata["offline_reason"] = normalized_offline_reason
|
||||
offlined_count += 1
|
||||
|
||||
return registered_count, offlined_count
|
||||
|
||||
def _remove_entry(self, entry: APIEntry) -> None:
|
||||
"""从全部索引中移除单个 API 条目。
|
||||
"""从全部索引中移除单个 API 条目。"""
|
||||
|
||||
Args:
|
||||
entry: 待移除的 API 条目。
|
||||
"""
|
||||
self._apis.pop(entry.registry_key, None)
|
||||
|
||||
full_name_entries = self._by_full_name.get(entry.full_name)
|
||||
if full_name_entries is not None:
|
||||
self._by_full_name[entry.full_name] = [
|
||||
candidate for candidate in full_name_entries if candidate is not entry
|
||||
]
|
||||
if not self._by_full_name[entry.full_name]:
|
||||
self._by_full_name.pop(entry.full_name, None)
|
||||
|
||||
self._apis.pop(entry.full_name, None)
|
||||
plugin_entries = self._by_plugin.get(entry.plugin_id)
|
||||
if plugin_entries is not None:
|
||||
self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
|
||||
@@ -163,14 +225,7 @@ class APIRegistry:
|
||||
self._by_name.pop(entry.name, None)
|
||||
|
||||
def remove_apis_by_plugin(self, plugin_id: str) -> int:
|
||||
"""移除某个插件的全部 API。
|
||||
|
||||
Args:
|
||||
plugin_id: 目标插件 ID。
|
||||
|
||||
Returns:
|
||||
int: 被移除的 API 数量。
|
||||
"""
|
||||
"""移除某个插件的全部 API。"""
|
||||
|
||||
entries = list(self._by_plugin.get(plugin_id, []))
|
||||
for entry in entries:
|
||||
@@ -181,49 +236,48 @@ class APIRegistry:
|
||||
self,
|
||||
full_name: str,
|
||||
*,
|
||||
version: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Optional[APIEntry]:
|
||||
"""按完整名查询单个 API。
|
||||
"""按完整名查询单个 API。"""
|
||||
|
||||
Args:
|
||||
full_name: API 完整名,格式为 ``plugin_id.api_name``。
|
||||
enabled_only: 是否仅返回启用状态的 API。
|
||||
session_id: 可选的会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[APIEntry]: 命中时返回 API 条目。
|
||||
"""
|
||||
|
||||
entry = self._apis.get(full_name)
|
||||
if entry is None:
|
||||
normalized_full_name, normalized_version = self._split_reference(full_name, version)
|
||||
if not normalized_full_name:
|
||||
return None
|
||||
if enabled_only and not self.check_api_enabled(entry, session_id):
|
||||
|
||||
if normalized_version:
|
||||
entry = self._apis.get(f"{normalized_full_name}@{normalized_version}")
|
||||
if entry is None:
|
||||
return None
|
||||
if enabled_only and not self.check_api_enabled(entry, session_id):
|
||||
return None
|
||||
return entry
|
||||
|
||||
candidates = list(self._by_full_name.get(normalized_full_name, []))
|
||||
filtered_entries = [
|
||||
entry
|
||||
for entry in candidates
|
||||
if not enabled_only or self.check_api_enabled(entry, session_id)
|
||||
]
|
||||
if len(filtered_entries) != 1:
|
||||
return None
|
||||
return entry
|
||||
return filtered_entries[0]
|
||||
|
||||
def get_api(
|
||||
self,
|
||||
plugin_id: str,
|
||||
name: str,
|
||||
*,
|
||||
version: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Optional[APIEntry]:
|
||||
"""按插件 ID 和短名查询单个 API。
|
||||
|
||||
Args:
|
||||
plugin_id: 提供方插件 ID。
|
||||
name: API 短名。
|
||||
enabled_only: 是否仅返回启用状态的 API。
|
||||
session_id: 可选的会话 ID。
|
||||
|
||||
Returns:
|
||||
Optional[APIEntry]: 命中时返回 API 条目。
|
||||
"""
|
||||
"""按插件 ID、短名与版本查询单个 API。"""
|
||||
|
||||
return self.get_api_by_full_name(
|
||||
f"{plugin_id}.{name}",
|
||||
version=version,
|
||||
enabled_only=enabled_only,
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -233,22 +287,15 @@ class APIRegistry:
|
||||
*,
|
||||
plugin_id: Optional[str] = None,
|
||||
name: str = "",
|
||||
version: str = "",
|
||||
enabled_only: bool = True,
|
||||
session_id: Optional[str] = None,
|
||||
) -> List[APIEntry]:
|
||||
"""查询 API 列表。
|
||||
|
||||
Args:
|
||||
plugin_id: 可选的插件 ID 过滤条件。
|
||||
name: 可选的 API 名称过滤条件。
|
||||
enabled_only: 是否仅返回启用状态的 API。
|
||||
session_id: 可选的会话 ID。
|
||||
|
||||
Returns:
|
||||
List[APIEntry]: 符合条件的 API 条目列表。
|
||||
"""
|
||||
"""查询 API 列表。"""
|
||||
|
||||
normalized_name = str(name or "").strip()
|
||||
normalized_version = self._normalize_query_version(version)
|
||||
|
||||
if plugin_id:
|
||||
candidates = list(self._by_plugin.get(plugin_id, []))
|
||||
elif normalized_name:
|
||||
@@ -258,26 +305,35 @@ class APIRegistry:
|
||||
|
||||
filtered_entries: List[APIEntry] = []
|
||||
for entry in candidates:
|
||||
if plugin_id and entry.plugin_id != plugin_id:
|
||||
continue
|
||||
if normalized_name and entry.name != normalized_name:
|
||||
continue
|
||||
if normalized_version and entry.version != normalized_version:
|
||||
continue
|
||||
if enabled_only and not self.check_api_enabled(entry, session_id):
|
||||
continue
|
||||
filtered_entries.append(entry)
|
||||
|
||||
filtered_entries.sort(key=lambda entry: (entry.plugin_id, entry.name, entry.version))
|
||||
return filtered_entries
|
||||
|
||||
def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
|
||||
"""设置指定 API 的启用状态。
|
||||
def toggle_api_status(
|
||||
self,
|
||||
full_name: str,
|
||||
enabled: bool,
|
||||
*,
|
||||
version: str = "",
|
||||
session_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""设置指定 API 的启用状态。"""
|
||||
|
||||
Args:
|
||||
full_name: API 完整名。
|
||||
enabled: 目标启用状态。
|
||||
session_id: 可选的会话 ID,仅对该会话生效。
|
||||
|
||||
Returns:
|
||||
bool: 是否设置成功。
|
||||
"""
|
||||
|
||||
entry = self._apis.get(full_name)
|
||||
entry = self.get_api_by_full_name(
|
||||
full_name,
|
||||
version=version,
|
||||
enabled_only=False,
|
||||
session_id=session_id,
|
||||
)
|
||||
if entry is None:
|
||||
return False
|
||||
if session_id:
|
||||
@@ -287,4 +343,7 @@ class APIRegistry:
|
||||
entry.disabled_session.add(session_id)
|
||||
else:
|
||||
entry.enabled = enabled
|
||||
if enabled:
|
||||
entry.offline_reason = ""
|
||||
entry.metadata.pop("offline_reason", None)
|
||||
return True
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
_ALWAYS_ALLOWED_CAPABILITIES = frozenset({"api.replace_dynamic"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class CapabilityPermissionToken:
|
||||
@@ -46,6 +48,9 @@ class AuthorizationManager:
|
||||
Returns:
|
||||
return (bool, str): (是否有此能力, 原因)
|
||||
"""
|
||||
if capability in _ALWAYS_ALLOWED_CAPABILITIES:
|
||||
return True, ""
|
||||
|
||||
token = self._permission_tokens.get(plugin_id)
|
||||
if not token:
|
||||
return False, f"插件 {plugin_id} 未注册能力令牌"
|
||||
|
||||
@@ -45,6 +45,14 @@ from src.plugin_runtime.runner.rpc_client import RPCClient
|
||||
|
||||
logger = get_logger("plugin_runtime.runner.main")
|
||||
|
||||
_PLUGIN_ALLOWED_RAW_HOST_METHODS = frozenset(
|
||||
{
|
||||
"cap.call",
|
||||
"host.route_message",
|
||||
"host.update_message_gateway_state",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class _ContextAwarePlugin(Protocol):
|
||||
"""支持注入运行时上下文的插件协议。
|
||||
@@ -247,8 +255,14 @@ class PluginRunner:
|
||||
logger.warning(
|
||||
f"插件 {bound_plugin_id} 尝试以 {plugin_id} 身份发起 RPC,已强制绑定回自身身份"
|
||||
)
|
||||
normalized_method = str(method or "").strip()
|
||||
if normalized_method not in _PLUGIN_ALLOWED_RAW_HOST_METHODS:
|
||||
raise PermissionError(
|
||||
f"插件 {bound_plugin_id} 不允许直接调用 Host 原始 RPC 方法: "
|
||||
f"{normalized_method or '<empty>'}"
|
||||
)
|
||||
resp = await rpc_client.send_request(
|
||||
method=method,
|
||||
method=normalized_method,
|
||||
plugin_id=bound_plugin_id,
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
@@ -57,20 +57,23 @@ def _inherit_platform_io_route_metadata(target_stream: BotChatSession) -> Dict[s
|
||||
inherited_metadata: Dict[str, object] = {}
|
||||
|
||||
context_message = target_stream.context.message if target_stream.context else None
|
||||
if context_message is None:
|
||||
return inherited_metadata
|
||||
if context_message is not None:
|
||||
additional_config = context_message.message_info.additional_config
|
||||
if isinstance(additional_config, dict):
|
||||
for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS):
|
||||
value = additional_config.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
normalized_value = str(value).strip()
|
||||
if normalized_value:
|
||||
inherited_metadata[key] = value
|
||||
|
||||
additional_config = context_message.message_info.additional_config
|
||||
if not isinstance(additional_config, dict):
|
||||
return inherited_metadata
|
||||
|
||||
for key in (*RouteKeyFactory.ACCOUNT_ID_KEYS, *RouteKeyFactory.SCOPE_KEYS):
|
||||
value = additional_config.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
normalized_value = str(value).strip()
|
||||
if normalized_value:
|
||||
inherited_metadata[key] = value
|
||||
# 当目标会话没有可继承的上下文消息时,至少补齐当前平台账号,
|
||||
# 让按 ``platform + account_id`` 绑定的路由仍有机会命中。
|
||||
if not RouteKeyFactory.extract_components(inherited_metadata)[0]:
|
||||
bot_account = get_bot_account(target_stream.platform)
|
||||
if bot_account:
|
||||
inherited_metadata["platform_io_account_id"] = bot_account
|
||||
|
||||
if target_stream.group_id and (normalized_group_id := str(target_stream.group_id).strip()):
|
||||
inherited_metadata["platform_io_target_group_id"] = normalized_group_id
|
||||
|
||||
Reference in New Issue
Block a user