Files
mai-bot/src/plugin_runtime/host/component_registry.py
2026-03-13 11:45:26 +08:00

253 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
- 按类型注册组件(action / command / tool / event_handler / workflow_step)
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
- 多维度查询(按名称、类型、插件)
- 注册统计
"""
from typing import Any, Dict, List, Optional
import re
from src.common.logger import get_logger
# Module-level logger; used below for regex-compile failures and component-overwrite warnings.
logger = get_logger("plugin_runtime.host.component_registry")
class RegisteredComponent:
    """A single entry in the component registry.

    ``full_name`` is the namespaced key ``"<plugin_id>.<name>"``. For
    ``command`` components the ``command_pattern`` metadata entry is compiled
    once here so that matching never has to recompile it.
    """

    __slots__ = (
        "name",
        "full_name",
        "component_type",
        "plugin_id",
        "metadata",
        "enabled",
        "_compiled_pattern",
    )

    def __init__(
        self,
        name: str,
        component_type: str,
        plugin_id: str,
        metadata: Optional[Dict[str, Any]],
    ) -> None:
        """Build the registry entry.

        Args:
            name: Component name within its plugin.
            component_type: Component kind, e.g. "command" or "tool".
            plugin_id: Id of the owning plugin.
            metadata: Component metadata, stored by reference. ``None`` is
                treated as an empty dict. Keys read here: ``enabled``
                (default True) and, for commands, ``command_pattern``.
        """
        # A null "metadata" in a plugin payload must not crash registration.
        if metadata is None:
            metadata = {}
        self.name = name
        self.full_name = f"{plugin_id}.{name}"
        self.component_type = component_type
        self.plugin_id = plugin_id
        self.metadata = metadata
        self.enabled = metadata.get("enabled", True)
        # Only commands carry a regex. It stays None when the pattern is
        # absent/empty or fails to compile (the failure is logged).
        self._compiled_pattern: Optional[re.Pattern] = None
        if component_type == "command":
            if pattern := metadata.get("command_pattern", ""):
                try:
                    self._compiled_pattern = re.compile(pattern)
                except (re.error, TypeError) as e:
                    # TypeError covers a non-string pattern (e.g. a number),
                    # which re.compile rejects without raising re.error.
                    logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")

    def __repr__(self) -> str:
        """Compact debug representation (slots classes have no useful default)."""
        return (
            f"{type(self).__name__}(full_name={self.full_name!r}, "
            f"component_type={self.component_type!r}, enabled={self.enabled!r})"
        )
class ComponentRegistry:
    """Host-side component registry.

    The Supervisor calls into it when it receives
    ``plugin.register_components``. The business layer queries it for
    available components, command matching and action/event dispatch.

    Three indexes are kept in sync:

    - ``_components``: full_name -> component
    - ``_by_type``: component_type -> {full_name -> component}
    - ``_by_plugin``: plugin_id -> [component, ...]
    """

    def __init__(self) -> None:
        # Primary index: full_name -> component.
        self._components: Dict[str, RegisteredComponent] = {}
        # Per-type index. The built-in types are pre-seeded; register_component
        # creates a bucket on the fly for any other type string.
        self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
            "action": {},
            "command": {},
            "tool": {},
            "event_handler": {},
            "workflow_step": {},
        }
        # Per-plugin index, in append (registration) order.
        self._by_plugin: Dict[str, List[RegisteredComponent]] = {}

    def clear(self) -> None:
        """Drop every registered component from all indexes."""
        self._components.clear()
        for type_dict in self._by_type.values():
            type_dict.clear()
        self._by_plugin.clear()

    # ──── Registration / removal ─────────────────────────────────

    def register_component(
        self,
        name: str,
        component_type: str,
        plugin_id: str,
        metadata: Dict[str, Any],
    ) -> bool:
        """Register a single component under ``plugin_id.name``.

        An existing component with the same full name is replaced. The old
        entry is first removed from the per-plugin and per-type indexes so no
        stale ("ghost") entries remain, even when the type changed.

        Args:
            name: Component name within the plugin; must be non-empty.
            component_type: e.g. "action", "command", "tool"; must be non-empty.
            plugin_id: Owning plugin.
            metadata: Component metadata; ``None`` is treated as ``{}``.

        Returns:
            True if the component was registered, False if it was rejected
            because ``name`` or ``component_type`` is empty.
        """
        if not name or not component_type:
            logger.warning(
                f"插件 {plugin_id} 的组件注册被拒绝: 名称或类型为空 "
                f"(name={name!r}, component_type={component_type!r})"
            )
            return False
        if metadata is None:
            metadata = {}

        comp = RegisteredComponent(name, component_type, plugin_id, metadata)
        if comp.full_name in self._components:
            logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
            self._unindex_replaced(self._components[comp.full_name])

        self._components[comp.full_name] = comp
        self._by_type.setdefault(component_type, {})[comp.full_name] = comp
        self._by_plugin.setdefault(plugin_id, []).append(comp)
        return True

    def _unindex_replaced(self, old_comp: RegisteredComponent) -> None:
        """Remove a component that is about to be overwritten from the secondary indexes."""
        old_list = self._by_plugin.get(old_comp.plugin_id)
        if old_list is not None:
            try:
                old_list.remove(old_comp)
            except ValueError:
                # Not in the list: nothing to clean up.
                pass
            # full_name collisions can cross plugins ("a.b"+"c" vs "a"+"b.c").
            # Drop a bucket left empty so get_stats()["plugins"] stays accurate.
            # If the same plugin is re-registering, the caller recreates it.
            if not old_list:
                del self._by_plugin[old_comp.plugin_id]
        old_type_dict = self._by_type.get(old_comp.component_type)
        if old_type_dict is not None:
            old_type_dict.pop(old_comp.full_name, None)

    def register_plugin_components(
        self,
        plugin_id: str,
        components: List[Dict[str, Any]],
    ) -> int:
        """Register every component a plugin declares.

        Each entry is read for ``name``, ``component_type`` and ``metadata``.
        A missing or null ``metadata`` becomes ``{}``.

        Returns:
            How many entries were actually registered. Entries rejected by
            ``register_component`` are not counted.
        """
        count = 0
        for comp_data in components:
            ok = self.register_component(
                name=comp_data.get("name", ""),
                component_type=comp_data.get("component_type", ""),
                plugin_id=plugin_id,
                metadata=comp_data.get("metadata") or {},
            )
            if ok:
                count += 1
        return count

    def remove_components_by_plugin(self, plugin_id: str) -> int:
        """Remove every component owned by ``plugin_id`` and return how many were removed."""
        comps = self._by_plugin.pop(plugin_id, [])
        for comp in comps:
            self._components.pop(comp.full_name, None)
            if type_dict := self._by_type.get(comp.component_type):
                type_dict.pop(comp.full_name, None)
        return len(comps)

    # ──── Enable / disable ───────────────────────────────────────

    def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
        """Enable or disable one component; return False if ``full_name`` is unknown."""
        comp = self._components.get(full_name)
        if comp is None:
            return False
        comp.enabled = enabled
        return True

    def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
        """Enable or disable all components of a plugin; return how many were touched."""
        comps = self._by_plugin.get(plugin_id, [])
        for comp in comps:
            comp.enabled = enabled
        return len(comps)

    # ──── Queries ────────────────────────────────────────────────

    def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
        """Look up a component by its full name (``plugin_id.name``)."""
        return self._components.get(full_name)

    def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
        """Return the components of one type, optionally only the enabled ones."""
        type_dict = self._by_type.get(component_type, {})
        if enabled_only:
            return [c for c in type_dict.values() if c.enabled]
        return list(type_dict.values())

    def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
        """Return a plugin's components, optionally only the enabled ones."""
        comps = self._by_plugin.get(plugin_id, [])
        return [c for c in comps if c.enabled] if enabled_only else list(comps)

    def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]:
        """Find the first enabled command matching ``text``.

        Commands are scanned in the command index's insertion order. For each
        one, its compiled regex is tried with ``search`` first. If that does
        not match, its ``aliases`` are tried as prefixes of ``text``.

        Returns:
            ``(component, groups)``, where ``groups`` is the regex's
            named-group dict, or ``{}`` for an alias match. ``None`` if
            nothing matches.
        """
        for comp in self._by_type.get("command", {}).values():
            if not comp.enabled:
                continue
            if comp._compiled_pattern is not None:
                m = comp._compiled_pattern.search(text)
                if m:
                    return comp, m.groupdict()
            # Alias matching. A null value means no aliases. A bare string is
            # one alias, not a sequence of one-character aliases.
            aliases = comp.metadata.get("aliases") or []
            if isinstance(aliases, str):
                aliases = [aliases]
            for alias in aliases:
                # Skip "": it is a prefix of every string and would swallow all input.
                if alias and text.startswith(alias):
                    return comp, {}
        return None

    def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
        """Return the event handlers for ``event_type``, sorted by ``weight`` descending."""
        handlers = [
            comp
            for comp in self._by_type.get("event_handler", {}).values()
            if (comp.enabled or not enabled_only) and comp.metadata.get("event_type") == event_type
        ]
        # A null weight counts as 0 instead of breaking the comparison.
        # The sort is stable, so equal weights keep their index order.
        return sorted(handlers, key=lambda c: c.metadata.get("weight") or 0, reverse=True)

    def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
        """Return the workflow steps for ``stage``, sorted by ``priority`` descending."""
        steps = [
            comp
            for comp in self._by_type.get("workflow_step", {}).values()
            if (comp.enabled or not enabled_only) and comp.metadata.get("stage") == stage
        ]
        # A null priority counts as 0 instead of breaking the comparison.
        # The sort is stable, so equal priorities keep their index order.
        return sorted(steps, key=lambda c: c.metadata.get("priority") or 0, reverse=True)

    def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
        """Build the tool list offered to the LLM (OpenAI function-calling format preview).

        Each entry has ``name`` (the component's full name) and
        ``description``. ``parameters`` is taken from the structured
        ``parameters`` metadata if non-empty, otherwise from
        ``parameters_raw`` if non-empty, and omitted otherwise.
        """
        result: List[Dict[str, Any]] = []
        for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
            tool_def: Dict[str, Any] = {
                "name": comp.full_name,
                "description": comp.metadata.get("description", ""),
            }
            params = comp.metadata.get("parameters", [])
            params_raw = comp.metadata.get("parameters_raw", {})
            if params:
                tool_def["parameters"] = params
            elif params_raw:
                tool_def["parameters"] = params_raw
            result.append(tool_def)
        return result

    # ──── Statistics ─────────────────────────────────────────────

    def get_stats(self) -> Dict[str, int]:
        """Return counts: ``total``, one entry per component type, and ``plugins``."""
        stats: Dict[str, int] = {"total": len(self._components)}
        for comp_type, type_dict in self._by_type.items():
            stats[comp_type] = len(type_dict)
        stats["plugins"] = len(self._by_plugin)
        return stats