Files
mai-bot/src/plugin_runtime/capabilities/components.py

234 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Protocol

from src.common.logger import get_logger

if TYPE_CHECKING:
    # Needed only for annotations; skipped at runtime (presumably to avoid
    # import cycles with the host package — confirm).
    from src.plugin_runtime.host.component_registry import RegisteredComponent
    from src.plugin_runtime.host.supervisor import PluginSupervisor

# Deliberately logs under the shared "plugin_runtime.integration" channel rather
# than a name derived from this module's path.
logger = get_logger("plugin_runtime.integration")
class _RuntimeComponentManagerProtocol(Protocol):
    """Structural type of the host object the component capability mixin is mixed into.

    The mixin's handlers annotate ``self`` with this protocol so that the host
    members they rely on are visible to the type checker.
    """

    @property
    def supervisors(self) -> List["PluginSupervisor"]:
        """Plugin supervisors managed by the host; iterated by the aggregating handlers."""

    def _iter_plugin_dirs(self) -> Iterable[str]:
        """Iterate over the plugin directories that feed the duplicate-ID check."""

    def _find_duplicate_plugin_ids(self, plugin_dirs: List[str]) -> Dict[str, List[str]]:
        """Return ``{plugin_id: [paths]}`` for conflicting plugin IDs; empty when there are none."""

    def _get_supervisor_for_plugin(self, plugin_id: str) -> Optional["PluginSupervisor"]:
        """Return the supervisor responsible for *plugin_id*, or ``None``.

        Callers treat a ``RuntimeError`` raised here as a lookup failure and
        report its message.
        """

    def _resolve_component_toggle_target(
        self, name: str, component_type: str
    ) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
        """Resolve *name*/*component_type* to ``(component, None)`` or ``(None, error_message)``."""
