Files
mai-bot/src/plugin_runtime/host/api_registry.py
DrSmoothl 9dea6b0e6f feat: implement dedicated API registry and enhance API handling capabilities
- Added APIEntry and APIRegistry classes for managing plugin APIs.
- Updated PluginRunnerSupervisor to include API registry and methods for invoking APIs.
- Enhanced PluginRuntimeManager to support API registration and invocation.
- Created tests for API registration, invocation, and visibility between plugins.
- Refactored component handling to distinguish between runtime components and APIs.
2026-03-24 12:14:41 +08:00

291 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Host 侧插件 API 动态注册表。"""
from typing import Any, Dict, List, Optional, Set
from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.api_registry")
class APIEntry:
"""API 组件条目。"""
__slots__ = (
"description",
"disabled_session",
"enabled",
"full_name",
"metadata",
"name",
"plugin_id",
"public",
"version",
)
def __init__(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
"""初始化 API 组件条目。
Args:
name: API 名称。
plugin_id: 所属插件 ID。
metadata: API 元数据。
"""
self.name: str = name
self.full_name: str = f"{plugin_id}.{name}"
self.plugin_id: str = plugin_id
self.description: str = str(metadata.get("description", "") or "")
self.version: str = str(metadata.get("version", "1") or "1").strip() or "1"
self.public: bool = bool(metadata.get("public", False))
self.metadata: Dict[str, Any] = dict(metadata)
self.enabled: bool = bool(metadata.get("enabled", True))
self.disabled_session: Set[str] = set()
class APIRegistry:
"""Host 侧插件 API 动态注册表。
该注册表不直接面向 Runner而是复用插件组件注册/卸载事件,
维护面向 API 调用场景的专用索引。
"""
def __init__(self) -> None:
"""初始化 API 注册表。"""
self._apis: Dict[str, APIEntry] = {}
self._by_plugin: Dict[str, List[APIEntry]] = {}
self._by_name: Dict[str, List[APIEntry]] = {}
def clear(self) -> None:
"""清空全部 API 注册状态。"""
self._apis.clear()
self._by_plugin.clear()
self._by_name.clear()
@staticmethod
def _is_api_component(component_type: Any) -> bool:
"""判断组件声明是否属于 API。
Args:
component_type: 原始组件类型值。
Returns:
bool: 是否为 API 组件。
"""
return str(component_type or "").strip().upper() == "API"
@staticmethod
def check_api_enabled(entry: APIEntry, session_id: Optional[str] = None) -> bool:
"""判断 API 条目当前是否处于启用状态。
Args:
entry: 待检查的 API 条目。
session_id: 可选的会话 ID。
Returns:
bool: 当前是否可用。
"""
if session_id and session_id in entry.disabled_session:
return False
return entry.enabled
def register_api(self, name: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个 API 条目。
Args:
name: API 名称。
plugin_id: 所属插件 ID。
metadata: API 元数据。
Returns:
bool: 是否成功注册。
"""
normalized_name = str(name or "").strip()
if not normalized_name:
logger.warning(f"插件 {plugin_id} 存在空 API 名称声明,已忽略")
return False
entry = APIEntry(name=normalized_name, plugin_id=plugin_id, metadata=metadata)
if entry.full_name in self._apis:
logger.warning(f"API {entry.full_name} 已存在,覆盖旧条目")
self._remove_entry(self._apis[entry.full_name])
self._apis[entry.full_name] = entry
self._by_plugin.setdefault(plugin_id, []).append(entry)
self._by_name.setdefault(entry.name, []).append(entry)
return True
def register_plugin_apis(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册某个插件声明的全部 API。
Args:
plugin_id: 插件 ID。
components: 插件组件声明列表。
Returns:
int: 成功注册的 API 数量。
"""
count = 0
for component in components:
if not self._is_api_component(component.get("component_type")):
continue
if self.register_api(
name=str(component.get("name", "") or ""),
plugin_id=plugin_id,
metadata=component.get("metadata", {}) if isinstance(component.get("metadata"), dict) else {},
):
count += 1
return count
def _remove_entry(self, entry: APIEntry) -> None:
"""从全部索引中移除单个 API 条目。
Args:
entry: 待移除的 API 条目。
"""
self._apis.pop(entry.full_name, None)
plugin_entries = self._by_plugin.get(entry.plugin_id)
if plugin_entries is not None:
self._by_plugin[entry.plugin_id] = [candidate for candidate in plugin_entries if candidate is not entry]
if not self._by_plugin[entry.plugin_id]:
self._by_plugin.pop(entry.plugin_id, None)
name_entries = self._by_name.get(entry.name)
if name_entries is not None:
self._by_name[entry.name] = [candidate for candidate in name_entries if candidate is not entry]
if not self._by_name[entry.name]:
self._by_name.pop(entry.name, None)
def remove_apis_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的全部 API。
Args:
plugin_id: 目标插件 ID。
Returns:
int: 被移除的 API 数量。
"""
entries = list(self._by_plugin.get(plugin_id, []))
for entry in entries:
self._remove_entry(entry)
return len(entries)
def get_api_by_full_name(
self,
full_name: str,
*,
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
"""按完整名查询单个 API。
Args:
full_name: API 完整名,格式为 ``plugin_id.api_name``。
enabled_only: 是否仅返回启用状态的 API。
session_id: 可选的会话 ID。
Returns:
Optional[APIEntry]: 命中时返回 API 条目。
"""
entry = self._apis.get(full_name)
if entry is None:
return None
if enabled_only and not self.check_api_enabled(entry, session_id):
return None
return entry
def get_api(
self,
plugin_id: str,
name: str,
*,
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> Optional[APIEntry]:
"""按插件 ID 和短名查询单个 API。
Args:
plugin_id: 提供方插件 ID。
name: API 短名。
enabled_only: 是否仅返回启用状态的 API。
session_id: 可选的会话 ID。
Returns:
Optional[APIEntry]: 命中时返回 API 条目。
"""
return self.get_api_by_full_name(
f"{plugin_id}.{name}",
enabled_only=enabled_only,
session_id=session_id,
)
def get_apis(
self,
*,
plugin_id: Optional[str] = None,
name: str = "",
enabled_only: bool = True,
session_id: Optional[str] = None,
) -> List[APIEntry]:
"""查询 API 列表。
Args:
plugin_id: 可选的插件 ID 过滤条件。
name: 可选的 API 名称过滤条件。
enabled_only: 是否仅返回启用状态的 API。
session_id: 可选的会话 ID。
Returns:
List[APIEntry]: 符合条件的 API 条目列表。
"""
normalized_name = str(name or "").strip()
if plugin_id:
candidates = list(self._by_plugin.get(plugin_id, []))
elif normalized_name:
candidates = list(self._by_name.get(normalized_name, []))
else:
candidates = list(self._apis.values())
filtered_entries: List[APIEntry] = []
for entry in candidates:
if normalized_name and entry.name != normalized_name:
continue
if enabled_only and not self.check_api_enabled(entry, session_id):
continue
filtered_entries.append(entry)
return filtered_entries
def toggle_api_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
"""设置指定 API 的启用状态。
Args:
full_name: API 完整名。
enabled: 目标启用状态。
session_id: 可选的会话 ID仅对该会话生效。
Returns:
bool: 是否设置成功。
"""
entry = self._apis.get(full_name)
if entry is None:
return False
if session_id:
if enabled:
entry.disabled_session.discard(session_id)
else:
entry.disabled_session.add(session_id)
else:
entry.enabled = enabled
return True