Files
mai-bot/src/plugin_runtime/host/component_registry.py

1277 lines
46 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
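"""Host-side component registry.

Mirrors the core capabilities of the legacy component_registry.py:
- Register components by type (action / command / tool / event_handler / hook_handler / message_gateway)
- Namespacing (plugin_id.component_name)
- Regex matching for commands
- Enabling / disabling components
- Multi-dimensional queries (by name, type, plugin)
- Registration statistics
"""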
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, TypedDict
import contextlib
import re
from src.common.logger import get_logger
from src.core.tooling import build_tool_detailed_description
from .hook_spec_registry import HookSpecRegistry
# Module-level logger shared by the component entries and the registry below.
logger = get_logger("plugin_runtime.host.component_registry")
class ComponentRegistrationError(ValueError):
    """Raised when a component declaration cannot be registered.

    Besides the message, the exception carries optional context identifying
    the offending component. Every context field is coerced to a stripped
    string, so ``None`` or missing values become ``""``.
    """

    def __init__(
        self,
        message: str,
        *,
        component_name: str = "",
        component_type: str = "",
        plugin_id: str = "",
    ) -> None:
        """Create the exception.

        Args:
            message: Human-readable error message.
            component_name: Name of the component being registered.
            component_type: Declared component type.
            plugin_id: ID of the plugin that owns the component.
        """

        def _clean(value: object) -> str:
            # Falsy values (None, "") collapse to an empty string.
            return str(value or "").strip()

        self.component_name, self.component_type, self.plugin_id = (
            _clean(component_name),
            _clean(component_type),
            _clean(plugin_id),
        )
        super().__init__(message)
class ComponentTypes(str, Enum):
    """Closed set of component kinds understood by the host registry.

    Each member's value equals its name, and because the enum mixes in
    ``str``, members compare equal to their plain string values.
    """

    # Legacy action components.
    ACTION = "ACTION"
    # Text commands matched by regex pattern or alias.
    COMMAND = "COMMAND"
    # Callable tools.
    TOOL = "TOOL"
    # Handlers bound to an event type.
    EVENT_HANDLER = "EVENT_HANDLER"
    # Handlers subscribing to a named hook.
    HOOK_HANDLER = "HOOK_HANDLER"
    # Message gateways (send / receive / duplex routes).
    MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
ComponentChatScope = Literal["all", "group", "private"]
def _normalize_chat_scope(raw_value: Any) -> ComponentChatScope:
"""规范化组件聊天类型适用范围。"""
normalized_value = str(raw_value or "all").strip().lower()
if normalized_value == "group":
return "group"
if normalized_value == "private":
return "private"
return "all"
# Registry statistics: "total" counts all registered entries, "plugins" counts
# plugin IDs, and the remaining keys hold per-kind component counts.
StatusDict = TypedDict(
    "StatusDict",
    {
        "total": int,
        "action": int,
        "command": int,
        "tool": int,
        "event_handler": int,
        "hook_handler": int,
        "message_gateway": int,
        "plugins": int,
    },
)
class ComponentEntry:
    """Base record shared by every registered component.

    Stores the component identity, its raw metadata, the global ``enabled``
    flag, and per-session visibility rules (``disabled_session``,
    ``chat_scope`` and ``allowed_session``).
    """

    __slots__ = (
        "name",
        "full_name",
        "component_type",
        "plugin_id",
        "metadata",
        "enabled",
        "compiled_pattern",
        "disabled_session",
        "chat_scope",
        "allowed_session",
    )

    def __init__(
        self,
        name: str,
        component_type: str,
        plugin_id: str,
        metadata: Dict[str, Any],
        chat_scope: str = "all",
        allowed_session: Optional[List[str]] = None,
    ) -> None:
        """Initialize the shared fields.

        Args:
            name: Component name without the plugin prefix.
            component_type: Component type string; must be a ``ComponentTypes`` value.
            plugin_id: Owning plugin ID.
            metadata: Declaration metadata (kept by reference, not copied).
            chat_scope: ``"all"``, ``"group"`` or ``"private"``; other values mean ``"all"``.
            allowed_session: Optional whitelist of session IDs; blank IDs are dropped.

        Raises:
            ValueError: If ``component_type`` is not a ``ComponentTypes`` value.
        """
        self.name: str = name
        # Namespaced identifier "<plugin_id>.<name>" used as the registry key.
        self.full_name: str = f"{plugin_id}.{name}"
        self.component_type: ComponentTypes = ComponentTypes(component_type)
        self.plugin_id: str = plugin_id
        self.metadata: Dict[str, Any] = metadata
        self.enabled: bool = metadata.get("enabled", True)
        # Sessions in which this component has been switched off individually.
        self.disabled_session: Set[str] = set()
        self.chat_scope: ComponentChatScope = _normalize_chat_scope(chat_scope)
        stripped_ids = (str(session_id).strip() for session_id in allowed_session or ())
        self.allowed_session: Set[str] = {session_id for session_id in stripped_ids if session_id}
class ActionEntry(ComponentEntry):
    """Legacy Action component record.

    It adds no fields of its own. The constructor is inherited unchanged from
    ``ComponentEntry``, with the same parameters in the same order and the
    same defaults.
    """
class CommandEntry(ComponentEntry):
    """Command component record.

    A command is matched either by its compiled ``command_pattern`` regex or
    by prefix against one of its ``aliases``.
    """

    def __init__(
        self,
        name: str,
        component_type: str,
        plugin_id: str,
        metadata: Dict[str, Any],
        chat_scope: str = "all",
        allowed_session: Optional[List[str]] = None,
    ) -> None:
        """Build a command entry and compile its pattern.

        Args:
            name: Component name without the plugin prefix.
            component_type: Component type string.
            plugin_id: Owning plugin ID.
            metadata: Declaration metadata; ``aliases`` and ``command_pattern`` are read.
            chat_scope: Chat scope declaration.
            allowed_session: Optional whitelist of session IDs.
        """
        super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session)
        self.aliases: List[str] = self._normalize_aliases(metadata.get("aliases", []))
        self.compiled_pattern: Optional[re.Pattern] = None
        if pattern := metadata.get("command_pattern", ""):
            try:
                self.compiled_pattern = re.compile(pattern)
            except (re.error, TypeError) as e:
                # An invalid pattern leaves the command matchable by alias only.
                logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")

    @staticmethod
    def _normalize_aliases(raw_aliases: Any) -> List[str]:
        """Normalize the ``aliases`` declaration into a list of non-empty strings.

        - A bare string counts as a single alias instead of being iterated
          character by character.
        - ``None``, ``""`` and non-iterable values yield ``[]``.
        - Empty aliases are dropped, because ``text.startswith("")`` would
          match every message.
        - Aliases are not stripped, since leading or trailing whitespace may
          be significant for prefix matching.
        """
        if raw_aliases is None:
            return []
        if isinstance(raw_aliases, str):
            raw_aliases = [raw_aliases]
        try:
            items = list(raw_aliases)
        except TypeError:
            return []
        return [alias for alias in (str(item) for item in items if item is not None) if alias]
class ToolEntry(ComponentEntry):
    """Tool component record.

    Legacy Actions are also stored as ``ToolEntry`` instances by the
    registry; their metadata carries ``legacy_component_type == "ACTION"``.
    """

    def __init__(
        self,
        name: str,
        component_type: str,
        plugin_id: str,
        metadata: Dict[str, Any],
        chat_scope: str = "all",
        allowed_session: Optional[List[str]] = None,
    ) -> None:
        """Build a tool entry from its declaration metadata.

        Args:
            name: Component name without the plugin prefix.
            component_type: Component type string.
            plugin_id: Owning plugin ID.
            metadata: Declaration metadata (descriptions, parameters, invoke method).
            chat_scope: Chat scope declaration.
            allowed_session: Optional whitelist of session IDs.
        """
        self.description: str = str(metadata.get("description", "") or "").strip()
        # Fall back to the description, then to a generic label, when the brief
        # description is missing, empty or whitespace-only.
        brief_description = str(metadata.get("brief_description", "") or "").strip()
        self.brief_description: str = (brief_description or self.description or f"工具 {name}").strip()
        self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
        self.parameters_raw: Dict[str, Any] | List[Dict[str, Any]] = metadata.get("parameters_raw", {})
        self.detailed_description: str = str(metadata.get("detailed_description", "") or "").strip()
        self.invoke_method: str = str(metadata.get("invoke_method", "plugin.invoke_tool") or "plugin.invoke_tool").strip()
        self.legacy_component_type: str = str(metadata.get("legacy_component_type", "") or "").strip()
        super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session)
        if not self.detailed_description:
            # No explicit description: derive one from the parameter schema.
            self.detailed_description = build_tool_detailed_description(self._get_parameters_schema())

    def _get_parameters_schema(self) -> Dict[str, Any] | None:
        """Return the object-level parameter schema for this tool.

        ``parameters_raw`` (a non-empty dict) takes precedence over the
        ``parameters`` list.

        Returns:
            Dict[str, Any] | None: The normalized ``{"type": "object", ...}``
            schema, or ``None`` when no parameters are declared.
        """
        if isinstance(self.parameters_raw, dict) and self.parameters_raw:
            return self._schema_from_raw_mapping(self.parameters_raw)
        if isinstance(self.parameters, list) and self.parameters:
            return self._schema_from_parameter_list(self.parameters)
        return None

    @staticmethod
    def _schema_from_raw_mapping(raw_schema: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a raw dict into an object schema.

        A dict that already looks like an object schema (``type == "object"``
        or has ``properties``) is returned as a shallow copy. Otherwise the
        dict is treated as ``property name -> property schema``, and each
        property's boolean ``required`` flag is lifted into the schema-level
        ``required`` list. Non-dict values are skipped.
        """
        if raw_schema.get("type") == "object" or "properties" in raw_schema:
            return dict(raw_schema)
        properties: Dict[str, Any] = {}
        required_names: List[str] = []
        for property_name, property_schema in raw_schema.items():
            if not isinstance(property_schema, dict):
                continue
            property_schema_copy = dict(property_schema)
            if bool(property_schema_copy.pop("required", False)):
                required_names.append(str(property_name))
            properties[str(property_name)] = property_schema_copy
        return ToolEntry._wrap_object_schema(properties, required_names)

    @staticmethod
    def _schema_from_parameter_list(parameters: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build an object schema from a list of parameter declarations.

        Each item needs a non-empty ``name``. Its ``type`` (or legacy
        ``param_type``, default ``"string"``) becomes the property type.
        """
        properties: Dict[str, Any] = {}
        required_names: List[str] = []
        for parameter in parameters:
            if not isinstance(parameter, dict):
                continue
            parameter_name = str(parameter.get("name", "") or "").strip()
            if not parameter_name:
                continue
            if bool(parameter.get("required", False)):
                required_names.append(parameter_name)
            property_schema = {
                key: value
                for key, value in parameter.items()
                if key not in {"name", "required", "param_type"}
            }
            raw_type = parameter.get("type", parameter.get("param_type", "string"))
            # Unwrap enum members (as the hook normalizers do), so a str-mixin
            # Enum contributes its value rather than a "Cls.MEMBER" string.
            raw_type = getattr(raw_type, "value", raw_type)
            property_schema["type"] = str(raw_type or "string")
            properties[parameter_name] = property_schema
        return ToolEntry._wrap_object_schema(properties, required_names)

    @staticmethod
    def _wrap_object_schema(properties: Dict[str, Any], required_names: List[str]) -> Dict[str, Any]:
        """Wrap properties into an object schema, adding ``required`` only when non-empty."""
        schema: Dict[str, Any] = {"type": "object", "properties": properties}
        if required_names:
            schema["required"] = required_names
        return schema
class EventHandlerEntry(ComponentEntry):
    """EventHandler component record.

    ``event_type``, ``weight`` and ``intercept_message`` are read verbatim
    from the metadata, with defaults ``""``, ``0`` and ``False``.
    """

    def __init__(
        self,
        name: str,
        component_type: str,
        plugin_id: str,
        metadata: Dict[str, Any],
        chat_scope: str = "all",
        allowed_session: Optional[List[str]] = None,
    ) -> None:
        """Read the event-handler fields, then run the base initializer.

        Args:
            name: Component name without the plugin prefix.
            component_type: Component type string.
            plugin_id: Owning plugin ID.
            metadata: Declaration metadata.
            chat_scope: Chat scope declaration.
            allowed_session: Optional whitelist of session IDs.
        """
        lookup = metadata.get
        # Event type this handler subscribes to.
        self.event_type: str = lookup("event_type", "")
        # Sort key: the registry orders handlers by descending weight.
        self.weight: int = lookup("weight", 0)
        # Presumably consumed by the event dispatcher (not shown here) to stop propagation.
        self.intercept_message: bool = lookup("intercept_message", False)
        super().__init__(
            name,
            component_type,
            plugin_id,
            metadata,
            chat_scope=chat_scope,
            allowed_session=allowed_session,
        )
class HookHandlerEntry(ComponentEntry):
    """HookHandler component record subscribing to a named hook."""

    def __init__(
        self,
        name: str,
        component_type: str,
        plugin_id: str,
        metadata: Dict[str, Any],
        chat_scope: str = "all",
        allowed_session: Optional[List[str]] = None,
    ) -> None:
        """Normalize the hook fields, then run the base initializer.

        The fields are validated in this order: hook, mode, order, timeout and
        error policy. The first invalid one raises ``ValueError``.

        Args:
            name: Component name without the plugin prefix.
            component_type: Component type string.
            plugin_id: Owning plugin ID.
            metadata: Declaration metadata.
            chat_scope: Chat scope declaration.
            allowed_session: Optional whitelist of session IDs.

        Raises:
            ValueError: When any hook field is invalid.
        """
        read = metadata.get
        self.hook: str = self._normalize_hook_name(read("hook", ""))
        self.mode: str = self._normalize_mode(read("mode", "blocking"))
        self.order: str = self._normalize_order(read("order", "normal"))
        self.timeout_ms: int = self._normalize_timeout_ms(read("timeout_ms", 0))
        self.error_policy: str = self._normalize_error_policy(read("error_policy", "skip"))
        super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session)

    @staticmethod
    def _coerce_choice(raw_value: Any, fallback: str, allowed: Set[str], error_label: str) -> str:
        """Normalize an enum-or-string choice and validate it against ``allowed``.

        Enum members are unwrapped through ``.value``. The result is stripped
        and lower-cased, and ``fallback`` is used when it is empty.

        Raises:
            ValueError: ``"<error_label>: <raw_value>"`` when the choice is not allowed.
        """
        choice = str(getattr(raw_value, "value", raw_value) or "").strip().lower() or fallback
        if choice in allowed:
            return choice
        raise ValueError(f"{error_label}: {raw_value}")

    @staticmethod
    def _normalize_error_policy(raw_value: Any) -> str:
        """Normalize the error policy (``abort`` / ``skip`` / ``log``; default ``skip``).

        Raises:
            ValueError: When the policy is not supported.
        """
        return HookHandlerEntry._coerce_choice(
            raw_value, "skip", {"abort", "skip", "log"}, "HookHandler 异常处理策略不合法"
        )

    @staticmethod
    def _normalize_hook_name(raw_value: Any) -> str:
        """Return the stripped hook name (case preserved).

        Raises:
            ValueError: When the name is empty.
        """
        hook_name = str(getattr(raw_value, "value", raw_value) or "").strip()
        if hook_name:
            return hook_name
        raise ValueError("HookHandler 的 hook 名称不能为空")

    @staticmethod
    def _normalize_mode(raw_value: Any) -> str:
        """Normalize the processing mode (``blocking`` / ``observe``; default ``blocking``).

        Raises:
            ValueError: When the mode is not supported.
        """
        return HookHandlerEntry._coerce_choice(
            raw_value, "blocking", {"blocking", "observe"}, "HookHandler 模式不合法"
        )

    @staticmethod
    def _normalize_order(raw_value: Any) -> str:
        """Normalize the order slot (``early`` / ``normal`` / ``late``; default ``normal``).

        Raises:
            ValueError: When the slot is not supported.
        """
        return HookHandlerEntry._coerce_choice(
            raw_value, "normal", {"early", "normal", "late"}, "HookHandler 顺序槽位不合法"
        )

    @staticmethod
    def _normalize_timeout_ms(raw_value: Any) -> int:
        """Convert the timeout to a non-negative integer number of milliseconds.

        Falsy values mean ``0``.

        Raises:
            ValueError: When the value cannot be converted or is negative.
        """
        try:
            timeout_ms = int(raw_value or 0)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"HookHandler 超时配置不合法: {raw_value}") from exc
        if timeout_ms >= 0:
            return timeout_ms
        raise ValueError(f"HookHandler 超时配置不能为负数: {raw_value}")

    @property
    def is_blocking(self) -> bool:
        """Whether this handler runs in blocking mode."""
        return self.mode == "blocking"

    @property
    def is_observe(self) -> bool:
        """Whether this handler runs in observe mode."""
        return self.mode == "observe"
class MessageGatewayEntry(ComponentEntry):
"""MessageGateway 组件条目"""
def __init__(
self,
name: str,
component_type: str,
plugin_id: str,
metadata: Dict[str, Any],
chat_scope: str = "all",
allowed_session: Optional[List[str]] = None,
) -> None:
self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
self.platform: str = str(metadata.get("platform", "") or "").strip()
self.protocol: str = str(metadata.get("protocol", "") or "").strip()
self.account_id: str = str(metadata.get("account_id", "") or "").strip()
self.scope: str = str(metadata.get("scope", "") or "").strip()
super().__init__(name, component_type, plugin_id, metadata, chat_scope, allowed_session)
@staticmethod
def _normalize_route_type(raw_value: Any) -> str:
"""规范化消息网关路由类型。
Args:
raw_value: 原始路由类型值。
Returns:
str: 规范化后的路由类型。
Raises:
ValueError: 当路由类型不受支持时抛出。
"""
normalized_value = str(raw_value or "").strip().lower()
route_type_aliases = {
"send": "send",
"receive": "receive",
"recv": "receive",
"recive": "receive",
"duplex": "duplex",
}
route_type = route_type_aliases.get(normalized_value)
if route_type is None:
raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}")
return route_type
@property
def supports_send(self) -> bool:
"""返回当前网关是否支持出站。"""
return self.route_type in {"send", "duplex"}
@property
def supports_receive(self) -> bool:
"""返回当前网关是否支持入站。"""
return self.route_type in {"receive", "duplex"}
class ComponentRegistry:
    """Host-side component registry.

    Called by the Supervisor when it receives ``plugin.register_components``.
    Business code uses it to look up available components, match commands
    and dispatch actions/events.

    Every mutation keeps three indexes in sync:

    - ``_components``: ``full_name -> entry``
    - ``_by_type``: ``component type -> {full_name -> entry}``
    - ``_by_plugin``: ``plugin_id -> [entry, ...]``
    """

    # Entry class for each stored component type. ACTION declarations are
    # converted to TOOL before this lookup, so ACTION has no entry here.
    _ENTRY_CLASSES: Dict[ComponentTypes, type] = {
        ComponentTypes.COMMAND: CommandEntry,
        ComponentTypes.TOOL: ToolEntry,
        ComponentTypes.EVENT_HANDLER: EventHandlerEntry,
        ComponentTypes.HOOK_HANDLER: HookHandlerEntry,
        ComponentTypes.MESSAGE_GATEWAY: MessageGatewayEntry,
    }

    def __init__(self, hook_spec_registry: Optional[HookSpecRegistry] = None) -> None:
        """Create an empty registry.

        Args:
            hook_spec_registry: Optional hook spec registry. When provided,
                HookHandler registrations are validated against it.
        """
        # Global index: full_name -> entry.
        self._components: Dict[str, ComponentEntry] = {}
        # Per-type index: component type -> (full_name -> entry).
        self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = {
            comp_type: {} for comp_type in ComponentTypes
        }
        # Per-plugin index: plugin_id -> entries in registration order.
        self._by_plugin: Dict[str, List[ComponentEntry]] = {}
        self._hook_spec_registry = hook_spec_registry

    # ====== Metadata helpers ======

    @staticmethod
    def _normalize_text_list(raw_value: Any) -> List[str]:
        """Turn a list-like metadata field into non-empty stripped strings.

        A bare string counts as a single item instead of being iterated
        character by character. ``None`` and non-iterable values yield ``[]``.
        """
        if raw_value is None:
            return []
        if isinstance(raw_value, str):
            raw_value = [raw_value]
        try:
            items = list(raw_value)
        except TypeError:
            return []
        return [text for text in (str(item).strip() for item in items) if text]

    @staticmethod
    def _convert_action_metadata_to_tool_metadata(
        name: str,
        metadata: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Convert legacy Action metadata into unified Tool metadata.

        ``action_parameters`` (name -> description) becomes a string-typed
        object schema. ``action_require``, ``associated_types``,
        ``activation_type``, ``activation_keywords`` and ``action_prompt``
        are folded into ``detailed_description``.

        Args:
            name: Component name, used for the fallback description.
            metadata: Original Action metadata (not mutated).

        Returns:
            Dict[str, Any]: New Tool metadata flagged with ``legacy_action`` and
            ``legacy_component_type="ACTION"``, invoked through ``plugin.invoke_action``.
        """
        normalize_list = ComponentRegistry._normalize_text_list

        action_parameters = metadata.get("action_parameters")
        parameters_schema: Dict[str, Any] | None = None
        if isinstance(action_parameters, dict) and action_parameters:
            properties: Dict[str, Any] = {}
            for parameter_name, parameter_description in action_parameters.items():
                normalized_name = str(parameter_name or "").strip()
                if not normalized_name:
                    continue
                properties[normalized_name] = {
                    "type": "string",
                    "description": str(parameter_description or "").strip() or "兼容旧 Action 参数",
                }
            if properties:
                parameters_schema = {
                    "type": "object",
                    "properties": properties,
                }

        detailed_parts: List[str] = []
        if parameters_schema is not None:
            parameter_description = build_tool_detailed_description(parameters_schema)
            if parameter_description:
                detailed_parts.append(parameter_description)

        if action_require := normalize_list(metadata.get("action_require")):
            detailed_parts.append("使用建议:\n" + "\n".join(f"- {item}" for item in action_require))

        if associated_types := normalize_list(metadata.get("associated_types")):
            # Use a visible separator; joining with "" glued the type names together.
            detailed_parts.append(f"适用消息类型:{', '.join(associated_types)}")

        activation_type = str(metadata.get("activation_type", "always") or "always").strip()
        activation_lines = [f"兼容旧 Action 激活方式:{activation_type}"]
        if activation_keywords := normalize_list(metadata.get("activation_keywords")):
            activation_lines.append(f"激活关键词:{', '.join(activation_keywords)}")
        if action_prompt := str(metadata.get("action_prompt", "") or "").strip():
            activation_lines.append(f"原始 Action 提示语:{action_prompt}")
        detailed_parts.append("\n".join(activation_lines))

        # Treat None / blank values as missing, so None never becomes the text "None".
        brief_description = (
            str(metadata.get("brief_description") or "").strip()
            or str(metadata.get("description") or "").strip()
            or f"工具 {name}"
        ).strip()
        return {
            **metadata,
            "description": brief_description,
            "brief_description": brief_description,
            "detailed_description": "\n\n".join(part for part in detailed_parts if part).strip(),
            "parameters_raw": parameters_schema or {},
            "invoke_method": "plugin.invoke_action",
            "legacy_action": True,
            "legacy_component_type": "ACTION",
        }

    @staticmethod
    def _normalize_component_type(component_type: str) -> ComponentTypes:
        """Normalize a component type string (stripped, upper-cased).

        Args:
            component_type: Raw component type string.

        Returns:
            ComponentTypes: The matching enum member.

        Raises:
            ValueError: When the component type is not supported.
        """
        normalized_value = str(component_type or "").strip().upper()
        return ComponentTypes(normalized_value)

    def clear(self) -> None:
        """Drop every registered component from all indexes."""
        self._components.clear()
        for type_dict in self._by_type.values():
            type_dict.clear()
        self._by_plugin.clear()

    @staticmethod
    def _is_legacy_action_component(component: ComponentEntry) -> bool:
        """Whether ``component`` is a Tool entry converted from a legacy Action.

        Args:
            component: Entry to inspect.

        Returns:
            bool: ``True`` for ``ToolEntry`` instances whose metadata has
            ``legacy_component_type == "ACTION"`` (case-insensitive).
        """
        if not isinstance(component, ToolEntry):
            return False
        return str(component.metadata.get("legacy_component_type", "") or "").strip().upper() == "ACTION"

    def _validate_hook_handler_entry(self, component: HookHandlerEntry) -> None:
        """Check a HookHandler against the registered hook spec.

        This is a no-op when no hook spec registry was supplied.

        Args:
            component: HookHandler entry to validate.

        Raises:
            ComponentRegistrationError: When the hook is unknown, or the
                handler's mode or error policy is not allowed by the spec.
        """
        if self._hook_spec_registry is None:
            return

        def _rejection(message: str) -> ComponentRegistrationError:
            return ComponentRegistrationError(
                message,
                component_name=component.name,
                component_type=component.component_type.value,
                plugin_id=component.plugin_id,
            )

        hook_spec = self._hook_spec_registry.get_hook_spec(component.hook)
        if hook_spec is None:
            raise _rejection(f"HookHandler {component.full_name} 声明了未注册的 Hook: {component.hook}")
        if component.is_blocking and not hook_spec.allow_blocking:
            raise _rejection(
                f"HookHandler {component.full_name} 不能注册为 blocking, Hook {component.hook} 不允许 blocking 处理器"
            )
        if component.is_observe and not hook_spec.allow_observe:
            raise _rejection(
                f"HookHandler {component.full_name} 不能注册为 observe, Hook {component.hook} 不允许 observe 处理器"
            )
        if component.error_policy == "abort" and not hook_spec.allow_abort:
            raise _rejection(
                f"HookHandler {component.full_name} 不能使用 error_policy=abort, Hook {component.hook} 不允许 abort"
            )

    def _build_component_entry(
        self,
        name: str,
        component_type: str,
        plugin_id: str,
        metadata: Dict[str, Any],
        chat_scope: str = "all",
        allowed_session: Optional[List[str]] = None,
    ) -> ComponentEntry:
        """Build and validate a component entry from its declaration.

        ACTION declarations are converted into TOOL entries carrying legacy
        metadata.

        Args:
            name: Component name.
            component_type: Declared component type.
            plugin_id: Plugin ID.
            metadata: Component metadata (copied, not mutated).
            chat_scope: Chat scope declaration.
            allowed_session: Optional whitelist of session IDs.

        Returns:
            ComponentEntry: The constructed, validated entry.

        Raises:
            ComponentRegistrationError: When the declaration is invalid. Any
                other exception raised during construction is wrapped.
        """
        try:
            normalized_type = self._normalize_component_type(component_type)
            normalized_metadata = dict(metadata)
            if normalized_type == ComponentTypes.ACTION:
                normalized_metadata = self._convert_action_metadata_to_tool_metadata(name, normalized_metadata)
                normalized_type = ComponentTypes.TOOL
            entry_cls = self._ENTRY_CLASSES.get(normalized_type)
            if entry_cls is None:
                raise ComponentRegistrationError(
                    f"组件类型 {component_type} 不存在",
                    component_name=name,
                    component_type=component_type,
                    plugin_id=plugin_id,
                )
            component = entry_cls(
                name,
                normalized_type.value,
                plugin_id,
                normalized_metadata,
                chat_scope,
                allowed_session,
            )
            if isinstance(component, HookHandlerEntry):
                self._validate_hook_handler_entry(component)
        except ComponentRegistrationError:
            raise
        except Exception as exc:
            raise ComponentRegistrationError(
                str(exc),
                component_name=name,
                component_type=component_type,
                plugin_id=plugin_id,
            ) from exc
        return component

    def _remove_existing_component_entry(self, component: ComponentEntry) -> None:
        """Remove an existing entry that has the same full name as ``component``.

        Args:
            component: New entry about to be written.
        """
        old_component = self._components.get(component.full_name)
        if old_component is None:
            return
        logger.warning(f"组件 {component.full_name} 已存在,覆盖")
        old_list = self._by_plugin.get(old_component.plugin_id)
        if old_list is not None:
            with contextlib.suppress(ValueError):
                old_list.remove(old_component)
            if not old_list:
                # Remove the emptied bucket so get_stats()["plugins"] stays accurate.
                del self._by_plugin[old_component.plugin_id]
        if (old_type_dict := self._by_type.get(old_component.component_type)) is not None:
            old_type_dict.pop(component.full_name, None)

    def _add_component_entry(self, component: ComponentEntry) -> None:
        """Write one entry into every index, replacing any same-named entry.

        Args:
            component: Entry to write.
        """
        self._remove_existing_component_entry(component)
        self._components[component.full_name] = component
        self._by_type[component.component_type][component.full_name] = component
        self._by_plugin.setdefault(component.plugin_id, []).append(component)

    # ====== Registration / removal ======

    def register_component(
        self,
        name: str,
        component_type: str,
        plugin_id: str,
        metadata: Dict[str, Any],
        chat_scope: str = "all",
        allowed_session: Optional[List[str]] = None,
    ) -> bool:
        """Register a single component.

        Args:
            name: Component name, without the plugin ID prefix.
            component_type: Component type (e.g. ``ACTION``, ``COMMAND``).
            plugin_id: Plugin ID.
            metadata: Component metadata.
            chat_scope: Chat scope declaration.
            allowed_session: Optional whitelist of session IDs.

        Returns:
            bool: Always ``True`` on success.

        Raises:
            ComponentRegistrationError: When the declaration is invalid.
        """
        component = self._build_component_entry(
            name,
            component_type,
            plugin_id,
            metadata,
            chat_scope,
            allowed_session,
        )
        self._add_component_entry(component)
        return True

    def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
        """Atomically replace a plugin's component set.

        Every declaration is validated first. The plugin's old components are
        replaced only if all of them succeed, so a plugin never ends up
        half-registered.

        ``chat_scope`` and ``allowed_session`` may be given at the top level
        of a declaration or inside its ``metadata``; the top level wins. Both
        keys are always removed from the stored metadata.

        Args:
            plugin_id: Plugin ID.
            components: Component declaration dicts.

        Returns:
            int: Number of components registered.

        Raises:
            ComponentRegistrationError: When any declaration is invalid.
        """
        prepared_components: List[ComponentEntry] = []
        for component_data in components:
            raw_metadata = (
                dict(component_data.get("metadata", {}))
                if isinstance(component_data.get("metadata"), dict)
                else {}
            )
            metadata_chat_scope = raw_metadata.pop("chat_scope", "all")
            metadata_allowed_session = raw_metadata.pop("allowed_session", [])
            chat_scope = str(component_data.get("chat_scope", metadata_chat_scope) or "all")
            raw_allowed_session = component_data.get("allowed_session", metadata_allowed_session)
            allowed_session = (
                [str(item).strip() for item in raw_allowed_session if str(item).strip()]
                if isinstance(raw_allowed_session, (list, tuple, set))
                else []
            )
            prepared_components.append(
                self._build_component_entry(
                    name=str(component_data.get("name", "") or ""),
                    component_type=str(component_data.get("component_type", "") or ""),
                    plugin_id=plugin_id,
                    metadata=raw_metadata,
                    chat_scope=chat_scope,
                    allowed_session=allowed_session,
                )
            )
        self.remove_components_by_plugin(plugin_id)
        for component in prepared_components:
            self._add_component_entry(component)
        return len(prepared_components)

    def remove_components_by_plugin(self, plugin_id: str) -> int:
        """Remove every component of a plugin.

        Args:
            plugin_id: Plugin ID.

        Returns:
            int: Number of components removed.
        """
        comps = self._by_plugin.pop(plugin_id, [])
        for comp in comps:
            self._components.pop(comp.full_name, None)
            if type_dict := self._by_type.get(comp.component_type):
                type_dict.pop(comp.full_name, None)
        return len(comps)

    # ====== Enable / disable ======

    def check_component_enabled(
        self,
        component: ComponentEntry,
        session_id: Optional[str] = None,
        is_group_chat: Optional[bool] = None,
        group_id: Optional[str] = None,
        platform: Optional[str] = None,
    ) -> bool:
        """Decide whether ``component`` is usable in the given context.

        Checks run in this order:

        1. The component was disabled for ``session_id``.
        2. ``chat_scope`` does not match ``is_group_chat`` (skipped when it is ``None``).
        3. ``allowed_session`` is set and none of ``session_id``, ``group_id``
           or ``"<platform>:<group_id>"`` is in it.

        If no check rejects the component, the global ``enabled`` flag decides.

        Args:
            component: Entry to check.
            session_id: Optional session ID.
            is_group_chat: Optional chat kind.
            group_id: Optional group ID.
            platform: Optional platform name, combined with ``group_id``.

        Returns:
            bool: Whether the component is enabled in this context.
        """
        if session_id and session_id in component.disabled_session:
            return False
        if is_group_chat is not None:
            if component.chat_scope == "group" and is_group_chat is not True:
                return False
            if component.chat_scope == "private" and is_group_chat is not False:
                return False
        if component.allowed_session:
            allowed_candidates = {str(session_id or "").strip(), str(group_id or "").strip()}
            if platform and group_id:
                allowed_candidates.add(f"{platform}:{group_id}")
            allowed_candidates.discard("")
            if component.allowed_session.isdisjoint(allowed_candidates):
                return False
        return component.enabled

    @staticmethod
    def _apply_enabled(comp: ComponentEntry, enabled: bool, session_id: Optional[str]) -> None:
        """Toggle one entry, either globally or only for ``session_id``."""
        if not session_id:
            comp.enabled = enabled
        elif enabled:
            comp.disabled_session.discard(session_id)
        else:
            comp.disabled_session.add(session_id)

    def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
        """Enable or disable a component.

        Args:
            full_name: Component full name.
            enabled: Target state.
            session_id: When given, only this session is affected.

        Returns:
            bool: ``False`` when the component does not exist.
        """
        comp = self._components.get(full_name)
        if comp is None:
            return False
        self._apply_enabled(comp, enabled, session_id)
        return True

    def set_component_enabled(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
        """Set a component's enabled state (alias of ``toggle_component_status``).

        Args:
            full_name: Component full name.
            enabled: Target state.
            session_id: When given, only this session is affected.

        Returns:
            bool: Whether the state was set.
        """
        return self.toggle_component_status(full_name, enabled, session_id=session_id)

    def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
        """Enable or disable every component of a plugin.

        Args:
            plugin_id: Plugin ID.
            enabled: Target state.
            session_id: When given, only this session is affected.

        Returns:
            int: Number of components updated (``0`` for an unknown plugin).
        """
        comps = self._by_plugin.get(plugin_id, [])
        for comp in comps:
            self._apply_enabled(comp, enabled, session_id)
        return len(comps)

    # ====== Queries ======

    def get_component(self, full_name: str) -> Optional[ComponentEntry]:
        """Look up a component by full name.

        Args:
            full_name: Component full name.

        Returns:
            Optional[ComponentEntry]: The entry, or ``None`` when not found.
        """
        return self._components.get(full_name)

    def get_components_by_type(
        self,
        component_type: str,
        *,
        enabled_only: bool = True,
        session_id: Optional[str] = None,
        is_group_chat: Optional[bool] = None,
        group_id: Optional[str] = None,
        platform: Optional[str] = None,
    ) -> List[ComponentEntry]:
        """List components of a type.

        ``ACTION`` returns the Tool entries converted from legacy Actions.

        Args:
            component_type: Component type (e.g. ``ACTION``, ``COMMAND``).
            enabled_only: Only return components enabled in the given context.
            session_id: Optional session ID.
            is_group_chat: Optional chat kind.
            group_id: Optional group ID.
            platform: Optional platform name.

        Returns:
            List[ComponentEntry]: Matching entries.

        Raises:
            ValueError: When the component type is not supported.
        """
        try:
            comp_type = self._normalize_component_type(component_type)
        except ValueError:
            logger.error(f"组件类型 {component_type} 不存在")
            raise
        if comp_type == ComponentTypes.ACTION:
            candidates = [
                component
                for component in self._by_type.get(ComponentTypes.TOOL, {}).values()
                if self._is_legacy_action_component(component)
            ]
        else:
            candidates = list(self._by_type.get(comp_type, {}).values())
        if not enabled_only:
            return candidates
        return [
            component
            for component in candidates
            if self.check_component_enabled(component, session_id, is_group_chat, group_id, platform)
        ]

    def get_components_by_plugin(
        self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None
    ) -> List[ComponentEntry]:
        """List a plugin's components.

        Args:
            plugin_id: Plugin ID.
            enabled_only: Only return enabled components.
            session_id: Optional session ID for per-session state.

        Returns:
            List[ComponentEntry]: Matching entries.
        """
        comps = self._by_plugin.get(plugin_id, [])
        return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps)

    def find_command_by_text(
        self, text: str, session_id: Optional[str] = None
    ) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]:
        """Find the first enabled command matching ``text``.

        The compiled regex is tried first (``search``); otherwise the text is
        prefix-matched against each non-empty alias.

        Args:
            text: Text to match.
            session_id: Optional session ID for per-session state.

        Returns:
            Optional[Tuple[ComponentEntry, Dict[str, Any]]]: The command and its
            named regex groups (``{}`` for alias matches), or ``None``.
        """
        for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values():
            if not self.check_component_enabled(comp, session_id):
                continue
            if not isinstance(comp, CommandEntry):
                continue
            if comp.compiled_pattern:
                if m := comp.compiled_pattern.search(text):
                    return comp, m.groupdict()
            for alias in comp.aliases:
                # An empty alias would make startswith() match every message.
                if alias and text.startswith(alias):
                    return comp, {}
        return None

    def get_event_handlers(
        self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
    ) -> List[EventHandlerEntry]:
        """List event handlers for an event type.

        Args:
            event_type: Event type.
            enabled_only: Only return enabled components.
            session_id: Optional session ID for per-session state.

        Returns:
            List[EventHandlerEntry]: Matching handlers sorted by descending weight.
        """
        handlers: List[EventHandlerEntry] = []
        for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values():
            if enabled_only and not self.check_component_enabled(comp, session_id):
                continue
            if isinstance(comp, EventHandlerEntry) and comp.event_type == event_type:
                handlers.append(comp)
        handlers.sort(key=lambda c: c.weight, reverse=True)
        return handlers

    def get_hook_handlers(
        self, hook_name: str, *, enabled_only: bool = True, session_id: Optional[str] = None
    ) -> List[HookHandlerEntry]:
        """List handlers subscribed to a named hook.

        Handlers are sorted by mode (blocking first), then order slot
        (early, normal, late), then plugin ID and name.

        Args:
            hook_name: Hook name.
            enabled_only: Only return enabled components.
            session_id: Optional session ID for per-session state.

        Returns:
            List[HookHandlerEntry]: Matching handlers in execution order.
        """
        handlers: List[HookHandlerEntry] = []
        for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
            if enabled_only and not self.check_component_enabled(comp, session_id):
                continue
            if isinstance(comp, HookHandlerEntry) and comp.hook == hook_name:
                handlers.append(comp)
        handlers.sort(
            key=lambda handler: (
                self._get_hook_mode_rank(handler.mode),
                self._get_hook_order_rank(handler.order),
                handler.plugin_id,
                handler.name,
            )
        )
        return handlers

    @staticmethod
    def _get_hook_mode_rank(mode: str) -> int:
        """Sort rank of a hook mode (lower runs first; unknown modes rank 99)."""
        return {"blocking": 0, "observe": 1}.get(mode, 99)

    @staticmethod
    def _get_hook_order_rank(order: str) -> int:
        """Sort rank of a hook order slot (lower runs first; unknown slots rank 99)."""
        return {"early": 0, "normal": 1, "late": 2}.get(order, 99)

    def get_message_gateway(
        self,
        plugin_id: str,
        name: str,
        *,
        enabled_only: bool = True,
        session_id: Optional[str] = None,
    ) -> Optional[MessageGatewayEntry]:
        """Look up one message gateway by plugin ID and component name.

        Args:
            plugin_id: Plugin ID.
            name: Gateway component name.
            enabled_only: Only return the gateway when it is enabled.
            session_id: Optional session ID.

        Returns:
            Optional[MessageGatewayEntry]: The gateway, or ``None``.
        """
        component = self._components.get(f"{plugin_id}.{name}")
        if not isinstance(component, MessageGatewayEntry):
            return None
        if enabled_only and not self.check_component_enabled(component, session_id):
            return None
        return component

    def get_message_gateways(
        self,
        *,
        plugin_id: Optional[str] = None,
        platform: str = "",
        route_type: str = "",
        enabled_only: bool = True,
        session_id: Optional[str] = None,
    ) -> List[MessageGatewayEntry]:
        """List message gateways matching the given filters.

        ``route_type`` accepts the same aliases as gateway declarations
        (e.g. ``recv``); an unrecognized value matches nothing.

        Args:
            plugin_id: Optional plugin ID filter.
            platform: Optional platform filter.
            route_type: Optional route type filter (exact match after alias normalization).
            enabled_only: Only return enabled components.
            session_id: Optional session ID.

        Returns:
            List[MessageGatewayEntry]: Matching gateways.
        """
        normalized_platform = str(platform or "").strip()
        normalized_route_type = str(route_type or "").strip().lower()
        if normalized_route_type:
            with contextlib.suppress(ValueError):
                normalized_route_type = MessageGatewayEntry._normalize_route_type(normalized_route_type)
        gateways: List[MessageGatewayEntry] = []
        for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values():
            if not isinstance(comp, MessageGatewayEntry):
                continue
            if plugin_id and comp.plugin_id != plugin_id:
                continue
            if enabled_only and not self.check_component_enabled(comp, session_id):
                continue
            if normalized_platform and comp.platform != normalized_platform:
                continue
            if normalized_route_type and comp.route_type != normalized_route_type:
                continue
            gateways.append(comp)
        return gateways

    def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
        """List all Tool entries, including ones converted from legacy Actions.

        Args:
            enabled_only: Only return enabled components.
            session_id: Optional session ID for per-session state.

        Returns:
            List[ToolEntry]: Matching tools.
        """
        tools: List[ToolEntry] = []
        for comp in self._by_type.get(ComponentTypes.TOOL, {}).values():
            if enabled_only and not self.check_component_enabled(comp, session_id):
                continue
            if isinstance(comp, ToolEntry):
                tools.append(comp)
        return tools

    def get_tools_for_llm(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Legacy interface: LLM-facing tool dicts, excluding legacy Actions.

        ``parameters`` is always the normalized object schema. A flat
        ``name -> schema`` ``parameters_raw`` mapping is therefore converted
        instead of being passed through verbatim.

        Args:
            enabled_only: Only return enabled components.
            session_id: Optional session ID for per-session state.

        Returns:
            List[Dict[str, Any]]: Tool dicts in the legacy shape.
        """
        return [
            {
                "name": tool.full_name,
                "description": tool.description,
                "parameters": tool._get_parameters_schema() or {},
                "parameters_raw": tool.parameters_raw,
                "enabled": tool.enabled,
                "plugin_id": tool.plugin_id,
            }
            for tool in self.get_tools(enabled_only=enabled_only, session_id=session_id)
            if not self._is_legacy_action_component(tool)
        ]

    # ====== Statistics ======

    def get_stats(self) -> StatusDict:
        """Return registration statistics.

        ``action`` counts Tool entries converted from legacy Actions, and
        ``tool`` counts the remaining Tool entries.

        Returns:
            StatusDict: Total, per-type and plugin counts.
        """
        tool_entries = list(self._by_type.get(ComponentTypes.TOOL, {}).values())
        legacy_action_count = sum(1 for component in tool_entries if self._is_legacy_action_component(component))
        return StatusDict(
            total=len(self._components),
            action=legacy_action_count,
            command=len(self._by_type[ComponentTypes.COMMAND]),
            tool=len(tool_entries) - legacy_action_count,
            event_handler=len(self._by_type[ComponentTypes.EVENT_HANDLER]),
            hook_handler=len(self._by_type[ComponentTypes.HOOK_HANDLER]),
            message_gateway=len(self._by_type[ComponentTypes.MESSAGE_GATEWAY]),
            plugins=len(self._by_plugin),
        )