class RuntimeComponentCapabilityMixin:
    """Handlers for the ``component.*`` plugin capabilities.

    Every ``_cap_component_*`` coroutine receives the calling ``plugin_id``, the
    ``capability`` name and the raw ``args`` dict, and returns a dict carrying a
    boolean ``"success"`` key (plus ``"error"`` on failure). ``self`` is typed as
    ``_RuntimeComponentManagerProtocol``: the host class must provide
    ``supervisors``, ``_get_supervisor_for_plugin``, ``_find_duplicate_plugin_ids``
    and ``_iter_plugin_dirs``.
    """

    async def _cap_component_get_all_plugins(
        self: "_RuntimeComponentManagerProtocol", plugin_id: str, capability: str, args: Dict[str, Any]
    ) -> Any:
        """Describe every registered plugin together with all of its components.

        Returns ``{"success": True, "plugins": {plugin_id: info}}``. Because the
        result is keyed by plugin ID, aggregation is refused (``success=False``)
        as soon as an ID shows up in a second supervisor.
        """
        result: Dict[str, Any] = {}
        for sv in self.supervisors:
            for pid, reg in sv._registered_plugins.items():
                if pid in result:
                    logger.error(f"检测到重复插件 ID {pid}, component.get_all_plugins 结果已拒绝聚合")
                    return {"success": False, "error": f"检测到重复插件 ID: {pid}"}
                info = _plugin_summary(pid, reg)
                # enabled_only=False: disabled components are listed too, flagged via "enabled".
                info["components"] = [
                    {
                        "name": component.name,
                        "full_name": component.full_name,
                        "type": component.component_type,
                        "enabled": component.enabled,
                        "metadata": component.metadata,
                    }
                    for component in sv.component_registry.get_components_by_plugin(pid, enabled_only=False)
                ]
                result[pid] = info
        return {"success": True, "plugins": result}

    async def _cap_component_get_plugin_info(
        self: "_RuntimeComponentManagerProtocol", plugin_id: str, capability: str, args: Dict[str, Any]
    ) -> Any:
        """Describe one plugin; ``args["plugin_name"]`` defaults to the calling plugin's ID."""
        plugin_name: str = args.get("plugin_name", plugin_id)
        try:
            sv = self._get_supervisor_for_plugin(plugin_name)
        except RuntimeError as exc:
            return {"success": False, "error": str(exc)}
        reg = sv._registered_plugins.get(plugin_name) if sv is not None else None
        if reg is None:
            return {"success": False, "error": f"未找到插件: {plugin_name}"}
        return {"success": True, "plugin": _plugin_summary(plugin_name, reg)}

    async def _cap_component_list_loaded_plugins(
        self: "_RuntimeComponentManagerProtocol", plugin_id: str, capability: str, args: Dict[str, Any]
    ) -> Any:
        """List the IDs of all registered plugins across every supervisor."""
        return {"success": True, "plugins": _collect_plugin_ids(self.supervisors)}

    async def _cap_component_list_registered_plugins(
        self: "_RuntimeComponentManagerProtocol", plugin_id: str, capability: str, args: Dict[str, Any]
    ) -> Any:
        """List the IDs of all registered plugins (same result as ``list_loaded_plugins``)."""
        return {"success": True, "plugins": _collect_plugin_ids(self.supervisors)}

    def _resolve_component_toggle_target(
        self: "_RuntimeComponentManagerProtocol", name: str, component_type: str
    ) -> tuple[Optional["RegisteredComponent"], Optional[str]]:
        """Find the component an enable/disable request refers to.

        Supervisors are checked in order: the first ``get_component(name)`` hit
        whose type equals *component_type* is returned immediately. Otherwise all
        components of *component_type* whose short ``name`` equals *name* are
        collected across every supervisor, and the match must be unique.

        Returns:
            ``(component, None)`` on success, ``(None, error_message)`` otherwise.
        """
        short_name_matches: List["RegisteredComponent"] = []
        for sv in self.supervisors:
            comp = sv.component_registry.get_component(name)
            if comp is not None and comp.component_type == component_type:
                return comp, None
            short_name_matches.extend(
                candidate
                for candidate in sv.component_registry.get_components_by_type(component_type, enabled_only=False)
                if candidate.name == name
            )
        if len(short_name_matches) == 1:
            return short_name_matches[0], None
        if len(short_name_matches) > 1:
            return None, f"组件名不唯一: {name} ({component_type}),请使用完整名 plugin_id.component_name"
        return None, f"未找到组件: {name} ({component_type})"

    async def _cap_component_enable(
        self: "_RuntimeComponentManagerProtocol", plugin_id: str, capability: str, args: Dict[str, Any]
    ) -> Any:
        """Globally enable the component named by ``args["name"]`` / ``args["component_type"]``."""
        return _toggle_component(self, args, enabled=True)

    async def _cap_component_disable(
        self: "_RuntimeComponentManagerProtocol", plugin_id: str, capability: str, args: Dict[str, Any]
    ) -> Any:
        """Globally disable the component named by ``args["name"]`` / ``args["component_type"]``."""
        return _toggle_component(self, args, enabled=False)

    async def _cap_component_load_plugin(
        self: "_RuntimeComponentManagerProtocol", plugin_id: str, capability: str, args: Dict[str, Any]
    ) -> Any:
        """Load ``args["plugin_name"]`` by hot-reloading the supervisor that owns it.

        If the plugin is already registered, its supervisor is reloaded. Otherwise
        the first supervisor with a plugin root containing a directory of that
        name is reloaded. Refused while duplicate plugin IDs exist on disk.
        """
        import os

        plugin_name: str = args.get("plugin_name", "")
        if not plugin_name:
            return {"success": False, "error": "缺少必要参数 plugin_name"}
        if (duplicate_error := _duplicate_plugin_ids_error(self)) is not None:
            return duplicate_error
        try:
            registered_supervisor = self._get_supervisor_for_plugin(plugin_name)
        except RuntimeError as exc:
            return {"success": False, "error": str(exc)}
        if registered_supervisor is not None:
            return await _reload_supervisor(
                registered_supervisor,
                plugin_name,
                reason=f"load {plugin_name}",
                log_tag="load_plugin",
                success_payload={"success": True, "count": 1},
            )
        # plugin_name comes from the calling plugin's args: only treat it as a
        # directory name when it is a single path component, so values such as
        # "..", "a/b" or an absolute path cannot probe outside the plugin roots.
        if os.path.basename(plugin_name) != plugin_name or plugin_name in (os.curdir, os.pardir):
            return {"success": False, "error": f"未找到插件: {plugin_name}"}
        for sv in self.supervisors:
            if any(os.path.isdir(os.path.join(pdir, plugin_name)) for pdir in sv._plugin_dirs):
                return await _reload_supervisor(
                    sv,
                    plugin_name,
                    reason=f"load {plugin_name}",
                    log_tag="load_plugin",
                    success_payload={"success": True, "count": 1},
                )
        return {"success": False, "error": f"未找到插件: {plugin_name}"}

    async def _cap_component_unload_plugin(
        self: "_RuntimeComponentManagerProtocol", plugin_id: str, capability: str, args: Dict[str, Any]
    ) -> Any:
        """Always fails: this runtime cannot unload a single plugin (use reload instead)."""
        return {"success": False, "error": "新运行时不支持单独卸载插件,请使用 reload"}

    async def _cap_component_reload_plugin(
        self: "_RuntimeComponentManagerProtocol", plugin_id: str, capability: str, args: Dict[str, Any]
    ) -> Any:
        """Hot-reload the supervisor that registered ``args["plugin_name"]``.

        Refused while duplicate plugin IDs exist on disk.
        """
        plugin_name: str = args.get("plugin_name", "")
        if not plugin_name:
            return {"success": False, "error": "缺少必要参数 plugin_name"}
        if (duplicate_error := _duplicate_plugin_ids_error(self)) is not None:
            return duplicate_error
        try:
            sv = self._get_supervisor_for_plugin(plugin_name)
        except RuntimeError as exc:
            return {"success": False, "error": str(exc)}
        if sv is None:
            return {"success": False, "error": f"未找到插件: {plugin_name}"}
        return await _reload_supervisor(
            sv,
            plugin_name,
            reason=f"reload {plugin_name}",
            log_tag="reload_plugin",
            success_payload={"success": True},
        )


