- Changed the message sending method to return DeliveryBatch instead of DeliveryReceipt in integration.py. - Removed AdapterDeclarationPayload and related references from envelope.py, replacing them with MessageGatewayStateUpdatePayload and MessageGatewayStateUpdateResultPayload. - Updated runner_main.py to remove adapter-related logic and methods, focusing on message gateway functionality. - Added tests for message gateway runtime state synchronization and action bridge functionality in test files.
542 lines
21 KiB
Python
542 lines
21 KiB
Python
"""Host-side ComponentRegistry
|
||
|
||
对齐旧系统 component_registry.py 的核心能力:
|
||
- 按类型注册组件(action / command / tool / event_handler / workflow_handler / message_gateway)
|
||
- 命名空间 (plugin_id.component_name)
|
||
- 命令正则匹配
|
||
- 组件启用/禁用
|
||
- 多维度查询(按名称、类型、插件)
|
||
- 注册统计
|
||
"""
|
||
|
||
from enum import Enum
|
||
from typing import Any, Dict, List, Optional, Set, TypedDict, Tuple
|
||
|
||
import contextlib
|
||
import re
|
||
|
||
from src.common.logger import get_logger
|
||
|
||
logger = get_logger("plugin_runtime.host.component_registry")
|
||
|
||
|
||
class ComponentTypes(str, Enum):
|
||
ACTION = "ACTION"
|
||
COMMAND = "COMMAND"
|
||
TOOL = "TOOL"
|
||
EVENT_HANDLER = "EVENT_HANDLER"
|
||
HOOK_HANDLER = "HOOK_HANDLER"
|
||
MESSAGE_GATEWAY = "MESSAGE_GATEWAY"
|
||
|
||
|
||
class StatusDict(TypedDict):
|
||
total: int
|
||
ACTION: int
|
||
COMMAND: int
|
||
TOOL: int
|
||
EVENT_HANDLER: int
|
||
HOOK_HANDLER: int
|
||
MESSAGE_GATEWAY: int
|
||
plugins: int
|
||
|
||
|
||
class ComponentEntry:
|
||
"""组件条目"""
|
||
|
||
__slots__ = (
|
||
"name",
|
||
"full_name",
|
||
"component_type",
|
||
"plugin_id",
|
||
"metadata",
|
||
"enabled",
|
||
"compiled_pattern",
|
||
"disabled_session",
|
||
)
|
||
|
||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||
self.name: str = name
|
||
self.full_name: str = f"{plugin_id}.{name}"
|
||
self.component_type: ComponentTypes = ComponentTypes(component_type)
|
||
self.plugin_id: str = plugin_id
|
||
self.metadata: Dict[str, Any] = metadata
|
||
self.enabled: bool = metadata.get("enabled", True)
|
||
self.disabled_session: Set[str] = set()
|
||
|
||
|
||
class ActionEntry(ComponentEntry):
|
||
"""Action 组件条目"""
|
||
|
||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||
super().__init__(name, component_type, plugin_id, metadata)
|
||
|
||
|
||
class CommandEntry(ComponentEntry):
|
||
"""Command 组件条目"""
|
||
|
||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||
self.compiled_pattern: Optional[re.Pattern] = None
|
||
self.aliases: List[str] = metadata.get("aliases", [])
|
||
if pattern := metadata.get("command_pattern", ""):
|
||
try:
|
||
self.compiled_pattern = re.compile(pattern)
|
||
except re.error as e:
|
||
logger.warning(f"命令 {self.full_name} 正则编译失败: {e}")
|
||
super().__init__(name, component_type, plugin_id, metadata)
|
||
|
||
|
||
class ToolEntry(ComponentEntry):
|
||
"""Tool 组件条目"""
|
||
|
||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||
self.description: str = metadata.get("description", "")
|
||
self.parameters: List[Dict[str, Any]] = metadata.get("parameters", [])
|
||
self.parameters_raw: List[Dict[str, Any]] = metadata.get("parameters_raw", [])
|
||
super().__init__(name, component_type, plugin_id, metadata)
|
||
|
||
|
||
class EventHandlerEntry(ComponentEntry):
|
||
"""EventHandler 组件条目"""
|
||
|
||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||
self.event_type: str = metadata.get("event_type", "")
|
||
self.weight: int = metadata.get("weight", 0)
|
||
self.intercept_message: bool = metadata.get("intercept_message", False)
|
||
super().__init__(name, component_type, plugin_id, metadata)
|
||
|
||
|
||
class HookHandlerEntry(ComponentEntry):
|
||
"""WorkflowHandler 组件条目"""
|
||
|
||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||
self.stage: str = metadata.get("stage", "")
|
||
self.priority: int = metadata.get("priority", 0)
|
||
self.blocking: bool = metadata.get("blocking", False)
|
||
super().__init__(name, component_type, plugin_id, metadata)
|
||
|
||
|
||
class MessageGatewayEntry(ComponentEntry):
|
||
"""MessageGateway 组件条目"""
|
||
|
||
def __init__(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> None:
|
||
self.route_type: str = self._normalize_route_type(metadata.get("route_type", ""))
|
||
self.platform: str = str(metadata.get("platform", "") or "").strip()
|
||
self.protocol: str = str(metadata.get("protocol", "") or "").strip()
|
||
self.account_id: str = str(metadata.get("account_id", "") or "").strip()
|
||
self.scope: str = str(metadata.get("scope", "") or "").strip()
|
||
super().__init__(name, component_type, plugin_id, metadata)
|
||
|
||
@staticmethod
|
||
def _normalize_route_type(raw_value: Any) -> str:
|
||
"""规范化消息网关路由类型。
|
||
|
||
Args:
|
||
raw_value: 原始路由类型值。
|
||
|
||
Returns:
|
||
str: 规范化后的路由类型。
|
||
|
||
Raises:
|
||
ValueError: 当路由类型不受支持时抛出。
|
||
"""
|
||
|
||
normalized_value = str(raw_value or "").strip().lower()
|
||
route_type_aliases = {
|
||
"send": "send",
|
||
"receive": "receive",
|
||
"recv": "receive",
|
||
"recive": "receive",
|
||
"duplex": "duplex",
|
||
}
|
||
route_type = route_type_aliases.get(normalized_value)
|
||
if route_type is None:
|
||
raise ValueError(f"MessageGateway 路由类型不合法: {raw_value}")
|
||
return route_type
|
||
|
||
@property
|
||
def supports_send(self) -> bool:
|
||
"""返回当前网关是否支持出站。"""
|
||
|
||
return self.route_type in {"send", "duplex"}
|
||
|
||
@property
|
||
def supports_receive(self) -> bool:
|
||
"""返回当前网关是否支持入站。"""
|
||
|
||
return self.route_type in {"receive", "duplex"}
|
||
|
||
|
||
class ComponentRegistry:
|
||
"""Host-side 组件注册表
|
||
|
||
由 Supervisor 在收到 plugin.register_components 时调用。
|
||
供业务层查询可用组件、匹配命令、调度 action/event 等。
|
||
"""
|
||
|
||
def __init__(self) -> None:
|
||
# 全量索引
|
||
self._components: Dict[str, ComponentEntry] = {} # full_name -> comp
|
||
|
||
# 按类型索引
|
||
self._by_type: Dict[ComponentTypes, Dict[str, ComponentEntry]] = {
|
||
comp_type: {} for comp_type in ComponentTypes
|
||
} # component_type -> (full_name -> comp)
|
||
|
||
# 按插件索引
|
||
self._by_plugin: Dict[str, List[ComponentEntry]] = {}
|
||
|
||
def clear(self) -> None:
|
||
"""清空全部组件注册状态。"""
|
||
self._components.clear()
|
||
for type_dict in self._by_type.values():
|
||
type_dict.clear()
|
||
self._by_plugin.clear()
|
||
|
||
# ====== 注册 / 注销 ======
|
||
def register_component(self, name: str, component_type: str, plugin_id: str, metadata: Dict[str, Any]) -> bool:
|
||
"""注册单个组件
|
||
|
||
Args:
|
||
name: 组件名称(不含插件id前缀)
|
||
component_type: 组件类型(如 `ACTION`、`COMMAND` 等)
|
||
plugin_id: 插件id
|
||
metadata: 组件元数据
|
||
Returns:
|
||
success (bool): 是否成功注册(失败原因通常是组件类型无效)
|
||
"""
|
||
try:
|
||
if component_type == ComponentTypes.ACTION:
|
||
comp = ActionEntry(name, component_type, plugin_id, metadata)
|
||
elif component_type == ComponentTypes.COMMAND:
|
||
comp = CommandEntry(name, component_type, plugin_id, metadata)
|
||
elif component_type == ComponentTypes.TOOL:
|
||
comp = ToolEntry(name, component_type, plugin_id, metadata)
|
||
elif component_type == ComponentTypes.EVENT_HANDLER:
|
||
comp = EventHandlerEntry(name, component_type, plugin_id, metadata)
|
||
elif component_type == ComponentTypes.HOOK_HANDLER:
|
||
comp = HookHandlerEntry(name, component_type, plugin_id, metadata)
|
||
elif component_type == ComponentTypes.MESSAGE_GATEWAY:
|
||
comp = MessageGatewayEntry(name, component_type, plugin_id, metadata)
|
||
else:
|
||
raise ValueError(f"组件类型 {component_type} 不存在")
|
||
except ValueError:
|
||
logger.error(f"组件类型 {component_type} 不存在")
|
||
return False
|
||
|
||
if comp.full_name in self._components:
|
||
logger.warning(f"组件 {comp.full_name} 已存在,覆盖")
|
||
old_comp = self._components[comp.full_name]
|
||
# 从 _by_plugin 列表中移除旧条目,防止幽灵组件堆积
|
||
old_list = self._by_plugin.get(old_comp.plugin_id)
|
||
if old_list is not None:
|
||
with contextlib.suppress(ValueError):
|
||
old_list.remove(old_comp)
|
||
# 从旧类型索引中移除,防止类型变更时幽灵残留
|
||
if old_type_dict := self._by_type.get(old_comp.component_type):
|
||
old_type_dict.pop(comp.full_name, None)
|
||
|
||
self._components[comp.full_name] = comp
|
||
self._by_type[comp.component_type][comp.full_name] = comp
|
||
self._by_plugin.setdefault(plugin_id, []).append(comp)
|
||
|
||
return True
|
||
|
||
def register_plugin_components(self, plugin_id: str, components: List[Dict[str, Any]]) -> int:
|
||
"""批量注册一个插件的所有组件,返回成功注册数。
|
||
Args:
|
||
plugin_id (str): 插件id
|
||
components (List[Dict[str, Any]]): 组件字典列表,每个组件包含 name, component_type, metadata 等字段
|
||
Returns:
|
||
count (int): 成功注册的组件数量
|
||
"""
|
||
count = 0
|
||
for comp_data in components:
|
||
ok = self.register_component(
|
||
name=comp_data.get("name", ""),
|
||
component_type=comp_data.get("component_type", ""),
|
||
plugin_id=plugin_id,
|
||
metadata=comp_data.get("metadata", {}),
|
||
)
|
||
if ok:
|
||
count += 1
|
||
return count
|
||
|
||
def remove_components_by_plugin(self, plugin_id: str) -> int:
|
||
"""移除某个插件的所有组件,返回移除数量。
|
||
|
||
Args:
|
||
plugin_id (str): 插件id
|
||
Returns:
|
||
count (int): 移除的组件数量
|
||
"""
|
||
comps = self._by_plugin.pop(plugin_id, [])
|
||
for comp in comps:
|
||
self._components.pop(comp.full_name, None)
|
||
if type_dict := self._by_type.get(comp.component_type):
|
||
type_dict.pop(comp.full_name, None)
|
||
return len(comps)
|
||
|
||
# ====== 启用 / 禁用 ======
|
||
def check_component_enabled(self, component: ComponentEntry, session_id: Optional[str] = None):
|
||
if session_id and session_id in component.disabled_session:
|
||
return False
|
||
return component.enabled
|
||
|
||
def toggle_component_status(self, full_name: str, enabled: bool, session_id: Optional[str] = None) -> bool:
|
||
"""启用或禁用指定组件。
|
||
|
||
Args:
|
||
full_name (str): 组件全名
|
||
enabled (bool): 使能情况
|
||
session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
|
||
Returns:
|
||
success (bool): 是否成功设置(失败原因通常是组件不存在)
|
||
"""
|
||
comp = self._components.get(full_name)
|
||
if comp is None:
|
||
return False
|
||
if session_id:
|
||
if enabled:
|
||
comp.disabled_session.discard(session_id)
|
||
else:
|
||
comp.disabled_session.add(session_id)
|
||
else:
|
||
comp.enabled = enabled
|
||
return True
|
||
|
||
def toggle_plugin_status(self, plugin_id: str, enabled: bool, session_id: Optional[str] = None) -> int:
|
||
"""批量启用或禁用某插件的所有组件。
|
||
|
||
Args:
|
||
plugin_id (str): 插件id
|
||
enabled (bool): 使能情况
|
||
session_id (Optional[str]): 可选的会话ID,仅对该会话禁用(如果提供)
|
||
Returns:
|
||
count (int): 成功设置的组件数量(失败原因通常是插件不存在)
|
||
"""
|
||
comps = self._by_plugin.get(plugin_id, [])
|
||
for comp in comps:
|
||
if session_id:
|
||
if enabled:
|
||
comp.disabled_session.discard(session_id)
|
||
else:
|
||
comp.disabled_session.add(session_id)
|
||
else:
|
||
comp.enabled = enabled
|
||
return len(comps)
|
||
|
||
def get_component(self, full_name: str) -> Optional[ComponentEntry]:
|
||
"""按全名查询。
|
||
|
||
Args:
|
||
full_name (str): 组件全名
|
||
Returns:
|
||
component (Optional[ComponentEntry]): 组件条目,未找到时为 None
|
||
"""
|
||
return self._components.get(full_name)
|
||
|
||
def get_components_by_type(
|
||
self, component_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||
) -> List[ComponentEntry]:
|
||
"""按类型查询组件
|
||
|
||
Args:
|
||
component_type (str): 组件类型(如 `ACTION`、`COMMAND` 等)
|
||
enabled_only (bool): 是否仅返回启用的组件
|
||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||
Returns:
|
||
components (List[ComponentEntry]): 组件条目列表
|
||
"""
|
||
try:
|
||
comp_type = ComponentTypes(component_type)
|
||
except ValueError:
|
||
logger.error(f"组件类型 {component_type} 不存在")
|
||
raise
|
||
type_dict = self._by_type.get(comp_type, {})
|
||
if enabled_only:
|
||
return [c for c in type_dict.values() if self.check_component_enabled(c, session_id)]
|
||
return list(type_dict.values())
|
||
|
||
def get_components_by_plugin(
|
||
self, plugin_id: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||
) -> List[ComponentEntry]:
|
||
"""按插件查询组件。
|
||
|
||
Args:
|
||
plugin_id (str): 插件ID
|
||
enabled_only (bool): 是否仅返回启用的组件
|
||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||
Returns:
|
||
components (List[ComponentEntry]): 组件条目列表
|
||
"""
|
||
comps = self._by_plugin.get(plugin_id, [])
|
||
return [c for c in comps if self.check_component_enabled(c, session_id)] if enabled_only else list(comps)
|
||
|
||
def find_command_by_text(
|
||
self, text: str, session_id: Optional[str] = None
|
||
) -> Optional[Tuple[ComponentEntry, Dict[str, Any]]]:
|
||
"""通过文本匹配命令正则,返回 (组件, matched_groups) 元组。
|
||
|
||
matched_groups 为正则命名捕获组 dict,别名匹配时为空 dict。
|
||
Args:
|
||
text (str): 待匹配文本
|
||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||
Returns:
|
||
result (Optional[tuple[ComponentEntry, Dict[str, Any]]]): 匹配到的组件及正则捕获组,未找到时为 None
|
||
"""
|
||
for comp in self._by_type.get(ComponentTypes.COMMAND, {}).values():
|
||
if not self.check_component_enabled(comp, session_id):
|
||
continue
|
||
if not isinstance(comp, CommandEntry):
|
||
continue
|
||
if comp.compiled_pattern:
|
||
if m := comp.compiled_pattern.search(text):
|
||
return comp, m.groupdict()
|
||
# 别名匹配
|
||
for alias in comp.aliases:
|
||
if text.startswith(alias):
|
||
return comp, {}
|
||
return None
|
||
|
||
def get_event_handlers(
|
||
self, event_type: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||
) -> List[EventHandlerEntry]:
|
||
"""查询指定事件类型的事件处理器组件。
|
||
|
||
Args:
|
||
event_type (str): 事件类型
|
||
enabled_only (bool): 是否仅返回启用的组件
|
||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||
Returns:
|
||
handlers (List[EventHandlerEntry]): 符合条件的 EventHandler 组件列表,按 weight 降序排序
|
||
"""
|
||
handlers: List[EventHandlerEntry] = []
|
||
for comp in self._by_type.get(ComponentTypes.EVENT_HANDLER, {}).values():
|
||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||
continue
|
||
if not isinstance(comp, EventHandlerEntry):
|
||
continue
|
||
if comp.event_type == event_type:
|
||
handlers.append(comp)
|
||
handlers.sort(key=lambda c: c.weight, reverse=True)
|
||
return handlers
|
||
|
||
def get_hook_handlers(
|
||
self, stage: str, *, enabled_only: bool = True, session_id: Optional[str] = None
|
||
) -> List[HookHandlerEntry]:
|
||
"""获取特定 hook 阶段的所有步骤,按 priority 降序。
|
||
|
||
Args:
|
||
stage: hook 名称
|
||
enabled_only: 是否仅返回启用的组件
|
||
session_id: 可选的会话ID,若提供则考虑会话禁用状态
|
||
Returns:
|
||
handlers (List[HookHandlerEntry]): 符合条件的 HookHandler 组件列表,按 priority 降序排序
|
||
"""
|
||
handlers: List[HookHandlerEntry] = []
|
||
for comp in self._by_type.get(ComponentTypes.HOOK_HANDLER, {}).values():
|
||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||
continue
|
||
if not isinstance(comp, HookHandlerEntry):
|
||
continue
|
||
if comp.stage == stage:
|
||
handlers.append(comp)
|
||
handlers.sort(key=lambda c: c.priority, reverse=True)
|
||
return handlers
|
||
|
||
def get_message_gateway(
|
||
self,
|
||
plugin_id: str,
|
||
name: str,
|
||
*,
|
||
enabled_only: bool = True,
|
||
session_id: Optional[str] = None,
|
||
) -> Optional[MessageGatewayEntry]:
|
||
"""按插件和组件名获取单个消息网关。
|
||
|
||
Args:
|
||
plugin_id: 插件 ID。
|
||
name: 网关组件名称。
|
||
enabled_only: 是否仅返回启用的组件。
|
||
session_id: 可选的会话 ID。
|
||
|
||
Returns:
|
||
Optional[MessageGatewayEntry]: 若存在则返回消息网关条目。
|
||
"""
|
||
|
||
component = self._components.get(f"{plugin_id}.{name}")
|
||
if not isinstance(component, MessageGatewayEntry):
|
||
return None
|
||
if enabled_only and not self.check_component_enabled(component, session_id):
|
||
return None
|
||
return component
|
||
|
||
def get_message_gateways(
|
||
self,
|
||
*,
|
||
plugin_id: Optional[str] = None,
|
||
platform: str = "",
|
||
route_type: str = "",
|
||
enabled_only: bool = True,
|
||
session_id: Optional[str] = None,
|
||
) -> List[MessageGatewayEntry]:
|
||
"""查询消息网关组件列表。
|
||
|
||
Args:
|
||
plugin_id: 可选的插件 ID 过滤条件。
|
||
platform: 可选的平台过滤条件。
|
||
route_type: 可选的路由类型过滤条件。
|
||
enabled_only: 是否仅返回启用的组件。
|
||
session_id: 可选的会话 ID。
|
||
|
||
Returns:
|
||
List[MessageGatewayEntry]: 符合条件的消息网关组件列表。
|
||
"""
|
||
|
||
normalized_platform = str(platform or "").strip()
|
||
normalized_route_type = str(route_type or "").strip().lower()
|
||
gateways: List[MessageGatewayEntry] = []
|
||
for comp in self._by_type.get(ComponentTypes.MESSAGE_GATEWAY, {}).values():
|
||
if not isinstance(comp, MessageGatewayEntry):
|
||
continue
|
||
if plugin_id and comp.plugin_id != plugin_id:
|
||
continue
|
||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||
continue
|
||
if normalized_platform and comp.platform != normalized_platform:
|
||
continue
|
||
if normalized_route_type and comp.route_type != normalized_route_type:
|
||
continue
|
||
gateways.append(comp)
|
||
return gateways
|
||
|
||
def get_tools(self, *, enabled_only: bool = True, session_id: Optional[str] = None) -> List[ToolEntry]:
|
||
"""查询所有工具组件。
|
||
|
||
Args:
|
||
enabled_only (bool): 是否仅返回启用的组件
|
||
session_id (Optional[str]): 可选的会话ID,若提供则考虑会话禁用状态
|
||
Returns:
|
||
tools (List[ToolEntry]): 符合条件的 Tool 组件列表
|
||
"""
|
||
tools: List[ToolEntry] = []
|
||
for comp in self._by_type.get(ComponentTypes.TOOL, {}).values():
|
||
if enabled_only and not self.check_component_enabled(comp, session_id):
|
||
continue
|
||
if isinstance(comp, ToolEntry):
|
||
tools.append(comp)
|
||
return tools
|
||
|
||
# ====== 统计信息 ======
|
||
def get_stats(self) -> StatusDict:
|
||
"""获取注册统计。
|
||
|
||
Returns:
|
||
stats (StatusDict): 组件统计信息,包括总数、各类型数量、插件数量等
|
||
"""
|
||
stats: StatusDict = {"total": len(self._components)} # type: ignore
|
||
for comp_type, type_dict in self._by_type.items():
|
||
stats[comp_type.value] = len(type_dict)
|
||
stats["plugins"] = len(self._by_plugin)
|
||
return stats
|