refactor: component_registry更易理解

This commit is contained in:
UnCLAS-Prommer
2026-03-18 02:08:13 +08:00
committed by DrSmoothl
parent ca6fd96d4c
commit 14a0c21cbf
2 changed files with 347 additions and 125 deletions

View File

@@ -1,7 +1,7 @@
"""Host-side ComponentRegistry
对齐旧系统 component_registry.py 的核心能力:
- 按类型注册组件action / command / tool / event_handler / workflow_step
- 按类型注册组件action / command / tool / event_handler / workflow_handler / message_gateway
- 命名空间 (plugin_id.component_name)
- 命令正则匹配
- 组件启用/禁用
@@ -9,8 +9,10 @@
- 注册统计
"""
from typing import Any, Dict, List, Optional
from enum import Enum
from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
import contextlib
import re
from src.common.logger import get_logger
@@ -18,8 +20,28 @@ from src.common.logger import get_logger
logger = get_logger("plugin_runtime.host.component_registry")
class RegisteredComponent:
"""已注册的组件条目"""
class ComponentTypes(str, Enum):
ACTION = "ACTION"
COMMAND = "COMMAND"
TOOL = "TOOL"
EVENT_HANDLER = "EVENT_HANDLER"
WORKFLOW_HANDLER = "WORKFLOW_HANDLER"
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
class StatusDict(TypedDict):
total: int
ACTION: int
COMMAND: int
TOOL: int
EVENT_HANDLER: int
WORKFLOW_HANDLER: int
MESSAGE_GATEWAY: int
plugins: int
class ComponentEntry:
"""组件条目"""
__slots__ = (
"name",
@@ -28,31 +50,74 @@ class RegisteredComponent:
"plugin_id",
"metadata",
"enabled",
"_compiled_pattern",
"compiled_pattern",
"disabled_session",
)
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
) -> None:
self.name = name
self.full_name = f"{plugin_id}.{name}"
self.component_type = component_type
self.plugin_id = plugin_id
self.metadata = metadata
self.enabled = metadata.get("enabled", True)
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.name: str = name
self.full_name: str = f"{plugin_id}.{name}"
self.component_type: ComponentTypes = ComponentTypes(component_type)
self.plugin_id: str = plugin_id
self.metadata: Dict[str, Any] = metadata
self.enabled: bool = metadata.get("enabled", True)
self.disabled_session: Set[str] = set()
# 预编译命令正则(仅 command 类型)
self._compiled_pattern: Optional[re.Pattern] = None
if component_type == "command":
if pattern := metadata.get("command_pattern", ""):
try:
self._compiled_pattern = re.compile(pattern)
except re.error as e:
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
class ActionEntry(ComponentEntry):
"""Action 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
super().__init__(name, component_type, plugin_id, metadata)
class CommandEntry(ComponentEntry):
"""Command 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.compiled_pattern: Optional[re.Pattern] = None
self.aliases: List[str] = metadata.get("aliases", [])
if pattern := metadata.get("command_pattern", ""):
try:
self.compiled_pattern = re.compile(pattern)
except re.error as e:
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
super().__init__(name, component_type, plugin_id, metadata)
class ToolEntry(ComponentEntry):
"""Tool 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.description: str = metadata.get("description", "")
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
super().__init__(name, component_type, plugin_id, metadata)
class EventHandlerEntry(ComponentEntry):
"""EventHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.event_type: str = metadata.get("event_type", "")
self.weight: int = metadata.get("weight", 0)
super().__init__(name, component_type, plugin_id, metadata)
class WorkflowHandlerEntry(ComponentEntry):
"""WorkflowHandler 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
self.stage: str = metadata.get("stage", "")
self.priority: int = metadata.get("priority", 0)
super().__init__(name, component_type, plugin_id, metadata)
class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目"""
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
super().__init__(name, component_type, plugin_id, metadata)
class ComponentRegistry:
@@ -64,19 +129,15 @@ class ComponentRegistry:
def __init__(self) -> None:
# 全量索引
self._components: Dict[str, RegisteredComponent] = {} # full_name -> comp
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
# 按类型索引
self._by_type: Dict[str, Dict[str, RegisteredComponent]] = {
"action": {},
"command": {},
"tool": {},
"event_handler": {},
"workflow_step": {},
}
self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = {
comp_type: {} for comp_type in ComponentTypes
} # component_type -> (full_name -> comp)
# 按插件索引
self._by_plugin: Dict[str, List[RegisteredComponent]] = {}
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
def clear(self) -> None:
"""清空全部组件注册状态。"""
@@ -85,47 +146,63 @@ class ComponentRegistry:
type_dict.clear()
self._by_plugin.clear()
# ──── 注册 / 注销 ─────────────────────────────────────────
# ====== 注册 / 注销 ======
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
"""注册单个组件
Args:
name: 组件名称不含插件id前缀
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
plugin_id: 插件id
metadata: 组件元数据
Returns:
success (bool): 是否成功注册(失败原因通常是组件类型无效)
"""
try:
if component_type == ComponentTypes.ACTION:
comp = ActionEntry(name, component_type, plugin_id, metadata)
elif component_type == ComponentTypes.COMMAND:
comp = CommandEntry(name, component_type, plugin_id, metadata)
elif component_type == ComponentTypes.TOOL:
comp = ToolEntry(name, component_type, plugin_id, metadata)
elif component_type == ComponentTypes.EVENT_HANDLER:
comp = EventHandlerEntry(name, component_type, plugin_id, metadata)
elif component_type == ComponentTypes.WORKFLOW_HANDLER:
comp = WorkflowHandlerEntry(name, component_type, plugin_id, metadata)
elif component_type == ComponentTypes.MESSAGE_GATEWAY:
comp = MessageGatewayEntry(name, component_type, plugin_id, metadata)
else:
raise ValueError(f"组件类型 {component_type} 不存在")
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
return False
def register_component(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
) -> bool:
"""注册单个组件。"""
comp = RegisteredComponent(name, component_type, plugin_id, metadata)
if comp.full_name in self._components:
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
old_comp = self._components[comp.full_name]
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
old_list = self._by_plugin.get(old_comp.plugin_id)
if old_list is not None:
try:
with contextlib.suppress(ValueError):
old_list.remove(old_comp)
except ValueError:
pass
# 从旧类型索引中移除,防止类型变更时幽灵残留
if old_type_dict := self._by_type.get(old_comp.component_type):
old_type_dict.pop(comp.full_name, None)
self._components[comp.full_name] = comp
if component_type not in self._by_type:
self._by_type[component_type] = {}
self._by_type[component_type][comp.full_name] = comp
self._by_type[comp.component_type][comp.full_name] = comp
self._by_plugin.setdefault(plugin_id, []).append(comp)
return True
def register_plugin_components(
self,
plugin_id: str,
components: List[Dict[str, Any]],
) -> int:
"""批量注册一个插件的所有组件,返回成功注册数。"""
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
"""批量注册一个插件的所有组件,返回成功注册数。
Args:
plugin_id (str): 插件id
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
Returns:
count (int): 成功注册的组件数量
"""
count = 0
for comp_data in components:
ok = self.register_component(
@@ -139,7 +216,13 @@ class ComponentRegistry:
return count
def remove_components_by_plugin(self, plugin_id: str) -> int:
"""移除某个插件的所有组件,返回移除数量。"""
"""移除某个插件的所有组件,返回移除数量。
Args:
plugin_id (str): 插件id
Returns:
count (int): 移除的组件数量
"""
comps = self._by_plugin.pop(plugin_id, [])
for comp in comps:
self._components.pop(comp.full_name, None)
@@ -147,106 +230,200 @@ class ComponentRegistry:
type_dict.pop(comp.full_name, None)
return len(comps)
# ──── 启用 / 禁用 ─────────────────────────────────────────
# ====== 启用 / 禁用 ======
def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
if session_id and session_id in component.disabled_session:
return False
return component.enabled
def set_component_enabled(self, full_name: str, enabled: bool) -> bool:
"""启用或禁用指定组件。"""
def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
"""启用或禁用指定组件。
Args:
full_name (str): 组件全名
enabled (bool): 使能情况
session_id (Optional[str]): 可选的会话ID仅对该会话禁用如果提供
Returns:
success (bool): 是否成功设置(失败原因通常是组件不存在)
"""
comp = self._components.get(full_name)
if comp is None:
return False
comp.enabled = enabled
if session_id:
if enabled:
comp.disabled_session.discard(session_id)
else:
comp.disabled_session.add(session_id)
else:
comp.enabled = enabled
return True
def set_plugin_enabled(self, plugin_id: str, enabled: bool) -> int:
"""批量启用或禁用某插件的所有组件。"""
def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
"""批量启用或禁用某插件的所有组件。
Args:
plugin_id (str): 插件id
enabled (bool): 使能情况
session_id (Optional[str]): 可选的会话ID仅对该会话禁用如果提供
Returns:
count (int): 成功设置的组件数量(失败原因通常是插件不存在)
"""
comps = self._by_plugin.get(plugin_id, [])
for comp in comps:
comp.enabled = enabled
if session_id:
if enabled:
comp.disabled_session.discard(session_id)
else:
comp.disabled_session.add(session_id)
else:
comp.enabled = enabled
return len(comps)
# ──── 查询方法 ─────────────────────────────────────────────
def get_component(self, full_name: str) -> Optional[ComponentEntry]:
"""按全名查询。
def get_component(self, full_name: str) -> Optional[RegisteredComponent]:
"""按全名查询。"""
Args:
full_name (str): 组件全名
Returns:
component (Optional[ComponentEntry]): 组件条目,未找到时为 None
"""
return self._components.get(full_name)
def get_components_by_type(self, component_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""按类型查询。"""
type_dict = self._by_type.get(component_type, {})
def get_components_by_type(
self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[ComponentEntry]:
"""按类型查询组件
Args:
component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等)
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
components (List[ComponentEntry]): 组件条目列表
"""
try:
comp_type = ComponentTypes(component_type)
except ValueError:
logger.error(f"组件类型 {component_type} 不存在")
raise
type_dict = self._by_type.get(comp_type, {})
if enabled_only:
return [c for c in type_dict.values() if c.enabled]
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
return list(type_dict.values())
def get_components_by_plugin(self, plugin_id: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""按插件查询。"""
comps = self._by_plugin.get(plugin_id, [])
return [c for c in comps if c.enabled] if enabled_only else list(comps)
def get_components_by_plugin(
self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[ComponentEntry]:
"""按插件查询组件。
def find_command_by_text(self, text: str) -> Optional[tuple[RegisteredComponent, Dict[str, Any]]]:
Args:
plugin_id (str): 插件ID
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
components (List[ComponentEntry]): 组件条目列表
"""
comps = self._by_plugin.get(plugin_id, [])
return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps)
def find_command_by_text(
self, text: str, session_id: Optional[str] = None
) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]:
"""通过文本匹配命令正则,返回 (组件, matched_groups) 元组。
matched_groups 为正则命名捕获组 dict别名匹配时为空 dict。
Args:
text (str): 待匹配文本
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None
"""
for comp in self._by_type.get("command", {}).values():
if not comp.enabled:
for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values():
if not self.check_component_enabled(comp, session_id):
continue
if comp._compiled_pattern:
m = comp._compiled_pattern.search(text)
if m:
if not isinstance(comp, CommandEntry):
continue
if comp.compiled_pattern:
if m := comp.compiled_pattern.search(text):
return comp, m.groupdict()
# 别名匹配
aliases = comp.metadata.get("aliases", [])
for alias in aliases:
for alias in comp.aliases:
if text.startswith(alias):
return comp, {}
return None
def get_event_handlers(self, event_type: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""获取特定事件类型的所有 event_handler按 weight 降序排列。"""
handlers = []
for comp in self._by_type.get("event_handler", {}).values():
if enabled_only and not comp.enabled:
def get_event_handlers(
self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[EventHandlerEntry]:
"""查询指定事件类型的事件处理器组件。
Args:
event_type (str): 事件类型
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序
"""
handlers: List[EventHandlerEntry] = []
for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if comp.metadata.get("event_type") == event_type:
if not isinstance(comp, EventHandlerEntry):
continue
if comp.event_type == event_type:
handlers.append(comp)
handlers.sort(key=lambda c: c.metadata.get("weight", 0), reverse=True)
handlers.sort(key=lambda c: c.weight, reverse=True)
return handlers
def get_workflow_steps(self, stage: str, *, enabled_only: bool = True) -> List[RegisteredComponent]:
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。"""
steps = []
for comp in self._by_type.get("workflow_step", {}).values():
if enabled_only and not comp.enabled:
def get_workflow_handlers(
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
) -> List[WorkflowHandlerEntry]:
"""获取特定 workflow 阶段的所有步骤,按 priority 降序。
Args:
stage: workflow 阶段名称
enabled_only: 是否仅返回启用的组件
session_id: 可选的会话ID若提供则考虑会话禁用状态
Returns:
handlers (List[WorkflowHandlerEntry]): 符合条件的 WorkflowHandler 组件列表,按 priority 降序排序
"""
handlers: List[WorkflowHandlerEntry] = []
for comp in self._by_type.get(ComponentTypes.WORKFLOW_HANDLER, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if comp.metadata.get("stage") == stage:
steps.append(comp)
steps.sort(key=lambda c: c.metadata.get("priority", 0), reverse=True)
return steps
if not isinstance(comp, WorkflowHandlerEntry):
continue
if comp.stage == stage:
handlers.append(comp)
handlers.sort(key=lambda c: c.priority, reverse=True)
return handlers
def get_tools_for_llm(self, *, enabled_only: bool = True) -> List[Dict[str, Any]]:
"""获取可供 LLM 使用的工具列表openai function-calling 格式预览)。"""
result: List[Dict[str, Any]] = []
for comp in self.get_components_by_type("tool", enabled_only=enabled_only):
tool_def: Dict[str, Any] = {
"name": comp.full_name,
"description": comp.metadata.get("description", ""),
}
# 从结构化参数或原始参数构建 parameters
params = comp.metadata.get("parameters", [])
params_raw = comp.metadata.get("parameters_raw", {})
if params:
tool_def["parameters"] = params
elif params_raw:
tool_def["parameters"] = params_raw
result.append(tool_def)
return result
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
"""查询所有工具组件。
Args:
enabled_only (bool): 是否仅返回启用的组件
session_id (Optional[str]): 可选的会话ID若提供则考虑会话禁用状态
Returns:
tools (List[ToolEntry]): 符合条件的 Tool 组件列表
"""
tools: List[ToolEntry] = []
for comp in self._by_type.get(ComponentTypes.TOOL, {}).values():
if enabled_only and not self.check_component_enabled(comp, session_id):
continue
if isinstance(comp, ToolEntry):
tools.append(comp)
return tools
# ──── 统计 ─────────────────────────────────────────────────
def get_stats(self) -> Dict[str, int]:
"""获取注册统计。"""
stats: Dict[str, int] = {"total": len(self._components)}
# ====== 统计信息 ======
def get_stats(self) -> StatusDict:
"""获取注册统计。
Returns:
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
"""
stats: StatusDict = {"total": len(self._components)} # type: ignore
for comp_type, type_dict in self._by_type.items():
stats[comp_type] = len(type_dict)
stats[comp_type.value] = len(type_dict)
stats["plugins"] = len(self._by_plugin)
return stats