def _plugin_summary(plugin_name: str, reg: Any) -> Dict[str, Any]:
    """Build the per-plugin info dict shared by get_all_plugins and get_plugin_info."""
    return {
        "name": plugin_name,
        "version": reg.plugin_version,
        # NOTE(review): description/author are always empty and enabled always True;
        # presumably the registration record does not carry them — confirm.
        "description": "",
        "author": "",
        "enabled": True,
    }


def _collect_plugin_ids(supervisors: Iterable["PluginSupervisor"]) -> List[str]:
    """List registered plugin IDs in supervisor order (duplicates are kept)."""
    return [pid for sv in supervisors for pid in sv._registered_plugins]


def _duplicate_plugin_ids_error(manager: "_RuntimeComponentManagerProtocol") -> Optional[Dict[str, Any]]:
    """Return a failure result if plugin IDs collide across plugin directories, else ``None``."""
    duplicate_plugin_ids = manager._find_duplicate_plugin_ids(list(manager._iter_plugin_dirs()))
    if not duplicate_plugin_ids:
        return None
    details = "; ".join(
        f"{conflict_plugin_id}: {', '.join(paths)}"
        for conflict_plugin_id, paths in sorted(duplicate_plugin_ids.items())
    )
    return {"success": False, "error": f"检测到重复插件 ID拒绝热重载: {details}"}


async def _reload_supervisor(
    sv: "PluginSupervisor",
    plugin_name: str,
    *,
    reason: str,
    log_tag: str,
    success_payload: Dict[str, Any],
) -> Dict[str, Any]:
    """Run ``sv.reload_plugins`` and translate its outcome into a capability result.

    Returns *success_payload* when the reload reports success. A falsy result is
    reported as a rolled-back failure. Any exception is logged and returned as
    an error result instead of being raised.
    """
    try:
        reloaded = await sv.reload_plugins(reason=reason)
    except Exception as e:
        logger.error(f"[cap.component.{log_tag}] 热重载失败: {e}")
        return {"success": False, "error": str(e)}
    if reloaded:
        return success_payload
    return {"success": False, "error": f"插件 {plugin_name} 热重载失败,已回滚"}


def _toggle_component(
    manager: "_RuntimeComponentManagerProtocol", args: Dict[str, Any], *, enabled: bool
) -> Dict[str, Any]:
    """Set ``enabled`` on the component described by *args* (global scope only)."""
    action = "启用" if enabled else "禁用"
    name: str = args.get("name", "")
    component_type: str = args.get("component_type", "")
    scope: str = args.get("scope", "global")
    stream_id: str = args.get("stream_id", "")
    if not name or not component_type:
        return {"success": False, "error": "缺少必要参数 name 或 component_type"}
    # Per-scope / per-stream toggling is not implemented; reject rather than
    # silently applying the change globally.
    if scope != "global" or stream_id:
        return {"success": False, "error": f"当前仅支持全局组件{action},不支持 scope/stream_id 定位"}
    comp, error = manager._resolve_component_toggle_target(name, component_type)
    if comp is None:
        return {"success": False, "error": error or f"未找到组件: {name} ({component_type})"}
    comp.enabled = enabled
    return {"success": True}